├── .gitignore ├── LICENSE ├── Makefile ├── README.md └── src ├── arithmetic ├── Deinterleave.cpp ├── Deinterleave.h ├── ExprUsesVar.h ├── Interval.cpp ├── Interval.h ├── ModulusRemainder.cpp ├── ModulusRemainder.h ├── Scope.h ├── Simplify.cpp ├── Simplify.h ├── Substitute.cpp └── Substitute.h ├── base ├── Debug.cpp ├── Debug.h ├── Error.cpp ├── Error.h ├── Float16.h ├── Float16Opt.cpp ├── RoundingMode.h ├── Type.cpp ├── Type.h ├── TypeBase.h ├── Util.cpp └── Util.h ├── ir ├── Expr.h ├── FunctionBase.h ├── IR.cpp ├── IR.h ├── IREquality.cpp ├── IREquality.h ├── IRMutator.cpp ├── IRMutator.h ├── IROperator.cpp ├── IROperator.h ├── IRPrinter.cpp ├── IRPrinter.h ├── IRVisitor.cpp ├── IRVisitor.h └── Range.h └── tvm └── node ├── container.h ├── ir_functor.h ├── memory.h ├── node.cpp ├── node.h └── node_base.h /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | *~ 31 | build 32 | 33 | # Miscellany 34 | tags 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 HalideIR contributors 2 | 3 | Copyright (c) 2012-2014 MIT CSAIL, Google Inc., and other contributors 4 | 5 | HalideIR is derived from the Halide project. 
6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | DMLC_CORE_PATH ?= ../dmlc-core 2 | DLPACK_INCLUDE_PATH ?= ../dlpack/include 3 | 4 | LDFLAGS = -pthread -lm 5 | CFLAGS = -std=c++11 -Wall -O2\ 6 | -Iinclude -I${DMLC_CORE_PATH}/include -I../include -I${DLPACK_INCLUDE_PATH} -Isrc -fPIC -fvisibility=hidden 7 | 8 | ifdef no_rtti 9 | CFLAGS += -fno-rtti 10 | endif 11 | 12 | # specify tensor path 13 | .PHONY: clean all test doc lint 14 | 15 | CCSUFFIX=cpp 16 | 17 | all: lib/libHalideIR.a lib/libHalideIR.so 18 | SRC = $(wildcard src/*.$(CCSUFFIX) src/*/*.$(CCSUFFIX) src/*/*/*.$(CCSUFFIX)) 19 | ALL_OBJ = $(patsubst src/%.$(CCSUFFIX), build/%.o, $(SRC)) 20 | ALL_DEP = $(ALL_OBJ) 21 | 22 | 23 | build/%.o: src/%.$(CCSUFFIX) 24 | @mkdir -p $(@D) 25 | $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d 26 | $(CXX) -c $(CFLAGS) $< -o $@ 27 | 28 | 29 | lib/libHalideIR.a: $(ALL_DEP) 30 | @mkdir -p $(@D) 31 | ar crv $@ $(filter %.o, $?) 32 | 33 | lib/libHalideIR.so: $(ALL_DEP) 34 | @mkdir -p $(@D) 35 | $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) 36 | 37 | 38 | lint: 39 | python2 ${DMLC_CORE_PATH}/scripts/lint.py tvm cpp include src/tvm 40 | 41 | doc: 42 | doxygen docs/Doxyfile 43 | 44 | clean: 45 | $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d 46 | 47 | # Header dependency files written by the -MM step of the build/%.o rule; one glob per source depth matched by SRC. 48 | -include build/*.d 49 | -include build/*/*.d 50 | -include build/*/*/*.d 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HalideIR: Symbolic Arithmetic IR Module 2 | 3 | HalideIR is a base module for building symbolic expression and arithmetic simplification 4 | for building new DSLs. It is isolated and refactored from part of Halide project (credit should go to the original authors). 5 | It was used in earlier versions of the tvm project.
6 | 7 | Note that some portions of the TVM compiler outside of this folder are also adapted from the Halide codebase, 8 | where we needed similar logic (e.g. loop vectorization). These are commented where they occur. 9 | 10 | ## Motivation 11 | 12 | We built this component during the development of the TVM project. 13 | A symbolic expression data structure and simplification library 14 | are essential for such a compiler stack project. 15 | 16 | Unfortunately, there is no standalone module that fits our needs at the moment. 17 | We find that the IR and simplification module of the Halide project fits such purposes nicely. 18 | So we isolated and refactored the corresponding module into this repo. 19 | 20 | The major goals are minimalism and interoperability with more front-end languages. 21 | HalideIR is used in the TVM project. Here are the major improvements: 22 | 23 | - An isolated, dependency-free symbolic IR and simplification (no LLVM dependency). 24 | - The project is modularized into logical components. 25 | - All IR structures are serializable and publicly accessible from front-end languages (e.g. python) 26 | - This eases development and prototyping in python 27 | - Designed with reusability in mind when adding new DSL structures. 28 | - A runtime dispatching mechanism is introduced to make it easy to add new IR nodes. 29 | - Simplified variable definition rule: change from match by string to match by pointer 30 | - This ensures each variable has a single definition location (like SSA) 31 | 32 | Besides these changes, we also refactored the code to resolve some of the issues in the codebase. 33 | Some of these changes address pain points raised in the Halide project. 34 | The detailed changes are listed in the later part of this document. 35 | 36 | ## Project Structure 37 | Based on code from Halide (release_2017_05_03). 38 | The code is componentized into four logical components 39 | 40 | - tvm: TVM container code for interoperation, basic data structures. 41 | - base: base utilities and types.
42 | - ir: The IR data structure 43 | - arithmetic: Arithmetic simplification. 44 | 45 | ### Code Style 46 | We keep old files in the old code style, but use Google C++ style in newly added files. 47 | 48 | ### List of Changes 49 | - Replace the implementation of Float16.cpp with a simpler version that does not depend on LLVM 50 | - IntrusivePtr 51 | - Remove IntrusivePtr, change everything to be based on std::shared_ptr 52 | - See tvm/node.h 53 | - Add support for IR reflection via NodeRef.VisitAttrs 54 | - All the IR is constructible from python via TVM 55 | - This enables quick prototyping and debugging from python 56 | - Call 57 | - Remove Parameter, BufferPtr from Variable, Call, Load 58 | - This simplifies dependencies; we also find it is cleaner to simply use Variable 59 | for Parameter. 60 | - AssertStmt 61 | - Add body field to AssertStmt, which represents the scope where the assert condition holds 62 | - This makes it easier to write visitors that take advantage of the scoped assert information. 63 | - AttrStmt 64 | - This is a new Stmt that can be used to annotate attributes of certain things 65 | (e.g. content type of buffer). 66 | This removes the need for string matching for properties and hacks around string matching in Let. 67 | - We use this extensively to add new hints in the environment scope. 68 | - FunctionContent 69 | - Make FunctionBaseNode abstract, to replace FunctionContent 70 | - Provide, Realize, Prefetch, ProducerConsumer, Call 71 | - Remove use of string name matching of function 72 | - Wherever a function is needed, use FunctionRef 73 | - FunctionRef is uniquely matched by its internal pointer. 74 | - Variable, Store, Allocate, Free, For, Load, Let, LetStmt 75 | - Remove use of string name matching of Variable 76 | - Wherever a variable is needed, use VarExpr 77 | - VarExpr is uniquely matched by its internal pointer. 78 | - Variable 79 | - Rename Variable.name -> Variable.name_hint, to avoid confusion.
80 | - Variable.name_hint is not used to uniquely identify the variable. 81 | - By default, each Variable should be only defined once 82 | - This is in analog to SSA, and can make ir pass simplier(remove need of scoping) 83 | - Provide, Realize, Prefetch 84 | - Change Provide and Realize to associate with value_index. 85 | - Original Provide/Realize with multiple values can be represented by several Provide and Realize chained together 86 | - This allows outputs of a Function to have different shapes. 87 | - Make every field the IR reflectable and accessible from python 88 | - std::vector<> - > Array<>(tvm/container.h) 89 | - struct Range -> Range(Range.h) 90 | - Range 91 | - Remove constructor Range(min, extent) to Range::make_with_min_extent(min, extent) 92 | - The original constructor could make new user confuse since a typical way is 93 | range(begin, end) in both c++ and python 94 | - Simplify Visitor 95 | - Remove use of string name mapping in the Scope 96 | - Remove Scope in Substitute to check name conflicts because we use variable pointer matching 97 | - Add Expr&/Stmt& to visit interface in the IRVisitor and IRMutator 98 | - Add a more flexible RTTI and dynamic dispatch support, see src/tvm/node.h 99 | - IRFunctor allows plugin of new IR Node more easily, without chaning interface 100 | -------------------------------------------------------------------------------- /src/arithmetic/Deinterleave.cpp: -------------------------------------------------------------------------------- 1 | #include "Deinterleave.h" 2 | #include "base/Debug.h" 3 | #include "ir/IRMutator.h" 4 | #include "ir/IROperator.h" 5 | #include "ir/IREquality.h" 6 | #include "ir/IRPrinter.h" 7 | #include "ModulusRemainder.h" 8 | #include "Scope.h" 9 | #include "Simplify.h" 10 | 11 | namespace HalideIR { 12 | namespace Internal { 13 | 14 | using std::pair; 15 | // Deinterleaver rewrites a vector expression so that only the lanes starting_lane + k * lane_stride, for k in [0, new_lanes), are kept. The three public fields have no initializers, so callers must set all of them before calling mutate(). 16 | class Deinterleaver : public IRMutator { 17 | public: 18 | int starting_lane; 19 | int new_lanes; 20 | int lane_stride; 21 | 22
| Deinterleaver() {} 23 | 24 | private: 25 | Scope internal; 26 | 27 | using IRMutator::visit; 28 | 29 | void visit(const Broadcast *op, const Expr &self) { 30 | if (new_lanes == 1) { 31 | expr = op->value; 32 | } else { 33 | expr = Broadcast::make(op->value, new_lanes); 34 | } 35 | } 36 | 37 | void visit(const Load *op, const Expr &self) { 38 | if (op->type.is_scalar()) { 39 | expr = self; 40 | } else { 41 | Type t = op->type.with_lanes(new_lanes); 42 | expr = Load::make(t, op->buffer_var, mutate(op->index), mutate(op->predicate)); 43 | } 44 | } 45 | 46 | void visit(const Ramp *op, const Expr &self) { 47 | expr = op->base + starting_lane * op->stride; 48 | internal_assert(expr.type() == op->base.type()); 49 | if (new_lanes > 1) { 50 | expr = Ramp::make(expr, op->stride * lane_stride, new_lanes); 51 | } 52 | } 53 | 54 | void visit(const Variable *op, const Expr &self) { 55 | if (op->type.is_scalar()) { 56 | expr = self; 57 | } else { 58 | if (internal.contains(op)) { 59 | expr = internal.get(op); 60 | } else { 61 | // Uh-oh, we don't know how to deinterleave this vector expression 62 | // Make llvm do it 63 | Array indices; 64 | for (int i = 0; i < new_lanes; i++) { 65 | indices.push_back(IntImm::make(Int(32), starting_lane + lane_stride * i)); 66 | } 67 | expr = Shuffle::make({self}, indices); 68 | } 69 | } 70 | } 71 | 72 | void visit(const Cast *op, const Expr &self) { 73 | if (op->type.is_scalar()) { 74 | expr = self; 75 | } else { 76 | Type t = op->type.with_lanes(new_lanes); 77 | expr = Cast::make(t, mutate(op->value)); 78 | } 79 | } 80 | 81 | void visit(const Call *op, const Expr &self) { 82 | Type t = op->type.with_lanes(new_lanes); 83 | 84 | // Don't mutate scalars 85 | if (op->type.is_scalar()) { 86 | expr = self; 87 | } else if (op->is_intrinsic(Call::glsl_texture_load)) { 88 | // glsl_texture_load returns a result.
Deinterleave by 89 | // wrapping the call in a shuffle_vector 90 | Array indices; 91 | for (int i = 0; i < new_lanes; i++) { 92 | indices.push_back(IntImm::make(Int(32), i*lane_stride + starting_lane)); 93 | } 94 | expr = Shuffle::make({self}, indices); 95 | } else { 96 | 97 | // Vector calls are always parallel across the lanes, so we 98 | // can just deinterleave the args. 99 | 100 | // Beware of other intrinsics for which this is not true! 101 | // Currently there's only interleave_vectors and 102 | // shuffle_vector. 103 | 104 | std::vector args(op->args.size()); 105 | for (size_t i = 0; i < args.size(); i++) { 106 | args[i] = mutate(op->args[i]); 107 | } 108 | 109 | expr = Call::make(t, op->name, args, op->call_type, 110 | op->func, op->value_index); 111 | } 112 | } 113 | 114 | void visit(const Let *op, const Expr& self) { 115 | if (op->type.is_vector()) { 116 | Expr new_value = mutate(op->value); 117 | Type new_type = new_value.type(); 118 | VarExpr new_var = Variable::make(new_type, "t"); 119 | internal.push(op->var.get(), new_var); 120 | Expr body = mutate(op->body); 121 | internal.pop(op->var.get()); 122 | 123 | // Define the new name. 124 | expr = Let::make(new_var, new_value, body); 125 | 126 | // Someone might still use the old name. 127 | expr = Let::make(op->var, op->value, expr); 128 | } else { 129 | IRMutator::visit(op, self); 130 | } 131 | } 132 | // Interleaves are taken apart structurally where possible; every other shuffle is rewritten by selecting a subset of its own indices. 133 | void visit(const Shuffle *op, const Expr &self) { 134 | if (op->is_interleave()) { 135 | internal_assert(starting_lane >= 0 && starting_lane < lane_stride); 136 | if ((int)op->vectors.size() == lane_stride) { 137 | expr = op->vectors[starting_lane]; 138 | } else if ((int)op->vectors.size() % lane_stride == 0) { 139 | // Pick up every lane-stride vector.
140 | std::vector new_vectors(op->vectors.size() / lane_stride); 141 | for (size_t i = 0; i < new_vectors.size(); i++) { 142 | new_vectors[i] = op->vectors[i*lane_stride + starting_lane]; 143 | } 144 | expr = Shuffle::make_interleave(new_vectors); 145 | } else { 146 | // Interleave some vectors then deinterleave by some other factor... 147 | // Brute force! 148 | Array indices; 149 | for (int i = 0; i < new_lanes; i++) { 150 | indices.push_back(IntImm::make(Int(32), i*lane_stride + starting_lane)); 151 | } 152 | expr = Shuffle::make({self}, indices); 153 | } 154 | } else { 155 | // Extract every nth numeric arg to the shuffle. Each entry of op->indices addresses a lane of the concatenation of op->vectors (not a lane of this shuffle's result), so the selected indices must be applied to op->vectors rather than to self. 156 | Array indices; 157 | for (int i = 0; i < new_lanes; i++) { 158 | int idx = i * lane_stride + starting_lane; 159 | indices.push_back(op->indices[idx]); 160 | } 161 | expr = Shuffle::make(op->vectors, indices); 162 | } 163 | } 164 | }; 165 | // Returns the odd lanes (1, 3, 5, ...) of e, simplified; e must have an even number of lanes. 166 | Expr extract_odd_lanes(Expr e) { 167 | internal_assert(e.type().lanes() % 2 == 0); 168 | Deinterleaver d; 169 | d.starting_lane = 1; 170 | d.lane_stride = 2; 171 | d.new_lanes = e.type().lanes()/2; 172 | e = d.mutate(e); 173 | return simplify(e); 174 | } 175 | // Returns the even lanes (0, 2, 4, ...) of e, simplified; e must have an even number of lanes. 176 | Expr extract_even_lanes(Expr e) { 177 | internal_assert(e.type().lanes() % 2 == 0); 178 | Deinterleaver d; 179 | d.starting_lane = 0; 180 | d.lane_stride = 2; 181 | d.new_lanes = (e.type().lanes()+1)/2; 182 | e = d.mutate(e); 183 | return simplify(e); 184 | } 185 | // Returns lane `lane` of e as a single-lane (new_lanes = 1) expression, simplified. 186 | Expr extract_lane(Expr e, int lane) { 187 | Deinterleaver d; 188 | d.starting_lane = lane; 189 | d.lane_stride = e.type().lanes(); 190 | d.new_lanes = 1; 191 | e = d.mutate(e); 192 | return simplify(e); 193 | } 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/arithmetic/Deinterleave.h: -------------------------------------------------------------------------------- 1 | #ifndef DEINTERLEAVE_H 2 | #define DEINTERLEAVE_H 3 | 4 | /** \file 5 | * 6 | * Defines methods for splitting up a vector into the even lanes and 7
| * the odd lanes. Useful for optimizing expressions such as select(x % 8 | * 2, f(x/2), g(x/2)) 9 | */ 10 | 11 | #include "ir/IR.h" 12 | 13 | namespace HalideIR { 14 | namespace Internal { 15 | 16 | /** Extract the odd-numbered lanes in a vector */ 17 | EXPORT Expr extract_odd_lanes(Expr a); 18 | 19 | /** Extract the even-numbered lanes in a vector */ 20 | EXPORT Expr extract_even_lanes(Expr a); 21 | 22 | /** Extract the nth lane of a vector */ 23 | EXPORT Expr extract_lane(Expr vec, int lane); 24 | 25 | } 26 | } 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /src/arithmetic/ExprUsesVar.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_EXPR_USES_VAR_H 2 | #define HALIDEIR_EXPR_USES_VAR_H 3 | 4 | /** \file 5 | * Defines a method to determine if an expression depends on some variables. 6 | */ 7 | 8 | #include "ir/IR.h" 9 | #include "ir/IRVisitor.h" 10 | #include "Scope.h" 11 | 12 | namespace HalideIR { 13 | namespace Internal { 14 | // Visitor that sets `result` to true once it reaches a Variable contained in `vars`. A Variable bound in `scope` (which may chain to the containing scope passed to the constructor) is followed into its bound value via include(). 15 | template 16 | class ExprUsesVars : public IRGraphVisitor { 17 | using IRGraphVisitor::visit; 18 | 19 | const Scope &vars; 20 | Scope scope; 21 | 22 | void visit(const Variable *v, const Expr&) { 23 | if (vars.contains(v)) { 24 | result = true; 25 | } else if (scope.contains(v)) { 26 | include(scope.get(v)); 27 | } 28 | } 29 | public: 30 | ExprUsesVars(const Scope &v, const Scope *s = nullptr) : vars(v), result(false) { 31 | scope.set_containing_scope(s); 32 | } 33 | bool result; // false until a Variable in `vars` has been visited 34 | }; 35 | 36 | /** Test if a statement or expression references the given variable.
*/ 37 | template 38 | inline bool stmt_or_expr_uses_var(StmtOrExpr e, const Variable* v) { 39 | Scope s; 40 | s.push(v, 0); // only membership is checked (vars.contains), so the pushed value is never read 41 | ExprUsesVars uses(s); 42 | e.accept(&uses); 43 | return uses.result; 44 | } 45 | 46 | /** Test if a statement or expression references any of the variables 47 | * in a scope, additionally considering variables bound to Expr's in 48 | * the scope provided in the final argument. 49 | */ 50 | template 51 | inline bool stmt_or_expr_uses_vars(StmtOrExpr e, const Scope &v, 52 | const Scope &s = Scope::empty_scope()) { 53 | ExprUsesVars uses(v, &s); 54 | e.accept(&uses); 55 | return uses.result; 56 | } 57 | 58 | /** Test if an expression references the given variable. */ 59 | inline bool expr_uses_var(Expr e, const Variable* v) { 60 | return stmt_or_expr_uses_var(e, v); 61 | } 62 | 63 | /** Test if a statement references the given variable. */ 64 | inline bool stmt_uses_var(Stmt s, const Variable* v) { 65 | return stmt_or_expr_uses_var(s, v); 66 | } 67 | 68 | /** Test if an expression references any of the variables in a scope, 69 | * additionally considering variables bound to Expr's in the scope 70 | * provided in the final argument. 71 | */ 72 | template 73 | inline bool expr_uses_vars(Expr e, const Scope &v, 74 | const Scope &s = Scope::empty_scope()) { 75 | return stmt_or_expr_uses_vars(e, v, s); 76 | } 77 | 78 | /** Test if a statement references any of the variables in a scope, 79 | * additionally considering variables bound to Expr's in the scope 80 | * provided in the final argument.
81 | */ 82 | template 83 | inline bool stmt_uses_vars(Stmt e, const Scope &v, 84 | const Scope &s = Scope::empty_scope()) { 85 | return stmt_or_expr_uses_vars(e, v, s); 86 | } 87 | 88 | } 89 | } 90 | 91 | #endif 92 | -------------------------------------------------------------------------------- /src/arithmetic/Interval.cpp: -------------------------------------------------------------------------------- 1 | #include "Interval.h" 2 | #include "ir/IROperator.h" 3 | #include "ir/IREquality.h" 4 | #include "Simplify.h" 5 | 6 | namespace HalideIR { 7 | namespace Internal { 8 | 9 | // This is called repeatedly by bounds inference and the solver to 10 | // build large expressions, so we want to simplify eagerly to avoid 11 | // monster expressions. 12 | Expr Interval::make_max(Expr a, Expr b) { 13 | if (a.same_as(b)) return a; 14 | 15 | // Deal with infinities: pos_inf absorbs everything, neg_inf is the identity of max 16 | if (a.same_as(Interval::pos_inf)) return a; 17 | if (b.same_as(Interval::pos_inf)) return b; 18 | if (a.same_as(Interval::neg_inf)) return b; 19 | if (b.same_as(Interval::neg_inf)) return a; 20 | 21 | // Deep equality 22 | if (equal(a, b)) return a; 23 | 24 | // Constant fold, only when both operands are constants of the same kind (int, uint or float) 25 | const int64_t *ia = as_const_int(a); 26 | const int64_t *ib = as_const_int(b); 27 | const uint64_t *ua = as_const_uint(a); 28 | const uint64_t *ub = as_const_uint(b); 29 | const double *fa = as_const_float(a); 30 | const double *fb = as_const_float(b); 31 | if (ia && ib) return (*ia > *ib) ? a : b; 32 | if (ua && ub) return (*ua > *ub) ? a : b; 33 | if (fa && fb) return (*fa > *fb) ?
a : b; 34 | 35 | // Balance trees to the left, with constants pushed rightwards 36 | const Max *ma = a.as(); 37 | const Max *mb = b.as(); 38 | if (mb && !ma && !(is_const(mb->a) && is_const(mb->b))) { 39 | std::swap(ma, mb); 40 | std::swap(a, b); 41 | } 42 | if (ma && is_const(ma->b) && is_const(b)) { 43 | return Interval::make_max(ma->a, Interval::make_max(ma->b, b)); 44 | } 45 | if (ma && (ma->a.same_as(b) || ma->b.same_as(b))) { 46 | // b is already represented in a 47 | return a; 48 | } 49 | 50 | return Max::make(a, b); 51 | } 52 | // Eagerly-simplifying min; the mirror image of make_max above. 53 | Expr Interval::make_min(Expr a, Expr b) { 54 | if (a.same_as(b)) return a; 55 | 56 | // Deal with infinities: neg_inf absorbs everything, pos_inf is the identity of min 57 | if (a.same_as(Interval::pos_inf)) return b; 58 | if (b.same_as(Interval::pos_inf)) return a; 59 | if (a.same_as(Interval::neg_inf)) return a; 60 | if (b.same_as(Interval::neg_inf)) return b; 61 | 62 | // Deep equality 63 | if (equal(a, b)) return a; 64 | 65 | // Constant fold, only when both operands are constants of the same kind (int, uint or float) 66 | const int64_t *ia = as_const_int(a); 67 | const int64_t *ib = as_const_int(b); 68 | const uint64_t *ua = as_const_uint(a); 69 | const uint64_t *ub = as_const_uint(b); 70 | const double *fa = as_const_float(a); 71 | const double *fb = as_const_float(b); 72 | if (ia && ib) return (*ia > *ib) ? b : a; 73 | if (ua && ub) return (*ua > *ub) ? b : a; 74 | if (fa && fb) return (*fa > *fb) ?
b : a; 75 | 76 | // Balance trees to the left, with constants pushed rightwards 77 | const Min *ma = a.as(); 78 | const Min *mb = b.as(); 79 | if (mb && !ma && !(is_const(mb->a) && is_const(mb->b))) { 80 | std::swap(ma, mb); 81 | std::swap(a, b); 82 | } 83 | if (ma && is_const(ma->b) && is_const(b)) { 84 | return Interval::make_min(ma->a, Interval::make_min(ma->b, b)); 85 | } 86 | if (ma && (ma->a.same_as(b) || ma->b.same_as(b))) { 87 | // b is already represented in a 88 | return a; 89 | } 90 | 91 | return Min::make(a, b); 92 | } 93 | // Widen this interval so that it also covers the interval i. 94 | void Interval::include(const Interval &i) { 95 | max = Interval::make_max(max, i.max); 96 | min = Interval::make_min(min, i.min); 97 | } 98 | // Widen this interval so that it also covers the value e. 99 | void Interval::include(Expr e) { 100 | max = Interval::make_max(max, e); 101 | min = Interval::make_min(min, e); 102 | } 103 | // Smallest interval containing both a and b. 104 | Interval Interval::make_union(const Interval &a, const Interval &b) { 105 | Interval result = a; 106 | result.include(b); 107 | return result; 108 | } 109 | // Largest interval contained in both a and b. Returns nothing() if an infinity makes the result empty, or if both bounds share an integer type and min > max can be proven; otherwise the (possibly symbolic) bounds are returned as-is. 110 | Interval Interval::make_intersection(const Interval &a, const Interval &b) { 111 | auto min = Interval::make_max(a.min, b.min); 112 | auto max = Interval::make_min(a.max, b.max); 113 | if (min.same_as(Interval::pos_inf) || max.same_as(Interval::neg_inf)) { 114 | return Interval::nothing(); 115 | } 116 | else if (min.type() == max.type() && 117 | (min.type().is_int() || min.type().is_uint()) && 118 | can_prove(min > max)) { 119 | return Interval::nothing(); 120 | } 121 | else { 122 | return Interval(min, max); 123 | } 124 | } 125 | 126 | // Use Handle types for positive and negative infinity, to prevent 127 | // accidentally doing arithmetic on them.
128 | Expr Interval::pos_inf = Variable::make(Handle(), "pos_inf"); 129 | Expr Interval::neg_inf = Variable::make(Handle(), "neg_inf"); 130 | 131 | // Test helper: asserts that result has the same bounds as expected (deep equality) and reports the caller's line on failure. 132 | namespace { 133 | void check(Interval result, Interval expected, int line) { 134 | internal_assert(equal(result.min, expected.min) && 135 | equal(result.max, expected.max)) 136 | << "Interval test on line " << line << " failed\n" 137 | << " Expected [" << expected.min << ", " << expected.max << "]\n" 138 | << " Got [" << result.min << ", " << result.max << "]\n"; 139 | } 140 | } 141 | // Self-test for the Interval predicates and for make_union / make_intersection. 142 | void interval_test() { 143 | Interval e = Interval::everything(); 144 | Interval n = Interval::nothing(); 145 | Expr x = Variable::make(Int(32), "x"); 146 | Interval xp{x, Interval::pos_inf}; 147 | Interval xn{Interval::neg_inf, x}; 148 | Interval xx{x, x}; 149 | 150 | internal_assert(e.is_everything()); 151 | internal_assert(!e.has_upper_bound()); 152 | internal_assert(!e.has_lower_bound()); 153 | internal_assert(!e.is_empty()); 154 | internal_assert(!e.is_bounded()); 155 | internal_assert(!e.is_single_point()); 156 | 157 | internal_assert(!n.is_everything()); 158 | internal_assert(!n.has_upper_bound()); 159 | internal_assert(!n.has_lower_bound()); 160 | internal_assert(n.is_empty()); 161 | internal_assert(!n.is_bounded()); 162 | internal_assert(!n.is_single_point()); 163 | 164 | internal_assert(!xp.is_everything()); 165 | internal_assert(!xp.has_upper_bound()); 166 | internal_assert(xp.has_lower_bound()); 167 | internal_assert(!xp.is_empty()); 168 | internal_assert(!xp.is_bounded()); 169 | internal_assert(!xp.is_single_point()); 170 | 171 | internal_assert(!xn.is_everything()); 172 | internal_assert(xn.has_upper_bound()); 173 | internal_assert(!xn.has_lower_bound()); 174 | internal_assert(!xn.is_empty()); 175 | internal_assert(!xn.is_bounded()); 176 | internal_assert(!xn.is_single_point()); 177 | 178 | internal_assert(!xx.is_everything()); 179 | internal_assert(xx.has_upper_bound()); 180 |
internal_assert(xx.has_lower_bound()); 181 | internal_assert(!xx.is_empty()); 182 | internal_assert(xx.is_bounded()); 183 | internal_assert(xx.is_single_point()); 184 | 185 | check(Interval::make_union(xp, xn), e, __LINE__); 186 | check(Interval::make_union(e, xn), e, __LINE__); 187 | check(Interval::make_union(xn, e), e, __LINE__); 188 | check(Interval::make_union(xn, n), xn, __LINE__); 189 | check(Interval::make_union(n, xp), xp, __LINE__); 190 | check(Interval::make_union(xp, xp), xp, __LINE__); 191 | 192 | check(Interval::make_intersection(xp, xn), Interval::single_point(x), __LINE__); 193 | check(Interval::make_intersection(e, xn), xn, __LINE__); 194 | check(Interval::make_intersection(xn, e), xn, __LINE__); 195 | check(Interval::make_intersection(xn, n), n, __LINE__); 196 | check(Interval::make_intersection(n, xp), n, __LINE__); 197 | check(Interval::make_intersection(xp, xp), xp, __LINE__); 198 | 199 | check(Interval::make_union({3, Interval::pos_inf}, {5, Interval::pos_inf}), {3, Interval::pos_inf}, __LINE__); 200 | check(Interval::make_intersection({3, Interval::pos_inf}, {5, Interval::pos_inf}), {5, Interval::pos_inf}, __LINE__); 201 | 202 | check(Interval::make_union({Interval::neg_inf, 3}, {Interval::neg_inf, 5}), {Interval::neg_inf, 5}, __LINE__); 203 | check(Interval::make_intersection({Interval::neg_inf, 3}, {Interval::neg_inf, 5}), {Interval::neg_inf, 3}, __LINE__); 204 | 205 | check(Interval::make_union({3, 4}, {9, 10}), {3, 10}, __LINE__); 206 | check(Interval::make_intersection({3, 4}, {9, 10}), Interval::nothing(), __LINE__); // disjoint: make_intersection proves 9 > 4 and returns the canonical empty interval, not {9, 4} 207 | 208 | check(Interval::make_union({3, 9}, {4, 10}), {3, 10}, __LINE__); 209 | check(Interval::make_intersection({3, 9}, {4, 10}), {4, 9}, __LINE__); 210 | 211 | std::cout << "Interval test passed" << std::endl; 212 | } 213 | 214 | 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /src/arithmetic/Interval.h:
-------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_INTERVAL_H 2 | #define HALIDEIR_INTERVAL_H 3 | 4 | /** \file 5 | * Defines the Interval class 6 | */ 7 | 8 | #include "ir/Expr.h" 9 | 10 | namespace HalideIR { 11 | namespace Internal { 12 | 13 | /** A class to represent ranges of Exprs. Can be unbounded above or below. */ 14 | struct Interval { 15 | 16 | /** Exprs to represent positive and negative infinity */ 17 | static Expr pos_inf, neg_inf; 18 | 19 | /** The lower and upper bound of the interval. They are included 20 | * in the interval. */ 21 | Expr min, max; 22 | 23 | /** A default-constructed Interval is everything */ 24 | Interval() : min(neg_inf), max(pos_inf) {} 25 | 26 | /** Construct an interval from a lower and upper bound. */ 27 | Interval(Expr min, Expr max) : min(min), max(max) { 28 | internal_assert(min.defined() && max.defined()); 29 | } 30 | 31 | /** The interval representing everything. */ 32 | static Interval everything() {return Interval(neg_inf, pos_inf);} 33 | 34 | /** The interval representing nothing. 
*/ 35 | static Interval nothing() {return Interval(pos_inf, neg_inf);} 36 | 37 | /** Construct an interval representing a single point */ 38 | static Interval single_point(Expr e) {return Interval(e, e);} 39 | 40 | /** Is the interval the empty set */ 41 | bool is_empty() const {return min.same_as(pos_inf) || max.same_as(neg_inf);} 42 | 43 | /** Is the interval the entire range */ 44 | bool is_everything() const {return min.same_as(neg_inf) && max.same_as(pos_inf);} 45 | 46 | /** Is the interval just a single value (min == max) */ 47 | bool is_single_point() const {return min.same_as(max);} 48 | 49 | /** Is the interval a particular single value */ 50 | bool is_single_point(Expr e) const {return min.same_as(e) && max.same_as(e);} 51 | 52 | /** Does the interval have a finite least upper bound */ 53 | bool has_upper_bound() const {return !max.same_as(pos_inf) && !is_empty();} 54 | 55 | /** Does the interval have a finite greatest lower bound */ 56 | bool has_lower_bound() const {return !min.same_as(neg_inf) && !is_empty();} 57 | 58 | /** Does the interval have a finite upper and lower bound */ 59 | bool is_bounded() const {return has_upper_bound() && has_lower_bound();} 60 | 61 | /** Is the interval the same as another interval */ 62 | bool same_as(const Interval &other) {return min.same_as(other.min) && max.same_as(other.max);} 63 | 64 | /** Expand the interval to include another Interval */ 65 | EXPORT void include(const Interval &i); 66 | 67 | /** Expand the interval to include an Expr */ 68 | EXPORT void include(Expr e); 69 | 70 | /** Construct the smallest interval containing two intervals. */ 71 | EXPORT static Interval make_union(const Interval &a, const Interval &b); 72 | 73 | /** Construct the largest interval contained within two intervals. */ 74 | EXPORT static Interval make_intersection(const Interval &a, const Interval &b); 75 | 76 | /** An eagerly-simplifying max of two Exprs that respects infinities. 
*/ 77 | EXPORT static Expr make_max(Expr a, Expr b); 78 | 79 | /** An eagerly-simplifying min of two Exprs that respects infinities. */ 80 | EXPORT static Expr make_min(Expr a, Expr b); 81 | 82 | }; 83 | 84 | EXPORT void interval_test(); 85 | 86 | } 87 | } 88 | 89 | #endif 90 | -------------------------------------------------------------------------------- /src/arithmetic/ModulusRemainder.cpp: -------------------------------------------------------------------------------- 1 | #include "ModulusRemainder.h" 2 | #include "Simplify.h" 3 | #include "ir/IROperator.h" 4 | #include "ir/IRPrinter.h" 5 | #include "ir/IR.h" 6 | 7 | // This file is largely a port of parts of src/analysis.ml 8 | namespace HalideIR { 9 | namespace Internal { 10 | 11 | class ComputeModulusRemainder : public IRVisitor { 12 | public: 13 | ModulusRemainder analyze(Expr e); 14 | 15 | int modulus, remainder; 16 | Scope scope; 17 | 18 | ComputeModulusRemainder(const Scope *s) { 19 | scope.set_containing_scope(s); 20 | } 21 | 22 | void visit(const IntImm *, const Expr &); 23 | void visit(const UIntImm *, const Expr &); 24 | void visit(const FloatImm *, const Expr &); 25 | void visit(const StringImm *, const Expr &); 26 | void visit(const Cast *, const Expr &); 27 | void visit(const Variable *, const Expr &); 28 | void visit(const Add *, const Expr &); 29 | void visit(const Sub *, const Expr &); 30 | void visit(const Mul *, const Expr &); 31 | void visit(const Div *, const Expr &); 32 | void visit(const Mod *, const Expr &); 33 | void visit(const Min *, const Expr &); 34 | void visit(const Max *, const Expr &); 35 | void visit(const EQ *, const Expr &); 36 | void visit(const NE *, const Expr &); 37 | void visit(const LT *, const Expr &); 38 | void visit(const LE *, const Expr &); 39 | void visit(const GT *, const Expr &); 40 | void visit(const GE *, const Expr &); 41 | void visit(const And *, const Expr &); 42 | void visit(const Or *, const Expr &); 43 | void visit(const Not *, const Expr &); 44 | void 
visit(const Select *, const Expr &); 45 | void visit(const Load *, const Expr &); 46 | void visit(const Ramp *, const Expr &); 47 | void visit(const Broadcast *, const Expr &); 48 | void visit(const Call *, const Expr &); 49 | void visit(const Let *, const Expr &); 50 | void visit(const LetStmt *, const Stmt &); 51 | void visit(const AssertStmt *, const Stmt &); 52 | void visit(const ProducerConsumer *, const Stmt &); 53 | void visit(const For *, const Stmt &); 54 | void visit(const Store *, const Stmt &); 55 | void visit(const Provide *, const Stmt &); 56 | void visit(const Allocate *, const Stmt &); 57 | void visit(const Free *, const Stmt &); 58 | void visit(const Realize *, const Stmt &); 59 | void visit(const Prefetch *, const Stmt &); 60 | void visit(const Block *, const Stmt &); 61 | void visit(const IfThenElse *, const Stmt &); 62 | void visit(const Evaluate *, const Stmt &); 63 | void visit(const Shuffle *, const Expr &); 64 | }; 65 | 66 | ModulusRemainder modulus_remainder(Expr e) { 67 | ComputeModulusRemainder mr(nullptr); 68 | return mr.analyze(e); 69 | } 70 | 71 | ModulusRemainder modulus_remainder(Expr e, const Scope &scope) { 72 | ComputeModulusRemainder mr(&scope); 73 | return mr.analyze(e); 74 | } 75 | 76 | 77 | 78 | bool reduce_expr_modulo(Expr expr, int modulus, int *remainder) { 79 | ModulusRemainder result = modulus_remainder(expr); 80 | 81 | /* As an example: If we asked for expr mod 8, and the analysis 82 | * said that expr = 16*k + 13, then because 16 % 8 == 0, the 83 | * result is 13 % 8 == 5. But if the analysis says that expr = 84 | * 6*k + 3, then expr mod 8 could be 1, 3, 5, or 7, so we just 85 | * return false. 
86 | */ 87 | 88 | if (result.modulus % modulus == 0) { 89 | *remainder = result.remainder % modulus; 90 | return true; 91 | } else { 92 | return false; 93 | } 94 | } 95 | bool reduce_expr_modulo(Expr expr, int modulus, int *remainder, const Scope &scope) { 96 | ModulusRemainder result = modulus_remainder(expr, scope); 97 | 98 | if (result.modulus % modulus == 0) { 99 | *remainder = result.remainder % modulus; 100 | return true; 101 | } else { 102 | return false; 103 | } 104 | } 105 | 106 | ModulusRemainder ComputeModulusRemainder::analyze(Expr e) { 107 | e.accept(this); 108 | return ModulusRemainder(modulus, remainder); 109 | } 110 | 111 | namespace { 112 | void check(Expr e, int m, int r) { 113 | ModulusRemainder result = modulus_remainder(e); 114 | if (result.modulus != m || result.remainder != r) { 115 | std::cerr << "Test failed for modulus_remainder:\n"; 116 | std::cerr << "Expression: " << e << "\n"; 117 | std::cerr << "Correct modulus, remainder = " << m << ", " << r << "\n"; 118 | std::cerr << "Computed modulus, remainder = " 119 | << result.modulus << ", " 120 | << result.remainder << "\n"; 121 | exit(-1); 122 | } 123 | } 124 | } 125 | 126 | void modulus_remainder_test() { 127 | VarExpr x = Variable::make(Int(32), "x"); 128 | VarExpr y = Variable::make(Int(32), "y"); 129 | 130 | check((30*x + 3) + (40*y + 2), 10, 5); 131 | check((6*x + 3) * (4*y + 1), 2, 1); 132 | check(max(30*x - 24, 40*y + 31), 5, 1); 133 | check(10*x - 33*y, 1, 0); 134 | check(10*x - 35*y, 5, 0); 135 | check(123, 0, 123); 136 | check(Let::make(y, x*3 + 4, y*3 + 4), 9, 7); 137 | 138 | std::cout << "modulus_remainder test passed\n"; 139 | } 140 | 141 | 142 | void ComputeModulusRemainder::visit(const IntImm *op, const Expr &) { 143 | // Equal to op->value modulo anything. We'll use zero as the 144 | // modulus to mark this special case. We'd better be able to 145 | // handle zero in the rest of the code... 
146 | remainder = op->value; 147 | modulus = 0; 148 | } 149 | 150 | void ComputeModulusRemainder::visit(const UIntImm *op, const Expr &) { 151 | internal_error << "modulus_remainder of uint\n"; 152 | } 153 | 154 | void ComputeModulusRemainder::visit(const FloatImm *, const Expr &) { 155 | internal_error << "modulus_remainder of float\n"; 156 | } 157 | 158 | void ComputeModulusRemainder::visit(const StringImm *, const Expr &) { 159 | internal_error << "modulus_remainder of string\n"; 160 | } 161 | 162 | void ComputeModulusRemainder::visit(const Cast *, const Expr &) { 163 | modulus = 1; 164 | remainder = 0; 165 | } 166 | 167 | void ComputeModulusRemainder::visit(const Variable *op, const Expr &) { 168 | if (scope.contains(op)) { 169 | ModulusRemainder mod_rem = scope.get(op); 170 | modulus = mod_rem.modulus; 171 | remainder = mod_rem.remainder; 172 | } else { 173 | modulus = 1; 174 | remainder = 0; 175 | } 176 | } 177 | 178 | int gcd(int a, int b) { 179 | if (a < b) std::swap(a, b); 180 | while (b != 0) { 181 | int64_t tmp = b; 182 | b = a % b; 183 | a = tmp; 184 | } 185 | return a; 186 | } 187 | 188 | int lcm(int a, int b) { 189 | return (a*b)/gcd(a, b); 190 | } 191 | 192 | int mod(int a, int m) { 193 | if (m == 0) return a; 194 | return mod_imp(a, m); 195 | } 196 | 197 | void ComputeModulusRemainder::visit(const Add *op, const Expr &) { 198 | ModulusRemainder a = analyze(op->a); 199 | ModulusRemainder b = analyze(op->b); 200 | modulus = gcd(a.modulus, b.modulus); 201 | remainder = mod(a.remainder + b.remainder, modulus); 202 | } 203 | 204 | void ComputeModulusRemainder::visit(const Sub *op, const Expr &) { 205 | ModulusRemainder a = analyze(op->a); 206 | ModulusRemainder b = analyze(op->b); 207 | modulus = gcd(a.modulus, b.modulus); 208 | remainder = mod(a.remainder - b.remainder, modulus); 209 | } 210 | 211 | void ComputeModulusRemainder::visit(const Mul *op, const Expr &) { 212 | ModulusRemainder a = analyze(op->a); 213 | ModulusRemainder b = analyze(op->b); 214 
| 215 | if (a.modulus == 0) { 216 | // a is constant 217 | modulus = a.remainder * b.modulus; 218 | remainder = a.remainder * b.remainder; 219 | } else if (b.modulus == 0) { 220 | // b is constant 221 | modulus = b.remainder * a.modulus; 222 | remainder = a.remainder * b.remainder; 223 | } else if (a.remainder == 0 && b.remainder == 0) { 224 | // multiple times multiple 225 | modulus = a.modulus * b.modulus; 226 | remainder = 0; 227 | } else if (a.remainder == 0) { 228 | modulus = a.modulus * gcd(b.modulus, b.remainder); 229 | remainder = 0; 230 | } else if (b.remainder == 0) { 231 | modulus = b.modulus * gcd(a.modulus, a.remainder); 232 | remainder = 0; 233 | } else { 234 | // All our tricks failed. Convert them to the same modulus and multiply 235 | modulus = gcd(a.modulus, b.modulus); 236 | a.remainder = mod(a.remainder * b.remainder, modulus); 237 | } 238 | } 239 | 240 | void ComputeModulusRemainder::visit(const Div *, const Expr &) { 241 | // We might be able to say something about this if the numerator 242 | // modulus is provably a multiple of a constant denominator, but 243 | // in this case we should have simplified away the division. 244 | remainder = 0; 245 | modulus = 1; 246 | } 247 | 248 | namespace { 249 | ModulusRemainder unify_alternatives(ModulusRemainder a, ModulusRemainder b) { 250 | // We don't know if we're going to get a or b, so we'd better find 251 | // a single modulus remainder that works for both. 252 | 253 | // For example: 254 | // max(30*_ + 13, 40*_ + 27) -> 255 | // max(10*_ + 3, 10*_ + 7) -> 256 | // max(2*_ + 1, 2*_ + 1) -> 257 | // 2*_ + 1 258 | 259 | // Reduce them to the same modulus and the same remainder 260 | int modulus = gcd(a.modulus, b.modulus); 261 | int64_t diff = (int64_t)a.remainder - (int64_t)b.remainder; 262 | if (!Int(32).can_represent(diff)) { 263 | // The difference overflows. 
264 | return ModulusRemainder(0, 1); 265 | } 266 | if (diff < 0) diff = -diff; 267 | modulus = gcd((int)diff, modulus); 268 | 269 | int ra = mod(a.remainder, modulus); 270 | 271 | internal_assert(ra == mod(b.remainder, modulus)) 272 | << "There's a bug inside ModulusRemainder in unify_alternatives:\n" 273 | << "a.modulus = " << a.modulus << "\n" 274 | << "a.remainder = " << a.remainder << "\n" 275 | << "b.modulus = " << b.modulus << "\n" 276 | << "b.remainder = " << b.remainder << "\n" 277 | << "diff = " << diff << "\n" 278 | << "unified modulus = " << modulus << "\n" 279 | << "unified remainder = " << ra << "\n"; 280 | 281 | 282 | return ModulusRemainder(modulus, ra); 283 | } 284 | } 285 | 286 | void ComputeModulusRemainder::visit(const Mod *op, const Expr &) { 287 | // We can treat x mod y as x + z*y, where we know nothing about z. 288 | // (ax + b) + z (cx + d) -> 289 | // ax + b + zcx + dz -> 290 | // gcd(a, c, d) * w + b 291 | 292 | // E.g: 293 | // (8x + 5) mod (6x + 2) -> 294 | // (8x + 5) + z (6x + 2) -> 295 | // (8x + 6zx + 2x) + 5 -> 296 | // 2(4x + 3zx + x) + 5 -> 297 | // 2w + 1 298 | ModulusRemainder a = analyze(op->a); 299 | ModulusRemainder b = analyze(op->b); 300 | modulus = gcd(a.modulus, b.modulus); 301 | modulus = gcd(modulus, b.remainder); 302 | remainder = mod(a.remainder, modulus); 303 | } 304 | 305 | void ComputeModulusRemainder::visit(const Min *op, const Expr &) { 306 | ModulusRemainder r = unify_alternatives(analyze(op->a), analyze(op->b)); 307 | modulus = r.modulus; 308 | remainder = r.remainder; 309 | } 310 | 311 | void ComputeModulusRemainder::visit(const Max *op, const Expr &) { 312 | ModulusRemainder r = unify_alternatives(analyze(op->a), analyze(op->b)); 313 | modulus = r.modulus; 314 | remainder = r.remainder; 315 | } 316 | 317 | void ComputeModulusRemainder::visit(const EQ *, const Expr &) { 318 | internal_assert(false) << "modulus_remainder of bool\n"; 319 | } 320 | 321 | void ComputeModulusRemainder::visit(const NE *, const Expr 
&) { 322 | internal_assert(false) << "modulus_remainder of bool\n"; 323 | } 324 | 325 | void ComputeModulusRemainder::visit(const LT *, const Expr &) { 326 | internal_assert(false) << "modulus_remainder of bool\n"; 327 | } 328 | 329 | void ComputeModulusRemainder::visit(const LE *, const Expr &) { 330 | internal_assert(false) << "modulus_remainder of bool\n"; 331 | } 332 | 333 | void ComputeModulusRemainder::visit(const GT *, const Expr &) { 334 | internal_assert(false) << "modulus_remainder of bool\n"; 335 | } 336 | 337 | void ComputeModulusRemainder::visit(const GE *, const Expr &) { 338 | internal_assert(false) << "modulus_remainder of bool\n"; 339 | } 340 | 341 | void ComputeModulusRemainder::visit(const And *, const Expr &) { 342 | internal_assert(false) << "modulus_remainder of bool\n"; 343 | } 344 | 345 | void ComputeModulusRemainder::visit(const Or *, const Expr &) { 346 | internal_assert(false) << "modulus_remainder of bool\n"; 347 | } 348 | 349 | void ComputeModulusRemainder::visit(const Not *, const Expr &) { 350 | internal_assert(false) << "modulus_remainder of bool\n"; 351 | } 352 | 353 | void ComputeModulusRemainder::visit(const Select *op, const Expr &) { 354 | ModulusRemainder r = unify_alternatives(analyze(op->true_value), 355 | analyze(op->false_value)); 356 | modulus = r.modulus; 357 | remainder = r.remainder; 358 | } 359 | 360 | void ComputeModulusRemainder::visit(const Load *, const Expr &) { 361 | modulus = 1; 362 | remainder = 0; 363 | } 364 | 365 | void ComputeModulusRemainder::visit(const Ramp *, const Expr &) { 366 | internal_assert(false) << "modulus_remainder of vector\n"; 367 | } 368 | 369 | void ComputeModulusRemainder::visit(const Broadcast *, const Expr &) { 370 | internal_assert(false) << "modulus_remainder of vector\n"; 371 | } 372 | 373 | void ComputeModulusRemainder::visit(const Call *, const Expr &) { 374 | modulus = 1; 375 | remainder = 0; 376 | } 377 | 378 | void ComputeModulusRemainder::visit(const Let *op, const Expr &) { 
379 | bool value_interesting = op->value.type().is_int(); 380 | 381 | if (value_interesting) { 382 | ModulusRemainder val = analyze(op->value); 383 | scope.push(op->var.get(), val); 384 | } 385 | ModulusRemainder val = analyze(op->body); 386 | if (value_interesting) { 387 | scope.pop(op->var.get()); 388 | } 389 | modulus = val.modulus; 390 | remainder = val.remainder; 391 | } 392 | 393 | void ComputeModulusRemainder::visit(const Shuffle *op, const Expr&) { 394 | // It's possible that scalar expressions are extracting a lane of a vector - don't fail in this case, but stop 395 | internal_assert(op->indices.size() == 1) << "modulus_remainder of vector\n"; 396 | modulus = 1; 397 | remainder = 0; 398 | } 399 | 400 | void ComputeModulusRemainder::visit(const LetStmt *, const Stmt &) { 401 | internal_assert(false) << "modulus_remainder of statement\n"; 402 | } 403 | 404 | void ComputeModulusRemainder::visit(const AssertStmt *, const Stmt &) { 405 | internal_assert(false) << "modulus_remainder of statement\n"; 406 | } 407 | 408 | void ComputeModulusRemainder::visit(const ProducerConsumer *, const Stmt &) { 409 | internal_assert(false) << "modulus_remainder of statement\n"; 410 | } 411 | 412 | void ComputeModulusRemainder::visit(const For *, const Stmt &) { 413 | internal_assert(false) << "modulus_remainder of statement\n"; 414 | } 415 | 416 | void ComputeModulusRemainder::visit(const Store *, const Stmt &) { 417 | internal_assert(false) << "modulus_remainder of statement\n"; 418 | } 419 | 420 | void ComputeModulusRemainder::visit(const Provide *, const Stmt &) { 421 | internal_assert(false) << "modulus_remainder of statement\n"; 422 | } 423 | 424 | void ComputeModulusRemainder::visit(const Allocate *, const Stmt &) { 425 | internal_assert(false) << "modulus_remainder of statement\n"; 426 | } 427 | 428 | void ComputeModulusRemainder::visit(const Realize *, const Stmt &) { 429 | internal_assert(false) << "modulus_remainder of statement\n"; 430 | } 431 | 432 | void 
ComputeModulusRemainder::visit(const Prefetch *, const Stmt &) { 433 | internal_assert(false) << "modulus_remainder of statement\n"; 434 | } 435 | 436 | void ComputeModulusRemainder::visit(const Block *, const Stmt &) { 437 | internal_assert(false) << "modulus_remainder of statement\n"; 438 | } 439 | 440 | void ComputeModulusRemainder::visit(const Free *, const Stmt &) { 441 | internal_assert(false) << "modulus_remainder of statement\n"; 442 | } 443 | 444 | void ComputeModulusRemainder::visit(const IfThenElse *, const Stmt &) { 445 | internal_assert(false) << "modulus_remainder of statement\n"; 446 | } 447 | 448 | void ComputeModulusRemainder::visit(const Evaluate *, const Stmt &) { 449 | internal_assert(false) << "modulus_remainder of statement\n"; 450 | } 451 | 452 | } 453 | } 454 | -------------------------------------------------------------------------------- /src/arithmetic/ModulusRemainder.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_MODULUS_REMAINDER_H 2 | #define HALIDEIR_MODULUS_REMAINDER_H 3 | 4 | /** \file 5 | * Routines for statically determining what expressions are divisible by. 6 | */ 7 | 8 | #include "Scope.h" 9 | 10 | namespace HalideIR { 11 | namespace Internal { 12 | 13 | /** The result of modulus_remainder analysis */ 14 | struct ModulusRemainder { 15 | ModulusRemainder() : modulus(0), remainder(0) {} 16 | ModulusRemainder(int m, int r) : modulus(m), remainder(r) {} 17 | int modulus, remainder; 18 | }; 19 | 20 | /** For things like alignment analysis, often it's helpful to know 21 | * if an integer expression is some multiple of a constant plus 22 | * some other constant. For example, it is straight-forward to 23 | * deduce that ((10*x + 2)*(6*y - 3) - 1) is congruent to five 24 | * modulo six. 25 | * 26 | * We get the most information when the modulus is large. E.g. 
if 27 | * something is congruent to 208 modulo 384, then we also know it's 28 | * congruent to 0 mod 8, and we can possibly use it as an index for an 29 | * aligned load. If all else fails, we can just say that an integer is 30 | * congruent to zero modulo one. 31 | */ 32 | EXPORT ModulusRemainder modulus_remainder(Expr e); 33 | 34 | /** If we have alignment information about external variables, we can 35 | * let the analysis know about that using this version of 36 | * modulus_remainder: */ 37 | EXPORT ModulusRemainder modulus_remainder(Expr e, const Scope &scope); 38 | 39 | /** Reduce an expression modulo some integer. Returns true and assigns 40 | * to remainder if an answer could be found. */ 41 | ///@{ 42 | EXPORT bool reduce_expr_modulo(Expr e, int modulus, int *remainder); 43 | EXPORT bool reduce_expr_modulo(Expr e, int modulus, int *remainder, const Scope &scope); 44 | ///@} 45 | 46 | EXPORT void modulus_remainder_test(); 47 | 48 | /** The greatest common divisor of two integers */ 49 | EXPORT int gcd(int, int); 50 | 51 | /** The least common multiple of two integers */ 52 | EXPORT int lcm(int, int); 53 | 54 | } 55 | } 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/arithmetic/Scope.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_SCOPE_H 2 | #define HALIDEIR_SCOPE_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "base/Util.h" 11 | #include "base/Debug.h" 12 | #include "base/Error.h" 13 | #include "ir/IR.h" 14 | 15 | /** \file 16 | * Defines the Scope class, which is used for keeping track of names in a scope while traversing IR 17 | */ 18 | 19 | namespace HalideIR { 20 | namespace Internal { 21 | 22 | /** A stack which can store one item very efficiently. Using this 23 | * instead of std::stack speeds up Scope substantially. 
*/ 24 | template 25 | class SmallStack { 26 | private: 27 | T _top; 28 | std::vector _rest; 29 | bool _empty; 30 | 31 | public: 32 | SmallStack() : _empty(true) {} 33 | 34 | void pop() { 35 | if (_rest.empty()) { 36 | _empty = true; 37 | _top = T(); 38 | } else { 39 | _top = _rest.back(); 40 | _rest.pop_back(); 41 | } 42 | } 43 | 44 | void push(const T &t) { 45 | if (_empty) { 46 | _empty = false; 47 | } else { 48 | _rest.push_back(_top); 49 | } 50 | _top = t; 51 | } 52 | 53 | T top() const { 54 | return _top; 55 | } 56 | 57 | T &top_ref() { 58 | return _top; 59 | } 60 | 61 | const T &top_ref() const { 62 | return _top; 63 | } 64 | 65 | bool empty() const { 66 | return _empty; 67 | } 68 | }; 69 | 70 | /** A common pattern when traversing Halide IR is that you need to 71 | * keep track of stuff when you find a Let or a LetStmt, and that it 72 | * should hide previous values with the same name until you leave the 73 | * Let or LetStmt nodes This class helps with that. */ 74 | template 75 | class Scope { 76 | private: 77 | std::map> table; 78 | 79 | // Copying a scope object copies a large table full of strings and 80 | // stacks. Bad idea. 81 | Scope(const Scope &); 82 | Scope &operator=(const Scope &); 83 | 84 | const Scope *containing_scope; 85 | 86 | 87 | public: 88 | Scope() : containing_scope(nullptr) {} 89 | 90 | /** Set the parent scope. If lookups fail in this scope, they 91 | * check the containing scope before returning an error. Caller is 92 | * responsible for managing the memory of the containing scope. */ 93 | void set_containing_scope(const Scope *s) { 94 | containing_scope = s; 95 | } 96 | 97 | /** A const ref to an empty scope. 
Useful for default function 98 | * arguments, which would otherwise require a copy constructor 99 | * (with llvm in c++98 mode) */ 100 | static const Scope &empty_scope() { 101 | static Scope *_empty_scope = new Scope(); 102 | return *_empty_scope; 103 | } 104 | 105 | /** Retrieve the value referred to by a name */ 106 | T get(const Variable* var) const { 107 | typename std::map>::const_iterator iter = table.find(var); 108 | if (iter == table.end() || iter->second.empty()) { 109 | if (containing_scope) { 110 | return containing_scope->get(var); 111 | } else { 112 | internal_error << "Symbol '" << var->name_hint << "' not found\n"; 113 | } 114 | } 115 | return iter->second.top(); 116 | } 117 | 118 | /** Return a reference to an entry. Does not consider the containing scope. */ 119 | T &ref(const Variable* var) { 120 | typename std::map>::iterator iter = table.find(var); 121 | if (iter == table.end() || iter->second.empty()) { 122 | internal_error << "Symbol '" << var->name_hint << "' not found\n"; 123 | } 124 | return iter->second.top_ref(); 125 | } 126 | 127 | /** Tests if a name is in scope */ 128 | bool contains(const Variable* var) const { 129 | typename std::map>::const_iterator iter = table.find(var); 130 | if (iter == table.end() || iter->second.empty()) { 131 | if (containing_scope) { 132 | return containing_scope->contains(var); 133 | } else { 134 | return false; 135 | } 136 | } 137 | return true; 138 | } 139 | 140 | /** Add a new (name, value) pair to the current scope. Hide old 141 | * values that have this name until we pop this name. 142 | */ 143 | void push(const Variable* var, const T &value) { 144 | table[var].push(value); 145 | } 146 | 147 | /** A name goes out of scope. 
Restore whatever its old value 148 | * was (or remove it entirely if there was nothing else of the 149 | * same name in an outer scope) */ 150 | void pop(const Variable* var) { 151 | typename std::map>::iterator iter = table.find(var); 152 | internal_assert(iter != table.end()) << "Name not in symbol table: " << var->name_hint << "\n"; 153 | iter->second.pop(); 154 | if (iter->second.empty()) { 155 | table.erase(iter); 156 | } 157 | } 158 | 159 | /** Iterate through the scope. Does not capture any containing scope. */ 160 | class const_iterator { 161 | typename std::map>::const_iterator iter; 162 | public: 163 | explicit const_iterator(const typename std::map>::const_iterator &i) : 164 | iter(i) { 165 | } 166 | 167 | const_iterator() {} 168 | 169 | bool operator!=(const const_iterator &other) { 170 | return iter != other.iter; 171 | } 172 | 173 | void operator++() { 174 | ++iter; 175 | } 176 | 177 | const Variable* var() { 178 | return iter->first; 179 | } 180 | 181 | const SmallStack &stack() { 182 | return iter->second; 183 | } 184 | 185 | const T &value() { 186 | return iter->second.top_ref(); 187 | } 188 | }; 189 | 190 | const_iterator cbegin() const { 191 | return const_iterator(table.begin()); 192 | } 193 | 194 | const_iterator cend() const { 195 | return const_iterator(table.end()); 196 | } 197 | 198 | class iterator { 199 | typename std::map>::iterator iter; 200 | public: 201 | explicit iterator(typename std::map>::iterator i) : 202 | iter(i) { 203 | } 204 | 205 | iterator() {} 206 | 207 | bool operator!=(const iterator &other) { 208 | return iter != other.iter; 209 | } 210 | 211 | void operator++() { 212 | ++iter; 213 | } 214 | 215 | const Variable* var() { 216 | return iter->first; 217 | } 218 | 219 | SmallStack &stack() { 220 | return iter->second; 221 | } 222 | 223 | T &value() { 224 | return iter->second.top_ref(); 225 | } 226 | }; 227 | 228 | iterator begin() { 229 | return iterator(table.begin()); 230 | } 231 | 232 | iterator end() { 233 | return 
iterator(table.end()); 234 | } 235 | 236 | void swap(Scope &other) { 237 | table.swap(other.table); 238 | std::swap(containing_scope, other.containing_scope); 239 | } 240 | }; 241 | 242 | template 243 | std::ostream &operator<<(std::ostream &stream, const Scope& s) { 244 | stream << "{\n"; 245 | typename Scope::const_iterator iter; 246 | for (iter = s.cbegin(); iter != s.cend(); ++iter) { 247 | stream << " " << iter.var()->name_hint << "\n"; 248 | } 249 | stream << "}"; 250 | return stream; 251 | } 252 | 253 | } 254 | } 255 | 256 | #endif 257 | -------------------------------------------------------------------------------- /src/arithmetic/Simplify.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_SIMPLIFY_H 2 | #define HALIDEIR_SIMPLIFY_H 3 | 4 | /** \file 5 | * Methods for simplifying halide statements and expressions 6 | */ 7 | 8 | #include 9 | 10 | #include "ir/IR.h" 11 | #include "Interval.h" 12 | #include "ModulusRemainder.h" 13 | 14 | namespace HalideIR { 15 | namespace Internal { 16 | 17 | /** Perform a a wide range of simplifications to expressions and 18 | * statements, including constant folding, substituting in trivial 19 | * values, arithmetic rearranging, etc. Simplifies across let 20 | * statements, so must not be called on stmts with dangling or 21 | * repeated variable names. 22 | */ 23 | // @{ 24 | EXPORT Stmt simplify(Stmt, bool simplify_lets = true, 25 | const Scope &bounds = Scope::empty_scope(), 26 | const Scope &alignment = Scope::empty_scope()); 27 | EXPORT Expr simplify(Expr, bool simplify_lets = true, 28 | const Scope &bounds = Scope::empty_scope(), 29 | const Scope &alignment = Scope::empty_scope()); 30 | // @} 31 | 32 | /** A common use of the simplifier is to prove boolean expressions are 33 | * true at compile time. 
Equivalent to is_one(simplify(e)) */ 34 | EXPORT bool can_prove(Expr e); 35 | 36 | /** Simplify expressions found in a statement, but don't simplify 37 | * across different statements. This is safe to perform at an earlier 38 | * stage in lowering than full simplification of a stmt. */ 39 | EXPORT Stmt simplify_exprs(Stmt); 40 | 41 | /** Implementations of division and mod that are specific to Halide. 42 | * Use these implementations; do not use native C division or mod to 43 | * simplify Halide expressions. Halide division and modulo satisify 44 | * the Euclidean definition of division for integers a and b: 45 | * 46 | /code 47 | (a/b)*b + a%b = a 48 | 0 <= a%b < |b| 49 | /endcode 50 | * 51 | */ 52 | // @{ 53 | template 54 | inline T mod_imp(T a, T b) { 55 | Type t = type_of(); 56 | if (t.is_int()) { 57 | T r = a % b; 58 | r = r + (r < 0 ? (T)std::abs((int64_t)b) : 0); 59 | return r; 60 | } else { 61 | return a % b; 62 | } 63 | } 64 | 65 | template 66 | inline T div_imp(T a, T b) { 67 | Type t = type_of(); 68 | if (t.is_int()) { 69 | int64_t q = a / b; 70 | int64_t r = a - q * b; 71 | int64_t bs = b >> (t.bits() - 1); 72 | int64_t rs = r >> (t.bits() - 1); 73 | return (T) (q - (rs & bs) + (rs & ~bs)); 74 | } else { 75 | return a / b; 76 | } 77 | } 78 | // @} 79 | 80 | // Special cases for float, double. 81 | template<> inline float mod_imp(float a, float b) { 82 | float f = a - b * (floorf(a / b)); 83 | // The remainder has the same sign as b. 
84 | return f; 85 | } 86 | template<> inline double mod_imp(double a, double b) { 87 | double f = a - b * (std::floor(a / b)); 88 | return f; 89 | } 90 | 91 | template<> inline float div_imp(float a, float b) { 92 | return a/b; 93 | } 94 | template<> inline double div_imp(double a, double b) { 95 | return a/b; 96 | } 97 | 98 | 99 | EXPORT void simplify_test(); 100 | 101 | } 102 | } 103 | 104 | #endif 105 | -------------------------------------------------------------------------------- /src/arithmetic/Substitute.cpp: -------------------------------------------------------------------------------- 1 | #include "Substitute.h" 2 | #include "Scope.h" 3 | #include "ir/IRMutator.h" 4 | #include "ir/IREquality.h" 5 | 6 | namespace HalideIR { 7 | namespace Internal { 8 | 9 | using std::map; 10 | using std::string; 11 | 12 | class Substitute : public IRMutator { 13 | /* We don't need a Scope to check if variable inside let statements has 14 | same name as the first argument because we use variable pointer to 15 | match. 
*/ 16 | const map &replace; 17 | 18 | Expr find_replacement(const Variable* s) { 19 | map::const_iterator iter = replace.find(s); 20 | if (iter != replace.end()) { 21 | return iter->second; 22 | } else { 23 | return Expr(); 24 | } 25 | } 26 | 27 | public: 28 | Substitute(const map &m) : replace(m) {} 29 | 30 | using IRMutator::visit; 31 | 32 | void visit(const Variable *v, const Expr &e) { 33 | Expr r = find_replacement(v); 34 | if (r.defined()) { 35 | expr = r; 36 | } else { 37 | expr = e; 38 | } 39 | } 40 | 41 | void visit(const Let *op, const Expr &e) { 42 | Expr new_value = mutate(op->value); 43 | Expr new_body = mutate(op->body); 44 | 45 | if (new_value.same_as(op->value) && 46 | new_body.same_as(op->body)) { 47 | expr = e; 48 | } else { 49 | expr = Let::make(op->var, new_value, new_body); 50 | } 51 | } 52 | 53 | void visit(const LetStmt *op, const Stmt &s) { 54 | Expr new_value = mutate(op->value); 55 | Stmt new_body = mutate(op->body); 56 | 57 | if (new_value.same_as(op->value) && 58 | new_body.same_as(op->body)) { 59 | stmt = s; 60 | } else { 61 | stmt = LetStmt::make(op->var, new_value, new_body); 62 | } 63 | } 64 | 65 | void visit(const For *op, const Stmt &s) { 66 | Expr new_min = mutate(op->min); 67 | Expr new_extent = mutate(op->extent); 68 | Stmt new_body = mutate(op->body); 69 | 70 | if (new_min.same_as(op->min) && 71 | new_extent.same_as(op->extent) && 72 | new_body.same_as(op->body)) { 73 | stmt = s; 74 | } else { 75 | stmt = For::make(op->loop_var, new_min, new_extent, op->for_type, op->device_api, new_body); 76 | } 77 | } 78 | 79 | }; 80 | 81 | Expr substitute(const Variable* var, Expr replacement, Expr expr) { 82 | map m; 83 | m[var] = replacement; 84 | Substitute s(m); 85 | return s.mutate(expr); 86 | } 87 | 88 | Stmt substitute(const Variable* var, Expr replacement, Stmt stmt) { 89 | map m; 90 | m[var] = replacement; 91 | Substitute s(m); 92 | return s.mutate(stmt); 93 | } 94 | 95 | Expr substitute(const map &m, Expr expr) { 96 | Substitute 
s(m); 97 | return s.mutate(expr); 98 | } 99 | 100 | Stmt substitute(const map &m, Stmt stmt) { 101 | Substitute s(m); 102 | return s.mutate(stmt); 103 | } 104 | 105 | 106 | class SubstituteExpr : public IRMutator { 107 | public: 108 | Expr find, replacement; 109 | 110 | using IRMutator::mutate; 111 | 112 | Expr mutate(Expr e) { 113 | if (equal(e, find)) { 114 | return replacement; 115 | } else { 116 | return IRMutator::mutate(e); 117 | } 118 | } 119 | }; 120 | 121 | Expr substitute(Expr find, Expr replacement, Expr expr) { 122 | SubstituteExpr s; 123 | s.find = find; 124 | s.replacement = replacement; 125 | return s.mutate(expr); 126 | } 127 | 128 | Stmt substitute(Expr find, Expr replacement, Stmt stmt) { 129 | SubstituteExpr s; 130 | s.find = find; 131 | s.replacement = replacement; 132 | return s.mutate(stmt); 133 | } 134 | 135 | /** Substitute an expr for a var in a graph. */ 136 | class GraphSubstitute : public IRGraphMutator { 137 | const Variable* var; 138 | Expr value; 139 | 140 | using IRGraphMutator::visit; 141 | 142 | void visit(const Variable *op, const Expr &e) { 143 | if (op == var) { 144 | expr = value; 145 | } else { 146 | expr = e; 147 | } 148 | } 149 | 150 | public: 151 | 152 | GraphSubstitute(const Variable* var, Expr value) : var(var), value(value) {} 153 | }; 154 | 155 | /** Substitute an Expr for another Expr in a graph. Unlike substitute, 156 | * this only checks for shallow equality. 
*/
157 | class GraphSubstituteExpr : public IRGraphMutator {
158 |     Expr find, replace;
159 | public:
160 |     // Matching uses Expr::same_as (node identity), unlike SubstituteExpr, which uses equal().
161 |     using IRGraphMutator::mutate;
162 |
163 |     Expr mutate(Expr e) {
164 |         if (e.same_as(find)) return replace;
165 |         return IRGraphMutator::mutate(e);
166 |     }
167 |
168 |     GraphSubstituteExpr(Expr find, Expr replace) : find(find), replace(replace) {}
169 | };
170 | // The four graph_substitute overloads construct the matching mutator and run it once.
171 | Expr graph_substitute(const Variable* var, Expr replacement, Expr expr) {
172 |     return GraphSubstitute(var, replacement).mutate(expr);
173 | }
174 |
175 | Stmt graph_substitute(const Variable* var, Expr replacement, Stmt stmt) {
176 |     return GraphSubstitute(var, replacement).mutate(stmt);
177 | }
178 |
179 | Expr graph_substitute(Expr find, Expr replacement, Expr expr) {
180 |     return GraphSubstituteExpr(find, replacement).mutate(expr);
181 | }
182 |
183 | Stmt graph_substitute(Expr find, Expr replacement, Stmt stmt) {
184 |     return GraphSubstituteExpr(find, replacement).mutate(stmt);
185 | }
186 | // Inlines every Let: the mutated value replaces the bound variable in the mutated body, and the Let node is dropped.
187 | class SubstituteInAllLets : public IRGraphMutator {
188 |
189 |     using IRGraphMutator::visit;
190 |     // NOTE(review): op->var is a VarExpr; which graph_substitute overload it binds to depends on VarExpr's conversions (declared elsewhere).
191 |     void visit(const Let *op, const Expr &) {
192 |         Expr value = mutate(op->value);
193 |         Expr body = mutate(op->body);
194 |         expr = graph_substitute(op->var, value, body);
195 |     }
196 | };
197 |
198 | Expr substitute_in_all_lets(Expr expr) {
199 |     return SubstituteInAllLets().mutate(expr);
200 | }
201 |
202 | Stmt substitute_in_all_lets(Stmt stmt) {
203 |     return SubstituteInAllLets().mutate(stmt);
204 | }
205 |
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
/src/arithmetic/Substitute.h:
--------------------------------------------------------------------------------
1 | #ifndef HALIDEIR_SUBSTITUTE_H
2 | #define HALIDEIR_SUBSTITUTE_H
3 |
4 | /** \file
5 |  *
6 |  * Defines methods for substituting out variables in expressions and
7 |  * statements.
*/ 8 | 9 | #include 10 | 11 | #include "ir/IR.h" 12 | 13 | namespace HalideIR { 14 | namespace Internal { 15 | 16 | /** Substitute variables with the given pointer with the replacement 17 | * expression within expr. */ 18 | EXPORT Expr substitute(const Variable* var, Expr replacement, Expr expr); 19 | 20 | /** Substitute variables with the given pointer with the replacement 21 | * expression within stmt. */ 22 | EXPORT Stmt substitute(const Variable* var, Expr replacement, Stmt stmt); 23 | 24 | EXPORT inline Expr substitute(const VarExpr& var, Expr replacement, Expr expr) { 25 | return substitute(var.get(), replacement, expr); 26 | } 27 | 28 | EXPORT inline Stmt substitute(const VarExpr& var, Expr replacement, Stmt stmt) { 29 | return substitute(var.get(), replacement, stmt); 30 | } 31 | 32 | /** Substitute variables with pointers in the map. */ 33 | // @{ 34 | EXPORT Expr substitute(const std::map &replacements, Expr expr); 35 | EXPORT Stmt substitute(const std::map &replacements, Stmt stmt); 36 | // @} 37 | 38 | /** Substitute expressions for other expressions. */ 39 | // @{ 40 | EXPORT Expr substitute(Expr find, Expr replacement, Expr expr); 41 | EXPORT Stmt substitute(Expr find, Expr replacement, Stmt stmt); 42 | // @} 43 | 44 | /** Substitutions where the IR may be a general graph (and not just a 45 | * DAG). */ 46 | // @{ 47 | Expr graph_substitute(const Variable* var, Expr replacement, Expr expr); 48 | Stmt graph_substitute(const Variable* var, Expr replacement, Stmt stmt); 49 | Expr graph_substitute(Expr find, Expr replacement, Expr expr); 50 | Stmt graph_substitute(Expr find, Expr replacement, Stmt stmt); 51 | // @} 52 | 53 | /** Substitute in all let Exprs in a piece of IR. Doesn't substitute 54 | * in let stmts, as this may change the meaning of the IR (e.g. by 55 | * moving a load after a store). Produces graphs of IR, so don't use 56 | * non-graph-aware visitors or mutators on it until you've CSE'd the 57 | * result. 
*/ 58 | // @{ 59 | Expr substitute_in_all_lets(Expr expr); 60 | Stmt substitute_in_all_lets(Stmt stmt); 61 | // @} 62 | 63 | } 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /src/base/Debug.cpp: -------------------------------------------------------------------------------- 1 | #include "Debug.h" 2 | 3 | namespace HalideIR { 4 | namespace Internal { 5 | 6 | int debug::debug_level() { 7 | return 0; 8 | } 9 | 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/base/Debug.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_DEBUG_H 2 | #define HALIDEIR_DEBUG_H 3 | 4 | /** \file 5 | * Defines functions for debug logging during code generation. 6 | */ 7 | 8 | #include 9 | #include 10 | #include 11 | #include "Util.h" 12 | 13 | namespace HalideIR { 14 | 15 | struct Expr; 16 | struct Type; 17 | // Forward declare some things from IRPrinter, which we can't include yet. 18 | EXPORT std::ostream &operator<<(std::ostream &stream, const Expr &); 19 | EXPORT std::ostream &operator<<(std::ostream &stream, const Type &); 20 | 21 | class Module; 22 | EXPORT std::ostream &operator<<(std::ostream &stream, const Module &); 23 | 24 | namespace Internal { 25 | 26 | struct Stmt; 27 | EXPORT std::ostream &operator<<(std::ostream &stream, const Stmt &); 28 | 29 | struct LoweredFunc; 30 | EXPORT std::ostream &operator << (std::ostream &, const LoweredFunc &); 31 | 32 | /** For optional debugging during codegen, use the debug class as 33 | * follows: 34 | * 35 | \code 36 | debug(verbosity) << "The expression is " << expr << std::endl; 37 | \endcode 38 | * 39 | * verbosity of 0 always prints, 1 should print after every major 40 | * stage, 2 should be used for more detail, and 3 should be used for 41 | * tracing everything that occurs. 
The verbosity with which to print 42 | * is determined by the value of the environment variable 43 | * HL_DEBUG_CODEGEN 44 | */ 45 | 46 | class debug { 47 | const bool logging; 48 | 49 | public: 50 | debug(int verbosity) : logging(verbosity <= debug_level()) {} 51 | 52 | template 53 | debug &operator<<(T&& x) { 54 | if (logging) { 55 | std::cerr << std::forward(x); 56 | } 57 | return *this; 58 | } 59 | 60 | EXPORT static int debug_level(); 61 | }; 62 | 63 | } 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /src/base/Error.cpp: -------------------------------------------------------------------------------- 1 | #include "Error.h" 2 | 3 | namespace HalideIR { 4 | 5 | namespace { 6 | 7 | CompileTimeErrorReporter* custom_error_reporter = nullptr; 8 | 9 | } // namespace 10 | 11 | void set_custom_compile_time_error_reporter(CompileTimeErrorReporter* error_reporter) { 12 | custom_error_reporter = error_reporter; 13 | } 14 | 15 | bool exceptions_enabled() { 16 | #ifdef WITH_EXCEPTIONS 17 | return true; 18 | #else 19 | return false; 20 | #endif 21 | } 22 | 23 | Error::Error(const std::string &msg) : std::runtime_error(msg) { 24 | } 25 | 26 | CompileError::CompileError(const std::string &msg) : Error(msg) { 27 | } 28 | 29 | RuntimeError::RuntimeError(const std::string &msg) : Error(msg) { 30 | } 31 | 32 | InternalError::InternalError(const std::string &msg) : Error(msg) { 33 | } 34 | 35 | 36 | namespace Internal { 37 | 38 | // Force the classes to exist, even if exceptions are off 39 | namespace { 40 | CompileError _compile_error(""); 41 | RuntimeError _runtime_error(""); 42 | InternalError _internal_error(""); 43 | } 44 | 45 | ErrorReport::ErrorReport(const char *file, int line, const char *condition_string, int flags) : flags(flags) { 46 | 47 | const std::string &source_loc = ""; 48 | 49 | if (flags & User) { 50 | // Only mention where inside of libHalide the error tripped if we have debug level > 0 51 | 
debug(1) << "User error triggered at " << file << ":" << line << "\n"; 52 | if (condition_string) { 53 | debug(1) << "Condition failed: " << condition_string << "\n"; 54 | } 55 | if (flags & Warning) { 56 | msg << "Warning"; 57 | } else { 58 | msg << "Error"; 59 | } 60 | if (source_loc.empty()) { 61 | msg << ":\n"; 62 | } else { 63 | msg << " at " << source_loc << ":\n"; 64 | } 65 | 66 | } else { 67 | msg << "Internal "; 68 | if (flags & Warning) { 69 | msg << "warning"; 70 | } else { 71 | msg << "error"; 72 | } 73 | msg << " at " << file << ":" << line; 74 | if (!source_loc.empty()) { 75 | msg << " triggered by user code at " << source_loc << ":\n"; 76 | } else { 77 | msg << "\n"; 78 | } 79 | if (condition_string) { 80 | msg << "Condition failed: " << condition_string << "\n"; 81 | } 82 | } 83 | } 84 | 85 | ErrorReport::~ErrorReport() 86 | #if __cplusplus >= 201100 || _MSC_VER >= 1900 87 | noexcept(false) 88 | #endif 89 | { 90 | if (!msg.str().empty() && msg.str().back() != '\n') { 91 | msg << '\n'; 92 | } 93 | 94 | if (custom_error_reporter != nullptr) { 95 | if (flags & Warning) { 96 | custom_error_reporter->warning(msg.str().c_str()); 97 | return; 98 | } else { 99 | custom_error_reporter->error(msg.str().c_str()); 100 | // error() should not have returned to us, but just in case 101 | // it does, make sure we don't continue. 102 | abort(); 103 | } 104 | } 105 | 106 | // TODO: Add an option to error out on warnings too 107 | if (flags & Warning) { 108 | std::cerr << msg.str(); 109 | return; 110 | } 111 | 112 | #ifdef WITH_EXCEPTIONS 113 | if (std::uncaught_exception()) { 114 | // This should never happen - evaluating one of the arguments 115 | // to the error message would have to throw an 116 | // exception. Nonetheless, in case it does, preserve the 117 | // exception already in flight and suppress this one. 
118 | return; 119 | } else if (flags & Runtime) { 120 | RuntimeError err(msg.str()); 121 | throw err; 122 | } else if (flags & User) { 123 | CompileError err(msg.str()); 124 | throw err; 125 | } else { 126 | InternalError err(msg.str()); 127 | throw err; 128 | } 129 | #else 130 | std::cerr << msg.str(); 131 | abort(); 132 | #endif 133 | } 134 | } 135 | 136 | } 137 | -------------------------------------------------------------------------------- /src/base/Error.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_ERROR_H 2 | #define HALIDEIR_ERROR_H 3 | 4 | #include 5 | #include 6 | 7 | #include "Debug.h" 8 | #include "TypeBase.h" 9 | 10 | #include 11 | 12 | namespace HalideIR { 13 | 14 | /** Query whether Halide was compiled with exceptions. */ 15 | EXPORT bool exceptions_enabled(); 16 | 17 | /** A base class for Halide errors. */ 18 | struct Error : public std::runtime_error { 19 | // Give each class a non-inlined constructor so that the type 20 | // doesn't get separately instantiated in each compilation unit. 21 | EXPORT Error(const std::string &msg); 22 | }; 23 | 24 | /** An error that occurs while running a JIT-compiled Halide pipeline. */ 25 | struct RuntimeError : public Error { 26 | EXPORT RuntimeError(const std::string &msg); 27 | }; 28 | 29 | /** An error that occurs while compiling a Halide pipeline that Halide 30 | * attributes to a user error. */ 31 | struct CompileError : public Error { 32 | EXPORT CompileError(const std::string &msg); 33 | }; 34 | 35 | /** An error that occurs while compiling a Halide pipeline that Halide 36 | * attributes to an internal compiler bug, or to an invalid use of 37 | * Halide's internals. */ 38 | struct InternalError : public Error { 39 | EXPORT InternalError(const std::string &msg); 40 | }; 41 | 42 | /** CompileTimeErrorReporter is used at compile time (*not* runtime) when 43 | * an error or warning is generated by Halide. 
Note that error() is called 44 | * a fatal error has occurred, and returning to Halide may cause a crash; 45 | * implementations of CompileTimeErrorReporter::error() should never return. 46 | * (Implementations of CompileTimeErrorReporter::warning() may return but 47 | * may also abort(), exit(), etc.) 48 | */ 49 | class CompileTimeErrorReporter { 50 | public: 51 | virtual ~CompileTimeErrorReporter() {} 52 | virtual void warning(const char* msg) = 0; 53 | virtual void error(const char* msg) = 0; 54 | }; 55 | 56 | /** The default error reporter logs to stderr, then throws an exception 57 | * (if WITH_EXCEPTIONS) or calls abort (if not). This allows customization 58 | * of that behavior if a more gentle response to error reporting is desired. 59 | * Note that error_reporter is expected to remain valid across all Halide usage; 60 | * it is up to the caller to ensure that this is the case (and to do any 61 | * cleanup necessary). 62 | */ 63 | EXPORT void set_custom_compile_time_error_reporter(CompileTimeErrorReporter* error_reporter); 64 | 65 | namespace Internal { 66 | 67 | struct ErrorReport { 68 | enum { 69 | User = 0x0001, 70 | Warning = 0x0002, 71 | Runtime = 0x0004 72 | }; 73 | 74 | std::ostringstream msg; 75 | const int flags; 76 | 77 | EXPORT ErrorReport(const char *f, int l, const char *cs, int flags); 78 | 79 | // Just a trick used to convert RValue into LValue 80 | HALIDEIR_ALWAYS_INLINE ErrorReport& ref() { return *this; } 81 | 82 | template 83 | ErrorReport &operator<<(const T &x) { 84 | msg << x; 85 | return *this; 86 | } 87 | 88 | /** When you're done using << on the object, and let it fall out of 89 | * scope, this errors out, or throws an exception if they are 90 | * enabled. This is a little dangerous because the destructor will 91 | * also be called if there's an exception in flight due to an 92 | * error in one of the arguments passed to operator<<. We handle 93 | * this by only actually throwing if there isn't an exception in 94 | * flight already. 
95 | */ 96 | #if __cplusplus >= 201100 || _MSC_VER >= 1900 97 | EXPORT ~ErrorReport() noexcept(false); 98 | #else 99 | EXPORT ~ErrorReport(); 100 | #endif 101 | }; 102 | 103 | // This uses operator precedence as a trick to avoid argument evaluation if 104 | // an assertion is true: it is intended to be used as part of the 105 | // _halideir_internal_assertion macro, to coerce the result of the stream 106 | // expression to void (to match the condition-is-false case). 107 | class Voidifier { 108 | public: 109 | HALIDEIR_ALWAYS_INLINE Voidifier() {} 110 | // This has to be an operator with a precedence lower than << but 111 | // higher than ?: 112 | HALIDEIR_ALWAYS_INLINE void operator&(ErrorReport&) {} 113 | }; 114 | 115 | /** 116 | * _halideir_internal_assertion is used to implement our assertion macros 117 | * in such a way that the messages output for the assertion are only 118 | * evaluated if the assertion's value is false. 119 | * 120 | * Note that this macro intentionally has no parens internally; in actual 121 | * use, the implicit grouping will end up being 122 | * 123 | * condition ? (void) : (Voidifier() & (ErrorReport << arg1 << arg2 ... << argN)) 124 | * 125 | * This (regrettably) requires a macro to work, but has the highly desirable 126 | * effect that all assertion parameters are totally skipped (not ever evaluated) 127 | * when the assertion is true. 128 | */ 129 | #define _halideir_internal_assertion(condition, flags) \ 130 | (condition) \ 131 | ? 
(void)0 \ 132 | : ::HalideIR::Internal::Voidifier() & \ 133 | ::HalideIR::Internal::ErrorReport(__FILE__, __LINE__, #condition, flags).ref() 134 | 135 | 136 | #define internal_error HalideIR::Internal::ErrorReport(__FILE__, __LINE__, nullptr, 0) 137 | #define user_error HalideIR::Internal::ErrorReport(__FILE__, __LINE__, nullptr, HalideIR::Internal::ErrorReport::User) 138 | #define user_warning HalideIR::Internal::ErrorReport(__FILE__, __LINE__, nullptr, HalideIR::Internal::ErrorReport::User | HalideIR::Internal::ErrorReport::Warning) 139 | #define halideir_runtime_error HalideIR::Internal::ErrorReport(__FILE__, __LINE__, nullptr, HalideIR::Internal::ErrorReport::User | HalideIR::Internal::ErrorReport::Runtime) 140 | 141 | // #define internal_assert(c) _halideir_internal_assertion(c, 0) 142 | // #define user_assert(c) _halideir_internal_assertion(c, HalideIR::Internal::ErrorReport::User) 143 | 144 | #define internal_assert CHECK 145 | #define user_assert CHECK 146 | 147 | // The nicely named versions get cleaned up at the end of Halide.h, 148 | // but user code might want to do halide-style user_asserts (e.g. the 149 | // Extern macros introduce calls to user_assert), so for that purpose 150 | // we define an equivalent macro that can be used outside of Halide.h 151 | #define _halideir_user_assert(c) _halideir_internal_assertion(c, HalideIR::Internal::ErrorReport::User) 152 | 153 | // N.B. Any function that might throw a user_assert or user_error may 154 | // not be inlined into the user's code, or the line number will be 155 | // misattributed to Halide.h. Either make such functions internal to 156 | // libHalide, or mark them as NO_INLINE. 
157 | 158 | } 159 | 160 | } 161 | 162 | #endif 163 | -------------------------------------------------------------------------------- /src/base/Float16.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_FLOAT16_H 2 | #define HALIDEIR_FLOAT16_H 3 | 4 | #include 5 | #include 6 | #include "./TypeBase.h" 7 | #include "./RoundingMode.h" 8 | #include "./Util.h" 9 | 10 | namespace HalideIR { 11 | 12 | /** Class that provides a type that implements half precision 13 | * floating point (IEEE754 2008 binary16) in software. 14 | * 15 | * This type is enforced to be 16-bits wide and maintains no state 16 | * other than the raw IEEE754 binary16 bits so that it can passed 17 | * to code that checks a type's size and used for buffer_t allocation. 18 | * */ 19 | struct float16_t { 20 | // NOTE: Do not use virtual methods here 21 | // it will change the size of this data type. 22 | 23 | /// \name Constructors 24 | /// @{ 25 | 26 | /** Construct from a float using a particular rounding mode. 27 | * A warning will be emitted if the result cannot be represented exactly 28 | * and error will be raised if the conversion results in overflow. 29 | * 30 | * \param value the input float 31 | * \param roundingMode The rounding mode to use 32 | * 33 | */ 34 | EXPORT explicit float16_t(float value, RoundingMode roundingMode=RoundingMode::ToNearestTiesToEven); 35 | 36 | /** Construct from a double using a particular rounding mode. 37 | * A warning will be emitted if the result cannot be represented exactly 38 | * and error will be raised if the conversion results in overflow. 39 | * 40 | * \param value the input double 41 | * \param roundingMode The rounding mode to use 42 | * 43 | */ 44 | EXPORT explicit float16_t(double value, RoundingMode roundingMode=RoundingMode::ToNearestTiesToEven); 45 | 46 | /** Construct by parsing a string using a particular rounding mode. 
47 | * A warning will be emitted if the result cannot be represented exactly 48 | * and error will be raised if the conversion results in overflow. 49 | * 50 | * \param stringRepr the input string. The string maybe in C99 hex format 51 | * (e.g. ``-0x1.000p-1``) or in a decimal (e.g.``-0.5``) format. 52 | * 53 | * \param roundingMode The rounding mode to use 54 | * 55 | */ 56 | EXPORT explicit float16_t(const char *stringRepr, RoundingMode roundingMode=RoundingMode::ToNearestTiesToEven); 57 | 58 | /** Construct a float16_t with the bits initialised to 0. This represents 59 | * positive zero.*/ 60 | EXPORT float16_t(); 61 | 62 | /// @} 63 | 64 | // Use explicit to avoid accidently raising the precision 65 | /** Cast to float */ 66 | EXPORT explicit operator float() const; 67 | /** Cast to double */ 68 | EXPORT explicit operator double() const; 69 | 70 | // Be explicit about how the copy constructor is expected to behave 71 | EXPORT float16_t(const float16_t&) = default; 72 | 73 | // Be explicit about how assignment is expected to behave 74 | EXPORT float16_t& operator=(const float16_t&) = default; 75 | 76 | /** \name Convenience "constructors" 77 | */ 78 | /**@{*/ 79 | 80 | /** Get a new float16_t that represents zero 81 | * \param positive if true then returns positive zero otherwise returns 82 | * negative zero. 83 | */ 84 | EXPORT static float16_t make_zero(bool positive); 85 | 86 | /** Get a new float16_t that represents infinity 87 | * \param positive if true then returns positive infinity otherwise returns 88 | * negative infinity. 89 | */ 90 | EXPORT static float16_t make_infinity(bool positive); 91 | 92 | /** Get a new float16_t that represents NaN (not a number) */ 93 | EXPORT static float16_t make_nan(); 94 | 95 | /** Get a new float16_t with the given raw bits 96 | * 97 | * \param bits The bits conformant to IEEE754 binary16 98 | */ 99 | EXPORT static float16_t make_from_bits(uint16_t bits); 100 | 101 | /** Get a new float16_t from a signed integer. 
102 | * It is not provided as a constructor to avoid call ambiguity 103 | * */ 104 | EXPORT static float16_t make_from_signed_int(int64_t value, RoundingMode roundingMode=RoundingMode::ToNearestTiesToEven); 105 | /**@}*/ 106 | 107 | /**\name Arithmetic operators 108 | * These compute the result of an arithmetic operation 109 | * using a particular ``roundingMode`` and return a new float16_t 110 | * representing the result. 111 | * 112 | * Exceptions are ignored. 113 | */ 114 | /**@{*/ 115 | /** add */ 116 | EXPORT float16_t add(float16_t rhs, RoundingMode roundingMode) const; 117 | /** subtract */ 118 | EXPORT float16_t subtract(float16_t rhs, RoundingMode roundingMode) const; 119 | /** multiply */ 120 | EXPORT float16_t multiply(float16_t rhs, RoundingMode roundingMode) const; 121 | /** divide */ 122 | EXPORT float16_t divide(float16_t denominator, RoundingMode roundingMode) const; 123 | /** IEEE-754 2008 5.3.1 General operations - remainder **/ 124 | EXPORT float16_t remainder(float16_t denominator) const; 125 | /** C fmod() */ 126 | EXPORT float16_t mod(float16_t denominator, RoundingMode roudingMode) const; 127 | /**@}*/ 128 | 129 | 130 | /** Return a new float16_t with a negated sign bit*/ 131 | EXPORT float16_t operator-() const; 132 | 133 | /** \name Overloaded arithmetic operators for convenience 134 | * These operators assume RoundingMode::ToNearestTiesToEven rounding 135 | */ 136 | /**@{*/ 137 | EXPORT float16_t operator+(float16_t rhs) const; 138 | EXPORT float16_t operator-(float16_t rhs) const; 139 | EXPORT float16_t operator*(float16_t rhs) const; 140 | EXPORT float16_t operator/(float16_t rhs) const; 141 | /**@}*/ 142 | 143 | /** \name Comparison operators */ 144 | /**@{*/ 145 | /** Equality */ 146 | EXPORT bool operator==(float16_t rhs) const; 147 | /** Not equal */ 148 | EXPORT bool operator!=(float16_t rhs) const { return !(*this == rhs); } 149 | /** Greater than */ 150 | EXPORT bool operator>(float16_t rhs) const; 151 | /** Less than */ 152 | 
EXPORT bool operator<(float16_t rhs) const; 153 | /** Greater than or equal to*/ 154 | EXPORT bool operator>=(float16_t rhs) const { return (*this > rhs) || (*this == rhs); } 155 | /** Less than or equal to*/ 156 | EXPORT bool operator<=(float16_t rhs) const { return (*this < rhs) || (*this == rhs); } 157 | /** \return true if and only if the float16_t and ``rhs`` are not ordered. E.g. 158 | * NaN and a normalised number 159 | */ 160 | EXPORT bool are_unordered(float16_t rhs) const; 161 | /**@}*/ 162 | 163 | /** \name String output methods */ 164 | /**@{*/ 165 | /** Return a string in the C99 hex format (e.g.\ ``-0x1.000p-1``) that 166 | * represents this float16_t precisely. 167 | */ 168 | EXPORT std::string to_hex_string() const; 169 | /** Returns a string in a decimal scientific notation (e.g.\ ``-5.0E-1``) 170 | * that represents the closest decimal value to this float16_t precise to 171 | * the number of significant digits requested. 172 | * 173 | * \param significantDigits The number of significant digits to use. If 174 | * set to ``0`` then string returned will have enough precision to 175 | * construct the same float16_t when using 176 | * RoundingMode::ToNearestTiesToEven 177 | */ 178 | EXPORT std::string to_decimal_string(unsigned int significantDigits = 0) const; 179 | /**@}*/ 180 | 181 | /** \name Properties */ 182 | /*@{*/ 183 | EXPORT bool is_nan() const; 184 | EXPORT bool is_infinity() const; 185 | EXPORT bool is_negative() const; 186 | EXPORT bool is_zero() const; 187 | /*@}*/ 188 | 189 | /** Returns the bits that represent this float16_t. 190 | * 191 | * An alternative method to access the bits is to cast a pointer 192 | * to this instance as a pointer to a uint16_t. 193 | **/ 194 | EXPORT uint16_t to_bits() const; 195 | 196 | private: 197 | // The raw bits. 198 | // This must be the **ONLY** data member so that 199 | // this data type is 16-bits wide. 
200 | uint16_t data; 201 | }; 202 | } // namespace HalideIR 203 | 204 | template<> 205 | HALIDEIR_ALWAYS_INLINE halideir_type_t halideir_type_of() { 206 | return halideir_type_t(halideir_type_float, 16); 207 | } 208 | 209 | #endif 210 | -------------------------------------------------------------------------------- /src/base/Float16Opt.cpp: -------------------------------------------------------------------------------- 1 | #include "Float16.h" 2 | #include "Error.h" 3 | 4 | #include 5 | #include 6 | 7 | using namespace HalideIR; 8 | 9 | namespace HalideIR { 10 | 11 | // An optional implementation of float16_t 12 | // Float16 conversion op that removes LLVM dep, 13 | // so things can be invariant from LLVM until codegen 14 | //5A 15 | // The float16_t here is not accurate for arithmetic(uses float) 16 | // But can be used as a good storage type. 17 | 18 | namespace { 19 | 20 | union Bits { 21 | float f; 22 | int32_t si; 23 | uint32_t ui; 24 | }; 25 | 26 | static int const shift = 13; 27 | static int const shiftSign = 16; 28 | 29 | static int32_t const infN = 0x7F800000; // flt32 infinity 30 | static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32 31 | static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 32 | static int32_t const signN = 0x80000000; // flt32 sign bit 33 | 34 | static int32_t const infC = infN >> shift; 35 | static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 36 | static int32_t const maxC = maxN >> shift; 37 | static int32_t const minC = minN >> shift; 38 | static int32_t const signC = signN >> shiftSign; // flt16 sign bit 39 | 40 | static int32_t const mulN = 0x52000000; // (1 << 23) / minN 41 | static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) 42 | 43 | static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted 44 | static int32_t const norC = 0x00400; // min flt32 normal down shifted 45 | 46 | static int32_t const maxD = infC - maxC - 1; 47 | static 
int32_t const minD = minC - subC - 1; 48 | 49 | inline uint16_t float2half(const float& value) { 50 | Bits v, s; 51 | v.f = value; 52 | uint32_t sign = v.si & signN; 53 | v.si ^= sign; 54 | sign >>= shiftSign; // logical shift 55 | s.si = mulN; 56 | s.si = static_cast(s.f * v.f); // correct subnormals 57 | v.si ^= (s.si ^ v.si) & -(minN > v.si); 58 | v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); 59 | v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); 60 | v.ui >>= shift; // logical shift 61 | v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); 62 | v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); 63 | return v.ui | sign; 64 | } 65 | 66 | inline float half2float(const uint16_t& value) { 67 | Bits v; 68 | v.ui = value; 69 | int32_t sign = v.si & signC; 70 | v.si ^= sign; 71 | sign <<= shiftSign; 72 | v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); 73 | v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); 74 | Bits s; 75 | s.si = mulC; 76 | s.f *= v.si; 77 | int32_t mask = -(norC > v.si); 78 | v.si <<= shift; 79 | v.si ^= (s.si ^ v.si) & mask; 80 | v.si |= sign; 81 | return v.f; 82 | } 83 | } // namespace 84 | 85 | // The static_asserts checking the size is to make sure 86 | // float16_t can be used as a 16-bits wide POD type. 
87 | float16_t::float16_t(float value, RoundingMode roundingMode) { 88 | static_assert(sizeof(float16_t) == 2, "float16_t is wrong size"); 89 | this->data = float2half(value); 90 | } 91 | 92 | float16_t::float16_t(double value, RoundingMode roundingMode) { 93 | static_assert(sizeof(float16_t) == 2, "float16_t is wrong size"); 94 | this->data = float2half(static_cast(value)); 95 | } 96 | 97 | float16_t::float16_t(const char *stringRepr, RoundingMode roundingMode) { 98 | static_assert(sizeof(float16_t) == 2, "float16_t is wrong size"); 99 | std::memcpy(&data, stringRepr, 2); 100 | } 101 | 102 | float16_t::float16_t() { 103 | static_assert(sizeof(float16_t) == 2, "float16_t is wrong size"); 104 | this->data = 0; 105 | } 106 | 107 | 108 | float16_t::operator float() const { 109 | return half2float(data); 110 | } 111 | 112 | float16_t::operator double() const { 113 | return half2float(data); 114 | } 115 | 116 | float16_t float16_t::make_zero(bool positive) { 117 | return float16_t(0.0f, RoundingMode::TowardZero); 118 | } 119 | 120 | float16_t float16_t::make_infinity(bool positive) { 121 | return float16_t(std::numeric_limits::infinity(), RoundingMode::TowardZero); 122 | } 123 | 124 | float16_t float16_t::make_nan() { 125 | return float16_t(std::nan(""), RoundingMode::TowardZero); 126 | } 127 | 128 | float16_t float16_t::add(float16_t rhs, RoundingMode roundingMode) const { 129 | return float16_t(half2float(data) + half2float(rhs.data), roundingMode); 130 | } 131 | 132 | float16_t float16_t::subtract(float16_t rhs, RoundingMode roundingMode) const { 133 | return float16_t(half2float(data) - half2float(rhs.data), roundingMode); 134 | } 135 | 136 | float16_t float16_t::multiply(float16_t rhs, RoundingMode roundingMode) const { 137 | return float16_t(half2float(data) * half2float(rhs.data), roundingMode); 138 | } 139 | 140 | float16_t float16_t::divide(float16_t rhs, RoundingMode roundingMode) const { 141 | return float16_t(half2float(data) / half2float(rhs.data), 
roundingMode); 142 | } 143 | 144 | float16_t float16_t::operator-() const { 145 | return float16_t(-half2float(data), RoundingMode::TowardZero); 146 | } 147 | 148 | float16_t float16_t::operator+(float16_t rhs) const { 149 | return this->add(rhs, RoundingMode::ToNearestTiesToEven); 150 | } 151 | 152 | float16_t float16_t::operator-(float16_t rhs) const { 153 | return this->subtract(rhs, RoundingMode::ToNearestTiesToEven); 154 | } 155 | 156 | float16_t float16_t::operator*(float16_t rhs) const { 157 | return this->multiply(rhs, RoundingMode::ToNearestTiesToEven); 158 | } 159 | 160 | float16_t float16_t::operator/(float16_t rhs) const { 161 | return this->divide(rhs, RoundingMode::ToNearestTiesToEven); 162 | } 163 | 164 | bool float16_t::operator==(float16_t rhs) const { 165 | return half2float(data) == half2float(rhs.data); 166 | } 167 | 168 | bool float16_t::operator>(float16_t rhs) const { 169 | internal_assert(!this->are_unordered(rhs)) << "Cannot compare unorderable values\n"; 170 | return half2float(data) > half2float(rhs.data); 171 | } 172 | 173 | bool float16_t::operator<(float16_t rhs) const { 174 | internal_assert(!this->are_unordered(rhs)) << "Cannot compare unorderable values\n"; 175 | return half2float(data) < half2float(rhs.data); 176 | } 177 | 178 | bool float16_t::are_unordered(float16_t rhs) const { 179 | return std::isunordered(half2float(data), half2float(rhs.data)); 180 | } 181 | 182 | std::string float16_t::to_decimal_string(unsigned int significantDigits) const { 183 | return std::to_string(half2float(data)); 184 | } 185 | 186 | bool float16_t::is_nan() const { 187 | return std::isnan(half2float(data)); 188 | } 189 | 190 | bool float16_t::is_infinity() const { 191 | return std::isinf(half2float(data)); 192 | } 193 | 194 | bool float16_t::is_negative() const { 195 | return half2float(data) < 0; 196 | } 197 | 198 | bool float16_t::is_zero() const { 199 | return half2float(data) == 0; 200 | } 201 | 202 | uint16_t float16_t::to_bits() const { 203 | 
return this->data; 204 | } 205 | 206 | } // namespace halide 207 | -------------------------------------------------------------------------------- /src/base/RoundingMode.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_ROUNDING_MODE_H 2 | #define HALIDEIR_ROUNDING_MODE_H 3 | namespace HalideIR { 4 | 5 | /** Rounding modes (IEEE754 2008 4.3 Rounding-direction attributes) */ 6 | enum class RoundingMode { 7 | TowardZero, ///< Round towards zero (IEEE754 2008 4.3.2) 8 | ToNearestTiesToEven, ///< Round to nearest, when there is a tie pick even integral significand (IEEE754 2008 4.3.1) 9 | ToNearestTiesToAway, ///< Round to nearest, when there is a tie pick value furthest away from zero (IEEE754 2008 4.3.1) 10 | TowardPositiveInfinity, ///< Round towards positive infinity (IEEE754 2008 4.3.2) 11 | TowardNegativeInfinity ///< Round towards negative infinity (IEEE754 2008 4.3.2) 12 | }; 13 | 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /src/base/Type.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // TODO(tqchen): remove recursive dep on IR? 
4 | #include "ir/IR.h" 5 | 6 | namespace HalideIR { 7 | 8 | using std::ostringstream; 9 | 10 | namespace { 11 | uint64_t max_uint(int bits) { 12 | uint64_t max_val = 0xffffffffffffffffULL; 13 | return max_val >> (64 - bits); 14 | } 15 | 16 | int64_t max_int(int bits) { 17 | int64_t max_val = 0x7fffffffffffffffLL; 18 | return max_val >> (64 - bits); 19 | } 20 | 21 | int64_t min_int(int bits) { 22 | return -max_int(bits) - 1; 23 | } 24 | 25 | } 26 | 27 | /** Return an expression which is the maximum value of this type */ 28 | HalideIR::Expr Type::max() const { 29 | if (is_vector()) { 30 | return Internal::Broadcast::make(element_of().max(), lanes()); 31 | } else if (is_int()) { 32 | return Internal::IntImm::make(*this, max_int(bits())); 33 | } else if (is_uint()) { 34 | return Internal::UIntImm::make(*this, max_uint(bits())); 35 | } else { 36 | internal_assert(is_float()); 37 | if (bits() == 16) { 38 | return Internal::FloatImm::make(*this, 65504.0); 39 | } else if (bits() == 32) { 40 | return Internal::FloatImm::make(*this, FLT_MAX); 41 | } else if (bits() == 64) { 42 | return Internal::FloatImm::make(*this, DBL_MAX); 43 | } else { 44 | internal_error 45 | << "Unknown float type: " << (*this) << "\n"; 46 | return 0; 47 | } 48 | } 49 | } 50 | 51 | /** Return an expression which is the minimum value of this type */ 52 | HalideIR::Expr Type::min() const { 53 | if (is_vector()) { 54 | return Internal::Broadcast::make(element_of().min(), lanes()); 55 | } else if (is_int()) { 56 | return Internal::IntImm::make(*this, min_int(bits())); 57 | } else if (is_uint()) { 58 | return Internal::UIntImm::make(*this, 0); 59 | } else { 60 | internal_assert(is_float()); 61 | if (bits() == 16) { 62 | return Internal::FloatImm::make(*this, -65504.0); 63 | } else if (bits() == 32) { 64 | return Internal::FloatImm::make(*this, -FLT_MAX); 65 | } else if (bits() == 64) { 66 | return Internal::FloatImm::make(*this, -DBL_MAX); 67 | } else { 68 | internal_error 69 | << "Unknown float type: " << 
(*this) << "\n"; 70 | return 0; 71 | } 72 | } 73 | } 74 | 75 | bool Type::is_max(int64_t x) const { 76 | return x > 0 && is_max((uint64_t)x); 77 | } 78 | 79 | bool Type::is_max(uint64_t x) const { 80 | if (is_int()) { 81 | return x == (uint64_t)max_int(bits()); 82 | } else if (is_uint()) { 83 | return x == max_uint(bits()); 84 | } else { 85 | return false; 86 | } 87 | } 88 | 89 | bool Type::is_min(int64_t x) const { 90 | if (is_int()) { 91 | return x == min_int(bits()); 92 | } else if (is_uint()) { 93 | return x == 0; 94 | } else { 95 | return false; 96 | } 97 | } 98 | 99 | bool Type::is_min(uint64_t x) const { 100 | return false; 101 | } 102 | 103 | bool Type::can_represent(Type other) const { 104 | if (lanes() != other.lanes()) return false; 105 | if (is_int()) { 106 | return ((other.is_int() && other.bits() <= bits()) || 107 | (other.is_uint() && other.bits() < bits())); 108 | } else if (is_uint()) { 109 | return other.is_uint() && other.bits() <= bits(); 110 | } else if (is_float()) { 111 | return ((other.is_float() && other.bits() <= bits()) || 112 | (bits() == 64 && other.bits() <= 32) || 113 | (bits() == 32 && other.bits() <= 16)); 114 | } else { 115 | return false; 116 | } 117 | } 118 | 119 | bool Type::can_represent(int64_t x) const { 120 | if (is_int()) { 121 | return x >= min_int(bits()) && x <= max_int(bits()); 122 | } else if (is_uint()) { 123 | return x >= 0 && (uint64_t)x <= max_uint(bits()); 124 | } else if (is_float()) { 125 | switch (bits()) { 126 | case 16: 127 | return (int64_t)(float)(float16_t)(float)x == x; 128 | case 32: 129 | return (int64_t)(float)x == x; 130 | case 64: 131 | return (int64_t)(double)x == x; 132 | default: 133 | return false; 134 | } 135 | } else { 136 | return false; 137 | } 138 | } 139 | 140 | bool Type::can_represent(uint64_t x) const { 141 | if (is_int()) { 142 | return x <= (uint64_t)(max_int(bits())); 143 | } else if (is_uint()) { 144 | return x <= max_uint(bits()); 145 | } else if (is_float()) { 146 | switch (bits()) 
{ 147 | case 16: 148 | return (uint64_t)(float)(float16_t)(float)x == x; 149 | case 32: 150 | return (uint64_t)(float)x == x; 151 | case 64: 152 | return (uint64_t)(double)x == x; 153 | default: 154 | return false; 155 | } 156 | } else { 157 | return false; 158 | } 159 | } 160 | 161 | bool Type::can_represent(double x) const { 162 | if (is_int()) { 163 | int64_t i = x; 164 | return (x >= min_int(bits())) && (x <= max_int(bits())) && (x == (double)i); 165 | } else if (is_uint()) { 166 | uint64_t u = x; 167 | return (x >= 0) && (x <= max_uint(bits())) && (x == (double)u); 168 | } else if (is_float()) { 169 | switch (bits()) { 170 | case 16: 171 | return (double)(float16_t)x == x; 172 | case 32: 173 | return (double)(float)x == x; 174 | case 64: 175 | return true; 176 | default: 177 | return false; 178 | } 179 | } else { 180 | return false; 181 | } 182 | } 183 | 184 | bool Type::same_handle_type(const Type &other) const { 185 | const halideir_handle_cplusplus_type *first = handle_type; 186 | const halideir_handle_cplusplus_type *second = other.handle_type; 187 | 188 | if (first == second) { 189 | return true; 190 | } 191 | 192 | if (first == nullptr) { 193 | first = halideir_handle_traits::type_info(); 194 | } 195 | if (second == nullptr) { 196 | second = halideir_handle_traits::type_info(); 197 | } 198 | 199 | return first->inner_name == second->inner_name && 200 | first->namespaces == second->namespaces && 201 | first->enclosing_types == second->enclosing_types && 202 | first->cpp_type_modifiers == second->cpp_type_modifiers && 203 | first->reference_type == second->reference_type; 204 | } 205 | 206 | } 207 | -------------------------------------------------------------------------------- /src/base/TypeBase.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_TYPEBASE_H 2 | #define HALIDEIR_TYPEBASE_H 3 | 4 | // type handling code stripped from Halide runtime 5 | 6 | #include 7 | #include 8 | #include 9 | // 
Forward declare type to allow naming typed handles. 10 | // See Type.h for documentation. 11 | template struct halideir_handle_traits; 12 | 13 | /** Types in the halide type system. They can be ints, unsigned ints, 14 | * or floats (of various bit-widths), or a handle (which is always 64-bits). 15 | * Note that the int/uint/float values do not imply a specific bit width 16 | * (the bit width is expected to be encoded in a separate value). 17 | */ 18 | typedef enum halideir_type_code_t 19 | #if __cplusplus >= 201103L 20 | : uint8_t 21 | #endif 22 | { 23 | halideir_type_int = 0, //!< signed integers 24 | halideir_type_uint = 1, //!< unsigned integers 25 | halideir_type_float = 2, //!< floating point numbers 26 | halideir_type_handle = 3 //!< opaque pointer type (void *) 27 | } halideir_type_code_t; 28 | 29 | // Note that while __attribute__ can go before or after the declaration, 30 | // __declspec apparently is only allowed before. 31 | #ifndef HALIDEIR_ATTRIBUTE_ALIGN 32 | #ifdef _MSC_VER 33 | #define HALIDEIR_ATTRIBUTE_ALIGN(x) __declspec(align(x)) 34 | #else 35 | #define HALIDEIR_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x))) 36 | #endif 37 | #endif 38 | 39 | /** A runtime tag for a type in the halide type system. Can be ints, 40 | * unsigned ints, or floats of various bit-widths (the 'bits' 41 | * field). Can also be vectors of the same (by setting the 'lanes' 42 | * field to something larger than one). This struct should be 43 | * exactly 32-bits in size. */ 44 | struct halideir_type_t { 45 | /** The basic type code: signed integer, unsigned integer, or floating point. */ 46 | #if __cplusplus >= 201103L 47 | HALIDEIR_ATTRIBUTE_ALIGN(1) halideir_type_code_t code; // halideir_type_code_t 48 | #else 49 | HALIDEIR_ATTRIBUTE_ALIGN(1) uint8_t code; // halideir_type_code_t 50 | #endif 51 | 52 | /** The number of bits of precision of a single scalar value of this type. */ 53 | HALIDEIR_ATTRIBUTE_ALIGN(1) uint8_t bits; 54 | 55 | /** How many elements in a vector. 
This is 1 for scalar types. */ 56 | HALIDEIR_ATTRIBUTE_ALIGN(2) uint16_t lanes; 57 | 58 | #ifdef __cplusplus 59 | /** Construct a runtime representation of a Halide type from: 60 | * code: The fundamental type from an enum. 61 | * bits: The bit size of one element. 62 | * lanes: The number of vector elements in the type. */ 63 | halideir_type_t(halideir_type_code_t code, uint8_t bits, uint16_t lanes = 1) 64 | : code(code), bits(bits), lanes(lanes) { 65 | } 66 | 67 | /** Default constructor is required e.g. to declare halideir_trace_event 68 | * instances. */ 69 | halideir_type_t() : code((halideir_type_code_t)0), bits(0), lanes(0) {} 70 | 71 | /** Compare two types for equality. */ 72 | bool operator==(const halideir_type_t &other) const { 73 | return (code == other.code && 74 | bits == other.bits && 75 | lanes == other.lanes); 76 | } 77 | 78 | /** Size in bytes for a single element, even if width is not 1, of this type. */ 79 | size_t bytes() const { return (bits + 7) / 8; } 80 | #endif 81 | }; 82 | 83 | namespace { 84 | 85 | template 86 | struct halideir_type_of_helper; 87 | 88 | template 89 | struct halideir_type_of_helper { 90 | operator halideir_type_t() { 91 | return halideir_type_t(halideir_type_handle, 64); 92 | } 93 | }; 94 | 95 | template 96 | struct halideir_type_of_helper { 97 | operator halideir_type_t() { 98 | return halideir_type_t(halideir_type_handle, 64); 99 | } 100 | }; 101 | 102 | // Halide runtime does not require C++11 103 | #if __cplusplus > 199711L 104 | template 105 | struct halideir_type_of_helper { 106 | operator halideir_type_t() { 107 | return halideir_type_t(halideir_type_handle, 64); 108 | } 109 | }; 110 | #endif 111 | 112 | template<> 113 | struct halideir_type_of_helper { 114 | operator halideir_type_t() { return halideir_type_t(halideir_type_float, 32); } 115 | }; 116 | 117 | template<> 118 | struct halideir_type_of_helper { 119 | operator halideir_type_t() { return halideir_type_t(halideir_type_float, 64); } 120 | }; 121 | 122 | 
template<> 123 | struct halideir_type_of_helper { 124 | operator halideir_type_t() { return halideir_type_t(halideir_type_uint, 8); } 125 | }; 126 | 127 | template<> 128 | struct halideir_type_of_helper { 129 | operator halideir_type_t() { return halideir_type_t(halideir_type_uint, 16); } 130 | }; 131 | 132 | template<> 133 | struct halideir_type_of_helper { 134 | operator halideir_type_t() { return halideir_type_t(halideir_type_uint, 32); } 135 | }; 136 | 137 | template<> 138 | struct halideir_type_of_helper { 139 | operator halideir_type_t() { return halideir_type_t(halideir_type_uint, 64); } 140 | }; 141 | 142 | template<> 143 | struct halideir_type_of_helper { 144 | operator halideir_type_t() { return halideir_type_t(halideir_type_int, 8); } 145 | }; 146 | 147 | template<> 148 | struct halideir_type_of_helper { 149 | operator halideir_type_t() { return halideir_type_t(halideir_type_int, 16); } 150 | }; 151 | 152 | template<> 153 | struct halideir_type_of_helper { 154 | operator halideir_type_t() { return halideir_type_t(halideir_type_int, 32); } 155 | }; 156 | 157 | template<> 158 | struct halideir_type_of_helper { 159 | operator halideir_type_t() { return halideir_type_t(halideir_type_int, 64); } 160 | }; 161 | 162 | template<> 163 | struct halideir_type_of_helper { 164 | operator halideir_type_t() { return halideir_type_t(halideir_type_uint, 1); } 165 | }; 166 | 167 | } 168 | 169 | /** Construct the halide equivalent of a C type */ 170 | template halideir_type_t halideir_type_of() { 171 | return halideir_type_of_helper(); 172 | } 173 | 174 | // it is not necessary, and may produce warnings for some build configurations. 
175 | #ifdef _MSC_VER 176 | #define HALIDEIR_ALWAYS_INLINE __forceinline 177 | #else 178 | #define HALIDEIR_ALWAYS_INLINE __attribute__((always_inline)) inline 179 | #endif 180 | 181 | #endif // HALIDEIR_HALIDERUNTIME_H 182 | -------------------------------------------------------------------------------- /src/base/Util.cpp: -------------------------------------------------------------------------------- 1 | #include "./Util.h" 2 | #include "./Debug.h" 3 | #include "./Error.h" 4 | #include 5 | #include 6 | 7 | namespace HalideIR { 8 | namespace Internal { 9 | std::vector split_string(const std::string &source, const std::string &delim) { 10 | std::vector elements; 11 | size_t start = 0; 12 | size_t found = 0; 13 | while ((found = source.find(delim, start)) != std::string::npos) { 14 | elements.push_back(source.substr(start, found - start)); 15 | start = found + delim.size(); 16 | } 17 | 18 | // If start is exactly source.size(), the last thing in source is a 19 | // delimiter, in which case we want to add an empty string to elements. 
20 | if (start <= source.size()) { 21 | elements.push_back(source.substr(start, std::string::npos)); 22 | } 23 | return elements; 24 | } 25 | 26 | std::string extract_namespaces(const std::string &name, std::vector &namespaces) { 27 | namespaces = split_string(name, "::"); 28 | std::string result = namespaces.back(); 29 | namespaces.pop_back(); 30 | return result; 31 | } 32 | 33 | bool add_would_overflow(int bits, int64_t a, int64_t b) { 34 | int64_t max_val = 0x7fffffffffffffffLL >> (64 - bits); 35 | int64_t min_val = -max_val - 1; 36 | return 37 | ((b > 0 && a > max_val - b) || // (a + b) > max_val, rewritten to avoid overflow 38 | (b < 0 && a < min_val - b)); // (a + b) < min_val, rewritten to avoid overflow 39 | } 40 | 41 | bool sub_would_overflow(int bits, int64_t a, int64_t b) { 42 | int64_t max_val = 0x7fffffffffffffffLL >> (64 - bits); 43 | int64_t min_val = -max_val - 1; 44 | return 45 | ((b < 0 && a > max_val + b) || // (a - b) > max_val, rewritten to avoid overflow 46 | (b > 0 && a < min_val + b)); // (a - b) < min_val, rewritten to avoid overflow 47 | } 48 | 49 | bool mul_would_overflow(int bits, int64_t a, int64_t b) { 50 | int64_t max_val = 0x7fffffffffffffffLL >> (64 - bits); 51 | int64_t min_val = -max_val - 1; 52 | if (a == 0) { 53 | return false; 54 | } else if (a == -1) { 55 | return b == min_val; 56 | } else { 57 | // Do the multiplication as a uint64, for which overflow is 58 | // well defined, then cast the bits back to int64 to get 59 | // multiplication modulo 2^64. 60 | int64_t ab = (int64_t)((uint64_t)a)*((uint64_t)b); 61 | // The first two clauses catch overflow mod 2^bits, assuming 62 | // no 64-bit overflow occurs, and the third clause catches 63 | // 64-bit overflow. 
64 | return ab < min_val || ab > max_val || (ab / a != b); 65 | } 66 | } 67 | 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/base/Util.h: -------------------------------------------------------------------------------- 1 | // Always use assert, even if llvm-config defines NDEBUG 2 | #ifdef NDEBUG 3 | #undef NDEBUG 4 | #include 5 | #define NDEBUG 6 | #else 7 | #include 8 | #endif 9 | 10 | #ifndef HALIDEIR_UTIL_H 11 | #define HALIDEIR_UTIL_H 12 | 13 | /** \file 14 | * Various utility functions used internally Halide. */ 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | // by default, the symbol EXPORT does nothing. In windows dll builds we can define it to __declspec(dllexport) 23 | #if defined(_WIN32) 24 | #ifdef Halide_EXPORTS 25 | #define EXPORT __declspec(dllexport) 26 | #else 27 | #define EXPORT __declspec(dllimport) 28 | #endif 29 | #else 30 | #define EXPORT __attribute__((visibility("default"))) 31 | #endif 32 | 33 | // If we're in user code, we don't want certain functions to be inlined. 34 | #if defined(COMPILING_HALIDE) || defined(BUILDING_PYTHON) 35 | #define NO_INLINE 36 | #else 37 | #ifdef _WIN32 38 | #define NO_INLINE __declspec(noinline) 39 | #else 40 | #define NO_INLINE __attribute__((noinline)) 41 | #endif 42 | #endif 43 | 44 | // On windows, Halide needs a larger stack than the default MSVC provides 45 | #ifdef _MSC_VER 46 | #pragma comment(linker, "/STACK:8388608,1048576") 47 | #endif 48 | 49 | namespace HalideIR { 50 | namespace Internal { 51 | 52 | /** An aggressive form of reinterpret cast used for correct type-punning. */ 53 | template 54 | DstType reinterpret_bits(const SrcType &src) { 55 | static_assert(sizeof(SrcType) == sizeof(DstType), "Types must be same size"); 56 | DstType dst; 57 | memcpy(&dst, &src, sizeof(SrcType)); 58 | return dst; 59 | } 60 | 61 | /** Perform a left fold of a vector. 
Returns a default-constructed 62 | * vector element if the vector is empty. Similar to std::accumulate 63 | * but with a less clunky syntax. */ 64 | template 65 | T fold_left(const std::vector &vec, Fn f) { 66 | T result; 67 | if (vec.empty()) { 68 | return result; 69 | } 70 | result = vec[0]; 71 | for (size_t i = 1; i < vec.size(); i++) { 72 | result = f(result, vec[i]); 73 | } 74 | return result; 75 | } 76 | 77 | /** Returns a right fold of a vector. Returns a default-constructed 78 | * vector element if the vector is empty. */ 79 | template 80 | T fold_right(const std::vector &vec, Fn f) { 81 | T result; 82 | if (vec.empty()) { 83 | return result; 84 | } 85 | result = vec.back(); 86 | for (size_t i = vec.size()-1; i > 0; i--) { 87 | result = f(vec[i-1], result); 88 | } 89 | return result; 90 | } 91 | 92 | template 93 | inline NO_INLINE void collect_paired_args(std::vector> &collected_args, 94 | const T3 &a1, const T4 &a2) { 95 | collected_args.push_back(std::pair(a1, a2)); 96 | } 97 | 98 | template 99 | inline NO_INLINE void collect_paired_args(std::vector> &collected_args, 100 | const T3 &a1, const T4 &a2, Args&&... args) { 101 | collected_args.push_back(std::pair(a1, a2)); 102 | collect_paired_args(collected_args, std::forward(args)...); 103 | } 104 | 105 | template 106 | struct meta_and : std::true_type {}; 107 | 108 | template 109 | struct meta_and : std::integral_constant::value> {}; 110 | 111 | template 112 | struct meta_or : std::false_type {}; 113 | 114 | template 115 | struct meta_or : std::integral_constant::value> {}; 116 | 117 | template 118 | struct all_are_convertible : meta_and...> {}; 119 | 120 | /** Returns base name and fills in namespaces, outermost one first in vector. */ 121 | EXPORT std::string extract_namespaces(const std::string &name, std::vector &namespaces); 122 | 123 | 124 | /** Routines to test if math would overflow for signed integers with 125 | * the given number of bits. 
*/ 126 | // @{ 127 | bool add_would_overflow(int bits, int64_t a, int64_t b); 128 | bool sub_would_overflow(int bits, int64_t a, int64_t b); 129 | bool mul_would_overflow(int bits, int64_t a, int64_t b); 130 | // @} 131 | 132 | // Wrappers for some C++14-isms that are useful and trivially implementable 133 | // in C++11; these are defined in the HalideIR::Internal namespace. If we 134 | // are compiling under C++14 or later, we just use the standard implementations 135 | // rather than our own. 136 | #if __cplusplus >= 201402L 137 | 138 | // C++14: Use the standard implementations 139 | using std::integer_sequence; 140 | using std::make_integer_sequence; 141 | using std::index_sequence; 142 | using std::make_index_sequence; 143 | 144 | #else 145 | 146 | // C++11: std::integer_sequence (etc) is standard in C++14 but not C++11, but 147 | // is easily written in C++11. This is a simple version that could 148 | // probably be improved. 149 | 150 | template 151 | struct integer_sequence { 152 | static constexpr size_t size() { return sizeof...(Ints); } 153 | }; 154 | 155 | template 156 | struct next_integer_sequence; 157 | 158 | template 159 | struct next_integer_sequence> { 160 | using type = integer_sequence; 161 | }; 162 | 163 | template 164 | struct make_integer_sequence_helper { 165 | using type = typename next_integer_sequence< 166 | typename make_integer_sequence_helper::type 167 | >::type; 168 | }; 169 | 170 | template 171 | struct make_integer_sequence_helper { 172 | using type = integer_sequence; 173 | }; 174 | 175 | template 176 | using make_integer_sequence = typename make_integer_sequence_helper::type; 177 | 178 | template 179 | using index_sequence = integer_sequence; 180 | 181 | template 182 | using make_index_sequence = make_integer_sequence; 183 | 184 | #endif 185 | 186 | } // namespace Internal 187 | } // namespace HalideIR 188 | 189 | #endif 190 | -------------------------------------------------------------------------------- /src/ir/Expr.h: 
-------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_EXPR_H 2 | #define HALIDEIR_EXPR_H 3 | 4 | /** \file 5 | * Base classes for Halide expressions (\ref HalideIR::Expr) and statements (\ref HalideIR::Internal::Stmt) 6 | */ 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | #include "base/Debug.h" 16 | #include "base/Error.h" 17 | #include "base/Float16.h" 18 | #include "base/Type.h" 19 | #include "base/Util.h" 20 | 21 | 22 | namespace HalideIR { 23 | using tvm::Node; 24 | using tvm::NodeRef; 25 | using tvm::Array; 26 | using tvm::NodePtr; 27 | using tvm::make_node; 28 | 29 | namespace IR { 30 | using tvm::AttrVisitor; 31 | } // namespace IR 32 | 33 | namespace Internal { 34 | 35 | struct Variable; 36 | class IRVisitor; 37 | 38 | /** All our IR node types get unique IDs for the purposes of RTTI */ 39 | enum class IRNodeType : int { 40 | IntImm, 41 | UIntImm, 42 | FloatImm, 43 | StringImm, 44 | Cast, 45 | Variable, 46 | Add, 47 | Sub, 48 | Mul, 49 | Div, 50 | Mod, 51 | Min, 52 | Max, 53 | EQ, 54 | NE, 55 | LT, 56 | LE, 57 | GT, 58 | GE, 59 | And, 60 | Or, 61 | Not, 62 | Select, 63 | Load, 64 | Ramp, 65 | Broadcast, 66 | Call, 67 | Let, 68 | LetStmt, 69 | AssertStmt, 70 | ProducerConsumer, 71 | For, 72 | Store, 73 | Provide, 74 | Allocate, 75 | Free, 76 | Realize, 77 | Block, 78 | IfThenElse, 79 | Evaluate, 80 | Shuffle, 81 | Prefetch, 82 | AttrStmt, 83 | ExtensionExpr 84 | }; 85 | 86 | /** The abstract base classes for a node in the Halide IR. */ 87 | struct IRNode : public Node { 88 | /** Each IR node subclass should return some unique pointer. We 89 | * can compare these pointers to do runtime type 90 | * identification. We don't compile with rtti because that 91 | * injects run-time type identification stuff everywhere (and 92 | * often breaks when linking external libraries compiled 93 | * without it), and we only want it for IR nodes. 
*/ 94 | virtual IRNodeType type_info() const = 0; 95 | }; 96 | 97 | /** IR nodes are split into expressions and statements. These are 98 | similar to expressions and statements in C - expressions 99 | represent some value and have some type (e.g. x + 3), and 100 | statements are side-effecting pieces of code that do not 101 | represent a value (e.g. assert(x > 3)) */ 102 | 103 | /** A base class for statement nodes. They have no properties or 104 | methods beyond base IR nodes for now */ 105 | struct BaseStmtNode : public IRNode { 106 | /** We use the visitor pattern to traverse IR nodes throughout the 107 | * compiler, so we have a virtual accept method which accepts 108 | * visitors. 109 | */ 110 | virtual void accept(IRVisitor *v, const Stmt &s) const = 0; 111 | // friendly type message 112 | static constexpr const char* _type_key = "Stmt"; 113 | 114 | TVM_DECLARE_BASE_NODE_INFO(BaseStmtNode, Node); 115 | }; 116 | 117 | /** A base class for expression nodes. They all contain their types 118 | * (e.g. Int(32), Float(32)) */ 119 | struct BaseExprNode : public IRNode { 120 | Type type; 121 | /** We use the visitor pattern to traverse IR nodes throughout the 122 | * compiler, so we have a virtual accept method which accepts 123 | * visitors. 124 | */ 125 | virtual void accept(IRVisitor *v, const Expr &e) const = 0; 126 | // friendly type message 127 | static constexpr const char* _type_key = "Expr"; 128 | 129 | TVM_DECLARE_BASE_NODE_INFO(BaseExprNode, Node); 130 | }; 131 | 132 | /** We use the "curiously recurring template pattern" to avoid 133 | duplicated code in the IR Nodes. These classes live between the 134 | abstract base classes and the actual IR Nodes in the 135 | inheritance hierarchy. It provides an implementation of the 136 | accept function necessary for the visitor pattern to work, and 137 | a concrete instantiation of a unique IRNodeType per class. 
*/ 138 | template 139 | struct ExprNode : public BaseExprNode { 140 | EXPORT void accept(IRVisitor *v, const Expr &e) const; 141 | IRNodeType type_info() const final {return T::_type_info;} 142 | 143 | TVM_DECLARE_NODE_TYPE_INFO(T, BaseExprNode); 144 | }; 145 | 146 | template 147 | struct StmtNode : public BaseStmtNode { 148 | EXPORT void accept(IRVisitor *v, const Stmt &s) const; 149 | IRNodeType type_info() const final {return T::_type_info;} 150 | 151 | TVM_DECLARE_NODE_TYPE_INFO(T, BaseStmtNode); 152 | }; 153 | 154 | /** IR nodes are passed around opaque handles to them. This is a 155 | base class for those handles. It manages the reference count, 156 | and dispatches visitors. */ 157 | struct IRHandle : public NodeRef { 158 | IRHandle() {} 159 | IRHandle(NodePtr p) : NodeRef(p) {} 160 | 161 | /** return internal content as IRNode */ 162 | inline const IRNode* get() const { 163 | return static_cast(node_.get()); 164 | } 165 | /** return internal content as IRNode */ 166 | inline const IRNode* operator->() const { 167 | return static_cast(node_.get()); 168 | } 169 | }; 170 | } // namespace Internal 171 | 172 | /** A fragment of Halide syntax. It's implemented as reference-counted 173 | * handle to a concrete expression node, but it's immutable, so you 174 | * can treat it as a value type. */ 175 | struct Expr : public Internal::IRHandle { 176 | /** Make an undefined expression */ 177 | Expr() : Internal::IRHandle() {} 178 | 179 | /** Make an expression from a concrete expression node pointer (e.g. Add) */ 180 | explicit Expr(NodePtr n) : IRHandle(n) {} 181 | 182 | /** Make an expression representing numeric constants of various types. 
*/ 183 | // @{ 184 | EXPORT explicit Expr(int8_t x); 185 | EXPORT explicit Expr(int16_t x); 186 | EXPORT Expr(int32_t x); 187 | EXPORT explicit Expr(int64_t x); 188 | EXPORT explicit Expr(uint8_t x); 189 | EXPORT explicit Expr(uint16_t x); 190 | EXPORT explicit Expr(uint32_t x); 191 | EXPORT explicit Expr(uint64_t x); 192 | EXPORT Expr(float16_t x); 193 | EXPORT Expr(float x); 194 | EXPORT explicit Expr(double x); 195 | // @} 196 | 197 | /** Make an expression representing a const string (i.e. a StringImm) */ 198 | // Ree 199 | EXPORT Expr(const std::string &s); 200 | 201 | /** Dispatch to the correct visitor method for this node. E.g. if 202 | * this node is actually an Add node, then this will call 203 | * IRVisitor::visit(const Add *) */ 204 | inline void accept(Internal::IRVisitor *v) const { 205 | static_cast(node_.get())->accept(v, *this); 206 | } 207 | 208 | /** Get the type of this expression node */ 209 | Type type() const { 210 | return (static_cast(node_.get()))->type; 211 | } 212 | /*! \brief type indicate the container type */ 213 | using ContainerType = Internal::BaseExprNode; 214 | }; 215 | 216 | /** This lets you use an Expr as a key in a map of the form 217 | * map */ 218 | struct ExprCompare { 219 | bool operator()(const Expr& a, const Expr& b) const { 220 | return a.get() < b.get(); 221 | } 222 | }; 223 | 224 | /** This lets you use an Expr as a key in a unordered_map of the form 225 | * unordered_map */ 226 | struct ExprHash { 227 | size_t operator()(const Expr& a) const { 228 | return a.hash(); 229 | } 230 | }; 231 | 232 | /** This lets you use an Expr as a key in a unordered_map of the form 233 | * unordered_map */ 234 | struct ExprEqual { 235 | bool operator()(const Expr& a, const Expr& b) const { 236 | return a.get() == b.get(); 237 | } 238 | }; 239 | 240 | /** 241 | * A subclass of Expr that only refers to a Variable 242 | * 243 | * Avoid use the Var to confuse with Halide's Var in high level DSL. 
244 | */ 245 | struct VarExpr : public Expr { 246 | VarExpr() : Expr() { } 247 | explicit VarExpr(NodePtr n) : Expr(n) {} 248 | /** 249 | * constructor from variable 250 | * Choose first have name then type, with default int32 251 | * because most VarExpr are used as looping variable. 252 | */ 253 | EXPORT explicit VarExpr(const std::string &name_hint, Type t = Int(32)); 254 | /** return internal content as Variable */ 255 | inline const Internal::Variable* get() const; 256 | /** return internal variable pointer */ 257 | inline const Internal::Variable* operator->() const; 258 | }; 259 | 260 | /** An enum describing a type of device API. Used by schedules, and in 261 | * the For loop IR node. */ 262 | enum class DeviceAPI : int { 263 | None = 0, /// Used to denote for loops that run on the same device as the containing code. 264 | Host, 265 | Default_GPU, 266 | CUDA, 267 | OpenCL, 268 | GLSL, 269 | OpenGLCompute, 270 | Metal, 271 | Hexagon 272 | }; 273 | 274 | namespace Internal { 275 | 276 | /** An enum describing a type of loop traversal. Used in schedules, 277 | * and in the For loop IR node. */ 278 | enum class ForType : int { 279 | Serial = 0, 280 | Parallel = 1, 281 | Vectorized = 2, 282 | Unrolled = 3 283 | }; 284 | 285 | 286 | /** A reference-counted handle to a statement node. */ 287 | struct Stmt : public IRHandle { 288 | Stmt() : IRHandle() {} 289 | Stmt(NodePtr n) : IRHandle(n) {} 290 | 291 | /** Dispatch to the correct visitor method for this node. E.g. if 292 | * this node is actually an Add node, then this will call 293 | * IRVisitor::visit(const Add *) */ 294 | inline void accept(Internal::IRVisitor *v) const { 295 | static_cast(node_.get())->accept(v, *this); 296 | } 297 | /*! 
\brief type indicate the container type */ 298 | using ContainerType = Internal::BaseStmtNode; 299 | }; 300 | 301 | 302 | } // namespace Internal 303 | } // namespace HalideIR 304 | 305 | namespace HalideIR { 306 | namespace IR { 307 | using ::HalideIR::Expr; 308 | using Internal::Stmt; 309 | } // namespace IR 310 | } // namespace Stmt 311 | 312 | namespace std { 313 | template <> 314 | struct hash<::HalideIR::Expr> { 315 | std::size_t operator()(const ::HalideIR::Expr& k) const { 316 | return k.hash(); 317 | } 318 | }; 319 | template <> 320 | struct hash<::HalideIR::Internal::Stmt> { 321 | std::size_t operator()(const ::HalideIR::Internal::Stmt& k) const { 322 | return k.hash(); 323 | } 324 | }; 325 | } 326 | #endif 327 | -------------------------------------------------------------------------------- /src/ir/FunctionBase.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2016 by Contributors 3 | * \file FunctionBase.h 4 | * \brief The function reference data structure to hold the function without defining it. 5 | * 6 | * This is used because Function is a high level object that can contain schedule, 7 | * which could have many variations. Removing FunctionContent dep from IR makes IR minimum. 8 | */ 9 | #ifndef HALIDEIR_IR_FUNCTION_BASE_H_ 10 | #define HALIDEIR_IR_FUNCTION_BASE_H_ 11 | 12 | #include 13 | #include "Expr.h" 14 | 15 | namespace HalideIR { 16 | namespace IR { 17 | 18 | // Internal node container of Range 19 | class FunctionBaseNode; 20 | 21 | /*! \brief reference to a function */ 22 | class FunctionRef : public NodeRef { 23 | public: 24 | /*! \brief constructor */ 25 | FunctionRef() {} 26 | FunctionRef(NodePtr n) : NodeRef(n) {} 27 | /*! 28 | * \brief access the internal node container 29 | * \return the pointer to the internal node container 30 | */ 31 | inline const FunctionBaseNode* operator->() const; 32 | }; 33 | 34 | /*! 
\brief range over one dimension */ 35 | class FunctionBaseNode : public Node { 36 | public: 37 | /*! \return the name of the function */ 38 | virtual const std::string& func_name() const = 0; 39 | /*! \return the number of outputs of this function */ 40 | virtual int num_outputs() const = 0; 41 | }; 42 | 43 | // implements of inline functions 44 | inline const FunctionBaseNode* FunctionRef::operator->() const { 45 | return static_cast(node_.get()); 46 | } 47 | 48 | } // namespace IR 49 | } // namespace HalideIR 50 | 51 | #endif // HALIDEIR_IR_FUNCTION_BASE_H_ 52 | -------------------------------------------------------------------------------- /src/ir/IREquality.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_IR_EQUALITY_H 2 | #define HALIDEIR_IR_EQUALITY_H 3 | 4 | /** \file 5 | * Methods to test Exprs and Stmts for equality of value 6 | */ 7 | 8 | #include "IR.h" 9 | 10 | namespace HalideIR { 11 | namespace Internal { 12 | 13 | /** A compare struct suitable for use in std::map and std::set that 14 | * computes a lexical ordering on IR nodes. */ 15 | struct IRDeepCompare { 16 | EXPORT bool operator()(const Expr &a, const Expr &b) const; 17 | EXPORT bool operator()(const Stmt &a, const Stmt &b) const; 18 | }; 19 | 20 | /** Lossily track known equal exprs with a cache. On collision, the 21 | * old pair is evicted. Used below by ExprWithCompareCache. */ 22 | class IRCompareCache { 23 | private: 24 | struct Entry { 25 | Expr a, b; 26 | }; 27 | 28 | int bits; 29 | 30 | uint32_t hash(const Expr &a, const Expr &b) const { 31 | // Note this hash is symmetric in a and b, so that a 32 | // comparison in a and b hashes to the same bucket as 33 | // a comparison on b and a. 
34 | uint64_t pa = (uint64_t)(a.get()); 35 | uint64_t pb = (uint64_t)(b.get()); 36 | uint64_t mix = (pa + pb) + (pa ^ pb); 37 | mix ^= (mix >> bits); 38 | mix ^= (mix >> (bits*2)); 39 | uint32_t bottom = mix & ((1 << bits) - 1); 40 | return bottom; 41 | } 42 | 43 | std::vector entries; 44 | 45 | public: 46 | void insert(const Expr &a, const Expr &b) { 47 | uint32_t h = hash(a, b); 48 | entries[h].a = a; 49 | entries[h].b = b; 50 | } 51 | 52 | bool contains(const Expr &a, const Expr &b) const { 53 | uint32_t h = hash(a, b); 54 | const Entry &e = entries[h]; 55 | return ((a.same_as(e.a) && b.same_as(e.b)) || 56 | (a.same_as(e.b) && b.same_as(e.a))); 57 | } 58 | 59 | void clear() { 60 | for (size_t i = 0; i < entries.size(); i++) { 61 | entries[i].a = Expr(); 62 | entries[i].b = Expr(); 63 | } 64 | } 65 | 66 | IRCompareCache() {} 67 | IRCompareCache(int b) : bits(b), entries(static_cast(1) << bits) {} 68 | }; 69 | 70 | /** A wrapper about Exprs so that they can be deeply compared with a 71 | * cache for known-equal subexpressions. Useful for unsanitized Exprs 72 | * coming in from the front-end, which may be horrible graphs with 73 | * sub-expressions that are equal by value but not by identity. This 74 | * isn't a comparison object like IRDeepCompare above, because libc++ 75 | * requires that comparison objects be stateless (and constructs a new 76 | * one for each comparison!), so they can't have a cache associated 77 | * with them. However, by sneakily making the cache a mutable member 78 | * of the objects being compared, we can dodge this issue. 
79 | * 80 | * Clunky example usage: 81 | * 82 | \code 83 | Expr a, b, c, query; 84 | std::set s; 85 | IRCompareCache cache(8); 86 | s.insert(ExprWithCompareCache(a, &cache)); 87 | s.insert(ExprWithCompareCache(b, &cache)); 88 | s.insert(ExprWithCompareCache(c, &cache)); 89 | if (m.contains(ExprWithCompareCache(query, &cache))) {...} 90 | \endcode 91 | * 92 | */ 93 | struct ExprWithCompareCache { 94 | Expr expr; 95 | mutable IRCompareCache *cache; 96 | 97 | ExprWithCompareCache() : cache(nullptr) {} 98 | ExprWithCompareCache(const Expr &e, IRCompareCache *c) : expr(e), cache(c) {} 99 | 100 | /** The comparison uses (and updates) the cache */ 101 | EXPORT bool operator<(const ExprWithCompareCache &other) const; 102 | }; 103 | 104 | /** Compare IR nodes for equality of value. Traverses entire IR 105 | * tree. For equality of reference, use Expr::same_as. If you're 106 | * comparing non-CSE'd Exprs, use graph_equal, which is safe for nasty 107 | * graphs of IR nodes. */ 108 | // @{ 109 | EXPORT bool equal(const Expr &a, const Expr &b); 110 | EXPORT bool equal(const Stmt &a, const Stmt &b); 111 | EXPORT bool graph_equal(const Expr &a, const Expr &b); 112 | EXPORT bool graph_equal(const Stmt &a, const Stmt &b); 113 | // @} 114 | 115 | 116 | 117 | EXPORT void ir_equality_test(); 118 | 119 | } 120 | } 121 | 122 | #endif 123 | -------------------------------------------------------------------------------- /src/ir/IRMutator.cpp: -------------------------------------------------------------------------------- 1 | #include "IRMutator.h" 2 | 3 | namespace HalideIR { 4 | namespace Internal { 5 | 6 | using std::vector; 7 | 8 | Expr IRMutator::mutate(Expr e) { 9 | if (e.defined()) { 10 | e.accept(this); 11 | } else { 12 | expr = Expr(); 13 | } 14 | stmt = Stmt(); 15 | return expr; 16 | } 17 | 18 | Stmt IRMutator::mutate(Stmt s) { 19 | if (s.defined()) { 20 | s.accept(this); 21 | } else { 22 | stmt = Stmt(); 23 | } 24 | expr = Expr(); 25 | return stmt; 26 | } 27 | 28 | void 
IRMutator::visit(const IntImm *op, const Expr &e) { expr = e; } 29 | void IRMutator::visit(const UIntImm *op, const Expr &e) { expr = e; } 30 | void IRMutator::visit(const FloatImm *op, const Expr &e) { expr = e; } 31 | void IRMutator::visit(const StringImm *op, const Expr &e) { expr = e; } 32 | void IRMutator::visit(const Variable *op, const Expr &e) { expr = e; } 33 | 34 | void IRMutator::visit(const Cast *op, const Expr &e) { 35 | Expr value = mutate(op->value); 36 | if (value.same_as(op->value)) { 37 | expr = e; 38 | } else { 39 | expr = Cast::make(op->type, value); 40 | } 41 | } 42 | 43 | // use macro to access private function. 44 | #define MUTATE_BINARY_OP(op, e, T) \ 45 | Expr a = mutate(op->a); \ 46 | Expr b = mutate(op->b); \ 47 | if (a.same_as(op->a) && \ 48 | b.same_as(op->b)) { \ 49 | expr = e; \ 50 | } else { \ 51 | expr = T::make(a, b); \ 52 | } \ 53 | 54 | void IRMutator::visit(const Add *op, const Expr &e) { 55 | MUTATE_BINARY_OP(op, e, Add); 56 | } 57 | void IRMutator::visit(const Sub *op, const Expr &e) { 58 | MUTATE_BINARY_OP(op, e, Sub); 59 | } 60 | void IRMutator::visit(const Mul *op, const Expr &e) { 61 | MUTATE_BINARY_OP(op, e, Mul); 62 | } 63 | void IRMutator::visit(const Div *op, const Expr &e) { 64 | MUTATE_BINARY_OP(op, e, Div); 65 | } 66 | void IRMutator::visit(const Mod *op, const Expr &e) { 67 | MUTATE_BINARY_OP(op, e, Mod); 68 | } 69 | void IRMutator::visit(const Min *op, const Expr &e) { 70 | MUTATE_BINARY_OP(op, e, Min); 71 | } 72 | void IRMutator::visit(const Max *op, const Expr &e) { 73 | MUTATE_BINARY_OP(op, e, Max); 74 | } 75 | void IRMutator::visit(const EQ *op, const Expr &e) { 76 | MUTATE_BINARY_OP(op, e, EQ); 77 | } 78 | void IRMutator::visit(const NE *op, const Expr &e) { 79 | MUTATE_BINARY_OP(op, e, NE); 80 | } 81 | void IRMutator::visit(const LT *op, const Expr &e) { 82 | MUTATE_BINARY_OP(op, e, LT); 83 | } 84 | void IRMutator::visit(const LE *op, const Expr &e) { 85 | MUTATE_BINARY_OP(op, e, LE); 86 | } 87 | void 
IRMutator::visit(const GT *op, const Expr &e) { 88 | MUTATE_BINARY_OP(op, e, GT); 89 | } 90 | void IRMutator::visit(const GE *op, const Expr &e) { 91 | MUTATE_BINARY_OP(op, e, GE); 92 | } 93 | void IRMutator::visit(const And *op, const Expr &e) { 94 | MUTATE_BINARY_OP(op, e, And); 95 | } 96 | void IRMutator::visit(const Or *op, const Expr &e) { 97 | MUTATE_BINARY_OP(op, e, Or); 98 | } 99 | 100 | void IRMutator::visit(const Not *op, const Expr &e) { 101 | Expr a = mutate(op->a); 102 | if (a.same_as(op->a)) { 103 | expr = e; 104 | } else { 105 | expr = Not::make(a); 106 | } 107 | } 108 | 109 | void IRMutator::visit(const Select *op, const Expr &e) { 110 | Expr cond = mutate(op->condition); 111 | Expr t = mutate(op->true_value); 112 | Expr f = mutate(op->false_value); 113 | if (cond.same_as(op->condition) && 114 | t.same_as(op->true_value) && 115 | f.same_as(op->false_value)) { 116 | expr = e; 117 | } else { 118 | expr = Select::make(cond, t, f); 119 | } 120 | } 121 | 122 | void IRMutator::visit(const Load *op, const Expr &e) { 123 | Expr index = mutate(op->index); 124 | Expr predicate = mutate(op->predicate); 125 | if (predicate.same_as(op->predicate) && index.same_as(op->index)) { 126 | expr = e; 127 | } else { 128 | expr = Load::make(op->type, op->buffer_var, index, predicate); 129 | } 130 | } 131 | 132 | void IRMutator::visit(const Ramp *op, const Expr &e) { 133 | Expr base = mutate(op->base); 134 | Expr stride = mutate(op->stride); 135 | if (base.same_as(op->base) && 136 | stride.same_as(op->stride)) { 137 | expr = e; 138 | } else { 139 | expr = Ramp::make(base, stride, op->lanes); 140 | } 141 | } 142 | 143 | void IRMutator::visit(const Broadcast *op, const Expr &e) { 144 | Expr value = mutate(op->value); 145 | if (value.same_as(op->value)) { 146 | expr = e; 147 | } else { 148 | expr = Broadcast::make(value, op->lanes); 149 | } 150 | } 151 | 152 | void IRMutator::visit(const Call *op, const Expr &e) { 153 | vector new_args(op->args.size()); 154 | bool changed = 
false; 155 | 156 | // Mutate the args 157 | for (size_t i = 0; i < op->args.size(); i++) { 158 | Expr old_arg = op->args[i]; 159 | Expr new_arg = mutate(old_arg); 160 | if (!new_arg.same_as(old_arg)) changed = true; 161 | new_args[i] = new_arg; 162 | } 163 | 164 | if (!changed) { 165 | expr = e; 166 | } else { 167 | expr = Call::make(op->type, op->name, new_args, op->call_type, 168 | op->func, op->value_index); 169 | } 170 | } 171 | 172 | void IRMutator::visit(const Let *op, const Expr &e) { 173 | Expr value = mutate(op->value); 174 | Expr body = mutate(op->body); 175 | if (value.same_as(op->value) && 176 | body.same_as(op->body)) { 177 | expr = e; 178 | } else { 179 | expr = Let::make(op->var, value, body); 180 | } 181 | } 182 | 183 | void IRMutator::visit(const LetStmt *op, const Stmt &s) { 184 | Expr value = mutate(op->value); 185 | Stmt body = mutate(op->body); 186 | if (value.same_as(op->value) && 187 | body.same_as(op->body)) { 188 | stmt = s; 189 | } else { 190 | stmt = LetStmt::make(op->var, value, body); 191 | } 192 | } 193 | 194 | void IRMutator::visit(const AttrStmt *op, const Stmt &s) { 195 | Expr value = mutate(op->value); 196 | Stmt body = mutate(op->body); 197 | if (value.same_as(op->value) && 198 | body.same_as(op->body)) { 199 | stmt = s; 200 | } else { 201 | stmt = AttrStmt::make(op->node, op->attr_key, value, body); 202 | } 203 | } 204 | 205 | void IRMutator::visit(const AssertStmt *op, const Stmt &s) { 206 | Expr condition = mutate(op->condition); 207 | Expr message = mutate(op->message); 208 | Stmt body = mutate(op->body); 209 | 210 | if (condition.same_as(op->condition) && 211 | message.same_as(op->message) && 212 | body.same_as(op->body)) { 213 | stmt = s; 214 | } else { 215 | stmt = AssertStmt::make(condition, message, body); 216 | } 217 | } 218 | 219 | void IRMutator::visit(const ProducerConsumer *op, const Stmt &s) { 220 | Stmt body = mutate(op->body); 221 | if (body.same_as(op->body)) { 222 | stmt = s; 223 | } else { 224 | stmt = 
ProducerConsumer::make(op->func, op->is_producer, body); 225 | } 226 | } 227 | 228 | void IRMutator::visit(const For *op, const Stmt &s) { 229 | Expr min = mutate(op->min); 230 | Expr extent = mutate(op->extent); 231 | Stmt body = mutate(op->body); 232 | if (min.same_as(op->min) && 233 | extent.same_as(op->extent) && 234 | body.same_as(op->body)) { 235 | stmt = s; 236 | } else { 237 | stmt = For::make( 238 | op->loop_var, min, extent, op->for_type, op->device_api, body); 239 | } 240 | } 241 | 242 | void IRMutator::visit(const Store *op, const Stmt &s) { 243 | Expr value = mutate(op->value); 244 | Expr index = mutate(op->index); 245 | Expr predicate = mutate(op->predicate); 246 | if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) { 247 | stmt = s; 248 | } else { 249 | stmt = Store::make(op->buffer_var, value, index, predicate); 250 | } 251 | } 252 | 253 | void IRMutator::visit(const Provide *op, const Stmt &s) { 254 | vector new_args(op->args.size()); 255 | 256 | bool changed = false; 257 | 258 | // Mutate the args 259 | for (size_t i = 0; i < op->args.size(); i++) { 260 | Expr old_arg = op->args[i]; 261 | Expr new_arg = mutate(old_arg); 262 | if (!new_arg.same_as(old_arg)) changed = true; 263 | new_args[i] = new_arg; 264 | } 265 | Expr old_value = op->value; 266 | Expr new_value = mutate(old_value); 267 | if (!new_value.same_as(old_value)) changed = true; 268 | 269 | if (!changed) { 270 | stmt = s; 271 | } else { 272 | stmt = Provide::make(op->func, op->value_index, new_value, new_args); 273 | } 274 | } 275 | 276 | void IRMutator::visit(const Allocate *op, const Stmt &s) { 277 | std::vector new_extents; 278 | bool all_extents_unmodified = true; 279 | for (size_t i = 0; i < op->extents.size(); i++) { 280 | new_extents.push_back(mutate(op->extents[i])); 281 | all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); 282 | } 283 | Stmt body = mutate(op->body); 284 | Expr condition = mutate(op->condition); 285 | Expr 
new_expr; 286 | if (op->new_expr.defined()) { 287 | new_expr = mutate(op->new_expr); 288 | } 289 | if (all_extents_unmodified && 290 | body.same_as(op->body) && 291 | condition.same_as(op->condition) && 292 | new_expr.same_as(op->new_expr)) { 293 | stmt = s; 294 | } else { 295 | stmt = Allocate::make(op->buffer_var, op->type, new_extents, condition, body, new_expr, op->free_function); 296 | } 297 | } 298 | 299 | void IRMutator::visit(const Free *op, const Stmt &s) { 300 | stmt = s; 301 | } 302 | 303 | void IRMutator::visit(const Realize *op, const Stmt &s) { 304 | Region new_bounds; 305 | bool bounds_changed = false; 306 | 307 | // Mutate the bounds 308 | for (size_t i = 0; i < op->bounds.size(); i++) { 309 | Expr old_min = op->bounds[i]->min; 310 | Expr old_extent = op->bounds[i]->extent; 311 | Expr new_min = mutate(old_min); 312 | Expr new_extent = mutate(old_extent); 313 | if (!new_min.same_as(old_min)) bounds_changed = true; 314 | if (!new_extent.same_as(old_extent)) bounds_changed = true; 315 | new_bounds.push_back( 316 | Range::make_by_min_extent(new_min, new_extent)); 317 | } 318 | 319 | Stmt body = mutate(op->body); 320 | Expr condition = mutate(op->condition); 321 | if (!bounds_changed && 322 | body.same_as(op->body) && 323 | condition.same_as(op->condition)) { 324 | stmt = s; 325 | } else { 326 | stmt = Realize::make(op->func, op->value_index, 327 | op->type, new_bounds, 328 | condition, body); 329 | } 330 | } 331 | 332 | void IRMutator::visit(const Prefetch *op, const Stmt &s) { 333 | Region new_bounds; 334 | bool bounds_changed = false; 335 | 336 | // Mutate the bounds 337 | for (size_t i = 0; i < op->bounds.size(); i++) { 338 | Expr old_min = op->bounds[i]->min; 339 | Expr old_extent = op->bounds[i]->extent; 340 | Expr new_min = mutate(old_min); 341 | Expr new_extent = mutate(old_extent); 342 | if (!new_min.same_as(old_min)) bounds_changed = true; 343 | if (!new_extent.same_as(old_extent)) bounds_changed = true; 344 | new_bounds.push_back( 345 | 
Range::make_by_min_extent(new_min, new_extent)); 346 | } 347 | 348 | if (!bounds_changed) { 349 | stmt = s; 350 | } else { 351 | stmt = Prefetch::make(op->func, op->value_index, 352 | op->type, new_bounds); 353 | } 354 | } 355 | 356 | void IRMutator::visit(const Block *op, const Stmt &s) { 357 | Stmt first = mutate(op->first); 358 | Stmt rest = mutate(op->rest); 359 | if (first.same_as(op->first) && 360 | rest.same_as(op->rest)) { 361 | stmt = s; 362 | } else { 363 | stmt = Block::make(first, rest); 364 | } 365 | } 366 | 367 | void IRMutator::visit(const IfThenElse *op, const Stmt &s) { 368 | Expr condition = mutate(op->condition); 369 | Stmt then_case = mutate(op->then_case); 370 | Stmt else_case = mutate(op->else_case); 371 | if (condition.same_as(op->condition) && 372 | then_case.same_as(op->then_case) && 373 | else_case.same_as(op->else_case)) { 374 | stmt = s; 375 | } else { 376 | stmt = IfThenElse::make(condition, then_case, else_case); 377 | } 378 | } 379 | 380 | void IRMutator::visit(const Evaluate *op, const Stmt &s) { 381 | Expr v = mutate(op->value); 382 | if (v.same_as(op->value)) { 383 | stmt = s; 384 | } else { 385 | stmt = Evaluate::make(v); 386 | } 387 | } 388 | 389 | void IRMutator::visit(const Shuffle *op, const Expr& e) { 390 | Array new_vectors; 391 | bool changed = false; 392 | 393 | for (size_t i = 0; i < op->vectors.size(); i++) { 394 | Expr old_vector = op->vectors[i]; 395 | Expr new_vector = mutate(old_vector); 396 | if (!new_vector.same_as(old_vector)) changed = true; 397 | new_vectors.push_back(new_vector); 398 | } 399 | 400 | if (!changed) { 401 | expr = e; 402 | } else { 403 | expr = Shuffle::make(new_vectors, op->indices); 404 | } 405 | } 406 | 407 | Stmt IRGraphMutator::mutate(Stmt s) { 408 | auto iter = stmt_replacements.find(s); 409 | if (iter != stmt_replacements.end()) { 410 | return iter->second; 411 | } 412 | Stmt new_s = IRMutator::mutate(s); 413 | stmt_replacements[s] = new_s; 414 | return new_s; 415 | } 416 | 417 | Expr 
IRGraphMutator::mutate(Expr e) { 418 | auto iter = expr_replacements.find(e); 419 | if (iter != expr_replacements.end()) { 420 | return iter->second; 421 | } 422 | Expr new_e = IRMutator::mutate(e); 423 | expr_replacements[e] = new_e; 424 | return new_e; 425 | } 426 | 427 | } 428 | } 429 | -------------------------------------------------------------------------------- /src/ir/IRMutator.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_IR_MUTATOR_H 2 | #define HALIDEIR_IR_MUTATOR_H 3 | 4 | /** \file 5 | * Defines a base class for passes over the IR that modify it 6 | */ 7 | 8 | #include "IRVisitor.h" 9 | #include 10 | 11 | namespace HalideIR { 12 | namespace Internal { 13 | 14 | /** A base class for passes over the IR which modify it 15 | * (e.g. replacing a variable with a value (Substitute.h), or 16 | * constant-folding). 17 | * 18 | * Your mutate should override the visit methods you care about. Return 19 | * the new expression by assigning to expr or stmt. The default ones 20 | * recursively mutate their children. To mutate sub-expressions and 21 | * sub-statements you should the mutate method, which will dispatch to 22 | * the appropriate visit method and then return the value of expr or 23 | * stmt after the call to visit. 24 | */ 25 | class IRMutator : public IRVisitor { 26 | public: 27 | 28 | /** This is the main interface for using a mutator. Also call 29 | * these in your subclass to mutate sub-expressions and 30 | * sub-statements. 
*/
    virtual EXPORT Expr mutate(Expr expr);
    virtual EXPORT Stmt mutate(Stmt stmt);

protected:
    /** visit methods that take Exprs assign to this to return their
     * new value */
    Expr expr;

    /** visit methods that take Stmts assign to this to return their
     * new value */
    Stmt stmt;

protected:
    /** Per-node-type hooks, reached through mutate(). The default
     * implementations (IRMutator.cpp) mutate each child via mutate() and
     * build a new node only when some child changed; otherwise they
     * assign the original node back to expr/stmt. Override the ones you
     * care about. */
    EXPORT virtual void visit(const IntImm *, const Expr &);
    EXPORT virtual void visit(const UIntImm *, const Expr &);
    EXPORT virtual void visit(const FloatImm *, const Expr &);
    EXPORT virtual void visit(const StringImm *, const Expr &);
    EXPORT virtual void visit(const Cast *, const Expr &);
    EXPORT virtual void visit(const Variable *, const Expr &);
    EXPORT virtual void visit(const Add *, const Expr &);
    EXPORT virtual void visit(const Sub *, const Expr &);
    EXPORT virtual void visit(const Mul *, const Expr &);
    EXPORT virtual void visit(const Div *, const Expr &);
    EXPORT virtual void visit(const Mod *, const Expr &);
    EXPORT virtual void visit(const Min *, const Expr &);
    EXPORT virtual void visit(const Max *, const Expr &);
    EXPORT virtual void visit(const EQ *, const Expr &);
    EXPORT virtual void visit(const NE *, const Expr &);
    EXPORT virtual void visit(const LT *, const Expr &);
    EXPORT virtual void visit(const LE *, const Expr &);
    EXPORT virtual void visit(const GT *, const Expr &);
    EXPORT virtual void visit(const GE *, const Expr &);
    EXPORT virtual void visit(const And *, const Expr &);
    EXPORT virtual void visit(const Or *, const Expr &);
    EXPORT virtual void visit(const Not *, const Expr &);
    EXPORT virtual void visit(const Select *, const Expr &);
    EXPORT virtual void visit(const Load *, const Expr &);
    EXPORT virtual void visit(const Ramp *, const Expr &);
    EXPORT virtual void visit(const Broadcast *, const Expr &);
    EXPORT virtual void visit(const Call *, const Expr &);
    EXPORT virtual void visit(const Let *, const Expr &);
    EXPORT virtual void visit(const LetStmt *, const Stmt &);
    EXPORT virtual void visit(const AttrStmt *, const Stmt &);
    EXPORT virtual void visit(const AssertStmt *, const Stmt &);
    EXPORT virtual void visit(const ProducerConsumer *, const Stmt &);
    EXPORT virtual void visit(const For *, const Stmt &);
    EXPORT virtual void visit(const Store *, const Stmt &);
    EXPORT virtual void visit(const Provide *, const Stmt &);
    EXPORT virtual void visit(const Allocate *, const Stmt &);
    EXPORT virtual void visit(const Free *, const Stmt &);
    EXPORT virtual void visit(const Realize *, const Stmt &);
    EXPORT virtual void visit(const Prefetch *, const Stmt &);
    EXPORT virtual void visit(const Block *, const Stmt &);
    EXPORT virtual void visit(const IfThenElse *, const Stmt &);
    EXPORT virtual void visit(const Evaluate *, const Stmt &);
    EXPORT virtual void visit(const Shuffle *, const Expr &);
};


/** A mutator that caches and reapplies previously-done mutations, so
 * that it can handle graphs of IR that have not had CSE done to
 * them. */
class IRGraphMutator : public IRMutator {
protected:
    // Memo tables mapping an already-mutated node to its replacement.
    // NOTE(review): the template arguments of these two maps appear to have
    // been lost in this copy of the file, so the declarations do not
    // compile as written. Presumably <Expr, Expr, ...> and <Stmt, Stmt, ...>
    // with a hash/equality functor for node references; restore them
    // from upstream.
    std::unordered_map expr_replacements;
    std::unordered_map stmt_replacements;

public:
    /** Return the cached replacement for s if there is one; otherwise
     * run IRMutator::mutate on it and remember the result. */
    EXPORT Stmt mutate(Stmt s);
    /** Expr counterpart of mutate(Stmt). */
    EXPORT Expr mutate(Expr e);
};


}
}

#endif
--------------------------------------------------------------------------------
/src/ir/IRPrinter.h:
--------------------------------------------------------------------------------
#ifndef HALIDEIR_IR_PRINTER_H
#define HALIDEIR_IR_PRINTER_H

/** \file
 * This header file defines operators that let you dump a Halide
 * expression, statement, or type directly into an output stream
 * in a human readable form.
 * E.g.:
\code
Expr foo = ...
11 | std::cout << "Foo is " << foo << std::endl; 12 | \endcode 13 | * 14 | * These operators are implemented using \ref HalideIR::Internal::IRPrinter 15 | */ 16 | 17 | #include 18 | #include "./IR.h" 19 | #include "./IRVisitor.h" 20 | 21 | namespace HalideIR { 22 | 23 | /** Emit an expression on an output stream (such as std::cout) in a 24 | * human-readable form */ 25 | EXPORT std::ostream &operator<<(std::ostream &stream, const Expr &); 26 | 27 | /** Emit a halide type on an output stream (such as std::cout) in a 28 | * human-readable form */ 29 | EXPORT std::ostream &operator<<(std::ostream &stream, const Type &); 30 | 31 | 32 | /** Emit a halide device api type in a human readable form */ 33 | EXPORT std::ostream &operator<<(std::ostream &stream, const DeviceAPI &); 34 | 35 | namespace Internal { 36 | 37 | /** Emit a halide statement on an output stream (such as std::cout) in 38 | * a human-readable form */ 39 | EXPORT std::ostream &operator<<(std::ostream &stream, const Stmt &); 40 | 41 | /** Emit a halide for loop type (vectorized, serial, etc) in a human 42 | * readable form */ 43 | EXPORT std::ostream &operator<<(std::ostream &stream, const ForType &); 44 | 45 | /** 46 | * An IRVisitor that emits IR to the given output stream in a human 47 | * readable form. Can be subclassed if you want to modify the way in 48 | * which it prints. 49 | * 50 | * IRPrinter is re-implemeneted using IRFunctor, as a demonstration 51 | * example on how Visitor based printing can be adopted to IRFunctor. 52 | * 53 | */ 54 | class IRPrinter { 55 | public: 56 | /** Construct an IRPrinter pointed at a given output stream 57 | * (e.g. 
std::cout, or a std::ofstream) */ 58 | EXPORT IRPrinter(std::ostream &); 59 | 60 | /** emit an expression on the output stream */ 61 | EXPORT void print(const NodeRef&); 62 | 63 | EXPORT static void test(); 64 | 65 | /** The stream we're outputting on */ 66 | std::ostream &stream; 67 | 68 | /** The current indentation level, useful for pretty-printing 69 | * statements */ 70 | int indent; 71 | 72 | /** Emit spaces according to the current indentation level */ 73 | void do_indent(); 74 | 75 | using FType = tvm::IRFunctor; 76 | 77 | EXPORT static FType& vtable(); 78 | }; 79 | } 80 | } 81 | 82 | #endif 83 | -------------------------------------------------------------------------------- /src/ir/IRVisitor.cpp: -------------------------------------------------------------------------------- 1 | #include "IRVisitor.h" 2 | 3 | namespace HalideIR { 4 | namespace Internal { 5 | 6 | IRVisitor::~IRVisitor() { 7 | } 8 | 9 | void IRVisitor::visit(const IntImm *, const Expr &) { 10 | } 11 | 12 | void IRVisitor::visit(const UIntImm *, const Expr &) { 13 | } 14 | 15 | void IRVisitor::visit(const FloatImm *, const Expr &) { 16 | } 17 | 18 | void IRVisitor::visit(const StringImm *, const Expr &) { 19 | } 20 | 21 | void IRVisitor::visit(const Cast *op, const Expr &) { 22 | op->value.accept(this); 23 | } 24 | 25 | void IRVisitor::visit(const Variable *, const Expr &) { 26 | } 27 | 28 | void IRVisitor::visit(const Add *op, const Expr &) { 29 | op->a.accept(this); 30 | op->b.accept(this); 31 | } 32 | 33 | void IRVisitor::visit(const Sub *op, const Expr &) { 34 | op->a.accept(this); 35 | op->b.accept(this); 36 | } 37 | 38 | void IRVisitor::visit(const Mul *op, const Expr &) { 39 | op->a.accept(this); 40 | op->b.accept(this); 41 | } 42 | 43 | void IRVisitor::visit(const Div *op, const Expr &) { 44 | op->a.accept(this); 45 | op->b.accept(this); 46 | } 47 | 48 | void IRVisitor::visit(const Mod *op, const Expr &) { 49 | op->a.accept(this); 50 | op->b.accept(this); 51 | } 52 | 53 | void 
IRVisitor::visit(const Min *op, const Expr &) { 54 | op->a.accept(this); 55 | op->b.accept(this); 56 | } 57 | 58 | void IRVisitor::visit(const Max *op, const Expr &) { 59 | op->a.accept(this); 60 | op->b.accept(this); 61 | } 62 | 63 | void IRVisitor::visit(const EQ *op, const Expr &) { 64 | op->a.accept(this); 65 | op->b.accept(this); 66 | } 67 | 68 | void IRVisitor::visit(const NE *op, const Expr &) { 69 | op->a.accept(this); 70 | op->b.accept(this); 71 | } 72 | 73 | void IRVisitor::visit(const LT *op, const Expr &) { 74 | op->a.accept(this); 75 | op->b.accept(this); 76 | } 77 | 78 | void IRVisitor::visit(const LE *op, const Expr &) { 79 | op->a.accept(this); 80 | op->b.accept(this); 81 | } 82 | 83 | void IRVisitor::visit(const GT *op, const Expr &) { 84 | op->a.accept(this); 85 | op->b.accept(this); 86 | } 87 | 88 | void IRVisitor::visit(const GE *op, const Expr &) { 89 | op->a.accept(this); 90 | op->b.accept(this); 91 | } 92 | 93 | void IRVisitor::visit(const And *op, const Expr &) { 94 | op->a.accept(this); 95 | op->b.accept(this); 96 | } 97 | 98 | void IRVisitor::visit(const Or *op, const Expr &) { 99 | op->a.accept(this); 100 | op->b.accept(this); 101 | } 102 | 103 | void IRVisitor::visit(const Not *op, const Expr &) { 104 | op->a.accept(this); 105 | } 106 | 107 | void IRVisitor::visit(const Select *op, const Expr &) { 108 | op->condition.accept(this); 109 | op->true_value.accept(this); 110 | op->false_value.accept(this); 111 | } 112 | 113 | void IRVisitor::visit(const Load *op, const Expr &) { 114 | op->index.accept(this); 115 | op->predicate.accept(this); 116 | } 117 | 118 | void IRVisitor::visit(const Ramp *op, const Expr &) { 119 | op->base.accept(this); 120 | op->stride.accept(this); 121 | } 122 | 123 | void IRVisitor::visit(const Broadcast *op, const Expr &) { 124 | op->value.accept(this); 125 | } 126 | 127 | void IRVisitor::visit(const Call *op, const Expr &) { 128 | for (size_t i = 0; i < op->args.size(); i++) { 129 | op->args[i].accept(this); 130 | } 
131 | 132 | // removed: Consider extern call args 133 | } 134 | 135 | void IRVisitor::visit(const Let *op, const Expr &) { 136 | op->value.accept(this); 137 | op->body.accept(this); 138 | } 139 | 140 | void IRVisitor::visit(const LetStmt *op, const Stmt &) { 141 | op->value.accept(this); 142 | op->body.accept(this); 143 | } 144 | 145 | void IRVisitor::visit(const AttrStmt *op, const Stmt &) { 146 | op->value.accept(this); 147 | op->body.accept(this); 148 | } 149 | 150 | void IRVisitor::visit(const AssertStmt *op, const Stmt &) { 151 | op->condition.accept(this); 152 | op->message.accept(this); 153 | op->body.accept(this); 154 | } 155 | 156 | void IRVisitor::visit(const ProducerConsumer *op, const Stmt &) { 157 | op->body.accept(this); 158 | } 159 | 160 | void IRVisitor::visit(const For *op, const Stmt &) { 161 | op->min.accept(this); 162 | op->extent.accept(this); 163 | op->body.accept(this); 164 | } 165 | 166 | void IRVisitor::visit(const Store *op, const Stmt &) { 167 | op->value.accept(this); 168 | op->index.accept(this); 169 | op->predicate.accept(this); 170 | } 171 | 172 | void IRVisitor::visit(const Provide *op, const Stmt &) { 173 | op->value.accept(this); 174 | for (size_t i = 0; i < op->args.size(); i++) { 175 | op->args[i].accept(this); 176 | } 177 | } 178 | 179 | void IRVisitor::visit(const Allocate *op, const Stmt &) { 180 | for (size_t i = 0; i < op->extents.size(); i++) { 181 | op->extents[i].accept(this); 182 | } 183 | op->condition.accept(this); 184 | if (op->new_expr.defined()) { 185 | op->new_expr.accept(this); 186 | } 187 | op->body.accept(this); 188 | } 189 | 190 | void IRVisitor::visit(const Free *op, const Stmt &) { 191 | } 192 | 193 | void IRVisitor::visit(const Realize *op, const Stmt &) { 194 | for (size_t i = 0; i < op->bounds.size(); i++) { 195 | op->bounds[i]->min.accept(this); 196 | op->bounds[i]->extent.accept(this); 197 | } 198 | op->condition.accept(this); 199 | op->body.accept(this); 200 | } 201 | 202 | void IRVisitor::visit(const 
Prefetch *op, const Stmt &) { 203 | for (size_t i = 0; i < op->bounds.size(); i++) { 204 | op->bounds[i]->min.accept(this); 205 | op->bounds[i]->extent.accept(this); 206 | } 207 | } 208 | 209 | void IRVisitor::visit(const Block *op, const Stmt &) { 210 | op->first.accept(this); 211 | if (op->rest.defined()) { 212 | op->rest.accept(this); 213 | } 214 | } 215 | 216 | void IRVisitor::visit(const IfThenElse *op, const Stmt &) { 217 | op->condition.accept(this); 218 | op->then_case.accept(this); 219 | if (op->else_case.defined()) { 220 | op->else_case.accept(this); 221 | } 222 | } 223 | 224 | void IRVisitor::visit(const Evaluate *op, const Stmt &) { 225 | op->value.accept(this); 226 | } 227 | 228 | void IRVisitor::visit(const Shuffle *op, const Expr &) { 229 | for (Expr i : op->vectors) { 230 | i.accept(this); 231 | } 232 | } 233 | 234 | void IRGraphVisitor::include(const Expr &e) { 235 | if (visited.count(e.get())) { 236 | return; 237 | } else { 238 | visited.insert(e.get()); 239 | e.accept(this); 240 | return; 241 | } 242 | } 243 | 244 | void IRGraphVisitor::include(const Stmt &s) { 245 | if (visited.count(s.get())) { 246 | return; 247 | } else { 248 | visited.insert(s.get()); 249 | s.accept(this); 250 | return; 251 | } 252 | } 253 | 254 | void IRGraphVisitor::visit(const IntImm *, const Expr &) { 255 | } 256 | 257 | void IRGraphVisitor::visit(const UIntImm *, const Expr &) { 258 | } 259 | 260 | void IRGraphVisitor::visit(const FloatImm *, const Expr &) { 261 | } 262 | 263 | void IRGraphVisitor::visit(const StringImm *, const Expr &) { 264 | } 265 | 266 | void IRGraphVisitor::visit(const Cast *op, const Expr &) { 267 | include(op->value); 268 | } 269 | 270 | void IRGraphVisitor::visit(const Variable *op, const Expr &) { 271 | } 272 | 273 | void IRGraphVisitor::visit(const Add *op, const Expr &) { 274 | include(op->a); 275 | include(op->b); 276 | } 277 | 278 | void IRGraphVisitor::visit(const Sub *op, const Expr &) { 279 | include(op->a); 280 | include(op->b); 281 | } 
282 | 283 | void IRGraphVisitor::visit(const Mul *op, const Expr &) { 284 | include(op->a); 285 | include(op->b); 286 | } 287 | 288 | void IRGraphVisitor::visit(const Div *op, const Expr &) { 289 | include(op->a); 290 | include(op->b); 291 | } 292 | 293 | void IRGraphVisitor::visit(const Mod *op, const Expr &) { 294 | include(op->a); 295 | include(op->b); 296 | } 297 | 298 | void IRGraphVisitor::visit(const Min *op, const Expr &) { 299 | include(op->a); 300 | include(op->b); 301 | } 302 | 303 | void IRGraphVisitor::visit(const Max *op, const Expr &) { 304 | include(op->a); 305 | include(op->b); 306 | } 307 | 308 | void IRGraphVisitor::visit(const EQ *op, const Expr &) { 309 | include(op->a); 310 | include(op->b); 311 | } 312 | 313 | void IRGraphVisitor::visit(const NE *op, const Expr &) { 314 | include(op->a); 315 | include(op->b); 316 | } 317 | 318 | void IRGraphVisitor::visit(const LT *op, const Expr &) { 319 | include(op->a); 320 | include(op->b); 321 | } 322 | 323 | void IRGraphVisitor::visit(const LE *op, const Expr &) { 324 | include(op->a); 325 | include(op->b); 326 | } 327 | 328 | void IRGraphVisitor::visit(const GT *op, const Expr &) { 329 | include(op->a); 330 | include(op->b); 331 | } 332 | 333 | void IRGraphVisitor::visit(const GE *op, const Expr &) { 334 | include(op->a); 335 | include(op->b); 336 | } 337 | 338 | void IRGraphVisitor::visit(const And *op, const Expr &) { 339 | include(op->a); 340 | include(op->b); 341 | } 342 | 343 | void IRGraphVisitor::visit(const Or *op, const Expr &) { 344 | include(op->a); 345 | include(op->b); 346 | } 347 | 348 | void IRGraphVisitor::visit(const Not *op, const Expr &) { 349 | include(op->a); 350 | } 351 | 352 | void IRGraphVisitor::visit(const Select *op, const Expr &) { 353 | include(op->condition); 354 | include(op->true_value); 355 | include(op->false_value); 356 | } 357 | 358 | void IRGraphVisitor::visit(const Load *op, const Expr &) { 359 | include(op->index); 360 | include(op->predicate); 361 | } 362 | 363 | 
/* IRGraphVisitor visitors (continued): each one forwards the node's child Expr/Stmt fields to include(), in the order listed. */ void IRGraphVisitor::visit(const Ramp *op, const Expr &) { 364 | include(op->base); 365 | include(op->stride); 366 | } 367 | 368 | void IRGraphVisitor::visit(const Broadcast *op, const Expr &) { 369 | include(op->value); 370 | } 371 | 372 | void IRGraphVisitor::visit(const Call *op, const Expr &) { 373 | for (size_t i = 0; i < op->args.size(); i++) { 374 | include(op->args[i]); 375 | } 376 | } 377 | 378 | void IRGraphVisitor::visit(const Let *op, const Expr &) { 379 | include(op->value); 380 | include(op->body); 381 | } 382 | 383 | void IRGraphVisitor::visit(const LetStmt *op, const Stmt &) { 384 | include(op->value); 385 | include(op->body); 386 | } 387 | 388 | void IRGraphVisitor::visit(const AssertStmt *op, const Stmt &) { 389 | include(op->condition); 390 | include(op->message); 391 | include(op->body); 392 | } 393 | 394 | void IRGraphVisitor::visit(const ProducerConsumer *op, const Stmt &) { 395 | include(op->body); 396 | } 397 | 398 | void IRGraphVisitor::visit(const For *op, const Stmt &) { 399 | include(op->min); 400 | include(op->extent); 401 | include(op->body); 402 | } 403 | 404 | void IRGraphVisitor::visit(const Store *op, const Stmt &) { 405 | include(op->value); 406 | include(op->index); 407 | include(op->predicate); 408 | } 409 | 410 | void IRGraphVisitor::visit(const Provide *op, const Stmt &) { 411 | include(op->value); 412 | for (size_t i = 0; i < op->args.size(); i++) { 413 | include(op->args[i]); 414 | } 415 | } 416 | 417 | void IRGraphVisitor::visit(const Allocate *op, const Stmt &) { 418 | for (size_t i = 0; i < op->extents.size(); i++) { 419 | include(op->extents[i]); 420 | } 421 | include(op->condition); 422 | if (op->new_expr.defined()) { 423 | include(op->new_expr); 424 | } 425 | include(op->body); 426 | } 427 | 428 | void IRGraphVisitor::visit(const Free *op, const Stmt &) { 429 | } 430 | 431 | void IRGraphVisitor::visit(const Realize *op, const Stmt &) { 432 | for (size_t i = 0; i < op->bounds.size(); i++) { 433 |
include(op->bounds[i]->min); 434 | include(op->bounds[i]->extent); 435 | } 436 | include(op->condition); 437 | include(op->body); 438 | } 439 | 440 | void IRGraphVisitor::visit(const Prefetch *op, const Stmt &) { 441 | for (size_t i = 0; i < op->bounds.size(); i++) { 442 | include(op->bounds[i]->min); 443 | include(op->bounds[i]->extent); 444 | } 445 | } 446 | 447 | void IRGraphVisitor::visit(const Block *op, const Stmt &) { 448 | include(op->first); 449 | if (op->rest.defined()) include(op->rest); 450 | } 451 | 452 | void IRGraphVisitor::visit(const IfThenElse *op, const Stmt &) { 453 | include(op->condition); 454 | include(op->then_case); 455 | if (op->else_case.defined()) { 456 | include(op->else_case); 457 | } 458 | } 459 | 460 | void IRGraphVisitor::visit(const Evaluate *op, const Stmt &) { 461 | include(op->value); 462 | } 463 | 464 | void IRGraphVisitor::visit(const Shuffle *op, const Expr &) { 465 | /* NOTE(review): the loop variable is taken by value, so each Expr handle is copied; iterating with `const Expr &` would avoid the copy. */ for (Expr i : op->vectors) { 466 | include(i); 467 | } 468 | } 469 | 470 | } 471 | } 472 | -------------------------------------------------------------------------------- /src/ir/IRVisitor.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDEIR_IR_VISITOR_H 2 | #define HALIDEIR_IR_VISITOR_H 3 | 4 | #include "IR.h" 5 | #include "base/Util.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | /** \file 12 | * Defines the base class for things that recursively walk over the IR 13 | */ 14 | 15 | namespace HalideIR { 16 | namespace Internal { 17 | 18 | /** A base class for algorithms that need to recursively walk over the 19 | * IR. The default implementations just recursively walk over the 20 | * children. Override the ones you care about.
21 | */ 22 | class IRVisitor { 23 | public: 24 | EXPORT virtual ~IRVisitor(); 25 | EXPORT virtual void visit(const IntImm *, const Expr &); 26 | EXPORT virtual void visit(const UIntImm *, const Expr &); 27 | EXPORT virtual void visit(const FloatImm *, const Expr &); 28 | EXPORT virtual void visit(const StringImm *, const Expr &); 29 | EXPORT virtual void visit(const Cast *, const Expr &); 30 | EXPORT virtual void visit(const Variable *, const Expr &); 31 | EXPORT virtual void visit(const Add *, const Expr &); 32 | EXPORT virtual void visit(const Sub *, const Expr &); 33 | EXPORT virtual void visit(const Mul *, const Expr &); 34 | EXPORT virtual void visit(const Div *, const Expr &); 35 | EXPORT virtual void visit(const Mod *, const Expr &); 36 | EXPORT virtual void visit(const Min *, const Expr &); 37 | EXPORT virtual void visit(const Max *, const Expr &); 38 | EXPORT virtual void visit(const EQ *, const Expr &); 39 | EXPORT virtual void visit(const NE *, const Expr &); 40 | EXPORT virtual void visit(const LT *, const Expr &); 41 | EXPORT virtual void visit(const LE *, const Expr &); 42 | EXPORT virtual void visit(const GT *, const Expr &); 43 | EXPORT virtual void visit(const GE *, const Expr &); 44 | EXPORT virtual void visit(const And *, const Expr &); 45 | EXPORT virtual void visit(const Or *, const Expr &); 46 | EXPORT virtual void visit(const Not *, const Expr &); 47 | EXPORT virtual void visit(const Select *, const Expr &); 48 | EXPORT virtual void visit(const Load *, const Expr &); 49 | EXPORT virtual void visit(const Ramp *, const Expr &); 50 | EXPORT virtual void visit(const Broadcast *, const Expr &); 51 | EXPORT virtual void visit(const Call *, const Expr &); 52 | EXPORT virtual void visit(const Let *, const Expr &); 53 | EXPORT virtual void visit(const Shuffle *, const Expr &); 54 | EXPORT virtual void visit(const LetStmt *, const Stmt &); 55 | EXPORT virtual void visit(const AttrStmt *, const Stmt &); 56 | EXPORT virtual void visit(const AssertStmt *,
const Stmt &); 57 | EXPORT virtual void visit(const ProducerConsumer *, const Stmt &); 58 | EXPORT virtual void visit(const For *, const Stmt &); 59 | EXPORT virtual void visit(const Store *, const Stmt &); 60 | EXPORT virtual void visit(const Provide *, const Stmt &); 61 | EXPORT virtual void visit(const Allocate *, const Stmt &); 62 | EXPORT virtual void visit(const Free *, const Stmt &); 63 | EXPORT virtual void visit(const Realize *, const Stmt &); 64 | EXPORT virtual void visit(const Prefetch *, const Stmt &); 65 | EXPORT virtual void visit(const Block *, const Stmt &); 66 | EXPORT virtual void visit(const IfThenElse *, const Stmt &); 67 | EXPORT virtual void visit(const Evaluate *, const Stmt &); 68 | }; 69 | 70 | /** A base class for algorithms that walk recursively over the IR 71 | * without visiting the same node twice. This is for passes that are 72 | * capable of interpreting the IR as a DAG instead of a tree. */ 73 | class IRGraphVisitor : public IRVisitor { 74 | protected: 75 | /** By default these methods add the node to the visited set. 76 | * They return nothing; if the node was not already in the set, 77 | * they delegate to the appropriate visit method. You can override 78 | * them if you like. */ 79 | // @{ 80 | EXPORT virtual void include(const Expr &); 81 | EXPORT virtual void include(const Stmt &); 82 | // @} 83 | 84 | /** The nodes visited so far */ 85 | std::set visited; 86 | 87 | public: 88 | 89 | /** These methods should call 'include' on the children to only 90 | * visit them if they haven't been visited already. NOTE(review): unlike IRVisitor, this class declares no visit(const AttrStmt *, const Stmt &) override, so AttrStmt nodes use IRVisitor's default traversal, which has no include() and therefore does not record their direct children in the visited set; confirm this is intended.
*/ 91 | // @{ 92 | EXPORT virtual void visit(const IntImm *, const Expr &); 93 | EXPORT virtual void visit(const UIntImm *, const Expr &); 94 | EXPORT virtual void visit(const FloatImm *, const Expr &); 95 | EXPORT virtual void visit(const StringImm *, const Expr &); 96 | EXPORT virtual void visit(const Cast *, const Expr &); 97 | EXPORT virtual void visit(const Variable *, const Expr &); 98 | EXPORT virtual void visit(const Add *, const Expr &); 99 | EXPORT virtual void visit(const Sub *, const Expr &); 100 | EXPORT virtual void visit(const Mul *, const Expr &); 101 | EXPORT virtual void visit(const Div *, const Expr &); 102 | EXPORT virtual void visit(const Mod *, const Expr &); 103 | EXPORT virtual void visit(const Min *, const Expr &); 104 | EXPORT virtual void visit(const Max *, const Expr &); 105 | EXPORT virtual void visit(const EQ *, const Expr &); 106 | EXPORT virtual void visit(const NE *, const Expr &); 107 | EXPORT virtual void visit(const LT *, const Expr &); 108 | EXPORT virtual void visit(const LE *, const Expr &); 109 | EXPORT virtual void visit(const GT *, const Expr &); 110 | EXPORT virtual void visit(const GE *, const Expr &); 111 | EXPORT virtual void visit(const And *, const Expr &); 112 | EXPORT virtual void visit(const Or *, const Expr &); 113 | EXPORT virtual void visit(const Not *, const Expr &); 114 | EXPORT virtual void visit(const Select *, const Expr &); 115 | EXPORT virtual void visit(const Load *, const Expr &); 116 | EXPORT virtual void visit(const Ramp *, const Expr &); 117 | EXPORT virtual void visit(const Broadcast *, const Expr &); 118 | EXPORT virtual void visit(const Call *, const Expr &); 119 | EXPORT virtual void visit(const Let *, const Expr &); 120 | EXPORT virtual void visit(const Shuffle *, const Expr &); 121 | EXPORT virtual void visit(const LetStmt *, const Stmt &); 122 | EXPORT virtual void visit(const AssertStmt *, const Stmt &); 123 | EXPORT virtual void visit(const ProducerConsumer *, const Stmt &); 124 | EXPORT
virtual void visit(const For *, const Stmt &); 125 | EXPORT virtual void visit(const Store *, const Stmt &); 126 | EXPORT virtual void visit(const Provide *, const Stmt &); 127 | EXPORT virtual void visit(const Allocate *, const Stmt &); 128 | EXPORT virtual void visit(const Free *, const Stmt &); 129 | EXPORT virtual void visit(const Realize *, const Stmt &); 130 | EXPORT virtual void visit(const Prefetch *, const Stmt &); 131 | EXPORT virtual void visit(const Block *, const Stmt &); 132 | EXPORT virtual void visit(const IfThenElse *, const Stmt &); 133 | EXPORT virtual void visit(const Evaluate *, const Stmt &); 134 | // @} 135 | }; 136 | 137 | } 138 | } 139 | 140 | #endif 141 | -------------------------------------------------------------------------------- /src/ir/Range.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2016 by Contributors 3 | * \file Range.h 4 | * \brief The Range data structure 5 | */ 6 | #ifndef HALIDEIR_IR_RANGE_H_ 7 | #define HALIDEIR_IR_RANGE_H_ 8 | 9 | #include 10 | #include "Expr.h" 11 | 12 | namespace HalideIR { 13 | namespace IR { 14 | 15 | // Internal node container of Range 16 | class RangeNode; 17 | 18 | /*! \brief Node range */ 19 | class Range : public NodeRef { 20 | public: 21 | /*! \brief constructor */ 22 | Range() {} 23 | Range(NodePtr n) : NodeRef(n) {} 24 | /*! 25 | * \brief access the internal node container 26 | * \return the pointer to the internal node container 27 | */ 28 | inline const RangeNode* operator->() const; 29 | /*! \brief specify container node */ 30 | using ContainerType = RangeNode; 31 | /*! 32 | * \brief construct a new range with min and extent 33 | * The corresponding constructor is removed, 34 | * because a (min, extent) constructor would clash with the traditional meaning 35 | * of range(begin, end) 36 | * 37 | * \param min The minimum value of the range. 38 | * \param extent The extent of the range.
39 | */ 40 | static inline Range make_by_min_extent(Expr min, Expr extent); 41 | }; 42 | 43 | /*! \brief range over one dimension */ 44 | class RangeNode : public Node { 45 | public: 46 | /*! \brief beginning of the range */ 47 | Expr min; 48 | /*! \brief the extent of the range */ 49 | Expr extent; 50 | /*! \brief constructor */ 51 | RangeNode() {} 52 | RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} 53 | 54 | void VisitAttrs(IR::AttrVisitor* v) final { 55 | v->Visit("min", &min); 56 | v->Visit("extent", &extent); 57 | } 58 | 59 | static constexpr const char* _type_key = "Range"; 60 | TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); 61 | }; 62 | 63 | // implementations of inline functions 64 | inline const RangeNode* Range::operator->() const { 65 | return static_cast(node_.get()); 66 | } 67 | 68 | inline Range Range::make_by_min_extent(Expr min, Expr extent) { 69 | internal_assert(min.type() == extent.type()) 70 | << "Region min and extent must have same type\n"; 71 | NodePtr n = make_node(); 72 | n->min = min; 73 | n->extent = extent; 74 | return Range(n); 75 | } 76 | 77 | // overload print function 78 | inline std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*) 79 | os << "Range(min=" << r->min << ", extent=" << r->extent <<')'; 80 | return os; 81 | } 82 | 83 | } // namespace IR 84 | } // namespace HalideIR 85 | 86 | #endif // HALIDEIR_IR_RANGE_H_ 87 | -------------------------------------------------------------------------------- /src/tvm/node/ir_functor.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2018 by Contributors 3 | * \file tvm/node/ir_functor.h 4 | * \brief Defines the IRFunctor data structures. 5 | */ 6 | #ifndef HALIDEIR_TVM_NODE_IR_FUNCTOR_H_ 7 | #define HALIDEIR_TVM_NODE_IR_FUNCTOR_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "node.h" 16 | 17 | namespace tvm { 18 | /*!
19 | * \brief A dynamical dispatched functor on NodeRef in the first argument. 20 | * 21 | * \code 22 | * IRFunctor tostr; 23 | * tostr.set_dispatch([](const Add* op, std::string prefix) { 24 | * return prefix + "Add"; 25 | * }); 26 | * tostr.set_dispatch([](const IntImm* op) { 27 | * return prefix + "IntImm" 28 | * }); 29 | * 30 | * Expr x = make_const(1); 31 | * Expr y = x + x; 32 | * // dispatch to IntImm, outputs "MyIntImm" 33 | * LOG(INFO) << tostr(x, "My"); 34 | * // dispatch to Add, outputs "MyAdd" 35 | * LOG(INFO) << tostr(y, "My"); 36 | * \endcode 37 | * Note: the IntImm lambda above is abbreviated; like the Add one it needs a std::string prefix parameter and a terminating semicolon. 38 | * \tparam FType function signature 39 | * This type is only defined for FType with function signature 40 | */ 41 | template 42 | class IRFunctor; 43 | 44 | template 45 | class IRFunctor { 46 | private: 47 | using Function = std::function; 48 | using TSelf = IRFunctor; 49 | /*! \brief internal function table */ 50 | std::vector func_; 51 | 52 | public: 53 | /*! \brief the result type of this functor */ 54 | using result_type = R; 55 | /*! 56 | * \brief Whether the functor can dispatch the corresponding Node 57 | * \param n The node to be dispatched 58 | * \return Whether dispatching function is registered for n's type. 59 | */ 60 | inline bool can_dispatch(const NodeRef& n) const { 61 | uint32_t type_index = n.type_index(); 62 | return type_index < func_.size() && func_[type_index] != nullptr; 63 | } 64 | /*! 65 | * \brief invoke the functor, dispatching on the type of n 66 | * \param n The Node argument 67 | * \param args The additional arguments 68 | * \return The result. 69 | */ 70 | inline R operator()(const NodeRef& n, Args... args) const { 71 | uint32_t type_index = n.type_index(); 72 | CHECK(type_index < func_.size() && 73 | func_[type_index] != nullptr) 74 | << "IRFunctor calls un-registered function on type " 75 | << Node::TypeIndex2Key(type_index); 76 | return func_[type_index](n, std::forward(args)...); 77 | } 78 | /*!
79 | * \brief set the dispacher for type TNode 80 | * \param f The function to be set. 81 | * \tparam TNode the type of Node to be dispatched. 82 | * \return reference to self. 83 | */ 84 | template 85 | inline TSelf& set_dispatch(Function f) { // NOLINT(*) 86 | uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); 87 | if (func_.size() <= tindex) { 88 | func_.resize(tindex + 1, nullptr); 89 | } 90 | CHECK(func_[tindex] == nullptr) 91 | << "Dispatch for " << Node::TypeIndex2Key(tindex) 92 | << " is already set"; 93 | func_[tindex] = f; 94 | return *this; 95 | } 96 | /*! 97 | * \brief set the dispacher for type TNode 98 | * This overload wraps f so that it receives the node pointer cast from n.node_ instead of the NodeRef itself. 99 | * 100 | * \param f The function to be set. 101 | * \tparam TNode the type of Node to be dispatched. 102 | * \return reference to self. 103 | */ 104 | template 105 | inline TSelf& set_dispatch(std::function f) { // NOLINT(*) 106 | Function fun = [f](const NodeRef& n, Args... args) { 107 | return f(static_cast(n.node_.get()), 108 | std::forward(args)...); 109 | }; 110 | return this->set_dispatch(fun); 111 | } 112 | /*! 113 | * \brief unset the dispacher for type TNode 114 | * 115 | * \tparam TNode the type of Node to be dispatched. 116 | * \return reference to self. 117 | */ 118 | template 119 | inline TSelf& clear_dispatch() { // NOLINT(*) 120 | uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); 121 | CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; 122 | func_[tindex] = nullptr; 123 | return *this; 124 | } 125 | }; 126 | 127 | #if defined(__GNUC__) 128 | #define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) 129 | #else 130 | #define TVM_ATTRIBUTE_UNUSED 131 | #endif 132 | 133 | /*!
\brief helper macro to generate string concat */ 134 | #define TVM_STR_CONCAT_(__x, __y) __x##__y 135 | #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) 136 | 137 | #define TVM_REGISTER_VAR_DEF(ClsName) \ 138 | static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName 139 | 140 | /*! 141 | * \brief Useful macro to set IRFunctor dispatch in a global static field. 142 | * 143 | * \code 144 | * // Use IRFunctor to implement IRPrinter similar to Visitor Pattern. 145 | * // vtable allows easy patch in of new Node types, without changing 146 | * // interface of IRPrinter. 147 | * 148 | * class IRPrinter { 149 | * public: 150 | * std::ostream& stream; 151 | * // the dispatch function. 152 | * void print(Expr e) { 153 | * const static FType& f = vtable(); 154 | * f(e, this); 155 | * } 156 | * 157 | * using FType = IRFunctor; 158 | * // function to return global function table 159 | * static FType& vtable(); 160 | * }; 161 | * 162 | * // in cpp/cc file 163 | * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*) 164 | * static FType inst; return inst; 165 | * } 166 | * 167 | * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) 168 | * .set_dispatch([](const Add* n, IRPrinter* p) { 169 | * p->print(n->a); 170 | * p->stream << '+'; 171 | * p->print(n->b); 172 | * }); 173 | * 174 | * 175 | * \endcode 176 | * 177 | * \param ClsName The name of the class 178 | * \param FField The static function that returns a singleton of IRFunctor. 179 | */ 180 | #define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ 181 | TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ 182 | ClsName::FField() 183 | 184 | /*! 185 | * \brief A container for a list of callbacks. All callbacks are invoked when 186 | * the object is destructed.
187 | */ 188 | class IRFunctorCleanList { 189 | public: 190 | ~IRFunctorCleanList() { 191 | for (auto &f : clean_items) { 192 | f(); 193 | } 194 | } 195 | 196 | void append(std::function func) { 197 | clean_items.push_back(func); 198 | } 199 | 200 | private: 201 | std::vector< std::function > clean_items; 202 | }; 203 | 204 | /*! 205 | * \brief A wrapper around IRFunctor that will record calls to set_dispatch 206 | * and make a corresponding call to clear_dispatch when the last copy of 207 | * the IRFunctorStaticRegistry is destructed. When assigned to a static variable, 208 | * this can be used by NNVM and other libraries to unregister callbacks when 209 | * the library is unloaded. This prevents crashes when the underlying IRFunctor 210 | * is destructed as it will no longer contain std::function instances allocated 211 | * by a library that has been unloaded. 212 | */ 213 | template 214 | class IRFunctorStaticRegistry; 215 | 216 | template 217 | class IRFunctorStaticRegistry { 218 | private: 219 | IRFunctor *irf_; 220 | std::shared_ptr free_list; 221 | 222 | using TSelf = IRFunctorStaticRegistry; 223 | 224 | public: 225 | IRFunctorStaticRegistry(IRFunctor *irf) { 226 | irf_ = irf; 227 | free_list = std::make_shared(); 228 | } 229 | 230 | template 231 | inline TSelf& set_dispatch(std::function f) { // NOLINT(*) 232 | irf_->template set_dispatch(f); 233 | auto irf_copy = irf_; 234 | free_list.get()->append([irf_copy] { 235 | irf_copy->template clear_dispatch(); 236 | }); 237 | return *this; 238 | } 239 | }; 240 | 241 | /*! 242 | * \brief Helper function for constructing an IRFunctorStaticRegistry. This allows 243 | * the compiler to deduce the template types. 244 | */ 245 | template 246 | IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( 247 | IRFunctor *irf) { 248 | return IRFunctorStaticRegistry(irf); 249 | } 250 | 251 | #define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ 252 | static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName 253 | 254 | /*!
255 | * \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry. 256 | * Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of 257 | * TVM_STATIC_IR_FUNCTOR. 258 | */ 259 | #define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \ 260 | TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ 261 | MakeIRFunctorStaticRegistry(&ClsName::FField()) 262 | 263 | } // namespace tvm 264 | #endif // HALIDEIR_TVM_NODE_IR_FUNCTOR_H_ 265 | -------------------------------------------------------------------------------- /src/tvm/node/memory.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2018 by Contributors 3 | * \file tvm/node/memory.h 4 | * \brief Node memory management. 5 | */ 6 | #ifndef HALIDEIR_TVM_NODE_MEMORY_H_ 7 | #define HALIDEIR_TVM_NODE_MEMORY_H_ 8 | 9 | #include "node.h" 10 | 11 | namespace tvm { 12 | /*! 13 | * \brief Allocate a node object. 14 | * \param args arguments to the constructor. 15 | * \tparam T the node type. 16 | * \return The NodePtr to the allocated object. 17 | */ 18 | template 19 | inline NodePtr make_node(Args&&... args); 20 | 21 | // Detail implementations after this 22 | // 23 | // The current design allows swapping the 24 | // allocator pattern when necessary. 25 | // 26 | // Possible future allocator optimizations: 27 | // - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) 28 | // - Thread-local object pools: one pool per size and alignment requirement. 29 | // - Can specialize by type of object to give the specific allocator to each object. 30 | // 31 | template 32 | class SimpleNodeAllocator { 33 | public: 34 | template 35 | static T* New(Args&&...
args) { 36 | return new T(std::forward(args)...); 37 | } 38 | static NodeBase::FDeleter Deleter() { 39 | return Deleter_; 40 | } 41 | 42 | private: 43 | static void Deleter_(NodeBase* ptr) { 44 | delete static_cast(ptr); 45 | } 46 | }; 47 | 48 | template 49 | inline NodePtr make_node(Args&&... args) { 50 | using Allocator = SimpleNodeAllocator; 51 | static_assert(std::is_base_of::value, 52 | "make_node can only be used to create NodeBase"); 53 | T* node = Allocator::New(std::forward(args)...); 54 | node->deleter_ = Allocator::Deleter(); 55 | return NodePtr(node); 56 | } 57 | 58 | } // namespace tvm 59 | #endif // HALIDEIR_TVM_NODE_MEMORY_H_ 60 | -------------------------------------------------------------------------------- /src/tvm/node/node.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2018 by Contributors 3 | * Implementation of IR Node API 4 | * \file node.cpp 5 | */ 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace tvm { 13 | 14 | namespace { 15 | // single manager of node type information. 16 | struct TypeManager { 17 | // mutex to avoid registration from multiple threads. 18 | // NOTE(review): this is a plain (non-recursive) std::mutex; neither TypeKey2Index nor TypeIndex2Key re-acquires it while already holding it. 19 | std::mutex mutex; 20 | std::atomic type_counter{0}; 21 | std::unordered_map key2index; 22 | std::vector index2key; 23 | // get the singleton instance 24 | static TypeManager* Global() { 25 | static TypeManager inst; 26 | return &inst; 27 | } 28 | }; 29 | } // namespace 30 | 31 | EXPORT const bool Node::_DerivedFrom(uint32_t tid) const { 32 | static uint32_t tindex = TypeKey2Index(Node::_type_key); 33 | return tid == tindex; 34 | } 35 | 36 | // this is slow (it takes a lock and does a hash-map lookup); callers usually hold the result in a static variable.
37 | EXPORT uint32_t Node::TypeKey2Index(const char* key) { 38 | TypeManager *t = TypeManager::Global(); 39 | std::lock_guard lock(t->mutex); 40 | std::string skey = key; 41 | auto it = t->key2index.find(skey); 42 | if (it != t->key2index.end()) { 43 | return it->second; 44 | } 45 | uint32_t tid = ++(t->type_counter); 46 | t->key2index[skey] = tid; 47 | t->index2key.push_back(skey); 48 | return tid; 49 | } 50 | /* NOTE(review): the returned pointer aliases a std::string stored in index2key and is used after the lock is released; a later TypeKey2Index call that grows index2key may relocate that string, so callers should copy the key rather than retain the pointer. */ 51 | EXPORT const char* Node::TypeIndex2Key(uint32_t index) { 52 | TypeManager *t = TypeManager::Global(); 53 | std::lock_guard lock(t->mutex); 54 | internal_assert(index != 0); 55 | return t->index2key.at(index - 1).c_str(); 56 | } 57 | 58 | } // namespace tvm 59 | -------------------------------------------------------------------------------- /src/tvm/node/node.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2018 by Contributors 3 | * \file tvm/node/node.h 4 | * \brief Node system data structure. 5 | */ 6 | #ifndef HALIDEIR_TVM_NODE_NODE_H_ 7 | #define HALIDEIR_TVM_NODE_NODE_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | #include "base/Type.h" 13 | #include "node_base.h" 14 | 15 | namespace tvm { 16 | using HalideIR::Type; 17 | // forward declaration 18 | class Node; 19 | class NodeRef; 20 | 21 | namespace runtime { 22 | // forward declaration 23 | class NDArray; 24 | // forward declaration 25 | class Object; 26 | } // namespace runtime 27 | 28 | /*! 29 | * \brief Visitor class to each node content. 30 | * The content is going to be called for each field. 31 | */ 32 | class EXPORT AttrVisitor { 33 | public: 34 | //!
\cond Doxygen_Suppress 35 | virtual ~AttrVisitor() = default; 36 | virtual void Visit(const char* key, double* value) = 0; 37 | virtual void Visit(const char* key, int64_t* value) = 0; 38 | virtual void Visit(const char* key, uint64_t* value) = 0; 39 | virtual void Visit(const char* key, int* value) = 0; 40 | virtual void Visit(const char* key, bool* value) = 0; 41 | virtual void Visit(const char* key, std::string* value) = 0; 42 | virtual void Visit(const char* key, void** value) = 0; 43 | virtual void Visit(const char* key, Type* value) = 0; 44 | virtual void Visit(const char* key, NodeRef* value) = 0; 45 | virtual void Visit(const char* key, runtime::NDArray* value) = 0; 46 | virtual void Visit(const char* key, runtime::Object* value) = 0; 47 | template::value>::type> 49 | void Visit(const char* key, ENum* ptr) { 50 | static_assert(std::is_same::type>::value, 51 | "declare enum to be enum int to use visitor"); 52 | this->Visit(key, reinterpret_cast(ptr)); 53 | } 54 | //! \endcond 55 | }; 56 | 57 | /*! 58 | * \brief base class of node container in DSL AST. 59 | * Node objects are reference counted and are held through NodePtr (see NodeRef::node_) rather than std::shared_ptr. 60 | */ 61 | class EXPORT Node : public NodeBase { 62 | public: 63 | /*! \brief virtual destructor */ 64 | virtual ~Node() {} 65 | /*! \return The unique type key of the node */ 66 | virtual const char* type_key() const = 0; 67 | /*! 68 | * \brief Apply visitor to each field of the Node 69 | * Visitor could mutate the content of the node. 70 | * override if Node contains attribute fields. 71 | * \param visitor The visitor 72 | */ 73 | virtual void VisitAttrs(AttrVisitor* visitor) {} 74 | /*! \return the type index of the node */ 75 | virtual const uint32_t type_index() const = 0; 76 | /*! 77 | * \brief Whether this node derives from node with type_index=tid. 78 | * Implemented by TVM_DECLARE_NODE_TYPE_INFO 79 | * 80 | * \param tid The type index. 81 | * \return the check result. 
82 | */ 83 | virtual const bool _DerivedFrom(uint32_t tid) const; 84 | /*!
85 | * \brief get a runtime unique type index given a type key 86 | * \param type_key Type key of a type. 87 | * \return the corresponding type index. 88 | */ 89 | static uint32_t TypeKey2Index(const char* type_key); 90 | /*! 91 | * \brief get type key from type index. 92 | * \param index The type index 93 | * \return the corresponding type key. 94 | */ 95 | static const char* TypeIndex2Key(uint32_t index); 96 | /*! 97 | * \return whether this node's type is T or derives from T 98 | */ 99 | template 100 | inline bool derived_from() const; 101 | /*! 102 | * \return whether the node is of type T 103 | * \tparam T The type to be checked. 104 | */ 105 | template 106 | inline bool is_type() const; 107 | /*! 108 | * \brief Get a NodePtr that holds reference to this Node. 109 | * \return the NodePtr 110 | */ 111 | inline NodePtr GetNodePtr() const; 112 | // node ref can see this 113 | friend class NodeRef; 114 | static constexpr const char* _type_key = "Node"; 115 | }; 116 | 117 | /*! \brief Base class of all node reference object */ 118 | class NodeRef { 119 | public: 120 | /*! \brief type indicate the container type */ 121 | using ContainerType = Node; 122 | /*! 123 | * \brief Comparator 124 | * \param other Another node ref. 125 | * \return the compare result. 126 | */ 127 | inline bool operator==(const NodeRef& other) const; 128 | /*! 129 | * \brief Comparator 130 | * \param other Another node ref. 131 | * \return the compare result. 132 | */ 133 | inline bool same_as(const NodeRef& other) const; 134 | /*! 135 | * \brief Comparator 136 | * \param other Another node ref. 137 | * \return the compare result. 138 | */ 139 | inline bool operator<(const NodeRef& other) const; 140 | /*! 141 | * \brief Comparator 142 | * \param other Another node ref. 143 | * \return the compare result. 144 | */ 145 | inline bool operator!=(const NodeRef& other) const; 146 | /*! \return the hash function for NodeRef */ 147 | inline size_t hash() const; 148 | /*!
\return whether the expression is null */ 149 | inline bool defined() const; /* NOTE(review): despite the wording above, defined() is true when the reference holds a non-null node (node_.get() != nullptr). */ 150 | /*! \return the internal type index of IRNode */ 151 | inline uint32_t type_index() const; 152 | /*! \return the internal node pointer */ 153 | inline const Node* get() const; 154 | /*! \return the internal node pointer */ 155 | inline const Node* operator->() const; 156 | /*! 157 | * \brief Downcast this ir node to its actual type (e.g. Add, or 158 | * Select). This returns nullptr if the node is not of the requested 159 | * type. Example usage: 160 | * 161 | * if (const Add *add = node->as()) { 162 | * // This is an add node 163 | * } 164 | * \tparam T the target type, must be subtype of IRNode 165 | */ 166 | template 167 | inline const T *as() const; 168 | /*! 169 | * \brief A more powerful version of as that also works with 170 | * intermediate base types. 171 | * \tparam T the target type, must be subtype of IRNode 172 | */ 173 | template 174 | inline const T *as_derived() const; 175 | /*! \brief default constructor */ 176 | NodeRef() = default; 177 | explicit NodeRef(NodePtr node) : node_(node) {} 178 | /*! \brief the internal node object, do not touch */ 179 | NodePtr node_; 180 | }; 181 | 182 | /*! 183 | * \brief Get a reference type from a Node ptr type 184 | * 185 | * It is always important to get a reference type 186 | * if we want to return a value as reference or keep 187 | * the node alive beyond the scope of the function. 188 | * 189 | * \param ptr The node pointer 190 | * \tparam RefType The reference type 191 | * \tparam NodeType The node type 192 | * \return The corresponding RefType 193 | */ 194 | template 195 | inline RefType GetRef(const NodeType* ptr); 196 | 197 | /*! 198 | * \brief Downcast a base reference type to a more specific type. 199 | * 200 | * \param ref The input reference 201 | * \return The corresponding SubRef. 
202 | * \tparam SubRef The target specific reference type. 203 | * \tparam BaseRef the current reference type.
204 | */ 205 | template 206 | inline SubRef Downcast(BaseRef ref); 207 | 208 | /*! 209 | * \brief helper macro to declare type information in a base node. 210 | */ 211 | #define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ 212 | const bool _DerivedFrom(uint32_t tid) const override { \ 213 | static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ 214 | if (tidx == tid) return true; \ 215 | return Parent::_DerivedFrom(tid); \ 216 | } 217 | 218 | /*! 219 | * \brief helper macro to declare type information in a terminal node 220 | */ 221 | #define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ 222 | const char* type_key() const final { \ 223 | return TypeName::_type_key; \ 224 | } \ 225 | const uint32_t type_index() const final { \ 226 | static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ 227 | return tidx; \ 228 | } \ 229 | const bool _DerivedFrom(uint32_t tid) const final { \ 230 | static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ 231 | if (tidx == tid) return true; \ 232 | return Parent::_DerivedFrom(tid); \ 233 | } 234 | 235 | // implementations of inline functions after this 236 | template 237 | inline bool Node::derived_from() const { 238 | // use static field so query only happens once. 239 | static uint32_t type_id = Node::TypeKey2Index(T::_type_key); 240 | return this->_DerivedFrom(type_id); 241 | } 242 | 243 | 244 | template 245 | inline bool Node::is_type() const { 246 | // use static field so query only happens once.
247 | static uint32_t type_id = Node::TypeKey2Index(T::_type_key); 248 | return type_id == this->type_index(); 249 | } 250 | 251 | 252 | inline NodePtr Node::GetNodePtr() const { 253 | return NodePtr(const_cast(this)); 254 | } 255 | 256 | template 257 | inline RefType GetRef(const NodeType* ptr) { 258 | static_assert(std::is_base_of::value, 259 | "Can only cast to the ref of same container type"); 260 | return RefType(ptr->GetNodePtr()); 261 | } 262 | 263 | template 264 | inline SubRef Downcast(BaseRef ref) { 265 | CHECK(ref->template is_type() || 266 | ref->template derived_from()) 267 | << "Downcast from " << ref->type_key() << " to " 268 | << SubRef::ContainerType::_type_key << " failed."; 269 | return SubRef(std::move(ref.node_)); 270 | } 271 | 272 | inline const Node* NodeRef::get() const { 273 | return node_.get(); 274 | } 275 | 276 | inline const Node* NodeRef::operator->() const { 277 | return node_.get(); 278 | } 279 | 280 | inline bool NodeRef::defined() const { 281 | return node_.get() != nullptr; 282 | } 283 | 284 | inline bool NodeRef::operator==(const NodeRef& other) const { 285 | return node_.get() == other.node_.get(); 286 | } 287 | 288 | inline bool NodeRef::same_as(const NodeRef& other) const { 289 | return node_.get() == other.node_.get(); 290 | } 291 | 292 | inline bool NodeRef::operator<(const NodeRef& other) const { 293 | return node_.get() < other.node_.get(); 294 | } 295 | 296 | inline bool NodeRef::operator!=(const NodeRef& other) const { 297 | return node_.get() != other.node_.get(); 298 | } 299 | 300 | inline size_t NodeRef::hash() const { 301 | return std::hash()(node_.get()); 302 | } 303 | 304 | inline uint32_t NodeRef::type_index() const { 305 | CHECK(node_.get() != nullptr) 306 | << "null type"; 307 | return get()->type_index(); 308 | } 309 | 310 | template 311 | inline const T* NodeRef::as() const { 312 | const Node* ptr = static_cast(get()); 313 | if (ptr && ptr->is_type()) { 314 | return static_cast(ptr); 315 | } 316 | return
nullptr; 317 | } 318 | 319 | template 320 | inline const T* NodeRef::as_derived() const { 321 | const Node* ptr = static_cast(get()); 322 | if (ptr && (ptr->is_type() || ptr->derived_from())) { 323 | return static_cast(ptr); 324 | } 325 | return nullptr; 326 | } 327 | 328 | /*! \brief The hash function for nodes */ 329 | struct NodeHash { 330 | size_t operator()(const NodeRef& a) const { 331 | return a.hash(); 332 | } 333 | }; 334 | 335 | /*! \brief The equal comparator for nodes */ 336 | struct NodeEqual { 337 | bool operator()(const NodeRef& a, const NodeRef& b) const { 338 | return a.get() == b.get(); 339 | } 340 | }; 341 | } // namespace tvm 342 | #endif // HALIDEIR_TVM_NODE_NODE_H_ 343 | -------------------------------------------------------------------------------- /src/tvm/node/node_base.h: -------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2018 by Contributors 3 | * \file tvm/node/node_base.h 4 | * \brief Base data structure for Node. 5 | * 6 | * \note Node is not a runtime feature. 7 | * This file only exposes the signature of NodePtr for PackedFunc. 8 | */ 9 | #ifndef TVM_RUNTIME_NODE_BASE_H_ 10 | #define TVM_RUNTIME_NODE_BASE_H_ 11 | 12 | #include 13 | #include 14 | 15 | namespace tvm { 16 | 17 | // forward declarations 18 | template 19 | class NodePtr; 20 | class Node; 21 | class NodeRef; 22 | 23 | /*! 24 | * \brief Base class of Node for runtime destructor purposes. 25 | * 26 | * Node is a reference counted object which is used to construct AST. 27 | * Each node is backed by a custom deleter, which deletes the object. 28 | * Do not create raw Node pointers; always use tvm::make_node. 29 | * 30 | * \note In most cases, please inherit from tvm::Node. 31 | * \sa Node, NodePtr, make_node 32 | */ 33 | class NodeBase { 34 | public: 35 | /*! 36 | * \brief type of NodeBase deleter 37 | * \param self pointer to the NodeBase.
38 | */ 39 | typedef void (*FDeleter)(NodeBase* self); 40 | 41 | protected: 42 | // default constructor and copy constructor 43 | NodeBase() {} 44 | // override the copy and assign constructors to do nothing. 45 | // This is to make sure only contents, but not deleter and ref_counter 46 | // are copied when a child class copies itself. 47 | NodeBase(const NodeBase& other) { // NOLINT(*) 48 | } 49 | NodeBase(NodeBase&& other) { // NOLINT(*) 50 | } 51 | NodeBase& operator=(const NodeBase& other) { //NOLINT(*) 52 | return *this; 53 | } 54 | NodeBase& operator=(NodeBase&& other) { //NOLINT(*) 55 | return *this; 56 | } 57 | 58 | private: 59 | /*! \brief Internal reference counter */ 60 | std::atomic ref_counter_{0}; 61 | /*! 62 | * \brief deleter of this object to enable customized allocation. 63 | * If the deleter is nullptr, no deletion will be performed. 64 | * The creator of the Node must always set the deleter field properly. 65 | */ 66 | FDeleter deleter_ = nullptr; 67 | // reference counting functions 68 | void IncRef() { 69 | ref_counter_.fetch_add(1, std::memory_order_relaxed); 70 | } 71 | void DecRef() { 72 | if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { 73 | std::atomic_thread_fence(std::memory_order_acquire); 74 | if (this->deleter_ != nullptr) { 75 | (*this->deleter_)(this); 76 | } 77 | } 78 | } 79 | int use_count() const { 80 | return ref_counter_.load(std::memory_order_relaxed); 81 | } 82 | // friend declaration 83 | template 84 | friend class NodePtr; 85 | template 86 | friend NodePtr make_node(Args&&...); 87 | }; 88 | 89 | /*! 90 | * \brief Smart pointer for Node containers, 91 | * must be subclass of NodeBase 92 | * \tparam T the content data type. 93 | */ 94 | template 95 | class NodePtr { 96 | public: 97 | /*! \brief default constructor */ 98 | NodePtr() {} 99 | /*! \brief default constructor */ 100 | NodePtr(std::nullptr_t) {} // NOLINT(*) 101 | /*! 
102 | * \brief copy constructor 103 | * \param other The value to be moved 104 | */ 105 | NodePtr(const NodePtr& other) // NOLINT(*) 106 | : NodePtr(other.data_) { 107 | } 108 | /*! 109 | * \brief copy constructor 110 | * \param other The value to be moved 111 | */ 112 | template 113 | NodePtr(const NodePtr& other) // NOLINT(*) 114 | : NodePtr(other.data_) { 115 | static_assert(std::is_base_of::value, 116 | "can only assign of child class NodePtr to parent"); 117 | } 118 | /*! 119 | * \brief move constructor 120 | * \param other The value to be moved 121 | */ 122 | NodePtr(NodePtr&& other) // NOLINT(*) 123 | : data_(other.data_) { 124 | other.data_ = nullptr; 125 | } 126 | /*! 127 | * \brief move constructor 128 | * \param other The value to be moved 129 | */ 130 | template 131 | NodePtr(NodePtr&& other) // NOLINT(*) 132 | : data_(other.data_) { 133 | static_assert(std::is_base_of::value, 134 | "can only assign of child class NodePtr to parent"); 135 | other.data_ = nullptr; 136 | } 137 | /*! \brief destructor */ 138 | ~NodePtr() { 139 | this->reset(); 140 | } 141 | /*! 142 | * \brief Swap this array with another NDArray 143 | * \param other The other NDArray 144 | */ 145 | void swap(NodePtr& other) { // NOLINT(*) 146 | std::swap(data_, other.data_); 147 | } 148 | /*! 149 | * \return Get the content of the pointer 150 | */ 151 | T* get() const { 152 | return static_cast(data_); 153 | } 154 | /*! 155 | * \return The pointer 156 | */ 157 | T* operator->() const { 158 | return get(); 159 | } 160 | /*! 161 | * \return The reference 162 | */ 163 | T& operator*() const { // NOLINT(*) 164 | return *get(); 165 | } 166 | /*! 167 | * \brief copy assignmemt 168 | * \param other The value to be assigned. 169 | * \return reference to self. 170 | */ 171 | NodePtr& operator=(const NodePtr& other) { // NOLINT(*) 172 | // takes in plane operator to enable copy elison. 173 | // copy-and-swap idiom 174 | NodePtr(other).swap(*this); // NOLINT(*) 175 | return *this; 176 | } 177 | /*! 
178 | * \brief move assignmemt 179 | * \param other The value to be assigned. 180 | * \return reference to self. 181 | */ 182 | NodePtr& operator=(NodePtr&& other) { // NOLINT(*) 183 | // copy-and-swap idiom 184 | NodePtr(std::move(other)).swap(*this); // NOLINT(*) 185 | return *this; 186 | } 187 | /*! \brief reset the content of ptr to be nullptr */ 188 | void reset() { 189 | if (data_ != nullptr) { 190 | data_->DecRef(); 191 | data_ = nullptr; 192 | } 193 | } 194 | /*! \return The use count of the ptr, for debug purposes */ 195 | int use_count() const { 196 | return data_ != nullptr ? data_->use_count() : 0; 197 | } 198 | /*! \return whether the reference is unique */ 199 | bool unique() const { 200 | return data_ != nullptr && data_->use_count() == 1; 201 | } 202 | /*! \return Whether two NodePtr do not equals each other */ 203 | bool operator==(const NodePtr& other) const { 204 | return data_ == other.data_; 205 | } 206 | /*! \return Whether two NodePtr equals each other */ 207 | bool operator!=(const NodePtr& other) const { 208 | return data_ != other.data_; 209 | } 210 | /*! \return Whether the pointer is nullptr */ 211 | bool operator==(std::nullptr_t null) const { 212 | return data_ == nullptr; 213 | } 214 | /*! \return Whether the pointer is not nullptr */ 215 | bool operator!=(std::nullptr_t null) const { 216 | return data_ != nullptr; 217 | } 218 | 219 | private: 220 | /*! \brief internal pointer field */ 221 | NodeBase* data_{nullptr}; 222 | /*! 
223 | * \brief constructor from NodeBase 224 | * \param data The node base pointer 225 | */ 226 | explicit NodePtr(NodeBase* data) 227 | : data_(data) { 228 | if (data != nullptr) { 229 | data_->IncRef(); 230 | } 231 | } 232 | // friend declaration 233 | friend class Node; 234 | template 235 | friend class NodePtr; 236 | template 237 | friend NodePtr make_node(Args&&...); 238 | }; 239 | } // namespace tvm 240 | 241 | #endif // TVM_RUNTIME_NODE_BASE_H_ 242 | --------------------------------------------------------------------------------