├── .clang-format ├── test ├── python │ ├── onnx_importer │ │ ├── .gitignore │ │ ├── LeakyReLU.onnx │ │ ├── lit.local.cfg │ │ ├── BadName.onnx │ │ ├── import_onnx_tool.runlit │ │ ├── function_expansion │ │ │ ├── GreaterOrEqual.runlit.onnx │ │ │ ├── ReduceSumSquare_no_attrs.runlit.onnx │ │ │ ├── ReduceSumSquare_keepdims=0.runlit.onnx │ │ │ ├── ReduceSumSquare_keepdims=0.runlit │ │ │ └── ReduceSumSquare_no_attrs.runlit │ │ ├── BadName.runlit │ │ ├── _torch_mlir_config.py │ │ └── constants.py │ ├── lit.local.cfg │ └── fx_importer │ │ ├── v2.3 │ │ └── lit.local.cfg │ │ └── sparsity │ │ └── lit.local.cfg ├── CAPI │ ├── lit.local.cfg │ └── CMakeLists.txt ├── RefBackend │ └── lit.local.cfg ├── Conversion │ ├── TorchToStablehlo │ │ └── lit.local.cfg │ ├── TorchToTensor │ │ └── torch_to_tensor.mlir │ ├── TorchConversionToMLProgram │ │ ├── multiple_functions.mlir │ │ └── basic.mlir │ ├── TorchToTosa │ │ └── cast_fp32_to_fp16.mlir │ └── TorchToLinalg │ │ └── squeeze.mlir ├── Dialect │ ├── Torch │ │ ├── erase-module-initializer.mlir │ │ ├── decompose-complex-ops-legal.mlir │ │ ├── GlobalizeObjectGraph │ │ │ ├── module-uses-error.mlir │ │ │ ├── visibility.mlir │ │ │ ├── submodules.mlir │ │ │ └── error.mlir │ │ ├── verify-backend-contract-unimplemented-op.mlir │ │ ├── reduce-op-variants-error.mlir │ │ └── verify-backend-contract-error.mlir │ ├── TMTensor │ │ └── canonicalize.mlir │ └── TorchConversion │ │ └── unpack-quant-tensor.mlir ├── CMakeLists.txt └── lit.site.cfg.py.in ├── projects ├── pt1 │ ├── python │ │ ├── torch_mlir_e2e_test │ │ │ ├── __init__.py │ │ │ ├── tosa_backends │ │ │ │ └── __init__.py │ │ │ ├── stablehlo_backends │ │ │ │ └── __init__.py │ │ │ ├── linalg_on_tensors_backends │ │ │ │ └── __init__.py │ │ │ ├── CMakeLists.txt │ │ │ ├── configs │ │ │ │ ├── __init__.py │ │ │ │ ├── native_torch.py │ │ │ │ └── torchscript.py │ │ │ ├── utils.py │ │ │ └── test_suite │ │ │ │ └── custom_op_example.py │ │ ├── torch_mlir │ │ │ ├── csrc │ │ │ │ ├── reference_lazy_backend │ │ 
│ │ │ ├── __init__.py │ │ │ │ │ ├── gen_dummy_lib.py │ │ │ │ │ └── backend_impl.h │ │ │ │ ├── .clang-format │ │ │ │ └── jit_ir_importer │ │ │ │ │ ├── import_options_pybind.h │ │ │ │ │ ├── init_python_bindings.cpp │ │ │ │ │ ├── get_registered_ops.h │ │ │ │ │ ├── import_options_pybind.cpp │ │ │ │ │ ├── class_annotator_pybind.h │ │ │ │ │ └── CMakeLists.txt │ │ │ ├── jit_ir_importer │ │ │ │ ├── build_tools │ │ │ │ │ └── __init__.py │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── __init__.py │ │ │ ├── _torch_mlir_custom_op_example │ │ │ │ ├── __init__.py │ │ │ │ ├── torch_mlir_custom_op_example.cpp │ │ │ │ └── CMakeLists.txt │ │ │ └── _version.py │ │ └── test │ │ │ ├── torchscript_e2e_test │ │ │ ├── README.md │ │ │ ├── basic.py │ │ │ ├── runtime_failure.py │ │ │ └── submodule.py │ │ │ ├── CMakeLists.txt │ │ │ ├── lazy_backend │ │ │ ├── run_test.py │ │ │ └── device_data_name.py │ │ │ └── compile_api │ │ │ ├── make_fx.py │ │ │ ├── backend_legal_ops.py │ │ │ ├── output_type_spec.py │ │ │ ├── already_traced.py │ │ │ └── already_scripted.py │ ├── examples │ │ ├── example-requirements.txt │ │ ├── torchscript_stablehlo_backend_resnet.py │ │ ├── torchscript_stablehlo_backend_tinybert.py │ │ └── torchscript_resnet18_all_output_types.py │ ├── test │ │ ├── python │ │ │ ├── lit.local.cfg │ │ │ ├── importer │ │ │ │ └── jit_ir │ │ │ │ │ ├── lit.local.cfg │ │ │ │ │ ├── ivalue_import │ │ │ │ │ ├── README.md │ │ │ │ │ ├── debug-module-name.py │ │ │ │ │ ├── methods-locations.py │ │ │ │ │ ├── object-identity.py │ │ │ │ │ ├── object-identity-error.py │ │ │ │ │ ├── strings.py │ │ │ │ │ ├── object-identity-error-submodule.py │ │ │ │ │ ├── tuple.py │ │ │ │ │ ├── list.py │ │ │ │ │ ├── submodules-select.py │ │ │ │ │ ├── annotations │ │ │ │ │ │ ├── arg-tensor-type-bound.py │ │ │ │ │ │ ├── export-error.py │ │ │ │ │ │ └── export.py │ │ │ │ │ ├── primitives.py │ │ │ │ │ └── methods-derefine.py │ │ │ │ │ ├── node_import │ │ │ │ │ ├── utils.py │ │ │ │ │ ├── types-none.py │ │ │ │ │ ├── types-bool.py │ │ │ │ │ 
├── union.py │ │ │ │ │ ├── README.md │ │ │ │ │ ├── errors.py │ │ │ │ │ ├── list.py │ │ │ │ │ ├── elif.py │ │ │ │ │ ├── function-block-arg-adjustment.py │ │ │ │ │ ├── debug-info.py │ │ │ │ │ └── classes.py │ │ │ │ │ └── get_registered_ops.py │ │ │ ├── smoketest.py │ │ │ └── compile.py │ │ ├── CMakeLists.txt │ │ └── lit.site.cfg.py.in │ ├── tools │ │ └── e2e_test.sh │ └── CMakeLists.txt ├── ltc │ ├── CMakeLists.txt │ └── csrc │ │ └── base_lazy_backend │ │ ├── utils │ │ ├── jit_utils.h │ │ ├── tensor_utils.h │ │ ├── sys_utils.h │ │ ├── debug.h │ │ ├── jit_utils.cpp │ │ └── string_utils.h │ │ ├── tensor.h │ │ ├── ops │ │ ├── generic.cpp │ │ ├── unbind_int.h │ │ ├── ivalue.h │ │ ├── ivalue.cpp │ │ └── generic.h │ │ ├── tensor.cpp │ │ ├── README.md │ │ └── mlir_node_lowering.h ├── jit_ir_common │ ├── CMakeLists.txt │ └── csrc │ │ └── jit_ir_importer │ │ ├── CMakeLists.txt │ │ └── ivalue_importer.h └── onnx_c_importer │ ├── README.md │ └── CMakeLists.txt ├── pytorch-hash.txt ├── include ├── torch-mlir-dialects │ ├── CMakeLists.txt │ └── Dialect │ │ ├── CMakeLists.txt │ │ └── TMTensor │ │ ├── CMakeLists.txt │ │ ├── Transforms │ │ ├── CMakeLists.txt │ │ ├── Passes.td │ │ └── Passes.h │ │ └── IR │ │ ├── TMTensorDialect.h │ │ ├── TMTensorOpInterface.h │ │ └── ScalarLoopOpInterface.h ├── torch-mlir │ ├── Dialect │ │ ├── Torch │ │ │ ├── IR │ │ │ │ ├── .gitignore │ │ │ │ ├── CMakeLists.txt │ │ │ │ └── TorchDialect.h │ │ │ ├── CMakeLists.txt │ │ │ ├── Transforms │ │ │ │ └── CMakeLists.txt │ │ │ └── Utils │ │ │ │ └── SparsityUtils.h │ │ ├── CMakeLists.txt │ │ └── TorchConversion │ │ │ ├── CMakeLists.txt │ │ │ ├── Transforms │ │ │ └── CMakeLists.txt │ │ │ └── IR │ │ │ ├── CMakeLists.txt │ │ │ ├── TorchConversionDialect.h │ │ │ ├── TorchConversionOps.h │ │ │ └── TorchConversionBase.td │ ├── CMakeLists.txt │ ├── RefBackend │ │ ├── CMakeLists.txt │ │ └── Passes.h │ ├── Conversion │ │ ├── TorchOnnxToTorch │ │ │ ├── CMakeLists.txt │ │ │ ├── Passes.h │ │ │ └── Passes.td │ │ ├── 
CMakeLists.txt │ │ ├── Passes.h │ │ ├── TorchToSCF │ │ │ └── TorchToSCF.h │ │ ├── TorchToArith │ │ │ └── TorchToArith.h │ │ ├── TorchToTMTensor │ │ │ └── TorchToTMTensor.h │ │ ├── TorchToTensor │ │ │ └── TorchToTensor.h │ │ ├── TorchToLinalg │ │ │ └── TorchToLinalg.h │ │ ├── TorchConversionToMLProgram │ │ │ └── TorchConversionToMLProgram.h │ │ └── TorchToStablehlo │ │ │ └── TorchToStablehlo.h │ └── InitAll.h ├── CMakeLists.txt └── torch-mlir-c │ ├── Transforms.h │ ├── Dialects.h │ ├── Registration.h │ └── TorchOps.h ├── whl-requirements.txt ├── lib ├── Dialect │ ├── TMTensor │ │ ├── CMakeLists.txt │ │ ├── Transforms │ │ │ ├── CMakeLists.txt │ │ │ └── Passes.cpp │ │ └── IR │ │ │ ├── CMakeLists.txt │ │ │ ├── ScalarLoopOpInterface.cpp │ │ │ └── TMTensorDialect.cpp │ ├── TorchConversion │ │ ├── CMakeLists.txt │ │ ├── IR │ │ │ └── CMakeLists.txt │ │ └── Transforms │ │ │ └── CMakeLists.txt │ ├── Torch │ │ ├── CMakeLists.txt │ │ ├── Utils │ │ │ └── CMakeLists.txt │ │ ├── IR │ │ │ ├── CMakeLists.txt │ │ │ └── TorchOpsODSGenerated.cpp │ │ └── Transforms │ │ │ └── CMakeLists.txt │ └── CMakeLists.txt ├── Conversion │ ├── Utils │ │ └── CMakeLists.txt │ ├── TorchToArith │ │ └── CMakeLists.txt │ ├── TorchToTensor │ │ └── CMakeLists.txt │ ├── TorchToSCF │ │ └── CMakeLists.txt │ ├── TorchToTosa │ │ └── CMakeLists.txt │ ├── TorchToTMTensor │ │ └── CMakeLists.txt │ ├── TorchConversionToMLProgram │ │ └── CMakeLists.txt │ ├── TorchOnnxToTorch │ │ ├── CMakeLists.txt │ │ └── Passes.cpp │ ├── TorchToLinalg │ │ └── CMakeLists.txt │ ├── TorchToStablehlo │ │ ├── CMakeLists.txt │ │ ├── Utils.h │ │ └── Utils.cpp │ └── CMakeLists.txt ├── CAPI │ ├── CMakeLists.txt │ ├── Dialects.cpp │ ├── Transforms.cpp │ └── Registration.cpp ├── RefBackend │ └── CMakeLists.txt └── CMakeLists.txt ├── test-requirements.txt ├── tools ├── CMakeLists.txt ├── torch-mlir-opt │ └── CMakeLists.txt └── torch-mlir-lsp-server │ ├── CMakeLists.txt │ └── torch-mlir-lsp-server.cpp ├── docs └── images │ ├── architecture.png │ 
├── roadmap_backend.png │ ├── ltc_architecture.png │ ├── roadmap_frontend.png │ ├── ltc_syncing_tensors.png │ ├── ltc_tracing_tensors.png │ ├── ltc_vendor_execution.png │ └── readme_architecture_diagram.png ├── requirements.txt ├── pytorch-requirements.txt ├── torchvision-requirements.txt ├── pyproject.toml ├── .github ├── dependabot.yml └── workflows │ ├── pre-commit-all.yml │ ├── pre-commit.yml │ └── merge-rollpytorch.yml ├── .gitmodules ├── utils └── bazel │ ├── BUILD.bazel │ ├── torch-mlir-overlay │ ├── .bazelignore │ └── test │ │ ├── Dialect │ │ └── BUILD.bazel │ │ ├── Conversion │ │ └── BUILD.bazel │ │ ├── RefBackend │ │ └── BUILD.bazel │ │ └── BUILD.bazel │ ├── docker │ └── run_docker.sh │ └── .bazelrc ├── python ├── torch_mlir │ ├── dialects │ │ ├── torch │ │ │ └── __init__.py │ │ └── TorchBinding.td │ ├── _mlir_libs │ │ └── _site_initialize_0.py │ └── tools │ │ └── opt │ │ └── __main__.py └── TorchMLIRModule.cpp ├── .git-blame-ignore-revs ├── CITATION.cff ├── .gitignore ├── .yamllint.yml └── .pre-commit-config.yaml /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | -------------------------------------------------------------------------------- /test/python/onnx_importer/.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/CAPI/lit.local.cfg: -------------------------------------------------------------------------------- 1 | config.suffixes.add('.c') 2 | -------------------------------------------------------------------------------- /pytorch-hash.txt: 
-------------------------------------------------------------------------------- 1 | 0dfcb1a118dd45c544a156e1d86566368e528e69 2 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/tosa_backends/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Dialect) 2 | -------------------------------------------------------------------------------- /projects/ltc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(csrc/base_lazy_backend) 2 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/IR/.gitignore: -------------------------------------------------------------------------------- 1 | JITOperatorRegistryDump.txt 2 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(TMTensor) 2 | -------------------------------------------------------------------------------- /projects/jit_ir_common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(csrc/jit_ir_importer) 2 | -------------------------------------------------------------------------------- /whl-requirements.txt: -------------------------------------------------------------------------------- 1 | -f build-requirements.txt 2 | -f pytorch-requirements.txt 3 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | dill 3 | multiprocess 4 | onnx==1.16.1 5 | mpmath==1.3.0 6 | -------------------------------------------------------------------------------- /include/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(torch-mlir) 2 | add_subdirectory(torch-mlir-dialects) 3 | -------------------------------------------------------------------------------- /lib/Dialect/TorchConversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- 
/projects/pt1/examples/example-requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | transformers 3 | requests 4 | pillow 5 | -------------------------------------------------------------------------------- /tools/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(torch-mlir-lsp-server) 2 | add_subdirectory(torch-mlir-opt) 3 | -------------------------------------------------------------------------------- /docs/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/architecture.png -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Torch) 2 | add_subdirectory(TorchConversion) 3 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r pytorch-requirements.txt 2 | -r build-requirements.txt 3 | -r test-requirements.txt 4 | -------------------------------------------------------------------------------- /docs/images/roadmap_backend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/roadmap_backend.png -------------------------------------------------------------------------------- /projects/pt1/test/python/lit.local.cfg: 
-------------------------------------------------------------------------------- 1 | if not config.enable_bindings_python: 2 | config.unsupported = True 3 | -------------------------------------------------------------------------------- /test/RefBackend/lit.local.cfg: -------------------------------------------------------------------------------- 1 | if not config.torch_mlir_enable_refbackend: 2 | config.unsupported = True 3 | -------------------------------------------------------------------------------- /docs/images/ltc_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/ltc_architecture.png -------------------------------------------------------------------------------- /docs/images/roadmap_frontend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/roadmap_frontend.png -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /lib/Dialect/Torch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | add_subdirectory(Utils) 4 | -------------------------------------------------------------------------------- /test/Conversion/TorchToStablehlo/lit.local.cfg: 
-------------------------------------------------------------------------------- 1 | if not config.enable_stablehlo: 2 | config.unsupported = True 3 | -------------------------------------------------------------------------------- /docs/images/ltc_syncing_tensors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/ltc_syncing_tensors.png -------------------------------------------------------------------------------- /docs/images/ltc_tracing_tensors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/ltc_tracing_tensors.png -------------------------------------------------------------------------------- /docs/images/ltc_vendor_execution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/ltc_vendor_execution.png -------------------------------------------------------------------------------- /lib/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(TMTensor) 2 | add_subdirectory(Torch) 3 | add_subdirectory(TorchConversion) 4 | -------------------------------------------------------------------------------- /pytorch-requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/nightly/cpu/torch/ 2 | --pre 3 | torch==2.10.0.dev20251016 4 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/lit.local.cfg: -------------------------------------------------------------------------------- 1 | if not config.enable_jit_ir_importer: 2 | config.unsupported = True 3 | -------------------------------------------------------------------------------- 
/include/torch-mlir/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Conversion) 2 | add_subdirectory(Dialect) 3 | add_subdirectory(RefBackend) 4 | -------------------------------------------------------------------------------- /test/python/lit.local.cfg: -------------------------------------------------------------------------------- 1 | if not config.enable_bindings_python or "Windows" in config.host_os: 2 | config.unsupported = True 3 | -------------------------------------------------------------------------------- /test/python/onnx_importer/LeakyReLU.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/test/python/onnx_importer/LeakyReLU.onnx -------------------------------------------------------------------------------- /docs/images/readme_architecture_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/docs/images/readme_architecture_diagram.png -------------------------------------------------------------------------------- /torchvision-requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ 2 | --pre 3 | torchvision==0.25.0.dev20251016 4 | -------------------------------------------------------------------------------- /projects/pt1/python/test/torchscript_e2e_test/README.md: -------------------------------------------------------------------------------- 1 | This directory is for testing the e2e_test framework itself. 2 | It is not for holding e2e tests themselves!!! 
3 | -------------------------------------------------------------------------------- /test/python/onnx_importer/lit.local.cfg: -------------------------------------------------------------------------------- 1 | try: 2 | import onnx 3 | except ModuleNotFoundError: 4 | print("Skipping onnx tests.. no onnx") 5 | config.unsupported = True 6 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | AlignAfterOpenBracket: AlwaysBreak # BlockIndent 3 | PointerAlignment: Left 4 | ReflowComments: false 5 | -------------------------------------------------------------------------------- /test/python/onnx_importer/BadName.onnx: -------------------------------------------------------------------------------- 1 |  2 | :tmain*6B&abz_.(1, 2)[$something, %anotherthing]Jb4 3 | &abz_.(1, 2)[$something, %anotherthing] 4 | 5 |  6 | B -------------------------------------------------------------------------------- /test/python/onnx_importer/import_onnx_tool.runlit: -------------------------------------------------------------------------------- 1 | # RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s 2 | 3 | # CHECK: torch.operator "onnx.LeakyRelu" 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 88 7 | target-version = ['py38'] 8 | -------------------------------------------------------------------------------- /test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx -------------------------------------------------------------------------------- /test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx -------------------------------------------------------------------------------- /test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llvm/torch-mlir/HEAD/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | groups: 8 | github-actions: 9 | patterns: 10 | - "*" 11 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "externals/llvm-project"] 2 | path = externals/llvm-project 3 | url = https://github.com/llvm/llvm-project.git 4 | [submodule "externals/stablehlo"] 5 | path = externals/stablehlo 6 | url = https://github.com/openxla/stablehlo.git 7 | -------------------------------------------------------------------------------- /include/torch-mlir/RefBackend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | mlir_tablegen(Passes.h.inc -gen-pass-decls) 
3 | add_public_tablegen_target(TorchMLIRRefBackendPassIncGen) 4 | 5 | #add_mlir_doc(Passes RefBackendPasses ./ -gen-pass-doc) 6 | -------------------------------------------------------------------------------- /test/CAPI/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_llvm_executable(torch-mlir-capi-torch-test torch.c) 2 | llvm_update_compile_flags(torch-mlir-capi-torch-test) 3 | target_link_libraries( 4 | torch-mlir-capi-torch-test 5 | PRIVATE 6 | MLIRCAPIIR 7 | TorchMLIRCAPI 8 | ) 9 | -------------------------------------------------------------------------------- /utils/bazel/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # This file is licensed under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | 5 | # Required to reference .bzl files in this package 6 | -------------------------------------------------------------------------------- /test/Dialect/Torch/erase-module-initializer.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-erase-module-initializer -split-input-file -verify-diagnostics %s | FileCheck %s 2 | 3 | // CHECK: module { 4 | // CHECK-NEXT: } 5 | torch.global_slot.module_initializer { 6 | torch.initialize.global_slots [ 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /test/python/fx_importer/v2.3/lit.local.cfg: -------------------------------------------------------------------------------- 1 | config.unsupported = True 2 | 3 | try: 4 | import torch 5 | if torch.__version__ >= "2.3.0" and "Windows" not in config.host_os: 6 | print("Enabling Torch v2.3+ tests") 7 | config.unsupported = False 8 | except ModuleNotFoundError: 9 | ... 
10 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | mlir_tablegen(Passes.h.inc -gen-pass-decls) 3 | add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen) 4 | add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc) 5 | -------------------------------------------------------------------------------- /lib/Conversion/Utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRConversionUtils 2 | Utils.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils 6 | 7 | LINK_LIBS PUBLIC 8 | MLIRArithDialect 9 | MLIRLinalgDialect 10 | TorchMLIRTorchDialect 11 | ) 12 | -------------------------------------------------------------------------------- /test/python/onnx_importer/BadName.runlit: -------------------------------------------------------------------------------- 1 | # The original constant name : "abz_.(1, 2)[$something, %anotherthing]" 2 | 3 | # RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/BadName.onnx | FileCheck %s 4 | 5 | # CHECK: torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_abz_._1__2___something___anotherthing_> 6 | -------------------------------------------------------------------------------- /test/python/fx_importer/sparsity/lit.local.cfg: -------------------------------------------------------------------------------- 1 | config.unsupported = True 2 | 3 | try: 4 | import torch 5 | if "2.5.0" <= str(torch.__version__) and "Windows" not in config.host_os: 6 | print("Enabling sparsity propagation tests") 7 | config.unsupported = False 8 | 9 | except ModuleNotFoundError: 10 | ... 
11 | -------------------------------------------------------------------------------- /utils/bazel/torch-mlir-overlay/.bazelignore: -------------------------------------------------------------------------------- 1 | # This file is licensed under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | 5 | # Skip the following directories when overlaying 6 | utils/bazel 7 | externals 8 | -------------------------------------------------------------------------------- /lib/Dialect/Torch/Utils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(TorchMLIRTorchUtils 2 | Utils.cpp 3 | SparsityUtils.cpp 4 | TorchUpstream.cpp 5 | 6 | ADDITIONAL_HEADER_DIRS 7 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Utils 8 | 9 | DEPENDS 10 | MLIRTorchOpsIncGen 11 | MLIRTorchTypesIncGen 12 | ) 13 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace torch { 4 | namespace jit { 5 | 6 | // Convert ScalarImplicit to IntImplicit or FloatImplicit. 
7 | TORCH_API void ConvertScalarImplicit(std::shared_ptr &graph); 8 | 9 | } // namespace jit 10 | } // namespace torch 11 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | 3 | mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS}) 4 | 5 | add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen) 6 | 7 | add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc) 8 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | mlir_tablegen(Passes.h.inc -gen-pass-decls) 3 | mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header) 4 | mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl) 5 | add_public_tablegen_target(TorchMLIRTMTensorTransformsPassesIncGen) 6 | -------------------------------------------------------------------------------- /utils/bazel/docker/run_docker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | docker build -f utils/bazel/docker/Dockerfile \ 4 | -t torch-mlir:dev \ 5 | . 6 | 7 | docker run -it \ 8 | -v "$(pwd)":"/opt/src/torch-mlir" \ 9 | -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ 10 | torch-mlir:dev 11 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # Register _torch_mlir_custom_op_example.identity as a side-effect of importing. 
5 | current_dir = os.path.dirname(os.path.abspath(__file__)) 6 | lib = os.path.join(*[current_dir, "libtorch_mlir_custom_op_example.so"]) 7 | torch.ops.load_library(lib) 8 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(TorchOnnxToTorch) 2 | 3 | set(LLVM_TARGET_DEFINITIONS Passes.td) 4 | 5 | 6 | 7 | mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS}) 8 | 9 | add_public_tablegen_target(TorchMLIRConversionPassIncGen) 10 | 11 | add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) 12 | -------------------------------------------------------------------------------- /projects/pt1/tools/e2e_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | src_dir="$(realpath "$(dirname "$0")"/..)" 5 | project_dir="$src_dir/../.." 6 | 7 | cd "$src_dir" 8 | 9 | # Ensure PYTHONPATH is set for export to child processes, even if empty. 
10 | export PYTHONPATH=${PYTHONPATH-} 11 | source $project_dir/.env 12 | 13 | python -m e2e_testing.main "$@" 14 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | mlir_tablegen(Passes.h.inc -gen-pass-decls) 3 | mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header) 4 | mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl) 5 | add_public_tablegen_target(TorchMLIRTorchPassIncGen) 6 | 7 | add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc) 8 | -------------------------------------------------------------------------------- /python/torch_mlir/dialects/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | from .._torch_ops_gen import * 7 | from ..._mlir_libs._torchMlir import register_dialect 8 | -------------------------------------------------------------------------------- /projects/onnx_c_importer/README.md: -------------------------------------------------------------------------------- 1 | # ONNX C Importer 2 | 3 | This project provides a C++ implementation of the `onnx_importer.py`, which is 4 | the canonical source. It is provided as sample code for anyone who wishes to 5 | integrate it into their system. By design, it only depends on the ONNX API 6 | and the MLIR C API via the `mlir-c` headers. As such, it should be easy to 7 | build into any system that already has those things by adding the sources. 
8 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToArith/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToArith 2 | TorchToArith.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith 6 | 7 | DEPENDS 8 | TorchMLIRConversionPassIncGen 9 | 10 | LINK_LIBS PUBLIC 11 | MLIRIR 12 | MLIRPass 13 | MLIRFuncDialect 14 | TorchMLIRTorchDialect 15 | ) 16 | 17 | torch_mlir_target_includes(TorchMLIRTorchToArith) 18 | -------------------------------------------------------------------------------- /utils/bazel/torch-mlir-overlay/test/Dialect/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@llvm-project//llvm:lit_test.bzl", "lit_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | [ 6 | lit_test( 7 | name = "%s.test" % src, 8 | srcs = [src], 9 | data = [ 10 | "@torch-mlir//:torch-mlir-opt", 11 | "@torch-mlir//test:lit_data", 12 | ], 13 | ) 14 | for src in glob(["**/*.mlir"]) 15 | ] 16 | -------------------------------------------------------------------------------- /utils/bazel/torch-mlir-overlay/test/Conversion/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@llvm-project//llvm:lit_test.bzl", "lit_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | [ 6 | lit_test( 7 | name = "%s.test" % src, 8 | srcs = [src], 9 | data = [ 10 | "@torch-mlir//:torch-mlir-opt", 11 | "@torch-mlir//test:lit_data", 12 | ], 13 | ) 14 | for src in glob(["**/*.mlir"]) 15 | ] 16 | -------------------------------------------------------------------------------- /utils/bazel/torch-mlir-overlay/test/RefBackend/BUILD.bazel: -------------------------------------------------------------------------------- 1 | load("@llvm-project//llvm:lit_test.bzl", 
"lit_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | [ 6 | lit_test( 7 | name = "%s.test" % src, 8 | srcs = [src], 9 | data = [ 10 | "@torch-mlir//:torch-mlir-opt", 11 | "@torch-mlir//test:lit_data", 12 | ], 13 | ) 14 | for src in glob(["**/*.mlir"]) 15 | ] 16 | -------------------------------------------------------------------------------- /projects/pt1/test/python/smoketest.py: -------------------------------------------------------------------------------- 1 | # RUN: %PYTHON %s 2 | 3 | import torch_mlir.ir 4 | from torch_mlir.dialects import torch 5 | 6 | with torch_mlir.ir.Context() as ctx: 7 | torch.register_dialect(ctx) 8 | with torch_mlir.ir.Location.unknown() as loc: 9 | module = torch_mlir.ir.Module.create(loc) 10 | with torch_mlir.ir.InsertionPoint.at_block_begin(module.body): 11 | n = torch.ConstantNoneOp() 12 | module.operation.print() 13 | -------------------------------------------------------------------------------- /test/Conversion/TorchToTensor/torch_to_tensor.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s 2 | 3 | // CHECK-LABEL: func.func @test_shape 4 | func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> { 5 | // CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64> 6 | %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64> 7 | return %0 : !torch.vtensor<[3],si64> 8 | } 9 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # This file contains the list of commits to exclude from 'git blame'. 2 | # Such commits do not meaningfully contribute to git history, and include 3 | # large-scale mechanical changes like code formatting style changes. 
4 | # 5 | # To set this file as the default ignore file for 'git blame', run: 6 | # ```shell 7 | # git config blame.ignoreRevsFile .git-blame-ignore-revs 8 | # ``` 9 | 10 | # Refresh clang-format 11 | 494089d53db4c183b3ba12e36f61ce1c7553984c 12 | -------------------------------------------------------------------------------- /lib/CAPI/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_public_c_api_library(TorchMLIRCAPI 2 | Dialects.cpp 3 | Registration.cpp 4 | TorchOps.cpp 5 | TorchTypes.cpp 6 | Transforms.cpp 7 | 8 | ADDITIONAL_HEADER_DIRS 9 | ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/ 10 | 11 | ENABLE_AGGREGATION 12 | 13 | LINK_LIBS PUBLIC 14 | MLIRIR 15 | MLIRSupport 16 | TorchMLIRTorchDialect 17 | TorchMLIRInitAll 18 | TorchMLIRTorchPasses 19 | ) 20 | 21 | torch_mlir_target_includes(TorchMLIRCAPI) 22 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToTensor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToTensor 2 | TorchToTensor.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTensor 6 | 7 | DEPENDS 8 | TorchMLIRConversionPassIncGen 9 | 10 | LINK_LIBS PUBLIC 11 | MLIRIR 12 | MLIRPass 13 | MLIRTensorDialect 14 | TorchMLIRTorchDialect 15 | TorchMLIRConversionUtils 16 | ) 17 | 18 | torch_mlir_target_includes(TorchMLIRTorchToTensor) 19 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToSCF/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToSCF 2 | TorchToSCF.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF 6 | 7 | DEPENDS 8 | TorchMLIRConversionPassIncGen 9 | 10 | LINK_LIBS PUBLIC 11 | MLIRIR 12 | MLIRPass 13 
| MLIRSCFDialect 14 | MLIRFuncDialect 15 | TorchMLIRTorchDialect 16 | TorchMLIRTorchConversionDialect 17 | ) 18 | 19 | torch_mlir_target_includes(TorchMLIRTorchToSCF) 20 | -------------------------------------------------------------------------------- /lib/Dialect/TorchConversion/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(TorchMLIRTorchConversionDialect 2 | TorchConversionDialect.cpp 3 | TorchConversionOps.cpp 4 | 5 | ADDITIONAL_HEADER_DIRS 6 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion 7 | 8 | DEPENDS 9 | MLIRTorchConversionOpsIncGen 10 | MLIRTorchTypesIncGen 11 | 12 | LINK_LIBS PUBLIC 13 | MLIRIR 14 | MLIRSupport 15 | MLIRSideEffectInterfaces 16 | ) 17 | 18 | torch_mlir_target_includes(TorchMLIRTorchConversionDialect) 19 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/README.md: -------------------------------------------------------------------------------- 1 | # ivalue_import 2 | 3 | Most of the tests in this directory test importing of TorchScript 4 | `torch::jit::Module`'s. 5 | 6 | Modules are just one of many types of c10::IValue's and recursively contain 7 | c10::IValue's. Thus, the work of importing TorchScript modules is mainly 8 | about importing the wide variety of possible c10::IValue's, hence the name 9 | of this directory and the corresponding code in ivalue_importer.cpp that it 10 | exercises. 11 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: Torch-MLIR 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 
6 | type: software 7 | authors: 8 | - name: LLVM 9 | repository-code: 'https://github.com/llvm/torch-mlir' 10 | abstract: >- 11 | The Torch-MLIR project aims to provide first class support 12 | from the PyTorch ecosystem to the MLIR ecosystem. 13 | keywords: 14 | - Compiler 15 | - PyTorch 16 | - MLIR 17 | license: 18 | - Apache-2.0 with LLVM Exceptions 19 | - BSD 20 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/torch_mlir_custom_op_example.cpp: -------------------------------------------------------------------------------- 1 | // For writing an extension like this one, see: 2 | // https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html 3 | 4 | #include // One-stop header for PyTorch 5 | 6 | torch::Tensor identity(torch::Tensor t) { 7 | // Do literally nothing. 8 | return t; 9 | } 10 | 11 | TORCH_LIBRARY(_torch_mlir_custom_op_example, m) { 12 | m.def("identity(Tensor t) -> Tensor"); 13 | m.impl("identity", &identity); 14 | } 15 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/_version.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | from packaging import version 7 | import torch 8 | 9 | 10 | def torch_version_for_comparison(): 11 | # Ignore +cpu, +cu117m, etc. 
in comparisons 12 | return version.parse(torch.__version__.split("+", 1)[0]) 13 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToTosa/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToTosa 2 | TorchToTosa.cpp 3 | TosaLegalizeUtils.cpp 4 | TosaLegalizeCommon.cpp 5 | 6 | ADDITIONAL_HEADER_DIRS 7 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa 8 | 9 | DEPENDS 10 | TorchMLIRConversionPassIncGen 11 | 12 | LINK_LIBS PUBLIC 13 | MLIRIR 14 | MLIRPass 15 | MLIRTosaDialect 16 | TorchMLIRConversionUtils 17 | TorchMLIRTorchDialect 18 | ) 19 | 20 | torch_mlir_target_includes(TorchMLIRTorchToTosa) 21 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-all.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit (all files on push) 2 | 3 | on: 4 | push: 5 | branches: [main, post-commit-test] 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-22.04 10 | steps: 11 | - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 12 | - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 13 | - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 14 | with: 15 | extra_args: --color=always --all-files 16 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToTMTensor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToTMTensor 2 | TorchToTMTensor.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor 6 | 7 | DEPENDS 8 | TorchMLIRConversionPassIncGen 9 | 10 | LINK_LIBS PUBLIC 11 | MLIRIR 12 | MLIRPass 13 | MLIRLinalgDialect 14 | MLIRMathDialect 15 | 
TorchMLIRTorchDialect 16 | TorchMLIRTMTensorDialect 17 | TorchMLIRTorchUtils 18 | ) 19 | 20 | torch_mlir_target_includes(TorchMLIRTorchToTMTensor) 21 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------- 2 | # Subdirectories 3 | #------------------------------------------------------------------------------- 4 | 5 | ## Declare the sources of the Python module. 6 | 7 | declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter 8 | ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" 9 | ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources 10 | SOURCES_GLOB 11 | jit_ir_importer/*.py 12 | ) 13 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/utils.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | # Helpers for the other tests. 6 | 7 | import torch 8 | from torch._C import CompilationUnit 9 | 10 | # RUN: %PYTHON %s 11 | 12 | 13 | # Import TorchScript IR string as ScriptFunction. 
14 | def create_script_function(func_name, ts_ir_str, **kwargs): 15 | cu = CompilationUnit() 16 | return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs)) 17 | -------------------------------------------------------------------------------- /lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram 2 | TorchConversionToMLProgram.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram 6 | 7 | DEPENDS 8 | TorchMLIRConversionPassIncGen 9 | 10 | LINK_LIBS PUBLIC 11 | MLIRIR 12 | MLIRLinalgDialect 13 | MLIRMLProgramDialect 14 | MLIRMathDialect 15 | MLIRPass 16 | TorchMLIRTorchDialect 17 | ) 18 | 19 | torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram) 20 | -------------------------------------------------------------------------------- /lib/RefBackend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(TorchMLIRRefBackend 2 | RefBackend.cpp 3 | 4 | ADDITIONAL_HEADER_DIRS 5 | ${PROJECT_SRC_DIR}/include/torch-mlir/RefBackend 6 | 7 | DEPENDS 8 | MLIRTorchTypesIncGen 9 | TorchMLIRRefBackendPassIncGen 10 | MLIRTorchConversionOpsIncGen 11 | 12 | LINK_COMPONENTS 13 | Core 14 | 15 | LINK_LIBS PUBLIC 16 | MLIRIR 17 | MLIRTransforms 18 | MLIRMathTransforms 19 | MLIRLinalgTransforms 20 | ) 21 | 22 | mlir_check_all_link_libraries(TorchMLIRRefBackend) 23 | torch_mlir_target_includes(TorchMLIRRefBackend) 24 | -------------------------------------------------------------------------------- /python/torch_mlir/_mlir_libs/_site_initialize_0.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # Multi-threading rarely helps the frontend and we are also running in contexts 7 | # where we want to run a lot of test parallelism (and nproc*nproc threads 8 | # puts a large load on the system and virtual memory). 9 | disable_multithreading = True 10 | -------------------------------------------------------------------------------- /tools/torch-mlir-opt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_llvm_executable(torch-mlir-opt torch-mlir-opt.cpp) 2 | 3 | install(TARGETS torch-mlir-opt 4 | EXPORT TorchMLIRTargets 5 | RUNTIME DESTINATION ${LLVM_TOOLS_INSTALL_DIR} 6 | COMPONENT torch-mlir-opt) 7 | 8 | set(dependency_libraries) 9 | if(TORCH_MLIR_ENABLE_STABLEHLO) 10 | list(APPEND dependency_libraries StablehloRegister) 11 | endif() 12 | 13 | target_link_libraries(torch-mlir-opt PRIVATE 14 | MLIROptLib 15 | MLIRTransforms 16 | TorchMLIRInitAll 17 | TorchMLIRTorchDialect 18 | TorchMLIRTorchPasses 19 | ${dependency_libraries} 20 | ) 21 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | declare_mlir_python_sources(TorchMLIRE2ETestPythonSources) 2 | 3 | declare_mlir_python_sources(TorchMLIRE2ETestPythonSources.Core 4 | ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" 5 | ADD_TO_PARENT TorchMLIRE2ETestPythonSources 6 | SOURCES_GLOB 7 | *.py 8 | ) 9 | 10 | add_mlir_python_modules(TorchMLIRE2ETestPythonModules 11 | ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir_e2e_test" 12 | INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir_e2e_test" 13 | DECLARED_SOURCES TorchMLIRE2ETestPythonSources 14 | ) 15 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.swp 2 | .cache/ 3 | .vscode 4 | .ccache 5 | .env 6 | *.code-workspace 7 | .ipynb_checkpoints 8 | *.venv/ 9 | mlir_venv/ 10 | externals/pytorch/ 11 | libtorch* 12 | 13 | /build/ 14 | .build-cache/ 15 | /setup_build/ 16 | __pycache__ 17 | *.pyc 18 | 19 | .pytype 20 | 21 | 22 | # Pip artifacts. 23 | *.egg-info 24 | *.whl 25 | /wheelhouse 26 | 27 | # Bazel 28 | bazel-* 29 | 30 | # Autogenerated files 31 | /projects/ltc/csrc/base_lazy_backend/generated 32 | 33 | #Docker builds 34 | build_oot/ 35 | docker_venv/ 36 | llvm-build/ 37 | 38 | # C++ build artifacts 39 | compile_commands.json 40 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(TorchMLIRTMTensorPasses 2 | ConvertToLoops.cpp 3 | Bufferize.cpp 4 | Passes.cpp 5 | 6 | DEPENDS 7 | TorchMLIRTMTensorTransformsPassesIncGen 8 | 9 | LINK_LIBS PUBLIC 10 | TorchMLIRTMTensorDialect 11 | MLIRAffineDialect 12 | MLIRIR 13 | MLIRLinalgDialect 14 | MLIRLinalgTransforms 15 | MLIRMathDialect 16 | MLIRMemRefDialect 17 | MLIRPass 18 | MLIRSCFDialect 19 | MLIRFuncDialect 20 | MLIRSupport 21 | MLIRTensorDialect 22 | MLIRTransforms 23 | ) 24 | 25 | torch_mlir_target_includes(TorchMLIRTMTensorPasses) 26 | -------------------------------------------------------------------------------- /utils/bazel/.bazelrc: -------------------------------------------------------------------------------- 1 | # This file is licensed under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | 5 | build --action_env=CC=clang-16 6 | build --action_env=CXX=clang++-16 7 | build --cxxopt=-std=c++17 8 | build --host_cxxopt=-std=c++17 9 | build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 10 | build --cxxopt=-U__GXX_ABI_VERSION 11 | build --cxxopt=-D__GXX_ABI_VERSION=1011 12 | build --cxxopt=-DPYBIND11_COMPILER_TYPE=\"_gcc\" 13 | build --cxxopt=-DMLIR_PYTHON_PACKAGE_PREFIX=torch_mlir. 14 | -------------------------------------------------------------------------------- /test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s 2 | 3 | module { 4 | func.func private @f0() -> i64 5 | func.func private @f1() -> i64 6 | func.func private @f2() -> i64 7 | func.func private @f3() -> i64 8 | func.func private @f4() -> i64 9 | func.func private @f5() -> i64 10 | func.func private @f6() -> i64 11 | func.func private @f7() -> i64 12 | } 13 | 14 | // CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor 15 | // CHECK-NOT: @global_seed 16 | -------------------------------------------------------------------------------- /test/Dialect/Torch/decompose-complex-ops-legal.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s 2 | 3 | // CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim 4 | func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> { 5 | %none = torch.constant.none 6 | %dim = torch.constant.int 1 7 | // CHECK: torch.aten.softmax.int 8 | %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> 9 | return %ret : 
!torch.tensor<[2,3],f32> 10 | } 11 | -------------------------------------------------------------------------------- /lib/Conversion/TorchOnnxToTorch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch 2 | ComMicrosoftDomain.cpp 3 | DefaultDomainAtoF.cpp 4 | DefaultDomainGtoP.cpp 5 | DefaultDomainQtoZ.cpp 6 | OnnxRecurrentLayerOpExpanders.cpp 7 | Passes.cpp 8 | Patterns.cpp 9 | TorchOnnxToTorch.cpp 10 | Utils.cpp 11 | 12 | ADDITIONAL_HEADER_DIRS 13 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch 14 | 15 | DEPENDS 16 | TorchMLIRConversionTorchOnnxToTorchPassIncGen 17 | 18 | LINK_LIBS PUBLIC 19 | MLIRIR 20 | MLIRPass 21 | TorchMLIRTorchDialect 22 | ) 23 | -------------------------------------------------------------------------------- /test/Dialect/Torch/GlobalizeObjectGraph/module-uses-error.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s 2 | 3 | torch.class_type @parent { 4 | torch.method "module_type_return", @module_type_return 5 | } 6 | 7 | func.func private @module_type_return(%arg0: !torch.nn.Module<"parent">) { 8 | // expected-error @+1 {{unsupported use of a torch.nn.Module. 
Expected only method calls or attribute get/set}} 9 | torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list> 10 | return 11 | } 12 | 13 | torch.nn_module {} : !torch.nn.Module<"parent"> 14 | -------------------------------------------------------------------------------- /projects/pt1/examples/torchscript_stablehlo_backend_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | from torch_mlir import torchscript 4 | 5 | model = models.resnet18(pretrained=True) 6 | model.eval() 7 | data = torch.randn(2, 3, 200, 200) 8 | out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" 9 | 10 | module = torchscript.compile( 11 | model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False 12 | ) 13 | with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: 14 | outf.write(module.operation.get_asm()) 15 | 16 | print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}") 17 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | pre-commit: 8 | runs-on: ubuntu-22.04 9 | steps: 10 | - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 11 | with: 12 | # requites to grab the history of the PR 13 | fetch-depth: 0 14 | - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 15 | - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 16 | with: 17 | extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} 18 | -------------------------------------------------------------------------------- /python/torch_mlir/dialects/TorchBinding.td: 
-------------------------------------------------------------------------------- 1 | //===-- TorchBinding.td - Torch dialect bindings -----------*- tablegen -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef PYTHON_BINDINGS_TORCH_OPS 11 | #define PYTHON_BINDINGS_TORCH_OPS 12 | 13 | include "torch-mlir/Dialect/Torch/IR/TorchOps.td" 14 | 15 | #endif // PYTHON_BINDINGS_TORCH_OPS 16 | -------------------------------------------------------------------------------- /test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file 2 | 3 | // CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16> 4 | func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> { 5 | %int5 = torch.constant.int 5 6 | %false = torch.constant.bool false 7 | %none = torch.constant.none 8 | %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16> 9 | return %out : !torch.vtensor<[1,32,220,220],f16> 10 | } 11 | -------------------------------------------------------------------------------- /lib/Dialect/Torch/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(TorchMLIRTorchDialect 2 | TorchDialect.cpp 3 | TorchOps.cpp 4 | TorchOpsODSGenerated.cpp 5 | TorchTypes.cpp 6 | UtilsForODSGenerated.cpp 7 | 8 | ADDITIONAL_HEADER_DIRS 9 | 
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch 10 | 11 | DEPENDS 12 | MLIRTorchOpsIncGen 13 | MLIRTorchTypesIncGen 14 | 15 | LINK_LIBS PUBLIC 16 | MLIRBytecodeOpInterface 17 | MLIRBytecodeReader 18 | MLIRBytecodeWriter 19 | MLIRFuncDialect 20 | MLIRIR 21 | MLIRSupport 22 | MLIRControlFlowInterfaces 23 | MLIRInferTypeOpInterface 24 | MLIRSideEffectInterfaces 25 | ) 26 | 27 | torch_mlir_target_includes(TorchMLIRTorchDialect) 28 | -------------------------------------------------------------------------------- /.yamllint.yml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | extends: default 4 | 5 | rules: 6 | # These do not appear to be conventional in GitHub actions. 7 | document-end: 8 | present: false 9 | document-start: 10 | present: false 11 | # GitHub actions use "on" for triggers. 12 | truthy: disable 13 | # We have lots of long strings and command lines. 14 | line-length: disable 15 | comments: 16 | # Formatters may do this (e.g. Prettier does) and it seems like the most 17 | # trivial thing to get a failing check for. 18 | min-spaces-from-content: 1 19 | # This is not a useful check, especially when disabling entire blocks. 
20 | comments-indentation: disable 21 | 22 | ignore: /third_party/* 23 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS TorchConversionOps.td) 2 | mlir_tablegen(TorchConversionOps.h.inc -gen-op-decls) 3 | mlir_tablegen(TorchConversionOps.cpp.inc -gen-op-defs) 4 | mlir_tablegen(TorchConversionDialect.h.inc -gen-dialect-decls -dialect=torch_c) 5 | mlir_tablegen(TorchConversionDialect.cpp.inc -gen-dialect-defs -dialect=torch_c) 6 | add_public_tablegen_target(MLIRTorchConversionOpsIncGen) 7 | add_dependencies(mlir-headers MLIRTorchConversionOpsIncGen) 8 | 9 | add_mlir_doc(TorchConversionDialect TorchConversionDialect TorchConversion/ -gen-dialect-doc) 10 | add_mlir_doc(TorchConversionOps TorchConversionOps TorchConversion/ -gen-op-doc) 11 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToLinalg/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToLinalg 2 | DataMovement.cpp 3 | IndirectDataMovement.cpp 4 | Linear.cpp 5 | Pooling.cpp 6 | Random.cpp 7 | Reduction.cpp 8 | TensorConstructors.cpp 9 | TensorScalarInterop.cpp 10 | TorchToLinalg.cpp 11 | Uncategorized.cpp 12 | Utils.cpp 13 | 14 | ADDITIONAL_HEADER_DIRS 15 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg 16 | 17 | DEPENDS 18 | TorchMLIRConversionPassIncGen 19 | 20 | LINK_LIBS PUBLIC 21 | MLIRIR 22 | MLIRPass 23 | MLIRLinalgDialect 24 | MLIRMathDialect 25 | TorchMLIRTorchDialect 26 | ) 27 | 28 | torch_mlir_target_includes(TorchMLIRTorchToLinalg) 29 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | from .lazy_tensor_core import LazyTensorCoreTestConfig 7 | from .native_torch import NativeTorchTestConfig 8 | from .onnx_backend import OnnxBackendTestConfig 9 | from .torchscript import TorchScriptTestConfig 10 | from .torchdynamo import TorchDynamoTestConfig 11 | from .jit_importer_backend import JITImporterTestConfig 12 | from .fx_importer_backend import FxImporterTestConfig 13 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/types-none.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import torch 6 | from torch_mlir.jit_ir_importer import ModuleBuilder 7 | 8 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 9 | 10 | mb = ModuleBuilder() 11 | 12 | 13 | # CHECK: @__torch__.returns_none 14 | @mb.import_function 15 | @torch.jit.script 16 | def returns_none(): 17 | # CHECK-NEXT: %[[NONE:.*]] = torch.constant.none 18 | # CHECK-NEXT: return %[[NONE]] 19 | pass 20 | 21 | 22 | assert isinstance(returns_none, torch.jit.ScriptFunction) 23 | mb.module.operation.print() 24 | print() 25 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import torch 6 | from torch_mlir.jit_ir_importer import ModuleBuilder 7 | 8 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 9 | 10 | mb = ModuleBuilder() 11 | 12 | 13 | # CHECK: @__torch__.returns_bool 14 | @mb.import_function 15 | @torch.jit.script 16 | def returns_bool(): 17 | # CHECK-NEXT: %[[T:.*]] = torch.constant.bool true 18 | # CHECK-NEXT: return %[[T]] 19 | return True 20 | 21 | 22 | assert isinstance(returns_bool, torch.jit.ScriptFunction) 23 | mb.module.operation.print() 24 | print() 25 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(TorchMLIRTMTensorDialect 2 | TMTensorDialect.cpp 3 | TMTensorInterfaces.cpp 4 | TMTensorOps.cpp 5 | ScalarLoopOpInterface.cpp 6 | 7 | ADDITIONAL_HEADER_DIRS 8 | ${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include 9 | 10 | DEPENDS 11 | TorchMLIRTMTensorOpsIncGen 12 | 13 | LINK_LIBS PUBLIC 14 | MLIRAffineDialect 15 | MLIRDialectUtils 16 | MLIRIR 17 | MLIRLinalgDialect 18 | MLIRMathDialect 19 | MLIRMemRefDialect 20 | MLIRPass 21 | MLIRSideEffectInterfaces 22 | MLIRSupport 23 | MLIRSCFDialect 24 | MLIRFuncDialect 25 | MLIRTensorDialect 26 | MLIRViewLikeInterface 27 | ) 28 | 29 | torch_mlir_target_includes(TorchMLIRTMTensorDialect) 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: "GeneratedTorchOps\\.td|abstract_interp_lib_gen\\.py|\\.excalidraw|\\.ipynb" 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v3.2.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-ast 11 | - id: check-yaml 12 | - id: check-added-large-files 13 | 
- repo: https://github.com/psf/black 14 | rev: 24.4.2 15 | hooks: 16 | - id: black 17 | 18 | - repo: https://github.com/pre-commit/mirrors-clang-format 19 | rev: 'v18.1.4' 20 | hooks: 21 | - id: clang-format 22 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToStablehlo/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(TorchMLIRTorchToStablehlo 2 | TorchToStablehlo.cpp 3 | StablehloLegalizeUtils.cpp 4 | Basic.cpp 5 | GatherScatter.cpp 6 | Linear.cpp 7 | ViewLike.cpp 8 | Reduction.cpp 9 | Rng.cpp 10 | Pooling.cpp 11 | Uncategorized.cpp 12 | Utils.cpp 13 | 14 | ADDITIONAL_HEADER_DIRS 15 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo 16 | 17 | DEPENDS 18 | TorchMLIRConversionPassIncGen 19 | 20 | LINK_LIBS PUBLIC 21 | MLIRIR 22 | MLIRPass 23 | MLIRComplexDialect 24 | ChloOps 25 | StablehloOps 26 | TorchMLIRTorchDialect 27 | TorchMLIRConversionUtils 28 | ) 29 | 30 | torch_mlir_target_includes(TorchMLIRTorchToStablehlo) 31 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/union.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | from typing import Union 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | # CHECK-LABEL: func.func @__torch__.f( 15 | # CHECK-SAME: %{{.*}}: !torch.union) -> !torch.none { 16 | 17 | 18 | @mb.import_function 19 | @torch.jit.script 20 | def f(x: Union[int, float]): 21 | return 22 | 23 | 24 | assert isinstance(f, torch.jit.ScriptFunction) 25 | mb.module.operation.print() 26 | print() 27 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/README.md: -------------------------------------------------------------------------------- 1 | # node_import 2 | 3 | Most of the tests in this directory test the importing of TorchScript 4 | `torch::jit::Graph`'s. 5 | 6 | However, TorchScript graphs don't really correspond directly to anything on 7 | the MLIR side. They are a weird combination of a context, builder, and 8 | function and just holds a `torch::jit::Block`. It is `torch::jit::Node` 9 | and `torch::jit::Block` which form the recursive structure analogous to 10 | MLIR's operation/region/block. 11 | 12 | - `torch::jit::Node` == `mlir::Operation`, 13 | - `torch::jit::Block` == `mlir::Region` containing single `mlir::Block` 14 | 15 | Hence the name of this directory and the corresponding code in 16 | node_importer.h/cpp. 17 | -------------------------------------------------------------------------------- /lib/CAPI/Dialects.cpp: -------------------------------------------------------------------------------- 1 | //===- Dialects.cpp - C Interface for Dialects ----------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. 
See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir-c/Dialects.h" 11 | 12 | #include "mlir/CAPI/Registration.h" 13 | #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" 14 | 15 | MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, 16 | mlir::torch::Torch::TorchDialect) 17 | -------------------------------------------------------------------------------- /projects/pt1/python/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | configure_lit_site_cfg( 2 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in 3 | ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py 4 | MAIN_CONFIG 5 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py 6 | ) 7 | 8 | set(TEST_DEPENDS 9 | FileCheck count not 10 | torch-mlir-opt 11 | TorchMLIRPythonModules 12 | ) 13 | 14 | add_lit_testsuite(check-torch-mlir-python "Running the torch-mlir Python regression tests" 15 | ${CMAKE_CURRENT_BINARY_DIR} 16 | DEPENDS ${TEST_DEPENDS} 17 | ) 18 | set_target_properties(check-torch-mlir-python PROPERTIES FOLDER "Tests") 19 | 20 | add_lit_testsuites(TORCH_MLIR_PYTHON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TEST_DEPENDS}) 21 | add_dependencies(check-torch-mlir-all check-torch-mlir-python) 22 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS TorchOps.td) 2 | mlir_tablegen(TorchOps.h.inc -gen-op-decls) 3 | mlir_tablegen(TorchOps.cpp.inc -gen-op-defs) 4 | mlir_tablegen(TorchDialect.h.inc -gen-dialect-decls -dialect=torch) 5 | mlir_tablegen(TorchDialect.cpp.inc -gen-dialect-defs -dialect=torch) 6 | add_public_tablegen_target(MLIRTorchOpsIncGen) 7 | add_dependencies(mlir-headers MLIRTorchOpsIncGen) 8 | 9 | set(LLVM_TARGET_DEFINITIONS TorchTypes.td) 10 | mlir_tablegen(TorchTypes.h.inc -gen-typedef-decls) 11 | 
mlir_tablegen(TorchTypes.cpp.inc -gen-typedef-defs) 12 | add_public_tablegen_target(MLIRTorchTypesIncGen) 13 | 14 | add_mlir_doc(TorchDialect TorchDialect Torch/ -gen-dialect-doc) 15 | add_mlir_doc(TorchOps TorchOps Torch/ -gen-op-doc) 16 | -------------------------------------------------------------------------------- /test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s 2 | func.func @forward(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor { 3 | %none = torch.constant.none 4 | %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,5],f32> to !torch.vtensor<*,f32> 5 | %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32> 6 | // expected-error @+1 {{unsupported by backend contract: Unimplemented operator 'an.unimplemented.op'}} 7 | %2 = torch.operator "an.unimplemented.op"(%1, %1, %none) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.none) -> !torch.tensor 8 | %3 = torch.copy.to_vtensor %2 : !torch.vtensor 9 | return %3 : !torch.vtensor 10 | } 11 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/get_registered_ops.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | # RUN: %PYTHON %s | FileCheck %s 6 | 7 | from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops 8 | 9 | # This check is just for a built-in op that is unlikely to change (and is 10 | # otherwise insignificant). 
11 | # CHECK: {'name': ('aten::mul', 'Tensor'), 'is_c10_op': True, 'is_vararg': False, 'is_varret': False, 'is_mutable': False, 'arguments': [{'name': 'self', 'type': 'Tensor', 'pytype': 'Tensor'}, {'name': 'other', 'type': 'Tensor', 'pytype': 'Tensor'}], 'returns': [{'name': '', 'type': 'Tensor', 'pytype': 'Tensor'}]} 12 | print("\n\n".join([repr(r) for r in get_registered_ops()])) 13 | -------------------------------------------------------------------------------- /test/python/onnx_importer/_torch_mlir_config.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # RUN: %PYTHON %s 7 | 8 | """This file exists so that the tests can find/configure torch_mlir. 9 | 10 | It allows the test file to be standalone and used verbatim in other 11 | projects (i.e. by just providing this file on the side). 12 | """ 13 | 14 | from torch_mlir import ir 15 | from torch_mlir.extras import onnx_importer 16 | 17 | 18 | def configure_context(context): 19 | from torch_mlir.dialects import torch as torch_d 20 | 21 | torch_d.register_dialect(context) 22 | -------------------------------------------------------------------------------- /lib/Conversion/TorchOnnxToTorch/Passes.cpp: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" 11 | 12 | namespace { 13 | #define GEN_PASS_REGISTRATION 14 | #include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" 15 | } // end namespace 16 | 17 | void mlir::torch::onnx_c::registerTorchOnnxToTorchPasses() { 18 | ::registerPasses(); 19 | } 20 | -------------------------------------------------------------------------------- /test/Dialect/Torch/GlobalizeObjectGraph/visibility.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s 2 | 3 | torch.class_type @c { 4 | // CHECK: torch.global_slot "private" @float : !torch.float 5 | torch.attr private "float" : !torch.float 6 | torch.method private "forward", @method 7 | } 8 | 9 | // CHECK: func.func private @forward() { 10 | func.func private @method(%arg0: !torch.nn.Module<"c">) { 11 | return 12 | } 13 | 14 | %c42 = torch.constant.float 42.0 15 | torch.nn_module { 16 | torch.slot "float", %c42 : !torch.float 17 | } : !torch.nn.Module<"c"> 18 | 19 | func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) { 20 | %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float 21 | return 22 | } 23 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/Passes.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_PASSES_H 11 | #define TORCHMLIR_CONVERSION_PASSES_H 12 | 13 | namespace mlir { 14 | namespace torch { 15 | 16 | /// Registers all torch-mlir conversion passes. 17 | void registerConversionPasses(); 18 | 19 | } // namespace torch 20 | } // namespace mlir 21 | 22 | #endif // TORCHMLIR_CONVERSION_PASSES_H 23 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py: -------------------------------------------------------------------------------- 1 | # When LTC is disabled in Torch-MLIR build, we will generate a dummy module to 2 | # ensure that no import errors occur. 3 | 4 | import sys 5 | import os 6 | 7 | if __name__ == "__main__": 8 | path = sys.argv[1] # dummy script path 9 | file_name = sys.argv[2] # dummy script 10 | 11 | contents = """ 12 | # This file was automatically generated due to LTC being disabled in build. 13 | 14 | class LazyTensorCoreTestConfig: 15 | def __init__(self): 16 | assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`" 17 | """ 18 | 19 | if not os.path.exists(path): 20 | os.makedirs(path) 21 | 22 | with open(os.path.join(path, file_name + ".py"), "w") as file: 23 | file.write(contents) 24 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H 11 | #define TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H 12 | 13 | #include "mlir/IR/Dialect.h" 14 | 15 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc" 16 | 17 | #endif // TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H 18 | -------------------------------------------------------------------------------- /projects/pt1/python/test/lazy_backend/run_test.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # RUN: true 7 | 8 | 9 | def run_test(*args, XPASS=False, XFAIL=False): 10 | def _run_test(test): 11 | test_name = test.__name__ 12 | try: 13 | test() 14 | print(("X" if XPASS else "") + f"PASS - {test_name}") 15 | except Exception as e: 16 | print(("X" if XFAIL else "") + f"FAIL - {test_name}") 17 | print("Errors: ", e) 18 | print(flush=True) 19 | 20 | if len(args): 21 | _run_test(args[0]) 22 | else: 23 | return _run_test 24 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | # CHECK: module attributes {torch.debug_module_name = "TestModule"} 16 | class TestModule(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | 21 | test_module = TestModule() 22 | recursivescriptmodule = torch.jit.script(test_module) 23 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 24 | mb.import_module(recursivescriptmodule._c) 25 | mb.module.operation.print() 26 | -------------------------------------------------------------------------------- /include/torch-mlir-c/Transforms.h: -------------------------------------------------------------------------------- 1 | //===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. 5 | // See https://llvm.org/LICENSE.txt for license information. 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 | // 8 | //===----------------------------------------------------------------------===// 9 | // 10 | // This header declares the registration and creation method for 11 | // transformation passes. 
12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | #ifndef TORCHMLIR_C_TRANSFORMS_H 16 | #define TORCHMLIR_C_TRANSFORMS_H 17 | 18 | #include "mlir-c/Support.h" 19 | 20 | #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" 21 | 22 | #endif // TORCHMLIR_C_TRANSFORMS_H 23 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/errors.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import enum 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | 11 | class Color(enum.Enum): 12 | RED = 1 13 | GREEN = 2 14 | 15 | 16 | # RUN: %PYTHON %s 17 | 18 | mb = ModuleBuilder() 19 | 20 | # To test errors, use a type that we don't support yet. 21 | try: 22 | 23 | @mb.import_function 24 | @torch.jit.script 25 | def import_class(x: Color): 26 | return x 27 | 28 | except Exception as e: 29 | # TODO: Once diagnostics are enabled, verify the actual error emitted. 30 | assert str(e) == "unsupported type in function schema: 'Enum<__torch__.Color>'" 31 | else: 32 | assert False, "Expected exception" 33 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.h: -------------------------------------------------------------------------------- 1 | //===- import_options_pybind.h ----------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H 11 | #define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H 12 | 13 | #include 14 | 15 | namespace torch_mlir { 16 | void initImportOptionsBindings(pybind11::module& m); 17 | } // namespace torch_mlir 18 | 19 | #endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H 20 | -------------------------------------------------------------------------------- /projects/pt1/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | message(STATUS "Building PyTorch1 compatibility project") 2 | 3 | if(TORCH_MLIR_ENABLE_LTC) 4 | set(ENV{TORCH_MLIR_ENABLE_LTC} 1) 5 | message(STATUS "LTC Backend build is enabled") 6 | else() 7 | set(ENV{TORCH_MLIR_ENABLE_LTC} 0) 8 | message(STATUS "LTC Backend build is disabled") 9 | endif() 10 | 11 | 12 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/python/torch_mlir/cmake/modules") 13 | 14 | ################################################################################ 15 | # Setup python. 16 | ################################################################################ 17 | 18 | if(MLIR_ENABLE_BINDINGS_PYTHON) 19 | add_dependencies(check-torch-mlir-all 20 | check-torch-mlir-pt1 21 | ) 22 | add_subdirectory(python) 23 | else() 24 | add_custom_target(TorchMLIRPythonModules) 25 | endif() 26 | 27 | add_subdirectory(test) 28 | -------------------------------------------------------------------------------- /projects/pt1/python/test/compile_api/make_fx.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import functorch 9 | import torch 10 | 11 | from torch_mlir import torchscript 12 | 13 | 14 | def simple(x): 15 | return x * x 16 | 17 | 18 | example_input = torch.randn( 19 | 1, 20 | ) 21 | graph = functorch.make_fx(simple)( 22 | torch.randn( 23 | 1, 24 | ) 25 | ) 26 | 27 | # Simplest case: One example argument. 28 | print(torchscript.compile(graph, example_input)) 29 | # CHECK-LABEL: @forward 30 | # CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> 31 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H 11 | #define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | 16 | namespace mlir { 17 | namespace torch { 18 | std::unique_ptr> createConvertTorchToSCFPass(); 19 | } 20 | } // namespace mlir 21 | 22 | #endif // TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H 23 | -------------------------------------------------------------------------------- /projects/pt1/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | llvm_canonicalize_cmake_booleans( 2 | MLIR_ENABLE_BINDINGS_PYTHON 3 | TORCH_MLIR_ENABLE_JIT_IR_IMPORTER 4 | ) 5 | 6 | configure_lit_site_cfg( 7 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in 8 | ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py 9 | MAIN_CONFIG 10 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py 11 | ) 12 | 13 | set(TORCH_MLIR_TEST_DEPENDS 14 | FileCheck count not 15 | TorchMLIRPythonModules 16 | torch-mlir-opt 17 | torch-mlir-capi-torch-test 18 | ) 19 | 20 | add_lit_testsuite(check-torch-mlir-pt1 "Running the torch-mlir PT1 regression tests" 21 | ${CMAKE_CURRENT_BINARY_DIR} 22 | DEPENDS ${TORCH_MLIR_TEST_DEPENDS} 23 | ) 24 | set_target_properties(check-torch-mlir-pt1 PROPERTIES FOLDER "Tests") 25 | 26 | add_lit_testsuites(TORCH_MLIR_PT1 ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS}) 27 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/list.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import torch 6 | from torch_mlir.jit_ir_importer import ModuleBuilder 7 | 8 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 9 | 10 | mb = ModuleBuilder() 11 | 12 | # CHECK-LABEL: func.func @__torch__.f( 13 | # CHECK-SAME: %[[T0:.*]]: !torch.tensor, 14 | # CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list { 15 | # CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list 16 | # CHECK: return %[[RET]] : !torch.list 17 | 18 | 19 | @mb.import_function 20 | @torch.jit.script 21 | def f(t0, t1): 22 | return [t0, t1] 23 | 24 | 25 | assert isinstance(f, torch.jit.ScriptFunction) 26 | mb.module.operation.print() 27 | print() 28 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToArith/TorchToArith.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H 11 | #define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | #include 16 | 17 | namespace mlir { 18 | namespace torch { 19 | std::unique_ptr> createConvertTorchToArithPass(); 20 | } 21 | } // namespace mlir 22 | 23 | #endif // TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H 24 | -------------------------------------------------------------------------------- /tools/torch-mlir-lsp-server/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # torch-mlir-lsp-server is always linked dynamically as we want to distribute the 2 | # binaries with the python packages for hacking/debugging. 3 | add_llvm_executable(torch-mlir-lsp-server torch-mlir-lsp-server.cpp) 4 | 5 | install(TARGETS torch-mlir-lsp-server 6 | EXPORT TorchMLIRTargets 7 | RUNTIME DESTINATION ${LLVM_TOOLS_INSTALL_DIR} 8 | COMPONENT torch-mlir-lsp-server) 9 | 10 | # get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) 11 | # get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) 12 | # get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) 13 | 14 | target_link_libraries(torch-mlir-lsp-server PRIVATE 15 | MLIRLspServerLib 16 | TorchMLIRInitAll 17 | 18 | # # TODO: Remove these in favor of interface deps. 
19 | # ${dialect_libs} 20 | # ${conversion_libs} 21 | # ${extension_libs} 22 | ) 23 | 24 | mlir_check_all_link_libraries(torch-mlir-lsp-server) 25 | -------------------------------------------------------------------------------- /tools/torch-mlir-lsp-server/torch-mlir-lsp-server.cpp: -------------------------------------------------------------------------------- 1 | //===- torch-mlir-lsp-server.cpp - MLIR Language Server ---------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "mlir/IR/Dialect.h" 11 | #include "mlir/IR/MLIRContext.h" 12 | #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" 13 | #include "torch-mlir/InitAll.h" 14 | 15 | using namespace mlir; 16 | 17 | int main(int argc, char **argv) { 18 | DialectRegistry registry; 19 | mlir::torch::registerAllDialects(registry); 20 | mlir::torch::registerOptionalInputDialects(registry); 21 | return failed(MlirLspServerMain(argc, argv, registry)); 22 | } 23 | -------------------------------------------------------------------------------- /lib/CAPI/Transforms.cpp: -------------------------------------------------------------------------------- 1 | //===- CAPIPasses.cpp - C API for Transformations Passes ------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // 7 | //===----------------------------------------------------------------------===// 8 | 9 | #include "mlir/CAPI/Pass.h" 10 | #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" 11 | 12 | // Must include the declarations as they carry important visibility attributes. 13 | #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" 14 | 15 | using namespace mlir; 16 | using namespace mlir::torch; 17 | using namespace mlir::torch::Torch; 18 | 19 | #ifdef __cplusplus 20 | extern "C" { 21 | #endif 22 | 23 | #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc" 24 | 25 | #ifdef __cplusplus 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-locations.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, x, y): 20 | # CHECK-LABEL: torch.nn_module 21 | # CHECK: loc("{{.*}}methods-locations.py":[[@LINE+1]] 22 | return x * y 23 | 24 | 25 | test_module = TestModule() 26 | recursivescriptmodule = torch.jit.script(test_module) 27 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
28 | mb.import_module(recursivescriptmodule._c) 29 | mb.module.operation.print(enable_debug_info=True) 30 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/elif.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import torch 6 | from torch_mlir.jit_ir_importer import ModuleBuilder 7 | 8 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 9 | 10 | mb = ModuleBuilder() 11 | 12 | 13 | # CHECK-LABEL: @__torch__.f 14 | @mb.import_function 15 | @torch.jit.script 16 | def f(b: bool, i: int): 17 | # elif is modeled as a nested if, so we only need to do cursory checking. 18 | # CHECK: torch.prim.If {{.*}} { 19 | # CHECK: } else { 20 | # CHECK: torch.prim.If {{.*}} { 21 | # CHECK: } else { 22 | # CHECK: } 23 | # CHECK: } 24 | 25 | if b: 26 | return i + i 27 | elif i: 28 | return i + i * i 29 | else: 30 | return i * i 31 | 32 | 33 | assert isinstance(f, torch.jit.ScriptFunction) 34 | mb.module.operation.print() 35 | print() 36 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | llvm_canonicalize_cmake_booleans( 2 | MLIR_ENABLE_BINDINGS_PYTHON 3 | TORCH_MLIR_ENABLE_REFBACKEND 4 | TORCH_MLIR_ENABLE_STABLEHLO 5 | ) 6 | 7 | configure_lit_site_cfg( 8 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in 9 | ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py 10 | MAIN_CONFIG 11 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py 12 | ) 13 | 14 | set(TORCH_MLIR_TEST_DEPENDS 15 | FileCheck count not 16 | TorchMLIRPythonModules 17 | torch-mlir-opt 18 | torch-mlir-capi-torch-test 19 | ) 20 | 21 | add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests" 22 | ${CMAKE_CURRENT_BINARY_DIR} 23 | DEPENDS 
${TORCH_MLIR_TEST_DEPENDS} 24 | ) 25 | set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests") 26 | 27 | add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS}) 28 | 29 | add_subdirectory(CAPI) 30 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H 11 | #define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | 16 | namespace mlir { 17 | namespace torch { 18 | std::unique_ptr> createConvertTorchToTMTensorPass(); 19 | } 20 | } // namespace mlir 21 | 22 | #endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H 23 | -------------------------------------------------------------------------------- /projects/pt1/python/test/compile_api/backend_legal_ops.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | 10 | from torch_mlir import torchscript 11 | 12 | 13 | class AddmmModule(torch.nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, y, z): 18 | return torch.ops.aten.addmm(x, y, z) 19 | 20 | 21 | example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)] 22 | 23 | print( 24 | torchscript.compile( 25 | AddmmModule(), 26 | example_args, 27 | output_type="torch", 28 | backend_legal_ops=["aten.addmm"], 29 | ) 30 | ) 31 | # CHECK-LABEL: @forward 32 | # CHECK: torch.aten.addmm 33 | -------------------------------------------------------------------------------- /projects/pt1/test/python/compile.py: -------------------------------------------------------------------------------- 1 | # RUN: %PYTHON %s 2>&1 | FileCheck %s 2 | 3 | import gc 4 | import sys 5 | import torch 6 | from torch_mlir import torchscript 7 | 8 | 9 | def run_test(f): 10 | print("TEST:", f.__name__, file=sys.stderr) 11 | f() 12 | gc.collect() 13 | 14 | 15 | class TinyModel(torch.nn.Module): 16 | def __init__(self): 17 | super(TinyModel, self).__init__() 18 | 19 | self.linear = torch.nn.Linear(20, 30) 20 | 21 | def forward(self, x): 22 | x = self.linear(x) 23 | return x 24 | 25 | 26 | # CHECK-LABEL: TEST: test_enable_ir_printing 27 | @run_test 28 | def test_enable_ir_printing(): 29 | torchscript.compile( 30 | TinyModel(), 31 | torch.ones(1, 3, 20, 20), 32 | output_type="linalg-on-tensors", 33 | enable_ir_printing=True, 34 | ) 35 | 36 | 37 | # CHECK: // -----// IR Dump After Inliner (inline) 38 | # CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { 39 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H 11 | #define TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | #include 16 | 17 | namespace mlir { 18 | namespace torch { 19 | std::unique_ptr> createConvertTorchToTensorPass(); 20 | } // namespace torch 21 | } // namespace mlir 22 | 23 | #endif // TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H 24 | -------------------------------------------------------------------------------- /include/torch-mlir/RefBackend/Passes.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_REFBACKEND_PASSES_H 11 | #define TORCHMLIR_REFBACKEND_PASSES_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | #include "mlir/Pass/PassManager.h" 16 | 17 | namespace mlir { 18 | namespace torch { 19 | namespace RefBackend { 20 | 21 | /// Registers all RefBackend passes. 
22 | void registerRefBackendPasses(); 23 | 24 | } // namespace RefBackend 25 | } // namespace torch 26 | } // namespace mlir 27 | 28 | #endif // TORCHMLIR_REFBACKEND_PASSES_H 29 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_ 11 | #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_ 12 | 13 | #include "mlir/IR/Dialect.h" 14 | #include "mlir/IR/OpDefinition.h" 15 | 16 | // clang-format off: must be included after all LLVM/MLIR headers 17 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc" // IWYU pragma: keep 18 | // clang-format on 19 | 20 | #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_ 21 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/jit_ir_importer/__init__.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # This is a trampoline module which loads the _torch_mlir native module 7 | # and binds names locally. 
It exists to allow for customization of behavior 8 | # prior to loading shared objects. 9 | 10 | import torch 11 | 12 | # Our native extension is not self-contained. It references libraries which 13 | # must come in via the above first. 14 | from .._mlir_libs._jit_ir_importer import * 15 | 16 | # Ensure that the torch dialect has been loaded as it registers passes 17 | # and other things the jit_ir_importer needs. 18 | from ..dialects import torch as _unused_torch_dialect 19 | 20 | __all__ = [ 21 | "debug_trace_to_stderr", 22 | "ModuleBuilder", 23 | "ClassAnnotator", 24 | ] 25 | -------------------------------------------------------------------------------- /include/torch-mlir-c/Dialects.h: -------------------------------------------------------------------------------- 1 | /*===-- torch-mlir-c/Dialects.h - Dialect functions --------------*- C -*-===*\ 2 | |* *| 3 | |* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| 4 | |* Exceptions. *| 5 | |* See https://llvm.org/LICENSE.txt for license information. *| 6 | |* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| 7 | |* *| 8 | \*===----------------------------------------------------------------------===*/ 9 | 10 | #ifndef TORCHMLIR_C_DIALECTS_H 11 | #define TORCHMLIR_C_DIALECTS_H 12 | 13 | #include "mlir-c/IR.h" 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Torch, torch); 20 | 21 | #ifdef __cplusplus 22 | } 23 | #endif 24 | 25 | #endif // TORCHMLIR_C_DIALECTS_H 26 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H 11 | #define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 | #include "mlir/Pass/Pass.h" 16 | #include 17 | 18 | namespace mlir { 19 | namespace torch { 20 | std::unique_ptr> createConvertTorchToLinalgPass(); 21 | } 22 | } // namespace mlir 23 | 24 | #endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H 25 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToStablehlo/Utils.h: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "mlir/Transforms/DialectConversion.h" 11 | #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" 12 | 13 | namespace mlir { 14 | namespace torch { 15 | namespace torch_to_stablehlo { 16 | 17 | // Convert a scalar type to the corresponding builtin type in the 18 | // stablehlo backend. 
19 | FailureOr 20 | getBackendTypeForScalarType(MLIRContext *context, 21 | torch_upstream::ScalarType dtypeInt); 22 | 23 | } // namespace torch_to_stablehlo 24 | } // namespace torch 25 | } // namespace mlir 26 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | # CHECK: %[[T:.*]] = torch.tensor.literal 19 | # CHECK: torch.nn_module { 20 | # CHECK: torch.slot "t1", %[[T]] 21 | # CHECK: torch.slot "t2", %[[T]] 22 | self.t1 = self.t2 = torch.tensor([10.0, 20.0]) 23 | 24 | 25 | test_module = TestModule() 26 | recursivescriptmodule = torch.jit.script(test_module) 27 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 28 | mb.import_module(recursivescriptmodule._c) 29 | mb.module.operation.print() 30 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: not %PYTHON %s 2>&1 | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | # CHECK: Unhandled tensor that shares storage with another tensor. 19 | # CHECK-NEXT: Found at path '.t2' from root object '__torch__.TestModule' 20 | self.t1 = torch.tensor([10.0, 20.0]) 21 | self.t2 = self.t1[0] 22 | 23 | 24 | test_module = TestModule() 25 | recursivescriptmodule = torch.jit.script(test_module) 26 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 27 | mb.import_module(recursivescriptmodule._c) 28 | mb.module.operation.print() 29 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/init_python_bindings.cpp: -------------------------------------------------------------------------------- 1 | //===- python_bindings.cpp --------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | // This is the top-level entry point for the JIT IR -> MLIR importer. 
11 | 12 | #include 13 | 14 | #include "class_annotator_pybind.h" 15 | #include "get_registered_ops.h" 16 | #include "import_options_pybind.h" 17 | #include "module_builder.h" 18 | 19 | using namespace torch_mlir; 20 | 21 | PYBIND11_MODULE(_jit_ir_importer, m) { 22 | ModuleBuilder::bind(m); 23 | initClassAnnotatorBindings(m); 24 | initGetRegisteredOpsBindings(m); 25 | initImportOptionsBindings(m); 26 | } 27 | -------------------------------------------------------------------------------- /projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_mlir import torchscript 3 | 4 | from transformers import BertForMaskedLM 5 | 6 | 7 | # Wrap the bert model to avoid multiple returns problem 8 | class BertTinyWrapper(torch.nn.Module): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.bert = BertForMaskedLM.from_pretrained( 12 | "prajjwal1/bert-tiny", return_dict=False 13 | ) 14 | 15 | def forward(self, data): 16 | return self.bert(data)[0] 17 | 18 | 19 | model = BertTinyWrapper() 20 | model.eval() 21 | data = torch.randint(30522, (2, 128)) 22 | out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" 23 | 24 | module = torchscript.compile( 25 | model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True 26 | ) 27 | with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: 28 | outf.write(module.operation.get_asm()) 29 | 30 | print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}") 31 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 
with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H 11 | #define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H 12 | 13 | #include "mlir/IR/BuiltinOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | 16 | namespace mlir { 17 | namespace torch { 18 | std::unique_ptr> 19 | createConvertTorchConversionToMLProgramPass(); 20 | } 21 | } // namespace mlir 22 | 23 | #endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H 24 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.h: -------------------------------------------------------------------------------- 1 | //===- backend_impl.h -----------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #pragma once 11 | 12 | #include 13 | 14 | namespace at { 15 | // This function is defined in the codegenerated RegisterLazy.cpp file. 
16 | TORCH_API void RegisterTorchMlirLazyNativeFunctions(); 17 | } // namespace at 18 | 19 | namespace torch { 20 | namespace lazy { 21 | 22 | torch::lazy::BackendImplInterface* GetReferenceLazyBackendImpl(); 23 | 24 | void InitReferenceLazyBackend(); 25 | 26 | ComputationPtr& GetLatestComputation(); 27 | 28 | } // namespace lazy 29 | } // namespace torch 30 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "torch/csrc/lazy/backend/backend_device.h" 4 | #include "torch/csrc/lazy/core/tensor.h" 5 | 6 | #include "../ops/device_data.h" 7 | 8 | namespace torch { 9 | namespace lazy { 10 | 11 | TORCH_API bool is_detach_copy(const torch::lazy::Node *); 12 | TORCH_API bool is_detach_copy(const torch::lazy::Value &); 13 | 14 | TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *); 15 | TORCH_API const torch::lazy::Node * 16 | extract_non_detach_copy_node(const torch::lazy::Node *); 17 | 18 | TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *); 19 | TORCH_API const torch::lazy::DeviceData * 20 | device_data_cast(const torch::lazy::Node *); 21 | TORCH_API torch::lazy::DeviceData * 22 | device_data_cast(const torch::lazy::Value &value); 23 | TORCH_API torch::lazy::DeviceData *device_data_cast( 24 | const at::Tensor &tensor, 25 | std::optional device = c10::nullopt); 26 | 27 | } // namespace lazy 28 | } // namespace torch 29 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/get_registered_ops.h: -------------------------------------------------------------------------------- 1 | //===- get_registered_ops.h -------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | // 10 | // Listing of the JIT operator registry, for use in generating the `torch` 11 | // dialect. 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | #ifndef TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H 16 | #define TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H 17 | 18 | #include 19 | 20 | namespace torch_mlir { 21 | 22 | void initGetRegisteredOpsBindings(py::module& m); 23 | 24 | } // namespace torch_mlir 25 | 26 | #endif // TORCHMLIRJITIRIMPORTER_CSRC_GETREGISTEREDOPS_H 27 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/tensor.h: -------------------------------------------------------------------------------- 1 | //===- tensor.h -----------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #pragma once 11 | 12 | #include 13 | 14 | namespace torch { 15 | namespace lazy { 16 | 17 | // Ops like torch.ones/zeros etc. which produce new tensor as an output 18 | // should have explicit tensor functinoalization. Otherwise we can get 19 | // unfanctionalized primitives or in the worst case if we apply inplace 20 | // operations to unfunctionalized tensor it won't be captured in LTC graph. 
21 | TORCH_API at::Tensor 22 | CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor); 23 | 24 | } // namespace lazy 25 | } // namespace torch 26 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace sys_util { 7 | 8 | template 9 | static T GetEnv(const std::string &name, const T &default_value = T(0)) { 10 | const char *env = std::getenv(name.c_str()); 11 | if (!env) { 12 | return default_value; 13 | } 14 | return T(std::atoi(env)); 15 | } 16 | 17 | [[maybe_unused]] static std::string 18 | GetEnvString(const std::string &name, const std::string &default_value) { 19 | const char *env = std::getenv(name.c_str()); 20 | if (!env) { 21 | return default_value; 22 | } 23 | return std::string(env); 24 | } 25 | 26 | [[maybe_unused]] static bool GetEnvBool(const char *name, bool defval) { 27 | const char *env = std::getenv(name); 28 | if (env == nullptr) { 29 | return defval; 30 | } 31 | if (std::strcmp(env, "true") == 0) { 32 | return true; 33 | } 34 | if (std::strcmp(env, "false") == 0) { 35 | return false; 36 | } 37 | return std::atoi(env) != 0; 38 | } 39 | 40 | } // namespace sys_util 41 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H 11 | #define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/IR/BuiltinOps.h" 15 | #include "mlir/Pass/Pass.h" 16 | #include 17 | 18 | namespace mlir::torch::onnx_c { 19 | 20 | std::unique_ptr> createTorchOnnxToTorchPass(); 21 | 22 | /// Registers all torch-mlir conversion passes. 23 | void registerTorchOnnxToTorchPasses(); 24 | 25 | } // namespace mlir::torch::onnx_c 26 | 27 | #endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H 28 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h" 11 | 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 | #include "mlir/Dialect/Utils/StaticValueUtils.h" 17 | #include "llvm/Support/Debug.h" 18 | 19 | #define DEBUG_TYPE "torch-mlir-tiled-op-interface" 20 | 21 | using namespace mlir; 22 | using namespace mlir::torch::TMTensor; 23 | 24 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc" 25 | -------------------------------------------------------------------------------- /lib/Dialect/TorchConversion/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LinkedLibs 2 | MLIRFuncTransforms 3 | MLIRControlFlowTransforms 4 | MLIRIR 5 | MLIRLinalgTransforms 6 | MLIRMemRefTransforms 7 | MLIRPass 8 | MLIRTosaTransforms 9 | MLIRVectorTransforms 10 | TorchMLIRTorchConversionDialect 11 | TorchMLIRTorchDialect 12 | TorchMLIRTorchPasses 13 | TorchMLIRConversionPasses 14 | ) 15 | 16 | if(TORCH_MLIR_ENABLE_STABLEHLO) 17 | list(APPEND LinkedLibs 18 | StablehloOps 19 | StablehloPasses 20 | ) 21 | endif() 22 | 23 | add_mlir_library(TorchMLIRTorchConversionPasses 24 | BackendTypeConversion.cpp 25 | BackendTypeConversionPasses.cpp 26 | Passes.cpp 27 | ConvertCustomQuantOp.cpp 28 | UnpackQuantTensor.cpp 29 | VerifyLinalgOnTensorsBackendContract.cpp 30 | VerifyTosaBackendContract.cpp 31 | VerifyStablehloBackendContract.cpp 32 | 33 | ADDITIONAL_HEADER_DIRS 34 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms 35 | 36 | DEPENDS 37 | TorchMLIRTorchConversionPassIncGen 38 | 39 | LINK_LIBS PUBLIC 40 | ${LinkedLibs} 41 | ) 42 | -------------------------------------------------------------------------------- 
/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp: -------------------------------------------------------------------------------- 1 | //===- generic.cpp --------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | // This file is adapted from pytorch/pytorch 10 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.cpp 11 | //===----------------------------------------------------------------------===// 12 | 13 | #include "generic.h" 14 | 15 | namespace torch { 16 | namespace lazy { 17 | 18 | Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs, 19 | hash_t hash_seed) 20 | : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed), 21 | hash_seed_(hash_seed) {} 22 | 23 | } // namespace lazy 24 | } // namespace torch 25 | -------------------------------------------------------------------------------- /test/python/onnx_importer/constants.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | 5 | # RUN: %PYTHON %s %t.onnx 6 | # RUN: %PYTHON -m torch_mlir.tools.import_onnx %t.onnx > %t.mlir 7 | # RUN: FileCheck %s < %t.mlir 8 | 9 | import onnx 10 | from onnx.helper import make_graph, make_tensor, make_tensor_value_info 11 | 12 | graph = make_graph( 13 | name="graph", 14 | inputs=[], 15 | nodes=[], 16 | outputs=[], 17 | initializer=[ 18 | # CHECK{LITERAL}: torch.operator "onnx.Constant"() {torch.onnx.value = dense<[[true, false], [false, true]]> : tensor<2x2xi1>} : () -> !torch.vtensor<[2,2],i1> 19 | make_tensor( 20 | "bool_tensor", 21 | onnx.TensorProto.BOOL, 22 | dims=[2, 2], 23 | vals=[True, False, False, True], 24 | ) 25 | ], 26 | ) 27 | model = onnx.helper.make_model(graph) 28 | 29 | import sys 30 | 31 | out_file_path = sys.argv[1] 32 | onnx.save(model, out_file_path) 33 | -------------------------------------------------------------------------------- /projects/pt1/examples/torchscript_resnet18_all_output_types.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | import torch 7 | import torchvision 8 | 9 | from torch_mlir import torchscript 10 | 11 | resnet18 = torchvision.models.resnet18(pretrained=True) 12 | resnet18.eval() 13 | 14 | module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") 15 | print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10)) 16 | module = torchscript.compile( 17 | resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors" 18 | ) 19 | print( 20 | "LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10) 21 | ) 22 | # TODO: Debug why this is so slow. 
23 | module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") 24 | print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10)) 25 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/tensor.cpp: -------------------------------------------------------------------------------- 1 | //===- tensor.cpp ---------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include 11 | 12 | #include "tensor.h" 13 | 14 | namespace torch { 15 | namespace lazy { 16 | 17 | at::Tensor 18 | CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor) { 19 | at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); 20 | if (!c10::impl::tls_is_dispatch_key_excluded( 21 | c10::DispatchKey::Functionalize) && 22 | !at::functionalization::impl::isFunctionalTensor(tensor)) { 23 | return at::functionalization::impl::to_functional_tensor(tensor); 24 | } 25 | return tensor; 26 | } 27 | 28 | } // namespace lazy 29 | } // namespace torch 30 | -------------------------------------------------------------------------------- /test/Dialect/TMTensor/canonicalize.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -canonicalize -split-input-file %s | FileCheck %s 2 | 3 | // CHECK-LABEL: func.func @tensor.cast( 4 | func.func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> { 5 | %init = tensor.empty() : tensor<128xi32> 6 | %c0 = tensor.empty() : tensor 7 | 8 | %casted_arg0 = tensor.cast %arg0 : tensor<128xi32> to tensor 9 | %casted_init = 
tensor.cast %init : tensor<128xi32> to tensor 10 | // CHECK: tm_tensor.scan 11 | // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>) 12 | // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor) 13 | %0, %1 = tm_tensor.scan dimension(0) inclusive(true) 14 | ins(%casted_arg0 : tensor) 15 | outs(%casted_init, %c0: tensor, tensor) { 16 | ^bb0(%barg0 : i32, %barg1 : i32, %barg2 : i32): 17 | %sum = arith.addi %barg0, %barg1 : i32 18 | tm_tensor.yield %sum : i32 19 | } -> tensor, tensor 20 | 21 | %2 = tensor.cast %0: tensor to tensor<128xi32> 22 | 23 | return %2: tensor<128xi32> 24 | } 25 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/import_options_pybind.cpp: -------------------------------------------------------------------------------- 1 | //===- import_options_pybind.cpp ------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "import_options_pybind.h" 11 | #include "jit_ir_importer/import_options.h" 12 | 13 | namespace py = pybind11; 14 | 15 | using namespace torch_mlir; 16 | 17 | void torch_mlir::initImportOptionsBindings(py::module& m) { 18 | py::class_(m, "ImportOptions") 19 | .def(py::init<>()) 20 | .def_readwrite( 21 | "assumeTensorsHaveValueSemantics", 22 | &ImportOptions::assumeTensorsHaveValueSemantics) 23 | .def_readwrite( 24 | "ignoreExistingTensorShapesAndDtypes", 25 | &ImportOptions::ignoreExistingTensorShapesAndDtypes); 26 | } 27 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | from torch_mlir.jit_ir_importer import ModuleBuilder 6 | 7 | from utils import create_script_function 8 | 9 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 10 | 11 | mb = ModuleBuilder() 12 | 13 | # CHECK-LABEL: func.func @__torch__.refined_block_arg( 14 | # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor { 15 | # CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32> 16 | # CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor 17 | # CHECK: return %[[RESULT]] : !torch.tensor 18 | mb.import_function( 19 | create_script_function( 20 | "__torch__.refined_block_arg", 21 | """ 22 | graph(%0 : Float(1, 384)): 23 | return (%0) 24 | """, 25 | ) 26 | ) 27 | 28 | mb.module.operation.print() 29 | print() 30 | -------------------------------------------------------------------------------- /lib/CAPI/Registration.cpp: -------------------------------------------------------------------------------- 1 | //===- Registration.cpp - C Interface for MLIR Registration ---------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir-c/Registration.h" 11 | 12 | #include "mlir/CAPI/IR.h" 13 | #include "mlir/Conversion/Passes.h" 14 | #include "mlir/Dialect/Linalg/Passes.h" 15 | #include "mlir/Transforms/Passes.h" 16 | #include "torch-mlir/InitAll.h" 17 | 18 | void torchMlirRegisterAllDialects(MlirContext context) { 19 | mlir::DialectRegistry registry; 20 | mlir::torch::registerAllDialects(registry); 21 | unwrap(context)->appendDialectRegistry(registry); 22 | // TODO: Don't eagerly load once D88162 is in and clients can do this. 23 | unwrap(context)->loadAllAvailableDialects(); 24 | } 25 | 26 | void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); } 27 | -------------------------------------------------------------------------------- /projects/pt1/python/test/compile_api/output_type_spec.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | 10 | from torch_mlir import torchscript 11 | 12 | 13 | class TanhModule(torch.nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x): 18 | return torch.ops.aten.tanh(x) 19 | 20 | 21 | tanh_example_input = torch.ones(2, 3) 22 | 23 | print( 24 | torchscript.compile( 25 | TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH 26 | ) 27 | ) 28 | # CHECK-LABEL: @forward 29 | # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> 30 | print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch")) 31 | # CHECK-LABEL: @forward 32 | # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> 33 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/IR/TorchDialect.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H 11 | #define TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H 12 | 13 | #include "mlir/IR/Dialect.h" 14 | 15 | #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc" 16 | 17 | namespace mlir { 18 | namespace torch { 19 | namespace Torch { 20 | 21 | /// Parse a type registered to this dialect. 22 | Type parseTorchDialectType(AsmParser &parser); 23 | 24 | /// Print a type registered to this dialect. 
25 | void printTorchDialectType(Type type, AsmPrinter &printer); 26 | 27 | } // namespace Torch 28 | } // namespace torch 29 | } // namespace mlir 30 | 31 | #endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H 32 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | #ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H 10 | #define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H 11 | 12 | #include "mlir/IR/Attributes.h" 13 | #include "mlir/IR/Value.h" 14 | #include "mlir/Support/LogicalResult.h" 15 | 16 | namespace mlir { 17 | namespace torch { 18 | namespace Torch { 19 | 20 | // Create a new SparseTensorEncodingAttr based on the provided `attr`, but with 21 | // a new dense level inserted at `dim`. 22 | FailureOr getSparsityWithDenseLTAtDim(Attribute attr, Value dim); 23 | 24 | } // namespace Torch 25 | } // namespace torch 26 | } // namespace mlir 27 | 28 | #endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H 29 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/strings.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.s = "foo" 19 | 20 | 21 | # CHECK: torch.class_type @[[CLASSTYPE:.*]] { 22 | # TODO: Don't lose element type. 23 | # CHECK: torch.attr "s" : !torch.str 24 | # CHECK: } 25 | # CHECK: %[[STR:.*]] = torch.constant.str "foo" 26 | # CHECK: torch.nn_module { 27 | # CHECK: torch.slot "s", %[[STR]] : !torch.str 28 | # CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> 29 | 30 | 31 | test_module = TestModule() 32 | recursivescriptmodule = torch.jit.script(test_module) 33 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 34 | mb.import_module(recursivescriptmodule._c) 35 | mb.module.operation.print() 36 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterface.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 11 | #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 12 | 13 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 | #include "mlir/IR/Builders.h" 15 | #include "mlir/IR/BuiltinTypes.h" 16 | #include "mlir/IR/Operation.h" 17 | #include "mlir/Interfaces/ViewLikeInterface.h" 18 | #include "mlir/Support/LLVM.h" 19 | 20 | /// Include the ODS generated interface header files. 21 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc" 22 | 23 | #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 24 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td: -------------------------------------------------------------------------------- 1 | //===-------------------------------------------------------*- tablegen -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_DIALECT_TMTENSOR_PASSES 11 | #define TORCH_MLIR_DIALECT_TMTENSOR_PASSES 12 | 13 | include "mlir/Pass/PassBase.td" 14 | 15 | def TMTensorToLoops : 16 | Pass<"tm-tensor-to-loops", "func::FuncOp"> { 17 | let summary = "Convert TMTensor ops to loops and Linalg ops."; 18 | let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()"; 19 | } 20 | 21 | def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> { 22 | let summary = "Bufferize the TMTensor dialect"; 23 | let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()"; 24 | } 25 | 26 | #endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES 27 | -------------------------------------------------------------------------------- /projects/jit_ir_common/csrc/jit_ir_importer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Static library with core functionality. 2 | # We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build) 3 | # For details, see: https://github.com/llvm/torch-mlir/runs/7919012376 4 | add_library(TorchMLIRJITIRImporter STATIC 5 | class_annotator.cpp 6 | function_importer.cpp 7 | node_importer.cpp 8 | ivalue_importer.cpp 9 | torch_to_mlir_utils.cpp 10 | ) 11 | message(STATUS "Linking TorchMLIRJITImporter with ${TORCH_LIBRARIES}") 12 | target_link_libraries(TorchMLIRJITIRImporter 13 | TorchMLIRAggregateCAPI 14 | ${TORCH_LIBRARIES} 15 | ) 16 | # Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...") 17 | target_include_directories(TorchMLIRJITIRImporter PUBLIC 18 | ${CMAKE_CURRENT_SOURCE_DIR}/.. 
19 | ) 20 | set_target_properties(TorchMLIRJITIRImporter PROPERTIES 21 | LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" 22 | OUTPUT_NAME lib_jit_ir_importer 23 | PREFIX "" 24 | SUFFIX ".a" 25 | CXX_VISIBILITY_PRESET "default" 26 | COMPILE_FLAGS "${TORCH_CXXFLAGS}" 27 | ) 28 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/class_annotator_pybind.h: -------------------------------------------------------------------------------- 1 | //===- class_annotator_pybind.h ---------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | // Includes Torch-specific pybind and associated helpers. 10 | // Depend on this for access to all Torch types (versus depending on pybind11 11 | // directly). 
12 | //===----------------------------------------------------------------------===// 13 | 14 | #ifndef TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H 15 | #define TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H 16 | 17 | #include 18 | 19 | namespace py = pybind11; 20 | namespace torch_mlir { 21 | void initClassAnnotatorBindings(py::module& m); 22 | } // namespace torch_mlir 23 | 24 | #endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H 25 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_ 11 | #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_ 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | 16 | namespace mlir { 17 | namespace torch { 18 | namespace TMTensor { 19 | 20 | std::unique_ptr> createTMTensorToLoopsPass(); 21 | std::unique_ptr> createTMTensorBufferizePass(); 22 | 23 | void registerPasses(); 24 | 25 | } // namespace TMTensor 26 | } // namespace torch 27 | } // namespace mlir 28 | 29 | #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_ 30 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/Transforms/Passes.cpp: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" 11 | 12 | #include "mlir/Pass/Pass.h" 13 | #include "mlir/Pass/PassRegistry.h" 14 | #include "mlir/Transforms/Passes.h" 15 | 16 | using namespace mlir; 17 | 18 | namespace mlir { 19 | namespace torch { 20 | namespace TMTensor { 21 | 22 | namespace detail { 23 | #define GEN_PASS_REGISTRATION 24 | #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: export 25 | } // namespace detail 26 | 27 | } // namespace TMTensor 28 | } // namespace torch 29 | } // namespace mlir 30 | 31 | void torch::TMTensor::registerPasses() { 32 | torch::TMTensor::detail::registerPasses(); 33 | } 34 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/utils.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | from torch_mlir.compiler_utils import TensorPlaceholder 7 | from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME 8 | 9 | 10 | def convert_annotations_to_placeholders(forward_method): 11 | """Converts the annotations on a forward method into tensor placeholders. 12 | 13 | These placeholders are suitable for being passed to `torchscript.compile`. 14 | """ 15 | annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) 16 | placeholders = [] 17 | # Skip the "self" annotation. 18 | for annotation in annotations[1:]: 19 | if not annotation[2]: 20 | raise ValueError( 21 | "Can only compile inputs annotated as having value semantics." 
22 | ) 23 | placeholders.append(TensorPlaceholder(annotation[0], annotation[1])) 24 | return placeholders 25 | -------------------------------------------------------------------------------- /projects/pt1/test/lit.site.cfg.py.in: -------------------------------------------------------------------------------- 1 | @LIT_SITE_CFG_IN_HEADER@ 2 | 3 | import sys 4 | 5 | config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ 6 | config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@" 7 | config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@" 8 | config.host_os = "@HOST_OS@" 9 | config.host_cxx = "@HOST_CXX@" 10 | config.host_arch = "@HOST_ARCH@" 11 | config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" 12 | config.llvm_src_root = "@LLVM_SOURCE_DIR@" 13 | config.llvm_obj_root = "@LLVM_BINARY_DIR@" 14 | config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" 15 | config.llvm_build_dir = "@CMAKE_BINARY_DIR@" 16 | config.llvm_lib_dir = "@LLVM_LIBS_DIR@" 17 | config.llvm_shlib_dir = "@SHLIBDIR@" 18 | config.llvm_shlib_ext = "@SHLIBEXT@" 19 | config.llvm_exe_ext = "@EXEEXT@" 20 | config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" 21 | config.python_executable = "@Python3_EXECUTABLE@" 22 | config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ 23 | 24 | import lit.llvm 25 | lit.llvm.initialize(lit_config, config) 26 | 27 | # Let the main config do the real work. 
28 | lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/projects/pt1/test/lit.cfg.py") 29 | -------------------------------------------------------------------------------- /.github/workflows/merge-rollpytorch.yml: -------------------------------------------------------------------------------- 1 | # yamllint disable rule:line-length 2 | name: RollPyTorch Merge 3 | 4 | on: 5 | workflow_run: 6 | workflows: [Build and Test] 7 | types: [completed] 8 | branches: [rollpytorch] 9 | 10 | jobs: 11 | merge-pr: 12 | runs-on: ubuntu-22.04 13 | if: | 14 | github.repository == 'llvm/torch-mlir' && 15 | github.event.workflow_run.actor.login == 'stellaraccident' && 16 | github.event.workflow_run.conclusion == 'success' 17 | 18 | steps: 19 | # Fetch the repo first so that the gh command knows where to look for the PR 20 | - name: Fetch Repo 21 | uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 22 | with: 23 | token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} 24 | 25 | - name: Merge RollPyTorch PR 26 | run: | 27 | for pr_id in ${{ join(github.event.workflow_run.pull_requests.*.number, ' ') }} 28 | do 29 | echo "Merging PR: $pr_id" 30 | gh pr merge $pr_id --delete-branch --squash 31 | done 32 | shell: bash 33 | env: 34 | GH_TOKEN: ${{ secrets.ROLLPYTORCH_TOKEN1 }} 35 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td: -------------------------------------------------------------------------------- 1 | //===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES 11 | #define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES 12 | 13 | include "mlir/Pass/PassBase.td" 14 | 15 | def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> { 16 | let summary = "Converts ONNX custom ops in the torch dialect to native torch ops"; 17 | let description = [{ 18 | Converts equivalent ONNX custom ops to built-in equivalents. 19 | 20 | See the README for a detailed description of how this operates. 21 | }]; 22 | 23 | let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()"; 24 | } 25 | 26 | #endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES 27 | -------------------------------------------------------------------------------- /projects/pt1/python/test/compile_api/already_traced.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | from torch_mlir import torchscript 10 | 11 | 12 | class BasicModule(torch.nn.Module): 13 | def forward(self, x): 14 | return torch.ops.aten.sin(x) 15 | 16 | 17 | example_arg = torch.ones(2, 3) 18 | example_args = torchscript.ExampleArgs.get(example_arg) 19 | 20 | traced = torch.jit.trace(BasicModule(), example_arg) 21 | print(torchscript.compile(traced, example_args)) 22 | # CHECK: module 23 | # CHECK-DAG: func.func @forward 24 | 25 | traced = torch.jit.trace(BasicModule(), example_arg) 26 | try: 27 | # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. 
28 | torchscript.compile( 29 | traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg) 30 | ) 31 | except Exception as e: 32 | print(e) 33 | -------------------------------------------------------------------------------- /test/lit.site.cfg.py.in: -------------------------------------------------------------------------------- 1 | @LIT_SITE_CFG_IN_HEADER@ 2 | 3 | import sys 4 | 5 | config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ 6 | config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@" 7 | config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@" 8 | config.torch_mlir_enable_refbackend = @TORCH_MLIR_ENABLE_REFBACKEND@ 9 | config.host_os = "@HOST_OS@" 10 | config.host_cxx = "@HOST_CXX@" 11 | config.host_arch = "@HOST_ARCH@" 12 | config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" 13 | config.llvm_src_root = "@LLVM_SOURCE_DIR@" 14 | config.llvm_obj_root = "@LLVM_BINARY_DIR@" 15 | config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" 16 | config.llvm_build_dir = "@CMAKE_BINARY_DIR@" 17 | config.llvm_lib_dir = "@LLVM_LIBS_DIR@" 18 | config.llvm_shlib_dir = "@SHLIBDIR@" 19 | config.llvm_shlib_ext = "@SHLIBEXT@" 20 | config.llvm_exe_ext = "@EXEEXT@" 21 | config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" 22 | config.python_executable = "@Python3_EXECUTABLE@" 23 | config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@ 24 | 25 | import lit.llvm 26 | lit.llvm.initialize(lit_config, config) 27 | 28 | # Let the main config do the real work. 29 | lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/test/lit.cfg.py") 30 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 
4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H 11 | #define TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H 12 | 13 | #include "mlir/IR/BuiltinTypes.h" 14 | #include "mlir/IR/OpDefinition.h" 15 | #include "mlir/IR/OpImplementation.h" 16 | #include "mlir/Interfaces/CastInterfaces.h" 17 | #include "mlir/Interfaces/InferTypeOpInterface.h" 18 | #include "mlir/Interfaces/SideEffectInterfaces.h" 19 | #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" 20 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" 21 | 22 | #define GET_OP_CLASSES 23 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h.inc" 24 | 25 | #endif // TORCHMLIR_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H 26 | -------------------------------------------------------------------------------- /include/torch-mlir/InitAll.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_INITALL_H 11 | #define TORCH_MLIR_INITALL_H 12 | 13 | #include "mlir/IR/Dialect.h" 14 | 15 | namespace mlir { 16 | namespace torch { 17 | 18 | // Registers all dialects that this project produces and any dependencies. 
19 | void registerAllDialects(mlir::DialectRegistry ®istry); 20 | 21 | // Registers all necessary dialect extensions for this project 22 | void registerAllExtensions(mlir::DialectRegistry ®istry); 23 | 24 | // Registers dialects that may be needed to parse torch-mlir inputs and 25 | // test cases. 26 | void registerOptionalInputDialects(mlir::DialectRegistry ®istry); 27 | 28 | void registerAllPasses(); 29 | 30 | } // namespace torch 31 | } // namespace mlir 32 | 33 | #endif // TORCH_MLIR_INITALL_H 34 | -------------------------------------------------------------------------------- /lib/Dialect/TMTensor/IR/TMTensorDialect.cpp: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" 11 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" 12 | 13 | #include "mlir/IR/Attributes.h" 14 | #include "mlir/IR/DialectImplementation.h" 15 | #include "mlir/IR/OpDefinition.h" 16 | #include "mlir/IR/OpImplementation.h" 17 | #include "llvm/ADT/SmallVector.h" 18 | #include "llvm/Support/SourceMgr.h" 19 | 20 | using namespace mlir; 21 | using namespace mlir::torch::TMTensor; 22 | 23 | void TMTensorDialect::initialize() { 24 | #define GET_OP_LIST 25 | addOperations< 26 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc" 27 | >(); 28 | } 29 | 30 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc" 31 | -------------------------------------------------------------------------------- /lib/Dialect/Torch/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(TorchMLIRTorchPasses 2 | AdjustCallingConventions.cpp 3 | DecomposeComplexOps.cpp 4 | DropAbstractInterpCalculations.cpp 5 | EraseModuleInitializer.cpp 6 | FuseQuantizedOps.cpp 7 | Passes.cpp 8 | GlobalizeObjectGraph.cpp 9 | InlineGlobalSlots.cpp 10 | LowerToBackendContract.cpp 11 | MatchQuantizedOps.cpp 12 | MaximizeValueSemantics.cpp 13 | PrepareForGlobalizeObjectGraph.cpp 14 | RecomposeComplexOps.cpp 15 | ReduceOpVariants.cpp 16 | RefinePublicReturn.cpp 17 | ReifyShapeCalculations.cpp 18 | ReifyDtypeCalculations.cpp 19 | ReifyAbstractInterpCalculationsUtils.cpp 20 | RestructureNonConstantAxes.cpp 21 | ScalarizeShapes.cpp 22 | AbstractInterpLibrary.cpp 23 | SimplifyShapeCalculations.cpp 24 | SimplifyDtypeCalculations.cpp 25 | SimplifyAbstractInterpCalculationsUtils.cpp 26 | 27 | ADDITIONAL_HEADER_DIRS 28 | ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms 29 | 30 | DEPENDS 31 | 
TorchMLIRTorchPassIncGen 32 | 33 | LINK_LIBS PUBLIC 34 | MLIRIR 35 | MLIRPass 36 | MLIRTransforms 37 | TorchMLIRTorchDialect 38 | TorchMLIRTorchUtils 39 | ) 40 | 41 | torch_mlir_target_includes(TorchMLIRTorchPasses) 42 | -------------------------------------------------------------------------------- /projects/pt1/python/test/compile_api/already_scripted.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | from torch_mlir import torchscript 10 | 11 | 12 | class BasicModule(torch.nn.Module): 13 | @torch.jit.export 14 | def sin(self, x): 15 | return torch.ops.aten.sin(x) 16 | 17 | 18 | example_args = torchscript.ExampleArgs() 19 | example_args.add_method("sin", torch.ones(2, 3)) 20 | 21 | scripted = torch.jit.script(BasicModule()) 22 | print(torchscript.compile(scripted, example_args)) 23 | # CHECK: module 24 | # CHECK-DAG: func.func @sin 25 | 26 | scripted = torch.jit.script(BasicModule()) 27 | try: 28 | # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. 
29 | torchscript.compile( 30 | scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)) 31 | ) 32 | except Exception as e: 33 | print(e) 34 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/_torch_mlir_custom_op_example/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Setup PyTorch 2 | include(TorchMLIRPyTorch) 3 | TorchMLIRProbeForPyTorchInstall() 4 | find_package(Torch 1.8 REQUIRED) 5 | TorchMLIRConfigurePyTorch() 6 | 7 | # Python sources 8 | declare_mlir_python_sources(TorchMLIRPythonSources.CustomOp 9 | ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" 10 | ADD_TO_PARENT TorchMLIRPythonSources 11 | SOURCES_GLOB 12 | _torch_mlir_custom_op_example/__init__.py 13 | ) 14 | 15 | # C++ extension 16 | include_directories(BEFORE 17 | ${TORCH_INCLUDE_DIRS} 18 | ) 19 | add_library(torch_mlir_custom_op_example SHARED torch_mlir_custom_op_example.cpp) 20 | target_link_libraries(torch_mlir_custom_op_example 21 | ${TORCH_LIBRARIES} 22 | ) 23 | # Because the custom op library is a bit odd, we'd like it to stay with the 24 | # Python component in the build directory. 25 | set_target_properties(torch_mlir_custom_op_example PROPERTIES 26 | LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_torch_mlir_custom_op_example/" 27 | COMPILE_FLAGS "${TORCH_CXXFLAGS}" 28 | ) 29 | torch_mlir_python_target_compile_options(torch_mlir_custom_op_example) 30 | mlir_check_all_link_libraries(torch_mlir_custom_op_example) 31 | -------------------------------------------------------------------------------- /test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s 2 | 3 | // Check that linkage names consist of the dotted path from the root. 
4 | 5 | // CHECK-LABEL: torch.global_slot.module_initializer { 6 | // CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01 7 | // CHECK: torch.initialize.global_slots [ 8 | // CHECK: @m.float(%[[FLOAT]] : !torch.float) 9 | // CHECK: ] 10 | // CHECK: } 11 | // CHECK-LABEL: torch.global_slot @m.float : !torch.float 12 | 13 | 14 | torch.class_type @child { 15 | torch.attr "float" : !torch.float 16 | } 17 | torch.class_type @parent { 18 | torch.attr "m" : !torch.nn.Module<"child"> 19 | } 20 | 21 | %c42 = torch.constant.float 42.0 22 | %child = torch.nn_module { 23 | torch.slot "float", %c42 : !torch.float 24 | } : !torch.nn.Module<"child"> 25 | %parent = torch.nn_module { 26 | torch.slot "m", %child : !torch.nn.Module<"child"> 27 | } : !torch.nn.Module<"parent"> 28 | 29 | func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"child">) { 30 | %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"child"> -> !torch.float 31 | return 32 | } 33 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import torch 6 | from torch_mlir.jit_ir_importer import ModuleBuilder 7 | 8 | # RUN: %PYTHON %s | FileCheck %s 9 | 10 | mb = ModuleBuilder() 11 | 12 | 13 | # CHECK-LABEL: func.func @__torch__.add3 14 | # Note that line-level debug information for parts unannotated in the Torch 15 | # graph are ascribed to the first op that carries source information. Presently 16 | # this includes naked constants, return and the function itself. This heuristic 17 | # likely needs to be improved and this test should be reworked when it is. 
18 | @mb.import_function 19 | @torch.jit.script 20 | def add3(t0, t1, t2): 21 | # CHECK-DAG: torch.aten.add.Tensor {{.*}} loc("aten::add"({{.*}}debug-info.py":[[# @LINE + 1]] 22 | intermediate = t0 + t1 23 | # CHECK-DAG: torch.aten.mul.Tensor {{.*}} loc("aten::mul"({{.*}}debug-info.py":[[# @LINE + 1]] 24 | return intermediate * t2 25 | 26 | 27 | # Verify again with debug info present. Just checking that it makes it in there. 28 | mb.module.operation.print(enable_debug_info=True, use_local_scope=True) 29 | print() 30 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/debug.h: -------------------------------------------------------------------------------- 1 | //===- debug.h ------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #pragma once 11 | 12 | #include 13 | 14 | #include "sys_utils.h" 15 | 16 | #define PRINT_DEBUG(msg) \ 17 | std::cout << msg << " (" << __FILE__ << ":" << __LINE__ << ")" \ 18 | << std::endl; 19 | 20 | #define PRINT_FUNCTION() \ 21 | if (verbose_print_function) { \ 22 | std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \ 23 | << ")" << std::endl; \ 24 | } 25 | 26 | static const bool verbose_print_function = 27 | sys_util::GetEnvBool("VERBOSE_PRINT_FUNCTION", false); 28 | -------------------------------------------------------------------------------- /lib/Conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(TorchOnnxToTorch) 2 | add_subdirectory(TorchToArith) 3 | add_subdirectory(TorchToLinalg) 4 | add_subdirectory(TorchToSCF) 5 | add_subdirectory(TorchToTensor) 6 | if(TORCH_MLIR_ENABLE_TOSA) 7 | add_subdirectory(TorchToTosa) 8 | endif() 9 | if(TORCH_MLIR_ENABLE_STABLEHLO) 10 | add_subdirectory(TorchToStablehlo) 11 | endif() 12 | add_subdirectory(TorchToTMTensor) 13 | add_subdirectory(TorchConversionToMLProgram) 14 | add_subdirectory(Utils) 15 | 16 | # TODO: Automate this with add_torch_mlir_conversion_library. 
17 | set(linked_libs TorchMLIRTorchToArith 18 | TorchMLIRTorchToLinalg 19 | TorchMLIRTorchToSCF 20 | TorchMLIRTorchToTensor 21 | TorchMLIRTorchToTMTensor 22 | TorchMLIRTorchConversionToMLProgram 23 | TorchMLIRConversionUtils) 24 | if(TORCH_MLIR_ENABLE_STABLEHLO) 25 | list(APPEND linked_libs TorchMLIRTorchToStablehlo) 26 | endif() 27 | if(TORCH_MLIR_ENABLE_TOSA) 28 | list(APPEND linked_libs TorchMLIRTorchToTosa) 29 | endif() 30 | 31 | add_mlir_library(TorchMLIRConversionPasses 32 | Passes.cpp 33 | 34 | DEPENDS 35 | TorchMLIRConversionPassIncGen 36 | 37 | LINK_LIBS PUBLIC 38 | ${linked_libs} 39 | ) 40 | -------------------------------------------------------------------------------- /projects/pt1/python/test/lazy_backend/device_data_name.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | 9 | import torch 10 | import torch._lazy 11 | 12 | import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend 13 | 14 | from run_test import run_test 15 | 16 | lazy_backend._initialize() 17 | 18 | device = "lazy" 19 | 20 | 21 | # CHECK: 0 input tensors found 22 | # ----- 23 | # CHECK: PASS - test_no_device_data_name 24 | @run_test 25 | def test_no_device_data_name(): 26 | x = torch.tensor(1).to(device) 27 | y = torch.tensor(2).to(device) 28 | z = x + y 29 | torch._lazy.mark_step() 30 | 31 | 32 | # CHECK: Input tensor: input_x 33 | # CHECK: 1 input tensors found 34 | # ----- 35 | # CHECK: PASS - test_device_data_name 36 | @run_test 37 | def test_device_data_name(): 38 | x = torch.tensor(1).to(device) 39 | y = torch.tensor(2).to(device) 40 | 41 | lazy_backend.set_parameter_name(x, "input_x") 42 | 43 | z = x + y 44 | torch._lazy.mark_step() 45 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-error-submodule.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: not %PYTHON %s 2>&1 | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class Submodule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.t1 = torch.tensor([10.0, 20.0]) 19 | # Test a nontrivial recursive case of the diagnostic. 20 | # CHECK: Unhandled tensor that shares storage with another tensor. 
21 | # CHECK-NEXT: Found at path '.m.t2' from root object '__torch__.TestModule' 22 | self.t2 = self.t1[0] 23 | 24 | 25 | class TestModule(torch.nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | self.m = Submodule() 29 | 30 | 31 | test_module = TestModule() 32 | recursivescriptmodule = torch.jit.script(test_module) 33 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 34 | mb.import_module(recursivescriptmodule._c) 35 | mb.module.operation.print() 36 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/node_import/classes.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch._C import CompilationUnit 9 | from torch_mlir.jit_ir_importer import ModuleBuilder 10 | 11 | import typing 12 | 13 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 14 | 15 | mb = ModuleBuilder() 16 | 17 | 18 | class BasicClass: 19 | def __init__(self, x: int): 20 | self.x = x 21 | 22 | 23 | # CHECK-LABEL: func.func @__torch__.prim_CreateObject( 24 | # CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.nn.Module<"__torch__.BasicClass"> { 25 | # CHECK: %[[OBJECT:.*]] = torch.prim.CreateObject !torch.nn.Module<"__torch__.BasicClass"> 26 | # CHECK: %[[NONE:.*]] = torch.prim.CallMethod %[[OBJECT]]["__init__"] (%[[ARG0]]) : !torch.nn.Module<"__torch__.BasicClass">, (!torch.int) -> !torch.none 27 | # CHECK: return %[[OBJECT]] : !torch.nn.Module<"__torch__.BasicClass"> 28 | @mb.import_function 29 | @torch.jit.script 30 | def prim_CreateObject(i: int): 31 | return BasicClass(i) 32 | 33 | 34 | mb.module.operation.print() 35 | print() 36 | -------------------------------------------------------------------------------- 
/python/torch_mlir/tools/opt/__main__.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | """Torch-MLIR modular optimizer driver 7 | 8 | Typically, when installed from a wheel, this can be invoked as: 9 | 10 | torch-mlir-opt [options] 11 | 12 | To see available passes, dialects, and options, run: 13 | 14 | torch-mlir-opt --help 15 | """ 16 | import os 17 | import platform 18 | import subprocess 19 | import sys 20 | 21 | from typing import Optional 22 | 23 | 24 | def _get_builtin_tool(exe_name: str) -> Optional[str]: 25 | if platform.system() == "Windows": 26 | exe_name = exe_name + ".exe" 27 | this_path = os.path.dirname(__file__) 28 | tool_path = os.path.join(this_path, "..", "..", "_mlir_libs", exe_name) 29 | return tool_path 30 | 31 | 32 | def main(args=None): 33 | if args is None: 34 | args = sys.argv[1:] 35 | exe = _get_builtin_tool("torch-mlir-opt") 36 | return subprocess.call(args=[exe] + args) 37 | 38 | 39 | if __name__ == "__main__": 40 | sys.exit(main()) 41 | -------------------------------------------------------------------------------- /lib/Conversion/TorchToStablehlo/Utils.cpp: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "./Utils.h" 11 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h" 12 | 13 | using namespace mlir; 14 | using namespace torch; 15 | 16 | FailureOr torch_to_stablehlo::getBackendTypeForScalarType( 17 | MLIRContext *context, torch_upstream::ScalarType dtypeInt) { 18 | FailureOr maybeType = Torch::getTypeForScalarType( 19 | context, (torch_upstream::ScalarType)dtypeInt); 20 | if (failed(maybeType)) { 21 | return failure(); 22 | } 23 | Type type = *maybeType; 24 | // The stablehlo backend expects signed integers to be signless. 25 | if (type.isSignedInteger()) { 26 | type = IntegerType::get(context, type.getIntOrFloatBitWidth(), 27 | IntegerType::Signless); 28 | } 29 | return type; 30 | } 31 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/tuple.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.t = (1, 2) 19 | 20 | 21 | # CHECK: torch.class_type @[[CLASSTYPE:.*]] { 22 | # TODO: Don't lose element type. 
23 | # CHECK: } 24 | # CHECK: %[[N1:.*]] = torch.constant.int 1 25 | # CHECK: %[[N2:.*]] = torch.constant.int 2 26 | # CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[N1]], %[[N2]] : !torch.int, !torch.int 27 | # CHECK: torch.nn_module { 28 | # CHECK: torch.slot "t", %[[TUPLE]] : !torch.tuple 29 | # CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> 30 | 31 | 32 | test_module = TestModule() 33 | recursivescriptmodule = torch.jit.script(test_module) 34 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 35 | mb.import_module(recursivescriptmodule._c) 36 | mb.module.operation.print() 37 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/list.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.l = [1, 2] 19 | 20 | 21 | # CHECK: torch.class_type @[[CLASSTYPE:.*]] { 22 | # CHECK: torch.attr "l" : !torch.list 23 | # CHECK: } 24 | # CHECK: %[[N1:.*]] = torch.constant.int 1 25 | # CHECK: %[[N2:.*]] = torch.constant.int 2 26 | # CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (!torch.int, !torch.int) -> !torch.list 27 | # CHECK: torch.nn_module { 28 | # CHECK: torch.slot "l", %[[LIST]] : !torch.list 29 | # CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> 30 | 31 | 32 | test_module = TestModule() 33 | recursivescriptmodule = torch.jit.script(test_module) 34 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
35 | mb.import_module(recursivescriptmodule._c) 36 | mb.module.operation.print() 37 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/README.md: -------------------------------------------------------------------------------- 1 | # Torch-MLIR Lazy Tensor Core Backend 2 | 3 | ## Detailed Documentation 4 | 5 | Detailed documentation about the architecture of this LTC backend is available [here](../../../../docs/ltc_backend.md). 6 | 7 | ## Summary 8 | 9 | Contained within this directory are the components that implements the 10 | Torch-MLIR LTC backend. Note that the code style for LTC components is 11 | consistent with that of LTC itself, rather than the rest of Torch-MLIR. 12 | 13 | The components are subclasses of the backend API interface classes found under 14 | [torch/csrc/lazy/backend](https://github.com/pytorch/pytorch/tree/master/torch/csrc/lazy/backend). 15 | 16 | Importantly, the subclasses are still abstract classes. Pure virtual methods 17 | such as `Compile` were purposefully not overridden as Torch-MLIR does not know 18 | how to compile the model for the target hardware. 19 | 20 | The intent is that vendor hardware specific plugins will subclass the Torch-MLIR 21 | backend classes and override the remaining pure virtual functions to complete 22 | the backend. 23 | 24 | The Torch-MLIR LTC backend's job is to perform the lowering from ATen to MLIR. A 25 | hardware vendor's backend job is to take care of the actual compile and 26 | execution of the lowered MLIR. 27 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | import torch 7 | 8 | from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem 9 | 10 | 11 | class NativeTorchTestConfig(TestConfig): 12 | """TestConfig that just runs the torch.nn.Module without compiling""" 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def compile( 18 | self, program: torch.nn.Module, verbose: bool = False 19 | ) -> torch.nn.Module: 20 | return program 21 | 22 | def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: 23 | # TODO: Deepcopy the torch.nn.Module, so that if the program is 24 | # stateful then it does not mutate the original compiled program. 25 | result: Trace = [] 26 | for item in trace: 27 | output = getattr(artifact, item.symbol)(*item.inputs) 28 | result.append( 29 | TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) 30 | ) 31 | return result 32 | -------------------------------------------------------------------------------- /lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | torch_mlir_enable_werror() 2 | 3 | add_subdirectory(CAPI) 4 | add_subdirectory(Conversion) 5 | add_subdirectory(Dialect) 6 | 7 | set(LinkedLibs 8 | MLIRComplexDialect 9 | MLIRFuncDialect 10 | MLIRFuncInlinerExtension 11 | MLIRIR 12 | MLIRMLProgramDialect 13 | MLIRMemRefDialect 14 | MLIRSCFDialect 15 | MLIRTensorDialect 16 | MLIRTensorInferTypeOpInterfaceImpl 17 | MLIRSupport 18 | 19 | # Dialects. 20 | TorchMLIRTMTensorDialect 21 | TorchMLIRTorchDialect 22 | TorchMLIRTorchConversionDialect 23 | 24 | # Dialect passes. 25 | TorchMLIRTMTensorPasses 26 | TorchMLIRTorchConversionPasses 27 | TorchMLIRTorchPasses 28 | 29 | # Conversion passes. 
30 | TorchMLIRConversionPasses 31 | TorchMLIRTorchOnnxToTorch 32 | ) 33 | 34 | if(TORCH_MLIR_ENABLE_STABLEHLO) 35 | list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses) 36 | endif() 37 | 38 | if(TORCH_MLIR_ENABLE_TOSA) 39 | list(APPEND LinkedLibs MLIRTosaDialect) 40 | endif() 41 | 42 | if(TORCH_MLIR_ENABLE_REFBACKEND) 43 | add_subdirectory(RefBackend) 44 | list(APPEND LinkedLibs TorchMLIRRefBackend) 45 | endif() 46 | 47 | add_mlir_library(TorchMLIRInitAll 48 | InitAll.cpp 49 | 50 | LINK_LIBS PUBLIC 51 | ${LinkedLibs} 52 | ) 53 | 54 | torch_mlir_target_includes(TorchMLIRInitAll) 55 | -------------------------------------------------------------------------------- /include/torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 11 | #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 12 | 13 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 | #include "mlir/IR/Builders.h" 15 | #include "mlir/IR/BuiltinTypes.h" 16 | #include "mlir/IR/Operation.h" 17 | #include "mlir/Interfaces/ViewLikeInterface.h" 18 | #include "mlir/Support/LLVM.h" 19 | 20 | /// Include the ODS generated interface header files. 
21 | #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc" 22 | 23 | namespace mlir { 24 | namespace torch { 25 | namespace TMTensor {} // namespace TMTensor 26 | } // namespace torch 27 | } // namespace mlir 28 | 29 | #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_ 30 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h: -------------------------------------------------------------------------------- 1 | //===- unbind_int.h ------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #pragma once 11 | 12 | #include "../mlir_node.h" 13 | 14 | namespace torch { 15 | namespace lazy { 16 | 17 | class UnbindCopyInt : public torch::lazy::TorchMlirNode { 18 | public: 19 | static torch::lazy::OpKind ClassOpKind() { 20 | return torch::lazy::OpKind(at::aten::unbind_copy); 21 | } 22 | 23 | UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, 24 | std::vector &&shapes); 25 | 26 | std::string ToString() const override; 27 | 28 | bool CanBeReused(const torch::lazy::Value &self, const int64_t &dim) const; 29 | 30 | TorchMlirOpVector Lower(TorchMlirFunction function, 31 | TorchMlirLoweringContext *loctx) const override; 32 | 33 | int64_t dim; 34 | }; 35 | 36 | } // namespace lazy 37 | } // namespace torch 38 | -------------------------------------------------------------------------------- /python/TorchMLIRModule.cpp: -------------------------------------------------------------------------------- 1 | //===-- TorchBind.td - Torch dialect bind 
------------------*- tablegen -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "mlir/Bindings/Python/PybindAdaptors.h" 11 | #include "torch-mlir-c/Dialects.h" 12 | #include "torch-mlir-c/Registration.h" 13 | 14 | namespace py = pybind11; 15 | 16 | PYBIND11_MODULE(_torchMlir, m) { 17 | torchMlirRegisterAllPasses(); 18 | 19 | m.doc() = "torch-mlir main python extension"; 20 | 21 | m.def( 22 | "register_dialect", 23 | [](MlirContext context, bool load) { 24 | MlirDialectHandle handle = mlirGetDialectHandle__torch__(); 25 | mlirDialectHandleRegisterDialect(handle, context); 26 | if (load) { 27 | mlirDialectHandleLoadDialect(handle, context); 28 | } 29 | }, 30 | py::arg("context"), py::arg("load") = true); 31 | 32 | m.def("get_int64_max", []() { return INT64_MAX; }); 33 | 34 | m.def("get_int64_min", []() { return INT64_MIN; }); 35 | } 36 | -------------------------------------------------------------------------------- /test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit: -------------------------------------------------------------------------------- 1 | # Test the expansion of ONNX operators that are functions, specifically the 2 | # propagation of attribute values from the call-site to nodes within the 3 | # expanded function. 4 | # 5 | # In this case, the model has a ReduceSumSquare node with the attribute 6 | # 'keepdims' set to 0, and the definition of this version of ReduceSumSquare 7 | # contains a ReduceSum node that references the value of 'keepdims', so we 8 | # expect to see this value propagated to the ReduceSum node in the expansion. 
9 | # 10 | # This also tests that the absence of 'axes' (as an optional attribute with no 11 | # default value) is propagated in the same way. 12 | # 13 | # The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx 14 | 15 | # RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s 16 | # 17 | # CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example 18 | # CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}" 19 | # 20 | # CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}" 21 | # CHECK: %0 = torch.operator "onnx.Mul" 22 | # CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64} 23 | -------------------------------------------------------------------------------- /test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit: -------------------------------------------------------------------------------- 1 | # Test the expansion of ONNX operators that are functions, specifically the 2 | # propagation of attribute values from the call-site to nodes within the 3 | # expanded function. 4 | # 5 | # In this case, the model has a ReduceSumSquare node with no attributes, but the 6 | # definition of this version of ReduceSumSquare contains a ReduceSum node that 7 | # references the value of 'keepdims', and the definition says its default value 8 | # is 1, so we expect to see this value propagated to the ReduceSum node in the 9 | # expansion. 10 | # 11 | # This also tests that the absence of 'axes' (as an optional attribute with no 12 | # default value) is propagated in the same way. 
13 | # 14 | # The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx 15 | 16 | # RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s 17 | # 18 | # CHECK-LABEL: func.func @test_reduce_sum_square_empty_set 19 | # CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}" 20 | # 21 | # CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}" 22 | # CHECK: %0 = torch.operator "onnx.Mul" 23 | # CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64} 24 | -------------------------------------------------------------------------------- /include/torch-mlir-c/Registration.h: -------------------------------------------------------------------------------- 1 | /*===-- torch-mlir-c/Registration.h - Registration functions -----*- C -*-===*\ 2 | |* *| 3 | |* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| 4 | |* Exceptions. *| 5 | |* See https://llvm.org/LICENSE.txt for license information. *| 6 | |* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| 7 | |* *| 8 | \*===----------------------------------------------------------------------===*/ 9 | 10 | #ifndef TORCHMLIR_C_REGISTRATION_H 11 | #define TORCHMLIR_C_REGISTRATION_H 12 | 13 | #include "mlir-c/IR.h" 14 | #include "mlir-c/Support.h" 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | /** Registers all dialects with a context. 21 | * This is needed before creating IR for these Dialects. 22 | */ 23 | MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context); 24 | 25 | /** Registers all passes for symbolic access with the global registry. 
*/ 26 | MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(void); 27 | 28 | #ifdef __cplusplus 29 | } 30 | #endif 31 | 32 | #endif // TORCHMLIR_C_REGISTRATION_H 33 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/ops/ivalue.h: -------------------------------------------------------------------------------- 1 | //===- index.h ------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #pragma once 11 | 12 | #include "../mlir_node.h" 13 | 14 | namespace torch { 15 | namespace lazy { 16 | 17 | // IValueConstant IR Node represents a `prim::Constant` constructed with IValue 18 | // parameter which is helpful in different usecases when we need custom 19 | // native ops lowering to torch-mlir IR nodes. 
20 | class IValueConstant : public torch::lazy::TorchMlirNode { 21 | public: 22 | static torch::lazy::OpKind ClassOpKind() { 23 | return torch::lazy::OpKind(at::prim::Constant); 24 | } 25 | 26 | IValueConstant(const c10::IValue &value); 27 | 28 | std::string ToString() const override; 29 | 30 | TorchMlirOpVector Lower(TorchMlirFunction function, 31 | TorchMlirLoweringContext *loctx) const override; 32 | 33 | c10::IValue value; 34 | }; 35 | 36 | } // namespace lazy 37 | } // namespace torch 38 | -------------------------------------------------------------------------------- /include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h: -------------------------------------------------------------------------------- 1 | //===------------------------------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H 11 | #define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H 12 | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" 14 | #include "mlir/Pass/Pass.h" 15 | #include 16 | 17 | namespace mlir { 18 | namespace torch { 19 | 20 | #define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO 21 | #include "torch-mlir/Conversion/Passes.h.inc" 22 | 23 | std::unique_ptr> 24 | createConvertTorchToStablehloPass(); 25 | 26 | // Convenience wrapper for users who want to pass options as individual 27 | // parameters 28 | std::unique_ptr> 29 | createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); 30 | 31 | } // namespace torch 32 | } // namespace mlir 33 | 34 | #endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H 35 | -------------------------------------------------------------------------------- /include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td: -------------------------------------------------------------------------------- 1 | //===-------------------------------------------------------*- tablegen -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHCONVERSION_BASE 11 | #define TORCHCONVERSION_BASE 12 | 13 | include "mlir/IR/OpBase.td" 14 | 15 | def TorchConversion_Dialect : Dialect { 16 | // `torch_conversion` is too verbose. 
17 | let name = "torch_c"; 18 | let cppNamespace = "::mlir::torch::TorchConversion"; 19 | let description = [{ 20 | This dialect contains ops and transforms for converting from the Torch 21 | backend contract to the linalg-on-tensors backend contract. 22 | 23 | This mainly consists of converting ops and types from `torch` dialect 24 | to the mix of dialects of the linalg-on-tensors backend contract, such as 25 | tensor ops being converted linalg-on-tensors and `!torch.vtensor` being 26 | converted to the builtin `tensor` type. 27 | }]; 28 | 29 | let hasConstantMaterializer = 1; 30 | } 31 | 32 | #endif // TORCHCONVERSION_BASE 33 | -------------------------------------------------------------------------------- /projects/jit_ir_common/csrc/jit_ir_importer/ivalue_importer.h: -------------------------------------------------------------------------------- 1 | //===- ivalue_importer.h ----------------------------------------*- C++ -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef TORCHMLIRJITIRIMPORTER_CSRC_IVALUE_IMPORTER_H 11 | #define TORCHMLIRJITIRIMPORTER_CSRC_IVALUE_IMPORTER_H 12 | 13 | #include 14 | 15 | #include "class_annotator.h" 16 | #include "import_options.h" 17 | 18 | #include "mlir-c/IR.h" 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | namespace torch_mlir { 25 | 26 | /// Main entry-point for importing torch IValue's . 27 | /// Recursively imports `ivalue`, inserting operations at the end of `block`. 
28 | MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context, 29 | ClassAnnotator &annotator, 30 | const ImportOptions &importOptions); 31 | 32 | } // namespace torch_mlir 33 | 34 | #endif // TORCHMLIRJITIRIMPORTER_CSRC_IVALUE_IMPORTER_H 35 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp: -------------------------------------------------------------------------------- 1 | //===- ivalue.cpp 2 | //----------------------------------------------------------===// 3 | // 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 | // See https://llvm.org/LICENSE.txt for license information. 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 | // Also available under a BSD-style license. See LICENSE. 8 | // 9 | //===----------------------------------------------------------------------===// 10 | 11 | #include "ivalue.h" 12 | 13 | #include 14 | 15 | namespace torch { 16 | namespace lazy { 17 | 18 | IValueConstant::IValueConstant(const c10::IValue &value) 19 | : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{}, 20 | std::vector{}, 21 | /* num_outputs */ 1, torch::lazy::MHash()), 22 | value(value) {} 23 | 24 | std::string IValueConstant::ToString() const { 25 | std::stringstream ss; 26 | ss << torch::lazy::TorchMlirNode::ToString(); 27 | return ss.str(); 28 | } 29 | 30 | TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function, 31 | TorchMlirLoweringContext *loctx) const { 32 | return {loctx->graph()->insertConstant(value)}; 33 | } 34 | 35 | } // namespace lazy 36 | } // namespace torch 37 | -------------------------------------------------------------------------------- /projects/onnx_c_importer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | message(STATUS "Enabling onnx_c_importer...") 2 | 3 | include(FetchContent) 4 | 5 | 
find_package(Protobuf REQUIRED CONFIG) 6 | 7 | FetchContent_Declare( 8 | onnx 9 | EXCLUDE_FROM_ALL 10 | GIT_REPOSITORY https://github.com/onnx/onnx.git 11 | GIT_TAG v1.16.1 12 | GIT_SHALLOW ON 13 | GIT_PROGRESS ON 14 | ) 15 | FetchContent_MakeAvailable(onnx) 16 | 17 | set(LLVM_REQUIRES_EH ON) 18 | set(LLVM_REQUIRES_RTTI ON) 19 | 20 | 21 | add_llvm_executable( 22 | torch-mlir-import-onnx 23 | PARTIAL_SOURCES_INTENDED 24 | 25 | import-onnx-main.cpp 26 | OnnxImporter.h 27 | OnnxImporter.cpp 28 | SimpleArgParser.hpp 29 | Dict.hpp 30 | Status.hpp 31 | onnx_extras.hpp 32 | ) 33 | 34 | set_target_properties(torch-mlir-import-onnx PROPERTIES CXX_STANDARD 20) 35 | 36 | # Suppress compiler warnings from onnx headers 37 | check_cxx_compiler_flag(-Wno-c++98-compat-extra-semi 38 | CXX_SUPPORTS_NO_CXX98_COMPAT_EXTRA_SEMI_FLAG) 39 | # NOTE: the if() must test the same variable that check_cxx_compiler_flag set 40 | # above; previously it tested the undefined CXX_SUPPORTS_CXX98_COMPAT_EXTRA_SEMI_FLAG 41 | # (missing the NO_ prefix), so the suppression flags were never applied. 42 | if (CXX_SUPPORTS_NO_CXX98_COMPAT_EXTRA_SEMI_FLAG) 43 | target_compile_options(torch-mlir-import-onnx PRIVATE 44 | "-Wno-c++98-compat-extra-semi") 45 | target_compile_options(onnx PRIVATE 46 | "-Wno-c++98-compat-extra-semi") 47 | endif() 48 | 49 | target_link_libraries( 50 | torch-mlir-import-onnx 51 | MLIRCAPIIR 52 | TorchMLIRCAPI 53 | onnx 54 | ) 55 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h: -------------------------------------------------------------------------------- 1 | //===- mlir_node_lowering.h -----------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE.
7 | // 8 | //===----------------------------------------------------------------------===// 9 | // This file is adapted from pytorch/pytorch 10 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.h 11 | //===----------------------------------------------------------------------===// 12 | 13 | #pragma once 14 | 15 | #include 16 | #include 17 | 18 | namespace torch { 19 | namespace lazy { 20 | 21 | typedef std::vector TorchMlirOpVector; 22 | typedef std::shared_ptr TorchMlirFunction; 23 | 24 | TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin( 25 | TorchMlirFunction function, c10::Symbol sym, 26 | const c10::ArrayRef result_shapes, 27 | const std::vector &arguments, 28 | const std::vector &kwarguments = {}); 29 | 30 | } // namespace lazy 31 | } // namespace torch 32 | -------------------------------------------------------------------------------- /include/torch-mlir-c/TorchOps.h: -------------------------------------------------------------------------------- 1 | //===-- torch-mlir-c/TorchOps.h - C API for torch ops -------------*- C -*-===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 | // Exceptions. 5 | // See https://llvm.org/LICENSE.txt for license information. 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 | // Also available under a BSD-style license. See LICENSE. 8 | // 9 | //===----------------------------------------------------------------------===// 10 | 11 | #ifndef TORCHMLIR_C_TORCHOPS_H 12 | #define TORCHMLIR_C_TORCHOPS_H 13 | 14 | #include "mlir-c/IR.h" 15 | #include "mlir-c/Support.h" 16 | 17 | #ifdef __cplusplus 18 | extern "C" { 19 | #endif 20 | 21 | //===----------------------------------------------------------------------===// 22 | // Utilities. 23 | //===----------------------------------------------------------------------===// 24 | 25 | /// Adjusts the static information in the type of `value` to `desiredType`. 
26 | /// 27 | /// Returns null if such an adjustment is not possible. 28 | /// 29 | /// If `userAllowsRefinement` is true, then the original value will be returned 30 | /// if it is a subtype of `desiredType`. 31 | MLIR_CAPI_EXPORTED MlirValue torchMlirAdjustStaticInformation( 32 | MlirBlock block, MlirOperation insertBefore, MlirValue value, 33 | MlirType desiredType, bool userAllowsRefinement); 34 | 35 | #ifdef __cplusplus 36 | } 37 | #endif 38 | 39 | #endif // TORCHMLIR_C_TORCHOPS_H 40 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/submodules-select.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class Submodule(torch.nn.Module): 16 | def __init__(self, n): 17 | super().__init__() 18 | self.n = n 19 | 20 | def forward(self): 21 | return self.n 22 | 23 | 24 | class TestModule(torch.nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | self.s1 = Submodule(1) 28 | self.s2 = Submodule(2) 29 | 30 | # CHECK-LABEL: func.func private @{{.*}}TestModule.forward 31 | def forward(self, b: bool): 32 | # Modules with the same class can be selected between. 33 | # CHECK: %[[MOD:.*]] = torch.prim.If 34 | s = self.s1 if b else self.s2 35 | # CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] () 36 | # CHECK: return %[[N]] 37 | return s.forward() 38 | 39 | 40 | test_module = TestModule() 41 | recursivescriptmodule = torch.jit.script(test_module) 42 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
43 | mb.import_module(recursivescriptmodule._c) 44 | mb.module.operation.print() 45 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/test_suite/custom_op_example.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | import torch 7 | 8 | from torch_mlir_e2e_test.framework import TestUtils 9 | from torch_mlir_e2e_test.registry import register_test_case 10 | from torch_mlir_e2e_test.annotations import annotate_args, export 11 | 12 | # ============================================================================== 13 | 14 | # Custom operators must be registered with PyTorch before being used. 15 | # This is part of the test. 16 | # Note that once this library has been loaded, the side effects mutate 17 | # the PyTorch op registry permanently. 
18 | import torch_mlir._torch_mlir_custom_op_example 19 | 20 | 21 | class CustomOpExampleModule(torch.nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @export 26 | @annotate_args( 27 | [ 28 | None, 29 | ([-1, -1], torch.float32, True), 30 | ] 31 | ) 32 | def forward(self, a): 33 | return torch.ops._torch_mlir_custom_op_example.identity(a) 34 | 35 | 36 | @register_test_case(module_factory=lambda: CustomOpExampleModule()) 37 | def CustomOpExampleModule_basic(module, tu: TestUtils): 38 | module.forward(tu.rand(3, 4)) 39 | -------------------------------------------------------------------------------- /test/Conversion/TorchToLinalg/squeeze.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s 2 | 3 | // CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic 4 | func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} { 5 | // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor 6 | // CHECK: %[[C0:.*]] = torch.constant.int 0 7 | // CHECK: %[[C0_1:.*]] = arith.constant 0 : index 8 | // CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor 9 | // CHECK: %[[C1:.*]] = arith.constant 1 : index 10 | // CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index 11 | // CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1" 12 | // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor into tensor 13 | // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor -> !torch.vtensor<[?,?],f32> 14 | %int0 = torch.constant.int 0 15 | %1 = 
torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> 16 | return %1 : !torch.vtensor<[?,?],f32> 17 | } 18 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | import copy 7 | from typing import Any 8 | 9 | import torch 10 | 11 | from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem 12 | 13 | 14 | class TorchScriptTestConfig(TestConfig): 15 | """TestConfig that runs the torch.nn.Module through TorchScript""" 16 | 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def compile( 21 | self, program: torch.nn.Module, verbose: bool = False 22 | ) -> torch.jit.ScriptModule: 23 | return torch.jit.script(program) 24 | 25 | def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace: 26 | # TODO: Deepcopy the torch.jit.ScriptModule, so that if the program is 27 | # stateful then it does not mutate the original compiled program. 
28 | 29 | result: Trace = [] 30 | for item in trace: 31 | attr = artifact 32 | for part in item.symbol.split("."): 33 | attr = getattr(attr, part) 34 | output = attr(*item.inputs) 35 | result.append( 36 | TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) 37 | ) 38 | return result 39 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp: -------------------------------------------------------------------------------- 1 | #include "jit_utils.h" 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace torch { 8 | namespace jit { 9 | 10 | void ConvertScalarImplicit(std::shared_ptr &graph) { 11 | DepthFirstGraphNodeIterator it(graph); 12 | for (auto *node = it.next(); node != nullptr; node = it.next()) { 13 | if (node->kind() != c10::aten::ScalarImplicit) { 14 | continue; 15 | } 16 | 17 | auto input = node->input(0); 18 | auto scalar_type = input->type()->cast()->scalarType(); 19 | TORCH_CHECK(scalar_type, "scalar type is not defined for input value"); 20 | 21 | NodeKind node_type; 22 | TypePtr output_type; 23 | if (c10::isIntegralType(*scalar_type, true)) { 24 | node_type = c10::aten::IntImplicit; 25 | output_type = IntType::get(); 26 | } else if (c10::isFloatingType(*scalar_type)) { 27 | node_type = c10::aten::FloatImplicit; 28 | output_type = FloatType::get(); 29 | } else { 30 | throw std::runtime_error("Expected isIntegralType or isFloatingType"); 31 | } 32 | 33 | Value *output = graph->create(node_type, {input}) 34 | ->insertBefore(node) 35 | ->output() 36 | ->setType(output_type); 37 | node->output()->replaceAllUsesWith(output); 38 | node->destroy(); 39 | } 40 | } 41 | 42 | } // namespace jit 43 | } // namespace torch 44 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/ops/generic.h: -------------------------------------------------------------------------------- 1 | //===- generic.h 
----------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 7 | // 8 | //===----------------------------------------------------------------------===// 9 | // This file is adapted from pytorch/pytorch 10 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/generic.h 11 | //===----------------------------------------------------------------------===// 12 | 13 | #pragma once 14 | 15 | #include "../mlir_node.h" 16 | 17 | namespace torch { 18 | namespace lazy { 19 | 20 | // Generic IR Node implementation for nodes which can simply be described by a 21 | // specific OpKind and a lowering function. IR nodes carrying 22 | // metadata should not be using this class TORCH_API (and have the metadata 23 | // captured by the LowerFn), but they should instead create a dedicated IR node. 24 | // Doing the former would limit IR introspection. 25 | class TORCH_API Generic : public TorchMlirNode { 26 | public: 27 | Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1, 28 | hash_t hash_seed = static_cast(0x5a2d296e9)); 29 | 30 | private: 31 | hash_t hash_seed_; 32 | }; 33 | 34 | } // namespace lazy 35 | } // namespace torch 36 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-tensor-type-bound.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, a, b): 20 | return 21 | 22 | 23 | test_module = TestModule() 24 | recursivescriptmodule = torch.jit.script(test_module) 25 | 26 | annotator = ClassAnnotator() 27 | class_type = recursivescriptmodule._c._type() 28 | # CHECK: func.func private @__torch__.TestModule.forward( 29 | # CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">, 30 | # CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}, 31 | # CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>} 32 | # CHECK-SAME: ) -> !torch.none 33 | annotator.annotateArgs( 34 | class_type, 35 | ["forward"], 36 | [ 37 | None, 38 | ((-1, 1024), torch.int8, True), 39 | ((), torch.float, True), 40 | ], 41 | ) 42 | 43 | # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
44 | mb.import_module(recursivescriptmodule._c, annotator) 45 | mb.module.operation.print() 46 | -------------------------------------------------------------------------------- /projects/ltc/csrc/base_lazy_backend/utils/string_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | template 8 | std::ostream &string_join(std::ostream &out, const std::vector &v, 9 | const std::string &delimiter) { 10 | size_t i = 0; 11 | for (const T &e : v) { 12 | if ((i++) > 0) { 13 | out << delimiter; 14 | } 15 | out << e; 16 | } 17 | return out; 18 | } 19 | 20 | template 21 | std::string string_join(const std::vector &v, const std::string &delimiter) { 22 | std::ostringstream joined; 23 | string_join(joined, v, delimiter); 24 | return joined.str(); 25 | } 26 | 27 | inline std::vector string_split(const std::string &str, 28 | const std::string &sep) { 29 | std::vector tokens; 30 | std::size_t pos1 = str.find_first_not_of(sep); 31 | while (pos1 != std::string::npos) { 32 | std::size_t pos2 = str.find_first_of(sep, pos1); 33 | if (pos2 == std::string::npos) { 34 | tokens.push_back(str.substr(pos1)); 35 | pos1 = pos2; 36 | } else { 37 | tokens.push_back(str.substr(pos1, pos2 - pos1)); 38 | pos1 = str.find_first_not_of(sep, pos2 + 1); 39 | } 40 | } 41 | return tokens; 42 | } 43 | 44 | /* 45 | * Returns true if str starts with prefix 46 | */ 47 | inline bool startswith(const std::string &str, const std::string &prefix) { 48 | return str.rfind(prefix, 0) == 0; 49 | } 50 | -------------------------------------------------------------------------------- /projects/pt1/python/test/torchscript_e2e_test/basic.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | 10 | from torch_mlir_e2e_test.framework import run_tests, TestUtils 11 | from torch_mlir_e2e_test.reporting import report_results 12 | from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY 13 | from torch_mlir_e2e_test.configs import TorchScriptTestConfig 14 | 15 | 16 | class MmModule(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, lhs, rhs): 21 | return torch.mm(lhs, rhs) 22 | 23 | 24 | # TODO: Refine messages. 25 | # CHECK: PASS - "MmModule_basic" 26 | @register_test_case(module_factory=lambda: MmModule()) 27 | def MmModule_basic(module, tu: TestUtils): 28 | module.forward(tu.rand(4, 4), tu.rand(4, 4)) 29 | 30 | 31 | # CHECK: PASS - "MmModule_basic2" 32 | @register_test_case(module_factory=lambda: MmModule()) 33 | def MmModule_basic2(module, tu: TestUtils): 34 | module.forward(tu.rand(4, 4), tu.rand(4, 4)) 35 | 36 | 37 | def main(): 38 | config = TorchScriptTestConfig() 39 | results = run_tests(GLOBAL_TEST_REGISTRY, config) 40 | report_results(results, set(), verbose=True) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /test/Dialect/TorchConversion/unpack-quant-tensor.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s 2 | 3 | // CHECK-LABEL: func @forward 4 | func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> { 5 | %q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], 
[127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8> 6 | // CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4> 7 | %scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16> 8 | %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16> 9 | %bit_width = torch.constant.int 4 10 | %group_size = torch.constant.int 2 11 | %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16> 12 | return %output : !torch.vtensor<[1,1,8],f16> 13 | } 14 | -------------------------------------------------------------------------------- /test/Dialect/Torch/reduce-op-variants-error.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-reduce-op-variants -verify-diagnostics -split-input-file %s 2 | 3 | // ----- 4 | 5 | func.func @convert_to_value_semantic_tensors_list( %list: !torch.list) -> !torch.tensor { 6 | %int1 = torch.constant.int 1 7 | // expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}} 8 | %ret = torch.aten.cat %list, %int1 : !torch.list, !torch.int -> !torch.tensor 9 | return %ret : !torch.tensor 10 | } 11 | 12 | // ----- 13 | 14 | func.func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional, 15 | %t: !torch.tensor, 16 | %training: !torch.bool, 17 | %cudnn_enable: !torch.bool, 18 | %f : !torch.float) -> !torch.tensor { 19 | // expected-error@+1 
{{failed to legalize operation 'torch.aten.batch_norm' that was explicitly marked illegal}} 20 | %ret = torch.aten.batch_norm %t, %tensor_optional, %tensor_optional, %tensor_optional, 21 | %tensor_optional, %training, %f, %f, %cudnn_enable: 22 | !torch.tensor, !torch.optional, !torch.optional, 23 | !torch.optional, !torch.optional, 24 | !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.tensor 25 | return %ret: !torch.tensor 26 | } 27 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/primitives.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.i = 3 19 | self.f = 42.5 20 | 21 | 22 | # CHECK: torch.class_type @[[CLASSTYPE:.*]] { 23 | # CHECK: torch.attr "training" : !torch.bool 24 | # CHECK: torch.attr "i" : !torch.int 25 | # CHECK: torch.attr "f" : !torch.float 26 | # CHECK: } 27 | # CHECK: %[[TRUE:.*]] = torch.constant.bool true 28 | # CHECK: %[[N3:.*]] = torch.constant.int 3 29 | # CHECK: %[[N42:.*]] = torch.constant.float 4.250000e+01 30 | # CHECK: %[[MODULE:.*]] = torch.nn_module { 31 | # Note: for some reason, Torch always adds a "training" property to all modules. 
32 | # CHECK: torch.slot "training", %[[TRUE]] : !torch.bool 33 | # CHECK: torch.slot "i", %[[N3]] : !torch.int 34 | # CHECK: torch.slot "f", %[[N42]] : !torch.float 35 | # CHECK: } : !torch.nn.Module<"[[CLASSTYPE:.*]]"> 36 | 37 | 38 | test_module = TestModule() 39 | recursivescriptmodule = torch.jit.script(test_module) 40 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 41 | mb.import_module(recursivescriptmodule._c) 42 | mb.module.operation.print() 43 | -------------------------------------------------------------------------------- /test/Dialect/Torch/GlobalizeObjectGraph/error.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s 2 | 3 | torch.class_type @c1 {} 4 | torch.class_type @c2 {} 5 | 6 | // expected-note @+1 {{see other root module here}} 7 | torch.nn_module {} : !torch.nn.Module<"c1"> 8 | // expected-error @+1 {{found more than one root module (module that is not a child of any other module)}} 9 | torch.nn_module {} : !torch.nn.Module<"c2"> 10 | 11 | // ----- 12 | 13 | torch.class_type @child { 14 | torch.attr "float" : !torch.float 15 | } 16 | torch.class_type @parent { 17 | torch.attr "m" : !torch.nn.Module<"child"> 18 | torch.attr "m2" : !torch.nn.Module<"child"> 19 | 20 | } 21 | 22 | %c42 = torch.constant.float 42.0 23 | // expected-error @+1 {{reachable by multiple paths from root object: '.m' and '.m2'}} 24 | %child = torch.nn_module { 25 | torch.slot "float", %c42 : !torch.float 26 | } : !torch.nn.Module<"child"> 27 | %parent = torch.nn_module { 28 | torch.slot "m", %child : !torch.nn.Module<"child"> 29 | torch.slot "m2", %child : !torch.nn.Module<"child"> 30 | } : !torch.nn.Module<"parent"> 31 | 32 | func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"child">) { 33 | %0 = torch.prim.GetAttr %arg0["m"] : 
!torch.nn.Module<"parent"> -> !torch.nn.Module<"child"> 34 | %1 = torch.prim.GetAttr %arg0["m2"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child"> 35 | %2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float 36 | return 37 | } 38 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export-error.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self): 20 | return 21 | 22 | 23 | test_module = TestModule() 24 | recursivescriptmodule = torch.jit.script(test_module) 25 | 26 | annotator = ClassAnnotator() 27 | class_type = recursivescriptmodule._c._type() 28 | 29 | try: 30 | annotator.exportPath(class_type, ["a"]) 31 | except Exception as e: 32 | # CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a' 33 | print(e) 34 | try: 35 | annotator.exportPath(class_type, []) 36 | except Exception as e: 37 | # CHECK: Empty exported path. Can only export a property of a class. 38 | print(e) 39 | 40 | try: 41 | annotator.exportPath(class_type, ["a", "b"]) 42 | except Exception as e: 43 | # This error is generated by PyTorch itself, so be a bit defensive about changes. 44 | # CHECK: __torch__.TestModule {{.*}} 'a' 45 | print(e) 46 | 47 | # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
48 | mb.import_module(recursivescriptmodule._c, annotator) 49 | mb.module.operation.print() 50 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/methods-derefine.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | # CHECK-LABEL: func.func private @__torch__.TestModule.forward( 20 | # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional { 21 | # CHECK: %[[NONE:.*]] = torch.constant.none 22 | # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional 23 | # CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional) -> !torch.optional 24 | # CHECK: return %[[RET]] : !torch.optional 25 | def forward(self): 26 | return self.callee(None) 27 | 28 | def callee(self, o: typing.Optional[int]): 29 | return o 30 | 31 | 32 | test_module = TestModule() 33 | recursivescriptmodule = torch.jit.script(test_module) 34 | # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 
35 | mb.import_module(recursivescriptmodule._c) 36 | mb.module.operation.print() 37 | -------------------------------------------------------------------------------- /test/Conversion/TorchConversionToMLProgram/basic.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s 2 | 3 | // CHECK-LABEL: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor 4 | // CHECK-LABEL: func.func @f() -> i64 { 5 | // CHECK: %[[GLOBAL:.*]] = ml_program.global_load @global_seed : tensor 6 | // CHECK: %[[SEED:.*]] = tensor.extract %[[GLOBAL]][] : tensor 7 | // CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64 8 | // CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64 9 | // CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64 10 | // CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64 11 | // CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[GLOBAL]][] : tensor 12 | // CHECK: ml_program.global_store @global_seed = %[[INSERTED]] : tensor 13 | // CHECK: return %[[NEXT_SEED]] : i64 14 | module { 15 | func.func @f() -> i64 { 16 | %seed = torch_c.get_next_seed : () -> i64 17 | return %seed : i64 18 | } 19 | } 20 | 21 | // ----- 22 | 23 | module { 24 | func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> { 25 | %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> 26 | return %0 : !torch.vtensor<[2,3],f32> 27 | } 28 | } 29 | 30 | // CHECK-NOT: ml_program.global 31 | // CHECK-LABEL: @no_seed_needed 32 | // CHECK-NEXT: torch_c.from_builtin_tensor 33 | -------------------------------------------------------------------------------- /test/Dialect/Torch/verify-backend-contract-error.mlir: -------------------------------------------------------------------------------- 1 | // RUN: torch-mlir-opt 
-torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s 2 | 3 | func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { 4 | // expected-error @below {{unsupported by backend contract: tensor with unknown rank}} 5 | // expected-note @below {{this is likely due to a missing transfer function}} 6 | %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor 7 | return %t : !torch.vtensor 8 | } 9 | 10 | // ----- 11 | 12 | // expected-error @below {{invalid dtype 'i9'}} 13 | func.func @bad_element_type(%arg: !torch.vtensor<[?],i9>) -> !torch.vtensor<[?],i9> { 14 | return %arg : !torch.vtensor<[?],i9> 15 | } 16 | 17 | // ----- 18 | 19 | // expected-error @below {{unsupported by backend contract: non-value tensor type}} 20 | // expected-note @below {{this is likely due to a missing case in the MaximizeValueSemantics pass}} 21 | func.func @non_value_tensor(%arg0: !torch.tensor) -> !torch.tensor { 22 | return %arg0 : !torch.tensor 23 | } 24 | 25 | // ----- 26 | 27 | func.func @valid_tuple(%arg0: !torch.vtensor<[?],f32>) -> !torch.tuple> { 28 | %0 = torch.prim.TupleConstruct %arg0 : !torch.vtensor<[?],f32> -> !torch.tuple> 29 | return %0 : !torch.tuple> 30 | } 31 | 32 | // ----- 33 | 34 | func.func @valid_multiple_ret_values(%arg0: !torch.vtensor<[?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) { 35 | return %arg0, %arg0 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> 36 | } 37 | -------------------------------------------------------------------------------- /projects/pt1/python/torch_mlir/csrc/jit_ir_importer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Separate Pybind MODULE due to issues with a SHARED library. 
2 | # https://github.com/llvm/torch-mlir/issues/1154 3 | add_library(TorchMLIRJITIRImporterPybind MODULE 4 | class_annotator_pybind.cpp 5 | get_registered_ops.cpp 6 | import_options_pybind.cpp 7 | init_python_bindings.cpp 8 | module_builder.cpp 9 | ) 10 | add_dependencies(TorchMLIRJITIRImporterPybind 11 | TorchMLIRJITIRImporter 12 | ) 13 | target_link_libraries(TorchMLIRJITIRImporterPybind 14 | ${TORCH_LIBRARIES} 15 | torch_python 16 | TorchMLIRJITIRImporter 17 | ) 18 | 19 | # On static Python builds, there may not be Python libraries to link against 20 | # (they will late bind at runtime from the executable). We have to condition 21 | # this because in that case it is set to NOTFOUND and CMake will consider 22 | # this an error. 23 | if(Python3_LIBRARIES) 24 | target_link_libraries(TorchMLIRJITIRImporterPybind 25 | ${Python3_LIBRARIES} 26 | ) 27 | endif() 28 | 29 | set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES 30 | LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" 31 | OUTPUT_NAME _jit_ir_importer 32 | PREFIX "${PYTHON_MODULE_PREFIX}" 33 | SUFFIX "${PYTHON_MODULE_EXTENSION}" 34 | CXX_VISIBILITY_PRESET "hidden" 35 | COMPILE_FLAGS "${TORCH_CXXFLAGS}" 36 | ) 37 | mlir_python_setup_extension_rpath(TorchMLIRJITIRImporterPybind) 38 | 39 | torch_mlir_python_target_compile_options(TorchMLIRJITIRImporterPybind) 40 | mlir_check_all_link_libraries(TorchMLIRJITIRImporterPybind) 41 | -------------------------------------------------------------------------------- /projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/export.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | # This file is licensed under a pytorch-style license 3 | # See LICENSE.pytorch for license information. 
4 | 5 | import typing 6 | 7 | import torch 8 | from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder 9 | 10 | # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s 11 | 12 | mb = ModuleBuilder() 13 | 14 | 15 | class TestModule(torch.nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.exported = 1 19 | self.not_exported = 2 20 | 21 | def forward(self): 22 | return self.not_exported_method() 23 | 24 | def not_exported_method(self): 25 | return 26 | 27 | 28 | test_module = TestModule() 29 | recursivescriptmodule = torch.jit.script(test_module) 30 | 31 | annotator = ClassAnnotator() 32 | class_type = recursivescriptmodule._c._type() 33 | # CHECK-LABEL: torch.class_type @__torch__.TestModule { 34 | # CHECK: torch.attr "exported" : !torch.int 35 | # CHECK: torch.attr private "not_exported" : !torch.int 36 | # CHECK: torch.method "forward", @{{.*}} 37 | # CHECK: torch.method private "not_exported_method", @{{.*}} 38 | # CHECK: } 39 | annotator.exportNone(class_type) 40 | annotator.exportPath(class_type, ["exported"]) 41 | annotator.exportPath(class_type, ["forward"]) 42 | 43 | # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. 44 | mb.import_module(recursivescriptmodule._c, annotator) 45 | mb.module.operation.print() 46 | -------------------------------------------------------------------------------- /projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | 10 | from torch_mlir_e2e_test.framework import run_tests, TestUtils 11 | from torch_mlir_e2e_test.reporting import report_results 12 | from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY 13 | from torch_mlir_e2e_test.configs import TorchScriptTestConfig 14 | 15 | 16 | class MmModule(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, t): 21 | # Input of `torch.tensor` only allows ints, floats, or bools while empty 22 | # list defaults to tensor type 23 | return torch.tensor([]) 24 | 25 | 26 | # CHECK: FAIL - "MmModule_basic" 27 | # CHECK: Runtime error: 28 | # Assume that the diagnostic from the TorchScript runtime will at least contain 29 | # the offending "return torch.tensor([])". 30 | # CHECK: return torch.tensor([]) 31 | @register_test_case(module_factory=lambda: MmModule()) 32 | def MmModule_basic(module, tu: TestUtils): 33 | module.forward(torch.ones([])) 34 | 35 | 36 | def main(): 37 | config = TorchScriptTestConfig() 38 | results = run_tests(GLOBAL_TEST_REGISTRY, config) 39 | report_results(results, set(), verbose=True) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | // See https://llvm.org/LICENSE.txt for license information. 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | // Also available under a BSD-style license. See LICENSE. 
7 | // 8 | //===----------------------------------------------------------------------===// 9 | // 10 | // This file is meant to include the `TorchOps.cpp.inc` file and compile it 11 | // separately from the main TorchOps.cpp file. The .inc file takes a very long 12 | // time to compile, and slows down the iteration time on folders, 13 | // canonicalizations, parser/printers, etc. in the actual TorchOps.cpp file, so 14 | // it makes sense to isolate it and let the build system cache it. 15 | // 16 | //===----------------------------------------------------------------------===// 17 | 18 | #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" 19 | 20 | #include "UtilsForODSGenerated.h" 21 | #include "mlir/IR/Builders.h" 22 | #include "mlir/IR/BuiltinOps.h" 23 | #include "mlir/IR/PatternMatch.h" 24 | #include "mlir/IR/TypeUtilities.h" 25 | #include "mlir/Support/LLVM.h" 26 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h" 27 | #include "llvm/ADT/BitVector.h" 28 | #include "llvm/ADT/StringMap.h" 29 | #include "llvm/Support/Casting.h" 30 | 31 | using namespace mlir; 32 | using namespace mlir::torch; 33 | using namespace mlir::torch::Torch; 34 | 35 | #define GET_OP_CLASSES 36 | #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc" 37 | -------------------------------------------------------------------------------- /projects/pt1/python/test/torchscript_e2e_test/submodule.py: -------------------------------------------------------------------------------- 1 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | # Also available under a BSD-style license. See LICENSE. 
5 | 6 | # RUN: %PYTHON %s | FileCheck %s 7 | 8 | import torch 9 | 10 | from torch_mlir_e2e_test.framework import run_tests, TestUtils 11 | from torch_mlir_e2e_test.reporting import report_results 12 | from torch_mlir_e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY 13 | from torch_mlir_e2e_test.configs import TorchScriptTestConfig 14 | 15 | 16 | class Submodule2(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, lhs, rhs): 21 | return torch.mm(lhs, rhs) 22 | 23 | 24 | class Submodule(torch.nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | self.m2 = Submodule2() 28 | 29 | 30 | class ModuleWithSubmodule(torch.nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.m = Submodule() 34 | 35 | 36 | # CHECK: PASS - "ModuleWithSubmodule_basic" 37 | @register_test_case(module_factory=lambda: ModuleWithSubmodule()) 38 | def ModuleWithSubmodule_basic(module, tu: TestUtils): 39 | module.m.m2.forward(tu.rand(4, 4), tu.rand(4, 4)) 40 | 41 | 42 | def main(): 43 | config = TorchScriptTestConfig() 44 | results = run_tests(GLOBAL_TEST_REGISTRY, config) 45 | report_results(results, set()) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /utils/bazel/torch-mlir-overlay/test/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # This file is licensed under the Apache License v2.0 with LLVM Exceptions. 2 | # See https://llvm.org/LICENSE.txt for license information. 
3 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4 | 5 | load("@bazel_skylib//rules:expand_template.bzl", "expand_template") 6 | load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") 7 | 8 | package( 9 | default_visibility = [ 10 | "//visibility:public", 11 | ], 12 | ) 13 | 14 | expand_template( 15 | name = "lit_site_cfg_py", 16 | testonly = True, 17 | out = "lit.site.cfg.py", 18 | substitutions = { 19 | "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", 20 | "@TORCH_MLIR_SOURCE_DIR@": package_path("@torch-mlir//:BUILD"), 21 | "\"@TORCH_MLIR_BINARY_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'torch-mlir')", 22 | "\"@LLVM_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'llvm')", 23 | # All disabled, but required to substituted because they are not in quotes. 24 | "@MLIR_ENABLE_BINDINGS_PYTHON@": "0", 25 | "@TORCH_MLIR_ENABLE_STABLEHLO@": "0", 26 | "@TORCH_MLIR_ENABLE_REFBACKEND@": "1", 27 | }, 28 | template = "lit.site.cfg.py.in", 29 | ) 30 | 31 | # Common data used by most lit tests. 32 | filegroup( 33 | name = "lit_data", 34 | testonly = True, 35 | data = [ 36 | "lit.cfg.py", 37 | "lit.site.cfg.py", 38 | "@llvm-project//llvm:FileCheck", 39 | "@llvm-project//llvm:count", 40 | "@llvm-project//llvm:not", 41 | ], 42 | ) 43 | --------------------------------------------------------------------------------