├── .github ├── Dockerfile ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── actions │ ├── build-base-image │ │ └── action.yaml │ ├── build-docs │ │ └── action.yaml │ ├── build-wheel │ │ └── action.yaml │ └── run-tests │ │ └── action.yaml ├── base-image │ └── Dockerfile ├── requirements-ci.txt ├── scripts │ ├── __init__.py │ ├── bench │ │ └── __init__.py │ ├── build_wheel.sh │ ├── cuda_cleanup.py │ ├── db_utils.py │ ├── set_gpu_types.sh │ ├── set_test_matrix.py │ ├── start_instances.py │ ├── stop_instances.py │ └── upload_results.py └── workflows │ ├── check-pr-title.yaml │ ├── launch.yaml │ ├── lint.yaml │ ├── nightly.yaml │ ├── publish-centml-pypi.yaml │ ├── regression.yaml │ ├── release.yaml │ ├── sync.yaml │ ├── tests.yaml │ └── upload-dev-wheel.yaml ├── .gitignore ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── apps └── compile_server │ ├── Dockerfile │ ├── README.md │ ├── app.py │ ├── requirements.txt │ ├── resources │ ├── __init__.py │ ├── auth.py │ ├── compilation.py │ ├── compile_worker.py │ ├── status.py │ ├── user.py │ └── utils.py │ ├── run.py │ └── run.sh ├── config.cmake ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── custom.css │ ├── favicon.svg │ ├── img │ │ ├── resolve-example-conv2d.svg │ │ ├── resolve-example-matmul.svg │ │ └── subgraph-rewrite-example.svg │ └── logo.svg │ ├── clear-hidet-cache.py │ ├── conf.py │ ├── developer-guides │ └── contributing.rst │ ├── genindex.rst │ ├── getting-started │ ├── build-from-source.rst │ └── install.rst │ ├── hidet-script │ ├── examples │ │ └── index.rst │ ├── index.rst │ └── reference │ │ ├── 1-type-system.rst │ │ ├── 2-expression.rst │ │ ├── 3-statement.rst │ │ ├── 4-function.rst │ │ ├── 5-module.rst │ │ ├── 6-cuda-specific.rst │ │ ├── 7-cpu-specific.rst │ │ └── index.rst │ ├── how-to-guides │ └── add-new-operator │ │ └── index.rst │ ├── index.rst │ ├── notes │ └── operator-cache.rst │ ├── python_api │ ├── 
cuda.rst │ ├── data_types.rst │ ├── drivers.rst │ ├── ffi │ │ └── index.rst │ ├── graph │ │ ├── frontend │ │ │ ├── index.rst │ │ │ ├── onnx.rst │ │ │ └── torch.rst │ │ ├── index.rst │ │ └── transforms │ │ │ ├── index.rst │ │ │ ├── resolve_variant.rst │ │ │ └── subgraph_rewrite.rst │ ├── index.rst │ ├── ir │ │ ├── compute.rst │ │ ├── expr.rst │ │ ├── func.rst │ │ ├── index.rst │ │ ├── stmt.rst │ │ ├── task.rst │ │ └── type.rst │ ├── ops │ │ └── index.rst │ ├── option.rst │ ├── root.rst │ ├── runtime │ │ └── index.rst │ ├── tensor.rst │ ├── testing │ │ └── index.rst │ └── utils │ │ └── index.rst │ └── references.bib ├── examples ├── README.md ├── distributed │ └── test.py └── quantization │ ├── gpt2.py │ └── gpt2_performance.py ├── gallery ├── README.rst ├── developer-guides │ ├── README.rst │ ├── add-new-operator-compute-definition.py │ ├── add-new-operator-rule-based.py │ ├── add-new-operator-template-based.py │ ├── add-operator-resolve-rule.py │ ├── add-subgraph-rewrite-rule.py │ ├── add-torch-operator-mapping.py │ └── hidet-script-dynamic-kernel.py ├── getting-started │ ├── README.rst │ └── quick-start.py ├── hidet-script │ ├── 0-hello-world.py │ ├── 1-scalar-addition.py │ ├── 2-vector-addition.py │ ├── 3-kernel-functions.py │ ├── 4-naive-matmul.py │ ├── 5-efficient-matmul.py │ └── README.rst ├── how-to-guides │ ├── README.rst │ └── visualize-flow-graph.py └── tutorials │ ├── README.rst │ ├── optimize-onnx-model.py.backup │ └── optimize-pytorch-model.py.backup ├── include └── hidet │ ├── runtime.h │ └── runtime │ ├── callbacks.h │ ├── common.h │ ├── context.h │ ├── cpu │ ├── bfloat16.h │ ├── complex.h │ ├── context.h │ ├── float16.h │ ├── float32.h │ └── vector_types.h │ ├── cuda │ ├── complex.h │ ├── context.h │ ├── cublas.h │ ├── cuda.h │ ├── cudnn.h │ ├── float8_e4m3.h │ └── float8_e5m2.h │ ├── hip │ ├── context.h │ └── f16_utils.h │ ├── int_fastdiv.h │ ├── logging.h │ ├── memory_planner.h │ ├── symbols.h │ └── torch │ └── stream.h ├── pyproject.toml ├── 
python └── hidet │ ├── __init__.py │ ├── apps │ └── compile_server │ │ ├── __init__.py │ │ ├── auth.py │ │ ├── compilation.py │ │ ├── core.py │ │ └── user.py │ ├── backend │ ├── __init__.py │ ├── build.py │ └── codegen.py │ ├── cli │ ├── __init__.py │ ├── bench │ │ ├── __init__.py │ │ ├── bench.py │ │ ├── bench_all.py │ │ ├── bench_common.py │ │ ├── model.py │ │ ├── nlp │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ └── nlp_model.py │ │ └── vision │ │ │ ├── __init__.py │ │ │ ├── inception_v3.py │ │ │ ├── mobilenet_v2.py │ │ │ ├── resnet.py │ │ │ ├── resnext.py │ │ │ └── vision_model.py │ ├── cache │ │ ├── __init__.py │ │ ├── clear.py │ │ ├── entry.py │ │ ├── status.py │ │ └── utils.py │ └── main.py │ ├── cuda │ ├── __init__.py │ ├── capability.py │ ├── cublas │ │ ├── __init__.py │ │ ├── ffi.py │ │ ├── kernels.py │ │ └── utils.py │ ├── cudnn │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── ffi.py │ │ ├── kernels.py │ │ └── utils.py │ ├── device.py │ ├── event.py │ ├── graph.py │ ├── memory.py │ ├── nccl │ │ ├── __init__.py │ │ ├── comm.py │ │ ├── ffi.py │ │ └── libinfo.py │ └── stream.py │ ├── distributed │ ├── __init__.py │ ├── distributed.py │ ├── group.py │ └── store.py │ ├── drivers │ ├── __init__.py │ ├── build_graph.py │ ├── build_module.py │ ├── build_task.py │ └── utils.py │ ├── ffi │ ├── __init__.py │ ├── array.py │ ├── callbacks.py │ ├── convert.py │ ├── crt.py │ ├── ffi.py │ ├── runtime_api.py │ ├── shared_lib.py │ └── utils.py │ ├── graph │ ├── __init__.py │ ├── common.py │ ├── flow_graph.py │ ├── frontend │ │ ├── __init__.py │ │ ├── onnx │ │ │ ├── __init__.py │ │ │ ├── availability.py │ │ │ ├── onnx.py │ │ │ └── utils.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── availability.py │ │ │ ├── dynamo_backends.py │ │ │ ├── dynamo_config.py │ │ │ ├── flow_graph_cache.py │ │ │ ├── interpreter.py │ │ │ ├── register_functions.py │ │ │ ├── register_methods.py │ │ │ ├── register_modules.py │ │ │ ├── registry.py │ │ │ └── utils.py │ ├── graph_utils │ │ ├── 
__init__.py │ │ ├── functors.py │ │ └── instruments │ │ │ ├── __init__.py │ │ │ ├── benchmark_instrument.py │ │ │ └── debug_instrument.py │ ├── impl │ │ ├── __init__.py │ │ ├── dlpack.py │ │ └── graph_impl.py │ ├── nn │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── attention.py │ │ ├── container.py │ │ ├── convolutions.py │ │ ├── identity.py │ │ ├── linear.py │ │ ├── module.py │ │ ├── norms.py │ │ ├── poolings.py │ │ └── transforms.py │ ├── operator.py │ ├── ops │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── arithmetic.py │ │ ├── arithmetic_resolve.py │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ └── attention_mask.py │ │ ├── compare.py │ │ ├── complex.py │ │ ├── conv1d │ │ │ ├── __init__.py │ │ │ ├── conv1d.py │ │ │ ├── conv1d_gemm.py │ │ │ ├── resolve.py │ │ │ └── utils.py │ │ ├── conv1d_transpose │ │ │ ├── __init__.py │ │ │ └── conv1d_transpose.py │ │ ├── conv2d │ │ │ ├── __init__.py │ │ │ ├── conv2d.py │ │ │ ├── conv2d_gemm.py │ │ │ ├── conv2d_winograd.py │ │ │ ├── resolve.py │ │ │ └── utils.py │ │ ├── conv2d_transpose │ │ │ ├── __init__.py │ │ │ ├── conv2d_transpose.py │ │ │ ├── conv2d_transpose_gemm.py │ │ │ └── resolve.py │ │ ├── conv3d │ │ │ ├── __init__.py │ │ │ ├── conv3d.py │ │ │ ├── conv3d_gemm.py │ │ │ ├── resolve.py │ │ │ └── utils.py │ │ ├── conv3d_transpose │ │ │ ├── __init__.py │ │ │ └── conv3d_transpose.py │ │ ├── create.py │ │ ├── cumulative.py │ │ ├── distributed.py │ │ ├── embedding_bag.py │ │ ├── fusion │ │ │ ├── __init__.py │ │ │ ├── apply_prologue_epilogue.py │ │ │ └── fused_operator.py │ │ ├── image.py │ │ ├── linear.py │ │ ├── matmul │ │ │ ├── __init__.py │ │ │ ├── cuda_batch_matmul.py │ │ │ ├── hip_batch_matmul.py │ │ │ ├── matmul.py │ │ │ ├── matmul_cublas.py │ │ │ ├── matmul_f16.py │ │ │ ├── matmul_f16_cute.py │ │ │ ├── matmul_f16_cute_experimental.py │ │ │ ├── matmul_f16_sm90.py │ │ │ ├── matmul_f32_x86.py │ │ │ ├── matmul_f8.py │ │ │ └── resolve.py │ │ ├── normalize │ │ │ ├── __init__.py │ │ │ ├── 
layers.py │ │ │ ├── lp.py │ │ │ ├── norm.py │ │ │ └── resolve.py │ │ ├── opaque.py │ │ ├── pool.py │ │ ├── quant │ │ │ ├── __init__.py │ │ │ ├── matmul.py │ │ │ ├── matmul_f16_i8.py │ │ │ ├── matmul_f16_i8_atomic.py │ │ │ ├── resolve.py │ │ │ └── symmetric.py │ │ ├── reduce │ │ │ ├── __init__.py │ │ │ ├── reduce.py │ │ │ └── resolve.py │ │ ├── scatter.py │ │ ├── softmax.py │ │ ├── special.py │ │ ├── transfer.py │ │ ├── transform.py │ │ ├── transpose2d.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── schedule_utils.py │ │ │ └── tensor_utils.py │ ├── tensor.py │ └── transforms │ │ ├── __init__.py │ │ ├── base.py │ │ ├── conv_channel_last.py │ │ ├── eliminate_barrier.py │ │ ├── fuse_operator.py │ │ ├── graph_patterns │ │ ├── __init__.py │ │ ├── arithmetic_patterns.py │ │ ├── attn_patterns.py │ │ ├── base.py │ │ ├── conv2d_patterns.py │ │ ├── matmul_patterns.py │ │ ├── quant │ │ │ ├── __init__.py │ │ │ ├── embedding.py │ │ │ └── linear.py │ │ ├── register_all_patterns.py │ │ └── transform_patterns.py │ │ ├── instruments │ │ ├── __init__.py │ │ ├── base.py │ │ ├── convert_flowgraph_to_vgpu.py │ │ ├── profile_instrument.py │ │ └── save_graph_instrument.py │ │ ├── resolve_variant.py │ │ ├── selective_quantize.py │ │ ├── subgraph_rewrite.py │ │ └── utils.py │ ├── hip │ ├── __init__.py │ ├── capability.py │ ├── device.py │ ├── event.py │ ├── ffi.py │ ├── graph.py │ ├── memory.py │ └── stream.py │ ├── ir │ ├── __init__.py │ ├── analyzers │ │ ├── __init__.py │ │ └── bound_analyzer.py │ ├── builders │ │ ├── __init__.py │ │ ├── func_builder.py │ │ └── stmt_builder.py │ ├── compute │ │ ├── __init__.py │ │ ├── cops │ │ │ ├── __init__.py │ │ │ ├── matmul.py │ │ │ ├── pad.py │ │ │ └── reduce.py │ │ ├── primitives.py │ │ └── reduce_operations.py │ ├── cute │ │ ├── __init__.py │ │ ├── algorithm.py │ │ ├── collective │ │ │ ├── __init__.py │ │ │ └── copy.py │ │ ├── contexts.py │ │ ├── expr.py │ │ ├── int_tuple.py │ │ ├── layout.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── 
arithmetic.py │ │ │ ├── copy.py │ │ │ ├── misc.py │ │ │ ├── mma.py │ │ │ ├── partition.py │ │ │ ├── rearrange.py │ │ │ ├── reduce.py │ │ │ ├── subtensor.py │ │ │ ├── tensor.py │ │ │ └── tiled_tensor_view.py │ │ ├── swizzle.py │ │ ├── type.py │ │ └── typing.py │ ├── dialects │ │ ├── __init__.py │ │ └── pattern.py │ ├── dtypes │ │ ├── __init__.py │ │ ├── boolean.py │ │ ├── complex.py │ │ ├── floats.py │ │ ├── integer.py │ │ ├── integer_subbyte.py │ │ ├── promotion.py │ │ ├── utils.py │ │ └── vector.py │ ├── expr.py │ ├── func.py │ ├── functors │ │ ├── __init__.py │ │ ├── base_functor.py │ │ ├── compute_functor.py │ │ ├── cute_functor.py │ │ ├── expr_functor.py │ │ ├── ir_functor.py │ │ ├── layout_functor.py │ │ ├── mapping_functor.py │ │ ├── module_functor.py │ │ ├── stmt_functor.py │ │ └── type_functor.py │ ├── layout.py │ ├── library │ │ ├── __init__.py │ │ ├── cuda │ │ │ ├── __init__.py │ │ │ ├── cublas │ │ │ │ ├── __init__.py │ │ │ │ ├── kernels.py │ │ │ │ └── regs.py │ │ │ └── matmul │ │ │ │ ├── __init__.py │ │ │ │ └── simt.py │ │ ├── tune.py │ │ └── utils.py │ ├── mapping.py │ ├── module.py │ ├── node.py │ ├── polinomial.py │ ├── primitives │ │ ├── __init__.py │ │ ├── complex.py │ │ ├── cpu │ │ │ ├── __init__.py │ │ │ ├── atomic.py │ │ │ ├── avx.py │ │ │ ├── avx_helper.py │ │ │ └── math │ │ │ │ ├── __init__.py │ │ │ │ ├── bfloat16.py │ │ │ │ ├── float16.py │ │ │ │ ├── float32.py │ │ │ │ ├── float64.py │ │ │ │ ├── int32.py │ │ │ │ └── int64.py │ │ ├── cuda │ │ │ ├── __init__.py │ │ │ ├── atomic.py │ │ │ ├── barrier.py │ │ │ ├── cluster.py │ │ │ ├── copy_tma.py │ │ │ ├── cp_async.py │ │ │ ├── cvt.py │ │ │ ├── cvta.py │ │ │ ├── errchk.py │ │ │ ├── fastintdiv.py │ │ │ ├── funcs.py │ │ │ ├── half.py │ │ │ ├── ldst.py │ │ │ ├── lop3.py │ │ │ ├── math │ │ │ │ ├── __init__.py │ │ │ │ ├── bfloat16.py │ │ │ │ ├── complex128.py │ │ │ │ ├── complex64.py │ │ │ │ ├── float16.py │ │ │ │ ├── float16x2.py │ │ │ │ ├── float32.py │ │ │ │ ├── float64.py │ │ │ │ ├── float8e4m3.py │ 
│ │ │ ├── float8e5m2.py │ │ │ │ ├── int32.py │ │ │ │ └── int64.py │ │ │ ├── memcpy.py │ │ │ ├── mma.py │ │ │ ├── mutex.py │ │ │ ├── nccl.py │ │ │ ├── prmt.py │ │ │ ├── setmaxnreg.py │ │ │ ├── shfl.py │ │ │ ├── smem.py │ │ │ ├── sync.py │ │ │ ├── tensor_map.py │ │ │ ├── time.py │ │ │ ├── vars.py │ │ │ ├── wgmma.py │ │ │ └── wmma.py │ │ ├── debug.py │ │ ├── func.py │ │ ├── hip │ │ │ ├── __init__.py │ │ │ ├── buffer_addr.py │ │ │ ├── errchk.py │ │ │ ├── lds_sync.py │ │ │ ├── math │ │ │ │ ├── __init__.py │ │ │ │ ├── float16.py │ │ │ │ ├── float32.py │ │ │ │ └── int32.py │ │ │ ├── mfma.py │ │ │ └── vars.py │ │ ├── math.py │ │ ├── runtime.py │ │ └── vars.py │ ├── schedulers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── cpu │ │ │ ├── __init__.py │ │ │ └── scheduler.py │ │ └── cuda │ │ │ ├── __init__.py │ │ │ └── scheduler.py │ ├── stmt.py │ ├── target.py │ ├── task.py │ ├── tools │ │ ├── __init__.py │ │ ├── free_var_collector.py │ │ ├── hasher.py │ │ ├── ir_dumper.py │ │ ├── printer.py │ │ ├── renamer.py │ │ ├── rewriter.py │ │ ├── simplifier.py │ │ ├── type_infer.py │ │ └── util_functors.py │ ├── type.py │ └── utils │ │ ├── __init__.py │ │ ├── broadcast_utils.py │ │ ├── call_graph.py │ │ ├── expr_utils.py │ │ ├── hash_sum.py │ │ ├── index_transform.py │ │ └── type_utils.py │ ├── lang │ ├── __init__.py │ ├── attrs.py │ ├── attrs │ │ ├── __init__.py │ │ ├── cuda.py │ │ └── hip.py │ ├── constructs │ │ ├── __init__.py │ │ ├── context.py │ │ ├── declare.py │ │ ├── loops.py │ │ └── meta.py │ ├── cpu.py │ ├── cuda │ │ ├── __init__.py │ │ └── contexts.py │ ├── hip.py │ ├── layout.py │ ├── mapping.py │ ├── runtime.py │ ├── script.py │ ├── transpiler.py │ └── types.py │ ├── libinfo.py │ ├── logging.py │ ├── option.py │ ├── runtime │ ├── __init__.py │ ├── compiled_graph.py │ ├── compiled_module.py │ ├── compiled_task.py │ ├── device.py │ ├── storage.py │ └── utils │ │ └── dispatch_table.py │ ├── testing │ ├── __init__.py │ ├── capture_stdout.py │ ├── models │ │ ├── __init__.py │ │ 
├── gemma.py │ │ ├── gpt2.py │ │ ├── llama.py │ │ └── resnet.py │ ├── onnx_models.py │ ├── onnx_utils.py │ ├── torch_utils.py │ └── utils.py │ ├── transforms │ ├── __init__.py │ ├── add_explicit_cast.py │ ├── add_hints.py │ ├── annotate_header_and_libs.py │ ├── attach_hash_to_signature.py │ ├── base.py │ ├── check_launch_configuration.py │ ├── convert_div_to_fastintdiv.py │ ├── cute │ │ ├── __init__.py │ │ ├── analysis │ │ │ ├── __init__.py │ │ │ └── tensor_alias_analysis.py │ │ ├── cuda │ │ │ ├── __init__.py │ │ │ ├── annotate_mbarrier.py │ │ │ ├── cost_model.py │ │ │ ├── instantiate_auto_annotation.py │ │ │ ├── instruction_selection.py │ │ │ ├── lower_cute_dialect.py │ │ │ ├── lower_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── arithmetic.py │ │ │ │ ├── collective.py │ │ │ │ ├── copy.py │ │ │ │ ├── mbarrier.py │ │ │ │ ├── misc.py │ │ │ │ ├── mma.py │ │ │ │ ├── partition.py │ │ │ │ ├── rearrange.py │ │ │ │ ├── reduce.py │ │ │ │ ├── registry.py │ │ │ │ ├── subtensor.py │ │ │ │ └── tensor.py │ │ │ ├── resolve_bank_conflict.py │ │ │ ├── shared_memory_allocation.py │ │ │ ├── tma_fallback_copy.py │ │ │ ├── tma_layout_utils.py │ │ │ └── vectorize_elementwise.py │ │ └── generic │ │ │ ├── __init__.py │ │ │ ├── canonicalize.py │ │ │ ├── canonicalize_arithmetic_expression.py │ │ │ └── deadcode_elimination.py │ ├── declare_to_let.py │ ├── expand_let_expr.py │ ├── expand_repeat.py │ ├── explicit_unroll.py │ ├── flatten_tensor_index.py │ ├── flatten_tensor_slice.py │ ├── generate_launch_func.py │ ├── import_primitive_functions.py │ ├── inline_function.py │ ├── inline_let_stmt.py │ ├── instantiate_symbols.py │ ├── instruments │ │ ├── __init__.py │ │ ├── base.py │ │ ├── profile_instrument.py │ │ └── save_ir_instrument.py │ ├── lower_integer_subbyte.py │ ├── lower_protect_access.py │ ├── lower_special_cast.py │ ├── lower_task_mapping.py │ ├── normalize_const_tensor.py │ ├── propagate_launch_bound.py │ ├── resolve_generic_primitive_function.py │ ├── rule_based_simplifier.py │ ├── 
simplify_addition_chain.py │ ├── simplify_stmt.py │ ├── spatial_simplification.py │ ├── task_mapping_bound_check.py │ └── unify_global_objects.py │ ├── utils │ ├── __init__.py │ ├── benchmark │ │ ├── __init__.py │ │ ├── bench.py │ │ ├── gpu_freq.py │ │ └── utils.py │ ├── cache_utils.py │ ├── counters.py │ ├── cuda_sanitizer.py │ ├── dataclass.py │ ├── doc.py │ ├── exiting.py │ ├── fault_handler.py │ ├── files.py │ ├── folder_lock.py │ ├── gc.py │ ├── git_utils.py │ ├── multiprocess.py │ ├── namer.py │ ├── ncu_utils.py │ ├── net_utils.py │ ├── netron.py │ ├── nsys_utils.py │ ├── omniperf_utils.py │ ├── ort_utils.py │ ├── overrides.py │ ├── py.py │ ├── stack_limit.py │ ├── structure.py │ ├── torch_utils.py │ ├── trace_utils.py │ └── transformers_utils.py │ └── version.py ├── scripts ├── bench │ ├── README.md │ ├── benchmark.py │ ├── requirements.txt │ └── run.py ├── lint │ ├── .clang-format │ ├── _format.py │ ├── add_copyright.py │ ├── format.sh │ ├── lint.sh │ └── pylintrc ├── nightly-builds │ ├── README.md │ ├── add-crontab-record.sh │ └── update-nightly.sh ├── regression │ ├── __init__.py │ ├── email_sender.py │ ├── model_performance.py │ ├── op_performance.py │ ├── regression_data.json │ ├── requirements.txt │ ├── result_entry.py │ └── run.py └── wheel │ ├── build_wheel.sh │ ├── build_wheel_manylinux_2_28_x86_64.sh │ ├── current_version.py │ ├── dockerfiles │ └── manylinux_2_28_x86_64 │ │ └── Dockerfile │ ├── update_nightly.sh │ └── upload_wheel_to_pypi.sh ├── setup.py ├── src └── hidet │ ├── empty.cpp │ └── runtime │ ├── callbacks.cpp │ ├── cpu │ └── context.cpp │ ├── cuda │ ├── context.cpp │ ├── cublas.cpp │ ├── cuda.cpp │ ├── cudnn.cpp │ └── utils.h │ ├── hip │ └── context.cpp │ ├── int_fastdiv.cpp │ ├── logging.cpp │ └── symbols.cpp └── tests ├── README.md ├── benchmarks ├── bench_dynamic.py ├── bench_op.py ├── bench_op_torch_api.py ├── bench_task.py ├── bench_transformer.py ├── bench_transformer_comptime.py ├── bench_vision.py ├── run_configs.json ├── 
run_configs_full.json └── run_tests.py ├── conftest.py ├── cuda ├── test_cublas.py ├── test_cuda_graph.py └── test_cudnn.py ├── cute_fusion ├── fusion_bench_utils.py ├── test_epilogue_fusion.py ├── test_issue436.py ├── test_matmul_bias_cast.py ├── test_matmul_cast_dynamic.py ├── test_matmul_residual_bias.py ├── test_matmul_residual_permute.py ├── test_matmul_riskfuel.py └── test_matmul_standalone.py ├── cute_kernel ├── gemm_quant_nbit.py ├── quant_utils.py ├── quantized_linear.py ├── test_attention_bwd.py ├── test_cute_attention.py ├── test_flash_decoding.py ├── test_gemm_quant_nbit.py ├── test_gemv.py ├── test_hopper_gemm.py ├── test_mlp.py ├── test_moe_align.py ├── test_quant_linear.py ├── test_rmsnorm.py └── test_weight_quantization.py ├── cute_ops ├── __init__.py ├── test_arithmetic2.py ├── test_collective.py ├── test_ldst.py ├── test_mma.py ├── test_rearrange.py └── test_reduce.py ├── distributed ├── test_file_store.py ├── test_op.py ├── test_runtime.py ├── test_tcp_store.py └── utils.py ├── flowgraph └── test_graph_visualization.py ├── frontends ├── onnx │ └── test_onnx_slice.py └── torch │ ├── models │ ├── test_torch_bert.py │ ├── test_torch_densenet121.py │ ├── test_torch_pegasus.py │ └── test_torch_resnet50.py │ ├── subgraphs │ ├── test_torch_graph_flatten.py │ ├── test_torch_graph_pad.py │ └── test_torch_graph_to.py │ ├── test_flow_graph_cache.py │ ├── test_torch_activation.py │ ├── test_torch_arithmetic.py │ ├── test_torch_conv1d.py │ ├── test_torch_conv1d_transpose.py │ ├── test_torch_conv2d.py │ ├── test_torch_conv2d_transpose.py │ ├── test_torch_conv3d.py │ ├── test_torch_conv3d_transpose.py │ ├── test_torch_creation.py │ ├── test_torch_dyn_shape.py │ ├── test_torch_fxgraph.py │ ├── test_torch_image.py │ ├── test_torch_inplace.py │ ├── test_torch_interoperability.py │ ├── test_torch_mix_cuda_cpu.py │ ├── test_torch_mul.py │ ├── test_torch_norm.py │ ├── test_torch_pooling.py │ ├── test_torch_reduce.py │ ├── test_torch_sdpa.py │ ├── test_torch_split.py 
│ ├── test_torch_stream.py │ └── test_torch_to.py ├── hip └── test_hip_graph.py ├── ir ├── dialects │ └── test_pattern.py ├── dtypes │ ├── test_fp8e4m3.py │ └── test_fp8e5m2.py ├── functors │ └── test_persistence.py ├── parser │ └── test_parser.py ├── primitives │ ├── cuda │ │ ├── test_barrier.py │ │ ├── test_cluster.py │ │ ├── test_copy_tma.py │ │ ├── test_exp2.py │ │ ├── test_half.py │ │ ├── test_lop3.py │ │ ├── test_mma.py │ │ ├── test_prmt.py │ │ └── test_wgmma.py │ └── hip │ │ ├── test_buffer_addr.py │ │ └── test_mfma.py ├── test_expr_const_fold.py ├── test_int_subbyte.py ├── test_layout.py ├── test_primitives.py └── test_symbol_var.py ├── minimal └── test_add.py ├── models ├── test_gemma.py ├── test_gpt2.py ├── test_llama.py ├── test_llama_graph.py └── test_quant_gpt2.py ├── multiprocessing ├── lazy_init_sample.py └── test_lazy_initialization.py ├── operators_core ├── test_activation.py ├── test_arithmetic.py ├── test_attention.py ├── test_compare.py ├── test_complex.py ├── test_conv1d_transpose.py ├── test_create.py ├── test_fusion.py ├── test_identity.py ├── test_inplace_operator.py ├── test_matmul.py ├── test_norm.py ├── test_opaque.py ├── test_operator.py ├── test_quantization.py ├── test_reduce.py ├── test_softmax.py ├── test_symmetric_quant.py ├── test_tensor.py ├── test_transform.py └── test_tri.py ├── operators_vision ├── test_conv1d.py ├── test_conv2d.py ├── test_conv2d_transpose.py ├── test_conv3d.py ├── test_conv3d_transpose.py ├── test_image.py └── test_pool.py ├── runtime ├── test_dispatch_table.py └── test_try_catch.py ├── script ├── test_assignment.py ├── test_comprehension.py ├── test_constant.py ├── test_context.py ├── test_cpu_kernel.py ├── test_for_loops.py ├── test_func_import.py ├── test_lambda.py ├── test_meta.py ├── test_parallel.py ├── test_return_type.py └── test_unroll.py ├── transforms ├── test_graph_rewrites.py ├── test_rule_based_simplifier.py └── test_simplify_addition_chain.py ├── unit_tests ├── check_import_time.py ├── 
test_compiled_model.py ├── test_dynamic_shape.py ├── test_frontend_onnx.py ├── test_import_time.py ├── test_save_lower_ir.py └── test_vllm_hidet_compile.py └── utils ├── benchmark └── example.py ├── test_cuda_sanitizer.py ├── test_ncu_utils.py └── test_nsys_utils.py /.github/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:25.01-py3 2 | ADD ./hidet /workspace/hidet 3 | ADD ./models /workspace/models 4 | WORKDIR /workspace 5 | RUN pip install -r hidet/.github/requirements-ci.txt && \ 6 | bash hidet/scripts/wheel/build_wheel.sh && \ 7 | WHEEL=$(find hidet/scripts/wheel/built_wheel -maxdepth 1 -name '*.whl') && \ 8 | pip install --force-reinstall $WHEEL[dev] && \ 9 | pip install -e models && \ 10 | hidet cache clear --all 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug] " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior. A small and reproducible script would be very helpful. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Environment** 20 | - OS: [e.g. Ubuntu 22.04] 21 | - GPU: [e.g. RTX 3090] 22 | - Others: [e.g. NVIDIA GPU Driver 525.85.12] 23 | 24 | **Additional context** 25 | Add any other context about the problem here. 
26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/actions/build-wheel/action.yaml: -------------------------------------------------------------------------------- 1 | name: 'Build wheel' 2 | description: 'Build a wheel with given version label' 3 | outputs: 4 | wheel_path: 5 | description: 'the path to the generated wheel' 6 | wheel_name: 7 | description: 'the full name of the wheel file' 8 | runs: 9 | using: 'docker' 10 | image: '../../../scripts/wheel/dockerfiles/manylinux_2_28_x86_64/Dockerfile' 11 | args: 12 | - "bash" 13 | - "./.github/scripts/build_wheel.sh" 14 | -------------------------------------------------------------------------------- /.github/actions/run-tests/action.yaml: -------------------------------------------------------------------------------- 1 | name: 'Run Tests' 2 | description: 'Runs test suite on a specified GPU' 3 | inputs: 4 | path: 5 | description: "The path to the tests" 6 | required: true 7 | runs: 8 | using: "composite" 9 | steps: 10 | - name: Run tests 11 | shell: bash 12 | run: | 13 | (while 
true; do echo "Heartbeat: $(date)"; sleep 30; done) & 14 | echo_loop_pid=$! 15 | 16 | # Ensure the heartbeat is stopped on exit, even if pytest fails. 17 | trap "kill $echo_loop_pid" EXIT 18 | 19 | rm -rf ~/.config/hidet 20 | nice -n 10 python -m pytest -v --durations=20 --clear-cache ${{ inputs.path }} 21 | 22 | # Fix of https://github.com/CentML/hidet/issues/928 23 | python .github/scripts/cuda_cleanup.py 24 | -------------------------------------------------------------------------------- /.github/requirements-ci.txt: -------------------------------------------------------------------------------- 1 | mysql-connector-python -------------------------------------------------------------------------------- /.github/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidet-org/hidet/599f628d30e235ade0a680297597fb84a0d7a54e/.github/scripts/__init__.py -------------------------------------------------------------------------------- /.github/scripts/bench/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidet-org/hidet/599f628d30e235ade0a680297597fb84a0d7a54e/.github/scripts/bench/__init__.py -------------------------------------------------------------------------------- /.github/scripts/build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # stop immediately if a command exits with a non-zero status. 
5 | set -e 6 | 7 | # print the executed commands 8 | set -x 9 | 10 | # use ./scripts/wheel/build_wheel.sh to build the wheel 11 | bash scripts/wheel/build_wheel.sh 12 | 13 | echo $(pwd) 14 | WHEEL=$(find scripts/wheel/built_wheel -maxdepth 1 -name '*.whl') 15 | WHEEL_FILENAME=$(basename "$WHEEL") 16 | 17 | echo "wheel_path=./scripts/wheel/built_wheel" >> "$GITHUB_OUTPUT" 18 | echo "wheel_name=${WHEEL_FILENAME}" >> "$GITHUB_OUTPUT" 19 | -------------------------------------------------------------------------------- /.github/scripts/cuda_cleanup.py: -------------------------------------------------------------------------------- 1 | # Script to ensure proper CUDA cleanup before process termination 2 | import torch 3 | import time 4 | import gc 5 | import sys 6 | 7 | print("=== Starting CUDA cleanup and synchronization ===") 8 | sys.stdout.flush() # Ensure output is visible in logs 9 | 10 | # Force garbage collection first 11 | gc.collect() 12 | 13 | # Explicitly empty CUDA cache 14 | torch.cuda.empty_cache() 15 | 16 | # Synchronize all CUDA operations 17 | print("Synchronizing CUDA devices...") 18 | for i in range(torch.cuda.device_count()): 19 | torch.cuda.synchronize(i) 20 | 21 | # Small delay to allow driver operations to complete 22 | time.sleep(2) 23 | 24 | # Print memory stats for debugging 25 | print(f"CUDA memory allocated: {torch.cuda.memory_allocated()} bytes") 26 | print(f"CUDA memory reserved: {torch.cuda.memory_reserved()} bytes") 27 | print("=== CUDA cleanup complete ===") 28 | sys.stdout.flush() -------------------------------------------------------------------------------- /.github/scripts/db_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mysql.connector 3 | 4 | def get_db_conn(): 5 | conn = mysql.connector.connect( 6 | host=os.environ.get('CI_DB_HOSTNAME'), 7 | user=os.environ.get('CI_DB_USERNAME'), 8 | password=os.environ.get('CI_DB_PASSWORD'), 9 | 
port=os.environ.get('CI_DB_PORT'), 10 | database='hidet_ci' 11 | ) 12 | return conn 13 | -------------------------------------------------------------------------------- /.github/scripts/set_gpu_types.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # expects $INPUT_GPU_L4, $INPUT_GPU_H100, $INPUT_GPU_A10, $INPUT_GPU_A100 4 | echo "Selecting GPU types based on input or github event" 5 | echo "Triggered by event: $GITHUB_EVENT_NAME" 6 | if [ "$GITHUB_EVENT_NAME" == "workflow_dispatch" ]; then 7 | echo "Using manual dispatch inputs" 8 | GPU_L4="$INPUT_GPU_L4" 9 | GPU_H100="$INPUT_GPU_H100" 10 | GPU_A10="$INPUT_GPU_A10" 11 | GPU_A100="$INPUT_GPU_A100" 12 | elif [ "$GITHUB_EVENT_NAME" == "pull_request" ]; then 13 | echo "Using pull request defaults" 14 | GPU_L4="true" 15 | GPU_H100="true" 16 | GPU_A10="false" 17 | GPU_A100="false" 18 | elif [ "$GITHUB_EVENT_NAME" == "push" ]; then 19 | echo "Using push defaults" 20 | GPU_L4="true" 21 | GPU_H100="true" 22 | GPU_A10="false" 23 | GPU_A100="false" 24 | elif [ "$GITHUB_EVENT_NAME" == "schedule" ]; then 25 | echo "Using scheduled run defaults" 26 | GPU_L4="false" 27 | GPU_H100="false" 28 | GPU_A10="true" 29 | GPU_A100="true" 30 | else 31 | echo "Unknown event type. Exiting." 32 | exit 1 33 | fi 34 | 35 | # Write the parameters to GITHUB_OUTPUT so they can be used in later steps 36 | echo "gpu_l4=$GPU_L4" >> $GITHUB_OUTPUT 37 | echo "gpu_h100=$GPU_H100" >> $GITHUB_OUTPUT 38 | echo "gpu_a10=$GPU_A10" >> $GITHUB_OUTPUT 39 | echo "gpu_a100=$GPU_A100" >> $GITHUB_OUTPUT -------------------------------------------------------------------------------- /.github/scripts/set_test_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sets the strategy matrix for the functional ci tests. 
3 | This mimics the discovery strategy used by pytest for files inside the tests/ folder 4 | and shards them based on the top level parent folders. 5 | 6 | Expects to be executed in a GHA environment, with GITHUB_OUTPUT context available. 7 | """ 8 | import glob 9 | import json 10 | import os 11 | from pathlib import Path 12 | 13 | patterns = ('test_*.py', '*_test.py') # the tuple of file types 14 | files_matched = [] 15 | for pattern in patterns: 16 | files_matched.extend(glob.glob(f"tests/**/{pattern}", recursive=True)) 17 | 18 | testing_paths = [] 19 | for path in files_matched: 20 | current_path = Path(path) 21 | testing_paths.append("/".join(current_path.parts[:2])) 22 | 23 | include = [] 24 | 25 | for path in list(set(testing_paths)): 26 | include.append({ 27 | "path": path 28 | }) 29 | 30 | matrix = { 31 | "include": include 32 | } 33 | 34 | matrix_str = json.dumps(matrix) 35 | name = 'matrix' 36 | value = matrix_str 37 | with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: 38 | print(f'{name}={value}', file=fh) -------------------------------------------------------------------------------- /.github/workflows/nightly.yaml: -------------------------------------------------------------------------------- 1 | name: Nightly Workflow 2 | 3 | permissions: 4 | contents: read 5 | 6 | on: 7 | schedule: 8 | - cron: '0 0 * * *' # Run every day at midnight 9 | workflow_dispatch: 10 | inputs: 11 | runner_group: 12 | type: choice 13 | options: 14 | - arc-l4 15 | - arc-a10 16 | - arc-h100 17 | description: "Runner group to run tests. 
# Image for the hidet compilation server, based on the CUDA 12.6.2 development
# image for Ubuntu 22.04.
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04

# Only run.py and requirements.txt are copied into the image; app.py and the
# resources/ package are not.
# NOTE(review): presumably run.py obtains the server code at start-up -- confirm.
COPY ./run.py /app/run.py
COPY ./requirements.txt /app/requirements.txt
WORKDIR /app

# Set the container time zone.
ENV TZ=America/Toronto
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# Install Python and a few tools, drop the apt package lists, make `python`
# an alias of python3, then install the Python dependencies.
# filelock, requests, gunicorn and flask are installed explicitly here and are
# also listed in requirements.txt.
RUN apt-get update && apt-get install -y \
    python3-dev \
    python3-pip \
    python3-setuptools \
    vim \
    wget \
    git \
    && rm -rf /var/lib/apt/lists/* \
    && ln -s /usr/bin/python3 /usr/bin/python \
    && python -m pip install --upgrade pip \
    && python -m pip install filelock requests gunicorn flask cmake \
    && python -m pip install -r ./requirements.txt

# Port the compilation server listens on (the same port is published by run.sh).
EXPOSE 3281

CMD ["python", "run.py"]
"""Flask entry point of the hidet compilation server."""
import os

from flask import Flask, send_from_directory
from flask_jwt_extended import JWTManager
from flask_restful import Api

from resources import CompilationResource, AuthResource, UserResource
from resources.utils import validate_path

app = Flask(__name__)
# The signing key can be supplied through the JWT_SECRET_KEY environment
# variable; the previous hard-coded value is kept as the fallback so existing
# deployments keep working.  Set the variable in any real deployment.
app.config['JWT_SECRET_KEY'] = os.environ.get('JWT_SECRET_KEY', 'jwt-secret-string')
api = Api(app)
jwt = JWTManager(app)

api.add_resource(CompilationResource, '/compile')
api.add_resource(AuthResource, '/auth')
api.add_resource(UserResource, '/user')


# The URL variable is required: without it Flask never passes ``filename`` to
# the view.  NOTE(review): the ``path`` converter (which also accepts '/') is
# assumed here; confirm against the upstream file.
@app.route('/download/<path:filename>')
def download(filename):
    """Send a file from the ``results`` directory as an attachment.

    Parameters
    ----------
    filename: str
        Path of the requested file, relative to ``<cwd>/results``.

    Returns
    -------
    The file as an attachment, ``('Invalid file path', 400)`` when the resolved
    path is outside the results directory, or ``('File not found', 404)`` when
    it does not exist.
    """
    results_dir = os.path.join(os.getcwd(), 'results')
    path = os.path.join(results_dir, filename)
    # Reject anything that resolves outside of the results directory.
    if not validate_path(path, results_dir):
        return 'Invalid file path', 400
    if not os.path.exists(path):
        return 'File not found', 404
    return send_from_directory(results_dir, filename, as_attachment=True)


if __name__ == '__main__':
    app.run(debug=False, port=3281)
import hmac

from flask import request
from flask_jwt_extended import create_access_token
from flask_restful import Resource

from .user import users


class AuthResource(Resource):
    """Exchange a username/password pair for a JWT access token."""

    def post(self):
        """Authenticate the user described by the JSON request body.

        The body is expected to be a JSON object with string fields
        ``username`` and ``password``.

        Returns
        -------
        ``{'access_token': <token>}`` on success, otherwise
        ``({'message': 'Invalid credentials'}, 401)``.
        """
        invalid = {'message': 'Invalid credentials'}, 401

        payload = request.json
        # A JSON body that is not an object (e.g. a list or a string) has no
        # ``.get``; treat it as bad credentials instead of crashing.
        if not isinstance(payload, dict):
            return invalid

        username = payload.get('username')
        password = payload.get('password')
        if not isinstance(username, str) or not isinstance(password, str):
            return invalid

        if username not in users:
            return invalid

        stored = users[username]
        # Constant-time comparison, so the response time does not reveal how
        # much of the password matched.
        if not isinstance(stored, str) or not hmac.compare_digest(
            stored.encode('utf-8'), password.encode('utf-8')
        ):
            return invalid

        # NOTE(review): expires_delta=False issues a token that never expires.
        access_token = create_access_token(identity=username, expires_delta=False)
        return {'access_token': access_token}
import time
from flask import request
from flask_jwt_extended import create_access_token
from flask_restful import Resource
from filelock import FileLock


# Record a timestamp when this module is imported.  The file must be opened
# for writing: with mode 'r' the write below fails (and opening fails outright
# when the file does not exist yet).
with FileLock('last_compile_timestamp.txt.lock'):
    with open('last_compile_timestamp.txt', 'w') as f:
        f.write(str(time.time()))


class StatusResource(Resource):
    """Answer status queries about the compilation server."""

    def get(self):
        """Handle a status query given as a JSON object ``{"query": <name>}``.

        Returns
        -------
        ``{'timestamp': <file content>}`` for the query
        ``'last_compile_timestamp'``; ``({'message': 'Invalid query'}, 400)``
        for any other, missing or malformed query.
        """
        # ``request.json`` is a property, not a method, so it cannot be called.
        # get_json(silent=True) returns None instead of raising on a
        # missing or malformed body.
        payload = request.get_json(silent=True)
        query = payload.get('query') if isinstance(payload, dict) else None

        if query == 'last_compile_timestamp':
            with FileLock('last_compile_timestamp.txt.lock'):
                with open('last_compile_timestamp.txt', 'r') as f:
                    return {'timestamp': f.read()}
        return {'message': 'Invalid query'}, 400
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Run the Sphinx "clean" step and also remove the ./source/gallery directory.
clean:
	@$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
	rm -rf ./source/gallery

# "clean" is listed as phony so that a file or directory named "clean" can
# never make the target look up to date.
.PHONY: help clean Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-------------------------------------------------------------------------------- 1 | import conf 2 | 3 | conf.hidet.utils.clear_op_cache() 4 | -------------------------------------------------------------------------------- /docs/source/developer-guides/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | To contribute to this project, please fork the hidet repository and create a pull request. 5 | Before submitting a pull request, please make sure that your code is properly formatted and that it passes the tests. 6 | 7 | **Install editable dev version of hidet** 8 | 9 | .. code-block:: bash 10 | 11 | $ git clone hidet-org/hidet 12 | $ cd hidet 13 | $ pip install -e .[dev] 14 | 15 | **Format & Lint** Run the following scripts to format and lint the code: 16 | 17 | .. code-block:: bash 18 | 19 | $ bash scripts/lint/format.sh 20 | $ bash scripts/lint/lint.sh 21 | 22 | **Tests** To run the tests, run the following script: 23 | 24 | .. code-block:: bash 25 | 26 | $ # use --clear-cache to clear the operator cache if you changed the following sub-modules 27 | $ # - hidet.ir 28 | $ # - hidet.backend 29 | $ pytest tests 30 | 31 | -------------------------------------------------------------------------------- /docs/source/genindex.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ===== -------------------------------------------------------------------------------- /docs/source/getting-started/build-from-source.rst: -------------------------------------------------------------------------------- 1 | Build from source 2 | ------------------- 3 | .. _Build-from-source: 4 | 5 | If you want to contribute to Hidet, or you encountered any problem directly installing hidet via pip, it is better to install 6 | hidet from source. 7 | 8 | Clone the code 9 | ~~~~~~~~~~~~~~ 10 | 11 | First clone the repository to local: 12 | 13 | .. 
code-block:: console 14 | 15 | $ git clone https://github.com/hidet-org/hidet 16 | $ cd hidet # enter the hidet directory 17 | 18 | Install the Hidet Python package 19 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | Next we will install the Python package of Hidet via pip. The following command will install the package in develop 22 | mode, in which modifications of the source code are immediately reflected in the installed package. If you want to 23 | install the package in normal mode, use 'pip install .' instead. 24 | 25 | .. code-block:: console 26 | 27 | $ pip install -e . 28 | 29 | Validation 30 | ~~~~~~~~~~ 31 | 32 | To make sure we have successfully installed hidet, run the following command in a new shell: 33 | 34 | .. code-block:: console 35 | 36 | $ python -c "import hidet" 37 | 38 | If no error is reported, then hidet has been successfully installed on your computer. 39 | -------------------------------------------------------------------------------- /docs/source/getting-started/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Run the following command to install the ``hidet`` package via python pip: 5 | 6 | .. code-block:: console 7 | 8 | $ pip install hidet 9 | 10 | To verify the installation, run the following command: 11 | 12 | .. code-block:: console 13 | 14 | $ python -c "import hidet" 15 | 16 | .. tip:: 17 | 18 | You can also install the nightly version of the ``hidet`` package via python pip: 19 | 20 | .. code-block:: console 21 | 22 | $ pip install --pre --extra-index-url https://download.hidet.org/whl hidet 23 | 24 | 25 | If you want, you can also :doc:`build from source <build-from-source>`. 26 | 27 | ..
toctree:: 28 | :hidden: 29 | 30 | build-from-source 31 | -------------------------------------------------------------------------------- /docs/source/hidet-script/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | This section contains a collection of examples that demonstrate how to use Hidet Script to write kernel programs. Each 5 | example is a self-contained hidet script program that can be run directly. 6 | 7 | .. _hidet script examples: 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Hidet Script Examples 12 | 13 | ../../gallery/hidet-script/0-hello-world 14 | ../../gallery/hidet-script/1-scalar-addition 15 | ../../gallery/hidet-script/2-vector-addition 16 | ../../gallery/hidet-script/3-kernel-functions 17 | ../../gallery/hidet-script/4-naive-matmul 18 | ../../gallery/hidet-script/5-efficient-matmul 19 | -------------------------------------------------------------------------------- /docs/source/hidet-script/index.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | Hidet Script is a domain specific language (DSL) for writing tensor programs directly in python. 5 | Users can write tensor programs in python's syntax, with some constraints and extensions. 6 | A transpiler is used to translate the python abstract syntax tree (AST) to Hidet's tensor program IR. 7 | Then, the translated tensor programs in Hidet IR are optimized and compiled to the target binary for execution. 8 | The tensor program writer works in the python environment throughout the whole process. 9 | 10 | 11 | To get started, please refer to the :ref:`hidet script examples`.
12 | -------------------------------------------------------------------------------- /docs/source/hidet-script/reference/4-function.rst: -------------------------------------------------------------------------------- 1 | Function 2 | ======== 3 | 4 | Function kinds 5 | -------------- 6 | 7 | A function can be one of the following kinds: 8 | 9 | - ``public``: a public function can be invoked in python directly 10 | - ``cuda_kernel``: a cuda kernel function 11 | - ``cuda_internal``: a cuda device function that can only be invoked by cuda kernel/device functions 12 | - ``cpu_kernel``: a cpu kernel function 13 | - ``cpu_internal``: a cpu function that will be used by other cpu functions 14 | 15 | Only the ``public`` functions will be exposed to python. If a module defines a kernel function 16 | (i.e., ``cuda_kernel`` or ``cpu_kernel``) and does not contain a ``public`` function named ``launch``, then hidet 17 | will automatically create a ``public`` function named ``launch`` that launches the kernel function. 18 | -------------------------------------------------------------------------------- /docs/source/hidet-script/reference/5-module.rst: -------------------------------------------------------------------------------- 1 | Module 2 | ====== 3 | 4 | Script module 5 | ------------- 6 | 7 | A script module is a collection of hidet script functions and global variables. It serves as a compilation unit 8 | of hidet. We can use ``hidet.script_module()`` to create a script module. The created script module can be used as 9 | a python context manager like 10 | 11 | .. code-block:: 12 | 13 | import hidet 14 | from hidet.lang import attrs 15 | from hidet.lang.types import f32 16 | 17 | with hidet.script_module() as script_module: 18 | # define global variables like 19 | script_module.define_global_var(name='global_var', var_type=f32) 20 | ...
21 | 22 | # define functions like 23 | @hidet.script 24 | def foo(): 25 | attrs.func_kind = 'public' # the function kind is mandatory 26 | ... 27 | 28 | # we can define multiple functions in the script module and call each other 29 | 30 | # we can build the script module to get a CompiledModule (hidet.runtime.CompiledModule) 31 | # that can be invoked in python directly 32 | module = script_module.build() 33 | -------------------------------------------------------------------------------- /docs/source/hidet-script/reference/7-cpu-specific.rst: -------------------------------------------------------------------------------- 1 | CPU Specifics 2 | ============= 3 | 4 | Primitive functions 5 | ------------------- 6 | 7 | Hidet provides primitives for using the AVX instructions of modern CPUs. They include 8 | 9 | - ``avx_f32x4_load(...)``: vectorized load of 4 f32 values from memory 10 | - ``avx_f32x4_store(...)``: vectorized store of 4 f32 values to memory 11 | - ``avx_f32x4_fmadd(...)``: vectorized fused multiply-add operation 12 | - ``avx_f32x4_setzero(...)``: get the zero initialized vector 13 | - ``avx_f32x4_broadcast(...)``: broadcast a scalar to a vector 14 | 15 | There are also corresponding ``f32x8`` primitives. 16 | 17 | Multi-threading 18 | --------------- 19 | 20 | Hidet relies on OpenMP to support multi-threading. To use multi-threading, please specify the 21 | ``p`` attribute of the ``hidet.lang.grid`` or ``hidet.lang.mapping.repeat`` functions. 22 | -------------------------------------------------------------------------------- /docs/source/hidet-script/reference/index.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | ========= 3 | 4 | Like other programming languages, Hidet Script has its own type system, expressions, statements, functions, and modules. 5 | 6 | Each module is a compilation unit, and it contains a collection of functions and global variables.
Each function 7 | executes a series of statements. The function can define variables and manipulate them. Each variable has its data 8 | type. 9 | 10 | The details of the type system, expressions, statements, functions, and modules are described in the following sections. 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Hidet Script Examples 15 | 16 | 1-type-system 17 | 2-expression 18 | 3-statement 19 | 4-function 20 | 5-module 21 | 6-cuda-specific 22 | 7-cpu-specific 23 | -------------------------------------------------------------------------------- /docs/source/how-to-guides/add-new-operator/index.rst: -------------------------------------------------------------------------------- 1 | Add New Operator 2 | ================ 3 | 4 | 5 | 6 | Hidet is designed to be extensible. It is easy to add new operators to Hidet. There are two ways to add and schedule 7 | an operator. 8 | 9 | 1. **Rule-based Scheduling** 10 | Define the mathematical computation of the operator, and Hidet will automatically schedule the computation into 11 | parallel tensor program with Hidet's rule-based scheduler. 12 | 2. **Template-based Scheduling** 13 | Besides the computation, user can also give the concrete implementation of the operator to achieve better performance 14 | for complex operators. 15 | 16 | .. toctree:: 17 | :maxdepth: 1 18 | :caption: Define Computation 19 | 20 | ../../gallery/developer-guides/add-new-operator-compute-definition 21 | 22 | .. 
toctree:: 23 | :maxdepth: 1 24 | :caption: Two Scheduling Methods 25 | 26 | ../../gallery/developer-guides/add-new-operator-rule-based 27 | ../../gallery/developer-guides/add-new-operator-template-based 28 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Hidet's Documentation 2 | ================================ 3 | 4 | Hidet is an open-source DNN inference framework, it features 5 | 6 | - **Ease of Use**: Support end to end inference for PyTorch and ONNX models. 7 | - **High Performance**: Graph-level optimizations and operator-level kernel tuning. 8 | - **Extensibility**: Easy to add new operators, and fusion patterns. 9 | - **Python Oriented**: All core components are written in Python. 10 | 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | :caption: Getting Started 15 | 16 | getting-started/install 17 | gallery/getting-started/quick-start 18 | 19 | .. toctree:: 20 | :maxdepth: 1 21 | :caption: Hidet Script 22 | 23 | hidet-script/index 24 | hidet-script/examples/index 25 | hidet-script/reference/index 26 | 27 | 28 | .. toctree:: 29 | :maxdepth: 1 30 | :caption: Developer Guide 31 | 32 | gallery/developer-guides/add-torch-operator-mapping 33 | how-to-guides/add-new-operator/index 34 | gallery/developer-guides/add-operator-resolve-rule 35 | gallery/developer-guides/add-subgraph-rewrite-rule 36 | developer-guides/contributing.rst 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | :caption: Notes 41 | 42 | notes/operator-cache 43 | gallery/how-to-guides/visualize-flow-graph 44 | 45 | .. 
toctree:: 46 | :maxdepth: 1 47 | :caption: Reference 48 | 49 | python_api/index.rst 50 | genindex 51 | -------------------------------------------------------------------------------- /docs/source/python_api/data_types.rst: -------------------------------------------------------------------------------- 1 | hidet.dtypes 2 | ============ 3 | 4 | Hidet supports the following primitive data types, which can be used as the ``dtype`` parameter of functions like 5 | :func:`hidet.zeros` and :func:`hidet.ones`: 6 | 7 | .. data:: hidet.uint8 8 | .. data:: hidet.uint16 9 | .. data:: hidet.uint32 10 | .. data:: hidet.uint64 11 | .. data:: hidet.int8 12 | .. data:: hidet.int16 13 | .. data:: hidet.int32 14 | .. data:: hidet.int64 15 | .. data:: hidet.float16 16 | .. data:: hidet.float32 17 | .. data:: hidet.float64 18 | .. data:: hidet.bfloat16 19 | .. data:: hidet.tfloat32 20 | .. data:: hidet.boolean 21 | -------------------------------------------------------------------------------- /docs/source/python_api/drivers.rst: -------------------------------------------------------------------------------- 1 | hidet.drivers 2 | ------------- 3 | 4 | .. automodule:: hidet.drivers 5 | :members: 6 | :imported-members: 7 | :autosummary: 8 | -------------------------------------------------------------------------------- /docs/source/python_api/ffi/index.rst: -------------------------------------------------------------------------------- 1 | hidet.ffi 2 | --------- 3 | 4 | .. automodule:: hidet.ffi 5 | :members: 6 | :imported-members: 7 | :autosummary: 8 | 9 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/frontend/index.rst: -------------------------------------------------------------------------------- 1 | hidet.graph.frontend 2 | ==================== 3 | 4 | ..
toctree:: 5 | :caption: Submodules 6 | 7 | onnx 8 | torch 9 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/frontend/onnx.rst: -------------------------------------------------------------------------------- 1 | hidet.graph.frontend.onnx 2 | ------------------------- 3 | 4 | .. autofunction:: hidet.graph.frontend.from_onnx 5 | 6 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/frontend/torch.rst: -------------------------------------------------------------------------------- 1 | hidet.graph.frontend.torch 2 | -------------------------- 3 | 4 | .. autofunction:: hidet.graph.frontend.from_torch 5 | 6 | .. autoclass:: hidet.graph.frontend.torch.DynamoConfig 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/index.rst: -------------------------------------------------------------------------------- 1 | hidet.graph 2 | =========== 3 | 4 | .. toctree:: 5 | :caption: Submodules 6 | 7 | frontend/index 8 | transforms/index 9 | 10 | 11 | .. automodule:: hidet.graph 12 | :members: 13 | :imported-members: 14 | :special-members: __call__ 15 | :exclude-members: Tensor 16 | :autosummary: 17 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/transforms/index.rst: -------------------------------------------------------------------------------- 1 | hidet.graph.transforms 2 | ====================== 3 | 4 | .. toctree:: 5 | :caption: Transforms 6 | 7 | subgraph_rewrite 8 | resolve_variant 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/transforms/resolve_variant.rst: -------------------------------------------------------------------------------- 1 | Resolve Operator Pass 2 | ===================== 3 | 4 | .. 
autoclass:: hidet.graph.transforms.resolve_variant.ResolveRule 5 | :members: 6 | 7 | .. autofunction:: hidet.graph.transforms.resolve_variant.register_resolve_rule 8 | -------------------------------------------------------------------------------- /docs/source/python_api/graph/transforms/subgraph_rewrite.rst: -------------------------------------------------------------------------------- 1 | Sub-graph Rewrite Pass 2 | ---------------------- 3 | 4 | 5 | .. autoclass:: hidet.graph.transforms.subgraph_rewrite.TensorPattern 6 | :members: 7 | 8 | .. autoclass:: hidet.graph.transforms.subgraph_rewrite.OperatorPattern 9 | :members: 10 | 11 | .. autoclass:: hidet.graph.transforms.subgraph_rewrite.SubgraphRewriteRule 12 | :members: 13 | 14 | .. autofunction:: hidet.graph.transforms.subgraph_rewrite.register_rewrite_rule 15 | -------------------------------------------------------------------------------- /docs/source/python_api/index.rst: -------------------------------------------------------------------------------- 1 | Python API 2 | ========== 3 | 4 | .. note:: 5 | 6 | We are actively working on adding more api documentation. 7 | 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Python API 12 | 13 | root 14 | option 15 | cuda 16 | tensor 17 | data_types 18 | drivers 19 | ir/index 20 | ops/index 21 | graph/index 22 | runtime/index 23 | ffi/index 24 | utils/index 25 | testing/index 26 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/compute.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.compute 2 | ================ 3 | 4 | .. tip:: 5 | 6 | Please refer to :doc:`here ` for how to use these 7 | compute primitives to define a computation task. 8 | 9 | 10 | .. 
automodule:: hidet.ir.compute 11 | :members: 12 | :imported-members: 13 | :autosummary: 14 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/expr.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.expr 2 | ============= 3 | 4 | .. automodule:: hidet.ir.expr 5 | :members: 6 | :autosummary: 7 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/func.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.func 2 | ============= 3 | 4 | .. automodule:: hidet.ir.func 5 | :members: 6 | :autosummary: 7 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/index.rst: -------------------------------------------------------------------------------- 1 | hidet.ir 2 | ======== 3 | 4 | .. toctree:: 5 | :caption: Submodules 6 | 7 | type 8 | expr 9 | stmt 10 | func 11 | compute 12 | task 13 | 14 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/stmt.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.stmt 2 | ============= 3 | 4 | .. automodule:: hidet.ir.stmt 5 | :members: 6 | :autosummary: 7 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/task.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.task 2 | ============= 3 | 4 | .. automodule:: hidet.ir.task 5 | :members: 6 | :autosummary: 7 | -------------------------------------------------------------------------------- /docs/source/python_api/ir/type.rst: -------------------------------------------------------------------------------- 1 | hidet.ir.type 2 | ============= 3 | 4 | .. autoclass:: hidet.ir.type.DataType 5 | 6 | .. 
autofunction:: hidet.ir.type.data_type 7 | -------------------------------------------------------------------------------- /docs/source/python_api/ops/index.rst: -------------------------------------------------------------------------------- 1 | hidet.ops 2 | ========= 3 | 4 | .. todo:: 5 | 6 | We are still working on the documentation of operators. 7 | 8 | .. automodule:: hidet.ops 9 | :members: 10 | :undoc-members: 11 | :imported-members: 12 | :autosummary: 13 | -------------------------------------------------------------------------------- /docs/source/python_api/option.rst: -------------------------------------------------------------------------------- 1 | hidet.option 2 | ------------ 3 | 4 | .. automodule:: hidet.option 5 | :members: 6 | :autosummary: 7 | :member-order: groupwise 8 | -------------------------------------------------------------------------------- /docs/source/python_api/root.rst: -------------------------------------------------------------------------------- 1 | hidet 2 | ----- 3 | 4 | 5 | .. automodule:: hidet 6 | :members: 7 | :exclude-members: FlowGraph, Tensor, Operator, Task 8 | :imported-members: 9 | :autosummary: 10 | -------------------------------------------------------------------------------- /docs/source/python_api/runtime/index.rst: -------------------------------------------------------------------------------- 1 | hidet.runtime 2 | ============= 3 | 4 | .. automodule:: hidet.runtime 5 | :members: 6 | :imported-members: 7 | :exclude-members: CudaGraph 8 | :autosummary: 9 | -------------------------------------------------------------------------------- /docs/source/python_api/tensor.rst: -------------------------------------------------------------------------------- 1 | hidet.Tensor 2 | ============ 3 | 4 | 5 | .. autoclass:: hidet.Tensor 6 | 7 | .. autoattribute:: shape 8 | .. autoattribute:: dtype 9 | .. autoattribute:: device 10 | .. autoattribute:: size 11 | .. autoattribute:: nbytes 12 | .. 
autoattribute:: storage 13 | .. autoattribute:: trace 14 | .. autoattribute:: op 15 | .. autoattribute:: layout 16 | .. automethod:: tolist 17 | .. automethod:: to_device 18 | .. automethod:: astype 19 | .. automethod:: cpu 20 | .. automethod:: cuda 21 | .. automethod:: copy 22 | .. automethod:: cpu_async 23 | .. automethod:: cuda_async 24 | .. automethod:: copy_async 25 | .. automethod:: detach 26 | .. automethod:: numpy 27 | .. automethod:: torch 28 | .. automethod:: to 29 | .. automethod:: item 30 | .. automethod:: signature 31 | .. automethod:: is_symbolic 32 | .. automethod:: contiguous 33 | .. automethod:: reshape 34 | .. automethod:: squeeze 35 | .. automethod:: unsqueeze 36 | .. automethod:: rearrange 37 | .. automethod:: sum 38 | .. automethod:: mean 39 | -------------------------------------------------------------------------------- /docs/source/python_api/testing/index.rst: -------------------------------------------------------------------------------- 1 | hidet.testing 2 | ============= 3 | 4 | .. automodule:: hidet.testing 5 | :members: 6 | :imported-members: 7 | :autosummary: 8 | -------------------------------------------------------------------------------- /docs/source/python_api/utils/index.rst: -------------------------------------------------------------------------------- 1 | hidet.utils 2 | ============= 3 | 4 | .. 
automodule:: hidet.utils 5 | :members: 6 | :imported-members: 7 | :autosummary: 8 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{taso, 2 | title={TASO: optimizing deep learning computation with automatic generation of graph substitutions}, 3 | author={Jia, Zhihao and Padon, Oded and Thomas, James and Warszawski, Todd and Zaharia, Matei and Aiken, Alex}, 4 | booktitle={Proceedings of the 27th ACM Symposium on Operating Systems Principles}, 5 | pages={47--62}, 6 | year={2019} 7 | } 8 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidet-org/hidet/599f628d30e235ade0a680297597fb84a0d7a54e/examples/README.md -------------------------------------------------------------------------------- /gallery/README.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ===== 3 | -------------------------------------------------------------------------------- /gallery/developer-guides/README.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hidet-org/hidet/599f628d30e235ade0a680297597fb84a0d7a54e/gallery/developer-guides/README.rst -------------------------------------------------------------------------------- /gallery/getting-started/README.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ===== -------------------------------------------------------------------------------- /gallery/hidet-script/2-vector-addition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vector Addition 3 | =============== 4 | """ 5 | 6 | # %% 7 | # In this example, 
# we will show you how to write a program that adds two float32 vectors in hidet script.
import hidet
from hidet.lang import attrs
from hidet.lang.types import f32

hidet.option.cache_dir('./outs/cache')

# %%
# In the script function, we annotate the data type of parameter ``a``, ``b``, and ``c`` as ``f32[3]``, which means
# a 3-element vector of 32-bit floating point numbers. In general, we can use ``dtype[shape]`` to define a tensor with
# given shape and data type. For example, ``f32[3, 4]`` is a 3x4 float32 matrix, and ``int32[3, 4, 5]`` is a 3x4x5 int32
# tensor.
#
# We can use ``for i in range(extent)`` to iterate over a range, where ``extent`` is the extent of the loop. The extent
# must not exceed the length of the vectors (3 here), otherwise the loop would access elements out of bounds.
with hidet.script_module() as script_module:

    @hidet.script
    def launch(a: f32[3], b: f32[3], c: f32[3]):
        attrs.func_kind = 'public'

        # every vector holds exactly 3 elements, so the loop extent is 3
        for i in range(3):
            c[i] = a[i] + b[i]


module = script_module.build()

# %%
# Create the input and output tensors (on cpu, with f32 data type by default):
a = hidet.randn([3])
b = hidet.randn([3])
c = hidet.empty([3])

# %%
# Call the compiled module with the input and output tensors
module(a, b, c)
print(a)
print(b)
print(c)
-------------------------------------------------------------------------------- /include/hidet/runtime.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | -------------------------------------------------------------------------------- /include/hidet/runtime/callbacks.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | #include 13 | #include 14 | 15 | DLL void register_callback(const char *name, void *func_ptr); 16 | 17 | DLL uint64_t allocate_cuda_storage(uint64_t nbytes); 18 | 19 | DLL void free_cuda_storage(uint64_t ptr); 20 | 21 | DLL uint64_t allocate_cpu_storage(uint64_t nbytes); 22 | 23 | DLL void free_cpu_storage(uint64_t ptr); 24 | 25 | DLL void cuda_memset(uint64_t ptr, int value, uint64_t nbytes); 26 | 27 | DLL uint64_t allocate_hip_storage(uint64_t nbytes); 28 | 29 | DLL void free_hip_storage(uint64_t ptr); 30 | 31 | DLL void hip_memset(uint64_t ptr, int value, uint64_t nbytes); 32 | 33 | DLL uint64_t get_torch_stream(); 34 | -------------------------------------------------------------------------------- /include/hidet/runtime/common.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | #pragma once 13 | #include 14 | #include 15 | 16 | #ifndef DLL 17 | #define DLL extern "C" __attribute__((visibility("default"))) 18 | #endif 19 | -------------------------------------------------------------------------------- /include/hidet/runtime/context.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
/* A raw buffer together with the number of bytes currently allocated for it. */
struct Workspace {
    void *base;
    size_t allocated_nbytes;

    /* Start with no buffer attached and zero bytes allocated. */
    Workspace() : base(nullptr), allocated_nbytes(0) {}
};

/* State shared by the device-specific contexts that derive from this struct. */
struct BaseContext {
    /* The clean workspace. The buffer only contains zero values. */
    Workspace clean_workspace;
    /* The dirty workspace. The buffer contains arbitrary values. */
    Workspace dirty_workspace;
};
/* Reciprocal square root of a single-precision value: 1 / sqrt(x). */
static inline float rsqrtf(float x) {
    const float root = sqrtf(x);
    return 1.0f / root;
}
/* Opaque cuDNN handle type, forward-declared here rather than pulled in from a cuDNN header. */
struct cudnnContext;
typedef struct cudnnContext *cudnnHandle_t;

/* Backend descriptors are passed around as untyped pointers. */
typedef void *cudnnBackendDescriptor_t;

/* Legacy API */
struct cudnnTensorStruct;
struct cudnnFilterStruct;
struct cudnnConvolutionStruct;

typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t;
typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t;

/* Holds the cuDNN handles used by the runtime. */
struct CudnnContext {
    /* One slot per GPU, up to HIDET_CUDNN_MAX_GPUS entries.
     * NOTE(review): presumably indexed by device id -- confirm against the implementation. */
    cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS];
    /* Accessor for the shared context instance (defined outside this header). */
    static CudnnContext *global();
    /* Accessor for the handle currently in use (defined outside this header). */
    static cudnnHandle_t current_handle();
};

/* Exported with C linkage; takes the path of the cuDNN library as a C string. */
DLL void hidet_cudnn_set_library_path(const char *path);
/* Pack of four `half` values, addressed as x, y, z, w. */
struct half4 {
    half x, y, z, w;

    /* Default constructor: does not explicitly initialize the members. */
    __device__ half4() {}
    /* Member-wise constructor. */
    __device__ half4(half x, half y, half z, half w) : x(x), y(y), z(z), w(w) {}
};

/* Pack of eight `half` values, addressed as x, y, z, w, a, b, c, d. */
struct half8 {
    half x, y, z, w, a, b, c, d;

    /* Default constructor: does not explicitly initialize the members. */
    __device__ half8() {}
    /* Member-wise constructor. */
    __device__ half8(half x, half y, half z, half w, half a, half b, half c, half d)
        : x(x), y(y), z(z), w(w), a(a), b(b), c(c), d(d) {}
};
12 | #pragma once 13 | 14 | #include 15 | 16 | #ifdef __CUDA_ARCH__ 17 | #define HOST_DEVICE __host__ __device__ 18 | #else 19 | #define HOST_DEVICE 20 | #endif 21 | 22 | HOST_DEVICE void calculate_magic_numbers(int d, int &m, int &s, int &as); 23 | -------------------------------------------------------------------------------- /include/hidet/runtime/symbols.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | #pragma once 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | DLL void reset_symbol_table(); 20 | 21 | DLL int32_t get_symbol_value(const char *symbol_name); 22 | 23 | DLL void set_symbol_value(const char *symbol_name, int32_t value); 24 | 25 | DLL void *get_ptr_symbol_value(const char *symbol_name); 26 | 27 | DLL void set_ptr_symbol_value(const char *symbol_name, void *value); 28 | -------------------------------------------------------------------------------- /include/hidet/runtime/torch/stream.h: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | #pragma once 13 | #include 14 | 15 | DLL void *hidet_get_current_torch_stream(); 16 | -------------------------------------------------------------------------------- /python/hidet/apps/compile_server/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .compilation import remote_build 13 | from .core import init_api 14 | -------------------------------------------------------------------------------- /python/hidet/apps/compile_server/auth.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
def get_access_token(username, password):
    """Request an access token from the compile server.

    Posts the credentials as JSON to the ``auth`` resource and returns the
    ``access_token`` field of the JSON reply.

    Raises
    ------
    RuntimeError
        If the server can not be reached, or if it answers with a status
        code other than 200.
    """
    credentials = {'username': username, 'password': password}
    try:
        reply = requests.post(api_url('auth'), json=credentials)
    except requests.exceptions.ConnectionError:
        # hide the low-level connection traceback; the message names the server url
        raise RuntimeError('Can not connect to compiler server {}'.format(api_url(''))) from None

    if reply.status_code == 200:
        return reply.json()['access_token']
    raise RuntimeError('Failed to get access token: {}'.format(reply.json()['message']))
# Lazily initialized connection state, filled in by init_api().
_api_url: Optional[str] = None
_access_token: Optional[str] = None


def init_api():
    """Read the compile server options and log in to obtain an access token."""
    global _api_url, _access_token
    from .auth import get_access_token

    host = hidet.option.get_option('compile_server.addr')
    port = hidet.option.get_option('compile_server.port')
    # the url is stored before the token is requested
    _api_url = 'http://{}:{}'.format(host, port)

    user = hidet.option.get_option('compile_server.username')
    secret = hidet.option.get_option('compile_server.password')
    _access_token = get_access_token(user, secret)


def api_url(resource: str):
    """Return the url of ``resource`` on the compile server, initializing the api on first use."""
    if _api_url is None:
        init_api()
    return '{}/{}'.format(_api_url, resource)


def access_token():
    """Return the access token, initializing the api on first use."""
    if _access_token is None:
        init_api()
    return _access_token
def add_user(username, password):
    """Ask the compile server to create a user and print the server's message.

    A reply with any status code other than 201 is reported with an
    ``Error: `` prefix.
    """
    endpoint = api_url('user')
    payload = {'username': username, 'password': password}
    auth_header = {'Authorization': f'Bearer {access_token()}'}
    reply = requests.post(endpoint, json=payload, headers=auth_header)

    message = reply.json()['message']
    if reply.status_code == 201:
        print(message)
    else:
        print('Error: ', message)


def delete_user(username):
    """Ask the compile server to delete a user and print the server's message.

    A reply with any status code other than 200 is reported with an
    ``Error: `` prefix.
    """
    endpoint = api_url('user')
    payload = {'username': username}
    auth_header = {'Authorization': f'Bearer {access_token()}'}
    reply = requests.delete(endpoint, json=payload, headers=auth_header)

    message = reply.json()['message']
    if reply.status_code == 200:
        print(message)
    else:
        print('Error: ', message)
# Fail at import time with an actionable message when torch dynamo is unavailable.
if not torch_availability.dynamo_available():
    raise RuntimeError(
        # adjacent literals are concatenated, so the first one must end with a space
        'PyTorch version is less than 2.0. Please upgrade PyTorch to 2.0 or higher to enable torch dynamo '
        'which is required by the benchmark scripts.'
    )
@click.command(name='all')
def bench_all():
    """Benchmark every registered model and print the results as a github-style table."""
    columns = BenchModel.headers()
    rows = []
    for candidate in all_registered_models:
        rows.append(candidate.benchmark())

    table = tabulate(rows, headers=columns, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
    click.echo(table)
@click.command(name='common')
def bench_common():
    """Benchmark the commonly used models and print the results as a github-style table."""
    columns = BenchModel.headers()
    rows = []
    for candidate in commonly_used_models:
        rows.append(candidate.benchmark())

    table = tabulate(rows, headers=columns, tablefmt='github', floatfmt='.3f', numalign='right', stralign='left')
    click.echo(table)
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .inception_v3 import bench_inception_v3 13 | from .mobilenet_v2 import bench_mobilenet_v2 14 | from .resnet import bench_resnet 15 | from .resnext import bench_resnext 16 | -------------------------------------------------------------------------------- /python/hidet/cli/bench/vision/vision_model.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
class VisionModel(BenchModel):
    """
    Benchmark entry for a vision model fetched through ``torch.hub``.

    The instance only records the hub entry name and the four sizes used to build a random
    input tensor; torch itself is imported lazily inside the methods that need it.
    """

    def __init__(self, model_name: str, batch_size, channels: int, height: int, width: int):
        # Assignment order matches the parameter order.
        self.model_name, self.batch_size = model_name, batch_size
        self.channels, self.height, self.width = channels, height, width

    def __str__(self):
        # The hub entry name doubles as the display name.
        return self.model_name

    def model(self):
        """Load the pretrained model named ``self.model_name`` from the ``pytorch/vision:v0.6.0`` hub repo."""
        import torch

        hub_options = dict(pretrained=True, verbose=False)
        return torch.hub.load('pytorch/vision:v0.6.0', self.model_name, **hub_options)

    def example_inputs(self):
        """Return ``(args, kwargs)``: one random tensor as the only positional input, no keyword inputs."""
        import torch

        input_shape = (self.batch_size, self.channels, self.height, self.width)
        return (torch.randn(*input_shape),), {}
@click.group(name='cache', help='Manage hidet cache.')
def hidet_cache_group():
    # Pure container for the `cache` sub-commands; the group callback itself does nothing.
    pass


# Attach the sub-commands imported from the sibling modules at import time. The isinstance
# check guards against one of the imported names not being a click command object.
for command in (hidet_cache_status, hidet_cache_clear):
    assert isinstance(command, click.Command)
    hidet_cache_group.add_command(command)
def get_size(path: str) -> int:
    """
    Return the total size in bytes of the file or directory tree at `path`.

    A path that does not exist counts as 0 and a regular file counts as its own size. Anything
    else is listed with ``os.scandir`` and the sizes of its entries are summed recursively.
    """
    if not os.path.exists(path):
        return 0
    if os.path.isfile(path):
        return os.path.getsize(path)
    return sum(get_size(entry.path) for entry in os.scandir(path))


def nbytes2str(nbytes: int) -> str:
    """
    Format a byte count as a short human-readable string with a binary-unit suffix.

    The value is divided by 1024 until it drops below 128, so e.g. 500 bytes is rendered in KiB
    as a fraction. A value that is still an ``int`` when it is printed (i.e. it was never
    divided) is shown without decimals; everything else is shown with two decimals. Values
    that are still not below 128 after the TiB step are reported in PiB.
    """
    value = nbytes
    for suffix in ('B', 'KiB', 'MiB', 'GiB', 'TiB'):
        if value < 128:
            template = '{} {}' if isinstance(value, int) else '{:.2f} {}'
            return template.format(value, suffix)
        # True division: from here on the value is a float.
        value = value / 1024
    return '{:.2f} PiB'.format(value)
@click.group(name='hidet')
def main():
    # Root `hidet` command group. Kept free of a docstring on purpose: click would show a
    # docstring as the group's --help text, which would change the CLI output.
    pass


@initialize()
def register_commands():
    """
    Attach the bench and cache command groups to the root ``hidet`` group.

    NOTE(review): decorated with ``hidet.utils.initialize`` -- presumably so that it runs once at
    import time; confirm against that helper.
    """
    for group in (hidet_bench_group, hidet_cache_group):
        assert isinstance(group, click.Command)
        main.add_command(group)


if __name__ == '__main__':
    main()
import cudnn 22 | -------------------------------------------------------------------------------- /python/hidet/cuda/cublas/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .ffi import cublasComputeType, cudaDataType 13 | from .kernels import gemm, strided_gemm, batched_gemm 14 | -------------------------------------------------------------------------------- /python/hidet/cuda/cudnn/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .ffi import cudnnDataType 13 | from .kernels import conv2d, conv2d_gemm 14 | -------------------------------------------------------------------------------- /python/hidet/cuda/nccl/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .ffi import nccl_available, nccl_version, nccl_library_filename, group_start, group_end 13 | from .comm import ( 14 | create_comm, 15 | NcclUniqueId, 16 | NcclDataType, 17 | NcclRedOp, 18 | comms_to_array, 19 | create_unique_id, 20 | dtype_to_nccl, 21 | NcclCommunicator, 22 | str_to_nccl_op, 23 | NCCL_SPLIT_NOCOLOR, 24 | ) 25 | -------------------------------------------------------------------------------- /python/hidet/cuda/nccl/libinfo.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
def _get_nccl_dirs():
    """
    List the ``nvidia/nccl`` directory under every site-packages root.

    The global site-packages roots come first, followed by the per-user one. The directories
    are not checked for existence.
    """
    import site

    roots = list(site.getsitepackages())
    roots.append(site.getusersitepackages())
    return [os.path.join(root, 'nvidia', 'nccl') for root in roots]


def get_nccl_include_dirs():
    """Return the ``include`` sub-directory of every candidate NCCL directory."""
    include_dirs = []
    for nccl_dir in _get_nccl_dirs():
        include_dirs.append(os.path.join(nccl_dir, 'include'))
    return include_dirs


def get_nccl_library_search_dirs():
    """Return the ``lib`` sub-directory of every candidate NCCL directory."""
    library_dirs = []
    for nccl_dir in _get_nccl_dirs():
        library_dirs.append(os.path.join(nccl_dir, 'lib'))
    return library_dirs
12 | 13 | from .distributed import ( 14 | init_process_group, 15 | all_reduce, 16 | broadcast, 17 | reduce, 18 | all_gather, 19 | all_gather_into_tensor, 20 | gather, 21 | scatter, 22 | reduce_scatter, 23 | reduce_scatter_tensor, 24 | barrier, 25 | send, 26 | recv, 27 | is_initialized, 28 | get_default_group, 29 | ) 30 | from .group import set_nccl_comms 31 | from .store import FileStore, TCPStore 32 | -------------------------------------------------------------------------------- /python/hidet/drivers/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .build_module import build_ir_module, build_ir_module_batch 13 | from .build_task import build_task, build_task_batch 14 | from .build_graph import build_flow_graph 15 | -------------------------------------------------------------------------------- /python/hidet/ffi/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
def ctypes_func_pointer(ctypes_func) -> int:
    """
    Return the address of a ctypes function object as a plain integer.

    The object is cast to ``ctypes.c_void_p`` and the pointer's ``value`` is returned.
    """
    import ctypes

    as_void_pointer = ctypes.cast(ctypes_func, ctypes.c_void_p)
    return as_void_pointer.value
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import tensor 13 | from . import operator 14 | from . import nn 15 | from . import ops 16 | from . import flow_graph 17 | from . import frontend 18 | 19 | from .tensor import Tensor 20 | from .operator import Operator 21 | from .flow_graph import FlowGraph 22 | from .transforms import GraphPass, PassContext, GraphPassInstrument 23 | from .flow_graph import GraphForwardContext, GraphForwardInstrument 24 | from .nn import Module 25 | from .graph_utils.instruments import GraphForwardBenchmarkInstrument, GraphForwardDebugInstrument 26 | 27 | from .tensor import asarray, randn, empty, zeros, ones, symbol, randint, randn_like, empty_like, zeros_like, ones_like 28 | from .tensor import symbol_like, full, full_like 29 | from .tensor import from_numpy, from_dlpack, from_torch 30 | from .flow_graph import trace_from, load_graph, save_graph, forward_context 31 | from .transforms import optimize, quant 32 | -------------------------------------------------------------------------------- /python/hidet/graph/common.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
def normalize(v, num=2):
    """
    Expand a scalar setting into a list of `num` copies.

    A list or tuple is returned unchanged (the very same object, with no length check); any
    other value is repeated `num` times in a new list.
    """
    already_a_sequence = isinstance(v, (list, tuple))
    return v if already_a_sequence else [v] * num
# Probe for the optional onnx dependency once, at import time. When the import fails the
# module-level name `onnx` is still defined (as None) so later references do not raise.
try:
    import onnx  # pylint: disable=unused-import

    _available = True
except ImportError:
    onnx, _available = None, False


def available():
    """
    Tell whether the onnx package could be imported.

    Returns
    -------
    ret: bool
        True if importing onnx succeeded when this module was loaded, False otherwise.
    """
    return _available
12 | from .benchmark_instrument import GraphForwardBenchmarkInstrument 13 | from .debug_instrument import GraphForwardDebugInstrument 14 | -------------------------------------------------------------------------------- /python/hidet/graph/impl/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | -------------------------------------------------------------------------------- /python/hidet/graph/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import module 13 | from . 
import container 14 | 15 | from .module import Module 16 | from .identity import Identity 17 | from .container import Sequential, ModuleList 18 | from .attention import CrossAttention 19 | from .activations import Relu, Gelu, Geglu, Tanh 20 | from .convolutions import Conv2d 21 | from .linear import Linear, LinearTransposed 22 | from .norms import BatchNorm2d, LayerNorm, GroupNorm 23 | from .poolings import MaxPool2d, AvgPool2d, AdaptiveAvgPool2d 24 | from .transforms import Embedding 25 | -------------------------------------------------------------------------------- /python/hidet/graph/nn/activations.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
class Relu(Module):
    """Module wrapper that applies ``ops.relu`` to its input."""

    def forward(self, x):
        activated = ops.relu(x)
        return activated


class Gelu(Module):
    """Module wrapper that applies ``ops.gelu`` to its input."""

    def forward(self, x):
        activated = ops.gelu(x)
        return activated


class Geglu(Module):
    """
    Gated GELU block.

    The input goes through a ``Linear`` layer with ``2 * dim_out`` output features; the result is
    split into two parts along axis 2 and the first part is multiplied by the GELU of the second.
    NOTE(review): the hard-coded ``axis=2`` presumably means the input is expected to be 3-D --
    confirm against the callers.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        # A single projection produces both the values and the gate.
        projected_features = dim_out * 2
        self.proj = Linear(dim_in, projected_features, bias=bias)

    def forward(self, x):
        projected = self.proj(x)
        values, gate = ops.split(projected, 2, axis=2)
        return values * ops.gelu(gate)


class Tanh(Module):
    """Module wrapper that applies ``ops.tanh`` to its input."""

    def forward(self, x):
        activated = ops.tanh(x)
        return activated


class Identity(Module):
    """
    Module that returns its input unchanged.

    Serves as a placeholder when a layer is removed from a container (e.g. a module list) but
    the container indices still have to line up with the corresponding torch model. Any
    constructor arguments are accepted and merely stored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.args, self.kwargs = args, kwargs

    def forward(self, x):
        return x
class Embedding(Module):
    """
    Lookup table that maps indices to rows of a ``[num_embeddings, embedding_dim]`` weight.

    The weight is created with ``empty`` as float32, i.e. its contents are not initialised here.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        table_shape = [num_embeddings, embedding_dim]
        self.weight = empty(shape=table_shape, dtype='float32')

    def forward(self, indices: Tensor) -> Tensor:
        # Gather along axis 0 of the weight, i.e. pick whole rows.
        gathered = ops.take(self.weight, indices, axis=0)
        return gathered
12 | from .attention import attention 13 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .conv1d import conv1d 13 | from .conv1d import Conv1dOp 14 | from .conv1d_gemm import conv1d_gemm 15 | 16 | from . import resolve 17 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv1d/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
def infer_conv1d_shape(
    x_shape: Sequence[int], w_shape: Sequence[int], stride: int, groups: int, dilation: int
) -> List[int]:
    """
    Infer the output shape of a 1-D convolution.

    Parameters
    ----------
    x_shape: Sequence[int]
        Input shape, unpacked as (batch, in_channels, length).
    w_shape: Sequence[int]
        Weight shape, unpacked as (out_channels, channels_per_group, kernel_length).
    stride: int
        Stride, normalized to a single value through ``normalize_stride(stride, dim=1)``.
    groups: int
        Number of groups.
    dilation: int
        Kernel dilation.

    Returns
    -------
    ret: List[int]
        ``[batch, out_channels, out_length]`` with
        ``out_length = (length - dilation * (kernel_length - 1) - 1) // stride + 1``.
        NOTE(review): the formula has no padding term, so the input is presumably expected to be
        padded already -- confirm against the callers.

    Raises
    ------
    ValueError
        If the (constant) input channel count differs from ``channels_per_group * groups``, or if
        ``out_channels`` is not divisible by ``groups``.
    """
    n, c, d = x_shape
    oc, gc, kd = w_shape
    (sx,) = normalize_stride(stride, dim=1)
    dilx = dilation
    # The channel consistency check only runs when the input channel count is a constant.
    if is_constant(c) and gc * groups != c:
        # The messages used to say "Conv2d" (copied from the 2-D helper); name the right operator.
        msg = 'Conv1d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
        raise ValueError(msg)
    if oc % groups != 0:
        msg = 'Conv1d expects out_channels % groups == 0, got out_channels {} and groups {}'.format(oc, groups)
        raise ValueError(msg)
    p = (d - dilx * (kd - 1) - 1) // sx + 1
    return [n, oc, p]
12 | from .conv1d_transpose import conv1d_transpose 13 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv2d/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .conv2d import conv2d, conv2d_channel_last 13 | from .conv2d import Conv2dOp, Conv2dChannelLastOp 14 | from .conv2d_winograd import conv2d_winograd, conv2d_winograd_image_transform, conv2d_winograd_filter_transform 15 | from .conv2d_winograd import conv2d_winograd_inverse_transform 16 | from .conv2d_winograd import Conv2dWinogradInverseTransformOp, Conv2dWinogradFilterTransformOp 17 | from .conv2d_winograd import Conv2dWinogradImageTransformOp 18 | from .conv2d_gemm import ( 19 | conv2d_gemm, 20 | conv2d_gemm_fp16, 21 | conv2d_gemm_fp16_channel_last, 22 | conv2d_gemm_image_transform, 23 | conv2d_gemm_filter_transform, 24 | ) 25 | from .conv2d_gemm import conv2d_gemm_inverse_transform 26 | from .conv2d_gemm import Conv2dGemmImageTransformOp 27 | 28 | 29 | from . 
import resolve 30 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv2d_transpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .conv2d_transpose import conv2d_transpose, Conv2dTransposeOp 13 | from .conv2d_transpose_gemm import conv2d_transpose_gemm 14 | 15 | from . import resolve 16 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv2d_transpose/resolve.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
@register_resolve_rule(Conv2dTransposeOp)
class Conv2dTransposeResolveRule(ResolveRule):
    """
    Resolve rule registered for Conv2dTransposeOp that rewrites the operator into a call to
    ``ops.conv2d_transpose_gemm`` using the operator's own inputs and attributes.
    """

    def resolve(self, op: Conv2dTransposeOp) -> Optional[List[Tensor]]:
        """
        Replace ``op`` with the gemm-based transposed convolution.

        Parameters
        ----------
        op: Conv2dTransposeOp
            The operator to resolve. Its two inputs are unpacked as (data, weight) and its attrs
            must provide 'stride', 'padding', 'groups' and 'output_padding'.

        Returns
        -------
        ret: Optional[List[Tensor]]
            A single-element list holding the output of ``ops.conv2d_transpose_gemm``.
        """
        op_attrs = op.attrs
        data, weight = op.inputs
        # Forwarded positionally, in the order conv2d_transpose_gemm is called with here.
        forwarded = [op_attrs[key] for key in ('stride', 'padding', 'groups', 'output_padding')]
        return [ops.conv2d_transpose_gemm(data, weight, *forwarded)]
import resolve 20 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/conv3d_transpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .conv3d_transpose import conv3d_transpose 13 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .fused_operator import fused_operator 13 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/matmul/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .matmul import matmul, MatmulOp, MatmulTask, matmul_nt 13 | from .cuda_batch_matmul import cuda_batch_matmul, CudaBatchMatmulOp, BatchMatmulTask 14 | from .matmul_cublas import matmul_cublas 15 | from . import resolve 16 | 17 | 18 | from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86 19 | from .matmul_f32_x86 import matmul_x86 20 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/normalize/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .layers import batch_norm_infer, layer_norm, instance_norm, group_norm 13 | from .lp import lp_norm 14 | from . import resolve 15 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/quant/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .symmetric import symmetric_quantize, symmetric_dequantize 13 | from .matmul import symmetric_quant_matmul 14 | from .matmul_f16_i8 import symmetric_quant_matmul_f16_i8 15 | from .matmul_f16_i8_atomic import symmetric_quant_matmul_atomic_f16_i8 16 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/reduce/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .reduce import ReduceBaseOp, ReduceTask 13 | from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any 14 | from .reduce import ReduceSumOp, ReduceMeanOp 15 | from . import resolve 16 | -------------------------------------------------------------------------------- /python/hidet/graph/ops/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from hidet.ir.utils.broadcast_utils import * 13 | 14 | from .tensor_utils import * 15 | -------------------------------------------------------------------------------- /python/hidet/graph/transforms/graph_patterns/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .base import TensorPattern, OperatorPattern, SubgraphRewriteRule, MatchDict, Usage, graph_pattern_match 13 | from .base import register_rewrite_rule, op_pattern, registered_rewrite_rules, deregister_rewrite_rule 14 | from .base import clear_registered_rewrite_rules 15 | -------------------------------------------------------------------------------- /python/hidet/graph/transforms/graph_patterns/register_all_patterns.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # pylint: disable=unused-import 13 | from .arithmetic_patterns import arithmetic_patterns 14 | from .transform_patterns import transform_patterns 15 | from .attn_patterns import attn_patterns 16 | from .conv2d_patterns import conv2d_patterns 17 | from .matmul_patterns import matmul_patterns 18 | -------------------------------------------------------------------------------- /python/hidet/graph/transforms/instruments/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .base import GraphPassInstrument 13 | from .profile_instrument import ProfileInstrument 14 | from .save_graph_instrument import SaveGraphInstrument 15 | from .convert_flowgraph_to_vgpu import ConvertGraphToVGPU 16 | -------------------------------------------------------------------------------- /python/hidet/graph/transforms/selective_quantize.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
class SelectiveQuantizePass(GraphPass):
    """
    Graph pass that applies the subgraph rewrite rules stored under the 'quantize_patterns'
    key of the current pass context's configs.
    """

    def process_graph(self, graph: FlowGraph) -> FlowGraph:
        """
        Run a SubgraphRewritePass built from the configured quantize patterns on ``graph``.

        Parameters
        ----------
        graph: FlowGraph
            The flow graph to rewrite.

        Returns
        -------
        ret: FlowGraph
            The result of calling the rewrite pass on ``graph``.
        """
        configs = self.current_context().configs
        quantize_rules: List[SubgraphRewriteRule] = configs['quantize_patterns']
        rewriter = SubgraphRewritePass(quantize_rules)
        return rewriter(graph)


def selective_quantize_pass() -> GraphPass:
    """Create a new SelectiveQuantizePass instance."""
    quantize_pass = SelectiveQuantizePass()
    return quantize_pass
def is_barrier(op: Operator):
    """
    Check whether the given operator is a barrier operator.

    Parameters
    ----------
    op: Operator
        The operator to test.

    Returns
    -------
    ret: bool
        True if ``op`` is an instance of ``BarrierOp``, False otherwise.
    """
    # Imported at call time rather than at module level
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from hidet.graph.ops.special import BarrierOp

    matches = isinstance(op, BarrierOp)
    return matches
import bound_analyzer 13 | 14 | from .bound_analyzer import BoundInfo, BoundAnalyzer, infer_bound, normalize_launch_dims 15 | -------------------------------------------------------------------------------- /python/hidet/ir/builders/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import func_builder 13 | from . import stmt_builder 14 | 15 | from .func_builder import FunctionBuilder 16 | from .stmt_builder import StmtBuilder 17 | -------------------------------------------------------------------------------- /python/hidet/ir/compute/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .reduce_operations import ReduceOperation 13 | from .primitives import ComputeNode 14 | from .primitives import TensorNode, ScalarNode 15 | from .primitives import ScalarInput, TensorInput 16 | from .primitives import GridCompute, ReduceCompute, ArgReduceCompute, ReduceType 17 | from .primitives import scalar_input, tensor_input, compute, reduce, arg_reduce 18 | -------------------------------------------------------------------------------- /python/hidet/ir/compute/cops/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .matmul import matmul 13 | from .pad import pad 14 | from .reduce import reduce 15 | -------------------------------------------------------------------------------- /python/hidet/ir/compute/cops/pad.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def pad(data: TensorNode, pads: List[int], value: float):
    """
    Build a compute node that pads ``data`` with a constant value.

    Parameters
    ----------
    data: TensorNode
        The tensor to pad.
    pads: List[int]
        Padding amounts; must hold exactly ``2 * rank`` entries. The first ``rank`` entries are
        added before each dimension and the last ``rank`` entries after it.
    value: float
        The fill value; converted to the data type of ``data``.

    Returns
    -------
    ret:
        The node produced by ``compute`` named 'out', whose extent along each dimension is
        ``before + original + after``.
    """
    in_shape = data.shape
    ndim = len(in_shape)
    assert ndim * 2 == len(pads)

    leading, trailing = pads[:ndim], pads[ndim:]
    padded_shape = [lo + extent + hi for lo, extent, hi in zip(leading, in_shape, trailing)]

    fill = convert(value, dtype=data.type.dtype.name)

    def fcompute(*out_indices):
        # Shift output coordinates back into the coordinate system of the input.
        src = [out_idx - lo for out_idx, lo in zip(out_indices, leading)]
        # An element comes from the input only when every shifted index is within bounds.
        in_bounds = [logical_and(0 <= idx, idx < in_shape[axis]) for axis, idx in enumerate(src)]
        return if_then_else(logical_and(*in_bounds), data[src], fill)

    return compute('out', shape=padded_shape, fcompute=fcompute)
12 | 13 | from .copy import CollectiveStore, collective_store 14 | -------------------------------------------------------------------------------- /python/hidet/ir/dialects/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | -------------------------------------------------------------------------------- /python/hidet/ir/functors/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .base_functor import BaseFunctor, BaseVisitor, BaseRewriter 13 | from .type_functor import TypeFunctor, TypeVisitor, TypeRewriter 14 | from .mapping_functor import MappingFunctor, MappingVisitor, MappingRewriter 15 | from .layout_functor import LayoutFunctor, LayoutVisitor, LayoutRewriter 16 | from .expr_functor import ExprFunctor, ExprVisitor, ExprRewriter 17 | from .stmt_functor import StmtFunctor, StmtVisitor, StmtRewriter 18 | from .compute_functor import ComputeFunctor, ComputeVisitor, ComputeRewriter 19 | from .module_functor import ModuleFunctor, ModuleVisitor, ModuleRewriter 20 | from .ir_functor import IRFunctor, IRVisitor, IRRewriter 21 | -------------------------------------------------------------------------------- /python/hidet/ir/library/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import cuda 13 | from .cuda import cublas 14 | -------------------------------------------------------------------------------- /python/hidet/ir/library/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .matmul import matmul_simt 13 | from . import cublas 14 | -------------------------------------------------------------------------------- /python/hidet/ir/library/cuda/cublas/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from hidet.cuda.cublas.utils import as_type_code 13 | from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType 14 | from .kernels import gemm, strided_gemm, batched_gemm 15 | from . import regs as _regs # register functions 16 | -------------------------------------------------------------------------------- /python/hidet/ir/library/cuda/matmul/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .simt import matmul_simt 13 | -------------------------------------------------------------------------------- /python/hidet/ir/library/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
def get_tensor_type(expr: Expr) -> TensorType:
    """
    Get the tensor type behind an expression.

    Parameters
    ----------
    expr: Expr
        The expression whose type is inferred with ``infer_type``.

    Returns
    -------
    ret: TensorType
        The inferred type itself when it is a TensorType, or the ``tensor_type`` of the inferred
        type when it is a TensorPointerType.

    Raises
    ------
    TypeError
        If the inferred type is neither a TensorType nor a TensorPointerType.
    """
    inferred = infer_type(expr)
    if isinstance(inferred, TensorType):
        return inferred
    if isinstance(inferred, TensorPointerType):
        # A tensor pointer carries the tensor type it points to.
        return inferred.tensor_type
    raise TypeError('Can not infer the expr type to get a tensor type, got {}'.format(inferred))


class Node:
    """Base class providing text rendering through the IR printer."""

    def __str__(self):
        """Return the text produced by ``hidet.ir.tools.printer.astext`` for this node."""
        # Imported lazily, inside the method, instead of at module level.
        from hidet.ir.tools.printer import astext  # pylint: disable=import-outside-toplevel

        text = astext(self)
        return text

    def __repr__(self):
        """Return the same text as ``str(self)``."""
        return str(self)

    def __int__(self):
        """Always return None (NOTE(review): intent not visible here - confirm before changing)."""
        return None
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import math 13 | 14 | from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero 15 | from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero 16 | from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc 17 | from .avx import avx_f32x8_store_aligned, avx_f32x8_load_aligned 18 | from .avx import avx_f32x4_store_aligned, avx_f32x4_load_aligned 19 | from .avx import ( 20 | avx_f32x8_unpackhi, 21 | avx_f32x8_unpacklo, 22 | avx_f32x8_shuffle, 23 | avx_f32x8_cast_f32x4, 24 | avx_f32x8_insert_f32x4, 25 | avx_f32x8_permute2f32x4, 26 | ) 27 | 28 | from .atomic import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor 29 | -------------------------------------------------------------------------------- /python/hidet/ir/primitives/cpu/math/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def lop3(d: Expr, a: Expr, b: Expr, c: Expr, *, imm_lut: int):
    """
    Build an inline-assembly `lop3.b32` instruction that combines three 32-bit operands and writes
    the result through `d`.

    Which logical operation is performed is selected by the immediate `imm_lut`.

    See the PTX ISA documentation for the `lop3` instruction for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3

    Parameters
    ----------
    d: Expr
        The pointer to the 32-bit result. It is cast to `~uint32` and element 0 is used as the output.
    a: Expr
        The first 32-bit operand.
    b: Expr
        The second 32-bit operand.
    c: Expr
        The third 32-bit operand.
    imm_lut: int
        The immediate value that determines the logical operation, in the range [0, 255]. Given a logical
        operation `f(a, b, c)`, `imm_lut` should be set to `f(0xF0, 0xCC, 0xAA)`.

    Returns
    -------
    The value returned by `asm` for the assembled instruction.
    """
    assert 0 <= imm_lut <= 255

    # The immediate is written directly into the instruction text.
    instruction = 'lop3.b32 %0, %1, %2, %3, {};'.format(imm_lut)
    destination = cast(d, ~uint32)[0]
    # NOTE(review): imm_lut is also passed as a fourth input although the template only references
    # %0..%3 -- confirm the extra operand is intended.
    operands = [a, b, c, imm_lut]

    return asm(instruction, outputs=[destination], inputs=operands, is_volatile=True)
import complex128 22 | -------------------------------------------------------------------------------- /python/hidet/ir/primitives/hip/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from . import math 14 | from . import mfma 15 | from . import buffer_addr 16 | from . import lds_sync 17 | 18 | from .errchk import check_hip_error 19 | from .vars import threadIdx, blockIdx, blockDim, gridDim 20 | -------------------------------------------------------------------------------- /python/hidet/ir/primitives/hip/errchk.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
@initialize()
def register_lds_sync():
    """
    Build the `hip_lds_sync` function and register it as a primitive function.

    The function body consists of two volatile assembly statements, emitted in this order:
    `s_waitcnt lgkmcnt(0)` followed by `s_barrier`.
    """
    # NOTE(review): presumably the wait drains outstanding LDS traffic before the barrier --
    # confirm against the AMD ISA documentation.
    instructions = ("s_waitcnt lgkmcnt(0)", "s_barrier")

    with FunctionBuilder("hip_lds_sync", kind='hip_internal') as fb:
        for instruction in instructions:
            fb += AsmStmt(instruction, is_volatile=True)

    register_primitive_function(name="hip_lds_sync", func_or_type=fb.func)


def lds_sync():
    """Return a call (with no arguments) to the registered `hip_lds_sync` primitive function."""
    no_args = []
    return call_primitive_func("hip_lds_sync", no_args)
# Registry of primitive variables, keyed by variable name.
registered_primitive_variables: Dict[str, Var] = {}


def register_primitive_variable(name: str, dtype: DataType):
    """
    Create a variable called `name` of type `dtype` and record it in the registry.

    Parameters
    ----------
    name: str
        The name of the primitive variable; it must not be registered yet.
    dtype: DataType
        The type given to the created variable.

    Returns
    -------
    The newly created `Var`.

    Raises
    ------
    KeyError
        If a primitive variable with the same name is already in the registry.
    """
    registry = registered_primitive_variables
    if name in registry:
        raise KeyError('Primitive variable {} has already registered.'.format(name))

    created = Var(hint=None, type=dtype, name=name)
    registry[name] = created
    return created


def lookup_primitive_variable(name: str) -> Var:
    """
    Return the primitive variable registered under `name`.

    Raises
    ------
    KeyError
        If no primitive variable with that name is in the registry.
    """
    if name in registered_primitive_variables:
        return registered_primitive_variables[name]
    raise KeyError('Primitive variable {} has not registered.'.format(name))
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .cpu import CpuAutoScheduler 13 | from .cuda import GpuAutoScheduler 14 | -------------------------------------------------------------------------------- /python/hidet/ir/schedulers/cpu/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .scheduler import CpuAutoScheduler 13 | -------------------------------------------------------------------------------- /python/hidet/ir/schedulers/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .scheduler import GpuAutoScheduler 13 | -------------------------------------------------------------------------------- /python/hidet/ir/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | from .type_infer import infer_type, TypeInfer 13 | from .util_functors import collect 14 | from .rewriter import rewrite, MapBasedRewriter 15 | from .free_var_collector import collect_free_vars 16 | from .printer import IRPrinter, astext 17 | from .simplifier import simplify, simplify_to_int 18 | from .hasher import ExprHash 19 | from .renamer import rename_funcs 20 | 21 | # from .ir_dumper import astext2, parse 22 | -------------------------------------------------------------------------------- /python/hidet/ir/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import call_graph 13 | from . import hash_sum 14 | from . import index_transform 15 | from . import expr_utils 16 | 17 | from .index_transform import index_serialize, index_deserialize 18 | from .expr_utils import as_expr 19 | from .broadcast_utils import can_broadcast, can_mutually_broadcast, broadcast_shape, broadcast_shapes, broadcast_indices 20 | -------------------------------------------------------------------------------- /python/hidet/ir/utils/expr_utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
def as_expr(e: Union[int, float, bool, Expr]):
    """
    Convert a python scalar to a hidet constant expression; pass existing expressions through.

    Parameters
    ----------
    e: Union[int, float, bool, Expr]
        The value to convert.

    Returns
    -------
    `e` itself when it is already an `Expr`; otherwise a constant of type 'bool' for a python bool,
    'int32' for a python int, and 'float32' for a python float.

    Raises
    ------
    ValueError
        If `e` is none of the supported kinds.
    """
    if isinstance(e, Expr):
        return e
    # bool must be tested before int: bool is a subclass of int, so isinstance(True, int) holds
    # and an int-first ordering would turn every bool into an int32 constant.
    if isinstance(e, bool):
        return constant(value=e, const_type='bool')
    if isinstance(e, int):
        return constant(value=e, const_type='int32')
    if isinstance(e, float):
        return constant(value=e, const_type='float32')
    raise ValueError('Cannot convert {} to hidet.ir.Expr.'.format(e))
"""
general attributes
"""
# The label of the scope.
label: Optional[str] = None

"""
function attributes
"""
# The name of the function. By default, the hidet function takes the name of the wrapped python function.
# Set this attribute to give the function a different name.
func_name: Optional[str] = None

# The kind of this function. Candidates: 'cuda_kernel', 'cuda_internal', 'cpu_kernel', 'packed_func'
# NOTE(review): this list looks incomplete -- e.g. the hip primitives build functions with
# kind='hip_internal'; confirm the full set of accepted kinds.
func_kind: Optional[str] = None
# A scalar extent: either a hidet expression or a python int.
Int = Union[Expr, int]
# A launch dimension: a single extent, or a tuple of two or three extents.
Dim3 = Union[Int, Tuple[Int, Int], Tuple[Int, Int, Int]]

# The grid dimension of a cuda kernel, specifying the number of thread blocks.
grid_dim: Dim3 = 1

# The optional cluster dimension of a cuda kernel, specifying the number of thread blocks per cluster.
cluster_dim: Dim3 = 1

# The block dimension of a cuda kernel, specifying the number of threads per block.
block_dim: Dim3 = 1

# A hint to the nvcc compiler: the minimal number of thread blocks that should be executed on
# the same streaming multiprocessor (SM). This attribute influences the register allocation
# strategy adopted by nvcc.
min_blocks: int = 1

# The size, in bytes, of the dynamic shared memory allocated to the cuda kernel.
dynamic_smem_bytes: Int = 0
class HidetContext:
    """
    Custom context manager used in Hidet Script to support the syntax of `with ... as ...`.

    A block written as

        with HidetContext() as bind_value:
            body(bind_value)

    will be transformed to

        post_process(body(bind_value))

    Subclasses must override `post_process`; `bind_value` may be overridden to supply the value
    bound by `as`.
    """

    def bind_value(self) -> Optional[Any]:
        """Return the value bound by the `as` clause; this base implementation returns None."""
        return None

    def post_process(self, body: Stmt) -> Stmt:
        """
        Transform the statement of the `with` body into the statement that replaces it.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()
class HidetMetaLoopIterable:
    """Wrapper that holds an iterable and iterates over it unchanged."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        # A fresh iterator over the wrapped object is produced on every call.
        wrapped = self.iterable
        return iter(wrapped)


class HidetMetaParamTypeList:
    """Wrapper that stores a list copy of a sequence of parameter types."""

    def __init__(self, arg_types: Sequence[BaseType]):
        # Copied into a new list so the stored value is independent of the caller's sequence.
        self.arg_types: List[BaseType] = [*arg_types]


def range(extent: int):
    """
    Wrap `builtins.range(extent)` in a `HidetMetaLoopIterable`.

    Note that this function shadows the builtin `range` inside this module, which is why the
    builtin is reached through the `builtins` module.
    """
    counter = builtins.range(extent)
    return HidetMetaLoopIterable(counter)


def each(iterable):
    """Wrap an arbitrary iterable in a `HidetMetaLoopIterable`."""
    return HidetMetaLoopIterable(iterable)


def types(arg_types: Sequence[Union[BaseType, Type[Union[int, float, bool]]]]):
    """Wrap a sequence of parameter types in a `HidetMetaParamTypeList`."""
    return HidetMetaParamTypeList(arg_types)
def chain(*task_mappings) -> TaskMapping:
    """
    Compose the given task mappings, left to right, with the `*` operator.

    `chain(a, b, c)` evaluates `(a * b) * c`; a single argument is returned unchanged.
    At least one mapping must be given.
    """
    assert len(task_mappings) > 0

    head = task_mappings[0]
    tail = task_mappings[1:]

    composed = head
    for following in tail:
        composed = composed * following
    return composed
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # pylint: disable=unused-import 13 | from hidet.ir.primitives.runtime import get_cuda_stream, request_cuda_workspace, request_cpu_workspace 14 | -------------------------------------------------------------------------------- /python/hidet/lang/types.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # pylint: disable=unused-import 13 | from hidet.ir.dtypes import i8, i16, i32, i64, u8, u16, u32, u64, f16, f32, f64, bf16, tf32, i4, u4, i2, u2, i1, u1 14 | from hidet.ir.dtypes import int8, int16, int32, int64, uint8, uint32, uint64, float16, float32, float64, bfloat16 15 | from hidet.ir.dtypes import tfloat32 16 | from hidet.ir.dtypes import f16x2, float16x2 17 | from hidet.ir.dtypes import float8_e4m3, float8_e5m2, f8e4m3, f8e5m2 18 | 19 | from hidet.ir.type import void_p, void, byte_p, tensor_pointer_type 20 | 21 | from hidet.lang.constructs.declare import register_tensor, shared_tensor, tensor_pointer, tensor, DeclareScope 22 | -------------------------------------------------------------------------------- /python/hidet/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from hidet.ffi.ffi import BackendException 13 | 14 | from . import storage 15 | from . import compiled_module 16 | from . import compiled_task 17 | from . 
import compiled_graph 18 | 19 | from .storage import Storage 20 | from .compiled_module import CompiledModule, CompiledFunction, load_compiled_module 21 | from .compiled_task import CompiledTask, load_compiled_task 22 | from .compiled_graph import CompiledGraph, save_compiled_graph, load_compiled_graph 23 | -------------------------------------------------------------------------------- /python/hidet/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from hidet.utils.benchmark import benchmark_func 13 | 14 | from . import models 15 | from . import utils 16 | from .utils import check_unary, check_unary_dynamic, check_binary, check_binary_dynamic 17 | from .utils import check_ternary, check_torch_unary 18 | from .utils import ( 19 | check_torch_binary, 20 | check_torch_binary_with_inputs, 21 | check_torch_binary_dynamic, 22 | check_torch_ternary, 23 | assert_torch_allclose, 24 | ) 25 | from .torch_utils import device_to_torch 26 | -------------------------------------------------------------------------------- /python/hidet/testing/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from . import gpt2 13 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 14 | -------------------------------------------------------------------------------- /python/hidet/transforms/cute/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | -------------------------------------------------------------------------------- /python/hidet/transforms/cute/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from .tensor_alias_analysis import TensorAliasAnalysis, TensorInfo, tensor_info 14 | -------------------------------------------------------------------------------- /python/hidet/transforms/cute/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | -------------------------------------------------------------------------------- /python/hidet/transforms/cute/cuda/lower_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .registry import OpEmitter, Buffer, register_impl, emit_op, request_smem_nbytes, TmaTensor, tma_tensor 13 | from . import registry 14 | from . import tensor 15 | from . import partition 16 | from . import copy 17 | from . import rearrange 18 | from . import collective 19 | from . import arithmetic 20 | from . import mma 21 | from . import subtensor 22 | from . import reduce 23 | from . import misc 24 | from . import mbarrier 25 | -------------------------------------------------------------------------------- /python/hidet/transforms/cute/generic/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | -------------------------------------------------------------------------------- /python/hidet/transforms/instruments/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .base import PassInstrument 13 | from .profile_instrument import ProfileInstrument 14 | from .save_ir_instrument import SaveIRInstrument 15 | -------------------------------------------------------------------------------- /python/hidet/transforms/instruments/base.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
class PassInstrument:
    """
    Base class with four overridable callbacks that receive an IR module.

    Every callback here is a no-op; subclasses override the ones they need.
    NOTE(review): presumably these are invoked by the pass pipeline around each
    transformation pass -- confirm against the caller.
    """

    def before_all_passes(self, ir_module: IRModule):
        """Callback taking the IR module; does nothing in the base class."""

    def before_pass(self, pass_name: str, ir_module: IRModule):
        """Callback taking a pass name and the IR module; does nothing in the base class."""

    def after_pass(self, pass_name: str, ir_module: IRModule):
        """Callback taking a pass name and the IR module; does nothing in the base class."""

    def after_all_passes(self, ir_module: IRModule):
        """Callback taking the IR module; does nothing in the base class."""
import fault_handler 20 | 21 | from .py import prod, Timer, repeat_until_converge, COLORS, get_next_file_index, factorize, HidetProfiler, TableBuilder 22 | from .py import same_list, strict_zip, index_of, initialize, gcd, lcm, error_tolerance, green, red, cyan, bold, blue 23 | from .py import str_indent, unique, assert_close, cdiv 24 | from .structure import DirectedGraph 25 | from .cache_utils import cache_dir, cache_file, clear_op_cache, clear_cache_dir 26 | from .net_utils import download 27 | from .files import copy_tree_ignore_existing 28 | from .gc import gc_disabled 29 | -------------------------------------------------------------------------------- /python/hidet/utils/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from .bench import Bench, BenchData, do_bench, benchmark_func 13 | from .gpu_freq import GPUSetFrequencyForBenchmarking 14 | -------------------------------------------------------------------------------- /python/hidet/utils/cache_utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
def cache_dir(*items: str) -> str:
    """
    Return the absolute path of a directory under the hidet cache root, creating it if needed.

    Parameters
    ----------
    items : str
        Path components joined below the root returned by ``hidet.option.get_cache_dir()``.

    Returns
    -------
    str
        The absolute path of the (now existing) directory.
    """
    path = os.path.abspath(os.path.join(hidet.option.get_cache_dir(), *items))
    os.makedirs(path, exist_ok=True)
    return path


def cache_file(*items: str) -> str:
    """
    Return the absolute path of a file under the hidet cache root.

    The parent directory of the file is created if needed; the file itself is not touched.

    Parameters
    ----------
    items : str
        Path components joined below the cache root; the last one is the file name.

    Returns
    -------
    str
        The absolute path of the file.
    """
    base = cache_dir('./')  # also makes sure the cache root itself exists
    path = os.path.abspath(os.path.join(base, *items))
    parent = os.path.dirname(path)
    os.makedirs(parent, exist_ok=True)
    return path


def clear_cache_dir(*items: str):
    """
    Recursively delete a directory under the hidet cache root.

    Errors raised while deleting (including a missing directory) are ignored.

    Parameters
    ----------
    items : str
        Path components joined below the cache root; with no components the root itself is removed.
    """
    target = os.path.abspath(os.path.join(hidet.option.get_cache_dir(), *items))
    print(f'Clearing hidet cache dir: {target}')
    shutil.rmtree(target, ignore_errors=True)


def clear_op_cache():
    """Delete the ``ops`` sub-directory of the hidet cache root."""
    clear_cache_dir('ops')
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | import collections 13 | from typing import Counter, DefaultDict 14 | 15 | counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter) 16 | -------------------------------------------------------------------------------- /python/hidet/utils/fault_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Enable python faulthandler to print the python traceback when a segfault occurs. 3 | 4 | See: https://docs.python.org/3/library/faulthandler.html 5 | """ 6 | 7 | import faulthandler 8 | 9 | faulthandler.enable() 10 | -------------------------------------------------------------------------------- /python/hidet/utils/folder_lock.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | import os 13 | from filelock import FileLock 14 | 15 | 16 | class FolderLock(FileLock): 17 | """ 18 | A context manager for file-based locking using flock. 19 | Ensures that only one process can acquire the lock at a time. 20 | 21 | Parameters 22 | ---------- 23 | lock_file_path : str 24 | Path to the lock file. 
@contextmanager
def gc_disabled():
    """
    Context manager that turns the garbage collector off for the duration of the block.

    On exit (normal or via an exception) the collector is switched back on only if it
    was on when the block was entered.
    """
    if not gc.isenabled():
        # Already off: nothing to switch off and nothing to restore afterwards.
        yield
        return

    gc.disable()
    try:
        yield
    finally:
        # The collector was on at entry, so turn it back on.
        gc.enable()
def set_module(module):
    """
    Build a decorator that overrides the ``__module__`` attribute of a class or function.

    Parameters
    ----------
    module : str or None
        The value to store in ``__module__``. When None, the decorated object is left untouched.

    Returns
    -------
    callable
        A decorator that returns the very object it was given.
    """

    def decorator(obj):
        if module is None:
            return obj
        obj.__module__ = module
        return obj

    return decorator
4 | 5 | ## Usage 6 | 7 | ```bash 8 | # Install github cli 9 | sudo apt install -y gh 10 | # Install dependency 11 | pip install -r scripts/bench/requirements.txt 12 | # clone the repo 13 | git clone git@github.com:hidet-org/hidet 14 | # cd into the repo 15 | cd hidet 16 | # run the daemon script, you can specify the issue to send the report to 17 | # is the issue number, e.g. 135 18 | # is the time to run the benchmark everyday in format HH:MM, e.g. 03:00 19 | python scripts/bench/run.py [--issue ] [--schedule-time ] 20 | ``` 21 | -------------------------------------------------------------------------------- /scripts/bench/requirements.txt: -------------------------------------------------------------------------------- 1 | schedule 2 | importlib_metadata 3 | -------------------------------------------------------------------------------- /scripts/lint/.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | IndentWidth: 4 4 | --- 5 | Language: Cpp 6 | ColumnLimit: 120 7 | 8 | # Right-aligned pointers/references, don't automatically determine on a 9 | # per-file basis 10 | DerivePointerAlignment: false 11 | PointerAlignment: Right 12 | ReferenceAlignment: Pointer 13 | 14 | # Grouping of imports: standard imports first, don't group 15 | IncludeBlocks: Merge 16 | IncludeCategories: 17 | - Regex: '^<.*\.h>' 18 | Priority: 2 19 | - Regex: '^<.*>' 20 | Priority: 1 21 | - Regex: '^".*"' 22 | Priority: 3 23 | 24 | # Misc 25 | AllowShortFunctionsOnASingleLine: Inline 26 | SpaceAfterTemplateKeyword: false 27 | SpaceBeforeInheritanceColon: false -------------------------------------------------------------------------------- /scripts/lint/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # work in the same directory of this script 4 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) 5 | cd $SCRIPT_DIR/../.. 
6 | 7 | # run pylint 8 | python -m pylint --rcfile ./scripts/lint/pylintrc -j $(nproc) ./python/hidet 9 | -------------------------------------------------------------------------------- /scripts/nightly-builds/README.md: -------------------------------------------------------------------------------- 1 | # Nightly built wheels server 2 | 3 | This folder contains the scripts used to serve downloads of the nightly-built wheels. 4 | 5 | 6 | The two scripts in this folder are: 7 | - `add-crontab-record.sh`: adds a crontab record to the system. The crontab record makes the system launch the `update-nightly.sh` script at 0:00 every day. 8 | - `update-nightly.sh`: pulls the latest commit from the main branch of the `hidet-org/hidet` repo, builds a wheel with a version like `0.4.1.dev20240721`, and puts it in the `whl/hidet` subdirectory of our wheel server for users to download. 9 | 10 | Setup steps: 11 | 1. Launch a web server. 12 | 2. Run `add-crontab-record.sh`. 13 | -------------------------------------------------------------------------------- /scripts/nightly-builds/add-crontab-record.sh: -------------------------------------------------------------------------------- 1 | cp ./update-nightly.sh /home/ubuntu 2 | sudo crontab -l > ./crontab-temp 3 | sudo echo "0 0 * * * /home/ubuntu/update-nightly.sh" >> ./crontab-temp 4 | sudo crontab ./crontab-temp 5 | -------------------------------------------------------------------------------- /scripts/nightly-builds/update-nightly.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # stop immediately if a command exits with a non-zero status.
class EmailSender:
    """
    Interactive helper that sends the regression report by e-mail through Gmail.

    Constructing an instance prompts on the terminal for the recipient list, the
    sender address and the sender's app password.
    """

    def __init__(self) -> None:
        raw_recipients = input('Enter comma separated recipient email addresses:\n')
        self.recipients = self._parse_recipients(raw_recipients)
        self.sender = input('Enter your Gmail address:\n')
        self.password = getpass.getpass(prompt='Enter your 16-digit Google app password: ')

    @staticmethod
    def _parse_recipients(raw):
        """
        Split a comma separated address string into a list of addresses.

        Spaces are removed, and empty entries (from a trailing comma, a doubled
        comma or an empty answer) are dropped so they are never handed to the
        SMTP server as recipients.
        """
        return [address for address in raw.replace(' ', '').split(',') if address]

    def send_email(self, body):
        """
        Send ``body`` as a plain-text message to every configured recipient.

        Parameters
        ----------
        body : str
            The text of the message.
        """
        subject = 'Hidet Performance Regression'
        msg = MIMEText(body)
        msg['Subject'] = subject
        msg['From'] = self.sender
        msg['To'] = ', '.join(self.recipients)
        # The connection is closed by the context manager even if login or sending fails.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465) as smtp_server:
            smtp_server.login(self.sender, self.password)
            smtp_server.sendmail(self.sender, self.recipients, msg.as_string())
        print("Results sent to", msg['To'])
29 | -------------------------------------------------------------------------------- /scripts/wheel/build_wheel_manylinux_2_28_x86_64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # work in the same directory of this script 4 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) 5 | cd $SCRIPT_DIR 6 | 7 | # get the root directory of the project 8 | HIDET_DIR=$(cd -- "$SCRIPT_DIR/../.." &> /dev/null && pwd) 9 | echo $HIDET_DIR 10 | 11 | # build the docker image 12 | ls ${SCRIPT_DIR}/dockerfiles/manylinux_2_28_x86_64/ 13 | docker build -t hidet-manylinux_2_28_x86_64-build ${SCRIPT_DIR}/dockerfiles/manylinux_2_28_x86_64/ 14 | 15 | # run the docker image 16 | docker run --rm -u $(id -u):$(id -g) -v $HIDET_DIR:/io hidet-manylinux_2_28_x86_64-build bash /io/scripts/wheel/build_wheel.sh 17 | -------------------------------------------------------------------------------- /scripts/wheel/dockerfiles/manylinux_2_28_x86_64/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/pypa/manylinux_2_28_x86_64 as base 2 | 3 | # Install pip 4 | RUN python3 -m ensurepip --upgrade 5 | 6 | RUN pip3 install wheel -------------------------------------------------------------------------------- /scripts/wheel/update_nightly.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # stop immediately if a command exits with a non-zero status. 
4 | set -e 5 | 6 | # download and extract the latest commit hidet 7 | wget https://github.com/hidet-org/hidet/archive/refs/heads/main.zip -O hidet.zip 8 | rm -rf hidet-main 9 | unzip hidet.zip 10 | cd hidet-main 11 | 12 | # build the manylinux1 wheel 13 | VERSION=`python3 scripts/wheel/current_version.py --nightly` 14 | bash scripts/wheel/build_wheel_manylinux1.sh $VERSION 15 | 16 | # copy the wheel to the deploy directory 17 | if [ $# -eq 1 ] 18 | then 19 | DEPLOY_DIR=$1 20 | else 21 | DEPLOY_DIR=/var/www/html/whl/hidet 22 | fi 23 | WHEEL_PATH=`ls scripts/wheel/built_wheel/*.whl` 24 | echo "Copying $WHEEL_PATH to $DEPLOY_DIR" 25 | cp scripts/wheel/built_wheel/*.whl $DEPLOY_DIR 26 | 27 | # remove the old wheels before 7 days ago in DEPLOY_DIR 28 | find $DEPLOY_DIR -type f -mtime +7 -name '*.whl' -execdir echo -- 'Removing old wheel {}' \; 29 | find $DEPLOY_DIR -type f -mtime +7 -name '*.whl' -execdir rm -- '{}' \; 30 | 31 | # clean up 32 | cd .. 33 | rm -f hidet.zip 34 | rm -rf hidet-main 35 | -------------------------------------------------------------------------------- /scripts/wheel/upload_wheel_to_pypi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WHEEL=$1 4 | 5 | set -e # exit immediately if a command exits with a non-zero status. 6 | 7 | # twine upload --repository testpypi $WHEEL 8 | twine upload $WHEEL 9 | 10 | -------------------------------------------------------------------------------- /src/hidet/empty.cpp: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 
3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 12 | -------------------------------------------------------------------------------- /src/hidet/runtime/cuda/utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | inline T get_symbol(void *lib, const char *name) { 6 | T ret = (T)dlsym(lib, name); 7 | if (ret == nullptr) { 8 | LOG(FATAL) << "Failed to load symbol: " << std::endl << " " << dlerror(); 9 | } 10 | return ret; 11 | } 12 | -------------------------------------------------------------------------------- /src/hidet/runtime/logging.cpp: -------------------------------------------------------------------------------- 1 | // Licensed under the Apache License, Version 2.0 (the "License"); 2 | // you may not use this file except in compliance with the License. 3 | // You may obtain a copy of the License at 4 | // 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // 7 | // Unless required by applicable law or agreed to in writing, software 8 | // distributed under the License is distributed on an "AS IS" BASIS, 9 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | // See the License for the specific language governing permissions and 11 | // limitations under the License. 
12 | #include 13 | 14 | ErrorState *ErrorState::global() { 15 | static thread_local ErrorState instance; 16 | return &instance; 17 | } 18 | 19 | DLL void hidet_set_last_error(const char *msg) { 20 | ErrorState *state = ErrorState::global(); 21 | if (state->has_error) { 22 | fprintf(stderr, "Warning: hidet error state has been override: %s\n", state->error_msg.c_str()); 23 | } 24 | state->has_error = true; 25 | state->error_msg = msg; 26 | } 27 | 28 | DLL const char *hidet_get_last_error() { 29 | ErrorState *state = ErrorState::global(); 30 | if (state->has_error) { 31 | state->has_error = false; 32 | return state->error_msg.c_str(); 33 | } else { 34 | return nullptr; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /tests/benchmarks/bench_task.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hidet 3 | 4 | from hidet.runtime.compiled_task import CompiledTask 5 | from hidet.drivers import build_task 6 | from hidet.testing.torch_utils import bench_model 7 | from hidet.testing.utils import init_hidet 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(prog='Benchmark task') 12 | parser.add_argument('--task', type=str, default='./task.pickle', help='Path to dump of task') 13 | parser.add_argument('--cache', type=str, default='', help='') 14 | args = parser.parse_args() 15 | task_path, cache = args.task, args.cache 16 | 17 | init_hidet(cache=cache) 18 | 19 | task: hidet.Task = hidet.load_task(task_path) 20 | inputs = task.dummy_arguments('cuda') 21 | compiled_task: CompiledTask = build_task(task, target='cuda') 22 | 23 | # For dynamic shapes should set their value 24 | for tensor in task.params: 25 | for dim in tensor.shape: 26 | if isinstance(dim, hidet.ir.expr.SymbolVar): 27 | hidet.ffi.runtime_api.set_symbol_value(dim.name, 2) 28 | 29 | out = compiled_task(*inputs) 30 | 31 | lat = bench_model(compiled_task, inputs) 32 | 33 | 
def bench(functor, args):
    """
    A small benchmark function for fusion tests.

    Calls ``functor(*args)`` 10 times as warm-up, then 100 more times with a
    dedicated pair of CUDA events around every call, so that each call yields
    its own latency sample.

    Returns a tuple ``(mean, min_lat, max_lat)`` of the per-call latencies, in
    the unit reported by ``torch.cuda.Event.elapsed_time`` (milliseconds).
    """
    warmup_iters = 10
    bench_iters = 100

    for _ in range(warmup_iters):
        functor(*args)

    # One event pair per iteration: a single pair around the whole loop only gives
    # one averaged sample, which makes mean, min and max collapse to the same value.
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(bench_iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(bench_iters)]

    for start, end in zip(starts, ends):
        start.record()
        functor(*args)
        end.record()

    # Events are recorded in order, so waiting for the last one is enough before
    # reading any of the timings.
    ends[-1].synchronize()
    latencies = [start.elapsed_time(end) for start, end in zip(starts, ends)]

    mean = sum(latencies) / len(latencies)
    min_lat = min(latencies)
    max_lat = max(latencies)

    return mean, min_lat, max_lat
class SliceModule(torch.nn.Module):
    """Module whose forward pass indexes its input with a fixed index expression."""

    def __init__(self, indices):
        super().__init__()
        # Index expression (e.g. a ``slice``) applied to every input.
        self.indices = indices

    def forward(self, x):
        """Return ``x`` indexed by the expression given at construction time."""
        selector = self.indices
        return x[selector]
@pytest.mark.slow
@pytest.mark.parametrize('shape', [[1, 3, 224, 224]])
def test_densenet121(shape, device):
    """Check the pretrained torch-hub densenet121 in float16 through ``check_module``."""
    hub_model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True)
    model = hub_model.eval().to(torch.float16)

    # Random float16 input, scaled by 0.1796 and shifted by 0.5491.
    noise = torch.randn(*shape).to(torch.float16)
    x = noise * 0.1796 + 0.5491

    check_module(model, [x], atol=4e-2, rtol=4e-2, dynamic=False, device=device)
# `(32)` is just the int 32, not a 1-tuple; spell the 1-D shape as `(32,)` so both
# parametrized cases are tuples (torch.randn produces the same 1-D tensor either way).
@pytest.mark.parametrize('shape', [(32,), (32, 32)])
@pytest.mark.parametrize('dtype', [torch.float32, torch.float16])
def test_empty_like(shape, dtype, device):
    """Run ``torch.empty_like`` on a random tensor of the given shape/dtype through ``check_module``."""
    # NOTE(review): torch.empty_like returns uninitialized memory; presumably
    # check_module does not compare the values element-wise here -- confirm.
    check_module(FunctionalModule(op=torch.empty_like), [torch.randn(shape, dtype=dtype)], device=device)
class MultiplyModel(nn.Module):
    """Parameter-free module that multiplies its input by 3."""

    # No ``__init__`` override: the original one only delegated to ``nn.Module.__init__``.

    def forward(self, x):
        # Multiply input tensor by 3
        return x * 3
26 | x1 = torch.randn(32, SHAPE_SIZE, device="cuda", dtype=torch.bfloat16) 27 | torch._dynamo.mark_dynamic(x1, 0) 28 | output_ref1 = model(x1) 29 | output_opt1 = model_opt(x1) 30 | 31 | assert torch.allclose(output_ref1, output_opt1, rtol=1e-2, atol=1e-2) 32 | 33 | # Second run with a different batch size 34 | x2 = torch.randn(64, SHAPE_SIZE, device="cuda", dtype=torch.bfloat16) 35 | torch._dynamo.mark_dynamic(x2, 0) 36 | output_ref2 = model(x2) 37 | output_opt2 = model_opt(x2) 38 | 39 | assert torch.allclose(output_ref2, output_opt2, rtol=1e-2, atol=1e-2) 40 | 41 | 42 | if __name__ == '__main__': 43 | # For manual testing 44 | test_mul() 45 | -------------------------------------------------------------------------------- /tests/ir/functors/test_persistence.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
@pytest.mark.requires_cuda
def test_lop3():
    """Build a one-block CUDA kernel around the ``lop3`` primitive and check its output."""
    from hidet.lang import attrs, script
    from hidet.lang.types import uint32
    from hidet.ir.primitives.cuda import lop3

    with hidet.script_module() as script_module:

        @script
        def kernel(d_ptr: ~uint32, a: uint32, b: uint32, c: uint32):
            attrs.func_kind = 'cuda_kernel'
            attrs.cuda.grid_dim = 1
            attrs.cuda.block_dim = 32

            lop3(d_ptr, a, b, c, imm_lut=(0xF0 & 0xCC) | 0xAA)

    func = script_module.build()

    out = torch.empty([1], dtype=torch.int32, device='cuda')
    func(out, 0xFFFFFFFF, 0x00FF00FF, 0x0E00EE00)
    # Expected value equals (a & b) | c for the three operands passed above.
    expected = 0x0EFFEEFF
    assert out[0] == expected
def test_add(device):
    """Compare hidet tensor addition against the same addition done with numpy."""
    torch_device = device_to_torch(device)
    lhs = hidet.randn([10], device=torch_device)
    rhs = hidet.randn([10], device=torch_device)

    total = lhs + rhs
    expected = lhs.cpu().numpy() + rhs.cpu().numpy()

    np.testing.assert_allclose(actual=total.cpu().numpy(), desired=expected, atol=1e-5, rtol=1e-5)
@pytest.mark.requires_cuda
def test_lazy_initialization():
    """Run ``lazy_init_sample.py`` (next to this file) in a fresh interpreter.

    ``check=True`` makes the test fail when the child process exits with a non-zero status.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    sample_script = os.path.join(here, 'lazy_init_sample.py')
    subprocess.run([sys.executable, sample_script], check=True)
@pytest.mark.parametrize('n', [3, 7])
@pytest.mark.parametrize('m', [3, 7])
@pytest.mark.parametrize('k', [-1, 1])
def test_tri(n, m, k):
    """Check ``hidet.ops.tri(n, m, k)`` against ``numpy.tri(n, m, k)``."""
    import numpy as np

    actual = hidet.ops.tri(n, m, k)
    expected = np.tri(n, m, k)
    assert np.allclose(actual.numpy(), expected)
@pytest.mark.requires_cuda
def test_profile_config():
    """Run a CUDA batch matmul with search space 1 and a minimal ``bench_config(1, 1, 1)``.

    The test only checks that the operator runs under this configuration; the
    result is not compared against a reference.
    """
    a = hidet.randn([1, 10, 10], device='cuda')
    b = hidet.randn([1, 10, 10], device='cuda')
    hidet.option.search_space(1)
    hidet.option.bench_config(1, 1, 1)
    try:
        hidet.ops.cuda_batch_matmul(a, b)
    finally:
        # Reset the search space even when the operator raises, so a failure
        # here does not leak the setting into tests that run afterwards.
        hidet.option.search_space(0)


if __name__ == '__main__':
    # pytest.main expects a list of command-line arguments, not a bare string.
    pytest.main([__file__])
@pytest.mark.parametrize("w", [32, 64])
@pytest.mark.parametrize("dtype", ['int8', 'int16'])
@pytest.mark.parametrize("dims", [[-1], [0]])
def test_symmetric_quant(w, dtype, dims):
    """Quantize then dequantize a random ``w`` x ``w`` float32 matrix and check the round trip."""
    weight = hidet.randn((w, w), dtype='float32')

    quantized, scale = ops.symmetric_quantize(weight, dtype, dims)
    restored = ops.symmetric_dequantize(quantized, scale, dims)

    # Loose tolerances: the round trip goes through an integer representation.
    assert np.allclose(weight.numpy(), restored.numpy(), atol=1e-1, rtol=1e-1)
def test_catch_runtime_exception():
    """A ``HidetException`` thrown inside a built function must surface as ``BackendException``."""

    with hidet.script_module() as script_module:

        @hidet.script
        def launch():
            attrs.func_kind = 'public'

            BlackBoxStmt('throw HidetException("This is a runtime exception.");')

    func = script_module.build()
    print(func.source())
    # The printed source is expected to look like:
    #
    #   DLL void hidet_launch() {
    #     try {
    #       throw HidetException("This is a runtime exception.");
    #     } catch (HidetException &e) {
    #       hidet_set_last_error(e.what());
    #       return ;
    #     }
    #   }
    raised = False
    try:
        func()
    except BackendException as e:
        raised = True
        print('Caught a runtime exception: ', e)
    if not raised:
        raise AssertionError('Should have raised a runtime exception')
def test_hidet_script_lambda():
    """Pass a Python lambda to a statement-building helper called from a hidet script."""

    def print_grid(m: int, n: int, f_grid):
        # Builds an m x n nested loop that prints f_grid(row, col) per cell
        # and a newline after each row.
        from hidet.ir.builders import StmtBuilder

        builder = StmtBuilder()
        with builder.for_range(m) as row:
            with builder.for_range(n) as col:
                builder += printf("%2d ", f_grid(row, col))
            builder += printf("\n")
        return builder.finish()

    with hidet.script_module() as script_module:

        @script
        def launch():
            attrs.func_kind = 'public'
            print_grid(9, 9, lambda i, j: (i + 1) * (j + 1))

    compiled = script_module.build()
    compiled()
def test_import_time():
    """
    Make sure hidet can be imported within a given time limit.

    The measurement itself lives in ``check_import_time.py`` next to this file;
    it is run in a separate interpreter and a non-zero exit status fails the test.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    checker = os.path.join(here, 'check_import_time.py')
    subprocess.check_call([sys.executable, checker])
def func(b):
    """Workload to profile: add ``b`` to an uninitialized 1000-element CUDA tensor."""
    tensor = hidet.empty([1000], device='cuda')
    # The sum is discarded; only running the addition matters here.
    tensor + b


@pytest.mark.skipif(not os.path.exists(nsys_get_path()), reason='Nsight System is not available.')
@pytest.mark.skip(
    reason='Skip due to the ci error: The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0'
)
def test_nsys_run():
    """Profile ``func`` with ``nsys_run``; the test passes if profiling completes without raising."""
    report = nsys_run(func, b=1)
    # The profiling result can be inspected by calling the `visualize` method of the report.
    # This is not exercised here because it opens the Nsight Systems UI and waits for the user to close it.
    # report.visualize()