├── .clang-format ├── .github └── workflows │ └── wheels.yml ├── .gitignore ├── .gitmodules ├── CITATION.cff ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── include └── loop_tool │ ├── backend.h │ ├── compile.h │ ├── cpp.h │ ├── dynlib.h │ ├── error.h │ ├── hardware.h │ ├── ir.h │ ├── lazy.h │ ├── loop_tool.h │ ├── mutate.h │ ├── nn.h │ ├── serialization.h │ ├── smallvec.h │ ├── symbolic.h │ ├── tensor.h │ └── wasm.h ├── install_cuda.sh ├── javascript ├── index.html ├── loop_tool.mjs ├── lt.mjs.gz ├── lzstring.js ├── main.mjs ├── test.html ├── tutorial.html ├── viz.html ├── webcam.html └── webcam.mjs ├── python ├── __init__.py ├── nn.py └── ui.py ├── requirements.txt ├── setup.py ├── src ├── backends │ ├── README.md │ ├── cpu │ │ ├── cpp.cpp │ │ └── loop_nest.cpp │ ├── cuda │ │ ├── FindCUDAToolkit.cmake │ │ ├── README.md │ │ ├── cuda.cpp │ │ └── cuda_backend.h │ └── wasm │ │ ├── wasm.cpp │ │ └── wasm_runtime.cpp ├── core │ ├── backend.cpp │ ├── compile.cpp │ ├── hardware.cpp │ ├── ir.cpp │ ├── serialization.cpp │ ├── symbolic.cpp │ └── tensor.cpp └── frontends │ ├── javascript.cpp │ ├── lazy.cpp │ ├── mutate.cpp │ ├── nn.cpp │ └── python.cpp ├── test ├── bench.py ├── bench_lazy.py ├── cuda_test.cpp ├── loop_nest_test.cpp ├── test.cpp ├── test.mjs ├── test.py ├── test_backend.cpp ├── test_cpp.cpp ├── test_ir.cpp ├── test_lazy.cpp ├── test_lazy.py ├── test_ln.py ├── test_mutate.cpp ├── test_nn.cpp ├── test_ops.py ├── test_serialization.cpp ├── test_symbolic.cpp ├── test_ui.py ├── test_utils.h ├── test_views.py ├── utils.cpp ├── wasm_runtime_test.cpp └── wasm_test.cpp └── tutorial.ipynb /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveMacros: false 7 | AlignConsecutiveAssignments: false 8 | AlignConsecutiveDeclarations: false 9 | 
AlignEscapedNewlines: Left 10 | AlignOperands: true 11 | AlignTrailingComments: true 12 | AllowAllArgumentsOnNextLine: true 13 | AllowAllConstructorInitializersOnNextLine: true 14 | AllowAllParametersOfDeclarationOnNextLine: true 15 | AllowShortBlocksOnASingleLine: Never 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: All 18 | AllowShortLambdasOnASingleLine: All 19 | AllowShortIfStatementsOnASingleLine: WithoutElse 20 | AllowShortLoopsOnASingleLine: true 21 | AlwaysBreakAfterDefinitionReturnType: None 22 | AlwaysBreakAfterReturnType: None 23 | AlwaysBreakBeforeMultilineStrings: true 24 | AlwaysBreakTemplateDeclarations: Yes 25 | BinPackArguments: true 26 | BinPackParameters: true 27 | BraceWrapping: 28 | AfterCaseLabel: false 29 | AfterClass: false 30 | AfterControlStatement: false 31 | AfterEnum: false 32 | AfterFunction: false 33 | AfterNamespace: false 34 | AfterObjCDeclaration: false 35 | AfterStruct: false 36 | AfterUnion: false 37 | AfterExternBlock: false 38 | BeforeCatch: false 39 | BeforeElse: false 40 | IndentBraces: false 41 | SplitEmptyFunction: true 42 | SplitEmptyRecord: true 43 | SplitEmptyNamespace: true 44 | BreakBeforeBinaryOperators: None 45 | BreakBeforeBraces: Attach 46 | BreakBeforeInheritanceComma: false 47 | BreakInheritanceList: BeforeColon 48 | BreakBeforeTernaryOperators: true 49 | BreakConstructorInitializersBeforeComma: false 50 | BreakConstructorInitializers: BeforeColon 51 | BreakAfterJavaFieldAnnotations: false 52 | BreakStringLiterals: true 53 | ColumnLimit: 80 54 | CommentPragmas: '^ IWYU pragma:' 55 | CompactNamespaces: false 56 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 57 | ConstructorInitializerIndentWidth: 4 58 | ContinuationIndentWidth: 4 59 | Cpp11BracedListStyle: true 60 | DeriveLineEnding: true 61 | DerivePointerAlignment: true 62 | DisableFormat: false 63 | ExperimentalAutoDetectBinPacking: false 64 | FixNamespaceComments: true 65 | ForEachMacros: 66 | - foreach 67 | - 
Q_FOREACH 68 | - BOOST_FOREACH 69 | IncludeBlocks: Regroup 70 | IncludeCategories: 71 | - Regex: '^' 72 | Priority: 2 73 | SortPriority: 0 74 | - Regex: '^<.*\.h>' 75 | Priority: 1 76 | SortPriority: 0 77 | - Regex: '^<.*' 78 | Priority: 2 79 | SortPriority: 0 80 | - Regex: '.*' 81 | Priority: 3 82 | SortPriority: 0 83 | IncludeIsMainRegex: '([-_](test|unittest))?$' 84 | IncludeIsMainSourceRegex: '' 85 | IndentCaseLabels: true 86 | IndentGotoLabels: true 87 | IndentPPDirectives: None 88 | IndentWidth: 2 89 | IndentWrappedFunctionNames: false 90 | JavaScriptQuotes: Leave 91 | JavaScriptWrapImports: true 92 | KeepEmptyLinesAtTheStartOfBlocks: false 93 | MacroBlockBegin: '' 94 | MacroBlockEnd: '' 95 | MaxEmptyLinesToKeep: 1 96 | NamespaceIndentation: None 97 | ObjCBinPackProtocolList: Never 98 | ObjCBlockIndentWidth: 2 99 | ObjCSpaceAfterProperty: false 100 | ObjCSpaceBeforeProtocolList: true 101 | PenaltyBreakAssignment: 2 102 | PenaltyBreakBeforeFirstCallParameter: 1 103 | PenaltyBreakComment: 300 104 | PenaltyBreakFirstLessLess: 120 105 | PenaltyBreakString: 1000 106 | PenaltyBreakTemplateDeclaration: 10 107 | PenaltyExcessCharacter: 1000000 108 | PenaltyReturnTypeOnItsOwnLine: 200 109 | PointerAlignment: Left 110 | RawStringFormats: 111 | - Language: Cpp 112 | Delimiters: 113 | - cc 114 | - CC 115 | - cpp 116 | - Cpp 117 | - CPP 118 | - 'c++' 119 | - 'C++' 120 | CanonicalDelimiter: '' 121 | BasedOnStyle: google 122 | - Language: TextProto 123 | Delimiters: 124 | - pb 125 | - PB 126 | - proto 127 | - PROTO 128 | EnclosingFunctions: 129 | - EqualsProto 130 | - EquivToProto 131 | - PARSE_PARTIAL_TEXT_PROTO 132 | - PARSE_TEST_PROTO 133 | - PARSE_TEXT_PROTO 134 | - ParseTextOrDie 135 | - ParseTextProtoOrDie 136 | CanonicalDelimiter: '' 137 | BasedOnStyle: google 138 | ReflowComments: true 139 | SortIncludes: true 140 | SortUsingDeclarations: true 141 | SpaceAfterCStyleCast: false 142 | SpaceAfterLogicalNot: false 143 | SpaceAfterTemplateKeyword: true 144 | 
SpaceBeforeAssignmentOperators: true 145 | SpaceBeforeCpp11BracedList: false 146 | SpaceBeforeCtorInitializerColon: true 147 | SpaceBeforeInheritanceColon: true 148 | SpaceBeforeParens: ControlStatements 149 | SpaceBeforeRangeBasedForLoopColon: true 150 | SpaceInEmptyBlock: false 151 | SpaceInEmptyParentheses: false 152 | SpacesBeforeTrailingComments: 2 153 | SpacesInAngles: false 154 | SpacesInConditionalStatement: false 155 | SpacesInContainerLiterals: true 156 | SpacesInCStyleCastParentheses: false 157 | SpacesInParentheses: false 158 | SpacesInSquareBrackets: false 159 | SpaceBeforeSquareBrackets: false 160 | Standard: Auto 161 | StatementMacros: 162 | - Q_UNUSED 163 | - QT_REQUIRE_VERSION 164 | TabWidth: 8 165 | UseCRLF: false 166 | UseTab: Never 167 | ... 168 | 169 | -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build_wheels: 7 | env: 8 | CIBW_ENVIRONMENT: "CMAKE_GENERATOR='Ninja' GNUMAKEFLAGS=-j4 CUDA_PATH=/usr/local/cuda" 9 | CIBW_BUILD_VERBOSITY: 3 10 | strategy: 11 | matrix: 12 | include: 13 | - runs-on: ubuntu-latest 14 | cibw-arch: manylinux_x86_64 15 | - runs-on: ubuntu-latest 16 | cibw-arch: manylinux_aarch64 17 | - runs-on: macos-latest 18 | cibw-arch: macosx_x86_64 19 | - runs-on: macos-latest 20 | cibw-arch: macosx_arm64 21 | name: Wheels • ${{ matrix.cibw-arch }} 22 | runs-on: ${{ matrix.runs-on }} 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | with: 27 | submodules: true 28 | 29 | - uses: actions/setup-python@v2 30 | 31 | - name: Package source distribution 32 | if: runner.os == 'Linux' 33 | run: | 34 | python setup.py sdist -d wheelhouse --formats=gztar 35 | 36 | - name: Configure cibuildwheel 37 | shell: bash 38 | run: | 39 | CMAKE_OSX_ARCHITECTURES=${{ matrix.cibw-arch == 'macosx_x86_64' && 'x86_64' || matrix.cibw-arch == 
'macosx_arm64' && 'arm64' || matrix.cibw-arch == 'macosx_universal2' && '"arm64;x86_64"' || '' }} 40 | echo "CIBW_SKIP=\"pp* cp36-*\"" >> $GITHUB_ENV 41 | echo "CIBW_ARCHS_LINUX=x86_64 aarch64" >> $GITHUB_ENV 42 | echo "CIBW_ARCHS_MACOS=x86_64 arm64" >> $GITHUB_ENV 43 | echo "CIBW_BUILD=*-${{ matrix.cibw-arch }}" >> $GITHUB_ENV 44 | echo "CIBW_ENVIRONMENT_MACOS=CMAKE_OSX_ARCHITECTURES=\"$CMAKE_OSX_ARCHITECTURES\"" >> $GITHUB_ENV 45 | 46 | - name: Set up QEMU 47 | if: runner.os == 'Linux' 48 | uses: docker/setup-qemu-action@v1 49 | with: 50 | platforms: all 51 | 52 | - name: Build wheels 53 | uses: pypa/cibuildwheel@v2.11.4 54 | env: 55 | CIBW_BEFORE_BUILD: pip install ninja 56 | # TODO re-enable when backend is cleaned up 57 | #CIBW_BEFORE_ALL_LINUX: sh install_cuda.sh 58 | 59 | - uses: actions/upload-artifact@v2 60 | with: 61 | path: ./wheelhouse/*.whl 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw? 2 | build/ 3 | dist/ 4 | loop_tool_py.egg-info/ 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "extern/wasmblr"] 2 | path = extern/wasmblr 3 | url = https://github.com/bwasti/wasmblr.git 4 | [submodule "extern/wasm-micro-runtime"] 5 | path = extern/wasm-micro-runtime 6 | url = https://github.com/bytecodealliance/wasm-micro-runtime.git 7 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: Wasti 5 | given-names: Bram 6 | orcid: https://orcid.org/0000-0003-0562-2688 7 | title: "loop_tool" 8 | version: 0.0.1 9 | date-released: 2021-7-8 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(loop_tool) 2 | cmake_minimum_required(VERSION 3.8) 3 | set(CMAKE_CXX_STANDARD 17) 4 | 5 | set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/) 6 | 7 | set(LIB_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include/) 8 | set(LIB_PRIVATE_INCLUDE ${SRC_DIR}) 9 | 10 | file(GLOB CORE_SRCS 11 | ${SRC_DIR}/core/*.cpp 12 | ) 13 | list(APPEND CORE_SRCS "${SRC_DIR}/frontends/lazy.cpp") 14 | list(APPEND CORE_SRCS "${SRC_DIR}/frontends/mutate.cpp") 15 | list(APPEND CORE_SRCS "${SRC_DIR}/frontends/nn.cpp") 16 | 17 | option(BUILD_WASM "Build a WebAssembly compilation target" ON) 18 | option(BUILD_WASM_RUNTIME "Build a WebAssembly runtime with WAMR" OFF) 19 | option(BUILD_ES6 "Build an ES6 target in emcc" OFF) 20 | 21 | if(DEFINED EMCC_DIR) 22 | set(BUILD_WASM ON) 23 | endif() 24 | 25 | if (BUILD_WASM_RUNTIME) 26 | set(BUILD_WASM ON) 27 | list(APPEND CORE_SRCS "${SRC_DIR}/backends/wasm/wasm_runtime.cpp") 28 | set(WAMR_BUILD_AOT 0) 29 | if (DEFINED LLVM_DIR) 30 | set(WAMR_BUILD_AOT 1) 31 | set(WAMR_BUILD_JIT 1) 32 | endif() 33 | set(WAMR_BUILD_LIBC_BUILTIN 0) 34 | set(WAMR_BUILD_LIBC_WASI 0) 35 | add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/extern/wasm-micro-runtime") 36 | endif() 37 | 38 | if (BUILD_WASM) 39 | list(APPEND CORE_SRCS "${SRC_DIR}/backends/wasm/wasm.cpp") 40 | set(LIB_INCLUDE ${LIB_INCLUDE} "${CMAKE_CURRENT_SOURCE_DIR}/extern/wasmblr") 41 | endif() 42 | 43 | option(BUILD_LOOP_NEST "Build loop_nest backend for fast contractions on CPU" OFF) 44 | if (BUILD_LOOP_NEST) 45 | list(APPEND CORE_SRCS "${SRC_DIR}/backends/cpu/loop_nest.cpp") 46 | set(DABUN_BUILD_APPS_FOR_ALL_SUPPORTED_VEX OFF) 47 | 
set(DABUN_BUILD_TESTS_FOR_ALL_ARCH_VEX OFF) 48 | add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/extern/loop_nest") 49 | set_target_properties(dabun PROPERTIES POSITION_INDEPENDENT_CODE ON) 50 | endif() 51 | 52 | set(CPU_SRCS ${SRC_DIR}/backends/cpu/cpp.cpp) 53 | 54 | set(WHOLE_ARCHIVE_START "-Wl,--whole-archive -Wl,--no-as-needed") 55 | set(WHOLE_ARCHIVE_END "-Wl,--no-whole-archive -Wl,--as-needed") 56 | if (APPLE) 57 | set(WHOLE_ARCHIVE_START "-Wl,-all_load") 58 | set(WHOLE_ARCHIVE_END "-Wl,-noall_load") 59 | endif() 60 | 61 | 62 | add_library(loop_tool SHARED ${CORE_SRCS} ${CPU_SRCS}) 63 | target_include_directories(loop_tool PRIVATE ${LIB_PRIVATE_INCLUDE}) 64 | target_include_directories(loop_tool PUBLIC 65 | "$" 66 | $ 67 | ) 68 | if (BUILD_LOOP_NEST) 69 | target_link_libraries(loop_tool PUBLIC dabun) 70 | endif() 71 | if (BUILD_WASM_RUNTIME) 72 | target_link_libraries(loop_tool PUBLIC iwasm_shared) 73 | target_include_directories(loop_tool PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/extern/wasm-micro-runtime/core/iwasm/include") 74 | endif() 75 | 76 | set(THREADS_PREFER_PTHREAD_FLAG ON) 77 | find_package(Threads REQUIRED) 78 | target_link_libraries(loop_tool PUBLIC ${CMAKE_DL_LIBS} PRIVATE Threads::Threads) 79 | 80 | list(APPEND CMAKE_MODULE_PATH "${SRC_DIR}/backends/cuda") 81 | if(DEFINED ENV{CUDA_PATH}) 82 | set(CUDAToolkit_ROOT $ENV{CUDA_PATH}) 83 | endif() 84 | find_package(CUDAToolkit) 85 | if (CUDAToolkit_FOUND) 86 | message("Found CUDA toolkit version ${CUDAToolkit_VERSION}") 87 | file(GLOB CUDA_SRCS ${SRC_DIR}/backends/cuda/*.cpp) 88 | 89 | add_library(loop_tool_cuda SHARED ${CUDA_SRCS}) 90 | target_include_directories(loop_tool_cuda PUBLIC ${LIB_INCLUDE} ${SRC_DIR}/backends/cuda ${CUDAToolkit_INCLUDE_DIRS}) 91 | #target_link_libraries(loop_tool_cuda CUDA::cudart_static CUDA::nvrtc) 92 | endif() 93 | 94 | option(BUILD_TESTS "Build all available tests" ON) 95 | if (BUILD_TESTS) 96 | 97 | set(TEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/test/) 98 | 99 | file(GLOB 
TEST_SRCS ${TEST_DIR}/test_*.cpp) 100 | list(APPEND TEST_SRCS "${TEST_DIR}/test.cpp") # main file 101 | add_library(loop_tool_test_utils "${TEST_DIR}/utils.cpp") 102 | target_include_directories(loop_tool_test_utils PUBLIC ${SRC_DIR}) 103 | 104 | if (BUILD_LOOP_NEST) 105 | list(APPEND TEST_SRCS "${TEST_DIR}/loop_nest_test.cpp") 106 | endif() 107 | 108 | if (BUILD_WASM) 109 | if (BUILD_LOOP_NEST) 110 | message(FATAL_ERROR "loop_nest doesn't have wasm backend yet, please rebuild with -DBUILD_LOOP_NEST=OFF or -DBUILD_WASM=OFF") 111 | endif() 112 | list(APPEND TEST_SRCS "${TEST_DIR}/wasm_test.cpp") 113 | endif() 114 | if (BUILD_WASM_RUNTIME) 115 | list(APPEND TEST_SRCS "${TEST_DIR}/wasm_runtime_test.cpp") 116 | endif() 117 | 118 | if (CUDAToolkit_FOUND) 119 | list(APPEND TEST_SRCS "${TEST_DIR}/cuda_test.cpp") 120 | endif() # CUDAToolkit_FOUND 121 | 122 | add_executable(loop_tool_test ${TEST_SRCS}) 123 | target_include_directories(loop_tool_test PUBLIC ${SRC_DIR}) 124 | target_link_libraries(loop_tool_test loop_tool_test_utils loop_tool) 125 | 126 | if (CUDAToolkit_FOUND) 127 | target_link_libraries(loop_tool_test 128 | ${WHOLE_ARCHIVE_START} 129 | loop_tool_cuda 130 | ${WHOLE_ARCHIVE_END} 131 | ) 132 | endif() # CUDAToolkit_FOUND 133 | 134 | endif() # BUILD_TESTS 135 | 136 | find_package(pybind11 CONFIG) 137 | if (pybind11_FOUND) 138 | message("Building python bindings...") 139 | file(GLOB PY_SRCS ${SRC_DIR}/frontends/python*.cpp) 140 | pybind11_add_module(loop_tool_py MODULE ${PY_SRCS}) 141 | target_include_directories(loop_tool_py PUBLIC ${LIB_INCLUDE}) 142 | if (CUDAToolkit_FOUND) 143 | target_compile_definitions(loop_tool_py PUBLIC ENABLE_CUDA) 144 | target_link_libraries(loop_tool_py PUBLIC 145 | -rdynamic 146 | ${WHOLE_ARCHIVE_START} 147 | loop_tool_cuda 148 | loop_tool 149 | ${WHOLE_ARCHIVE_END} 150 | ) 151 | else() 152 | target_link_libraries(loop_tool_py PUBLIC 153 | -rdynamic 154 | ${WHOLE_ARCHIVE_START} 155 | loop_tool 156 | ${WHOLE_ARCHIVE_END} 157 | ) 158 
| endif() 159 | else() 160 | message("To build python bindings, pip install pybind11 and run `cmake .. -Dpybind11_DIR=$(python -c 'import pybind11;print(pybind11.get_cmake_dir())')`") 161 | endif() 162 | 163 | FUNCTION(PREPEND var prefix) 164 | SET(listVar "") 165 | FOREACH(f ${ARGN}) 166 | LIST(APPEND listVar "${prefix}/${f}") 167 | ENDFOREACH(f) 168 | SET(${var} "${listVar}" PARENT_SCOPE) 169 | ENDFUNCTION(PREPEND) 170 | 171 | if(DEFINED EMCC_DIR) 172 | SET(EMCC ${EMCC_DIR}/emcc) 173 | message("Using ${EMCC} to compile javascript bindings...") 174 | SET(EMCC_INCLUDE ${EMCC_DIR}/system/include) 175 | file(GLOB JS_SRCS ${SRC_DIR}/frontends/javascript*.cpp) 176 | SET(EMCC_FLAGS -s NO_DISABLE_EXCEPTION_CATCHING -s MODULARIZE -s SINGLE_FILE=1 -s "EXPORT_NAME=\"createMyModule\"" -s "TOTAL_MEMORY=268435456") 177 | set(EMCC_TARGET libloop_tool.js) 178 | if (BUILD_ES6) 179 | SET(EMCC_FLAGS ${EMCC_FLAGS} -s EXPORT_ES6=1) 180 | set(EMCC_TARGET libloop_tool.mjs) 181 | endif() 182 | if (CMAKE_BUILD_TYPE MATCHES Debug) 183 | message("Running a debug build for emcc...") 184 | if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 185 | SET(EMCC_FLAGS ${EMCC_FLAGS} -gsource-map) 186 | endif() 187 | SET(EMCC_FLAGS ${EMCC_FLAGS} -g -s ASSERTIONS=1 -s DEMANGLE_SUPPORT) 188 | else() 189 | SET(EMCC_FLAGS ${EMCC_FLAGS} -Oz) 190 | endif() 191 | PREPEND(INC "-I" ${LIB_INCLUDE} ${LIB_PRIVATE_INCLUDE}) 192 | add_custom_command(OUTPUT loop_tool_js_emcc 193 | COMMAND ${EMCC} -I${EMCC_INCLUDE} ${INC} ${CORE_SRCS} ${JS_SRCS} ${EMCC_FLAGS} -o ${EMCC_TARGET} --bind 194 | WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" 195 | DEPENDS "${CORE_SRCS}" 196 | COMMENT "Compiling with emcc" 197 | VERBATIM 198 | ) 199 | 200 | add_custom_target(loop_tool_js ALL DEPENDS loop_tool_js_emcc) 201 | endif() 202 | 203 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 
3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `loop_tool` 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | 33 | ## License 34 | By contributing to `loop_tool`, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `loop_tool` 2 | 3 | A tiny linear algebra code-generator and optimization toolkit. 
Try it out here: http://loop-tool.glitch.me 4 | 5 | 6 | ![](https://user-images.githubusercontent.com/4842908/174701426-5e479bad-821e-4395-ae1a-91c48776efd9.gif) 7 | 8 | Preprint on [arxiv](https://arxiv.org/abs/2205.00618) 9 | 10 | ## Installation 11 | 12 | ### C++: 13 | ```bash 14 | git clone https://github.com/facebookresearch/loop_tool.git 15 | mkdir build 16 | cd build 17 | cmake .. -DCMAKE_BUILD_TYPE=Release 18 | make -j$(nproc) 19 | ``` 20 | 21 | ### Python: 22 | ```bash 23 | pip install loop_tool 24 | ``` 25 | 26 | ### JavaScript: 27 | ```bash 28 | curl -O -L https://github.com/facebookresearch/loop_tool/raw/main/javascript/lt.mjs.gz 29 | gunzip lt.mjs.gz 30 | ``` 31 | 32 | ## Import 33 | ### C++: 34 | ```cpp 35 | #include 36 | namespace lt = loop_tool; 37 | ``` 38 | ### Python: 39 | ```python 40 | import loop_tool as lt 41 | ``` 42 | ### JavaScript: 43 | ```javascript 44 | import * as lt from './lt.mjs'; 45 | ``` 46 | 47 | ## Usage 48 | 49 | ### C++: 50 | ```cpp 51 | #include 52 | namespace lz = ::loop_tool::lazy; 53 | 54 | auto mm = [](lz::Tensor A, lz::Tensor B) { 55 | lz::Symbol M, N, K; 56 | auto C = A.as(M, K) * B.as(K, N); 57 | return C.sum(K); 58 | }; 59 | 60 | lz::Tensor A(128, 128); 61 | lz::Tensor B(128, 128); 62 | rand(A.data(), 128 * 128); 63 | rand(B.data(), 128 * 128); 64 | 65 | auto C = mm(A, B); 66 | std::cout << C.data()[0]; 67 | ``` 68 | 69 | ### Python: 70 | ```python 71 | import loop_tool as lt 72 | import numpy as np 73 | 74 | def mm(a, b): 75 | m, n, k = lt.symbols("m n k") 76 | return (a.to(m, k) * b.to(k, n)).sum(k) 77 | 78 | A_np = np.random.randn(128, 128) 79 | B_np = np.random.randn(128, 128) 80 | A = lt.Tensor(A_np) 81 | B = lt.Tensor(B_np) 82 | 83 | C = mm(A, B) 84 | print(C.numpy()[0]) 85 | ``` 86 | ### JavaScript: 87 | ```javascript 88 | import * as lt from './lt.mjs'; 89 | 90 | function mm(A, B) { 91 | const [m, n, k] = lt.symbols("m n k"); 92 | return A.to(m, k).mul(B.to(k, n)).sum(k); 93 | } 94 | 95 | const A = 
lt.tensor(128, 128); 96 | const B = lt.tensor(128, 128); 97 | A.set(new Float32Array(128 * 128, 1)); 98 | B.set(new Float32Array(128 * 128, 1)); 99 | 100 | const C = mm(A, B); 101 | console.log(await C.data()[0]); 102 | ``` 103 | 104 | ## Overview 105 | 106 | `loop_tool` is an experimental loop-based computation toolkit. 107 | Building on the fact that many useful operations (in linear algebra, neural networks, and media processing) 108 | can be written as highly optimized bounded loops, 109 | `loop_tool` is composed of two ideas: 110 | 111 | 1. A lazy symbolic frontend 112 | - Extension of typical eager interfaces (e.g. [Numpy](https://numpy.org) or earlier [PyTorch](https://pytorch.org)) 113 | - Symbolic shape deduction (including input shapes) 114 | - Transparent JIT compilation 115 | 2. A simple functional IR 116 | - Optimized through local node-level annotations 117 | - Lowered to various backends (currently C, WASM) 118 | 119 | Additionally, a curses-based UI is provided (in Python `lt.ui(tensor)`) to interactively optimize loop structures on the fly: 120 | 121 | https://user-images.githubusercontent.com/4842908/172877041-fd4b9ed8-7164-49f6-810f-b14f169e2ca9.mp4 122 | 123 | ## Arithmetic Support 124 | 125 | `loop_tool` is built on fixed-range loop iterations over quasi-linear index equations. Within these loops, the following operations are currently supported: 126 | 127 | - `copy` 128 | - `add` 129 | - `subtract` 130 | - `multiply` 131 | - `divide` 132 | - `min` 133 | - `max` 134 | - `log` 135 | - `exp` 136 | - `sqrt` 137 | - `abs` 138 | - `negate` 139 | - `reciprocal` 140 | 141 | It is straight-forward to add arithmetic support if needed. 
142 | 143 | With the constraint-based quasi-linear indexing system, far more complex operations can easily be defined, such as 144 | 145 | - Padding 146 | - Concatenation 147 | - Matrix multiplication 148 | - Convolution (group, strided, dilated) 149 | - Pooling (max, average) 150 | 151 | ## License 152 | 153 | loop_tool is MIT licensed, as found in the LICENSE file. 154 | -------------------------------------------------------------------------------- /include/loop_tool/backend.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "ir.h" 14 | #include "tensor.h" 15 | 16 | namespace loop_tool { 17 | 18 | struct Compiled { 19 | virtual ~Compiled() {} 20 | virtual void run(const std::vector &memory, 21 | bool sync = true) const = 0; 22 | virtual std::string dump() const { 23 | return "[not implemented, override `std::string Compiled::dump() const`]"; 24 | } 25 | 26 | void operator()(const std::vector &tensors, bool sync = true) const; 27 | 28 | std::vector allocate(std::vector &sizes) const; 29 | 30 | template 31 | void run(Args const &...tensors) const { 32 | std::vector memory = {tensors.data.address...}; 33 | run(memory, sync); 34 | } 35 | 36 | template 37 | void operator()(Args const &...tensors) const { 38 | run(std::forward(tensors)...); 39 | } 40 | 41 | template 42 | void async(Args const &...tensors) const { 43 | run(std::forward(tensors)...); 44 | } 45 | 46 | std::unordered_map int_properties; 47 | std::unordered_map string_properties; 48 | int hardware_requirement = -1; 49 | std::string name; 50 | }; 51 | 52 | struct Backend { 53 | std::string name_; 54 | Backend(std::string name) : name_(name) {} 55 | virtual ~Backend(){}; 56 | 57 | const std::string 
&name() const { return name_; } 58 | 59 | virtual std::unique_ptr compile_impl(const LoopTree <) const = 0; 60 | virtual int hardware_requirement() const = 0; 61 | 62 | std::unique_ptr compile(const LoopTree <) const { 63 | auto compiled = compile_impl(lt); 64 | compiled->hardware_requirement = hardware_requirement(); 65 | compiled->name = name(); 66 | return compiled; 67 | } 68 | }; 69 | 70 | const std::unordered_map> &getBackends(); 71 | void registerBackend(std::shared_ptr backend); 72 | 73 | std::shared_ptr &getDefaultBackend(); 74 | void setDefaultBackend(std::string backend); 75 | 76 | struct ScopedBackend { 77 | std::string old_backend_name; 78 | ScopedBackend(std::string backend_name) { 79 | const auto &old_backend = getDefaultBackend(); 80 | old_backend_name = old_backend->name(); 81 | setDefaultBackend(backend_name); 82 | } 83 | ~ScopedBackend() { setDefaultBackend(old_backend_name); } 84 | }; 85 | 86 | struct RegisterBackend { 87 | RegisterBackend(std::shared_ptr backend) { 88 | registerBackend(backend); 89 | } 90 | }; 91 | 92 | void loadLibrary(std::string lib_name); 93 | 94 | } // namespace loop_tool 95 | -------------------------------------------------------------------------------- /include/loop_tool/compile.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "backend.h" 14 | #include "ir.h" 15 | 16 | #define MAX_DEPTH 8 17 | 18 | namespace loop_tool { 19 | 20 | using InnerFnType = 21 | std::function &, int[MAX_DEPTH])>; 22 | 23 | // Generates runnable code (there's also CodeGenerator, which generates text) 24 | class Compiler { 25 | public: 26 | struct Allocation { 27 | Allocation() = default; 28 | Allocation(int memory_idx, IR::NodeRef node_ref_) 29 | : mem_idx(memory_idx), node_ref(node_ref_) {} 30 | Allocation(int memory_idx, IR::NodeRef node_ref_, 31 | const std::vector &sizes_, LoopTree::TreeRef lca_) 32 | : mem_idx(memory_idx), sizes(sizes_), node_ref(node_ref_), lca(lca_) {} 33 | int mem_idx = -1; 34 | // scoped sizes 35 | std::vector sizes; 36 | std::vector strides; 37 | inline int64_t size(int start_idx = 0) const { 38 | int64_t s = 1; 39 | for (int i = start_idx; i < sizes.size(); ++i) { 40 | const auto &s_ = sizes.at(i); 41 | s *= std::max(s_, (int64_t)1); 42 | } 43 | return s; 44 | } 45 | IR::NodeRef node_ref = -1; 46 | LoopTree::TreeRef lca = -1; 47 | }; 48 | 49 | struct Access { 50 | Access(const Allocation &a) : alloc(a) {} 51 | Allocation alloc; 52 | // alloc (base vars) mapped to expr, max 53 | std::vector full_exprs; 54 | std::vector scoped_exprs; 55 | std::vector> bounds; 56 | }; 57 | 58 | // optionally always true, this is for cleanup 59 | mutable bool set_called = false; 60 | size_t count; 61 | LoopTree lt; 62 | std::unordered_map 63 | inner_sizes; // total size of inner loops over same var 64 | std::unordered_map allocations; 65 | std::unordered_map resolved_reads; 66 | std::unordered_map resolved_writes; 67 | std::unordered_map var_sizes; 68 | std::unordered_map var_to_sym; 69 | std::unordered_map> 71 | sym_to_var; 72 | 73 | Compiler(const LoopTree <_); 74 | virtual ~Compiler() = default; 75 | 76 | Allocation gen_alloc(IR::NodeRef node_ref) const; 77 | 78 | std::pair, std::vector> 79 | 
gen_index_equations(IR::NodeRef read_node_ref, IR::NodeRef write_node_ref, 80 | LoopTree::TreeRef ref) const; 81 | // given a node used at point "ref", generate access information 82 | Access gen_access(IR::NodeRef node, LoopTree::TreeRef ref) const; 83 | 84 | symbolic::Expr get_scoped_expr(const Access &access) const; 85 | std::unordered_map>, 87 | symbolic::Hash> 88 | get_symbol_strides(LoopTree::TreeRef ref, LoopTree::TreeRef root) const; 89 | 90 | std::function &memory, 91 | int indices[MAX_DEPTH])> 92 | gen_access_fn(const Access &access, LoopTree::TreeRef ref) const; 93 | std::vector> get_constraints( 94 | const Access &access) const; 95 | 96 | InnerFnType gen_reset(LoopTree::TreeRef ref) const; 97 | 98 | symbolic::Expr reify_sizes(const symbolic::Expr &expr) const; 99 | int64_t get_expr_max(const symbolic::Expr &) const; 100 | int64_t get_expr_min(const symbolic::Expr &) const; 101 | 102 | InnerFnType gen_mem_node(LoopTree::TreeRef ref) const; 103 | InnerFnType gen_binary_node(LoopTree::TreeRef ref) const; 104 | InnerFnType gen_unary_node(LoopTree::TreeRef ref) const; 105 | InnerFnType gen_node(LoopTree::TreeRef ref) const; 106 | 107 | InnerFnType gen_loop(LoopTree::TreeRef ref, 108 | std::unordered_map overrides) const; 109 | InnerFnType gen_backend_exec(LoopTree::TreeRef ref, 110 | std::unordered_map overrides, 111 | const std::string &backend) const; 112 | 113 | InnerFnType gen_exec( 114 | LoopTree::TreeRef ref = -1, 115 | std::unordered_map overrides = {}) const; 116 | virtual std::string gen_string() const; 117 | 118 | std::vector allocate() const; 119 | std::vector memory_sizes(bool include_io = false) const; 120 | }; 121 | 122 | struct CPUInterpretedBackend : public Backend { 123 | CPUInterpretedBackend() : Backend("cpu_interpreted") {} 124 | ~CPUInterpretedBackend() {} 125 | CPUInterpretedBackend(std::string name) : Backend(name) {} 126 | 127 | std::unique_ptr compile_impl(const LoopTree <) const override; 128 | int hardware_requirement() const 
override; 129 | }; 130 | 131 | } // namespace loop_tool 132 | -------------------------------------------------------------------------------- /include/loop_tool/cpp.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include "loop_tool/compile.h" 10 | 11 | namespace loop_tool { 12 | 13 | class CppCompiler : public Compiler { 14 | public: 15 | CppCompiler(const LoopTree& lt); 16 | inline std::string gen_string() const override { return gen_string_impl(); } 17 | 18 | private: 19 | std::string gen_string_impl( 20 | LoopTree::TreeRef ref = -1, 21 | std::unordered_map overrides = {}) const; 22 | bool is_input_output(IR::NodeRef nr) const; 23 | std::string gen_access_string(IR::NodeRef node_ref, 24 | LoopTree::TreeRef ref) const; 25 | std::string gen_reset_string(LoopTree::TreeRef ref) const; 26 | std::string gen_mem_node_string(LoopTree::TreeRef ref) const; 27 | std::string gen_compute_node_string(LoopTree::TreeRef ref) const; 28 | inline std::string gen_indent(LoopTree::TreeRef ref, int extra = 0) const { 29 | auto depth = ((ref == -1) ? 
0 : lt.depth(ref) + 1); 30 | return std::string((depth + extra) * 2, ' '); 31 | } 32 | std::string gen_node_string(LoopTree::TreeRef ref) const; 33 | std::string gen_loop_string( 34 | LoopTree::TreeRef ref, 35 | std::unordered_map overrides) const; 36 | }; 37 | 38 | struct CppBackend : public Backend { 39 | CppBackend() : Backend("cpp") {} 40 | ~CppBackend() {} 41 | CppBackend(std::string name) : Backend(name) {} 42 | 43 | std::unique_ptr compile_impl(const LoopTree& lt) const override; 44 | int hardware_requirement() const override; 45 | }; 46 | 47 | } // namespace loop_tool 48 | -------------------------------------------------------------------------------- /include/loop_tool/dynlib.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | #include 9 | 10 | #include 11 | 12 | #include "error.h" 13 | 14 | namespace loop_tool { 15 | 16 | struct DynamicLibrary { 17 | DynamicLibrary(const char* name, bool expose_symbols = false) : name_(name) { 18 | int symbol_flag = expose_symbols ? RTLD_GLOBAL : RTLD_LOCAL; 19 | lib_ = dlopen(name, symbol_flag | RTLD_NOW); 20 | ASSERT(lib_) << "Couldn't load library " << name_ 21 | << " dlerror: " << dlerror(); 22 | } 23 | 24 | static bool exists(const char* name, bool expose_symbols = false) { 25 | int symbol_flag = expose_symbols ? 
RTLD_GLOBAL : RTLD_LOCAL; 26 | return !!dlopen(name, symbol_flag | RTLD_NOW); 27 | } 28 | 29 | inline void* sym(const char* sym_name) const { 30 | ASSERT(lib_) << "Library " << name_ << " not loaded for symbol " 31 | << sym_name; 32 | auto* symbol = dlsym(lib_, sym_name); 33 | ASSERT(symbol) << "Couldn't find " << sym_name << " in " << name_; 34 | return symbol; 35 | } 36 | 37 | template 38 | F sym(const char* sym_name) const { 39 | F fn; 40 | auto fnp = sym(sym_name); 41 | reinterpret_cast(fn) = fnp; 42 | return fn; 43 | } 44 | 45 | ~DynamicLibrary() { dlclose(lib_); } 46 | 47 | private: 48 | void* lib_ = nullptr; 49 | std::string name_; 50 | }; 51 | 52 | #define DYNLIB(lib, name) reinterpret_cast(lib->sym(#name)) 53 | 54 | } // namespace loop_tool 55 | -------------------------------------------------------------------------------- /include/loop_tool/error.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #define S1(x) #x 15 | #define S2(x) S1(x) 16 | #define LOCATION __FILE__ ":" S2(__LINE__) 17 | 18 | namespace loop_tool { 19 | 20 | struct NullStream { 21 | template 22 | NullStream &operator<<(T const &) { 23 | return *this; 24 | } 25 | }; 26 | 27 | struct StreamOut : public NullStream { 28 | std::stringstream ss; 29 | bool failure = false; 30 | 31 | StreamOut(bool pass, std::string location, std::string cond = "") 32 | : failure(!pass) { 33 | if (failure && cond.size()) { 34 | ss << "assertion: " << cond << " "; 35 | } 36 | ss << "failed @ " << location << " "; 37 | } 38 | 39 | template 40 | StreamOut &operator<<(const T &d) { 41 | if (failure) { 42 | ss << d; 43 | } 44 | return *this; 45 | } 46 | 47 | ~StreamOut() noexcept(false) { 48 | if (failure) { 49 | throw std::runtime_error(ss.str()); 50 | } 51 | } 52 | }; 53 | 54 | } // namespace loop_tool 55 | 56 | #ifdef NOEXCEPTIONS 57 | #define ASSERT(x) loop_tool::NullStream() 58 | #else 59 | #define ASSERT(x) \ 60 | if (!(x)) loop_tool::StreamOut(x, LOCATION, #x) 61 | #endif 62 | -------------------------------------------------------------------------------- /include/loop_tool/hardware.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "error.h" 16 | 17 | namespace loop_tool { 18 | 19 | struct Memory { 20 | int compatible = 0; 21 | void *address = 0; 22 | }; 23 | 24 | struct Hardware { 25 | std::string name_; 26 | int count_; 27 | int id_ = 0; // default for CPU 28 | 29 | Hardware(std::string name, int count) : name_(name), count_(count) {} 30 | 31 | void setId(int id) { 32 | id_ = id; 33 | ASSERT(id >= 0 && id < 32) << "Invalid ID for hardware: " << id; 34 | } 35 | 36 | // Allocation must be compatible with CPU 37 | virtual Memory alloc(size_t size) = 0; 38 | virtual void free(Memory &data) = 0; 39 | 40 | // TODO 41 | // virtual Memory copy(const Memory& data) = 0; 42 | // virtual Memory move(const Memory& data) = 0; 43 | 44 | bool compatible(const Memory &m) const { return m.compatible & (1 << id_); }; 45 | const std::string &name() const { return name_; } 46 | int id() const { return id_; } 47 | int count() const { return count_; } 48 | }; 49 | 50 | int availableCPUs(); 51 | 52 | struct CPUHardware : public Hardware { 53 | CPUHardware() : Hardware("cpu", availableCPUs()) {} 54 | 55 | Memory alloc(size_t size) override { return Memory{0x1, malloc(size)}; } 56 | 57 | void free(Memory &data) override { 58 | ::free(data.address); 59 | data.address = nullptr; 60 | data.compatible = 0; 61 | } 62 | 63 | static Hardware *create() { return new CPUHardware(); } 64 | }; 65 | 66 | const std::vector> &getHardware(); 67 | int getAvailableHardware(); 68 | 69 | int &getDefaultHardwareId(); 70 | void setDefaultHardwareId(int id); 71 | const std::shared_ptr &getDefaultHardware(); 72 | 73 | void registerHardware(std::shared_ptr hw); 74 | 75 | struct RegisterHardware { 76 | RegisterHardware(std::shared_ptr hw) { registerHardware(hw); } 77 | }; 78 | 79 | } // namespace loop_tool 80 | -------------------------------------------------------------------------------- /include/loop_tool/loop_tool.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include "backend.h" 10 | #include "compile.h" 11 | #include "hardware.h" 12 | #include "ir.h" 13 | #include "serialization.h" 14 | #include "tensor.h" 15 | // default frontends 16 | #include "lazy.h" 17 | #include "mutate.h" 18 | #include "nn.h" 19 | -------------------------------------------------------------------------------- /include/loop_tool/mutate.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include 10 | 11 | namespace loop_tool { 12 | 13 | IR split_node(const IR& ir, IR::NodeRef node_ref, 14 | std::vector injected); 15 | IR split_var(const IR& ir, IR::VarRef v); 16 | IR swap_vars(const IR& ir, IR::NodeRef node_ref, IR::VarRef a, IR::VarRef b); 17 | 18 | // split out a subtree at the ref 19 | LoopTree subtree(const LoopTree& lt, LoopTree::TreeRef ref, 20 | std::unordered_map node_map = {}, 21 | std::unordered_map var_map = {}); 22 | 23 | LoopTree split(const LoopTree& lt, LoopTree::TreeRef ref, int64_t size); 24 | // merges upward 25 | LoopTree merge(const LoopTree& lt, LoopTree::TreeRef ref); 26 | LoopTree copy_input(const LoopTree& lt, LoopTree::TreeRef ref, int idx); 27 | LoopTree delete_copy(const LoopTree& lt, LoopTree::TreeRef ref); 28 | // generic swap for addressable loops and nodes, may fail silently 29 | LoopTree try_swap(const LoopTree& lt, LoopTree::TreeRef a, LoopTree::TreeRef b); 30 | LoopTree swap_loops(const LoopTree& lt, LoopTree::TreeRef a, 31 | LoopTree::TreeRef b); 32 | LoopTree 
add_loop(const LoopTree& lt, LoopTree::TreeRef ref, 33 | LoopTree::TreeRef add); 34 | LoopTree remove_loop(const LoopTree& lt, LoopTree::TreeRef ref, 35 | LoopTree::TreeRef rem); 36 | LoopTree swap_nodes(const LoopTree& lt, LoopTree::TreeRef a, 37 | LoopTree::TreeRef b); 38 | LoopTree swap_vars(const LoopTree& lt, IR::NodeRef node_ref, IR::VarRef a, 39 | IR::VarRef b); 40 | LoopTree disable_reuse(const LoopTree& lt, LoopTree::TreeRef loop, 41 | IR::NodeRef n); 42 | LoopTree enable_reuse(const LoopTree& lt, LoopTree::TreeRef loop, 43 | IR::NodeRef n); 44 | LoopTree decrease_reuse(const LoopTree& lt, LoopTree::TreeRef ref); 45 | LoopTree increase_reuse(const LoopTree& lt, LoopTree::TreeRef ref); 46 | LoopTree::TreeRef next_ref(const LoopTree& lt, LoopTree::TreeRef ref); 47 | LoopTree::TreeRef previous_ref(const LoopTree& lt, LoopTree::TreeRef ref); 48 | 49 | LoopTree annotate(const LoopTree& lt, LoopTree::TreeRef ref, std::string annot); 50 | // map an old ref to a close new ref after mutation, return the new ref 51 | LoopTree::TreeRef map_ref(const LoopTree& new_lt, LoopTree::TreeRef old_ref, 52 | const LoopTree& old_lt); 53 | 54 | LoopTree maximize_reuse(const LoopTree& lt); 55 | LoopTree unroll_inner_loops(const LoopTree& lt, int32_t unroll_amount); 56 | 57 | // Informational functions 58 | int64_t flops(const LoopTree& lt); 59 | bool is_trivially_parallel(const LoopTree& lt, LoopTree::TreeRef ref); 60 | std::vector find(const IR& ir, Operation op); 61 | 62 | } // namespace loop_tool 63 | -------------------------------------------------------------------------------- /include/loop_tool/nn.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | 11 | namespace loop_tool { 12 | namespace nn { 13 | 14 | lazy::Tensor convolve(lazy::Tensor X, lazy::Tensor W, 15 | std::vector spatial_dims, 16 | std::vector window_dims, 17 | int stride = 1); 18 | lazy::Tensor maxpool(lazy::Tensor X, std::vector spatial_dims, 19 | int k, int stride = 1); 20 | 21 | } // namespace nn 22 | } // namespace loop_tool 23 | -------------------------------------------------------------------------------- /include/loop_tool/serialization.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include 10 | 11 | namespace loop_tool { 12 | 13 | std::string serialize(const IR& ir); 14 | IR deserialize(const std::string& str); 15 | 16 | } // namespace loop_tool 17 | -------------------------------------------------------------------------------- /include/loop_tool/smallvec.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | 11 | template 12 | class smallvec { 13 | private: 14 | static_assert(Nm > 0, "Smallvec only supports non-zero sizes"); 15 | 16 | public: 17 | using value_type = Tp; 18 | using pointer = value_type *; 19 | using const_pointer = value_type const *; 20 | using reference = value_type &; 21 | using const_reference = value_type const &; 22 | using iterator = value_type *; 23 | using const_iterator = value_type const *; 24 | using size_type = std::size_t; 25 | using difference_type = std::ptrdiff_t; 26 | using reverse_iterator = std::reverse_iterator; 27 | using const_reverse_iterator = std::reverse_iterator; 28 | 29 | private: 30 | size_type size_ = 0; 31 | Tp elements_[Nm]; 32 | 33 | public: 34 | constexpr pointer data() noexcept { return const_cast(elements_); } 35 | constexpr const_pointer data() const noexcept { 36 | return const_cast(elements_); 37 | } 38 | constexpr const_pointer cdata() const noexcept { 39 | return const_cast(elements_); 40 | } 41 | 42 | // clang-format off 43 | 44 | // Iterators 45 | constexpr iterator begin() noexcept { return iterator(data()); } 46 | constexpr const_iterator begin() const noexcept { return const_iterator(data()); } 47 | constexpr const_iterator cbegin() const noexcept { return const_iterator(data()); } 48 | constexpr iterator end() noexcept { return iterator(data() + size()); } 49 | constexpr const_iterator end() const noexcept { return const_iterator(data() + size()); } 50 | constexpr const_iterator cend() const noexcept { return const_iterator(data() + size()); } 51 | 52 | constexpr reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } 53 | constexpr const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); } 54 | constexpr const_reverse_iterator crbegin() const noexcept { return const_reverse_iterator(end()); } 55 | constexpr reverse_iterator rend() noexcept { return reverse_iterator(begin()); } 56 | constexpr const_reverse_iterator 
rend() const noexcept { return const_reverse_iterator(begin()); } 57 | constexpr const_reverse_iterator crend() const noexcept { return const_reverse_iterator(begin()); } 58 | 59 | // Capacity 60 | constexpr size_type size() const noexcept { return size_; } 61 | constexpr size_type max_size() const noexcept { return Nm; } 62 | [[nodiscard]] 63 | constexpr bool empty() const noexcept { return size() == 0; } 64 | 65 | // Element access 66 | constexpr reference operator[](size_type n) noexcept { return elements_[n]; } 67 | constexpr const_reference operator[](size_type n) const noexcept { return elements_[n]; } 68 | 69 | constexpr reference front() noexcept { return elements_[0]; } 70 | constexpr const_reference front() const noexcept { return elements_[0]; } 71 | 72 | constexpr reference back() noexcept { return elements_[size_ - 1]; } 73 | constexpr const_reference back() const noexcept { return elements_[size_ - 1]; } 74 | 75 | // clang-format on 76 | 77 | constexpr reference at(size_type n) { 78 | if (n >= size_) { 79 | throw std::out_of_range( 80 | "vec::at out of range"); // TODO(zi) provide size() and n; 81 | } 82 | return elements_[n]; 83 | } 84 | 85 | constexpr const_reference at(size_type n) const { 86 | if (n >= size_) { 87 | throw std::out_of_range( 88 | "vec::at out of range"); // TODO(zi) provide size() and n; 89 | } 90 | return elements_[n]; 91 | } 92 | 93 | constexpr void fill(value_type const &v) { std::fill_n(begin(), size(), v); } 94 | 95 | constexpr void swap(smallvec &other) { 96 | std::swap_ranges(begin(), end(), other.begin()); 97 | std::swap(size_, other.size_); 98 | } 99 | 100 | private: 101 | void destruct_elements() { 102 | for (size_type i = 0; i < size_; ++i) { 103 | (elements_ + i)->~Tp(); 104 | } 105 | } 106 | 107 | public: 108 | // Constructors 109 | constexpr smallvec() {} 110 | ~smallvec() { /*destruct_elements();*/ 111 | } 112 | 113 | constexpr smallvec(smallvec const &other) noexcept( 114 | std::is_nothrow_copy_constructible::value) 
115 | : size_(other.size_) { 116 | for (size_type i = 0; i < size_; ++i) { 117 | new (elements_ + i) Tp(other.elements_[i]); 118 | } 119 | } 120 | 121 | constexpr smallvec &operator=(smallvec const &other) noexcept( 122 | std::is_nothrow_copy_constructible::value) { 123 | destruct_elements(); 124 | size_ = other.size_; 125 | for (size_type i = 0; i < size_; ++i) { 126 | new (elements_ + i) Tp(other.elements_[i]); 127 | } 128 | return *this; 129 | } 130 | 131 | constexpr smallvec(smallvec &&other) noexcept( 132 | std::is_nothrow_move_constructible::value) { 133 | destruct_elements(); 134 | size_ = std::exchange(other.size_, 0); 135 | for (size_type i = 0; i < size_; ++i) { 136 | new (elements_ + i) Tp(std::move(other.elements_[i])); 137 | } 138 | } 139 | 140 | // todo(zi) call operator= on elements when present in both this and other 141 | constexpr smallvec &operator=(smallvec &&other) noexcept( 142 | std::is_nothrow_move_constructible::value) { 143 | destruct_elements(); 144 | size_ = other.size_; 145 | for (size_type i = 0; i < size_; ++i) { 146 | new (elements_ + i) Tp(std::move(other.elements_[i])); 147 | } 148 | return *this; 149 | } 150 | 151 | constexpr void clear() noexcept { 152 | destruct_elements(); 153 | size_ = 0; 154 | } 155 | 156 | constexpr void push_back(Tp const &value) { 157 | if (size_ >= max_size()) { 158 | throw std::out_of_range("..."); // TODO(zi) provide size() and n; 159 | } 160 | 161 | new (elements_ + size_++) Tp(value); 162 | } 163 | 164 | constexpr void push_back(Tp &&value) { 165 | if (size_ >= max_size()) { 166 | throw std::out_of_range("..."); // TODO(zi) provide size() and n; 167 | } 168 | 169 | new (elements_ + size_++) Tp(std::move(value)); 170 | } 171 | 172 | template 173 | constexpr reference emplace_back(Args &&...args) { 174 | if (size_ >= max_size()) { 175 | throw std::out_of_range("..."); // TODO(zi) provide size() and n; 176 | } 177 | new (elements_ + size_++) Tp(std::forward(args)...); 178 | 179 | return 
this->operator[](size_ - 1); 180 | } 181 | 182 | constexpr void pop_back() { 183 | --size_; 184 | (elements_ + size_)->~Tp(); 185 | } 186 | 187 | constexpr void resize(size_type count) { 188 | if (count > max_size()) { 189 | throw std::out_of_range("..."); // TODO(zi) provide size() and n; 190 | } 191 | 192 | if (size_ > count) { 193 | for (size_type i = count; i < size_; ++i) { 194 | (elements_ + i)->~Tp(); 195 | } 196 | } else if (size_ < count) { 197 | for (size_type i = size_; i < count; ++i) { 198 | new (elements_ + i) Tp; 199 | } 200 | } 201 | size_ = count; 202 | } 203 | 204 | constexpr void resize(size_type count, value_type const &other) { 205 | if (count > max_size()) { 206 | throw std::out_of_range("..."); // TODO(zi) provide size() and n; 207 | } 208 | 209 | if (size_ > count) { 210 | for (size_type i = count; i < size_; ++i) { 211 | (elements_ + i)->~Tp(); 212 | } 213 | } else if (size_ < count) { 214 | for (size_type i = size_; i < count; ++i) { 215 | new (elements_ + i) Tp(other); 216 | } 217 | } 218 | size_ = count; 219 | } 220 | }; 221 | -------------------------------------------------------------------------------- /include/loop_tool/symbolic.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "error.h" 18 | #include "smallvec.h" 19 | 20 | namespace loop_tool { 21 | namespace symbolic { 22 | 23 | inline uint64_t hash(uint64_t x) { 24 | x += 1337; 25 | x = (x ^ (x >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); 26 | x = (x ^ (x >> 27)) * UINT64_C(0x94d049bb133111eb); 27 | x = x ^ (x >> 31); 28 | return x; 29 | } 30 | 31 | inline uint64_t hash_combine(uint64_t a, uint64_t b) { 32 | std::hash hasher; 33 | const uint64_t kMul = 0x9ddfea08eb382d69ULL; 34 | uint64_t x = (hasher(b) ^ a) * kMul; 35 | x ^= (x >> 47); 36 | uint64_t y = (a ^ x) * kMul; 37 | y ^= (y >> 47); 38 | return y * kMul; 39 | } 40 | 41 | template 42 | struct Hash { 43 | std::size_t operator()(const T& k) const { return k.hash(); } 44 | }; 45 | 46 | enum struct Op { 47 | // no inputs 48 | constant = 0, 49 | // unary 50 | negate, 51 | reciprocal, 52 | size, 53 | max, 54 | // binary 55 | add, 56 | multiply, 57 | divide, 58 | modulo 59 | }; 60 | 61 | struct Expr; 62 | 63 | struct Symbol { 64 | // TODO replace with smaller construct 65 | std::string name_; 66 | int32_t id_ = -1; 67 | Symbol() : id_(getNewId()), name_("X") {} 68 | Symbol(std::string name) : id_(getNewId()), name_(name) {} 69 | Symbol(const Symbol& s) : id_(s.id_), name_(s.name_) {} 70 | static const int32_t getNewId(); 71 | const int32_t id() const; 72 | size_t hash() const; 73 | bool operator==(const Symbol& s) const; 74 | bool operator!=(const Symbol& s) const; 75 | std::string name() const; 76 | 77 | operator Expr() const; 78 | Expr operator+(const Symbol& rhs) const; 79 | Expr operator*(const Symbol& rhs) const; 80 | Expr operator+(const Expr& rhs) const; 81 | Expr operator*(const Expr& rhs) const; 82 | }; 83 | 84 | struct Expr; 85 | 86 | struct ExprImpl { 87 | enum class Type { value, symbol, function } type_; 88 | Op op_ = Op::constant; 89 | int64_t val_; 90 | Symbol symbol_; 91 | smallvec, 
2> args_; 92 | uint64_t hash_ = 0; 93 | uint64_t symbol_hash_ = 0; 94 | bool simplified_ = false; 95 | explicit ExprImpl(int64_t val) 96 | : type_(Type::value), val_(val), simplified_(true) { 97 | init(); 98 | } 99 | explicit ExprImpl(const Symbol& symbol) 100 | : type_(Type::symbol), symbol_(symbol), simplified_(true) { 101 | init(); 102 | } 103 | explicit ExprImpl(Op op, const Expr&, bool simplified = false); 104 | explicit ExprImpl(Op op, const Expr&, const Expr&, bool simplified = false); 105 | void init(); 106 | inline uint64_t hash(bool symbol_sensitive) { 107 | if (symbol_sensitive) { 108 | return symbol_hash_; 109 | } 110 | return hash_; 111 | } 112 | 113 | bool contains(const Symbol& s) const { 114 | switch (type_) { 115 | case Type::symbol: 116 | if (symbol_ == s) { 117 | return true; 118 | } 119 | return false; 120 | case Type::function: { 121 | for (const auto& arg : args_) { 122 | if (arg->contains(s)) { 123 | return true; 124 | } 125 | } 126 | } 127 | default: 128 | return false; 129 | } 130 | } 131 | }; 132 | 133 | struct Expr { 134 | std::shared_ptr impl_; 135 | using Type = ExprImpl::Type; 136 | 137 | explicit Expr() : impl_(std::make_shared(-1)) {} 138 | explicit Expr(std::shared_ptr impl) : impl_(impl) {} 139 | 140 | template 141 | explicit Expr(Args... 
args) 142 | : impl_(std::make_shared(std::forward(args)...)) {} 143 | 144 | inline smallvec args() const { 145 | smallvec out; 146 | for (const auto& impl : impl_->args_) { 147 | out.emplace_back(Expr(impl)); 148 | } 149 | return out; 150 | } 151 | 152 | Expr arg(int idx) const { return Expr(impl_->args_.at(idx)); } 153 | 154 | inline const smallvec, 2>& impl_args() const { 155 | return impl_->args_; 156 | } 157 | 158 | bool simplified() const { return impl_->simplified_; } 159 | 160 | inline Type type() const { return impl_->type_; } 161 | inline Op op() const { return impl_->op_; } 162 | inline int64_t value() const { 163 | if (type() != Type::value) { 164 | ASSERT(type() == Type::value) 165 | << "attempted to get real value from symbolic or unsimplified " 166 | "expression: " 167 | << dump(); 168 | } 169 | return impl_->val_; 170 | } 171 | 172 | inline const Symbol& symbol() const { 173 | if (type() != Type::symbol) { 174 | ASSERT(type() == Type::symbol) 175 | << "attempted to get symbol from value or unsimplified " 176 | "expression: " 177 | << dump(); 178 | } 179 | return impl_->symbol_; 180 | } 181 | 182 | Expr walk(std::function f) const; 183 | void visit(std::function f) const; 184 | Expr replace(Symbol A, Symbol B) const; 185 | Expr replace(Symbol A, Expr e) const; 186 | Expr replace(const Expr& e, Symbol B) const; 187 | Expr replace(const Expr& e, int64_t c) const; 188 | Expr replace(Symbol A, int64_t c) const; 189 | 190 | std::string dump(bool short_form = false, 191 | const std::unordered_map>& 192 | replacements = {}) const; 193 | 194 | inline uint64_t hash(bool symbol_sensitive = false) const { 195 | return impl_->hash(symbol_sensitive); 196 | } 197 | 198 | inline size_t contains(const Symbol& s) const { return impl_->contains(s); } 199 | std::vector symbols(bool include_sized = true) const; 200 | 201 | inline Expr operator+(const Expr& rhs) const { 202 | return Expr(Op::add, *this, rhs); 203 | } 204 | inline Expr operator*(const Expr& rhs) const { 
205 | return Expr(Op::multiply, *this, rhs); 206 | } 207 | inline Expr operator-() const { return Expr(Op::negate, *this); } 208 | inline Expr operator-(const Expr& rhs) const { return *this + (-rhs); } 209 | inline Expr operator/(const Expr& rhs) const { 210 | return Expr(Op::divide, *this, rhs); 211 | } 212 | inline Expr operator%(const Expr& rhs) const { 213 | return Expr(Op::modulo, *this, rhs); 214 | } 215 | static Expr size(const Expr& arg) { 216 | ASSERT(arg.type() == Type::symbol); 217 | return Expr(Op::size, arg); 218 | } 219 | static Expr max(const Expr& lhs, const Expr& rhs) { 220 | return Expr(Op::max, lhs, rhs); 221 | } 222 | inline Expr reciprocal() const { 223 | if (type() == Type::value) { 224 | ASSERT(value() != 0) << "cannot calculate 1/0"; 225 | } 226 | return Expr(Op::reciprocal, *this); 227 | } 228 | bool operator!=(const Expr& rhs) const; 229 | bool operator==(const Expr& rhs) const; 230 | Expr simplify() const; 231 | bool can_evaluate() const; 232 | float evaluate() const; 233 | }; 234 | 235 | // This might seem generic, but it should be limited to either: 236 | // - Expr(Symbol) -> Expr 237 | // - Expr::size(Symbol) -> Expr 238 | using Constraint = std::pair; 239 | 240 | std::vector unify(std::vector constraints); 241 | bool can_isolate(const Constraint& c, const Symbol& sym); 242 | Constraint isolate(const Constraint& c, const Symbol& sym); 243 | 244 | std::vector evaluate( 245 | const std::vector& old_constraints); 246 | 247 | Expr differentiate(Expr, Symbol); 248 | // zero out every symbol 249 | Expr intercept(Expr); 250 | 251 | } // namespace symbolic 252 | } // namespace loop_tool 253 | -------------------------------------------------------------------------------- /include/loop_tool/tensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | #include 9 | #include 10 | 11 | #include "hardware.h" 12 | 13 | namespace loop_tool { 14 | 15 | struct Tensor { 16 | Tensor(size_t N, int hardware = 0); 17 | Tensor() = delete; 18 | Tensor(const Tensor &) = delete; 19 | Tensor(Tensor &&) = default; 20 | ~Tensor(); 21 | int hardware_id = -1; 22 | Memory data; 23 | size_t numel; 24 | }; 25 | 26 | } // namespace loop_tool 27 | -------------------------------------------------------------------------------- /include/loop_tool/wasm.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #pragma once 8 | 9 | #include "loop_tool/compile.h" 10 | #include "wasmblr.h" 11 | 12 | namespace loop_tool { 13 | 14 | class WebAssemblyCompiler : public Compiler { 15 | mutable std::shared_ptr cg; 16 | std::unordered_set stack_storage; 17 | std::unordered_set local_storage; 18 | std::unordered_map stack_vector_storage; 19 | std::unordered_map local_vector_storage; 20 | std::unordered_set vectorized_loops; 21 | mutable std::unordered_set stack_f32; 22 | mutable std::unordered_set stack_v128; 23 | mutable int32_t tmp_i32; 24 | mutable int32_t tmp_f32; 25 | mutable int32_t tmp_v128; 26 | mutable std::unordered_map> local_f32; 27 | mutable std::unordered_map> local_v128; 28 | mutable std::unordered_map iterators; 29 | mutable std::unordered_map memory_locations; 30 | 31 | public: 32 | WebAssemblyCompiler(const LoopTree& lt); 33 | 34 | int64_t get_unroll_offset( 35 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 36 | const std::unordered_map& unrolls) const; 37 | int64_t get_unroll_offset( 38 | IR::NodeRef node_ref, LoopTree::TreeRef ref, LoopTree::TreeRef 
root, 39 | const symbolic::Expr& idx_expr, 40 | const std::unordered_map& unrolls) const; 41 | 42 | void push_expr_to_stack( 43 | const symbolic::Expr& idx_expr, 44 | std::unordered_map>, 46 | symbolic::Hash> 47 | sym_strides, 48 | 49 | std::unordered_map unrolls, 50 | int32_t base_stride) const; 51 | bool push_constraints_to_stack( 52 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 53 | std::unordered_map unrolls) const; 54 | int32_t push_access_to_stack( 55 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 56 | std::unordered_map unrolls) const; 57 | void push_float_to_stack( 58 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 59 | std::unordered_map unrolls, 60 | bool force_memory_load = false) const; 61 | void push_vector_to_stack( 62 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 63 | std::unordered_map unrolls, IR::VarRef dim, 64 | bool force_memory_load = false) const; 65 | void store_float_from_stack( 66 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 67 | std::unordered_map unrolls) const; 68 | void store_vector_from_stack( 69 | IR::NodeRef node_ref, LoopTree::TreeRef ref, 70 | std::unordered_map unrolls, 71 | IR::VarRef dim) const; 72 | int32_t get_tmp_i32() const; 73 | int32_t get_tmp_f32() const; 74 | int32_t get_tmp_v128() const; 75 | 76 | private: 77 | bool should_store_stack(IR::NodeRef node_ref) const; 78 | IR::VarRef should_store_vectorized_dim(IR::NodeRef node_ref) const; 79 | 80 | public: 81 | bool needs_reset(IR::NodeRef node_ref) const; 82 | 83 | void emit_node(LoopTree::TreeRef ref, 84 | std::unordered_map unrolls) const; 85 | void emit_reset(LoopTree::TreeRef ref) const; 86 | void emit_loop(LoopTree::TreeRef ref, 87 | std::unordered_map overrides, 88 | std::unordered_map unrolls) const; 89 | void emit(LoopTree::TreeRef ref, 90 | std::unordered_map overrides, 91 | std::unordered_map unrolls) const; 92 | std::vector emit() const; 93 | inline bool is_local(IR::NodeRef node_ref) const { 94 | return local_storage.count(node_ref) || 95 | 
local_vector_storage.count(node_ref); 96 | } 97 | inline bool is_on_stack(IR::NodeRef node_ref) const { 98 | return stack_storage.count(node_ref) || 99 | stack_vector_storage.count(node_ref); 100 | } 101 | inline bool is_vector_stored(IR::NodeRef node_ref) const { 102 | return local_vector_storage.count(node_ref) || 103 | stack_vector_storage.count(node_ref); 104 | } 105 | IR::VarRef vector_storage_dim(IR::NodeRef node_ref) const { 106 | if (local_vector_storage.count(node_ref)) { 107 | return local_vector_storage.at(node_ref); 108 | } 109 | if (stack_vector_storage.count(node_ref)) { 110 | return stack_vector_storage.at(node_ref); 111 | } 112 | return -1; 113 | } 114 | inline bool is_broadcast(IR::NodeRef node_ref) const { 115 | const auto& vs = lt.ir.node(node_ref).vars(); 116 | auto var = vector_storage_dim(node_ref); 117 | if (var == -1) { 118 | return false; 119 | } 120 | return !(vs.size() && vs.back() == var); 121 | } 122 | bool should_vectorize(LoopTree::TreeRef ref) const; 123 | 124 | void emit_vectorized_node( 125 | LoopTree::TreeRef ref, 126 | std::unordered_map unrolls) const; 127 | void emit_vectorized_loop( 128 | LoopTree::TreeRef ref, std::unordered_map overrides, 129 | std::unordered_map unrolls) const; 130 | }; 131 | 132 | } // namespace loop_tool 133 | -------------------------------------------------------------------------------- /install_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -vex 3 | uname -a 4 | if [ "$(uname -m)" != "x86_64" ]; then 5 | exit 0 6 | fi 7 | 8 | cat /proc/version 9 | 10 | apt-get -qq update 11 | apt-get -yqq install wget libxml2-dev 12 | wget -q https://developer.download.nvidia.com/compute/cuda/11.6.0/local_installers/cuda_11.6.0_510.39.01_linux.run 13 | sh cuda_11.6.0_510.39.01_linux.run --help 14 | sh cuda_11.6.0_510.39.01_linux.run --silent --toolkit 15 | cat /var/log/cuda-installer.log 16 | 17 | ls /usr/local/cuda 18 | 
-------------------------------------------------------------------------------- /javascript/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | loop_tool demo 5 | 22 | 23 | 24 | webcam demo 25 |
26 | matrix multiplication + wasm demo 27 |
28 | reference tested matmul for perf/debugging 29 | 30 | 31 | -------------------------------------------------------------------------------- /javascript/lt.mjs.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/loop_tool/df22dd2d29ed96c4349c8ce3c55797781e647446/javascript/lt.mjs.gz -------------------------------------------------------------------------------- /javascript/lzstring.js: -------------------------------------------------------------------------------- 1 | var LZString=function(){function o(o,r){if(!t[o]){t[o]={};for(var n=0;ne;e++){var s=r.charCodeAt(e);n[2*e]=s>>>8,n[2*e+1]=s%256}return n},decompressFromUint8Array:function(o){if(null===o||void 0===o)return i.decompress(o);for(var n=new Array(o.length/2),e=0,t=n.length;t>e;e++)n[e]=256*o[2*e]+o[2*e+1];var s=[];return n.forEach(function(o){s.push(r(o))}),i.decompress(s.join(""))},compressToEncodedURIComponent:function(o){return null==o?"":i._compress(o,6,function(o){return e.charAt(o)})},decompressFromEncodedURIComponent:function(r){return null==r?"":""==r?null:(r=r.replace(/ /g,"+"),i._decompress(r.length,32,function(n){return o(e,r.charAt(n))}))},compress:function(o){return i._compress(o,16,function(o){return r(o)})},_compress:function(o,r,n){if(null==o)return"";var e,t,i,s={},p={},u="",c="",a="",l=2,f=3,h=2,d=[],m=0,v=0;for(i=0;ie;e++)m<<=1,v==r-1?(v=0,d.push(n(m)),m=0):v++;for(t=a.charCodeAt(0),e=0;8>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1}else{for(t=1,e=0;h>e;e++)m=m<<1|t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t=0;for(t=a.charCodeAt(0),e=0;16>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1}l--,0==l&&(l=Math.pow(2,h),h++),delete p[a]}else 
for(t=s[a],e=0;h>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1;l--,0==l&&(l=Math.pow(2,h),h++),s[c]=f++,a=String(u)}if(""!==a){if(Object.prototype.hasOwnProperty.call(p,a)){if(a.charCodeAt(0)<256){for(e=0;h>e;e++)m<<=1,v==r-1?(v=0,d.push(n(m)),m=0):v++;for(t=a.charCodeAt(0),e=0;8>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1}else{for(t=1,e=0;h>e;e++)m=m<<1|t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t=0;for(t=a.charCodeAt(0),e=0;16>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1}l--,0==l&&(l=Math.pow(2,h),h++),delete p[a]}else for(t=s[a],e=0;h>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1;l--,0==l&&(l=Math.pow(2,h),h++)}for(t=2,e=0;h>e;e++)m=m<<1|1&t,v==r-1?(v=0,d.push(n(m)),m=0):v++,t>>=1;for(;;){if(m<<=1,v==r-1){d.push(n(m));break}v++}return d.join("")},decompress:function(o){return null==o?"":""==o?null:i._decompress(o.length,32768,function(r){return o.charCodeAt(r)})},_decompress:function(o,n,e){var t,i,s,p,u,c,a,l,f=[],h=4,d=4,m=3,v="",w=[],A={val:e(0),position:n,index:1};for(i=0;3>i;i+=1)f[i]=i;for(p=0,c=Math.pow(2,2),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;switch(t=p){case 0:for(p=0,c=Math.pow(2,8),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;l=r(p);break;case 1:for(p=0,c=Math.pow(2,16),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;l=r(p);break;case 2:return""}for(f[3]=l,s=l,w.push(l);;){if(A.index>o)return"";for(p=0,c=Math.pow(2,m),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;switch(l=p){case 0:for(p=0,c=Math.pow(2,8),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;f[d++]=r(p),l=d-1,h--;break;case 
1:for(p=0,c=Math.pow(2,16),a=1;a!=c;)u=A.val&A.position,A.position>>=1,0==A.position&&(A.position=n,A.val=e(A.index++)),p|=(u>0?1:0)*a,a<<=1;f[d++]=r(p),l=d-1,h--;break;case 2:return w.join("")}if(0==h&&(h=Math.pow(2,m),m++),f[l])v=f[l];else{if(l!==d)return null;v=s+s.charAt(0)}w.push(v),f[d++]=s+v.charAt(0),h--,s=v,0==h&&(h=Math.pow(2,m),m++)}}};return i}();"function"==typeof define&&define.amd?define(function(){return LZString}):"undefined"!=typeof module&&null!=module&&(module.exports=LZString); 2 | -------------------------------------------------------------------------------- /javascript/test.html: -------------------------------------------------------------------------------- 1 | 5 | 6 | 7 | 8 | 32 |
33 |
34 |

 35 |   
36 |
37 | 172 | -------------------------------------------------------------------------------- /javascript/tutorial.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | 14 | 15 | 77 | 78 | 113 | 114 | 115 | 116 |

117 | Intro 118 |

119 | 120 | Starting with a basic matmul, we do blah blah blah.... 121 |
122 |
123 |
124 | const SIZE = 128; 125 | const [m, n, k] = lt.symbols("m n k"); 126 | const x = lt.tensor(SIZE, SIZE).to(m, k); 127 | const y = lt.tensor(SIZE, SIZE).to(k, n); 128 | const z = x.mul(y).sum(k); 129 | z.load_loop_tree(`v:m_0 130 | v:k_2 131 | v:n_1 132 | n:2::0,1,:::0::::: 133 | n:2::1,2,:::0::::: 134 | n:7:0,1,:0,1,2,:::0:0;128;0,2;128;0,1;128;0,::,,,:: 135 | n:5:2,:0,2,:::0:0;128;0,2;128;0,1;128;0,::,,,:: 136 | n:1:3,:0,2,:::0:0;128;0,2;128;0,::,,:: 137 | i:0,1, 138 | o:4,`); 139 | z 140 |
141 | 142 |
143 | but we can always unroll! that lets us do blah blah.... 144 |
145 |
146 | 147 |
148 | const SIZE = 128; 149 | const [m, n, k] = lt.symbols("m n k"); 150 | const x = lt.tensor(SIZE, SIZE).to(m, k); 151 | const y = lt.tensor(SIZE, SIZE).to(k, n); 152 | const z = x.mul(y).sum(k); 153 | z.load_loop_tree(`v:m_83 154 | v:k_85 155 | v:n_84 156 | n:2::0,1,:::0::::: 157 | n:2::1,2,:::0::::: 158 | n:7:0,1,:0,1,2,:::0:0;32;0,2;32;0,1;128;0,0;4;0,2;4;0,::,,,unroll,unroll,:: 159 | n:5:2,:0,2,:::0:0;32;0,2;32;0,1;128;0,0;4;0,2;4;0,::,,,unroll,unroll,:: 160 | n:1:3,:0,2,:::0:0;32;0,2;32;0,0;4;0,2;4;0,::,,unroll,unroll,:: 161 | i:0,1, 162 | o:4,`); 163 | z 164 |
165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /javascript/viz.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | loop_tool demo 5 | 6 | 7 | 8 | 9 | 13 | 14 | 15 | 16 | 20 | 24 | 25 | 26 | 92 | 93 | 94 | 99 |
100 |
101 |
102 | 108 | 124 |
125 |
126 | [[source repo]] 132 | [[webcam demo]] 138 |

139 |       
140 |
141 |
142 | 143 | 144 | -------------------------------------------------------------------------------- /javascript/webcam.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | loop_tool live webcam demo 4 | 5 | 9 | 13 | 17 | 18 | 19 | 20 | 134 | 135 | 136 | 140 | 141 | 160 |
161 | 162 |
163 | 164 |
165 |
166 |
167 |
168 | 169 | 170 | 171 | 172 |
173 |
174 |
175 |
176 | 177 |
178 |
179 | 180 |
181 | 182 | 183 | -------------------------------------------------------------------------------- /javascript/webcam.mjs: -------------------------------------------------------------------------------- 1 | import * as lt from "./loop_tool.mjs"; 2 | import * as util from "./main.mjs"; 3 | window.lt = lt; 4 | 5 | let data_canvas = null; 6 | let data_ctx = null; 7 | let cur_hash = null; 8 | let opt_hash = null; 9 | let loop_editor = null; 10 | 11 | const blank_template = `const [h, w, c] = lt.symbols("h w c"); 12 | const V = lt.tensor(video_data.height, 13 | video_data.width, 14 | 4).to(h, w, c); 15 | V.set(video_data.data); 16 | 17 | await display(V);`; 18 | 19 | const color_template = `// manipulate colors 20 | 21 | const [h, w, c] = lt.symbols("h w c"); 22 | const V = lt.tensor(video_data.height, video_data.width, 4).to(h, w, c); 23 | V.set(video_data.data); 24 | 25 | // brighten but clip to max value of 240 26 | const B = V.mul(1.2).min(240); 27 | 28 | // make 40% more red 29 | const redden = lt.tensor(4).to(c); 30 | redden.set([1.4, 1, 1, 1]); 31 | const R = B.mul(redden); 32 | 33 | // contrast 34 | const C = R.sub(15).div(230).mul(255); 35 | 36 | await display(C);`; 37 | 38 | const edge_template = `// Edge detection kernel 39 | // https://en.wikipedia.org/wiki/Kernel_(image_processing) 40 | 41 | const [h, w, c] = lt.symbols("h w c"); 42 | const V = lt.tensor(video_data.height, 43 | video_data.width, 44 | 4).to(h, w, c); 45 | V.set(video_data.data); 46 | 47 | const [kh, kw] = lt.symbols("kh kw"); 48 | const W = lt.tensor(3, 3).to(kh, kw); 49 | W.set([0, -1, 0, 50 | -1, 4, -1, 51 | 0, -1, 0]); 52 | 53 | const X = lt.convolve(V, W, [h, w], [kh, kw]); 54 | const Y = X.sum(c).add(100); 55 | 56 | const alpha = lt.tensor(4).to(c); 57 | alpha.set([255, 255, 255, 0]); 58 | 59 | const Z = Y.max(alpha); 60 | await display(Z);`; 61 | 62 | const sobel_template = `// Sobel operator 63 | 64 | const [h, w, c] = lt.symbols("h w c"); 65 | const V = lt.tensor(video_data.height, 
66 | video_data.width, 67 | 4).to(h, w, c); 68 | V.set(video_data.data); 69 | 70 | const [kh, kw] = lt.symbols("kh kw"); 71 | 72 | const Gx = lt.tensor(3, 3).to(kh, kw); 73 | Gx.set([ -1, 0, 1, 74 | -2, 0, 2, 75 | -1, 0, 1 ]) 76 | 77 | const Gy = lt.tensor(3, 3).to(kh, kw); 78 | Gy.set([ 1, 2, 1, 79 | 0, 0, 0, 80 | -1, -2, -1 ]) 81 | 82 | let Vx = lt.convolve(V, Gx, [h, w], [kh, kw]) 83 | .sum(c) 84 | .div(3); 85 | let Vy = lt.convolve(V, Gy, [h, w], [kh, kw]) 86 | .sum(c) 87 | .div(3); 88 | Vx = Vx.to(...Vy.symbolic_shape); 89 | 90 | const Vx2 = Vx.mul(Vx); 91 | const Vy2 = Vy.mul(Vy); 92 | const Y = Vx2.add(Vy2).sqrt(); 93 | 94 | const alpha = lt.tensor(4).to(c); 95 | alpha.set([255, 255, 255, 0]); 96 | 97 | const Z = Y.add(alpha); 98 | 99 | await display(Z);`; 100 | 101 | async function initWebcam() { 102 | const webcam_div = document.querySelector("#webcam"); 103 | try { 104 | const stream = await navigator.mediaDevices.getUserMedia({ 105 | audio: false, 106 | video: true, 107 | }); 108 | let { width, height, facingMode } = stream.getTracks()[0].getSettings(); 109 | console.log(stream.getTracks()[0].getSettings()); 110 | if (data_canvas === null) { 111 | data_canvas = document.createElement("canvas"); 112 | data_canvas.height = height; 113 | data_canvas.width = width; 114 | data_ctx = data_canvas.getContext("2d"); 115 | } 116 | 117 | const vid_elem = document.createElement("video"); 118 | vid_elem.setAttribute("muted", true); 119 | vid_elem.setAttribute("playsinline", true); 120 | vid_elem.srcObject = stream; 121 | vid_elem.addEventListener("play", function () { 122 | runLoop(); 123 | }); 124 | vid_elem.play(); 125 | 126 | const display_canvas = document.querySelector("#output_canvas"); 127 | display_canvas.height = height; 128 | display_canvas.width = width; 129 | 130 | webcam_div.appendChild(vid_elem); 131 | } catch (err) { 132 | webcam_div.textContent = err; 133 | } 134 | } 135 | 136 | function optimize(tensor) { 137 | if (tensor.hash !== opt_hash) { 138 | 
tensor.optimize(); 139 | opt_hash = tensor.hash; 140 | } 141 | } 142 | 143 | async function display(tensor) { 144 | if ( 145 | (cur_hash === null || cur_hash != tensor.hash) && 146 | !document.querySelector("#opt").classList.contains("hidden") 147 | ) { 148 | cur_hash = tensor.hash; 149 | loop_editor = new util.Editor(document.getElementById("opt"), tensor); 150 | loop_editor.render(); 151 | } 152 | const float_data = await tensor.data; 153 | const [h, w, c] = tensor.shape; 154 | const image_data = new ImageData(w, h); 155 | image_data.data.set(float_data); 156 | const canvas = document.querySelector("#output_canvas"); 157 | const ctx = canvas.getContext("2d"); 158 | ctx.putImageData(image_data, 0, 0); 159 | } 160 | 161 | async function loop() { 162 | if (!data_ctx) { 163 | return; 164 | } 165 | const video = document.querySelector("#webcam video"); 166 | if (!video) { 167 | return; 168 | } 169 | 170 | data_ctx.drawImage(video, 0, 0, data_canvas.width, data_canvas.height); 171 | const video_data = data_ctx.getImageData( 172 | 0, 173 | 0, 174 | data_canvas.width, 175 | data_canvas.height 176 | ); 177 | } 178 | 179 | async function runLoop() { 180 | while (true) { 181 | try { 182 | const t = performance.now(); 183 | if (document.querySelector("#opt").classList.contains("hidden")) { 184 | lt.clear_heap(); 185 | cur_hash = null; 186 | } 187 | await loop(); 188 | const d = performance.now() - t; 189 | document.querySelector("#stats").textContent = `${Math.round( 190 | 1e3 / d 191 | )} fps`; 192 | } catch (e) { 193 | if (Number.isInteger(e)) { 194 | e = lt.getExceptionMessage(e); 195 | } 196 | document.querySelector("#error_out").textContent = e; 197 | } 198 | await new Promise((r) => { 199 | requestAnimationFrame(r); 200 | }); 201 | } 202 | } 203 | 204 | function updateLoop(editor) { 205 | const fn_def = ` 206 | loop = null; 207 | loop = async function() { 208 | 209 | if (!data_ctx) { 210 | return; 211 | } 212 | const video = document.querySelector("#webcam video"); 213 
| if (!video) { 214 | return; 215 | } 216 | 217 | data_ctx.drawImage(video, 0, 0, data_canvas.width, data_canvas.height); 218 | const video_data = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height); 219 | 220 | ${editor.getValue()} 221 | 222 | } 223 | `; 224 | eval(fn_def); 225 | } 226 | 227 | function init() { 228 | let editor = CodeMirror.fromTextArea(document.querySelector("#codeeditor"), { 229 | lineNumbers: true, 230 | tabSize: 2, 231 | mode: "javascript", 232 | }); 233 | if (window.location.hash) { 234 | const s = window.location.hash.slice(1); 235 | const c = LZString.decompressFromBase64(decodeURIComponent(s)); 236 | editor.setValue(c); 237 | } else { 238 | editor.setValue(edge_template); 239 | } 240 | editor.on("change", function () { 241 | try { 242 | updateLoop(editor); 243 | window.location.hash = encodeURIComponent( 244 | LZString.compressToBase64(editor.getValue()) 245 | ); 246 | document.querySelector("#error_out").textContent = ""; 247 | } catch (e) { 248 | document.querySelector("#error_out").textContent = e; 249 | console.log(e); 250 | } 251 | }); 252 | 253 | document 254 | .querySelector('button[name="blank"]') 255 | .addEventListener("click", () => { 256 | editor.setValue(blank_template); 257 | }); 258 | document 259 | .querySelector('button[name="color"]') 260 | .addEventListener("click", () => { 261 | editor.setValue(color_template); 262 | }); 263 | document 264 | .querySelector('button[name="edge"]') 265 | .addEventListener("click", () => { 266 | editor.setValue(edge_template); 267 | }); 268 | document 269 | .querySelector('button[name="sobel"]') 270 | .addEventListener("click", () => { 271 | editor.setValue(sobel_template); 272 | }); 273 | 274 | updateLoop(editor); 275 | 276 | initWebcam(); 277 | window.addEventListener("keydown", (e) => { 278 | if (editor.hasFocus()) { 279 | return; 280 | } 281 | if (!loop_editor) { 282 | return; 283 | } 284 | e.preventDefault(); 285 | loop_editor.handle_keydown(e); 286 | 
loop_editor.render(); 287 | }); 288 | } 289 | 290 | export { init }; 291 | -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- 1 | from loop_tool_py import * 2 | from .ui import ui 3 | from . import nn 4 | 5 | def symbols(s): 6 | syms = [] 7 | for n in s.split(" "): 8 | syms.append(Symbol(n)) 9 | return syms 10 | 11 | class Backend(): 12 | def __init__(self, backend): 13 | self.old_backend = get_default_backend() 14 | self.backend = backend 15 | def __enter__(self): 16 | set_default_backend(self.backend) 17 | return self 18 | def __exit__(self, type, value, traceback): 19 | set_default_backend(self.old_backend) 20 | 21 | -------------------------------------------------------------------------------- /python/nn.py: -------------------------------------------------------------------------------- 1 | import loop_tool_py as lt 2 | 3 | const_map = {} 4 | 5 | 6 | def fill(constant, symbolic_shape): 7 | if constant in const_map: 8 | const = const_map[constant] 9 | else: 10 | const = lt.Tensor().set(constant) 11 | const_map[constant] = const 12 | return const 13 | 14 | 15 | def mean(X, dims): 16 | exprs = [(x, lt.Expr(0)) for x in dims] 17 | one = fill(1, dims).to(*dims, constraints=exprs) 18 | return X.sum(*dims) / one.sum(*dims) 19 | 20 | 21 | def sigmoid(T): 22 | shape = T.symbolic_shape 23 | one = fill(1, shape) 24 | return (one + (-T).exp()).reciprocal() 25 | 26 | 27 | def swish(T): 28 | shape = T.symbolic_shape 29 | return T * sigmoid(T) 30 | 31 | 32 | def relu(T): 33 | shape = T.symbolic_shape 34 | zero = fill(0, shape) 35 | return T.max(zero) 36 | 37 | 38 | def relu6(T): 39 | shape = T.symbolic_shape 40 | six = fill(6, shape) 41 | return relu(T) - relu(T - six) 42 | 43 | 44 | def hardswish(T): 45 | shape = T.symbolic_shape 46 | three = fill(3, shape) 47 | sixth = fill(1 / 6, shape) 48 | return T * relu6(T + three) * sixth 49 | 50 | 51 
| def tanh(T): 52 | shape = T.symbolic_shape 53 | two = fill(2, shape) 54 | one = fill(1, shape) 55 | return two * sigmoid(two * T) - one 56 | 57 | 58 | def linear(X, W, bias=None): 59 | reduction_dims = set(X.symbolic_shape) & set(W.symbolic_shape) 60 | Y = (X * W).sum(*reduction_dims) 61 | if bias: 62 | Y = Y + bias 63 | return Y 64 | 65 | 66 | # pad(X, (s.K, 1), (s.X, (0, 1))) 67 | def pad(X, *args): 68 | for d, pad in args: 69 | if type(pad) is tuple: 70 | X = X.pad(d, *pad) 71 | else: 72 | X = X.pad(d, pad) 73 | return X 74 | 75 | 76 | def conv(X, W, spatial, window, stride=1, channel_reduce=True): 77 | assert len(spatial) == len(window) 78 | # output dimensions need new names 79 | new_spatial = [lt.Symbol(x.name + "o") for x in spatial] 80 | outer = [d for d in X.symbolic_shape if d not in spatial] 81 | exprs = [lt.Expr(stride) * x + k for x, k in zip(new_spatial, window)] 82 | X = X.to(*outer, *new_spatial, *window, constraints=zip(spatial, exprs)) 83 | 84 | # reduce over input channels and the windowed dims 85 | if channel_reduce: 86 | reduction_dims = (set(X.symbolic_shape) & set(W.symbolic_shape)) | set(window) 87 | else: 88 | reduction_dims = set(window) 89 | return (X * W).sum(*reduction_dims) 90 | 91 | 92 | def batch_norm(x, mean, var, weight, bias, eps=None): 93 | if eps == None: 94 | eps = lt.Tensor().set(1e-5) 95 | x = (x - mean) * weight 96 | return x / (var + eps).sqrt() + bias 97 | -------------------------------------------------------------------------------- /python/ui.py: -------------------------------------------------------------------------------- 1 | import loop_tool_py as lt 2 | import curses 3 | from curses import wrapper 4 | from curses.textpad import Textbox 5 | import time 6 | 7 | # use with vim [file] -c 'set updatetime=750 | set autoread | au CursorHold * checktime | call feedkeys("lh")' 8 | 9 | 10 | def hex_to_color(h): 11 | r, g, b = tuple(int(h[i : i + 2], 16) for i in (0, 2, 4)) 12 | return curses.color_pair(int(16 + r / 48 * 
36 + g / 48 * 6 + b / 48)) 13 | 14 | 15 | def init_colors(): 16 | curses.start_color() 17 | curses.use_default_colors() 18 | for i in range(0, curses.COLORS): 19 | curses.init_pair(i + 1, i, -1) 20 | 21 | 22 | def get_versions(loop): 23 | versions = [] 24 | 25 | def f(r, depth): 26 | nonlocal versions 27 | if tree.is_loop(r) and (tree.loop(r) == loop): 28 | versions.append(r) 29 | 30 | tree.walk(f) 31 | return versions 32 | 33 | 34 | def benchmark(tensor, limit_ms=100): 35 | start = time.time() * 1000 36 | iters = 1 37 | t = 0 38 | while (t - start) < limit_ms: 39 | for i in range(iters): 40 | tensor.force_recompute() 41 | tensor.resolve() 42 | t = time.time() * 1000 43 | iters *= 2 44 | return 1000 * (iters - 1) / (t - start) 45 | 46 | 47 | def loop_version(tree, ref): 48 | if not tree.is_loop(ref): 49 | return None 50 | loop = tree.loop(ref) 51 | version = 0 52 | keep_scanning = True 53 | 54 | def f(r, depth): 55 | nonlocal keep_scanning 56 | nonlocal version 57 | if r == ref: 58 | keep_scanning = False 59 | if keep_scanning and tree.is_loop(r) and tree.loop(r) == loop: 60 | version += 1 61 | 62 | tree.walk(f) 63 | return (loop, version) 64 | 65 | 66 | def highlight(tree, drag): 67 | assert drag 68 | highlighted = None 69 | version = 0 70 | 71 | def find_loop(ref, depth): 72 | nonlocal highlighted 73 | nonlocal version 74 | if ( 75 | tree.is_loop(ref) 76 | and (tree.loop(ref) == drag[0] and version == drag[1]) 77 | and highlighted == None 78 | ): 79 | highlighted = ref 80 | if tree.is_loop(ref) and (tree.loop(ref) == drag[0]): 81 | version += 1 82 | 83 | tree.walk(find_loop) 84 | assert highlighted != None, ( 85 | f"found {version} versions but wanted {drag[1]}:\n" + tree.dump() 86 | ) 87 | return highlighted 88 | 89 | 90 | def gen_info(tree, highlighted, drag): 91 | s = "" 92 | if tree.is_loop(highlighted): 93 | if drag is not None: 94 | s += f"[dragging {tree.ir.dump_var(drag[0].var)} v{drag[1]}]" 95 | else: 96 | allocs = lt.Compiler(tree).allocations 97 | n = 
tree.ir_node(highlighted) 98 | if n in allocs: 99 | s += f"[size: {allocs[n].size}]" 100 | else: 101 | s += f"[allocs size {len(allocs)}]" 102 | return s 103 | 104 | 105 | def prompt(stdscr, pad, s): 106 | rows, cols = stdscr.getmaxyx() 107 | pad.addstr(0, 0, s + " " * (cols - len(s) - 1)) 108 | # , hex_to_color("00ff00")) 109 | stdscr.refresh() 110 | pad.refresh(0, 0, 0, 0, rows, cols) 111 | editwin = curses.newwin(1, 30, 0, len(s)) 112 | box = Textbox(editwin, insert_mode=True) 113 | 114 | def validate(x): 115 | if x == 10: 116 | x = 7 117 | if x == 127: 118 | x = curses.KEY_BACKSPACE 119 | return x 120 | 121 | box.edit(validate) 122 | message = box.gather() 123 | split_size = 0 124 | try: 125 | split_size = int(message) 126 | except: 127 | pass 128 | return split_size 129 | 130 | 131 | def ui_impl(stdscr, tensor, fn): 132 | tree = tensor.loop_tree 133 | trees = [tree] 134 | highlighted = tree.roots[0] 135 | drag = None 136 | rows, cols = stdscr.getmaxyx() 137 | stdscr.clear() 138 | curses.curs_set(0) 139 | tree_pad = curses.newpad(rows, cols) 140 | 141 | iters_sec = 0 142 | flops = 0 143 | reads = 0 144 | writes = 0 145 | 146 | def render(changed): 147 | nonlocal highlighted, iters_sec, flops, reads, writes 148 | highlighted = highlight(tree, drag) if drag else highlighted 149 | tree_pad.erase() 150 | i = 0 151 | info = gen_info(tree, highlighted, drag) 152 | tree_pad.addstr(i, 0, info) 153 | 154 | if changed: 155 | tensor.set(tree) 156 | trees.append(tree) 157 | if fn: 158 | with open(fn, "w") as f: 159 | f.write(tensor.code) 160 | _ = benchmark(tensor, 10) # warmup 161 | iters_sec = benchmark(tensor) 162 | flops = tree.flops() 163 | tree_pad.addstr( 164 | i, 165 | len(info) + 1, 166 | f"{flops * iters_sec / 1e9:.2f} GFlops, ({iters_sec:.2f} iters/sec, {flops} total flops)", 167 | ) 168 | 169 | def _render_ref(ref): 170 | if tree.is_loop(ref): 171 | loop = tree.loop(ref) 172 | v = tree.ir.dump_var(loop.var) 173 | r = f" r {loop.tail}" if loop.tail else "" 174 | 
return f"for {v} in {loop.size}{r}" 175 | return tree.dump(ref) 176 | 177 | def _r(ref, depth): 178 | nonlocal i 179 | i += 1 180 | tree_pad.addstr(i, depth, _render_ref(ref)) 181 | if ref == highlighted: 182 | tree_pad.chgat(i, 0, curses.A_REVERSE) 183 | 184 | tree.walk(_r) 185 | 186 | stdscr.refresh() 187 | tree_pad.refresh(0, 0, 0, 0, rows, cols) 188 | 189 | render(True) 190 | 191 | def update_tree(new_tree): 192 | nonlocal highlighted, tree, changed 193 | highlighted = new_tree.map_ref(highlighted, tree) 194 | tree = new_tree 195 | changed = True 196 | 197 | while True: 198 | key = stdscr.getkey() 199 | changed = False 200 | if key == "q": 201 | break 202 | elif key == "s": 203 | split_size = prompt(stdscr, tree_pad, "inner size? ") 204 | try: 205 | update_tree(tree.split(highlighted, split_size)) 206 | except: 207 | pass 208 | elif key == "u" and len(trees) > 1: 209 | trees = trees[:-1] 210 | update_tree(trees[-1]) 211 | elif key == "KEY_DOWN": 212 | if drag: 213 | update_tree(tree.try_swap(highlighted, tree.next_ref(highlighted))) 214 | else: 215 | n = tree.next_ref(highlighted) 216 | if n is not None: 217 | highlighted = n 218 | elif key == "KEY_UP": 219 | if drag: 220 | update_tree(tree.try_swap(highlighted, tree.previous_ref(highlighted))) 221 | else: 222 | p = tree.previous_ref(highlighted) 223 | if p is not None: 224 | highlighted = p 225 | elif key == "KEY_SR": # up + shift 226 | update_tree(tree.try_swap(highlighted, tree.previous_ref(highlighted))) 227 | elif key == "KEY_SF": # down + shift 228 | update_tree(tree.try_swap(highlighted, tree.next_ref(highlighted))) 229 | changed = True 230 | elif key in ("KEY_BACKSPACE", "\b", "\x7f"): 231 | update_tree(tree.merge(highlighted)) 232 | elif key == "\n": 233 | key = "ENTER" 234 | drag = None if drag else loop_version(tree, highlighted) 235 | render(changed) 236 | if key == "u": 237 | trees = trees[:-1] 238 | return tree 239 | 240 | 241 | def ui(T, path=""): 242 | T.set(wrapper(ui_impl, T, path)) 243 | 
# Copyright (c) 2016 The Pybind Development Team, All rights reserved.

# -*- coding: utf-8 -*-
import os
import sys
import subprocess

from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext


# Convert distutils Windows platform specifiers to CMake -A arguments
PLAT_TO_CMAKE = {
    "win32": "Win32",
    "win-amd64": "x64",
    "win-arm32": "ARM",
    "win-arm64": "ARM64",
}


class CMakeExtension(Extension):
    """A setuptools Extension whose build is delegated entirely to CMake.

    `sources` is intentionally empty: CMakeBuild configures and builds the
    whole tree rooted at `sourcedir` instead of compiling individual files.
    """

    def __init__(self, name, sourcedir=""):
        Extension.__init__(self, name, sources=[])
        self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
    """build_ext command that shells out to `cmake` to build the extension."""

    def build_extension(self, ext):
        # Imported lazily so setup_requires can install pybind11 first.
        import pybind11

        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

        # required for auto-detection of auxiliary "native" libs
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        cfg = "Debug" if self.debug else "Release"

        # CMake lets you override the generator - we need to check this.
        # Can be set with Conda-Build, for example.
        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        cmake_args = [
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}".format(extdir),
            "-DPYTHON_EXECUTABLE={}".format(sys.executable),
            "-DCMAKE_BUILD_TYPE={}".format(cfg),  # not used on MSVC, but no harm
            "-Dpybind11_DIR={}".format(pybind11.get_cmake_dir()),
            "-DCMAKE_BUILD_WITH_INSTALL_RPATH=TRUE",
            "-DCMAKE_INSTALL_RPATH={}".format(
                "@loader_path" if "darwin" in sys.platform else "$ORIGIN"
            ),
            "-DBUILD_TESTS=OFF",
        ]
        build_args = []

        if self.compiler.compiler_type != "msvc":
            # Using Ninja-build since it a) is available as a wheel and b)
            # multithreads automatically. MSVC would require all variables be
            # exported for Ninja to pick it up, which is a little tricky to do.
            # Users can override the generator with CMAKE_GENERATOR in CMake
            # 3.15+.
            if not cmake_generator:
                cmake_args += ["-GNinja"]

        else:

            # Single config generators are handled "normally"
            single_config = any(x in cmake_generator for x in {"NMake", "Ninja"})

            # CMake allows an arch-in-generator style for backward compatibility
            contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"})

            # Specify the arch if using MSVC generator, but only if it doesn't
            # contain a backward-compatibility arch spec already in the
            # generator name.
            if not single_config and not contains_arch:
                cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]]

            # Multi-config generators have a different way to specify configs
            if not single_config:
                cmake_args += [
                    "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)
                ]
                build_args += ["--config", cfg]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if hasattr(self, "parallel") and self.parallel:
                # CMake 3.12+ only.
                build_args += ["-j{}".format(self.parallel)]

        # exist_ok avoids the check-then-create race of the previous
        # `if not os.path.exists(...): os.makedirs(...)` pattern.
        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(
            ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
        )
        subprocess.check_call(
            ["cmake", "--build", "."] + build_args, cwd=self.build_temp
        )


setup(
    name="loop_tool",
    version="0.1.0",
    author="Bram Wasti",
    author_email="bwasti@fb.com",
    description="A lightweight IR for dense linear algebra",
    long_description="",  # TODO
    url="https://github.com/facebookresearch/loop_tool",
    ext_modules=[CMakeExtension("loop_tool_py")],
    packages=["loop_tool"],
    package_dir={"loop_tool": "python"},
    install_requires=[],
    setup_requires=["pybind11", "ninja"],
    include_package_data=True,
    exclude_package_data={'': ['test']},
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
)
#include "dabun/loop_nest.hpp"

#include "dabun/arithmetic_operation.hpp"
#include "dabun/code_generator/memory_resource.hpp"
#include "dabun/isa.hpp"
#include "loop_tool/backend.h"
#include "loop_tool/mutate.h"

using namespace loop_tool;
using namespace symbolic;

// NOTE(review): several template argument lists in this file appear to have
// been stripped during text extraction (e.g. `std::vector>` and the bare
// `dabun::shared_aot_fn`). Restore them from version control before building.

// Pattern-matches a LoopTree against the shapes dabun's loop_nest JIT can
// emit (currently only a fused multiply-accumulate reduction nest) and, on a
// match, compiles it to a native function.
struct LoopNestCompiler : public Compiler {
  bool fma_nest = false;        // tree matched the mul+add reduction pattern
  bool transpose_nest = false;  // transpose matching is not implemented yet
  LoopNestCompiler(const LoopTree& lt) : Compiler(lt) {
    fma_nest = is_fma_nest();
    transpose_nest = is_transpose_nest();
  }

  // True iff this backend can handle the tree at all.
  bool can_compile() const { return fma_nest || transpose_nest; }

  // JIT-compile the matched nest; only the FMA pattern is supported so far.
  dabun::shared_aot_fn gen_exec() const {
    ASSERT(fma_nest);
    return compile_fma_nest();
  }

// Matcher helper: check a predicate and bail out with a diagnostic on failure.
#define REQUIRE(x)                    \
  {                                   \
    if (!(x)) {                       \
      std::cerr << #x << " failed\n"; \
      return false;                   \
    }                                 \
  }
  // Returns true iff the IR is exactly: two unscheduled reads feeding a
  // single multiply, accumulated by a single add that lives in the same loop
  // scope, then written once -- optionally through single-use views that all
  // feed the multiply.
  bool is_fma_nest() const {
    auto reads = find(lt.ir, Operation::read);
    REQUIRE(reads.size() == 2);
    REQUIRE(lt.scheduled.count(reads.at(0)) == 0);
    REQUIRE(lt.scheduled.count(reads.at(1)) == 0);

    // find mul and add operations
    auto muls = find(lt.ir, Operation::multiply);
    REQUIRE(muls.size() == 1);
    auto mul_ref = muls.at(0);
    const auto& mul = lt.ir.node(mul_ref);
    REQUIRE(lt.scheduled.count(mul_ref) == 1);
    REQUIRE(mul.inputs().size() == 2);

    auto views = find(lt.ir, Operation::view);
    for (auto v : views) {
      REQUIRE(lt.ir.node(v).outputs().size() == 1);
      REQUIRE(lt.ir.node(v).outputs().at(0) == mul_ref);
      REQUIRE(lt.ir.node(v).inputs().size() == 1);
    }

    // 2 reads (unscheduled), 1 mul, 1 add, 1 write scheduled, optional views
    REQUIRE(lt.ir.nodes().size() == 2 + 1 + 1 + 1 + views.size());

    auto adds = find(lt.ir, Operation::add);
    REQUIRE(adds.size() == 1);
    auto add_ref = adds.at(0);
    const auto& add = lt.ir.node(add_ref);
    REQUIRE(add.inputs().size() == 1);
    REQUIRE(lt.scheduled.count(add_ref) == 1);

    REQUIRE(add.inputs().size() == 1);
    REQUIRE(add.inputs().at(0) == mul_ref);

    // the multiply and the accumulating add must share the same parent loop
    auto mul_parent = lt.parent(lt.scheduled.at(mul_ref));
    auto add_parent = lt.parent(lt.scheduled.at(add_ref));
    REQUIRE(mul_parent == add_parent);

    auto writes = find(lt.ir, Operation::write);
    REQUIRE(writes.size() == 1);
    auto write_ref = writes.at(0);
    const auto& write = lt.ir.node(write_ref);
    REQUIRE(lt.scheduled.count(write_ref) == 1);

    return true;
  }

  // Transpose nests are not matched yet.
  bool is_transpose_nest() const { return false; }
#undef REQUIRE

  // Lower the matched FMA nest to dabun's loop_nest compiler: collect
  // per-variable sizes, the loop order surrounding the multiply, and the
  // per-tensor strides (derivative of each access index expression with
  // respect to each variable), then JIT the kernel.
  dabun::shared_aot_fn compile_fma_nest() const {
    auto muls = find(lt.ir, Operation::multiply);
    const auto& mul = lt.ir.node(muls.at(0));
    auto ref = lt.parent(lt.scheduled.at(muls.at(0)));
    std::vector> order;
    std::vector> sizes;
    for (auto v : lt.ir.vars()) {
      auto size = var_sizes.at(v);
      auto size_name = lt.ir.var(v).name();
      sizes.emplace_back(size_name, (int)size);
    }

    // Walk from the innermost loop outward, prepending so `order` ends up
    // outermost-first.
    while (ref != -1) {
      auto loop = lt.loop(ref);
      auto order_size = inner_sizes.at(ref);
      auto order_name = lt.ir.var(loop.var).name();
      order.emplace(order.begin(), order_name, order_size);
      ref = lt.parent(ref);
    }

    auto reads = find(lt.ir, Operation::read);
    auto A_ref = mul.inputs().at(0);
    const auto& A = lt.ir.node(A_ref);
    auto B_ref = mul.inputs().at(1);
    const auto& B = lt.ir.node(B_ref);
    auto C_ref = find(lt.ir, Operation::write).at(0);
    const auto& C = lt.ir.node(C_ref);

    // for strides, get the Access when reading from A, B
    auto mul_ref = lt.scheduled.at(muls.at(0));
    auto A_acc = gen_access(A_ref, mul_ref);
    auto A_idx = get_scoped_expr(A_acc);
    auto B_acc = gen_access(B_ref, mul_ref);
    auto B_idx = get_scoped_expr(B_acc);
    auto C_acc = gen_access(C_ref, lt.scheduled.at(C_ref));
    auto C_idx = get_scoped_expr(C_acc);

    std::vector> A_strides;
    std::vector> B_strides;
    std::vector> C_strides;
    // stride of a tensor w.r.t. a var = d(index expression) / d(var);
    // zero derivatives (var unused by that tensor) are skipped.
    for (auto v : lt.ir.vars()) {
      auto v_name = lt.ir.var(v).name();
      auto A_v = differentiate(A_idx, var_to_sym.at(v)).evaluate();
      auto B_v = differentiate(B_idx, var_to_sym.at(v)).evaluate();
      auto C_v = differentiate(C_idx, var_to_sym.at(v)).evaluate();
      if (A_v) {
        A_strides.emplace_back(v_name, A_v);
      }
      if (B_v) {
        B_strides.emplace_back(v_name, B_v);
      }
      if (C_v) {
        C_strides.emplace_back(v_name, C_v);
      }
    }

    std::vector A_axes;
    std::vector B_axes;
    std::vector C_axes;
    for (auto v : lt.ir.node(A_ref).vars()) {
      A_axes.emplace_back(lt.ir.var(v).name());
    }
    for (auto v : lt.ir.node(B_ref).vars()) {
      B_axes.emplace_back(lt.ir.var(v).name());
    }
    for (auto v : lt.ir.node(C_ref).vars()) {
      C_axes.emplace_back(lt.ir.var(v).name());
    }

    // Assemble the dabun loop_nest descriptor (builder pattern).
    auto arg = dabun::LN_sizes(sizes)
                   .C_axes(C_axes)
                   .A_axes(A_axes)
                   .B_axes(B_axes)
                   .C_strides(C_strides)
                   .A_strides(A_strides)
                   .B_strides(B_strides)
                   .append_loops(order);

// Pick the widest SIMD extension the host was compiled for.
// NOTE(review): VEX was presumably used as a (stripped) template argument to
// dabun::loop_nest_compiler below -- confirm against version control.
#if defined(__AVX512F__)
#define VEX dabun::extension::avx512
#elif defined(__aarch64__) || defined(__arm64__)
#define VEX dabun::extension::neon
#else  // default to avx2
#define VEX dabun::extension::avx2
#endif

    return dabun::loop_nest_compiler(arg, dabun::fma).get_shared();
  }

  // Unreachable until is_transpose_nest() can return true.
  void compile_transpose_nest() const {
    ASSERT(0);
    return;
  }
};

// Owns a JIT-compiled FMA kernel and runs it over raw tensor memory.
struct LoopNestCompiled : public Compiled {
  dabun::shared_aot_fn fn;

  LoopNestCompiled() = delete;
  LoopNestCompiled(const LoopNestCompiled&) = delete;
  LoopNestCompiled(LoopNestCompiled&&) = delete;

  LoopNestCompiled(const LoopTree& lt) {
    LoopNestCompiler cc(lt);
    fn = std::move(cc.gen_exec());
  }

  ~LoopNestCompiled() {}

  // memory[0], memory[1] are the inputs; memory[2] is the output buffer.
  void run(const std::vector& memory, bool sync) const override {
    fn((float*)(memory[2]), (const float*)(memory[0]),
       (const float*)(memory[1]), 0);
  }

  std::string dump() const override { return ""; }
};

// Registers the "loop_nest" backend with the global backend registry.
struct LoopNestBackend : public Backend {
  LoopNestBackend() : Backend("loop_nest") {
    // static destruction order hack
    (void)dabun::memory_resource::default_resource();
  }
  ~LoopNestBackend() {}

  std::unique_ptr compile_impl(const LoopTree& lt) const override {
    return std::make_unique(lt);
  }
  // Bit 0 == CPU (see getAvailableHardware()).
  int hardware_requirement() const override { return 1 << 0; }
};

static RegisterBackend loop_nest_backend_reg_(
    std::make_shared());
#pragma once

// NOTE(review): the #include targets and several template argument lists in
// this header were lost during text extraction (angle-bracket contents
// stripped) -- likely <cuda.h>, <nvrtc.h> and standard containers. Restore
// from version control before compiling.
#include
#include

#include
#include

#include "loop_tool/compile.h"
#include "loop_tool/dynlib.h"
#include "loop_tool/error.h"
#include "loop_tool/hardware.h"
#include "loop_tool/ir.h"

namespace loop_tool {

// Lazily opened handles to the CUDA driver, CUDA runtime and NVRTC shared
// libraries (resolved through dynlib so the build has no hard CUDA link-time
// dependency).
std::shared_ptr &cudaLib();
std::shared_ptr &cudaRuntimeLib();
std::shared_ptr &nvrtcLib();

// Resolve a symbol out of the corresponding dynamically loaded library.
#define CULIB(sym) DYNLIB(loop_tool::cudaLib(), sym)
#define CURTLIB(sym) DYNLIB(loop_tool::cudaRuntimeLib(), sym)
#define NVRTCLIB(sym) DYNLIB(loop_tool::nvrtcLib(), sym)

// Auxiliary scheduling metadata used by CUDA code generation.
struct CudaAux {
  // maps loops to the inner size of other threaded loops
  std::unordered_map threaded;
  std::unordered_set unrolled;  // loops to fully unroll
  int threads_per_warp;         // warp width of the target device
  int threads_per_block;
  std::unordered_map alloc_threads;
  std::unordered_map syncs;  // where synchronization must be emitted
  std::unordered_map tail;   // temporary
};

// Derive CUDA-specific scheduling metadata for a tree, given generic
// auxiliary info and the set of loops chosen for threading.
// NOTE(review): `const LoopTree <` below is an extraction artifact of
// `const LoopTree &lt` (HTML `&lt` un-escaping) -- restore before use.
CudaAux calc_cuda_aux(const LoopTree <, const Auxiliary &aux,
                      const std::unordered_set &threaded);

}  // namespace loop_tool

// Abort with file/line context if a CUDA runtime call fails.
#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
                      bool abort = true) {
  if (code != cudaSuccess) {
    ASSERT(0) << CULIB(cudaGetErrorString)(code) << " " << file << ":" << line;
  }
}

// Abort with the driver-API error name if a cu* driver call fails.
#define CUDA_SAFE_CALL(x)                                       \
  do {                                                          \
    CUresult result = x;                                        \
    const char *msg;                                            \
    CULIB(cuGetErrorName)(result, &msg);                        \
    ASSERT(result == CUDA_SUCCESS)                              \
        << "\nerror: " #x " failed with error " << msg << '\n'; \
  } while (0)
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include "loop_tool/wasm.h" 8 | #include "wasm_c_api.h" 9 | 10 | using namespace loop_tool; 11 | 12 | struct WebAssemblyCompiled : public Compiled { 13 | std::vector emitted_wasm; 14 | wasm_engine_t* engine; 15 | wasm_store_t* store; 16 | wasm_instance_t* instance; 17 | wasm_extern_vec_t exports; 18 | wasm_memory_t* wasm_memory; 19 | wasm_func_t* fn; 20 | 21 | std::vector memory_size_map; 22 | std::vector memory_input_offset_map; 23 | std::vector memory_output_offset_map; 24 | 25 | WebAssemblyCompiled(const LoopTree& lt) { 26 | WebAssemblyCompiler wc(lt); 27 | emitted_wasm = wc.emit(); 28 | engine = wasm_engine_new(); 29 | store = wasm_store_new(engine); 30 | wasm_byte_vec_t binary; 31 | wasm_byte_vec_new_uninitialized(&binary, emitted_wasm.size()); 32 | memcpy(binary.data, emitted_wasm.data(), emitted_wasm.size()); 33 | wasm_module_t* m = wasm_module_new(store, &binary); 34 | ASSERT(m) << "Couldn't compile WebAssembly module"; 35 | wasm_byte_vec_delete(&binary); 36 | 37 | wasm_extern_vec_t imports = WASM_EMPTY_VEC; 38 | instance = 39 | wasm_instance_new_with_args(store, m, &imports, NULL, KILOBYTE(32), 0); 40 | ASSERT(instance) << "Couldn't instantiate WebAssembly module"; 41 | wasm_instance_exports(instance, &exports); 42 | wasm_memory = wasm_extern_as_memory(exports.data[0]); 43 | fn = wasm_extern_as_func(exports.data[1]); 44 | wasm_module_delete(m); 45 | 46 | // include input/output sizes 47 | const auto& inputs = lt.ir.inputs(); 48 | const auto& outputs = lt.ir.outputs(); 49 | auto all_mem_sizes = wc.memory_sizes(true); 50 | int64_t offset = 0; 51 | memory_input_offset_map.resize(inputs.size() + outputs.size()); 52 | memory_output_offset_map.resize(inputs.size() + outputs.size()); 53 | for (auto i = 0; i < inputs.size() + outputs.size(); ++i) { 54 | int64_t size = all_mem_sizes.at(i) * 4; 55 | 
memory_size_map.emplace_back(size); 56 | if (i < inputs.size()) { 57 | memory_input_offset_map[i] = offset; 58 | memory_output_offset_map[i] = -1; 59 | } else { 60 | memory_input_offset_map[i] = -1; 61 | memory_output_offset_map[i] = offset; 62 | } 63 | offset += size; 64 | } 65 | } 66 | 67 | void run(const std::vector& user_memory, bool sync) const override { 68 | wasm_val_vec_t args = WASM_EMPTY_VEC; 69 | wasm_val_vec_t results = WASM_EMPTY_VEC; 70 | // copy user memory to webassembly and back 71 | char* data = wasm_memory_data(wasm_memory); 72 | for (auto i = 0; i < user_memory.size(); ++i) { 73 | int64_t offset = memory_input_offset_map[i]; 74 | if (offset == -1) { 75 | continue; 76 | } 77 | void* ptr = &data[offset]; 78 | memcpy(ptr, user_memory[i], memory_size_map[i]); 79 | } 80 | wasm_func_call(fn, &args, &results); 81 | for (auto i = 0; i < user_memory.size(); ++i) { 82 | int64_t offset = memory_output_offset_map[i]; 83 | if (offset == -1) { 84 | continue; 85 | } 86 | void* ptr = &data[offset]; 87 | memcpy(user_memory[i], ptr, memory_size_map[i]); 88 | } 89 | } 90 | std::string dump() const override { return ""; } 91 | 92 | ~WebAssemblyCompiled() { 93 | wasm_extern_vec_delete(&exports); 94 | wasm_instance_delete(instance); 95 | wasm_store_delete(store); 96 | wasm_engine_delete(engine); 97 | } 98 | }; 99 | 100 | struct WebAssemblyBackend : public Backend { 101 | WebAssemblyBackend() : Backend("wasm") {} 102 | ~WebAssemblyBackend() {} 103 | WebAssemblyBackend(std::string name) : Backend(name) {} 104 | 105 | std::unique_ptr compile_impl(const LoopTree& lt) const override { 106 | return std::make_unique(lt); 107 | } 108 | int hardware_requirement() const override { return 1 << 0; } 109 | }; 110 | 111 | static RegisterBackend wasm_backend_reg_( 112 | std::make_shared()); 113 | -------------------------------------------------------------------------------- /src/core/backend.cpp: -------------------------------------------------------------------------------- 1 | 
/*
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
#include "loop_tool/backend.h"

// NOTE(review): the #include targets and template argument lists below were
// stripped during text extraction; restore from version control.
#include
#include

#include "loop_tool/dynlib.h"

// Guards backend registration, which can run from static initializers in
// multiple translation units.
static std::mutex registration_mutex_;
// Keeps dynamically loaded backend libraries alive for the process lifetime
// so their registered backends stay valid.
static std::vector loaded_libs;

namespace loop_tool {

// Convenience call operator: unwrap the raw data pointer of each tensor and
// dispatch to the backend-specific run().
void Compiled::operator()(const std::vector &tensors,
                          bool sync) const {
  std::vector memory;
  for (const auto &t : tensors) {
    memory.emplace_back(t->data.address);
  }
  run(memory, sync);
}

// Allocate a zero-initialized float buffer per requested size; entries with
// size <= 0 stay default (null). The caller owns the returned pointers.
std::vector Compiled::allocate(std::vector &sizes) const {
  std::vector memory(sizes.size());
  for (auto i = 0; i < sizes.size(); ++i) {
    if (sizes[i] > 0) {
      memory[i] = calloc(sizes[i], sizeof(float));
    }
  }
  return memory;
}

// Name -> backend registry. Function-local static avoids the static
// initialization order fiasco across translation units.
std::unordered_map>
    &getMutableBackends() {
  static std::unordered_map> backends_;
  return backends_;
}

const std::unordered_map> &getBackends() {
  return getMutableBackends();
}

void registerBackend(std::shared_ptr backend) {
  std::lock_guard guard(registration_mutex_);
  getMutableBackends()[backend->name()] = backend;
}

// The "cpp" backend is the default; it must already be registered by the
// time this is first called (it registers via a static initializer).
std::shared_ptr &getDefaultBackend() {
  static std::shared_ptr default_backend_ = getBackends().at("cpp");
  return default_backend_;
}

void setDefaultBackend(std::string backend) {
  ASSERT(getBackends().count(backend)) << "couldn't find backend " << backend;
  getDefaultBackend() = getBackends().at(backend);
}

// dlopen a backend library; its static initializers self-register backends.
void loadLibrary(std::string lib_name) {
  loaded_libs.emplace_back(lib_name.c_str(), true);
}

}  // namespace loop_tool
/*
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
#include "loop_tool/hardware.h"

// NOTE(review): the #include targets and template argument lists below were
// stripped during text extraction; restore from version control.
#include
#include

// Guards hardware registration from concurrent static initializers.
static std::mutex registration_mutex_;

namespace loop_tool {

// Registry of available hardware targets; index 0 is always the CPU.
std::vector> &getMutableHardware() {
  // We want CPU to be first, so we don't use registration pattern
  static std::vector> hardware_ = {
      std::make_shared()};
  return hardware_;
}
const std::vector> &getHardware() {
  return getMutableHardware();
}

// Bitmask of hardware ids that report at least one usable device.
int getAvailableHardware() {
  int avail = 0;
  for (auto &hw : getHardware()) {
    if (hw->count()) {
      avail |= 1 << hw->id();
    }
  }
  return avail;
}

void registerHardware(std::shared_ptr hw) {
  std::lock_guard guard(registration_mutex_);
  // the id doubles as the bit position used by getAvailableHardware()
  hw->setId(getHardware().size());
  getMutableHardware().emplace_back(hw);
}

int availableCPUs() {
  // TODO
  return 1;
}

int &getDefaultHardwareId() {
  static int default_hardware_id_ = 0;
  return default_hardware_id_;
}

// Find the registered hardware matching the default id; fall back to CPU.
const std::shared_ptr &getDefaultHardware() {
  for (auto &hw : getHardware()) {
    if (hw->id() == getDefaultHardwareId()) {
      return hw;
    }
  }
  return getHardware().at(0);
}

void setDefaultHardwareId(int id) { getDefaultHardwareId() = id; }

}  // namespace loop_tool
6 | */ 7 | #include "loop_tool/tensor.h" 8 | 9 | #include 10 | 11 | using namespace loop_tool; 12 | 13 | Tensor::Tensor(size_t N, int hardware) : hardware_id(hardware) { 14 | data = getHardware().at(hardware_id)->alloc(N * sizeof(float)); 15 | numel = N; 16 | } 17 | 18 | Tensor::~Tensor() { getHardware().at(hardware_id)->free(data); } 19 | -------------------------------------------------------------------------------- /src/frontends/nn.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include "loop_tool/nn.h" 8 | 9 | namespace loop_tool { 10 | namespace nn { 11 | 12 | using Tensor = loop_tool::lazy::Tensor; 13 | using namespace loop_tool::symbolic; 14 | 15 | Tensor convolve(Tensor X, Tensor W, std::vector spatial_dims, 16 | std::vector window_dims, int stride) { 17 | ASSERT(spatial_dims.size() == window_dims.size()); 18 | std::vector new_spatial_dims; 19 | std::vector constraints; 20 | for (auto i = 0; i < spatial_dims.size(); ++i) { 21 | const auto& sp_dim = spatial_dims.at(i); 22 | const auto& w_dim = window_dims.at(i); 23 | Symbol new_dim(sp_dim.name() + "_x_" + w_dim.name()); 24 | new_spatial_dims.emplace_back(new_dim); 25 | const auto& idx_equation = Expr(new_dim) * Expr(stride) + Expr(w_dim); 26 | constraints.emplace_back(std::make_pair(Expr(sp_dim), idx_equation)); 27 | } 28 | 29 | std::vector batch_dims; 30 | std::vector reduction_dims = window_dims; 31 | auto W_dims = to_set(W.shape()); 32 | auto X_sp_dims = to_set(spatial_dims); 33 | for (auto sym : X.shape()) { 34 | if (W_dims.count(sym)) { 35 | reduction_dims.emplace_back(sym); 36 | } else if (!X_sp_dims.count(sym)) { 37 | batch_dims.emplace_back(sym); 38 | } 39 | } 40 | 41 | batch_dims.insert(batch_dims.end(), new_spatial_dims.begin(), 42 | new_spatial_dims.end()); 
43 | batch_dims.insert(batch_dims.end(), window_dims.begin(), window_dims.end()); 44 | X = X.to(batch_dims, constraints); 45 | return (X * W).sum(reduction_dims); 46 | } 47 | 48 | Tensor maxpool(Tensor X, std::vector spatial_dims, int k, int stride) { 49 | std::vector new_spatial_dims; 50 | std::vector new_window_dims; 51 | std::vector constraints; 52 | for (const auto& sym : spatial_dims) { 53 | Symbol new_spatial(sym.name() + "_p"); 54 | Symbol new_window(sym.name() + "_k"); 55 | new_spatial_dims.emplace_back(new_spatial); 56 | new_window_dims.emplace_back(new_window); 57 | auto idx_equation = Expr(new_spatial) * Expr(stride) + Expr(new_window); 58 | constraints.emplace_back(std::make_pair(Expr(sym), idx_equation)); 59 | constraints.emplace_back(std::make_pair(Expr::size(new_window), Expr(k))); 60 | } 61 | std::vector batch_dims; 62 | auto spatial_dim_set = to_set(spatial_dims); 63 | for (const auto& sym : X.shape()) { 64 | if (spatial_dim_set.count(sym)) { 65 | continue; 66 | } 67 | batch_dims.emplace_back(sym); 68 | } 69 | 70 | batch_dims.insert(batch_dims.end(), new_spatial_dims.begin(), 71 | new_spatial_dims.end()); 72 | batch_dims.insert(batch_dims.end(), new_window_dims.begin(), 73 | new_window_dims.end()); 74 | X = X.to(batch_dims, constraints); 75 | return X.max(new_window_dims); 76 | } 77 | 78 | } // namespace nn 79 | } // namespace loop_tool 80 | -------------------------------------------------------------------------------- /test/bench.py: -------------------------------------------------------------------------------- 1 | import loop_tool_py as lt 2 | import random 3 | import numpy as np 4 | import math 5 | import time 6 | 7 | lt.set_default_hardware("cuda") 8 | 9 | 10 | def gen_pw_add(): 11 | ir = lt.IR() 12 | a = ir.create_var("a") 13 | r0 = ir.create_node(lt.read, [], [a]) 14 | r1 = ir.create_node(lt.read, [], [a]) 15 | add = ir.create_node(lt.add, [r0, r1], [a]) 16 | w = ir.create_node(lt.write, [add], [a]) 17 | ir.set_inputs([r0, r1]) 18 | 
    ir.set_outputs([w])
    return ir, a


def test_pw(size, inner_size, vec_size):
    """Benchmark a pointwise add of `size` floats on the CUDA backend.

    The single pointwise loop is split into outer x inner_size x vec_size
    (remainder carried as the outer loop's tail), the children of the root
    are threaded, and the result is validated against numpy before timing.

    Returns (GB/sec moved, generated code, annotated loop tree dump).
    """
    assert size >= (inner_size * vec_size)
    ir, v = gen_pw_add()  # v = pointwise var
    size_map = {}
    size_map[v] = size
    for n in ir.nodes:
        outer = size // (inner_size * vec_size)
        outer_rem = size % (inner_size * vec_size)

        ir.set_order(
            n, [(v, (outer, outer_rem)), (v, (inner_size, 0)), (v, (vec_size, 0))]
        )
        ir.disable_reuse(n, 2)
    loop_tree = lt.LoopTree(ir)
    A = lt.RawTensor(size)
    B = lt.RawTensor(size)
    C = lt.RawTensor(size)
    Ap = np.random.randn(size)
    Bp = np.random.randn(size)
    A.set(Ap)
    B.set(Bp)
    C_ref = Ap + Bp
    # sentinel value so a no-op kernel is caught by the check below
    C.set(1337.0)
    # thread the loops directly under the root
    parallel = set(loop_tree.children(loop_tree.roots[0]))
    print(loop_tree)
    c = lt.cuda(loop_tree, parallel)
    c([A, B, C])
    C_test = C.to_numpy()
    # correctness check (relative to mean magnitude) before timing
    max_diff = np.max(np.abs(C_test - C_ref))
    mean_val = np.mean(np.abs(C_ref))
    assert max_diff < 1e-3 * mean_val
    iters = 10000
    # warmup
    for i in range(50):
        c([A, B, C])
    t = time.time()
    # async launches; the final call (sync=True default) flushes the queue
    for i in range(iters - 1):
        c([A, B, C], False)
    c([A, B, C])
    t_ = time.time()
    # print(loop_tree.dump(lambda x: "[threaded]" if x in parallel else ""))
    # print(c.code)
    # 2 read 1 write, 4 bytes per float
    bytes_moved = (2 + 1) * 4 * size * iters / (t_ - t) / 1e9
    pct = bytes_moved / c.bandwidth
    usec = (t_ - t) / iters * 1e6
    # print(f"peak: {c.bandwidth} GB/sec")
    print(
        f"{bytes_moved:.2f} GB/sec",
        f"({100 * pct:.2f}% of peak, {usec:.2f} usec per iter)",
    )
    return (
        bytes_moved,
        c.code,
        loop_tree.dump(lambda x: "// Threaded" if x in parallel else ""),
    )


# Sweep inner/vector split sizes for a 1M-element add; keep the best kernel.
s = 1024 * 1024
best = 0
code = ""
loop_tree = ""
inner_scale = 512 * 8
for i in range(1, s // inner_scale):
    inner = i * inner_scale
    for vec_pow in range(0, 3):
        vec = 2 ** vec_pow
        # note: `inner` is progressively divided across vec sizes
        inner = inner // vec
        b, c, l = test_pw(s, inner, vec)
        if b > best:
            best = b
            code = c
            loop_tree = l
print(f"Best kernel found ({best:.2f} GB/sec):")
print(loop_tree)
print(code)
loop_tree.annotate(loop_tree.loops[0], "parallel") 66 | loop_tree.annotate(loop_tree.loops[1], "parallel") 67 | 68 | Z.set(loop_tree) 69 | 70 | print(Z.loop_tree) 71 | bench(loop_tree, 10, 1000) 72 | -------------------------------------------------------------------------------- /test/loop_nest_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | using namespace loop_tool::testing; 12 | 13 | TEST(LoopNestBackend) { 14 | loop_tool::ScopedBackend sb("loop_nest"); 15 | namespace lz = ::loop_tool::lazy; 16 | auto mm = [](lz::Tensor A, lz::Tensor B) { 17 | auto M = lz::Symbol("M"); 18 | auto N = lz::Symbol("N"); 19 | auto K = lz::Symbol("K"); 20 | auto C = A.as(M, K) * B.as(K, N); 21 | return C.sum(K); 22 | }; 23 | 24 | auto M = 16; 25 | auto N = 16; 26 | auto K = 16; 27 | 28 | lz::Tensor A(M, K); 29 | lz::Tensor B(K, N); 30 | for (auto i = 0; i < M * K; ++i) { 31 | A.data()[i] = 1; 32 | B.data()[i] = 2; 33 | } 34 | auto C = mm(A, B); 35 | auto d = C.data(); 36 | for (auto i = 0; i < M * N; ++i) { 37 | std::cerr << d[i] << " "; 38 | } 39 | std::cerr << "\n"; 40 | C.clear_cache(); 41 | } 42 | 43 | TEST(LoopNestMM) { 44 | loop_tool::ScopedBackend sb("loop_nest"); 45 | namespace lz = ::loop_tool::lazy; 46 | auto mm = [](lz::Tensor A, lz::Tensor B) { 47 | auto M = lz::Symbol("M"); 48 | auto N = lz::Symbol("N"); 49 | auto K = lz::Symbol("K"); 50 | auto C = A.as(M, K) * B.as(K, N); 51 | return C.sum(K); 52 | }; 53 | 54 | auto M = 16; 55 | auto N = 16; 56 | auto K = 16; 57 | 58 | lz::Tensor A(M, K); 59 | lz::Tensor B(K, N); 60 | rand(A.data(), M * K); 61 | rand(B.data(), K * N); 62 | 63 | auto C = mm(A, B); 64 | lz::Tensor C_ref(M * N); 65 | ref_mm(A.data(), B.data(), M, N, K, 
           C_ref.data());

  ASSERT(all_close(C_ref.data(), C.data(), M * N));
  C.clear_cache();
}

// Runs a 4x64x64, 3x3 "valid" convolution through the loop_nest backend via
// an im2col-style view and checks it against the reference implementation.
TEST(LoopNestConv) {
  loop_tool::ScopedBackend sb("loop_nest");
  namespace lz = ::loop_tool::lazy;

  auto conv = [](lz::Tensor X, lz::Tensor w) {
    lz::Symbol N("N"), M("M"), C("C"), H("H"), Ho("Ho"), W("W"), Wo("Wo"),
        Kh("Kh"), Kw("Kw");
    X = X.as(N, C, H, W);
    w = w.as(M, C, Kh, Kw);
    // im2col: re-index H as Ho + Kh and W as Wo + Kw
    auto X_im2col = X.to({N, C, Ho, Kh, Wo, Kw}, lz::Constraint(H, Ho + Kh),
                         lz::Constraint(W, Wo + Kw));
    auto Y = (X_im2col * w).sum(Kh, Kw, C);
    return Y.transpose({N, M, Ho, Wo});
  };

  auto N = 4;
  auto M = 64;
  auto C = 64;
  auto HW = 8;
  auto K = 3;
  auto HWo = HW - K + 1;  // output spatial extent for a valid convolution

  lz::Tensor A(N, C, HW, HW);
  lz::Tensor B(M, C, K, K);
  rand(A.data(), A.numel());
  rand(B.data(), B.numel());

  auto C_lt = conv(A, B);
  std::cerr << C_lt.numel() << " vs " << (N * M * HWo * HWo) << "\n";
  ASSERT(C_lt.numel() == N * M * HWo * HWo);
  lz::Tensor C_ref(C_lt.numel());
  ref_conv(A.data(), B.data(), N, M, C, HW, K,
           C_ref.data());

  ASSERT(all_close(C_ref.data(), C_lt.data(), C_lt.numel()));
}

// Runs a matmul on the interpreted CPU backend but annotates the root loop
// with [loop_nest] so the loop_nest JIT handles that subtree; verifies the
// mixed execution still matches the reference matmul.
TEST(LoopNestEmbedded) {
  loop_tool::ScopedBackend sb("cpu_interpreted");
  namespace lz = ::loop_tool::lazy;
  auto mm = [](lz::Tensor A, lz::Tensor B) {
    auto M = lz::Symbol("M");
    auto N = lz::Symbol("N");
    auto K = lz::Symbol("K");
    auto C = A.as(M, K) * B.as(K, N);
    return C.sum(K);
  };

  auto M = 16;
  auto N = 16;
  auto K = 16;

  lz::Tensor A(M, K);
  lz::Tensor B(K, N);
  rand(A.data(), M * K);
  rand(B.data(), K * N);

  auto C = mm(A, B);
  auto tree = C.loop_tree();
  tree = annotate(tree, tree.roots[0], "[loop_nest]");
  C.set(tree);
  std::cerr << "TRE IS " << tree.dump() << "\n";
  lz::Tensor C_ref(M * N);
  ref_mm(A.data(), B.data(), M, N, K, C_ref.data());

  ASSERT(all_close(C_ref.data(), C.data(), M * N));
}
6 | */ 7 | import * as lt from '../javascript/loop_tool.mjs'; 8 | import * as fs from 'fs'; 9 | 10 | import { 11 | PerformanceObserver, 12 | performance 13 | } from 'perf_hooks'; 14 | 15 | function cmp(a, b) { 16 | if (a.length != b.length) { 17 | return false; 18 | } 19 | for (let i = 0; i < a.length; ++i) { 20 | if (Math.abs(a[i] - b[i]) > 0.001) { 21 | console.log(a[i], b[i], "at index", i); 22 | return false; 23 | } 24 | } 25 | return true; 26 | 27 | } 28 | 29 | function rand(array) { 30 | for (let i = 0; i < array.length; ++i) { 31 | array[i] = Math.random(); 32 | } 33 | } 34 | 35 | function mm(a, b, m, n, k) { 36 | const c = new Float32Array(m * n); 37 | for (let m_ = 0; m_ < m; ++m_) { 38 | for (let n_ = 0; n_ < n; ++n_) { 39 | for (let k_ = 0; k_ < k; ++k_) { 40 | c[m_ * n + n_] += a[m_ * k + k_] * b[k_ * n + n_]; 41 | } 42 | } 43 | } 44 | return c; 45 | } 46 | 47 | try { 48 | (async () => { 49 | let [m, n, k] = lt.symbols("M N K"); 50 | let a = new lt.Tensor(100, 200).to(m, k); 51 | let b = new lt.Tensor(200, 300).to(k, n); 52 | rand(a.buffer); 53 | rand(b.buffer); 54 | let c = a.mul(b).sum(k); 55 | let loop_tree = c.loop_tree; 56 | console.log(loop_tree.walk().length); 57 | for (let ref of loop_tree.walk()) { 58 | const d = loop_tree.depth(ref); 59 | if (loop_tree.is_loop(ref)) { 60 | const loop = loop_tree.loop(ref); 61 | const v = loop.v(); 62 | console.log(" ".repeat(d), "iter", loop_tree.var_name(v)); 63 | } else { 64 | const node = loop_tree.node(ref); 65 | console.log(" ".repeat(d), 'node'); 66 | } 67 | } 68 | })() 69 | } catch (e) { 70 | console.log(e); 71 | } 72 | 73 | (async () => { 74 | let n = new lt.Symbol("N"); 75 | let k = new lt.Symbol("K"); 76 | let no = new lt.Symbol("No"); 77 | let a = new lt.Tensor(10).to(n); 78 | rand(a.buffer); 79 | let b = new lt.Tensor(3).to(k); 80 | rand(b.buffer); 81 | a = a.to(no, k, [ 82 | [n.expr(), no.expr().add(k.expr())] 83 | ]); 84 | let c = a.mul(b).sum(k); 85 | console.log(c.shape); 86 | const loop_tree = 
c.loop_tree; 87 | for (let ref of loop_tree.walk()) { 88 | if (loop_tree.is_loop(ref)) { 89 | console.log(loop_tree.depth(ref)); 90 | } 91 | } 92 | let d = await c.data; 93 | console.log("data", d); 94 | })(); 95 | 96 | (async () => { 97 | let n = new lt.Symbol("N"); 98 | let a = new lt.Tensor(2).to(n); 99 | a.buffer[0] = 3; 100 | a.buffer[1] = 2; 101 | let b = new lt.Tensor(2).to(n); 102 | b.set(new Float32Array([4, 9])); 103 | let c = a.add(b); 104 | c = c.add(b); 105 | console.log(c.hash + '.wasm'); 106 | fs.writeFile(c.hash + '.wasm', c.wasm, _ => {}); 107 | //console.log(c.graphviz); 108 | let d = await c.data; 109 | console.log(d); 110 | }); 111 | 112 | (async () => { 113 | let n = new lt.Symbol("N"); 114 | const N = 10; 115 | let a = new lt.Tensor(N).to(n); 116 | let b = new lt.Tensor(N).to(n); 117 | rand(a.buffer); 118 | rand(b.buffer); 119 | let c = a.add(b); 120 | c = c.add(b); 121 | const loop_tree = c.loop_tree; 122 | let roots = loop_tree.children(-1); 123 | loop_tree.annotate(roots[0], "unroll"); 124 | console.log(loop_tree.dump()); 125 | c.set_loop_tree(loop_tree); 126 | console.log(c.hash + '.wasm'); 127 | fs.writeFile(c.hash + '.wasm', c.wasm, _ => {}); 128 | let d = await c.data; 129 | for (let i = 0; i < N; ++i) { 130 | if (Math.abs(d[i] - (a.buffer[i] + 2 * b.buffer[i])) > 0.001) { 131 | console.log("EROR", d[i]); 132 | } 133 | } 134 | console.log(d); 135 | })(); 136 | 137 | (async () => { 138 | let [m, n, k] = lt.symbols("M N K"); 139 | let a = new lt.Tensor(100, 200).to(m, k); 140 | let b = new lt.Tensor(200, 300).to(k, n); 141 | rand(a.buffer); 142 | rand(b.buffer); 143 | let c_ref = mm(a.buffer, b.buffer, 100, 300, 200); 144 | let c = a.mul(b).sum(k); 145 | let d = await c.data; 146 | //console.log(c.graphviz); 147 | console.log(c.shape, c.symbolic_shape); 148 | if (cmp(c_ref, d)) { 149 | console.log("results look good"); 150 | } else { 151 | console.log("ERROR!"); 152 | } 153 | })(); 154 | 155 | async function benchmark(fn, warmup = 100, 
iters = 10000) { 156 | for (let i = 0; i < warmup; ++i) { 157 | await fn(); 158 | } 159 | let t0 = performance.now(); 160 | for (let i = 0; i < iters; ++i) { 161 | await fn(); 162 | } 163 | let t1 = performance.now(); 164 | return 1e3 * iters / (t1 - t0); 165 | } 166 | 167 | (async () => { 168 | const fn_wrapped = async () => { 169 | let n = new lt.Symbol("N"); 170 | let a = new lt.Tensor(128 * 128).to(n); 171 | let b = new lt.Tensor(128 * 128).to(n); 172 | let c = a.add(b); 173 | let d = await c.data; 174 | } 175 | let n = new lt.Symbol("N"); 176 | let a = new lt.Tensor(128 * 128).to(n); 177 | let b = new lt.Tensor(128 * 128).to(n); 178 | let c = a.add(b); 179 | let [mem_map, fn] = await c.compile(); 180 | const fn_mem = async () => { 181 | for (let k of Object.keys(mem_map)) { 182 | if (k == c._id) { 183 | continue; 184 | } 185 | mem_map[k].fill(1); 186 | } 187 | fn(); 188 | } 189 | console.log(await benchmark(fn), "iters per second (pure fn)"); 190 | console.log(await benchmark(fn_mem), "iters per second (fn + fill inputs)"); 191 | console.log(await benchmark(fn_wrapped, 10, 100), "iters per second (wrapped)"); 192 | 193 | { 194 | let [m, n, k] = lt.symbols("M N K"); 195 | let a = new lt.Tensor(100, 200).to(m, k); 196 | let b = new lt.Tensor(200, 300).to(k, n); 197 | let c = a.mul(b).sum(k); 198 | let [mem_map, fn] = await c.compile(); 199 | let iter_sec = await benchmark(fn, 10, 100); 200 | console.log(iter_sec, "mm iters per second (pure fn)", `${100 * 200 * 300 * 2 * iter_sec / 1e9} gflops`); 201 | } 202 | 203 | })(); 204 | 205 | (async () => { 206 | let m = lt.symbol("m"); 207 | let a = lt.rand(128).to(m); 208 | let b = a.sum(m); 209 | const d = new Float32Array(1); 210 | d.set(await b.data); 211 | let tree = b.loop_tree; 212 | const roots = tree.children(-1); 213 | let new_tree = tree.annotate(roots[0], "unroll"); 214 | new_tree; 215 | b.set_loop_tree(new_tree); 216 | const e = new Float32Array(1); 217 | console.log(b.data); 218 | e.set(await b.data); 219 | 
console.log("diff", d, e); 220 | })(); -------------------------------------------------------------------------------- /test/test_backend.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | namespace lt = loop_tool; 12 | 13 | struct CustomCompiled : public lt::Compiled { 14 | void run(const std::vector &memory, bool sync) const override { 15 | std::cerr << "here!\n"; 16 | return; 17 | } 18 | }; 19 | 20 | struct CustomBackend : lt::Backend { 21 | CustomBackend() : lt::Backend("custom") {} 22 | 23 | std::unique_ptr compile_impl( 24 | const lt::LoopTree <) const override { 25 | return std::make_unique(); 26 | } 27 | 28 | int hardware_requirement() const override { 29 | return 0; // CPU 30 | } 31 | }; 32 | 33 | static lt::RegisterBackend custom_backend_reg_{ 34 | std::make_shared()}; 35 | 36 | TEST(CustomBackend) { 37 | // define 38 | lt::IR ir; 39 | auto a = ir.create_var("a"); 40 | auto b = ir.create_var("b"); 41 | auto r = ir.create_node(lt::Operation::read, {}, {a, b}); 42 | auto add = ir.create_node(lt::Operation::add, {r}, {}); 43 | auto w = ir.create_node(lt::Operation::write, {add}, {}); 44 | ir.set_inputs({r}); 45 | ir.set_outputs({w}); 46 | 47 | // schedule 48 | constexpr int N = 16; 49 | /* 50 | read and add nodes have the loop order: 51 | 52 | ``` 53 | for a in N: 54 | for b in N: 55 | read 56 | add 57 | ``` 58 | 59 | **/ 60 | ir.set_order(r, {{a, {N, 0}}, {b, {N, 0}}}); 61 | ir.set_order(add, {{a, {N, 0}}, {b, {N, 0}}}); 62 | // write can be executed without looping 63 | ir.set_order(w, {}); 64 | lt::LoopTree loop_tree(ir); 65 | 66 | std::cout << loop_tree.dump(); 67 | 68 | // compile and run 69 | auto compiled = lt::getBackends().at("custom")->compile(loop_tree); 70 | 
auto A = lt::Tensor(N * N); 71 | auto B = lt::Tensor(1); 72 | const auto &f = *compiled; 73 | f(A, B); 74 | f.async(A, B); 75 | } 76 | -------------------------------------------------------------------------------- /test/test_cpp.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include "test_utils.h" 14 | 15 | TEST(CppFromLazy) { 16 | namespace lz = ::loop_tool::lazy; 17 | auto mm = [](lz::Tensor A, lz::Tensor B) { 18 | auto M = lz::Symbol("M"); 19 | auto N = lz::Symbol("N"); 20 | auto K = lz::Symbol("K"); 21 | auto C = A.as(M, K) * B.as(K, N); 22 | return C.sum(K); 23 | }; 24 | 25 | auto M = 16; 26 | auto N = 16; 27 | auto K = 16; 28 | 29 | lz::Tensor A(M, K); 30 | lz::Tensor B(K, N); 31 | auto C = mm(A, B); 32 | auto compiler = loop_tool::CppCompiler(C.loop_tree()); 33 | auto code = compiler.gen_string(); 34 | std::string fn_name = "fn_" + std::to_string(compiler.count); 35 | 36 | std::ofstream("/tmp/fn_impl.c") << code; 37 | std::system( 38 | "cc -Wall -Werror -fpic -shared -o /tmp/fn_impl.so /tmp/fn_impl.c"); // compile 39 | loop_tool::DynamicLibrary dll("/tmp/fn_impl.so"); 40 | auto fn = dll.sym(fn_name.c_str()); 41 | { 42 | float* A = (float*)calloc(sizeof(float), 16 * 16); 43 | float* B = (float*)calloc(sizeof(float), 16 * 16); 44 | float* C = (float*)calloc(sizeof(float), 16 * 16); 45 | float* C_ref = (float*)calloc(sizeof(float), 16 * 16); 46 | for (int64_t i = 0; i < 16 * 16; ++i) { 47 | A[i] = i * 3; 48 | B[i] = 100 - (i * 2); 49 | } 50 | for (int64_t i = 0; i < 16; ++i) { 51 | for (int64_t j = 0; j < 16; ++j) { 52 | for (int64_t k = 0; k < 16; ++k) { 53 | C_ref[i * 16 + j] += A[i * 16 + k] * B[k * 16 + j]; 54 | } 55 | } 56 | } 57 | void* tmp = 
malloc(sizeof(float) * 16); 58 | void* mem[5] = {A, B, C, 0, tmp}; 59 | fn(mem); 60 | for (int64_t i = 0; i < 16 * 16; ++i) { 61 | auto diff = std::abs(C[i] - C_ref[i]); 62 | ASSERT(diff < 0.01) << "difference of " << diff; 63 | } 64 | } 65 | } 66 | 67 | TEST(CppWithTail) { 68 | namespace lz = ::loop_tool::lazy; 69 | auto mm = [](lz::Tensor A, lz::Tensor B) { 70 | auto M = lz::Symbol("m"), N = lz::Symbol("n"), K = lz::Symbol("k"); 71 | auto C = A.as(M, K) * B.as(K, N); 72 | return C.sum(K); 73 | }; 74 | 75 | lz::Tensor A(16, 16); 76 | lz::Tensor B(16, 16); 77 | auto C = mm(A, B); 78 | auto lt = C.loop_tree(); 79 | std::cerr << '\n'; 80 | std::cerr << lt.dump(); 81 | std::cerr << '\n'; 82 | auto r = lt.children(lt.roots.at(0)).at(0); 83 | lt = loop_tool::split(lt, r, 10); 84 | 85 | auto a = lt.children(lt.children(lt.roots.at(0)).at(0)).at(0); 86 | auto b = lt.children(a).at(0); 87 | lt = loop_tool::swap_loops(lt, a, b); 88 | 89 | std::cerr << '\n'; 90 | std::cerr << lt.dump(); 91 | 92 | C.compile(); 93 | C.set(lt); 94 | auto compiler = loop_tool::CppCompiler(C.loop_tree()); 95 | auto code = compiler.gen_string(); 96 | std::string fn_name = "fn_" + std::to_string(compiler.count); 97 | std::cerr << code << "\n"; 98 | std::ofstream("/tmp/fn_impl.c") << code; 99 | std::system( 100 | "cc -Wall -Werror -fpic -shared -o /tmp/fn_impl.so /tmp/fn_impl.c"); // compile 101 | loop_tool::DynamicLibrary dll("/tmp/fn_impl.so"); 102 | auto fn = dll.sym(fn_name.c_str()); 103 | { 104 | float* A = (float*)calloc(sizeof(float), 16 * 16); 105 | float* B = (float*)calloc(sizeof(float), 16 * 16); 106 | float* C = (float*)calloc(sizeof(float), 16 * 16); 107 | float* C_ref = (float*)calloc(sizeof(float), 16 * 16); 108 | for (int64_t i = 0; i < 16 * 16; ++i) { 109 | A[i] = i * 3; 110 | B[i] = 100 - (i * 2); 111 | } 112 | for (int64_t i = 0; i < 16; ++i) { 113 | for (int64_t j = 0; j < 16; ++j) { 114 | for (int64_t k = 0; k < 16; ++k) { 115 | C_ref[i * 16 + j] += A[i * 16 + k] * B[k * 16 + 
j]; 116 | } 117 | } 118 | } 119 | void* tmp = malloc(sizeof(float) * 16); 120 | void* mem[5] = {A, B, C, 0, tmp}; 121 | fn(mem); 122 | for (int64_t i = 0; i < 16 * 16; ++i) { 123 | auto diff = std::abs(C[i] - C_ref[i]); 124 | ASSERT(diff < 0.01) << "difference of " << diff << " at " << i / 16 125 | << ", " << i % 16 << " (" << C[i] << " vs expected " 126 | << C_ref[i] << ")"; 127 | } 128 | } 129 | } 130 | 131 | TEST(CppView) { 132 | namespace lz = ::loop_tool::lazy; 133 | auto padded_conv = [](lz::Tensor X, lz::Tensor W) { 134 | auto N = lz::Symbol("n"), Np = lz::Symbol("np"); 135 | auto X_pad = X.as(N).pad(N, 1).as(Np); 136 | auto No = lz::Symbol("no"), K = lz::Symbol("k"); 137 | return (X_pad.to({No, K}, {{Np, No + K}}) * W.as(K)).sum(K); 138 | }; 139 | lz::Tensor A(16); 140 | lz::Tensor B(3); 141 | auto C = padded_conv(A, B); 142 | auto lt = C.loop_tree(); 143 | std::cerr << '\n'; 144 | std::cerr << lt.dump(); 145 | std::cerr << '\n'; 146 | auto compiler = loop_tool::CppCompiler(C.loop_tree()); 147 | auto code = compiler.gen_string(); 148 | std::string fn_name = "fn_" + std::to_string(compiler.count); 149 | 150 | std::cerr << code << "\n"; 151 | std::ofstream("/tmp/fn_impl.c") << code; 152 | std::system( 153 | "cc -g -O0 -Wall -Werror -fpic -shared -o /tmp/fn_impl.so " 154 | "/tmp/fn_impl.c"); // compile 155 | loop_tool::DynamicLibrary dll("/tmp/fn_impl.so"); 156 | auto fn = dll.sym(fn_name.c_str()); 157 | { 158 | float* A = (float*)calloc(sizeof(float), 16); 159 | float* B = (float*)calloc(sizeof(float), 3); 160 | float* C = (float*)calloc(sizeof(float), 16); 161 | float* C_ref = (float*)calloc(sizeof(float), 16); 162 | for (int64_t i = 0; i < 16; ++i) { 163 | A[i] = i * 3 + 1; 164 | } 165 | for (int64_t i = 0; i < 3; ++i) { 166 | B[i] = 1 - (i * 2); 167 | } 168 | for (int64_t i = 0; i < 16; ++i) { 169 | for (int64_t k = 0; k < 3; ++k) { 170 | if ((i + k - 1 >= 0) && (i + k - 1 < 16)) { 171 | C_ref[i] += A[i + k - 1] * B[k]; 172 | } 173 | } 174 | } 175 | void* 
tmp = malloc(sizeof(float) * 18); 176 | void* mem[5] = {A, B, C, tmp}; 177 | fn(mem); 178 | for (int64_t i = 0; i < 16; ++i) { 179 | auto diff = std::abs(C[i] - C_ref[i]); 180 | std::cerr << C[i] << " vs " << C_ref[i] << "\n"; 181 | ASSERT(diff < 0.01) << "difference of " << diff; 182 | } 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /test/test_ir.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | using namespace loop_tool; 12 | using namespace loop_tool::testing; 13 | 14 | TEST(DotDump) { 15 | IR ir; 16 | constexpr int N = 16; 17 | auto a = ir.create_var("a"); 18 | auto b = ir.create_var("b"); 19 | auto r = ir.create_node(Operation::read, {}, {a, b}); 20 | auto add = ir.create_node(Operation::add, {r}, {}); 21 | auto w = ir.create_node(Operation::write, {add}, {}); 22 | ir.set_inputs({r}); 23 | ir.set_outputs({w}); 24 | std::cerr << LoopTree(ir).dump() << "\n"; 25 | std::cerr << dot(ir) << "\n"; 26 | ir = split_node(ir, add, {b}); 27 | std::cerr << " -- split -- \n"; 28 | std::cerr << LoopTree(ir).dump() << "\n"; 29 | std::cerr << dot(ir) << "\n"; 30 | } 31 | 32 | TEST(SetPriority) { 33 | IR ir; 34 | auto a = ir.create_var("a"); 35 | auto b = ir.create_var("b"); 36 | auto r0 = ir.create_node(Operation::read, {}, {a, b}); 37 | auto r1 = ir.create_node(Operation::read, {}, {a, b}); 38 | auto add = ir.create_node(Operation::add, {r0, r1}, {a, b}); 39 | auto w = ir.create_node(Operation::write, {add}, {a, b}); 40 | ir.set_inputs({r0, r1}); 41 | ir.set_priority(r1, 10); 42 | LoopTree lt(ir); 43 | std::cerr << "dumping:\n"; 44 | std::cerr << lt.dump(); 45 | } 46 | 47 | TEST(NegativeSizes) { 48 | IR ir; 49 | auto a = 
ir.create_var("a"); 50 | auto b = ir.create_var("b"); 51 | auto c = ir.create_var("c"); 52 | auto r0 = ir.create_node(Operation::read, {}, {a, b}); 53 | auto r1 = ir.create_node(Operation::read, {}, {b, c}); 54 | auto mul = ir.create_node(Operation::multiply, {r0, r1}, {a, b, c}); 55 | auto add = ir.create_node(Operation::add, {mul}, {a, c}); 56 | auto w = ir.create_node(Operation::write, {add}, {a, c}); 57 | ir.set_inputs({r0, r1}); 58 | ir.set_priority(r1, 10); 59 | ir.set_priority(r0, 100); 60 | ir.set_order(r1, {{b, {-1, 0}}, {c, {-1, 0}}}); 61 | LoopTree lt(ir); 62 | std::cerr << "dumping:\n"; 63 | std::cerr << lt.dump(); 64 | } 65 | 66 | TEST(BasicSchedule) { 67 | IR ir; 68 | constexpr int M = 16; 69 | constexpr int N = 16; 70 | constexpr int K = 16; 71 | auto m = ir.create_var("m"); 72 | auto n = ir.create_var("n"); 73 | auto k = ir.create_var("k"); 74 | 75 | auto r0 = ir.create_node(Operation::read, {}, {m, k}); 76 | auto r1 = ir.create_node(Operation::read, {}, {k, n}); 77 | 78 | auto mul = ir.create_node(Operation::multiply, {r1, r0}, {m, k, n}); 79 | auto add = ir.create_node(Operation::add, {mul}, {m, n}); 80 | 81 | auto w = ir.create_node(Operation::write, {add}, {m, n}); 82 | 83 | ir.set_order(r0, {{m, {M, 0}}, {k, {K, 0}}}); 84 | ir.set_order(r1, {{m, {M, 0}}, {n, {N, 0}}, {k, {K, 0}}}); 85 | ir.set_priority(r1, 10); 86 | ir.set_priority(r0, 0); 87 | ir.set_order(mul, {{m, {M, 0}}, {n, {N, 0}}, {k, {K, 0}}}); 88 | ir.set_order(add, {{m, {M, 0}}, {n, {N, 0}}, {k, {K, 0}}}); 89 | ir.set_order(w, {{m, {M, 0}}, {n, {N, 0}}}); 90 | ir.set_inputs({r0, r1}); 91 | ir.set_outputs({w}); 92 | LoopTree lt(ir); 93 | std::cerr << lt.dump(); 94 | float in0[M * K]; 95 | float in1[N * K]; 96 | float out[M * N]; 97 | rand(in0, M * K); 98 | rand(in1, N * K); 99 | auto cc = getDefaultBackend()->compile(lt); 100 | cc->run({in0, in1, out}); 101 | float out_ref[M * N]; 102 | ref_mm(in0, in1, M, N, K, out_ref); 103 | float max_diff = 0; 104 | for (auto i = 0; i < M * N; 
++i) { 105 | max_diff = std::max(max_diff, std::abs((float)(out_ref[i] - out[i]))); 106 | ASSERT(max_diff < 0.01) 107 | << "diff is " << max_diff << " at index " << i << " (" << out[i] 108 | << " vs ref " << out_ref[i] << ")"; 109 | } 110 | std::cout << "max diff " << max_diff << "\n"; 111 | } 112 | 113 | TEST(NodeSplit) { 114 | IR ir; 115 | constexpr int N = 16; 116 | auto a = ir.create_var("a"); 117 | auto b = ir.create_var("b"); 118 | auto r = ir.create_node(Operation::read, {}, {a, b}); 119 | auto add = ir.create_node(Operation::add, {r}, {}); 120 | auto w = ir.create_node(Operation::write, {add}, {}); 121 | ir.set_inputs({r}); 122 | ir.set_outputs({w}); 123 | ir = split_node(ir, add, {b}); 124 | std::cout << dot(ir) << "\n"; 125 | 126 | for (auto n : ir.nodes()) { 127 | std::vector> sched; 128 | for (auto v : ir.loop_vars(n)) { 129 | sched.emplace_back(std::pair{v, {N, 0}}); 130 | } 131 | ir.set_order(n, sched); 132 | } 133 | 134 | auto lt = LoopTree(ir); 135 | lt.walk([&](LoopTree::TreeRef ref, int) { 136 | if (is_trivially_parallel(lt, ref)) { 137 | annotate(lt, ref, "parallel"); 138 | } 139 | }); 140 | std::cout << lt.dump() << "\n"; 141 | 142 | auto cc = getDefaultBackend()->compile(lt); 143 | std::vector input(N * N); 144 | float ref = 0; 145 | for (auto i = 0; i < N * N; ++i) { 146 | input[i] = i * 3; 147 | ref += i * 3; 148 | } 149 | std::vector output(1); 150 | cc->run({input.data(), output.data()}, true); 151 | std::cout << "sum of vals from 0 to " << (N * N - 1) << " is " << output[0] 152 | << "\n"; 153 | ASSERT(std::abs(ref - output[0]) < 0.01) 154 | << "expected " << ref << " but got " << output[0]; 155 | } 156 | 157 | TEST(BasicInterpreter) { 158 | IR ir; 159 | constexpr int N = 405; 160 | auto a = ir.create_var("a"); 161 | auto r = ir.create_node(Operation::read, {}, {a}); 162 | auto add = ir.create_node(Operation::add, {r}, {a}); 163 | auto w = ir.create_node(Operation::write, {add}, {a}); 164 | ir.set_inputs({r}); 165 | ir.set_outputs({w}); 166 
| 167 | for (auto n : ir.nodes()) { 168 | std::vector> sched; 169 | // for (auto v : ir.loop_vars(n)) { 170 | sched.emplace_back(std::pair{a, {10, 15}}); 171 | sched.emplace_back(std::pair{a, {4, 3}}); 172 | sched.emplace_back(std::pair{a, {4, 1}}); 173 | sched.emplace_back(std::pair{a, {2, 0}}); 174 | //} 175 | ir.set_order(n, sched); 176 | } 177 | 178 | auto lt = LoopTree(ir); 179 | std::cout << lt.dump() << "\n"; 180 | auto cc = getDefaultBackend()->compile(lt); 181 | std::vector input(N); 182 | for (auto i = 0; i < N; ++i) { 183 | input[i] = i * 3; 184 | } 185 | std::vector output(N); 186 | // cc->run({input.data(), output.data()}, true); 187 | } 188 | -------------------------------------------------------------------------------- /test/test_lazy.py: -------------------------------------------------------------------------------- 1 | import loop_tool_py as lt 2 | import numpy as np 3 | import time 4 | 5 | 6 | backend = "cpu" 7 | if "cuda" in lt.backends(): 8 | backend = "cuda" 9 | lt.set_default_hardware("cuda") 10 | lt.set_default_backend("cuda") 11 | 12 | 13 | m, n, k = 8, 8, 8 14 | A = lt.Tensor(m, k).set(np.random.randn(m, k)) 15 | B = lt.Tensor(k, n).set(np.random.randn(k, n)) 16 | 17 | 18 | def mm(A, B): 19 | N = lt.Symbol("N") 20 | M = lt.Symbol("M") 21 | K = lt.Symbol("K") 22 | C = A.to(M, K) * B.to(K, N) 23 | return C.sum(K) 24 | 25 | 26 | C = mm(A, B) 27 | if backend == "cuda": 28 | print(C.compiled.code) 29 | C_ref = A.numpy() @ B.numpy() 30 | assert np.allclose(C.numpy(), C_ref, atol=0.0001, rtol=0.0001) 31 | 32 | 33 | def conv(X, W): 34 | with lt.SymbolGenerator() as s: 35 | return (X[s.No + s.K] * W.to(s.K)).sum(s.K) 36 | 37 | 38 | X = lt.Tensor(128).set(np.random.randn(128)) 39 | W = lt.Tensor(3).set(np.random.randn(3)) 40 | 41 | Y = conv(X, W) 42 | Y_ref = np.correlate(X.numpy(), W.numpy(), mode="valid") 43 | assert np.allclose(Y.numpy(), Y_ref, atol=0.0001, rtol=0.0001) 44 | 45 | # unbound sizes can be inferred 46 | 47 | N = lt.Symbol("N") 48 
| K = lt.Symbol("K") 49 | 50 | X = lt.Tensor(N) 51 | W0 = lt.Tensor(3) 52 | W1 = lt.Tensor(8) 53 | 54 | Y = conv(X, W0) 55 | Z = (Y.to(K) * W1.to(K)).sum(K) 56 | 57 | Z.unify() 58 | assert X.shape[0] == 10 59 | Z.compile() 60 | 61 | # we can override the schedule 62 | def schedule(ir): 63 | # print(ir) 64 | return ir 65 | 66 | 67 | Z.set(schedule(Z.ir)) 68 | print(Z.loop_tree) 69 | 70 | W0.set(np.ones(3)) 71 | W1.set(np.ones(8)) 72 | X.set(np.ones(10)) 73 | 74 | assert Z.numpy() == 24 75 | 76 | L = 1024 * 128 77 | 78 | X = lt.Tensor(L) 79 | Y = lt.Tensor(L) 80 | X.set(np.random.randn(L)) 81 | Y.set(np.random.randn(L)) 82 | 83 | N = lt.Symbol("N") 84 | Z = X.to(N) + Y.to(N) 85 | print(Z.loop_tree) 86 | 87 | assert np.allclose(Z.numpy(), X.numpy() + Y.numpy(), atol=0.0001, rtol=0.0001) 88 | 89 | 90 | def bench(loop_tree, warmup, iters): 91 | X = lt.Tensor(L) 92 | Y = lt.Tensor(L) 93 | X.set(np.random.randn(L)) 94 | Y.set(np.random.randn(L)) 95 | N = lt.Symbol("N") 96 | Z = X.to(N) + Y.to(N) 97 | Z.set(loop_tree) 98 | 99 | for i in range(warmup): 100 | Z = X.to(N) + Y.to(N) 101 | Z.resolve() 102 | t1 = time.time() 103 | for i in range(iters): 104 | Z = X.to(N) + Y.to(N) 105 | Z.resolve() 106 | t2 = time.time() 107 | print(f"{iters / (t2 - t1):.2f} iters/sec") 108 | 109 | 110 | print(Z.loop_tree) 111 | bench(Z.loop_tree, 10, 1000) 112 | 113 | 114 | def split(loop, inner_size): 115 | assert loop.tail == 0 116 | s = loop.size // inner_size 117 | t = loop.size % inner_size 118 | return [(loop.var, (s, t)), (loop.var, (inner_size, 0))] 119 | 120 | 121 | loop_tree = Z.loop_tree 122 | ir = loop_tree.ir 123 | 124 | for l in loop_tree.loops: 125 | if loop_tree.trivially_parallel(l): 126 | loop = loop_tree.loop(l) 127 | for n in ir.nodes: 128 | ir.set_order(n, split(loop, L // 2)) 129 | 130 | loop_tree = lt.LoopTree(ir) 131 | # parallelize the outermost loop 132 | loop_tree.annotate(loop_tree.loops[0], "parallel") 133 | 134 | Z.set(loop_tree) 135 | 136 | print(Z.loop_tree) 137 
| bench(loop_tree, 10, 1000) 138 | -------------------------------------------------------------------------------- /test/test_ln.py: -------------------------------------------------------------------------------- 1 | import loop_tool as lt 2 | import numpy as np 3 | 4 | M = 128 * 2 5 | N = 128 * 2 6 | K = 128 * 2 7 | 8 | A_np = np.random.randn(M, K) 9 | B_np = np.random.randn(K, N) 10 | 11 | def mm(a, b): 12 | m, n, k = lt.symbols("m n k") 13 | return (a.to(m, k) * b.to(k, n)).sum(k) 14 | 15 | A = lt.Tensor(A_np) 16 | B = lt.Tensor(B_np) 17 | 18 | with lt.Backend("loop_nest"): 19 | C = mm(A, B) 20 | lt.ui(C) 21 | tree = C.loop_tree 22 | 23 | #tree = tree.annotate(tree.loops[0], "[loop_nest]") 24 | #print(hash(C)) 25 | #print(C.numpy()) 26 | #print(C.numpy()) 27 | C.set(tree) 28 | print(C.code) 29 | #C = mm(A, B) 30 | #C.clear_cache() 31 | #print(C.numpy()) 32 | -------------------------------------------------------------------------------- /test/test_mutate.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #include 8 | #include 9 | 10 | #include "test_utils.h" 11 | 12 | TEST(MutateSplit) { 13 | namespace lz = ::loop_tool::lazy; 14 | auto mm = [](lz::Tensor A, lz::Tensor B) { 15 | auto M = lz::Symbol("m"), N = lz::Symbol("n"), K = lz::Symbol("k"); 16 | auto C = A.as(M, K) * B.as(K, N); 17 | return C.sum(K); 18 | }; 19 | 20 | lz::Tensor A(16, 16); 21 | lz::Tensor B(16, 17); 22 | auto C = mm(A, B); 23 | auto lt = C.loop_tree(); 24 | std::cerr << "presplit:\n"; 25 | std::cerr << lt.dump(); 26 | std::cerr << '\n'; 27 | auto r = lt.children(lt.children(lt.roots.at(0)).at(0)).at(0); 28 | lt = split(lt, r, 10); 29 | std::cerr << '\n'; 30 | std::cerr << lt.dump(); 31 | } 32 | 33 | TEST(MutateMerge) { 34 | namespace lz = ::loop_tool::lazy; 35 | auto N = lz::Symbol("n"); 36 | lz::Tensor A(16); 37 | lz::Tensor B(16); 38 | auto C = A + B; 39 | auto lt = C.loop_tree(); 40 | std::cerr << "presplit:\n"; 41 | std::cerr << lt.dump(); 42 | std::cerr << '\n'; 43 | auto r = lt.children(lt.roots.at(0)).at(0); 44 | lt = split(lt, r, 10); 45 | std::cerr << '\n'; 46 | std::cerr << lt.dump(); 47 | auto c = lt.children(lt.children(lt.roots.at(0)).at(0)).at(0); 48 | lt = merge(lt, c); 49 | std::cerr << '\n'; 50 | std::cerr << "postmerge:\n"; 51 | std::cerr << lt.dump(); 52 | } 53 | 54 | TEST(MutateSwap) { 55 | namespace lz = ::loop_tool::lazy; 56 | auto mm = [](lz::Tensor A, lz::Tensor B) { 57 | auto M = lz::Symbol("m"), N = lz::Symbol("n"), K = lz::Symbol("k"); 58 | auto C = A.as(M, K) * B.as(K, N); 59 | return C.sum(K); 60 | }; 61 | 62 | lz::Tensor A(16, 16); 63 | lz::Tensor B(16, 17); 64 | auto C = mm(A, B); 65 | auto lt = C.loop_tree(); 66 | std::cerr << '\n'; 67 | std::cerr << lt.dump(); 68 | std::cerr << '\n'; 69 | auto r = lt.children(lt.roots.at(0)).at(0); 70 | lt = split(lt, r, 10); 71 | 72 | auto a = lt.children(lt.children(lt.roots.at(0)).at(0)).at(0); 73 | auto b = lt.children(a).at(0); 74 | lt = swap_loops(lt, a, b); 75 | 76 | std::cerr << '\n'; 77 | std::cerr << 
lt.dump(); 78 | 79 | C.compile(); 80 | C.set(lt); 81 | std::cerr << C.code(); 82 | } 83 | 84 | TEST(MutateSubTree) { 85 | namespace lz = ::loop_tool::lazy; 86 | auto mm = [](lz::Tensor A, lz::Tensor B) { 87 | auto M = lz::Symbol("m"), N = lz::Symbol("n"), K = lz::Symbol("k"); 88 | auto C = A.as(M, K) * B.as(K, N); 89 | return C.sum(K); 90 | }; 91 | 92 | lz::Tensor A(16, 16); 93 | lz::Tensor B(16, 16); 94 | lz::Tensor C(16, 16); 95 | lz::Tensor D(16, 16); 96 | lz::Tensor E(16, 16); 97 | auto F = mm(A, B); 98 | auto G = mm(C, D); 99 | auto H = mm(E, F); 100 | auto I = mm(G, H); 101 | auto lt = I.loop_tree(); 102 | std::cerr << "old loop_tree" << lt.dump() << "\n"; 103 | lt = subtree(lt, lt.roots[1]); 104 | std::cerr << "new loop_tree: " << lt.dump() << "\n"; 105 | } 106 | -------------------------------------------------------------------------------- /test/test_nn.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include "test_utils.h" 14 | 15 | namespace lt = loop_tool; 16 | using namespace loop_tool::lazy; 17 | using namespace loop_tool::symbolic; 18 | using namespace loop_tool::testing; 19 | 20 | TEST(MNIST) { 21 | loop_tool::ScopedBackend sb("wasm"); 22 | auto conv = [](Tensor X, Tensor W, Tensor B, int stride, int padding) { 23 | auto inp_shape = X.shape(); 24 | if (padding > 0) { 25 | // H 26 | X = X.pad(inp_shape[1], padding); 27 | // W 28 | X = X.pad(inp_shape[2], padding); 29 | } 30 | auto w_shape = W.shape(); 31 | auto inp_padded_shape = X.shape(); 32 | auto oc = w_shape[0]; 33 | auto ic = w_shape[1]; 34 | auto kh = w_shape[2]; 35 | auto kw = w_shape[3]; 36 | auto ih = inp_padded_shape[1]; 37 | auto iw = inp_padded_shape[2]; 38 | X = X.as(ic, ih, iw); 39 | auto Y = lt::nn::convolve(X, W, {ih, iw}, {kh, kw}, stride); 40 | Y = Y.transpose({2, 0, 1}); 41 | if (B.shape().size()) { 42 | Y = Y + B.as(Y.shape().at(0)); 43 | } 44 | return Y; 45 | }; 46 | 47 | auto maxp = [](Tensor X, int k, int stride) { 48 | auto s = X.shape(); 49 | return lt::nn::maxpool(X, {s[1], s[2]}, k, stride); 50 | }; 51 | for (auto i = 0; i < 500; ++i) { 52 | Tensor X(1, 28, 28); 53 | Tensor W0(16, 1, 5, 5); 54 | Tensor B0(16); 55 | Tensor W1(32, 16, 5, 5); 56 | Tensor B1(32); 57 | X = conv(X, W0, B0, 1, 2); 58 | X = maxp(X, 2, 2); 59 | X = conv(X, W1, B1, 1, 2); 60 | X = maxp(X, 2, 2); 61 | //(void)X.sizes()[0]; 62 | X.compile(); 63 | X.clear_cache(); 64 | std::cerr << "hash is " << X.hash() << "\n"; 65 | } 66 | }; 67 | -------------------------------------------------------------------------------- /test/test_serialization.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | namespace lt = loop_tool; 12 | using namespace loop_tool::testing; 13 | 14 | TEST(SerializationBasic) { 15 | namespace lz = ::loop_tool::lazy; 16 | auto N = lz::Symbol("N"); 17 | auto size = 137; 18 | lz::Tensor A(size); 19 | lz::Tensor B(size); 20 | auto C = A.as(N) + B.as(N); 21 | rand(A.data(), size); 22 | rand(B.data(), size); 23 | const auto& ir = C.ir(); 24 | std::cerr << dot(ir) << "\n"; 25 | auto s_ir = lt::serialize(ir); 26 | std::cerr << s_ir << "\n"; 27 | const auto& ir_d = lt::deserialize(s_ir); 28 | C.set(ir_d); 29 | std::cerr << dot(ir_d) << "\n"; 30 | float max_diff = 0; 31 | for (auto i = 0; i < size; ++i) { 32 | auto ref = A.data()[i] + B.data()[i]; 33 | auto diff = std::abs(C.data()[i] - ref); 34 | max_diff = std::max(max_diff, diff); 35 | } 36 | ASSERT(max_diff < 0.01) << "got diff of " << max_diff; 37 | } 38 | 39 | TEST(SerializationScheduled) { 40 | namespace lz = ::loop_tool::lazy; 41 | auto N = lz::Symbol("N"); 42 | auto size = 138; 43 | lz::Tensor A(size); 44 | lz::Tensor B(size); 45 | auto C = A.as(N) + B.as(N); 46 | rand(A.data(), size); 47 | rand(B.data(), size); 48 | auto ir = C.ir(); 49 | auto v = ir.vars().at(0); 50 | for (auto n : ir.nodes()) { 51 | switch (ir.node(n).op()) { 52 | case lt::Operation::read: 53 | case lt::Operation::write: 54 | continue; 55 | default: 56 | break; 57 | } 58 | ir.set_order(n, {{v, {27, 3}}, {v, {5, 0}}}); 59 | ir.annotate_loop(n, 1, "unroll"); 60 | ir.disable_reuse(n, 1); 61 | } 62 | auto dot_before = dot(ir); 63 | auto s_ir = lt::serialize(ir); 64 | const auto& ir_d = lt::deserialize(s_ir); 65 | auto dot_after = dot(ir_d); 66 | ASSERT(dot_before == dot_after); 67 | C.set(ir_d); 68 | float max_diff = 0; 69 | for (auto i = 0; i < size; ++i) { 70 | auto ref = A.data()[i] + B.data()[i]; 71 | auto diff = std::abs(C.data()[i] - ref); 72 | max_diff = std::max(max_diff, diff); 73 | } 74 | ASSERT(max_diff < 0.01) << "got diff of " << max_diff; 75 | } 
76 | 77 | TEST(SerializationDeletedNodes) { 78 | namespace lz = ::loop_tool::lazy; 79 | auto N = lz::Symbol("N"); 80 | auto size = 132; 81 | lz::Tensor A(size); 82 | lz::Tensor B(size); 83 | auto C = A.as(N) + B.as(N); 84 | rand(A.data(), size); 85 | rand(B.data(), size); 86 | auto ir = C.ir(); 87 | lt::LoopTree tree(ir); 88 | { 89 | auto c = tree.children(tree.roots[0]); 90 | auto write = c[1]; 91 | tree = copy_input(tree, write, 0); 92 | } 93 | { 94 | auto c = tree.children(tree.roots[0]); 95 | auto add = c[0]; 96 | tree = copy_input(tree, add, 0); 97 | } 98 | { 99 | auto c = tree.children(tree.roots[0]); 100 | auto copy = c[c.size() - 2]; 101 | tree = delete_copy(tree, copy); 102 | } 103 | { 104 | auto c = tree.children(tree.roots[0]); 105 | auto copy = c[0]; 106 | tree = delete_copy(tree, copy); 107 | } 108 | C.set(tree); 109 | ir = C.ir(); 110 | ir.reify_deletions(); 111 | auto dot_before = dot(ir); 112 | std::cerr << dot_before << "\n"; 113 | auto s_ir = lt::serialize(ir); 114 | const auto& ir_d = lt::deserialize(s_ir); 115 | auto dot_after = dot(ir_d); 116 | ASSERT(dot_before == dot_after); 117 | C.set(ir_d); 118 | float max_diff = 0; 119 | for (auto i = 0; i < size; ++i) { 120 | auto ref = A.data()[i] + B.data()[i]; 121 | auto diff = std::abs(C.data()[i] - ref); 122 | max_diff = std::max(max_diff, diff); 123 | } 124 | ASSERT(max_diff < 0.01) << "got diff of " << max_diff; 125 | } 126 | 127 | TEST(SerializationConv) { 128 | namespace lz = ::loop_tool::lazy; 129 | lz::Symbol N("N"), N_o("N_o"), K("K"); 130 | lz::Tensor A(N); 131 | lz::Tensor W(K); 132 | lz::Tensor X = A.to({N_o, K}, lz::Constraint(N, lz::Expr(2) * N_o + K)); 133 | auto Y = (X * W).sum(K); 134 | Y.bind(nullptr, {8}); // we can infer the size of A from this 135 | W.bind(nullptr, {3}); 136 | float A_data[17] = {0}; 137 | A.bind(A_data, {17}); 138 | for (auto i = 0; i < 10; ++i) { 139 | A.data()[i] = 1; 140 | } 141 | for (auto i = 0; i < 3; ++i) { 142 | W.data()[i] = 1; 143 | } 144 | 145 | auto 
dot_before = dot(Y.ir()); 146 | std::cerr << dot_before; 147 | auto ir = Y.ir(); 148 | auto s = lt::serialize(ir); 149 | std::cerr << s << "\n"; 150 | auto ir_d = lt::deserialize(s); 151 | auto dot_after = dot(ir_d); 152 | std::cerr << dot_after << "\n"; 153 | ASSERT(dot_before == dot_after); 154 | Y.set(ir_d); 155 | ASSERT(Y.data()[3] == 3); 156 | } 157 | -------------------------------------------------------------------------------- /test/test_symbolic.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | using namespace loop_tool; 12 | using namespace loop_tool::testing; 13 | 14 | TEST(SymbolicReplace) { 15 | namespace lz = loop_tool::lazy; 16 | lz::Symbol A("A"); 17 | lz::Symbol B("B"); 18 | lz::Expr C = A + B; 19 | std::cerr << C.dump() << "\n"; 20 | auto D = C.replace(A, 4); 21 | std::cerr << D.dump() << "\n"; 22 | D = D.replace(B, 4); 23 | auto E = D.simplify(); 24 | std::cerr << E.dump() << "\n"; 25 | } 26 | 27 | TEST(SymbolicBasic) { 28 | namespace lz = loop_tool::lazy; 29 | std::vector constraints; 30 | lz::Symbol A("A"); 31 | lz::Symbol B("B"); 32 | lz::Symbol C("C"); 33 | lz::Symbol D("D"); 34 | constraints.emplace_back( 35 | std::make_pair(lz::Expr::size(C), lz::Expr::size(B) * lz::Expr(9))); 36 | constraints.emplace_back(std::make_pair(lz::Expr::size(A), lz::Expr(8))); 37 | constraints.emplace_back( 38 | std::make_pair(lz::Expr::size(B), lz::Expr::size(A) + lz::Expr(2))); 39 | constraints.emplace_back( 40 | std::make_pair(lz::Expr::size(D), lz::Expr::size(A) + lz::Expr::size(C))); 41 | 42 | auto out = unify(constraints); 43 | for (auto p : out) { 44 | std::cout << p.first.dump() << " = " << p.second.dump() << "\n"; 45 | if (p.first == lz::Expr::size(D)) { 46 | 
ASSERT(p.second.value() == 98); 47 | } 48 | } 49 | } 50 | 51 | TEST(SymbolicUnbound) { 52 | namespace lz = loop_tool::lazy; 53 | std::vector constraints; 54 | lz::Symbol A("A"); 55 | lz::Symbol B("B"); 56 | lz::Symbol C("C"); 57 | lz::Symbol D("D"); 58 | constraints.emplace_back( 59 | std::make_pair(lz::Expr::size(C), lz::Expr(B) * lz::Expr(9))); 60 | constraints.emplace_back( 61 | std::make_pair(lz::Expr::size(B), lz::Expr(A) + lz::Expr(2))); 62 | constraints.emplace_back( 63 | std::make_pair(lz::Expr::size(D), lz::Expr(A) + lz::Expr(C))); 64 | 65 | auto out = unify(constraints); 66 | } 67 | 68 | TEST(SymbolicDerivative) { 69 | namespace lz = loop_tool::lazy; 70 | lz::Symbol N("N"), N_o("N_o"), K("K"); 71 | { 72 | auto d = loop_tool::symbolic::differentiate(N_o + K, N_o); 73 | ASSERT(d == lz::Expr(1)); 74 | } 75 | { 76 | auto d = loop_tool::symbolic::differentiate(N_o + K, N); 77 | ASSERT(d == lz::Expr(0)); 78 | } 79 | { 80 | auto d = loop_tool::symbolic::differentiate(lz::Expr(2) * N_o + K, N_o); 81 | ASSERT(d == lz::Expr(2)); 82 | } 83 | { 84 | auto d = 85 | loop_tool::symbolic::differentiate(lz::Expr(2) * N_o + K * N_o, N_o); 86 | ASSERT(d == lz::Expr(2) + K) << "found " << d.dump(); 87 | } 88 | { 89 | auto d = loop_tool::symbolic::differentiate(N + lz::Expr(2) * N_o + K * N_o, 90 | N_o); 91 | ASSERT(d == lz::Expr(2) + K) << "found " << d.dump(); 92 | } 93 | { 94 | auto d = loop_tool::symbolic::differentiate( 95 | N * (lz::Expr(2) * N_o + K * N_o) + N_o, N_o); 96 | ASSERT(d == N * (lz::Expr(2) + K) + lz::Expr(1)) << "found " << d.dump(); 97 | } 98 | } 99 | 100 | TEST(SymbolicPaddedConv) { 101 | namespace lz = loop_tool::lazy; 102 | auto xi = lz::Symbol("xi"); 103 | auto xp = lz::Symbol("xp"); 104 | auto xo = lz::Symbol("xo"); 105 | auto k = lz::Symbol("k"); 106 | std::vector constraints = { 107 | {xi + lz::Expr(k) / lz::Expr(2), xp}, // pad left 108 | {lz::Expr::size(xp), 109 | lz::Expr::size(xi) + lz::Expr::size(k) / lz::Expr(2)}, // pad right 110 | {xp, xo + 
k}, // conv 111 | {lz::Expr::size(xo), lz::Expr(100)}, 112 | {lz::Expr::size(k), lz::Expr(3)}, 113 | //{lz::Expr::size(xp), lz::Expr(102)}, 114 | {lz::Expr::size(xi), lz::Expr(100)}}; 115 | auto out = unify(constraints); 116 | } 117 | 118 | TEST(SymbolicConcat) { 119 | namespace lz = loop_tool::lazy; 120 | auto k = lz::Symbol("k"); 121 | auto j = lz::Symbol("j"); 122 | auto kj = lz::Symbol("kj"); 123 | std::vector constraints = { 124 | {kj, k}, 125 | {kj, j + lz::Expr::size(k)}, 126 | {lz::Expr::size(k), lz::Expr(10)}, 127 | {lz::Expr::size(j), lz::Expr(7)}, 128 | }; 129 | auto out = unify(constraints); 130 | for (auto& p : out) { 131 | if (p.first == lz::Expr::size(kj)) { 132 | ASSERT(p.second == lz::Expr(17)) 133 | << "found kj size to be " << p.second.dump(); 134 | } 135 | std::cerr << p.first.dump() << ": " << p.second.dump() << "\n"; 136 | } 137 | } 138 | 139 | TEST(SymbolicCanonicalization) {} 140 | 141 | TEST(SymbolicNewImpl) { 142 | using namespace loop_tool::symbolic; 143 | const auto& s = Symbol("X"); 144 | auto e0 = Expr(s); 145 | auto e1 = Expr(1234LL); 146 | auto e2 = e0 + e1; 147 | std::cerr << e2.args().size() << "\n"; 148 | std::cerr << e2.contains(s) << "\n"; 149 | std::cerr << e2.dump() << "\n"; 150 | std::cerr << e2.hash() << "\n"; 151 | std::cerr << e2.hash(true) << "\n"; 152 | e2.visit([&](const Expr& e) { std::cerr << e.dump() << "\n"; }); 153 | std::cerr << "NUM SYM " << e2.symbols().size() << "\n"; 154 | 155 | constexpr int N = 1; 156 | { 157 | auto start = std::chrono::steady_clock::now(); 158 | auto e = Expr(s); 159 | for (auto i = 0; i < N; ++i) { 160 | e = e + Expr(1234LL); 161 | e = Expr(1234LL) + e; 162 | } 163 | e = e.walk([](const Expr& e) { 164 | if (e.type() == Expr::Type::symbol) { 165 | return Expr(0); 166 | } 167 | return e; 168 | }); 169 | std::cerr << e.contains(s) << "\n"; 170 | auto end = std::chrono::steady_clock::now(); 171 | std::chrono::duration diff = end - start; 172 | std::cerr << "expr_: " << (1e6 * diff.count()) << 
"us\n"; 173 | } 174 | { 175 | auto start = std::chrono::steady_clock::now(); 176 | auto e = Expr(s); 177 | for (auto i = 0; i < N; ++i) { 178 | e = e + Expr(1234LL); 179 | e = Expr(1234LL) + e; 180 | } 181 | e = e.walk([](const Expr& e) { 182 | if (e.type() == Expr::Type::symbol) { 183 | return Expr(0LL); 184 | } 185 | return e; 186 | }); 187 | std::cerr << e.contains(s) << "\n"; 188 | auto end = std::chrono::steady_clock::now(); 189 | std::chrono::duration diff = end - start; 190 | std::cerr << "expr: " << (1e6 * diff.count()) << "us\n"; 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /test/test_ui.py: -------------------------------------------------------------------------------- 1 | import loop_tool as lt 2 | 3 | # import loop_tool_py.ui as ui 4 | import numpy as np 5 | 6 | 7 | def mm(A, B): 8 | s = lt.SymbolGenerator() 9 | C = A.to(s.m, s.k) * B.to(s.k, s.n) 10 | return C.sum(s.k) 11 | 12 | 13 | m, n, k = 128, 128, 128 # 8, 16, 128 14 | A = lt.Tensor(m, k).set(np.random.randn(m, k)) 15 | B = lt.Tensor(k, n).set(np.random.randn(k, n)) 16 | 17 | s = lt.SymbolGenerator() 18 | C = mm(A, B).to(s.m, s.n).sum(s.m) # * A.to(s.m, s.k) 19 | 20 | 21 | def conv(X, W): 22 | s = lt.SymbolGenerator() 23 | X = X.pad(X.symbolic_shape[1], 1) 24 | return (X[s.B, s.No + s.K] * W.to(s.B, s.K)).sum(s.K) 25 | 26 | 27 | X = lt.Tensor(256, 128).set(np.random.randn(256, 128)) 28 | W = lt.Tensor(256, 3).set(np.random.randn(256, 3)) 29 | 30 | C = conv(X, W) 31 | 32 | A = lt.Tensor(m, k).set(np.random.randn(m, k)) 33 | B = lt.Tensor(m, k).set(np.random.randn(m, k)) 34 | C = mm(A, B) 35 | 36 | lt.ui(C, "/tmp/woo.c") 37 | 38 | print(C.code) 39 | -------------------------------------------------------------------------------- /test/test_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace loop_tool { 8 | namespace testing { 
9 | 10 | struct Test { 11 | std::string file; 12 | std::string name; 13 | std::function fn; 14 | Test(std::string file_, std::string name_, std::function fn_) 15 | : file(file_), name(name_), fn(fn_) {} 16 | void operator()() const { fn(); } 17 | }; 18 | 19 | std::vector &getTestRegistry(); 20 | 21 | struct AddTest { 22 | AddTest(std::string file, std::string name, std::function fn); 23 | }; 24 | 25 | void runner(int argc, char *argv[]); 26 | void rand(float *data, int N); 27 | void ref_mm(const float *A, const float *B, int M, int N, int K, float *C, 28 | float alpha = 0); 29 | 30 | // input: NCHW weight: MCKhKw out: NM(H-K+2)(W-K+2) 31 | void ref_conv(const float *X, const float *W, int N, int M, int C, int HW, 32 | int K, float *Y); 33 | 34 | bool all_close(const float *A, const float *B, size_t N, float eps = 0.001); 35 | 36 | } // namespace testing 37 | } // namespace loop_tool 38 | 39 | #define TEST(name) \ 40 | void _loop_tool_test_##name(); \ 41 | static loop_tool::testing::AddTest _loop_tool_test_add_##name( \ 42 | __FILE__, #name, _loop_tool_test_##name); \ 43 | void _loop_tool_test_##name() 44 | 45 | #define RUN_TESTS(argc, argv) loop_tool::testing::runner(argc, argv); 46 | -------------------------------------------------------------------------------- /test/test_views.py: -------------------------------------------------------------------------------- 1 | import loop_tool_py as lt 2 | import numpy as np 3 | 4 | 5 | def test_pad(): 6 | m, n = 128, 16 7 | 8 | base_np = np.random.randn(m, n) 9 | padded_np = np.pad(base_np, [(0,), (3,)]) 10 | 11 | base_lt = lt.Tensor(m, n).set(base_np) 12 | padded_lt = base_lt.pad(base_lt.symbolic_shape[1], 3) 13 | 14 | assert np.allclose(padded_lt.numpy(), padded_np) 15 | 16 | 17 | def test_concat(): 18 | m, n, k = 128, 16, 5 19 | 20 | A_np = np.random.randn(m, n) 21 | B_np = np.random.randn(m, k) 22 | C_np = np.concatenate((A_np, B_np), axis=1) 23 | 24 | A_lt = lt.Tensor(m, n).set(A_np) 25 | B_lt = lt.Tensor(m, 
k).set(B_np) 26 | 27 | with lt.SymbolGenerator() as s: 28 | C_lt = A_lt.to(s.m, s.n) | B_lt.to(s.m, s.k) 29 | 30 | assert np.allclose(C_lt.numpy(), C_np) 31 | 32 | 33 | def test_2d_conv(): 34 | import torch 35 | import torch.nn.functional as F 36 | 37 | def conv2d(X, W): 38 | s = lt.SymbolGenerator() 39 | X = X[s.C, s.H + s.Kh, s.W + s.Kw] 40 | W = W.to(s.Co, s.C, s.Kh, s.Kw) 41 | return (X * W).sum(s.C, s.Kh, s.Kw).transpose(s.Co, s.H, s.W) 42 | 43 | ci = 16 44 | co = 16 45 | x = 8 46 | k = 3 47 | X_np = np.random.randn(ci, x, x) 48 | W_np = np.random.randn(co, ci, k, k) 49 | Y_np = F.conv2d(torch.tensor(X_np).unsqueeze(0), torch.tensor(W_np)).numpy() 50 | 51 | X_lt = lt.Tensor(ci, x, x).set(X_np) 52 | W_lt = lt.Tensor(co, ci, k, k).set(W_np) 53 | Y_lt = conv2d(X_lt, W_lt) 54 | 55 | assert np.allclose(Y_lt.numpy(), Y_np, rtol=0.001, atol=0.001) 56 | 57 | 58 | def test_padded_2d_conv(): 59 | import torch 60 | import torch.nn.functional as F 61 | 62 | def conv2d(X, W): 63 | s = lt.SymbolGenerator() 64 | X = X.to(s.c, s.h, s.w).pad(s.h, 1).pad(s.w, 1) 65 | X = X[s.C, s.H + s.Kh, s.W + s.Kw] 66 | W = W.to(s.Co, s.C, s.Kh, s.Kw) 67 | return (X * W).sum(s.C, s.Kh, s.Kw).transpose(s.Co, s.H, s.W) 68 | 69 | ci = 16 70 | co = 16 71 | x = 8 72 | k = 3 73 | X_np = np.random.randn(ci, x, x) 74 | W_np = np.random.randn(co, ci, k, k) 75 | Y_np = F.conv2d( 76 | torch.tensor(X_np).unsqueeze(0), torch.tensor(W_np), padding=1 77 | ).numpy() 78 | 79 | X_lt = lt.Tensor(ci, x, x).set(X_np) 80 | W_lt = lt.Tensor(co, ci, k, k).set(W_np) 81 | Y_lt = conv2d(X_lt, W_lt) 82 | print(Y_lt.loop_tree) 83 | 84 | assert np.allclose(Y_lt.numpy(), Y_np, rtol=0.001, atol=0.001) 85 | 86 | 87 | def test_many_pad(): 88 | import string 89 | 90 | N = 5 91 | a = lt.Tensor(lt.Symbol("A")).set(np.random.randn(N)) 92 | X = a.symbolic_shape[0] 93 | Y = lt.Symbol("B") 94 | b = a.to( 95 | Y, constraints=[(Y, X + lt.Expr(1)), (lt.Size(Y), lt.Size(X) + lt.Expr(2))] 96 | ) 97 | for i in range(10): 98 | X = 
b.symbolic_shape[0] 99 | Y = lt.Symbol(string.ascii_uppercase[i + 2]) 100 | b = b.to( 101 | Y, constraints=[(Y, X + lt.Expr(1)), (lt.Size(Y), lt.Size(X) + lt.Expr(2))] 102 | ) 103 | ir = b.ir 104 | vs = ir.vars 105 | print(vs) 106 | for n in ir.nodes: 107 | if "write" in ir.dump(n): 108 | print(ir.dump(n)) 109 | ir.set_order(n, [(vs[0], (N, 0))]) 110 | # b.set(ir) 111 | print(b.loop_tree) 112 | print(b.code) 113 | print(b.numpy()) 114 | 115 | 116 | def test_many_conv(): 117 | import string 118 | 119 | N = 11 120 | # a = lt.Tensor(lt.Symbol("A")).set(np.arange(N)) 121 | a = lt.Tensor(lt.Symbol("A")).set(np.ones(N)) 122 | X = a.symbolic_shape[0] 123 | Y = lt.Symbol("B") 124 | Z = lt.Symbol("C") 125 | cur_syms = [Y, Z] 126 | b = a.to(*cur_syms, constraints=[(X, Y + Z), (lt.Size(Z), lt.Expr(3))]) 127 | for i in range(2): 128 | X = b.symbolic_shape[0] 129 | Y = lt.Symbol(string.ascii_uppercase[i * 2 + 3]) 130 | Z = lt.Symbol(string.ascii_uppercase[i * 2 + 4]) 131 | cur_syms = [Y, Z] + cur_syms[1:] 132 | # print("setting size of ", Y, Z, b.shape) 133 | b = b.to(*cur_syms, constraints=[(X, Y + Z), (lt.Size(Z), lt.Expr(3))]) 134 | print(b.ir) 135 | print(b.loop_tree) 136 | print(b.shape) 137 | print(b.numpy()) 138 | print(b.code) 139 | 140 | 141 | # test_pad() 142 | # test_concat() 143 | # test_2d_conv() 144 | # test_padded_2d_conv() 145 | test_many_pad() 146 | # test_many_conv() 147 | -------------------------------------------------------------------------------- /test/utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | namespace loop_tool { 12 | namespace testing { 13 | 14 | AddTest::AddTest(std::string file, std::string name, std::function fn) { 15 | getTestRegistry().emplace_back(file, name, fn); 16 | } 17 | 18 | std::vector& getTestRegistry() { 19 | static std::vector tests_; 20 | return tests_; 21 | } 22 
| 23 | #define ANSI_GREEN "\033[32m" 24 | #define ANSI_RED "\033[31m" 25 | #define ANSI_RESET "\033[39m" 26 | 27 | void runner(int argc, char* argv[]) { 28 | // unsigned int microseconds = 1000; 29 | // usleep(microseconds); 30 | bool verbose = false; 31 | bool strict = false; 32 | std::regex filter(".*"); 33 | for (auto i = 0; i < argc; ++i) { 34 | auto arg = std::string(argv[i]); 35 | if (arg == "--verbose" || arg == "-v") { 36 | verbose = true; 37 | } 38 | if (arg == "--strict" || arg == "-f") { 39 | strict = true; 40 | } 41 | if (arg == "-fv" || arg == "-vf") { 42 | strict = true; 43 | verbose = true; 44 | } 45 | if (arg == "--filter") { 46 | if (argc <= i + 1) { 47 | std::cerr << "no argument found for --filter\n"; 48 | return; 49 | } 50 | arg = argv[++i]; 51 | filter = std::regex(arg); 52 | } 53 | } 54 | std::stringstream stdout_buffer; 55 | std::stringstream stderr_buffer; 56 | std::streambuf* old_stdout; 57 | std::streambuf* old_stderr; 58 | auto hide_output = [&]() { 59 | stdout_buffer.str(""); 60 | stderr_buffer.str(""); 61 | if (!verbose) { 62 | old_stdout = std::cout.rdbuf(stdout_buffer.rdbuf()); 63 | old_stderr = std::cerr.rdbuf(stderr_buffer.rdbuf()); 64 | } 65 | }; 66 | auto restore_output = [&]() { 67 | if (!verbose) { 68 | std::cout.rdbuf(old_stdout); 69 | std::cerr.rdbuf(old_stderr); 70 | } 71 | }; 72 | 73 | auto tests = getTestRegistry(); 74 | std::sort(tests.begin(), tests.end(), [](const Test& a, const Test& b) { 75 | return a.file.compare(b.file) < 0; 76 | }); 77 | tests.erase(std::remove_if(tests.begin(), tests.end(), 78 | [&](const Test& test) { 79 | std::string q = test.file + " " + test.name; 80 | return !std::regex_search(q, filter); 81 | }), 82 | tests.end()); 83 | std::string curr_file = ""; 84 | size_t passed = 0; 85 | 86 | for (const auto& test : tests) { 87 | if (test.file != curr_file) { 88 | std::cout << "running tests in " << test.file << "\n"; 89 | ; 90 | curr_file = test.file; 91 | } 92 | 93 | std::cout << " - " << test.name << 
" ... "; 94 | 95 | try { 96 | hide_output(); 97 | test(); 98 | restore_output(); 99 | std::cout << ANSI_GREEN << "passed" << ANSI_RESET << ".\n"; 100 | passed++; 101 | } catch (const std::exception& e) { 102 | restore_output(); 103 | std::cout << ANSI_RED << "failed" << ANSI_RESET << ".\n"; 104 | if (strict) { 105 | throw; 106 | } else { 107 | if (stdout_buffer.str().size()) { 108 | std::cout << "==== stdout for failed test \"" << test.name 109 | << "\" ====\n"; 110 | std::cout << stdout_buffer.str(); 111 | std::cout << "\n[ run tests with -f flag to throw ]\n"; 112 | } 113 | if (stderr_buffer.str().size()) { 114 | std::cerr << "==== stderr for failed test \"" << test.name 115 | << "\" ====\n"; 116 | std::cerr << stderr_buffer.str(); 117 | std::cerr << "\n[ run tests with -f flag to throw ]\n"; 118 | } 119 | } 120 | } 121 | } 122 | std::cout << "[" << passed << "/" << tests.size() << " tests passed]\n"; 123 | } 124 | 125 | void rand(float* data, int N) { 126 | std::random_device rd; 127 | std::mt19937 e2(rd()); 128 | std::normal_distribution<> dist(2, 2); 129 | for (auto i = 0; i < N; ++i) { 130 | data[i] = dist(e2); 131 | } 132 | } 133 | 134 | // assumes LCA=K, LCB=N 135 | void ref_mm(const float* A, const float* B, int M, int N, int K, float* C, 136 | float alpha) { 137 | for (auto n = 0; n < N; ++n) { 138 | for (auto m = 0; m < M; ++m) { 139 | float tmp = 0; 140 | for (auto k = 0; k < K; ++k) { 141 | tmp += A[m * K + k] * B[k * N + n]; 142 | } 143 | C[m * N + n] = (alpha ? 
(alpha * C[m * N + n]) : 0) + tmp; 144 | } 145 | } 146 | } 147 | 148 | void ref_conv(const float* X, const float* W, int N, int M, int C, int HW, 149 | int K, float* Y) { 150 | const auto HWO = HW - K + 1; 151 | for (auto i = 0; i < N * M * HWO * HWO; ++i) { 152 | Y[i] = 0; 153 | } 154 | for (auto n = 0; n < N; ++n) { 155 | for (auto m = 0; m < M; ++m) { 156 | for (auto c = 0; c < C; ++c) { 157 | for (auto h = 0; h < HWO; ++h) { 158 | for (auto w = 0; w < HWO; ++w) { 159 | for (auto kh = 0; kh < K; ++kh) { 160 | for (auto kw = 0; kw < K; ++kw) { 161 | Y[n * M * HWO * HWO + m * HWO * HWO + h * HWO + w] += 162 | X[(n)*HW * HW * C + (c)*HW * HW + (h + kh) * HW + 163 | (w + kw)] * 164 | W[m * C * K * K + c * K * K + kh * K + kw]; 165 | } 166 | } 167 | } 168 | } 169 | } 170 | } 171 | } 172 | } 173 | 174 | bool all_close(const float* A, const float* B, size_t N, float eps) { 175 | float max_diff = 0; 176 | float min_val = std::numeric_limits::max(); 177 | for (size_t i = 0; i < N; ++i) { 178 | max_diff = std::max(std::abs(A[i] - B[i]), max_diff); 179 | min_val = std::min(std::abs(A[i]), min_val); 180 | min_val = std::min(std::abs(B[i]), min_val); 181 | } 182 | std::cerr << "max diff " << max_diff << " vs min val " << min_val 183 | << " (eps: " << eps << ")\n"; 184 | return max_diff < std::max(eps * min_val, eps); 185 | } 186 | 187 | } // namespace testing 188 | } // namespace loop_tool 189 | -------------------------------------------------------------------------------- /test/wasm_runtime_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | */ 7 | #include 8 | 9 | #include "test_utils.h" 10 | 11 | using namespace loop_tool::testing; 12 | 13 | TEST(WasmBackend) { 14 | loop_tool::ScopedBackend sb("wasm"); 15 | namespace lz = ::loop_tool::lazy; 16 | auto mm = [](lz::Tensor A, lz::Tensor B) { 17 | auto M = lz::Symbol("M"); 18 | auto N = lz::Symbol("N"); 19 | auto K = lz::Symbol("K"); 20 | auto C = A.as(M, K) * B.as(K, N); 21 | return C.sum(K); 22 | }; 23 | 24 | auto M = 16; 25 | auto N = 16; 26 | auto K = 16; 27 | 28 | lz::Tensor A(M, K); 29 | lz::Tensor B(K, N); 30 | for (auto i = 0; i < M * K; ++i) { 31 | A.data()[i] = 1; 32 | B.data()[i] = 2; 33 | } 34 | auto C = mm(A, B); 35 | auto d = C.data(); 36 | for (auto i = 0; i < M * N; ++i) { 37 | std::cerr << d[i] << " "; 38 | } 39 | std::cerr << "\n"; 40 | C.clear_cache(); 41 | } 42 | 43 | TEST(WasmMM) { 44 | loop_tool::ScopedBackend sb("wasm"); 45 | namespace lz = ::loop_tool::lazy; 46 | auto mm = [](lz::Tensor A, lz::Tensor B) { 47 | auto M = lz::Symbol("M"); 48 | auto N = lz::Symbol("N"); 49 | auto K = lz::Symbol("K"); 50 | auto C = A.as(M, K) * B.as(K, N); 51 | return C.sum(K); 52 | }; 53 | 54 | auto M = 16; 55 | auto N = 16; 56 | auto K = 16; 57 | 58 | lz::Tensor A(M, K); 59 | lz::Tensor B(K, N); 60 | rand(A.data(), M * K); 61 | rand(B.data(), K * N); 62 | 63 | auto C = mm(A, B); 64 | lz::Tensor C_ref(M * N); 65 | ref_mm(A.data(), B.data(), M, N, K, C_ref.data()); 66 | 67 | ASSERT(all_close(C_ref.data(), C.data(), M * N)); 68 | C.clear_cache(); 69 | } 70 | 71 | TEST(WasmConv) { 72 | loop_tool::ScopedBackend sb("wasm"); 73 | namespace lz = ::loop_tool::lazy; 74 | 75 | auto conv = [](lz::Tensor X, lz::Tensor w) { 76 | lz::Symbol N("N"), M("M"), C("C"), H("H"), Ho("Ho"), W("W"), Wo("Wo"), 77 | Kh("Kh"), Kw("Kw"); 78 | X = X.as(N, C, H, W); 79 | w = w.as(M, C, Kh, Kw); 80 | auto X_im2col = X.to({N, C, Ho, Kh, Wo, Kw}, lz::Constraint(H, Ho + Kh), 81 | lz::Constraint(W, Wo + Kw)); 82 | auto Y = (X_im2col * w).sum(Kh, Kw, C); 83 | return Y.transpose({N, 
M, Ho, Wo}); 84 | }; 85 | 86 | auto N = 4; 87 | auto M = 64; 88 | auto C = 64; 89 | auto HW = 8; 90 | auto K = 3; 91 | auto HWo = HW - K + 1; 92 | 93 | lz::Tensor A(N, C, HW, HW); 94 | lz::Tensor B(M, C, K, K); 95 | rand(A.data(), A.numel()); 96 | rand(B.data(), B.numel()); 97 | 98 | auto C_lt = conv(A, B); 99 | std::cerr << C_lt.numel() << " vs " << (N * M * HWo * HWo) << "\n"; 100 | ASSERT(C_lt.numel() == N * M * HWo * HWo); 101 | lz::Tensor C_ref(C_lt.numel()); 102 | ref_conv(A.data(), B.data(), N, M, C, HW, K, 103 | C_ref.data()); 104 | 105 | ASSERT(all_close(C_ref.data(), C_lt.data(), C_lt.numel())); 106 | } 107 | 108 | TEST(WasmConcat1D) { 109 | loop_tool::ScopedBackend sb("wasm"); 110 | namespace lz = ::loop_tool::lazy; 111 | lz::Symbol N("N"), M("M"), NM("NM"); 112 | lz::Tensor A(N); 113 | lz::Tensor B(M); 114 | A.bind(nullptr, {8}); 115 | B.bind(nullptr, {5}); 116 | auto A_ = A.to({NM}, lz::Constraint(NM, N), 117 | lz::Constraint(lz::Expr::size(NM), 118 | lz::Expr::size(N) + lz::Expr::size(M))); 119 | auto B_ = B.to({NM}, lz::Constraint(NM, M + lz::Expr::size(N))); 120 | auto C = A_ + B_; 121 | std::cerr << C.loop_tree().dump() << "\n"; 122 | std::cerr << C.code() << "\n"; 123 | std::cerr << "shape:\n"; 124 | for (auto s : C.shape()) { 125 | std::cerr << s.name() << "\n"; 126 | } 127 | ASSERT(C.size(0) == 13) << "size is " << C.size(0); 128 | for (auto i = 0; i < 8; ++i) { 129 | A.data()[i] = i; 130 | } 131 | for (auto i = 0; i < 5; ++i) { 132 | B.data()[i] = i; 133 | } 134 | for (auto i = 0; i < 13; ++i) { 135 | std::cerr << "C[" << i << "]: " << C.data()[i] << "\n"; 136 | } 137 | // ASSERT(C.data()[2] == 2); 138 | ASSERT(C.data()[10] == 2); 139 | auto D = A | B; 140 | ASSERT(D.data()[10] == C.data()[10]); 141 | C.clear_cache(); 142 | D.clear_cache(); 143 | } 144 | 145 | TEST(WasmConcat2D) { 146 | loop_tool::ScopedBackend sb("wasm"); 147 | namespace lz = ::loop_tool::lazy; 148 | int64_t batch = 2; 149 | lz::Symbol N("N"), M0("M0"), M1("M1"), M("M"); 150 | 
lz::Tensor A(N, M0); 151 | lz::Tensor B(N, M1); 152 | auto C = A | B; // different dimensions are concatenated 153 | A.bind(nullptr, {batch, 5}); 154 | B.bind(nullptr, {batch, 3}); 155 | ASSERT(C.shape()[0] == N); 156 | ASSERT(C.size(1) == 8); 157 | for (auto i = 0; i < batch * 5; ++i) { 158 | A.data()[i] = 11; // i; 159 | } 160 | for (auto i = 0; i < batch * 3; ++i) { 161 | B.data()[i] = 7; // i; 162 | } 163 | std::cerr << loop_tool::dot(C.ir()) << "\n"; 164 | std::cerr << C.code() << "\n"; 165 | std::cerr << "checking " << C.data()[0] << "\n"; 166 | for (auto i = 0; i < batch * 8; ++i) { 167 | std::cerr << C.data()[i] << "\n"; 168 | } 169 | ASSERT(C.data()[6] == 7); 170 | ASSERT(C.data()[8] == 11); 171 | C.clear_cache(); 172 | } 173 | 174 | TEST(WasmPad) { 175 | loop_tool::ScopedBackend sb("wasm"); 176 | namespace lz = ::loop_tool::lazy; 177 | lz::Symbol N("N"), Np("Np"); 178 | lz::Tensor X(N); 179 | // pads both sizes by 1 180 | lz::Tensor X_pad = 181 | X.to({Np}, lz::Constraint(Np, N + lz::Expr(1)), 182 | lz::Constraint(lz::Expr::size(Np), lz::Expr::size(N) + lz::Expr(2))); 183 | X.bind(nullptr, {5}); 184 | for (auto i = 0; i < 5; ++i) { 185 | X.data()[i] = i; 186 | } 187 | ASSERT(X_pad.size(0) == 7); 188 | for (auto i = 0; i < 7; ++i) { 189 | std::cerr << "XPAD " << i << ": " << X_pad.data()[i] << "\n"; 190 | } 191 | ASSERT(X_pad.data()[2] == 1); 192 | ASSERT(X_pad.data()[6] == 0); 193 | X_pad.clear_cache(); 194 | } 195 | 196 | TEST(WasmPaddedConv) { 197 | loop_tool::ScopedBackend sb("wasm"); 198 | namespace lz = ::loop_tool::lazy; 199 | lz::Symbol N("N"), Np("Np"), K("K"), No("No"); 200 | lz::Tensor X(N); 201 | lz::Tensor W(K); 202 | auto paddedX = 203 | X.to({Np}, lz::Constraint(Np, N + lz::Expr(1)), 204 | lz::Constraint(lz::Expr::size(Np), lz::Expr::size(N) + lz::Expr(2))); 205 | 206 | // implicit constraint -> Np = size(N) + size(K) - 1 207 | auto expandedX = paddedX.to({No, K}, lz::Constraint(Np, No + K)); 208 | ASSERT(expandedX.shape().size() == 2); 209 
| // ASSERT(expandedX.shape().at(0) == N); 210 | ASSERT(expandedX.shape().at(1) == K); 211 | auto Y = (expandedX * W).sum(K); 212 | X.bind(nullptr, {5}); 213 | W.bind(nullptr, {3}); 214 | for (auto i = 0; i < 5; ++i) { 215 | X.data()[i] = 1; 216 | } 217 | for (auto i = 0; i < 3; ++i) { 218 | W.data()[i] = 1; 219 | } 220 | ASSERT(Y.size(0) == 5); 221 | Y.data(); 222 | ASSERT(Y.data()[0] == 2); 223 | ASSERT(Y.data()[2] == 3); 224 | } 225 | -------------------------------------------------------------------------------- /test/wasm_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | */ 7 | #include "loop_tool/wasm.h" 8 | 9 | #include 10 | 11 | #include 12 | 13 | #include "test_utils.h" 14 | 15 | TEST(WasmBasic) { 16 | namespace lz = ::loop_tool::lazy; 17 | auto mm = [](lz::Tensor A, lz::Tensor B) { 18 | auto M = lz::Symbol("M"); 19 | auto N = lz::Symbol("N"); 20 | auto K = lz::Symbol("K"); 21 | auto C = A.as(M, K) * B.as(K, N); 22 | return C.sum(K); 23 | }; 24 | 25 | auto M = 16; 26 | auto N = 16; 27 | auto K = 16; 28 | 29 | lz::Tensor A(M, K); 30 | lz::Tensor B(K, N); 31 | auto C = mm(A, B); 32 | auto wc = loop_tool::WebAssemblyCompiler(C.loop_tree()); 33 | auto bytes = wc.emit(); 34 | std::ofstream wasm("out.wasm", std::ios::binary); 35 | wasm.write((char*)bytes.data(), bytes.size()); 36 | } 37 | 38 | TEST(WasmUnroll) { 39 | namespace lz = ::loop_tool::lazy; 40 | auto N = 4; 41 | auto n = lz::Symbol("N"); 42 | 43 | lz::Tensor A(N); 44 | lz::Tensor B(N); 45 | auto C = A.as(n) + B.as(n); 46 | auto lt = C.loop_tree(); 47 | 48 | lt = split(lt, lt.roots.at(0), 2); 49 | lt = 50 | disable_reuse(lt, lt.children(lt.roots.at(0)).at(0), lt.ir.nodes().at(2)); 51 | lt = loop_tool::annotate(lt, lt.roots.at(0), "unroll"); 52 | std::cerr << "\n" << 
lt.dump(); 53 | auto wc = loop_tool::WebAssemblyCompiler(lt); 54 | auto bytes = wc.emit(); 55 | std::ofstream wasm("out.wasm", std::ios::binary); 56 | wasm.write((char*)bytes.data(), bytes.size()); 57 | } 58 | 59 | TEST(WasmVectorize) { 60 | namespace lz = ::loop_tool::lazy; 61 | auto N = 4; 62 | auto n = lz::Symbol("N"); 63 | 64 | lz::Tensor A(N); 65 | lz::Tensor B(N); 66 | auto C = A.as(n) + B.as(n); 67 | auto lt = C.loop_tree(); 68 | 69 | lt = loop_tool::annotate(lt, lt.roots.at(0), "vectorize"); 70 | std::cerr << "\n" << lt.dump(); 71 | auto wc = loop_tool::WebAssemblyCompiler(lt); 72 | auto bytes = wc.emit(); 73 | std::ofstream wasm("out.wasm", std::ios::binary); 74 | wasm.write((char*)bytes.data(), bytes.size()); 75 | } 76 | 77 | TEST(WasmVectorizeCopy) { 78 | namespace lz = ::loop_tool::lazy; 79 | auto N = 4; 80 | auto n = lz::Symbol("N"); 81 | 82 | lz::Tensor A(N); 83 | lz::Tensor B(N); 84 | lz::Tensor C(N); 85 | auto D = A.as(n) * B.as(n) + C.as(n); 86 | auto lt = D.loop_tree(); 87 | 88 | auto ref = lt.children(lt.roots.at(0)).at(1); 89 | lt = loop_tool::copy_input(lt, ref, 1); 90 | lt = loop_tool::annotate(lt, lt.roots.at(0), "vectorize"); 91 | std::cerr << "\n" << lt.dump(); 92 | auto wc = loop_tool::WebAssemblyCompiler(lt); 93 | auto bytes = wc.emit(); 94 | std::ofstream wasm("out.wasm", std::ios::binary); 95 | wasm.write((char*)bytes.data(), bytes.size()); 96 | } 97 | 98 | TEST(WasmVectorizeWithTail) { 99 | namespace lz = ::loop_tool::lazy; 100 | auto mm = [](lz::Tensor A, lz::Tensor B) { 101 | auto M = lz::Symbol("M"); 102 | auto N = lz::Symbol("N"); 103 | auto K = lz::Symbol("K"); 104 | auto C = A.as(M, K) * B.as(K, N); 105 | return C.sum(K); 106 | }; 107 | 108 | auto M = 5; 109 | auto N = 5; 110 | auto K = 5; 111 | 112 | lz::Tensor A(M, K); 113 | lz::Tensor B(K, N); 114 | auto C = mm(A, B); 115 | 116 | auto lt = C.loop_tree(); 117 | lt = split(lt, lt.children(lt.roots.at(0)).at(1), 4); 118 | lt = annotate(lt, 
lt.children(lt.children(lt.roots.at(0)).at(1)).at(0), 119 | "vectorize"); 120 | lt = split(lt, lt.children(lt.children(lt.roots.at(0)).at(0)).at(0), 4); 121 | lt = annotate( 122 | lt, 123 | lt.children(lt.children(lt.children(lt.roots.at(0)).at(0)).at(0)).at(0), 124 | "vectorize"); 125 | C.set(lt); 126 | 127 | auto wc = loop_tool::WebAssemblyCompiler(C.loop_tree()); 128 | auto bytes = wc.emit(); 129 | std::ofstream wasm("out.wasm", std::ios::binary); 130 | wasm.write((char*)bytes.data(), bytes.size()); 131 | } 132 | 133 | TEST(WasmUnrollVectorizeWithTail) { 134 | namespace lz = ::loop_tool::lazy; 135 | auto mm = [](lz::Tensor A, lz::Tensor B) { 136 | auto M = lz::Symbol("M"); 137 | auto N = lz::Symbol("N"); 138 | auto K = lz::Symbol("K"); 139 | auto C = A.as(M, K) * B.as(K, N); 140 | return C.sum(K); 141 | }; 142 | 143 | auto M = 2; 144 | auto N = 5; 145 | auto K = 5; 146 | 147 | lz::Tensor A(M, K); 148 | lz::Tensor B(K, N); 149 | auto C = mm(A, B); 150 | 151 | auto lt = C.loop_tree(); 152 | lt = split(lt, lt.children(lt.roots.at(0)).at(1), 4); 153 | lt = annotate(lt, lt.children(lt.roots.at(0)).at(1), "unroll"); 154 | lt = annotate(lt, lt.children(lt.children(lt.roots.at(0)).at(1)).at(0), 155 | "vectorize"); 156 | lt = split(lt, lt.children(lt.children(lt.roots.at(0)).at(0)).at(0), 4); 157 | lt = annotate(lt, lt.children(lt.children(lt.roots.at(0)).at(0)).at(0), 158 | "unroll"); 159 | lt = annotate( 160 | lt, 161 | lt.children(lt.children(lt.children(lt.roots.at(0)).at(0)).at(0)).at(0), 162 | "vectorize"); 163 | C.set(lt); 164 | 165 | auto wc = loop_tool::WebAssemblyCompiler(C.loop_tree()); 166 | auto bytes = wc.emit(); 167 | std::ofstream wasm("out.wasm", std::ios::binary); 168 | wasm.write((char*)bytes.data(), bytes.size()); 169 | } 170 | --------------------------------------------------------------------------------