├── CMakeLists.txt ├── README.md ├── README.pdf ├── bert_graphs ├── bert_layout_nw.onnx ├── bert_sequential_nw.onnx ├── bert_subst_nw.onnx └── bert_xflow_nw.onnx ├── cmake ├── FindCUDA.cmake └── SetCUDA.cmake ├── config.cmake ├── examples ├── bert.py ├── eval_groups.py ├── eval_joint.py ├── nasnet_a.py ├── nasrnn.py ├── resnet50.py └── resnext50.py ├── graph_subst.pb ├── include └── xflow │ ├── cuda_helper.h │ └── ops.h ├── python ├── sample.py ├── setup.py ├── test.py └── xflow │ ├── __init__.py │ └── _cython │ ├── CCore.pxd │ └── core.pyx ├── rules_pb2.py ├── src ├── core │ ├── activation.cc │ ├── batchnorm.cc │ ├── concat.cc │ ├── constant.cc │ ├── conv2d.cc │ ├── element.cc │ ├── enlarge.cc │ ├── graph_to_trt.cc │ ├── matmul.cc │ ├── merge_gconv.cc │ ├── mul.cc │ ├── noop.cc │ ├── ops.cc │ ├── pool2d.cc │ ├── reshape.cc │ ├── rules.pb.cc │ ├── rules.pb.h │ ├── rules.proto │ ├── split.cc │ ├── substitution.cc │ ├── substitution.h │ └── transpose.cc ├── cudnn │ ├── activation_cudnn.cu │ ├── batchnorm_cudnn.cu │ ├── concat_cudnn.cu │ ├── constant_kernel.cu │ ├── conv2d_cudnn.cu │ ├── cuda_helper.cu │ ├── element_cudnn.cu │ ├── enlarge_cudnn.cu │ ├── matmul_cudnn.cu │ ├── merge_gconv_cudnn.cu │ ├── mul_cudnn.cu │ ├── ops_cudnn.cu │ ├── pool2d_cudnn.cu │ ├── reshape_cudnn.cu │ └── transpose_cudnn.cu └── generator │ ├── compile.sh │ ├── generator.cc │ ├── rules.pb.cc │ └── rules.pb.h ├── tensorflow_py ├── bert.py ├── nasnet_a.py ├── nasrnn.py ├── resnet50.py ├── resnext50.py └── shared_functions.py ├── validate_axioms.py └── verify.py /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.2) 2 | project(xflow LANGUAGES CXX CUDA) 3 | 4 | #required packages 5 | find_package(Protobuf REQUIRED) 6 | 7 | if (EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) 8 | include (${CMAKE_CURRENT_BINARY_DIR}/config.cmake) 9 | else() 10 | if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/config.cmake) 11 | include(${CMAKE_CURRENT_SOURCE_DIR}/config.cmake) 12 | endif() 13 | endif() 14 | 15 | #include directories 16 | include_directories("include") 17 | include_directories("src/core") 18 | 19 | #initial variables 20 | set(XF_LIBS "") 21 | set(XF_LINK_LIBS ${CMAKE_DL_LIBS}) 22 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 23 | 24 | file(GLOB_RECURSE XF_SRCS 25 | src/core/*.cc 26 | ) 27 | 28 | file(GLOB_RECURSE XF_CUDA_SRCS 29 | src/cudnn/*.cu 30 | ) 31 | 32 | list(APPEND XF_SRCS ${XF_CUDA_SRCS}) 33 | 34 | #Generic compilation options 35 | include(CheckCXXCompilerFlag) 36 | check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) 37 | if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") 38 | message("Build in Debug mode") 39 | set(CMAKE_CUDA_FLAGS "-O0 -g -arch compute_70 -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") 40 | set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") 41 | else() 42 | set(CMAKE_CUDA_FLAGS "-O2 -arch compute_70 -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") 43 | set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC ${CMAKE_CXX_FLAGS}") 44 | endif() 45 | 46 | #set CUDA 47 | include(cmake/FindCUDA.cmake) 48 | find_cuda(${USE_CUDA}) 49 | if (CUDA_FOUND) 50 | include_directories(${CUDA_INCLUDE_DIRS}) 51 | message(STATUS "CUDA_INCLUDE_DIR=" ${CUDA_INCLUDE_DIRS}) 52 | add_definitions(-DUSE_CUDNN) 53 | list(APPEND XF_LINKER_LIBS ${CUDA_CUDART_LIBRARY}) 54 | list(APPEND XF_LINKER_LIBS ${CUDA_CUDA_LIBRARY}) 55 | list(APPEND XF_LINKER_LIBS ${CUDA_CUDNN_LIBRARY}) 56 | list(APPEND XF_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY}) 57 | elseif(CUDA_FOUND) 58 | message(FATAL_ERROR "Cannot find 
CUDA, USE_CUDA=" ${USE_CUDA}) 59 | endif(CUDA_FOUND) 60 | 61 | #set ProtocolBuffer 62 | message(STATUS "PROTOBUF=" ${PROTOBUF_LIBRARY}) 63 | list(APPEND XF_LINKER_LIBS ${PROTOBUF_LIBRARY}) 64 | 65 | add_library(xf_runtime SHARED ${XF_SRCS}) 66 | 67 | set_target_properties(xf_runtime 68 | PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 69 | 70 | target_compile_features(xf_runtime PUBLIC cxx_std_11) 71 | 72 | target_link_libraries(xf_runtime ${XF_LINKER_LIBS}) 73 | 74 | target_include_directories(xf_runtime 75 | PUBLIC ${PROJECT_SOURCE_DIR}/include) 76 | 77 | set_target_properties(xf_runtime 78 | PROPERTIES CUDA_SEPARABLE_COMPILATION ON) 79 | #install library 80 | install(TARGETS xf_runtime 81 | LIBRARY DESTINATION lib) 82 | 83 | install (DIRECTORY ${PROJECT_SOURCE_DIR}/include 84 | DESTINATION .) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Artifacts for SOSP'19 paper Optimizing Deep Learning Computation with Automatic Generation of Graph Substitutions 2 | -------------------------------------------------------------------------------- /README.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/README.pdf -------------------------------------------------------------------------------- /bert_graphs/bert_layout_nw.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/bert_graphs/bert_layout_nw.onnx -------------------------------------------------------------------------------- /bert_graphs/bert_sequential_nw.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/bert_graphs/bert_sequential_nw.onnx -------------------------------------------------------------------------------- /bert_graphs/bert_subst_nw.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/bert_graphs/bert_subst_nw.onnx -------------------------------------------------------------------------------- /bert_graphs/bert_xflow_nw.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/bert_graphs/bert_xflow_nw.onnx -------------------------------------------------------------------------------- /cmake/FindCUDA.cmake: -------------------------------------------------------------------------------- 1 | ####################################################### 2 | # Enhanced version of find CUDA. 
3 | # 4 | # Usage: 5 | # find_cuda(${USE_CUDA}) 6 | # 7 | # - When USE_CUDA=ON, use auto search 8 | # - When USE_CUDA=/path/to/cuda-path, use the cuda path 9 | # 10 | # Provide variables: 11 | # 12 | # - CUDA_FOUND 13 | # - CUDA_INCLUDE_DIRS 14 | # - CUDA_TOOLKIT_ROOT_DIR 15 | # - CUDA_CUDA_LIBRARY 16 | # - CUDA_CUDART_LIBRARY 17 | # - CUDA_NVRTC_LIBRARY 18 | # - CUDA_CUDNN_LIBRARY 19 | # - CUDA_CUBLAS_LIBRARY 20 | # 21 | macro(find_cuda use_cuda) 22 | set(__use_cuda ${use_cuda}) 23 | if(__use_cuda STREQUAL "ON") 24 | find_package(CUDA QUIET) 25 | elseif(IS_DIRECTORY ${__use_cuda}) 26 | set(CUDA_TOOLKIT_ROOT_DIR ${__use_cuda}) 27 | message(STATUS "Custom CUDA_PATH=" ${CUDA_TOOLKIT_ROOT_DIR}) 28 | set(CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/include) 29 | set(CUDA_FOUND TRUE) 30 | if(MSVC) 31 | find_library(CUDA_CUDART_LIBRARY cudart 32 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 33 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32) 34 | else(MSVC) 35 | find_library(CUDA_CUDART_LIBRARY cudart 36 | ${CUDA_TOOLKIT_ROOT_DIR}/lib64 37 | ${CUDA_TOOLKIT_ROOT_DIR}/lib) 38 | endif(MSVC) 39 | endif() 40 | 41 | # additional libraries 42 | if(CUDA_FOUND) 43 | if(MSVC) 44 | find_library(CUDA_CUDA_LIBRARY cuda 45 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 46 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32) 47 | find_library(CUDA_NVRTC_LIBRARY nvrtc 48 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 49 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32) 50 | find_library(CUDA_CUDNN_LIBRARY cudnn 51 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 52 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32) 53 | find_library(CUDA_CUBLAS_LIBRARY cublas 54 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 55 | ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32) 56 | else(MSVC) 57 | find_library(_CUDA_CUDA_LIBRARY cuda 58 | PATHS ${CUDA_TOOLKIT_ROOT_DIR} 59 | PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs 60 | NO_DEFAULT_PATH) 61 | if(_CUDA_CUDA_LIBRARY) 62 | set(CUDA_CUDA_LIBRARY ${_CUDA_CUDA_LIBRARY}) 63 | endif() 64 | find_library(CUDA_NVRTC_LIBRARY nvrtc 65 | PATHS ${CUDA_TOOLKIT_ROOT_DIR} 66 | PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu 67 | NO_DEFAULT_PATH) 68 | find_library(CUDA_CUDNN_LIBRARY cudnn 69 | ${CUDA_TOOLKIT_ROOT_DIR}/lib64 70 | ${CUDA_TOOLKIT_ROOT_DIR}/lib) 71 | find_library(CUDA_CUBLAS_LIBRARY cublas 72 | ${CUDA_TOOLKIT_ROOT_DIR}/lib64 73 | ${CUDA_TOOLKIT_ROOT_DIR}/lib) 74 | endif(MSVC) 75 | message(STATUS "Found CUDA_TOOLKIT_ROOT_DIR=" ${CUDA_TOOLKIT_ROOT_DIR}) 76 | message(STATUS "Found CUDA_CUDA_LIBRARY=" ${CUDA_CUDA_LIBRARY}) 77 | message(STATUS "Found CUDA_CUDART_LIBRARY=" ${CUDA_CUDART_LIBRARY}) 78 | message(STATUS "Found CUDA_NVRTC_LIBRARY=" ${CUDA_NVRTC_LIBRARY}) 79 | message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY}) 80 | message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY}) 81 | endif(CUDA_FOUND) 82 | endmacro(find_cuda) 83 | -------------------------------------------------------------------------------- /cmake/SetCUDA.cmake: -------------------------------------------------------------------------------- 1 | # CUDA Module 2 | find_cuda(${USE_CUDA}) 3 | 4 | if(CUDA_FOUND) 5 | # always set the includedir when cuda is available 6 | # avoid global retrigger of cmake 7 | include_directories(${CUDA_INCLUDE_DIRS}) 8 | endif(CUDA_FOUND) 9 | 10 | if(USE_CUDA) 11 | if(NOT CUDA_FOUND) 12 | message(FATAL_ERROR "Cannot find CUDA, USE_CUDA=" ${USE_CUDA}) 13 | endif() 14 | message(STATUS "Build with CUDA support") 15 | file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) 16 | 
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS}) 17 | list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_on.cc) 18 | 19 | list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY}) 20 | list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY}) 21 | list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY}) 22 | list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY}) 23 | 24 | if(USE_CUDNN) 25 | message(STATUS "Build with cuDNN support") 26 | file(GLOB CONTRIB_CUDNN_SRCS src/contrib/cudnn/*.cc) 27 | list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_SRCS}) 28 | list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY}) 29 | endif(USE_CUDNN) 30 | 31 | if(USE_CUBLAS) 32 | message(STATUS "Build with cuBLAS support") 33 | file(GLOB CONTRIB_CUBLAS_SRCS src/contrib/cublas/*.cc) 34 | list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS}) 35 | list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY}) 36 | endif(USE_CUBLAS) 37 | 38 | else(USE_CUDA) 39 | list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_off.cc) 40 | endif(USE_CUDA) 41 | -------------------------------------------------------------------------------- /config.cmake: -------------------------------------------------------------------------------- 1 | set(USE_CUDA ON) 2 | set(USE_CUDNN ON) 3 | -------------------------------------------------------------------------------- /examples/bert.py: -------------------------------------------------------------------------------- 1 | import xflow as xf 2 | 3 | seq_length = 64 4 | hidden_dims = 1024 5 | 6 | def attention(graph, input, heads): 7 | d_model = input.dim(1) 8 | d_k = d_model // heads 9 | assert input.dim(1) % heads == 0 10 | weights = list() 11 | for i in range(3): 12 | weights.append(graph.new_weight(dims=(d_model, d_model))) 13 | # compute query, key, value tensors 14 | q = graph.matmul(input, weights[0]) 15 | k = graph.matmul(input, weights[1]) 16 | v = graph.matmul(input, weights[2]) 17 | # reshape query, key, value to multiple heads 18 | q = graph.reshape(q, shape=(64,16,64)) 19 | k = graph.reshape(k, shape=(64,16,64)) 20 | v = graph.reshape(v, shape=(64,16,64)) 21 | # transpose query, key, value for batched matmul 22 | q = graph.transpose(q, perm=(1,0,2), shuffle=True) 23 | k = graph.transpose(k, perm=(1,0,2), shuffle=True) 24 | v = graph.transpose(v, perm=(1,0,2), shuffle=True) 25 | # perform matrix multiplications 26 | logits = graph.matmul(q, k) 27 | output = graph.matmul(logits, v) 28 | # transpose the output back 29 | output = graph.transpose(output,perm=(1,0,2), shuffle=True) 30 | output = graph.reshape(output, shape=(64, 1024)) 31 | 32 | # a final linear layer 33 | linear = graph.new_weight(dims=(d_model, d_model)) 34 | output = graph.matmul(input, linear) 35 | return output 36 | 37 | graph = xf.new_graph() 38 | input = graph.new_input(dims=(seq_length, hidden_dims)) 39 | input = graph.relu(input) 40 | t = input 41 | for i in range(8): 42 | t = attention(graph, t, 16) 43 | 44 | new_graph = xf.optimize(graph, alpha=1.0, budget=100) 45 | -------------------------------------------------------------------------------- /examples/eval_groups.py: -------------------------------------------------------------------------------- 1 | import xflow 2 | 3 | graph = xflow.new_graph() 4 | input = graph.new_input(dims=(1,512,28,28)) 5 | input = graph.maxpool2d(input=input, kernels=(1,1), strides=(1,1), padding="SAME") 6 | # Printing the performance of different multi-branch convolutions 7 | graph.print_measurements() 8 | i = 1 9 | while i <= 32: 10 | print("Num. 
Convs Per Grop = {}".format(i)) 11 | weight = graph.new_weight(dims=(512,512//i,3,3)) 12 | t = graph.conv2d(input=input,weight=weight,strides=(1,1),padding="SAME", activation="RELU") 13 | i *= 2 14 | 15 | #weight1 = graph.new_weight(dims=(256,8,3,3)) 16 | #t1 = graph.conv2d(input=input,weight=weight1,strides=(1,1), padding="SAME", activation="RELU") 17 | #weight2 = graph.new_weight(dims=(256,16,3,3)) 18 | #t2 = graph.conv2d(input=input,weight=weight2,strides=(1,1), padding="SAME", activation="RELU") 19 | #weight3 = graph.new_weight(dims=(256,32,3,3)) 20 | #t3 = graph.conv2d(input=input,weight=weight3,strides=(1,1), padding="SAME", activation="RELU") 21 | #weight4 = graph.new_weight(dims=(256,64,3,3)) 22 | #t4 = graph.conv2d(input=input,weight=weight4,strides=(1,1), padding="SAME", activation="RELU") 23 | #weight5 = graph.new_weight(dims=(256,128,3,3)) 24 | #t5 = graph.conv2d(input=input,weight=weight5,strides=(1,1), padding="SAME", activation="RELU") 25 | #weight6 = graph.new_weight(dims=(256,256,3,3)) 26 | #t6 = graph.conv2d(input=input,weight=weight6,strides=(1,1), padding="SAME", activation="RELU") 27 | 28 | -------------------------------------------------------------------------------- /examples/eval_joint.py: -------------------------------------------------------------------------------- 1 | import xflow 2 | import onnx 3 | 4 | # 1. evaluate the performance by just considering substitution optimizations 5 | print("Measuring the performance of graph substitution optimizations (average of 1000 runs)") 6 | graph = xflow.load('bert_graphs/bert_subst_nw.onnx') 7 | print("XFlow: end-to-end inference time = {}ms".format(graph.run_time())) 8 | print() 9 | 10 | #2. evaluate the performance by just performing data layout optimizations 11 | print("Measuring the performance of data layout optimizations") 12 | graph = xflow.load('bert_graphs/bert_layout_nw.onnx') 13 | print("XFlow: end-to-end inference time = {}ms".format(graph.run_time())) 14 | print() 15 | 16 | #3. evaluate the performance by sequential optimizations 17 | print("Measuring the performance of sequential optimizations") 18 | graph = xflow.load('bert_graphs/bert_sequential_nw.onnx') 19 | print("XFlow: end-to-end inference time = {}ms".format(graph.run_time())) 20 | print() 21 | 22 | #4. 
evaluate the performance by joint optimizations 23 | print("Measuring the performance of joint optimizations") 24 | graph = xflow.load('bert_graphs/bert_xflow_nw.onnx') 25 | print("XFlow: end-to-end inference time = {}ms".format(graph.run_time())) 26 | print() 27 | 28 | -------------------------------------------------------------------------------- /examples/nasnet_a.py: -------------------------------------------------------------------------------- 1 | import xflow as xf 2 | 3 | def squeeze(graph, out_channels, input): 4 | weight = graph.new_weight(dims=(out_channels, input.dim(1), 1, 1)) 5 | return graph.conv2d(input=input, weight=weight, 6 | strides=(1, 1), padding="SAME", 7 | activation="RELU") 8 | 9 | def fit(graph, current, input): 10 | if input.dim(2) == current.dim(2): 11 | return squeeze(graph, current.dim(1), input) 12 | else: 13 | weight = graph.new_weight(dims=(current.dim(1), input.dim(1), 3, 3)) 14 | return graph.conv2d(input=input, weight=weight, strides=(2, 2), padding="SAME", activation="RELU") 15 | 16 | def seperable_conv(graph, input, out_channels, kernels, strides, padding, activation = "NONE"): 17 | assert input.dim(1) % out_channels == 0, "input.dim(1)={}, out_channels={}".format(input.dim(1), out_channels) 18 | weight1 = graph.new_weight(dims=(out_channels, input.dim(1) // out_channels, kernels[0], kernels[1])) 19 | t = graph.conv2d(input=input, weight=weight1, strides=strides, padding=padding) 20 | weight2 = graph.new_weight(dims=(out_channels, t.dim(1), 1, 1)) 21 | return graph.conv2d(input=t, weight=weight2, strides=(1, 1), padding="SAME", activation=activation) 22 | 23 | def normal_cell(graph, prev, cur, out_channels): 24 | cur = squeeze(graph, out_channels, cur) 25 | prev = fit(graph, cur, prev) 26 | ts = list() 27 | ts.append(seperable_conv(graph, input=cur, out_channels=out_channels, 28 | kernels=(3,3), strides=(1,1), padding="SAME")) 29 | ts.append(cur) 30 | ts.append(seperable_conv(graph, input=prev, out_channels=out_channels, 31 | kernels=(3,3), strides=(1,1), padding="SAME")) 32 | ts.append(seperable_conv(graph, input=cur, out_channels=out_channels, 33 | kernels=(3,3), strides=(1,1), padding="SAME")) 34 | ts.append(graph.avgpool2d(input=cur, kernels=(3,3), strides=(1,1), padding="SAME")) 35 | ts.append(prev) 36 | ts.append(graph.avgpool2d(input=prev, kernels=(3,3), strides=(1,1), padding="SAME")) 37 | ts.append(graph.avgpool2d(input=prev, kernels=(3,3), strides=(1,1), padding="SAME")) 38 | ts.append(seperable_conv(graph, input=prev, out_channels=out_channels, 39 | kernels=(3,3), strides=(1,1), padding="SAME")) 40 | ts.append(seperable_conv(graph, input=prev, out_channels=out_channels, 41 | kernels=(3,3), strides=(1,1), padding="SAME")) 42 | assert len(ts) == 10, "Expected 10 tensors, got {}".format(len(ts)) 43 | outputs = list() 44 | for i in range(5): 45 | outputs.append(graph.add(ts[2*i], ts[2*i+1])) 46 | return graph.concat(1, outputs) 47 | 48 | def reduction_cell(graph, prev, cur, out_channels): 49 | cur = squeeze(graph, out_channels, cur) 50 | prev = fit(graph, cur, prev) 51 | ts = list() 52 | outputs = list() 53 | ts.append(seperable_conv(graph, input=prev, out_channels=out_channels, 54 | kernels=(7,7), strides=(2,2), padding="SAME")) 55 | ts.append(seperable_conv(graph, input=cur, out_channels=out_channels, 56 | kernels=(5,5), strides=(2,2), padding="SAME")) 57 | outputs.append(graph.add(ts[0], ts[1])) 58 | ts.append(graph.maxpool2d(input=cur, kernels=(3,3), strides=(2,2), padding="SAME")) 59 | ts.append(seperable_conv(graph, input=prev, 
out_channels=out_channels, 60 | kernels=(7,7), strides=(2,2), padding="SAME")) 61 | outputs.append(graph.add(ts[2], ts[3])) 62 | ts.append(graph.avgpool2d(input=cur, kernels=(3,3), strides=(2,2), padding="SAME")) 63 | ts.append(seperable_conv(graph, input=prev, out_channels=out_channels, 64 | kernels=(5,5), strides=(2,2), padding="SAME")) 65 | outputs.append(graph.add(ts[4], ts[5])) 66 | ts.append(graph.maxpool2d(input=cur, kernels=(3,3), strides=(2,2), padding="SAME")) 67 | ts.append(seperable_conv(graph, input=outputs[0], out_channels=out_channels, 68 | kernels=(3,3), strides=(1,1), padding="SAME")) 69 | outputs.append(graph.add(ts[6], ts[7])) 70 | ts.append(graph.avgpool2d(input=outputs[0], kernels=(3,3), strides=(1,1), padding="SAME")) 71 | ts.append(outputs[1]) 72 | outputs.append(graph.add(ts[8], ts[9])) 73 | return graph.concat(1, outputs) 74 | 75 | graph = xf.new_graph() 76 | input = graph.new_input(dims = (1, 128, 56, 56)) 77 | input = graph.maxpool2d(input=input, kernels=(1,1), strides=(1,1), padding="SAME") 78 | out_channels = 128 79 | for i in range(3): 80 | prev = input 81 | cur = input 82 | for j in range(5): 83 | t = normal_cell(graph, prev, cur, out_channels) 84 | prev = cur 85 | cur = t 86 | out_channels *= 2 87 | input = reduction_cell(graph, prev, cur, out_channels) 88 | new_graph = xf.optimize(graph, alpha=1.0, budget=100) 89 | #onnx_model = xf.export_onnx(new_graph) 90 | #onnx.checker.check_model(onnx_model) 91 | #onnx.save(onnx_model, "nasneta_xflow.onnx") 92 | -------------------------------------------------------------------------------- /examples/nasrnn.py: -------------------------------------------------------------------------------- 1 | import xflow as xf 2 | 3 | hidden_size = 512 4 | length = 5 5 | 6 | def combine(graph, x, h): 7 | w1 = graph.new_weight(dims=(hidden_size, x.dim(1))) 8 | w2 = graph.new_weight(dims=(hidden_size, h.dim(1))) 9 | return graph.add(graph.matmul(x, w1), graph.matmul(h, w2)) 10 | 11 | def nas_node(graph, input, x): 12 | t = list() 13 | for i in range(8): 14 | t.append(combine(graph, x, input)) 15 | midt = list() 16 | midt.append(graph.add(graph.relu(t[0]), graph.sigmoid(t[3]))) 17 | midt.append(graph.add(graph.sigmoid(t[1]), graph.tanh(t[2]))) 18 | midt.append(graph.mul(graph.sigmoid(t[4]), graph.tanh(t[5]))) 19 | midt.append(graph.mul(graph.sigmoid(t[6]), graph.relu(t[7]))) 20 | midt.append(graph.add(graph.sigmoid(midt[1]), graph.tanh(midt[2]))) 21 | midt.append(graph.mul(graph.tanh(midt[0]), graph.tanh(midt[3]))) 22 | midt.append(graph.mul(graph.tanh(midt[4]), graph.tanh(midt[5]))) 23 | return graph.tanh(midt[6]) 24 | 25 | graph = xf.new_graph() 26 | xs = list() 27 | for i in range(length): 28 | xs.append(graph.new_input(dims=(1, hidden_size))) 29 | state = graph.new_weight(dims=(1, hidden_size)) 30 | for i in range(length): 31 | state = nas_node(graph, state, xs[i]) 32 | new_graph = xf.optimize(graph, alpha=1.0, budget=100) 33 | -------------------------------------------------------------------------------- /examples/resnet50.py: -------------------------------------------------------------------------------- 1 | import xflow as xf 2 | 3 | def resnet_block(graph, input, strides, out_channels): 4 | w1 = graph.new_weight(dims=(out_channels,input.dim(1),1,1)) 5 | t = graph.conv2d(input=input, weight=w1, 6 | strides=(1,1), padding="SAME", 7 | activation="RELU") 8 | w2 = graph.new_weight(dims=(out_channels,t.dim(1),3,3)) 9 | t = graph.conv2d(input=t, weight=w2, 10 | strides=strides, padding="SAME", 11 | activation="RELU") 12 | w3 = 
graph.new_weight(dims=(4*out_channels,t.dim(1),1,1)) 13 | t = graph.conv2d(input=t, weight=w3, 14 | strides=(1,1), padding="SAME") 15 | if (strides[0]>1) or (input.dim(1) != out_channels*4): 16 | w4 = graph.new_weight(dims=(out_channels*4,input.dim(1),1,1)) 17 | input=graph.conv2d(input=input, weight=w4, 18 | strides=strides, padding="SAME", 19 | activation="RELU") 20 | return graph.relu(graph.add(input, t)) 21 | 22 | graph = xf.new_graph() 23 | input = graph.new_input(dims=(1,64,56,56)) 24 | t = input 25 | for i in range(3): 26 | t = resnet_block(graph, t, (1,1), 64) 27 | strides = (2,2) 28 | for i in range(4): 29 | t = resnet_block(graph, t, strides, 128) 30 | strides = (1,1) 31 | strides = (2,2) 32 | for i in range(6): 33 | t = resnet_block(graph, t, strides, 256) 34 | strides = (1,1) 35 | strides = (2,2) 36 | for i in range(3): 37 | t = resnet_block(graph, t, strides, 512) 38 | strides = (1,1) 39 | 40 | new_graph = xf.optimize(graph, alpha=1.0, budget=1000) 41 | #onnx_model = xf.export_onnx(new_graph) 42 | 43 | -------------------------------------------------------------------------------- /examples/resnext50.py: -------------------------------------------------------------------------------- 1 | import xflow as xf 2 | 3 | def resnext_block(graph, input, strides, out_channels, groups): 4 | w1 = graph.new_weight(dims=(out_channels,input.dim(1),1,1)) 5 | t = graph.conv2d(input=input, weight=w1, 6 | strides=(1,1), padding="SAME", 7 | activation="RELU") 8 | w2 = graph.new_weight(dims=(out_channels,t.dim(1)//groups,3,3)) 9 | t = graph.conv2d(input=t, weight=w2, 10 | strides=strides, padding="SAME", 11 | activation="RELU") 12 | w3 = graph.new_weight(dims=(2*out_channels,t.dim(1),1,1)) 13 | t = graph.conv2d(input=t, weight=w3, 14 | strides=(1,1), padding="SAME") 15 | if (strides[0]>1) or (input.dim(1) != out_channels*2): 16 | w4 = graph.new_weight(dims=(out_channels*2,input.dim(1),1,1)) 17 | input=graph.conv2d(input=input, weight=w4, 18 | strides=strides, padding="SAME", 19 | activation="RELU") 20 | return graph.relu(graph.add(input, t)) 21 | 22 | graph = xf.new_graph() 23 | input = graph.new_input(dims=(1,64,56,56)) 24 | t = input 25 | for i in range(3): 26 | t = resnext_block(graph, t, (1,1), 128, 64) 27 | strides = (2,2) 28 | for i in range(4): 29 | t = resnext_block(graph, t, strides, 256, 64) 30 | strides = (1,1) 31 | strides = (2,2) 32 | for i in range(6): 33 | t = resnext_block(graph, t, strides, 512, 64) 34 | strides = (1,1) 35 | strides = (2,2) 36 | for i in range(3): 37 | t = resnext_block(graph, t, strides, 1024, 64) 38 | strides = (1,1) 39 | 40 | new_graph = xf.optimize(graph, alpha=1.0, budget=100) 41 | #onnx_model = xf.export_onnx(new_graph) 42 | #onnx.checker.check_model(onnx_model) 43 | #onnx.save(onnx_model, "resnext50_xflow.onnx") 44 | -------------------------------------------------------------------------------- /graph_subst.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiazhihao/sosp19ae/6946461166180d2da719cd1b19ebe37ca97b2b5f/graph_subst.pb -------------------------------------------------------------------------------- /include/xflow/cuda_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_HELPER_H_ 2 | #define _CUDA_HELPER_H_ 3 | 4 | #include 5 | #include 6 | #include "xflow/ops.h" 7 | #include 8 | 9 | #define FatalError(s) do { \ 10 | std::stringstream _where, _message; \ 11 | _where << __FILE__ << ':' << __LINE__; \ 12 | _message << 
std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ 13 | std::cerr << _message.str() << "\nAborting...\n"; \ 14 | exit(1); \ 15 | } while(0) 16 | 17 | #define checkCUDNN(status) do { \ 18 | std::stringstream _error; \ 19 | if (status != CUDNN_STATUS_SUCCESS) { \ 20 | _error << "CUDNN failure: " << cudnnGetErrorString(status); \ 21 | FatalError(_error.str()); \ 22 | } \ 23 | } while(0) 24 | 25 | #define checkCUDA(status) do { \ 26 | std::stringstream _error; \ 27 | if (status != 0) { \ 28 | _error << "Cuda failure: " << status; \ 29 | FatalError(_error.str()); \ 30 | } \ 31 | } while(0) 32 | 33 | // CUDA: grid stride looping 34 | #define CUDA_KERNEL_LOOP(i, n) \ 35 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) 36 | 37 | const int CUDA_NUM_THREADS = 1024; 38 | const int BLOCK_SIZE_LIMIT = 32768; 39 | 40 | // CUDA: number of blocks for threads. 41 | inline int GET_BLOCKS(const int N) 42 | { 43 | int ret = (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 44 | return (ret > BLOCK_SIZE_LIMIT) ? BLOCK_SIZE_LIMIT : ret; 45 | } 46 | 47 | void helperSetTensorDescriptor(const XFlow::Tensor& tensor, 48 | cudnnTensorDescriptor_t tensorDesc); 49 | 50 | __global__ 51 | void assign_kernel(float* ptr, int size, float value); 52 | 53 | cudnnActivationMode_t get_activation_mode(XFlow::ActiMode activation); 54 | #endif 55 | -------------------------------------------------------------------------------- /python/sample.py: -------------------------------------------------------------------------------- 1 | import xflow 2 | import onnx 3 | 4 | graph = xflow.new_graph() 5 | input = graph.new_input(dims=(1,256,28,28)) 6 | input = graph.maxpool2d(input=input, kernels=(1,1), strides=(1,1), padding="SAME") 7 | weight1 = graph.new_weight(dims=(256,8,3,3)) 8 | #weight2 = graph.new_weight(dims=(256,16,3,3)) 9 | #weight3 = graph.new_weight(dims=(256,32,3,3)) 10 | #weight4 = graph.new_weight(dims=(256,64,3,3)) 11 | #weight5 = graph.new_weight(dims=(256,128,3,3)) 12 | #weight6 = graph.new_weight(dims=(256,256,3,3)) 13 | t1 = graph.conv2d(input=input,weight=weight1,strides=(1,1), padding="SAME", activation="RELU") 14 | #t2 = graph.conv2d(input=input,weight=weight2,strides=(1,1), padding="SAME", activation="RELU") 15 | #t3 = graph.conv2d(input=input,weight=weight3,strides=(1,1), padding="SAME", activation="RELU") 16 | #t4 = graph.conv2d(input=input,weight=weight4,strides=(1,1), padding="SAME", activation="RELU") 17 | #t5 = graph.conv2d(input=input,weight=weight5,strides=(1,1), padding="SAME", activation="RELU") 18 | #t6 = graph.conv2d(input=input,weight=weight6,strides=(1,1), padding="SAME", activation="RELU") 19 | 20 | new_graph = xflow.optimize(graph, alpha=1.0, budget=100) 21 | onnx_model = xflow.export_onnx(new_graph) 22 | onnx.checker.check_model(onnx_model) 23 | onnx.save(onnx_model, "/home/ubuntu/ONNXModel/inception_v2/model_xflow.onnx") 24 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Stanford 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import os 16 | import sys 17 | import sysconfig 18 | from setuptools import find_packages 19 | 20 | # need to use distutils.core for correct placement of cython dll 21 | if "--inplace" in sys.argv: 22 | from distutils.core import setup 23 | from distutils.extension import Extension 24 | else: 25 | from setuptools import setup 26 | from setuptools.extension import Extension 27 | 28 | def config_cython(): 29 | sys_cflags = sysconfig.get_config_var("CFLAGS") 30 | try: 31 | from Cython.Build import cythonize 32 | ret = [] 33 | path = "xflow/_cython" 34 | for fn in os.listdir(path): 35 | if not fn.endswith(".pyx"): 36 | continue 37 | ret.append(Extension( 38 | "xflow.%s" % fn[:-4], 39 | ["%s/%s" % (path, fn)], 40 | include_dirs=["../include", "/usr/local/cuda/include"], 41 | libraries=["xf_runtime"], 42 | extra_compile_args=["-DUSE_CUDNN", "-std=c++11"], 43 | extra_link_args=[], 44 | language="c++")) 45 | return cythonize(ret, compiler_directives={"language_level" : 3}) 46 | except ImportError: 47 | print("WARNING: cython is not installed!!!") 48 | return [] 49 | 50 | setup_args = {} 51 | 52 | #if not os.getenv('CONDA_BUILD'): 53 | # curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) 54 | # for i, path in enumerate(LIB_LIST): 55 | # LIB_LIST[i] = os.path.relpath(path, curr_path) 56 | # setup_args = { 57 | # "include_package_data": True, 58 | # "data_files": [('xflow', LIB_LIST)] 59 | # } 60 | 61 | setup(name='xflow', 62 | #version=__version__, 63 | description="XFlow: A DNN Computation Graph Superoptimizer", 64 | zip_safe=False, 65 | install_requires=[], 66 | packages=find_packages(), 67 | url='https://github.com/jiazhihao/xflow', 68 | ext_modules=config_cython(), 69 | #**setup_args, 70 | ) 71 | 72 | -------------------------------------------------------------------------------- /python/test.py: -------------------------------------------------------------------------------- 1 | import xflow 2 | import onnx 3 | 4 | graph = xflow.load("/home/ubuntu/squeezenet1.1.onnx") 5 | #graph = xflow.load("/home/ubuntu/resnext-101.onnx") 6 | #graph = xflow.load("/home/ubuntu/ONNXModel/inception_v2/model.onnx") 7 | new_graph = xflow.optimize(graph, alpha = 1.0, budget = 100) 8 | onnx_model = xflow.export_onnx(new_graph) 9 | onnx.checker.check_model(onnx_model) 10 | onnx.save(onnx_model, "/home/ubuntu/ONNXModel/inception_v2/model_xflow.onnx") 11 | -------------------------------------------------------------------------------- /python/xflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | import onnx 3 | from onnx import helper, TensorProto, numpy_helper 4 | 5 | def _check_output(xf_output, onnx_output): 6 | # TODO: check output match 7 | return True 8 | 9 | def _parse_attribute(attributes): 10 | atts = dict() 11 | for att in attributes: 12 | if att.type == onnx.AttributeProto.INT: 13 | atts[att.name] = att.i 14 | elif att.type == onnx.AttributeProto.INTS: 15 | atts[att.name] = att.ints 16 | elif att.type == onnx.AttributeProto.FLOAT: 17 | atts[att.name] = att.f 
18 | elif att.type == onnx.AttributeProto.STRING: 19 | atts[att.name] = att.s 20 | else: 21 | assert False, "Unsupported Attribute Type: {}".format(att.type) 22 | return atts 23 | 24 | def _get_inputs(op, tensors): 25 | inputs = list() 26 | for i in op.input: 27 | assert i in tensors, "Input tensor not found" 28 | inputs.append(tensors[i]) 29 | return inputs 30 | 31 | def _add(op, graph, tensors, initializer): 32 | inputs = _get_inputs(op, tensors) 33 | outputs = graph.add(inputs[0], inputs[1]); 34 | return outputs; 35 | 36 | def _batchnorm(op, graph, tensors, initializer): 37 | inputs = _get_inputs(op, tensors) 38 | attrs = _parse_attribute(op.attribute) 39 | outputs = graph.batchnorm(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]) 40 | return outputs 41 | 42 | def _concat(op, graph, tensors, initializer): 43 | inputs = _get_inputs(op, tensors) 44 | attrs = _parse_attribute(op.attribute) 45 | axis = attrs["axis"] 46 | outputs = graph.concat(axis, inputs) 47 | return outputs 48 | 49 | def _conv2d(op, graph, tensors, initializer): 50 | inputs = _get_inputs(op, tensors) 51 | attrs = _parse_attribute(op.attribute) 52 | if "group" not in attrs: 53 | group = 1 # default 1 54 | else: 55 | group = attrs["group"] 56 | # Note that we always think conv1x1 has SAME padding 57 | # This will allow fusing enlarged convs 58 | if sum(attrs["pads"]) == 0 and sum(attrs['kernel_shape']) > 2: 59 | pads = "VALID" 60 | else: 61 | pads = "SAME" 62 | strides = attrs["strides"] 63 | outputs = graph.conv2d(input=inputs[0], weight=inputs[1], strides=strides, padding=pads) 64 | return outputs 65 | 66 | def _dropout(op, graph, tensors, initializer): 67 | inputs = _get_inputs(op, tensors) 68 | assert len(inputs) == 1, "Dropout requires exactly one input" 69 | attrs = _parse_attribute(op.attribute) 70 | rate = attrs["ratio"] 71 | outputs = graph.dropout(inputs[0], rate) 72 | return outputs 73 | 74 | def _matmul(op, graph, tensors, initializer): 75 | inputs = _get_inputs(op, tensors) 76 | assert len(inputs) == 2, "Matmul requires exactly two inputs" 77 | outputs = graph.matmul(inputs[0], inputs[1]) 78 | return outputs 79 | 80 | def _pad(op, graph, tensors, initializer): 81 | inputs = _get_inputs(op, tensors) 82 | attrs = _parse_attribute(op.attribute) 83 | # Currently treat pad as a no op 84 | assert sum(attrs['pads']) == 0 85 | return inputs 86 | 87 | def _maxpool2d(op, graph, tensors, initializer): 88 | assert len(op.input) == 1, "MaxPool2D requires exactly one input" 89 | assert op.input[0] in tensors 90 | attrs = _parse_attribute(op.attribute) 91 | kernels = attrs["kernel_shape"] 92 | strides = attrs["strides"] 93 | if sum(attrs["pads"]) == 0: 94 | pads = "VALID" 95 | else: 96 | pads = "SAME" 97 | outputs = graph.maxpool2d(input=tensors[op.input[0]], kernels=kernels, strides=strides, padding=pads) 98 | return outputs 99 | 100 | def _avgpool2d(op, graph, tensors, initializer): 101 | assert len(op.input) == 1, "MaxPool2D requires exactly one input" 102 | assert op.input[0] in tensors 103 | attrs = _parse_attribute(op.attribute) 104 | kernels = attrs["kernel_shape"] 105 | strides = attrs["strides"] 106 | if sum(attrs["pads"]) == 0: 107 | pads = "VALID" 108 | else: 109 | pads = "SAME" 110 | outputs = graph.avgpool2d(input=tensors[op.input[0]], kernels=kernels, strides=strides, padding=pads) 111 | return outputs 112 | 113 | def _reshape(op, graph, tensors, initializer): 114 | inputs = _get_inputs(op, tensors) 115 | assert len(inputs) == 2 116 | for data in initializer: 117 | if data.name == op.input[1]: 118 | shape 
= list() 119 | for dim in data.int64_data: 120 | shape.append(dim) 121 | outputs = graph.reshape(inputs[0], tuple(shape)) 122 | return outputs 123 | 124 | def _relu(op, graph, tensors, initializer): 125 | assert len(op.input) == 1, "Relu requires exactly one input" 126 | assert op.input[0] in tensors 127 | attrs = _parse_attribute(op.attribute) 128 | outputs = graph.relu(tensors[op.input[0]]) 129 | return outputs 130 | 131 | def _split(op, graph, tensors, initializer): 132 | assert len(op.input) == 1, "Split requires exactly one input" 133 | assert op.input[0] in tensors 134 | attrs = _parse_attribute(op.attribute) 135 | axis = attrs["axis"] 136 | split_ints = attrs["split"] 137 | split_list = list() 138 | for i in split_ints: 139 | split_list.append(i) 140 | outputs = graph.split(tensors[op.input[0]], axis, split_list) 141 | return outputs 142 | 143 | def _transpose(op, graph, tensors, initializer): 144 | assert len(op.input) == 1, "Transpose requires exactly one input" 145 | assert op.input[0] in tensors 146 | attrs = _parse_attribute(op.attribute) 147 | perm_ints = attrs["perm"] 148 | perm = list() 149 | for i in perm_ints: 150 | perm.append(i) 151 | outputs = graph.transpose(tensors[op.input[0]], tuple(perm), shuffle=True) 152 | return outputs 153 | 154 | # Add all supported operators 155 | xf_operators = dict() 156 | xf_operators['Add'] = _add 157 | xf_operators['BatchNormalization'] = _batchnorm 158 | xf_operators['Concat'] = _concat 159 | xf_operators['Conv'] = _conv2d 160 | xf_operators['Dropout'] = _dropout 161 | xf_operators['Pad'] = _pad 162 | xf_operators['Reshape'] = _reshape 163 | xf_operators['Relu'] = _relu 164 | xf_operators['Matmul'] = _matmul 165 | xf_operators['MaxPool'] = _maxpool2d 166 | xf_operators['AveragePool'] = _avgpool2d 167 | xf_operators['Split'] = _split 168 | xf_operators['Transpose'] = _transpose 169 | 170 | def new_graph(print_measurements = False): 171 | graph = core.PyGraph() 172 | if print_measurements: 173 | graph.print_measurements() 174 | return graph 175 | 176 | def load(filename): 177 | ''' 178 | Load a onnx file and return a Graph 179 | 180 | @params 181 | filename is a string containing a file name 182 | 183 | @return 184 | Loaded in-memory Graph 185 | ''' 186 | graph = core.PyGraph() 187 | model = onnx.load(filename) 188 | tensors = dict() 189 | for t in model.graph.input: 190 | dims = list() 191 | for d in t.type.tensor_type.shape.dim: 192 | dims.append(d.dim_value) 193 | if "data" in t.name: 194 | tensors[t.name] = graph.new_input(dims=tuple(dims)) 195 | else: 196 | weight_data = None 197 | for weight in model.graph.initializer: 198 | if (weight.name == t.name): 199 | weight_data = numpy_helper.to_array(weight) 200 | #assert(weight_data is not None) 201 | tensors[t.name] = graph.new_weight(dims=tuple(dims), data = weight_data) 202 | 203 | for op in model.graph.node: 204 | if op.op_type in xf_operators: 205 | outputs = xf_operators[op.op_type](op, graph, tensors, model.graph.initializer) 206 | if not isinstance(outputs, list): 207 | outputs = [outputs] 208 | assert len(outputs) == len(op.output), "Number of output tensors mismatch" 209 | for i in range(len(outputs)): 210 | assert _check_output(outputs[i], op.output[i]) 211 | tensors[op.output[i]] = outputs[i] 212 | else: 213 | assert False, "Unsupported ONNX operator: {}".format(op.op_type) 214 | return graph 215 | 216 | input_weight_names = dict() 217 | input_weight_names['Conv'] = ['input', 'weight', 'bias'] 218 | input_weight_names['Matmul'] = ['input', 'weight'] 219 | 
input_weight_names['Reshpe'] = ['input', 'shape'] 220 | 221 | operator_attrs = dict() 222 | operator_attrs['Add'] = [] 223 | operator_attrs['AveragePool'] = ['kernel_shape', 'pads', 'strides'] 224 | operator_attrs['Concat'] = ['axis'] 225 | operator_attrs['Conv'] = ['group', 'kernel_shape', 'pads', 'strides'] 226 | operator_attrs['Dropout'] = [] 227 | operator_attrs['Matmul'] = [] 228 | operator_attrs['MaxPool'] = ['kernel_shape', 'pads', 'strides'] 229 | operator_attrs['Split'] = ['axis', 'split'] 230 | operator_attrs['Relu'] = [] 231 | operator_attrs['Reshape'] = [] 232 | operator_attrs['Transpose'] = ['perm'] 233 | 234 | def _input_tensor_name(graph, inedge, op): 235 | intype = graph.get_operator_type(inedge['srcOp']) 236 | if intype == "Input": 237 | return "data" 238 | elif intype == "Weight": 239 | mytype = graph.get_operator_type(op) 240 | return "{}{}_{}".format(mytype, op['guid'], input_weight_names[mytype][inedge['dstIdx']]) 241 | else: 242 | return _output_tensor_name(graph, inedge['srcOp'], inedge['srcIdx']) 243 | 244 | def _output_tensor_name(graph, op, idx): 245 | type = graph.get_operator_type(op) 246 | return "{}{}_fwd{}".format(type, op['guid'], idx) 247 | 248 | def _add_node_attribute(graph, node, op, optype): 249 | for key in operator_attrs[optype]: 250 | val = graph.get_operator_attr(op, key) 251 | attr = helper.make_attribute(key, val) 252 | node.attribute.append(attr) 253 | 254 | def export_onnx(graph): 255 | ''' 256 | Export a XFlow graph to an ONNX graph 257 | 258 | @params 259 | graph is a XFlow graph 260 | 261 | @return 262 | A in-memory ONNX graph 263 | ''' 264 | opList = graph.get_operator_list() 265 | graph_nodes = list() 266 | graph_inputs = list() 267 | graph_initializers = list() 268 | graph_outputs = list() 269 | output_guids = dict() 270 | for op in opList: 271 | mytype = graph.get_operator_type(op) 272 | inedges = graph.get_input_edges(op) 273 | print("op.guid={} mytype={} inedges={}".format(op['guid'], mytype, len(inedges))) 274 | inputs = list() 275 | for e in inedges: 276 | intype = graph.get_operator_type(e['srcOp']) 277 | inputs.append(_input_tensor_name(graph, e, op)) 278 | output_guids.pop((e['srcOp']['guid'], e['srcIdx']), None) 279 | if intype == 'Input' or intype == 'Weight': 280 | graph_inputs.append(helper.make_tensor_value_info(_input_tensor_name(graph, e, op), 281 | TensorProto.FLOAT, graph.get_input_dims(op, e['dstIdx']))) 282 | if intype == 'Weight': 283 | graph_initializers.append(helper.make_tensor(_input_tensor_name(graph, e, op), 284 | TensorProto.FLOAT, graph.get_input_dims(op, e['dstIdx']), 285 | graph.get_weight_value(e['srcOp']))) 286 | 287 | # add a second input for Reshape 288 | if mytype == 'Reshape': 289 | inputs.append('Reshape_attr{}'.format(op['guid'])) 290 | shape = graph.get_output_dims(op, 0) 291 | graph_inputs.append(helper.make_tensor_value_info('Reshape_attr{}'.format(op['guid']), TensorProto.INT64, [len(shape)])) 292 | graph_initializers.append(helper.make_tensor('Reshape_attr{}'.format(op['guid']), TensorProto.INT64, [len(shape)], shape)) 293 | outputs = list() 294 | for i in range(graph.get_num_outputs(op)): 295 | outputs.append(_output_tensor_name(graph, op, i)) 296 | output_guids[(op['guid'], i)] = op 297 | node = helper.make_node(mytype, inputs, outputs, '{}{}'.format(mytype, op['guid'])) 298 | _add_node_attribute(graph, node, op, mytype) 299 | graph_nodes.append(node) 300 | for guid, idx in output_guids: 301 | op = output_guids[(guid, idx)] 302 | 
graph_outputs.append(helper.make_tensor_value_info(_output_tensor_name(graph, op, idx), 303 | TensorProto.FLOAT, graph.get_output_dims(op, idx))) 304 | onnx_graph = helper.make_graph(graph_nodes, 'main', graph_inputs, graph_outputs, graph_initializers) 305 | onnx_model = helper.make_model(onnx_graph, producer_name='XFlow Optimized Model') 306 | return onnx_model 307 | 308 | def optimize(graph, alpha = 1.0, budget = 1000): 309 | return graph.optimize(alpha, budget) 310 | -------------------------------------------------------------------------------- /python/xflow/_cython/CCore.pxd: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Stanford 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | #ccore.pxd 17 | 18 | from libcpp.memory cimport shared_ptr 19 | from libcpp.vector cimport vector 20 | from libcpp cimport bool 21 | 22 | cdef extern from "xflow/ops.h" namespace "XFlow": 23 | # This must be consistent with include/xflow/ops.h 24 | cdef enum OpType: 25 | OP_INPUT 26 | OP_WEIGHT 27 | OP_ANY 28 | OP_CONV2D 29 | OP_DROPOUT 30 | OP_LINEAR 31 | OP_POOL2D_MAX 32 | OP_POOL2D_AVG 33 | OP_RELU 34 | OP_SIGMOID 35 | OP_TANH 36 | OP_BATCHNORM 37 | OP_CONCAT 38 | OP_SPLIT 39 | OP_RESHAPE 40 | OP_TRANSPOSE 41 | # RNN operators 42 | OP_EW_ADD 43 | OP_EW_MUL 44 | OP_MATMUL 45 | OP_SCALARMUL 46 | OP_ENLARGE 47 | OP_MERGE_GCONV 48 | OP_CONSTANT_IMM, 49 | OP_CONSTANT_ICONV, 50 | OP_CONSTANT_ONE, 51 | OP_CONSTANT_POOL, 52 | 53 | # This must be consistent with include/xflow/ops.h 54 | cdef enum PMParameter: 55 | PM_OP_TYPE 56 | PM_NUM_INPUTS 57 | PM_NUM_OUTPUTS 58 | PM_GROUP 59 | PM_KERNEL_H 60 | PM_KERNEL_W 61 | PM_STRIDE_H 62 | PM_STRIDE_W 63 | PM_PAD 64 | PM_ACTI 65 | PM_NUMDIM 66 | PM_AXIS 67 | PM_PERM 68 | PM_OUTSHUFFLE 69 | PM_MERGE_GCONV_COUNT 70 | 71 | # This must be consistent with include/xflow/ops.h 72 | cdef enum ActiMode: 73 | AC_MODE_NONE 74 | AC_MODE_SIGMOID 75 | AC_MODE_RELU 76 | AC_MODE_TANH 77 | 78 | # This must be consistent with include/xflow/ops.h 79 | cdef enum PaddingMode: 80 | PD_MODE_SAME 81 | PD_MODE_VALID 82 | 83 | # This must be consistent with include/xflow/ops.h 84 | cdef enum ConstantMode: 85 | CN_MODE_IDENTITY 86 | CN_MODE_ZEROS 87 | CN_MODE_ONES 88 | CN_MODE_ONES_SCALED_L1 89 | CN_MODE_ONES_SCALED_L2 90 | CN_MODE_ONES_SCALED_ALL 91 | 92 | cdef cppclass Model: 93 | Model() 94 | 95 | # ctypedef struct SplitInfo: 96 | # int num 97 | # int pos[MAX_NUM_SPLITS] 98 | # 99 | # ctypedef cppclass OpBase: 100 | # pass 101 | 102 | ctypedef struct Op: 103 | size_t guid 104 | pass 105 | 106 | ctypedef struct Edge: 107 | Op srcOp 108 | Op dstOp 109 | int srcIdx 110 | int dstIdx 111 | 112 | ctypedef struct Tensor: 113 | int numDim 114 | int dim[4] 115 | int stride[4] # NOTE: this must be consistent with the C++ header 116 | pass 117 | # int idx 118 | # Op op 119 | # void* ptr 120 | # SplitInfo split[MAX_DIM] 121 | 122 | ctypedef Tensor* TensorHandle 123 | 124 | cdef cppclass Graph: 125 | 
Graph() 126 | TensorHandle batchnorm(const TensorHandle input, 127 | const TensorHandle scale, 128 | const TensorHandle bias, 129 | const TensorHandle mean, 130 | const TensorHandle var) 131 | TensorHandle concat(int axis, int n, 132 | const TensorHandle* inputs) 133 | TensorHandle conv2d(const TensorHandle input, 134 | const TensorHandle weight, 135 | int strideH, int strideW, 136 | PaddingMode _padding, 137 | ActiMode _activation) 138 | TensorHandle dropout(const TensorHandle input) 139 | TensorHandle element(OpType type, 140 | const TensorHandle x, 141 | const TensorHandle y) 142 | TensorHandle pool2d_max(const TensorHandle input, 143 | int kernelH, int kernelW, 144 | int strideH, int strideW, 145 | PaddingMode padding, 146 | ActiMode activation) 147 | TensorHandle pool2d_avg(const TensorHandle input, 148 | int kernelH, int kernelW, 149 | int strideH, int strideW, 150 | PaddingMode padding, 151 | ActiMode activation) 152 | TensorHandle matmul(const TensorHandle input, 153 | const TensorHandle weight, 154 | ActiMode activation) 155 | TensorHandle reshape(const TensorHandle input, 156 | const vector[int] shape) 157 | TensorHandle relu(const TensorHandle input, 158 | bool _inplace) 159 | TensorHandle sigmoid(const TensorHandle input, 160 | bool _inplace) 161 | TensorHandle tanh(const TensorHandle input, 162 | bool _inplace) 163 | TensorHandle transpose(const TensorHandle input, 164 | const vector[int] perm, 165 | bool shuffle) 166 | TensorHandle new_input(int ndim, const int* dims) 167 | TensorHandle new_weight(int ndim, const int* dims, const float* data) 168 | Graph* optimize(float alpha, int budget) 169 | int get_operator_list(Op* ops, size_t maxNumOps) 170 | int get_input_edges(Edge* edges, size_t guid) 171 | OpType get_operator_type(size_t guid) 172 | int get_operator_int_attr(size_t guid, PMParameter attr) 173 | int get_num_outputs(size_t guid) 174 | int get_input_dims(size_t guid, int* dims, int idx) 175 | void get_weight_value(size_t guid, float* data) 176 | int get_split_lens(size_t guid, int* lens) 177 | int get_output_dims(size_t guid, int* dims, int idx) 178 | void print_measurements() 179 | float run() 180 | -------------------------------------------------------------------------------- /rules_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: rules.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='rules.proto', 20 | package='GraphSubst', 21 | syntax='proto2', 22 | serialized_pb=_b('\n\x0brules.proto\x12\nGraphSubst\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x02(\x05\x12\r\n\x05value\x18\x02 \x02(\x05\"$\n\x06Tensor\x12\x0c\n\x04opId\x18\x01 \x02(\x05\x12\x0c\n\x04tsId\x18\x02 \x02(\x05\"`\n\x08Operator\x12\x0c\n\x04type\x18\x01 \x02(\x05\x12!\n\x05input\x18\x02 \x03(\x0b\x32\x12.GraphSubst.Tensor\x12#\n\x04para\x18\x03 \x03(\x0b\x32\x15.GraphSubst.Parameter\"O\n\tMapOutput\x12\x0f\n\x07srcOpId\x18\x01 \x02(\x05\x12\x0f\n\x07\x64stOpId\x18\x02 \x02(\x05\x12\x0f\n\x07srcTsId\x18\x03 \x02(\x05\x12\x0f\n\x07\x64stTsId\x18\x04 \x02(\x05\"}\n\x04Rule\x12#\n\x05srcOp\x18\x01 \x03(\x0b\x32\x14.GraphSubst.Operator\x12#\n\x05\x64stOp\x18\x02 \x03(\x0b\x32\x14.GraphSubst.Operator\x12+\n\x0cmappedOutput\x18\x03 \x03(\x0b\x32\x15.GraphSubst.MapOutput\"0\n\x0eRuleCollection\x12\x1e\n\x04rule\x18\x01 \x03(\x0b\x32\x10.GraphSubst.Rule') 23 | ) 24 | 25 | 26 | 27 | 28 | _PARAMETER = _descriptor.Descriptor( 29 | name='Parameter', 30 | full_name='GraphSubst.Parameter', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='GraphSubst.Parameter.key', index=0, 37 | number=1, type=5, cpp_type=1, label=2, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='value', full_name='GraphSubst.Parameter.value', index=1, 44 | number=2, type=5, cpp_type=1, label=2, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto2', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=27, 62 | serialized_end=66, 63 | ) 64 | 65 | 66 | _TENSOR = _descriptor.Descriptor( 67 | name='Tensor', 68 | full_name='GraphSubst.Tensor', 69 | filename=None, 70 | file=DESCRIPTOR, 71 | containing_type=None, 72 | fields=[ 73 | _descriptor.FieldDescriptor( 74 | name='opId', full_name='GraphSubst.Tensor.opId', index=0, 75 | number=1, type=5, cpp_type=1, label=2, 76 | has_default_value=False, default_value=0, 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | options=None), 80 | _descriptor.FieldDescriptor( 81 | name='tsId', full_name='GraphSubst.Tensor.tsId', index=1, 82 | number=2, type=5, cpp_type=1, label=2, 83 | has_default_value=False, default_value=0, 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | options=None), 87 | ], 88 | extensions=[ 89 | ], 90 | nested_types=[], 91 | enum_types=[ 92 | ], 93 | options=None, 94 | 
is_extendable=False, 95 | syntax='proto2', 96 | extension_ranges=[], 97 | oneofs=[ 98 | ], 99 | serialized_start=68, 100 | serialized_end=104, 101 | ) 102 | 103 | 104 | _OPERATOR = _descriptor.Descriptor( 105 | name='Operator', 106 | full_name='GraphSubst.Operator', 107 | filename=None, 108 | file=DESCRIPTOR, 109 | containing_type=None, 110 | fields=[ 111 | _descriptor.FieldDescriptor( 112 | name='type', full_name='GraphSubst.Operator.type', index=0, 113 | number=1, type=5, cpp_type=1, label=2, 114 | has_default_value=False, default_value=0, 115 | message_type=None, enum_type=None, containing_type=None, 116 | is_extension=False, extension_scope=None, 117 | options=None), 118 | _descriptor.FieldDescriptor( 119 | name='input', full_name='GraphSubst.Operator.input', index=1, 120 | number=2, type=11, cpp_type=10, label=3, 121 | has_default_value=False, default_value=[], 122 | message_type=None, enum_type=None, containing_type=None, 123 | is_extension=False, extension_scope=None, 124 | options=None), 125 | _descriptor.FieldDescriptor( 126 | name='para', full_name='GraphSubst.Operator.para', index=2, 127 | number=3, type=11, cpp_type=10, label=3, 128 | has_default_value=False, default_value=[], 129 | message_type=None, enum_type=None, containing_type=None, 130 | is_extension=False, extension_scope=None, 131 | options=None), 132 | ], 133 | extensions=[ 134 | ], 135 | nested_types=[], 136 | enum_types=[ 137 | ], 138 | options=None, 139 | is_extendable=False, 140 | syntax='proto2', 141 | extension_ranges=[], 142 | oneofs=[ 143 | ], 144 | serialized_start=106, 145 | serialized_end=202, 146 | ) 147 | 148 | 149 | _MAPOUTPUT = _descriptor.Descriptor( 150 | name='MapOutput', 151 | full_name='GraphSubst.MapOutput', 152 | filename=None, 153 | file=DESCRIPTOR, 154 | containing_type=None, 155 | fields=[ 156 | _descriptor.FieldDescriptor( 157 | name='srcOpId', full_name='GraphSubst.MapOutput.srcOpId', index=0, 158 | number=1, type=5, cpp_type=1, label=2, 159 | has_default_value=False, default_value=0, 160 | message_type=None, enum_type=None, containing_type=None, 161 | is_extension=False, extension_scope=None, 162 | options=None), 163 | _descriptor.FieldDescriptor( 164 | name='dstOpId', full_name='GraphSubst.MapOutput.dstOpId', index=1, 165 | number=2, type=5, cpp_type=1, label=2, 166 | has_default_value=False, default_value=0, 167 | message_type=None, enum_type=None, containing_type=None, 168 | is_extension=False, extension_scope=None, 169 | options=None), 170 | _descriptor.FieldDescriptor( 171 | name='srcTsId', full_name='GraphSubst.MapOutput.srcTsId', index=2, 172 | number=3, type=5, cpp_type=1, label=2, 173 | has_default_value=False, default_value=0, 174 | message_type=None, enum_type=None, containing_type=None, 175 | is_extension=False, extension_scope=None, 176 | options=None), 177 | _descriptor.FieldDescriptor( 178 | name='dstTsId', full_name='GraphSubst.MapOutput.dstTsId', index=3, 179 | number=4, type=5, cpp_type=1, label=2, 180 | has_default_value=False, default_value=0, 181 | message_type=None, enum_type=None, containing_type=None, 182 | is_extension=False, extension_scope=None, 183 | options=None), 184 | ], 185 | extensions=[ 186 | ], 187 | nested_types=[], 188 | enum_types=[ 189 | ], 190 | options=None, 191 | is_extendable=False, 192 | syntax='proto2', 193 | extension_ranges=[], 194 | oneofs=[ 195 | ], 196 | serialized_start=204, 197 | serialized_end=283, 198 | ) 199 | 200 | 201 | _RULE = _descriptor.Descriptor( 202 | name='Rule', 203 | full_name='GraphSubst.Rule', 204 | filename=None, 205 | 
file=DESCRIPTOR, 206 | containing_type=None, 207 | fields=[ 208 | _descriptor.FieldDescriptor( 209 | name='srcOp', full_name='GraphSubst.Rule.srcOp', index=0, 210 | number=1, type=11, cpp_type=10, label=3, 211 | has_default_value=False, default_value=[], 212 | message_type=None, enum_type=None, containing_type=None, 213 | is_extension=False, extension_scope=None, 214 | options=None), 215 | _descriptor.FieldDescriptor( 216 | name='dstOp', full_name='GraphSubst.Rule.dstOp', index=1, 217 | number=2, type=11, cpp_type=10, label=3, 218 | has_default_value=False, default_value=[], 219 | message_type=None, enum_type=None, containing_type=None, 220 | is_extension=False, extension_scope=None, 221 | options=None), 222 | _descriptor.FieldDescriptor( 223 | name='mappedOutput', full_name='GraphSubst.Rule.mappedOutput', index=2, 224 | number=3, type=11, cpp_type=10, label=3, 225 | has_default_value=False, default_value=[], 226 | message_type=None, enum_type=None, containing_type=None, 227 | is_extension=False, extension_scope=None, 228 | options=None), 229 | ], 230 | extensions=[ 231 | ], 232 | nested_types=[], 233 | enum_types=[ 234 | ], 235 | options=None, 236 | is_extendable=False, 237 | syntax='proto2', 238 | extension_ranges=[], 239 | oneofs=[ 240 | ], 241 | serialized_start=285, 242 | serialized_end=410, 243 | ) 244 | 245 | 246 | _RULECOLLECTION = _descriptor.Descriptor( 247 | name='RuleCollection', 248 | full_name='GraphSubst.RuleCollection', 249 | filename=None, 250 | file=DESCRIPTOR, 251 | containing_type=None, 252 | fields=[ 253 | _descriptor.FieldDescriptor( 254 | name='rule', full_name='GraphSubst.RuleCollection.rule', index=0, 255 | number=1, type=11, cpp_type=10, label=3, 256 | has_default_value=False, default_value=[], 257 | message_type=None, enum_type=None, containing_type=None, 258 | is_extension=False, extension_scope=None, 259 | options=None), 260 | ], 261 | extensions=[ 262 | ], 263 | nested_types=[], 264 | enum_types=[ 265 | ], 266 | options=None, 267 | is_extendable=False, 268 | syntax='proto2', 269 | extension_ranges=[], 270 | oneofs=[ 271 | ], 272 | serialized_start=412, 273 | serialized_end=460, 274 | ) 275 | 276 | _OPERATOR.fields_by_name['input'].message_type = _TENSOR 277 | _OPERATOR.fields_by_name['para'].message_type = _PARAMETER 278 | _RULE.fields_by_name['srcOp'].message_type = _OPERATOR 279 | _RULE.fields_by_name['dstOp'].message_type = _OPERATOR 280 | _RULE.fields_by_name['mappedOutput'].message_type = _MAPOUTPUT 281 | _RULECOLLECTION.fields_by_name['rule'].message_type = _RULE 282 | DESCRIPTOR.message_types_by_name['Parameter'] = _PARAMETER 283 | DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR 284 | DESCRIPTOR.message_types_by_name['Operator'] = _OPERATOR 285 | DESCRIPTOR.message_types_by_name['MapOutput'] = _MAPOUTPUT 286 | DESCRIPTOR.message_types_by_name['Rule'] = _RULE 287 | DESCRIPTOR.message_types_by_name['RuleCollection'] = _RULECOLLECTION 288 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 289 | 290 | Parameter = _reflection.GeneratedProtocolMessageType('Parameter', (_message.Message,), dict( 291 | DESCRIPTOR = _PARAMETER, 292 | __module__ = 'rules_pb2' 293 | # @@protoc_insertion_point(class_scope:GraphSubst.Parameter) 294 | )) 295 | _sym_db.RegisterMessage(Parameter) 296 | 297 | Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), dict( 298 | DESCRIPTOR = _TENSOR, 299 | __module__ = 'rules_pb2' 300 | # @@protoc_insertion_point(class_scope:GraphSubst.Tensor) 301 | )) 302 | _sym_db.RegisterMessage(Tensor) 303 | 304 | Operator = 
_reflection.GeneratedProtocolMessageType('Operator', (_message.Message,), dict( 305 | DESCRIPTOR = _OPERATOR, 306 | __module__ = 'rules_pb2' 307 | # @@protoc_insertion_point(class_scope:GraphSubst.Operator) 308 | )) 309 | _sym_db.RegisterMessage(Operator) 310 | 311 | MapOutput = _reflection.GeneratedProtocolMessageType('MapOutput', (_message.Message,), dict( 312 | DESCRIPTOR = _MAPOUTPUT, 313 | __module__ = 'rules_pb2' 314 | # @@protoc_insertion_point(class_scope:GraphSubst.MapOutput) 315 | )) 316 | _sym_db.RegisterMessage(MapOutput) 317 | 318 | Rule = _reflection.GeneratedProtocolMessageType('Rule', (_message.Message,), dict( 319 | DESCRIPTOR = _RULE, 320 | __module__ = 'rules_pb2' 321 | # @@protoc_insertion_point(class_scope:GraphSubst.Rule) 322 | )) 323 | _sym_db.RegisterMessage(Rule) 324 | 325 | RuleCollection = _reflection.GeneratedProtocolMessageType('RuleCollection', (_message.Message,), dict( 326 | DESCRIPTOR = _RULECOLLECTION, 327 | __module__ = 'rules_pb2' 328 | # @@protoc_insertion_point(class_scope:GraphSubst.RuleCollection) 329 | )) 330 | _sym_db.RegisterMessage(RuleCollection) 331 | 332 | 333 | # @@protoc_insertion_point(module_scope) 334 | -------------------------------------------------------------------------------- /src/core/activation.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::relu(const TensorHandle _input, bool _inPlace) 20 | { 21 | Op op = model->get_or_create_activation(*_input, OP_RELU, _inPlace); 22 | assert(op != Op::INVALID_OP); 23 | add_edge(_input->op, op, _input->idx, 0); 24 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 25 | t->op = op; 26 | return t; 27 | } 28 | 29 | TensorHandle Graph::sigmoid(const TensorHandle _input, bool _inPlace) 30 | { 31 | Op op = model->get_or_create_activation(*_input, OP_SIGMOID, _inPlace); 32 | assert(op != Op::INVALID_OP); 33 | add_edge(_input->op, op, _input->idx, 0); 34 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 35 | t->op = op; 36 | return t; 37 | } 38 | 39 | TensorHandle Graph::tanh(const TensorHandle _input, bool _inPlace) 40 | { 41 | Op op = model->get_or_create_activation(*_input, OP_TANH, _inPlace); 42 | assert(op != Op::INVALID_OP); 43 | add_edge(_input->op, op, _input->idx, 0); 44 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 45 | t->op = op; 46 | return t; 47 | } 48 | 49 | Op Model::get_or_create_activation(Tensor _input, OpType _type, bool _inPlace) 50 | { 51 | // keys are (inputN, inputC, inputH, inputW, _type, _inPlace) 52 | ActivationKey key(_input, _type, _inPlace); 53 | Activation* actOp; 54 | if (activation.find(key) != activation.end()) { 55 | actOp = activation[key]; 56 | } else { 57 | actOp = new Activation(this, _input, _type, _inPlace); 58 | measure_activation_cost(actOp); 59 | activation[key] = actOp; 60 | } 61 | Op ret; 62 | ret.guid = global_unique_id ++; 63 | ret.ptr = actOp; 64 | return ret; 65 | } 66 | 67 | Activation::Activation(Model* _model, Tensor _input, OpType _type, bool _inPlace) 68 | : OpBase(_input, _model, _type), inPlace(_inPlace) 69 | { 70 | numOutputs = 1; 71 | outputs[0] = _input; 72 | outputs[0].idx = 0; 73 | } 74 | 75 | Activation::~Activation(void) 76 | { 77 | } 78 | 79 | bool Activation::get_parameter(PMParameter para, int* value) 80 | { 81 | return OpBase::get_parameter(para, value); 82 | } 83 | 84 | void Activation::collect_costs(float& exe_time, float& flops, 85 | float& mem_acc, int& num_kernels) 86 | { 87 | int outputSize = 1, inputSize = 1; 88 | for (int i = 0; i < outputs[0].numDim; i++) 89 | outputSize *= outputs[0].dim[i]; 90 | for (int i = 0; i < inputs[0].numDim; i++) 91 | inputSize *= inputs[0].dim[i]; 92 | // cost metrics 93 | exe_time += runtime; 94 | if (type == OP_RELU) 95 | flops += 0; // relu does not involve flops 96 | else 97 | flops += outputSize; 98 | mem_acc += inputSize; 99 | num_kernels += 1; 100 | printf(" cost[Activation]: mode(%d) cost(%.4lf) total_cost(%.4lf)\n", 101 | type, runtime, exe_time); 102 | } 103 | 104 | // Key ordering: type, inPlace, _input 105 | ActivationKey::ActivationKey(Tensor _input, OpType _type, bool _inPlace) 106 | { 107 | int idx = 0; 108 | keys[idx++] = _type; 109 | keys[idx++] = (int)(_inPlace); 110 | _input.serialize(keys, idx); 111 | while (idx < KEY_LENGTH) 112 | keys[idx++] = 0; 113 | assert(idx == KEY_LENGTH); 114 | } 115 | 116 | -------------------------------------------------------------------------------- /src/core/batchnorm.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::batchnorm(const TensorHandle _input, 20 | const TensorHandle _scale, 21 | const TensorHandle _bias, 22 | const TensorHandle _mean, 23 | const TensorHandle _var) 24 | { 25 | Op op = model->get_or_create_batchnorm(*_input, *_scale, *_bias, 26 | *_mean, *_var); 27 | add_edge(_input->op, op, _input->idx, 0); 28 | add_edge(_scale->op, op, _scale->idx, 1); 29 | add_edge(_bias->op, op, _bias->idx, 2); 30 | add_edge(_mean->op, op, _mean->idx, 3); 31 | add_edge(_var->op, op, _var->idx, 4); 32 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 33 | t->op = op; 34 | return t; 35 | } 36 | 37 | Op Model::get_or_create_batchnorm(Tensor _input, 38 | Tensor _scale, 39 | Tensor _bias, 40 | Tensor _mean, 41 | Tensor _var) 42 | { 43 | // key is (inputN, inputC, inputH, inputW) 44 | BatchNormKey key(_input); 45 | BatchNorm* bnOp; 46 | if(batchnorm.find(key) != batchnorm.end()) { 47 | bnOp = batchnorm[key]; 48 | } else { 49 | bnOp = new BatchNorm(this, _input, _scale, _bias, _mean, _var); 50 | measure_batchnorm_cost(bnOp); 51 | batchnorm[key] = bnOp; 52 | } 53 | Op ret; 54 | ret.guid = global_unique_id ++; 55 | ret.ptr = bnOp; 56 | return ret; 57 | } 58 | 59 | BatchNorm::BatchNorm(Model* _model, Tensor _input, Tensor _scale, 60 | Tensor _bias, Tensor _mean, Tensor _var) 61 | : OpBase(_input, _scale, _bias, _mean, _var, _model, OP_BATCHNORM) 62 | { 63 | assert(_input.numDim == 4); 64 | numOutputs = 1; 65 | outputs[0] = _input; 66 | outputs[0].idx = 0; 67 | } 68 | 69 | BatchNorm::~BatchNorm(void) 70 | {} 71 | 72 | bool BatchNorm::get_parameter(PMParameter para, int* value) 73 | { 74 | return OpBase::get_parameter(para, value); 75 | } 76 | 77 | void BatchNorm::collect_costs(float& exe_time, float& flops, 78 | float& mem_acc, int& num_kernels) 79 | { 80 | int outputSize = 1, inputSize = 1; 81 | for (int i = 0; i < outputs[0].numDim; i++) 82 | outputSize *= outputs[0].dim[i]; 83 | for (int i = 0; i < inputs[0].numDim; i++) 84 | inputSize *= inputs[0].dim[i]; 85 | // cost metrics 86 | exe_time += runtime; 87 | flops += outputSize * 2; 88 | mem_acc += inputSize; 89 | num_kernels += 1; 90 | } 91 | 92 | // key is (_input) 93 | BatchNormKey::BatchNormKey(Tensor _input) 94 | { 95 | int idx = 0; 96 | _input.serialize(keys, idx); 97 | while (idx < KEY_LENGTH) 98 | keys[idx++] = 0; 99 | assert(KEY_LENGTH == idx); 100 | } 101 | -------------------------------------------------------------------------------- /src/core/concat.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::concat(int axis, int n, const TensorHandle* _inputs) 20 | { 21 | Tensor inputTensors[MAX_NUM_INPUTS]; 22 | for (int i = 0; i < n; i++) { 23 | inputTensors[i] = *_inputs[i]; 24 | } 25 | bool needCopy[MAX_NUM_INPUTS]; 26 | for (int i = 0; i < n; i++) 27 | needCopy[i] = true; 28 | Op op = model->get_or_create_concat(axis, n, inputTensors, needCopy); 29 | // Assert op must be valid 30 | assert (op != Op::INVALID_OP); 31 | for (int i = 0; i < n; i++) 32 | add_edge(_inputs[i]->op, op, _inputs[i]->idx, i); 33 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 34 | t->op = op; 35 | return t; 36 | } 37 | 38 | Op Model::get_or_create_concat(int axis, int n, Tensor* _inputs, bool* _needCopy) 39 | { 40 | // key ordering is: 41 | // axis, n, bitmask(needCopy) 42 | // inputs[0].dim[0], ..., inputs[0].dim[axis-1], 43 | // inputs[0].dim[axis+1], ..., inputs[0].dim[nDims - 1] 44 | // inputs[0].dim[axis], ..., inputs[n-1].dim[axis] 45 | // Check validness 46 | for (int i = 0; i < n; i++) { 47 | if (_inputs[i].numDim != _inputs[0].numDim) { 48 | return Op::INVALID_OP; 49 | } 50 | for (int j = 0; j < _inputs[0].numDim; j++) 51 | if ((j != axis) && (_inputs[i].dim[j] != _inputs[0].dim[j])) { 52 | return Op::INVALID_OP; 53 | } 54 | } 55 | ConcatKey key(axis, n, _inputs, _needCopy); 56 | Concat* concatOp; 57 | if (concat.find(key) != concat.end()) { 58 | concatOp = concat[key]; 59 | } else { 60 | concatOp = new Concat(this, axis, n, _inputs, _needCopy); 61 | measure_concat_cost(concatOp); 62 | concat[key] = concatOp; 63 | } 64 | Op ret; 65 | ret.guid = global_unique_id ++; 66 | ret.ptr = concatOp; 67 | return ret; 68 | } 69 | 70 | Concat::Concat(Model* _model, int _axis, int n, Tensor* _inputs, bool* _needCopy) 71 | : OpBase(n, _inputs, _model, OP_CONCAT), axis(_axis) 72 | { 73 | //for (int i = 0; i < n; i++) { 74 | // printf(" concat2[%d]:", i); 75 | // for (int j = 0; j < _inputs[i].numDim; j++) 76 | // printf("%d, ", _inputs[i].dim[j]); 77 | // printf("\n"); 78 | //} 79 | assert(n <= MAX_NUM_INPUTS); 80 | for (int i = 0; i < n; i++) 81 | needCopy[i] = _needCopy[i]; 82 | numOutputs = 1; 83 | outputs[0].numDim = inputs[0].numDim; 84 | for (int i = 0; i < outputs[0].numDim; i++) 85 | outputs[0].dim[i] = inputs[0].dim[i]; 86 | for (int i = 0; i < outputs[0].numDim; i++) 87 | if (i != axis) { 88 | outputs[0].split[i] = inputs[0].split[i]; 89 | for (int j = 1; j < n; j++) 90 | outputs[0].split[i].combine(inputs[j].split[i]); 91 | } 92 | outputs[0].split[axis] = inputs[0].split[axis]; 93 | for (int i = 1; i < n; i++) { 94 | outputs[0].split[axis].merge(outputs[0].dim[axis], inputs[i].split[axis]); 95 | outputs[0].dim[axis] += inputs[i].dim[axis]; 96 | } 97 | for (int i = outputs[0].numDim-1; i >= 0; i--) { 98 | if (i == outputs[0].numDim-1) 99 | outputs[0].stride[i] = 1; 100 | else 101 | outputs[0].stride[i] = outputs[0].stride[i+1] * outputs[0].dim[i+1]; 102 | } 103 | outputs[0].idx = 0; 104 | } 105 | 106 | Concat::~Concat(void) 107 | {} 108 | 109 | bool 
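// Exposes the concatenation axis through PM_AXIS; every other parameter query
// falls through to OpBase::get_parameter, matching the pattern used by the
// other operators in this directory.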
Concat::get_parameter(PMParameter para, int* value) 110 | { 111 | switch (para) { 112 | case PM_AXIS: 113 | *value = axis; 114 | return true; 115 | default: 116 | return OpBase::get_parameter(para, value); 117 | } 118 | } 119 | 120 | void Concat::collect_costs(float& exe_time, float& flops, 121 | float& mem_acc, int& num_kernels) 122 | { 123 | for (int i = 0; i < numInputs; i++) 124 | if (needCopy[i]) { 125 | int inputSize = 1; 126 | for (int j = 0; j < inputs[i].numDim; j++) 127 | inputSize *= inputs[i].dim[j]; 128 | mem_acc += inputSize; 129 | } 130 | // cost metrics 131 | exe_time += runtime; 132 | flops += 0; 133 | num_kernels += 1; 134 | printf(" cost[Concat]: numInputs(%d) cost(%.4lf) total_cost(%.4lf)\n", 135 | numInputs, runtime, exe_time); 136 | } 137 | 138 | int bitmask(int n, bool* bits) 139 | { 140 | int ret = 0; 141 | for (int i = 0; i < n; i++) 142 | ret = bits[i] ? ret * 2 + 1 : ret * 2; 143 | return ret; 144 | } 145 | 146 | // key ordering is: axis, n, bitmask(needCopy), inputs[0], ..., inputs[n-1] 147 | // 148 | // 149 | // axis, n, bitmask(needCopy), inputs[0], inputs[n-1] 150 | // inputs[0].dim[0], ..., inputs[0].dim[axis-1], 151 | // inputs[0].dim[axis+1], ..., inputs[0].dim[nDims - 1] 152 | // inputs[0].dim[axis], ..., inputs[n-1].dim[axis] 153 | ConcatKey::ConcatKey(int axis, int n, Tensor* _inputs, bool* _needCopy) 154 | { 155 | int idx = 0; 156 | keys[idx++] = axis; 157 | keys[idx++] = n; 158 | keys[idx++] = bitmask(n, _needCopy); 159 | for (int i = 0; i < n; i++) 160 | _inputs[i].serialize(keys, idx); 161 | while (idx < KEY_LENGTH) 162 | keys[idx++] = 0; 163 | assert(idx == KEY_LENGTH); 164 | #ifdef DEADCODE 165 | assert(_inputs[0].numDim + n + 2 <= KEY_LENGTH); 166 | int idx = 0; 167 | keys[idx++] = axis; 168 | keys[idx++] = n; 169 | keys[idx++] = bitmask(n, _needCopy); 170 | for (int i = 0; i < _inputs[0].numDim; i++) 171 | if (i != axis) 172 | keys[idx++] = _inputs[0].dim[i]; 173 | for (int i = 0; i < n; i++) 174 | keys[idx++] = _inputs[i].dim[axis]; 175 | while (idx < KEY_LENGTH) 176 | keys[idx++] = 0; 177 | assert(idx == KEY_LENGTH); 178 | #endif 179 | } 180 | 181 | -------------------------------------------------------------------------------- /src/core/constant.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::constant(int ndim, int* dims, OpType type) 20 | { 21 | Op op = model->get_or_create_constant(ndim, dims, type); 22 | // NOTE that constant do not have any inputs 23 | // we need to manually add op to the inedges 24 | assert(inEdges.find(op) == inEdges.end()); 25 | inEdges[op]; 26 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 27 | t->op = op; 28 | return t; 29 | } 30 | 31 | Op Model::get_or_create_constant(int ndim, int* dims, OpType _type) 32 | { 33 | ConstantKey key(ndim, dims, _type); 34 | Constant* constantOp; 35 | if (constant.find(key) != constant.end()) { 36 | constantOp = constant[key]; 37 | } else { 38 | constantOp = new Constant(this, ndim, dims, _type); 39 | constantOp->runtime = 0.0f; 40 | constant[key] = constantOp; 41 | } 42 | Op ret; 43 | ret.guid = global_unique_id ++; 44 | ret.ptr = constantOp; 45 | return ret; 46 | } 47 | 48 | Constant::Constant(Model* _model, int ndim, int* dims, OpType _type) 49 | : OpBase(_model, _type) 50 | { 51 | numOutputs = 1; 52 | outputs[0].numDim = ndim; 53 | for (int i = 0; i < ndim; i++) 54 | outputs[0].dim[i] = dims[i]; 55 | outputs[0].stride[ndim-1] = 1; 56 | for (int i = ndim-2; i >= 0; i--) 57 | outputs[0].stride[i] = outputs[0].stride[i+1] * outputs[0].dim[i+1]; 58 | // Set SplitInfo 59 | for (int i = 0; i < ndim; i++) 60 | outputs[0].split[i] = SplitInfo::NO_SPLIT; 61 | outputs[0].idx = 0; 62 | } 63 | 64 | Constant::~Constant(void) 65 | {} 66 | 67 | bool Constant::get_parameter(PMParameter para, int* value) 68 | { 69 | return OpBase::get_parameter(para, value); 70 | } 71 | 72 | void Constant::collect_costs(float& exe_time, float& flops, 73 | float& mem_acc, int& num_kernels) 74 | { 75 | // TODO; implement 76 | assert(false); 77 | } 78 | 79 | ConstantKey::ConstantKey(int ndim, int* dims, OpType type) 80 | { 81 | int idx = 0; 82 | keys[idx++] = ndim; 83 | for (int i = 0; i < ndim; i++) 84 | keys[idx++] = dims[i]; 85 | keys[idx++] = type; 86 | while (idx < KEY_LENGTH) 87 | keys[idx++] = 0; 88 | assert(KEY_LENGTH == idx); 89 | } 90 | 91 | -------------------------------------------------------------------------------- /src/core/conv2d.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::group_conv2d(int groups, 20 | const TensorHandle _input, 21 | int _outputC, 22 | int _kernelH, int _kernelW, 23 | int _strideH, int _strideW, 24 | PaddingMode _padding, 25 | ActiMode _activation) 26 | { 27 | assert(_input->dim[1] % groups == 0); 28 | assert(_outputC % groups == 0); 29 | int dims[4] = {_outputC, _input->dim[1] / groups, _kernelH, _kernelW}; 30 | int total = dims[0] * dims[1] * dims[2] * dims[3]; 31 | // Randomly initialize weights 32 | DATATYPE* data = (DATATYPE*) malloc(total * sizeof(DATATYPE)); 33 | for (int i = 0; i < total; i++) 34 | data[i] = (DATATYPE)std::rand() / RAND_MAX; 35 | TensorHandle weight = new_weight(4, dims, data); 36 | free(data); 37 | /* 38 | weight.numDim = 4; 39 | weight.dim[0] = _outputC; 40 | weight.dim[1] = _input.dim[1] / groups; 41 | weight.dim[2] = _kernelH; 42 | weight.dim[3] = _kernelW; 43 | weight.stride[3] = 1; 44 | weight.stride[2] = weight.stride[3] * weight.dim[3]; 45 | weight.stride[1] = weight.stride[2] * weight.dim[2]; 46 | weight.stride[0] = weight.stride[1] * weight.dim[1]; 47 | weight.op.guid = GUID_WEIGHT; 48 | weight.op.ptr = NULL; 49 | weight.idx = 0; 50 | weight = noop(weight); 51 | */ 52 | return conv2d(_input, weight, _strideH, _strideW, _padding, _activation); 53 | } 54 | 55 | 56 | TensorHandle Graph::conv2d(const TensorHandle _input, 57 | int _outputC, 58 | int _kernelH, int _kernelW, 59 | int _strideH, int _strideW, 60 | PaddingMode _padding, 61 | ActiMode _activation) 62 | { 63 | const int dims[4] = {_outputC, _input->dim[1], _kernelH, _kernelW}; 64 | int total = dims[0] * dims[1] * dims[2] * dims[3]; 65 | // Randomly initialize weights 66 | DATATYPE* data = (DATATYPE*) malloc(total * sizeof(DATATYPE)); 67 | for (int i = 0; i < total; i++) 68 | data[i] = (DATATYPE)std::rand() / RAND_MAX; 69 | TensorHandle weight = new_weight(4, dims, data); 70 | free(data); 71 | /* 72 | weight.numDim = 4; 73 | weight.dim[0] = _outputC; 74 | weight.dim[1] = _input.dim[1]; 75 | weight.dim[2] = _kernelH; 76 | weight.dim[3] = _kernelW; 77 | weight.stride[3] = 1; 78 | weight.stride[2] = weight.stride[3] * weight.dim[3]; 79 | weight.stride[1] = weight.stride[2] * weight.dim[2]; 80 | weight.stride[0] = weight.stride[1] * weight.dim[1]; 81 | weight.op.guid = GUID_WEIGHT; 82 | weight.op.ptr = NULL; 83 | weight.idx = 0; 84 | weight = noop(weight); 85 | */ 86 | return conv2d(_input, weight, _strideH, _strideW, 87 | _padding, _activation); 88 | } 89 | 90 | /* 91 | Tensor Graph::conv2d(Tensor _input, Tensor _weight, 92 | int _strideH, int _strideW, 93 | PaddingMode _padding, 94 | ActiMode _activation) 95 | { 96 | Op op = model->get_or_create_conv2d(_input, _weight, _strideH, _strideW, 97 | _padding, _activation); 98 | add_edge(_input.op, op, _input.idx, 0); 99 | add_edge(_weight.op, op, _weight.idx, 1); 100 | Tensor t = op.ptr->outputs[0]; 101 | t.op = op; 102 | return t; 103 | } 104 | */ 105 | 106 | TensorHandle Graph::conv2d(const TensorHandle _input, 107 | const TensorHandle _weight, 108 | int _strideH, int _strideW, 109 | PaddingMode _padding, 110 | ActiMode _activation) 111 | { 112 | Op op = model->get_or_create_conv2d(*_input, *_weight, _strideH, _strideW, 113 | _padding, _activation); 114 | assert(op != Op::INVALID_OP); 115 | add_edge(_input->op, op, _input->idx, 0); 116 | add_edge(_weight->op, op, _weight->idx, 1); 117 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 118 | t->op = op; 119 | return t; 120 | } 121 | 122 | Op 
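// get_or_create_conv2d follows the memoization pattern used by every operator
// in this directory: configurations whose input channel count is not divisible
// by the weight's channel count are rejected with Op::INVALID_OP, and valid
// configurations are cached by Conv2DKey so that measure_conv2d_cost runs only
// once per distinct (shape, stride, padding, activation) combination. Each call
// still returns an Op with a fresh guid, so identical convolutions in a graph
// share one measured Conv2D instance while remaining distinct graph nodes.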
Model::get_or_create_conv2d(Tensor _input, Tensor _weight, 123 | int _strideH, int _strideW, 124 | PaddingMode _padding, 125 | ActiMode _activation) 126 | { 127 | if (_input.dim[1] % _weight.dim[1] != 0) 128 | return Op::INVALID_OP; 129 | // key is (inputN, inputC, inputH, inputW, outputC, kernelH, kernelW, 130 | // strideH, strideW, padding, activation) 131 | Conv2DKey key(_input, _weight, _strideH, _strideW, _padding, _activation); 132 | Conv2D* convOp; 133 | if (conv2d.find(key) != conv2d.end()) { 134 | convOp = conv2d[key]; 135 | } else { 136 | convOp = new Conv2D(this, _input, _weight, _strideH, _strideW, 137 | _padding, _activation); 138 | measure_conv2d_cost(convOp); 139 | conv2d[key] = convOp; 140 | } 141 | Op ret; 142 | ret.guid = global_unique_id ++; 143 | ret.ptr = convOp; 144 | return ret; 145 | } 146 | 147 | Conv2D::Conv2D(Model* _model, Tensor _input, Tensor _weight, 148 | int _strideH, int _strideW, 149 | PaddingMode _padding, 150 | ActiMode _activation) 151 | : OpBase(_input, _weight, _model, OP_CONV2D), 152 | strideH(_strideH), strideW(_strideW), 153 | padding(_padding), activation(_activation) 154 | { 155 | assert(_input.numDim == 4); 156 | assert(_weight.numDim == 4); 157 | //assert(_input.dim[1] == _weight.dim[1]); 158 | assert(_input.dim[1] % _weight.dim[1] == 0); 159 | int groups = _input.dim[1] / _weight.dim[1]; 160 | assert(_weight.dim[0] % groups == 0); 161 | //printf("k(%d %d) pad(%d %d) stride(%d %d)\n", 162 | // kernelH, kernelW, padH, padW, strideH, strideW); 163 | int inputH = _input.dim[2]; 164 | int inputW = _input.dim[3]; 165 | int kernelH = _weight.dim[2]; 166 | int kernelW = _weight.dim[3]; 167 | int outputH, outputW; 168 | switch (padding) 169 | { 170 | case PD_MODE_SAME: 171 | outputH = (inputH + strideH - 1) / strideH; 172 | outputW = (inputW + strideW - 1) / strideW; 173 | break; 174 | case PD_MODE_VALID: 175 | outputH = (inputH - kernelH) / strideH + 1; 176 | outputW = (inputW - kernelW) / strideW + 1; 177 | break; 178 | default: 179 | assert(false); 180 | } 181 | //int outputH = 1 + (inputH + 2 * padH - kernelH) / strideH; 182 | //int outputW = 1 + (inputW + 2 * padW - kernelW) / strideW; 183 | // Set dims and strides 184 | numOutputs = 1; 185 | outputs[0].numDim = 4; 186 | outputs[0].dim[0] = _input.dim[0]; 187 | outputs[0].dim[1] = _weight.dim[0]; 188 | outputs[0].dim[2] = outputH; 189 | outputs[0].dim[3] = outputW; 190 | outputs[0].stride[3] = 1; 191 | outputs[0].stride[2] = outputs[0].stride[3] * outputs[0].dim[3]; 192 | outputs[0].stride[1] = outputs[0].stride[2] * outputs[0].dim[2]; 193 | outputs[0].stride[0] = outputs[0].stride[1] * outputs[0].dim[1]; 194 | // Set SplitInfo 195 | outputs[0].split[0] = _input.split[0]; 196 | outputs[0].split[1] = _weight.split[0]; 197 | outputs[0].split[2] = _input.split[2]; 198 | outputs[0].split[3] = _input.split[3]; 199 | // Assume we cannot split the H and W dimension, 200 | // otherwise we need to extend Conv2DKey to include their SplitInfo 201 | assert(outputs[0].split[2] == SplitInfo::NO_SPLIT); 202 | assert(outputs[0].split[3] == SplitInfo::NO_SPLIT); 203 | outputs[0].idx = 0; 204 | } 205 | 206 | Conv2D::~Conv2D(void) 207 | {} 208 | 209 | bool Conv2D::get_parameter(PMParameter para, int* value) 210 | { 211 | switch (para) { 212 | case PM_GROUP: 213 | { 214 | int inputC = inputs[0].dim[1]; 215 | int weightC = inputs[1].dim[1]; 216 | assert(inputC % weightC == 0); 217 | *value = inputC / weightC; 218 | return true; 219 | } 220 | case PM_KERNEL_H: 221 | *value = inputs[1].dim[2]; 222 | return true; 223 | 
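// Conv2D stores no explicit kernel size; both kernel extents are read from the
// weight tensor, which is laid out as (Cout, Cin/groups, kernelH, kernelW).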
case PM_KERNEL_W: 224 | *value = inputs[1].dim[3]; 225 | return true; 226 | case PM_STRIDE_H: 227 | *value = strideH; 228 | return true; 229 | case PM_STRIDE_W: 230 | *value = strideW; 231 | return true; 232 | case PM_PAD: 233 | *value = padding; 234 | return true; 235 | case PM_ACTI: 236 | *value = (int) activation; 237 | return true; 238 | default: 239 | return OpBase::get_parameter(para, value); 240 | } 241 | } 242 | 243 | void Conv2D::get_padding(int* padH, int* padW) { 244 | int inputH = inputs[0].dim[2]; 245 | int inputW = inputs[0].dim[3]; 246 | int kernelH = inputs[1].dim[2]; 247 | int kernelW = inputs[1].dim[3]; 248 | // Reference: https://www.tensorflow.org/api_guides/python/nn#Convolution 249 | switch (padding) { 250 | case PD_MODE_SAME: 251 | int totalPadH, totalPadW; 252 | if (inputH % strideH == 0) 253 | totalPadH = max(kernelH - strideH, 0); 254 | else 255 | totalPadH = max(kernelH - (inputH % strideH), 0); 256 | if (inputW % strideW == 0) 257 | totalPadW = max(kernelW - strideW, 0); 258 | else 259 | totalPadW = max(kernelW - (inputW % strideW), 0); 260 | // assert same padding on both sides 261 | *padH = (totalPadH + 1) / 2; 262 | *padW = (totalPadW + 1) / 2; 263 | break; 264 | case PD_MODE_VALID: 265 | *padH = 0; 266 | *padW = 0; 267 | break; 268 | default: 269 | assert(false); 270 | } 271 | } 272 | 273 | void Conv2D::collect_costs(float& exe_time, float& flops, 274 | float& mem_acc, int& num_kernels) 275 | { 276 | size_t outputSize = outputs[0].volume() * sizeof(DATATYPE); 277 | size_t inputSize = inputs[0].volume() * sizeof(DATATYPE); 278 | size_t weightSize = inputs[1].volume() * sizeof(DATATYPE); 279 | // cost metrics 280 | exe_time += runtime; 281 | int kernelH = inputs[1].dim[2]; 282 | int kernelW = inputs[1].dim[3]; 283 | int inputC = inputs[1].dim[1]; 284 | flops += outputSize * (kernelH * kernelW * inputC + 1); 285 | if (activation != AC_MODE_NONE) 286 | flops += outputSize; 287 | mem_acc += inputSize + outputSize + weightSize; 288 | num_kernels += 1; 289 | printf(" cost[Conv2D]: i(%d %d %d %d) w(%d %d %d %d) s(%d %d) p(%d) cost(%.4lf) total_cost(%.4lf)\n", 290 | inputs[0].dim[0], inputs[0].dim[1], inputs[0].dim[2], inputs[0].dim[3], 291 | inputs[1].dim[0], inputs[1].dim[1], inputs[1].dim[2], inputs[1].dim[3], 292 | strideH, strideW, padding, runtime, exe_time); 293 | } 294 | 295 | // keys are (inputN, inputC, inputH, inputW, outputC, kernelH, kernelW, 296 | // strideH, strideW, padding, acitvation, 297 | // input.split[0], weight.split[0]) 298 | Conv2DKey::Conv2DKey(Tensor _input, Tensor _weight, 299 | int _strideH, int _strideW, 300 | PaddingMode _padding, 301 | ActiMode _activation) 302 | { 303 | assert(_input.dim[1] % _weight.dim[1] == 0); 304 | int groups = _input.dim[1] / _weight.dim[1]; 305 | assert(_weight.dim[0] % groups == 0); 306 | int idx = 0; 307 | keys[idx++] = _strideH; 308 | keys[idx++] = _strideW; 309 | keys[idx++] = _padding; 310 | keys[idx++] = _activation; 311 | _input.serialize(keys, idx); 312 | _weight.serialize(keys, idx); 313 | while (idx < KEY_LENGTH) 314 | keys[idx++] = 0; 315 | assert(KEY_LENGTH == idx); 316 | } 317 | 318 | -------------------------------------------------------------------------------- /src/core/element.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::element(OpType type, 20 | const TensorHandle t1, 21 | const TensorHandle t2) 22 | { 23 | assert(t1->numDim == t2->numDim); 24 | for (int i = 0; i < t1->numDim; i++) 25 | assert(t1->dim[i] == t2->dim[i]); 26 | Op op = model->get_or_create_element(type, *t1, *t2); 27 | add_edge(t1->op, op, t1->idx, 0); 28 | add_edge(t2->op, op, t2->idx, 1); 29 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 30 | t->op = op; 31 | return t; 32 | } 33 | 34 | Op Model::get_or_create_element(OpType type, 35 | Tensor t1, Tensor t2) 36 | { 37 | // key is (inputN, inputC, inputH, inputW, type) 38 | ElementKey key(t1, type); 39 | Element* eleOp; 40 | if (element.find(key) != element.end()) { 41 | eleOp = element[key]; 42 | } else { 43 | eleOp = new Element(this, type, t1, t2); 44 | measure_element_cost(eleOp); 45 | element[key] = eleOp; 46 | } 47 | Op ret; 48 | ret.guid = global_unique_id ++; 49 | ret.ptr = eleOp; 50 | return ret; 51 | } 52 | 53 | Element::Element(Model* _model, OpType _type, 54 | Tensor _t1, Tensor _t2) 55 | : OpBase(_t1, _t2, _model, _type) 56 | { 57 | numOutputs = 1; 58 | outputs[0] = _t1; 59 | for (int i = 0; i < outputs[0].numDim; i++) 60 | outputs[0].split[i].combine(_t2.split[i]); 61 | outputs[0].idx = 0; 62 | } 63 | 64 | Element::~Element(void) 65 | {} 66 | 67 | bool Element::get_parameter(PMParameter para, int* value) 68 | { 69 | return OpBase::get_parameter(para, value); 70 | } 71 | 72 | void Element::collect_costs(float& exe_time, float& flops, 73 | float& mem_acc, int& num_kernels) 74 | { 75 | int outputSize = 1, inputSize = 1; 76 | for (int i = 0; i < outputs[0].numDim; i++) 77 | outputSize *= outputs[0].dim[i]; 78 | for (int i = 0; i < inputs[0].numDim; i++) 79 | inputSize *= inputs[0].dim[i]; 80 | // cost metrics 81 | exe_time += runtime; 82 | flops += outputSize; 83 | mem_acc += inputSize * 2; 84 | num_kernels += 1; 85 | printf(" cost[Element]: cost(%.4lf) total_cost(%.4lf)\n", runtime, exe_time); 86 | } 87 | 88 | // Key ordering: type, input 89 | ElementKey::ElementKey(Tensor input, OpType type) 90 | { 91 | int idx = 0; 92 | keys[idx++] = type; 93 | input.serialize(keys, idx); 94 | while (idx < KEY_LENGTH) 95 | keys[idx++] = 0; 96 | assert(idx == KEY_LENGTH); 97 | } 98 | -------------------------------------------------------------------------------- /src/core/enlarge.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | // Enlarge the third and forth dimension of _w1 to the same size as _w2 20 | TensorHandle Graph::enlarge(const TensorHandle _w1, 21 | const TensorHandle _w2) 22 | { 23 | // Currently the weight being enlarged must be 4D: 24 | // Cout, Cin, KerelH, KernelW 25 | assert(_w1->numDim == 4); 26 | assert(_w2->numDim == 4); 27 | assert(_w1->dim[2] <= _w2->dim[2]); 28 | assert(_w1->dim[3] <= _w2->dim[3]); 29 | Op op = model->get_or_create_enlarge(*_w1, *_w2); 30 | assert(op != Op::INVALID_OP); 31 | add_edge(_w1->op, op, _w1->idx, 0); 32 | add_edge(_w2->op, op, _w2->idx, 1); 33 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 34 | t->op = op; 35 | return t; 36 | } 37 | 38 | Op Model::get_or_create_enlarge(Tensor _w1, Tensor _w2) 39 | { 40 | // Check 1: w1 and w2 must both have 4D 41 | if (_w1.numDim != 4 || _w2.numDim != 4) 42 | return Op::INVALID_OP; 43 | // Check 2: w1 is smaller than w2 44 | if (_w1.dim[2] > _w2.dim[2] || _w1.dim[3] > _w2.dim[3]) 45 | return Op::INVALID_OP; 46 | EnlargeKey key(_w1, _w2); 47 | Enlarge* enlargeOp; 48 | if (enlarge.find(key) != enlarge.end()) { 49 | enlargeOp = enlarge[key]; 50 | } else { 51 | enlargeOp = new Enlarge(this, _w1, _w2); 52 | measure_enlarge_cost(enlargeOp); 53 | enlarge[key] = enlargeOp; 54 | } 55 | Op ret; 56 | ret.guid = global_unique_id ++; 57 | ret.ptr = enlargeOp; 58 | return ret; 59 | } 60 | 61 | Enlarge::Enlarge(Model* _model, Tensor _w1, Tensor _w2) 62 | : OpBase(_w1, _w2, _model, OP_ENLARGE) 63 | { 64 | assert(_w1.numDim == 4); 65 | assert(_w2.numDim == 4); 66 | assert(_w1.dim[2] <= _w2.dim[2]); 67 | assert(_w1.dim[3] <= _w2.dim[3]); 68 | numOutputs = 1; 69 | outputs[0].numDim = _w1.numDim; 70 | outputs[0].dim[0] = _w1.dim[0]; 71 | outputs[0].dim[1] = _w1.dim[1]; 72 | outputs[0].dim[2] = _w2.dim[2]; 73 | outputs[0].dim[3] = _w2.dim[3]; 74 | outputs[0].stride[3] = 1; 75 | outputs[0].stride[2] = outputs[0].stride[3] * outputs[0].dim[3]; 76 | outputs[0].stride[1] = outputs[0].stride[2] * outputs[0].dim[2]; 77 | outputs[0].stride[0] = outputs[0].stride[1] * outputs[0].dim[1]; 78 | // Set SplitInfo 79 | outputs[0].split[0] = _w1.split[0]; 80 | outputs[0].split[1] = _w1.split[1]; 81 | outputs[0].idx = 0; 82 | } 83 | 84 | Enlarge::~Enlarge(void) 85 | {} 86 | 87 | bool Enlarge::get_parameter(PMParameter para, int* value) 88 | { 89 | switch (para) { 90 | //case PM_KERNEL_H: 91 | // *value = kernelH; 92 | // return true; 93 | //case PM_KERNEL_W: 94 | // *value = kernelW; 95 | // return true; 96 | default: 97 | return OpBase::get_parameter(para, value); 98 | } 99 | } 100 | 101 | void Enlarge::collect_costs(float& exe_time, float& flops, 102 | float& mem_acc, int& num_kernels) 103 | { 104 | int outputSize = outputs[0].volume(); 105 | int inputSize = inputs[0].volume(); 106 | exe_time += runtime; 107 | flops += outputSize; 108 | mem_acc += inputSize + outputSize; 109 | num_kernels += 1; 110 | } 111 | 112 | // keys are (kernelH, kernelW, _weight) 113 | EnlargeKey::EnlargeKey(Tensor _w1, Tensor _w2) 114 | { 115 | assert(_w1.numDim == 4); 116 | int idx = 0; 117 | _w1.serialize(keys, idx); 118 | _w2.serialize(keys, idx); 119 | while (idx < KEY_LENGTH) 120 | keys[idx++] = 0; 121 | assert(idx == KEY_LENGTH); 122 | } 123 | -------------------------------------------------------------------------------- /src/core/matmul.cc: 
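// matmul.cc defines Graph::fc and Graph::matmul. fc builds a randomly
// initialized 2-D weight of shape (outputC, inputC) and forwards to matmul;
// get_or_create_matmul rejects shape-incompatible operands with Op::INVALID_OP
// and memoizes Matmul instances by MatmulKey. A minimal usage sketch follows
// (the graph and its input tensor are assumed to exist; names are illustrative):
//
//   TensorHandle x = ...;                          // (batch, inputC) activations
//   TensorHandle y = graph->fc(x, 256, AC_MODE_NONE);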
-------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::fc(const TensorHandle _input, 20 | int _outputC, 21 | ActiMode acti) 22 | { 23 | assert(_input->numDim == 2); 24 | const int dims[2] = {_outputC, _input->dim[1]}; 25 | int total = dims[0] * dims[1]; 26 | // Randomly initialize weights 27 | DATATYPE* data = (DATATYPE*) malloc(total * sizeof(DATATYPE)); 28 | for (int i = 0; i < total; i++) 29 | data[i] = (DATATYPE)std::rand() / RAND_MAX; 30 | TensorHandle weight = new_weight(2, dims, data); 31 | free(data); 32 | return matmul(_input, weight, acti); 33 | } 34 | 35 | TensorHandle Graph::matmul(const TensorHandle _input, 36 | const TensorHandle _weight, 37 | ActiMode acti) 38 | { 39 | Op op = model->get_or_create_matmul(*_input, *_weight, acti); 40 | assert(op != Op::INVALID_OP); 41 | add_edge(_input->op, op, _input->idx, 0); 42 | add_edge(_weight->op, op, _weight->idx, 1); 43 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 44 | t->op = op; 45 | return t; 46 | } 47 | 48 | Op Model::get_or_create_matmul(Tensor _input, Tensor _weight, 49 | ActiMode _acti) 50 | { 51 | if (_input.numDim != _weight.numDim) 52 | return Op::INVALID_OP; 53 | for (int i = 0; i < _input.numDim - 2; i++) 54 | if (_input.dim[i] != _weight.dim[i]) 55 | return Op::INVALID_OP; 56 | if (_input.dim[_input.numDim-1] != _weight.dim[_weight.numDim-2]) 57 | return Op::INVALID_OP; 58 | // key is (inputX, inputN, inputC, outputC, acti) 59 | MatmulKey key(_input, _weight, _acti); 60 | Matmul* matmulOp; 61 | if (matmul.find(key) != matmul.end()) { 62 | matmulOp = matmul[key]; 63 | } else { 64 | matmulOp = new Matmul(this, _input, _weight, _acti); 65 | measure_matmul_cost(matmulOp); 66 | matmul[key] = matmulOp; 67 | } 68 | Op ret; 69 | ret.guid = global_unique_id ++; 70 | ret.ptr = matmulOp; 71 | return ret; 72 | } 73 | 74 | Matmul::Matmul(Model* _model, Tensor _input, Tensor _weight, ActiMode _activation) 75 | : OpBase(_input, _weight, _model, OP_MATMUL), activation(_activation) 76 | { 77 | int numDim = _input.numDim; 78 | assert(numDim == _weight.numDim); 79 | for (int i = 0; i < numDim - 2; i++) 80 | assert(_input.dim[i] == _weight.dim[i]); 81 | assert(_input.dim[numDim-1] == _weight.dim[numDim-2]); 82 | numOutputs = 1; 83 | // set dims and strides 84 | outputs[0].numDim = numDim; 85 | for (int i = 0; i < numDim-1; i++) 86 | outputs[0].dim[i] = _input.dim[i]; 87 | outputs[0].dim[numDim-1] = _weight.dim[numDim-1]; 88 | outputs[0].stride[numDim-2] = 1; 89 | outputs[0].stride[numDim-1] = outputs[0].dim[numDim-2]; 90 | int size = outputs[0].dim[numDim-2] * outputs[0].dim[numDim-1]; 91 | for (int i = numDim-3; i >= 0; i--) { 92 | outputs[0].stride[i] = size; 93 | size *= outputs[0].dim[i]; 94 | } 95 | assert(size == outputs[0].volume()); 96 | // set SplitInfo 97 | for (int i = 0; i < numDim-2; 
i++) { 98 | if (_input.split[i] == _weight.split[i]) 99 | outputs[0].split[i] = _input.split[i]; 100 | else 101 | outputs[0].split[i] = SplitInfo::NO_SPLIT; 102 | } 103 | outputs[0].split[numDim-2] = _input.split[numDim-2]; 104 | outputs[0].split[numDim-1] = _weight.split[numDim-1]; 105 | outputs[0].idx = 0; 106 | } 107 | 108 | Matmul::~Matmul(void) 109 | {} 110 | 111 | bool Matmul::get_parameter(PMParameter para, int* value) 112 | { 113 | switch (para) { 114 | case PM_ACTI: 115 | *value = (int) activation; 116 | return true; 117 | default: 118 | return OpBase::get_parameter(para, value); 119 | } 120 | } 121 | 122 | void Matmul::collect_costs(float& exe_time, float& flops, 123 | float& mem_acc, int& num_kernels) 124 | { 125 | int outputSize = 1, inputSize = 1; 126 | for (int i = 0; i < outputs[0].numDim; i++) 127 | outputSize *= outputs[0].dim[i]; 128 | for (int i = 0; i < inputs[0].numDim; i++) 129 | inputSize *= inputs[0].dim[i]; 130 | // cost metrics 131 | exe_time += runtime; 132 | assert(inputs[0].numDim == inputs[1].numDim); 133 | flops += outputSize * inputs[0].dim[inputs[0].numDim-1]; 134 | mem_acc += inputSize; 135 | num_kernels += 1; 136 | printf(" cost[Matmul]: %s %s cost(%.4lf) total_cost(%.4lf)\n", 137 | inputs[0].to_string("input").c_str(), 138 | inputs[1].to_string("weight").c_str(), 139 | runtime, exe_time); 140 | } 141 | 142 | // key is (inputN, inputC, outputC, acti) 143 | MatmulKey::MatmulKey(Tensor _input, Tensor _weight, ActiMode _mode) 144 | { 145 | assert(_input.numDim == _weight.numDim); 146 | int idx = 0; 147 | keys[idx++] = (int)(_mode); 148 | _input.serialize(keys, idx); 149 | _weight.serialize(keys, idx); 150 | while (idx < KEY_LENGTH) 151 | keys[idx++] = 0; 152 | assert(idx == KEY_LENGTH); 153 | } 154 | 155 | -------------------------------------------------------------------------------- /src/core/merge_gconv.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | // Merge multiple group convs to a single group conv 20 | TensorHandle Graph::merge_gconv(const TensorHandle _weight, 21 | int count) 22 | { 23 | // Currently the weight being merged must be 4D: 24 | // Count, Cin, KernelH, KernelW 25 | assert(_weight->numDim == 4); 26 | Op op = model->get_or_create_merge_gconv(*_weight, count); 27 | assert(op != Op::INVALID_OP); 28 | add_edge(_weight->op, op, _weight->idx, 0); 29 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 30 | t->op = op; 31 | return t; 32 | } 33 | 34 | Op Model::get_or_create_merge_gconv(const Tensor& _weight, 35 | int count) 36 | { 37 | // Check 1: weight must have 4D 38 | if (_weight.numDim != 4) 39 | return Op::INVALID_OP; 40 | // new group number must be an integer 41 | //if (_input.dim[1] % (_weight.dim[1] * count) != 0) 42 | //return Op::INVALID_OP; 43 | MergeGConvKey key(_weight, count); 44 | MergeGConv* mergeOp; 45 | if (merge_gconv.find(key) != merge_gconv.end()) { 46 | mergeOp = merge_gconv[key]; 47 | } else { 48 | mergeOp = new MergeGConv(this, _weight, count); 49 | mergeOp->runtime = 0.0f; 50 | merge_gconv[key] = mergeOp; 51 | } 52 | Op ret; 53 | ret.guid = global_unique_id ++; 54 | ret.ptr = mergeOp; 55 | return ret; 56 | } 57 | 58 | MergeGConv::MergeGConv(Model* _model, 59 | const Tensor& _weight, 60 | int _count) 61 | : OpBase(_weight, _model, OP_MERGE_GCONV), count(_count) 62 | { 63 | assert(_weight.numDim == 4); 64 | numOutputs = 1; 65 | outputs[0].numDim = _weight.numDim; 66 | outputs[0].dim[0] = _weight.dim[0]; 67 | outputs[0].dim[1] = _weight.dim[1] * count; 68 | outputs[0].dim[2] = _weight.dim[2]; 69 | outputs[0].dim[3] = _weight.dim[3]; 70 | outputs[0].stride[3] = 1; 71 | outputs[0].stride[2] = outputs[0].stride[3] * outputs[0].dim[3]; 72 | outputs[0].stride[1] = outputs[0].stride[2] * outputs[0].dim[2]; 73 | outputs[0].stride[0] = outputs[0].stride[1] * outputs[0].dim[1]; 74 | // Set SplitInfo 75 | outputs[0].split[0] = _weight.split[0]; 76 | outputs[0].split[1] = SplitInfo::NO_SPLIT; 77 | outputs[0].split[2] = _weight.split[2]; 78 | outputs[0].split[3] = _weight.split[3]; 79 | outputs[0].idx = 0; 80 | // assume that group number is an integer 81 | } 82 | 83 | MergeGConv::~MergeGConv(void) 84 | {} 85 | 86 | bool MergeGConv::get_parameter(PMParameter para, int* value) 87 | { 88 | switch (para) { 89 | case PM_MERGE_GCONV_COUNT: 90 | *value = count; 91 | return true; 92 | default: 93 | return OpBase::get_parameter(para, value); 94 | } 95 | } 96 | 97 | void MergeGConv::collect_costs(float& exe_time, float& flops, 98 | float& mem_acc, int& num_kernels) 99 | { 100 | int outputSize = outputs[0].volume(); 101 | int inputSize = inputs[0].volume(); 102 | exe_time += runtime; 103 | flops += outputSize; 104 | mem_acc += inputSize + outputSize; 105 | num_kernels += 1; 106 | } 107 | 108 | // keys are (count, _weight) 109 | MergeGConvKey::MergeGConvKey(const Tensor& _weight, 110 | int count) 111 | { 112 | assert(_weight.numDim == 4); 113 | int idx = 0; 114 | keys[idx++] = count; 115 | _weight.serialize(keys, idx); 116 | while (idx < KEY_LENGTH) 117 | keys[idx++] = 0; 118 | assert(idx == KEY_LENGTH); 119 | } 120 | -------------------------------------------------------------------------------- /src/core/mul.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in 
compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::mul(const TensorHandle x, 20 | const TensorHandle y) 21 | { 22 | Op op = model->get_or_create_mul(*x, *y); 23 | add_edge(x->op, op, x->idx, 0); 24 | add_edge(y->op, op, y->idx, 1); 25 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 26 | t->op = op; 27 | return t; 28 | } 29 | 30 | Op Model::get_or_create_mul(const Tensor& x, 31 | const Tensor& y) 32 | { 33 | MulKey key(x, y); 34 | Mul* mulOp; 35 | if (mul.find(key) != mul.end()) { 36 | mulOp = mul[key]; 37 | } else { 38 | mulOp = new Mul(this, x, y); 39 | measure_mul_cost(mulOp); 40 | mul[key] = mulOp; 41 | } 42 | Op ret; 43 | ret.guid = global_unique_id ++; 44 | ret.ptr = mulOp; 45 | return ret; 46 | } 47 | 48 | Mul::Mul(Model* _model, const Tensor& x, const Tensor& y) 49 | : OpBase(x, y, _model, OP_MUL) 50 | { 51 | // TODO: support broadcast 52 | // Currently assume _y.numDim = 0 53 | int numDim = x.numDim; 54 | assert(y.numDim == 0); 55 | for (int i = 0; i < y.numDim; i++) 56 | assert(x.dim[i] == y.dim[i]); 57 | numOutputs = 1; 58 | outputs[0].numDim = numDim; 59 | for (int i = 0; i < numDim-1; i++) { 60 | outputs[0].dim[i] = x.dim[i]; 61 | outputs[0].stride[i] = x.stride[i]; 62 | outputs[0].split[i] = x.split[i]; 63 | } 64 | outputs[0].idx = 0; 65 | } 66 | 67 | Mul::~Mul(void) 68 | {} 69 | 70 | bool Mul::get_parameter(PMParameter para, int* value) 71 | { 72 | return OpBase::get_parameter(para, value); 73 | } 74 | 75 | void Mul::collect_costs(float& exe_time, float& flops, 76 | float& mem_acc, int& num_kernels) 77 | { 78 | // TODO: to be implemented 79 | assert(false); 80 | } 81 | 82 | MulKey::MulKey(const Tensor& _x, const Tensor& _y) 83 | { 84 | int idx = 0; 85 | _x.serialize(keys, idx); 86 | _y.serialize(keys, idx); 87 | while (idx < KEY_LENGTH) 88 | keys[idx++] = 0; 89 | assert(idx == KEY_LENGTH); 90 | } 91 | 92 | -------------------------------------------------------------------------------- /src/core/noop.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::input_wrapper(const TensorHandle _input) 20 | { 21 | // Always create new operator for input 22 | Op op = model->create_input(*_input, OP_INPUT); 23 | add_edge(_input->op, op, _input->idx, 0); 24 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 25 | t->op = op; 26 | return t; 27 | } 28 | 29 | TensorHandle Graph::weight_wrapper(const TensorHandle _weight) 30 | { 31 | // Always create new operator for weight 32 | Op op = model->create_weight(*_weight, OP_WEIGHT); 33 | add_edge(_weight->op, op, _weight->idx, 0); 34 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 35 | t->op = op; 36 | return t; 37 | } 38 | 39 | // TODO: we ignore dropout rate for inference 40 | TensorHandle Graph::dropout(const TensorHandle _input) 41 | { 42 | Op op = model->get_or_create_noop(*_input, OP_DROPOUT); 43 | add_edge(_input->op, op, _input->idx, 0); 44 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 45 | t->op = op; 46 | return t; 47 | } 48 | 49 | Op Model::create_input(Tensor _input, OpType _type) 50 | { 51 | assert(_type == OP_INPUT); 52 | Op ret; 53 | ret.ptr = new NoOp(this, _input, _type); 54 | ret.guid = global_unique_id ++; 55 | return ret; 56 | } 57 | 58 | Op Model::create_weight(Tensor _weight, OpType _type) 59 | { 60 | assert(_type == OP_WEIGHT); 61 | assert(_weight.data_ptr != NULL); 62 | Op ret; 63 | ret.ptr = new NoOp(this, _weight, _type); 64 | ret.guid = global_unique_id ++; 65 | return ret; 66 | } 67 | 68 | 69 | Op Model::get_or_create_noop(Tensor _input, OpType _type) 70 | { 71 | assert(_type == OP_DROPOUT); 72 | // key is (_type, _input) 73 | NoopKey key(_input, _type); 74 | NoOp* noOp; 75 | if (noop.find(key) != noop.end()) { 76 | noOp = noop[key]; 77 | } else { 78 | noOp = new NoOp(this, _input, _type); 79 | noOp->runtime = 0.0f; 80 | noop[key] = noOp; 81 | } 82 | Op ret; 83 | ret.guid = global_unique_id ++; 84 | ret.ptr = noOp; 85 | return ret; 86 | } 87 | 88 | NoOp::NoOp(Model* _model, Tensor _input, OpType type) 89 | : OpBase(_input, _model, type) 90 | { 91 | numOutputs = 1; 92 | outputs[0] = _input; 93 | outputs[0].idx = 0; 94 | } 95 | 96 | NoOp::~NoOp(void) 97 | {} 98 | 99 | bool NoOp::get_parameter(PMParameter para, int* value) 100 | { 101 | switch (para) { 102 | case PM_OP_TYPE: 103 | *value = (int) type; 104 | return true; 105 | case PM_NUM_INPUTS: 106 | *value = numInputs; 107 | return true; 108 | case PM_NUM_OUTPUTS: 109 | *value = numOutputs; 110 | return true; 111 | default: 112 | return false; 113 | } 114 | } 115 | 116 | void NoOp::map(void) 117 | {} 118 | 119 | void NoOp::unmap(void) 120 | {} 121 | 122 | void NoOp::forward(bool block) 123 | {} 124 | 125 | void NoOp::collect_costs(float& exe_time, float& flops, 126 | float& mem_acc, int& num_kernels) 127 | { 128 | // cost metrics 129 | exe_time += 0; 130 | flops += 0; 131 | mem_acc += 0; 132 | num_kernels += 0; 133 | } 134 | 135 | // key ordering: _type, input 136 | NoopKey::NoopKey(Tensor input, OpType _type) 137 | { 138 | int idx = 0; 139 | keys[idx++] = _type; 140 | input.serialize(keys, idx); 141 | while (idx < KEY_LENGTH) 142 | keys[idx++] = 0; 143 | assert(idx == KEY_LENGTH); 144 | } 145 | 146 | -------------------------------------------------------------------------------- /src/core/pool2d.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in 
compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::pool2d_max(const TensorHandle _input, 20 | int _kernelH, int _kernelW, 21 | int _strideH, int _strideW, 22 | PaddingMode _padding, 23 | ActiMode _activation) 24 | { 25 | int num = _input->dim[1] * _kernelH * _kernelW; 26 | DATATYPE* data_ptr = (DATATYPE*) malloc(num * sizeof(DATATYPE)); 27 | for (int i = 0; i < num; i++) 28 | data_ptr[i] = 1.0f / (_kernelH * _kernelW); 29 | const int dims[4] = {_input->dim[1], 1, _kernelH, _kernelW}; 30 | TensorHandle weight = new_weight(4, dims, data_ptr); 31 | /* 32 | weight.numDim = 4; 33 | weight.dim[0] = _input.dim[1]; 34 | weight.dim[1] = 1; 35 | weight.dim[2] = _kernelH; 36 | weight.dim[3] = _kernelW; 37 | weight.stride[3] = 1; 38 | weight.stride[2] = weight.stride[3] * weight.dim[3]; 39 | weight.stride[1] = weight.stride[2] * weight.dim[2]; 40 | weight.stride[0] = weight.stride[1] * weight.dim[1]; 41 | weight.op.guid = GUID_WEIGHT; 42 | weight.op.ptr = NULL; 43 | weight.idx = 0; 44 | weight = noop(weight); 45 | */ 46 | Op op = model->get_or_create_pool2d( 47 | *_input, *weight, OP_POOL2D_MAX, _kernelH, _kernelW, 48 | _strideH, _strideW, _padding, _activation); 49 | add_edge(_input->op, op, _input->idx, 0); 50 | add_edge(weight->op, op, weight->idx, 1); 51 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 52 | t->op = op; 53 | return t; 54 | } 55 | 56 | TensorHandle Graph::pool2d_avg(const TensorHandle _input, 57 | int _kernelH, int _kernelW, 58 | int _strideH, int _strideW, 59 | PaddingMode _padding, 60 | ActiMode _activation) 61 | { 62 | int num = _input->dim[1] * _kernelH * _kernelW; 63 | DATATYPE* data_ptr = (DATATYPE*) malloc(num * sizeof(DATATYPE)); 64 | for (int i = 0; i < num; i++) 65 | data_ptr[i] = 1.0f / (_kernelH * _kernelW); 66 | const int dims[4] = {_input->dim[1], 1, _kernelH, _kernelW}; 67 | TensorHandle weight = new_weight(4, dims, data_ptr); 68 | /* 69 | weight.numDim = 4; 70 | weight.dim[0] = _input.dim[1]; 71 | weight.dim[1] = 1; 72 | weight.dim[2] = _kernelH; 73 | weight.dim[3] = _kernelW; 74 | weight.stride[3] = 1; 75 | weight.stride[2] = weight.stride[3] * weight.dim[3]; 76 | weight.stride[1] = weight.stride[2] * weight.dim[2]; 77 | weight.stride[0] = weight.stride[1] * weight.dim[1]; 78 | weight.op.guid = GUID_WEIGHT; 79 | weight.op.ptr = NULL; 80 | weight.idx = 0; 81 | weight = noop(weight); 82 | */ 83 | Op op = model->get_or_create_pool2d( 84 | *_input, *weight, OP_POOL2D_AVG, _kernelH, _kernelW, 85 | _strideH, _strideW, _padding, _activation); 86 | add_edge(_input->op, op, _input->idx, 0); 87 | add_edge(weight->op, op, weight->idx, 1); 88 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 89 | t->op = op; 90 | return t; 91 | } 92 | 93 | Op Model::get_or_create_pool2d(Tensor _input, Tensor _weight, 94 | OpType _type, 95 | int _kernelH, int _kernelW, 96 | int _strideH, int _strideW, 97 | PaddingMode _padding, 98 | ActiMode _activation) 99 | 100 | { 101 | // keys are (inputN, inputC, inputH, inputW, kernelH, kernelW, 102 | // strideH, strideW, padding, 
activation, _type) 103 | Pool2DKey key(_input, _type, _kernelH, _kernelW, _strideH, _strideW, 104 | _padding, _activation); 105 | Pool2D* poolOp; 106 | if (pool2d.find(key) != pool2d.end()) { 107 | poolOp = pool2d[key]; 108 | } else { 109 | poolOp = new Pool2D(this, _input, _weight, _type, _kernelH, _kernelW, 110 | _strideH, _strideW, _padding, _activation); 111 | measure_pool2d_cost(poolOp); 112 | pool2d[key] = poolOp; 113 | } 114 | Op ret; 115 | ret.guid = global_unique_id ++; 116 | ret.ptr = poolOp; 117 | return ret; 118 | } 119 | 120 | Pool2D::Pool2D(Model* _model, Tensor _input, 121 | Tensor _weight, OpType _type, 122 | int _kernelH, int _kernelW, 123 | int _strideH, int _strideW, 124 | PaddingMode _padding, 125 | ActiMode _activation) 126 | : OpBase(_input, _weight, _model, _type), 127 | kernelH(_kernelH), kernelW(_kernelW), 128 | strideH(_strideH), strideW(_strideW), 129 | padding(_padding), activation(_activation) 130 | { 131 | assert(type == OP_POOL2D_MAX || type == OP_POOL2D_AVG); 132 | assert(_input.numDim == 4); 133 | int inputC = _input.dim[1]; 134 | int inputH = _input.dim[2]; 135 | int inputW = _input.dim[3]; 136 | int outputH, outputW; 137 | switch (padding) 138 | { 139 | case PD_MODE_SAME: 140 | outputH = (inputH + strideH - 1) / strideH; 141 | outputW = (inputW + strideW - 1) / strideW; 142 | break; 143 | case PD_MODE_VALID: 144 | outputH = (inputH - kernelH) / strideH + 1; 145 | outputW = (inputW - kernelW) / strideW + 1; 146 | break; 147 | default: 148 | assert(false); 149 | } 150 | //int outputH = 1 + (inputH + 2 * padH - kernelH) / strideH; 151 | //int outputW = 1 + (inputW + 2 * padW - kernelW) / strideW; 152 | //printf("k(%d %d) padding(%d) s(%d %d) o(%d %d)\n", 153 | // kernelH, kernelW, padding, strideH, strideW, outputH, outputW); 154 | numOutputs = 1; 155 | outputs[0].numDim = 4; 156 | outputs[0].dim[0] = _input.dim[0]; 157 | outputs[0].dim[1] = _input.dim[1]; 158 | outputs[0].dim[2] = outputH; 159 | outputs[0].dim[3] = outputW; 160 | // Set strides 161 | outputs[0].stride[3] = 1; 162 | outputs[0].stride[2] = outputs[0].dim[3] * outputs[0].stride[3]; 163 | outputs[0].stride[1] = outputs[0].dim[2] * outputs[0].stride[2]; 164 | outputs[0].stride[0] = outputs[0].dim[1] * outputs[0].stride[1]; 165 | // Set SplitInfo 166 | outputs[0].split[0] = _input.split[0]; 167 | outputs[0].split[1] = _input.split[1]; 168 | outputs[0].split[2] = SplitInfo::NO_SPLIT; 169 | outputs[0].split[3] = SplitInfo::NO_SPLIT; 170 | outputs[0].idx = 0; 171 | } 172 | 173 | Pool2D::~Pool2D(void) 174 | { 175 | } 176 | 177 | bool Pool2D::get_parameter(PMParameter para, int* value) 178 | { 179 | switch (para) { 180 | case PM_KERNEL_H: 181 | *value = kernelH; 182 | return true; 183 | case PM_KERNEL_W: 184 | *value = kernelW; 185 | return true; 186 | case PM_STRIDE_H: 187 | *value = strideH; 188 | return true; 189 | case PM_STRIDE_W: 190 | *value = strideW; 191 | return true; 192 | case PM_PAD: 193 | *value = padding; 194 | return true; 195 | case PM_ACTI: 196 | *value = activation; 197 | return true; 198 | default: 199 | return OpBase::get_parameter(para, value); 200 | } 201 | } 202 | 203 | void Pool2D::get_padding(int* padH, int* padW) { 204 | int inputH = inputs[0].dim[2]; 205 | int inputW = inputs[0].dim[3]; 206 | // TODO eliminate duplicated code with conv2d version 207 | // Reference: https://www.tensorflow.org/api_guides/python/nn#Convolution 208 | switch (padding) { 209 | case PD_MODE_SAME: 210 | int totalPadH, totalPadW; 211 | if (inputH % strideH == 0) 212 | totalPadH = max(kernelH - 
strideH, 0); 213 | else 214 | totalPadH = max(kernelH - (inputH % strideH), 0); 215 | if (inputW % strideW == 0) 216 | totalPadW = max(kernelW - strideW, 0); 217 | else 218 | totalPadW = max(kernelW - (inputW % strideW), 0); 219 | // assert same padding on both sides 220 | *padH = (totalPadH + 1) / 2; 221 | *padW = (totalPadW + 1) / 2; 222 | break; 223 | case PD_MODE_VALID: 224 | *padH = 0; 225 | *padW = 0; 226 | break; 227 | default: 228 | assert(false); 229 | } 230 | } 231 | 232 | void Pool2D::collect_costs(float& exe_time, float& flops, 233 | float& mem_acc, int& num_kernels) 234 | { 235 | int outputSize = 1, inputSize = 1; 236 | for (int i = 0; i < outputs[0].numDim; i++) 237 | outputSize *= outputs[0].dim[i]; 238 | for (int i = 0; i < inputs[0].numDim; i++) 239 | inputSize *= inputs[0].dim[i]; 240 | // cost metrics 241 | exe_time += runtime; 242 | flops += outputSize * kernelH * kernelW; 243 | mem_acc += inputSize; 244 | num_kernels += 1; 245 | printf(" cost[Pool2D]: i(%d %d %d %d) k(%d %d) s(%d %d) cost(%.4lf) total_cost(%.4lf)\n", 246 | inputs[0].dim[0], inputs[0].dim[1], inputs[0].dim[2], inputs[0].dim[3], 247 | kernelH, kernelW, strideH, strideW, runtime, exe_time); 248 | } 249 | 250 | // keys are (kernelH, kernelW, strideH, strideW, padding, activation, _type, 251 | // input) 252 | Pool2DKey::Pool2DKey(Tensor _input, OpType _type, 253 | int _kernelH, int _kernelW, int _strideH, int _strideW, 254 | PaddingMode _padding, 255 | ActiMode _activation) 256 | { 257 | int idx = 0; 258 | keys[idx++] = _kernelH; 259 | keys[idx++] = _kernelW; 260 | keys[idx++] = _strideH; 261 | keys[idx++] = _strideW; 262 | keys[idx++] = _padding; 263 | keys[idx++] = _activation; 264 | keys[idx++] = _type; 265 | _input.serialize(keys, idx); 266 | while (idx < KEY_LENGTH) 267 | keys[idx++] = 0; 268 | assert(KEY_LENGTH == idx); 269 | } 270 | 271 | -------------------------------------------------------------------------------- /src/core/reshape.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | TensorHandle Graph::reshape(const TensorHandle _input, 20 | const std::vector& _shape) 21 | { 22 | std::vector myshape = _shape; 23 | // replace zeros with input dims 24 | for (size_t i = 0; i < myshape.size(); i++) 25 | if (myshape[i] == 0) 26 | myshape[i] = _input->dim[i]; 27 | int input_size = _input->volume(); 28 | // replace -1 with actual size 29 | for (size_t i = 0; i < myshape.size(); i++) 30 | if (myshape[i] != -1) { 31 | assert(input_size % myshape[i] == 0); 32 | input_size /= myshape[i]; 33 | } 34 | for (size_t i = 0; i < myshape.size(); i++) 35 | if (myshape[i] == -1) { 36 | myshape[i] = input_size; 37 | input_size = 1; 38 | } 39 | assert(input_size == 1); 40 | Op op = model->get_or_create_reshape(*_input, myshape); 41 | add_edge(_input->op, op, _input->idx, 0); 42 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 43 | t->op = op; 44 | return t; 45 | } 46 | 47 | Op Model::get_or_create_reshape(Tensor _input, 48 | const std::vector& _shape) 49 | { 50 | ReshapeKey key(_input, _shape); 51 | Reshape* reshapeOp; 52 | if (reshape.find(key) != reshape.end()) { 53 | reshapeOp = reshape[key]; 54 | } else { 55 | reshapeOp = new Reshape(this, _input, _shape); 56 | measure_reshape_cost(reshapeOp); 57 | reshape[key] = reshapeOp; 58 | } 59 | Op ret; 60 | ret.guid = global_unique_id ++; 61 | ret.ptr = reshapeOp; 62 | return ret; 63 | } 64 | 65 | Reshape::Reshape(Model* _model, Tensor _input, 66 | const std::vector& _shape) 67 | 68 | : OpBase(_input, _model, OP_RESHAPE) 69 | { 70 | int size = 1; 71 | // set dims and strides 72 | numOutputs = 1; 73 | outputs[0].numDim = _shape.size(); 74 | for (int i = _shape.size() - 1; i >= 0; i--) { 75 | outputs[0].dim[i] = _shape[i]; 76 | outputs[0].stride[i] = size; 77 | size *= _shape[i]; 78 | outputs[0].split[i] = SplitInfo::NO_SPLIT; 79 | } 80 | assert(_input.volume() == size); 81 | outputs[0].idx = 0; 82 | } 83 | 84 | Reshape::~Reshape(void) 85 | {} 86 | 87 | bool Reshape::get_parameter(PMParameter para, int* value) 88 | { 89 | return OpBase::get_parameter(para, value); 90 | } 91 | 92 | void Reshape::collect_costs(float& exe_time, float& flops, 93 | float& mem_acc, int& num_kernels) 94 | { 95 | } 96 | 97 | ReshapeKey::ReshapeKey(Tensor _input, const std::vector& shape) 98 | { 99 | int idx = 0; 100 | keys[idx++] = shape.size(); 101 | for (size_t i = 0; i < shape.size(); i++) 102 | keys[idx++] = shape[i]; 103 | _input.serialize(keys, idx); 104 | while (idx < KEY_LENGTH) 105 | keys[idx++] = 0; 106 | assert(idx == KEY_LENGTH); 107 | } 108 | 109 | -------------------------------------------------------------------------------- /src/core/rules.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package GraphSubst; 4 | 5 | message Parameter { 6 | required int32 key = 1; 7 | required int32 value = 2; 8 | } 9 | 10 | message Tensor { 11 | required int32 opId = 1; 12 | required int32 tsId = 2; 13 | } 14 | 15 | message Operator { 16 | required int32 type = 1; 17 | repeated Tensor input = 2; 18 | repeated Parameter para = 3; 19 | } 20 | 21 | message MapOutput { 22 | required int32 srcOpId = 1; 23 | required int32 dstOpId = 2; 24 | required int32 srcTsId = 3; 25 | required int32 dstTsId = 4; 26 | } 27 | 28 | message Rule { 29 | repeated Operator srcOp = 1; 30 | repeated Operator dstOp = 2; 31 | repeated MapOutput mappedOutput = 3; 32 | } 33 | 34 | message RuleCollection { 35 | repeated Rule rule = 1; 36 | } 37 | 
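Note: the rule files loaded by GraphXfer::load_graph_xfer_from_pb_file (declared in src/core/substitution.h below) are serialized RuleCollection messages in the format defined above. The following is a minimal sketch of reading such a file through the standard C++ code that protoc generates from rules.proto; the file name used here is only a placeholder, not a path from this repository.

#include <fstream>
#include <iostream>
#include "rules.pb.h"  // generated by protoc from rules.proto

// Parse a serialized GraphSubst::RuleCollection and print a per-rule summary.
int main(int argc, char** argv)
{
  GOOGLE_PROTOBUF_VERIFY_VERSION;
  GraphSubst::RuleCollection collection;
  // Placeholder file name; pass the actual rule file as the first argument.
  std::ifstream in(argc > 1 ? argv[1] : "rules.pb", std::ios::binary);
  if (!in || !collection.ParseFromIstream(&in)) {
    std::cerr << "failed to parse rule collection" << std::endl;
    return 1;
  }
  for (int i = 0; i < collection.rule_size(); i++) {
    const GraphSubst::Rule& rule = collection.rule(i);
    std::cout << "rule " << i
              << ": srcOps=" << rule.srcop_size()
              << " dstOps=" << rule.dstop_size()
              << " mappedOutputs=" << rule.mappedoutput_size() << std::endl;
  }
  return 0;
}

Building this sketch assumes linking against libprotobuf together with the generated rules.pb.cc.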
-------------------------------------------------------------------------------- /src/core/split.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | void Graph::split(Tensor _input, int axis, int _num, 20 | const int* _sizes, Tensor* outputs) 21 | { 22 | int n = _num, sizes[MAX_NUM_OUTPUTS]; 23 | for (int i = 0; i < n; i++) 24 | sizes[i] = _sizes[i]; 25 | Op op = model->get_or_create_split(_input, axis, n, sizes); 26 | add_edge(_input.op, op, _input.idx, 0); 27 | for (int i = 0; i < n; i++) { 28 | outputs[i] = op.ptr->outputs[i]; 29 | outputs[i].op = op; 30 | } 31 | } 32 | 33 | /* 34 | void Graph::split(Tensor _input, int axis, int _num, Tensor* outputs) 35 | { 36 | int sizes[MAX_NUM_OUTPUTS]; 37 | SplitInfo parent = _input.split[axis], left, right; 38 | int curPos, oldPos = _input.dim[axis]; 39 | for (int i = _num - 1; i >= 0; i--) { 40 | parent.divide(left, right, curPos); 41 | sizes[i] = oldPos - curPos; 42 | oldPos = curPos; 43 | parent = left; 44 | } 45 | Graph::split(_input, axis, _num, sizes, outputs); 46 | } 47 | */ 48 | 49 | void Graph::split(Tensor _input, int axis, int size1, int size2, Tensor* outputs) 50 | { 51 | int sizes[2]; 52 | sizes[0] = size1; 53 | sizes[1] = size2; 54 | Graph::split(_input, axis, 2, sizes, outputs); 55 | } 56 | 57 | Op Model::get_or_create_split(Tensor _input, int axis, int n, int* sizes) 58 | { 59 | // key ordering is: 60 | // axis, n, inputs[0].dim[0], ..., inputs[0].dim[axis-1], 61 | // inputs[0].dim[axis+1], ..., inputs[0].dim[nDims - 1] 62 | // sizes[0], ..., sizes[n-1] 63 | SplitKey key(_input, axis, n, sizes); 64 | Split* splitOp; 65 | if (split.find(key) != split.end()) { 66 | splitOp = split[key]; 67 | } else { 68 | splitOp = new Split(this, _input, axis, n, sizes); 69 | measure_split_cost(splitOp); 70 | split[key] = splitOp; 71 | } 72 | Op ret; 73 | ret.guid = global_unique_id ++; 74 | ret.ptr = splitOp; 75 | return ret; 76 | } 77 | 78 | Op Model::get_or_create_split(Tensor _input, int axis, int n) 79 | { 80 | int sizes[MAX_NUM_OUTPUTS]; 81 | SplitInfo parent = _input.split[axis], left, right; 82 | int curPos, oldPos = _input.dim[axis]; 83 | for (int i = n - 1; i > 0; i--) { 84 | parent.divide(left, right, curPos); 85 | sizes[i] = oldPos - curPos; 86 | oldPos = curPos; 87 | parent = left; 88 | } 89 | sizes[0] = oldPos; 90 | Op ret = get_or_create_split(_input, axis, n, sizes); 91 | return ret; 92 | } 93 | 94 | Split::Split(Model* _model, Tensor _input, int _axis, int n, int* _sizes) 95 | : OpBase(_input, model, OP_SPLIT), axis(_axis) 96 | { 97 | assert(n <= MAX_NUM_OUTPUTS); 98 | numOutputs = n; 99 | for (int i = 0; i < n; i++) 100 | sizes[i] = _sizes[i]; 101 | SplitInfo parent = inputs[0].split[axis], left, right; 102 | int oldPos = inputs[0].dim[axis], curPos; 103 | bool misMatch = false; 104 | for (int i = n - 1; i >= 
0; i--) { 105 | outputs[i].numDim = inputs[0].numDim; 106 | for (int j = 0; j < inputs[0].numDim; j++) 107 | if (j != axis) { 108 | outputs[i].dim[j] = inputs[0].dim[j]; 109 | outputs[i].stride[j] = inputs[0].stride[j]; 110 | outputs[i].split[j] = inputs[0].split[j]; 111 | } else { 112 | outputs[i].dim[j] = _sizes[i]; 113 | outputs[i].stride[j] = inputs[0].stride[j]; 114 | if (i > 0) { 115 | parent.divide(left, right, curPos); 116 | } else { 117 | curPos = 0; 118 | right = parent; 119 | } 120 | if (oldPos - curPos == _sizes[i]) 121 | outputs[i].split[j] = right; 122 | else { 123 | misMatch = true; 124 | outputs[i].split[j] = SplitInfo::NO_SPLIT; 125 | } 126 | oldPos = curPos; 127 | parent = left; 128 | } 129 | } 130 | if (misMatch) { 131 | // Clear split info if mismatch 132 | for (int i = n - 1; i >= 0; i--) 133 | outputs[i].split[axis] = SplitInfo::NO_SPLIT; 134 | } 135 | } 136 | 137 | Split::~Split(void) 138 | {} 139 | 140 | bool Split::get_parameter(PMParameter para, int* value) 141 | { 142 | switch (para) { 143 | case PM_AXIS: 144 | *value = axis; 145 | return true; 146 | default: 147 | return OpBase::get_parameter(para, value); 148 | } 149 | } 150 | 151 | void Split::map(void) 152 | { 153 | size_t offset = 0; 154 | for (int i = 0; i < numOutputs; i++) { 155 | outputs[i].data_ptr = (DATATYPE*)inputs[0].data_ptr + offset; 156 | offset += outputs[i].dim[axis] * inputs[0].stride[axis]; 157 | } 158 | } 159 | 160 | void Split::unmap(void) 161 | {} 162 | 163 | void Split::forward(bool block) 164 | {} 165 | 166 | void Split::collect_costs(float& exe_time, float& flops, 167 | float& mem_acc, int& num_kernels) 168 | { 169 | // cost metrics 170 | exe_time += 0; 171 | flops += 0; 172 | mem_acc += 0; 173 | num_kernels += 0; 174 | printf(" cost[Split]: numOutputs(%d) cost(%.4lf) total_cost(%.4lf)\n", 175 | numOutputs, 0.0f, exe_time); 176 | } 177 | 178 | void Model::measure_split_cost(Split* split) 179 | { 180 | // We assume split cost is zero 181 | split->runtime = 0; 182 | if (print_cost) 183 | printf(" measure[split]: cost(%.4lf)\n", split->runtime); 184 | } 185 | 186 | // key ordering is: 187 | // axis, n, sizes[0], ..., sizes[n-1], input 188 | SplitKey::SplitKey(Tensor input, int axis, int n, int* sizes) 189 | { 190 | int idx = 0; 191 | keys[idx++] = axis; 192 | keys[idx++] = n; 193 | for (int i = 0; i < n; i++) 194 | keys[idx++] = sizes[i]; 195 | input.serialize(keys, idx); 196 | while (idx < KEY_LENGTH) 197 | keys[idx++] = 0; 198 | assert(idx == KEY_LENGTH); 199 | } 200 | -------------------------------------------------------------------------------- /src/core/substitution.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #ifndef _SUBSTITUTION_H_ 17 | #define _SUBSTITUTION_H_ 18 | #include "xflow/ops.h" 19 | #include "rules.pb.h" 20 | #include 21 | namespace XFlow { 22 | 23 | enum Compare { 24 | COMPARE_EQ, 25 | COMPARE_NE, 26 | COMPARE_LT, 27 | COMPARE_LE, 28 | COMPARE_GT, 29 | COMPARE_GE, 30 | }; 31 | 32 | struct PMConstraint { 33 | PMConstraint(Compare comp, PMParameter para, int value); 34 | Compare comp; 35 | PMParameter para; 36 | int value; 37 | }; 38 | 39 | struct TNConstraint { 40 | TNConstraint(Compare comp, TNParameter para, DIMParameter dim, int value); 41 | TNConstraint(Compare comp, TNParameter para1, DIMParameter dim1, 42 | TNParameter para2, DIMParameter dim2); 43 | bool singlePara; 44 | Compare comp; 45 | TNParameter para1, para2; 46 | DIMParameter dim1, dim2; 47 | int value; 48 | }; 49 | 50 | class OpX; 51 | class GraphXfer; 52 | struct TensorX { 53 | TensorX(void): op(NULL), idx(0) {} 54 | TensorX(OpX* _op, int _idx): op(_op), idx(_idx) {} 55 | Tensor to_tensor(const GraphXfer* xfer) const; 56 | OpX* op; 57 | int idx; 58 | }; 59 | 60 | struct TensorXCompare { 61 | bool operator()(const TensorX& a, const TensorX& b) const { 62 | if (a.op != b.op) return a.op < b.op; 63 | return a.idx < b.idx; 64 | }; 65 | }; 66 | 67 | class OpX { 68 | public: 69 | OpX(OpType _type, TensorX input0, int numOutputs = 1); 70 | OpX(OpType _type, TensorX input0, TensorX input1); 71 | OpX(OpType _type, int n, TensorX* ins); 72 | bool add_pm_constraint(Compare comp, PMParameter para, int value); 73 | bool add_input_constraint(Compare, TNParameter, DIMParameter, int); 74 | bool add_input_constraint(Compare, TNParameter, DIMParameter, TNParameter, DIMParameter); 75 | bool get_pm_constraint(PMParameter para, int& value) const; 76 | public: 77 | OpType type; 78 | Op mapOp; 79 | std::vector inputs, outputs; 80 | std::vector pmConstraints; 81 | std::vector tnConstraints; 82 | }; 83 | 84 | class DstOp; 85 | class SrcOp { 86 | public: 87 | SrcOp(OpType _type); 88 | bool add_constraint(Compare comp, PMParameter para, int value); 89 | bool match(Op op); 90 | public: 91 | std::vector constraints; 92 | OpType type; 93 | Op mapOp; 94 | DstOp *mapInput, *mapOutput; 95 | }; 96 | 97 | class DstOp { 98 | public: 99 | DstOp(OpType _type); 100 | DstOp(OpType _type, const SrcOp* op); 101 | DstOp(OpType _type, const SrcOp* op1, const SrcOp* op2); 102 | virtual Op create_operator(Model* model) = 0; 103 | public: 104 | OpType type; 105 | Op mapOp; 106 | SrcOp *mapInput, *mapOutput; 107 | SrcOp *srcOps[MAX_NUM_INPUTS]; 108 | }; 109 | 110 | template 111 | struct SubEdge { 112 | SubEdge(OpType* _srcOp, OpType* _dstOp, int _srcIdx, int _dstIdx) 113 | : srcOp(_srcOp), dstOp(_dstOp), srcIdx(_srcIdx), dstIdx(_dstIdx) {} 114 | int srcIdx, dstIdx; 115 | OpType *srcOp, *dstOp; 116 | }; 117 | 118 | template 119 | struct SubEdgeCompare { 120 | bool operator()(const SubEdge& a, const SubEdge& b) const { 121 | if (a.srcOp != b.srcOp) return a.srcOp < b.srcOp; 122 | if (a.dstOp != b.dstOp) return a.dstOp < b.dstOp; 123 | if (a.srcIdx != b.srcIdx) return a.srcIdx < b.srcIdx; 124 | if (a.dstIdx != b.dstIdx) return a.dstIdx < b.dstIdx; 125 | return false; 126 | }; 127 | }; 128 | 129 | class GraphCompare { 130 | public: 131 | bool operator() (Graph* lhs, Graph* rhs) { 132 | return lhs->total_cost() > rhs->total_cost(); 133 | } 134 | }; 135 | 136 | class GraphXfer { 137 | public: 138 | GraphXfer(Model* _model); 139 | static void load_graph_xfer_from_pb_file(Model* model, 140 | std::vector& xfers, 141 | std::string filename); 142 | TensorX 
new_tensor(void); 143 | bool can_match(OpX* srcOp, Op op, Graph* graph); 144 | void match(OpX* srcOp, Op op, Graph* graph); 145 | void unmatch(OpX* srcOp, Op op, Graph* graph); 146 | void create_operator_from_pb(const GraphSubst::Operator& pbOp, 147 | std::map& mappedInputs, 148 | bool isSrcOp = true); 149 | OpX* create_activation(TensorX input, OpType type, bool isSrcOp = true); 150 | OpX* create_conv2d(TensorX input, TensorX weight, 151 | //int kernelH, int kernelW, 152 | int strideH, int strideW, 153 | PaddingMode padding, 154 | ActiMode activation, 155 | bool isSrcOp = true); 156 | OpX* create_element(TensorX input0, TensorX input1, 157 | OpType type, bool isSrcOp = true); 158 | OpX* create_pool2d_avg(TensorX input, TensorX weight, 159 | //int kernelH, int kernelW, 160 | int strideH, int strideW, 161 | PaddingMode padding, 162 | ActiMode activation, 163 | bool isSrcOp = true); 164 | OpX* create_matmul(TensorX input, TensorX weight, 165 | ActiMode activation, bool isSrcOp = true); 166 | OpX* create_mul(TensorX x, TensorX y, bool isSrcOp = true); 167 | OpX* create_transpose(TensorX input, int numDim, int* perm, int shuffle); 168 | OpX* create_enlarge(TensorX w1, TensorX w2, bool isSrcOp = true); 169 | OpX* create_merge_gconv(TensorX w, int count, bool isSrcOp = true); 170 | OpX* create_concat(int axis, int numDim, TensorX in1, TensorX in2, bool isSrcOp = true); 171 | OpX* create_concat(int axis, int numDim, int n, TensorX* ins, bool isSrcOp = true); 172 | OpX* create_split(TensorX input, int axis, int n, bool isSrcOp = true); 173 | void add_src_op(SrcOp* op); 174 | void add_dst_op(DstOp* op); 175 | void add_src_edge(SrcOp* src, SrcOp* tgt, int srcIdx = 0, int dstIdx = 0); 176 | void add_dst_edge(DstOp* src, DstOp* tgt, int srcIdx = 0, int dstIdx = 0); 177 | bool add_constraint(Compare comp, SrcOp* src, PMParameter srcPara, 178 | SrcOp* tgt, PMParameter dstPara); 179 | bool map_input(SrcOp* src, DstOp* dst); 180 | bool map_output(SrcOp* src, DstOp* dst); 181 | bool map_output(TensorX src, TensorX dst); 182 | void run(int depth, Graph* graph, 183 | std::priority_queue, GraphCompare>&, 184 | std::set&, float threshold, int maxNumOps); 185 | Graph* create_new_graph(Graph* graph); 186 | bool create_new_operator(const OpX* opx, Op& op); 187 | 188 | // built-in substitutions 189 | static GraphXfer* create_conv_relu(Model* model, int strideH, int strideW, PaddingMode padding); 190 | static GraphXfer* create_enlarge_merge_convs(Model* model, ActiMode activation); 191 | static GraphXfer* create_merge_group_convs(Model* model, int strideH, int strideW, ActiMode activation); 192 | public: 193 | Model* model; 194 | int tensorId; 195 | //std::vector constraints; 196 | //std::map, SubEdgeCompare > > srcInEdges, srcOutEdges; 197 | //std::map, SubEdgeCompare > > dstInEdges, dstOutEdges; 198 | std::map mappedOps; 199 | std::multimap > mappedInputs; 200 | std::map mappedOutputs; 201 | std::vector srcOps; 202 | std::vector dstOps; 203 | }; 204 | 205 | } // namespace XFlow 206 | #endif 207 | -------------------------------------------------------------------------------- /src/core/transpose.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | using namespace XFlow; 18 | 19 | int permutation_to_index(const std::vector& perm) 20 | { 21 | // check perm 22 | for (size_t i = 0; i < perm.size(); i++) { 23 | assert(perm[i] >= 0 && perm[i] < perm.size()); 24 | for (size_t j = i + 1; j < perm.size(); j++) 25 | assert(perm[i] != perm[j]); 26 | } 27 | int idx = 0; 28 | for (size_t i = 0; i < perm.size(); i++) 29 | idx = idx * perm.size() + perm[i]; 30 | return idx; 31 | } 32 | 33 | TensorHandle Graph::transpose(const TensorHandle _input, 34 | const std::vector& perm, 35 | bool _shuffle) 36 | { 37 | Op op = model->get_or_create_transpose(*_input, perm, _shuffle); 38 | assert(op != Op::INVALID_OP); 39 | add_edge(_input->op, op, _input->idx, 0); 40 | TensorHandle t = new Tensor(op.ptr->outputs[0]); 41 | t->op = op; 42 | return t; 43 | } 44 | 45 | Op Model::get_or_create_transpose(Tensor _input, int permIdx, 46 | bool _shuffle) 47 | { 48 | int ndim = _input.numDim; 49 | std::vector permVec; 50 | int permArray[MAX_DIM]; 51 | for (int i = ndim - 1; i >= 0; i--) { 52 | permArray[i] = permIdx % ndim; 53 | permIdx = permIdx / ndim; 54 | } 55 | if (permIdx != 0) { 56 | return Op::INVALID_OP; 57 | } 58 | for (int i = 0; i < ndim; i++) 59 | for (int j = i + 1; j < ndim; j++) 60 | if (permArray[i] != permArray[j]) { 61 | return Op::INVALID_OP; 62 | } 63 | for (int i = 0; i < ndim; i++) 64 | permVec.push_back(permArray[i]); 65 | return get_or_create_transpose(_input, permVec, _shuffle); 66 | } 67 | 68 | Op Model::get_or_create_transpose(Tensor _input, 69 | const std::vector& perm, 70 | bool _shuffle) 71 | { 72 | TransposeKey key(_input, perm, _shuffle); 73 | Transpose* transposeOp; 74 | if (transpose.find(key) != transpose.end()) { 75 | transposeOp = transpose[key]; 76 | } else { 77 | transposeOp = new Transpose(this, _input, perm, _shuffle); 78 | measure_transpose_cost(transposeOp); 79 | transpose[key] = transposeOp; 80 | } 81 | Op ret; 82 | ret.guid = global_unique_id ++; 83 | ret.ptr = transposeOp; 84 | return ret; 85 | } 86 | 87 | Transpose::Transpose(Model* _model, Tensor _input, 88 | const std::vector& _perm, 89 | bool _shuffle) 90 | : OpBase(_input, _model, OP_TRANSPOSE), shuffle(_shuffle) 91 | { 92 | assert(shuffle); 93 | permIdx = permutation_to_index(_perm); 94 | assert(_input.numDim == _perm.size()); 95 | numOutputs = 1; 96 | // set dims and strides 97 | outputs[0].numDim = _input.numDim; 98 | for (int i = 0; i < _perm.size(); i++) { 99 | outputs[0].dim[i] = _input.dim[_perm[i]]; 100 | outputs[0].split[i] = _input.split[_perm[i]]; 101 | } 102 | if (shuffle) { 103 | int size = 1; 104 | for (int i = _perm.size() - 1; i >= 0; i--) { 105 | outputs[0].stride[i] = size; 106 | size *= outputs[0].dim[i]; 107 | } 108 | assert(size == outputs[0].volume()); 109 | } else { 110 | for (int i = 0; i < _perm.size(); i++) 111 | outputs[0].stride[i] = _input.stride[_perm[i]]; 112 | } 113 | outputs[0].idx = 0; 114 | } 115 | 116 | Transpose::~Transpose(void) 117 | {} 118 | 119 | bool Transpose::get_parameter(PMParameter para, int* value) 120 | { 121 | switch (para) { 122 | case 
PM_NUMDIM: 123 | *value = outputs[0].numDim; 124 | return true; 125 | case PM_PERM: 126 | *value = permIdx; 127 | return true; 128 | case PM_OUTSHUFFLE: 129 | *value = (int) shuffle; 130 | return true; 131 | default: 132 | return OpBase::get_parameter(para, value); 133 | } 134 | } 135 | 136 | void Transpose::collect_costs(float& exe_time, float& flops, 137 | float& mem_acc, int& num_kernels) 138 | { 139 | if (shuffle) { 140 | exe_time += runtime; 141 | flops += outputs[0].volume(); 142 | mem_acc += outputs[0].volume(); 143 | num_kernels += 1; 144 | } 145 | } 146 | 147 | TransposeKey::TransposeKey(Tensor _input, 148 | const std::vector& perm, 149 | bool _shuffle) 150 | { 151 | int idx = 0; 152 | keys[idx++] = permutation_to_index(perm); 153 | keys[idx++] = (int) _shuffle; 154 | _input.serialize(keys, idx); 155 | while (idx < KEY_LENGTH) 156 | keys[idx++] = 0; 157 | assert(idx == KEY_LENGTH); 158 | } 159 | 160 | -------------------------------------------------------------------------------- /src/cudnn/activation_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Activation::map(void) 21 | { 22 | // create descriptors 23 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 24 | helperSetTensorDescriptor(inputs[0], inputTensor); 25 | checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); 26 | cudnnActivationMode_t mode; 27 | switch (type) { 28 | case OP_RELU: 29 | mode = CUDNN_ACTIVATION_RELU; 30 | break; 31 | case OP_SIGMOID: 32 | mode = CUDNN_ACTIVATION_SIGMOID; 33 | break; 34 | case OP_TANH: 35 | mode = CUDNN_ACTIVATION_TANH; 36 | break; 37 | default: 38 | assert(false); 39 | } 40 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 41 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 42 | if (!inPlace) { 43 | size_t outputSize = sizeof(DATATYPE); 44 | for (int i = 0; i < inputs[0].numDim; i++) 45 | outputSize *= inputs[0].dim[i]; 46 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 47 | } else { 48 | outputs[0].data_ptr = inputs[0].data_ptr; 49 | } 50 | } 51 | 52 | void Activation::unmap(void) 53 | { 54 | checkCUDNN(cudnnDestroyTensorDescriptor(inputTensor)); 55 | checkCUDNN(cudnnDestroyActivationDescriptor(actiDesc)); 56 | if (!inPlace) { 57 | checkCUDA(cudaFree(outputs[0].data_ptr)); 58 | } 59 | } 60 | 61 | void Activation::forward(bool block) 62 | { 63 | const float alpha = 1.0f; 64 | const float beta = 0.0f; 65 | checkCUDNN(cudnnActivationForward(model->dnn, actiDesc, 66 | &alpha, inputTensor, inputs[0].data_ptr, 67 | &beta, inputTensor, outputs[0].data_ptr)); 68 | if (block) 69 | checkCUDA(cudaDeviceSynchronize()); 70 | } 71 | 72 | void Model::measure_activation_cost(Activation* act) 73 | { 74 | const float alpha = 1.0f; 75 | const float beta = 0.0f; 76 | helperSetTensorDescriptor(act->inputs[0], inputTensor); 
77 | cudnnActivationMode_t mode; 78 | switch (act->type) { 79 | case OP_RELU: 80 | mode = CUDNN_ACTIVATION_RELU; 81 | break; 82 | case OP_SIGMOID: 83 | mode = CUDNN_ACTIVATION_SIGMOID; 84 | break; 85 | case OP_TANH: 86 | mode = CUDNN_ACTIVATION_TANH; 87 | break; 88 | default: 89 | assert(false); 90 | } 91 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 92 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 93 | checkCUDA(cudaDeviceSynchronize()); 94 | checkCUDA(cudaEventRecord(startEvent)); 95 | for (int i = 0; i < REPEAT_TIMES; i++) { 96 | if (act->inPlace) { 97 | checkCUDNN(cudnnActivationForward(dnn, actiDesc, 98 | &alpha, inputTensor, inputPtr, 99 | &beta, inputTensor, inputPtr)); 100 | } else { 101 | checkCUDNN(cudnnActivationForward(dnn, actiDesc, 102 | &alpha, inputTensor, inputPtr, 103 | &beta, inputTensor, outputPtr)); 104 | } 105 | } 106 | checkCUDA(cudaEventRecord(endEvent)); 107 | checkCUDA(cudaEventSynchronize(endEvent)); 108 | float milliseconds; 109 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 110 | act->runtime = milliseconds / REPEAT_TIMES; 111 | if (print_cost) 112 | printf(" measure[Activation]: i(%d %d %d %d) type(%d) cost(%.4lf)\n", 113 | act->inputs[0].dim[0], act->inputs[0].dim[1], act->inputs[0].dim[2], 114 | act->inputs[0].dim[3], act->type, act->runtime); 115 | } 116 | 117 | -------------------------------------------------------------------------------- /src/cudnn/batchnorm_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void BatchNorm::map(void) 21 | { 22 | assert(inputs[0].numDim == 4); 23 | // create descriptors 24 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 25 | checkCUDNN(cudnnCreateTensorDescriptor(&biasTensor)); 26 | checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); 27 | int inputN = inputs[0].dim[0]; 28 | int inputC = inputs[0].dim[1]; 29 | int inputH = inputs[0].dim[2]; 30 | int inputW = inputs[0].dim[3]; 31 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 32 | CUDNN_DATA_FLOAT, inputN, inputC, inputH, inputW)); 33 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 34 | CUDNN_DATA_FLOAT, inputN, inputC, inputH, inputW)); 35 | checkCUDNN(cudnnSetTensor4dDescriptor(biasTensor, CUDNN_TENSOR_NCHW, 36 | CUDNN_DATA_FLOAT, 1, inputC, 1, 1)); 37 | #ifdef DO_TRAINING 38 | checkCUDA(cudaMalloc(&runningMean, sizeof(DATATYPE) * inputC)); 39 | checkCUDA(cudaMalloc(&runningVar, sizeof(DATATYPE) * inputC)); 40 | checkCUDA(cudaMalloc(&saveMean, sizeof(DATATYPE) * inputC)); 41 | checkCUDA(cudaMalloc(&saveVar, sizeof(DATATYPE) * inputC)); 42 | checkCUDA(cudaMalloc(&biasPtr, sizeof(DATATYPE) * inputC)); 43 | checkCUDA(cudaMalloc(&scalePtr, sizeof(DATATYPE) * inputC)); 44 | // initialize scale to ones and bias to zeros 45 | assign_kernel<<>>( 46 | scalePtr, inputC, 1.0f); 47 | assign_kernel<<>>( 48 | biasPtr, inputC, 0.0f); 49 | assign_kernel<<>>( 50 | runningMean, inputC, 0.0f); 51 | assign_kernel<<>>( 52 | runningVar, inputC, 0.0f); 53 | #endif 54 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 55 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 56 | } 57 | 58 | void BatchNorm::unmap(void) 59 | { 60 | checkCUDNN(cudnnDestroyTensorDescriptor(inputTensor)); 61 | checkCUDNN(cudnnDestroyTensorDescriptor(biasTensor)); 62 | checkCUDNN(cudnnDestroyTensorDescriptor(outputTensor)); 63 | #ifdef DO_TRAINING 64 | checkCUDA(cudaFree(runningMean)); 65 | checkCUDA(cudaFree(runningVar)); 66 | checkCUDA(cudaFree(saveMean)); 67 | checkCUDA(cudaFree(saveVar)); 68 | checkCUDA(cudaFree(biasPtr)); 69 | checkCUDA(cudaFree(scalePtr)); 70 | checkCUDA(cudaFree(outputs[0].data_ptr)); 71 | #endif 72 | } 73 | 74 | void BatchNorm::forward(bool block) 75 | { 76 | const float alpha = 1.0f; 77 | const float beta = 0.0f; 78 | cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; 79 | int inputC = inputs[0].dim[1]; 80 | #ifdef DO_TRAINING 81 | if (model->isTraining) { 82 | assign_kernel<<>>( 83 | runningMean, inputC, 0.0f); 84 | assign_kernel<<>>( 85 | runningVar, inputC, 0.0f); 86 | checkCUDNN(cudnnBatchNormalizationForwardTraining( 87 | model->dnn, mode, &alpha, &beta, inputTensor, inputs[0].data_ptr, 88 | outputTensor, outputs[0].data_ptr, biasTensor, scalePtr, biasPtr, 89 | 1.0, runningMean, runningVar, CUDNN_BN_MIN_EPSILON, saveMean, saveVar)); 90 | } else { 91 | #endif 92 | checkCUDNN(cudnnBatchNormalizationForwardInference( 93 | model->dnn, mode, &alpha, &beta, inputTensor, inputs[0].data_ptr, 94 | outputTensor, outputs[0].data_ptr, biasTensor, inputs[1].data_ptr, inputs[2].data_ptr, 95 | inputs[3].data_ptr, inputs[4].data_ptr, CUDNN_BN_MIN_EPSILON)); 96 | #ifdef DO_TRAINING 97 | } 98 | #endif 99 | if (block) 100 | checkCUDA(cudaDeviceSynchronize()); 101 | } 102 | 103 | void Model::measure_batchnorm_cost(BatchNorm* bn) 104 | { 105 | const float alpha = 1.0f; 106 | const float beta = 0.0f; 107 | int inputC = bn->inputs[0].dim[1]; 108 | int inputH = 
bn->inputs[0].dim[2]; 109 | int inputW = bn->inputs[0].dim[3]; 110 | cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; 111 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 112 | CUDNN_DATA_FLOAT, BATCH_SIZE, inputC, inputH, inputW)); 113 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 114 | CUDNN_DATA_FLOAT, BATCH_SIZE, inputC, inputH, inputW)); 115 | checkCUDNN(cudnnSetTensor4dDescriptor(biasTensor, CUDNN_TENSOR_NCHW, 116 | CUDNN_DATA_FLOAT, 1, inputC, 1, 1)); 117 | #ifdef DO_TRAINING 118 | assign_kernel<<>>( 119 | scalePtr, inputC, 0.5f); 120 | assign_kernel<<>>( 121 | biasPtr, inputC, 0.5f); 122 | #endif 123 | checkCUDA(cudaDeviceSynchronize()); 124 | checkCUDA(cudaEventRecord(startEvent)); 125 | for (int i = 0; i < REPEAT_TIMES; i++) { 126 | #ifdef DO_TRAINING 127 | if (isTraining) { 128 | assign_kernel<<>>( 129 | runningMean, inputC, 0.0f); 130 | assign_kernel<<>>( 131 | runningVar, inputC, 0.0f); 132 | checkCUDNN(cudnnBatchNormalizationForwardTraining( 133 | dnn, mode, &alpha, &beta, inputTensor, inputPtr, 134 | outputTensor, outputPtr, biasTensor, scalePtr, biasPtr, 135 | 1.0, runningMean, runningVar, CUDNN_BN_MIN_EPSILON, 136 | saveMean, saveVar)); 137 | } else { 138 | #endif 139 | checkCUDNN(cudnnBatchNormalizationForwardInference( 140 | dnn, mode, &alpha, &beta, inputTensor, inputPtr, 141 | outputTensor, outputPtr, biasTensor, scalePtr, biasPtr, 142 | runningMean, runningVar, CUDNN_BN_MIN_EPSILON)); 143 | #ifdef DO_TRAINING 144 | } 145 | #endif 146 | } 147 | checkCUDA(cudaEventRecord(endEvent)); 148 | checkCUDA(cudaEventSynchronize(endEvent)); 149 | float milliseconds; 150 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 151 | bn->runtime = milliseconds / REPEAT_TIMES; 152 | printf("measure[BatchNorm]: i(%d %d %d %d) cost(%.4lf)\n", 153 | BATCH_SIZE, bn->inputs[0].dim[1], bn->inputs[0].dim[2], 154 | bn->inputs[0].dim[3], bn->runtime); 155 | } 156 | 157 | -------------------------------------------------------------------------------- /src/cudnn/concat_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Concat::map(void) 21 | { 22 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 23 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 24 | } 25 | 26 | void Concat::unmap(void) 27 | { 28 | checkCUDA(cudaFree(outputs[0].data_ptr)); 29 | } 30 | 31 | __global__ 32 | void assign_with_stride(DATATYPE* dst, 33 | const DATATYPE* src, 34 | int num_blocks, 35 | int dst_blk_size, 36 | int src_blk_size) 37 | { 38 | assert(src_blk_size <= dst_blk_size); 39 | CUDA_KERNEL_LOOP(i, num_blocks * src_blk_size) 40 | { 41 | int blk_idx = i / src_blk_size; 42 | int blk_offset = i % src_blk_size; 43 | int src_offset = blk_idx * src_blk_size + blk_offset; 44 | int dst_offset = blk_idx * dst_blk_size + blk_offset; 45 | dst[dst_offset] = src[src_offset]; 46 | } 47 | } 48 | 49 | void Concat::forward(bool block) 50 | { 51 | int offset = 0; 52 | for (int i = 0; i < numInputs; i++) 53 | if (needCopy[i]) { 54 | int dst_blk_size = 1, src_blk_size = 1, num_blocks = 1; 55 | for (int j = inputs[i].numDim-1; j >= 0; j--) 56 | if (j >= axis) { 57 | dst_blk_size *= outputs[0].dim[j]; 58 | src_blk_size *= inputs[i].dim[j]; 59 | } else { 60 | num_blocks *= outputs[0].dim[j]; 61 | } 62 | assert(inputs[i].data_ptr != NULL); 63 | assign_with_stride<<>>( 64 | ((DATATYPE*)outputs[0].data_ptr) + offset, (DATATYPE*)inputs[i].data_ptr, 65 | num_blocks, dst_blk_size, src_blk_size); 66 | offset += src_blk_size; 67 | } 68 | if (block) 69 | checkCUDA(cudaDeviceSynchronize()); 70 | //FIXME 71 | //DATATYPE* print_vals = (DATATYPE*) malloc(outputs[0].volume() * sizeof(DATATYPE)); 72 | //checkCUDA(cudaMemcpy(print_vals, outputs[0].data_ptr, outputs[0].volume() * sizeof(DATATYPE), cudaMemcpyDefault)); 73 | //for (int i = 0; i < outputs[0].volume(); i++) 74 | // printf("output[%d]: %.4lf\n", i, print_vals[i]); 75 | //for (int i = 0; i < numInputs; i++) { 76 | // checkCUDA(cudaMemcpy(print_vals, inputs[i].data_ptr, inputs[i].volume() * sizeof(DATATYPE), cudaMemcpyDefault)); 77 | // printf("concat_forward: inputs[%d].ptr=%p\n", i, inputs[i].data_ptr); 78 | // for (int j = 0; j < inputs[i].volume(); j++) 79 | // printf("input[%d][%d]: %.4lf\n", i, j, print_vals[j]); 80 | //} 81 | } 82 | 83 | void Model::measure_concat_cost(Concat* concat) 84 | { 85 | checkCUDA(cudaDeviceSynchronize()); 86 | checkCUDA(cudaEventRecord(startEvent)); 87 | for (int i = 0; i < REPEAT_TIMES; i++) { 88 | int offset = 0; 89 | // TODO: remove needCopy and should not include operators 90 | // that can be preproceed 91 | for (int j = 0; j < concat->numInputs; j++) 92 | if (concat->needCopy[j]) { 93 | int dst_blk_size = 1, src_blk_size = 1, num_blocks = 1; 94 | for (int d = concat->inputs[j].numDim-1; d >= 0; d--) 95 | if (d >= concat->axis) { 96 | dst_blk_size *= concat->outputs[0].dim[d]; 97 | src_blk_size *= concat->inputs[j].dim[d]; 98 | } else { 99 | num_blocks *= concat->outputs[0].dim[d]; 100 | } 101 | assign_with_stride<<>>( 102 | ((DATATYPE*)outputPtr) + offset, (DATATYPE*)inputPtr, 103 | num_blocks, dst_blk_size, src_blk_size); 104 | offset += src_blk_size; 105 | } 106 | } 107 | checkCUDA(cudaEventRecord(endEvent)); 108 | checkCUDA(cudaEventSynchronize(endEvent)); 109 | float milliseconds; 110 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 111 | concat->runtime = milliseconds / REPEAT_TIMES; 112 | if (print_cost) 113 | printf(" measure[Concat]: cost(%.4lf)\n", concat->runtime); 114 | } 115 | 116 | 
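Note: the copy in Concat::forward above views each input as num_blocks contiguous runs of src_blk_size elements and writes run k of an input into run k of the output, whose runs are dst_blk_size elements long; the running offset marks where that input starts inside each output run, which is why offset advances by src_blk_size per copied input. Below is a small host-side sketch of the same indexing, given only as an illustration; it is not part of the source tree.

#include <cassert>

// Host-side reference for the indexing done by assign_with_stride.
// As in Concat::forward, the caller passes dst = output_base + offset,
// so element j of block blk of src lands at
// offset + blk * dst_blk_size + j in the full output buffer.
static void assign_with_stride_reference(float* dst, const float* src,
                                         int num_blocks, int dst_blk_size,
                                         int src_blk_size)
{
  assert(src_blk_size <= dst_blk_size);
  for (int blk = 0; blk < num_blocks; blk++)
    for (int j = 0; j < src_blk_size; j++)
      dst[blk * dst_blk_size + j] = src[blk * src_blk_size + j];
}

Because consecutive inputs are copied with offsets that grow by their src_blk_size, they occupy consecutive sub-ranges of every output block, which realizes concatenation along the chosen axis.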
-------------------------------------------------------------------------------- /src/cudnn/constant_kernel.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Constant::map(void) 21 | { 22 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 23 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 24 | } 25 | 26 | void Constant::unmap(void) 27 | { 28 | checkCUDA(cudaFree(outputs[0].data_ptr)); 29 | } 30 | 31 | void Constant::forward(bool block) 32 | { 33 | if (block) 34 | checkCUDA(cudaDeviceSynchronize()); 35 | } 36 | 37 | -------------------------------------------------------------------------------- /src/cudnn/conv2d_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Conv2D::map(void) 21 | { 22 | // create descriptors 23 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 24 | checkCUDNN(cudnnCreateTensorDescriptor(&biasTensor)); 25 | checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); 26 | checkCUDNN(cudnnCreateFilterDescriptor(&filterDesc)); 27 | checkCUDNN(cudnnCreateConvolutionDescriptor(&convDesc)); 28 | int inputN = inputs[0].dim[0]; 29 | int inputC = inputs[0].dim[1]; 30 | int inputH = inputs[0].dim[2]; 31 | int inputW = inputs[0].dim[3]; 32 | int outputC = inputs[1].dim[0]; 33 | int groups = inputs[0].dim[1] / inputs[1].dim[1]; 34 | int padH, padW; 35 | get_padding(&padH, &padW); 36 | // set descriptors 37 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 38 | CUDNN_DATA_FLOAT, inputN, inputC, inputH, inputW)); 39 | checkCUDNN(cudnnSetTensor4dDescriptor(biasTensor, CUDNN_TENSOR_NCHW, 40 | CUDNN_DATA_FLOAT, 1, outputC, 1, 1)); 41 | checkCUDNN(cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_FLOAT, 42 | CUDNN_TENSOR_NCHW, inputs[1].dim[0], inputs[1].dim[1], 43 | inputs[1].dim[2], inputs[1].dim[3])); 44 | checkCUDNN(cudnnSetConvolution2dDescriptor(convDesc, padH, padW, 45 | strideH, strideW, 1/*dilationH*/, 1/*dilationW*/, 46 | CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); 47 | checkCUDNN(cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH)); 48 | if (groups != 1) { 49 | checkCUDNN(cudnnSetConvolutionGroupCount(convDesc, groups)); 50 | } 51 | int n, c, h, w; 52 | checkCUDNN(cudnnGetConvolution2dForwardOutputDim(convDesc, 53 | inputTensor, filterDesc, &n, &c, &h, &w)); 54 | assert(n == inputN); 55 | assert(c == outputC); 56 | assert(outputs[0].dim[2] == h); 57 | assert(outputs[0].dim[3] == w); 58 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 59 | CUDNN_DATA_FLOAT, n, c, h, w)); 60 | if (activation != AC_MODE_NONE) { 61 | checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); 62 | cudnnActivationMode_t mode = get_activation_mode(activation); 63 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 64 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 65 | } 66 | // allocate tensors 67 | size_t outputSize = sizeof(DATATYPE) * n * c * h * w; 68 | size_t biasSize = sizeof(DATATYPE) * outputC; 69 | checkCUDA(cudaMalloc(&biasPtr, biasSize)); 70 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 71 | } 72 | 73 | void Conv2D::unmap(void) 74 | { 75 | checkCUDNN(cudnnDestroyTensorDescriptor(inputTensor)); 76 | checkCUDNN(cudnnDestroyTensorDescriptor(biasTensor)); 77 | checkCUDNN(cudnnDestroyTensorDescriptor(outputTensor)); 78 | checkCUDNN(cudnnDestroyFilterDescriptor(filterDesc)); 79 | checkCUDNN(cudnnDestroyConvolutionDescriptor(convDesc)); 80 | if (activation != AC_MODE_NONE) { 81 | checkCUDNN(cudnnDestroyActivationDescriptor(actiDesc)); 82 | } 83 | // free tensors 84 | checkCUDA(cudaFree(outputs[0].data_ptr)); 85 | checkCUDA(cudaFree(biasPtr)); 86 | } 87 | 88 | void Conv2D::forward(bool block) 89 | { 90 | const float alpha = 1.0f; 91 | const float beta = 0.0f; 92 | if (activation != AC_MODE_NONE) { 93 | checkCUDNN(cudnnConvolutionBiasActivationForward( 94 | model->dnn, &alpha, inputTensor, inputs[0].data_ptr, filterDesc, inputs[1].data_ptr, 95 | convDesc, fwdAlgo, model->workSpace, model->workSpaceSize, 96 | &beta, outputTensor, outputs[0].data_ptr, biasTensor, biasPtr, actiDesc, 97 | outputTensor, outputs[0].data_ptr)); 98 | } else { 99 | checkCUDNN(cudnnConvolutionForward( 
100 | model->dnn, &alpha, inputTensor, inputs[0].data_ptr, filterDesc, inputs[1].data_ptr, 101 | convDesc, fwdAlgo, model->workSpace, model->workSpaceSize, 102 | &beta, outputTensor, outputs[0].data_ptr)); 103 | checkCUDNN(cudnnAddTensor(model->dnn, &alpha, biasTensor, biasPtr, 104 | &alpha, outputTensor, outputs[0].data_ptr)); 105 | } 106 | if (block) 107 | checkCUDA(cudaDeviceSynchronize()); 108 | } 109 | 110 | void Model::measure_conv2d_cost(Conv2D* conv) 111 | { 112 | const float alpha = 1.0f; 113 | const float beta = 0.0f; 114 | int inputN = conv->inputs[0].dim[0]; 115 | int inputC = conv->inputs[0].dim[1]; 116 | int inputH = conv->inputs[0].dim[2]; 117 | int inputW = conv->inputs[0].dim[3]; 118 | int kernelH = conv->inputs[1].dim[2]; 119 | int kernelW = conv->inputs[1].dim[3]; 120 | int outputC = conv->outputs[0].dim[1]; 121 | int outputH = conv->outputs[0].dim[2]; 122 | int outputW = conv->outputs[0].dim[3]; 123 | int groups = conv->inputs[0].dim[1] / conv->inputs[1].dim[1]; 124 | int padH, padW; 125 | // Reference: https://www.tensorflow.org/api_guides/python/nn#Convolution 126 | switch (conv->padding) { 127 | case PD_MODE_SAME: 128 | int totalPadH, totalPadW; 129 | if (inputH % conv->strideH == 0) 130 | totalPadH = max(kernelH - conv->strideH, 0); 131 | else 132 | totalPadH = max(kernelH - (inputH % conv->strideH), 0); 133 | if (inputW % conv->strideW == 0) 134 | totalPadW = max(kernelW - conv->strideW, 0); 135 | else 136 | totalPadW = max(kernelW - (inputW % conv->strideW), 0); 137 | // assert same padding on both sides 138 | padH = (totalPadH + 1) / 2; 139 | padW = (totalPadW + 1) / 2; 140 | break; 141 | case PD_MODE_VALID: 142 | padH = 0; 143 | padW = 0; 144 | break; 145 | default: 146 | assert(false); 147 | } 148 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 149 | CUDNN_DATA_FLOAT, inputN, inputC, inputH, inputW)); 150 | checkCUDNN(cudnnSetTensor4dDescriptor(biasTensor, CUDNN_TENSOR_NCHW, 151 | CUDNN_DATA_FLOAT, 1, outputC, 1, 1)); 152 | checkCUDNN(cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_FLOAT, 153 | CUDNN_TENSOR_NCHW, conv->inputs[1].dim[0], conv->inputs[1].dim[1], 154 | conv->inputs[1].dim[2], conv->inputs[1].dim[3])); 155 | checkCUDNN(cudnnSetConvolution2dDescriptor(convDesc, padH, padW, 156 | conv->strideH, conv->strideW, 1/*dilationH*/, 1/*dilationW*/, 157 | CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)); 158 | checkCUDNN(cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH)); 159 | checkCUDNN(cudnnSetConvolutionGroupCount(convDesc, groups)); 160 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, CUDNN_ACTIVATION_RELU, 161 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 162 | int n, c, h, w; 163 | checkCUDNN(cudnnGetConvolution2dForwardOutputDim(convDesc, 164 | inputTensor, filterDesc, &n, &c, &h, &w)); 165 | assert(n == inputN); 166 | assert(c == outputC); 167 | assert(outputH == h); 168 | assert(outputW == w); 169 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 170 | CUDNN_DATA_FLOAT, n, c, h, w)); 171 | size_t inputSize = sizeof(DATATYPE) * inputN * inputC * inputH * inputW; 172 | size_t filterSize = sizeof(DATATYPE) * inputC * outputC 173 | * kernelH * kernelW; 174 | size_t outputSize = sizeof(DATATYPE) * n * c * h * w; 175 | assert(inputSize < MAX_TENSOR_SIZE); 176 | assert(filterSize < MAX_TENSOR_SIZE); 177 | assert(outputSize < MAX_TENSOR_SIZE); 178 | 179 | const int reqAlgCnt = 8; 180 | int cnt = 0; 181 | cudnnConvolutionFwdAlgoPerf_t perfResults[reqAlgCnt]; 182 | checkCUDNN(cudnnFindConvolutionForwardAlgorithmEx( 
183 | dnn, inputTensor, inputPtr, filterDesc, filterPtr, convDesc, 184 | outputTensor, outputPtr, reqAlgCnt, &cnt, perfResults, 185 | workSpace, workSpaceSize)); 186 | assert(cnt > 0); 187 | checkCUDNN(perfResults[0].status); 188 | //for (int i = 0; i < cnt; i++) { 189 | //printf("fwdAlgo(%d) time(%.2lfms) space(%dMB)\n", perfResults[i].algo, 190 | // perfResults[i].time, perfResults[i].memory / 1024 / 1024); 191 | //} 192 | conv->fwdAlgo = perfResults[0].algo; 193 | 194 | checkCUDA(cudaDeviceSynchronize()); 195 | for (int i = 0; i < WARMUP_TIMES + REPEAT_TIMES; i++) { 196 | if (i == WARMUP_TIMES) { 197 | checkCUDA(cudaEventRecord(startEvent)); 198 | } 199 | if (conv->activation != AC_MODE_NONE) { 200 | checkCUDNN(cudnnConvolutionBiasActivationForward( 201 | dnn, &alpha, inputTensor, inputPtr, filterDesc, filterPtr, 202 | convDesc, conv->fwdAlgo, workSpace, workSpaceSize, 203 | &beta, outputTensor, outputPtr, biasTensor, biasPtr, actiDesc, 204 | outputTensor, outputPtr)); 205 | } else { 206 | checkCUDNN(cudnnConvolutionForward( 207 | dnn, &alpha, inputTensor, inputPtr, filterDesc, filterPtr, 208 | convDesc, conv->fwdAlgo, workSpace, workSpaceSize, 209 | &beta, outputTensor, outputPtr)); 210 | checkCUDNN(cudnnAddTensor(dnn, &alpha, biasTensor, biasPtr, 211 | &alpha, outputTensor, outputPtr)); 212 | } 213 | } 214 | checkCUDA(cudaEventRecord(endEvent)); 215 | checkCUDA(cudaEventSynchronize(endEvent)); 216 | float milliseconds; 217 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 218 | conv->runtime = milliseconds / REPEAT_TIMES; 219 | if (print_cost) 220 | printf(" measure[Conv2D]: i(%d %d %d %d) w(%d %d %d %d) s(%d %d) p(%d %d) cost(%.4lf)\n", 221 | conv->inputs[0].dim[0], conv->inputs[0].dim[1], conv->inputs[0].dim[2], conv->inputs[0].dim[3], 222 | conv->inputs[1].dim[0], conv->inputs[1].dim[1], conv->inputs[1].dim[2], conv->inputs[1].dim[3], 223 | conv->strideH, conv->strideW, padH, padW, conv->runtime); 224 | } 225 | 226 | -------------------------------------------------------------------------------- /src/cudnn/cuda_helper.cu: -------------------------------------------------------------------------------- 1 | #include "xflow/cuda_helper.h" 2 | using namespace XFlow; 3 | 4 | __global__ 5 | void assign_kernel(float* ptr, int size, float value) 6 | { 7 | CUDA_KERNEL_LOOP(i, size) 8 | { 9 | ptr[i] = value; 10 | } 11 | } 12 | 13 | cudnnActivationMode_t get_activation_mode(ActiMode activation) 14 | { 15 | switch (activation) { 16 | case AC_MODE_SIGMOID: 17 | return CUDNN_ACTIVATION_SIGMOID; 18 | case AC_MODE_RELU: 19 | return CUDNN_ACTIVATION_RELU; 20 | case AC_MODE_TANH: 21 | return CUDNN_ACTIVATION_TANH; 22 | default: 23 | assert(false); 24 | } 25 | // return RELU as default 26 | return CUDNN_ACTIVATION_RELU; 27 | } 28 | 29 | void helperSetTensorDescriptor(const Tensor& tensor, 30 | cudnnTensorDescriptor_t tensorDesc) 31 | { 32 | switch(tensor.numDim) { 33 | case 1: 34 | { 35 | int dims[] = {tensor.dim[0], 1, 1}; 36 | int strides[] = {tensor.stride[0], 1, 1}; 37 | checkCUDNN(cudnnSetTensorNdDescriptor(tensorDesc, CUDNN_DATA_FLOAT, 38 | 3, dims, strides)); 39 | break; 40 | } 41 | case 2: 42 | { 43 | int dims[] = {tensor.dim[0], tensor.dim[1], 1, 1}; 44 | int strides[] = {tensor.stride[0], tensor.stride[1], 1, 1}; 45 | checkCUDNN(cudnnSetTensorNdDescriptor(tensorDesc, CUDNN_DATA_FLOAT, 46 | 4, dims, strides)); 47 | break; 48 | } 49 | default: 50 | { 51 | assert(tensor.numDim >= 3); 52 | checkCUDNN(cudnnSetTensorNdDescriptor(tensorDesc, CUDNN_DATA_FLOAT, 53 | tensor.numDim, 
tensor.dim, tensor.stride)); 54 | } 55 | } 56 | } 57 | 58 | -------------------------------------------------------------------------------- /src/cudnn/element_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Element::map(void) 21 | { 22 | // create descriptors 23 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 24 | checkCUDNN(cudnnCreateOpTensorDescriptor(&opDesc)); 25 | // set descriptors 26 | helperSetTensorDescriptor(inputs[0], inputTensor); 27 | 28 | cudnnOpTensorOp_t opType; 29 | switch (type) { 30 | case OP_EW_ADD: 31 | opType = CUDNN_OP_TENSOR_ADD; 32 | break; 33 | case OP_EW_MUL: 34 | opType = CUDNN_OP_TENSOR_MUL; 35 | break; 36 | default: 37 | assert(false); 38 | } 39 | checkCUDNN(cudnnSetOpTensorDescriptor(opDesc, opType, CUDNN_DATA_FLOAT, 40 | CUDNN_NOT_PROPAGATE_NAN)); 41 | // allocate tensors 42 | size_t outputSize = sizeof(DATATYPE); 43 | for (int i = 0; i < outputs[0].numDim; i++) 44 | outputSize *= outputs[0].dim[i]; 45 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 46 | } 47 | 48 | void Element::unmap(void) 49 | { 50 | checkCUDNN(cudnnDestroyTensorDescriptor(inputTensor)); 51 | checkCUDNN(cudnnDestroyOpTensorDescriptor(opDesc)); 52 | checkCUDA(cudaFree(outputs[0].data_ptr)); 53 | } 54 | 55 | void Element::forward(bool block) 56 | { 57 | const float alpha = 1.0f; 58 | const float beta = 0.0f; 59 | checkCUDNN(cudnnOpTensor(model->dnn, opDesc, &alpha, inputTensor, inputs[0].data_ptr, 60 | &alpha, inputTensor, inputs[1].data_ptr, &beta, inputTensor, outputs[0].data_ptr)); 61 | if (block) 62 | checkCUDA(cudaDeviceSynchronize()); 63 | } 64 | 65 | void Model::measure_element_cost(Element* ele) 66 | { 67 | const float alpha = 1.0f; 68 | const float beta = 0.0f; 69 | helperSetTensorDescriptor(ele->inputs[0], inputTensor); 70 | //int inputN = ele->inputs[0].dim[0]; 71 | //int inputC = max(ele->inputs[0].dim[1], 1); 72 | //int inputH = max(ele->inputs[0].dim[2], 1); 73 | //int inputW = max(ele->inputs[0].dim[3], 1); 74 | //checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 75 | // CUDNN_DATA_FLOAT, inputN, inputC, inputH, inputW)); 76 | 77 | cudnnOpTensorOp_t opType; 78 | switch (ele->type) { 79 | case OP_EW_ADD: 80 | opType = CUDNN_OP_TENSOR_ADD; 81 | break; 82 | case OP_EW_MUL: 83 | opType = CUDNN_OP_TENSOR_MUL; 84 | break; 85 | default: 86 | assert(false); 87 | } 88 | checkCUDNN(cudnnSetOpTensorDescriptor(opDesc, opType, CUDNN_DATA_FLOAT, 89 | CUDNN_NOT_PROPAGATE_NAN)); 90 | 91 | checkCUDA(cudaDeviceSynchronize()); 92 | checkCUDA(cudaEventRecord(startEvent)); 93 | for (int i = 0; i < REPEAT_TIMES; i++) { 94 | checkCUDNN(cudnnOpTensor(dnn, opDesc, &alpha, inputTensor, inputPtr, 95 | &alpha, inputTensor, filterPtr, &beta, inputTensor, outputPtr)); 96 | } 97 | 
checkCUDA(cudaEventRecord(endEvent)); 98 | checkCUDA(cudaEventSynchronize(endEvent)); 99 | float milliseconds; 100 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 101 | ele->runtime = milliseconds / REPEAT_TIMES; 102 | if (print_cost) 103 | printf(" measure[Element]: i(%d %d %d %d) type(%d) cost(%.4lf)\n", 104 | ele->inputs[0].dim[0], ele->inputs[0].dim[1], ele->inputs[0].dim[2], 105 | ele->inputs[0].dim[3], ele->type, ele->runtime); 106 | } 107 | 108 | -------------------------------------------------------------------------------- /src/cudnn/enlarge_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | __global__ 21 | void enlarge_kernel(DATATYPE* dst_ptr, 22 | const DATATYPE* src_ptr, 23 | int volume, 24 | int dst_h, 25 | int dst_w, 26 | int src_h, 27 | int src_w) 28 | { 29 | int off_h = (dst_h - src_h) / 2; 30 | int off_w = (dst_w - src_w) / 2; 31 | CUDA_KERNEL_LOOP(i, volume) 32 | { 33 | int h = (i % (dst_h * dst_w)) / dst_w - off_h; 34 | int w = (i % (dst_h * dst_w)) % dst_w - off_w; 35 | if ((h < 0) || (h >= src_h) || (w < 0) || (w >= src_w)) 36 | dst_ptr[i] = 0.0f; 37 | else { 38 | int offset = (i / (dst_h * dst_w)) * (src_h * src_w) + h * src_w + w; 39 | dst_ptr[i] = src_ptr[offset]; 40 | } 41 | } 42 | } 43 | 44 | void Enlarge::map(void) 45 | { 46 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 47 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 48 | } 49 | 50 | void Enlarge::unmap(void) 51 | { 52 | checkCUDA(cudaFree(outputs[0].data_ptr)); 53 | } 54 | 55 | void Enlarge::forward(bool block) 56 | { 57 | enlarge_kernel<<<GET_BLOCKS(outputs[0].volume()), CUDA_NUM_THREADS>>>( 58 | (DATATYPE*)outputs[0].data_ptr, (DATATYPE*)inputs[0].data_ptr, outputs[0].volume(), 59 | outputs[0].dim[2], outputs[0].dim[3], inputs[0].dim[2], inputs[0].dim[3]); 60 | if (block) 61 | checkCUDA(cudaDeviceSynchronize()); 62 | } 63 | 64 | void Model::measure_enlarge_cost(Enlarge* enl) 65 | { 66 | enl->runtime = 0.0f; 67 | } 68 | -------------------------------------------------------------------------------- /src/cudnn/matmul_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Matmul::map(void) 21 | { 22 | // create descriptors 23 | checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); 24 | helperSetTensorDescriptor(outputs[0], outputTensor); 25 | if (activation != AC_MODE_NONE) { 26 | cudnnActivationMode_t mode; 27 | switch (activation) { 28 | case AC_MODE_SIGMOID: 29 | mode = CUDNN_ACTIVATION_SIGMOID; 30 | break; 31 | case AC_MODE_RELU: 32 | mode = CUDNN_ACTIVATION_RELU; 33 | break; 34 | case AC_MODE_TANH: 35 | mode = CUDNN_ACTIVATION_TANH; 36 | break; 37 | default: 38 | assert(false); 39 | } 40 | checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); 41 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 42 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 43 | } 44 | // allocate tensors 45 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 46 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 47 | } 48 | 49 | void Matmul::unmap(void) 50 | { 51 | checkCUDNN(cudnnDestroyTensorDescriptor(outputTensor)); 52 | if (activation != AC_MODE_NONE) { 53 | checkCUDNN(cudnnDestroyActivationDescriptor(actiDesc)); 54 | } 55 | checkCUDA(cudaFree(outputs[0].data_ptr)); 56 | } 57 | 58 | void Matmul::forward(bool block) 59 | { 60 | const float alpha = 1.0f; 61 | const float beta = 0.0f; 62 | int numDim = outputs[0].numDim; 63 | int m = inputs[0].dim[numDim-2]; 64 | int n = inputs[1].dim[numDim-1]; 65 | int k = inputs[0].dim[numDim-1]; 66 | cublasOperation_t transA, transB; 67 | int lda, ldb, ldc; 68 | if (inputs[0].stride[numDim-2] == 1) { 69 | transA = CUBLAS_OP_N; 70 | lda = inputs[0].stride[numDim-1]; 71 | } else { 72 | assert(inputs[0].stride[numDim-1] == 1); 73 | transA = CUBLAS_OP_T; 74 | lda = inputs[0].stride[numDim-2]; 75 | } 76 | if (inputs[1].stride[numDim-2] == 1) { 77 | transB = CUBLAS_OP_N; 78 | ldb = inputs[1].stride[numDim-1]; 79 | } else { 80 | assert(inputs[1].stride[numDim-1] == 1); 81 | transB = CUBLAS_OP_T; 82 | ldb = inputs[1].stride[numDim-2]; 83 | } 84 | ldc = outputs[0].stride[numDim-1]; 85 | if (numDim == 2) { 86 | // Normal 2D Matmul 87 | checkCUDA(cublasSgemm(model->blas, transA, transB, 88 | m, n, k, &alpha, (float*)inputs[0].data_ptr, lda, 89 | (float*)inputs[1].data_ptr, ldb, &beta, (float*)outputs[0].data_ptr, ldc)); 90 | } else { 91 | // Batched Matmul 92 | int strideA = inputs[0].stride[numDim-3]; 93 | int strideB = inputs[1].stride[numDim-3]; 94 | int strideC = outputs[0].stride[numDim-3]; 95 | int batch = 1; 96 | for (int i = 0; i < numDim-2; i++) 97 | batch *= outputs[0].dim[i]; 98 | checkCUDA(cublasSgemmStridedBatched(model->blas, transA, transB, 99 | m, n, k, &alpha, (float*)inputs[0].data_ptr, lda, strideA, 100 | (float*)inputs[1].data_ptr, ldb, strideB, 101 | &beta, (float*)outputs[0].data_ptr, ldc, strideC, batch)); 102 | } 103 | if (activation != AC_MODE_NONE) 104 | checkCUDNN(cudnnActivationForward(model->dnn, actiDesc, 105 | &alpha, outputTensor, outputs[0].data_ptr, 106 | &beta, outputTensor, outputs[0].data_ptr)); 107 | if (block) 108 | checkCUDA(cudaDeviceSynchronize()); 109 | } 110 | 111 | void Model::measure_matmul_cost(Matmul* mm) 112 | { 113 | const float alpha = 1.0f; 114 | const float beta = 0.0f; 115 | int numDim = mm->outputs[0].numDim; 116 | int m = mm->inputs[0].dim[numDim-2]; 117 | int n = mm->inputs[1].dim[numDim-1]; 118 | int k = mm->inputs[0].dim[numDim-1]; 119 | cublasOperation_t transA, transB; 120 | int lda, ldb, ldc; 121 | if (mm->inputs[0].stride[numDim-2] == 1) { 122 | transA = 
CUBLAS_OP_N; 123 | lda = mm->inputs[0].stride[numDim-1]; 124 | } else { 125 | assert(mm->inputs[0].stride[numDim-1] == 1); 126 | transA = CUBLAS_OP_T; 127 | lda = mm->inputs[0].stride[numDim-2]; 128 | } 129 | if (mm->inputs[1].stride[numDim-2] == 1) { 130 | transB = CUBLAS_OP_N; 131 | ldb = mm->inputs[1].stride[numDim-1]; 132 | } else { 133 | assert(mm->inputs[1].stride[numDim-1] == 1); 134 | transB = CUBLAS_OP_T; 135 | ldb = mm->inputs[1].stride[numDim-2]; 136 | } 137 | ldc = mm->outputs[0].stride[numDim-1]; 138 | 139 | if (mm->activation != AC_MODE_NONE) { 140 | cudnnActivationMode_t mode; 141 | switch (mm->activation) { 142 | case AC_MODE_SIGMOID: 143 | mode = CUDNN_ACTIVATION_SIGMOID; 144 | break; 145 | case AC_MODE_RELU: 146 | mode = CUDNN_ACTIVATION_RELU; 147 | break; 148 | case AC_MODE_TANH: 149 | mode = CUDNN_ACTIVATION_TANH; 150 | break; 151 | default: 152 | assert(false); 153 | } 154 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 155 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 156 | } 157 | helperSetTensorDescriptor(mm->outputs[0], outputTensor); 158 | 159 | checkCUDA(cudaDeviceSynchronize()); 160 | for (int i = 0; i < WARMUP_TIMES + REPEAT_TIMES; i++) { 161 | if (i == WARMUP_TIMES) 162 | checkCUDA(cudaEventRecord(startEvent)); 163 | if (numDim == 2) { 164 | // Normal 2D Matmul 165 | checkCUDA(cublasSgemm(blas, transA, transB, 166 | m, n, k, &alpha, inputPtr, lda, 167 | filterPtr, ldb, &beta, outputPtr, ldc)); 168 | } else { 169 | // Batched Matmul 170 | int strideA = mm->inputs[0].stride[numDim-3]; 171 | int strideB = mm->inputs[1].stride[numDim-3]; 172 | int strideC = mm->outputs[0].stride[numDim-3]; 173 | int batch = 1; 174 | for (int i = 0; i < numDim-2; i++) 175 | batch *= mm->outputs[0].dim[i]; 176 | checkCUDA(cublasSgemmStridedBatched(blas, transA, transB, 177 | m, n, k, &alpha, inputPtr, lda, strideA, 178 | filterPtr, ldb, strideB, 179 | &beta, outputPtr, ldc, strideC, batch)); 180 | } 181 | if (mm->activation != AC_MODE_NONE) 182 | checkCUDNN(cudnnActivationForward(dnn, actiDesc, 183 | &alpha, outputTensor, outputPtr, 184 | &beta, outputTensor, outputPtr)); 185 | } 186 | checkCUDA(cudaEventRecord(endEvent)); 187 | checkCUDA(cudaEventSynchronize(endEvent)); 188 | float milliseconds; 189 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 190 | mm->runtime = milliseconds / REPEAT_TIMES; 191 | if (print_cost) 192 | printf(" measure[Matmul]: %s %s acti(%d) cost(%.4lf)\n", 193 | mm->inputs[0].to_string("input").c_str(), 194 | mm->inputs[1].to_string("weight").c_str(), 195 | mm->activation, mm->runtime); 196 | } 197 | 198 | -------------------------------------------------------------------------------- /src/cudnn/merge_gconv_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void MergeGConv::map(void) 21 | { 22 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 23 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 24 | } 25 | 26 | void MergeGConv::unmap(void) 27 | { 28 | checkCUDA(cudaFree(outputs[0].data_ptr)); 29 | } 30 | 31 | void MergeGConv::forward(bool block) 32 | { 33 | //merge_gconv_kernel<<>>( 34 | // (DATATYPE*)outputs[0].data_ptr, (DATATYPE*)inputs[0].data_ptr, 35 | 36 | if (block) 37 | checkCUDA(cudaDeviceSynchronize()); 38 | } 39 | -------------------------------------------------------------------------------- /src/cudnn/mul_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Mul::map(void) 21 | { 22 | // allocate tensors 23 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 24 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 25 | } 26 | 27 | void Mul::unmap(void) 28 | { 29 | checkCUDA(cudaFree(outputs[0].data_ptr)); 30 | } 31 | 32 | void Mul::forward(bool block) 33 | { 34 | assert(false); 35 | if (block) 36 | checkCUDA(cudaDeviceSynchronize()); 37 | } 38 | 39 | void Model::measure_mul_cost(Mul* m) 40 | { 41 | assert(false); 42 | } 43 | -------------------------------------------------------------------------------- /src/cudnn/ops_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | Model::Model() 21 | : isTraining(false), print_cost(false) 22 | { 23 | //int* a = (int*) malloc(sizeof(int) * 8); 24 | checkCUDA(cudaSetDevice(0)); 25 | checkCUDNN(cudnnCreate(&dnn)); 26 | checkCUDA(cublasCreate(&blas)); 27 | workSpaceSize = WORK_SPACE_SIZE; 28 | global_unique_id = 100; 29 | checkCUDA(cudaMalloc(&workSpace, workSpaceSize)); 30 | // printf("handle.workSpace = 0x%x\n", workSpace); 31 | // create all descriptors 32 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 33 | checkCUDNN(cudnnCreateTensorDescriptor(&biasTensor)); 34 | checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); 35 | checkCUDNN(cudnnCreateTensorDescriptor(&scaleTensor)); 36 | checkCUDNN(cudnnCreateFilterDescriptor(&filterDesc)); 37 | checkCUDNN(cudnnCreateConvolutionDescriptor(&convDesc)); 38 | checkCUDNN(cudnnCreatePoolingDescriptor(&poolDesc)); 39 | checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); 40 | checkCUDNN(cudnnCreateOpTensorDescriptor(&opDesc)); 41 | // allocate tensors for measuring performance 42 | checkCUDA(cudaMalloc(&inputPtr, MAX_TENSOR_SIZE)); 43 | checkCUDA(cudaMalloc(&biasPtr, MAX_TENSOR_SIZE)); 44 | checkCUDA(cudaMalloc(&outputPtr, MAX_TENSOR_SIZE)); 45 | checkCUDA(cudaMalloc(&filterPtr, MAX_TENSOR_SIZE)); 46 | // create tensors for batch norm 47 | checkCUDA(cudaMalloc(&scalePtr, MAX_TENSOR_SIZE)); 48 | checkCUDA(cudaMalloc(&runningMean, MAX_TENSOR_SIZE)); 49 | checkCUDA(cudaMalloc(&runningVar, MAX_TENSOR_SIZE)); 50 | checkCUDA(cudaMalloc(&saveMean, MAX_TENSOR_SIZE)); 51 | checkCUDA(cudaMalloc(&saveVar, MAX_TENSOR_SIZE)); 52 | // create cuda events 53 | checkCUDA(cudaEventCreate(&startEvent)); 54 | checkCUDA(cudaEventCreate(&endEvent)); 55 | } 56 | 57 | float Model::measure_oplist_runtime(const std::vector<OpBase*>& opBaseList) 58 | { 59 | const int num_runs = 100; 60 | // warmup 61 | for (int times = 0; times < num_runs; times++) 62 | for (int i = 0; i < opBaseList.size(); i++) 63 | opBaseList[i]->forward(); 64 | // measure runtime 65 | // checkCUDA(cudaDeviceSynchronize()); 66 | checkCUDA(cudaEventRecord(startEvent)); 67 | for (int times = 0; times < num_runs; times++) { 68 | for (int i = 0; i < opBaseList.size(); i++) 69 | opBaseList[i]->forward(); 70 | } 71 | checkCUDA(cudaEventRecord(endEvent)); 72 | checkCUDA(cudaEventSynchronize(endEvent)); 73 | float milliseconds; 74 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 75 | return milliseconds / num_runs; 76 | } 77 | 78 | void* Model::allocate_memory(size_t size, const DATATYPE* data_initial) 79 | { 80 | void* ptr; 81 | checkCUDA(cudaMalloc(&ptr, size)); 82 | if (data_initial != NULL) { 83 | checkCUDA(cudaMemcpy(ptr, data_initial, size, cudaMemcpyDefault)); 84 | } 85 | return ptr; 86 | } 87 | 88 | bool Model::copy_memory(DATATYPE* dst, const DATATYPE* src, size_t size) 89 | { 90 | checkCUDA(cudaMemcpy(dst, src, size, cudaMemcpyDefault)); 91 | return true; 92 | } 93 | -------------------------------------------------------------------------------- /src/cudnn/pool2d_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Pool2D::map(void) 21 | { 22 | // create descriptors 23 | checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); 24 | checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); 25 | checkCUDNN(cudnnCreatePoolingDescriptor(&poolDesc)); 26 | int inputC = inputs[0].dim[1]; 27 | int inputH = inputs[0].dim[2]; 28 | int inputW = inputs[0].dim[3]; 29 | int padH, padW; 30 | get_padding(&padH, &padW); 31 | // set descriptors 32 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 33 | CUDNN_DATA_FLOAT, BATCH_SIZE, inputC, inputH, inputW)); 34 | cudnnPoolingMode_t mode; 35 | if (type == OP_POOL2D_MAX) 36 | mode = CUDNN_POOLING_MAX; 37 | else if (type == OP_POOL2D_AVG) 38 | mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; 39 | checkCUDNN(cudnnSetPooling2dDescriptor(poolDesc, mode, CUDNN_PROPAGATE_NAN, 40 | kernelH, kernelW, padH, padW, strideH, strideW)); 41 | int n, c, h, w; 42 | checkCUDNN(cudnnGetPooling2dForwardOutputDim(poolDesc, 43 | inputTensor, &n, &c, &h, &w)); 44 | assert(n == BATCH_SIZE); 45 | assert(c == inputC); 46 | assert(outputs[0].dim[2] == h); 47 | assert(outputs[0].dim[3] == w); 48 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 49 | CUDNN_DATA_FLOAT, n, c, h, w)); 50 | if (activation != AC_MODE_NONE) { 51 | checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); 52 | cudnnActivationMode_t mode = get_activation_mode(activation); 53 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, mode, 54 | CUDNN_PROPAGATE_NAN, 0.0)); 55 | } 56 | // allocate tensors 57 | size_t outputSize = sizeof(DATATYPE) * n * c * h * w; 58 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 59 | } 60 | 61 | void Pool2D::unmap(void) 62 | { 63 | checkCUDNN(cudnnDestroyTensorDescriptor(inputTensor)); 64 | checkCUDNN(cudnnDestroyTensorDescriptor(outputTensor)); 65 | checkCUDNN(cudnnDestroyPoolingDescriptor(poolDesc)); 66 | if (activation != AC_MODE_NONE) { 67 | checkCUDNN(cudnnDestroyActivationDescriptor(actiDesc)); 68 | } 69 | // free tensors 70 | checkCUDA(cudaFree(outputs[0].data_ptr)); 71 | } 72 | 73 | void Pool2D::forward(bool block) 74 | { 75 | const float alpha = 1.0f; 76 | const float beta = 0.0f; 77 | checkCUDNN(cudnnPoolingForward(model->dnn, poolDesc, 78 | &alpha, inputTensor, inputs[0].data_ptr, 79 | &beta, outputTensor, outputs[0].data_ptr)); 80 | if (activation != AC_MODE_NONE) { 81 | checkCUDNN(cudnnActivationForward(model->dnn, actiDesc, 82 | &alpha, outputTensor, outputs[0].data_ptr, 83 | &beta, outputTensor, outputs[0].data_ptr)); 84 | } 85 | if (block) 86 | checkCUDA(cudaDeviceSynchronize()); 87 | } 88 | 89 | void Model::measure_pool2d_cost(Pool2D* pool) 90 | { 91 | const float alpha = 1.0f; 92 | const float beta = 0.0f; 93 | int inputC = pool->inputs[0].dim[1]; 94 | int inputH = pool->inputs[0].dim[2]; 95 | int inputW = pool->inputs[0].dim[3]; 96 | int outputH = pool->outputs[0].dim[2]; 97 | int outputW = pool->outputs[0].dim[3]; 98 | int padH, padW; 99 | // Reference: 
https://www.tensorflow.org/api_guides/python/nn#Convolution 100 | switch (pool->padding) { 101 | case PD_MODE_SAME: 102 | int totalPadH, totalPadW; 103 | if (inputH % pool->strideH == 0) 104 | totalPadH = max(pool->kernelH - pool->strideH, 0); 105 | else 106 | totalPadH = max(pool->kernelH - (inputH % pool->strideH), 0); 107 | if (inputW % pool->strideW == 0) 108 | totalPadW = max(pool->kernelW - pool->strideW, 0); 109 | else 110 | totalPadW = max(pool->kernelW - (inputW % pool->strideW), 0); 111 | // assert same padding on both sides 112 | padH = (totalPadH + 1) / 2; 113 | padW = (totalPadW + 1)/ 2; 114 | break; 115 | case PD_MODE_VALID: 116 | padH = 0; 117 | padW = 0; 118 | break; 119 | default: 120 | assert(false); 121 | } 122 | checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, 123 | CUDNN_DATA_FLOAT, BATCH_SIZE, inputC, inputH, inputW)); 124 | cudnnPoolingMode_t mode; 125 | if (pool->type == OP_POOL2D_MAX) 126 | mode = CUDNN_POOLING_MAX; 127 | else if (pool->type == OP_POOL2D_AVG) 128 | mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; 129 | checkCUDNN(cudnnSetPooling2dDescriptor(poolDesc, mode, 130 | CUDNN_PROPAGATE_NAN, pool->kernelH, pool->kernelW, padH, padW, 131 | pool->strideH, pool->strideW)); 132 | checkCUDNN(cudnnSetActivationDescriptor(actiDesc, CUDNN_ACTIVATION_RELU, 133 | CUDNN_NOT_PROPAGATE_NAN, 0.0)); 134 | int n, c, h, w; 135 | checkCUDNN(cudnnGetPooling2dForwardOutputDim(poolDesc, 136 | inputTensor, &n, &c, &h, &w)); 137 | assert(n == BATCH_SIZE); 138 | assert(c == inputC); 139 | assert(outputH == h); 140 | assert(outputW == w); 141 | checkCUDNN(cudnnSetTensor4dDescriptor(outputTensor, CUDNN_TENSOR_NCHW, 142 | CUDNN_DATA_FLOAT, n, c, h, w)); 143 | size_t inputSize = sizeof(DATATYPE) * BATCH_SIZE * inputC * inputH * inputW; 144 | size_t outputSize = sizeof(DATATYPE) * BATCH_SIZE * inputC * outputH * outputW; 145 | assert(inputSize < MAX_TENSOR_SIZE); 146 | assert(outputSize < MAX_TENSOR_SIZE); 147 | checkCUDA(cudaDeviceSynchronize()); 148 | for (int i = 0; i < WARMUP_TIMES + REPEAT_TIMES; i++) { 149 | if (i == WARMUP_TIMES) { 150 | checkCUDA(cudaEventRecord(startEvent)); 151 | } 152 | checkCUDNN(cudnnPoolingForward(dnn, poolDesc, 153 | &alpha, inputTensor, inputPtr, 154 | &beta, outputTensor, outputPtr)); 155 | if (pool->activation != AC_MODE_NONE) { 156 | checkCUDNN(cudnnActivationForward(dnn, actiDesc, 157 | &alpha, outputTensor, outputPtr, 158 | &beta, outputTensor, outputPtr)); 159 | } 160 | } 161 | checkCUDA(cudaEventRecord(endEvent)); 162 | checkCUDA(cudaEventSynchronize(endEvent)); 163 | float milliseconds; 164 | cudaEventElapsedTime(&milliseconds, startEvent, endEvent); 165 | pool->runtime = milliseconds / REPEAT_TIMES; 166 | if (print_cost) 167 | printf(" measure[Pool2D]: i(%d %d %d %d) k(%d %d) s(%d %d) p(%d %d) cost(%.4lf)\n", 168 | BATCH_SIZE, inputC, inputH, inputW, pool->kernelH, pool->kernelW, 169 | pool->strideH, pool->strideW, padH, padW, pool->runtime); 170 | } 171 | 172 | -------------------------------------------------------------------------------- /src/cudnn/reshape_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Reshape::map(void) 21 | { 22 | // allocate tensors 23 | size_t outputSize = sizeof(DATATYPE) * outputs[0].volume(); 24 | checkCUDA(cudaMalloc(&outputs[0].data_ptr, outputSize)); 25 | } 26 | 27 | void Reshape::unmap(void) 28 | { 29 | checkCUDA(cudaFree(outputs[0].data_ptr)); 30 | } 31 | 32 | void Reshape::forward(bool block) 33 | { 34 | if (block) 35 | checkCUDA(cudaDeviceSynchronize()); 36 | } 37 | 38 | void Model::measure_reshape_cost(Reshape* reshape) 39 | { 40 | // FIXME: assume the cost is zero for now 41 | reshape->runtime = 0; 42 | } 43 | -------------------------------------------------------------------------------- /src/cudnn/transpose_cudnn.cu: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Stanford 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #include "xflow/ops.h" 17 | #include "xflow/cuda_helper.h" 18 | using namespace XFlow; 19 | 20 | void Transpose::map(void) 21 | { 22 | //TODO: for now the output and input share the same instance 23 | outputs[0].data_ptr = inputs[0].data_ptr; 24 | } 25 | 26 | void Transpose::unmap(void) 27 | { 28 | } 29 | 30 | void Transpose::forward(bool block) 31 | { 32 | if (block) 33 | checkCUDA(cudaDeviceSynchronize()); 34 | } 35 | 36 | void Model::measure_transpose_cost(Transpose* transpose) 37 | { 38 | // Transpose requires no kernel launch 39 | transpose->runtime = 0; 40 | } 41 | -------------------------------------------------------------------------------- /src/generator/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | g++ generator.cc rules.pb.cc -o generator -I/usr/local/include -L/usr/local/lib -lprotobuf -std=c++11 -pthread -O2 3 | -------------------------------------------------------------------------------- /tensorflow_py/bert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | from shared_functions import make_matmul 6 | 7 | def attention(input, heads): 8 | d_model = input.shape[1].value 9 | q = make_matmul(input, d_model) 10 | k = make_matmul(input, d_model) 11 | v = make_matmul(input, d_model) 12 | # reshape query, key, value 13 | q = tf.reshape(q, shape=(64,16,64)) 14 | k = tf.reshape(k, shape=(64,16,64)) 15 | v = tf.reshape(v, shape=(64,16,64)) 16 | # transpose q, k, v for batched matmul 17 | q = tf.transpose(q, perm=(1,0,2)) 18 | k = tf.transpose(k, perm=(1,0,2)) 19 | v = tf.transpose(v, perm=(1,0,2)) 20 | logits = tf.matmul(q, k) 21 | output = tf.matmul(logits, v) 22 | # transpose the output back 23 | output = tf.transpose(output, perm=(1,0,2)) 24 | output = tf.reshape(output, shape=(64, 1024)) 25 | # a final linear layer 26 | output = make_matmul(tf.nn.relu(make_matmul(input, 4*d_model)), d_model) 27 | return output 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--xla", help="Whether to run with TensorFlowXLA optimizations", action="store_true") 31 | parser.add_argument("--print_tensorboard", help="Name of folder to output the tensorboard information") 32 | parser.add_argument("--iterations", help="How many iterations to average for timing (default 5000)", type=int, default=1000) 33 | parser.add_argument("--discard_iter", help="How many iterations to not time during warm up (default 1000)", type=int, default=1000) 34 | args = parser.parse_args() 35 | 36 | input = tf.placeholder(tf.float32, shape=(64,1024)) 37 | input_dictionary = {} 38 | input_dictionary[input] = np.random.random_sample((64, 1024)) 39 | t = input 40 | for i in range(12): 41 | t = attention(t, 16) 42 | 43 | output_nodes = [t] 44 | 45 | config = tf.ConfigProto() 46 | if (args.xla): 47 | print("Measuring inference performance with XLA ON") 48 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 49 | else: 50 | print("Measuring inference performance with XLA OFF") 51 | print(config.graph_options.optimizer_options.global_jit_level) 52 | 53 | with tf.Session(config=config) as sess: 54 | if (args.print_tensorboard): 55 | writer = tf.summary.FileWriter(args.print_tensorboard, sess.graph) 56 | times = [] 57 | for i in range(args.discard_iter + args.iterations): 58 | t0 = time.time() 59 | sess.run(output_nodes, input_dictionary) 60 | t1 = time.time() 61 | 
times.append(t1 - t0) 62 | total = 0 63 | for i in range(args.discard_iter, len(times)): 64 | total += times[i] 65 | avg = total / (args.iterations) * 1000.0 66 | print("Average inference time of the last " + str(args.iterations) + " iterations: " + str(avg) + " ms") 67 | -------------------------------------------------------------------------------- /tensorflow_py/nasnet_a.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | from shared_functions import make_activation, make_conv2d, make_seperable_conv2d, make_avgpool2d, make_maxpool2d 6 | 7 | def squeeze(out_channels, input): 8 | return make_conv2d(input_tensor=input, filter_shape=(1,1,input.shape[1].value,out_channels), strides=(1,1,1,1), padding="SAME", actimode="RELU", name="squeeze") 9 | 10 | def fit(current, input): 11 | if (input.shape[2].value == current.shape[2].value): 12 | return squeeze(current.shape[1].value, input) 13 | else: 14 | return make_conv2d(input_tensor=input, filter_shape=(3,3,input.shape[1].value,current.shape[1].value), strides=(1,1,2,2), padding="SAME", actimode="RELU", name="fit") 15 | 16 | def normal_cell(prev, cur, out_channels): 17 | cur = squeeze(out_channels, cur) 18 | prev = fit(cur, prev) 19 | ts = list() 20 | ts.append(make_seperable_conv2d(input_tensor=cur, out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 21 | ts.append(cur) 22 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 23 | ts.append(make_seperable_conv2d(input_tensor=cur, out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 24 | ts.append(make_avgpool2d(input_tensor=cur, kernels=(1,1,3,3), strides=(1,1,1,1), padding="SAME")) 25 | ts.append(prev) 26 | ts.append(make_avgpool2d(input_tensor=prev, kernels=(1,1,3,3), strides=(1,1,1,1), padding="SAME")) 27 | ts.append(make_avgpool2d(input_tensor=prev, kernels=(1,1,3,3), strides=(1,1,1,1), padding="SAME")) 28 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 29 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 30 | assert len(ts) == 10 31 | outputs=list() 32 | for i in range(5): 33 | outputs.append(tf.add(ts[2*i], ts[2*i+1])) 34 | return tf.concat(outputs, axis=1, name="concat1") 35 | 36 | def reduction_cell(prev, cur, out_channels): 37 | cur = squeeze(out_channels, cur) 38 | prev = fit(cur, prev) 39 | ts = list() 40 | outputs = list() 41 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(7,7), strides=(1,1,2,2), padding="SAME")) 42 | ts.append(make_seperable_conv2d(input_tensor=cur, out_channels=out_channels, kernels=(5,5), strides=(1,1,2,2), padding="SAME")) 43 | outputs.append(tf.add(ts[0], ts[1])) 44 | ts.append(make_maxpool2d(input_tensor=cur, kernels=(1,1,3,3), strides=(1,1,2,2), padding="SAME")) 45 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(7,7), strides=(1,1,2,2), padding="SAME")) 46 | outputs.append(tf.add(ts[2], ts[3])) 47 | ts.append(make_avgpool2d(input_tensor=cur, kernels=(1,1,3,3), strides=(1,1,2,2), padding="SAME")) 48 | ts.append(make_seperable_conv2d(input_tensor=prev, out_channels=out_channels, kernels=(5,5), strides=(1,1,2,2), padding="SAME")) 49 | outputs.append(tf.add(ts[4], ts[5])) 
50 | ts.append(make_maxpool2d(input_tensor=cur, kernels=(1,1,3,3), strides=(1,1,2,2), padding="SAME")) 51 | ts.append(make_seperable_conv2d(input_tensor=outputs[0], out_channels=out_channels, kernels=(3,3), strides=(1,1,1,1), padding="SAME")) 52 | outputs.append(tf.add(ts[6], ts[7])) 53 | ts.append(make_avgpool2d(input_tensor=outputs[0], kernels=(1,1,3,3), strides=(1,1,1,1), padding="SAME")) 54 | ts.append(outputs[1]) 55 | outputs.append(tf.add(ts[8], ts[9])) 56 | return tf.concat(outputs, axis=1, name="concat2") 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--xla", help="Whether to run with TensorFlowXLA optimizations", action="store_true") 60 | parser.add_argument("--print_tensorboard", help="Name of folder to output the tensorboard information") 61 | parser.add_argument("--iterations", help="How many iterations to average for timing (default 5000)", type=int, default=1000) 62 | parser.add_argument("--discard_iter", help="How many iterations to not time during warm up (default 1000)", type=int, default=1000) 63 | args = parser.parse_args() 64 | 65 | input0 = tf.placeholder(tf.float32, shape=(1,128,56,56)) 66 | input = input0 67 | out_channels = 128 68 | for i in range(3): 69 | if i > 0: 70 | input = reduction_cell(prev, cur, out_channels) 71 | prev = input 72 | cur = input 73 | for j in range(10): 74 | t = normal_cell(prev, cur, out_channels) 75 | prev = cur 76 | cur = t 77 | out_channels *= 2 78 | 79 | config = tf.ConfigProto() 80 | if (args.xla): 81 | print("Measuring inference performance with XLA ON") 82 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 83 | else: 84 | print("Measuring inference performance with XLA OFF") 85 | print(config.graph_options.optimizer_options.global_jit_level) 86 | 87 | output_nodes = [t] 88 | input_dictionary = {} 89 | input_dictionary[input0] = np.random.random_sample((1,128,56,56)) 90 | 91 | with tf.Session(config=config) as sess: 92 | if (args.print_tensorboard): 93 | writer = tf.summary.FileWriter(args.print_tensorboard, sess.graph) 94 | times = [] 95 | for i in range(args.discard_iter + args.iterations): 96 | t0 = time.time() 97 | sess.run(output_nodes, input_dictionary) 98 | t1 = time.time() 99 | times.append(t1 - t0) 100 | total = 0 101 | for i in range(args.discard_iter, len(times)): 102 | total += times[i] 103 | avg = total / (args.iterations) * 1000.0 104 | print("Average inference time of the last " + str(args.iterations) + " iterations: " + str(avg) + " ms") 105 | -------------------------------------------------------------------------------- /tensorflow_py/nasrnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | from shared_functions import make_matmul 6 | 7 | hidden_size = 512 8 | length = 5 9 | 10 | def combine(x, h): 11 | w1 = make_matmul(x, hidden_size) 12 | w2 = make_matmul(h, hidden_size) 13 | return tf.add(tf.nn.relu(w1), tf.nn.relu(w2)) 14 | 15 | def nas_node(input, x): 16 | t = list() 17 | for i in range(8): 18 | t.append(combine(x, input)) 19 | midt = list() 20 | midt.append(tf.add(tf.nn.relu(t[0]), tf.nn.sigmoid(t[3]))) 21 | midt.append(tf.add(tf.nn.sigmoid(t[1]), tf.nn.tanh(t[2]))) 22 | midt.append(tf.multiply(tf.nn.sigmoid(t[4]), tf.nn.tanh(t[5]))) 23 | midt.append(tf.multiply(tf.nn.sigmoid(t[6]), tf.nn.relu(t[7]))) 24 | midt.append(tf.add(tf.nn.sigmoid(midt[1]), tf.nn.tanh(midt[2]))) 25 | midt.append(tf.multiply(tf.nn.tanh(midt[0]), 
tf.nn.tanh(midt[3]))) 26 | midt.append(tf.multiply(tf.nn.tanh(midt[4]), tf.nn.tanh(midt[5]))) 27 | return tf.nn.tanh(midt[6]) 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--xla", help="Whether to run with TensorFlowXLA optimizations", action="store_true") 31 | parser.add_argument("--print_tensorboard", help="Name of folder to output the tensorboard information") 32 | parser.add_argument("--iterations", help="How many iterations to average for timing (default 5000)", type=int, default=1000) 33 | parser.add_argument("--discard_iter", help="How many iterations to not time during warm up (default 1000)", type=int, default=1000) 34 | args = parser.parse_args() 35 | 36 | input_dictionary = {} 37 | xs = list() 38 | output_nodes = [] 39 | for i in range(length): 40 | xs.append(tf.placeholder(tf.float32, shape=(1, hidden_size))) 41 | input_dictionary[xs[i]] = np.random.random_sample((1, hidden_size)) 42 | state = tf.constant(np.random.random_sample((1, hidden_size)), dtype=tf.float32) 43 | for i in range(length): 44 | state = nas_node(state, xs[i]) 45 | output_nodes.append(state) 46 | 47 | config = tf.ConfigProto() 48 | if (args.xla): 49 | print("Measuring inference performance with XLA ON") 50 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 51 | else: 52 | print("Measuring inference performance with XLA OFF") 53 | print(config.graph_options.optimizer_options.global_jit_level) 54 | 55 | with tf.Session(config=config) as sess: 56 | if (args.print_tensorboard): 57 | writer = tf.summary.FileWriter(args.print_tensorboard, sess.graph) 58 | times = [] 59 | for i in range(args.discard_iter + args.iterations): 60 | t0 = time.time() 61 | sess.run(output_nodes, input_dictionary) 62 | t1 = time.time() 63 | times.append(t1 - t0) 64 | total = 0 65 | for i in range(args.discard_iter, len(times)): 66 | total += times[i] 67 | avg = total / (args.iterations) * 1000.0 68 | print("Average inference time of the last " + str(args.iterations) + " iterations: " + str(avg) + " ms") 69 | -------------------------------------------------------------------------------- /tensorflow_py/resnet50.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | from shared_functions import make_activation, make_conv2d 6 | 7 | def resnet_block(input, strides, out_channels, name): 8 | t = make_conv2d(input_tensor=input, filter_shape=(1,1,input.shape[1].value,out_channels), strides=(1,1,1,1), padding="SAME", actimode="RELU", name=name+"_conv1") 9 | t = make_conv2d(input_tensor=t, filter_shape=(3,3,out_channels,out_channels), strides=strides, padding="SAME", actimode="RELU", name=name+"_conv2") 10 | t = make_conv2d(input_tensor=t, filter_shape=(1,1,out_channels,out_channels*4), strides=(1,1,1,1), padding="SAME", actimode="NONE", name=name+"_conv3") 11 | if (strides[2]>1) or (input.shape[1].value != out_channels * 4): 12 | input = make_conv2d(input_tensor=input, filter_shape=(1,1,input.shape[1].value,out_channels*4), strides=strides, padding="SAME", actimode="RELU", name=name+"_conv4") 13 | return tf.nn.relu(tf.add(input, t)) 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--xla", help="Whether to run with TensorFlowXLA optimizations", action="store_true") 17 | parser.add_argument("--print_tensorboard", help="Name of folder to output the tensorboard information") 18 | parser.add_argument("--iterations", help="How many iterations to average 
for timing (default 5000)", type=int, default=1000) 19 | parser.add_argument("--discard_iter", help="How many iterations to not time during warm up (default 1000)", type=int, default=1000) 20 | args = parser.parse_args() 21 | 22 | input = tf.placeholder(tf.float32, shape=(1,64,56,56)) 23 | t = input 24 | for i in range(3): 25 | t = resnet_block(t, (1,1,1,1), 64, "resnet_block_1_{}".format(i)) 26 | strides=(1,1,2,2) 27 | for i in range(4): 28 | t = resnet_block(t, strides, 128, "resnet_block_2_{}".format(i)) 29 | strides=(1,1,1,1) 30 | strides=(1,1,2,2) 31 | for i in range(6): 32 | t = resnet_block(t, strides, 256, "resnet_block_3_{}".format(i)) 33 | strides=(1,1,1,1) 34 | strides=(1,1,2,2) 35 | for i in range(3): 36 | t = resnet_block(t, strides, 512, "resnet_block_4_{}".format(i)) 37 | strides=(1,1,1,1) 38 | 39 | config = tf.ConfigProto() 40 | if (args.xla): 41 | print("Measuring inference performance with XLA ON") 42 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 43 | else: 44 | print("Measuring inference performance with XLA OFF") 45 | print(config.graph_options.optimizer_options.global_jit_level) 46 | 47 | output_nodes = [t] 48 | input_dictionary = {} 49 | input_dictionary[input] = np.random.random_sample((1,64,56,56)) 50 | 51 | with tf.Session(config=config) as sess: 52 | if (args.print_tensorboard): 53 | writer = tf.summary.FileWriter(args.print_tensorboard, sess.graph) 54 | times = [] 55 | for i in range(args.discard_iter + args.iterations): 56 | t0 = time.time() 57 | sess.run(output_nodes, input_dictionary) 58 | t1 = time.time() 59 | times.append(t1 - t0) 60 | total = 0 61 | for i in range(args.discard_iter, len(times)): 62 | total += times[i] 63 | avg = total / (args.iterations) * 1000.0 64 | print("Average inference time of the last " + str(args.iterations) + " iterations: " + str(avg) + " ms") 65 | -------------------------------------------------------------------------------- /tensorflow_py/resnext50.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | from shared_functions import make_activation, make_conv2d 6 | 7 | def resnext_block(input, strides, out_channels, groups, name): 8 | t = make_conv2d(input_tensor=input, filter_shape=(1,1,input.shape[1].value,out_channels), strides=(1,1,1,1), padding="SAME", actimode="RELU", name=name+"_conv1") 9 | t = tf.split(t, groups, axis=1, name=name+"_split") 10 | assert(len(t) == groups) 11 | for i in range(groups): 12 | t[i] = make_conv2d(input_tensor=t[i], filter_shape=(3,3,t[i].shape[1].value,out_channels//groups), strides=strides, padding="SAME", actimode="RELU", name=name+"_conv2_".format(i)) 13 | output = tf.concat(t, axis=1, name=name+"_concat") 14 | t = make_conv2d(input_tensor=output, filter_shape=(1,1,output.shape[1].value,2*out_channels), strides=(1,1,1,1), padding="SAME", actimode="NONE", name=name+"_conv3") 15 | if (strides[2]>1) or (input.shape[1].value != out_channels*2): 16 | input = make_conv2d(input_tensor=input, filter_shape=(1,1,input.shape[1].value,out_channels*2), strides=strides, padding="SAME", actimode="RELU", name=name+"_conv4") 17 | return tf.nn.relu(tf.add(input, t)) 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--xla", help="Whether to run with TensorFlowXLA optimizations", action="store_true") 21 | parser.add_argument("--print_tensorboard", help="Name of folder to output the tensorboard information") 22 | 
parser.add_argument("--iterations", help="How many iterations to average for timing (default 5000)", type=int, default=1000) 23 | parser.add_argument("--discard_iter", help="How many iterations to not time during warm up (default 1000)", type=int, default=1000) 24 | args = parser.parse_args() 25 | 26 | input = tf.placeholder(tf.float32, shape=(1,64,56,56)) 27 | t = input 28 | for i in range(3): 29 | t = resnext_block(t, (1,1,1,1), 128, 32, "resnet_block_1_{}".format(i)) 30 | strides=(1,1,2,2) 31 | for i in range(4): 32 | t = resnext_block(t, strides, 256, 32, "resnet_block_2_{}".format(i)) 33 | strides=(1,1,1,1) 34 | strides=(1,1,2,2) 35 | for i in range(6): 36 | t = resnext_block(t, strides, 512, 32, "resnet_block_3_{}".format(i)) 37 | strides=(1,1,1,1) 38 | strides=(1,1,2,2) 39 | for i in range(3): 40 | t = resnext_block(t, strides, 1024, 32, "resnet_block_4_{}".format(i)) 41 | strides=(1,1,1,1) 42 | 43 | config = tf.ConfigProto() 44 | if (args.xla): 45 | print("Measuring inference performance with XLA ON") 46 | config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 47 | else: 48 | print("Measuring inference performance with XLA OFF") 49 | print(config.graph_options.optimizer_options.global_jit_level) 50 | 51 | output_nodes = [t] 52 | input_dictionary = {} 53 | input_dictionary[input] = np.random.random_sample((1,64,56,56)) 54 | 55 | with tf.Session(config=config) as sess: 56 | if (args.print_tensorboard): 57 | writer = tf.summary.FileWriter(args.print_tensorboard, sess.graph) 58 | times = [] 59 | for i in range(args.discard_iter + args.iterations): 60 | t0 = time.time() 61 | sess.run(output_nodes, input_dictionary) 62 | t1 = time.time() 63 | times.append(t1 - t0) 64 | total = 0 65 | for i in range(args.discard_iter, len(times)): 66 | total += times[i] 67 | avg = total / (args.iterations) * 1000.0 68 | print("Average inference time of the last " + str(args.iterations) + " iterations: " + str(avg) + " ms") 69 | -------------------------------------------------------------------------------- /tensorflow_py/shared_functions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def make_activation(input, actimode, name): 5 | if actimode == "NONE": 6 | return input 7 | elif actimode == "RELU": 8 | relu_name = name + "_relu" 9 | relu = tf.nn.relu(input, name=relu_name) 10 | return relu 11 | elif actimode == "SIGMOID": 12 | sigmoid_name = name + "_sigmoid" 13 | sigmoid = tf.nn.sigmoid(input, name=sigmoid_name) 14 | return sigmoid 15 | elif actimode == "TANH": 16 | tanh_name = name + "_tanh" 17 | tanh = tf.nn.tanh(input, name=tanh_name) 18 | return tanh 19 | else: 20 | print("Unknown Actimode") 21 | assert(0) 22 | 23 | def make_conv2d(input_tensor, filter_shape, strides, padding, actimode, name): 24 | weights_name = name + "_weights" 25 | conv_name = name + "_conv2d" 26 | weights = tf.constant(np.random.random_sample(filter_shape), name=weights_name, dtype=tf.float32) 27 | conv2d = tf.nn.conv2d(input_tensor, weights, strides, padding, data_format="NCHW", name=conv_name) 28 | return make_activation(conv2d, actimode, name) 29 | 30 | def make_seperable_conv2d(input_tensor, out_channels, kernels, strides, padding, actimode="NONE", name="seperable_conv2d"): 31 | depthwise_filter_shape=(kernels[0],kernels[1],input_tensor.shape[1].value,1) 32 | pointwise_filter_shape=(1,1,input_tensor.shape[1].value,out_channels) 33 | dp_filter = 
tf.constant(np.random.random_sample(depthwise_filter_shape), name=name+"_dp_filter", dtype=tf.float32) 34 | pw_filter = tf.constant(np.random.random_sample(pointwise_filter_shape), name=name+"_pw_filter", dtype=tf.float32) 35 | conv2d = tf.nn.separable_conv2d(input_tensor, dp_filter, pw_filter, strides, padding, data_format="NCHW", name=name) 36 | return make_activation(conv2d, actimode, name) 37 | 38 | def make_avgpool2d(input_tensor, kernels, strides, padding): 39 | return tf.nn.avg_pool(input_tensor, kernels, strides, padding, data_format="NCHW") 40 | 41 | def make_maxpool2d(input_tensor, kernels, strides, padding): 42 | return tf.nn.max_pool(input_tensor, kernels, strides, padding, data_format="NCHW") 43 | 44 | def make_matmul(input_tensor, out_channels): 45 | weight_shape = (input_tensor.shape[1].value, out_channels) 46 | weight = tf.constant(np.random.random_sample(weight_shape), dtype=tf.float32) 47 | return tf.matmul(input_tensor, weight) 48 | --------------------------------------------------------------------------------