├── .ahub └── tcchecker-tca │ └── config.yaml ├── .coveragerc ├── .github ├── actionlint.yaml ├── actions │ └── download-release-asset │ │ ├── action.yml │ │ └── entrypoint.sh └── workflows │ ├── check-pr.yaml │ ├── check-workflows.yaml │ ├── publish-nightly-package.yaml │ └── publish-official-package.yaml ├── .gitignore ├── .lintrunner.toml ├── .pylintrc ├── LICENSE ├── README.md ├── ccex ├── docs ├── design.md ├── requirements.md └── system_test.md ├── infra ├── command │ ├── build │ ├── configure │ ├── coverage │ ├── format │ ├── install │ └── test ├── dependency │ ├── torch_dev.txt │ └── torchvision_dev.txt ├── scripts │ ├── build.sh │ ├── coverage.sh │ ├── install.sh │ ├── install_requirements.txt │ ├── test_configure.sh │ └── test_run.py └── style │ ├── format.sh │ ├── install.sh │ └── requirements.txt ├── mypy.ini ├── requirements-lintrunner.txt ├── setup.py ├── test ├── README.md ├── __init__.py ├── dump_exported_program.py ├── dump_pt2_model.py ├── modules │ ├── README.md │ ├── __init__.py │ ├── base.py │ ├── model │ │ ├── Bert │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── BitNet_b1_58 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── DeepSeekR1DistillQwen1_5B │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── EfficientFormerL1 │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── EfficientNetV2S │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── Florence2 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── GPT2 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── Gemma3 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── InceptionV3 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── Llama │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── LlamaAttentionWithKVCache │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── LlamaDecoderLayerTRIV │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── LlamaWithGQA │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── LlamaWithKVCache │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── Mamba │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── MambaMixer │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── MobileNetV2 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── MobileNetV3S │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── ResNet18 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── SmolVLM_connector │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── SmolVLM_text_model │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── SmolVLM_vision_model │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── TinyLlama │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ ├── TinyLlamaWithFusedRMSNorm │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── requirements.txt │ │ └── __init__.py │ ├── net │ │ ├── ConvEmbed.py │ │ ├── KVCache.py │ │ ├── README.md │ │ ├── RMSNorm.py │ │ ├── RoPE.py │ │ ├── SDPA.py │ │ ├── __init__.py │ │ ├── mlp.py │ │ └── mlp_dyn.py │ └── op │ │ ├── README.md │ │ ├── __init__.py │ │ ├── abs.py │ │ ├── add.py │ │ ├── addmm.py │ │ ├── alias_copy.py │ │ ├── any.py │ │ ├── arange.py │ │ ├── argmax.py │ │ ├── avg_pool2d.py │ │ ├── bmm.py │ │ ├── cat.py │ │ ├── clamp.py │ │ ├── clone.py │ │ ├── constant_pad_nd.py │ │ ├── conv1d.py │ │ ├── conv2d.py │ │ ├── conv_transpose2d.py │ │ ├── copy.py │ │ ├── cos.py │ │ ├── cumsum.py │ │ ├── depthwise_conv2d.py │ │ ├── detach.py │ │ ├── div.py │ │ ├── embedding.py │ │ ├── eq.py │ │ ├── exp.py │ │ ├── expand_copy.py │ │ ├── full.py │ │ ├── full_like.py │ │ ├── ge.py │ │ ├── gelu.py │ │ ├── gt.py │ │ ├── hardtanh.py │ │ ├── index.py │ │ ├── index_select.py │ │ ├── instance_norm.py │ │ ├── interpolate.py │ │ ├── le.py │ │ ├── leaky_relu.py │ │ ├── linear.py │ │ ├── log.py │ │ ├── log1p.py │ │ ├── logical_and.py │ │ ├── logical_not.py │ │ ├── lt.py │ │ ├── max_dim.py │ │ ├── max_pool2d.py │ │ ├── maximum.py │ │ ├── mean.py │ │ ├── minimum.py │ │ ├── mm.py │ │ ├── mul.py │ │ ├── native_batch_norm.py │ │ ├── native_group_norm.py │ │ ├── native_layer_norm.py │ │ ├── ne.py │ │ ├── neg.py │ │ ├── permute.py │ │ ├── pow.py │ │ ├── prelu.py │ │ ├── reciprocal.py │ │ ├── relu.py │ │ ├── relu6.py │ │ ├── repeat.py │ │ ├── reshape.py │ │ ├── round.py │ │ ├── rsqrt.py │ │ ├── scalar_tensor.py │ │ ├── select.py │ │ ├── select_copy.py │ │ ├── sigmoid.py │ │ ├── sin.py │ │ ├── slice_copy.py │ │ ├── slice_scatter.py │ │ ├── softmax.py │ │ ├── split_with_sizes.py │ │ ├── sqrt.py │ │ ├── squeeze.py │ │ ├── sub.py │ │ ├── sum.py │ │ ├── tanh.py │ │ ├── to.py │ │ ├── to_dim_order_copy.py │ │ ├── unsqueeze.py │ │ ├── view.py │ │ └── where.py ├── performance │ ├── README.md │ ├── __init__.py │ ├── benchmark_perf.py │ ├── requirements.txt │ └── utils.py ├── pt2_to_circle_test │ ├── README.md │ ├── __init__.py │ ├── builder.py │ ├── test_model.py │ ├── test_net.py │ ├── test_op.py │ └── test_pt2_to_circle.py ├── pt2_to_qcircle_test │ ├── __init__.py │ ├── builder.py │ └── test_op.py ├── quantization │ ├── __init__.py │ ├── algorithm │ │ ├── __init__.py │ │ ├── test_fpi_gptq.py │ │ ├── test_gptq.py │ │ └── test_smooth_quant.py │ ├── evaluation │ │ ├── __init__.py │ │ └── test_evaluation.py │ ├── pass │ │ ├── __init__.py │ │ ├── test_convert_layout_op_to_reshape.py │ │ ├── test_fold_quant_ops.py │ │ ├── test_insert_quantize_on_dtype_mismatch.py │ │ ├── test_propagate_qparam_backward.py │ │ ├── test_propagate_quant_param.py │ │ └── test_remove_weight_dequant_op.py │ ├── test_quantizer_registry.py │ └── wrapq │ │ ├── __init__.py │ │ ├── observers │ │ ├── __init__.py │ │ ├── test_affine_base.py │ │ ├── test_base.py │ │ ├── test_ema.py │ │ ├── test_identity.py │ │ ├── test_minmax.py │ │ └── test_mx.py │ │ ├── test_dtype.py │ │ ├── test_mode.py │ │ ├── test_qscheme.py │ │ ├── test_quant_config.py │ │ ├── utils │ │ ├── __init__.py │ │ ├── test_introspection.py │ │ ├── test_metrics.py │ │ └── test_reduce_utils.py │ │ └── wrappers │ │ ├── __init__.py │ │ ├── fairseq │ │ ├── __init__.py │ │ ├── test_decoder_export_single_step.py │ │ ├── test_quant_decoder.py │ │ ├── test_quant_decoder_layer.py │ │ ├── test_quant_encoder.py │ │ ├── test_quant_encoder_layer.py │ │ └── test_quant_mha.py │ │ ├── llama │ │ ├── __init__.py │ │ ├── test_quant_attn.py │ │ ├── test_quant_decoder_layer.py │ │ └── test_quant_mlp.py │ │ ├── nn │ │ ├── __init__.py │ │ ├── test_quant_layernorm.py │ │ ├── test_quant_linear.py │ │ └── test_quant_silu.py │ │ ├── test_ptq_wrapper.py │ │ ├── test_quant_elementwise.py │ │ ├── test_quant_module_base.py │ │ └── test_registry.py ├── requirements.txt ├── requirements_pre.txt ├── unit_test │ ├── __init__.py │ ├── pass_test │ │ ├── __init__.py │ │ ├── test_cast_clamp_mixed_type_args.py │ │ ├── test_convert_conv1d_to_conv2d.py │ │ ├── test_convert_expand_to_slice_cat.py │ │ ├── test_decompose_fake_quantize.py │ │ ├── test_decompose_fake_quantize_tensor_qparam.py │ │ ├── test_decompose_grouped_conv2d.py │ │ ├── test_fuse_leading_unsqueeze_reshape.py │ │ ├── test_fuse_redundant_reshape_to_mean.py │ │ ├── test_legalize_causal_mask_value.py │ │ ├── test_lower_pow2_to_mul.py │ │ ├── test_lower_to_slice.py │ │ ├── test_merge_consecutive_cat.py │ │ ├── test_remove_nop.py │ │ ├── test_remove_redundant_expand.py │ │ ├── test_remove_redundant_permute.py │ │ ├── test_remove_redundant_reshape.py │ │ ├── test_remove_redundant_slice.py │ │ ├── test_remove_redundant_to_copy.py │ │ └── test_segment_index_select.py │ ├── quantization_test │ │ ├── __init__.py │ │ ├── test_adaptive_avg_pool2d.py │ │ ├── test_circle_executor.py │ │ ├── test_evaluate.py │ │ └── test_metric.py │ ├── serialize_test │ │ ├── __init__.py │ │ ├── operator │ │ │ ├── __init__.py │ │ │ ├── fixture.py │ │ │ └── test_utils.py │ │ ├── test_circle_graph.py │ │ ├── test_circle_mapping.py │ │ └── test_pack.py │ └── utils_test │ │ ├── __init__.py │ │ ├── test_convert.py │ │ ├── test_diff_graph.py │ │ ├── test_enforce_type.py │ │ ├── test_infer.py │ │ ├── test_mx.py │ │ ├── test_record_input.py │ │ ├── test_register_custom_op.py │ │ ├── test_run_bash_cmd.py │ │ ├── test_serialize.py │ │ ├── test_signature.py │ │ └── test_utils.py └── utils │ ├── base_builders.py │ ├── helper.py │ ├── infer.py │ ├── pass_value_test.py │ ├── runtime.py │ └── tag.py ├── tico ├── __init__.py ├── config │ ├── __init__.py │ ├── base.py │ ├── factory.py │ └── v1.py ├── experimental │ └── __init__.py ├── interpreter │ ├── __init__.py │ ├── infer.py │ └── interpreter.py ├── passes │ ├── __init__.py │ ├── cast_aten_where_arg_type.py │ ├── cast_clamp_mixed_type_args.py │ ├── cast_mixed_type_args.py │ ├── const_prop_pass.py │ ├── convert_conv1d_to_conv2d.py │ ├── convert_expand_to_slice_cat.py │ ├── convert_layout_op_to_reshape.py │ ├── convert_matmul_to_linear.py │ ├── convert_repeat_to_expand_copy.py │ ├── convert_to_relu6.py │ ├── decompose_addmm.py │ ├── decompose_batch_norm.py │ ├── decompose_fake_quantize.py │ ├── decompose_fake_quantize_tensor_qparams.py │ ├── decompose_group_norm.py │ ├── decompose_grouped_conv2d.py │ ├── decompose_slice_scatter.py │ ├── extract_dtype_kwargs.py │ ├── fill_meta_val.py │ ├── fuse_leading_unsqueeze_reshape.py │ ├── fuse_redundant_reshape_to_mean.py │ ├── legalize_causal_mask_value.py │ ├── legalize_predefined_layout_operators.py │ ├── lower_pow2_to_mul.py │ ├── lower_to_resize_nearest_neighbor.py │ ├── lower_to_slice.py │ ├── merge_consecutive_cat.py │ ├── ops.py │ ├── remove_nop.py │ ├── remove_redundant_assert_nodes.py │ ├── remove_redundant_expand.py │ ├── remove_redundant_permute.py │ ├── remove_redundant_reshape.py │ ├── remove_redundant_slice.py │ ├── remove_redundant_to_copy.py │ ├── restore_linear.py │ └── segment_index_select.py ├── pt2_to_circle.py ├── quantization │ ├── README.md │ ├── __init__.py │ ├── algorithm │ │ ├── README.md │ │ ├── __init__.py │ │ ├── fpi_gptq │ │ │ ├── __init__.py │ │ │ ├── fpi_gptq.py │ │ │ └── quantizer.py │ │ ├── gptq │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── gptq.py │ │ │ ├── quant.py │ │ │ ├── quantizer.py │ │ │ └── utils.py │ │ ├── pt2e │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── annotation │ │ │ │ ├── __init__.py │ │ │ │ ├── annotator.py │ │ │ │ ├── config.py │ │ │ │ ├── op │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── adaptive_avg_pool2d.py │ │ │ │ │ ├── add.py │ │ │ │ │ ├── conv2d.py │ │ │ │ │ ├── div.py │ │ │ │ │ ├── linear.py │ │ │ │ │ ├── mean.py │ │ │ │ │ ├── mul.py │ │ │ │ │ ├── relu6.py │ │ │ │ │ ├── rsqrt.py │ │ │ │ │ └── sub.py │ │ │ │ ├── spec.py │ │ │ │ └── utils.py │ │ │ ├── quantizer.py │ │ │ ├── transformation │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ └── convert_scalars_to_attrs.py │ │ │ └── utils.py │ │ └── smoothquant │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── observer.py │ │ │ ├── quantizer.py │ │ │ ├── smooth_quant.py │ │ │ └── smooth_quant.txt │ ├── config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── fpi_gptq.py │ │ ├── gptq.py │ │ ├── pt2e.py │ │ ├── ptq.py │ │ └── smoothquant.py │ ├── evaluation │ │ ├── README.md │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── evaluate.py │ │ ├── executor │ │ │ ├── __init__.py │ │ │ ├── backend_executor.py │ │ │ ├── circle_executor.py │ │ │ └── triv24_executor.py │ │ ├── metric.py │ │ └── utils.py │ ├── passes │ │ ├── __init__.py │ │ ├── fold_quant_ops.py │ │ ├── insert_quantize_on_dtype_mismatch.py │ │ ├── propagate_qparam_backward.py │ │ ├── propagate_qparam_forward.py │ │ ├── quantize_bias.py │ │ └── remove_weight_dequant_op.py │ ├── public_interface.py │ ├── quantizer.py │ ├── quantizer_registry.py │ └── wrapq │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dtypes.py │ │ ├── examples │ │ ├── __init__.py │ │ ├── compare_ppl.py │ │ ├── debug_quant_outputs.py │ │ ├── quantize_linear.py │ │ ├── quantize_llama_attn.py │ │ ├── quantize_llama_decoder_layer.py │ │ ├── quantize_llama_mlp.py │ │ └── quantize_with_gptq.py │ │ ├── mode.py │ │ ├── observers │ │ ├── __init__.py │ │ ├── affine_base.py │ │ ├── base.py │ │ ├── ema.py │ │ ├── identity.py │ │ ├── minmax.py │ │ └── mx.py │ │ ├── qscheme.py │ │ ├── quantizer.py │ │ ├── utils │ │ ├── __init__.py │ │ ├── introspection.py │ │ ├── metrics.py │ │ └── reduce_utils.py │ │ └── wrappers │ │ ├── __init__.py │ │ ├── fairseq │ │ ├── __init__.py │ │ ├── decoder_export_single_step.py │ │ ├── quant_decoder.py │ │ ├── quant_decoder_layer.py │ │ ├── quant_encoder.py │ │ ├── quant_encoder_layer.py │ │ └── quant_mha.py │ │ ├── llama │ │ ├── __init__.py │ │ ├── quant_attn.py │ │ ├── quant_decoder_layer.py │ │ └── quant_mlp.py │ │ ├── nn │ │ ├── __init__.py │ │ ├── quant_layernorm.py │ │ ├── quant_linear.py │ │ └── quant_silu.py │ │ ├── ptq_wrapper.py │ │ ├── quant_elementwise.py │ │ ├── quant_module_base.py │ │ └── registry.py ├── serialize │ ├── __init__.py │ ├── circle_graph.py │ ├── circle_mapping.py │ ├── circle_serializer.py │ ├── operators │ │ ├── __init__.py │ │ ├── adapters │ │ │ ├── __init__.py │ │ │ └── llama_rmsnorm.py │ │ ├── hashable_opcode.py │ │ ├── node_visitor.py │ │ ├── op_abs.py │ │ ├── op_add.py │ │ ├── op_alias_copy.py │ │ ├── op_any.py │ │ ├── op_arange_start_step.py │ │ ├── op_argmax.py │ │ ├── op_avg_pool2d.py │ │ ├── op_bmm.py │ │ ├── op_cat.py │ │ ├── op_clamp.py │ │ ├── op_clone.py │ │ ├── op_constant_pad_nd.py │ │ ├── op_conv2d.py │ │ ├── op_copy.py │ │ ├── op_cos.py │ │ ├── op_cumsum.py │ │ ├── op_depthwise_conv2d.py │ │ ├── op_dequantize_per_channel.py │ │ ├── op_dequantize_per_tensor.py │ │ ├── op_div.py │ │ ├── op_embedding.py │ │ ├── op_eq.py │ │ ├── op_exp.py │ │ ├── op_expand.py │ │ ├── op_full.py │ │ ├── op_full_like.py │ │ ├── op_ge.py │ │ ├── op_gelu.py │ │ ├── op_gt.py │ │ ├── op_index.py │ │ ├── op_index_select.py │ │ ├── op_instance_norm.py │ │ ├── op_le.py │ │ ├── op_leaky_relu.py │ │ ├── op_linear.py │ │ ├── op_log.py │ │ ├── op_log1p.py │ │ ├── op_logical_and.py │ │ ├── op_logical_not.py │ │ ├── op_lt.py │ │ ├── op_max_dim.py │ │ ├── op_max_pool2d_with_indices.py │ │ ├── op_maximum.py │ │ ├── op_mean.py │ │ ├── op_minimum.py │ │ ├── op_mm.py │ │ ├── op_mul.py │ │ ├── op_ne.py │ │ ├── op_neg.py │ │ ├── op_permute.py │ │ ├── op_pow.py │ │ ├── op_prelu.py │ │ ├── op_quantize_per_tensor.py │ │ ├── op_reciprocal.py │ │ ├── op_relu.py │ │ ├── op_relu6.py │ │ ├── op_repeat.py │ │ ├── op_reshape.py │ │ ├── op_resize_nearest_neighbor.py │ │ ├── op_rmsnorm.py │ │ ├── op_round.py │ │ ├── op_rsqrt.py │ │ ├── op_scalar_tensor.py │ │ ├── op_select_copy.py │ │ ├── op_sigmoid.py │ │ ├── op_sin.py │ │ ├── op_slice.py │ │ ├── op_softmax.py │ │ ├── op_split_with_sizes.py │ │ ├── op_sqrt.py │ │ ├── op_squeeze.py │ │ ├── op_sub.py │ │ ├── op_sum.py │ │ ├── op_tanh.py │ │ ├── op_to_copy.py │ │ ├── op_transpose_conv.py │ │ ├── op_unsqueeze.py │ │ ├── op_view.py │ │ ├── op_where.py │ │ └── utils.py │ ├── pack.py │ └── quant_param.py └── utils │ ├── __init__.py │ ├── convert.py │ ├── define.py │ ├── diff_graph.py │ ├── dtype.py │ ├── errors.py │ ├── graph.py │ ├── installed_packages.py │ ├── logging.py │ ├── model.py │ ├── mx │ ├── __init__.py │ ├── elemwise_ops.py │ ├── formats.py │ └── mx_ops.py │ ├── padding.py │ ├── passes.py │ ├── pytree_utils.py │ ├── record_input.py │ ├── register_custom_op.py │ ├── serialize.py │ ├── signature.py │ ├── torch_compat.py │ ├── trace_decorators.py │ ├── utils.py │ └── validate_args_kwargs.py └── version.py /.ahub/tcchecker-tca/config.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | test: 3 | - name: TICO_TCA 4 | testCaseLanguage: PYTHON 5 | testFW: PYUNIT 6 | testCaseFolder: 7 | - test/ 8 | 9 | testFile: 10 | - extension: py 11 | starts: 12 | - test_ 13 | 14 | testCase: 15 | - condition: 16 | - functionName: 17 | starts: 18 | - test_ 19 | 20 | positiveTestCase: 21 | - condition: 22 | - inverse: negativeTestCase 23 | 24 | negativeTestCase: 25 | - condition: 26 | - testName: 27 | ends: 28 | - _neg 29 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | include = 3 | tico/* 4 | omit = 5 | ; From external sources 6 | tico/utils/mx/* 7 | tico/quantization/algorithm/gptq/* 8 | 9 | [report] 10 | exclude_lines = 11 | if TYPE_CHECKING: 12 | 13 | -------------------------------------------------------------------------------- /.github/actionlint.yaml: -------------------------------------------------------------------------------- 1 | self-hosted-runner: 2 | labels: 3 | - tico-linux 4 | -------------------------------------------------------------------------------- /.github/actions/download-release-asset/action.yml: -------------------------------------------------------------------------------- 1 | name: Download release asset 2 | description: Download a specific file from a GitHub release 3 | 4 | inputs: 5 | owner: 6 | description: GitHub owner of the repository 7 | required: true 8 | repo: 9 | description: GitHub repository name 10 | required: true 11 | tag: 12 | description: GitHub release tag (e.g., v1.2.3) 13 | required: true 14 | filename: 15 | description: File name to download from the release 16 | required: true 17 | 18 | outputs: 19 | filename: 20 | description: The downloaded file name 21 | value: ${{ steps.download-step.outputs.filename }} 22 | 23 | runs: 24 | using: composite 25 | steps: 26 | - shell: bash 27 | id: download-step 28 | run: | 29 | cd .github/actions/download-release-asset/ 30 | chmod +x entrypoint.sh 31 | ./entrypoint.sh \ 32 | "${{ inputs.owner }}" \ 33 | "${{ inputs.repo }}" \ 34 | "${{ inputs.tag }}" \ 35 | "${{ inputs.filename }}" 36 | -------------------------------------------------------------------------------- /.github/actions/download-release-asset/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | echo "[download-release-asset] Start" 6 | 7 | # Input parameters 8 | OWNER=$1 9 | REPO=$2 10 | TAG=$3 11 | FILENAME=$4 12 | 13 | echo "[download-release-asset] Target filename: $FILENAME" 14 | 15 | API_URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/tags/${TAG}" 16 | DOWNLOAD_URL=$(curl -s $API_URL | jq -r ".assets[] | select(.name == \"$FILENAME\") | .browser_download_url") 17 | 18 | if [ -z "$DOWNLOAD_URL" ]; then 19 | echo "[download-release-asset] ERROR: file not found in release assets: $FILENAME" 20 | exit 1 21 | fi 22 | 23 | echo "[download-release-asset] Downloading from: $DOWNLOAD_URL" 24 | curl -L -o "$FILENAME" "$DOWNLOAD_URL" 25 | echo "filename=$FILENAME" >> "$GITHUB_OUTPUT" 26 | 27 | echo "[download-release-asset] End" 28 | -------------------------------------------------------------------------------- /.github/workflows/check-workflows.yaml: -------------------------------------------------------------------------------- 1 | name: Check Workflows 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | # WARNING: PRs from forked repo are not allowed. 7 | - '.github/workflows/**' 8 | 9 | jobs: 10 | actionlint: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Run actionlint 16 | run: | 17 | bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) 1.7.7 18 | ./actionlint -color -verbose 19 | shell: bash 20 | -------------------------------------------------------------------------------- /.github/workflows/publish-nightly-package.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Nightly Package 2 | 3 | on: 4 | schedule: 5 | # 05:00 AM (KST) Mon-Fri 6 | - cron: "00 20 * * 0-4" 7 | 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-and-upload: 12 | runs-on: ubuntu-22.04 13 | strategy: 14 | matrix: 15 | ubuntu_version: [22.04] 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: "Build package" 20 | run: | 21 | NIGHTLY_VERSION=$(date '+%y%m%d') 22 | export NIGHTLY_VERSION 23 | echo "NIGHTLY_VERSION=$NIGHTLY_VERSION" >> "$GITHUB_ENV" 24 | ./ccex build 25 | 26 | - name: "Upload artifact" 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: "${{ matrix.ubuntu_version }}_wheel" 30 | path: "./dist/" 31 | 32 | publish-to-pypi: 33 | needs: 34 | - build-and-upload 35 | runs-on: ubuntu-22.04 36 | strategy: 37 | matrix: 38 | ubuntu_version: [22.04] 39 | 40 | environment: 41 | name: pypi 42 | url: https://pypi.org/p/TICO 43 | 44 | permissions: 45 | id-token: write # IMPORTANT: mandatory for trusted publishing 46 | 47 | steps: 48 | - name: "Download all the dists" 49 | uses: actions/download-artifact@v4 50 | with: 51 | name: "${{ matrix.ubuntu_version }}_wheel" 52 | path: "./dist/" 53 | 54 | - name: "Publish distribution to PyPI" 55 | uses: pypa/gh-action-pypi-publish@release/v1 56 | -------------------------------------------------------------------------------- /.github/workflows/publish-official-package.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Official Package 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | ref_name: 7 | description: 'Git reference (branch or tag) to build and publish' 8 | required: true 9 | type: string 10 | 11 | jobs: 12 | build-and-upload: 13 | runs-on: ubuntu-22.04 14 | 15 | steps: 16 | - name: Checkout the specified ref 17 | uses: actions/checkout@v4 18 | with: 19 | ref: ${{ inputs.ref_name }} 20 | 21 | # Verify version consistency between tag and version.py 22 | - name: Check version.py matches ref_name 23 | run: | 24 | FILE_VERSION=$(python -c 'exec(open("version.py").read()); print(VERSION)') 25 | INPUT_REF="${{ inputs.ref_name }}" 26 | TAG_VERSION="${INPUT_REF#v}" # Strip leading 'v' from ref_name 27 | 28 | echo "VERSION in version.py: $FILE_VERSION" 29 | echo "VERSION from ref_name: $TAG_VERSION" 30 | 31 | if [ "$FILE_VERSION" != "$TAG_VERSION" ]; then 32 | echo "::error::VERSION in version.py ($FILE_VERSION) does not match ref_name ($INPUT_REF)" 33 | exit 1 34 | fi 35 | 36 | - name: "Build package" 37 | run: | 38 | ./ccex build 39 | 40 | - name: "Upload artifact" 41 | uses: actions/upload-artifact@v4 42 | with: 43 | name: "wheel" 44 | path: "./dist/" 45 | 46 | publish-to-pypi: 47 | needs: 48 | - build-and-upload 49 | runs-on: ubuntu-22.04 50 | 51 | environment: 52 | name: pypi 53 | url: https://pypi.org/p/TICO 54 | 55 | permissions: 56 | id-token: write # IMPORTANT: mandatory for trusted publishing 57 | 58 | steps: 59 | - name: "Download all the dists" 60 | uses: actions/download-artifact@v4 61 | with: 62 | name: "wheel" 63 | path: "./dist/" 64 | 65 | - name: "Publish distribution to PyPI" 66 | uses: pypa/gh-action-pypi-publish@release/v1 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Distribution 6 | build/ 7 | dist/ 8 | *.egg-info/ 9 | .eggs/ 10 | 11 | # Virtual environment 12 | .venv 13 | test/.venv 14 | infra/style/.venv 15 | 16 | # Model 17 | *.circle 18 | *.pt2 19 | 20 | # Reports for test results 21 | test/reports 22 | .coverage 23 | -------------------------------------------------------------------------------- /.lintrunner.toml: -------------------------------------------------------------------------------- 1 | merge_base_with = "origin/main" 2 | 3 | [[linter]] 4 | code = 'PYLINT' 5 | include_patterns = ['**/*.py'] 6 | exclude_patterns = [ 7 | 'submodules/**', 8 | '**/submodules/**', 9 | 'tico/utils/mx/**', 10 | ] 11 | command = [ 12 | 'lintrunner_adapters', 13 | 'run', 14 | 'pylint_linter', 15 | '--', 16 | '--rcfile=.pylintrc', 17 | '--jobs=0', 18 | '@{{PATHSFILE}}' 19 | ] 20 | init_command = [ 21 | 'lintrunner_adapters', 22 | 'run', 23 | 'pip_init', 24 | '--dry-run={{DRYRUN}}', 25 | '--requirement=requirements-lintrunner.txt', 26 | ] 27 | 28 | # Black + usort 29 | [[linter]] 30 | code = 'UFMT' 31 | include_patterns = [ 32 | '**/*.py', 33 | '**/*.pyi', 34 | ] 35 | exclude_patterns = [ 36 | 'submodules/**', 37 | '**/submodules/**', 38 | 'tico/utils/mx/**', 39 | ] 40 | command = [ 41 | 'lintrunner_adapters', 42 | 'run', 43 | 'ufmt_linter', 44 | '--', 45 | '@{{PATHSFILE}}' 46 | ] 47 | init_command = [ 48 | 'lintrunner_adapters', 49 | 'run', 50 | 'pip_init', 51 | '--dry-run={{DRYRUN}}', 52 | '--no-black-binary', 53 | '--requirement=requirements-lintrunner.txt', 54 | ] 55 | is_formatter = true 56 | 57 | [[linter]] 58 | code = 'MYPY' 59 | include_patterns = [ 60 | '**/*.py', 61 | '**/*.pyi', 62 | ] 63 | exclude_patterns = [ 64 | 'submodules/**', 65 | '**/submodules/**', 66 | 'old/**', 67 | 'tico/utils/mx/**', 68 | ] 69 | command = [ 70 | 'lintrunner_adapters', 71 | 'run', 72 | 'mypy_linter', 73 | '--config=mypy.ini', 74 | '--show-notes', 75 | '--show-disable', 76 | '--', 77 | '@{{PATHSFILE}}', 78 | ] 79 | init_command = [ 80 | 'lintrunner_adapters', 81 | 'run', 82 | 'pip_init', 83 | '--dry-run={{DRYRUN}}', 84 | '--requirement=requirements-lintrunner.txt', 85 | ] 86 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable= 3 | W, C, R, 4 | # Docstring 5 | C0114, # missing-module-docstring 6 | C0115, # missing-class-docstring 7 | C0116, # missing-function-docstring 8 | # Naming 9 | C0103, # invalid-name 10 | # Length 11 | C0301, # line-too-long 12 | # Imports 13 | E0401, # import-error 14 | W0611, # unused-import 15 | # Misc 16 | E1120, # no-value-for-parameter 17 | E1102, # not-callable 18 | W0107, # unnecessary-pass 19 | W0511, # fixme 20 | W1309, # f-string-without-interpolation 21 | -------------------------------------------------------------------------------- /ccex: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CCEX_COMMAND_RPATH="infra/command" 4 | CCEX_PROJECT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 5 | 6 | function Usage() 7 | { 8 | echo "Usage: $0 [COMMAND] ..." 9 | echo "Command:" 10 | for file in "$CCEX_COMMAND_RPATH"/*; 11 | do 12 | echo " $(basename "$file")" 13 | done 14 | } 15 | 16 | COMMAND=$1; shift 17 | 18 | if [[ -z "${COMMAND}" ]]; then 19 | Usage 20 | exit 255 21 | fi 22 | 23 | COMMAND_FILE="${CCEX_PROJECT_PATH}/${CCEX_COMMAND_RPATH}/${COMMAND}" 24 | 25 | if [[ ! -f "${COMMAND_FILE}" ]]; then 26 | echo "ERROR: '${COMMAND}' is not supported" 27 | exit 255 28 | fi 29 | 30 | export CCEX_PROJECT_PATH 31 | 32 | source "${COMMAND_FILE}" "$@" 33 | -------------------------------------------------------------------------------- /infra/command/build: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Build a package 9 | 10 | CCEX_SCRIPTS_PATH="${CCEX_PROJECT_PATH}/infra/scripts" 11 | 12 | source ${CCEX_SCRIPTS_PATH}/build.sh "$@" 13 | -------------------------------------------------------------------------------- /infra/command/configure: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Install packages for formatting 9 | # - Create pyhton virtual env for testing 10 | 11 | CCEX_SCRIPTS_PATH="${CCEX_PROJECT_PATH}/infra/scripts" 12 | CCEX_STYLE_PATH="${CCEX_PROJECT_PATH}/infra/style" 13 | 14 | function Usage() 15 | { 16 | echo "Usage: $0 configure [COMMAND] ..." 17 | echo "Command:" 18 | echo " all (default) prepare for both format/style checkers and testing environment" 19 | echo " format prepare for format/style checkers only" 20 | echo " test prepare for testing environment only" 21 | echo " --help, -h show this help message" 22 | } 23 | 24 | function Formatting() 25 | { 26 | bash ${CCEX_STYLE_PATH}/install.sh "$@" 27 | } 28 | 29 | function Testing() 30 | { 31 | bash ${CCEX_SCRIPTS_PATH}/test_configure.sh "$@" 32 | } 33 | 34 | _RUN_FORMAT=1 35 | _RUN_TEST=1 36 | 37 | COMMAND="$1" 38 | if [[ -z "${COMMAND}" ]]; then 39 | COMMAND="all" 40 | fi 41 | 42 | case $COMMAND in 43 | -h|--help ) 44 | Usage 45 | exit 0 46 | ;; 47 | format ) 48 | _RUN_FORMAT=1 49 | _RUN_TEST=0 50 | shift 51 | ;; 52 | test ) 53 | _RUN_FORMAT=0 54 | _RUN_TEST=1 55 | shift 56 | ;; 57 | all ) 58 | _RUN_FORMAT=1 59 | _RUN_TEST=1 60 | shift 61 | ;; 62 | '--'*[a-z] ) 63 | # skip for options (ex. --dev) 64 | ;; 65 | *) 66 | echo "[ERROR] Unknown parameter passed: $COMMAND"; 67 | Usage 68 | exit 255 69 | ;; 70 | esac 71 | 72 | if [ ${_RUN_FORMAT} -eq 1 ]; then 73 | echo "Prepare format/style checkers..." 74 | Formatting "$@" 75 | if [ $? -ne 0 ]; then 76 | echo "[ERROR] Failed to install formatters." 77 | exit 255 78 | fi 79 | fi 80 | 81 | if [ ${_RUN_TEST} -eq 1 ]; then 82 | echo "Prepare testing environment..." 83 | Testing "$@" 84 | if [ $? -ne 0 ]; then 85 | echo "[ERROR] Failed to install test dependencies." 86 | exit 255 87 | fi 88 | fi 89 | 90 | -------------------------------------------------------------------------------- /infra/command/coverage: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Parse the test coverage 9 | # - Make a report for the test coverage 10 | 11 | CCEX_SCRIPTS_PATH="${CCEX_PROJECT_PATH}/infra/scripts" 12 | 13 | source ${CCEX_SCRIPTS_PATH}/coverage.sh "$@" 14 | -------------------------------------------------------------------------------- /infra/command/format: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Do formatting 9 | 10 | CCEX_STYLE_PATH="${CCEX_PROJECT_PATH}/infra/style" 11 | 12 | source ${CCEX_STYLE_PATH}/format.sh "$@" 13 | -------------------------------------------------------------------------------- /infra/command/install: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Install built `tico` package 9 | # - Install runtime dependencies 10 | 11 | CCEX_SCRIPTS_PATH="${CCEX_PROJECT_PATH}/infra/scripts" 12 | 13 | source ${CCEX_SCRIPTS_PATH}/install.sh "$@" 14 | -------------------------------------------------------------------------------- /infra/command/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is called by `ccex` 3 | # 4 | # [EXPORTED VARIABLES] 5 | # - CCEX_PROJECT_PATH 6 | # 7 | # [WHAT IT DOES] 8 | # - Run test 9 | 10 | CCEX_SCRIPTS_PATH="${CCEX_PROJECT_PATH}/infra/scripts" 11 | 12 | python3 ${CCEX_SCRIPTS_PATH}/test_run.py "$@" 13 | -------------------------------------------------------------------------------- /infra/dependency/torch_dev.txt: -------------------------------------------------------------------------------- 1 | torch==2.10.0.dev20251012+cpu 2 | -------------------------------------------------------------------------------- /infra/dependency/torchvision_dev.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.25.0.dev20251012+cpu 2 | -------------------------------------------------------------------------------- /infra/scripts/build.sh: -------------------------------------------------------------------------------- 1 | python3 setup.py bdist_wheel --dist-dir dist 2 | -------------------------------------------------------------------------------- /infra/scripts/coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script is called by `ccex` 18 | # 19 | # [EXPORTED VARIABLES] 20 | # - CCEX_PROJECT_PATH 21 | 22 | TEST_DIR="${CCEX_PROJECT_PATH}/test" 23 | COVERAGE_REPORT_DIR="${TEST_DIR}/reports/cov" 24 | 25 | # The test should be run at the project root or test directory 26 | pushd ${CCEX_PROJECT_PATH} > /dev/null 27 | 28 | command_args="$@" 29 | 30 | COVERAGE_EXIST=$(pip list | grep -w coverage) 31 | if [ -z "${COVERAGE_EXIST}" ] > /dev/null; then 32 | echo "'coverage' does not exist." 33 | echo "Run python3 -m pip install coverage==7.6.1" 34 | exit 1 35 | fi 36 | 37 | coverage run -m unittest discover -s ${TEST_DIR} -v 2>&1 38 | 39 | if [ $# -eq 0 ]; then 40 | coverage report -i -m 41 | else 42 | OPTION=$1; shift 43 | if [[ "${OPTION}" != '-f' ]]; then 44 | echo "${OPTION} is not supported" 45 | else 46 | if [ ! -d ${COVERAGE_REPORT_DIR} ] ; then 47 | mkdir -p ${COVERAGE_REPORT_DIR} 48 | fi 49 | 50 | FORMAT=$1; shift 51 | if [[ "${FORMAT}" == 'txt' ]]; then 52 | coverage report -i -m > ${COVERAGE_REPORT_DIR}/coverage.txt 53 | elif [[ "${FORMAT}" == 'xml' ]]; then 54 | coverage xml -i -o ${COVERAGE_REPORT_DIR}/coverage.xml 55 | else 56 | echo "Unknown format: ${FORMAT}" 57 | echo "Following formats are supported: txt, xml" 58 | fi 59 | fi 60 | fi 61 | -------------------------------------------------------------------------------- /infra/scripts/install_requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools==78.1.1 2 | wheel==0.43.0 3 | 4 | circle_schema 5 | 6 | cffi==1.17.1 7 | packaging==25.0 8 | pyyaml==6.0.2 9 | tqdm 10 | -------------------------------------------------------------------------------- /infra/style/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Immediately exit if any command has a non-zero exit status 4 | set -e 5 | 6 | APPLY_PATCH_OPTION="-a" 7 | while [[ "$#" -gt 0 ]]; do 8 | case $1 in 9 | --no-apply-patches) 10 | APPLY_PATCH_OPTION="" 11 | shift 12 | ;; 13 | -d|--diff-only) 14 | CHECK_DIFF_ONLY="1" 15 | shift 16 | ;; 17 | *) echo "[ERROR] Unknown parameter passed: $1"; exit 255 ;; 18 | esac 19 | done 20 | 21 | GIT_ROOT="$(git rev-parse --show-toplevel)" 22 | STYLE_DIR="${GIT_ROOT}/infra/style" 23 | 24 | if [[ "${CHECK_DIFF_ONLY}" = "1" ]]; then 25 | MAIN_EXIST=$(git rev-parse --verify main) 26 | CURRENT_BRANCH=$(git branch | grep \* | cut -d ' ' -f2-) 27 | DIFF_COMMITS=`git log --graph --oneline main..HEAD | wc -l` 28 | if [[ -z "${MAIN_EXIST}" ]]; then 29 | echo "Cannot find main branch" 30 | exit 1 31 | elif [[ "${CURRENT_BRANCH}" = "main" ]]; then 32 | echo "Current branch is main" 33 | exit 1 34 | else 35 | # Gather diff from HEAD 36 | FILES_TO_CHECK=$(git diff --name-only --diff-filter=d HEAD~${DIFF_COMMITS}) 37 | 38 | # Remove links 39 | # Git file mode 40 | # 120000: symbolic link 41 | # 160000: git link 42 | # Reference: https://github.com/git/git/blob/cd42415/Documentation/technical/index-format.txt#L72-L81 43 | FILES_TO_CHECK=$(git ls-files -c -s --exclude-standard ${FILES_TO_CHECK[@]} | egrep -v '^1[26]0000' | cut -f2) 44 | fi 45 | 46 | lintrunner --force-color $APPLY_PATCH_OPTION --config "${GIT_ROOT}/.lintrunner.toml" $FILES_TO_CHECK 47 | else 48 | lintrunner --force-color --all-files $APPLY_PATCH_OPTION --config "${GIT_ROOT}/.lintrunner.toml" 49 | fi 50 | -------------------------------------------------------------------------------- /infra/style/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GIT_ROOT="$(git rev-parse --show-toplevel)" 4 | STYLE_DIR="${GIT_ROOT}/infra/style" 5 | 6 | REQ_PKG="requirements.txt" 7 | python3 -m pip install -r "${STYLE_DIR}/${REQ_PKG}" 8 | 9 | lintrunner init --config ${GIT_ROOT}/.lintrunner.toml 10 | -------------------------------------------------------------------------------- /infra/style/requirements.txt: -------------------------------------------------------------------------------- 1 | lintrunner==0.11.0 2 | lintrunner-adapters==0.11.0 3 | types-PyYAML 4 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | show_error_codes = True 3 | show_column_numbers = True 4 | ignore_missing_imports = True 5 | check_untyped_defs = True 6 | -------------------------------------------------------------------------------- /requirements-lintrunner.txt: -------------------------------------------------------------------------------- 1 | # Lintrunner itself 2 | lintrunner==0.11.0 3 | lintrunner-adapters==0.11.0 4 | 5 | # Linter 6 | pylint==3.3.1 7 | astroid==3.3.10 # for pylint 8 | 9 | # Formatter 10 | black==22.12.0 11 | ufmt==2.0.1 12 | usort==1.0.5 13 | 14 | # Type checker 15 | mypy==1.9.0 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """setup.py for TICO""" 2 | 3 | import os 4 | import re 5 | from pathlib import Path 6 | 7 | from setuptools import find_packages, setup 8 | from version import VERSION 9 | 10 | ############################################################ 11 | ### Set dev(nightly) version ### 12 | ############################################################ 13 | if nightly_release_version := os.environ.get("NIGHTLY_VERSION"): 14 | VERSION = f"{VERSION}.dev{nightly_release_version}" 15 | 16 | ############################################################ 17 | ### Update __version__ in __init__.py ### 18 | ############################################################ 19 | with open("tico/__init__.py", "r") as init_file: 20 | init_file_update = re.sub( 21 | "__version__ =.*", f'__version__ = "{VERSION}"', init_file.read() 22 | ) 23 | with open("tico/__init__.py", "w") as init_file: 24 | init_file.write(init_file_update) 25 | 26 | ############################################################ 27 | ### Prepare long_description ### 28 | ############################################################ 29 | readme = Path(__file__).with_name("README.md").read_text(encoding="utf-8") 30 | 31 | ############################################################ 32 | ### Run setup ### 33 | ############################################################ 34 | setup( 35 | name="tico", 36 | python_requires=">=3.10.0", 37 | version=VERSION, 38 | description="Convert exported Torch module to circle", 39 | long_description=readme, 40 | long_description_content_type="text/markdown", 41 | license_files=("LICENSE",), 42 | packages=find_packages(include=["tico*"]), 43 | entry_points={"console_scripts": ["pt2-to-circle = tico.pt2_to_circle:main"]}, 44 | install_requires=["circle-schema", "packaging", "cffi", "torch", "pyyaml", "tqdm"], 45 | ) 46 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # test module 2 | 3 | This is a module for TICO unittest. 4 | 5 | ## How to debug using vscode? 6 | 7 | 1. Add below configuration to `.vscode/launch.json`. 8 | 2. Make a breakpoint on the line you want. 9 | 3. Run the configuration with `Run and Debug`. 10 | 11 | **TIP** Install this project in editable mode (pip install -e) for interactive test. 12 | 13 | ```json 14 | { 15 | "version": "0.2.0", 16 | "configurations": [ 17 | { 18 | "name": "Debug unit test (pt2_to_circle)", 19 | "type": "debugpy", 20 | "request": "launch", 21 | "module": "unittest", 22 | "args": ["test.pt2_to_circle_test.test_pt2_to_circle"], 23 | "cwd": "${workspaceFolder}", 24 | "console": "integratedTerminal", 25 | "justMyCode": false, 26 | }, 27 | ] 28 | } 29 | ``` -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/dump_pt2_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import importlib 17 | import inspect 18 | 19 | import torch 20 | 21 | 22 | def main() -> None: 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "-m", 26 | "--module_name", 27 | required=True, 28 | help="provide a model name.", 29 | ) 30 | parser.add_argument( 31 | "-o", 32 | "--output", 33 | required=True, 34 | help="provide an output .pt2 model name.", 35 | ) 36 | 37 | args = parser.parse_args() 38 | 39 | path = args.module_name.split(".") 40 | module_name = (".").join(path[:-1]) 41 | model_name = path[-1] 42 | 43 | module = importlib.import_module(module_name) 44 | models = inspect.getmembers(module, inspect.isclass) 45 | 46 | if all(name != model_name for name, _ in models): 47 | raise RuntimeError("Invalid module name") 48 | for name, cls in models: 49 | if name == model_name: 50 | model = cls() 51 | example_inputs = model.get_example_inputs() 52 | 53 | pt2_module = torch.export.export(model, example_inputs) 54 | 55 | torch.export.save(pt2_module, args.output) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /test/modules/README.md: -------------------------------------------------------------------------------- 1 | Test target module directory. 2 | It can be accessed as a module as `test.modules.*` 3 | 4 | ### How to use `target` decorator? 5 | 6 | By default, all torch.nn.Module subclasses in the python script are assumed to be test targets. 7 | If `@target` decorator is set for any torch.nn.Module subclass in the file, only those modules will be tested. 8 | 9 | #### Example 1. net.Llama2 10 | 11 | `@target` decorator is used. 12 | 13 | ```py 14 | class TransformerBlock(nn.Module): # <--- Not tested 15 | def __init__(self, layer_id: int, args: ModelArgs): 16 | super().__init__() 17 | self.use_kv_cache = args.use_kv_cache 18 | self.n_heads = args.n_heads 19 | 20 | @tag.target 21 | class Transformer(nn.Module): # <--- Tested 22 | def __init__(self, params: ModelArgs = ModelArgs()): 23 | super().__init__() 24 | ``` 25 | 26 | #### Example 2. 27 | 28 | `@target` decorator is not used for any module in the file. All torch.nn.Module subclasses are test targets. 29 | 30 | ```py 31 | class SimpleSqueeze(torch.nn.Module): # <--- Tested 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, x): 36 | z = torch.squeeze(x) 37 | return z 38 | 39 | def get_example_inputs(self): 40 | torch.manual_seed(1234) 41 | return (torch.randn(2, 1, 2, 1, 2),) 42 | 43 | 44 | class SimpleSqueezeWithDims(torch.nn.Module): # <--- Tested 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x): 49 | z = torch.squeeze(x, dim=(1, 3)) 50 | return z 51 | 52 | def get_example_inputs(self): 53 | torch.manual_seed(1234) 54 | return (torch.randn(2, 1, 2, 1, 2),) 55 | ``` -------------------------------------------------------------------------------- /test/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class ExampleInputMixin: 7 | def get_example_inputs(self) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: 8 | raise NotImplementedError("Must implement get_example_inputs") 9 | 10 | 11 | class DynamicShapesMixin: 12 | def get_dynamic_shapes(self) -> Dict[str, Tuple[int, ...]]: # type: ignore[empty-body] 13 | pass 14 | 15 | 16 | class TestModuleBase(nn.Module, ExampleInputMixin, DynamicShapesMixin): 17 | pass 18 | -------------------------------------------------------------------------------- /test/modules/model/Bert/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/Bert/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import BertConfig, BertModel 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class Bert(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = BertModel(config=BertConfig()).to("cpu") 25 | self.rtol = 1e-4 26 | self.atol = 1e-4 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | def get_example_inputs(self): 32 | # >>> tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", legacy=True, from_slow=True) 33 | # >>> tokenizer.encode("Hello .") 34 | # [101, 7592, 1026, 1055, 1028, 1012, 102] 35 | return ( 36 | torch.Tensor([[101, 7592, 1026, 1055, 1028, 1012, 102]]).to( 37 | dtype=torch.int32 38 | ), 39 | ), {} 40 | -------------------------------------------------------------------------------- /test/modules/model/Bert/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.53.0 2 | -------------------------------------------------------------------------------- /test/modules/model/BitNet_b1_58/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/BitNet_b1_58/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.55.4 2 | accelerate 3 | -------------------------------------------------------------------------------- /test/modules/model/DeepSeekR1DistillQwen1_5B/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/DeepSeekR1DistillQwen1_5B/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.1 2 | -------------------------------------------------------------------------------- /test/modules/model/EfficientFormerL1/model.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | from test.modules.base import TestModuleBase 4 | 5 | 6 | class EfficientFormerL1(TestModuleBase): 7 | def __init__(self): 8 | super().__init__() 9 | self.model: timm.models.efficientformer.EfficientFormer = ( 10 | timm.create_model("efficientformer_l1", pretrained=True).to("cpu").eval() 11 | ) 12 | self.rtol = 1e-4 13 | self.atol = 1e-4 14 | 15 | def forward(self, x: torch.Tensor): 16 | return self.model.forward(x) 17 | 18 | def get_example_inputs(self): 19 | torch.manual_seed(1) 20 | return (torch.randn(1, 3, 224, 224),), {} 21 | -------------------------------------------------------------------------------- /test/modules/model/EfficientFormerL1/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | timm 3 | -------------------------------------------------------------------------------- /test/modules/model/EfficientNetV2S/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/EfficientNetV2S/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torchvision.models.efficientnet import efficientnet_v2_s 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class EfficientNet(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = efficientnet_v2_s(pretrained=True).to("cpu") 25 | self.rtol = 1e-4 26 | self.atol = 1e-4 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | def get_example_inputs(self): 32 | torch.manual_seed(1) 33 | return (torch.randn(1, 3, 16, 16),), {} 34 | -------------------------------------------------------------------------------- /test/modules/model/EfficientNetV2S/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /test/modules/model/Florence2/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/Florence2/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.49.0 2 | einops 3 | timm 4 | -------------------------------------------------------------------------------- /test/modules/model/GPT2/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/GPT2/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import GPT2Config, GPT2LMHeadModel 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class GPT2(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained("gpt2")).to( 25 | "cpu" 26 | ) 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | def get_example_inputs(self): 32 | # >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") 33 | # >>> tokenizer("Hello world")["input_ids"] 34 | # [15496, 995] 35 | input_ids = torch.Tensor([15496, 995]).to(dtype=torch.int32) 36 | return (input_ids,), {} 37 | -------------------------------------------------------------------------------- /test/modules/model/GPT2/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | transformers==4.52.4 3 | -------------------------------------------------------------------------------- /test/modules/model/Gemma3/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/Gemma3/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.56.1 2 | -------------------------------------------------------------------------------- /test/modules/model/InceptionV3/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/InceptionV3/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torchvision.models.inception import inception_v3, Inception_V3_Weights 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class InceptionV3(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = inception_v3( 25 | weights=Inception_V3_Weights.DEFAULT, transform_input=False 26 | ).to("cpu") 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | def get_example_inputs(self): 32 | torch.manual_seed(1) 33 | return (torch.randn(1, 3, 299, 299),), {} 34 | -------------------------------------------------------------------------------- /test/modules/model/InceptionV3/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /test/modules/model/Llama/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/Llama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import LlamaConfig, LlamaModel 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class Llama(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = LlamaModel( 25 | config=LlamaConfig( 26 | hidden_size=512, 27 | num_hidden_layers=8, 28 | num_attention_heads=8, 29 | ) 30 | ).to("cpu") 31 | self.rtol = 1e-4 32 | self.atol = 1e-4 33 | 34 | def forward(self, x): 35 | return self.model(x) 36 | 37 | def get_example_inputs(self): 38 | # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) 39 | # >>> tokenizer.encode("Hello .") # 869 is '▁.' 40 | # [1, 15043, 29871, 1, 869] 41 | return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {} 42 | -------------------------------------------------------------------------------- /test/modules/model/Llama/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | transformers==4.52.4 3 | -------------------------------------------------------------------------------- /test/modules/model/LlamaAttentionWithKVCache/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/LlamaAttentionWithKVCache/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | transformers==4.49.0 3 | -------------------------------------------------------------------------------- /test/modules/model/LlamaDecoderLayerTRIV/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/LlamaDecoderLayerTRIV/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.2.6 2 | transformers==4.52.4 3 | -------------------------------------------------------------------------------- /test/modules/model/LlamaWithGQA/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/LlamaWithGQA/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import LlamaConfig, LlamaModel 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class LlamaWithGQA(TestModuleBase): 22 | """ 23 | Llama model with Group Query Attention. 24 | """ 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.model = LlamaModel( 29 | config=LlamaConfig( 30 | hidden_size=512, 31 | num_hidden_layers=8, 32 | num_attention_heads=16, 33 | num_key_value_heads=4, 34 | ) 35 | ).to("cpu") 36 | self.rtol = 1e-4 37 | self.atol = 1e-4 38 | 39 | def forward(self, x): 40 | return self.model(x) 41 | 42 | def get_example_inputs(self): 43 | # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) 44 | # >>> tokenizer.encode("Hello .") # 869 is '▁.' 45 | # [1, 15043, 29871, 1, 869] 46 | return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {} 47 | -------------------------------------------------------------------------------- /test/modules/model/LlamaWithGQA/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | transformers==4.52.4 3 | -------------------------------------------------------------------------------- /test/modules/model/LlamaWithKVCache/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/LlamaWithKVCache/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.1 2 | transformers==4.51.3 3 | -------------------------------------------------------------------------------- /test/modules/model/Mamba/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/Mamba/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class Mamba(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | # Prevent the result from being nan 26 | # WARNING This removes non-determinism only partially. 27 | torch.use_deterministic_algorithms(True) 28 | 29 | config = MambaConfig(use_cache=False) 30 | self.model = MambaForCausalLM.from_pretrained( 31 | "state-spaces/mamba-130m-hf", config=config 32 | ).to("cpu") 33 | 34 | self.rtol = 1e-3 35 | self.atol = 1e-3 36 | 37 | def forward(self, x): 38 | return self.model(x) 39 | 40 | def get_example_inputs(self): 41 | tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") 42 | input_ids = tokenizer("What is your name?", return_tensors="pt")[ 43 | "input_ids" 44 | ].to("cpu") 45 | return (input_ids,), {} 46 | -------------------------------------------------------------------------------- /test/modules/model/Mamba/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.0 -------------------------------------------------------------------------------- /test/modules/model/MambaMixer/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/MambaMixer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import MambaConfig, MambaForCausalLM 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class MambaMixer(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | config = MambaConfig() 25 | self.model = MambaForCausalLM.from_pretrained( 26 | "state-spaces/mamba-130m-hf", config=config 27 | ) 28 | 29 | self.rtol = 1e-2 30 | self.atol = 1e-2 31 | 32 | def forward(self, x): 33 | return self.model.backbone.layers[0].mixer(x) 34 | 35 | def get_example_inputs(self): 36 | # Let's fix the seed for now. 37 | # WHY? 1~5 among 4000+ exceeds the error rate with other seeds. 38 | # TODO Find way to increase accuracy 39 | torch.manual_seed(5) 40 | hidden_state = torch.randn(1, 6, 768) 41 | return (hidden_state,), {} 42 | -------------------------------------------------------------------------------- /test/modules/model/MambaMixer/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.0 -------------------------------------------------------------------------------- /test/modules/model/MobileNetV2/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/MobileNetV2/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torchvision.models import mobilenet_v2, MobileNet_V2_Weights 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class MobileNetV2(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).to("cpu") 25 | 26 | def forward(self, x): 27 | return self.model(x) 28 | 29 | def get_example_inputs(self): 30 | torch.manual_seed(1) 31 | return (torch.randn(1, 3, 224, 224),), {} 32 | -------------------------------------------------------------------------------- /test/modules/model/MobileNetV2/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /test/modules/model/MobileNetV3S/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/MobileNetV3S/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torchvision.models import mobilenet_v3_small 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class MobileNetV3S(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = mobilenet_v3_small(pretrained=True).to("cpu") 25 | 26 | def forward(self, x): 27 | return self.model(x) 28 | 29 | def get_example_inputs(self): 30 | torch.manual_seed(1) 31 | return (torch.randn(1, 3, 224, 224),), {} 32 | -------------------------------------------------------------------------------- /test/modules/model/MobileNetV3S/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /test/modules/model/ResNet18/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/ResNet18/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torchvision.models.resnet import resnet18, ResNet18_Weights 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class ResNet18(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = resnet18(weights=ResNet18_Weights.DEFAULT).to("cpu") 25 | 26 | def forward(self, x): 27 | return self.model(x) 28 | 29 | def get_example_inputs(self): 30 | torch.manual_seed(1) 31 | return (torch.randn(1, 3, 16, 16),), {} 32 | -------------------------------------------------------------------------------- /test/modules/model/ResNet18/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_connector/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_connector/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import AutoModelForImageTextToText 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class SmolVLM_connector(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.rtol = 1e-4 25 | self.atol = 1e-4 26 | self.model = AutoModelForImageTextToText.from_pretrained( 27 | "HuggingFaceTB/SmolVLM-256M-Instruct" 28 | ).model.connector.to("cpu") 29 | 30 | def forward(self, *x): 31 | return self.model(*x) 32 | 33 | def get_example_inputs(self): 34 | image_hidden_states = torch.randn(26, 1024, 768) 35 | return (image_hidden_states,), {} 36 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_connector/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.1 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_text_model/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_text_model/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import AutoModelForImageTextToText 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class SmolVLM_text_model(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.rtol = 1e-2 25 | self.atol = 1e-2 26 | self.model = AutoModelForImageTextToText.from_pretrained( 27 | "HuggingFaceTB/SmolVLM-256M-Instruct" 28 | ).model.text_model.to("cpu") 29 | 30 | def forward(self, *args, **kwargs): 31 | 32 | return self.model(*args, **kwargs) 33 | 34 | def get_example_inputs(self): 35 | kwargs = { 36 | "inputs_embeds": torch.randn(1, 1739, 576), 37 | "attention_mask": torch.ones(1, 1739, dtype=torch.int32), 38 | "use_cache": False, 39 | "output_attentions": False, 40 | "output_hidden_states": False, 41 | "return_dict": True, 42 | } 43 | return (), kwargs 44 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_text_model/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.1 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_vision_model/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/SmolVLM_vision_model/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.1 2 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlama/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from transformers import AutoModelForCausalLM 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | class TinyLlama(TestModuleBase): 22 | def __init__(self): 23 | super().__init__() 24 | self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to( 25 | "cpu" 26 | ) 27 | self.rtol = 1e-4 28 | self.atol = 1e-4 29 | 30 | def forward(self, x): 31 | return self.model(x) 32 | 33 | def get_example_inputs(self): 34 | # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) 35 | # >>> tokenizer.encode("Hello .") # 869 is '▁.' 36 | # [1, 15043, 29871, 1, 869] 37 | return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {} 38 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlama/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.52.4 2 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlamaWithFusedRMSNorm/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlamaWithFusedRMSNorm/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm 18 | from tico.utils.pytree_utils import register_dynamic_cache 19 | 20 | from transformers import AutoModelForCausalLM 21 | 22 | from test.modules.base import TestModuleBase 23 | 24 | 25 | class TinyLlamaWithFusedRMSNorm(TestModuleBase): 26 | def __init__(self): 27 | super().__init__() 28 | with patched_llama_rmsnorm(): 29 | self.model = AutoModelForCausalLM.from_pretrained( 30 | "Maykeye/TinyLLama-v0" 31 | ).to("cpu") 32 | self.rtol = 1e-4 33 | self.atol = 1e-4 34 | register_dynamic_cache() 35 | 36 | def forward(self, x): 37 | return self.model(x) 38 | 39 | def get_example_inputs(self): 40 | # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True) 41 | # >>> tokenizer.encode("Hello .") # 869 is '▁.' 42 | # [1, 15043, 29871, 1, 869] 43 | return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {} 44 | -------------------------------------------------------------------------------- /test/modules/model/TinyLlamaWithFusedRMSNorm/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.52.4 2 | -------------------------------------------------------------------------------- /test/modules/model/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/net/README.md: -------------------------------------------------------------------------------- 1 | It can be accessed as a module as `test.modules.net.*` 2 | -------------------------------------------------------------------------------- /test/modules/net/RMSNorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 2 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | 21 | B = 4 22 | SEQ_LEN = 8 23 | DIM = 16 24 | 25 | 26 | class LlamaRMSNorm(TestModuleBase): 27 | def __init__(self, hidden_size=DIM, eps=1e-6): 28 | """ 29 | LlamaRMSNorm is equivalent to T5LayerNorm 30 | """ 31 | super().__init__() 32 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 33 | self.variance_epsilon = eps 34 | 35 | def forward(self, hidden_states: torch.Tensor): 36 | input_dtype = hidden_states.dtype 37 | hidden_states = hidden_states.to(torch.float32) 38 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 39 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 40 | return self.weight * hidden_states.to(input_dtype) 41 | 42 | def get_example_inputs(self): 43 | # (B, Seq_Len, Dim) 44 | return (torch.randn(B, SEQ_LEN, DIM),), {} 45 | -------------------------------------------------------------------------------- /test/modules/net/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/net/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | B = 4 21 | SEQ_LEN = 8 22 | DIM = 16 23 | INTERMEDATE = 64 24 | 25 | 26 | class MLP(TestModuleBase): 27 | def __init__(self): 28 | super().__init__() 29 | self.gate_proj = torch.nn.Linear(DIM, INTERMEDATE, bias=False) 30 | self.up_proj = torch.nn.Linear(DIM, INTERMEDATE, bias=False) 31 | self.down_proj = torch.nn.Linear(INTERMEDATE, DIM, bias=False) 32 | 33 | def forward(self, x): 34 | down_proj = self.down_proj( 35 | torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x) 36 | ) 37 | return down_proj 38 | 39 | def get_example_inputs(self): 40 | # (B, Seq_Len, Dim) 41 | return (torch.randn(B, SEQ_LEN, DIM),), {} 42 | -------------------------------------------------------------------------------- /test/modules/net/mlp_dyn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torch.export import Dim 17 | 18 | from test.modules.base import TestModuleBase 19 | 20 | from test.utils.tag import use_onert 21 | 22 | B = 4 23 | SEQ_LEN = 8 24 | DIM = 16 25 | INTERMEDATE = 64 26 | 27 | 28 | @use_onert 29 | class MLP_DynamicShape(TestModuleBase): 30 | def __init__(self): 31 | super().__init__() 32 | self.gate_proj = torch.nn.Linear(DIM, INTERMEDATE, bias=False) 33 | self.up_proj = torch.nn.Linear(DIM, INTERMEDATE, bias=False) 34 | self.down_proj = torch.nn.Linear(INTERMEDATE, DIM, bias=False) 35 | 36 | def forward(self, x): 37 | down_proj = self.down_proj( 38 | torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x) 39 | ) 40 | return down_proj 41 | 42 | def get_example_inputs(self): 43 | return (torch.randn(B, SEQ_LEN, DIM),), {} 44 | 45 | def get_dynamic_shapes(self): 46 | batch = Dim("batch", min=1, max=128) 47 | dynamic_shapes = { 48 | "x": {0: batch}, 49 | } 50 | 51 | return dynamic_shapes 52 | -------------------------------------------------------------------------------- /test/modules/op/README.md: -------------------------------------------------------------------------------- 1 | It can be accessed as a module as `test.modules.op.*` 2 | 3 | 4 | ### Test target instantiation strategy: test-all, skip-some 5 | 6 | Use `@skip` decorator to mark which torch.nn.Module to skip. 7 | 8 | ```py 9 | # test.modules.op.add 10 | 11 | class SimpleAdd(torch.nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, x, y): 16 | z = x + y 17 | z = z + x 18 | z = z + x 19 | z = z + z 20 | return (z,) 21 | 22 | def get_example_inputs(self): 23 | return (torch.ones(1), torch.ones(1)) 24 | 25 | @skip(reason="Too large!") 26 | class VeryLargeSimpleAdd(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def forward(self, x, y): 31 | z = x + y 32 | return (z,) 33 | 34 | def get_example_inputs(self): 35 | return (torch.ones(99999999), torch.ones(99999999)) 36 | ``` 37 | -------------------------------------------------------------------------------- /test/modules/op/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/modules/op/abs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleAbs(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, input): 25 | return torch.abs(input) 26 | 27 | def get_example_inputs(self): 28 | return (torch.randn(3, 3),), {} 29 | 30 | 31 | class SimpleAbsWithNone(TestModuleBase): 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, input): 36 | return torch.abs(input), None 37 | 38 | def get_example_inputs(self): 39 | return (torch.randn(3, 3),), {} 40 | -------------------------------------------------------------------------------- /test/modules/op/alias_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleAliasCopy(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | result = torch.alias_copy(x) 26 | # Intentionally insert `add` op as there's no op after removing no-op. 27 | result += 2.0 28 | return result 29 | 30 | def get_example_inputs(self): 31 | return (torch.randn(3, 3),), {} 32 | 33 | 34 | class SimpleAliasCopyWithConstantTensor(TestModuleBase): 35 | def __init__(self): 36 | super().__init__() 37 | 38 | def forward(self, x): 39 | result = torch.alias_copy(x) 40 | # Intentionally insert `add` op as there's no op after removing no-op. 41 | result += 2.0 42 | return result 43 | 44 | def get_example_inputs(self): 45 | return (torch.Tensor([1.0, 2.0, 3.0]),), {} 46 | -------------------------------------------------------------------------------- /test/modules/op/argmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleArgMax(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.argmax(x, dim=0) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1234) 30 | return (torch.randn(3, 3),), {} 31 | 32 | 33 | class SimpleArgMaxWithNegativeDim(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | z = torch.argmax(x, dim=-1) 39 | return z 40 | 41 | def get_example_inputs(self): 42 | torch.manual_seed(1234) 43 | return (torch.randn(3, 3),), {} 44 | 45 | 46 | class SimpleArgMaxWithRankThreeTensor(TestModuleBase): 47 | def __init__(self): 48 | super().__init__() 49 | 50 | def forward(self, x): 51 | z = torch.argmax(x, dim=0) 52 | return z 53 | 54 | def get_example_inputs(self): 55 | torch.manual_seed(1234) 56 | return (torch.randn(3, 3, 3),), {} 57 | -------------------------------------------------------------------------------- /test/modules/op/cat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleCatWithDim(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, tensors): 25 | z = torch.cat(tensors=tensors, dim=1) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return ((torch.zeros(3, 3), torch.ones(3, 3)),), {} 30 | 31 | 32 | class SimpleCatDefault(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, tensors): 37 | z = torch.cat(tensors) 38 | return z 39 | 40 | def get_example_inputs(self): 41 | return ((torch.zeros(3), torch.ones(2)),), {} 42 | 43 | 44 | class SimpleCatThreeTensors(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, tensors): 49 | z = torch.cat(tensors=tensors, dim=2) 50 | return z 51 | 52 | def get_example_inputs(self): 53 | return ( 54 | ( 55 | torch.zeros(3, 3, 2), 56 | torch.ones(3, 3, 1), 57 | torch.ones(3, 3, 1), 58 | ), 59 | ), {} 60 | -------------------------------------------------------------------------------- /test/modules/op/copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleCopy(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, dst, src): 25 | dst.copy_(src) 26 | return dst 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(5, 5), torch.randn(5, 5)), {} 30 | 31 | 32 | class SimpleCopyWithBroadcastTo(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, dst, src): 37 | dst.copy_(src) 38 | return dst 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(5, 5), torch.randn(1, 5)), {} 42 | -------------------------------------------------------------------------------- /test/modules/op/cos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleCos(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, tensor): 25 | result = torch.cos(tensor) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/detach.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleDetach(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = x.detach() 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.zeros(3, 3),), {} 30 | 31 | 32 | class SimpleDetachConst(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x): 37 | x = torch.ones(3, 3) + x 38 | z = x.detach() 39 | return z 40 | 41 | def get_example_inputs(self): 42 | return (torch.randn(3, 3),), {} 43 | -------------------------------------------------------------------------------- /test/modules/op/exp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleExp(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, tensor): 25 | result = torch.exp(tensor) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/expand_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleExpandCopy(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.expand_copy(x, size=(1, 3, 4)) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1) 30 | return (torch.randn(1, 4),), {} 31 | 32 | 33 | class SimpleExpandCopyMinusDim(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | z = torch.expand_copy(x, size=(3, -1)) 39 | return z 40 | 41 | def get_example_inputs(self): 42 | torch.manual_seed(1) 43 | return (torch.randn(1, 4),), {} 44 | -------------------------------------------------------------------------------- /test/modules/op/full.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleFull(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | y = torch.full((2, 3), 1.0) 26 | z = x + y 27 | return z 28 | 29 | def get_example_inputs(self): 30 | return (torch.randn(2, 3),), {} 31 | 32 | 33 | class SimpleFullBool(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | y = torch.full((2, 3), True) 39 | z = x + y 40 | return z 41 | 42 | def get_example_inputs(self): 43 | return (torch.randn(2, 3),), {} 44 | -------------------------------------------------------------------------------- /test/modules/op/ge.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleGeWithScalarFloat(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | result = x >= y 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(1, 3), 2.0), {} 30 | 31 | 32 | class SimpleGeWithScalarInt(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x, y): 37 | result = x >= y 38 | return result 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(1, 3), 2), {} 42 | 43 | 44 | class SimpleGeWithTensor(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x, y): 49 | result = x >= y 50 | return result 51 | 52 | def get_example_inputs(self): 53 | return (torch.randn(2, 3), torch.randn(2, 3)), {} 54 | 55 | 56 | class SimpleGeWithDifferentTypeTensor(TestModuleBase): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def forward(self, x, y): 61 | result = x >= y 62 | return result 63 | 64 | def get_example_inputs(self): 65 | return (torch.randn(2, 3), torch.randn(2, 3).to(torch.int64)), {} 66 | -------------------------------------------------------------------------------- /test/modules/op/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleGelu(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | self.gelu = torch.nn.GELU() 24 | 25 | def forward(self, tensor): 26 | result = self.gelu(tensor) 27 | return result 28 | 29 | def get_example_inputs(self): 30 | return (torch.randn(3, 3),), {} 31 | 32 | 33 | class GeluWithApproximate(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | self.gelu = torch.nn.GELU(approximate="tanh") 37 | 38 | def forward(self, tensor): 39 | result = self.gelu(tensor) 40 | return result 41 | 42 | def get_example_inputs(self): 43 | return (torch.randn(3, 3),), {} 44 | -------------------------------------------------------------------------------- /test/modules/op/gt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleGtWithScalarFloat(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | result = x > y 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(1, 3), 2.0), {} 30 | 31 | 32 | class SimpleGtWithScalarInt(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x, y): 37 | result = x > y 38 | return result 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(1, 3), 2), {} 42 | 43 | 44 | class SimpleGtWithTensor(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x, y): 49 | result = x > y 50 | return result 51 | 52 | def get_example_inputs(self): 53 | return (torch.randn(2, 3), torch.randn(2, 3)), {} 54 | 55 | 56 | class SimpleGtWithDifferentTypeTensor(TestModuleBase): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def forward(self, x, y): 61 | result = x > y 62 | return result 63 | 64 | def get_example_inputs(self): 65 | return (torch.randn(2, 3), torch.randn(2, 3).to(torch.int64)), {} 66 | -------------------------------------------------------------------------------- /test/modules/op/hardtanh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleHardtanhFrom0to6(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | x = torch.nn.functional.hardtanh(x, 0.0, 6.0) 26 | return x 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(1, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/interpolate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class InterpolateDouble(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | return torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 26 | 27 | def get_example_inputs(self): 28 | return (torch.randn(1, 2, 3, 4),), {} 29 | 30 | 31 | class InterpolateThreeTimes(TestModuleBase): 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, x): 36 | return torch.nn.functional.interpolate(x, scale_factor=3.0, mode="nearest") 37 | 38 | def get_example_inputs(self): 39 | return (torch.randn(1, 2, 3, 4),), {} 40 | 41 | 42 | class InterpolateOnePointFive(TestModuleBase): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def forward(self, x): 47 | return torch.nn.functional.interpolate(x, scale_factor=1.5, mode="nearest") 48 | 49 | def get_example_inputs(self): 50 | return (torch.randn(1, 3, 6, 6),), {} 51 | -------------------------------------------------------------------------------- /test/modules/op/le.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLeWithScalarFloat(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | result = x <= y 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(1, 3), 2.0), {} 30 | 31 | 32 | class SimpleLeWithScalarInt(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x, y): 37 | result = x <= y 38 | return result 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(1, 3), 2), {} 42 | 43 | 44 | class SimpleLeWithTensor(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x, y): 49 | result = x <= y 50 | return result 51 | 52 | def get_example_inputs(self): 53 | return (torch.randn(2, 3), torch.randn(2, 3)), {} 54 | 55 | 56 | class SimpleLeWithDifferentTypeTensor(TestModuleBase): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def forward(self, x, y): 61 | result = x <= y 62 | return result 63 | 64 | def get_example_inputs(self): 65 | return (torch.randn(2, 3), torch.randn(2, 3).to(torch.int64)), {} 66 | -------------------------------------------------------------------------------- /test/modules/op/leaky_relu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLeakyRelu(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | self.leaky_relu = torch.nn.LeakyReLU(0.1) 24 | 25 | def forward(self, tensor): 26 | result = self.leaky_relu(tensor) 27 | return result 28 | 29 | def get_example_inputs(self): 30 | return (torch.randn(3, 3),), {} 31 | -------------------------------------------------------------------------------- /test/modules/op/log.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLog(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, input): 25 | return torch.log(input) 26 | 27 | def get_example_inputs(self): 28 | return (torch.randn(3, 3),), {} 29 | -------------------------------------------------------------------------------- /test/modules/op/log1p.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLog1p(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, input): 25 | return torch.log1p(input) 26 | 27 | def get_example_inputs(self): 28 | return (torch.randn(3, 3),), {} 29 | -------------------------------------------------------------------------------- /test/modules/op/logical_and.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLogicalAnd(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | z = torch.logical_and(x, y) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | # generate random bool tensor 30 | lhs = torch.randn((3, 5)) < 0.5 31 | rhs = torch.randn((3, 5)) < 0.5 32 | return ( 33 | lhs, 34 | rhs, 35 | ), {} 36 | -------------------------------------------------------------------------------- /test/modules/op/logical_not.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLogicalNot(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.logical_not(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | # generate random bool tensor 30 | tensor = torch.randn((3, 5)) < 0.5 31 | return (tensor,), {} 32 | -------------------------------------------------------------------------------- /test/modules/op/lt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleLt(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | z = torch.lt(x, y) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3), torch.randn(3, 3)), {} 30 | 31 | 32 | class SimpleLtWithAngleBracket(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x, y): 37 | z = x < y 38 | return z 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(3, 3), torch.randn(3, 3)), {} 42 | -------------------------------------------------------------------------------- /test/modules/op/max_dim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleMaxDim(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x: torch.Tensor): 25 | max = x.max(dim=0) 26 | return max[0] # max: tuple (max, max_indices) 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(2, 3),), {} 30 | 31 | 32 | class MaxDimKeepDim(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x: torch.Tensor): 37 | max = x.max(dim=0, keepdim=True) 38 | return max[0] # max: tuple (max, max_indices) 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(2, 3),), {} 42 | -------------------------------------------------------------------------------- /test/modules/op/mul.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleMulWithTensor(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x, y): 25 | z = x * y 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3), torch.randn(3, 3)), {} 30 | 31 | 32 | class SimpleMulWithScalar(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x, y): 37 | z = x * y 38 | return z 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(3, 3), 5), {} 42 | 43 | 44 | class MulWithBuiltinFloat(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x, y): 49 | z = x * y 50 | return z 51 | 52 | def get_example_inputs(self): 53 | return (torch.ones(1), 2.0), {} 54 | 55 | 56 | class MulWithBuiltinInt(TestModuleBase): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def forward(self, x, y): 61 | z = x * y 62 | return z 63 | 64 | def get_example_inputs(self): 65 | return (torch.ones(1).to(torch.int64), 2), {} 66 | -------------------------------------------------------------------------------- /test/modules/op/neg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleNeg(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.neg(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/prelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimplePReLU(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | self.prelu = torch.nn.PReLU() 24 | 25 | def forward(self, x): 26 | result = self.prelu(x) 27 | return result 28 | 29 | def get_example_inputs(self): 30 | return (torch.randn(1, 2, 3, 3),), {} 31 | -------------------------------------------------------------------------------- /test/modules/op/reciprocal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleReciprocal2(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, input_): 25 | result = torch.reciprocal(input_) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | 31 | 32 | class SimpleReciprocalOperator(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, input_): 37 | result = 1 / input_ 38 | return result 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(3, 3),), {} 42 | -------------------------------------------------------------------------------- /test/modules/op/relu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleRelu(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | result = torch.relu(x) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 5),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/relu6.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleReLU6(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.nn.functional.relu6(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3) * 7,), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/round.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | from test.utils import tag 20 | 21 | 22 | @tag.use_onert 23 | class SimpleRound(TestModuleBase): 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def forward(self, x): 28 | result = torch.round(x) 29 | return result 30 | 31 | def get_example_inputs(self): 32 | return (torch.randn(3, 5),), {} 33 | -------------------------------------------------------------------------------- /test/modules/op/rsqrt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleRsqrt(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.rsqrt(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1) 30 | return (torch.randn(3, 3),), {} 31 | -------------------------------------------------------------------------------- /test/modules/op/scalar_tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleScalarTensor(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | y = torch.scalar_tensor(2.0) 26 | z = x + y 27 | return z 28 | 29 | def get_example_inputs(self): 30 | return (torch.randn(3),), {} 31 | 32 | 33 | class SimpleScalarTensorInt(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | y = torch.scalar_tensor(2) 39 | z = x + y 40 | return z 41 | 42 | def get_example_inputs(self): 43 | return (torch.randn(3),), {} 44 | 45 | 46 | class SimpleScalarTensorBool(TestModuleBase): 47 | def __init__(self): 48 | super().__init__() 49 | 50 | def forward(self, x): 51 | y = torch.scalar_tensor(True) 52 | z = x + y 53 | return z 54 | 55 | def get_example_inputs(self): 56 | return (torch.randn(3),), {} 57 | -------------------------------------------------------------------------------- /test/modules/op/select.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSelect(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x: torch.Tensor): 25 | dim = 0 26 | idx = 1 27 | # Equivalent to `tensor[idx]` 28 | selected_x = torch.select(x, dim, idx) 29 | return selected_x 30 | 31 | def get_example_inputs(self): 32 | return (torch.randn(4),), {} 33 | 34 | 35 | class SimpleSelect2(TestModuleBase): 36 | def __init__(self): 37 | super().__init__() 38 | 39 | def forward(self, x: torch.Tensor): 40 | dim = 2 41 | idx = 1 42 | # Equivalent to `tensor[:,:,idx]` 43 | selected_x = torch.select(x, dim, idx) 44 | return selected_x 45 | 46 | def get_example_inputs(self): 47 | return (torch.randn(2, 3, 4),), {} 48 | 49 | 50 | class SimpleConstIndex(TestModuleBase): 51 | def __init__(self): 52 | super().__init__() 53 | 54 | def forward(self, discrete_A): 55 | i = 1 56 | ssm_state = discrete_A[:, :, i, :] 57 | return ssm_state 58 | 59 | def get_example_inputs(self): 60 | return (torch.rand(1, 32, 6, 16),), {} 61 | -------------------------------------------------------------------------------- /test/modules/op/select_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSelectCopy(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x: torch.Tensor): 25 | dim = 0 26 | idx = 1 27 | # Equivalent to `tensor[idx]` 28 | copy_x = torch.select_copy(x, dim, idx) 29 | return copy_x 30 | 31 | def get_example_inputs(self): 32 | return (torch.randn(4),), {} 33 | 34 | 35 | class SimpleSelectCopy2(TestModuleBase): 36 | def __init__(self): 37 | super().__init__() 38 | 39 | def forward(self, x: torch.Tensor): 40 | dim = 2 41 | idx = 1 42 | # Equivalent to `tensor[:,:,idx]` 43 | copy_x = torch.select_copy(x, dim, idx) 44 | return copy_x 45 | 46 | def get_example_inputs(self): 47 | return (torch.randn(2, 3, 4),), {} 48 | -------------------------------------------------------------------------------- /test/modules/op/sigmoid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSigmoid(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.sigmoid(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1234) 30 | return (torch.randn(3, 3),), {} 31 | -------------------------------------------------------------------------------- /test/modules/op/sin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSin(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, tensor): 25 | result = torch.sin(tensor) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSoftMax(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch._softmax(x, dim=2, half_to_float=False) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1) 30 | return (torch.randn(4, 4, 3),), {} 31 | 32 | 33 | class SimpleSoftMaxDimMinus(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | z = torch._softmax(x, dim=-1, half_to_float=False) 39 | return z 40 | 41 | def get_example_inputs(self): 42 | torch.manual_seed(1) 43 | return (torch.randn(4, 4, 3),), {} 44 | 45 | 46 | class SimpleSafeSoftMax(TestModuleBase): 47 | def __init__(self): 48 | super().__init__() 49 | 50 | def forward(self, x): 51 | z = torch.ops.aten._safe_softmax(x, dim=2) 52 | return z 53 | 54 | def get_example_inputs(self): 55 | torch.manual_seed(1) 56 | return (torch.randn(4, 4, 3),), {} 57 | -------------------------------------------------------------------------------- /test/modules/op/sqrt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSqrt(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.sqrt(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/squeeze.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSqueeze(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.squeeze(x) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | torch.manual_seed(1234) 30 | return (torch.randn(2, 1, 2, 1, 2),), {} 31 | 32 | 33 | class SimpleSqueezeWithDims(TestModuleBase): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | def forward(self, x): 38 | z = torch.squeeze(x, dim=(1, 3)) 39 | return z 40 | 41 | def get_example_inputs(self): 42 | torch.manual_seed(1234) 43 | return (torch.randn(2, 1, 2, 1, 2),), {} 44 | 45 | 46 | class SimpleSqueezeWithSingleDim(TestModuleBase): 47 | def __init__(self): 48 | super().__init__() 49 | 50 | def forward(self, x): 51 | z = torch.squeeze(x, dim=(1,)) 52 | return z 53 | 54 | def get_example_inputs(self): 55 | torch.manual_seed(1234) 56 | return (torch.randn(2, 1, 2, 1, 2),), {} 57 | -------------------------------------------------------------------------------- /test/modules/op/sub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleSub(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, input_, other): 25 | return torch.sub(input_, other) 26 | 27 | def get_example_inputs(self): 28 | return (torch.randn(2, 3), torch.randn(2, 3)), {} 29 | 30 | 31 | class SubWithOut(TestModuleBase): 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, input_, other): 36 | out = torch.empty_like(input_) 37 | torch.sub(input_, other, out=out) 38 | return out 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(2, 3), torch.randn(2, 3)), {} 42 | 43 | 44 | class SubWithBuiltinFloat(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x, y): 49 | z = x - y 50 | return z 51 | 52 | def get_example_inputs(self): 53 | return (torch.ones(1), 2.0), {} 54 | 55 | 56 | class SubWithBuiltinInt(TestModuleBase): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def forward(self, x, y): 61 | z = x - y 62 | return z 63 | 64 | def get_example_inputs(self): 65 | return (torch.ones(1).to(torch.int64), 2), {} 66 | -------------------------------------------------------------------------------- /test/modules/op/tanh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleTanh(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, tensor): 25 | result = torch.tanh(tensor) 26 | return result 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(3, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/to_dim_order_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleToF32I32(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = x.to(dtype=torch.int32) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(2, 3, dtype=torch.float32),), {} 30 | 31 | 32 | class SimpleToI32F32(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x): 37 | z = x.to(dtype=torch.float32) 38 | return z 39 | 40 | def get_example_inputs(self): 41 | return (torch.ones(2, 3, dtype=torch.int32),), {} 42 | -------------------------------------------------------------------------------- /test/modules/op/unsqueeze.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleUnsqueeze(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = torch.unsqueeze(x, dim=0) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(4, 4, 3),), {} 30 | -------------------------------------------------------------------------------- /test/modules/op/view.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from test.modules.base import TestModuleBase 18 | 19 | 20 | class SimpleView(TestModuleBase): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | z = x.view((4, 10)) 26 | return z 27 | 28 | def get_example_inputs(self): 29 | return (torch.randn(2, 4, 5),), {} 30 | 31 | 32 | class SimpleViewFirstDimMinus(TestModuleBase): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self, x): 37 | z = x.view([-1, 10]) 38 | return z 39 | 40 | def get_example_inputs(self): 41 | return (torch.randn(2, 4, 5),), {} 42 | 43 | 44 | class SimpleViewLastDimMinus(TestModuleBase): 45 | def __init__(self): 46 | super().__init__() 47 | 48 | def forward(self, x): 49 | z = x.view([5, -1]) 50 | return z 51 | 52 | def get_example_inputs(self): 53 | return (torch.randn(2, 4, 5),), {} 54 | -------------------------------------------------------------------------------- /test/performance/README.md: -------------------------------------------------------------------------------- 1 | # Performance Benchmarks 2 | 3 | This directory contains benchmark scripts that verify the performance requirements 4 | specified in **docs/system_test.md**: 5 | 6 | * **RNF‑1 – Conversion Speed** 7 | * **RNF‑2 – Model Size** 8 | 9 | Both requirements can be tested with: 10 | 11 | ```bash 12 | python3 -m test.performance.benchmark_perf 13 | ``` 14 | 15 | The test uses baseline models (`Llama-3.2-1B` and `Llama-3.2-3B`). 16 | 17 | Feel free to adjust thresholds or models as needed. 18 | -------------------------------------------------------------------------------- /test/performance/__init__.py: -------------------------------------------------------------------------------- 1 | # Performance test package for TICO 2 | -------------------------------------------------------------------------------- /test/performance/benchmark_perf.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | 3 | from tico import convert 4 | 5 | from test.performance.utils import ( 6 | load_model, 7 | measure_time, 8 | size_of_bytes, 9 | temp_state_dict_size, 10 | ) 11 | 12 | 13 | def run_benchmark(model_name, speed_threshold, size_threshold): 14 | print(f"Start performance test with {model_name}") 15 | 16 | model, inputs = load_model(model_name) 17 | 18 | # Conversion speed benchmark 19 | timings = measure_time(convert, model, inputs) 20 | mean_time = statistics.mean(timings) * model.num_hidden_layers() # type: ignore[operator] 21 | print( 22 | f"Mean conversion time (Single decoder layer * num_hidden_layers): {mean_time:.2f}s (threshold {speed_threshold}s)" 23 | ) 24 | if mean_time > speed_threshold: 25 | raise AssertionError( 26 | f"Conversion too slow: {mean_time:.2f}s exceeds threshold {speed_threshold}s" 27 | ) 28 | 29 | # Model size benchmark 30 | circle_model = convert(model, inputs) 31 | circle_size = size_of_bytes(circle_model.circle_binary) 32 | state_dict_size = temp_state_dict_size(model) 33 | print(f"Circle size: {circle_size} bytes") 34 | print(f"State dict size: {state_dict_size} bytes") 35 | print(f"Circle / State dict ratio: {circle_size / state_dict_size}") 36 | if circle_size > state_dict_size * size_threshold: 37 | raise AssertionError( 38 | f"Circle size {circle_size} is increased by {(size_threshold - 1)*100}% compared to state_dict size {state_dict_size}" 39 | ) 40 | 41 | 42 | if __name__ == "__main__": 43 | models = ["Llama-3.2-1B", "Llama-3.2-3B"] 44 | 45 | # Conversion speed thresholds (seconds) 46 | speed_thresholds = [60, 180] 47 | 48 | # Model size increase thresholds (1.01 = 1% increase) 49 | size_thresholds = [1.01, 1.01] 50 | 51 | for model, speed_threshold, size_threshold in zip( 52 | models, speed_thresholds, size_thresholds 53 | ): 54 | run_benchmark(model, speed_threshold, size_threshold) 55 | -------------------------------------------------------------------------------- /test/performance/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.53.0 2 | -------------------------------------------------------------------------------- /test/pt2_to_circle_test/README.md: -------------------------------------------------------------------------------- 1 | # pt2_to_circle_test 2 | 3 | `pt2_to_circle_test` validates the circle model which generated from pt2 file. 4 | 5 | The test proceeds as follows 6 | 7 | 1. Generate `pt2` model from `torch` model and export it to `.pt2` format. 8 | 2. Load `pt2` model from the exported file, and convert it to the `circle` model using `pt2-to-circle`. 9 | 3. Validate converted `circle` model using `circle2circle` (validity of shape/dtype). 10 | 4. Execute and compare the results from `torch` model and `circle` model. The results must be the same. 11 | -------------------------------------------------------------------------------- /test/pt2_to_circle_test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | 18 | # NOTE load_tests is `unittest` protocol to find tests dynamically 19 | # https://docs.python.org/ko/3.13/library/unittest.html#load-tests-protocol 20 | def load_tests(loader, standard_tests, pattern): 21 | # top level directory cached on loader instance 22 | this_dir = os.path.dirname(__file__) 23 | 24 | # Add test files to be found by `unittest` 25 | # WHY? Not to include other files by mistake and to make it clear which files are being tested 26 | for testfile in ["test_net.py", "test_op.py"]: 27 | package_tests = loader.discover(start_dir=this_dir, pattern=testfile) 28 | standard_tests.addTests(package_tests) 29 | 30 | return standard_tests 31 | -------------------------------------------------------------------------------- /test/pt2_to_circle_test/test_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from test.pt2_to_circle_test.builder import NormalTestDictBuilder 16 | from test.utils.helper import declare_unittests 17 | 18 | # NOTE Thie file's name must start with `test_` to be found by unittest 19 | 20 | 21 | declare_unittests(globals(), "test.modules.net", NormalTestDictBuilder) 22 | -------------------------------------------------------------------------------- /test/pt2_to_circle_test/test_op.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from test.pt2_to_circle_test.builder import NormalTestDictBuilder 16 | from test.utils.helper import declare_unittests 17 | 18 | # NOTE Thie file's name must start with `test_` to be found by unittest 19 | 20 | 21 | declare_unittests(globals(), "test.modules.op", NormalTestDictBuilder) 22 | -------------------------------------------------------------------------------- /test/pt2_to_qcircle_test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | 18 | # NOTE load_tests is `unittest` protocol to find tests dynamically 19 | # https://docs.python.org/ko/3.13/library/unittest.html#load-tests-protocol 20 | def load_tests(loader, standard_tests, pattern): 21 | # top level directory cached on loader instance 22 | this_dir = os.path.dirname(__file__) 23 | 24 | # Add test files to be found by `unittest` 25 | # WHY? Not to include other files by mistake and to make it clear which files are being tested 26 | for testfile in ["test_op.py"]: 27 | package_tests = loader.discover(start_dir=this_dir, pattern=testfile) 28 | standard_tests.addTests(package_tests) 29 | 30 | return standard_tests 31 | -------------------------------------------------------------------------------- /test/pt2_to_qcircle_test/test_op.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from test.pt2_to_qcircle_test.builder import QuantizationTestDictBuilder 16 | from test.utils.helper import declare_unittests 17 | 18 | # NOTE Thie file's name must start with `test_` to be found by unittest 19 | 20 | 21 | declare_unittests(globals(), "test.modules.op", QuantizationTestDictBuilder) 22 | -------------------------------------------------------------------------------- /test/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/pass/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/observers/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/test_dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from tico.quantization.wrapq.dtypes import DType, INT8, UINT4 18 | 19 | 20 | class TestDType(unittest.TestCase): 21 | def test_presets(self): 22 | self.assertEqual(INT8.bits, 8) 23 | self.assertTrue(INT8.signed) 24 | self.assertEqual(UINT4.bits, 4) 25 | self.assertFalse(UINT4.signed) 26 | 27 | def test_range_signed(self): 28 | dt = DType.int(6) # 6-bit signed 29 | self.assertEqual(dt.qmin, -32) 30 | self.assertEqual(dt.qmax, 31) 31 | 32 | def test_range_unsigned(self): 33 | dt = DType.uint(5) # 5-bit unsigned 34 | self.assertEqual(dt.qmin, 0) 35 | self.assertEqual(dt.qmax, 31) 36 | 37 | def test_str(self): 38 | self.assertEqual(str(DType.int(3)), "int3") 39 | self.assertEqual(str(DType.uint(7)), "uint7") 40 | -------------------------------------------------------------------------------- /test/quantization/wrapq/test_mode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Unit-tests for the `Mode` enumeration that governs the wrapper 17 | finite-state-machine. 18 | 19 | Checks 20 | ------ 21 | 1. Enumeration contains exactly the three expected members. 22 | 2. Natural ordering follows definition order: NO_QUANT < CALIB < QUANT 23 | (the `auto()` values are monotonically increasing). 24 | 3. The `str(...)` representation returns the lower-case name 25 | ("no_quant", "calib", "quant"). 26 | """ 27 | 28 | import unittest 29 | 30 | from tico.quantization.wrapq.mode import Mode 31 | 32 | 33 | class TestModeEnum(unittest.TestCase): 34 | def test_member_names(self): 35 | self.assertEqual( 36 | list(Mode.__members__.keys()), 37 | ["NO_QUANT", "CALIB", "QUANT"], 38 | msg="Mode enum must contain exactly NO_QUANT, CALIB, QUANT in that order", 39 | ) 40 | 41 | def test_ordering(self): 42 | # auto() assigns consecutive integers starting from 1 43 | self.assertLess(Mode.NO_QUANT.value, Mode.CALIB.value) 44 | self.assertLess(Mode.CALIB.value, Mode.QUANT.value) 45 | 46 | def test_str_representation(self): 47 | self.assertEqual(str(Mode.NO_QUANT), "no_quant") 48 | self.assertEqual(str(Mode.CALIB), "calib") 49 | self.assertEqual(str(Mode.QUANT), "quant") 50 | -------------------------------------------------------------------------------- /test/quantization/wrapq/test_qscheme.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from tico.quantization.wrapq.qscheme import QScheme 18 | 19 | 20 | class TestQScheme(unittest.TestCase): 21 | def test_enum_members(self): 22 | self.assertIn(QScheme.PER_TENSOR_ASYMM, QScheme) 23 | self.assertIn(QScheme.PER_TENSOR_SYMM, QScheme) 24 | self.assertIn(QScheme.PER_CHANNEL_ASYMM, QScheme) 25 | self.assertIn(QScheme.PER_CHANNEL_SYMM, QScheme) 26 | 27 | def test_is_per_channel(self): 28 | self.assertFalse(QScheme.PER_TENSOR_ASYMM.is_per_channel()) 29 | self.assertTrue(QScheme.PER_CHANNEL_ASYMM.is_per_channel()) 30 | 31 | def test_is_symmetric(self): 32 | self.assertTrue(QScheme.PER_TENSOR_SYMM.is_symmetric()) 33 | self.assertFalse(QScheme.PER_CHANNEL_ASYMM.is_symmetric()) 34 | 35 | def test_str(self): 36 | self.assertEqual(str(QScheme.PER_TENSOR_ASYMM), "per_tensor_asymm") 37 | self.assertEqual(str(QScheme.PER_CHANNEL_SYMM), "per_channel_symm") 38 | -------------------------------------------------------------------------------- /test/quantization/wrapq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/wrappers/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/wrappers/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/quantization/wrapq/wrappers/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/requirements.txt: -------------------------------------------------------------------------------- 1 | # TEST DEPENDENCIES TO BE INSTALLED 2 | # DO NOT REMOVE THIS FILE 3 | -------------------------------------------------------------------------------- /test/requirements_pre.txt: -------------------------------------------------------------------------------- 1 | # TEST DEPENDENCIES TO BE INSTALLED WITH `pip install --pre` 2 | onert==0.2.0.dev250922 3 | -------------------------------------------------------------------------------- /test/unit_test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/pass_test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/pass_test/test_lower_pow2_to_mul.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from tico.passes.lower_pow2_to_mul import LowerPow2ToMul 17 | 18 | from test.utils.helper import num_of_ops 19 | from test.utils.pass_value_test import SinglePassValueTest 20 | 21 | 22 | class Pow2Net(torch.nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | def forward(self, x): 27 | z = x.pow(2) 28 | return z 29 | 30 | def get_example_inputs(self): 31 | return (torch.randn(3, 4),), {} 32 | 33 | 34 | class LowerPow2ToMulTest(SinglePassValueTest): 35 | def test_pass(self): 36 | self.setup(Pow2Net()) 37 | self.assertEqual( 38 | num_of_ops(self.exported_program(), [torch.ops.aten.pow.Tensor_Scalar]), 1 39 | ) 40 | 41 | self.run_value_test(LowerPow2ToMul()) 42 | self.assertEqual( 43 | num_of_ops(self.exported_program(), [torch.ops.aten.pow.Tensor_Scalar]), 0 44 | ) 45 | -------------------------------------------------------------------------------- /test/unit_test/pass_test/test_remove_nop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tico.passes.ops import aten 16 | from tico.passes.remove_nop import RemoveNop 17 | 18 | from test.modules.op.clone import SimpleCloneWithMemoryFormatContiguous 19 | 20 | from test.modules.op.detach import SimpleDetach 21 | from test.utils.helper import num_of_ops 22 | from test.utils.pass_value_test import SinglePassValueTest 23 | 24 | 25 | class TestRemoveDetach(SinglePassValueTest): 26 | def test_pass(self): 27 | self.setup(SimpleDetach()) 28 | self.assertEqual(num_of_ops(self.exported_program(), aten.detach), 1) 29 | 30 | self.run_value_test(RemoveNop()) 31 | self.assertEqual(num_of_ops(self.exported_program(), aten.detach), 0) 32 | 33 | 34 | class TestCloneContiguous(SinglePassValueTest): 35 | def test_pass(self): 36 | self.setup(SimpleCloneWithMemoryFormatContiguous()) 37 | self.assertEqual(num_of_ops(self.exported_program(), aten.clone), 1) 38 | 39 | self.run_value_test(RemoveNop()) 40 | self.assertEqual(num_of_ops(self.exported_program(), aten.clone), 0) 41 | -------------------------------------------------------------------------------- /test/unit_test/pass_test/test_remove_redundant_slice.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | from tico.passes import ops 19 | from tico.passes.remove_redundant_slice import RemoveRedundantSlice 20 | from tico.utils.torch_compat import export_produces_slice 21 | 22 | from test.utils.helper import num_of_ops 23 | from test.utils.pass_value_test import SinglePassValueTest 24 | 25 | 26 | class RedundantSliceNet(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def forward(self, x): 31 | z = x[0, :] 32 | return z 33 | 34 | def get_example_inputs(self): 35 | return (torch.randn(1, 4),), {} 36 | 37 | 38 | class RemoveRedundantSliceTest(SinglePassValueTest): 39 | @unittest.skipUnless( 40 | export_produces_slice(), 41 | "Skip when torch doesn't produce redundant slices.", 42 | ) 43 | def test_pass(self): 44 | self.setup(RedundantSliceNet()) 45 | self.assertEqual(num_of_ops(self.exported_program(), ops.aten.slice), 1) 46 | 47 | self.run_value_test(RemoveRedundantSlice()) 48 | self.assertEqual(num_of_ops(self.exported_program(), ops.aten.slice), 0) 49 | -------------------------------------------------------------------------------- /test/unit_test/quantization_test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/serialize_test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/serialize_test/operator/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/serialize_test/test_circle_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | from tico.serialize.circle_graph import CircleModel, CircleSubgraph, is_const 19 | 20 | 21 | class CircleGraphTest(unittest.TestCase): 22 | def test_is_const(self): 23 | self.assertTrue(is_const(1)) 24 | self.assertTrue(is_const(1.1)) 25 | self.assertTrue(is_const([1, 1])) 26 | self.assertTrue(is_const([0.1, 0.1])) 27 | self.assertTrue(is_const(torch.tensor(1))) 28 | self.assertTrue(is_const([torch.tensor(1)])) 29 | self.assertTrue(is_const([torch.tensor(1), 1])) 30 | self.assertTrue(is_const([torch.tensor([1, 1])])) 31 | 32 | def test_duplicate_names(self): 33 | mod = CircleModel() 34 | g = CircleSubgraph(mod) 35 | g.add_tensor_from_scratch( 36 | prefix="name", shape=[1, 2, 3], shape_signature=None, dtype=0 37 | ) 38 | g.add_tensor_from_scratch( 39 | prefix="name", shape=[1, 2, 3], shape_signature=None, dtype=0 40 | ) 41 | 42 | self.assertTrue(g.has_tensor("name")) 43 | # This result depends on the naming rule of _gen_unique_name_with_prefix 44 | # Change this if the rule changes 45 | self.assertTrue(g.has_tensor("name_0")) 46 | -------------------------------------------------------------------------------- /test/unit_test/serialize_test/test_circle_mapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from tico.serialize.circle_mapping import validate_circle_shape 18 | 19 | 20 | class CircleSerializeTest(unittest.TestCase): 21 | def test_validate_circle_shape(self): 22 | # static shape 23 | validate_circle_shape(shape=[1, 2, 3], shape_signature=None) 24 | # dynamic shape 25 | validate_circle_shape(shape=[1, 1, 3], shape_signature=[1, -1, 3]) 26 | validate_circle_shape(shape=[1, 1, 3], shape_signature=[-1, -1, 3]) 27 | 28 | # Invalid dynamic shape 29 | with self.assertRaises(ValueError): 30 | validate_circle_shape(shape=[1, 2, 3], shape_signature=[1, -1, 2]) 31 | with self.assertRaises(ValueError): 32 | validate_circle_shape(shape=[1, 2, 3], shape_signature=[1, -2, 3]) 33 | with self.assertRaises(ValueError): 34 | validate_circle_shape(shape=[1], shape_signature=[-1, -1]) 35 | with self.assertRaises(ValueError): 36 | validate_circle_shape(shape=[1, 2, 3], shape_signature=[]) 37 | -------------------------------------------------------------------------------- /test/unit_test/serialize_test/test_pack.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import numpy as np 18 | 19 | from tico.serialize.pack import pack_buffer 20 | 21 | 22 | class PackTest(unittest.TestCase): 23 | def test_pack_uint4(self): 24 | input_ = np.array([1, 2, 3, 4, 5, 6], dtype=np.uint8) 25 | 26 | output_ = pack_buffer(input_, "uint4") 27 | 28 | self.assertEqual((3,), output_.shape) 29 | self.assertEqual(1 + (2 << 4), output_[0]) 30 | self.assertEqual(3 + (4 << 4), output_[1]) 31 | self.assertEqual(5 + (6 << 4), output_[2]) 32 | 33 | def test_pack_uint4_odd(self): 34 | input_ = np.array([1, 2, 3, 4, 5], dtype=np.uint8) 35 | 36 | output_ = pack_buffer(input_, "uint4") 37 | 38 | self.assertEqual((3,), output_.shape) 39 | self.assertEqual(1 + (2 << 4), output_[0]) 40 | self.assertEqual(3 + (4 << 4), output_[1]) 41 | self.assertEqual(5, output_[2]) 42 | 43 | def test_pack_dtype_mismatch_neg(self): 44 | input_ = np.array([1, 2, 3, 4, 5, 6], dtype=np.int16) 45 | 46 | # uint4 data has to be saved in uint8 47 | with self.assertRaises(RuntimeError): 48 | pack_buffer(input_, "uint4") 49 | -------------------------------------------------------------------------------- /test/unit_test/utils_test/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /test/unit_test/utils_test/test_run_bash_cmd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from tico.utils.utils import run_bash_cmd 18 | 19 | 20 | class RunBashCmdTest(unittest.TestCase): 21 | def test_simple_bash_cmd(self): 22 | completed_process = run_bash_cmd(["echo", "Hello World"]) 23 | self.assertEqual(completed_process.stdout, "Hello World\n") 24 | 25 | def test_invalid_cmd_neg(self): 26 | with self.assertRaises(RuntimeError): 27 | run_bash_cmd(["ls", "for_invalid_test"]) 28 | -------------------------------------------------------------------------------- /test/unit_test/utils_test/test_serialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | from tico.serialize.circle_graph import CircleModel, CircleSubgraph 18 | 19 | from tico.utils.serialize import validate_tensor_shapes 20 | 21 | 22 | class CircleSerializeTest(unittest.TestCase): 23 | def test_validate_circle_shape(self): 24 | mod = CircleModel() 25 | g = CircleSubgraph(mod) 26 | g.add_tensor_from_scratch( 27 | prefix="name", shape=[1, 2, 3], shape_signature=None, dtype=0 28 | ) 29 | g.add_tensor_from_scratch( 30 | prefix="name", shape=[1, 2, 3], shape_signature=None, dtype=0 31 | ) 32 | validate_tensor_shapes(g) 33 | 34 | def test_validate_tensor_shape_neg(self): 35 | mod = CircleModel() 36 | g = CircleSubgraph(mod) 37 | g.add_tensor_from_scratch( 38 | prefix="tensor0", 39 | shape=[1, 2, 3], 40 | shape_signature=[-1, 0, 0], # Invalid shape pair 41 | dtype=0, 42 | ) 43 | g.add_tensor_from_scratch( 44 | prefix="tensor1", shape=[1, 2, 3], shape_signature=None, dtype=0 45 | ) 46 | with self.assertRaises(ValueError): 47 | validate_tensor_shapes(g) 48 | -------------------------------------------------------------------------------- /test/utils/runtime.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Literal 16 | 17 | Runtime = Literal["circle-interpreter", "onert"] 18 | -------------------------------------------------------------------------------- /tico/config/__init__.py: -------------------------------------------------------------------------------- 1 | from tico.config.base import CompileConfigBase 2 | from tico.config.factory import get_default_config 3 | 4 | from tico.config.v1 import CompileConfigV1 5 | -------------------------------------------------------------------------------- /tico/config/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | 17 | 18 | @dataclass 19 | class CompileConfigBase: 20 | def get(self, name: str): 21 | return getattr(self, name) if hasattr(self, name) else None 22 | 23 | def set(self, name: str, enabled: bool): 24 | setattr(self, name, enabled) 25 | 26 | def to_dict(self): 27 | return {key: value for key, value in self.__dict__.items()} 28 | 29 | @classmethod 30 | def from_dict(cls, config_dict: dict): 31 | config = cls() 32 | for key in config_dict: 33 | if key in config.to_dict(): 34 | assert isinstance(config.get(key), bool) 35 | config.set(key, config_dict[key]) 36 | 37 | return config 38 | -------------------------------------------------------------------------------- /tico/config/factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Type 16 | 17 | from tico.config.base import CompileConfigBase 18 | from tico.config.v1 import CompileConfigV1 19 | 20 | 21 | class CompileConfigFactory: 22 | _config_classes = { 23 | "1.0": CompileConfigV1, 24 | # '2.0': CompileConfigV2, 25 | } 26 | 27 | @classmethod 28 | def get_config(cls, version: str) -> Type[CompileConfigBase]: 29 | if version not in cls._config_classes: 30 | raise ValueError(f"Unsupported version: {version}") 31 | 32 | return cls._config_classes[version] 33 | 34 | @classmethod 35 | def create(cls, version: str): 36 | config_class = cls.get_config(version) 37 | return config_class() 38 | 39 | 40 | def get_default_config(version: str = "1.0"): 41 | return CompileConfigFactory.create(version) 42 | -------------------------------------------------------------------------------- /tico/config/v1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | 17 | from tico.config.base import CompileConfigBase 18 | 19 | 20 | @dataclass 21 | class CompileConfigV1(CompileConfigBase): 22 | legalize_causal_mask_value: bool = False 23 | remove_constant_input: bool = False 24 | convert_lhs_const_mm_to_fc: bool = False 25 | convert_rhs_const_mm_to_fc: bool = True 26 | convert_single_batch_lhs_const_bmm_to_fc: bool = False 27 | convert_expand_to_slice_cat: bool = False 28 | 29 | def get(self, name: str): 30 | return super().get(name) 31 | 32 | def set(self, name: str, enabled: bool): 33 | super().set(name, enabled) 34 | 35 | def to_dict(self): 36 | return super().to_dict() 37 | 38 | @classmethod 39 | def from_dict(cls, config_dict: dict): 40 | return super().from_dict(config_dict) 41 | -------------------------------------------------------------------------------- /tico/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/interpreter/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/passes/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/passes/fill_meta_val.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from torch.export import ExportedProgram 16 | 17 | from tico.utils import logging 18 | from tico.utils.passes import PassBase, PassResult 19 | from tico.utils.trace_decorators import trace_graph_diff_on_pass 20 | from tico.utils.utils import set_new_meta_val 21 | 22 | 23 | @trace_graph_diff_on_pass 24 | class FillMetaVal(PassBase): 25 | """ 26 | Let's set new meta['val'] for nodes which don't have meta['val'] 27 | """ 28 | 29 | def __init__(self): 30 | super().__init__() 31 | 32 | def call(self, exported_program: ExportedProgram) -> PassResult: 33 | logger = logging.getLogger(__name__) 34 | 35 | graph_module = exported_program.graph_module 36 | graph = graph_module.graph 37 | modified = False 38 | # To make sure graph is topologically sorted 39 | graph.lint() 40 | for node in graph.nodes: 41 | if not node.op == "call_function": 42 | continue 43 | 44 | if hasattr(node, "meta") and "val" in node.meta: 45 | continue 46 | 47 | set_new_meta_val(node) 48 | 49 | modified = True 50 | 51 | logger.debug(f"{node.name} has new meta values.") 52 | 53 | graph.eliminate_dead_code() 54 | graph.lint() 55 | graph_module.recompile() 56 | 57 | return PassResult(modified) 58 | -------------------------------------------------------------------------------- /tico/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | from tico.quantization.public_interface import convert, prepare 2 | 3 | __all__ = [ 4 | "convert", 5 | "prepare", 6 | ] 7 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/fpi_gptq/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/gptq/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/annotation/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/annotation/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | 18 | from torch.ao.quantization.quantizer import QuantizationSpec 19 | 20 | 21 | @dataclass(eq=True, frozen=True) 22 | class QuantizationConfig: 23 | input_activation: Optional[QuantizationSpec] 24 | output_activation: Optional[QuantizationSpec] 25 | weight: Optional[QuantizationSpec] 26 | bias: Optional[QuantizationSpec] 27 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/annotation/op/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | from os.path import basename, dirname, isfile, join 17 | 18 | modules = glob.glob(join(dirname(__file__), "*.py")) 19 | __all__ = [ 20 | basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") 21 | ] 22 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/annotation/spec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Callable, Dict, List, Optional, TYPE_CHECKING 16 | 17 | if TYPE_CHECKING: 18 | import torch.fx 19 | import torch 20 | 21 | from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig 22 | 23 | AnnotatorType = Callable[ 24 | [ 25 | torch.fx.GraphModule, 26 | torch.fx.Node, 27 | Optional[QuantizationConfig], 28 | Optional[Callable[[torch.fx.Node], bool]], 29 | ], 30 | None, 31 | ] 32 | OP_TO_ANNOTATOR: Dict[torch._ops.OpOverload, AnnotatorType] = {} 33 | OP_TO_SHARE_QUANT_SPEC: List[Callable] = [ 34 | torch.ops.aten.view_copy.default, 35 | torch.ops.aten.view.default, 36 | ] 37 | 38 | 39 | def register_annotator(target: List[torch._ops.OpOverload]): 40 | def decorator(annotator: AnnotatorType): 41 | for t in target: 42 | OP_TO_ANNOTATOR[t] = annotator 43 | return annotator 44 | 45 | return decorator 46 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/transformation/README.md: -------------------------------------------------------------------------------- 1 | ## transformation 2 | 3 | The _transformation_ module provides a set of user-defined transformation passes 4 | designed to prepare a graph for quantization. 5 | 6 | Before annotating the graph, this module allows users to perfrom custom transformation. 7 | For example, it includes passes that convert scalars in the graph into tensors, 8 | enabling them to be quantized. 9 | 10 | ### Design Considerations 11 | 12 | The modules(passes) in the _transformation_ module are implemented to work with the 13 | `transformation_for_annotation` interface by the `Quantizer`. 14 | 15 | Unlike other passes in `tico.utils.passes`, the passes in this module do **not** 16 | inherit from the `PassBase` class. This is because the `transformation_for_annotation` 17 | API in the `Quantizer` uses a `torch.fx.GraphModule` as input, rather than the 18 | `ExportedProgram` used by `PassBase`. 19 | 20 | ```python 21 | # https://github.com/pytorch/pytorch/blob/06b4b96b/torch/ao/quantization/quantizer/quantizer.py#L137-L150 22 | class Quantizer(ABC): 23 | def transform_for_annotation( 24 | self, model: torch.fx.GraphModule 25 | ) -> torch.fx.GraphModule: 26 | """Allows for user defined transforms to run before annotating the graph. 27 | This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. 28 | For example quantizer can 29 | a) decompose a compound operator like scaled dot product attention, 30 | into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa 31 | or b) transform scalars to tensor to allow quantizing scalares. 32 | 33 | Note: this is an optional method 34 | """ 35 | return model 36 | ``` -------------------------------------------------------------------------------- /tico/quantization/algorithm/pt2e/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/smoothquant/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/algorithm/smoothquant/smooth_quant.txt: -------------------------------------------------------------------------------- 1 | # requirments.txt 2 | transformers 3 | datasets 4 | tqdm 5 | -------------------------------------------------------------------------------- /tico/quantization/config/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/config/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from abc import ABC, abstractmethod 16 | 17 | 18 | class BaseConfig(ABC): 19 | """ 20 | Base configuration class for quantization. 21 | """ 22 | 23 | @property 24 | @abstractmethod 25 | def name(self) -> str: 26 | pass 27 | -------------------------------------------------------------------------------- /tico/quantization/config/fpi_gptq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tico.quantization.config.gptq import GPTQConfig 16 | 17 | 18 | class FPIGPTQConfig(GPTQConfig): 19 | """ 20 | Configuration for FPIGPTQ (Fixed Point Iteration). 21 | """ 22 | 23 | def __init__(self, verbose: bool = False, show_progress: bool = True): 24 | self.verbose = verbose 25 | self.show_progress = show_progress 26 | 27 | @property 28 | def name(self) -> str: 29 | return "fpi_gptq" 30 | -------------------------------------------------------------------------------- /tico/quantization/config/gptq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tico.quantization.config.base import BaseConfig 16 | 17 | 18 | class GPTQConfig(BaseConfig): 19 | """ 20 | Configuration for GPTQ. 21 | """ 22 | 23 | def __init__(self, verbose: bool = False, show_progress: bool = True): 24 | self.verbose = verbose 25 | self.show_progress = show_progress 26 | 27 | @property 28 | def name(self) -> str: 29 | return "gptq" 30 | -------------------------------------------------------------------------------- /tico/quantization/config/pt2e.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from tico.quantization.config.base import BaseConfig 16 | 17 | 18 | class PT2EConfig(BaseConfig): 19 | """ 20 | Configuration for pytorch 2.0 export quantization. 21 | """ 22 | 23 | @property 24 | def name(self) -> str: 25 | return "pt2e" 26 | -------------------------------------------------------------------------------- /tico/quantization/config/smoothquant.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, Literal, Optional 16 | 17 | from tico.quantization.config.base import BaseConfig 18 | 19 | 20 | class SmoothQuantConfig(BaseConfig): 21 | """ 22 | Configuration for smooth quant. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | alpha: float = 0.5, 28 | custom_alpha_map: Optional[Dict[str, float]] = None, 29 | acts_from: Literal["input", "output"] = "input", 30 | ): 31 | self.alpha = alpha 32 | self.custom_alpha_map = custom_alpha_map 33 | # Where to collect activation statistics from: 34 | # - "input": use forward-pre-hook (Tensor before the Linear op) 35 | # - "output": use forward-hook (Tensor after the Linear op) 36 | # Default is "input". 37 | self.acts_from = acts_from 38 | 39 | @property 40 | def name(self) -> str: 41 | return "smoothquant" 42 | -------------------------------------------------------------------------------- /tico/quantization/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/evaluation/backend.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import Enum 16 | 17 | 18 | class BACKEND(Enum): 19 | CIRCLE = 1 20 | TRIV24 = 2 21 | -------------------------------------------------------------------------------- /tico/quantization/evaluation/executor/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/evaluation/executor/backend_executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | from typing import Any 17 | 18 | from tico.utils.model import CircleModel 19 | 20 | 21 | class BackendExecutor(abc.ABC): 22 | """ 23 | Abstract base class for executing a circle model on a specific backend. 24 | """ 25 | 26 | @abc.abstractmethod 27 | def compile(self, circle_model: CircleModel) -> None: 28 | """ 29 | Compile the circle model for this backend, if needed. 30 | 31 | Parameters 32 | ----------- 33 | circle_model 34 | The circle model to be compiled. 35 | """ 36 | pass 37 | 38 | @abc.abstractmethod 39 | def run_inference(self, input_data: Any) -> Any: 40 | """ 41 | Run inference using the compiled (or directly usable) model 42 | on the given input data. 43 | 44 | Parameters 45 | ----------- 46 | input_data 47 | The input data to be fed to the model. 48 | 49 | Returns 50 | -------- 51 | Any 52 | The model's inference output. 53 | """ 54 | pass 55 | -------------------------------------------------------------------------------- /tico/quantization/passes/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/mode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import auto, Enum 16 | 17 | 18 | class Mode(Enum): 19 | """ 20 | Mode — global FSM for PTQWrapper & Handlers. 21 | 22 | • NO_QUANT : pure pass-through (no stats, no fake-quant) 23 | • CALIB : collect observer statistics only 24 | • QUANT : use cached (scale, zero-point) → fake-quant enabled 25 | """ 26 | 27 | NO_QUANT = auto() 28 | CALIB = auto() 29 | QUANT = auto() 30 | 31 | def __str__(self) -> str: 32 | return self.name.lower() 33 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/observers/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/observers/minmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from tico.quantization.wrapq.observers.affine_base import AffineObserverBase 18 | from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax 19 | 20 | 21 | class MinMaxObserver(AffineObserverBase): 22 | """Plain min/max range tracker.""" 23 | 24 | @torch.no_grad() 25 | def _update_stats(self, x: torch.Tensor) -> None: 26 | """ 27 | Update running min/max with the incoming batch. 28 | 29 | Per-tensor: use global min/max. 30 | Per-channel: reduce all axes except the channel axis. 31 | """ 32 | if self.channel_axis is None: 33 | curr_min, curr_max = x.min(), x.max() 34 | else: 35 | curr_min, curr_max = channelwise_minmax(x, self.channel_axis) 36 | 37 | # Broadcasting handles scalar-vs-vector cases 38 | self.min_val = torch.minimum(self.min_val, curr_min) 39 | self.max_val = torch.maximum(self.max_val, curr_max) 40 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/qscheme.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from enum import auto, Enum 16 | 17 | 18 | class QScheme(Enum): 19 | # ───── Per-tensor ──────────── 20 | PER_TENSOR_ASYMM = auto() 21 | PER_TENSOR_SYMM = auto() 22 | # ───── Per-channel ─────────── 23 | PER_CHANNEL_ASYMM = auto() 24 | PER_CHANNEL_SYMM = auto() 25 | 26 | # helper 27 | def is_per_channel(self) -> bool: 28 | return self in { 29 | QScheme.PER_CHANNEL_ASYMM, 30 | QScheme.PER_CHANNEL_SYMM, 31 | } 32 | 33 | def is_symmetric(self) -> bool: 34 | return self in { 35 | QScheme.PER_TENSOR_SYMM, 36 | QScheme.PER_CHANNEL_SYMM, 37 | } 38 | 39 | def __str__(self) -> str: 40 | return self.name.lower() 41 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/utils/reduce_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | def channelwise_minmax(x: torch.Tensor, channel_axis: int): 19 | """ 20 | Compute per-channel (min, max) by reducing all axes except `channel_axis`. 21 | """ 22 | channel_axis = channel_axis % x.ndim # handle negative indices safely 23 | dims = tuple(d for d in range(x.ndim) if d != channel_axis) 24 | 25 | return x.amin(dim=dims), x.amax(dim=dims) 26 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/wrappers/fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | from tico.quantization.wrapq.wrappers.fairseq.quant_mha import ( 2 | QuantFairseqMultiheadAttention, 3 | ) 4 | 5 | __all__ = ["QuantFairseqMultiheadAttention"] 6 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/wrappers/llama/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/quantization/wrapq/wrappers/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/serialize/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/serialize/operators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | from os.path import basename, dirname, isfile, join 17 | 18 | from tico.utils.register_custom_op import RegisterOps 19 | 20 | 21 | # Register custom ops to torch namespace 22 | RegisterOps() 23 | 24 | # Load all modules in the current directory 25 | modules = glob.glob(join(dirname(__file__), "*.py")) 26 | __all__ = [ 27 | basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") 28 | ] 29 | -------------------------------------------------------------------------------- /tico/serialize/operators/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/serialize/operators/adapters/llama_rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from contextlib import contextmanager 16 | 17 | import torch 18 | 19 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 20 | 21 | 22 | def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor): 23 | return torch.ops.circle_custom.rms_norm( 24 | hidden_states, self.weight, self.variance_epsilon 25 | ) 26 | 27 | 28 | @contextmanager 29 | def patched_llama_rmsnorm(): 30 | orig = LlamaRMSNorm.forward 31 | LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter 32 | try: 33 | yield 34 | finally: 35 | LlamaRMSNorm.forward = orig 36 | -------------------------------------------------------------------------------- /tico/serialize/operators/hashable_opcode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from circle_schema import circle 16 | 17 | 18 | class OpCode(circle.OperatorCode.OperatorCodeT): 19 | """ 20 | Wrapper class for operator code in circle schema 21 | This implements __eq__ and __hash__ for use with dict() 22 | """ 23 | 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def __eq__(self, other): 28 | if self.version != other.version: 29 | return False 30 | 31 | if self.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM: 32 | return self.customCode == other.customCode 33 | 34 | return self.builtinCode == other.builtinCode 35 | 36 | def __hash__(self): 37 | val = ( 38 | self.deprecatedBuiltinCode, 39 | self.customCode, 40 | self.version, 41 | self.builtinCode, 42 | ) 43 | return hash(val) 44 | -------------------------------------------------------------------------------- /tico/serialize/operators/op_full.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, List, TYPE_CHECKING 16 | 17 | if TYPE_CHECKING: 18 | import torch._ops 19 | import torch.fx 20 | import torch 21 | from circle_schema import circle 22 | 23 | from tico.serialize.circle_graph import CircleSubgraph 24 | from tico.serialize.operators.hashable_opcode import OpCode 25 | from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor 26 | from tico.utils.validate_args_kwargs import FullArgs 27 | 28 | 29 | @register_node_visitor 30 | class FullVisitor(NodeVisitor): 31 | target: List[torch._ops.OpOverload] = [torch.ops.aten.full.default] 32 | 33 | def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): 34 | super().__init__(op_codes, graph) 35 | 36 | def define_node( 37 | self, 38 | node: torch.fx.Node, 39 | ) -> circle.Operator.OperatorT: 40 | args = FullArgs(*node.args, **node.kwargs) # type: ignore[arg-type] 41 | size = args.size 42 | fill_value = args.fill_value 43 | 44 | output_data = torch.full(size, fill_value) 45 | 46 | self.graph.update_tensor_buffer(output_data, node.name) 47 | 48 | return None # type: ignore[return-value] 49 | -------------------------------------------------------------------------------- /tico/serialize/pack.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | 17 | 18 | def pack_buffer(flat_data: np.ndarray, dtype: str) -> np.ndarray: 19 | assert len(flat_data.shape) == 1 20 | 21 | if dtype == "uint4": 22 | if flat_data.dtype != np.uint8: 23 | raise RuntimeError("uint4 data should be saved in uint8.") 24 | 25 | numel = flat_data.shape[0] 26 | packed = np.zeros((numel + 1) // 2, dtype=np.uint8) 27 | for i in range(numel): 28 | assert flat_data[i] >= 0 and flat_data[i] <= 15 29 | if i % 2 == 0: 30 | packed[i // 2] = flat_data[i] 31 | else: 32 | packed[i // 2] |= flat_data[i] << 4 33 | return packed 34 | else: 35 | raise NotImplementedError(f"NYI dtype: {dtype}") 36 | -------------------------------------------------------------------------------- /tico/serialize/quant_param.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import List, Optional 17 | 18 | import torch 19 | 20 | """ 21 | This is a key for torch.fx.Node's meta dict to save QuantParam 22 | 23 | QuantParam can be retrieved as node.meta[QPARAM_KEY] 24 | """ 25 | QPARAM_KEY = "_quantization_parameters_" 26 | 27 | 28 | @dataclass 29 | class QuantParam: 30 | scale: Optional[List[float]] = None 31 | zero_point: Optional[List[int]] = None 32 | quantized_dimension: Optional[int] = None 33 | min: Optional[List[float]] = None 34 | max: Optional[List[float]] = None 35 | # NOTE We define dtype as a string to easily extend new dtypes (ex: uint4) 36 | dtype: str = "" 37 | 38 | 39 | def to_qparam_dtype(dtype: torch.dtype) -> str: 40 | str_type = str(dtype) 41 | assert str_type.startswith("torch.") 42 | return str_type[6:] 43 | -------------------------------------------------------------------------------- /tico/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/utils/define.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict, List 16 | 17 | from circle_schema import circle 18 | 19 | from tico.serialize.circle_graph import CircleSubgraph 20 | from tico.serialize.operators.hashable_opcode import OpCode 21 | from tico.serialize.operators.utils import create_builtin_operator, get_op_index 22 | 23 | 24 | def define_pad_node( 25 | graph: CircleSubgraph, op_codes: Dict[OpCode, int], inputs: List, outputs: List 26 | ) -> circle.Operator.OperatorT: 27 | def set_pad_option(operator: circle.Operator.OperatorT): 28 | operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions 29 | option = circle.PadOptions.PadOptionsT() 30 | operator.builtinOptions = option 31 | 32 | pad_op_index = get_op_index(circle.BuiltinOperator.BuiltinOperator.PAD, op_codes) 33 | operator = create_builtin_operator(graph, pad_op_index, inputs, outputs) 34 | set_pad_option(operator) 35 | return operator 36 | -------------------------------------------------------------------------------- /tico/utils/dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from circle_schema import circle 5 | 6 | NUMPY_TO_TORCH_DTYPE_DICT = { 7 | np.dtype("float32"): torch.float32, 8 | np.dtype("float64"): torch.float64, 9 | np.dtype("float16"): torch.float16, 10 | np.dtype("complex64"): torch.complex64, 11 | np.dtype("complex128"): torch.complex128, 12 | np.dtype("int64"): torch.int64, 13 | np.dtype("int32"): torch.int32, 14 | np.dtype("int16"): torch.int16, 15 | np.dtype("int8"): torch.int8, 16 | np.dtype("uint8"): torch.uint8, 17 | np.dtype("bool"): torch.bool, 18 | } 19 | 20 | CIRCLE_TO_TORCH_DTYPE_DICT = { 21 | circle.TensorType.TensorType.FLOAT32: torch.float32, 22 | circle.TensorType.TensorType.UINT8: torch.uint8, 23 | circle.TensorType.TensorType.INT8: torch.int8, 24 | circle.TensorType.TensorType.INT16: torch.int16, 25 | circle.TensorType.TensorType.INT32: torch.int32, 26 | circle.TensorType.TensorType.INT64: torch.int64, 27 | circle.TensorType.TensorType.BOOL: torch.bool, 28 | } 29 | 30 | 31 | def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype: 32 | return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype] 33 | 34 | 35 | def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype: 36 | assert isinstance(circle_dtype, int) 37 | if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT: 38 | raise RuntimeError(f"Unsupported dtype {circle_dtype}") 39 | 40 | torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype] 41 | assert torch_dtype is not None 42 | return torch_dtype 43 | -------------------------------------------------------------------------------- /tico/utils/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class CircleExirError(Exception): 17 | """Base class for custom exceptions in project""" 18 | 19 | pass 20 | 21 | 22 | class NotYetSupportedError(CircleExirError): 23 | """ 24 | Not yet supported feature or functionality 25 | """ 26 | 27 | pass 28 | 29 | 30 | class InvalidArgumentError(CircleExirError): 31 | """ 32 | Invalid argument, which is never allowed 33 | """ 34 | 35 | pass 36 | -------------------------------------------------------------------------------- /tico/utils/installed_packages.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ############################## 16 | #### Transformers Package #### 17 | ############################## 18 | 19 | 20 | def is_transformers_installed(): 21 | try: 22 | import transformers 23 | 24 | return True 25 | except ImportError: 26 | return False 27 | 28 | 29 | def is_dynamic_cache_available(): 30 | try: 31 | from transformers.cache_utils import DynamicCache 32 | 33 | return True 34 | except ImportError: 35 | return False 36 | -------------------------------------------------------------------------------- /tico/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | 18 | 19 | def _loggerLevel(): 20 | TICO_LOG = os.environ.get("TICO_LOG") 21 | if TICO_LOG == "1": 22 | log_level = logging.FATAL 23 | elif TICO_LOG == "2": 24 | log_level = logging.WARNING 25 | elif TICO_LOG == "3": 26 | log_level = logging.INFO 27 | elif TICO_LOG == "4": 28 | log_level = logging.DEBUG 29 | else: 30 | log_level = logging.WARNING 31 | return log_level 32 | 33 | 34 | LOG_LEVEL = _loggerLevel() 35 | 36 | 37 | def getLogger(name: str): 38 | """ 39 | Get logger with setting log level according to the `TICO_LOG` environment variable. 40 | """ 41 | logging.basicConfig() 42 | logger = logging.getLogger(name) 43 | logger.setLevel(LOG_LEVEL) 44 | 45 | return logger 46 | -------------------------------------------------------------------------------- /tico/utils/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from pathlib import Path 18 | from typing import Any 19 | 20 | from tico.interpreter import infer 21 | 22 | 23 | class CircleModel: 24 | def __init__(self, circle_binary: bytes): 25 | self.circle_binary = circle_binary 26 | 27 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 28 | return infer.infer(self.circle_binary, *args, **kwargs) 29 | 30 | @staticmethod 31 | def load(circle_path: str) -> CircleModel: 32 | with open(circle_path, "rb") as f: 33 | buf = bytes(f.read()) 34 | return CircleModel(buf) 35 | 36 | def save(self, circle_path: str | Path) -> None: 37 | with open(circle_path, "wb") as f: 38 | f.write(self.circle_binary) 39 | -------------------------------------------------------------------------------- /tico/utils/mx/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /tico/utils/torch_compat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Runtime **capability-detection helpers** for the `torch.export` stack. 17 | 18 | Instead of sprinkling version checks like `torch.__version__ >= "2.9"` throughout 19 | the codebase, import these helpers once and branch on the feature you need. 20 | 21 | Each probe executes only **once per process** thanks to `functools.lru_cache`, 22 | so the overhead is negligible. 23 | """ 24 | 25 | import functools 26 | 27 | import torch 28 | 29 | 30 | @functools.lru_cache(maxsize=None) 31 | def export_produces_slice() -> bool: 32 | """ 33 | Compile a minimal model with `torch.export.export` and inspect its FX graph 34 | to see whether an `aten.slice.Tensor` node appears. 35 | 36 | Returns 37 | ------- 38 | bool 39 | * ``True`` — downstream passes should expect redundant **slice** nodes. 40 | * ``False`` — downstream passes should expect only a **select** node. 41 | """ 42 | 43 | class _Probe(torch.nn.Module): 44 | def forward(self, x): # simple slice: keep all dims except 3rd 45 | return x[:, :, 1] 46 | 47 | def get_example_inputs(self): 48 | return (torch.randn(1, 4, 4),) 49 | 50 | m = _Probe() 51 | ep = torch.export.export(m, m.get_example_inputs()) 52 | return any(n.target == torch.ops.aten.slice.Tensor for n in ep.graph.nodes) 53 | -------------------------------------------------------------------------------- /version.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.1.0" 2 | --------------------------------------------------------------------------------