├── .clang-format ├── .clangd ├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake ├── CPM.cmake └── Utils │ ├── AddGoogleTest.cmake │ ├── CxxUtils.cmake │ └── GitVersion.cmake ├── cpp ├── c_api.cc ├── printer.cc ├── registry.h ├── structure.cc └── sym │ ├── analyzer_canonical_simplify.cc │ ├── analyzer_canonical_simplify.h │ ├── analyzer_const_int_bound.cc │ ├── analyzer_const_int_bound.h │ ├── analyzer_impl.h │ ├── analyzer_interval_set.cc │ ├── analyzer_interval_set.h │ ├── analyzer_modular_set.cc │ ├── analyzer_modular_set.h │ ├── analyzer_rewrite_simplify.cc │ ├── analyzer_rewrite_simplify.h │ ├── analyzer_transitive_comparisons.cc │ ├── analyzer_transitive_comparisons.h │ ├── sym.cc │ ├── text_format.cc │ └── utils.h ├── include └── mlc │ ├── base │ ├── all.h │ ├── alloc.h │ ├── any.h │ ├── base_traits.h │ ├── lib.h │ ├── optional.h │ ├── ref.h │ ├── traits_device.h │ ├── traits_dtype.h │ ├── traits_object.h │ ├── traits_scalar.h │ ├── traits_str.h │ └── utils.h │ ├── c_api.h │ ├── core │ ├── all.h │ ├── dict.h │ ├── dict_base.h │ ├── error.h │ ├── func.h │ ├── func_details.h │ ├── list.h │ ├── list_base.h │ ├── object.h │ ├── object_path.h │ ├── opaque.h │ ├── reflection.h │ ├── str.h │ ├── tensor.h │ ├── typing.h │ ├── utils.h │ └── visitor.h │ ├── printer │ ├── all.h │ ├── ast.h │ └── ir_printer.h │ └── sym │ ├── all.h │ ├── analyzer.h │ ├── expr.h │ ├── expr_functor.h │ ├── op.h │ └── pattern_match.h ├── pyproject.toml ├── python └── mlc │ ├── __init__.py │ ├── _cython │ ├── .gitignore │ ├── __init__.py │ ├── base.py │ └── core.pyx │ ├── cc │ ├── __init__.py │ ├── compiler.py │ ├── jit.py │ └── loader.py │ ├── config.py │ ├── core │ ├── __init__.py │ ├── dep_graph.py │ ├── device.py │ ├── dict.py │ ├── dtype.py │ ├── error.py │ ├── func.py │ ├── list.py │ ├── object.py │ ├── object_path.py │ ├── opaque.py │ ├── tensor.py │ └── typing.py │ 
├── dataclasses │ ├── __init__.py │ ├── c_class.py │ ├── py_class.py │ └── utils.py │ ├── parser │ ├── __init__.py │ ├── diagnostic.py │ ├── env.py │ └── parser.py │ ├── printer │ ├── __init__.py │ ├── ast.py │ ├── cprint.py │ └── ir_printer.py │ ├── sym │ ├── __init__.py │ ├── _internal.py │ ├── analyzer.py │ ├── expr.py │ └── op.py │ └── testing │ ├── __init__.py │ ├── dataclasses.py │ └── toy_ir │ ├── __init__.py │ ├── ir.py │ ├── ir_builder.py │ └── parser.py ├── scripts ├── cpp_tests.bat ├── cpp_tests.sh ├── setup_manylinux2014.sh └── show_wheel_content.py └── tests ├── cpp ├── CMakeLists.txt ├── common.h ├── test_base_any.cc ├── test_base_optional.cc ├── test_base_ref.cc ├── test_base_type_info.cc ├── test_core_dict.cc ├── test_core_func.cc ├── test_core_list.cc ├── test_core_str.cc ├── test_core_udict.cc ├── test_core_udict_legacy.cc ├── test_core_ulist.cc ├── test_core_ulist_legacy.cc └── test_sym_pattern.cc └── python ├── test_cc.py ├── test_cli_config.py ├── test_core_dep_graph.py ├── test_core_device.py ├── test_core_dict.py ├── test_core_dtype.py ├── test_core_func.py ├── test_core_json.py ├── test_core_list.py ├── test_core_object.py ├── test_core_object_path.py ├── test_core_opaque.py ├── test_core_tensor.py ├── test_core_typing.py ├── test_cython_traceback.py ├── test_dataclasses_copy.py ├── test_dataclasses_fields.py ├── test_dataclasses_prototype.py ├── test_dataclasses_py_class.py ├── test_dataclasses_serialize.py ├── test_dataclasses_structure.py ├── test_parser_toy_ir_parser.py ├── test_printer_ast.py ├── test_printer_ir_printer.py ├── test_sym_analyzer_canonical_simplify.py ├── test_sym_analyzer_const_int_bound.py ├── test_sym_analyzer_interval_set.py ├── test_sym_analyzer_modular_set.py ├── test_sym_analyzer_rewrite_simplify.py ├── test_sym_analyzer_simplify.py └── test_sym_expr.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | ColumnLimit: 120 3 | 
-------------------------------------------------------------------------------- /.clangd: -------------------------------------------------------------------------------- 1 | CompileFlags: 2 | CompilationDatabase: build-vscode/ 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | env: 5 | CIBW_BUILD_VERBOSITY: 3 6 | CIBW_TEST_REQUIRES: "pytest torch jsonpickle" 7 | CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/" 8 | CIBW_ENVIRONMENT: "MLC_SHOW_CPP_STACKTRACES=1" 9 | CIBW_REPAIR_WHEEL_COMMAND_LINUX: > 10 | auditwheel repair -w {dest_dir} {wheel} && 11 | pipx run abi3audit --strict --report {wheel} 12 | CIBW_REPAIR_WHEEL_COMMAND_MACOS: > 13 | delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} && 14 | pipx run abi3audit --strict --report {wheel} 15 | CIBW_REPAIR_WHEEL_COMMAND_WINDOWS: > 16 | pipx run delvewheel repair -w {dest_dir} {wheel} && 17 | pipx run abi3audit --strict --report {wheel} 18 | MLC_CIBW_VERSION: "2.22.0" 19 | MLC_PYTHON_VERSION: "3.9" 20 | MLC_CIBW_WIN_BUILD: "cp39-win_amd64" 21 | MLC_CIBW_MAC_BUILD: "cp39-macosx_arm64" 22 | MLC_CIBW_LINUX_BUILD: "cp313-manylinux_x86_64" 23 | 24 | jobs: 25 | pre-commit: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: actions/setup-python@v5 30 | with: 31 | python-version: ${{ env.MLC_PYTHON_VERSION }} 32 | - uses: pre-commit/action@v3.0.1 33 | - uses: ytanikin/pr-conventional-commits@1.4.0 34 | if: github.event_name == 'pull_request' 35 | with: 36 | task_types: '["feat", "fix", "ci", "chore", "test"]' 37 | add_label: 'false' 38 | windows: 39 | name: Windows 40 | runs-on: windows-latest 41 | needs: pre-commit 42 | steps: 43 | - uses: actions/checkout@v4 44 | with: 45 | submodules: "recursive" 46 | - uses: actions/setup-python@v5 47 | with: 48 | 
python-version: ${{ env.MLC_PYTHON_VERSION }} 49 | - name: Install cibuildwheel 50 | run: python -m pip install cibuildwheel=="${{ env.MLC_CIBW_VERSION }}" 51 | - name: Build wheels 52 | run: python -m cibuildwheel --output-dir wheelhouse 53 | env: 54 | CIBW_BEFORE_ALL: ".\\scripts\\cpp_tests.bat" 55 | CIBW_BUILD: ${{ env.MLC_CIBW_WIN_BUILD }} 56 | - name: Show package contents 57 | run: python scripts/show_wheel_content.py wheelhouse 58 | macos: 59 | name: MacOS 60 | runs-on: macos-latest 61 | needs: pre-commit 62 | steps: 63 | - uses: actions/checkout@v4 64 | with: 65 | submodules: "recursive" 66 | - uses: actions/setup-python@v5 67 | with: 68 | python-version: ${{ env.MLC_PYTHON_VERSION }} 69 | - name: Install cibuildwheel 70 | run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} 71 | - name: Build wheels 72 | run: python -m cibuildwheel --output-dir wheelhouse 73 | env: 74 | CIBW_BEFORE_ALL: "./scripts/cpp_tests.sh" 75 | CIBW_BUILD: ${{ env.MLC_CIBW_MAC_BUILD }} 76 | - name: Show package contents 77 | run: python scripts/show_wheel_content.py wheelhouse 78 | linux: 79 | name: Linux 80 | runs-on: ubuntu-latest 81 | needs: pre-commit 82 | steps: 83 | - uses: actions/checkout@v4 84 | with: 85 | submodules: "recursive" 86 | - uses: actions/setup-python@v5 87 | with: 88 | python-version: ${{ env.MLC_PYTHON_VERSION }} 89 | - name: Install cibuildwheel 90 | run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} 91 | - name: Build wheels 92 | run: python -m cibuildwheel --output-dir wheelhouse 93 | env: 94 | CIBW_BEFORE_ALL: "./scripts/setup_manylinux2014.sh && ./scripts/cpp_tests.sh" 95 | CIBW_BUILD: ${{ env.MLC_CIBW_LINUX_BUILD }} 96 | - name: Show package contents 97 | run: python scripts/show_wheel_content.py wheelhouse 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.dSYM 3 | build 4 | 
build-cpp-tests 5 | python/mlc/_version.py 6 | wheelhouse/ 7 | uv.lock 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/dlpack"] 2 | path = 3rdparty/dlpack 3 | url = https://github.com/dmlc/dlpack.git 4 | [submodule "3rdparty/googletest"] 5 | path = 3rdparty/googletest 6 | url = https://github.com/google/googletest.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_install_hook_types: 4 | - pre-commit 5 | - commit-msg 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v5.0.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: mixed-line-ending 12 | - id: end-of-file-fixer 13 | - id: check-yaml 14 | - id: check-toml 15 | - id: check-added-large-files 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.9.0 18 | hooks: 19 | - id: ruff 20 | types_or: [python, pyi, jupyter] 21 | args: [--fix] 22 | - id: ruff-format 23 | types_or: [python, pyi, jupyter] 24 | - repo: https://github.com/pre-commit/mirrors-mypy 25 | rev: "v1.14.1" 26 | hooks: 27 | - id: mypy 28 | additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch", "jsonpickle"] 29 | args: [--show-error-codes] 30 | - repo: https://github.com/pre-commit/mirrors-clang-format 31 | rev: "v19.1.6" 32 | hooks: 33 | - id: clang-format 34 | - repo: https://github.com/MarcoGorelli/cython-lint 35 | rev: v0.16.6 36 | hooks: 37 | - id: cython-lint 38 | - id: double-quote-cython-strings 39 | - repo: https://github.com/scop/pre-commit-shfmt 40 | rev: v3.10.0-2 41 | hooks: 42 | - id: shfmt 43 | - repo: 
https://github.com/shellcheck-py/shellcheck-py 44 | rev: v0.10.0.1 45 | hooks: 46 | - id: shellcheck 47 | # - repo: https://github.com/cheshirekow/cmake-format-precommit 48 | # rev: v0.6.10 49 | # hooks: 50 | # - id: cmake-format 51 | # - id: cmake-lint 52 | - repo: https://github.com/compilerla/conventional-pre-commit 53 | rev: v4.0.0 54 | hooks: 55 | - id: conventional-pre-commit 56 | stages: [commit-msg] 57 | args: [feat, fix, ci, chore, test] 58 | -------------------------------------------------------------------------------- /cmake/CPM.cmake: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: MIT 2 | # 3 | # SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors 4 | 5 | set(CPM_DOWNLOAD_VERSION 0.40.8) 6 | set(CPM_HASH_SUM "78ba32abdf798bc616bab7c73aac32a17bbd7b06ad9e26a6add69de8f3ae4791") 7 | 8 | if(CPM_SOURCE_CACHE) 9 | set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") 10 | elseif(DEFINED ENV{CPM_SOURCE_CACHE}) 11 | set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") 12 | else() 13 | set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") 14 | endif() 15 | 16 | # Expand relative path. 
This is important if the provided path contains a tilde (~) 17 | get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE) 18 | 19 | file(DOWNLOAD 20 | https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake 21 | ${CPM_DOWNLOAD_LOCATION} EXPECTED_HASH SHA256=${CPM_HASH_SUM} 22 | ) 23 | 24 | include(${CPM_DOWNLOAD_LOCATION}) 25 | -------------------------------------------------------------------------------- /cmake/Utils/AddGoogleTest.cmake: -------------------------------------------------------------------------------- 1 | set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) 2 | set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) 3 | set(BUILD_GTEST ON CACHE BOOL "" FORCE) 4 | set(GOOGLETEST_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/googletest) 5 | 6 | add_subdirectory(${GOOGLETEST_ROOT}) 7 | include(GoogleTest) 8 | 9 | set_target_properties(gtest gtest_main 10 | PROPERTIES 11 | EXPORT_COMPILE_COMMANDS OFF 12 | EXCLUDE_FROM_ALL ON 13 | FOLDER 3rdparty 14 | ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 15 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 16 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" 17 | ) 18 | 19 | # Install gtest and gtest_main 20 | install(TARGETS gtest gtest_main DESTINATION "lib/") 21 | 22 | # Mark advanced variables 23 | mark_as_advanced( 24 | BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS 25 | gmock_build_tests gtest_build_samples gtest_build_tests 26 | gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols 27 | ) 28 | 29 | macro(add_googletest target_name) 30 | add_test( 31 | NAME ${target_name} 32 | COMMAND ${target_name} 33 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} 34 | ) 35 | target_link_libraries(${target_name} PRIVATE gtest_main) 36 | gtest_discover_tests(${target_name} 37 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 38 | DISCOVERY_MODE PRE_TEST 39 | PROPERTIES 40 | VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" 41 | ) 42 | 
set_target_properties(${target_name} PROPERTIES FOLDER tests) 43 | endmacro() 44 | -------------------------------------------------------------------------------- /cmake/Utils/CxxUtils.cmake: -------------------------------------------------------------------------------- 1 | if(APPLE) 2 | find_program(DSYMUTIL_PROGRAM dsymutil) 3 | mark_as_advanced(DSYMUTIL_PROGRAM) 4 | endif() 5 | 6 | function(add_cxx_warning target_name) 7 | # GNU, Clang, or AppleClang 8 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") 9 | target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic") 10 | return() 11 | endif() 12 | # MSVC 13 | if(MSVC) 14 | target_compile_options(${target_name} PRIVATE "/W4" "/WX") 15 | return() 16 | endif() 17 | message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") 18 | endfunction() 19 | 20 | function(add_debug_symbol_apple _target _directory) 21 | if(APPLE) 22 | add_custom_command(TARGET ${_target} POST_BUILD 23 | COMMAND ${DSYMUTIL_PROGRAM} ARGS $ 24 | COMMENT "Running dsymutil" VERBATIM 25 | ) 26 | install(FILES $.dSYM DESTINATION ${_directory}) 27 | endif(APPLE) 28 | endfunction() 29 | 30 | # Add AddressSanitizer compile/link flags to `target_name`, but only if the compiler can build with them. 31 | function(add_sanitizer_address target_name) 32 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") 33 | include(CheckCXXSourceCompiles) 34 | set(_saved_CRF ${CMAKE_REQUIRED_FLAGS}) 35 | set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") 36 | check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) 37 | set(CMAKE_REQUIRED_FLAGS ${_saved_CRF}) 38 | # Honor the probe result instead of unconditionally adding flags that would break the build. 39 | if(NOT COMPILER_SUPPORTS_ASAN) 40 | message(WARNING "AddressSanitizer is not supported by the compiler; skipping for target ${target_name}") 41 | return() 42 | endif() 43 | get_target_property(_saved_type ${target_name} TYPE) 44 | if(${_saved_type} STREQUAL "INTERFACE_LIBRARY") 45 | set(_saved_type INTERFACE) 46 | else() 47 | set(_saved_type PRIVATE) 48 | endif() 49 | target_link_options(${target_name} ${_saved_type} "-fsanitize=address") 50 | target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") 51 | return() 52 | endif() 53 | endfunction() 54 | 
-------------------------------------------------------------------------------- /cmake/Utils/GitVersion.cmake: -------------------------------------------------------------------------------- 1 | find_package(Git) 2 | 3 | if (GIT_EXECUTABLE) 4 | execute_process( 5 | COMMAND ${GIT_EXECUTABLE} describe --tags --dirty --match "v*" 6 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} 7 | OUTPUT_VARIABLE _GIT_OUTPUT 8 | RESULT_VARIABLE _GIT_ERROR 9 | OUTPUT_STRIP_TRAILING_WHITESPACE 10 | ) 11 | if (NOT _GIT_ERROR) 12 | string(REGEX REPLACE "^v" "" MLC_VERSION_RAW "${_GIT_OUTPUT}") 13 | endif() 14 | else() 15 | message(WARNING "Git not found, cannot determine version. Falling back to 0.0.0-0-unknown.") # Non-fatal: the default version set below is used instead 16 | endif() 17 | 18 | if(NOT DEFINED MLC_VERSION_RAW) 19 | set(MLC_VERSION_RAW 0.0.0-0-unknown) 20 | message(WARNING "Failed to determine MLC_VERSION_RAW from Git tags. Using default version \"${MLC_VERSION_RAW}\".") 21 | endif() 22 | 23 | string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)(-([0-9]+)-([a-z0-9]+))?" 
MLC_VERSION_MATCH ${MLC_VERSION_RAW}) 24 | set(MLC_VERSION_MAJOR ${CMAKE_MATCH_1}) 25 | set(MLC_VERSION_MINOR ${CMAKE_MATCH_2}) 26 | set(MLC_VERSION_PATCH ${CMAKE_MATCH_3}) 27 | set(MLC_VERSION_COMMIT_NUM ${CMAKE_MATCH_5}) 28 | set(MLC_VERSION_COMMIT_SHA ${CMAKE_MATCH_6}) 29 | 30 | if (NOT MLC_VERSION_COMMIT_NUM) 31 | set(MLC_VERSION_GIT ${MLC_VERSION_MAJOR}.${MLC_VERSION_MINOR}.${MLC_VERSION_PATCH}) 32 | else() 33 | # Increment `MLC_VERSION_PATCH` by 1 34 | math(EXPR MLC_VERSION_PATCH "${MLC_VERSION_PATCH}+1") 35 | set(MLC_VERSION_GIT ${MLC_VERSION_MAJOR}.${MLC_VERSION_MINOR}.${MLC_VERSION_PATCH}.dev${MLC_VERSION_COMMIT_NUM}+${MLC_VERSION_COMMIT_SHA}) 36 | endif() 37 | -------------------------------------------------------------------------------- /cpp/sym/analyzer_canonical_simplify.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_CANONICAL_SIMPLIFY_H_ 2 | #define MLC_SYM_ANALYZER_CANONICAL_SIMPLIFY_H_ 3 | 4 | #include "./utils.h" 5 | #include 6 | 7 | namespace mlc { 8 | namespace sym { 9 | 10 | enum class DivMode { 11 | kTruncDiv = 0, 12 | kFloorDiv = 1, 13 | }; 14 | 15 | class CanonicalSimplifier { 16 | public: 17 | Expr operator()(const Expr &expr); 18 | void Update(const Var &var, const Expr &new_expr, bool allow_override = false); 19 | 20 | explicit CanonicalSimplifier(AnalyzerObj::Impl *parent); 21 | ~CanonicalSimplifier(); 22 | struct Impl; 23 | std::unique_ptr impl_; 24 | }; 25 | 26 | struct SplitExpr; 27 | struct SumExpr; 28 | 29 | struct SplitExprObj : public ExprObj { 30 | MLC_DEF_DYN_TYPE(MLC_EXPORTS, SplitExprObj, ExprObj, "mlc.sym.SplitExpr"); 31 | 32 | Expr index; 33 | int64_t lower_factor; // = 1 34 | int64_t upper_factor; // = kPosInf 35 | int64_t scale; // = 1 36 | DivMode div_mode; // = DivMode::kTruncDiv 37 | 38 | explicit SplitExprObj(DLDataType dtype, Expr index, int64_t lower_factor, int64_t upper_factor, int64_t scale, 39 | DivMode div_mode) 40 | : ExprObj(dtype), 
index(index), lower_factor(lower_factor), upper_factor(upper_factor), scale(scale), 41 | div_mode(div_mode) {} 42 | 43 | void Verify() const { 44 | if (!(upper_factor == kPosInf || upper_factor % lower_factor == 0)) { 45 | MLC_THROW(InternalError) << "Failed verification"; 46 | } 47 | } 48 | std::string __str__() const { 49 | std::ostringstream os; 50 | os << "SplitExpr(index=" << this->index // 51 | << ", lower_factor=" << this->lower_factor // 52 | << ", upper_factor=" << this->upper_factor // 53 | << ", scale=" << this->scale // 54 | << ", div_mode=" << (this->div_mode == DivMode::kTruncDiv ? "kTruncDiv" : "kFloorDiv") << ")"; 55 | return os.str(); 56 | } 57 | Expr NormalizeWithScale(int64_t sscale) const; 58 | Expr Normalize() const { return NormalizeWithScale(1); } 59 | void MulToSelf(int64_t s) { this->scale *= s; } 60 | bool CanPushCastToChildren(DLDataType dtype, AnalyzerObj::Impl *analyzer) const; 61 | void PushCastToChildren(DLDataType dtype); 62 | inline bool IndexEqual(const SplitExpr &other) const; 63 | inline bool DivModeCompatibleTo(DivMode mode) const; 64 | }; 65 | 66 | struct SplitExpr : public Expr { 67 | MLC_DEF_OBJ_REF(MLC_EXPORTS, SplitExpr, SplitExprObj, Expr) // 68 | .MemFn("__str__", &SplitExprObj::__str__); 69 | MLC_DEF_OBJ_REF_COW_(); 70 | explicit SplitExpr(DLDataType dtype, Expr index, int64_t lower_factor = 1, int64_t upper_factor = kPosInf, 71 | int64_t scale = 1, DivMode div_mode = DivMode::kTruncDiv) 72 | : SplitExpr(SplitExpr::New(dtype, index, lower_factor, upper_factor, scale, div_mode)) {} 73 | }; 74 | 75 | struct SumExprObj : public ExprObj { 76 | MLC_DEF_DYN_TYPE(MLC_EXPORTS, SumExprObj, ExprObj, "mlc.sym.SumExpr"); 77 | 78 | std::vector args; 79 | int64_t base{0}; 80 | explicit SumExprObj(DLDataType dtype, std::vector args, int64_t base) : args(std::move(args)), base(base) { 81 | this->dtype = dtype; 82 | } 83 | explicit SumExprObj(DLDataType dtype) : args(), base(0) { this->dtype = dtype; } 84 | 85 | bool IsZero() const { 
return base == 0 && args.size() == 0; } 86 | Expr Normalize() const; 87 | bool DivisibleBy(int64_t scale); 88 | void MulToSelf(int64_t scale); 89 | void DivideBy(int64_t scale); 90 | void AddToSelf(int64_t value) { this->base += value; } 91 | void AddToSelf(SplitExpr other, int64_t scale); 92 | void AddToSelf(const SumExpr &other, int64_t scale); 93 | bool CanPushCastToChildren(DLDataType dtype, AnalyzerObj::Impl *analyzer) const; 94 | void PushCastToChildren(DLDataType dtype); 95 | std::string __str__() const { 96 | std::ostringstream os; 97 | os << "SumExpr(base=" << this->base << ", args=["; 98 | bool is_first = true; 99 | for (const auto &arg : this->args) { 100 | if (!is_first) { 101 | os << ", "; 102 | } else { 103 | is_first = false; 104 | } 105 | os << arg->__str__(); 106 | } 107 | os << "])"; 108 | return os.str(); 109 | } 110 | 111 | private: 112 | static std::vector SimplifySplitExprs(std::vector args); 113 | static Expr Normalize_(DLDataType dtype, const std::vector &args, int64_t base); 114 | }; 115 | 116 | struct SumExpr : public Expr { 117 | SumExpr(DLDataType dtype) : SumExpr(SumExpr::New(dtype)) {} 118 | SumExpr(DLDataType dtype, std::vector args, int64_t base) : SumExpr(SumExpr::New(dtype, args, base)) {} 119 | MLC_DEF_OBJ_REF(MLC_EXPORTS, SumExpr, SumExprObj, Expr) // 120 | .MemFn("__str__", &SumExprObj::__str__); 121 | MLC_DEF_OBJ_REF_COW_(); 122 | }; 123 | 124 | } // namespace sym 125 | } // namespace mlc 126 | 127 | #endif // MLC_SYM_ANALYZER_CANONICAL_SIMPLIFY_H_ 128 | -------------------------------------------------------------------------------- /cpp/sym/analyzer_const_int_bound.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_CONST_INT_BOUND_H_ 2 | #define MLC_SYM_ANALYZER_CONST_INT_BOUND_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace mlc { 8 | namespace sym { 9 | 10 | struct ConstIntBoundObj { 11 | MLCAny _mlc_header; 12 | int64_t min_value; 13 | int64_t max_value; 14 | 
explicit ConstIntBoundObj(int64_t min_value, int64_t max_value) 15 | : _mlc_header{}, min_value(min_value), max_value(max_value) {} 16 | 17 | std::string __str__() const { 18 | std::ostringstream oss; 19 | oss << "ConstIntBound[" << min_value << ", " << max_value << "]"; 20 | return oss.str(); 21 | } 22 | 23 | static constexpr int64_t kPosInf = std::numeric_limits::max(); 24 | static constexpr int64_t kNegInf = -kPosInf; 25 | 26 | MLC_DEF_DYN_TYPE(MLC_EXPORTS, ConstIntBoundObj, ::mlc::Object, "mlc.sym.ConstIntBound"); 27 | }; // struct ConstIntBoundObj 28 | 29 | struct ConstIntBound : public ::mlc::ObjectRef { 30 | MLC_DEF_OBJ_REF(MLC_EXPORTS, ConstIntBound, ConstIntBoundObj, ::mlc::ObjectRef) 31 | .Field("min_value", &ConstIntBoundObj::min_value) 32 | .Field("max_value", &ConstIntBoundObj::max_value) 33 | .MemFn("__str__", &ConstIntBoundObj::__str__) 34 | .StaticFn("__init__", ::mlc::InitOf); 35 | MLC_DEF_OBJ_REF_FWD_NEW(ConstIntBound) 36 | }; // struct ConstIntBound 37 | 38 | struct ConstIntBoundAnalyzer { 39 | using BoundMapType = Dict; 40 | ConstIntBound operator()(const Expr &expr) const; 41 | ConstIntBound operator()(const Expr &expr, BoundMapType *bound); 42 | void Update(const Var &var, const ConstIntBound &info, bool allow_override = false); 43 | void Bind(const Var &var, const Range &range, bool allow_override = false); 44 | 45 | explicit ConstIntBoundAnalyzer(AnalyzerObj::Impl *parent); 46 | ~ConstIntBoundAnalyzer(); 47 | std::function EnterConstraint(const Expr &constraint); 48 | struct Impl; 49 | std::unique_ptr impl_; 50 | }; 51 | 52 | } // namespace sym 53 | } // namespace mlc 54 | 55 | #endif // MLC_SYM_ANALYZER_CONST_INT_BOUND_H_ 56 | -------------------------------------------------------------------------------- /cpp/sym/analyzer_interval_set.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_INTERVAL_SET_H_ 2 | #define MLC_SYM_ANALYZER_INTERVAL_SET_H_ 3 | 4 | #include 5 | 6 | 
namespace mlc { 7 | namespace sym { 8 | 9 | struct IntervalSet; 10 | 11 | struct IntervalSetObj { 12 | MLCAny _mlc_header; 13 | Expr min_value; 14 | Expr max_value; 15 | explicit IntervalSetObj(Expr min_value, Expr max_value) : _mlc_header{}, min_value(min_value), max_value(max_value) {} 16 | MLC_DEF_DYN_TYPE(MLC_EXPORTS, IntervalSetObj, ::mlc::Object, "mlc.sym.IntervalSet"); 17 | 18 | bool HasUpperBound() const; 19 | bool HasLowerBound() const; 20 | bool IsSinglePoint() const; 21 | bool IsEmpty() const; 22 | bool IsEverything() const; 23 | std::string __str__() const { 24 | std::ostringstream os; 25 | os << "IntervalSet[" << min_value << ", " << max_value << "]"; 26 | return os.str(); 27 | } 28 | IntervalSet Union(IntervalSetObj *b, AnalyzerObj::Impl *analyzer) const; 29 | IntervalSet Intersect(IntervalSetObj *b, AnalyzerObj::Impl *analyzer) const; 30 | }; 31 | 32 | struct IntervalSet : public ObjectRef { 33 | MLC_DEF_OBJ_REF(MLC_EXPORTS, IntervalSet, IntervalSetObj, ObjectRef) 34 | .Field("min_value", &IntervalSetObj::min_value) 35 | .Field("max_value", &IntervalSetObj::max_value) 36 | .MemFn("__str__", &IntervalSetObj::__str__) 37 | .StaticFn("__init__", ::mlc::InitOf); 38 | explicit IntervalSet(Expr min_value, Expr max_value) : IntervalSet(IntervalSet::New(min_value, max_value)) {} 39 | static IntervalSet Nothing() { return IntervalSet::Empty(); } 40 | static IntervalSet SinglePoint(Expr value) { return IntervalSet(value, value); } 41 | static IntervalSet Everything(); 42 | static IntervalSet Empty(); 43 | static IntervalSet FromRange(const Range &range); 44 | static IntervalSet Interval(Expr min, Expr max); 45 | static IntervalSet Intersect(const List &sets, AnalyzerObj::Impl *analyzer); 46 | }; 47 | 48 | struct IntervalSetAnalyzer { 49 | IntervalSet operator()(const Expr &expr, const Dict &dom_map); 50 | IntervalSet operator()(const Expr &expr); 51 | void Update(const Var &var, const IntervalSet &new_interval_set, bool allow_override = false); 52 | void 
Bind(const Var &var, const Range &new_range, bool allow_override = false); 53 | std::function EnterConstraint(const Expr &constraint); 54 | explicit IntervalSetAnalyzer(AnalyzerObj::Impl *parent); 55 | ~IntervalSetAnalyzer(); 56 | struct Impl; 57 | std::unique_ptr impl_; 58 | }; 59 | 60 | } // namespace sym 61 | } // namespace mlc 62 | 63 | #endif // MLC_SYM_ANALYZER_INTERVAL_SET_H_ 64 | -------------------------------------------------------------------------------- /cpp/sym/analyzer_modular_set.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_MODULAR_SET_H_ 2 | #define MLC_SYM_ANALYZER_MODULAR_SET_H_ 3 | 4 | #include 5 | 6 | namespace mlc { 7 | namespace sym { 8 | 9 | struct ModularSetObj { 10 | MLCAny _mlc_header; 11 | int64_t coeff; 12 | int64_t base; 13 | explicit ModularSetObj(int64_t coeff, int64_t base) : _mlc_header{}, coeff(coeff), base(base) {} 14 | 15 | std::string __str__() const { 16 | std::ostringstream oss; 17 | oss << "ModularSet(coeff=" << coeff << ", base=" << base << ")"; 18 | return oss.str(); 19 | } 20 | 21 | MLC_DEF_DYN_TYPE(MLC_EXPORTS, ModularSetObj, ::mlc::Object, "mlc.sym.ModularSet"); 22 | }; // struct ModularSetObj 23 | 24 | struct ModularSet : public ::mlc::ObjectRef { 25 | MLC_DEF_OBJ_REF(MLC_EXPORTS, ModularSet, ModularSetObj, ::mlc::ObjectRef) 26 | .Field("coeff", &ModularSetObj::coeff) 27 | .Field("base", &ModularSetObj::base) 28 | .MemFn("__str__", &ModularSetObj::__str__) 29 | .StaticFn("__init__", ::mlc::InitOf); 30 | MLC_DEF_OBJ_REF_FWD_NEW(ModularSet) 31 | }; // struct ModularSet 32 | 33 | class ModularSetAnalyzer { 34 | public: 35 | ModularSet operator()(const Expr &expr); 36 | void Update(const Var &var, const ModularSet &info, bool allow_override = false); 37 | 38 | explicit ModularSetAnalyzer(AnalyzerObj::Impl *parent); 39 | ~ModularSetAnalyzer(); 40 | std::function EnterConstraint(const Expr &constraint); 41 | struct Impl; 42 | std::unique_ptr impl_; 43 | }; 
44 | 45 | } // namespace sym 46 | } // namespace mlc 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /cpp/sym/analyzer_transitive_comparisons.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_TRANSITIVE_COMPARISONS_H_ 2 | #define MLC_SYM_ANALYZER_TRANSITIVE_COMPARISONS_H_ 3 | 4 | #include "./utils.h" 5 | #include 6 | 7 | namespace mlc { 8 | namespace sym { 9 | 10 | class TransitiveComparisonAnalyzer { 11 | public: 12 | CompareResult TryCompare(const Expr &lhs, const Expr &rhs, bool propagate_inequalities = true); 13 | void Bind(const Var &var, const Expr &expr, bool allow_override = false); 14 | void Bind(const Var &var, const Range &range, bool allow_override = false); 15 | std::function EnterConstraint(const Expr &constraint); 16 | 17 | explicit TransitiveComparisonAnalyzer(AnalyzerObj::Impl *analyzer); 18 | ~TransitiveComparisonAnalyzer(); 19 | struct Impl; 20 | std::unique_ptr impl_; 21 | }; 22 | 23 | } // namespace sym 24 | } // namespace mlc 25 | 26 | #endif // MLC_SYM_ANALYZER_TRANSITIVE_COMPARISONS_H_ 27 | -------------------------------------------------------------------------------- /include/mlc/base/alloc.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_ALLOC_H_ 2 | #define MLC_BASE_ALLOC_H_ 3 | 4 | #include "./utils.h" 5 | #include 6 | #include 7 | 8 | namespace mlc { 9 | 10 | template struct DefaultObjectAllocator { 11 | using Storage = typename std::aligned_storage::type; 12 | 13 | template >> 14 | MLC_INLINE_NO_MSVC static T *New(Args &&...args) { 15 | Storage *data = new Storage; 16 | try { 17 | new (data) T(std::forward(args)...); 18 | } catch (...) 
{ 19 | delete data; 20 | throw; 21 | } 22 | T *ret = reinterpret_cast(data); 23 | ret->_mlc_header.type_index = T::_type_index; 24 | ret->_mlc_header.ref_cnt = 0; 25 | ret->_mlc_header.v.deleter = DefaultObjectAllocator::Deleter; 26 | return ret; 27 | } 28 | 29 | template >> 30 | MLC_INLINE_NO_MSVC static T *NewWithPad(size_t pad_size, Args &&...args) { 31 | size_t num_storages = (sizeof(T) + pad_size * sizeof(PadType) + sizeof(Storage) - 1) / sizeof(Storage); 32 | Storage *data = new Storage[num_storages]; 33 | try { 34 | new (data) T(std::forward(args)...); 35 | } catch (...) { 36 | delete[] data; 37 | throw; 38 | } 39 | T *ret = reinterpret_cast(data); 40 | ret->_mlc_header.type_index = T::_type_index; 41 | ret->_mlc_header.ref_cnt = 0; 42 | ret->_mlc_header.v.deleter = DefaultObjectAllocator::DeleterArray; 43 | return ret; 44 | } 45 | 46 | static void Deleter(void *objptr) { 47 | T *tptr = static_cast(objptr); 48 | tptr->T::~T(); 49 | delete reinterpret_cast(tptr); 50 | } 51 | 52 | static void DeleterArray(void *objptr) { 53 | T *tptr = static_cast(objptr); 54 | tptr->T::~T(); 55 | delete[] reinterpret_cast(tptr); 56 | } 57 | }; 58 | 59 | template struct PODAllocator; 60 | 61 | #define MLC_DEF_POD_ALLOCATOR(Type, TypeIndex, Field) \ 62 | template <> struct PODAllocator { \ 63 | MLC_INLINE_NO_MSVC static MLCAny *New(Type data) { \ 64 | MLCBoxedPOD *ret = new MLCBoxedPOD; \ 65 | ret->_mlc_header.type_index = static_cast(TypeIndex); \ 66 | ret->_mlc_header.ref_cnt = 0; \ 67 | ret->_mlc_header.v.deleter = PODAllocator::Deleter; \ 68 | ret->data.v_int64 = 0; \ 69 | ret->data.Field = data; \ 70 | return reinterpret_cast(ret); \ 71 | } \ 72 | static void Deleter(void *objptr) { delete static_cast(objptr); } \ 73 | } 74 | 75 | MLC_DEF_POD_ALLOCATOR(bool, MLCTypeIndex::kMLCBool, v_bool); 76 | MLC_DEF_POD_ALLOCATOR(int64_t, MLCTypeIndex::kMLCInt, v_int64); 77 | MLC_DEF_POD_ALLOCATOR(double, MLCTypeIndex::kMLCFloat, v_float64); 78 | MLC_DEF_POD_ALLOCATOR(DLDataType, 
MLCTypeIndex::kMLCDataType, v_dtype); 79 | MLC_DEF_POD_ALLOCATOR(DLDevice, MLCTypeIndex::kMLCDevice, v_device); 80 | MLC_DEF_POD_ALLOCATOR(::mlc::base::VoidPtr, MLCTypeIndex::kMLCPtr, v_ptr); 81 | 82 | #undef MLC_DEF_POD_ALLOCATOR 83 | 84 | MLC_INLINE mlc::Object *AllocExternObject(int32_t type_index, int32_t num_bytes) { 85 | MLCAny *ptr = reinterpret_cast(std::malloc(num_bytes)); 86 | std::memset(ptr, 0, num_bytes); 87 | ptr->type_index = type_index; 88 | ptr->ref_cnt = 0; 89 | ptr->v.deleter = MLCExtObjDelete; 90 | return reinterpret_cast(ptr); 91 | } 92 | 93 | } // namespace mlc 94 | 95 | #endif // MLC_BASE_ALLOC_H_ 96 | -------------------------------------------------------------------------------- /include/mlc/base/lib.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_LIB_H_ 2 | #define MLC_BASE_LIB_H_ 3 | 4 | #include "./utils.h" 5 | 6 | namespace mlc { 7 | 8 | struct VTable { 9 | VTable(const VTable &) = delete; 10 | VTable &operator=(const VTable &) = delete; 11 | VTable(VTable &&other) noexcept : self(other.self) { other.self = nullptr; } 12 | VTable &operator=(VTable &&other) noexcept { 13 | this->Swap(other); 14 | return *this; 15 | } 16 | ~VTable() { MLC_CHECK_ERR(::MLCVTableDelete(self)); } 17 | 18 | template R operator()(Args... 
args) const; 19 | template VTable &Set(Func func); 20 | 21 | private: 22 | friend struct Lib; 23 | VTable(MLCVTableHandle self) : self(self) {} 24 | void Swap(VTable &other) { std::swap(self, other.self); } 25 | MLCVTableHandle self; 26 | }; 27 | 28 | struct Lib { 29 | static int32_t FuncSetGlobal(const char *name, FuncObj *func, bool allow_override = false); 30 | static FuncObj *FuncGetGlobal(const char *name, bool allow_missing = false); 31 | static ::mlc::Str CxxStr(AnyView obj); 32 | static ::mlc::Str Str(AnyView obj); 33 | static int64_t StructuralHash(AnyView obj); 34 | static bool StructuralEqual(AnyView a, AnyView b, bool bind_free_vars = true, bool assert_mode = false); 35 | static Any IRPrint(AnyView obj, AnyView printer, AnyView path); 36 | static const char *DeviceTypeToStr(int32_t device_type); 37 | static int32_t DeviceTypeFromStr(const char *source); 38 | static void DeviceTypeRegister(const char *name); 39 | static const char *DataTypeCodeToStr(int32_t dtype_code); 40 | static DLDataType DataTypeFromStr(const char *source); 41 | static void DataTypeRegister(const char *name, int32_t dtype_bits); 42 | 43 | static FuncObj *_init(int32_t type_index) { return VTableGetFunc(init, type_index, "__init__"); } 44 | static VTable MakeVTable(const char *name) { 45 | MLCVTableHandle vtable = nullptr; 46 | MLC_CHECK_ERR(::MLCVTableCreate(_lib, name, &vtable)); 47 | return VTable(vtable); 48 | } 49 | MLC_INLINE static MLCTypeInfo *GetTypeInfo(int32_t type_index) { 50 | MLCTypeInfo *type_info = nullptr; 51 | MLC_CHECK_ERR(::MLCTypeIndex2Info(_lib, type_index, &type_info)); 52 | return type_info; 53 | } 54 | MLC_INLINE static MLCTypeInfo *GetTypeInfo(const char *type_key) { 55 | MLCTypeInfo *type_info = nullptr; 56 | MLC_CHECK_ERR(::MLCTypeKey2Info(_lib, type_key, &type_info)); 57 | return type_info; 58 | } 59 | MLC_INLINE static const char *GetTypeKey(int32_t type_index) { 60 | if (MLCTypeInfo *type_info = GetTypeInfo(type_index)) { 61 | return 
type_info->type_key; 62 | } 63 | return "(undefined)"; 64 | } 65 | MLC_INLINE static const char *GetTypeKey(const MLCAny *self) { 66 | if (self == nullptr) { 67 | return "None"; 68 | } else if (MLCTypeInfo *type_info = GetTypeInfo(self->type_index)) { 69 | return type_info->type_key; 70 | } 71 | return "(undefined)"; 72 | } 73 | MLC_INLINE static int32_t GetTypeIndex(const char *type_key) { 74 | if (MLCTypeInfo *type_info = GetTypeInfo(type_key)) { 75 | return type_info->type_index; 76 | } 77 | MLC_THROW(TypeError) << "Cannot find type with key: " << type_key; 78 | MLC_UNREACHABLE(); 79 | } 80 | MLC_INLINE static MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) { 81 | MLCTypeInfo *info = nullptr; 82 | MLC_CHECK_ERR(::MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info)); 83 | return info; 84 | } 85 | 86 | private: 87 | static FuncObj *VTableGetFunc(MLCVTableHandle vtable, int32_t type_index, const char *vtable_name) { 88 | MLCAny func{}; 89 | MLC_CHECK_ERR(::MLCVTableGetFunc(vtable, type_index, true, &func)); 90 | if (!::mlc::base::IsTypeIndexPOD(func.type_index)) { 91 | ::mlc::base::DecRef(func.v.v_obj); 92 | } 93 | FuncObj *ret = reinterpret_cast(func.v.v_obj); 94 | if (func.type_index == kMLCNone) { 95 | MLC_THROW(TypeError) << "Function `" << vtable_name << "` for type: " << GetTypeKey(type_index) 96 | << " is not defined in the vtable"; 97 | } else if (func.type_index != kMLCFunc) { 98 | MLC_THROW(TypeError) << "Function `" << vtable_name << "` for type: " << GetTypeKey(type_index) 99 | << " is not callable. 
Its type is " << GetTypeKey(func.type_index); 100 | } 101 | return ret; 102 | } 103 | static MLCVTableHandle VTableGetGlobal(const char *name) { 104 | MLCVTableHandle ret = nullptr; 105 | MLC_CHECK_ERR(::MLCVTableGetGlobal(_lib, name, &ret)); 106 | return ret; 107 | } 108 | static MLC_SYMBOL_HIDE inline MLCTypeTableHandle _lib = []() { 109 | MLCTypeTableHandle ret = nullptr; 110 | MLC_CHECK_ERR(::MLCHandleGetGlobal(&ret)); 111 | return ret; 112 | }(); 113 | static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__"); 114 | static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__"); 115 | static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__"); 116 | static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__"); 117 | }; 118 | 119 | } // namespace mlc 120 | #endif // MLC_BASE_LIB_H_ 121 | -------------------------------------------------------------------------------- /include/mlc/base/traits_device.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_TRAITS_DEVICE_H_ 2 | #define MLC_BASE_TRAITS_DEVICE_H_ 3 | 4 | #include "./lib.h" 5 | #include "./utils.h" 6 | 7 | namespace mlc { 8 | namespace base { 9 | 10 | DLDevice DeviceFromStr(const std::string &source); 11 | 12 | inline bool DeviceEqual(DLDevice a, DLDevice b) { return a.device_type == b.device_type && a.device_id == b.device_id; } 13 | inline const char *DeviceType2Str(int32_t device_type) { return ::mlc::Lib::DeviceTypeToStr(device_type); } 14 | 15 | template <> struct TypeTraits { 16 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCDevice); 17 | static constexpr const char *type_str = "Device"; 18 | 19 | MLC_INLINE static void TypeToAny(DLDevice src, MLCAny *ret) { 20 | ret->type_index = static_cast(MLCTypeIndex::kMLCDevice); 21 | ret->v.v_device = src; 22 | } 23 | 24 | MLC_INLINE static DLDevice AnyToTypeOwned(const MLCAny *v) { 25 | 
MLCTypeIndex ty = static_cast(v->type_index); 26 | if (ty == MLCTypeIndex::kMLCDevice) { 27 | return v->v.v_device; 28 | } 29 | if (ty == MLCTypeIndex::kMLCRawStr) { 30 | return DeviceFromStr(v->v.v_str); 31 | } 32 | if (ty == MLCTypeIndex::kMLCStr) { 33 | return DeviceFromStr(reinterpret_cast(v->v.v_obj)->data); 34 | } 35 | throw TemporaryTypeError(); 36 | } 37 | 38 | MLC_INLINE static DLDevice AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 39 | 40 | MLC_INLINE static std::string __str__(DLDevice device) { 41 | std::ostringstream os; 42 | os << DeviceType2Str(static_cast(device.device_type)) << ":" << device.device_id; 43 | return os.str(); 44 | } 45 | }; 46 | 47 | inline DLDevice DeviceFromStr(const std::string &source) { 48 | constexpr int64_t i32_max = 2147483647; 49 | int32_t device_type; 50 | int64_t device_id = 0; 51 | try { 52 | if (size_t c_pos = source.rfind(':'); c_pos != std::string::npos) { 53 | device_type = ::mlc::Lib::DeviceTypeFromStr(source.substr(0, c_pos).c_str()); 54 | device_id = StrToInt(source, c_pos + 1); 55 | } else { 56 | device_type = ::mlc::Lib::DeviceTypeFromStr(source.c_str()); 57 | device_id = 0; 58 | } 59 | if (device_type < 0 || device_id < 0 || device_id > i32_max) { 60 | throw std::runtime_error(""); // Going to catch it below 61 | } 62 | return DLDevice{static_cast(device_type), static_cast(device_id)}; 63 | } catch (...) 
{ 64 | } 65 | MLC_THROW(ValueError) << "Cannot convert to `Device` from string: " << source; 66 | MLC_UNREACHABLE(); 67 | } 68 | 69 | inline std::string DeviceToStr(DLDevice device) { return TypeTraits::__str__(device); } 70 | 71 | } // namespace base 72 | } // namespace mlc 73 | 74 | #endif // MLC_BASE_TRAITS_DEVICE_H_ 75 | -------------------------------------------------------------------------------- /include/mlc/base/traits_dtype.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_TRAITS_DTYPE_H_ 2 | #define MLC_BASE_TRAITS_DTYPE_H_ 3 | 4 | #include "./lib.h" 5 | #include "./utils.h" 6 | 7 | namespace mlc { 8 | namespace base { 9 | 10 | inline const char *DataTypeCode2Str(int32_t type_code) { return ::mlc::Lib::DataTypeCodeToStr(type_code); } 11 | 12 | template <> struct TypeTraits { 13 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCDataType); 14 | static constexpr const char *type_str = "dtype"; 15 | 16 | MLC_INLINE static void TypeToAny(DLDataType src, MLCAny *ret) { 17 | ret->type_index = static_cast(MLCTypeIndex::kMLCDataType); 18 | ret->v.v_dtype = src; 19 | } 20 | 21 | MLC_INLINE static DLDataType AnyToTypeOwned(const MLCAny *v) { 22 | MLCTypeIndex ty = static_cast(v->type_index); 23 | if (ty == MLCTypeIndex::kMLCDataType) { 24 | return v->v.v_dtype; 25 | } 26 | if (ty == MLCTypeIndex::kMLCRawStr) { 27 | return ::mlc::Lib::DataTypeFromStr(v->v.v_str); 28 | } 29 | if (ty == MLCTypeIndex::kMLCStr) { 30 | return ::mlc::Lib::DataTypeFromStr(reinterpret_cast(v->v.v_obj)->data); 31 | } 32 | throw TemporaryTypeError(); 33 | } 34 | 35 | MLC_INLINE static DLDataType AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 36 | 37 | MLC_INLINE static std::string __str__(DLDataType dtype) { 38 | int32_t code = static_cast(dtype.code); 39 | int32_t bits = static_cast(dtype.bits); 40 | int32_t lanes = static_cast(dtype.lanes); 41 | if (code == kDLUInt && bits == 1 && lanes == 1) { 42 | 
return "bool"; 43 | } 44 | if (code == kDLOpaqueHandle && bits == 0 && lanes == 0) { 45 | return "void"; 46 | } 47 | std::ostringstream os; 48 | os << DataTypeCode2Str(code); 49 | if (code < kMLCExtension_DLDataTypeCode_Begin) { 50 | // for `code >= kMLCExtension_DLDataTypeCode_Begin`, the `bits` is already encoded in `code` 51 | os << bits; 52 | } 53 | if (lanes != 1) { 54 | os << "x" << lanes; 55 | } 56 | return os.str(); 57 | } 58 | }; 59 | 60 | struct DType { 61 | static DLDataType Int(int bits, int lanes = 1) { 62 | DLDataType dtype; 63 | dtype.code = kDLInt; 64 | dtype.bits = static_cast(bits); 65 | dtype.lanes = static_cast(lanes); 66 | return dtype; 67 | } 68 | static DLDataType UInt(int bits, int lanes = 1) { 69 | DLDataType dtype; 70 | dtype.code = kDLUInt; 71 | dtype.bits = static_cast(bits); 72 | dtype.lanes = static_cast(lanes); 73 | return dtype; 74 | } 75 | static DLDataType Float(int bits, int lanes = 1) { 76 | DLDataType dtype; 77 | dtype.code = kDLFloat; 78 | dtype.bits = static_cast(bits); 79 | dtype.lanes = static_cast(lanes); 80 | return dtype; 81 | } 82 | static DLDataType Bool(int lanes = 1) { 83 | DLDataType dtype; 84 | dtype.code = kDLUInt; 85 | dtype.bits = 1; 86 | dtype.lanes = static_cast(lanes); 87 | return dtype; 88 | } 89 | static DLDataType Void() { 90 | DLDataType dtype; 91 | dtype.code = kDLOpaqueHandle; 92 | dtype.bits = 0; 93 | dtype.lanes = 0; 94 | return dtype; 95 | } 96 | static bool Equal(DLDataType a, DLDataType b) { return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes; } 97 | static bool IsIntOrUIntOrBool(DLDataType dtype) { return dtype.code == kDLInt || dtype.code == kDLUInt; } 98 | static bool IsBool(DLDataType dtype) { return dtype.code == kDLUInt && dtype.bits == 1; } 99 | static bool IsFloat(DLDataType dtype) { 100 | // TODO: handle fp8 101 | return dtype.code == kDLFloat || dtype.code == kDLBfloat; 102 | } 103 | static std::string Str(DLDataType dtype) { return TypeTraits::__str__(dtype); } 104 | static 
int32_t Size(DLDataType dtype) { 105 | int32_t bits = static_cast(dtype.bits); 106 | int32_t lanes = static_cast(dtype.lanes); 107 | return ((bits + 7) / 8) * lanes; 108 | } 109 | }; 110 | 111 | } // namespace base 112 | } // namespace mlc 113 | 114 | #endif // MLC_BASE_TRAITS_DTYPE_H_ 115 | -------------------------------------------------------------------------------- /include/mlc/base/traits_object.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_TRAITS_OBJECT_H_ 2 | #define MLC_BASE_TRAITS_OBJECT_H_ 3 | 4 | #include "./lib.h" 5 | #include "./utils.h" 6 | 7 | namespace mlc { 8 | namespace base { 9 | 10 | template struct ObjPtrTraitsDefault { 11 | MLC_INLINE static void TypeToAny(T *src, MLCAny *ret) { 12 | if (src == nullptr) { 13 | ret->type_index = static_cast(MLCTypeIndex::kMLCNone); 14 | ret->v.v_obj = nullptr; 15 | } else { 16 | ret->type_index = src->_mlc_header.type_index; 17 | ret->v.v_obj = const_cast(reinterpret_cast(src)); 18 | } 19 | } 20 | MLC_INLINE static T *AnyToTypeUnowned(const MLCAny *v) { 21 | if (::mlc::base::IsTypeIndexNone(v->type_index)) { 22 | return nullptr; 23 | } 24 | if (!::mlc::base::IsTypeIndexPOD(v->type_index) && ::mlc::base::IsInstanceOf(v)) { 25 | return reinterpret_cast(v->v.v_obj); 26 | } 27 | throw TemporaryTypeError(); 28 | } 29 | MLC_INLINE static T *AnyToTypeOwned(const MLCAny *v) { return AnyToTypeUnowned(v); } 30 | MLC_INLINE static T *AnyToTypeWithStorage(const MLCAny *v, Any *storage) = delete; 31 | }; 32 | 33 | template struct TypeTraits && !IsTemplate>> { 34 | MLC_INLINE static void TypeToAny(T *src, MLCAny *ret) { ObjPtrTraitsDefault::TypeToAny(src, ret); } 35 | MLC_INLINE static T *AnyToTypeUnowned(const MLCAny *v) { return ObjPtrTraitsDefault::AnyToTypeUnowned(v); } 36 | MLC_INLINE static T *AnyToTypeOwned(const MLCAny *v) { return ObjPtrTraitsDefault::AnyToTypeOwned(v); } 37 | }; 38 | 39 | template struct TypeTraits *> { 40 | using T = ListObj; 41 | 
MLC_INLINE static void TypeToAny(T *src, MLCAny *ret) { ObjPtrTraitsDefault::TypeToAny(src, ret); } 42 | MLC_INLINE static T *AnyToTypeOwned(const MLCAny *v) { return AnyToTypeUnowned(v); } 43 | MLC_INLINE static T *AnyToTypeUnowned(const MLCAny *v); 44 | }; 45 | template struct TypeTraits *> { 46 | using T = DictObj; 47 | MLC_INLINE static void TypeToAny(T *src, MLCAny *ret) { ObjPtrTraitsDefault::TypeToAny(src, ret); } 48 | MLC_INLINE static T *AnyToTypeOwned(const MLCAny *v) { return AnyToTypeUnowned(v); } 49 | MLC_INLINE static T *AnyToTypeUnowned(const MLCAny *v); 50 | }; 51 | 52 | template MLC_INLINE bool IsInstanceOf(const MLCAny *self) { 53 | if constexpr (std::is_same_v || std::is_base_of_v) { 54 | return true; 55 | } 56 | // Special case: `DerivedType` is exactly the underlying type of `type_index` 57 | if (self == nullptr) { 58 | return false; 59 | } 60 | int32_t type_index = self->type_index; 61 | if (type_index == DerivedType::_type_index) { 62 | return true; 63 | } 64 | // Given an index `i = DerivedType::_type_index`, 65 | // and the underlying type of `T = *this`, we wanted to check if 66 | // `T::_type_ancestors[i] == DerivedType::_type_index`. 67 | // 68 | // There are 3 ways to reflect `T` out of `this->type_index`: 69 | // (Case 1) Use `SelfType` as a surrogate type if `SelfType::_type_ancestors` 70 | // is long enough, whose length is reflected by `SelfType::_type_depth`. 71 | if constexpr (SelfType::_type_depth > DerivedType::_type_depth) { 72 | return SelfType::_type_ancestors[DerivedType::_type_depth] == DerivedType::_type_index; 73 | } 74 | if constexpr (SelfType::_type_depth == DerivedType::_type_depth) { 75 | return SelfType::_type_index == DerivedType::_type_index; 76 | } 77 | // (Case 2) If `type_index` falls in static object section 78 | if (::mlc::base::IsTypeIndexPOD(type_index)) { 79 | return false; 80 | } 81 | // (Case 3) Look up the type table for type hierarchy via `type_index`. 
82 | if (MLCTypeInfo *info = Lib::GetTypeInfo(type_index)) { 83 | return info->type_depth > DerivedType::_type_depth && 84 | info->type_ancestors[DerivedType::_type_depth] == DerivedType::_type_index; 85 | } 86 | MLC_THROW(InternalError) << "Undefined type index: " << type_index; 87 | MLC_UNREACHABLE(); 88 | } 89 | } // namespace base 90 | } // namespace mlc 91 | 92 | #endif // MLC_BASE_TRAITS_OBJECT_H_ 93 | -------------------------------------------------------------------------------- /include/mlc/base/traits_scalar.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_TRAITS_SCALAR_H_ 2 | #define MLC_BASE_TRAITS_SCALAR_H_ 3 | 4 | #include "./utils.h" 5 | #include 6 | #include 7 | 8 | namespace mlc { 9 | namespace base { 10 | 11 | template <> struct TypeTraits { 12 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCBool); 13 | static constexpr const char *type_str = "bool"; 14 | 15 | MLC_INLINE static void TypeToAny(bool src, MLCAny *ret) { 16 | ret->type_index = static_cast(MLCTypeIndex::kMLCBool); 17 | ret->v.v_int64 = static_cast(src); 18 | } 19 | MLC_INLINE static bool AnyToTypeOwned(const MLCAny *v) { 20 | MLCTypeIndex ty = static_cast(v->type_index); 21 | if (ty == MLCTypeIndex::kMLCBool) { 22 | return v->v.v_bool; 23 | } 24 | throw TemporaryTypeError(); 25 | } 26 | MLC_INLINE static bool AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 27 | MLC_INLINE static std::string __str__(bool src) { return src ? 
"True" : "False"; } 28 | }; 29 | 30 | template struct TypeTraits>> { 31 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCInt); 32 | static constexpr const char *type_str = "int"; 33 | 34 | MLC_INLINE static void TypeToAny(Int src, MLCAny *ret) { 35 | ret->type_index = static_cast(MLCTypeIndex::kMLCInt); 36 | ret->v.v_int64 = static_cast(src); 37 | } 38 | MLC_INLINE static Int AnyToTypeOwned(const MLCAny *v) { 39 | MLCTypeIndex ty = static_cast(v->type_index); 40 | if (ty == MLCTypeIndex::kMLCInt) { 41 | return static_cast(v->v.v_int64); 42 | } 43 | throw TemporaryTypeError(); 44 | } 45 | MLC_INLINE static Int AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 46 | MLC_INLINE static std::string __str__(Int src) { return std::to_string(src); } 47 | }; 48 | 49 | template struct TypeTraits>> { 50 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCFloat); 51 | static constexpr const char *type_str = "float"; 52 | 53 | MLC_INLINE static void TypeToAny(Float src, MLCAny *ret) { 54 | ret->type_index = static_cast(MLCTypeIndex::kMLCFloat); 55 | ret->v.v_float64 = src; 56 | } 57 | MLC_INLINE static Float AnyToTypeOwned(const MLCAny *v) { 58 | MLCTypeIndex ty = static_cast(v->type_index); 59 | if (ty == MLCTypeIndex::kMLCFloat) { 60 | return static_cast(v->v.v_float64); 61 | } else if (ty == MLCTypeIndex::kMLCInt) { 62 | return static_cast(v->v.v_int64); 63 | } 64 | throw TemporaryTypeError(); 65 | } 66 | MLC_INLINE static Float AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 67 | MLC_INLINE static std::string __str__(Float src) { return std::to_string(src); } 68 | }; 69 | 70 | template <> struct TypeTraits { 71 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCPtr); 72 | static constexpr const char *type_str = "Ptr"; 73 | 74 | MLC_INLINE static void TypeToAny(void *src, MLCAny *ret) { 75 | ret->type_index = 76 | (src == nullptr) ? 
static_cast(MLCTypeIndex::kMLCNone) : static_cast(MLCTypeIndex::kMLCPtr); 77 | ret->v.v_ptr = src; 78 | } 79 | MLC_INLINE static void *AnyToTypeOwned(const MLCAny *v) { 80 | MLCTypeIndex ty = static_cast(v->type_index); 81 | if (ty == MLCTypeIndex::kMLCPtr || ty == MLCTypeIndex::kMLCRawStr || ty == MLCTypeIndex::kMLCNone) { 82 | return v->v.v_ptr; 83 | } 84 | throw TemporaryTypeError(); 85 | } 86 | MLC_INLINE static void *AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 87 | MLC_INLINE static std::string __str__(void *src) { 88 | if (src == nullptr) { 89 | return "None"; 90 | } else { 91 | std::ostringstream oss; 92 | oss << "0x" << std::setfill('0') << std::setw(12) << std::hex << (uintptr_t)(src); 93 | return oss.str(); 94 | } 95 | } 96 | }; 97 | 98 | template <> struct TypeTraits : public TypeTraits { 99 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCNone); 100 | static constexpr const char *type_str = "None"; 101 | }; 102 | 103 | } // namespace base 104 | } // namespace mlc 105 | 106 | #endif // MLC_BASE_TRAITS_SCALAR_H_ 107 | -------------------------------------------------------------------------------- /include/mlc/base/traits_str.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_BASE_TRAITS_STR_H_ 2 | #define MLC_BASE_TRAITS_STR_H_ 3 | 4 | #include "./utils.h" 5 | 6 | namespace mlc { 7 | namespace base { 8 | 9 | template <> struct TypeTraits { 10 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCRawStr); 11 | static constexpr const char *type_str = "char *"; 12 | 13 | static void TypeToAny(const char *src, MLCAny *ret) { 14 | ret->type_index = static_cast(MLCTypeIndex::kMLCRawStr); 15 | ret->v.v_str = src; 16 | } 17 | static const char *AnyToTypeOwned(const MLCAny *v) { 18 | MLCTypeIndex ty = static_cast(v->type_index); 19 | if (ty == MLCTypeIndex::kMLCRawStr) { 20 | return v->v.v_str; 21 | } 22 | if (ty == MLCTypeIndex::kMLCStr) { 23 | return 
reinterpret_cast(v->v.v_obj)->data; 24 | } 25 | throw TemporaryTypeError(); 26 | } 27 | static const char *AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 28 | static std::string __str__(const char *src) { return '"' + std::string(src) + '"'; } 29 | }; 30 | 31 | template <> struct TypeTraits { 32 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCRawStr); 33 | static constexpr const char *type_str = "char *"; 34 | 35 | static void TypeToAny(char *src, MLCAny *ret) { return TypeTraits::TypeToAny(src, ret); } 36 | static char *AnyToTypeOwned(const MLCAny *v) { 37 | return const_cast(TypeTraits::AnyToTypeOwned(v)); 38 | } 39 | static const char *AnyToTypeUnowned(const MLCAny *v) { return AnyToTypeOwned(v); } 40 | static std::string __str__(const char *src) { return '"' + std::string(src) + '"'; } 41 | }; 42 | 43 | template <> struct TypeTraits { 44 | static constexpr int32_t type_index = static_cast(MLCTypeIndex::kMLCRawStr); 45 | static constexpr const char *type_str = "char *"; 46 | 47 | static void TypeToAny(const std::string &src, MLCAny *ret) { 48 | return TypeTraits::TypeToAny(src.data(), ret); 49 | } 50 | static std::string AnyToTypeOwned(const MLCAny *v) { return TypeTraits::AnyToTypeOwned(v); } 51 | static std::string AnyToTypeUnowned(const MLCAny *v) { return TypeTraits::AnyToTypeUnowned(v); } 52 | static std::string __str__(const char *src) { return '"' + std::string(src) + '"'; } 53 | }; 54 | 55 | template struct TypeTraits : public TypeTraits {}; 56 | 57 | } // namespace base 58 | } // namespace mlc 59 | 60 | #endif // MLC_BASE_TRAITS_STR_H_ 61 | -------------------------------------------------------------------------------- /include/mlc/core/error.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_CORE_ERROR_H_ 2 | #define MLC_CORE_ERROR_H_ 3 | 4 | #include "./object.h" 5 | #include 6 | #include 7 | 8 | namespace mlc { 9 | 10 | struct ErrorObj : public MLCError { 11 | 
struct Allocator; 12 | std::string __str__() const { return this->ByteArray(); } 13 | const char *ByteArray() const { return reinterpret_cast(this + 1); } 14 | char *ByteArray() { return reinterpret_cast(this + 1); } 15 | const char *kind() const { return MLCError::kind; } 16 | 17 | explicit ErrorObj(const char *kind, MLCByteArray message, MLCByteArray traceback) { 18 | // Assumption: 19 | // 1) `message` ends with no '\0' 20 | // 2) `traceback` ends with exactly 1 '\0' 21 | MLCError::kind = kind; 22 | char *byte_array = this->ByteArray(); 23 | std::memcpy(byte_array, message.bytes, message.num_bytes); 24 | byte_array[message.num_bytes] = '\0'; 25 | byte_array += message.num_bytes + 1; 26 | std::memcpy(byte_array, traceback.bytes, traceback.num_bytes); 27 | byte_array[traceback.num_bytes] = '\0'; 28 | } 29 | 30 | explicit ErrorObj(const char *kind, int64_t num_bytes, const char *bytes) { 31 | MLCError::kind = kind; 32 | char *byte_array = this->ByteArray(); 33 | std::memcpy(byte_array, bytes, num_bytes); 34 | byte_array[num_bytes] = '\0'; 35 | } 36 | 37 | inline Ref AppendWith(MLCByteArray traceback) const; 38 | 39 | void GetInfo(std::vector *ret) const { 40 | ret->clear(); 41 | const char *bytes = this->ByteArray(); 42 | while (*bytes != '\0') { 43 | ret->push_back(bytes); 44 | bytes += std::strlen(bytes) + 1; 45 | } 46 | } 47 | 48 | void FormatExc(std::ostream &os) const { 49 | std::vector info; 50 | this->GetInfo(&info); 51 | os << "Traceback (most recent call last):" << std::endl; 52 | int frame_id = 0; 53 | for (size_t i = 1; i < info.size(); i += 3) { 54 | const char *filename = info[i]; 55 | const char *funcname = info[i + 2]; 56 | os << " [" << ++frame_id << "] File \"" << filename << "\", line " << info[i + 1] << ", in " << funcname 57 | << std::endl; 58 | } 59 | os << this->kind() << ": " << info[0] << std::endl; 60 | } 61 | 62 | MLC_DEF_STATIC_TYPE(MLC_EXPORTS, ErrorObj, Object, MLCTypeIndex::kMLCError, "object.Error"); 63 | }; 64 | 65 | struct 
ErrorObj::Allocator { 66 | MLC_INLINE static ErrorObj *New(const char *kind, MLCByteArray message, MLCByteArray traceback) { 67 | return ::mlc::DefaultObjectAllocator::NewWithPad(message.num_bytes + traceback.num_bytes + 2, kind, 68 | message, traceback); 69 | } 70 | MLC_INLINE static ErrorObj *New(const char *kind, int64_t num_bytes, const char *bytes) { 71 | return ::mlc::DefaultObjectAllocator::NewWithPad(num_bytes + 1, kind, num_bytes, bytes); 72 | } 73 | }; 74 | 75 | struct Error : public ObjectRef { 76 | MLC_DEF_OBJ_REF(MLC_EXPORTS, Error, ErrorObj, ObjectRef) 77 | .Field("kind", &MLCError::kind, /*frozen=*/true) 78 | .MemFn("__str__", &ErrorObj::__str__); 79 | explicit Error(const char *kind, MLCByteArray message, MLCByteArray traceback) 80 | : Error(Error::New(kind, message, traceback)) {} 81 | explicit Error(const char *kind, int64_t num_bytes, const char *bytes) : Error(Error::New(kind, num_bytes, bytes)) {} 82 | }; 83 | 84 | inline Ref ErrorObj::AppendWith(MLCByteArray traceback) const { 85 | MLCByteArray self; 86 | self.bytes = this->ByteArray(); 87 | self.num_bytes = [&]() -> int64_t { 88 | const char *end_bytes = this->ByteArray(); 89 | while (*end_bytes != '\0') { 90 | end_bytes += std::strlen(end_bytes) + 1; 91 | } 92 | return end_bytes - self.bytes - 1; 93 | }(); 94 | return Ref::New(MLCError::kind, self, traceback); 95 | } 96 | 97 | inline const char *Exception::what() const noexcept(true) { 98 | if (data_.get() == nullptr) { 99 | return "mlc::ffi::Exception: Unspecified"; 100 | } 101 | return Obj()->ByteArray(); 102 | } 103 | 104 | inline Exception::Exception(Ref data) : data_(data.get()) {} 105 | 106 | inline void Exception::FormatExc(std::ostream &os) const { 107 | if (data_.get()) { 108 | Obj()->FormatExc(os); 109 | } else { 110 | os << "mlc.Exception: Unspecified"; 111 | } 112 | } 113 | } // namespace mlc 114 | 115 | namespace mlc { 116 | namespace base { 117 | [[noreturn]] inline void MLCThrowError(const char *kind, MLCByteArray message, 
MLCByteArray traceback) noexcept(false) { 118 | throw Exception(Ref::New(kind, message, traceback)); 119 | } 120 | inline Any MLCCreateError(const char *kind, const std::string &message, MLCByteArray traceback) { 121 | return Ref::New(kind, MLCByteArray{static_cast(message.size()), message.data()}, traceback); 122 | } 123 | } // namespace base 124 | } // namespace mlc 125 | 126 | #endif // MLC_CORE_ERROR_H_ 127 | -------------------------------------------------------------------------------- /include/mlc/core/func.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_CORE_FUNC_H_ 2 | #define MLC_CORE_FUNC_H_ 3 | 4 | #include "./error.h" 5 | #include "./str.h" 6 | #include "./utils.h" 7 | #include 8 | #include 9 | 10 | #define MLC_REGISTER_FUNC(name) [[maybe_unused]] static auto MLC_UNIQUE_ID() = ::mlc::core::FuncRegistryHelper(name) 11 | 12 | namespace mlc { 13 | struct FuncObj : public MLCFunc { 14 | using Call = void(const FuncObj *, int32_t, const AnyView *, Any *); 15 | using SafeCall = int32_t(const FuncObj *, int32_t, const AnyView *, Any *); 16 | struct Allocator; 17 | 18 | template inline Any operator()(Args &&...args) const { 19 | constexpr size_t N = sizeof...(Args); 20 | AnyViewArray stack_args; 21 | Any ret; 22 | stack_args.Fill(std::forward(args)...); 23 | ::mlc::base::FuncCall(this, N, stack_args.v, &ret); 24 | return ret; 25 | } 26 | 27 | static Ref FromForeign(void *self, MLCDeleterType deleter, MLCFuncSafeCallType safe_call); 28 | 29 | static int32_t SafeCallImpl(const FuncObj *self, int32_t num_args, const AnyView *args, Any *ret) { 30 | MLC_SAFE_CALL_BEGIN(); 31 | self->call(self, num_args, args, ret); 32 | MLC_SAFE_CALL_END(ret); 33 | } 34 | 35 | MLC_INLINE FuncObj(Call *f) : MLCFunc() { 36 | this->MLCFunc::call = reinterpret_cast(f); 37 | this->MLCFunc::safe_call = reinterpret_cast(FuncObj::SafeCallImpl); 38 | } 39 | 40 | MLC_DEF_STATIC_TYPE(MLC_EXPORTS, FuncObj, Object, MLCTypeIndex::kMLCFunc, 
"object.Func"); 41 | }; 42 | 43 | struct FuncObj::Allocator { 44 | public: 45 | template >> 46 | MLC_INLINE static FuncObj *New(FuncType func); 47 | }; 48 | 49 | struct Func : public ObjectRef { 50 | template MLC_INLINE Any operator()(Args &&...args) const { 51 | return get()->operator()(std::forward(args)...); 52 | } 53 | template >> 54 | Func(FuncType func) : Func(FuncObj::Allocator::New(std::move(func))) {} 55 | static FuncObj *GetGlobal(const char *name, bool allow_missing = false) { 56 | return Lib::FuncGetGlobal(name, allow_missing); 57 | } 58 | MLC_DEF_OBJ_REF(MLC_EXPORTS, Func, FuncObj, ObjectRef).MemFn("__str__", ::mlc::core::StringifyOpaque); 59 | }; 60 | } // namespace mlc 61 | 62 | namespace mlc { 63 | namespace core { 64 | template struct FuncImpl : public FuncObj { 65 | using TSelf = FuncImpl; 66 | using Allocator = ::mlc::DefaultObjectAllocator; 67 | MLC_INLINE FuncImpl(FuncType func, FuncObj::Call *f) : FuncObj(f), func_(std::forward(func)) {} 68 | mutable std::decay_t func_; 69 | }; 70 | 71 | struct FuncRegistryHelper { 72 | explicit FuncRegistryHelper(const char *name) : name(name) {} 73 | template FuncRegistryHelper &set_body(FuncType func, bool allow_override = false) { 74 | Ref f = Ref::New(std::forward(func)); 75 | Lib::FuncSetGlobal(name, f.get(), allow_override); 76 | return *this; 77 | } 78 | const char *name; 79 | }; 80 | 81 | } // namespace core 82 | } // namespace mlc 83 | 84 | namespace mlc { 85 | namespace base { 86 | inline void FuncCallCheckError(int32_t err_code, MLCAny *ret) noexcept(false) { 87 | Any err; 88 | if (ret != nullptr) { 89 | err = static_cast(*ret); 90 | } else { 91 | static_cast(err) = ::MLCGetLastError(); 92 | } 93 | if (err_code == -1) { // string errors 94 | MLC_THROW(InternalError) << "Error: " << err; 95 | } else if (err_code == -2) { // error objects 96 | throw Exception(err.operator Ref()->AppendWith(MLC_TRACEBACK_HERE())); 97 | } else { // error code 98 | MLC_THROW(InternalError) << "Error code: " << err_code; 
99 | } 100 | MLC_UNREACHABLE(); 101 | } 102 | inline void FuncCall(const void *self, int32_t num_args, const MLCAny *args, MLCAny *ret) { 103 | const MLCFunc *func = static_cast(self); 104 | if (func->call && reinterpret_cast(func->safe_call) == reinterpret_cast(FuncObj::SafeCallImpl)) { 105 | func->call(func, num_args, args, ret); 106 | } else if (int32_t err_code = func->safe_call(func, num_args, args, ret)) { 107 | FuncCallCheckError(err_code, ret); 108 | } 109 | } 110 | template inline auto GetGlobalFuncCall(const char *name) { 111 | using ArrayType = std::array; 112 | FuncObj *func = Func::GetGlobal(name); 113 | return [func](ArrayType &&args) { 114 | Any ret; 115 | ::mlc::base::FuncCall(func, num_args, std::move(args).data(), &ret); 116 | return ret; 117 | }; 118 | } 119 | template inline Any CallableToAny(Callable &&callable) { 120 | return Ref::New(std::forward(callable)); 121 | } 122 | } // namespace base 123 | } // namespace mlc 124 | 125 | #endif // MLC_CORE_FUNC_H_ 126 | -------------------------------------------------------------------------------- /include/mlc/core/opaque.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_CORE_OPAQUE_H_ 2 | #define MLC_CORE_OPAQUE_H_ 3 | 4 | #include "./object.h" 5 | #include "./typing.h" 6 | #include 7 | 8 | namespace mlc { 9 | 10 | struct OpaqueObj : public MLCOpaque { 11 | explicit OpaqueObj(void *handle_, void *handle_deleter_, const char *opaque_type_name_) : MLCOpaque{} { 12 | this->handle = handle_; 13 | this->handle_deleter = reinterpret_cast(handle_deleter_); 14 | this->opaque_type_name = [opaque_type_name_]() { 15 | char *ret = new char[std::strlen(opaque_type_name_) + 1]; 16 | std::memcpy(ret, opaque_type_name_, std::strlen(opaque_type_name_) + 1); 17 | return ret; 18 | }(); 19 | } 20 | ~OpaqueObj() { 21 | delete[] this->opaque_type_name; 22 | this->handle_deleter(this->handle); 23 | } 24 | std::string __str__() const { return "opaque_type_name) + "`>"; } 25 
| MLC_DEF_STATIC_TYPE(MLC_EXPORTS, OpaqueObj, Object, MLCTypeIndex::kMLCOpaque, "mlc.core.Opaque"); 26 | }; 27 | 28 | struct Opaque : public ObjectRef { 29 | MLC_DEF_OBJ_REF(MLC_EXPORTS, Opaque, OpaqueObj, ObjectRef) 30 | .StaticFn("__init__", InitOf) 31 | .MemFn("__str__", &OpaqueObj::__str__) 32 | .Field("handle", &OpaqueObj::handle, /*frozen=*/true) 33 | ._Field("handle_deleter", offsetof(MLCOpaque, handle_deleter), sizeof(MLCOpaque::handle_deleter), true, 34 | ::mlc::core::ParseType()) 35 | .Field("opaque_type_name", &OpaqueObj::opaque_type_name, /*frozen=*/true); 36 | explicit Opaque(void *handle, void *handle_deleter, const char *opaque_type_name) 37 | : Opaque(Opaque::New(handle, handle_deleter, opaque_type_name)) {} 38 | }; 39 | 40 | } // namespace mlc 41 | 42 | #endif // MLC_CORE_OPAQUE_H_ 43 | -------------------------------------------------------------------------------- /include/mlc/core/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_CORE_UTILS_H_ 2 | #define MLC_CORE_UTILS_H_ 3 | 4 | #include 5 | #include 6 | 7 | #define MLC_SAFE_CALL_BEGIN() \ 8 | try { \ 9 | (void)0 10 | 11 | #define MLC_SAFE_CALL_END(err_ret) \ 12 | return 0; \ 13 | } \ 14 | catch (::mlc::Exception & err) { \ 15 | *(err_ret) = std::move(err.data_); \ 16 | return -2; \ 17 | } \ 18 | catch (std::exception & err) { \ 19 | *(err_ret) = err.what(); \ 20 | return -1; \ 21 | } \ 22 | MLC_UNREACHABLE() 23 | 24 | namespace mlc { 25 | namespace core { 26 | 27 | /********** Section 1. 
Nested type checking *********/ 28 | 29 | struct NestedTypeError : public std::runtime_error { 30 | explicit NestedTypeError(const char *msg) : std::runtime_error(msg) {} 31 | 32 | struct Frame { 33 | std::string expected_type; 34 | std::vector indices; 35 | }; 36 | 37 | NestedTypeError &NewFrame(std::string expected_type) { 38 | frames.push_back(Frame{expected_type, {}}); 39 | return *this; 40 | } 41 | 42 | NestedTypeError &NewIndex(AnyView index) { 43 | frames.back().indices.push_back(index); 44 | return *this; 45 | } 46 | 47 | void Format(std::ostream &os, std::string overall_expected) const { 48 | int32_t num_frames = static_cast(frames.size()); 49 | if (num_frames == 1) { 50 | os << "Let input be `A: " << overall_expected // 51 | << "`. Type mismatch on `A"; 52 | for (auto rit = frames[0].indices.rbegin(); rit != frames[0].indices.rend(); ++rit) { 53 | os << "[" << *rit << "]"; 54 | } 55 | os << "`: " << this->what(); 56 | return; 57 | } 58 | int32_t last_var = num_frames; 59 | os << "Let input be `A_0: " << overall_expected << "`"; 60 | for (int32_t frame_id = num_frames - 1; frame_id >= 0; --frame_id) { 61 | const Frame &frame = frames[frame_id]; 62 | if (frame_id == 0 && frame.indices.empty()) { 63 | last_var = num_frames - 1; 64 | break; 65 | } 66 | os << ", `A_" << (num_frames - frame_id) // 67 | << ": " << frame.expected_type // 68 | << (frame_id == 0 ? " := A_" : " in A_") << (num_frames - frame_id - 1); 69 | for (auto rit = frame.indices.rbegin(); rit != frame.indices.rend(); ++rit) { 70 | os << "[" << *rit << "]"; 71 | } 72 | if (frame_id > 0) { 73 | os << ".keys()"; 74 | } 75 | os << "`"; 76 | } 77 | os << ". 
Type mismatch on `A_" << last_var << "`: " << this->what(); 78 | } 79 | 80 | std::vector frames; 81 | }; 82 | 83 | template struct NestedTypeCheck { 84 | MLC_INLINE_NO_MSVC static void Run(const MLCAny &any); 85 | }; 86 | template struct NestedTypeCheck> { 87 | MLC_INLINE_NO_MSVC static void Run(const MLCAny &any); 88 | }; 89 | template struct NestedTypeCheck> { 90 | MLC_INLINE_NO_MSVC static void Run(const MLCAny &any); 91 | }; 92 | 93 | } // namespace core 94 | } // namespace mlc 95 | 96 | #endif // MLC_CORE_UTILS_H_ 97 | -------------------------------------------------------------------------------- /include/mlc/printer/all.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_PRINTER_ALL_H_ 2 | #define MLC_PRINTER_ALL_H_ 3 | 4 | #include "./ast.h" // IWYU pragma: export 5 | #include "./ir_printer.h" // IWYU pragma: export 6 | 7 | namespace mlc { 8 | namespace printer { 9 | 10 | inline Expr ExprObj::Attr(mlc::Str name) const { 11 | return ::mlc::printer::Attr(::mlc::List<::mlc::core::ObjectPath>{}, Expr(this), name); 12 | } 13 | 14 | inline Expr ExprObj::Index(mlc::List<::mlc::printer::Expr> idx) const { 15 | return ::mlc::printer::Index(::mlc::List<::mlc::core::ObjectPath>{}, Expr(this), idx); 16 | } 17 | 18 | inline Expr ExprObj::Call(mlc::List<::mlc::printer::Expr> args) const { 19 | return ::mlc::printer::Call(::mlc::List<::mlc::core::ObjectPath>{}, Expr(this), args, mlc::List<::mlc::Str>{}, 20 | mlc::List<::mlc::printer::Expr>{}); 21 | } 22 | 23 | inline Expr ExprObj::CallKw(mlc::List<::mlc::printer::Expr> args, mlc::List<::mlc::Str> kwargs_keys, 24 | mlc::List<::mlc::printer::Expr> kwargs_values) const { 25 | return ::mlc::printer::Call(::mlc::List<::mlc::core::ObjectPath>{}, Expr(this), args, kwargs_keys, kwargs_values); 26 | } 27 | 28 | } // namespace printer 29 | } // namespace mlc 30 | 31 | #endif // MLC_PRINTER_ALL_H_ 32 | -------------------------------------------------------------------------------- 
/include/mlc/sym/all.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ALL_H_ 2 | #define MLC_SYM_ALL_H_ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #endif // MLC_SYM_ALL_H_ 9 | -------------------------------------------------------------------------------- /include/mlc/sym/analyzer.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_SYM_ANALYZER_H_ 2 | #define MLC_SYM_ANALYZER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace mlc { 9 | namespace sym { 10 | 11 | struct AnalyzerObj { 12 | MLCAny _mlc_header; 13 | MLC_DEF_DYN_TYPE(MLC_SYM_EXPORTS, AnalyzerObj, ::mlc::Object, "mlc.sym.Analyzer"); 14 | 15 | struct Impl; 16 | struct Testing; 17 | enum class ProofStrength : int { 18 | kDefault = 0, 19 | kSymbolicBound = 1, 20 | }; 21 | MLC_API AnalyzerObj(); 22 | MLC_API ~AnalyzerObj(); 23 | MLC_API void MarkGlobalNonNegValue(const Expr &value); 24 | MLC_API void Bind(const Var &var, const Expr &expr, bool allow_override = false); 25 | MLC_API void Bind(const Var &var, const Range &range, bool allow_override = false); 26 | MLC_API void Bind(const Dict &variables, bool allow_override = false); 27 | MLC_API bool CanProveGreaterEqual(const Expr &expr, int64_t lower_bound); 28 | MLC_API bool CanProveLess(const Expr &expr, int64_t upper_bound); 29 | MLC_API bool CanProveEqual(const Expr &lhs, const Expr &rhs); 30 | MLC_API bool CanProveLessEqualThanSymbolicShapeValue(const Expr &lhs, const Expr &shape); 31 | MLC_API bool CanProve(const Expr &cond, ProofStrength strength = ProofStrength::kDefault); 32 | MLC_API Expr Simplify(const Expr &expr, int steps = 2); 33 | 34 | private: 35 | friend struct IRMutatorWithAnalyzer; 36 | std::unique_ptr impl_; 37 | }; 38 | 39 | struct Analyzer : public ObjectRef { 40 | Analyzer() : ObjectRef(Analyzer::New()) {} 41 | MLC_DEF_OBJ_REF(MLC_SYM_EXPORTS, Analyzer, AnalyzerObj, ObjectRef) 42 | 
.StaticFn("__init__", InitOf) 43 | .MemFn("mark_global_non_neg_value", &AnalyzerObj::MarkGlobalNonNegValue) 44 | .MemFn("_bind_range", [](AnalyzerObj *self, const Var &var, Range range, 45 | bool allow_override) { self->Bind(var, range, allow_override); }) 46 | .MemFn("_bind_expr", [](AnalyzerObj *self, const Var &var, Expr expr, 47 | bool allow_override) { self->Bind(var, expr, allow_override); }) 48 | .MemFn("can_prove_greater_equal", &AnalyzerObj::CanProveGreaterEqual) 49 | .MemFn("can_prove_less", &AnalyzerObj::CanProveLess) 50 | .MemFn("can_prove_equal", &AnalyzerObj::CanProveEqual) 51 | .MemFn("can_prove_less_equal_than_symbolic_shape_value", &AnalyzerObj::CanProveLessEqualThanSymbolicShapeValue) 52 | .MemFn("can_prove", 53 | [](AnalyzerObj *self, const Expr &cond, int32_t strength) { 54 | return self->CanProve(cond, static_cast(strength)); 55 | }) 56 | .MemFn("simplify", &AnalyzerObj::Simplify); 57 | }; 58 | 59 | struct IRMutatorWithAnalyzer : public ExprMutator { 60 | explicit IRMutatorWithAnalyzer(AnalyzerObj *analyzer) : analyzer_(analyzer->impl_.get()) {} 61 | explicit IRMutatorWithAnalyzer(AnalyzerObj::Impl *analyzer) : analyzer_(analyzer) {} 62 | using ExprMutator::VisitExpr_; 63 | MLC_API Expr VisitExpr_(const LetObj *op) override; 64 | MLC_API Expr VisitExpr_(const SelectObj *op) override; 65 | MLC_API Expr VisitExpr_(const CallObj *op) override; 66 | 67 | protected: 68 | AnalyzerObj::Impl *analyzer_; 69 | }; 70 | 71 | } // namespace sym 72 | } // namespace mlc 73 | 74 | #endif // MLC_SYM_ANALYZER_H_ 75 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mlc-python" 3 | dynamic = ["version"] 4 | dependencies = [ 5 | 'numpy >= 1.22', 6 | 'ml-dtypes >= 0.1', 7 | 'Pygments>=2.4.0', 8 | 'colorama', 9 | 'setuptools ; platform_system == "Windows"', 10 | ] 11 | description = "Python-first Development for 
AI Compilers" 12 | requires-python = ">=3.9" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3 :: Only", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Intended Audience :: Science/Research", 18 | ] 19 | keywords = [] 20 | readme = "README.md" 21 | license = { file = "LICENSE" } 22 | authors = [{ name = "MLC Authors", email = "junrushao@apache.org" }] 23 | 24 | [project.scripts] 25 | "mlc.config" = "mlc.config:main" 26 | 27 | [project.optional-dependencies] 28 | tests = ['pytest', 'torch', 'jsonpickle'] 29 | dev = [ 30 | "cython>=3.1", 31 | "pre-commit", 32 | "pytest", 33 | "pipx", 34 | "ipdb", 35 | "ruff", 36 | "mypy", 37 | "torch", 38 | "jsonpickle", 39 | ] 40 | 41 | [build-system] 42 | requires = ["scikit-build-core>=0.9.8", "cython>=3.1", "setuptools-scm"] 43 | build-backend = "scikit_build_core.build" 44 | 45 | [tool.setuptools_scm] 46 | version_file = "python/mlc/_version.py" 47 | write_to = "python/mlc/_version.py" 48 | 49 | [tool.scikit-build] 50 | metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" 51 | build.targets = ["mlc_py", "mlc"] 52 | build.verbose = true 53 | cmake.build-type = "RelWithDebInfo" 54 | logging.level = "DEBUG" 55 | wheel.license-files = [] 56 | wheel.install-dir = "mlc" 57 | wheel.py-api = "cp39" 58 | install.strip = false 59 | build-dir = "build-wheels/{wheel_tag}-{build_type}" 60 | sdist.include = ["python/mlc/_version.py"] 61 | 62 | [tool.scikit-build.wheel.packages] 63 | "mlc" = "python/mlc" 64 | 65 | [tool.scikit-build.cmake.define] 66 | MLC_BUILD_PY = "ON" 67 | MLC_BUILD_STATIC = "ON" 68 | MLC_BUILD_TESTS = "OFF" 69 | 70 | [[tool.scikit-build.overrides]] 71 | if.env.MLC_RELEASE = true 72 | inherit.cmake.define = "append" 73 | cmake.define.MLC_BUILD_STATIC = "OFF" 74 | 75 | [tool.ruff] 76 | line-length = 100 77 | indent-width = 4 78 | target-version = "py39" 79 | include = ["pyproject.toml", "python/**/*.py", "tests/python/**/*.py"] 80 | 
select = [ 81 | "UP", # pyupgrade, https://docs.astral.sh/ruff/rules/#pyupgrade-up 82 | "PL", # pylint, https://docs.astral.sh/ruff/rules/#pylint-pl 83 | "I", # isort, https://docs.astral.sh/ruff/rules/#isort-i 84 | "RUF", # ruff, https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf 85 | "NPY", # numpy, https://docs.astral.sh/ruff/rules/#numpy-specific-rules-npy 86 | "F", # pyflakes, https://docs.astral.sh/ruff/rules/#pyflakes-f 87 | "ANN", # flake8-annotations, https://docs.astral.sh/ruff/rules/#flake8-annotations-ann 88 | "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth 89 | # "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d 90 | ] 91 | ignore = [ 92 | "PLR2004", # pylint: magic-value-comparison 93 | "ANN401", # flake8-annotations: any-type 94 | ] 95 | fixable = ["ALL"] 96 | unfixable = [] 97 | 98 | [tool.ruff.lint.per-file-ignores] 99 | "__init__.py" = ["F401"] 100 | 101 | [tool.ruff.format] 102 | quote-style = "double" 103 | indent-style = "space" 104 | skip-magic-trailing-comma = false 105 | line-ending = "auto" 106 | docstring-code-format = false 107 | docstring-code-line-length = "dynamic" 108 | 109 | [tool.cython-lint] 110 | max-line-length = 120 111 | 112 | [tool.mypy] 113 | mypy_path = "./python" 114 | -------------------------------------------------------------------------------- /python/mlc/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import _cython, cc, dataclasses, parser, printer 2 | from ._cython import Ptr, Str 3 | from .core import ( 4 | DataType, 5 | Device, 6 | Dict, 7 | Error, 8 | Func, 9 | List, 10 | Object, 11 | ObjectPath, 12 | Opaque, 13 | Tensor, 14 | build_info, 15 | dep_graph, 16 | json_loads, 17 | typing, 18 | ) 19 | from .core.dep_graph import DepGraph, DepNode 20 | from .dataclasses import PyClass, c_class, py_class 21 | 22 | try: 23 | from ._version import __version__, __version_tuple__ # type: ignore[import-not-found] 24 | except ImportError: 25 | __version__ = version = "0.0.0.dev0" 26 | __version_tuple__ = version_tuple = (0, 0, 0, "dev0", "ga1bb7a5c.d20230415") 27 | -------------------------------------------------------------------------------- /python/mlc/_cython/.gitignore: -------------------------------------------------------------------------------- 1 | *.c 2 | *.so 3 | *.html 4 | -------------------------------------------------------------------------------- /python/mlc/_cython/__init__.py: -------------------------------------------------------------------------------- 1 | import ctypes as _ctypes 2 | import pathlib as _pathlib 3 | 4 | from . 
import core as _core # type: ignore[import-not-found] 5 | from .base import ( 6 | DSO_SUFFIX, 7 | MISSING, 8 | SYSTEM, 9 | DataTypeCode, 10 | DeviceType, 11 | DLDataType, 12 | DLDevice, 13 | Field, 14 | MetaNoSlots, 15 | MLCAny, 16 | MLCHeader, 17 | MLCObjPtr, 18 | Ptr, 19 | TypeField, 20 | TypeInfo, 21 | TypeMethod, 22 | attach_field, 23 | attach_method, 24 | c_class_core, 25 | device_normalize, 26 | dtype_normalize, 27 | ) 28 | from .core import ( # type: ignore[import-not-found] 29 | container_to_py, 30 | cxx_stacktrace_enabled, 31 | device_as_pair, 32 | dtype_as_triple, 33 | dtype_from_triple, 34 | error_get_info, 35 | error_pycode_fake, 36 | func_call, 37 | func_get, 38 | func_init, 39 | func_register, 40 | make_mlc_init, 41 | opaque_init, 42 | register_opauqe_type, 43 | str_c2py, 44 | str_py2c, 45 | tensor_byte_offset, 46 | tensor_data, 47 | tensor_device, 48 | tensor_dtype, 49 | tensor_init, 50 | tensor_ndim, 51 | tensor_shape, 52 | tensor_strides, 53 | tensor_to_dlpack, 54 | toggle_cxx_stacktrace, 55 | type_add_method, 56 | type_cast, 57 | type_create, 58 | type_create_instance, 59 | type_field_get_accessor, 60 | type_index2cached_py_type_info, 61 | type_index2type_methods, 62 | type_key2py_type_info, 63 | type_register_fields, 64 | type_register_structure, 65 | type_table, 66 | ) 67 | 68 | LIB: _ctypes.CDLL = _core.LIB 69 | LIB_PATH: _pathlib.Path = _core.LIB_PATH 70 | Str: type[str] = _core.Str 71 | 72 | 73 | class PyAny(_core.PyAny, metaclass=MetaNoSlots): 74 | __slots__ = () 75 | -------------------------------------------------------------------------------- /python/mlc/cc/__init__.py: -------------------------------------------------------------------------------- 1 | from .compiler import create_shared 2 | from .jit import jit_load 3 | from .loader import load_dso 4 | -------------------------------------------------------------------------------- /python/mlc/cc/jit.py: -------------------------------------------------------------------------------- 
1 | from __future__ import annotations 2 | 3 | import tempfile 4 | from collections.abc import Mapping, Sequence 5 | from pathlib import Path 6 | from types import MappingProxyType 7 | 8 | from mlc._cython import DSO_SUFFIX, SYSTEM 9 | 10 | from .compiler import DEFAULT_OPTIONS, create_shared 11 | from .loader import load_dso 12 | 13 | 14 | def jit_load( 15 | sources: str | Path | Sequence[Path | str], 16 | options: Mapping[str, Sequence[str]] = MappingProxyType(DEFAULT_OPTIONS), 17 | ) -> None: 18 | """Compile `sources` into a shared library in a temporary directory and load it.""" 19 | if isinstance(options, Sequence): # a flat list of flags applies to the current SYSTEM 20 | options: dict[str, Sequence[str]] = {SYSTEM: options} # type: ignore[no-redef] 21 | with tempfile.TemporaryDirectory() as temp_dir_str: 22 | output = Path(temp_dir_str) / f"jit{DSO_SUFFIX}" 23 | create_shared( 24 | sources=sources, 25 | output=output, 26 | options=options, 27 | ) 28 | load_dso(output) 29 | -------------------------------------------------------------------------------- /python/mlc/cc/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import shutil 4 | from pathlib import Path 5 | 6 | from mlc._cython import SYSTEM 7 | from mlc.core import Func 8 | 9 | _C_load_dso = Func.get("mlc.ffi.LoadDSO") 10 | 11 | 12 | def load_dso(path: Path) -> None: 13 | filename: Path | None = None 14 | if SYSTEM == "Windows": 15 | filename = Path.cwd() / "main.dll" 16 | shutil.copy(str(path), str(filename)) 17 | path = filename 18 | return _C_load_dso(str(path)) 19 | -------------------------------------------------------------------------------- /python/mlc/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import platform 4 | import shutil 5 | import warnings 6 | from pathlib import Path 7 | 8 | from mlc._cython import LIB_PATH, SYSTEM 9 | 10 | 11 | def includedir() -> tuple[Path, ...]: 12 | path = LIB_PATH.parent / ".." / ".."
/ "include" 13 | path = path.resolve() 14 | return (path,) 15 | 16 | 17 | def libdir() -> Path: 18 | return LIB_PATH.parent.resolve() 19 | 20 | 21 | def probe_vcvarsall() -> Path: 22 | for path in probe_msvc(): 23 | cur = path 24 | while cur.parent != cur: 25 | if cur.name == "VC": 26 | break 27 | cur = cur.parent 28 | else: 29 | continue 30 | vcvarsall = cur / "Auxiliary" / "Build" / "vcvarsall.bat" 31 | if vcvarsall.exists(): 32 | return vcvarsall.resolve() 33 | raise RuntimeError("vcvarsall.bat not found") 34 | 35 | 36 | def probe_msvc() -> tuple[Path, ...]: 37 | import setuptools # type: ignore[import-not-found,import-untyped] 38 | 39 | results = [] 40 | if (path := shutil.which("cl.exe", mode=os.X_OK)) is not None: 41 | results.append(Path(path).resolve()) 42 | 43 | try: 44 | vctools = setuptools.msvc.EnvironmentInfo(platform.machine()).VCTools 45 | except Exception: 46 | pass 47 | else: 48 | for vctool in vctools: 49 | if (cl_exe := Path(vctool) / "cl.exe").exists(): 50 | results.append(cl_exe.resolve()) 51 | return tuple(dict.fromkeys(results)) 52 | 53 | 54 | def probe_compiler() -> tuple[Path, ...]: 55 | results = [] 56 | if compiler := os.environ.get("CXX") or os.environ.get("CC"): 57 | results.append(Path(compiler).resolve()) 58 | if SYSTEM == "Windows": 59 | return probe_msvc() 60 | else: 61 | for compiler in ["g++", "gcc", "clang++", "clang", "c++", "cc"]: 62 | if (path := shutil.which(compiler, mode=os.X_OK)) is not None: 63 | results.append(Path(path).resolve()) 64 | if not results: 65 | warnings.warn("No compiler found. 
Set environment variable `CXX` to override") 66 | return tuple(dict.fromkeys(results)) 67 | 68 | 69 | def display_build_info() -> None: 70 | from mlc.core import Func 71 | 72 | info = Func.get("mlc.core.BuildInfo")() 73 | for k in sorted(info.keys()): 74 | v = info[k] 75 | print(f"{k}: {v}") 76 | 77 | 78 | def main() -> None: 79 | parser = argparse.ArgumentParser(description="MLC Config Tool") 80 | parser.add_argument("--build-info", action="store_true", help="Print build information") 81 | parser.add_argument("--includedir", action="store_true", help="Print the include directory") 82 | parser.add_argument("--libdir", action="store_true", help="Print the library directory") 83 | parser.add_argument("--probe-compiler", action="store_true", help="Probe the compiler") 84 | parser.add_argument("--probe-msvc", action="store_true", help="Probe MSVC") 85 | parser.add_argument("--probe-vcvarsall", action="store_true", help="Probe vcvarsall.bat") 86 | 87 | def _tuple_path_to_str(paths: tuple[Path, ...]) -> str: 88 | return ";".join(str(path) for path in paths) 89 | 90 | args = parser.parse_args() 91 | has_action = False 92 | if args.build_info: 93 | has_action = True 94 | display_build_info() 95 | if args.includedir: 96 | has_action = True 97 | print(_tuple_path_to_str(includedir())) 98 | if args.libdir: 99 | has_action = True 100 | print(libdir()) 101 | if args.probe_compiler: 102 | has_action = True 103 | print(_tuple_path_to_str(probe_compiler())) 104 | if args.probe_msvc: 105 | has_action = True 106 | print(_tuple_path_to_str(probe_msvc())) 107 | if args.probe_vcvarsall: 108 | has_action = True 109 | print(probe_vcvarsall()) 110 | if not has_action: 111 | parser.print_help() 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /python/mlc/core/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import typing 2 | from .device import Device 3 | from .dict import Dict 4 | from .dtype import DataType 5 | from .error import Error 6 | from .func import Func, build_info, json_loads 7 | from .list import List 8 | from .object import Object 9 | from .object_path import ObjectPath 10 | from .opaque import Opaque 11 | from .tensor import Tensor 12 | -------------------------------------------------------------------------------- /python/mlc/core/dep_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable, Generator 4 | from typing import Any 5 | 6 | from mlc.core import Func, Object 7 | from mlc.dataclasses import c_class 8 | 9 | Stmt = Any 10 | Var = Any 11 | 12 | 13 | @c_class("mlc.core.DepNode") 14 | class DepNode(Object): 15 | stmt: Stmt 16 | input_vars: list[Var] 17 | output_vars: list[Var] 18 | _prev: Object 19 | _next: Object 20 | 21 | @property 22 | def prev(self) -> DepNode | None: 23 | return self._prev 24 | 25 | @property 26 | def next(self) -> DepNode | None: 27 | return self._next 28 | 29 | 30 | @c_class("mlc.core.DepGraph", init=False) 31 | class DepGraph(Object): 32 | _stmt_to_inputs: Func 33 | _stmt_to_outputs: Func 34 | _stmt_to_node: dict 35 | _var_to_producer: dict 36 | _var_to_consumers: dict 37 | _head: DepNode 38 | 39 | @staticmethod 40 | def from_stmts( 41 | input_vars: list[Var], 42 | stmts: list[Stmt], 43 | stmt_to_inputs: Callable[[Stmt], list[Var]], 44 | stmt_to_outputs: Callable[[Stmt], list[Var]], 45 | ) -> DepGraph: 46 | return DepGraph._C( 47 | b"_init_from_stmts", 48 | input_vars, 49 | stmts, 50 | stmt_to_inputs, 51 | stmt_to_outputs, 52 | ) 53 | 54 | def clear(self) -> None: 55 | DepGraph._C(b"clear", self) 56 | 57 | def create_node(self, stmt: Stmt) -> DepNode: 58 | return DepGraph._C(b"create_node", self, stmt) 59 | 60 | def get_node_from_stmt(self, stmt: Stmt) -> DepNode: 61 | return 
DepGraph._C(b"get_node_from_stmt", self, stmt) 62 | 63 | def insert_before(self, anchor: DepNode, node: DepNode) -> None: 64 | DepGraph._C(b"insert_before", self, anchor, node) 65 | 66 | def insert_after(self, anchor: DepNode, node: DepNode) -> None: 67 | DepGraph._C(b"insert_after", self, anchor, node) 68 | 69 | def erase_node(self, to_erase: DepNode) -> None: 70 | DepGraph._C(b"erase_node", self, to_erase) 71 | 72 | def replace(self, old_node: DepNode, new_node: DepNode) -> None: 73 | DepGraph._C(b"replace", self, old_node, new_node) 74 | 75 | def get_node_producers(self, node: DepNode) -> list[DepNode]: 76 | return DepGraph._C(b"get_node_producers", self, node) 77 | 78 | def get_node_consumers(self, node: DepNode) -> list[DepNode]: 79 | return DepGraph._C(b"get_node_consumers", self, node) 80 | 81 | def get_var_producer(self, v: Var) -> DepNode: 82 | return DepGraph._C(b"get_var_producer", self, v) 83 | 84 | def get_var_consumers(self, v: Var) -> list[DepNode]: 85 | return DepGraph._C(b"get_var_consumers", self, v) 86 | 87 | @property 88 | def nodes(self) -> Generator[DepNode, None, None]: 89 | node: DepNode | None = self._head 90 | while node is not None: 91 | yield node 92 | node = node.next 93 | -------------------------------------------------------------------------------- /python/mlc/core/device.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from mlc._cython import DeviceType, PyAny, c_class_core, device_as_pair, device_normalize 6 | 7 | if TYPE_CHECKING: 8 | import torch 9 | 10 | 11 | @c_class_core("Device") 12 | class Device(PyAny): 13 | def __init__(self, device: str | Device | torch.device) -> None: 14 | self._mlc_init(device_normalize(device)) 15 | 16 | @property 17 | def _device_pair(self) -> tuple[int, int]: 18 | return device_as_pair(self) 19 | 20 | @property 21 | def device_type(self) -> DeviceType | int: 22 | code = 
self._device_pair[0] 23 | try: 24 | return DeviceType(code) 25 | except ValueError: 26 | return code 27 | 28 | @property 29 | def device_id(self) -> int: 30 | return self._device_pair[1] 31 | 32 | def __eq__(self, other: object) -> bool: 33 | return isinstance(other, Device) and self._device_pair == other._device_pair 34 | 35 | def __ne__(self, other: object) -> bool: 36 | return not (self == other) # always the exact negation of `__eq__` 37 | 38 | def __hash__(self) -> int: 39 | return hash((Device, *self._device_pair)) 40 | 41 | def torch(self) -> torch.device: 42 | import torch 43 | 44 | return torch.device(str(self)) 45 | 46 | @staticmethod 47 | def register(name: str) -> int: 48 | from .func import Func 49 | 50 | return Func.get("mlc.base.DeviceTypeRegister")(name) 51 | -------------------------------------------------------------------------------- /python/mlc/core/dtype.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import numpy as np 6 | 7 | from mlc._cython import ( 8 | DataTypeCode, 9 | PyAny, 10 | c_class_core, 11 | dtype_as_triple, 12 | dtype_from_triple, 13 | dtype_normalize, 14 | ) 15 | 16 | if TYPE_CHECKING: 17 | import torch 18 | 19 | 20 | @c_class_core("dtype") 21 | class DataType(PyAny): 22 | def __init__(self, dtype: str | np.dtype | type[torch.dtype] | DataType) -> None: 23 | self._mlc_init(dtype_normalize(dtype)) 24 | 25 | @property 26 | def _dtype_triple(self) -> tuple[int, int, int]: 27 | return dtype_as_triple(self) 28 | 29 | @staticmethod 30 | def from_triple(code: DataTypeCode | int, bits: int, lanes: int) -> DataType: 31 | if isinstance(code, DataTypeCode): 32 | code = code.value 33 | return dtype_from_triple(code, bits, lanes) 34 | 35 | @property 36 | def code(self) -> DataTypeCode | int: 37 | code = self._dtype_triple[0] 38 | try: 39 | return DataTypeCode(code) 40 | except ValueError: 41 | return code
42 | 43 | @property 44 | def bits(self) -> int: 45 | return self._dtype_triple[1] 46 | 47 | @property 48 | def lanes(self) -> int: 49 | return self._dtype_triple[2] 50 | 51 | def __eq__(self, other: object) -> bool: 52 | if isinstance(other, str): 53 | other = DataType(other) 54 | return isinstance(other, DataType) and self._dtype_triple == other._dtype_triple 55 | 56 | def __ne__(self, other: object) -> bool: 57 | # Delegate to `__eq__` (which also coerces `str` operands) so that `!=` is 58 | # always the exact negation of `==`, including for non-DataType operands. 59 | return not (self == other) 60 | 61 | def __hash__(self) -> int: 62 | return hash((DataType, *self._dtype_triple)) 63 | 64 | def torch(self) -> torch.dtype: 65 | import torch 66 | 67 | if (ret := getattr(torch, str(self), None)) is not None: 68 | if isinstance(ret, torch.dtype): 69 | return ret 70 | raise ValueError(f"Cannot convert to `torch.dtype` from: {self}") 71 | 72 | def numpy(self) -> np.dtype: 73 | return np.dtype(str(self)) 74 | 75 | @staticmethod 76 | def register(name: str, bits: int) -> int: 77 | from .func import Func 78 | 79 | return Func.get("mlc.base.DataTypeRegister")(name, bits) 80 | -------------------------------------------------------------------------------- /python/mlc/core/error.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | from typing import TextIO 5 | 6 | from mlc._cython import c_class_core, error_get_info 7 | 8 | from .object import Object 9 | 10 | 11 | @c_class_core("object.Error") 12 | class Error(Object): 13 | kind: str 14 | 15 | @property 16 | def _info(self) -> list[str]: 17 | return error_get_info(self) 18 | 19 | def print_exc(self, file: TextIO | None = None) -> None: 20 | if not file: 21 | file = sys.stdout 22 | print("Traceback (most recent call last):", file=file) 23 | info = self._info 24 | msg, info = info[0], info[1:] 25 | frame_id = 0 26 | while info: 27 | frame_id += 1 28 | filename, lineno, funcname, info
= info[0], int(info[1]), info[2], info[3:] 29 | print(f' [{frame_id}] File "{filename}", line {lineno}, in {funcname}', file=file) 30 | print(f"{self.kind}: {msg}", file=file) 31 | -------------------------------------------------------------------------------- /python/mlc/core/func.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import Any, TypeVar 5 | 6 | from mlc._cython import ( 7 | c_class_core, 8 | cxx_stacktrace_enabled, 9 | func_call, 10 | func_get, 11 | func_init, 12 | func_register, 13 | ) 14 | 15 | from .object import Object 16 | 17 | _CallableType = TypeVar("_CallableType", bound=Callable) 18 | 19 | 20 | @c_class_core("object.Func") 21 | class Func(Object): 22 | def __init__(self, func: Callable) -> None: 23 | assert callable(func), "func must be callable" 24 | func_init(self, func) 25 | 26 | def __call__(self, *args: Any) -> Any: 27 | if cxx_stacktrace_enabled(): 28 | return func_call(self, args) 29 | else: 30 | try: 31 | return func_call(self, args) 32 | except Exception as e: 33 | raise e.with_traceback(None) 34 | 35 | @staticmethod 36 | def get(name: str, allow_missing: bool = False) -> Func: 37 | ret = func_get(name) 38 | if (not allow_missing) and (ret is None): 39 | raise ValueError(f"Can't find global function: {name}") 40 | return ret 41 | 42 | @staticmethod 43 | def register( 44 | name: str, 45 | allow_override: bool = False, 46 | ) -> Callable[[_CallableType], _CallableType]: 47 | def decorator(func: _CallableType) -> _CallableType: 48 | func_register(name, allow_override, func) 49 | return func 50 | 51 | return decorator 52 | 53 | 54 | def json_loads(s: str) -> Any: 55 | return _json_loads(s) 56 | 57 | 58 | def build_info() -> dict[str, Any]: 59 | return _build_info() 60 | 61 | 62 | _json_loads = Func.get("mlc.core.JSONLoads") 63 | _build_info = Func.get("mlc.core.BuildInfo") 64 | 
-------------------------------------------------------------------------------- /python/mlc/core/list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABCMeta 4 | from collections.abc import Iterable, Iterator, Sequence 5 | from typing import Any, TypeVar, overload 6 | 7 | from mlc._cython import MetaNoSlots, Ptr, c_class_core, container_to_py 8 | 9 | from .object import Object 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ListMeta(MetaNoSlots, ABCMeta): ... 15 | 16 | 17 | @c_class_core("object.List") 18 | class List(Object, Sequence[T], metaclass=ListMeta): 19 | capacity: int 20 | size: int 21 | _frozen: int 22 | data: Ptr 23 | 24 | def __init__(self, iterable: Iterable[T] = ()) -> None: 25 | self._mlc_init(*iterable) 26 | 27 | def __len__(self) -> int: 28 | return self.size 29 | 30 | def freeze(self) -> None: 31 | if self._frozen == 0: 32 | self._frozen = 1 33 | 34 | @property 35 | def frozen(self) -> bool: 36 | return self._frozen == 1 37 | 38 | @overload 39 | def __getitem__(self, i: int) -> T: ... 40 | 41 | @overload 42 | def __getitem__(self, i: slice) -> Sequence[T]: ... 
43 | 44 | def __getitem__(self, i: int | slice) -> T | Sequence[T]: 45 | if isinstance(i, int): 46 | i = _normalize_index(i, len(self)) 47 | return List._C(b"__iter_at__", self, i) 48 | elif isinstance(i, slice): 49 | # Implement slicing 50 | start, stop, step = i.indices(len(self)) 51 | return List([self[i] for i in range(start, stop, step)]) 52 | else: 53 | raise TypeError(f"list indices must be integers or slices, not {type(i).__name__}") 54 | 55 | def __setitem__(self, index: int, value: T) -> None: 56 | if self._frozen: 57 | raise RuntimeError("Cannot modify a frozen list") 58 | length = len(self) 59 | if not -length <= index < length: 60 | raise IndexError(f"list assignment index out of range: {index}") 61 | if index < 0: 62 | index += length 63 | List._C(b"__setitem__", self, index, value) 64 | 65 | def __iter__(self) -> Iterator[T]: 66 | return iter(self[i] for i in range(len(self))) 67 | 68 | def __add__(self, other: Sequence[T]) -> List[T]: 69 | if not isinstance(other, (list, tuple, List)): 70 | return NotImplemented 71 | result = List(self) 72 | result.extend(other) 73 | return result 74 | 75 | def __radd__(self, other: Sequence[T]) -> List[T]: 76 | if not isinstance(other, (list, tuple)): 77 | return NotImplemented 78 | result = List(other) 79 | result.extend(self) 80 | return result 81 | 82 | def insert(self, i: int, x: T) -> None: 83 | if self._frozen: 84 | raise RuntimeError("Cannot modify a frozen list") 85 | i = _normalize_index(i, len(self) + 1) 86 | return List._C(b"_insert", self, i, x) 87 | 88 | def append(self, x: T) -> None: 89 | if self._frozen: 90 | raise RuntimeError("Cannot modify a frozen list") 91 | return List._C(b"_append", self, x) 92 | 93 | def pop(self, i: int = -1) -> T: 94 | if self._frozen: 95 | raise RuntimeError("Cannot modify a frozen list") 96 | i = _normalize_index(i, len(self)) 97 | return List._C(b"_pop", self, i) 98 | 99 | def clear(self) -> None: 100 | if self._frozen: 101 | raise RuntimeError("Cannot modify a frozen 
list") 102 | return List._C(b"_clear", self) 103 | 104 | def extend(self, iterable: Iterable[T]) -> None: 105 | if self._frozen: 106 | raise RuntimeError("Cannot modify a frozen list") 107 | return List._C(b"_extend", self, *iterable) 108 | 109 | def __eq__(self, other: Any) -> bool: 110 | if isinstance(other, List) and self._mlc_address == other._mlc_address: 111 | return True 112 | if not isinstance(other, (list, tuple, List)): 113 | return False 114 | if len(self) != len(other): 115 | return False 116 | return all(a == b for a, b in zip(self, other)) 117 | 118 | def __ne__(self, other: Any) -> bool: 119 | return not (self == other) 120 | 121 | def __delitem__(self, i: int) -> None: 122 | if self._frozen: 123 | raise RuntimeError("Cannot modify a frozen list") 124 | self.pop(i) 125 | 126 | def py(self) -> list[T]: 127 | return container_to_py(self) 128 | 129 | 130 | def _normalize_index(i: int, length: int) -> int: 131 | if not -length <= i < length: 132 | raise IndexError(f"list index out of range: {i}") 133 | if i < 0: 134 | i += length 135 | return i 136 | -------------------------------------------------------------------------------- /python/mlc/core/object.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | from collections.abc import Callable 5 | 6 | from mlc._cython import PyAny, TypeInfo, c_class_core 7 | 8 | 9 | @c_class_core("object.Object") 10 | class Object(PyAny): 11 | def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: 12 | def init() -> None: 13 | self._mlc_init() 14 | 15 | init(*args, **kwargs) 16 | 17 | @property 18 | def id_(self) -> int: 19 | return self._mlc_address 20 | 21 | def is_(self, other: Object) -> bool: 22 | return isinstance(other, Object) and self._mlc_address == other._mlc_address 23 | 24 | def json( 25 | self, 26 | fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None, 27 | ) -> str: 28 | return 
super()._mlc_json(fn_opaque_serialize) 29 | 30 | @staticmethod 31 | def from_json( 32 | json_str: str, 33 | fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None, 34 | ) -> Object: 35 | return PyAny._mlc_from_json(json_str, fn_opaque_deserialize) # type: ignore[attr-defined] 36 | 37 | def eq_s( 38 | self, 39 | other: Object, 40 | *, 41 | bind_free_vars: bool = True, 42 | assert_mode: bool = False, 43 | ) -> bool: 44 | return PyAny._mlc_eq_s(self, other, bind_free_vars, assert_mode) # type: ignore[attr-defined] 45 | 46 | def eq_s_fail_reason( 47 | self, 48 | other: Object, 49 | *, 50 | bind_free_vars: bool = True, 51 | ) -> tuple[bool, str]: 52 | return PyAny._mlc_eq_s_fail_reason(self, other, bind_free_vars) 53 | 54 | def hash_s(self) -> int: 55 | return PyAny._mlc_hash_s(self) # type: ignore[attr-defined] 56 | 57 | def eq_ptr(self, other: typing.Any) -> bool: 58 | return isinstance(other, Object) and self._mlc_address == other._mlc_address 59 | 60 | def __copy__(self: Object) -> Object: 61 | return PyAny._mlc_copy_shallow(self) # type: ignore[attr-defined] 62 | 63 | def __deepcopy__(self: Object, memo: dict[int, Object] | None) -> Object: 64 | return PyAny._mlc_copy_deep(self) 65 | 66 | def __replace__(self: Object, /, **changes: typing.Any) -> Object: 67 | unpacked: list[typing.Any] = [self] 68 | for key, value in changes.items(): 69 | unpacked.append(key) 70 | unpacked.append(value) 71 | return PyAny._mlc_copy_replace(*unpacked) 72 | 73 | def __hash__(self) -> int: 74 | return hash((type(self), self._mlc_address)) 75 | 76 | def __eq__(self, other: typing.Any) -> bool: 77 | return self.eq_ptr(other) 78 | 79 | def __ne__(self, other: typing.Any) -> bool: 80 | return not self == other 81 | 82 | def _mlc_setattr(self, name: str, value: typing.Any) -> None: 83 | type_info: TypeInfo = type(self)._mlc_type_info 84 | for field in type_info.fields: 85 | if field.name == name: 86 | if field.setter is None: 87 | raise AttributeError(f"Attribute `{name}` 
missing setter") 88 | field.setter(self, value) 89 | return 90 | raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`") 91 | 92 | def _mlc_getattr(self, name: str) -> typing.Any: 93 | type_info: TypeInfo = type(self)._mlc_type_info 94 | for field in type_info.fields: 95 | if field.name == name: 96 | if field.getter is None: 97 | raise AttributeError(f"Attribute `{name}` missing getter") 98 | return field.getter(self) 99 | raise AttributeError(f"Attribute `{name}` not found in `{type(self)}`") 100 | 101 | def swap(self, other: typing.Any) -> None: 102 | if type(self) == type(other): 103 | self._mlc_swap(other) 104 | else: 105 | raise TypeError(f"Cannot different types: `{type(self)}` and `{type(other)}`") 106 | -------------------------------------------------------------------------------- /python/mlc/core/object_path.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from mlc._cython import c_class_core 6 | 7 | from .object import Object 8 | 9 | 10 | @c_class_core("mlc.core.ObjectPath") 11 | class ObjectPath(Object): 12 | kind: int 13 | key: Any 14 | prev: Object | None 15 | length: int 16 | 17 | @staticmethod 18 | def root() -> ObjectPath: 19 | return ObjectPath._C(b"root") 20 | 21 | def with_field(self, field: str) -> ObjectPath: 22 | return ObjectPath._C(b"with_field", self, field) 23 | 24 | def with_list_index(self, index: int) -> ObjectPath: 25 | return ObjectPath._C(b"with_list_index", self, index) 26 | 27 | def with_dict_key(self, key: Any) -> ObjectPath: 28 | return ObjectPath._C(b"with_dict_key", self, key) 29 | 30 | def equal(self, other: ObjectPath) -> bool: 31 | return ObjectPath._C(b"equal", self, other) 32 | 33 | def get_prefix(self, length: int) -> ObjectPath: 34 | return ObjectPath._C(b"get_prefix", self, length) 35 | 36 | def is_prefix_of(self, other: ObjectPath) -> bool: 37 | return ObjectPath._C(b"is_prefix_of", self, 
other) 38 | 39 | def __getitem__(self, key: str | int) -> ObjectPath: 40 | if isinstance(key, int): 41 | return self.with_list_index(key) 42 | elif isinstance(key, str): 43 | return self.with_field(key) 44 | else: 45 | raise TypeError( 46 | f"Unsupported key type: {type(key)}. Please explicitly use " 47 | "`with_field`, `with_list_index`, or `with_dict_key` methods." 48 | ) 49 | -------------------------------------------------------------------------------- /python/mlc/core/opaque.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from collections.abc import Callable 5 | from typing import Any, Literal 6 | 7 | from mlc._cython import Ptr, c_class_core, func_register, opaque_init, register_opauqe_type 8 | 9 | from .object import Object 10 | 11 | 12 | @c_class_core("mlc.core.Opaque") 13 | class Opaque(Object): 14 | handle: Ptr 15 | 16 | def __init__(self, instance: Any) -> None: 17 | opaque_init(self, instance) 18 | 19 | @staticmethod 20 | def register( 21 | ty: type, 22 | eq_s: Callable | Literal["default"] | None = "default", 23 | hash_s: Callable | Literal["default"] | None = "default", 24 | deepcopy: Callable | Literal["default"] | None = "default", 25 | ) -> None: 26 | register_opauqe_type(ty) 27 | name = ty.__module__ + "." 
+ ty.__name__ 28 | 29 | if isinstance(eq_s, str) and eq_s == "default": 30 | func_register(f"Opaque.eq_s.{name}", False, lambda a, b: a == b) 31 | elif callable(eq_s): 32 | func_register(f"Opaque.eq_s.{name}", False, eq_s) 33 | else: 34 | assert eq_s is None, "eq_s must be a callable, a literal 'default', or None" 35 | 36 | if isinstance(hash_s, str) and hash_s == "default": 37 | func_register(f"Opaque.hash_s.{name}", False, lambda a: hash(a)) 38 | elif callable(hash_s): 39 | func_register(f"Opaque.hash_s.{name}", False, hash_s) 40 | else: 41 | assert hash_s is None, "hash_s must be a callable, a literal 'default', or None" 42 | 43 | if isinstance(deepcopy, str) and deepcopy == "default": 44 | func_register(f"Opaque.deepcopy.{name}", False, lambda a: copy.deepcopy(a)) 45 | elif callable(deepcopy): 46 | func_register(f"Opaque.deepcopy.{name}", False, deepcopy) 47 | else: 48 | assert deepcopy is None, "deepcopy must be a callable, a literal 'default', or None" 49 | 50 | 51 | def _default_serialize(opaques: list[Any]) -> str: 52 | import jsonpickle # type: ignore[import-untyped] 53 | 54 | return jsonpickle.dumps(list(opaques)) 55 | 56 | 57 | def _default_deserialize(json_str: str) -> list[Any]: 58 | import jsonpickle # type: ignore[import-untyped] 59 | 60 | return jsonpickle.loads(json_str) 61 | 62 | 63 | func_register("mlc.Opaque.default.serialize", False, _default_serialize) 64 | func_register("mlc.Opaque.default.deserialize", False, _default_deserialize) 65 | -------------------------------------------------------------------------------- /python/mlc/core/tensor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any 4 | 5 | import numpy as np 6 | 7 | from mlc._cython import ( 8 | Ptr, 9 | c_class_core, 10 | tensor_byte_offset, 11 | tensor_data, 12 | tensor_device, 13 | tensor_dtype, 14 | tensor_init, 15 | tensor_ndim, 16 | tensor_shape, 17 | 
tensor_strides, 18 | tensor_to_dlpack, 19 | ) 20 | 21 | from .func import Func 22 | from .object import Object 23 | 24 | if TYPE_CHECKING: 25 | import torch 26 | 27 | from mlc.core import DataType, Device 28 | 29 | 30 | @c_class_core("mlc.core.Tensor") 31 | class Tensor(Object): 32 | def __init__(self, tensor: Any) -> None: 33 | tensor_init(self, tensor) 34 | 35 | @property 36 | def data(self) -> Ptr: 37 | return tensor_data(self) 38 | 39 | @property 40 | def device(self) -> Device: 41 | return tensor_device(self) 42 | 43 | @property 44 | def dtype(self) -> DataType: 45 | return tensor_dtype(self) 46 | 47 | @property 48 | def ndim(self) -> int: 49 | return tensor_ndim(self) 50 | 51 | @property 52 | def shape(self) -> tuple[int, ...]: 53 | return tensor_shape(self) 54 | 55 | @property 56 | def strides(self) -> tuple[int, ...] | None: 57 | return tensor_strides(self) 58 | 59 | @property 60 | def byte_offset(self) -> int: 61 | return tensor_byte_offset(self) 62 | 63 | def base64(self) -> str: 64 | return TensorToBase64(self) 65 | 66 | @staticmethod 67 | def from_base64(base64: str) -> Tensor: 68 | return TensorFromBase64(base64) 69 | 70 | def __dlpack__(self, stream: Any = None) -> Any: 71 | return tensor_to_dlpack(self) 72 | 73 | def __dlpack_device__(self) -> tuple[int, int]: 74 | return self.device._device_pair 75 | 76 | def numpy(self) -> np.ndarray: 77 | return np.from_dlpack(self) 78 | 79 | def torch(self) -> torch.Tensor: 80 | import torch 81 | 82 | return torch.from_dlpack(self) 83 | 84 | 85 | TensorToBase64 = Func.get("mlc.core.TensorToBase64") 86 | TensorFromBase64 = Func.get("mlc.core.TensorFromBase64") 87 | -------------------------------------------------------------------------------- /python/mlc/core/typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ctypes 4 | import sys 5 | import types 6 | import typing 7 | 8 | from mlc._cython import DLDataType, DLDevice, 
MLCAny, MLCObjPtr, Ptr, c_class_core, type_cast 9 | 10 | from .object import Object 11 | 12 | if sys.version_info >= (3, 10): 13 | UnionType = types.UnionType 14 | else: 15 | UnionType = None 16 | 17 | 18 | @c_class_core("mlc.core.typing.Type") 19 | class Type(Object): 20 | def args(self) -> tuple[Type, ...]: 21 | raise NotImplementedError 22 | 23 | def cast(self, obj: typing.Any) -> typing.Any: 24 | return type_cast(self, obj) 25 | 26 | def _ctype(self) -> typing.Any: 27 | raise NotImplementedError 28 | 29 | def cxx_str(self) -> str: 30 | return self._C(b"__cxx_str__", self) 31 | 32 | 33 | @c_class_core("mlc.core.typing.AnyType") 34 | class AnyType(Type): 35 | def __init__(self) -> None: 36 | self._mlc_init() 37 | 38 | def args(self) -> tuple: 39 | return () 40 | 41 | def _ctype(self) -> typing.Any: 42 | return MLCAny 43 | 44 | 45 | @c_class_core("mlc.core.typing.AtomicType") 46 | class AtomicType(Type): 47 | type_index: int 48 | 49 | def __init__(self, type_index: int) -> None: 50 | self._mlc_init(type_index) 51 | 52 | def args(self) -> tuple: 53 | return () 54 | 55 | def _ctype(self) -> typing.Any: 56 | type_index = self.type_index 57 | if type_index >= _kMLCStaticObjectBegin: 58 | return Ptr 59 | if (ret := _TYPE_INDEX_TO_CTYPES.get(type_index)) is not None: 60 | return ret 61 | raise ValueError(f"Unsupported type index: {type_index}") 62 | 63 | 64 | @c_class_core("mlc.core.typing.PtrType") 65 | class PtrType(Type): 66 | ty: Type 67 | 68 | def __init__(self, ty: Type) -> None: 69 | self._mlc_init(ty) 70 | 71 | def args(self) -> tuple: 72 | return (self.ty,) 73 | 74 | def _ctype(self) -> typing.Any: 75 | return MLCObjPtr 76 | 77 | 78 | @c_class_core("mlc.core.typing.Optional") 79 | class Optional(Type): 80 | ty: Type 81 | 82 | def __init__(self, ty: Type) -> None: 83 | self._mlc_init(ty) 84 | 85 | def args(self) -> tuple: 86 | return (self.ty,) 87 | 88 | def _ctype(self) -> typing.Any: 89 | return MLCObjPtr 90 | 91 | 92 | @c_class_core("mlc.core.typing.List") 93 
| class List(Type): 94 | ty: Type 95 | 96 | def __init__(self, ty: Type) -> None: 97 | self._mlc_init(ty) 98 | 99 | def args(self) -> tuple: 100 | return (self.ty,) 101 | 102 | def _ctype(self) -> typing.Any: 103 | return MLCObjPtr 104 | 105 | 106 | @c_class_core("mlc.core.typing.Dict") 107 | class Dict(Type): 108 | ty_k: Type 109 | ty_v: Type 110 | 111 | def __init__(self, ty_k: Type, ty_v: Type) -> None: 112 | self._mlc_init(ty_k, ty_v) 113 | 114 | def args(self) -> tuple: 115 | return (self.ty_k, self.ty_v) 116 | 117 | def _ctype(self) -> typing.Any: 118 | return MLCObjPtr 119 | 120 | 121 | def from_py(ann: type) -> Type: 122 | if (ty := _PY_TYPE_TO_MLC_TYPE.get(ann)) is not None: 123 | return ty 124 | elif (type_info := getattr(ann, "_mlc_type_info", None)) is not None: 125 | return AtomicType(type_info.type_index) 126 | elif (origin := typing.get_origin(ann)) is not None: 127 | from mlc.core import Dict as MLCDict 128 | from mlc.core import List as MLCList 129 | 130 | args = typing.get_args(ann) 131 | if (origin is list) or (origin is MLCList): 132 | if len(args) == 1: 133 | return List(from_py(args[0])) 134 | raise ValueError(f"Unsupported type: {ann}") 135 | elif (origin is dict) or (origin is MLCDict): 136 | if len(args) == 2: 137 | return Dict(from_py(args[0]), from_py(args[1])) 138 | raise ValueError(f"Unsupported type: {ann}") 139 | elif origin is tuple: 140 | raise ValueError("Unsupported type: `tuple`. 
Use `list` instead.") 141 | elif (origin is UnionType) or (origin is typing.Union): 142 | if len(args) == 2: 143 | if args[1] is type(None): 144 | return Optional(from_py(args[0])) 145 | if args[0] is type(None): 146 | return Optional(from_py(args[1])) 147 | raise ValueError(f"Unsupported type: {ann}") 148 | raise ValueError(f"Unsupported type: {ann}") 149 | 150 | 151 | _kMLCBool = 1 152 | _kMLCInt = 2 153 | _kMLCFloat = 3 154 | _kMLCPtr = 4 155 | _kMLCDataType = 5 156 | _kMLCDevice = 6 157 | _kMLCRawStr = 7 158 | _kMLCStaticObjectBegin = 1000 159 | _kMLCStr = 1005 160 | _Any = AnyType() 161 | _BOOL = AtomicType(_kMLCBool) 162 | _INT = AtomicType(_kMLCInt) 163 | _FLOAT = AtomicType(_kMLCFloat) 164 | _PTR = AtomicType(_kMLCPtr) 165 | _STR = AtomicType(_kMLCStr) 166 | _UList = List(_Any) 167 | _UDict = Dict(_Any, _Any) 168 | _TYPE_INDEX_TO_CTYPES = { 169 | _kMLCBool: ctypes.c_bool, 170 | _kMLCInt: ctypes.c_int64, 171 | _kMLCFloat: ctypes.c_double, 172 | _kMLCPtr: Ptr, 173 | _kMLCDataType: DLDataType, 174 | _kMLCDevice: DLDevice, 175 | _kMLCRawStr: Ptr, 176 | } 177 | _PY_TYPE_TO_MLC_TYPE: dict[typing.Any, Type] = { 178 | bool: _BOOL, 179 | int: _INT, 180 | float: _FLOAT, 181 | Ptr: _PTR, 182 | str: _STR, 183 | typing.Any: _Any, 184 | Ellipsis: _Any, 185 | list: _UList, 186 | dict: _UDict, 187 | typing.List: _UList, # noqa: UP006 188 | typing.Dict: _UDict, # noqa: UP006 189 | } 190 | -------------------------------------------------------------------------------- /python/mlc/dataclasses/__init__.py: -------------------------------------------------------------------------------- 1 | from .c_class import c_class 2 | from .py_class import PyClass, py_class 3 | from .utils import ( 4 | Structure, 5 | add_vtable_method, 6 | field, 7 | prototype, 8 | replace, 9 | stringify, 10 | vtable_method, 11 | ) 12 | -------------------------------------------------------------------------------- /python/mlc/dataclasses/c_class.py: 
-------------------------------------------------------------------------------- 1 | import functools 2 | import typing 3 | import warnings 4 | from collections.abc import Callable 5 | 6 | from mlc._cython import ( 7 | TypeInfo, 8 | TypeMethod, 9 | attach_field, 10 | attach_method, 11 | type_index2type_methods, 12 | type_key2py_type_info, 13 | ) 14 | from mlc.core import typing as mlc_typing 15 | 16 | from .utils import ( 17 | add_vtable_methods_for_type_cls, 18 | get_parent_type, 19 | inspect_dataclass_fields, 20 | method_init, 21 | prototype, 22 | ) 23 | 24 | ClsType = typing.TypeVar("ClsType") 25 | 26 | 27 | def c_class( 28 | type_key: str, 29 | init: bool = True, 30 | ) -> Callable[[type[ClsType]], type[ClsType]]: 31 | def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: 32 | @functools.wraps(super_type_cls, updated=()) 33 | class type_cls(super_type_cls): # type: ignore[valid-type,misc] 34 | __slots__ = () 35 | 36 | # Step 1. Retrieve `type_info` from registry 37 | type_info: TypeInfo = type_key2py_type_info(type_key) 38 | parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] 39 | 40 | if type_info.type_cls is not None: 41 | raise ValueError(f"Type is already registered: {type_key}") 42 | _, d_fields = inspect_dataclass_fields(type_key, type_cls, parent_type_info, frozen=False) 43 | type_info.type_cls = type_cls 44 | type_info.d_fields = tuple(d_fields) 45 | 46 | # Step 2. Check if all fields are exposed as type annotations 47 | _check_c_class(super_type_cls, type_info) 48 | 49 | # Step 3. Attach fields 50 | setattr(type_cls, "_mlc_type_info", type_info) 51 | for field in type_info.fields: 52 | attach_field( 53 | cls=type_cls, 54 | name=field.name, 55 | getter=field.getter, 56 | setter=field.setter, 57 | frozen=field.frozen, 58 | ) 59 | 60 | # Step 4. 
Attach methods 61 | if init: 62 | attach_method( 63 | parent_cls=super_type_cls, 64 | cls=type_cls, 65 | name="__init__", 66 | method=method_init(super_type_cls, d_fields), 67 | check_exists=True, 68 | ) 69 | add_vtable_methods_for_type_cls(super_type_cls, type_index=type_info.type_index) 70 | return type_cls 71 | 72 | return decorator 73 | 74 | 75 | def _check_c_class( 76 | type_cls: type[ClsType], 77 | type_info: TypeInfo, 78 | ) -> None: 79 | type_hints = typing.get_type_hints(type_cls) 80 | warned: bool = False 81 | for field in type_info.fields: 82 | if field.name in type_hints: 83 | c_ty_str = field.ty.__str__() 84 | py_ty = type_hints.pop(field.name) 85 | py_ty_str = mlc_typing.from_py(py_ty).__str__() 86 | if c_ty_str != py_ty_str and not (c_ty_str == "char*" and py_ty_str == "str"): 87 | warnings.warn( 88 | f"Type mismatch on `{type_cls.__module__}.{type_cls.__qualname__}.{field.name}`: " 89 | f"Expected `{c_ty_str}`, but got `{py_ty_str}`." 90 | ) 91 | warned = True 92 | else: 93 | warnings.warn( 94 | f"Attribute not found: `{type_cls.__module__}.{type_cls.__qualname__}.{field.name}`. " 95 | f"Add `{field.name}: {field.ty}` to class definition." 96 | ) 97 | warned = True 98 | if type_hints: 99 | extra_attrs = ", ".join(str(k) for k in type_hints) 100 | warnings.warn(f"Unused attributes in class definition: {extra_attrs}") 101 | warned = True 102 | method: TypeMethod 103 | for method in type_index2type_methods(type_info.type_index): 104 | if method.name.startswith("_"): 105 | continue 106 | func = getattr(type_cls, method.name, None) 107 | if func is None: 108 | warnings.warn( 109 | f"Method not found `{type_cls.__module__}.{type_cls.__qualname__}.{method.name}`. " 110 | ) 111 | warned = True 112 | elif not callable(func): 113 | warnings.warn( 114 | f"Attribute `{type_cls.__module__}.{type_cls.__qualname__}.{method.name}` is not a method. 
" 115 | ) 116 | warned = True 117 | if warned: 118 | warnings.warn( 119 | f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n" 120 | + prototype(type_info, lang="py") 121 | ) 122 | -------------------------------------------------------------------------------- /python/mlc/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .diagnostic import DiagnosticError 2 | from .env import Env, Span, check_decorator 3 | from .parser import Frame, Parser 4 | -------------------------------------------------------------------------------- /python/mlc/parser/diagnostic.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | from types import TracebackType 6 | from typing import Any, Literal 7 | 8 | import colorama # type: ignore[import-untyped] 9 | 10 | from .env import Span 11 | 12 | colorama.init(autoreset=True) 13 | 14 | 15 | class DiagnosticError(Exception): 16 | pass 17 | 18 | 19 | def raise_diagnostic_error( 20 | source: str, 21 | source_name: str, 22 | span: Span, 23 | err: Exception | str, 24 | ) -> None: 25 | diag_err = DiagnosticError("Diagnostics were emitted, please check rendered error message.") 26 | if isinstance(err, Exception): 27 | msg = type(err).__name__ + ": " + [i for i in str(err).split("\n") if i][-1] 28 | diag_err.with_traceback(err.__traceback__) 29 | else: 30 | msg = str(err) 31 | print( 32 | _render_at( 33 | source=source, 34 | source_name=source_name, 35 | span=span, 36 | message=msg, 37 | level="error", 38 | ), 39 | file=sys.stderr, 40 | ) 41 | if isinstance(err, Exception): 42 | raise diag_err from err 43 | else: 44 | raise diag_err 45 | 46 | 47 | def _render_at( 48 | source: str, 49 | source_name: str, 50 | span: Span, 51 | message: str, 52 | level: Literal["warning", "error", "bug", "note"] = "error", 53 | ) -> str: 54 | lines = 
source.splitlines() 55 | row_st = max(1, span.row_st) 56 | row_ed = min(len(lines), span.row_ed) 57 | # If no valid rows, just return the bare message. 58 | if row_st > row_ed: 59 | return message 60 | # Map the "level" to a color and label (similar to rang::fg usage in C++). 61 | color, diag_type = { 62 | "warning": (colorama.Fore.YELLOW, "warning"), 63 | "error": (colorama.Fore.RED, "error"), 64 | "bug": (colorama.Fore.BLUE, "bug"), 65 | "note": (colorama.Fore.RESET, "note"), 66 | "help": (colorama.Fore.RESET, "help"), 67 | }.get(level, (colorama.Fore.RED, "error")) 68 | 69 | # Prepare lines of output 70 | out_lines = [ 71 | f"{colorama.Style.BRIGHT}{color}{diag_type}{colorama.Style.RESET_ALL}: {message}", 72 | f"{colorama.Fore.BLUE} --> {colorama.Style.RESET_ALL}{source_name}:{row_st}:{span.col_st}", 73 | ] 74 | left_margin_width = len(str(row_ed)) 75 | for row_idx in range(row_st, row_ed + 1): 76 | line_text = lines[row_idx - 1] # zero-based 77 | line_label = str(row_idx).rjust(left_margin_width) 78 | # Step 1. the actual source line 79 | out_lines.append(f"{line_label} | {line_text}") 80 | # Step 2. the marker line: 81 | marker = [" "] * len(line_text) 82 | # For the first line... 83 | if row_idx == row_st and row_idx == row_ed: 84 | # Case 1. Single-line: highlight col_st..col_ed 85 | c_start = max(1, span.col_st) 86 | c_end = min(len(line_text), span.col_ed) 87 | marker[c_start:c_end] = "^" * (c_end - c_start) 88 | elif row_idx == row_st: 89 | # Case 2. The first line in a multi-line highlight 90 | c_start = max(1, span.col_st) 91 | marker[c_start:] = "^" * (len(line_text) - c_start) 92 | elif row_idx == row_ed: 93 | # Case 3. The last line in a multi-line highlight 94 | c_end = min(len(line_text), span.col_ed) 95 | marker[:c_end] = "^" * c_end 96 | else: 97 | # Case 4. 
A line in the middle of row_st..row_ed => highlight entire line 98 | marker = ["^"] * len(line_text) 99 | out_lines.append(f"{' ' * (left_margin_width)} | {''.join(marker)}") 100 | return "\n".join(out_lines) 101 | 102 | 103 | def excepthook( 104 | exctype: type[BaseException], 105 | value: BaseException, 106 | traceback: TracebackType | None, 107 | ) -> Any: 108 | should_hide_backtrace = os.environ.get("MLC_BACKTRACE", None) is None 109 | if exctype is DiagnosticError and should_hide_backtrace: 110 | print("note: run with `MLC_BACKTRACE=1` environment variable to display a backtrace.") 111 | return 112 | sys_excepthook(exctype, value, traceback) 113 | 114 | 115 | sys_excepthook = sys.excepthook 116 | sys.excepthook = excepthook 117 | -------------------------------------------------------------------------------- /python/mlc/printer/__init__.py: -------------------------------------------------------------------------------- 1 | from mlc.core import ObjectPath 2 | 3 | from . import ast 4 | from .ast import PrinterConfig 5 | from .ir_printer import ( 6 | Bool, 7 | DefaultFrame, 8 | Float, 9 | Int, 10 | IRPrinter, 11 | None_, 12 | Str, 13 | print_python, 14 | to_python, 15 | ) 16 | -------------------------------------------------------------------------------- /python/mlc/printer/cprint.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | import typing 5 | 6 | if typing.TYPE_CHECKING: 7 | import pygments # type: ignore[import-untyped] 8 | 9 | 10 | def cprint(printable: str, style: str | None = None) -> None: 11 | """Print Python code with Pygments highlight. 12 | 13 | Parameters 14 | ---------- 15 | style : str, optional 16 | 17 | Pygmentize printing style, auto-detected if None. 18 | 19 | Notes 20 | ----- 21 | 22 | The style parameter follows the Pygments style names or Style objects. Three 23 | built-in styles are extended: "light", "dark" and "ansi". 
By default, "light" 24 | will be used for notebook environment and terminal style will be "ansi" for 25 | better style consistency. As an fallback when the optional Pygment library is 26 | not installed, plain text will be printed with a one-time warning to suggest 27 | installing the Pygment library. Other Pygment styles can be found in 28 | https://pygments.org/styles/ 29 | """ 30 | is_in_notebook = "ipykernel" in sys.modules # in notebook env (support html display). 31 | 32 | pygment_style = _get_pygments_style(style, is_in_notebook) 33 | 34 | if pygment_style is None: 35 | print(printable) 36 | return 37 | 38 | # pylint: disable=import-outside-toplevel 39 | from pygments import highlight # type: ignore[import-untyped] 40 | from pygments.formatters import ( # type: ignore[import-untyped] 41 | HtmlFormatter, 42 | Terminal256Formatter, 43 | ) 44 | from pygments.lexers.python import Python3Lexer # type: ignore[import-untyped] 45 | 46 | if is_in_notebook: 47 | from IPython import display # type: ignore[import-not-found] 48 | 49 | formatter = HtmlFormatter(style=pygment_style) 50 | formatter.noclasses = True # inline styles 51 | html = highlight(printable, Python3Lexer(), formatter) 52 | display.display(display.HTML(html)) 53 | else: 54 | print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=pygment_style))) 55 | 56 | 57 | def _get_pygments_style( 58 | style: str | None, 59 | is_in_notebook: bool, 60 | ) -> pygments.style.Style | str | None: 61 | from pygments.style import Style # type: ignore[import-untyped] 62 | from pygments.token import ( # type: ignore[import-untyped] 63 | Comment, 64 | Keyword, 65 | Name, 66 | Number, 67 | Operator, 68 | String, 69 | ) 70 | 71 | class JupyterLight(Style): 72 | """A Jupyter-Notebook-like Pygments style configuration (aka. 
"light")""" 73 | 74 | background_color = "" 75 | styles: typing.ClassVar = { 76 | Keyword: "bold #008000", 77 | Keyword.Type: "nobold #008000", 78 | Name.Function: "#0000FF", 79 | Name.Class: "bold #0000FF", 80 | Name.Decorator: "#AA22FF", 81 | String: "#BA2121", 82 | Number: "#008000", 83 | Operator: "bold #AA22FF", 84 | Operator.Word: "bold #008000", 85 | Comment: "italic #007979", 86 | } 87 | 88 | class VSCDark(Style): 89 | """A VSCode-Dark-like Pygments style configuration (aka. "dark")""" 90 | 91 | background_color = "" 92 | styles: typing.ClassVar = { 93 | Keyword: "bold #c586c0", 94 | Keyword.Type: "#82aaff", 95 | Keyword.Namespace: "#4ec9b0", 96 | Name.Class: "bold #569cd6", 97 | Name.Function: "bold #dcdcaa", 98 | Name.Decorator: "italic #fe4ef3", 99 | String: "#ce9178", 100 | Number: "#b5cea8", 101 | Operator: "#bbbbbb", 102 | Operator.Word: "#569cd6", 103 | Comment: "italic #6a9956", 104 | } 105 | 106 | class AnsiTerminalDefault(Style): 107 | """The default style for terminal display with ANSI colors (aka. 
"ansi")""" 108 | 109 | background_color = "" 110 | styles: typing.ClassVar = { 111 | Keyword: "bold ansigreen", 112 | Keyword.Type: "nobold ansigreen", 113 | Name.Class: "bold ansiblue", 114 | Name.Function: "bold ansiblue", 115 | Name.Decorator: "italic ansibrightmagenta", 116 | String: "ansiyellow", 117 | Number: "ansibrightgreen", 118 | Operator: "bold ansimagenta", 119 | Operator.Word: "bold ansigreen", 120 | Comment: "italic ansibrightblack", 121 | } 122 | 123 | if style == "light": 124 | return JupyterLight 125 | elif style == "dark": 126 | return VSCDark 127 | elif style == "ansi": 128 | return AnsiTerminalDefault 129 | if style is not None: 130 | return style 131 | if is_in_notebook: 132 | return JupyterLight 133 | return AnsiTerminalDefault 134 | -------------------------------------------------------------------------------- /python/mlc/printer/ir_printer.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from collections.abc import Generator 3 | from typing import Any, Optional, TypeVar, Union 4 | 5 | import mlc.dataclasses as mlcd 6 | from mlc.core import Func, Object, ObjectPath 7 | 8 | from .ast import Expr, Literal, Node, PrinterConfig, Stmt 9 | from .cprint import cprint 10 | 11 | 12 | def Int(value: int, source_paths: Optional[list[ObjectPath]] = None) -> Literal: 13 | if source_paths is None: 14 | source_paths = [] 15 | return Literal(value, source_paths=source_paths) 16 | 17 | 18 | def Float(value: float, source_paths: Optional[list[ObjectPath]] = None) -> Literal: 19 | if source_paths is None: 20 | source_paths = [] 21 | return Literal(value, source_paths=source_paths) 22 | 23 | 24 | def Str(value: str, source_paths: Optional[list[ObjectPath]] = None) -> Literal: 25 | if source_paths is None: 26 | source_paths = [] 27 | return Literal(value, source_paths=source_paths) 28 | 29 | 30 | def Bool(value: bool, source_paths: Optional[list[ObjectPath]] = None) -> Literal: 31 | if source_paths 
is None: 32 | source_paths = [] 33 | return Literal(value, source_paths=source_paths) 34 | 35 | 36 | def None_(source_paths: Optional[list[ObjectPath]] = None) -> Literal: 37 | if source_paths is None: 38 | source_paths = [] 39 | return Literal(None, source_paths=source_paths) 40 | 41 | 42 | @mlcd.c_class("mlc.printer.VarInfo") 43 | class VarInfo(Object): 44 | name: Optional[str] 45 | creator: Func 46 | 47 | 48 | FrameType = TypeVar("FrameType") 49 | 50 | 51 | @mlcd.c_class("mlc.printer.DefaultFrame", init=False) 52 | class DefaultFrame(Object): 53 | stmts: list[Stmt] 54 | 55 | def __init__(self, stmts: Optional[list[Stmt]] = None) -> None: 56 | if stmts is None: 57 | stmts = [] 58 | self._mlc_init(stmts) 59 | 60 | 61 | @mlcd.c_class("mlc.printer.IRPrinter", init=False) 62 | class IRPrinter(Object): 63 | cfg: PrinterConfig 64 | obj2info: dict[Any, VarInfo] 65 | defined_names: dict[str, int] 66 | frames: list[Any] 67 | frame_vars: dict[Any, Any] 68 | 69 | def __init__(self, cfg: Optional[PrinterConfig] = None) -> None: 70 | if cfg is None: 71 | cfg = PrinterConfig() 72 | self._mlc_init(cfg, {}, {}, [], {}) 73 | 74 | def var_is_defined(self, obj: Any) -> bool: 75 | return bool(IRPrinter._C(b"var_is_defined", self, obj)) 76 | 77 | def var_def( 78 | self, 79 | name: str, 80 | obj: Any, 81 | frame: Optional[Any] = None, 82 | ) -> None: 83 | return IRPrinter._C(b"var_def", self, name, obj, frame) 84 | 85 | def var_def_no_name( 86 | self, 87 | creator: Func, 88 | obj: Any, 89 | frame: Optional[Any] = None, 90 | ) -> None: 91 | IRPrinter._C(b"var_def_no_name", self, creator, obj, frame) 92 | 93 | def var_remove(self, obj: Any) -> None: 94 | IRPrinter._C(b"var_remove", self, obj) 95 | 96 | def var_get(self, obj: Any) -> Optional[Expr]: 97 | return IRPrinter._C(b"var_get", self, obj) 98 | 99 | def frame_push(self, frame: Any) -> None: 100 | IRPrinter._C(b"frame_push", self, frame) 101 | 102 | def frame_pop(self) -> None: 103 | IRPrinter._C(b"frame_pop", self) 104 | 105 | def 
__call__(self, obj: Union[Node, int, str, bool, float, None], path: ObjectPath) -> Node: 106 | return IRPrinter._C(b"__call__", self, obj, path) 107 | 108 | @contextlib.contextmanager 109 | def with_frame(self, frame: FrameType) -> Generator[FrameType, None, None]: 110 | self.frame_push(frame) 111 | try: 112 | yield frame 113 | finally: 114 | self.frame_pop() 115 | 116 | 117 | _C_ToPython = Func.get("mlc.printer.ToPython") 118 | 119 | 120 | def to_python(obj: Any, cfg: Optional[PrinterConfig] = None) -> str: 121 | if cfg is None: 122 | cfg = PrinterConfig() 123 | return _C_ToPython(obj, cfg) 124 | 125 | 126 | def print_python( 127 | obj: Any, 128 | cfg: Optional[PrinterConfig] = None, 129 | style: Optional[str] = None, 130 | ) -> None: 131 | cprint(to_python(obj, cfg=cfg), style=style) 132 | -------------------------------------------------------------------------------- /python/mlc/sym/__init__.py: -------------------------------------------------------------------------------- 1 | from .analyzer import Analyzer 2 | from .expr import ( 3 | EQ, 4 | GE, 5 | GT, 6 | LE, 7 | LT, 8 | NE, 9 | Add, 10 | And, 11 | BoolImm, 12 | Broadcast, 13 | Call, 14 | Cast, 15 | Div, 16 | Expr, 17 | FloatImm, 18 | FloorDiv, 19 | FloorMod, 20 | IntImm, 21 | Let, 22 | Max, 23 | Min, 24 | Mod, 25 | Mul, 26 | Not, 27 | Op, 28 | Or, 29 | Ramp, 30 | Range, 31 | Select, 32 | ShapeVar, 33 | Shuffle, 34 | Sub, 35 | Var, 36 | const, 37 | ) 38 | from .op import ( 39 | abs, 40 | broadcast, 41 | cast, 42 | equal, 43 | floordiv, 44 | floormod, 45 | greater, 46 | greater_equal, 47 | if_then_else, 48 | less, 49 | less_equal, 50 | let, 51 | logical_and, 52 | logical_not, 53 | logical_or, 54 | max, 55 | max_value, 56 | min, 57 | min_value, 58 | not_equal, 59 | ramp, 60 | select, 61 | truncdiv, 62 | truncmod, 63 | ) 64 | -------------------------------------------------------------------------------- /python/mlc/sym/_internal.py: 
-------------------------------------------------------------------------------- 1 | """System internals that are not part of the public API. Testing-only.""" 2 | 3 | from __future__ import annotations 4 | 5 | import contextlib 6 | from collections.abc import Generator 7 | 8 | import mlc.dataclasses as mlcd 9 | from mlc.core import Func, Object 10 | 11 | from .analyzer import Analyzer 12 | from .expr import Expr, Var, const 13 | 14 | 15 | @mlcd.c_class("mlc.sym.ConstIntBound") 16 | class ConstIntBound(Object): 17 | min_value: int 18 | max_value: int 19 | 20 | 21 | @mlcd.c_class("mlc.sym.IntervalSet", init=False) 22 | class IntervalSet(Object): 23 | min_value: Expr 24 | max_value: Expr 25 | 26 | def __init__(self, min_value: Expr | int, max_value: Expr | int) -> None: 27 | if isinstance(min_value, int) and isinstance(max_value, int): 28 | min_value = const("int32", min_value) 29 | max_value = const("int32", max_value) 30 | if isinstance(min_value, int): 31 | assert isinstance(max_value, Expr) 32 | min_value = const(max_value.dtype, min_value) 33 | if isinstance(max_value, int): 34 | assert isinstance(min_value, Expr) 35 | max_value = const(min_value.dtype, max_value) 36 | self._mlc_init(min_value, max_value) 37 | 38 | 39 | @mlcd.c_class("mlc.sym.ModularSet") 40 | class ModularSet(Object): 41 | coeff: int 42 | base: int 43 | 44 | 45 | def const_int_bound(analyzer: Analyzer, expr: Expr) -> ConstIntBound: 46 | return _Analyzer_ConstIntBound(analyzer, expr) 47 | 48 | 49 | def modular_set(analyzer: Analyzer, expr: Expr) -> ModularSet: 50 | return _Analyzer_ModularSet(analyzer, expr) 51 | 52 | 53 | def rewrite_simplify(analyzer: Analyzer, expr: Expr) -> Expr: 54 | return _Analyzer_RewriteSimplify(analyzer, expr) 55 | 56 | 57 | def canonical_simplify(analyzer: Analyzer, expr: Expr) -> Expr: 58 | return _Analyzer_CanonicalSimplify(analyzer, expr) 59 | 60 | 61 | def interval_set( 62 | analyzer: Analyzer, 63 | expr: Expr, 64 | dom_map: dict[Var, IntervalSet], 65 | ) -> 
IntervalSet: 66 | return _Analyzer_IntervalSet(analyzer, expr, dom_map) 67 | 68 | 69 | def const_int_bound_update( 70 | analyzer: Analyzer, 71 | var: Var, 72 | info: ConstIntBound, 73 | allow_override: bool = False, 74 | ) -> None: 75 | return _Analyzer_ConstIntBoundUpdate(analyzer, var, info, allow_override) 76 | 77 | 78 | def get_enabled_extensions(analyzer: Analyzer) -> int: 79 | return _Analyzer_GetEnabledExtensions(analyzer) 80 | 81 | 82 | def set_enabled_extensions(analyzer: Analyzer, flags: int) -> None: 83 | return _Analyzer_SetEnabledExtensions(analyzer, flags) 84 | 85 | 86 | @contextlib.contextmanager 87 | def enter_constraint(analyzer: Analyzer, constraint: Expr | None) -> Generator[None, None, None]: 88 | if constraint is None: 89 | yield 90 | return 91 | exit_constraint: Func = _Analyzer_EnterConstraint(analyzer, constraint) 92 | try: 93 | yield 94 | finally: 95 | exit_constraint() 96 | 97 | 98 | _Analyzer_ConstIntBound = Func.get("mlc.sym._internal.Analyzer.ConstIntBound") 99 | _Analyzer_ModularSet = Func.get("mlc.sym._internal.Analyzer.ModularSet") 100 | _Analyzer_RewriteSimplify = Func.get("mlc.sym._internal.Analyzer.RewriteSimplify") 101 | _Analyzer_CanonicalSimplify = Func.get("mlc.sym._internal.Analyzer.CanonicalSimplify") 102 | _Analyzer_IntervalSet = Func.get("mlc.sym._internal.Analyzer.IntervalSet") 103 | _Analyzer_ConstIntBoundUpdate = Func.get("mlc.sym._internal.Analyzer.ConstIntBoundUpdate") 104 | _Analyzer_GetEnabledExtensions = Func.get("mlc.sym._internal.Analyzer.GetEnabledExtensions") 105 | _Analyzer_SetEnabledExtensions = Func.get("mlc.sym._internal.Analyzer.SetEnabledExtensions") 106 | _Analyzer_EnterConstraint = Func.get("mlc.sym._internal.Analyzer.EnterConstraint") 107 | -------------------------------------------------------------------------------- /python/mlc/sym/analyzer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import 
TYPE_CHECKING, Literal 4 | 5 | import mlc.dataclasses as mlcd 6 | from mlc.core import Object 7 | 8 | if TYPE_CHECKING: 9 | from .expr import Expr, Range, Var 10 | 11 | 12 | @mlcd.c_class("mlc.sym.Analyzer") 13 | class Analyzer(Object): 14 | def mark_global_non_neg_value(self, v: Expr) -> None: 15 | Analyzer._C(b"_mark_global_non_neg_value", self, v) 16 | 17 | def bind( 18 | self, 19 | v: Var, 20 | bound: Range | Expr | int | float, 21 | allow_override: bool = False, 22 | ) -> None: 23 | from .expr import Expr, Range, const 24 | 25 | if isinstance(bound, Range): 26 | Analyzer._C(b"_bind_range", self, v, bound, allow_override) 27 | elif isinstance(bound, Expr): 28 | Analyzer._C(b"_bind_expr", self, v, bound, allow_override) 29 | elif isinstance(bound, (int, float)): 30 | Analyzer._C(b"_bind_expr", self, v, const(v.dtype, bound), allow_override) 31 | else: 32 | raise TypeError(f"Unsupported type for bound: {type(bound)}") 33 | 34 | def can_prove_greater_equal(self, a: Expr, b: int) -> bool: 35 | assert isinstance(b, int) 36 | return Analyzer._C(b"can_prove_greater_equal", self, a, b) 37 | 38 | def can_prove_less(self, a: Expr, b: int) -> bool: 39 | assert isinstance(b, int) 40 | return Analyzer._C(b"can_prove_less", self, a, b) 41 | 42 | def can_prove_equal(self, a: Expr, b: Expr | int) -> bool: 43 | from .expr import Expr, const 44 | 45 | assert isinstance(a, Expr) 46 | if isinstance(b, int): 47 | b = const(a.dtype, b) 48 | return Analyzer._C(b"can_prove_equal", self, a, b) 49 | 50 | def can_prove_less_equal_than_symbolic_shape_value(self, a: Expr, b: Expr) -> bool: 51 | return Analyzer._C(b"can_prove_less_equal_than_symbolic_shape_value", self, a, b) 52 | 53 | def can_prove( 54 | self, 55 | cond: Expr, 56 | *, 57 | strength: Literal["default", "symbolic_bound"] = "default", 58 | ) -> bool: 59 | return Analyzer._C(b"can_prove", self, cond, _STRENGTH[strength]) 60 | 61 | def simplify(self, expr: Expr, *, steps: int = 2) -> Expr: 62 | return Analyzer._C(b"simplify", 
self, expr, steps) 63 | 64 | 65 | _STRENGTH = { 66 | "default": 0, 67 | "symbolic_bound": 1, 68 | } 69 | -------------------------------------------------------------------------------- /python/mlc/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlc-ai/mlc-python/0bb1ac18e9980013c96325f7e9c2d30c78835a91/python/mlc/testing/__init__.py -------------------------------------------------------------------------------- /python/mlc/testing/dataclasses.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import mlc 4 | 5 | 6 | @mlc.c_class("mlc.testing.c_class") 7 | class CClassForTest(mlc.Object): 8 | bool_: bool 9 | i8: int 10 | i16: int 11 | i32: int 12 | i64: int 13 | f32: float 14 | f64: float 15 | raw_ptr: mlc.Ptr 16 | dtype: mlc.DataType 17 | device: mlc.Device 18 | any: Any 19 | func: mlc.Func 20 | ulist: list[Any] 21 | udict: dict 22 | str_: str 23 | str_readonly: str 24 | ### 25 | list_any: list[Any] 26 | list_list_int: list[list[int]] 27 | dict_any_any: dict[Any, Any] 28 | dict_str_any: dict[str, Any] 29 | dict_any_str: dict[Any, str] 30 | dict_str_list_int: dict[str, list[int]] 31 | ### 32 | opt_bool: Optional[bool] 33 | opt_i64: Optional[int] 34 | opt_f64: Optional[float] 35 | opt_raw_ptr: Optional[mlc.Ptr] 36 | opt_dtype: Optional[mlc.DataType] 37 | opt_device: Optional[mlc.Device] 38 | opt_func: Optional[mlc.Func] 39 | opt_ulist: Optional[list] 40 | opt_udict: Optional[dict[Any, Any]] 41 | opt_str: Optional[str] 42 | ### 43 | opt_list_any: Optional[list[Any]] 44 | opt_list_list_int: Optional[list[list[int]]] 45 | opt_dict_any_any: Optional[dict] 46 | opt_dict_str_any: Optional[dict[str, Any]] 47 | opt_dict_any_str: Optional[dict[Any, str]] 48 | opt_dict_str_list_int: Optional[dict[str, list[int]]] 49 | 50 | def i64_plus_one(self) -> int: 51 | return type(self)._C(b"i64_plus_one", self) 52 | 53 | 
@mlc.py_class("mlc.testing.py_class")
class PyClassForTest(mlc.PyClass):
    """Pure-Python (``py_class``) test class declaring the same fields as ``CClassForTest``.

    Unlike the C++ class, ``i64_plus_one`` is computed in Python.
    """

    # -- scalars, handles, untyped containers and strings --
    bool_: bool
    i8: int  # no `int8` in `py_class`; stored as `int64_t`
    i16: int  # no `int16` in `py_class`; stored as `int64_t`
    i32: int  # no `int32` in `py_class`; stored as `int64_t`
    i64: int
    f32: float  # no `float32` in `py_class`; stored as `float64`
    f64: float
    raw_ptr: mlc.Ptr
    dtype: mlc.DataType
    device: mlc.Device
    any: Any
    func: mlc.Func
    ulist: list[Any]
    udict: dict
    str_: str
    str_readonly: str
    # -- typed containers --
    list_any: list[Any]
    list_list_int: list[list[int]]
    dict_any_any: dict[Any, Any]
    dict_str_any: dict[str, Any]
    dict_any_str: dict[Any, str]
    dict_str_list_int: dict[str, list[int]]
    # -- optional scalars, handles and untyped containers --
    opt_bool: Optional[bool]
    opt_i64: Optional[int]
    opt_f64: Optional[float]
    opt_raw_ptr: Optional[mlc.Ptr]
    opt_dtype: Optional[mlc.DataType]
    opt_device: Optional[mlc.Device]
    opt_func: Optional[mlc.Func]
    opt_ulist: Optional[list]
    opt_udict: Optional[dict[Any, Any]]
    opt_str: Optional[str]
    # -- optional typed containers --
    opt_list_any: Optional[list[Any]]
    opt_list_list_int: Optional[list[list[int]]]
    opt_dict_any_any: Optional[dict]
    opt_dict_str_any: Optional[dict[str, Any]]
    opt_dict_any_str: Optional[dict[Any, str]]
    opt_dict_str_list_int: Optional[dict[str, list[int]]]

    def i64_plus_one(self) -> int:
        """Return ``self.i64`` incremented by one."""
        incremented = self.i64 + 1
        return incremented


def visit_fields(obj: mlc.Object) -> list[tuple[str, str, Any]]:
    """Return ``(type, name, value)`` triples for the fields of ``obj``.

    The C++ helper returns three parallel sequences, which are zipped here.
    """
    field_types, field_names, field_values = _C_VisitFields(obj)
    return [*zip(field_types, field_names, field_values)]


def field_get(obj: mlc.Object, name: str) -> Any:
    """Read field ``name`` of ``obj`` through the C++ ``FieldGet`` helper."""
    return _C_FieldGet(obj, name)


def field_set(obj: mlc.Object, name: str, value: Any) -> None:
    """Write ``value`` into field ``name`` of ``obj`` through the C++ ``FieldSet`` helper."""
    _C_FieldSet(obj, name, value)
_C_VisitFields = mlc.Func.get("mlc.testing.VisitFields") 116 | _C_FieldGet = mlc.Func.get("mlc.testing.FieldGet") 117 | _C_FieldSet = mlc.Func.get("mlc.testing.FieldSet") 118 | -------------------------------------------------------------------------------- /python/mlc/testing/toy_ir/__init__.py: -------------------------------------------------------------------------------- 1 | from .ir import Add, Assign, Expr, Func, Node, Stmt, Var 2 | from .ir_builder import FunctionFrame, IRBuilder 3 | from .parser import Parser, parse_func 4 | -------------------------------------------------------------------------------- /python/mlc/testing/toy_ir/ir.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import mlc.dataclasses as mlcd 4 | import mlc.printer as mlcp 5 | import mlc.printer.ast as mlt 6 | 7 | 8 | @mlcd.py_class 9 | class Node(mlcd.PyClass): ... 10 | 11 | 12 | @mlcd.py_class 13 | class Expr(Node): ... 14 | 15 | 16 | @mlcd.py_class 17 | class Stmt(Node): ... 
18 | 19 | 20 | @mlcd.py_class(structure="var") 21 | class Var(Expr): 22 | name: str = mlcd.field(structure=None) 23 | 24 | def __add__(self, other: Var) -> Add: 25 | return Add(lhs=self, rhs=other) 26 | 27 | def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: 28 | if not printer.var_is_defined(obj=self): 29 | printer.var_def(self.name, obj=self) 30 | ret = printer.var_get(obj=self) 31 | assert ret is not None 32 | return ret 33 | 34 | 35 | @mlcd.py_class(structure="nobind") 36 | class Add(Expr): 37 | lhs: Expr 38 | rhs: Expr 39 | 40 | def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: 41 | lhs: mlt.Expr = printer(self.lhs, path=path["lhs"]) 42 | rhs: mlt.Expr = printer(self.rhs, path=path["rhs"]) 43 | return lhs + rhs 44 | 45 | 46 | @mlcd.py_class(structure="bind") 47 | class Assign(Stmt): 48 | rhs: Expr 49 | lhs: Var = mlcd.field(structure="bind") 50 | 51 | def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: 52 | rhs: mlt.Expr = printer(self.rhs, path=path["rhs"]) 53 | printer.var_def(self.lhs.name, obj=self.lhs) 54 | lhs: mlt.Expr = printer(self.lhs, path=path["lhs"]) 55 | return mlt.Assign(lhs=lhs, rhs=rhs) 56 | 57 | 58 | @mlcd.py_class(structure="bind") 59 | class Func(Node): 60 | name: str = mlcd.field(structure=None) 61 | args: list[Var] = mlcd.field(structure="bind") 62 | stmts: list[Stmt] 63 | ret: Var 64 | 65 | def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: 66 | with printer.with_frame(mlcp.DefaultFrame()): 67 | for arg in self.args: 68 | printer.var_def(arg.name, obj=arg) 69 | args: list[mlt.Expr] = [ 70 | printer(arg, path=path["args"][i]) for i, arg in enumerate(self.args) 71 | ] 72 | stmts: list[mlt.Expr] = [ 73 | printer(stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts) 74 | ] 75 | ret_stmt = mlt.Return(printer(self.ret, path=path["ret"])) 76 | return mlt.Function( 77 | name=mlt.Id(self.name), 78 | 
args=[mlt.Assign(lhs=arg, rhs=None) for arg in args], 79 | decorators=[], 80 | return_type=None, 81 | body=[*stmts, ret_stmt], 82 | ) 83 | -------------------------------------------------------------------------------- /python/mlc/testing/toy_ir/ir_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from types import TracebackType 4 | from typing import Any, ClassVar 5 | 6 | from .ir import Func, Stmt, Var 7 | 8 | 9 | class IRBuilder: 10 | _ctx: ClassVar[list[IRBuilder]] = [] 11 | frames: list[Any] 12 | result: Any 13 | 14 | def __init__(self) -> None: 15 | self.frames = [] 16 | self.result = None 17 | 18 | def __enter__(self) -> IRBuilder: 19 | IRBuilder._ctx.append(self) 20 | return self 21 | 22 | def __exit__( 23 | self, 24 | exc_type: type[BaseException], 25 | exc_value: BaseException, 26 | traceback: TracebackType, 27 | ) -> None: 28 | IRBuilder._ctx.pop() 29 | 30 | @staticmethod 31 | def get() -> IRBuilder: 32 | return IRBuilder._ctx[-1] 33 | 34 | 35 | class FunctionFrame: 36 | name: str 37 | args: list[Var] 38 | stmts: list[Stmt] 39 | ret: Var | None 40 | 41 | def __init__(self, name: str) -> None: 42 | self.name = name 43 | self.args = [] 44 | self.stmts = [] 45 | self.ret = None 46 | 47 | def add_arg(self, arg: Var) -> Var: 48 | self.args.append(arg) 49 | return arg 50 | 51 | def __enter__(self) -> FunctionFrame: 52 | IRBuilder.get().frames.append(self) 53 | return self 54 | 55 | def __exit__( 56 | self, 57 | exc_type: type[BaseException], 58 | exc_value: BaseException, 59 | traceback: TracebackType, 60 | ) -> None: 61 | frame = IRBuilder.get().frames.pop() 62 | assert frame is self 63 | if exc_type is None: 64 | IRBuilder.get().result = Func( 65 | name=self.name, 66 | args=frame.args, 67 | stmts=frame.stmts, 68 | ret=frame.ret, 69 | ) 70 | -------------------------------------------------------------------------------- /python/mlc/testing/toy_ir/parser.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ast 4 | from typing import Callable 5 | 6 | import mlc.parser as mlcs 7 | 8 | from .ir import Assign, Node, Var 9 | from .ir_builder import FunctionFrame, IRBuilder 10 | 11 | 12 | class Parser(ast.NodeVisitor): 13 | base: mlcs.Parser 14 | 15 | def __init__(self, env: mlcs.Env) -> None: 16 | self.base = mlcs.Parser(env, include_builtins=True, extra_vars=None) 17 | 18 | def visit_Assign(self, node: ast.Assign) -> None: 19 | if len(node.targets) != 1: 20 | self.base.report_error(node, "Continuous assignment is not supported") 21 | (target,) = node.targets 22 | if not isinstance(target, ast.Name): 23 | self.base.report_error(target, "Invalid assignment target") 24 | assert isinstance(target, ast.Name) 25 | value = self.base.eval_assign( 26 | target=target, 27 | source=self.base.eval_expr(node.value), 28 | )[target.id] 29 | var = Var(name=target.id) 30 | self.base.var_def(name=target.id, value=var) 31 | IRBuilder.get().frames[-1].stmts.append( 32 | Assign( 33 | lhs=var, 34 | rhs=value, 35 | ) 36 | ) 37 | 38 | def visit_FunctionDef(self, node: ast.FunctionDef) -> None: 39 | with mlcs.Frame().scope(self.base): 40 | with FunctionFrame(node.name) as frame: 41 | for node_arg in node.args.args: 42 | self.base.var_def( 43 | name=node_arg.arg, 44 | value=frame.add_arg(Var(name=node_arg.arg)), 45 | ) 46 | for node_stmt in node.body: 47 | self.visit(node_stmt) 48 | 49 | def visit_Return(self, node: ast.Return) -> None: 50 | if not isinstance(node.value, ast.Name): 51 | self.base.report_error(node, "Return statement must return a single variable") 52 | assert isinstance(node.value, ast.Name) 53 | frame: FunctionFrame = IRBuilder.get().frames[-1] 54 | frame.ret = self.base.eval_expr(node.value) 55 | 56 | 57 | def parse_func(source: Callable) -> Node: 58 | env = mlcs.Env.from_function(source) 59 | parser = Parser(env) 60 | node = ast.parse( 61 | 
env.source, 62 | filename=env.source_name, 63 | ) 64 | with IRBuilder() as ib: 65 | parser.visit(node) 66 | return ib.result 67 | -------------------------------------------------------------------------------- /scripts/cpp_tests.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | 4 | set BUILD_TYPE=RelWithDebInfo 5 | set BUILD_DIR=build-cpp-tests/ 6 | set GTEST_COLOR=1 7 | 8 | if exist %BUILD_DIR% rmdir /s /q %BUILD_DIR% 9 | cmake -S . -B %BUILD_DIR% ^ 10 | -DMLC_BUILD_TESTS=ON ^ 11 | -DCMAKE_BUILD_TYPE=%BUILD_TYPE% ^ 12 | -DCMAKE_EXPORT_COMPILE_COMMANDS=ON 13 | if errorlevel 1 goto :error 14 | 15 | cmake --build %BUILD_DIR% ^ 16 | --config %BUILD_TYPE% ^ 17 | --target mlc_tests ^ 18 | -j %NUMBER_OF_PROCESSORS% ^ 19 | -- -verbosity:detailed 20 | if errorlevel 1 goto :error 21 | 22 | ctest -V -C %BUILD_TYPE% --test-dir %BUILD_DIR% --output-on-failure 23 | if errorlevel 1 goto :error 24 | 25 | rmdir /s /q %BUILD_DIR% 26 | goto :eof 27 | 28 | :error 29 | echo Script failed with error #%errorlevel%. 30 | exit /b %errorlevel% 31 | :eof 32 | 33 | endlocal 34 | -------------------------------------------------------------------------------- /scripts/cpp_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | BUILD_TYPE=RelWithDebInfo 5 | BUILD_DIR=build-cpp-tests/ 6 | if [[ "$(uname)" == "Darwin" ]]; then 7 | NUM_PROCS=$(sysctl -n hw.ncpu) 8 | else 9 | NUM_PROCS=$(nproc) 10 | fi 11 | 12 | cmake -S . 
-B ${BUILD_DIR} \ 13 | -DMLC_BUILD_TESTS=ON \ 14 | -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ 15 | -DCMAKE_EXPORT_COMPILE_COMMANDS=ON 16 | cmake --build ${BUILD_DIR} \ 17 | --config ${BUILD_TYPE} \ 18 | --target mlc_tests \ 19 | -j "${NUM_PROCS}" 20 | GTEST_COLOR=1 ctest -V -C ${BUILD_TYPE} --test-dir ${BUILD_DIR} --output-on-failure 21 | -------------------------------------------------------------------------------- /scripts/setup_manylinux2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo 4 | sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo 5 | yum clean all 6 | yum list installed 7 | yum install -y devtoolset-10-libasan-devel devtoolset-10-liblsan-devel devtoolset-10-libtsan-devel devtoolset-10-libubsan-devel 8 | -------------------------------------------------------------------------------- /scripts/show_wheel_content.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | from pathlib import Path 4 | from zipfile import ZipFile 5 | 6 | 7 | def main() -> None: 8 | directory = Path(sys.argv[1]) 9 | # for each file in the directory that ends with `.whl` 10 | for file in directory.iterdir(): 11 | if file.suffix == ".whl": 12 | # print the name of the file 13 | print(f"============= Wheel: {file} =============") 14 | # open the file as a zip file 15 | with ZipFile(file) as z: 16 | # print the names of the files in the zip file 17 | print(*z.namelist(), sep="\n") 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") 2 | add_executable(mlc_tests ${_test_sources}) 3 | set_target_properties( 4 | 
mlc_tests PROPERTIES 5 | POSITION_INDEPENDENT_CODE ON 6 | CXX_STANDARD 17 7 | CXX_EXTENSIONS OFF 8 | CXX_STANDARD_REQUIRED ON 9 | MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL" 10 | ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 11 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" 12 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" 13 | ) 14 | target_compile_definitions(mlc_tests PRIVATE MLC_CPPTESTS_EXPORTS) 15 | if(MSVC) 16 | add_custom_command( 17 | TARGET mlc_tests 18 | POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ $ 19 | COMMAND_EXPAND_LISTS 20 | ) 21 | endif() 22 | add_cxx_warning(mlc_tests) 23 | add_sanitizer_address(mlc_tests) 24 | add_sanitizer_address(mlc-shared) 25 | target_include_directories(mlc_tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/dlpack/include") 26 | target_link_libraries(mlc_tests PRIVATE mlc-shared) 27 | add_googletest(mlc_tests) 28 | -------------------------------------------------------------------------------- /tests/cpp/common.h: -------------------------------------------------------------------------------- 1 | #ifndef MLC_CPPTESTS_EXPORTS 2 | #define MLC_CPPTESTS_EXPORTS 0 3 | #else 4 | #undef MLC_CPPTESTS_EXPORTS 5 | #define MLC_CPPTESTS_EXPORTS 1 6 | #endif 7 | -------------------------------------------------------------------------------- /tests/cpp/test_core_str.cc: -------------------------------------------------------------------------------- 1 | #include "./common.h" 2 | #include 3 | #include 4 | 5 | namespace { 6 | using namespace mlc; 7 | 8 | TEST(StrObj, DefaultConstructor) { 9 | Ref str = Ref::New(""); 10 | EXPECT_EQ(str->size(), 0); 11 | EXPECT_STREQ(str->c_str(), ""); 12 | } 13 | 14 | TEST(StrObj, ConstructorFromNullptr) { 15 | try { 16 | Ref::New(static_cast(nullptr)); 17 | FAIL() << "No exception thrown"; 18 | } catch (Exception &ex) { 19 | EXPECT_STREQ(ex.what(), "Cannot create StrObj from nullptr"); 20 | } 21 | } 22 | 23 | TEST(StrObj, ConstructorFromCStyleCharPointer) { 24 | const char 
*c_str = "Hello, World!"; 25 | Ref str = Ref::New(c_str); 26 | EXPECT_EQ(str->size(), strlen(c_str)); 27 | EXPECT_STREQ(str->c_str(), c_str); 28 | } 29 | 30 | TEST(StrObj, ConstructorFromCStyleCharArray) { 31 | constexpr int64_t len = 23; 32 | char c_str[len] = "Hello, World!"; 33 | Ref str = Ref::New(c_str); 34 | #ifndef _MSC_VER 35 | EXPECT_EQ(str->size(), len - 1); 36 | #endif 37 | EXPECT_STREQ(str->c_str(), c_str); 38 | } 39 | 40 | TEST(StrObj, ConstructorFromStdString) { 41 | std::string std_str = "Hello, World!"; 42 | Ref str = Ref::New(std_str); 43 | EXPECT_EQ(str->size(), std_str.size()); 44 | EXPECT_STREQ(str->c_str(), std_str.c_str()); 45 | } 46 | 47 | TEST(StrObj, ConstructorFromStdStringRValue) { 48 | std::string std_str = "Hello, World!"; 49 | Ref str = Ref::New(std::move(std_str)); 50 | EXPECT_EQ(str->size(), 13); 51 | EXPECT_STREQ(str->c_str(), "Hello, World!"); 52 | } 53 | 54 | TEST(Str, DefaultConstructor) { 55 | Str str(Null); 56 | EXPECT_EQ(str.get(), nullptr); 57 | } 58 | 59 | TEST(Str, ConstructorFromNullptr) { 60 | const char *c_str = nullptr; 61 | try { 62 | Str str(c_str); 63 | FAIL() << "No exception thrown"; 64 | } catch (Exception &ex) { 65 | EXPECT_STREQ(ex.what(), "Cannot create StrObj from nullptr"); 66 | } 67 | } 68 | 69 | TEST(Str, ConstructorFromCStyleCharPointer) { 70 | const char *c_str = "Hello, World!"; 71 | Str str(c_str); 72 | EXPECT_EQ(str.size(), strlen(c_str)); 73 | EXPECT_STREQ(str.c_str(), c_str); 74 | } 75 | 76 | TEST(Str, ConstructorFromCStyleCharArray) { 77 | constexpr int64_t len = 23; 78 | char c_str[len] = "Hello, World!"; 79 | Str str(c_str); 80 | #ifndef _MSC_VER 81 | EXPECT_EQ(str.size(), len - 1); 82 | #endif 83 | EXPECT_STREQ(str.c_str(), c_str); 84 | } 85 | 86 | TEST(Str, ConstructorFromStdString) { 87 | std::string std_str = "Hello, World!"; 88 | Str str(std_str); 89 | EXPECT_EQ(str.size(), std_str.size()); 90 | EXPECT_STREQ(str.c_str(), std_str.c_str()); 91 | } 92 | 93 | TEST(Str, 
ConstructorFromStdStringRValue) { 94 | std::string std_str = "Hello, World!"; 95 | Str str(std::move(std_str)); 96 | EXPECT_EQ(str.size(), 13); 97 | EXPECT_STREQ(str.c_str(), "Hello, World!"); 98 | } 99 | 100 | TEST(Str, CopyConstructor) { 101 | Str str1("Hello, World!"); 102 | Str str2(str1); 103 | EXPECT_EQ(str1.get(), str2.get()); 104 | } 105 | 106 | TEST(Str, MoveConstructor) { 107 | Str str1("Hello, World!"); 108 | Str str2(std::move(str1)); 109 | EXPECT_EQ(str1.get(), nullptr); 110 | EXPECT_EQ(str2.size(), 13); 111 | EXPECT_STREQ(str2.c_str(), "Hello, World!"); 112 | } 113 | 114 | TEST(Str, CopyAssignment) { 115 | Str str1("Hello, World!"); 116 | Str str2("Test"); 117 | str2 = str1; 118 | EXPECT_EQ(str1.get(), str2.get()); 119 | } 120 | 121 | TEST(Str, MoveAssignment) { 122 | Str str1("Hello, World!"); 123 | Str str2("Test"); 124 | StrObj *str_ptr = str1.get(); 125 | str2 = std::move(str1); 126 | EXPECT_EQ(str1.get(), nullptr); 127 | EXPECT_EQ(str2.get(), str_ptr); 128 | } 129 | 130 | TEST(Str, Comparison) { 131 | Str str1("Hello"); 132 | Str str2("World"); 133 | Str str3("Hello"); 134 | std::string std_str1 = "Hello"; 135 | std::string std_str2 = "World"; 136 | 137 | EXPECT_TRUE(str1 < str2); 138 | EXPECT_FALSE(str1 > str2); 139 | EXPECT_TRUE(str1 <= str3); 140 | EXPECT_TRUE(str1 >= str3); 141 | EXPECT_TRUE(str1 == str3); 142 | EXPECT_TRUE(str1 != str2); 143 | 144 | EXPECT_TRUE(str1 < "World"); 145 | EXPECT_TRUE("World" > str1); 146 | EXPECT_TRUE(str1 <= "Hello"); 147 | EXPECT_TRUE("Hello" >= str1); 148 | EXPECT_TRUE(str1 == "Hello"); 149 | EXPECT_TRUE(str1 != "World"); 150 | 151 | EXPECT_TRUE(str1 < std_str2); 152 | EXPECT_TRUE(std_str2 > str1); 153 | EXPECT_TRUE(str1 <= std_str1); 154 | EXPECT_TRUE(std_str1 >= str1); 155 | EXPECT_TRUE(str1 == std_str1); 156 | EXPECT_TRUE(str1 != std_str2); 157 | } 158 | 159 | TEST(Str, StreamOperator) { 160 | Str str("Hello, World!"); 161 | std::ostringstream oss; 162 | oss << str; 163 | EXPECT_EQ(oss.str(), "Hello, 
World!"); 164 | } 165 | 166 | } // namespace 167 | -------------------------------------------------------------------------------- /tests/cpp/test_core_udict_legacy.cc: -------------------------------------------------------------------------------- 1 | #include "./common.h" 2 | #include 3 | #include 4 | #include 5 | 6 | namespace { 7 | 8 | using namespace mlc; 9 | using mlc::base::AnyEqual; 10 | using mlc::base::DType; 11 | 12 | TEST(Legacy_UDict_Construtor, Default) { 13 | UDict dict; 14 | ASSERT_EQ(dict->size(), 0); 15 | MLCDict *dict_ptr = reinterpret_cast(dict.get()); 16 | EXPECT_EQ(dict_ptr->_mlc_header.type_index, static_cast(MLCTypeIndex::kMLCDict)); 17 | EXPECT_EQ(dict_ptr->_mlc_header.ref_cnt, 1); 18 | EXPECT_NE(dict_ptr->_mlc_header.v.deleter, nullptr); 19 | EXPECT_EQ(dict_ptr->size, 0); 20 | EXPECT_EQ(dict_ptr->capacity, 0); 21 | } 22 | 23 | TEST(Legacy_UDict_Construtor, InitializerList) { 24 | UDict dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; 25 | EXPECT_EQ(dict->size(), 3); 26 | EXPECT_EQ(int(dict["key1"]), 1); 27 | EXPECT_EQ(dict["key2"].operator std::string(), "value2"); 28 | EXPECT_EQ(int(dict[3]), 4); 29 | 30 | bool found[3] = {false, false, false}; 31 | for (const auto &kv : *dict) { 32 | if (AnyEqual(kv.first, Any("key1"))) { 33 | found[0] = true; 34 | EXPECT_EQ(int(kv.second), 1); 35 | } else if (AnyEqual(kv.first, Any("key2"))) { 36 | found[1] = true; 37 | EXPECT_EQ(kv.second.operator std::string(), "value2"); 38 | } else if (AnyEqual(kv.first, Any(3))) { 39 | found[2] = true; 40 | EXPECT_EQ(int(kv.second), 4); 41 | } else { 42 | FAIL() << "Unexpected key: " << kv.first; 43 | } 44 | } 45 | EXPECT_TRUE(found[0]); 46 | EXPECT_TRUE(found[1]); 47 | EXPECT_TRUE(found[2]); 48 | } 49 | 50 | TEST(Legacy_UDict_Insert, New) { 51 | int64_t integer = 100; 52 | double fp = 1.0; 53 | std::string str = "Hi"; 54 | DLDataType dtype{kDLInt, 32, 1}; 55 | DLDevice device{kDLCPU, 0}; 56 | Ref obj = Ref::New(); 57 | Ref null_obj{nullptr}; 58 | UDict 
dict{{integer, fp}, {str, dtype}, {null_obj, 0}}; 59 | dict[device] = null_obj; 60 | EXPECT_EQ(dict->size(), 4); 61 | EXPECT_DOUBLE_EQ(double(dict[integer]), fp); 62 | EXPECT_PRED2(DType::Equal, DLDataType(dict[str]), dtype); 63 | EXPECT_EQ(int(dict[null_obj]), 0); 64 | EXPECT_EQ((Object *)(dict[device]), nullptr); 65 | } 66 | 67 | TEST(Legacy_UDict_Insert, Override) { 68 | UDict dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; 69 | EXPECT_EQ(dict->size(), 3); 70 | dict["key1"] = 2; 71 | dict["key2"] = "new_value"; 72 | dict[3] = 5; 73 | EXPECT_EQ(dict->size(), 3); 74 | EXPECT_EQ(int(dict["key1"]), 2); 75 | EXPECT_EQ(dict["key2"].operator std::string(), "new_value"); 76 | EXPECT_EQ(int(dict[3]), 5); 77 | } 78 | 79 | TEST(Legacy_UDict_At, Found) { 80 | int64_t integer = 100; 81 | double fp = 1.0; 82 | std::string str = "Hi"; 83 | DLDataType dtype{kDLInt, 32, 1}; 84 | Ref obj = Ref::New(); 85 | Ref null_obj{nullptr}; 86 | UDict dict{{integer, fp}, {str, dtype}, {null_obj, 0}}; 87 | EXPECT_DOUBLE_EQ(double(dict->at(integer)), fp); 88 | EXPECT_PRED2(DType::Equal, DLDataType(dict->at(str)), dtype); 89 | EXPECT_EQ(int(dict->at(null_obj)), 0); 90 | } 91 | 92 | TEST(Legacy_UDict_At, NotFound) { 93 | UDict dict{{"key1", 1}, {"key2", "value2"}, {3, 4}}; 94 | try { 95 | dict->at("key3"); 96 | FAIL() << "Expected Exception"; 97 | } catch (const Exception &) { 98 | } 99 | } 100 | 101 | TEST(Legacy_UDict_ReHash, POD) { 102 | UDict dict; 103 | for (int j = 0; j < 1000; ++j) { 104 | dict[j] = j; 105 | } 106 | EXPECT_EQ(dict->size(), 1000); 107 | std::unordered_set keys; 108 | for (auto &kv : *dict) { 109 | int64_t key = kv.first; 110 | int64_t value = kv.second; 111 | EXPECT_EQ(key, value); 112 | EXPECT_FALSE(keys.count(key)); 113 | EXPECT_EQ(key, value); 114 | EXPECT_TRUE(0 <= key && key < 1000); 115 | } 116 | EXPECT_EQ(dict->size(), 1000); 117 | } 118 | 119 | TEST(Legacy_UDict_ReHash, Object) { 120 | std::vector> objs; 121 | std::unordered_map obj_map; 122 | for (int j = 0; j < 
1000; ++j) { 123 | objs.push_back(Ref::New()); 124 | obj_map[objs[j].get()] = j; 125 | } 126 | UDict dict; 127 | for (int j = 0; j < 1000; ++j) { 128 | dict[objs[j]] = j; 129 | } 130 | EXPECT_EQ(dict->size(), 1000); 131 | std::unordered_set keys; 132 | for (auto &kv : *dict) { 133 | Ref key = kv.first; 134 | int64_t value = kv.second; 135 | keys.insert(key.get()); 136 | EXPECT_EQ(value, obj_map[key.get()]); 137 | } 138 | EXPECT_EQ(dict->size(), 1000); 139 | } 140 | 141 | TEST(Legacy_UDict_Erase, POD) { 142 | UDict dict; 143 | for (int j = 0; j < 1000; ++j) { 144 | dict[j] = j; 145 | } 146 | EXPECT_EQ(dict->size(), 1000); 147 | for (int j = 0; j < 1000; ++j) { 148 | dict->erase(j); 149 | EXPECT_EQ(dict->size(), 1000 - j - 1); 150 | } 151 | for (int j = 0; j < 1000; ++j) { 152 | dict[j] = j; 153 | EXPECT_EQ(dict->size(), j + 1); 154 | } 155 | } 156 | 157 | TEST(Legacy_UDict_Erase, Object) { 158 | std::vector> objs; 159 | std::unordered_map obj_map; 160 | for (int j = 0; j < 1000; ++j) { 161 | objs.push_back(Ref::New()); 162 | obj_map[objs[j].get()] = j; 163 | } 164 | UDict dict; 165 | for (int j = 0; j < 1000; ++j) { 166 | dict[objs[j]] = j; 167 | } 168 | EXPECT_EQ(dict->size(), 1000); 169 | for (int j = 0; j < 1000; ++j) { 170 | dict->erase(objs[j]); 171 | EXPECT_EQ(dict->size(), 1000 - j - 1); 172 | } 173 | for (int j = 0; j < 1000; ++j) { 174 | dict[objs[j]] = j; 175 | EXPECT_EQ(dict->size(), j + 1); 176 | } 177 | } 178 | 179 | } // namespace 180 | -------------------------------------------------------------------------------- /tests/cpp/test_sym_pattern.cc: -------------------------------------------------------------------------------- 1 | #include "./common.h" 2 | #include 3 | #include 4 | 5 | namespace { 6 | 7 | using namespace mlc::sym; 8 | using mlc::base::DType; 9 | 10 | TEST(Pattern, Basic_1) { 11 | Var y("y", DType::Int(64)); 12 | PVar px, py, pz; 13 | auto r = 1 + (y + 1); 14 | ASSERT_FALSE((px + (px + px)).Match(r)); 15 | ASSERT_FALSE((px + (py + 
py)).Match(r)); 16 | ASSERT_TRUE((px + (py + pz)).Match(r)); 17 | } 18 | 19 | TEST(Pattern, Basic_2) { 20 | Var y("y", DType::Int(64)); 21 | PVar px, py, pz; 22 | Expr r = 1 + (y + 1); 23 | ASSERT_TRUE((px + (py + px)).Match(r)); 24 | Expr rr = (px + py).Eval(); 25 | ASSERT_TRUE(ExprDeepEqual()(rr, 1 + y)); 26 | ASSERT_TRUE(ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); 27 | } 28 | 29 | TEST(Pattern, Basic_3) { 30 | Var x("x", DType::Int(64)); 31 | Var y("y", DType::Int(64)); 32 | Var z("z", DType::Int(64)); 33 | PVar px, py, pz; 34 | ASSERT_TRUE((px + max(py, px)).Match((x + 1) + max(y, (x + 1)))); 35 | ASSERT_TRUE(ExprDeepEqual()(px.Eval(), x + 1)); 36 | ASSERT_TRUE(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); 37 | ASSERT_TRUE((px + min(py, px)).Match(z + min(y, z))); 38 | ASSERT_TRUE((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); 39 | ASSERT_TRUE((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); 40 | ASSERT_TRUE((px - floormod(py, px * 2)).Match(x - floormod(2, x * 2))); 41 | } 42 | 43 | TEST(Pattern, Logical) { 44 | Var x("x", DType::Int(64)); 45 | Var y("y", DType::Int(64)); 46 | Var z("z", DType::Int(64)); 47 | PVar px, py, pz; 48 | ASSERT_TRUE((px == pz).Match(x == 1)); 49 | ASSERT_TRUE((px != pz).Match(x != 1)); 50 | ASSERT_TRUE((px > py).Match(x > y)); 51 | ASSERT_TRUE((px < py).Match(x < y)); 52 | ASSERT_TRUE((px <= py).Match(x <= y)); 53 | ASSERT_TRUE((px >= py).Match(x >= y)); 54 | ASSERT_TRUE((px >= py && px < pz).Match(x >= y && x < z)); 55 | ASSERT_TRUE((!(px > py || px != py)).Match(!(x > y || x != y))); 56 | } 57 | 58 | TEST(Pattern, Select) { 59 | Var x("x", DType::Int(64)); 60 | Var y("y", DType::Int(64)); 61 | Var z("z", DType::Int(64)); 62 | PVar px, py, pz; 63 | { 64 | ASSERT_TRUE(select(px >= pz, py, py + pz).Match(Select((x + 1) >= 1, y, y + 1))); 65 | ASSERT_TRUE(ExprDeepEqual()(px.Eval(), x + 1)); 66 | } 67 | { 68 | ASSERT_TRUE(select(px > pz, py, py + pz).Match(Select(x > 1, y, y + 1))); 69 | 
ASSERT_EQ(pz.Eval().as()->value, 1); 70 | } 71 | ASSERT_TRUE(!select(px > pz, py, py + pz).Match(Select(x > 2, y, y + 1))); 72 | ASSERT_TRUE(!select(px > pz, py, py).Match(Select(x > 2, y, y + 1))); 73 | { 74 | ASSERT_TRUE(select(px, py, pz).Match(Select(x > 2, y, y + 1))); 75 | ASSERT_TRUE(ExprDeepEqual()(pz.Eval(), y + 1)); 76 | } 77 | } 78 | 79 | TEST(Pattern, BitIntrinsics) { 80 | Var x("x", DType::Int(64)); 81 | Var y("y", DType::Int(64)); 82 | Var z("z", DType::Int(64)); 83 | PVar px, py, pz; 84 | ASSERT_TRUE((px << py).Match(x << 1)); 85 | ASSERT_TRUE((px >> py).Match(x >> 1)); 86 | ASSERT_TRUE((px & py).Match(x & 1)); 87 | ASSERT_TRUE((px | py).Match(x | 1)); 88 | ASSERT_TRUE((px ^ py).Match(x ^ 1)); 89 | ASSERT_TRUE((~px).Match(~x)); 90 | ASSERT_TRUE((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); 91 | } 92 | 93 | TEST(Pattern, IntImm) { 94 | Var tx("tx", DType::Int(64)); 95 | Var ty("ty", DType::Int(64)); 96 | PVar c; 97 | PVar v; 98 | { 99 | // We can match integer and Var, both of which are 100 | // special case container of Expr 101 | ASSERT_TRUE((v * c).Match(tx * 3)); 102 | ASSERT_EQ(c.Eval()->value, 3); 103 | ASSERT_TRUE((v * 3).Match(tx * 3)); 104 | } 105 | // cannot match c to ty 106 | ASSERT_TRUE(!(v * c).Match(tx * ty)); 107 | // cannot match tx + 1 to v 108 | ASSERT_TRUE(!(v * c).Match((tx + 1) * 3)); 109 | } 110 | 111 | TEST(Pattern, IfThenElse) { 112 | Var x("x", DType::Int(64)); 113 | Var y("y", DType::Int(64)); 114 | Var z("z", DType::Int(64)); 115 | PVar px, py, pz; 116 | ASSERT_TRUE(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1))); 117 | ASSERT_EQ(pz.Eval()->as()->value, 1); 118 | } 119 | 120 | TEST(Pattern, Ramp) { 121 | Var x("x", DType::Int(64)); 122 | PVar px; 123 | PVar lanes; 124 | ASSERT_TRUE(ramp(px, PConst(Expr::Int64(1)), lanes).Match(Ramp(x, Expr::Int64(1), 10))); 125 | ASSERT_TRUE(lanes.Eval() == 10); 126 | ASSERT_TRUE(!ramp(px, PConst(Expr::Int64(1)), lanes).Match(Ramp(x, Expr::Int64(2), 
10))); 127 | } 128 | 129 | TEST(Pattern, Broadcast) { 130 | Var x("x", DType::Int(64)); 131 | PVar px, py; 132 | PVar lanes; 133 | ASSERT_TRUE(broadcast(px, lanes).Match(Broadcast(x, 10))); 134 | ASSERT_TRUE(lanes.Eval() == 10); 135 | ASSERT_TRUE(broadcast(px * py, lanes).Match(Broadcast(x * 10, 10))); 136 | } 137 | 138 | } // namespace 139 | -------------------------------------------------------------------------------- /tests/python/test_cc.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | import mlc.dataclasses as mlcd 3 | import pytest 4 | from mlc._cython import SYSTEM 5 | 6 | 7 | @pytest.mark.xfail( 8 | condition=SYSTEM == "Windows", 9 | reason="`vcvarsall.bat` not found for some reason", 10 | ) 11 | def test_jit_load() -> None: 12 | mlc.cc.jit_load(""" 13 | #define MLC_JIT_EXPORTS 1 14 | #include 15 | #include 16 | 17 | struct MyObj : public mlc::Object { 18 | mlc::Str x; 19 | int32_t y; 20 | MyObj(mlc::Str x, int y) : x(x), y(y) {} 21 | int32_t YPlusOne() const { return y + 1; } 22 | MLC_DEF_DYN_TYPE(MLC_JIT_EXPORTS, MyObj, Object, "mlc.MyObj"); 23 | }; 24 | 25 | struct MyObjRef : public mlc::ObjectRef { 26 | MLC_DEF_OBJ_REF(MLC_JIT_EXPORTS, MyObjRef, MyObj, mlc::ObjectRef) 27 | .Field("x", &MyObj::x) 28 | .Field("y", &MyObj::y, /*frozen=*/true) 29 | .StaticFn("__init__", mlc::InitOf) 30 | .MemFn("YPlusOne", &MyObj::YPlusOne); 31 | MLC_DEF_OBJ_REF_FWD_NEW(MyObjRef) 32 | }; 33 | """) 34 | 35 | @mlcd.c_class("mlc.MyObj") 36 | class MyObj(mlc.Object): 37 | x: str 38 | y: int 39 | 40 | def YPlusOne(self) -> int: 41 | return type(self)._C(b"YPlusOne", self) 42 | 43 | obj = MyObj("hello", 42) 44 | assert obj.x == "hello" 45 | assert obj.y == 42 46 | assert obj.YPlusOne() == 43 47 | 48 | obj.x = "world" 49 | assert obj.x == "world" 50 | with pytest.raises(TypeError): 51 | obj.x = 42 # type: ignore[assignment] 52 | with pytest.raises(AttributeError): 53 | obj.y = 42 54 | del obj 55 | 
-------------------------------------------------------------------------------- /tests/python/test_cli_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mlc import config as cfg 3 | from mlc._cython import SYSTEM 4 | 5 | 6 | def test_includedir() -> None: 7 | (include_dir,) = cfg.includedir() 8 | assert include_dir.exists() 9 | assert (include_dir / "mlc").exists() 10 | assert (include_dir / "dlpack").exists() 11 | assert (include_dir / "mlc" / "backtrace").exists() 12 | 13 | 14 | def test_libdir() -> None: 15 | libdir = cfg.libdir() 16 | assert libdir.exists() 17 | if SYSTEM == "Windows": 18 | assert (libdir / "libmlc.dll").exists() 19 | assert (libdir / "libmlc-static.lib").exists() 20 | elif SYSTEM == "Darwin": 21 | assert (libdir / "libmlc.dylib").exists() 22 | assert (libdir / "libmlc-static.a").exists() 23 | else: 24 | assert (libdir / "libmlc.so").exists() 25 | assert (libdir / "libmlc-static.a").exists() 26 | 27 | 28 | @pytest.mark.xfail( 29 | condition=SYSTEM == "Windows", 30 | reason="`vcvarsall.bat` not found for some reason", 31 | ) 32 | def test_probe_compiler() -> None: 33 | compilers = cfg.probe_compiler() 34 | for compiler in compilers: 35 | assert compiler.exists() 36 | if SYSTEM == "Windows": 37 | print("Compilers found: ", ";".join(str(i) for i in compilers)) 38 | print("vcvarsall.bat found: ", cfg.probe_vcvarsall()) 39 | -------------------------------------------------------------------------------- /tests/python/test_core_device.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | import pytest 3 | import torch 4 | from mlc import Device 5 | 6 | 7 | @pytest.mark.parametrize("x", ["cpu", "cpu:1", "cuda:0", "mps:1"]) 8 | def test_device_init_str(x: str) -> None: 9 | func = mlc.Func.get("mlc.testing.cxx_device") 10 | y = func(Device(x)) 11 | z = func(x) 12 | if ":" not in x: 13 | x += ":0" 14 | assert isinstance(y, Device) and 
y == Device(x) and str(y) == x 15 | assert isinstance(z, Device) and z == Device(x) and str(z) == x 16 | 17 | 18 | @pytest.mark.parametrize("x", ["unk"]) 19 | def test_device_init_fail(x: str) -> None: 20 | func = mlc.Func.get("mlc.testing.cxx_device") 21 | try: 22 | func(x) 23 | except ValueError as e: 24 | assert str(e) == f"Cannot convert to `Device` from string: {x}" 25 | else: 26 | assert False 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "x, y", 31 | [ 32 | ("meta:0", torch.device("meta")), 33 | ("cpu:0", torch.device("cpu")), 34 | ("cpu:0", torch.device("cpu:0")), 35 | ("cuda:0", torch.device("cuda")), 36 | ("cuda:1", torch.device("cuda:1")), 37 | ("mps:2", torch.device("mps:2")), 38 | ], 39 | ) 40 | def test_device_from_torch(x: str, y: torch.device) -> None: 41 | assert x == str(Device(y)) 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "x, y", 46 | [ 47 | (torch.device("meta:0"), Device("meta")), 48 | (torch.device("cpu:0"), Device("cpu")), 49 | (torch.device("cpu:0"), Device("cpu:0")), 50 | (torch.device("cuda:0"), Device("cuda")), 51 | (torch.device("cuda:1"), Device("cuda:1")), 52 | (torch.device("mps:2"), Device("mps:2")), 53 | ], 54 | ) 55 | def test_device_to_torch(x: torch.device, y: Device) -> None: 56 | assert x == y.torch() 57 | 58 | 59 | def test_device_register() -> None: 60 | code = Device.register("my_device") 61 | device = Device("my_device:10") 62 | assert device.device_type == code and device.device_id == 10 63 | assert str(device) == "my_device:10" 64 | -------------------------------------------------------------------------------- /tests/python/test_core_dict.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import pytest 4 | from mlc import Dict 5 | 6 | 7 | def test_dict_init_len() -> None: 8 | a = Dict({i: i * i for i in range(1, 5)}) 9 | assert len(a) == 4 10 | assert sorted(list(a)) == [1, 2, 3, 4] 11 | assert a[1] == 1 12 | assert a[2] == 4 13 | 
assert a[3] == 9 14 | assert a[4] == 16 15 | 16 | 17 | def test_dict_iter() -> None: 18 | a = Dict({i: i * i for i in range(1, 5)}) 19 | assert sorted(a) == [1, 2, 3, 4] 20 | assert sorted(a.keys()) == [1, 2, 3, 4] 21 | assert sorted(a.values()) == [1, 4, 9, 16] 22 | assert sorted(a.items()) == [(1, 1), (2, 4), (3, 9), (4, 16)] 23 | 24 | 25 | def test_dict_set_item() -> None: 26 | a = Dict({i: i * i for i in range(1, 5)}) 27 | a[1] = -1 28 | a[2] = -4 29 | assert a[1] == -1 30 | assert a[2] == -4 31 | assert dict(a) == {1: -1, 2: -4, 3: 9, 4: 16} 32 | 33 | 34 | def test_dict_del_item() -> None: 35 | a = Dict({i: i * i for i in range(1, 5)}) 36 | del a[1] 37 | del a[2] 38 | assert dict(a) == {3: 9, 4: 16} 39 | 40 | 41 | def test_dict_pop() -> None: 42 | a = Dict({i: i * i for i in range(1, 5)}) 43 | assert a.pop(1) == 1 44 | assert dict(a) == {2: 4, 3: 9, 4: 16} 45 | assert a.pop(2) == 4 46 | assert dict(a) == {3: 9, 4: 16} 47 | with pytest.raises(KeyError): 48 | a.pop(5) 49 | assert a.pop("not found", "default") == "default" # type: ignore[arg-type] 50 | 51 | 52 | def test_dict_key_error() -> None: 53 | a = Dict({i: i * i for i in range(1, 5)}) 54 | try: 55 | a[5] 56 | except KeyError as e: 57 | assert str(e) == "'5'" 58 | else: 59 | assert False 60 | 61 | 62 | def test_dict_get() -> None: 63 | a = Dict({i: i * i for i in range(1, 5)}) 64 | assert a.get(1) == 1 65 | assert a.get(2) == 4 66 | assert a.get(3) == 9 67 | assert a.get(4) == 16 68 | assert a.get(5) is None 69 | 70 | 71 | def test_dict_str() -> None: 72 | a = Dict({i: i * i for i in range(1, 5)}) 73 | assert "{" + ", ".join(sorted(str(a)[1:-1].split(", "))) + "}" == "{1: 1, 2: 4, 3: 9, 4: 16}" 74 | 75 | 76 | def test_dict_setdefault() -> None: 77 | a = Dict({i: i * i for i in range(1, 5)}) 78 | assert a.setdefault(1, -1) == 1 79 | assert a.setdefault(2, -4) == 4 80 | assert a.setdefault(5, 25) == 25 81 | assert dict(a) == {1: 1, 2: 4, 3: 9, 4: 16, 5: 25} 82 | 83 | 84 | def test_dict_eq() -> None: 85 | a = 
Dict({i: i * i for i in range(1, 5)}) 86 | b = {i: i * i for i in range(1, 5)} 87 | assert a == b 88 | assert b == a 89 | assert a == Dict(b) 90 | assert Dict(b) == a 91 | assert not a.eq_ptr(Dict(b)) 92 | assert not Dict(b).eq_ptr(a) 93 | assert a == a # noqa: PLR0124 94 | assert a.eq_ptr(a) 95 | 96 | 97 | def test_dict_ne_0() -> None: 98 | a = Dict({i: i * i for i in range(1, 5)}) 99 | b = {i: i * i for i in range(1, 6)} 100 | assert a != b 101 | assert b != a 102 | 103 | 104 | def test_dict_ne_1() -> None: 105 | a = Dict({i: i * i for i in range(1, 6)}) 106 | b = {i: i * i for i in range(1, 5)} 107 | assert a != b 108 | assert b != a 109 | 110 | 111 | def test_dict_to_py_0() -> None: 112 | a = Dict({i: i * i for i in range(1, 5)}).py() 113 | assert isinstance(a, dict) 114 | assert len(a) == 4 115 | assert isinstance(a[1], int) and a[1] == 1 116 | assert isinstance(a[2], int) and a[2] == 4 117 | assert isinstance(a[3], int) and a[3] == 9 118 | assert isinstance(a[4], int) and a[4] == 16 119 | 120 | 121 | def test_dict_to_py_1() -> None: 122 | a = Dict( 123 | { 124 | "a": { 125 | "b": [2], 126 | "c": 3.0, 127 | }, 128 | 1: "one", 129 | None: "e", 130 | } 131 | ).py() 132 | assert len(a) == 3 and set(a.keys()) == {"a", 1, None} 133 | assert isinstance(a["a"], dict) and len(a["a"]) == 2 134 | assert isinstance(a["a"]["b"], list) and len(a["a"]["b"]) == 1 and a["a"]["b"][0] == 2 135 | assert isinstance(a["a"]["c"], float) and a["a"]["c"] == 3.0 136 | assert isinstance(a[1], str) and a[1] == "one" 137 | assert isinstance(a[None], str) and a[None] == "e" 138 | 139 | 140 | @pytest.mark.parametrize( 141 | "callable", 142 | [ 143 | lambda a: a.__setitem__(0, 0), 144 | lambda a: a.__delitem__(1), 145 | lambda a: a.pop(2), 146 | lambda a: a.clear(), 147 | lambda a: a.setdefault(1, 5), 148 | ], 149 | ) 150 | def test_dict_freeze(callable: Callable[[Dict], None]) -> None: 151 | a = Dict({i: i * i for i in range(1, 5)}) 152 | assert a.frozen == False 153 | a.freeze() 154 | 
assert a.frozen == True 155 | with pytest.raises(RuntimeError) as e: 156 | callable(a) 157 | assert str(e.value) == "Cannot modify a frozen dict" 158 | -------------------------------------------------------------------------------- /tests/python/test_core_dtype.py: -------------------------------------------------------------------------------- 1 | import ml_dtypes 2 | import mlc 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from mlc import DataType 7 | 8 | 9 | @pytest.mark.parametrize("x", ["int32", "int32x3", "float8"]) 10 | def test_dtype_init_str(x: str) -> None: 11 | func = mlc.Func.get("mlc.testing.cxx_dtype") 12 | y = func(DataType(x)) 13 | z = func(x) 14 | assert isinstance(y, DataType) and y == DataType(x) and str(y) == x 15 | assert isinstance(z, DataType) and z == DataType(x) and str(z) == x 16 | 17 | 18 | @pytest.mark.parametrize("x", ["int", "int32xx3", "float821"]) 19 | def test_dtype_init_fail(x: str) -> None: 20 | func = mlc.Func.get("mlc.testing.cxx_dtype") 21 | try: 22 | print(func(x)) 23 | except ValueError as e: 24 | assert str(e) == f"Cannot convert to `dtype` from string: {x}" 25 | else: 26 | assert False 27 | 28 | 29 | @pytest.mark.parametrize("x", [DataType("int32"), DataType("int32x3"), DataType("float8")]) 30 | def test_dtype_init_dtype(x: DataType) -> None: 31 | y = DataType(x) 32 | assert x == y 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "x, y", 37 | [ 38 | (torch.int32, "int32"), 39 | (torch.float16, "float16"), 40 | (torch.float8_e4m3fn, "float8_e4m3fn"), 41 | (torch.float8_e4m3fnuz, "float8_e4m3fnuz"), 42 | (torch.float8_e5m2, "float8_e5m2"), 43 | (torch.float8_e5m2fnuz, "float8_e5m2fnuz"), 44 | ], 45 | ) 46 | def test_dtype_init_torch(x: str, y: DataType) -> None: 47 | assert y == str(DataType(x)) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "x, y", 52 | [ 53 | (ml_dtypes.int2, "int2"), 54 | (ml_dtypes.int4, "int4"), 55 | (ml_dtypes.uint2, "uint2"), 56 | (ml_dtypes.uint4, "uint4"), 57 | (ml_dtypes.float8_e3m4, 
"float8_e3m4"), 58 | (ml_dtypes.float8_e4m3, "float8_e4m3"), 59 | (ml_dtypes.float8_e4m3b11fnuz, "float8_e4m3b11fnuz"), 60 | (ml_dtypes.float8_e4m3fn, "float8_e4m3fn"), 61 | (ml_dtypes.float8_e4m3fnuz, "float8_e4m3fnuz"), 62 | (ml_dtypes.float8_e5m2, "float8_e5m2"), 63 | (ml_dtypes.float8_e5m2fnuz, "float8_e5m2fnuz"), 64 | (ml_dtypes.float8_e8m0fnu, "float8_e8m0fnu"), 65 | (ml_dtypes.float4_e2m1fn, "float4_e2m1fn"), 66 | (ml_dtypes.float6_e2m3fn, "float6_e2m3fn"), 67 | (ml_dtypes.float6_e3m2fn, "float6_e3m2fn"), 68 | ], 69 | ) 70 | def test_dtype_init_ml_dtypes(x: type, y: str) -> None: 71 | assert y == str(DataType(x)) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "x, y", 76 | [ 77 | ("int32", torch.int32), 78 | ("float16", torch.float16), 79 | ("float8_e4m3fn", torch.float8_e4m3fn), 80 | ("float8_e4m3fnuz", torch.float8_e4m3fnuz), 81 | ("float8_e5m2", torch.float8_e5m2), 82 | ("float8_e5m2fnuz", torch.float8_e5m2fnuz), 83 | ], 84 | ) 85 | def test_dtype_to_torch(x: str, y: torch.dtype) -> None: 86 | assert DataType(x).torch() == y 87 | 88 | 89 | @pytest.mark.parametrize( 90 | "x, y", 91 | [ 92 | ("int2", ml_dtypes.int2), 93 | ("int4", ml_dtypes.int4), 94 | ("uint2", ml_dtypes.uint2), 95 | ("uint4", ml_dtypes.uint4), 96 | ("int32", np.int32), 97 | ("float16", np.float16), 98 | ("float8_e3m4", ml_dtypes.float8_e3m4), 99 | ("float8_e4m3", ml_dtypes.float8_e4m3), 100 | ("float8_e4m3b11fnuz", ml_dtypes.float8_e4m3b11fnuz), 101 | ("float8_e4m3fn", ml_dtypes.float8_e4m3fn), 102 | ("float8_e4m3fnuz", ml_dtypes.float8_e4m3fnuz), 103 | ("float8_e5m2", ml_dtypes.float8_e5m2), 104 | ("float8_e5m2fnuz", ml_dtypes.float8_e5m2fnuz), 105 | ("float8_e8m0fnu", ml_dtypes.float8_e8m0fnu), 106 | ("float4_e2m1fn", ml_dtypes.float4_e2m1fn), 107 | ("float6_e2m3fn", ml_dtypes.float6_e2m3fn), 108 | ("float6_e3m2fn", ml_dtypes.float6_e3m2fn), 109 | ], 110 | ) 111 | def test_dtype_to_numpy(x: str, y: type) -> None: 112 | assert DataType(x).numpy() == np.dtype(y) 113 | 114 | 115 | def 
test_dtype_register() -> None: 116 | code = DataType.register("float8_custom", bits=8) 117 | dtype = DataType("float8_custom") 118 | assert dtype.code == code and dtype.bits == 8 and dtype.lanes == 1 119 | assert str(dtype) == "float8_custom" 120 | 121 | 122 | def test_dtype_init_from_triple() -> None: 123 | i32 = DataType("int32") 124 | i32x2 = DataType.from_triple(i32.code, i32.bits, lanes=2) 125 | assert i32x2 == "int32x2" 126 | -------------------------------------------------------------------------------- /tests/python/test_core_func.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | import pytest 3 | from mlc import Func 4 | 5 | 6 | def test_cxx_none() -> None: 7 | func = mlc.Func.get("mlc.testing.cxx_none") 8 | assert func() is None 9 | 10 | 11 | def test_cxx_nullptr() -> None: 12 | func = mlc.Func.get("mlc.testing.cxx_null") 13 | assert func() is None 14 | 15 | 16 | @pytest.mark.parametrize("x", [-1, 0, 1]) 17 | def test_cxx_int(x: int) -> None: 18 | func = mlc.Func.get("mlc.testing.cxx_int") 19 | y = func(x) 20 | assert isinstance(y, int) and y == x 21 | 22 | 23 | @pytest.mark.parametrize("x", [True, False]) 24 | def test_cxx_bool(x: int) -> None: 25 | func = mlc.Func.get("mlc.testing.cxx_bool") 26 | y = func(x) 27 | assert isinstance(y, bool) and y == x 28 | 29 | 30 | @pytest.mark.parametrize("x", [-1, 0, 1, 2.0, -2.0]) 31 | def test_cxx_float(x: float) -> None: 32 | func = mlc.Func.get("mlc.testing.cxx_float") 33 | y = func(x) 34 | assert isinstance(y, float) and y == x 35 | 36 | 37 | @pytest.mark.parametrize("x", [0x0, 0xDEADBEEF]) 38 | def test_cxx_ptr(x: int) -> None: 39 | import ctypes 40 | 41 | func = mlc.Func.get("mlc.testing.cxx_ptr") 42 | y = func(ctypes.c_void_p(x)) 43 | if x == 0: 44 | assert y is None 45 | else: 46 | assert isinstance(y, ctypes.c_void_p) and y.value == x 47 | 48 | 49 | def test_func_init() -> None: 50 | func = Func(lambda x: x + 1) 51 | assert func(1) == 2 52 | assert 
str(func).startswith("object.Func@0x") 53 | -------------------------------------------------------------------------------- /tests/python/test_core_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import mlc 4 | 5 | 6 | def test_json_loads_bool() -> None: 7 | src = json.dumps([True, False]) 8 | result = mlc.json_loads(src) 9 | assert isinstance(result, mlc.List) and len(result) == 2 10 | assert result[0] == True 11 | assert result[1] == False 12 | 13 | 14 | def test_json_loads_none() -> None: 15 | src = json.dumps([None]) 16 | result = mlc.json_loads(src) 17 | assert isinstance(result, mlc.List) and len(result) == 1 18 | assert result[0] is None 19 | -------------------------------------------------------------------------------- /tests/python/test_core_object.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | 3 | 4 | def test_object_init() -> None: 5 | obj = mlc.Object() 6 | assert str(obj).startswith("object.Object@0x") 7 | 8 | 9 | def test_object_swap() -> None: 10 | a = mlc.Object() 11 | b = mlc.Object() 12 | a_addr = a._mlc_address 13 | b_addr = b._mlc_address 14 | a.swap(b) 15 | assert a._mlc_address == b_addr 16 | assert b._mlc_address == a_addr 17 | -------------------------------------------------------------------------------- /tests/python/test_core_object_path.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | 3 | 4 | def test_root() -> None: 5 | root = mlc.ObjectPath.root() 6 | assert root.kind == -1 7 | assert str(root) == "{root}" 8 | 9 | 10 | def test_with_field_0() -> None: 11 | obj = mlc.ObjectPath.root().with_field("field") 12 | assert obj.kind == 0 13 | assert str(obj) == "{root}.field" 14 | 15 | 16 | def test_with_field_1() -> None: 17 | obj = mlc.ObjectPath.root()["field"] 18 | assert obj.kind == 0 19 | assert str(obj) == "{root}.field" 20 | 21 | 22 | def 
test_with_list_index() -> None: 23 | obj = mlc.ObjectPath.root().with_list_index(1) 24 | assert obj.kind == 1 25 | assert str(obj) == "{root}[1]" 26 | 27 | 28 | def test_with_list_index_1() -> None: 29 | obj = mlc.ObjectPath.root()[1] 30 | assert obj.kind == 1 31 | assert str(obj) == "{root}[1]" 32 | 33 | 34 | def test_with_dict_key() -> None: 35 | obj = mlc.ObjectPath.root().with_dict_key("key") 36 | assert obj.kind == 2 37 | assert str(obj) == '{root}["key"]' 38 | -------------------------------------------------------------------------------- /tests/python/test_core_opaque.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from typing import Any 4 | 5 | import mlc 6 | import pytest 7 | 8 | 9 | class MyTypeNotRegistered: 10 | def __init__(self, a: int) -> None: 11 | self.a = a 12 | 13 | 14 | class MyType: 15 | def __init__(self, a: int) -> None: 16 | self.a = a 17 | 18 | def __call__(self, x: int) -> int: 19 | return x + self.a 20 | 21 | def __eq__(self, other: Any) -> bool: 22 | return isinstance(self, MyType) and isinstance(other, MyType) and self.a == other.a 23 | 24 | def __hash__(self) -> int: 25 | assert isinstance(self, MyType) 26 | return hash((MyType, self.a)) 27 | 28 | 29 | mlc.Opaque.register(MyType) 30 | 31 | 32 | @mlc.dataclasses.py_class(structure="bind") 33 | class Wrapper(mlc.dataclasses.PyClass): 34 | field: Any = mlc.dataclasses.field(structure="nobind") 35 | 36 | 37 | def test_opaque_init() -> None: 38 | a = MyType(a=10) 39 | opaque = mlc.Opaque(a) 40 | assert str(opaque) == "" 41 | 42 | 43 | def test_opaque_init_error() -> None: 44 | a = MyTypeNotRegistered(a=10) 45 | with pytest.raises(TypeError) as e: 46 | mlc.Opaque(a) 47 | assert ( 48 | str(e.value) 49 | == "MLC does not recognize type: . " 50 | "If it's intentional, please register using `mlc.Opaque.register(type)`." 
51 | ) 52 | 53 | 54 | def test_opaque_ffi() -> None: 55 | func = mlc.Func.get("mlc.testing.cxx_obj") 56 | a = func(MyType(a=10)) 57 | assert isinstance(a, MyType) 58 | assert a.a == 10 59 | 60 | 61 | def test_opaque_ffi_error() -> None: 62 | func = mlc.Func.get("mlc.testing.cxx_obj") 63 | with pytest.raises(TypeError) as e: 64 | func(MyTypeNotRegistered(a=10)) 65 | assert ( 66 | str(e.value) 67 | == "MLC does not recognize type: " 68 | ) 69 | 70 | 71 | def test_opaque_dataclass() -> None: 72 | a = MyType(a=10) 73 | wrapper = Wrapper(field=a) 74 | assert isinstance(wrapper.field, MyType) 75 | assert wrapper.field.a == 10 76 | 77 | 78 | def test_opaque_dataclass_eq_s() -> None: 79 | a1 = Wrapper(field=MyType(a=10)) 80 | a2 = Wrapper(field=MyType(a=10)) 81 | a1.eq_s(a2, assert_mode=True) 82 | 83 | 84 | def test_opaque_dataclass_eq_s_fail() -> None: 85 | a1 = Wrapper(field=MyType(a=10)) 86 | a2 = Wrapper(field=MyType(a=20)) 87 | with pytest.raises(ValueError) as exc_info: 88 | a1.eq_s(a2, assert_mode=True) 89 | assert str(exc_info.value).startswith("Structural equality check failed at {root}.field") 90 | 91 | 92 | def test_opaque_dataclass_hash_s() -> None: 93 | a1 = Wrapper(field=MyType(a=10)) 94 | assert isinstance(a1.hash_s(), int) 95 | 96 | 97 | def test_opaque_serialize() -> None: 98 | obj_1 = Wrapper(field=MyType(a=10)) 99 | json_str = obj_1.json() 100 | js = json.loads(json_str) 101 | assert js["opaques"] == '[{"py/object": "test_core_opaque.MyType", "a": 10}]' 102 | assert js["values"] == [[0, 0], [1, 0]] 103 | assert js["type_keys"] == ["mlc.core.Opaque", "test_core_opaque.Wrapper"] 104 | obj_2 = Wrapper.from_json(json_str) 105 | assert isinstance(obj_2.field, MyType) 106 | assert obj_2.field.a == 10 107 | 108 | 109 | def test_opaque_serialize_with_alias() -> None: 110 | a1 = MyType(a=10) 111 | a2 = MyType(a=20) 112 | a3 = MyType(a=30) 113 | obj_1 = Wrapper(field=[a1, a2, a3, a3, a2, a1]) 114 | obj_2 = Wrapper.from_json(obj_1.json()) 115 | assert 
obj_2.field[0] is obj_2.field[5] 116 | assert obj_2.field[1] is obj_2.field[4] 117 | assert obj_2.field[2] is obj_2.field[3] 118 | assert obj_2.field[0].a == 10 119 | assert obj_2.field[1].a == 20 120 | assert obj_2.field[2].a == 30 121 | assert obj_2.field[3].a == 30 122 | assert obj_2.field[4].a == 20 123 | assert obj_2.field[5].a == 10 124 | 125 | 126 | def test_opaque_deepcopy() -> None: 127 | a = MyType(a=10) 128 | obj_1 = Wrapper(field=a) 129 | obj_2 = copy.deepcopy(obj_1) 130 | assert isinstance(obj_2.field, MyType) 131 | assert obj_2.field.a == 10 132 | assert obj_1 is not obj_2 133 | assert obj_1.field is not obj_2.field 134 | -------------------------------------------------------------------------------- /tests/python/test_core_tensor.py: -------------------------------------------------------------------------------- 1 | import mlc 2 | import numpy as np 3 | import pytest 4 | import torch 5 | 6 | 7 | @pytest.fixture 8 | def cxx_func() -> mlc.Func: 9 | return mlc.Func.get("mlc.testing.cxx_obj") 10 | 11 | 12 | def test_tensor_from_numpy(cxx_func: mlc.Func) -> None: 13 | a = np.arange(24, dtype=np.int16).reshape(2, 3, 4) 14 | 15 | b = cxx_func(a) 16 | assert b.dtype == mlc.DataType("int16") 17 | assert b.device == mlc.Device("cpu") 18 | assert b.shape == (2, 3, 4) 19 | assert b.strides is None 20 | assert b.byte_offset == 0 21 | assert str(b) == "" 22 | 23 | b = mlc.Tensor(a) 24 | assert b.dtype == mlc.DataType("int16") 25 | assert b.device == mlc.Device("cpu") 26 | assert b.shape == (2, 3, 4) 27 | assert b.strides is None 28 | assert b.byte_offset == 0 29 | assert str(b) == "" 30 | 31 | assert np.array_equal(a, b.numpy()) 32 | 33 | 34 | def test_opaque_from_torch(cxx_func: mlc.Func) -> None: 35 | a = torch.arange(24, dtype=torch.int16).reshape(2, 3, 4) 36 | 37 | b = cxx_func(a) 38 | assert b.dtype == mlc.DataType("int16") 39 | assert b.device == mlc.Device("cpu") 40 | assert b.shape == (2, 3, 4) 41 | assert b.strides is None 42 | assert b.byte_offset == 0 
43 | assert str(b) == "" 44 | 45 | b = mlc.Tensor(a) 46 | assert b.dtype == mlc.DataType("int16") 47 | assert b.device == mlc.Device("cpu") 48 | assert b.shape == (2, 3, 4) 49 | assert b.strides is None 50 | assert b.byte_offset == 0 51 | assert str(b) == "" 52 | 53 | assert torch.equal(a, b.torch()) 54 | 55 | 56 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") 57 | def test_opaque_from_torch_cuda(cxx_func: mlc.Func) -> None: 58 | a = torch.arange(24, dtype=torch.int16).reshape(2, 3, 4).cuda() 59 | 60 | b = cxx_func(a) 61 | assert b.dtype == mlc.DataType("int16") 62 | assert b.device == mlc.Device("cuda") 63 | assert b.shape == (2, 3, 4) 64 | assert b.strides is None 65 | assert b.byte_offset == 0 66 | assert str(b) == "" 67 | 68 | b = mlc.Tensor(a) 69 | assert b.dtype == mlc.DataType("int16") 70 | assert b.device == mlc.Device("cuda") 71 | assert b.shape == (2, 3, 4) 72 | assert b.strides is None 73 | assert b.byte_offset == 0 74 | assert str(b) == "" 75 | 76 | assert torch.equal(a, b.torch()) 77 | 78 | 79 | def test_tensor_base64_int16() -> None: 80 | a = mlc.Tensor(np.arange(24, dtype=np.int16).reshape(2, 3, 4)) 81 | assert ( 82 | a.base64() 83 | == "P6G0lvBAXt0DAAAAABABAAIAAAAAAAAAAwAAAAAAAAAEAAAAAAAAAAAAAQACAAMABAAFAAYABwAIAAkACgALAAwADQAOAA8AEAARABIAEwAUABUAFgAXAA==" 84 | ) 85 | b = mlc.Tensor.from_base64(a.base64()) 86 | assert a.ndim == b.ndim 87 | assert a.shape == b.shape 88 | assert a.dtype == b.dtype 89 | assert a.device == b.device 90 | assert a.strides == b.strides 91 | assert a.byte_offset == b.byte_offset 92 | assert a.base64() == b.base64() 93 | 94 | assert np.array_equal(a.numpy(), b.numpy()) 95 | assert torch.equal(a.torch(), b.torch()) 96 | 97 | 98 | def test_tensor_base64_float16() -> None: 99 | a = mlc.Tensor(np.array([3.0, 10.0, 20.0, 30.0, 35.50], dtype=np.float16)) 100 | assert a.base64() == "P6G0lvBAXt0BAAAAAhABAAUAAAAAAAAAAEIASQBNgE9wUA==" 101 | b = mlc.Tensor.from_base64(a.base64()) 102 | assert a.ndim 
== b.ndim 103 | assert a.shape == b.shape 104 | assert a.dtype == b.dtype 105 | assert a.device == b.device 106 | assert a.strides == b.strides 107 | assert a.byte_offset == b.byte_offset 108 | assert a.base64() == b.base64() 109 | 110 | assert np.array_equal(a.numpy(), b.numpy()) 111 | assert torch.equal(a.torch(), b.torch()) 112 | 113 | 114 | def test_torch_strides() -> None: 115 | a = torch.empty(4, 1, 6, 1, 10, dtype=torch.int16) 116 | a = torch.from_dlpack(torch.to_dlpack(a)) 117 | b = mlc.Tensor(a) 118 | assert b.strides is None 119 | 120 | 121 | def test_tensor_serialize() -> None: 122 | a = mlc.Tensor(np.arange(24, dtype=np.int16).reshape(2, 3, 4)) 123 | a_json = mlc.List([a, a]).json() 124 | b = mlc.List.from_json(a_json) 125 | assert isinstance(b, mlc.List) 126 | assert len(b) == 2 127 | assert isinstance(b[0], mlc.Tensor) 128 | assert isinstance(b[1], mlc.Tensor) 129 | assert b[0].eq_ptr(b[1]) 130 | assert np.array_equal(a.numpy(), b[0].numpy()) 131 | -------------------------------------------------------------------------------- /tests/python/test_cython_traceback.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from io import StringIO 3 | 4 | import mlc 5 | import pytest 6 | 7 | 8 | def test_throw_exception_from_c() -> None: 9 | func = mlc.Func.get("mlc.testing.throw_exception_from_c") 10 | with pytest.raises(ValueError) as exc_info: 11 | func() 12 | 13 | msg = traceback.format_exception(exc_info.type, exc_info.value, exc_info.tb) 14 | msg = "".join(msg).strip().splitlines() 15 | assert "Traceback (most recent call last)" in msg[0] 16 | assert "in test_throw_exception_from_c" in msg[1] 17 | assert "ValueError: This is an error message" in msg[-1] 18 | assert "c_api.cc" in msg[-3] 19 | 20 | 21 | def test_throw_exception_from_c_empty() -> None: 22 | func = mlc.Func.get("mlc.testing.throw_exception_from_c_empty") 23 | with pytest.raises(ValueError) as exc_info: 24 | func() 25 | 26 | msg = 
traceback.format_exception(exc_info.type, exc_info.value, exc_info.tb) 27 | msg = "".join(msg).strip().splitlines() 28 | assert "Traceback (most recent call last)" in msg[0] 29 | assert "in test_throw_exception_from_c_empty" in msg[1] 30 | assert "ValueError" == msg[-1].strip() 31 | 32 | 33 | def test_throw_exception_from_ffi() -> None: 34 | def throw_ValueError() -> None: 35 | raise ValueError("This is a ValueError") 36 | 37 | err: mlc.Error = mlc.Func.get("mlc.testing.throw_exception_from_ffi")(throw_ValueError) 38 | assert err.kind == "ValueError" 39 | assert str(err) == "This is a ValueError" 40 | io = StringIO() 41 | err.print_exc(file=io) 42 | msg = io.getvalue().rstrip().splitlines() 43 | assert "Traceback (most recent call last)" in msg[0] 44 | assert "ValueError: This is a ValueError" in msg[-1] 45 | assert msg[1].endswith(", in throw_ValueError") 46 | 47 | 48 | def test_throw_exception_from_ffi_in_c() -> None: 49 | def throw_ValueError() -> None: 50 | def _inner() -> None: 51 | raise ValueError("This is a ValueError") 52 | 53 | _inner() 54 | 55 | with pytest.raises(ValueError) as exc_info: 56 | mlc.Func.get("mlc.testing.throw_exception_from_ffi_in_c")(throw_ValueError) 57 | 58 | msg = traceback.format_exception(exc_info.type, exc_info.value, exc_info.tb) 59 | msg = "".join(msg).strip().splitlines() 60 | assert "Traceback (most recent call last)" in msg[0] 61 | assert "in test_throw_exception_from_ffi_in_c" in msg[1] 62 | assert "ValueError: This is a ValueError" in msg[-1] 63 | idx_c_api_tests = next(i for i, line in enumerate(msg) if "c_api.cc" in line) 64 | idx_handle_error = next(i for i, line in enumerate(msg) if "_func_safe_call_impl" in line) 65 | assert idx_c_api_tests < idx_handle_error 66 | 67 | 68 | def test_throw_NotImplementedError_from_ffi_in_c() -> None: 69 | def throw_ValueError() -> None: 70 | def _inner() -> None: 71 | raise NotImplementedError 72 | 73 | _inner() 74 | 75 | with pytest.raises(NotImplementedError): 76 | 
mlc.Func.get("mlc.testing.throw_exception_from_ffi_in_c")(throw_ValueError) 77 | -------------------------------------------------------------------------------- /tests/python/test_dataclasses_serialize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from typing import Any, Optional 4 | 5 | import mlc 6 | 7 | 8 | @mlc.dataclasses.py_class("mlc.testing.serialize", init=False) 9 | class ObjTest(mlc.PyClass): 10 | a: int 11 | b: float 12 | c: str 13 | d: bool 14 | 15 | def __init__(self, b: float, c: str, a: int, d: bool) -> None: 16 | self.a = a 17 | self.b = b 18 | self.c = c 19 | self.d = d 20 | 21 | 22 | @mlc.dataclasses.py_class("mlc.testing.serialize_opt") 23 | class ObjTestOpt(mlc.PyClass): 24 | a: Optional[int] 25 | b: Optional[float] 26 | c: Optional[str] 27 | d: Optional[bool] 28 | 29 | 30 | @mlc.dataclasses.py_class("mlc.testing.AnyContainer") 31 | class AnyContainer(mlc.PyClass): 32 | field: Any 33 | 34 | 35 | def test_json() -> None: 36 | obj = ObjTest(a=1, b=2.0, c="3", d=True) 37 | obj_json = obj.json() 38 | obj_json_dict = json.loads(obj_json) 39 | assert obj_json_dict["type_keys"] == ["mlc.testing.serialize", "int"] 40 | assert obj_json_dict["values"] == ["3", [0, [1, 1], 2.0, 0, True]] 41 | obj_from_json: ObjTest = ObjTest.from_json(obj_json) 42 | assert obj.a == obj_from_json.a 43 | assert obj.b == obj_from_json.b 44 | assert obj.c == obj_from_json.c 45 | assert obj.d == obj_from_json.d 46 | 47 | 48 | def test_pickle() -> None: 49 | obj = ObjTest(a=1, b=2.0, c="3", d=True) 50 | obj_pickle = pickle.dumps(obj) 51 | obj_from_pickle = pickle.loads(obj_pickle) 52 | assert obj.a == obj_from_pickle.a 53 | assert obj.b == obj_from_pickle.b 54 | assert obj.c == obj_from_pickle.c 55 | assert obj.d == obj_from_pickle.d 56 | 57 | 58 | def test_json_opt_0() -> None: 59 | obj = ObjTestOpt(1, 2.0, "3", True) 60 | obj_json = obj.json() 61 | obj_json_dict = json.loads(obj_json) 62 | assert 
obj_json_dict["type_keys"] == ["mlc.testing.serialize_opt", "int"] 63 | assert obj_json_dict["values"] == [ 64 | "3", # values[0] = literal: "3" 65 | [ 66 | # values[1].type_index = 0 ===> ObjTestOpt 67 | 0, 68 | # values[1].0 69 | [ 70 | 1, # type_index = 1 ===> int 71 | 1, # value = 1 72 | ], 73 | # values[1].1 = literal: 2.0 74 | 2.0, 75 | # values[1].2 = values[0] 76 | 0, 77 | # values[1].3 = literal: True 78 | True, 79 | ], 80 | ] 81 | obj_from_json: ObjTestOpt = ObjTestOpt.from_json(obj_json) 82 | assert obj.a == obj_from_json.a 83 | assert obj.b == obj_from_json.b 84 | assert obj.c == obj_from_json.c 85 | assert obj.d == obj_from_json.d 86 | 87 | 88 | def test_json_opt_1() -> None: 89 | obj = ObjTestOpt(None, None, None, None) 90 | obj_json = obj.json() 91 | obj_json_dict = json.loads(obj_json) 92 | assert obj_json_dict["type_keys"] == ["mlc.testing.serialize_opt"] 93 | assert obj_json_dict["values"] == [[0, None, None, None, None]] 94 | obj_from_json: ObjTestOpt = ObjTestOpt.from_json(obj_json) 95 | assert obj.a == obj_from_json.a 96 | assert obj.b == obj_from_json.b 97 | assert obj.c == obj_from_json.c 98 | assert obj.d == obj_from_json.d 99 | 100 | 101 | def test_json_dag() -> None: 102 | lst = mlc.List([1, 2.0, "3", True]) 103 | dct = mlc.Dict({"a": 1, "b": 2.0, "c": "3", "d": True, "v": lst}) 104 | big_lst = mlc.List([lst, dct, lst, dct]) 105 | obj_1 = AnyContainer([big_lst, big_lst]) 106 | obj_2: AnyContainer = AnyContainer.from_json(obj_1.json()) 107 | assert obj_2.field[0].is_(obj_2.field[1]) 108 | assert obj_2.field[0] == big_lst 109 | big_lst = obj_2.field[0] 110 | assert big_lst[0].is_(big_lst[2]) # type: ignore[attr-defined] 111 | assert big_lst[1].is_(big_lst[3]) # type: ignore[attr-defined] 112 | assert big_lst[0] == lst 113 | assert big_lst[1] == dct 114 | lst, dct = big_lst[:2] # type: ignore[assignment] 115 | assert dct["v"].is_(lst) # type: ignore[attr-defined] 116 | 
-------------------------------------------------------------------------------- /tests/python/test_parser_toy_ir_parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from mlc.testing import toy_ir 4 | from mlc.testing.toy_ir import Add, Assign, Func, Var 5 | 6 | 7 | def test_parse_func() -> None: 8 | def source_code(a, b, c): # noqa: ANN001, ANN202 9 | d = a + b 10 | e = d + c 11 | return e 12 | 13 | def _expected() -> Func: 14 | a = Var(name="_a") 15 | b = Var(name="_b") 16 | c = Var(name="_c") 17 | d = Var(name="_d") 18 | e = Var(name="_e") 19 | stmts = [ 20 | Assign(lhs=d, rhs=Add(a, b)), 21 | Assign(lhs=e, rhs=Add(d, c)), 22 | ] 23 | f = Func(name="_f", args=[a, b, c], stmts=stmts, ret=e) 24 | return f 25 | 26 | result = toy_ir.parse_func(source_code) 27 | expected = _expected() 28 | result.eq_s(expected, assert_mode=True) 29 | -------------------------------------------------------------------------------- /tests/python/test_printer_ir_printer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import mlc.printer as mlcp 4 | import pytest 5 | from mlc.printer import ObjectPath 6 | from mlc.testing.toy_ir import Add, Assign, Func, Var 7 | 8 | 9 | def test_var_print() -> None: 10 | a = Var(name="a") 11 | assert mlcp.to_python(a) == "a" 12 | 13 | 14 | def test_var_print_name_normalize() -> None: 15 | a = Var(name="a/0/b") 16 | assert mlcp.to_python(a) == "a_0_b" 17 | assert mlcp.to_python(a) == "a_0_b" 18 | 19 | 20 | def test_add_print() -> None: 21 | a = Var(name="a") 22 | b = Var(name="b") 23 | c = Add(lhs=a, rhs=b) 24 | assert mlcp.to_python(c) == "a + b" 25 | 26 | 27 | def test_assign_print() -> None: 28 | a = Var(name="a") 29 | b = Var(name="b") 30 | c = Assign(lhs=a, rhs=b) 31 | assert mlcp.to_python(c) == "a = b" 32 | 33 | 34 | def test_func_print() -> None: 35 | a = Var(name="a") 36 | b = Var(name="b") 37 | c = 
Var(name="c") 38 | d = Var(name="d") 39 | e = Var(name="e") 40 | stmts = [ 41 | Assign(lhs=d, rhs=Add(a, b)), 42 | Assign(lhs=e, rhs=Add(d, c)), 43 | ] 44 | f = Func(name="f", args=[a, b, c], stmts=stmts, ret=e) 45 | assert ( 46 | mlcp.to_python(f) 47 | == """ 48 | def f(a, b, c): 49 | d = a + b 50 | e = d + c 51 | return e 52 | """.strip() 53 | ) 54 | 55 | 56 | def test_print_none() -> None: 57 | printer = mlcp.IRPrinter() 58 | path = mlcp.ObjectPath.root() 59 | node = printer(None, path) 60 | assert node.to_python() == "None" 61 | 62 | 63 | def test_print_int() -> None: 64 | printer = mlcp.IRPrinter() 65 | path = mlcp.ObjectPath.root() 66 | node = printer(42, path) 67 | assert node.to_python() == "42" 68 | 69 | 70 | def test_print_str() -> None: 71 | printer = mlcp.IRPrinter() 72 | path = mlcp.ObjectPath.root() 73 | node = printer("hey", path) 74 | assert node.to_python() == '"hey"' 75 | 76 | 77 | def test_print_bool() -> None: 78 | printer = mlcp.IRPrinter() 79 | path = mlcp.ObjectPath.root() 80 | node = printer(True, path) 81 | assert node.to_python() == "True" 82 | 83 | 84 | def test_duplicated_vars() -> None: 85 | a = Var(name="a") 86 | b = Var(name="a") 87 | stmts = [ 88 | Assign(lhs=b, rhs=Add(a, a)), 89 | ] 90 | f = Func( 91 | name="f", 92 | args=[a], 93 | stmts=stmts, 94 | ret=b, 95 | ) 96 | assert ( 97 | mlcp.to_python(f) 98 | == """ 99 | def f(a): 100 | a_1 = a + a 101 | return a_1 102 | """.strip() 103 | ) 104 | assert re.fullmatch( 105 | r"^def f\(a\):\n" 106 | r" a_0x[0-9A-Fa-f]+ = a \+ a\n" 107 | r" return a_0x[0-9A-Fa-f]+$", 108 | mlcp.to_python(f, mlcp.PrinterConfig(print_addr_on_dup_var=True)), 109 | ) 110 | 111 | 112 | @pytest.mark.parametrize( 113 | "path, expected", 114 | [ 115 | ( 116 | ObjectPath.root()["args"][0], 117 | """ 118 | def f(a, b): 119 | ^ 120 | c = a + b 121 | return c 122 | """, 123 | ), 124 | ( 125 | ObjectPath.root()["args"][1], 126 | """ 127 | def f(a, b): 128 | ^ 129 | c = a + b 130 | return c 131 | """, 132 | ), 133 | ( 
134 | ObjectPath.root()["stmts"][0], 135 | """ 136 | def f(a, b): 137 | c = a + b 138 | ^^^^^^^^^ 139 | return c 140 | """, 141 | ), 142 | ( 143 | ObjectPath.root()["stmts"][0], 144 | """ 145 | def f(a, b): 146 | c = a + b 147 | ^^^^^^^^^ 148 | return c 149 | """, 150 | ), 151 | ( 152 | ObjectPath.root()["stmts"][0]["lhs"], 153 | """ 154 | def f(a, b): 155 | c = a + b 156 | ^ 157 | return c 158 | """, 159 | ), 160 | ( 161 | ObjectPath.root()["stmts"][0]["rhs"], 162 | """ 163 | def f(a, b): 164 | c = a + b 165 | ^^^^^ 166 | return c 167 | """, 168 | ), 169 | ( 170 | ObjectPath.root()["stmts"][0]["rhs"]["lhs"], 171 | """ 172 | def f(a, b): 173 | c = a + b 174 | ^ 175 | return c 176 | """, 177 | ), 178 | ( 179 | ObjectPath.root()["stmts"][0]["rhs"]["rhs"], 180 | """ 181 | def f(a, b): 182 | c = a + b 183 | ^ 184 | return c 185 | """, 186 | ), 187 | ( 188 | ObjectPath.root()["ret"], 189 | """ 190 | def f(a, b): 191 | c = a + b 192 | return c 193 | ^ 194 | """, 195 | ), 196 | ], 197 | ) 198 | def test_print_underscore(path: ObjectPath, expected: str) -> None: 199 | a = Var(name="a") 200 | b = Var(name="b") 201 | c = Var(name="c") 202 | f = Func( 203 | name="f", 204 | args=[a, b], 205 | stmts=[ 206 | Assign(lhs=c, rhs=Add(a, b)), 207 | ], 208 | ret=c, 209 | ) 210 | actual = mlcp.to_python( 211 | f, 212 | mlcp.PrinterConfig(path_to_underline=[path]), 213 | ) 214 | assert actual.strip() == expected.strip() 215 | -------------------------------------------------------------------------------- /tests/python/test_sym_analyzer_simplify.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from collections.abc import Mapping 3 | from types import MappingProxyType 4 | from typing import Literal 5 | 6 | import pytest 7 | from mlc import sym as S 8 | 9 | 10 | def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: 11 | if "param" in metafunc.fixturenames: 12 | if test_cases := getattr(metafunc.cls, "param", None): 13 
| metafunc.parametrize("param", test_cases) 14 | 15 | 16 | @pytest.fixture 17 | def analyzer() -> S.Analyzer: 18 | return S.Analyzer() 19 | 20 | 21 | def test_index_flatten(analyzer: S.Analyzer) -> None: 22 | i0 = S.Var("i0", "int64") 23 | i1 = S.Var("i1", "int64") 24 | analyzer.bind(i0, S.Range.from_const("int64", 0, 8)) 25 | analyzer.bind(i1, S.Range.from_const("int64", 0, 3)) 26 | 27 | i_flattened = i0 * 3 + i1 28 | before = (i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4 29 | expected_after = i_flattened 30 | after = analyzer.simplify(before) 31 | expected_after.eq_s(after, assert_mode=True) 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "dtype", 36 | ( 37 | "uint8", 38 | "uint16", 39 | "uint32", 40 | "uint64", 41 | "int8", 42 | "int16", 43 | "int32", 44 | "int64", 45 | "float16", 46 | "float32", 47 | "float64", 48 | ), 49 | ) 50 | def test_can_prove_self_identity(analyzer: S.Analyzer, dtype: str) -> None: 51 | n = S.Var("n", dtype) 52 | assert analyzer.can_prove(n == n) # type: ignore[arg-type] # noqa: PLR0124 53 | assert analyzer.can_prove_equal(n, n) 54 | 55 | 56 | class TestSymbolicCompare: 57 | @dataclasses.dataclass 58 | class Param: 59 | expr: S.Expr 60 | expected: bool = True 61 | strength: Literal["default", "symbolic_bound"] = "symbolic_bound" 62 | bounds: Mapping[S.Var, S.Range] = dataclasses.field(default_factory=dict) 63 | 64 | i0 = S.Var("i0", "int64") 65 | i1 = S.Var("i1", "int64") 66 | n = S.ShapeVar("n", "int64") 67 | m = S.ShapeVar("m", "int64") 68 | bounds = MappingProxyType( 69 | { 70 | i0: S.Range(0, (n + 31) // 32), 71 | i1: S.Range.from_const("int64", 0, 32), 72 | } 73 | ) 74 | 75 | param = ( 76 | Param( 77 | i0 * 32 + i1 < (n + 31) // 32 * 32, 78 | strength="default", 79 | expected=False, 80 | bounds=bounds, 81 | ), 82 | Param(i0 * 32 + i1 < (n + 31) // 32 * 32, bounds=bounds), 83 | Param(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, bounds=bounds), 84 | Param(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, bounds=bounds), 
85 | Param((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, bounds=bounds), 86 | Param((n + 31) // 32 * 32 >= i0 * 32 + i1, bounds=bounds), 87 | ) 88 | 89 | @staticmethod 90 | def test_body(analyzer: S.Analyzer, param: Param) -> None: 91 | if param.bounds: 92 | for var, bound in param.bounds.items(): 93 | analyzer.bind(var, bound) 94 | assert analyzer.can_prove(param.expr, strength=param.strength) == param.expected 95 | -------------------------------------------------------------------------------- /tests/python/test_sym_expr.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Callable 3 | 4 | import pytest 5 | from mlc import sym as S 6 | 7 | 8 | def test_op() -> None: 9 | op = S.Op.get("mlc.S.if_then_else") 10 | assert str(op) == 'S.Op("mlc.S.if_then_else")' 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "x, expected", 15 | [ 16 | [S.Var("x", "int64"), 'S.int64("x")'], 17 | [S.Var("y", "int32"), 'S.int32("y")'], 18 | ], 19 | ids=itertools.count(), 20 | ) 21 | def test_var(x: S.Var, expected: str) -> None: 22 | assert str(x) == expected 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "x, expected", 27 | [ 28 | [S.ShapeVar("x", "int64"), 'S.int64("x", size_var=True)'], 29 | [S.ShapeVar("y", "int32"), 'S.int32("y", size_var=True)'], 30 | ], 31 | ids=itertools.count(), 32 | ) 33 | def test_size_var(x: S.Var, expected: str) -> None: 34 | assert str(x) == expected 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "x, expected", 39 | [ 40 | [S.IntImm(-1, "int64"), "-1"], 41 | [S.IntImm(1, "int64"), "1"], 42 | [S.IntImm(65536, "int32"), "65536"], # TODO: detect overflow 43 | ], 44 | ids=itertools.count(), 45 | ) 46 | def test_int_imm(x: S.IntImm, expected: str) -> None: 47 | assert str(x) == expected 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "x, expected", 52 | [ 53 | [S.BoolImm(True), "True"], 54 | [S.BoolImm(False), "False"], 55 | ], 56 | ids=itertools.count(), 57 | ) 58 | def test_bool_imm(x: 
S.IntImm, expected: str) -> None: 59 | assert str(x) == expected 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "x, expected", 64 | [ 65 | [S.FloatImm(0, "float64"), "0.0"], 66 | [S.FloatImm(1, "float64"), "1.0"], 67 | [S.FloatImm(-1, "float64"), "-1.0"], 68 | ], 69 | ids=itertools.count(), 70 | ) 71 | def test_float_imm(x: S.FloatImm, expected: str) -> None: 72 | assert str(x) == expected 73 | 74 | 75 | def test_cast() -> None: 76 | x = S.Var("x", "int64") 77 | y = x.cast("int32") 78 | assert str(y) == 'x.cast("int32")' 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "op, expected", 83 | [ 84 | [lambda x, y: x + y, "x + y"], 85 | [lambda x, y: x - y, "x - y"], 86 | [lambda x, y: x * y, "x * y"], 87 | [lambda x, y: S.truncdiv(x, y), "S.truncdiv(x, y)"], 88 | [lambda x, y: S.truncmod(x, y), "S.truncmod(x, y)"], 89 | [lambda x, y: x // y, "x // y"], 90 | [lambda x, y: x % y, "x % y"], 91 | [lambda x, y: S.min(x, y), "S.min(x, y)"], 92 | [lambda x, y: S.max(x, y), "S.max(x, y)"], 93 | ], 94 | ids=itertools.count(), 95 | ) 96 | def test_arith_binary(op: Callable[[S.Expr, S.Expr], S.Expr], expected: str) -> None: 97 | x = S.Var("x", "int64") 98 | y = S.Var("y", "int64") 99 | z = op(x, y) 100 | assert str(z) == expected 101 | 102 | 103 | @pytest.mark.parametrize( 104 | "op, expected", 105 | [ 106 | [lambda x, y: x == y, "x == y"], 107 | [lambda x, y: x != y, "x != y"], 108 | [lambda x, y: x >= y, "x >= y"], 109 | [lambda x, y: x <= y, "x <= y"], 110 | [lambda x, y: x > y, "x > y"], 111 | [lambda x, y: x < y, "x < y"], 112 | ], 113 | ids=itertools.count(), 114 | ) 115 | def test_cmp(op: Callable[[S.Expr, S.Expr], S.Expr], expected: str) -> None: 116 | x = S.Var("x", "int64") 117 | y = S.Var("y", "int64") 118 | z = op(x, y) 119 | assert z.dtype == "bool" 120 | assert str(z) == expected 121 | 122 | 123 | def test_logical_and() -> None: 124 | x = S.Var("x", "bool") 125 | y = S.Var("y", "bool") 126 | z = S.logical_and(x, y) 127 | assert str(z) == "x and y" 128 | 129 | 130 | def 
test_logical_or() -> None: 131 | x = S.Var("x", "bool") 132 | y = S.Var("y", "bool") 133 | z = S.logical_or(x, y) 134 | assert str(z) == "x or y" 135 | 136 | 137 | def test_logical_not() -> None: 138 | x = S.Var("x", "bool") 139 | y = S.logical_not(x) 140 | assert str(y) == "not x" 141 | 142 | 143 | def test_select() -> None: 144 | cond = S.Var("cond", "bool") 145 | true_value = S.Var("true_value", "int64") 146 | false_value = S.Var("false_value", "int64") 147 | z = S.select(cond, true_value, false_value) 148 | assert str(z) == "S.select(cond, true_value, false_value)" 149 | 150 | 151 | def test_let() -> None: 152 | x = S.Var("x", "int64") 153 | y = S.Var("y", "int64") 154 | z = S.let( 155 | var=y, 156 | value=x + 1, 157 | body=x + y, 158 | ) 159 | assert ( 160 | str(z).strip() 161 | == """ 162 | y = x + 1 163 | x + y 164 | """.strip() 165 | ) 166 | 167 | 168 | def test_ramp() -> None: 169 | tx = S.Var("tx", "int64") 170 | ty = S.Var("ty", "int64") 171 | tz = S.Var("tz", "int64") 172 | x = S.ramp( 173 | ty * 512 + tz * 256 + tx * 8 + 2048, 174 | stride=S.IntImm(1, "int64"), 175 | lanes=8, 176 | ) 177 | assert str(x) == "S.ramp(ty * 512 + tz * 256 + tx * 8 + 2048, 1, 8)" 178 | assert x.dtype == "int64x8" 179 | 180 | 181 | def test_broadcast() -> None: 182 | value = S.Var("value", "int64") 183 | x = S.broadcast(value, lanes=8) 184 | assert str(x) == "S.broadcast(value, 8)" 185 | assert x.dtype == "int64x8" 186 | 187 | 188 | def test_range() -> None: 189 | min = S.Var("min", "int64") 190 | extent = S.Var("extent", "int64") 191 | x = S.Range(min, extent) 192 | assert str(x) == "S.Range(min, extent)" 193 | --------------------------------------------------------------------------------