├── .github └── workflows │ └── ccpp.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README-old.md ├── README.md ├── include ├── IR.h ├── IRMutator.h ├── IRPrinter.h ├── IRVisitor.h ├── arith.h ├── debug.h └── type.h ├── project1 ├── .gitignore ├── CMakeLists.txt ├── cases │ ├── case1.json │ ├── case10.json │ ├── case2.json │ ├── case3.json │ ├── case4.json │ ├── case5.json │ ├── case6.json │ ├── case7.json │ ├── case8.json │ ├── case9.json │ └── example.json ├── clean │ └── clean.cc ├── kernels │ ├── kernel_case1.cc │ ├── kernel_case10.cc │ ├── kernel_case2.cc │ ├── kernel_case3.cc │ ├── kernel_case4.cc │ ├── kernel_case5.cc │ ├── kernel_case6.cc │ ├── kernel_case7.cc │ ├── kernel_case8.cc │ ├── kernel_case9.cc │ └── kernel_example.cc ├── run.cc ├── run.h └── solution │ └── example_solution.cc ├── project2 ├── .gitignore ├── CMakeLists.txt ├── cases │ ├── case1.json │ ├── case10.json │ ├── case2.json │ ├── case3.json │ ├── case4.json │ ├── case5.json │ ├── case6.json │ ├── case7.json │ ├── case8.json │ └── case9.json ├── clean │ └── clean2.cc ├── kernels │ ├── grad_case1.cc │ ├── grad_case10.cc │ ├── grad_case2.cc │ ├── grad_case3.cc │ ├── grad_case4.cc │ ├── grad_case5.cc │ ├── grad_case6.cc │ ├── grad_case7.cc │ ├── grad_case8.cc │ └── grad_case9.cc ├── run2.cc ├── run2.h └── solution │ └── solution2.cc ├── src ├── IR.cc ├── IRMutator.cc ├── IRPrinter.cc └── IRVisitor.cc ├── test ├── CMakeLists.txt ├── conv2d.cc ├── gemm.cc └── ir_mutator.cc ├── 编译大作业-第一部分.md ├── 编译大作业-第一部分.pdf ├── 编译大作业-第二部分.md └── 编译大作业-第二部分.pdf /.github/workflows/ccpp.yml: -------------------------------------------------------------------------------- 1 | name: C/C++ CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: submodule 17 | run: git submodule init && git submodule update 18 | - name: cmake 19 | run: 
mkdir build && cd build && cmake --version && cmake .. 20 | - name: make 21 | run: cd build && make -j 4 22 | - name: test 23 | run: cd build/test && ./gemm && ./conv2d 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | .DS_Store 3 | .vscode -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/jsoncpp"] 2 | path = 3rdparty/jsoncpp 3 | url = https://github.com/open-source-parsers/jsoncpp.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | project(boost C CXX) 3 | set(LIB_NAME boost) 4 | 5 | # include directories 6 | include_directories(${CMAKE_INCLUDE_PATH}) 7 | include_directories("include") 8 | 9 | # initial variables 10 | set(BOOST_LINKER_LIBS "") 11 | set(BOOST_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) 12 | 13 | # Generic compilation options 14 | if(MSVC) 15 | add_definitions(-DWIN32_LEAN_AND_MEAN) 16 | add_definitions(-D_CRT_SECURE_NO_WARNINGS) 17 | add_definitions(-D_SCL_SECURE_NO_WARNINGS) 18 | add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) 19 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") 21 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") 22 | if(USE_MSVC_MT) 23 | foreach(flag_var 24 | CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE 25 | CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) 26 | if(${flag_var} MATCHES "/MD") 27 | string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") 28 | endif(${flag_var} MATCHES "/MD") 29 | endforeach(flag_var) 30 | endif() 31 | else(MSVC) 32 | if ("${CMAKE_BUILD_TYPE}" STREQUAL 
"Debug") 33 | message("Build in Debug mode") 34 | set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") 35 | set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") 36 | else() 37 | set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") 38 | set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}") 39 | if (HIDE_PRIVATE_SYMBOLS) 40 | message(STATUS "Hide private symbols...") 41 | set(CMAKE_C_FLAGS "-fvisibility=hidden ${CMAKE_C_FLAGS}") 42 | set(CMAKE_CXX_FLAGS "-fvisibility=hidden ${CMAKE_CXX_FLAGS}") 43 | endif(HIDE_PRIVATE_SYMBOLS) 44 | endif () 45 | if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND 46 | CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) 47 | set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}") 48 | endif() 49 | endif(MSVC) 50 | 51 | 52 | # Source file 53 | file(GLOB COMPILER_SRCS 54 | src/*.cc 55 | ) 56 | 57 | file(GLOB COMPILER_INCLUDES 58 | include/*.h 59 | ) 60 | 61 | 62 | if(NOT MSVC) 63 | include(CheckCXXCompilerFlag) 64 | check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) 65 | message(STATUS "Build with c++11") 66 | set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}") 67 | endif() 68 | 69 | add_library(${LIB_NAME} SHARED ${COMPILER_SRCS} ${COMPILER_INCLUDES}) 70 | 71 | target_link_libraries(${LIB_NAME} ${BOOST_LINKER_LIBS} ${BOOST_RUNTIME_LINKER_LIBS}) 72 | 73 | if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") 74 | set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL") 75 | # Note: 'target_link_options' with 'PRIVATE' keyword would be cleaner 76 | # but it's not available until CMake 3.13. Switch to 'target_link_options' 77 | # once minimum CMake version is bumped up to 3.13 or above. 78 | target_link_libraries(${LIB_NAME} ${HIDE_SYMBOLS_LINKER_FLAGS}) 79 | endif() 80 | 81 | add_subdirectory(test) 82 | add_subdirectory(project1) 83 | add_subdirectory(project2) 84 | 85 | # JSON targets. 
86 | add_subdirectory("3rdparty/jsoncpp" 87 | ${CMAKE_CURRENT_BINARY_DIR}/jsoncpp-build 88 | EXCLUDE_FROM_ALL) 89 | 90 | target_link_libraries(${LIB_NAME} jsoncpp_lib) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Size Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-old.md: -------------------------------------------------------------------------------- 1 | ![Build](https://github.com/pku-compiler-design-spring/CompilerProject-2020Spring-Part1/workflows/C/C++%20CI/badge.svg?branch=master) 2 | 3 | ## Code Generation Compiler 4 | 5 | This project is designed for undergraduate students who are taking Compiler Design courses in spring. 
6 | 7 | > Author: Size Zheng 8 | 9 | > Email: zhengsz@pku.edu.cn 10 | 11 | ### BUG report And Bonus 12 | 13 | __Format:__ [date] "message" by **reporters** [bonus] 14 | 15 | 1. [2020-4-14] "In run.cc case 4,5 golden array shape bug" by **Ye Yuan, Anjiang Wei, Yuyue Wang, Chenyang Yang** [+1] 16 | 2. [2020-4-15] "In document, input BNF bug in AList" by **Jing Mai, Can Su, Zixuan Ling** [+1] 17 | 3. [2020-4-16] "In CMakeLists.txt, link target to library" by **Chenqian Wang, Jiaqi Zhang, Wenqi Wang** [+1] 18 | 19 | ### 1. Overview 20 | In this project, we provide several useful IR nodes and corresponding IRVisitor and IRMutator. The concepts behind these structures are well studied in [Halide](https://github.com/halide/Halide) and [TVM](https://github.com/apache/incubator-tvm). Here we invent some new IR nodes and re-implement the Visitor and Mutator for them. 21 | 22 | The purpose of this project is to help students to better understand how to build an IR system and implement a simple code generation tool. 23 | 24 | The IR infrastructure of this project contains four levels: 25 | 26 | ``` 27 | Program 28 | Group 29 | Stmt 30 | Expr 31 | ``` 32 | The first level `Program` is not explicitly implemented. 33 | Each level of IR has several different types of nodes: 34 | ``` 35 | Group: Kernel 36 | Stmt: LoopNest, IfThenElse, Move 37 | Expr: IntImm, 38 | UIntImm, 39 | FloatImm, 40 | StringImm, 41 | Unary, 42 | Binary, 43 | Select, 44 | Compare, 45 | Call, 46 | Var, 47 | Cast, 48 | Ramp, 49 | Index, 50 | Dom 51 | ``` 52 | 53 | Using these IR nodes, we can potentially represent many kinds of programs. 54 | 55 | ### 2. Build 56 | ```sh 57 | mkdir build 58 | cd build 59 | cmake .. 60 | make -j 4 61 | ``` 62 | 63 | ### 3. Example 64 | In the `test` directory, there are two examples, `gemm` and `conv2d`; they are good examples of how to represent computations with our IR infrastructure. 
If you run them: 65 | ```sh 66 | cd build/test 67 | ./gemm 68 | ./conv2d 69 | ``` 70 | You can see the results are very similar to C programs, however, the printed strings are just intermediate representation, you can't run the printed strings for now. We hope you can improve the current system so that it prints valid C/C++ programs that can be compiled with C/C++ compilers. 71 | 72 | 73 | ### 4. Tasks 74 | 1. Please read the source code base thoroughly, you need to understand every part of it. 75 | 2. You need to implement your C/C++ code generation. Hints: learn how the IRPrinter works, imitate it and try to write a new IRVisitor which can print C/C++ source codes. 76 | 3. Go to `project1` directory, you will find many json files in `cases` directory. They are inputs to your questions. For example, `example.json` contains: 77 | ```json 78 | { 79 | "name": "kernel_example", 80 | "ins": ["B", "C"], 81 | "outs": ["A"], 82 | "data_type": "float", 83 | "kernel": "A<32, 16>[i, j] = C<32, 16>[i, j] * B<32, 16>[i, j];" 84 | } 85 | ``` 86 | It means you need to generate a `.cc` file which implements the computation of `A<32, 16>[i, j] = C<32, 16>[i, j] * B<32, 16>[i, j];`. Put the computation in a function named `kernel_example`, whose inputs are `B` and `C`, and output is `A`, the data type is `float`. In the expression, we can see `A` has shape of [32, 16], and also `B` and `C`. So the function's signature is 87 | ```c 88 | void kernel_example(float (&B)[32][16], float (&C)[32][16], float (&A)[32][16]) 89 | ``` 90 | Please try to generate C/C++ source files for these json files and put them under directory `kernels`. 91 | 92 | 4. Your code generation application source files should be placed in `solution` directory. (But your code generation passes can be put in outer directories such as `include` and `src`) 93 | 94 | 95 | ### 5. Notice 96 | 1. We present a silly solution in `solution` directory, please do not follow such silly manner. 
The example is just used to tell you how our framework works. 97 | 2. All the source files you put in `solution` directory should only contain one `main` function, as we will compile all the source files in `solution` directory into one executable file. 98 | 3. Please be careful and do not delete important files, which may break the system. 99 | 4. If you want to test your designs, just enter the `build` directory, run `make -j 4`, you will see the binaries in `build` directory, there are sub-directories such as `project1`, your executable files should be placed there automatically. 100 | 5. You are not supposed to modify `run.h` and `run.cc`. These files will be changed to another version which contains the full 10 test cases, so any modification is meaningless. 101 | 6. If you are confused about what kinds of C/C++ code you are supposed to generate, see `solution/example_solution.cc`. 102 | 103 | ### 6. Judge 104 | 1. We provide an auto-test file: after building the project, enter `build/project1` and run `./test1` to see the results. 105 | 2. We only show you 6 cases and 4 cases are hidden. The TAs will test all the 10 files and decide scores according to how many cases you can pass. Don't be worried, the hidden cases are no more complex than the open cases. If you can handle the open cases, you should pass all hidden cases. 106 | 3. Do not copy code from others; we will check! Any attempt to break this rule will result in a score of 0. 107 | 108 | 109 | ### 7. How it works? 110 | When you build the project, we will actually build four parts: 111 | - the files in `include` and `src` 112 | - the files in `test` 113 | - the files in `project1` are compiled to one executable 114 | - the files in `project1/solution` are compiled to one executable 115 | 116 | And we will automatically clean files under `kernels/*.cc`, so you can't expect to modify them manually. 
117 | 118 | Then we will call the executable from `project1/solution` automatically, which is expected to generate all the functions and put them in `kernels/*.cc`. 119 | 120 | At last, we will run `./test1` manually to see your results and decide your scores according to the results. 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Build](https://github.com/pku-compiler-design-spring/CompilerProject-2020Spring/workflows/C/C++%20CI/badge.svg?branch=master) 2 | 3 | # 编译大作业 4 | ## 第二部分——自动求导的编译器 5 | 6 | 7 | ### 1. 前言 8 | 在第一部分的作业中,我们做的事情是根据输入的表达式生成C/C++代码,并且在10个例子上测试正确性(6个公开,4个隐藏)。此时,每位同学手头都应该有一个可用的代码生成器了。回忆我们做这个project的初衷,我们想要做一个面向当前重要应用——深度学习——的代码生成工具,利用我们编译课上学习的知识完成这一任务。在第一部分中,我们体会了词法分析、语法分析、中间表示形式(IR Node)、语法树构建、语法树遍历(通过IRVisitor)和代码生成,并且还可能用到了少数SDD, SDT中的知识。我们第二次project将继续这个方向,利用编译技术做更多有趣的功能,这一次,我们的重点将放在语法树的变换(一个个pass)上来,对于变换的设计可能会用到课本上更多的知识(但不一定是严格局限课本的例子,同学们可以根据实际情况活用)。这一次project可能对于一些同学来说比较困难,希望通过小组合作,大家都能掌握这个过程中需要的知识和技术。 9 | 10 | 11 | ### 2. 
问题描述 12 | 13 | #### 2.1 传统的深度学习框架求导 14 | 自动求导是深度学习中当前必不可少的功能(依赖于梯度优化的算法都摆脱不了求导过程)。在深度学习框架中(如Tensorflow, PyTorch),自动求导都是由框架完成的,它们的方法论是,首先形成计算图,然后根据链式法则构建计算梯度的图。举一个例子,一个简单的计算过程为: 15 | ```py 16 | X is a tensor of shape [4, 3, 28, 28] 17 | T is a label of shape [4, 8 * 28 * 28] 18 | Y1 = Conv2d(X, kernel=(8, 3, 3, 3), padding=1, stride=1) # result shape is [4, 8, 28, 28] 19 | Y2 = flatten(Y1) # result shape is [4, 8 * 28 * 28] 20 | loss = mse_loss(Y2, T) # loss is scalar 21 | ``` 22 | 如果想要求出对于X的导数(虽然常见情况是对于网络参数求导,而不是网络输入,但这里只是做一个例子),就要从loss开始求,首先loss对于自己的导数是1,然后求Y2的导数,框架发现Y2用来计算loss时,使用的时mse_loss函数,于是找到了mse_loss函数的导函数grad_mse_loss,用于计算Y2的导数;接着对于Y1,框架又发现Y2是通过flatten函数求出来的,于是找到了flatten函数的导函数grad_flatten,利用这个导函数求出对于Y1的导数,继续向上求X的导数,框架又发现了Conv2d层,于是找到了对应的卷积求导函数,用于求X的导数。可以看到,这个过程除了链式法则,框架还在不断地识别正向传播时使用的函数/层的名字,然后在自己的函数库里寻找对应的导函数,框架知道应该找哪个导函数,都是依赖于编写框架的人了解这些知识,然后在框架的库里准备好需要的函数们。 23 | 这是一种传统的求导方式,它的粒度是算子(加减乘除也算是算子),而在我们这次project中,我们将使用编译技术自动地根据前向算子计算定义生成其反向计算导数的函数,整个过程,不需要特意知道算子的名称,只需要看到数学表达式即可。这样做的一个优势是,深度学习应用的可扩展性将被加强,当有人希望自己设计一个算子时,他/她不再需要自己推导导函数的定义,然后自己实现出来再注册到框架里使用,而是只需要提供一个正向传播的计算表达式,就可以得到对应的导函数计算定义。 24 | 25 | #### 2.2 问题定义 26 | __现在我们开始进行问题描述:__ 27 | 对于一个给定的表达式$Output = expr(Input_1, Input_2, ..., Input_n)$($Output, Input_i$是张量或标量, $expr()$表示用其参数构造一个表达式),我们如果已知了最终loss对于$Output$的导数$dOutput = \frac{\partial loss}{\partial Output}$,我们想知道loss对于某个输入的导函数是什么,也就是求$dInput_i = \frac{\partial loss}{\partial Input_i}$的问题,**我们要求如下:** 28 | - 分析出来的求导表达式是一个或多个赋值语句形式,每个语句左侧的下标索引上不能有加减乘除等运算,也就是不能出现`A[i+1] = B[i]`的形式。 29 | - 必须通过对输入表达式的编译分析过程,综合出求导表达式的内容,并生成代码,不能通过判断case的名字直接得出求导表达式(这样就和传统框架一样了),也不能用打表法直接打印出字符串 30 | 31 | #### 2.3一个例子 32 | 为了帮助理解,我们给一个例子: 33 | ```py 34 | C[i, j] = A[i, k] * B[k, j] 35 | ``` 36 | 基于第一次project的知识,我们知道这个式子表达了一个矩阵乘法。 37 | 现在已知了某个$loss$对于$C$的导数$dC$(是个张量,大小与$C$的大小相同,注意这里的$dC$是个名字,不是算符),假设想要求$dA$,那么根据求导的数学方法得到 38 | $$dA[i, k] = \frac{\partial loss}{\partial A[i, k]} = \sum_{j}{\frac{\partial loss}{\partial C[i, j]} \times \frac{\partial C[i, j]}{\partial A[i, 
k]}} = dC[i, j] \times B[k, j]$$ 39 | 所以可以得到对于$A$的导数计算式为 40 | ```py 41 | dA[i, k] = dC[i, j] * B[k, j] 42 | ``` 43 | 翻译为C代码就是 44 | ```c 45 | for (int i = 0; i < M; ++i) { 46 | for (int k = 0; k < K; ++k) { 47 | dA[i][k] = 0.0; 48 | for (int j = 0; j < N; ++j) { 49 | dA[i][k] += dC[i][j] * B[k][j]; 50 | } 51 | } 52 | } 53 | ``` 54 | 对于$B$也可以类似写出求导的式子: 55 | ```py 56 | dB[k, j] = dC[i, j] * A[i, k] 57 | ``` 58 | 59 | ### 3. Project输入与输出 60 | 基于上一次project的测试法,我们这次仍然给出10个例子,与第一次project的测试case不同的是,我们给出的json文件中多了一个"grad_to"的键值,这个键值的信息是对哪个/哪些输入(可能一个或多个输入)进行求导。 61 | 比如看case1的json文件: 62 | ```json 63 | { 64 | "name": "grad_case1", 65 | "ins": ["A", "B"], 66 | "outs": ["C"], 67 | "data_type": "float", 68 | "kernel": "C<4, 16>[i, j] = A<4, 16>[i, j] * B<4, 16>[i, j] + 1.0;", 69 | "grad_to": ["A"] 70 | } 71 | ``` 72 | 这里指明了对于$A$进行求导,所以得到的式子应该是 73 | ```py 74 | dA<4, 16>[i, j] = dC<4, 16>[i, j] * B<4, 16>[i, j] 75 | ``` 76 | 这里的$dC$符号就是C的导数,我们认为所有的输出的导数张量都是已知,命名规则都是原来的名字前面加个$d$,此外,我们只考虑正向传播表达式有且仅有一个输出的情况。 77 | 同学们读如json文件,分析正向表达式后,根据编译技术分析出反向传播表达式,然后对这个反向传播的表达式生成C/C++代码,放在kernels/目录对应的文件内。每次cmake这个project时,都会先自动运行solution下的代码,然后运行run2.cc,run2.cc里有测试逻辑,会测试同学们生成的反向传播代码的正确性。 78 | 79 | 80 | #### 3.1 测试例子 81 | 这次一共10个测试例子,全部是公开的,公开理由为: 82 | - 自动求导本身蕴含NP问题,只有下标的变换满足一定条件(如线性)才是易解的,即使是易解的,其求解细节也比较复杂,所以给出具体的10个例子,同学们不必花费太多精力担心输入不可预测性。 83 | - 隐藏例子测试法(第一次project)是防止同学们通过打表法做题,只给出trivial的解决方案(比如直接输出字符串),这个问题可以通过设计审查方法来杜绝(审查法在后面介绍) 84 | - 并非所有同学都有求导的先验知识,所以给出所有例子方便同学们掌握问题,更好地完成任务 85 | 86 | 这10个例子都是紧贴实际深度学习应用的,涵盖的实际应用包括: 87 | 1. element-wise的乘法 88 | 2. 矩阵乘法 89 | 3. dense MTTKRP 90 | 4. 二维普通卷积 91 | 5. 转置 92 | 6. flatten 93 | 7. broadcast 94 | 8. 
blur 95 | 96 | 考虑到并非所有同学都接触过求导方法,我们提供了ground truth。每个测试例子在run2.cc里都会有一个对应地测试函数,在函数体以及注释里,都可以获取正确求导的结果。比如对case1,test_case1函数为 97 | ```c 98 | bool test_case1(std::mt19937 &gen, std::uniform_real_distribution &dis) { 99 | // "C<4, 16>[i, j] = A<4, 16>[i, j] * B<4, 16>[i, j] + 1.0;" 100 | // "dA<4, 16>[i, j] = dC<4, 16>[i, j] * B<4, 16>[i, j];" 101 | float B[4][16] = {{0}}; 102 | float dA[4][16] = {{0}}; 103 | float dC[4][16] = {{0}}; 104 | float golden[4][16] = {{0}}; 105 | // initialize 106 | for (int i = 0; i < 4; ++i) { 107 | for (int j = 0; j < 16; ++j) { 108 | B[i][j] = dis(gen); 109 | dC[i][j] = dis(gen); 110 | } 111 | } 112 | // compute golden 113 | for (int i = 0; i < 4; ++i) { 114 | for (int j = 0; j < 16; ++j) { 115 | golden[i][j] = dC[i][j] * B[i][j]; 116 | } 117 | } 118 | try { 119 | grad_case1(B, dC, dA); 120 | } catch (...) { 121 | std::cout << "Failed because of runtime error\n"; 122 | return false; 123 | } 124 | 125 | // check 126 | for (int i = 0; i < 4; ++i) { 127 | for (int j = 0; j < 16; ++j) { 128 | if (std::abs(golden[i][j] - dA[i][j]) >= 1e-5) { 129 | std::cout << "Wrong answer\n"; 130 | return false; 131 | } 132 | } 133 | } 134 | // correct 135 | return true; 136 | } 137 | ``` 138 | 可以看注释,或者golden的记算方法,来学习正确的求导结果。 139 | 140 | ### 4. 评分与要求 141 | #### 4.1 关键日期 142 | project2 开始:2020年5月16日晚23:59 143 | project2 截至:2020年6月21日晚23:59 144 | **不接受补交,请及时提交文件,并检查是否提交成功** 145 | #### 4.2 毕业班政策及组队 146 | 毕业班同学有两个选择: 147 | 1. 正常按时完成project2并计分。 148 | 2. 不做project2,使用期末成绩折合20%作为project2的分数。 149 | 150 | 如果选择了第二种,只需要不提交project2即可,助教会自动认为选择了使用期末成绩折合的方式。 151 | 考虑到有毕业班同学之前和非毕业班的同学组队。第二次project允许重新组队,请计划不做project2的毕业班同学不要再组队。新的组队信息在6月1日23:59前发送至compiler2020spring@163.com。没有变更的小组不用发邮件。 152 | #### 4.3 提交途径与要求 153 | ##### 4.3.1 途径 154 | 以小组为单位提交。 155 | 提交代码途径:发送github**链接**到邮箱compiler2020spring@163.com 156 | 157 | ##### 4.3.2 要求 158 | 1. 必须包含一个pdf版本的报告在project2目录下,报告内容必须涵盖小组分工,自动求导技术设计,实现流程,实验结果。其余内容可根据个人爱好添加。 159 | 2. 
不可以更改/拷贝run2.h, run2.cc, clean2.cc的内容,其余内容均可自由改动 160 | 3. 不要从stdin读取内容,请从json文件读取输入 161 | 4. 使用编译器版本需要兼容C++11标准(gcc 4.8.5以上应该都满足) 162 | 163 | 发送github链接前,请一定保证在提交截止日期后代码仓是public的,这样助教有权限下载代码(但也要注意不要提前public了,以防有人抄代码)。助教的测试命令为: 164 | ```sh 165 | git clone --recursive <提交的github链接> CompilerProject 166 | cd CompilerProject 167 | mkdir build 168 | cd build 169 | cmake .. 170 | make -j 4 171 | cd project2 172 | ./test2 173 | ``` 174 | 175 | #### 4.4 评分法 176 | 本次project占总成绩20%,这20分中5分来自pdf报告,15分来自提交代码。为了保证每个小组成员都要给出有效贡献,pdf报告中的分工将被参考到最终评分中,同时,如果小组成员举报某一成员并未做出任何贡献,一经查实(查实方法为审核github提交贡献量),将**不予该成员给分**,请小组内部紧密合作。 177 | 178 | pdf评分标准: 179 | - 包含小组分工,自动求导技术设计(2分) 180 | - 包含实现流程,实验结果内容(1分) 181 | - 通过一个具体例子解释所设计的求导技术的可行性和正确性(1分) 182 | - 总结使用到的编译知识,讲解如何实现(1分) 183 | 184 | 代码评分标准: 185 | - 根据通过的case数目计算分数(真实分数),具体按照下表 186 | 187 | | 通过case数目 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 188 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | 189 | | 得分 | 0 | 2.25 | 4.5 | 6.75 | 9 | 10.5 | 12 | 12.75 | 13.5 | 14.25 | 15 | 190 | 191 | - 在运行./test2后会自动打出分数,符合上表的定义 192 | 193 | #### 4.5 审查法 194 | 审查代码是为了防止同学作弊,作弊的定义包含: 195 | - 拷贝或修改run2.h/run2.cc/clean2.cc的内容 196 | - 任何两组的代码重合度过高甚至完全一致 197 | - 完全使用第三方项目解决问题 198 | - 报告内容与实际实现不一致 199 | 200 | 可能使用的审查法包括 201 | 1. 全面审查法,3名助教平均每人审查若干支队伍的代码 202 | 2. 抽样审查法,随机挑选若干队伍进行代码审查 203 | 3. 输入测试法,助教随机修改输入case的json文件,如果仍然能输出有意义的代码到kernels/下(助教会看),说明没有作弊(不测试代码正确性,只考察能生成代码) 204 | 4. 结合报告内容审查法,检查报告中的技术和实现的技术一致性。 205 | 5. 重点审查法,对于得分高于12分的组进行全面代码审查。 206 | 207 | 具体审查方法不公开,请同学们自觉使用编译课所学知识解决问题,我们的给分策略是非线性,通过4个例子的组都可以得到9分以上(60%的分数),过6个例子就可以得到12分(80%的分数),宗旨是鼓励大家侧重于自己设计编译器,而不是唯分数论。 208 | 209 | ### 5. 参考文献与代码 210 | 1. Halide的一个自动求导工作 211 | https://people.csail.mit.edu/tzumao/gradient_halide/gradient_halide.pdf 212 | 2. Halide自动求导代码 213 | https://github.com/halide/Halide/blob/master/src/Derivative.cpp 214 | 3. TVM自动求导代码 215 | https://github.com/apache/incubator-tvm/pull/2498 216 | 4. 
二维卷积反向传播推导 217 | https://zhuanlan.zhihu.com/p/61898234 218 | 5. 线性下标变换下求导方法 219 | https://arxiv.org/abs/1711.01348 220 | 221 | ### 6. 讨论 222 | Project可能潜在的bug可以在微信群、github issue上提出,有价值的issue可以为全组加分,每个bug加1分 223 | 另外,鼓励小组内部协作与讨论,也鼓励适当的小组间交流,交流方式为github issue或微信群,助教也会参与讨论,解答一些技术问题。 224 | 225 | ### 附录 226 | #### 1. IRMutator的使用 227 | IRMutator的功能是遍历IR,并且在遍历到每个节点的时候,返回一个新的IR节点。默认的IRMutator行为是返回和先前一摸一样的新节点。实际使用时,可以通过继承IRMutator,并重载特定的visit函数来定制对于IRMutator的遍历和修改行为。所有通过IRMutator对于AST的修改,都是创造新的AST,所以不会影响原来的AST的内容。 228 | 229 | 在test/目录下,ir_mutator.cc文件中展示了一个简单的定制Mutator的过程: 230 | ```c 231 | class MyMutator : public IRMutator { 232 | public: 233 | Expr visit(Ref op) override { 234 | if (op->name == "A") { 235 | return Var::make(op->type(), "modified_A", op->args, op->shape); 236 | } 237 | return IRMutator::visit(op); 238 | } 239 | }; 240 | ``` 241 | 利用这个Mutator,可以把表达式里名字为"A"的Var节点更改为名字为"modified_A"的Var节点。 242 | ```c 243 | MyMutator mutator; 244 | kernel = mutator.mutate(kernel); 245 | ``` 246 | 更改后的kernel,打印出来是这样的: 247 | ```py 248 | simple_gemm(modified_A<1024, 256>, B<256, 512>, C<1024, 512>) { 249 | for i in dom[((int32_t <1>) 0), ((int32_t <1>) 1024)){ 250 | for j in dom[((int32_t <1>) 0), ((int32_t <1>) 512)){ 251 | for k in dom[((int32_t <1>) 0), ((int32_t <1>) 256)){ 252 | C[i, j] = C[i, j] + modified_A[i, k] * B[k, j] 253 | } 254 | } 255 | } 256 | } 257 | ``` 258 | 可以看到名字的确改了。这样,我们可以利用IRMutator实现很多不同的pass。 -------------------------------------------------------------------------------- /include/IR.h: -------------------------------------------------------------------------------- 1 | /* 2 | * MIT License 3 | * 4 | * Copyright (c) 2020 Size Zheng 5 | 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy 7 | * of this software and associated documentation files (the "Software"), to deal 8 | * in the Software without restriction, including without limitation the rights 9 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 10 | * copies of the Software, and to permit persons to whom the Software is 11 | * furnished to do so, subject to the following conditions: 12 | 13 | * The above copyright notice and this permission notice shall be included in all 14 | * copies or substantial portions of the Software. 15 | 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | * SOFTWARE. 23 | */ 24 | 25 | #ifndef BOOST_IR_H 26 | #define BOOST_IR_H 27 | 28 | #include 29 | #include 30 | 31 | #include "type.h" 32 | #include "arith.h" 33 | #include "debug.h" 34 | 35 | namespace Boost { 36 | 37 | namespace Internal { 38 | 39 | /** 40 | * This class is inspired by Halide IntrusivePtr 41 | * The difference is that we use std::shared_ptr 42 | */ 43 | template 44 | class Ref { 45 | 46 | protected: 47 | std::shared_ptr ptr = nullptr; 48 | 49 | public: 50 | Ref() {} 51 | 52 | Ref(Ref &other) : ptr(other.ptr) {} 53 | 54 | Ref(Ref &&other) : ptr(std::move(other.ptr)) {} 55 | 56 | /** 57 | * allow constructing from sub-class 58 | */ 59 | template::value>::type* = nullptr> 60 | Ref(Ref &other) : ptr(other.real_ptr()) {} 61 | 62 | template::value>::type* = nullptr> 63 | Ref(Ref &&other) : ptr(std::move(other.real_ptr())) {} 64 | 65 | /** 66 | * allow constructing from shared_ptr of sub-class 67 | */ 68 | template::value>::type* = nullptr> 69 | Ref(std::shared_ptr _ptr) : ptr(_ptr) {} 70 | 71 | bool defined() const { return ptr != nullptr; } 72 | 73 | T *get() const { return ptr.get(); } 74 | 75 | /** 76 | * have to expose inner shared_ptr, required by constructor 77 | */ 
78 | void set_ptr(std::shared_ptr other) { 79 | this->ptr = other; 80 | } 81 | 82 | std::shared_ptr real_ptr() const { 83 | return this->ptr; 84 | } 85 | 86 | T &operator*() const { return *ptr; } 87 | 88 | T *operator->() const { return ptr.operator->(); } 89 | 90 | Ref &operator=(Ref &b) { 91 | this->ptr = b.ptr; 92 | return *this; 93 | } 94 | 95 | Ref &operator=(Ref &&b) { 96 | swap(this->ptr, b.ptr); 97 | return *this; 98 | } 99 | 100 | bool operator<(Ref &b) const { 101 | /* Don't directly compare shared_ptr, for C++20 removes operator< */ 102 | return this->get() < b.get(); 103 | } 104 | }; 105 | 106 | /** 107 | * different type of IRNodes 108 | */ 109 | enum class IRNodeType : short { 110 | // Groups 111 | Kernel, 112 | // Stmts 113 | LoopNest, 114 | IfThenElse, 115 | Move, 116 | // Exprs 117 | Unary, 118 | Binary, 119 | Select, 120 | Compare, 121 | Call, 122 | Var, 123 | Cast, 124 | Ramp, 125 | Index, 126 | IntImm, 127 | UIntImm, 128 | FloatImm, 129 | StringImm, 130 | Dom 131 | }; 132 | 133 | 134 | /** 135 | * forward declaration 136 | */ 137 | class IRVisitor; 138 | class IRMutator; 139 | class Expr; 140 | class Stmt; 141 | class Group; 142 | 143 | 144 | /** 145 | * this is the base class of all IR nodes 146 | */ 147 | class IRNode { 148 | public: 149 | IRNode(const IRNodeType _type) : _node_type(_type) {} 150 | 151 | IRNodeType node_type() const { 152 | return this->_node_type; 153 | } 154 | 155 | /** 156 | * for IRVisitor 157 | */ 158 | virtual void visit_node(IRVisitor *visitor) const = 0; 159 | 160 | private: 161 | /** 162 | * indicate the concrete type of this IR node 163 | */ 164 | IRNodeType _node_type; 165 | }; 166 | 167 | 168 | /** 169 | * base node of expression 170 | */ 171 | class ExprNode : public IRNode { 172 | private: 173 | Type type_; 174 | public: 175 | ExprNode(Type _type, const IRNodeType node_type) : IRNode(node_type), type_(_type) {} 176 | 177 | virtual ~ExprNode() = default; 178 | 179 | virtual Expr mutate_expr(IRMutator *mutator) 
const = 0; 180 | 181 | Type type() const { 182 | return type_; 183 | } 184 | }; 185 | 186 | 187 | /** 188 | * base node of statement 189 | */ 190 | class StmtNode : public IRNode { 191 | private: 192 | 193 | public: 194 | StmtNode(IRNodeType _type) : IRNode(_type) {} 195 | 196 | virtual ~StmtNode() = default; 197 | 198 | virtual Stmt mutate_stmt(IRMutator *mutator) const = 0; 199 | }; 200 | 201 | 202 | /** 203 | * base node of group 204 | */ 205 | class GroupNode : public IRNode { 206 | private: 207 | 208 | public: 209 | GroupNode(IRNodeType _type) : IRNode(_type) {} 210 | 211 | virtual ~GroupNode() = default; 212 | 213 | virtual Group mutate_group(IRMutator *mutator) const = 0; 214 | }; 215 | 216 | 217 | /** 218 | * inherited from Halide 219 | */ 220 | class IntImm : public ExprNode, public std::enable_shared_from_this { 221 | private: 222 | int64_t value_; 223 | public: 224 | IntImm(Type _type, const int64_t _value) : ExprNode(_type, IRNodeType::IntImm), value_(_value) {} 225 | 226 | /** 227 | * May need consider bits 228 | */ 229 | int64_t value() const { 230 | return value_; 231 | } 232 | 233 | Expr mutate_expr(IRMutator *mutator) const; 234 | void visit_node(IRVisitor *visitor) const; 235 | 236 | static Ref make(Type t, const int64_t _value) { 237 | return std::make_shared(t, _value); 238 | } 239 | 240 | static const IRNodeType node_type_ = IRNodeType::IntImm; 241 | }; 242 | 243 | 244 | /** 245 | * inherited from Halide 246 | */ 247 | class UIntImm : public ExprNode, public std::enable_shared_from_this { 248 | private: 249 | uint64_t value_; 250 | public: 251 | UIntImm(Type _type, const uint64_t _value) : ExprNode(_type, IRNodeType::UIntImm), value_(_value) {} 252 | 253 | /** 254 | * May need consider bits 255 | */ 256 | uint64_t value() const { 257 | return value_; 258 | } 259 | 260 | Expr mutate_expr(IRMutator *mutator) const; 261 | void visit_node(IRVisitor *visitor) const; 262 | 263 | static Ref make(Type t, const uint64_t _value) { 264 | return 
std::make_shared(t, _value); 265 | } 266 | 267 | static const IRNodeType node_type_ = IRNodeType::UIntImm; 268 | }; 269 | 270 | 271 | /** 272 | * inherited from Halide 273 | */ 274 | class FloatImm : public ExprNode, public std::enable_shared_from_this { 275 | private: 276 | double value_; 277 | public: 278 | FloatImm(Type _type, const double _value) : ExprNode(_type, IRNodeType::FloatImm), value_(_value) {} 279 | 280 | /** 281 | * May need consider bits 282 | */ 283 | double value() const { 284 | return value_; 285 | } 286 | 287 | Expr mutate_expr(IRMutator *mutator) const; 288 | void visit_node(IRVisitor *visitor) const; 289 | 290 | static Ref make(Type t, const double _value) { 291 | return std::make_shared(t, _value); 292 | } 293 | 294 | static const IRNodeType node_type_ = IRNodeType::FloatImm; 295 | }; 296 | 297 | 298 | /** 299 | * inherited from Halide 300 | */ 301 | class StringImm : public ExprNode, public std::enable_shared_from_this { 302 | private: 303 | std::string value_; 304 | public: 305 | StringImm(Type _type, const std::string _value) : 306 | ExprNode(_type, IRNodeType::StringImm), value_(_value) {} 307 | 308 | std::string value() const { 309 | return value_; 310 | } 311 | 312 | Expr mutate_expr(IRMutator *mutator) const; 313 | void visit_node(IRVisitor *visitor) const; 314 | 315 | static Ref make(Type t, const std::string _value) { 316 | return std::make_shared(t, _value); 317 | } 318 | 319 | static const IRNodeType node_type_ = IRNodeType::StringImm; 320 | }; 321 | 322 | 323 | /** 324 | * a reference to expression 325 | */ 326 | class Expr : public Ref { 327 | public: 328 | Expr() : Ref() {} 329 | 330 | Expr(const Expr &other) : Ref(other.real_ptr()) {} 331 | 332 | Expr(const Expr &&other) : Ref(other.real_ptr()) {} 333 | 334 | template::value>::type* = nullptr> 336 | Expr(Ref &other) : Ref(other) {} 337 | 338 | template::value>::type* = nullptr> 340 | Expr(Ref &&other) : Ref(std::move(other)) {} 341 | 342 | template::value>::type* = nullptr> 
344 | Expr(std::shared_ptr _ptr) : Ref(_ptr) {} 345 | 346 | /** 347 | * convenient constructors 348 | */ 349 | explicit Expr(bool value) : 350 | Ref(UIntImm::make(Type::uint_scalar(1), static_cast(value))) {} 351 | 352 | explicit Expr(uint8_t value) : 353 | Ref(UIntImm::make(Type::uint_scalar(8), static_cast(value))) {} 354 | 355 | explicit Expr(uint16_t value) : 356 | Ref(UIntImm::make(Type::uint_scalar(16), static_cast(value))) {} 357 | 358 | explicit Expr(uint32_t value) : 359 | Ref(UIntImm::make(Type::uint_scalar(32), static_cast(value))) {} 360 | 361 | explicit Expr(uint64_t value) : 362 | Ref(UIntImm::make(Type::uint_scalar(64), value)) {} 363 | 364 | explicit Expr(int8_t value) : 365 | Ref(IntImm::make(Type::int_scalar(8), static_cast(value))) {} 366 | 367 | explicit Expr(int16_t value) : 368 | Ref(IntImm::make(Type::int_scalar(16), static_cast(value))) {} 369 | 370 | Expr(int value) : 371 | Ref(IntImm::make(Type::int_scalar(32), static_cast(value))) {} 372 | 373 | explicit Expr(int64_t value) : 374 | Ref(IntImm::make(Type::int_scalar(64), value)) {} 375 | 376 | explicit Expr(float value) : 377 | Ref(FloatImm::make(Type::float_scalar(32), static_cast(value))) {} 378 | 379 | Expr(double value) : 380 | Ref(FloatImm::make(Type::float_scalar(64), value)) {} 381 | 382 | Expr &operator=(const Expr &other) { 383 | this->set_ptr(other.real_ptr()); 384 | return *this; 385 | } 386 | 387 | IRNodeType node_type() const { 388 | return this->get()->node_type(); 389 | } 390 | 391 | Type type() const { 392 | return this->get()->type(); 393 | } 394 | 395 | void visit_expr(IRVisitor *visitor) const { 396 | return this->get()->visit_node(visitor); 397 | } 398 | 399 | Expr mutate_expr(IRMutator *mutator) const { 400 | return this->get()->mutate_expr(mutator); 401 | } 402 | 403 | /** 404 | * cast to other type of reference 405 | */ 406 | template 407 | std::shared_ptr as() const { 408 | if (this->node_type() == T::node_type_) { 409 | return 
std::static_pointer_cast(this->real_ptr()); 410 | } 411 | return nullptr; 412 | } 413 | }; 414 | 415 | 416 | /** 417 | * a reference to statement 418 | */ 419 | class Stmt : public Ref { 420 | public: 421 | Stmt() : Ref() {} 422 | 423 | Stmt(const Stmt &other) : Ref(other.real_ptr()) {} 424 | 425 | Stmt(const Stmt &&other) : Ref(other.real_ptr()) {} 426 | 427 | template::value>::type* = nullptr> 428 | Stmt(Ref &other) : Ref(other) {} 429 | 430 | template::value>::type* = nullptr> 431 | Stmt(Ref &&other) : Ref(std::move(other)) {} 432 | 433 | template::value>::type* = nullptr> 434 | Stmt(std::shared_ptr _ptr) : Ref(_ptr) {} 435 | 436 | Stmt(std::shared_ptr _ptr) : Ref(_ptr) {} 437 | 438 | Stmt &operator=(const Stmt &other) { 439 | this->set_ptr(other.real_ptr()); 440 | return *this; 441 | } 442 | 443 | IRNodeType node_type() const { 444 | return this->get()->node_type(); 445 | } 446 | 447 | void visit_stmt(IRVisitor *visitor) const { 448 | return this->get()->visit_node(visitor); 449 | } 450 | 451 | Stmt mutate_stmt(IRMutator *mutator) const { 452 | return this->get()->mutate_stmt(mutator); 453 | } 454 | 455 | /** 456 | * cast to other type of reference 457 | */ 458 | template 459 | std::shared_ptr as() const { 460 | if (this->node_type() == T::node_type_) { 461 | return std::static_pointer_cast(this->real_ptr()); 462 | } 463 | return nullptr; 464 | } 465 | }; 466 | 467 | 468 | /** 469 | * a reference to group 470 | */ 471 | class Group : public Ref { 472 | public: 473 | Group() : Ref() {} 474 | 475 | Group(const Group &other) : Ref(other.real_ptr()) {} 476 | 477 | Group(const Group &&other) : Ref(other.real_ptr()) {} 478 | 479 | template::value>::type* = nullptr> 480 | Group(Ref &other) : Ref(other) {} 481 | 482 | template::value>::type* = nullptr> 483 | Group(Ref &&other) : Ref(std::move(other)) {} 484 | 485 | template::value>::type* = nullptr> 486 | Group(std::shared_ptr _ptr) : Ref(_ptr) {} 487 | 488 | Group(std::shared_ptr _ptr) : Ref(_ptr) {} 489 | 490 | Group 
&operator=(const Group &other) { 491 | this->set_ptr(other.real_ptr()); 492 | return *this; 493 | } 494 | 495 | IRNodeType node_type() const { 496 | return this->get()->node_type(); 497 | } 498 | 499 | void visit_group(IRVisitor *visitor) const { 500 | return this->get()->visit_node(visitor); 501 | } 502 | 503 | Group mutate_group(IRMutator *mutator) const { 504 | return this->get()->mutate_group(mutator); 505 | } 506 | 507 | /** 508 | * cast to other type of reference 509 | */ 510 | template 511 | std::shared_ptr as() const { 512 | if (this->node_type() == T::node_type_) { 513 | return std::static_pointer_cast(this->real_ptr()); 514 | } 515 | return nullptr; 516 | } 517 | }; 518 | 519 | 520 | enum class UnaryOpType : uint8_t { 521 | Neg, /* negate */ 522 | Not /* logic not */ 523 | }; 524 | 525 | 526 | /** 527 | * unary operation 528 | */ 529 | class Unary : public ExprNode, public std::enable_shared_from_this { 530 | public: 531 | UnaryOpType op_type; 532 | Expr a; 533 | 534 | Unary(Type _type, UnaryOpType _op_type, Expr _a) : ExprNode(_type, IRNodeType::Unary), 535 | op_type(_op_type), a(_a) {} 536 | 537 | Expr mutate_expr(IRMutator *mutator) const; 538 | void visit_node(IRVisitor *visitor) const; 539 | 540 | static Expr make(Type t, UnaryOpType _op_type, Expr _a) { 541 | return std::make_shared(t, _op_type, _a); 542 | } 543 | 544 | static const IRNodeType node_type_ = IRNodeType::Unary; 545 | }; 546 | 547 | 548 | enum class BinaryOpType : uint8_t { 549 | Add, 550 | Sub, 551 | Mul, 552 | Div, 553 | Mod, 554 | And, 555 | Or, 556 | }; 557 | 558 | 559 | /** 560 | * binary operation 561 | */ 562 | class Binary : public ExprNode, public std::enable_shared_from_this { 563 | public: 564 | BinaryOpType op_type; 565 | Expr a, b; 566 | 567 | Binary(Type _type, BinaryOpType _op_type, Expr _a, Expr _b) : ExprNode(_type, IRNodeType::Binary), 568 | op_type(_op_type), a(_a), b(_b) {} 569 | 570 | Expr mutate_expr(IRMutator *mutator) const; 571 | void visit_node(IRVisitor 
*visitor) const; 572 | 573 | static Expr make(Type t, BinaryOpType _op_type, Expr _a, Expr _b) { 574 | return std::make_shared(t, _op_type, _a, _b); 575 | } 576 | 577 | static const IRNodeType node_type_ = IRNodeType::Binary; 578 | }; 579 | 580 | 581 | enum class CompareOpType : uint8_t { 582 | LT, 583 | LE, 584 | EQ, 585 | NE, 586 | GE, 587 | GT 588 | }; 589 | 590 | 591 | /** 592 | * compare op <, <=, =, !=, >=, > 593 | */ 594 | class Compare : public ExprNode, public std::enable_shared_from_this { 595 | public: 596 | CompareOpType op_type; 597 | Expr a, b; 598 | 599 | Compare(Type _type, CompareOpType _op_type, Expr _a, Expr _b) : ExprNode(_type, IRNodeType::Compare), 600 | op_type(_op_type), a(_a), b(_b) {} 601 | 602 | Expr mutate_expr(IRMutator *mutator) const; 603 | void visit_node(IRVisitor *visitor) const; 604 | 605 | static Expr make(Type t, CompareOpType _op_type, Expr _a, Expr _b) { 606 | return std::make_shared(t, _op_type, _a, _b); 607 | } 608 | 609 | static const IRNodeType node_type_ = IRNodeType::Compare; 610 | }; 611 | 612 | 613 | /** 614 | * select op: cond? true_value : false_value 615 | */ 616 | class Select : public ExprNode, public std::enable_shared_from_this