├── python
│   ├── __init__.py
│   ├── frontend
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   ├── constants.py
│   │   │   ├── __init__.py
│   │   │   ├── bench
│   │   │   │   ├── list.py
│   │   │   │   └── __init__.py
│   │   │   ├── prove.py
│   │   │   ├── model_check.py
│   │   │   ├── compile.py
│   │   │   ├── witness.py
│   │   │   ├── verify.py
│   │   │   └── args.py
│   │   └── cli.py
│   ├── scripts
│   │   └── __init__.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── utils_testing
│   │   │   └── __init__.py
│   │   ├── circuit_e2e_tests
│   │   │   └── __init__.py
│   │   ├── circuit_parent_classes
│   │   │   ├── __init__.py
│   │   │   └── test_ort_custom_layers.py
│   │   └── onnx_quantizer_tests
│   │       ├── layers_tests
│   │       │   ├── __init__.py
│   │       │   ├── test_validation.py
│   │       │   ├── test_error_cases.py
│   │       │   ├── base_test.py
│   │       │   ├── test_check_model.py
│   │       │   └── test_scalability.py
│   │       ├── __init__.py
│   │       ├── layers
│   │       │   ├── __init__.py
│   │       │   ├── constant_config.py
│   │       │   ├── flatten_config.py
│   │       │   ├── relu_config.py
│   │       │   ├── reshape_config.py
│   │       │   ├── min_config.py
│   │       │   ├── max_config.py
│   │       │   ├── add_config.py
│   │       │   ├── mul_config.py
│   │       │   └── sub_config.py
│   │       ├── testing_helper_functions.py
│   │       ├── test_exceptions.py
│   │       └── test_registered_quantizers.py
│   ├── core
│   │   ├── binaries
│   │   │   └── __init__.py
│   │   ├── circuits
│   │   │   ├── __init__.py
│   │   │   └── zk_model_base.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   ├── scratch_tests.py
│   │   │   └── errors.py
│   │   ├── circuit_models
│   │   │   └── __init__.py
│   │   ├── model_processing
│   │   │   ├── __init__.py
│   │   │   ├── converters
│   │   │   │   └── __init__.py
│   │   │   ├── onnx_quantizer
│   │   │   │   ├── __init__.py
│   │   │   │   └── layers
│   │   │   │       ├── __init__.py
│   │   │   │       ├── max.py
│   │   │   │       ├── mul.py
│   │   │   │       ├── add.py
│   │   │   │       ├── sub.py
│   │   │   │       ├── min.py
│   │   │   │       ├── relu.py
│   │   │   │       ├── clip.py
│   │   │   │       ├── constant.py
│   │   │   │       └── gemm.py
│   │   │   └── onnx_custom_ops
│   │   │       ├── __init__.py
│   │   │       ├── relu.py
│   │   │       ├── batchnorm.py
│   │   │       ├── custom_helpers.py
│   │   │       ├── mul.py
│   │   │       ├── maxpool.py
│   │   │       ├── gemm.py
│   │   │       └── conv.py
│   │   ├── model_templates
│   │   │   ├── __init__.py
│   │   │   └── circuit_template.py
│   │   └── __init__.py
│   └── models
│       └── models_onnx
│           └── lenet.onnx
├── rust
│   └── jstprove_circuits
│       ├── src
│       │   ├── io
│       │   │   └── mod.rs
│       │   ├── runner
│       │   │   ├── mod.rs
│       │   │   └── errors.rs
│       │   ├── circuit_functions
│       │   │   ├── mod.rs
│       │   │   ├── layers
│       │   │   │   ├── mod.rs
│       │   │   │   ├── errors.rs
│       │   │   │   ├── constant.rs
│       │   │   │   ├── layer_ops.rs
│       │   │   │   ├── flatten.rs
│       │   │   │   ├── reshape.rs
│       │   │   │   ├── sub.rs
│       │   │   │   └── add.rs
│       │   │   ├── gadgets
│       │   │   │   └── mod.rs
│       │   │   ├── utils
│       │   │   │   ├── onnx_types.rs
│       │   │   │   ├── mod.rs
│       │   │   │   ├── constants.rs
│       │   │   │   └── errors.rs
│       │   │   ├── hints
│       │   │   │   ├── mod.rs
│       │   │   │   └── bits.rs
│       │   │   └── errors.rs
│       │   └── lib.rs
│       ├── Cargo.toml
│       └── bin
│           └── simple_circuit.rs
├── rust-toolchain.toml
├── .gitmodules
├── rustfmt.toml
├── pytest.ini
├── docs
│   ├── faq.md
│   ├── troubleshooting.md
│   ├── artifacts.md
│   ├── CONTRIBUTING.md
│   ├── models.md
│   ├── developer-notes.md
│   └── cli.md
├── .gitignore
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── unit-integration-tests.yml
│       ├── security.yml
│       └── e2e-tests.yml
├── pyproject.toml
├── LICENSE
├── .pre-commit-config.yaml
├── Cargo.toml
└── conftest.py

/python/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/python/frontend/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/python/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | 
-------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/binaries/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/circuits/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/circuit_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/utils_testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/model_processing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/model_templates/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/circuit_e2e_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/circuit_parent_classes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/model_processing/converters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/io/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod io_reader; 2 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/__init__.py: -------------------------------------------------------------------------------- 1 | TEST_RNG_SEED = 2 2 | -------------------------------------------------------------------------------- /rust-toolchain.toml: 
-------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly-2025-03-27" 3 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/runner/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod main_runner; 2 | 3 | pub mod errors; 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Expander"] 2 | path = Expander 3 | url = https://github.com/PolyhedraZK/Expander.git 4 | -------------------------------------------------------------------------------- /python/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Package metadata constants 2 | PACKAGE_NAME = "JSTprove" 3 | RUST_BINARY_NAME = "onnx_generic_circuit" 4 | -------------------------------------------------------------------------------- /python/models/models_onnx/lenet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inference-labs-inc/JSTprove/HEAD/python/models/models_onnx/lenet.onnx -------------------------------------------------------------------------------- /python/core/utils/constants.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | MODEL_SOURCE_ONNX: str = "onnx" 4 | MODEL_SOURCE_CLASS: str = "class" 5 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 100 2 | tab_spaces = 4 3 | edition = "2024" 4 | merge_derives = true 5 | use_field_init_shorthand = true 6 | use_try_shorthand = true 7 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/mod.rs: -------------------------------------------------------------------------------- 1 | mod errors; 2 | pub mod gadgets; 3 | pub mod hints; 4 | pub mod layers; 5 | pub mod utils; 6 | 7 | pub use errors::CircuitError; 8 | -------------------------------------------------------------------------------- /python/frontend/commands/constants.py: -------------------------------------------------------------------------------- 1 | """Constants for CLI commands.""" 2 | 3 | # Circuit defaults 4 | DEFAULT_CIRCUIT_MODULE = "python.core.circuit_models.generic_onnx" 5 | DEFAULT_CIRCUIT_CLASS = "GenericModelONNX" 6 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # Set the base directory for test discovery 3 | testpaths = python 4 | markers = 5 | unit: mark a test as a unit test 6 | integration: mark a test as an integration test (non-E2E) 7 | e2e: mark a test as a full end-to-end test 8 | 9 | # Optionally, set the pythonpath to ensure pytest finds the core module 10 | pythonpath = . 
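# Note: CI selects these marker groups through the repo's custom pytest flags
# (presumably defined in conftest.py); see .github/workflows/unit-integration-tests.yml,
# which runs e.g. `uv run pytest --unit` and `uv run pytest --integration`.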
11 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/layers/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod add; 2 | pub mod batchnorm; 3 | pub mod clip; 4 | pub mod constant; 5 | pub mod conv; 6 | mod errors; 7 | pub mod flatten; 8 | pub mod gemm; 9 | mod layer_kinds; 10 | pub mod layer_ops; 11 | pub mod max; 12 | pub mod maxpool; 13 | pub mod min; 14 | pub mod mul; 15 | pub mod relu; 16 | pub mod reshape; 17 | pub mod sub; 18 | 19 | pub use errors::LayerError; 20 | pub use layer_kinds::LayerKind; 21 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLayerConfigProvider, LayerTestConfig 2 | from .factory import TestLayerFactory 3 | 4 | # Auto-discover and make available all config providers 5 | # This triggers the discovery process when the package is imported 6 | _all_configs = TestLayerFactory.get_layer_configs() 7 | 8 | # Export the factory and base classes 9 | __all__ = [ 10 | "BaseLayerConfigProvider", 11 | "LayerTestConfig", 12 | "TestLayerFactory", 13 | ] 14 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/gadgets/mod.rs: -------------------------------------------------------------------------------- 1 | //! Public gadget API for circuit construction. 2 | //! 3 | //! Re-exports all constraint-enforcing gadgets used by layers. 4 | 5 | pub mod linear_algebra; 6 | pub mod max_min_clip; 7 | pub mod range_check; 8 | 9 | pub use max_min_clip::{ShiftRangeContext, constrained_clip, constrained_max, constrained_min}; 10 | pub use range_check::{ 11 | DEFAULT_LOGUP_CHUNK_BITS, LogupRangeCheckContext, constrained_reconstruct_from_bits, 12 | logup_range_check_pow2_unsigned, range_check_pow2_unsigned, 13 | }; 14 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pkgutil 3 | from pathlib import Path 4 | 5 | # Get the package name of the current module 6 | package_name = __name__ 7 | 8 | # Dynamically import all .py files in this package directory (except __init__.py) 9 | package_dir = Path(__file__).parent.as_posix() 10 | 11 | 12 | __all__: list[str] = [] 13 | 14 | for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]): 15 | if not is_pkg and (module_name != "custom_helpers"): 16 | importlib.import_module(f"{package_name}.{module_name}") 17 | __all__.append(module_name) # noqa: PYI056 18 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ## Do I need to specify a circuit class? 4 | **No.** JSTprove defaults to **GenericModelONNX**. 5 | 6 | ## Can I run only witness/prove/verify without compile? 7 | **Yes**, as long as you already have the **circuit** and **quantized ONNX** produced by a prior compile. 8 | 9 | ## Where does the CLI put reshaped inputs? 10 | It writes a local `*_reshaped.json` in your **current working directory** during witness/verify. 11 | 12 | ## What exactly is proven? 
13 | That the **quantized model**, when evaluated on your input, produces the stated output — and that the **circuit constraints hold** (via GKR/sumcheck in Expander).
14 | 
--------------------------------------------------------------------------------
/python/tests/onnx_quantizer_tests/testing_helper_functions.py:
--------------------------------------------------------------------------------
1 | from onnx import ModelProto
2 | 
3 | 
4 | # Helper to extract input shapes
5 | def get_input_shapes(model: ModelProto) -> dict[str, list[int]]:
6 |     input_shapes = {}
7 |     for inp in model.graph.input:
8 |         shape = []
9 |         for dim in inp.type.tensor_type.shape.dim:
10 |             if dim.HasField("dim_value"):
11 |                 shape.append(int(dim.dim_value))
12 |             elif dim.dim_param:
13 |                 shape.append(1)  # Default for dynamic dims
14 |             else:
15 |                 shape.append(1)
16 |         input_shapes[inp.name] = shape
17 |     return input_shapes
18 | 
--------------------------------------------------------------------------------
/python/frontend/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from python.frontend.commands.base import BaseCommand
2 | from python.frontend.commands.bench import BenchCommand
3 | from python.frontend.commands.compile import CompileCommand
4 | from python.frontend.commands.model_check import ModelCheckCommand
5 | from python.frontend.commands.prove import ProveCommand
6 | from python.frontend.commands.verify import VerifyCommand
7 | from python.frontend.commands.witness import WitnessCommand
8 | 
9 | __all__ = [
10 |     "BaseCommand",
11 |     "BenchCommand",
12 |     "CompileCommand",
13 |     "ModelCheckCommand",
14 |     "ProveCommand",
15 |     "VerifyCommand",
16 |     "WitnessCommand",
17 | ]
18 | 
--------------------------------------------------------------------------------
/rust/jstprove_circuits/src/circuit_functions/utils/onnx_types.rs:
--------------------------------------------------------------------------------
1 | use serde::Deserialize;
2 | use serde_json::Value;
3 | use std::collections::HashMap;
4 | 
5 | #[derive(Deserialize, Clone, Debug)]
6 | pub struct ONNXIO {
7 |     pub name: String,
8 |     pub elem_type: i16,
9 |     pub shape: Vec<i64>, // element type assumed: ONNX dims serialize as int64
10 | }
11 | 
12 | #[derive(Deserialize, Clone, Debug)]
13 | pub struct ONNXLayer {
14 |     pub id: usize,
15 |     pub name: String,
16 |     pub op_type: String,
17 |     pub inputs: Vec<String>,
18 |     pub outputs: Vec<String>,
19 |     pub shape: HashMap<String, Vec<i64>>, // key/element types assumed from the serialized ONNX metadata
20 |     pub tensor: Option<Value>,
21 |     pub params: Option<Value>,
22 |     pub opset_version_number: i16,
23 | }
24 | 
--------------------------------------------------------------------------------
/rust/jstprove_circuits/src/circuit_functions/utils/mod.rs:
--------------------------------------------------------------------------------
1 | //! Utility modules used throughout circuit construction and ONNX translation.
2 | 3 | pub mod build_layers; 4 | pub mod constants; 5 | mod errors; 6 | pub mod graph_pattern_matching; 7 | pub mod json_array; // JSON-to-array conversions and trait lifting 8 | pub mod onnx_model; // ONNX layer param extraction and model shape helpers 9 | pub mod onnx_types; // 10 | pub mod quantization; // Quantization utilities and rescaling logic 11 | pub mod shaping; // Shape manipulation and input partitioning 12 | pub mod tensor_ops; // Conversions between nested Vecs and ArrayD 13 | pub mod typecasting; 14 | 15 | pub use errors::{ArrayConversionError, BuildError, PatternError, RescaleError, UtilsError}; 16 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/utils/constants.rs: -------------------------------------------------------------------------------- 1 | /// Matrix transpose flag 2 | pub const TRANS_A: &str = "transA"; 3 | pub const TRANS_B: &str = "transB"; 4 | 5 | /// GEMM alpha or beta flag 6 | pub const ALPHA: &str = "alpha"; 7 | pub const BETA: &str = "beta"; 8 | 9 | /// General Matrix Multiply (GEMM) operation name 10 | pub const GEMM: &str = "Gemm"; 11 | 12 | /// Value for Constant layer 13 | pub const VALUE: &str = "value"; 14 | 15 | /// AXIS for reshaping 16 | pub const AXIS: &str = "axis"; 17 | 18 | pub const KERNEL_SHAPE: &str = "kernel_shape"; 19 | pub const STRIDES: &str = "strides"; 20 | pub const DILATION: &str = "dilations"; 21 | pub const PADS: &str = "pads"; 22 | pub const GROUP: &str = "group"; 23 | 24 | pub const INPUT: &str = "input"; 25 | pub const WEIGHTS: &str = "weights"; 26 | pub const BIAS: &str = "bias"; 27 | pub const INPUT_SHAPE: &str = "input shape"; 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Output directories and temporary files 2 | inputs/ 3 | output/ 4 | proofs/ 5 | temp/ 6 | 7 | # Virtual environments 8 | .venv/ 9 | env/ 10 | venv/ 11 | 12 | # System and metadata files 13 | .DS_Store 14 | .DS_* 15 | ._* 16 | 17 | # IDE/editor folders 18 | .vscode/ 19 | .idea/ 20 | 21 | # Compiled and cache files 22 | __pycache__/ 23 | *.pyc 24 | *.pyo 25 | *.py[cod] 26 | *.ipynb_checkpoints/ 27 | *.cfg 28 | lib/ 29 | bin/ 30 | share/ 31 | report 32 | **.egg-info 33 | 34 | # Rust Binaries 35 | target/ 36 | !rust/jstprove_circuits/bin 37 | python/core/binaries/ 38 | 39 | # Dataset and binary files 40 | *.json 41 | *.jsonl 42 | *.txt 43 | *.bin 44 | *.csv 45 | *.compiled 46 | *.pth 47 | 48 | # Log files 49 | *.log 50 | 51 | # Archives 52 | *.zip 53 | 54 | # Environment variables 55 | .env 56 | 57 | # Retain Specific Model Files 58 | *.onnx 59 | !python/models/models_onnx/*.onnx 60 | 61 | # Proving System 62 | Expander/ 63 | -------------------------------------------------------------------------------- /python/core/circuits/zk_model_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from python.core.circuits.base import Circuit 4 | from python.core.utils.general_layer_functions import GeneralLayerFunctions 5 | 6 | 7 | class ZKModelBase(GeneralLayerFunctions, Circuit): 8 | """ 9 | Abstract base class for Zero-Knowledge (ZK) ML models. 10 | 11 | This class provides a standard interface for ZK circuit ML models. 12 | Instantiates Circuit and GeneralLayerFunctions. 
13 | 14 | Subclasses must implement the constructor to define the model's 15 | architecture, layers, and circuit details. 16 | """ 17 | 18 | def __init__(self: ZKModelBase) -> None: 19 | """Initialize the ZK model. Must be overridden by subclasses 20 | 21 | Raises: 22 | NotImplementedError: If called on the base class directly. 23 | """ 24 | msg = "Must implement __init__" 25 | raise NotImplementedError(msg) 26 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Related Issue 5 | 6 | - 7 | 8 | ## Type of Change 9 | 10 | - [ ] Bug fix (non-breaking) 11 | - [ ] New feature (non-breaking) 12 | - [ ] Breaking change (fix/feature causing existing functionality to break) 13 | - [ ] Refactor (non-functional changes) 14 | - [ ] Documentation update 15 | 16 | ## Checklist 17 | 18 | - [ ] Code follows project patterns 19 | - [ ] Tests added/updated (if applicable) 20 | - [ ] Documentation updated (if applicable) 21 | - [ ] Self-review of code 22 | - [ ] All tests pass locally 23 | - [ ] Linter passes locally 24 | 25 | ## Deployment Notes 26 | 27 | - 28 | 29 | ## Additional Comments 30 | 31 | - 32 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/hints/mod.rs: -------------------------------------------------------------------------------- 1 | //! Hint infrastructure: LogUp hints + unconstrained arithmetic helpers. 2 | 3 | pub mod bits; 4 | pub use bits::unconstrained_to_bits; 5 | 6 | pub mod max_min_clip; 7 | pub use max_min_clip::{unconstrained_clip, unconstrained_max, unconstrained_min}; 8 | 9 | /// LogUp hint registration 10 | use circuit_std_rs::logup::{query_count_by_key_hint, query_count_hint, rangeproof_hint}; 11 | use expander_compiler::field::Field as CompilerField; 12 | use expander_compiler::hints::registry::HintRegistry; 13 | 14 | /// Build a HintRegistry with all LogUp-related hints registered. 15 | /// These names MUST match the identifiers used by new_hint(...). 
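///
/// Illustrative call, as a sketch (`M31` here is a stand-in for any concrete
/// field type satisfying the compiler's `Field` trait):
/// ```ignore
/// let registry = build_logup_hint_registry::<M31>();
/// ```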
16 | pub fn build_logup_hint_registry<F: CompilerField>() -> HintRegistry<F> { // `CompilerField` bound assumed from the import alias above
17 |     let mut registry = HintRegistry::<F>::new();
18 | 
19 |     registry.register("myhint.querycounthint", query_count_hint::<F>);
20 |     registry.register("myhint.querycountbykeyhint", query_count_by_key_hint::<F>);
21 |     registry.register("myhint.rangeproofhint", rangeproof_hint::<F>);
22 | 
23 |     registry
24 | }
25 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "JSTprove"
7 | version = "0.1.0"
8 | description = "Zero-knowledge proofs of ML inference on ONNX models"
9 | readme = "README.md"
10 | authors = [{ name = "Inference Labs Inc" }]
11 | requires-python = ">=3.10"
12 | dependencies = [
13 |     "numpy==2.2.6",
14 |     "onnx==1.17.0",
15 |     "onnxruntime==1.21.0",
16 |     "onnxruntime_extensions==0.14.0",
17 |     "psutil==7.0.0",
18 |     "Requests==2.32.3",
19 |     "scikit_learn==1.6.1",
20 |     "toml==0.10.2",
21 |     "tomli==2.0.1; python_version < '3.11'",
22 |     "torch==2.6.0",
23 |     "transformers==4.52.4",
24 | ]
25 | 
26 | [dependency-groups]
27 | test = [
28 |     "pytest==8.3.5",
29 |     "pytest-html==4.1.1",
30 |     "pre-commit==4.3.0"
31 | ]
32 | 
33 | [project.scripts]
34 | jst = "python.frontend.cli:main"
35 | 
36 | [tool.setuptools.packages.find]
37 | where = ["."]
38 | include = ["python*"]
39 | 
40 | [tool.setuptools.package-data]
41 | "python.core.binaries" = ["*"]
42 | "python.core.lib" = ["*.so*", "*.dylib"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025 Inference Labs Inc.
2 | 
3 | Source Access Grant
4 | You may access, view, study, and modify the source code of this software.
5 | 
6 | Redistribution Conditions
7 | You may redistribute this software in source or modified form provided that:
8 | a) You retain this license document and all copyright notices
9 | b) Any modified files carry prominent notices stating you changed them
10 | c) You do not misrepresent the origin of the software
11 | 
12 | Usage Restriction
13 | NO USE RIGHTS ARE GRANTED BY THIS LICENSE. Any operational use including but not limited to:
14 | - Execution of the software
15 | - Integration with other systems
16 | - Deployment in any environment
17 | - Commercial or production utilization requires express written permission from the IP Owner.
18 | 
19 | Intellectual Property Reservation
20 | All rights not expressly granted herein are reserved by the IP Owner. For usage permissions, contact: legal@inferencelabs.com
21 | 
22 | Disclaimer
23 | THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE IP OWNER SHALL NOT BE LIABLE FOR ANY DAMAGES ARISING FROM ACCESS OR DISTRIBUTION.
24 | 
25 | License Propagation
26 | Any distribution of this software or derivatives must be under this same license agreement.
27 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/errors.rs: -------------------------------------------------------------------------------- 1 | use crate::circuit_functions::{ 2 | layers::LayerError, 3 | utils::{ArrayConversionError, BuildError, PatternError, RescaleError, UtilsError}, 4 | }; 5 | use crate::io::io_reader::onnx_context::OnnxContextError; 6 | use thiserror::Error; 7 | 8 | #[derive(Debug, Error)] 9 | pub enum CircuitError { 10 | #[error("Layer failed: {0}")] 11 | Layer(#[from] LayerError), 12 | 13 | #[error(transparent)] 14 | UtilsError(#[from] UtilsError), 15 | 16 | #[error("Failed to parse weights JSON: {0}")] 17 | InvalidWeightsFormat(#[from] serde_json::Error), 18 | 19 | #[error("Architecture definition is empty")] 20 | EmptyArchitecture, 21 | 22 | #[error("Graph error: {0}")] 23 | GraphPatternError(#[from] PatternError), 24 | 25 | #[error("Array conversion error: {0}")] 26 | ArrayConversionError(#[from] ArrayConversionError), 27 | 28 | #[error("Rescaling error: {0}")] 29 | RescaleError(#[from] RescaleError), 30 | 31 | #[error("Error building layers: {0}")] 32 | BuildError(#[from] BuildError), 33 | 34 | #[error("ONNX context error: {0}")] 35 | OnnxContext(#[from] OnnxContextError), 36 | 37 | #[error("Other circuit error: {0}")] 38 | Other(String), 39 | } 40 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Circuit construction utilities for `JSTprove`, a zero-knowledge proof system 2 | //! supporting fixed-point quantization, neural-network inference, and modular 3 | //! arithmetic over finite fields. 4 | //! 5 | //! # Crate Structure 6 | //! 7 | //! - [`circuit_functions`]: Low-level arithmetic gadgets (LogUp, range-checks, 8 | //! max/min/clip, etc.) and high-level building blocks for layers such as 9 | //! matmul, convolution, `ReLU`, and quantized rescaling. 10 | //! 11 | //! - [`runner`]: CLI-oriented orchestration for compiling, proving, and 12 | //! verifying circuits, including witness generation and memory tracking. 13 | //! 14 | //! - [`io`]: Input/output helpers for serializing circuit inputs and outputs, 15 | //! including ONNX exports and JSON-encoded tensors. 16 | //! 17 | //! Typical usage involves composing layer gadgets from [`circuit_functions`], 18 | //! then invoking the tools in [`runner`] to generate and verify proofs. 19 | //! 20 | //! # Feature Flags 21 | //! 22 | //! This crate requires the nightly feature `min_specialization`. 23 | #![allow( 24 | clippy::doc_markdown, 25 | clippy::doc_lazy_continuation, 26 | clippy::doc_overindented_list_items 27 | )] 28 | #![feature(min_specialization)] 29 | 30 | pub mod circuit_functions; 31 | pub mod io; 32 | pub mod runner; 33 | -------------------------------------------------------------------------------- /docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting 2 | 3 | Common issues and quick fixes. 4 | 5 | --- 6 | 7 | ## Runner not found 8 | 9 | - Run commands from the **repo root** so `./target/release/*` is visible. 10 | - Re-run **compile** (it will build the runner automatically if needed). 11 | 12 | 13 | --- 14 | 15 | ## Shape or "out of bounds" errors during witness 16 | 17 | - Ensure your `--input-path` matches the model's input **shape**. 
18 | - Re-run **compile** after changing the model (to refresh circuit + quantization). 19 | - Make sure **witness** and **verify** both use the **same `quantized.onnx`** produced by the last compile. 20 | 21 | --- 22 | 23 | ## Verification complains about shapes 24 | 25 | - If your model has multiple inputs, ensure the input JSON includes **all input keys** with correct shapes. 26 | 27 | --- 28 | 29 | ## Slow runs 30 | 31 | - Large CNNs are heavy. For smoke tests: 32 | - Use a **smaller model** or **reduced input size**. 33 | - Or use the `simple_circuit` Rust binary to validate the toolchain quickly. 34 | 35 | --- 36 | 37 | ## General tips 38 | 39 | - Keep artifacts from the **same compile** together: `circuit.txt` + `quantized.onnx`. 40 | - If anything looks mismatched, re-run **compile → witness → prove → verify** end-to-end. 41 | - Set `JSTPROVE_NO_BANNER=1` or use `--no-banner` for quiet logs in CI. 42 | -------------------------------------------------------------------------------- /.github/workflows/unit-integration-tests.yml: -------------------------------------------------------------------------------- 1 | name: Unit, Integration Python Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.12" 21 | 22 | - name: Install UV 23 | uses: astral-sh/setup-uv@v5 24 | 25 | - name: Install the project 26 | run: uv sync --group test 27 | 28 | - name: Install OMPI 29 | run: | 30 | sudo apt-get update && sudo apt-get install -y \ 31 | libopenmpi-dev \ 32 | pkg-config \ 33 | libclang-dev \ 34 | clang \ 35 | openmpi-bin \ 36 | 37 | - name: Run Unit Tests 38 | run: | 39 | mkdir -p report 40 | uv run pytest --unit --html=report/unit_tests.html 41 | 42 | - name: Run Integration Tests 43 | run: | 44 | mkdir -p report 45 | uv run pytest --integration --html=report/integration_tests.html 46 | 47 | - name: Upload Pytest HTML Report 48 | uses: actions/upload-artifact@v4 49 | with: 50 | name: pytest-html-report 51 | path: report 52 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/constant_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnx import numpy_helper 3 | 4 | from python.tests.onnx_quantizer_tests.layers.base import e2e_test, valid_test 5 | from python.tests.onnx_quantizer_tests.layers.factory import ( 6 | BaseLayerConfigProvider, 7 | LayerTestConfig, 8 | ) 9 | 10 | 11 | class ConstantConfigProvider(BaseLayerConfigProvider): 12 | """Test configuration provider for Constant layers""" 13 | 14 | @property 15 | def layer_name(self) -> str: 16 | return "Constant" 17 | 18 | def get_config(self) -> LayerTestConfig: 19 | return LayerTestConfig( 20 | op_type="Constant", 21 | valid_inputs=[], 22 | valid_attributes={ 23 | "value": numpy_helper.from_array(np.array([1.0]), name="const_value"), 24 | }, 25 | required_initializers={}, 26 | ) 27 | 28 | def get_test_specs(self) -> list: 29 | return [ 30 | valid_test("basic") 31 | .description("Basic Constant node returning scalar 1.0") 32 | .tags("basic", "constant") 33 | .build(), 34 | e2e_test("e2e_basic") 35 | .description("End-to-end test for Constant node") 36 | .override_output_shapes(constant_output=[1]) 37 | .tags("e2e", 
"constant") 38 | .build(), 39 | ] 40 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/layers/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use crate::circuit_functions::layers::LayerKind; 4 | 5 | #[derive(Debug, Error)] 6 | pub enum LayerError { 7 | #[error("{layer} is missing input: {name}")] 8 | MissingInput { layer: LayerKind, name: String }, 9 | 10 | #[error("Shape mismatch in {layer} for {var_name}: expected {expected:?}, got {got:?}")] 11 | ShapeMismatch { 12 | layer: LayerKind, 13 | expected: Vec, 14 | got: Vec, 15 | var_name: String, 16 | }, 17 | 18 | #[error("{layer} is missing parameter: {param}")] 19 | MissingParameter { layer: LayerKind, param: String }, 20 | 21 | #[error("{layer} layer '{layer_name}' has an invalid value for {param_name}: {value}")] 22 | InvalidParameterValue { 23 | layer: LayerKind, 24 | layer_name: String, 25 | param_name: String, 26 | value: String, 27 | }, 28 | 29 | #[error("Unsupported config in {layer}: {msg}")] 30 | UnsupportedConfig { layer: LayerKind, msg: String }, 31 | 32 | #[error("Invalid shape in {layer}: {msg}")] 33 | InvalidShape { layer: LayerKind, msg: String }, 34 | 35 | #[error("Unknown operator type: {op_type}")] 36 | UnknownOp { op_type: String }, 37 | 38 | #[error("Other error in {layer}: {msg}")] 39 | Other { layer: LayerKind, msg: String }, 40 | } 41 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/relu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnxruntime_extensions import PyCustomOpDef, onnx_op 3 | 4 | 5 | @onnx_op( 6 | op_type="Int64Relu", 7 | domain="ai.onnx.contrib", 8 | inputs=[PyCustomOpDef.dt_int64], 9 | outputs=[PyCustomOpDef.dt_int64], 10 | ) 11 | def int64_relu(x: np.ndarray) -> np.ndarray: 12 | """ 13 | Performs a ReLU operation on int64 input tensors. 14 | 15 | This function is registered as a custom ONNX operator via onnxruntime_extensions 16 | and is used in the JSTprove quantized inference pipeline. 17 | It applies ReLU as is (there are no attributes to ReLU). 18 | 19 | Parameters 20 | ---------- 21 | X : Input tensor with dtype int64. 22 | 23 | Returns 24 | ------- 25 | numpy.ndarray 26 | ReLU tensor with dtype int64. 27 | 28 | Notes 29 | ----- 30 | - This op is part of the `ai.onnx.contrib` custom domain. 31 | - ONNX Runtime Extensions is required to register this op. 
32 | 33 | References 34 | ---------- 35 | For more information on the ReLU operation, please refer to the 36 | ONNX standard ReLU operator documentation: 37 | https://onnx.ai/onnx/operators/onnx__Relu.html 38 | """ 39 | try: 40 | return np.maximum(x, 0).astype(np.int64) 41 | except Exception as e: 42 | msg = f"Int64ReLU failed: {e}" 43 | raise RuntimeError(msg) from e 44 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "jstprove_circuits" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | ark-std.workspace = true 8 | rand.workspace = true 9 | chrono.workspace = true 10 | clap.workspace = true 11 | ethnum.workspace = true 12 | halo2curves.workspace = true 13 | tiny-keccak.workspace = true 14 | expander_circuit.workspace = true 15 | expander_transcript.workspace = true 16 | gkr.workspace = true 17 | expander_binary.workspace = true 18 | arith.workspace = true 19 | gf2.workspace = true 20 | mersenne31.workspace = true 21 | crosslayer_prototype.workspace = true 22 | serde_json = "1.0" 23 | peakmem-alloc = "0.3.0" 24 | expander_compiler.workspace = true 25 | lazy_static = "1.4" 26 | circuit-std-rs.workspace = true 27 | babybear.workspace = true 28 | gkr_engine.workspace = true 29 | gkr_hashers.workspace = true 30 | goldilocks.workspace = true 31 | serdes.workspace = true 32 | ndarray.workspace = true 33 | mpi.workspace = true 34 | thiserror.workspace = true 35 | strum.workspace = true 36 | strum_macros.workspace = true 37 | once_cell.workspace = true 38 | 39 | 40 | [dev-dependencies] 41 | rayon = "1.9" 42 | sha2 = "0.10.8" 43 | 44 | [dependencies.serde] 45 | version = "1.0" 46 | features = [ "derive",] 47 | 48 | [[bin]] 49 | name = "simple_circuit" 50 | path = "bin/simple_circuit.rs" 51 | 52 | [[bin]] 53 | name = "onnx_generic_circuit_0-1-0" 54 | path = "bin/generic_demo.rs" 55 | -------------------------------------------------------------------------------- /python/frontend/commands/bench/list.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import argparse 7 | 8 | from python.frontend.commands.args import ArgSpec 9 | from python.frontend.commands.base import BaseCommand 10 | 11 | LIST_MODELS = ArgSpec( 12 | name="list_models", 13 | flag="--list-models", 14 | help_text="List all available circuit models.", 15 | extra_kwargs={"action": "store_true", "default": False}, 16 | ) 17 | 18 | 19 | class ListCommand(BaseCommand): 20 | """List all available circuit models for benchmarking.""" 21 | 22 | name: ClassVar[str] = "list" 23 | aliases: ClassVar[list[str]] = [] 24 | help: ClassVar[str] = "List all available circuit models for benchmarking." 
25 | 26 | @classmethod 27 | def configure_parser( 28 | cls: type[ListCommand], 29 | parser: argparse.ArgumentParser, 30 | ) -> None: 31 | LIST_MODELS.add_to_parser(parser) 32 | 33 | @classmethod 34 | def run(cls: type[ListCommand], args: argparse.Namespace) -> None: # noqa: ARG003 35 | from python.core.utils.model_registry import ( # noqa: PLC0415 36 | list_available_models, 37 | ) 38 | 39 | available_models = list_available_models() 40 | print("\nAvailable Circuit Models:") # noqa: T201 41 | for model in available_models: 42 | print(f"- {model}") # noqa: T201 43 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/runner/errors.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use thiserror::Error; 4 | 5 | use crate::io::io_reader::onnx_context::OnnxContextError; 6 | 7 | #[derive(Error, Debug)] 8 | pub enum CliError { 9 | #[error("Missing required argument: {0}")] 10 | MissingArgument(&'static str), 11 | 12 | #[error("Unknown command: {0}")] 13 | UnknownCommand(String), 14 | 15 | #[error("IO error: {0}")] 16 | Io(#[from] std::io::Error), 17 | 18 | #[error(transparent)] 19 | RunError(#[from] RunError), 20 | 21 | #[error("Other error: {0}")] 22 | Other(String), 23 | } 24 | 25 | #[derive(Debug, Error)] 26 | pub enum RunError { 27 | #[error("I/O error while accessing {path}: {source}")] 28 | Io { 29 | #[source] 30 | source: io::Error, 31 | path: String, 32 | }, 33 | 34 | #[error("JSON deserialization error: {0}")] 35 | Json(String), 36 | 37 | #[error("Circuit compilation failed: {0}")] 38 | Compile(String), 39 | 40 | #[error("Serialization error: {0}")] 41 | Serialize(String), 42 | 43 | #[error("Deserialization error: {0}")] 44 | Deserialize(String), 45 | 46 | #[error("Witness generation failed: {0}")] 47 | Witness(String), 48 | 49 | #[error("Proving witness failed: {0}")] 50 | Prove(String), 51 | 52 | #[error("Verifying proof failed: {0}")] 53 | Verify(String), 54 | 55 | #[error("Error configuring circuit: {0}")] 56 | ConfigureCircuit(#[from] OnnxContextError), 57 | } 58 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/max.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeMax(QuantizerBase): 16 | OP_TYPE = "Max" 17 | DOMAIN = "" 18 | USE_WB = True 19 | USE_SCALING = False 20 | SCALE_PLAN: ClassVar = {1: 1} 21 | 22 | 23 | class MaxQuantizer(BaseOpQuantizer, QuantizeMax): 24 | def __init__( 25 | self, 26 | new_initializers: list[onnx.TensorProto] | None = None, 27 | ) -> None: 28 | super().__init__() 29 | if new_initializers is not None: 30 | # Share the caller-provided buffer instead of the default list. 
31 | self.new_initializers = new_initializers 32 | 33 | def quantize( 34 | self, 35 | node: onnx.NodeProto, 36 | graph: onnx.GraphProto, 37 | scale_config: ScaleConfig, 38 | initializer_map: dict[str, onnx.TensorProto], 39 | ) -> list[onnx.NodeProto]: 40 | # Delegate to the shared QuantizerBase logic 41 | return QuantizeMax.quantize(self, node, graph, scale_config, initializer_map) 42 | 43 | def check_supported( 44 | self, 45 | node: onnx.NodeProto, 46 | initializer_map: dict[str, onnx.TensorProto] | None = None, 47 | ) -> None: 48 | # If later we want to enforce/relax broadcasting, add it here. 49 | pass 50 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/test_validation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import onnx 6 | import pytest 7 | 8 | if TYPE_CHECKING: 9 | from python.tests.onnx_quantizer_tests.layers.base import ( 10 | LayerTestConfig, 11 | LayerTestSpec, 12 | ) 13 | from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory 14 | from python.tests.onnx_quantizer_tests.layers_tests.base_test import ( 15 | BaseQuantizerTest, 16 | ) 17 | 18 | 19 | class TestValidation(BaseQuantizerTest): 20 | """Ensure that layer factory models produce valid ONNX graphs.""" 21 | 22 | __test__ = True 23 | 24 | @pytest.mark.unit 25 | @pytest.mark.parametrize( 26 | "test_case_data", 27 | TestLayerFactory.get_all_test_cases(), 28 | ids=BaseQuantizerTest._generate_test_id, 29 | ) 30 | def test_factory_models_pass_onnx_validation( 31 | self: TestValidation, 32 | test_case_data: tuple[str, LayerTestConfig, LayerTestSpec], 33 | ) -> None: 34 | layer_name, config, test_spec = test_case_data 35 | test_case_id = f"{layer_name}_{test_spec.name}" 36 | 37 | if test_spec.skip_reason: 38 | pytest.skip(f"{test_case_id}: {test_spec.skip_reason}") 39 | 40 | model = config.create_test_model(test_spec) 41 | try: 42 | onnx.checker.check_model(model) 43 | except onnx.checker.ValidationError as e: 44 | self._validation_failed_cases.add(test_case_id) 45 | pytest.fail(f"Invalid ONNX model: {e}") 46 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/mul.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeMul(QuantizerBase): 16 | OP_TYPE = "Int64Mul" 17 | USE_WB = True 18 | USE_SCALING = True 19 | SCALE_PLAN: ClassVar = {0: 1, 1: 1} 20 | 21 | 22 | class MulQuantizer(BaseOpQuantizer, QuantizeMul): 23 | """ 24 | Quantizer for ONNX Mul layers. 25 | 26 | - Uses custom Mul layer to incorporate rescaling, and 27 | makes relevant additional changes to the graph. 
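    - Rationale: multiplying two fixed-point operands at scale s yields a
      product at scale s**2, so the result must be rescaled back to s
      (unlike Add/Sub, which preserve the operand scale).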
28 | """ 29 | 30 | def __init__( 31 | self: MulQuantizer, 32 | new_initializers: list[onnx.TensorProto] | None = None, 33 | ) -> None: 34 | super().__init__() 35 | # Only replace if caller provided something 36 | if new_initializers is not None: 37 | self.new_initializers = new_initializers 38 | 39 | def quantize( 40 | self: MulQuantizer, 41 | node: onnx.NodeProto, 42 | graph: onnx.GraphProto, 43 | scale_config: ScaleConfig, 44 | initializer_map: dict[str, onnx.TensorProto], 45 | ) -> list[onnx.NodeProto]: 46 | return QuantizeMul.quantize(self, node, graph, scale_config, initializer_map) 47 | 48 | def check_supported( 49 | self: MulQuantizer, 50 | node: onnx.NodeProto, 51 | initializer_map: dict[str, onnx.TensorProto] | None = None, 52 | ) -> None: 53 | pass 54 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/add.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeAdd(QuantizerBase): 16 | OP_TYPE = "Add" 17 | DOMAIN = "" 18 | USE_WB = True 19 | USE_SCALING = False 20 | SCALE_PLAN: ClassVar = {0: 1, 1: 1} 21 | 22 | 23 | class AddQuantizer(BaseOpQuantizer, QuantizeAdd): 24 | """ 25 | Quantizer for ONNX Add layers. 26 | 27 | - Uses standard ONNX Add layer in standard domain, and 28 | makes relevant additional changes to the graph. 29 | """ 30 | 31 | def __init__( 32 | self: AddQuantizer, 33 | new_initializers: list[onnx.TensorProto] | None = None, 34 | ) -> None: 35 | super().__init__() 36 | # Only replace if caller provided something 37 | if new_initializers is not None: 38 | self.new_initializers = new_initializers 39 | 40 | def quantize( 41 | self: AddQuantizer, 42 | node: onnx.NodeProto, 43 | graph: onnx.GraphProto, 44 | scale_config: ScaleConfig, 45 | initializer_map: dict[str, onnx.TensorProto], 46 | ) -> list[onnx.NodeProto]: 47 | return QuantizeAdd.quantize(self, node, graph, scale_config, initializer_map) 48 | 49 | def check_supported( 50 | self: AddQuantizer, 51 | node: onnx.NodeProto, 52 | initializer_map: dict[str, onnx.TensorProto] | None = None, 53 | ) -> None: 54 | pass 55 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/sub.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeSub(QuantizerBase): 16 | OP_TYPE = "Sub" 17 | DOMAIN = "" 18 | USE_WB = True 19 | USE_SCALING = False 20 | SCALE_PLAN: ClassVar = {0: 1, 1: 1} 21 | 22 | 23 | class SubQuantizer(BaseOpQuantizer, QuantizeSub): 24 | """ 25 | Quantizer for ONNX Sub layers. 26 | 27 | - Uses standard ONNX Sub layer in standard domain, and 28 | makes relevant additional changes to the graph. 
29 | """ 30 | 31 | def __init__( 32 | self: SubQuantizer, 33 | new_initializers: list[onnx.TensorProto] | None = None, 34 | ) -> None: 35 | super().__init__() 36 | # Only replace if caller provided something 37 | if new_initializers is not None: 38 | self.new_initializers = new_initializers 39 | 40 | def quantize( 41 | self: SubQuantizer, 42 | node: onnx.NodeProto, 43 | graph: onnx.GraphProto, 44 | scale_config: ScaleConfig, 45 | initializer_map: dict[str, onnx.TensorProto], 46 | ) -> list[onnx.NodeProto]: 47 | return QuantizeSub.quantize(self, node, graph, scale_config, initializer_map) 48 | 49 | def check_supported( 50 | self: SubQuantizer, 51 | node: onnx.NodeProto, 52 | initializer_map: dict[str, onnx.TensorProto] | None = None, 53 | ) -> None: 54 | pass 55 | -------------------------------------------------------------------------------- /docs/artifacts.md: -------------------------------------------------------------------------------- 1 | # Artifacts 2 | 3 | This page describes the files JSTprove reads/writes during the pipeline. 4 | 5 | --- 6 | 7 | ## Files you'll typically see 8 | 9 | - **Circuit** — `circuit.txt` 10 | Compiled Expander circuit description. 11 | _Produced by:_ `compile` 12 | 13 | - **Quantized model** — `quantized.onnx` 14 | ONNX model with integerized ops (used by witness/verify to hydrate shapes). 15 | _Produced by:_ `compile` 16 | 17 | - **Inputs** — your input JSON (you provide it) 18 | During witness/verify the CLI also creates a local `*_reshaped.json` (next to your CWD) after scaling/reshaping. 19 | _Consumed by:_ `witness`, `verify` 20 | 21 | - **Outputs** — `output.json` 22 | Model outputs (integer domain) computed from the quantized model. 23 | _Produced by:_ `witness` (and used by `verify`) 24 | 25 | - **Witness** — `witness.bin` 26 | Private inputs / auxiliary data for proving. 27 | _Produced by:_ `witness` (consumed by `prove`, `verify`) 28 | 29 | - **Proof** — `proof.bin` 30 | Zero-knowledge proof blob. 31 | _Produced by:_ `prove` (checked by `verify`) 32 | 33 | --- 34 | 35 | ## Typical layout 36 | 37 | You control all paths; the CLI **does not** infer directories. 38 | 39 | ``` 40 | 41 | artifacts/ 42 | lenet/ 43 | circuit.txt 44 | quantized.onnx 45 | output.json 46 | witness.bin 47 | proof.bin 48 | models/ 49 | inputs/ 50 | lenet_input.json 51 | 52 | ``` 53 | 54 | > Note: `*_reshaped.json` is generated in your current working directory during witness/verify. It’s a convenience file reflecting the scaled/reshaped inputs actually fed into the circuit. 55 | 56 | --- 57 | 58 | ## Tips 59 | 60 | - Keep artifacts from the **same compile** together (circuit + quantized ONNX) to avoid shape/version mismatches. 61 | - If you change the ONNX model, **re-run compile** before witness/prove/verify. 62 | - Store inputs/outputs under versioned folders if you need reproducibility. 
63 | -------------------------------------------------------------------------------- /python/frontend/commands/prove.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import argparse 7 | 8 | from python.core.circuits.errors import CircuitRunError 9 | from python.core.utils.helper_functions import CircuitExecutionConfig, RunType 10 | from python.frontend.commands.args import CIRCUIT_PATH, PROOF_PATH, WITNESS_PATH 11 | from python.frontend.commands.base import BaseCommand 12 | 13 | 14 | class ProveCommand(BaseCommand): 15 | """Generate proof from witness.""" 16 | 17 | name: ClassVar[str] = "prove" 18 | aliases: ClassVar[list[str]] = ["prov"] 19 | help: ClassVar[str] = "Generate a proof from a circuit and witness." 20 | 21 | @classmethod 22 | def configure_parser( 23 | cls: type[ProveCommand], 24 | parser: argparse.ArgumentParser, 25 | ) -> None: 26 | CIRCUIT_PATH.add_to_parser(parser) 27 | WITNESS_PATH.add_to_parser(parser, "Path to an existing witness.") 28 | PROOF_PATH.add_to_parser(parser) 29 | 30 | @classmethod 31 | @BaseCommand.validate_required(CIRCUIT_PATH, WITNESS_PATH, PROOF_PATH) 32 | @BaseCommand.validate_paths(CIRCUIT_PATH, WITNESS_PATH) 33 | @BaseCommand.validate_parent_paths(PROOF_PATH) 34 | def run(cls: type[ProveCommand], args: argparse.Namespace) -> None: 35 | circuit = cls._build_circuit("cli") 36 | 37 | try: 38 | circuit.base_testing( 39 | CircuitExecutionConfig( 40 | run_type=RunType.PROVE_WITNESS, 41 | circuit_path=args.circuit_path, 42 | witness_file=args.witness_path, 43 | proof_file=args.proof_path, 44 | ecc=False, 45 | ), 46 | ) 47 | except CircuitRunError as e: 48 | raise RuntimeError(e) from e 49 | 50 | print(f"[prove] wrote proof → {args.proof_path}") # noqa: T201 51 | -------------------------------------------------------------------------------- /python/frontend/commands/bench/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import argparse 7 | 8 | from python.frontend.commands.base import BaseCommand 9 | from python.frontend.commands.bench.list import ListCommand 10 | from python.frontend.commands.bench.model import ModelCommand 11 | from python.frontend.commands.bench.sweep import SweepCommand 12 | 13 | 14 | class BenchCommand(BaseCommand): 15 | """Benchmark JSTprove models with various configurations.""" 16 | 17 | name: ClassVar[str] = "bench" 18 | aliases: ClassVar[list[str]] = [] 19 | help: ClassVar[str] = "Benchmark JSTprove models with various configurations." 
20 | 21 | SUBCOMMANDS: ClassVar[list[type[BaseCommand]]] = [ 22 | ListCommand, 23 | ModelCommand, 24 | SweepCommand, 25 | ] 26 | 27 | @classmethod 28 | def configure_parser( 29 | cls: type[BenchCommand], 30 | parser: argparse.ArgumentParser, 31 | ) -> None: 32 | subparsers = parser.add_subparsers( 33 | dest="bench_subcommand", 34 | required=True, 35 | help="Benchmark subcommands", 36 | ) 37 | 38 | for subcommand_cls in cls.SUBCOMMANDS: 39 | subparser = subparsers.add_parser( 40 | subcommand_cls.name, 41 | help=subcommand_cls.help, 42 | aliases=subcommand_cls.aliases, 43 | ) 44 | subcommand_cls.configure_parser(subparser) 45 | 46 | @classmethod 47 | def run(cls: type[BenchCommand], args: argparse.Namespace) -> None: 48 | for subcommand_cls in cls.SUBCOMMANDS: 49 | if args.bench_subcommand in [subcommand_cls.name, *subcommand_cls.aliases]: 50 | subcommand_cls.run(args) 51 | return 52 | 53 | msg = f"Unknown bench subcommand: {args.bench_subcommand}" 54 | raise ValueError(msg) 55 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/min.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeMin(QuantizerBase): 16 | OP_TYPE = "Min" 17 | DOMAIN = "" # standard ONNX domain 18 | USE_WB = True # let framework wire inputs/outputs normally 19 | USE_SCALING = False # passthrough: no internal scaling 20 | SCALE_PLAN: ClassVar = {1: 1} # elementwise arity plan 21 | 22 | 23 | class MinQuantizer(BaseOpQuantizer, QuantizeMin): 24 | """ 25 | Passthrough quantizer for elementwise Min. 26 | We rely on the converter to quantize graph inputs; no extra scaling here. 27 | """ 28 | 29 | def __init__( 30 | self: MinQuantizer, 31 | new_initializers: list[onnx.TensorProto] | None = None, 32 | ) -> None: 33 | super().__init__() 34 | if new_initializers is not None: 35 | self.new_initializers = new_initializers 36 | 37 | def quantize( 38 | self: MinQuantizer, 39 | node: onnx.NodeProto, 40 | graph: onnx.GraphProto, 41 | scale_config: ScaleConfig, 42 | initializer_map: dict[str, onnx.TensorProto], 43 | ) -> list[onnx.NodeProto]: 44 | # Delegate to QuantizerBase's generic passthrough implementation. 45 | return QuantizeMin.quantize(self, node, graph, scale_config, initializer_map) 46 | 47 | def check_supported( 48 | self: MinQuantizer, 49 | node: onnx.NodeProto, 50 | initializer_map: dict[str, onnx.TensorProto] | None = None, 51 | ) -> None: 52 | # Min has no attributes; elementwise, variadic ≥ 1 input per ONNX spec. 53 | # We mirror Add/Max broadcasting behavior; no extra checks here. 
54 | _ = node, initializer_map 55 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/relu.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | if TYPE_CHECKING: 6 | from onnx import GraphProto, NodeProto, TensorProto 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeRelu(QuantizerBase): 16 | OP_TYPE = "Int64Relu" 17 | USE_WB = False 18 | USE_SCALING = False 19 | 20 | 21 | class ReluQuantizer(BaseOpQuantizer, QuantizeRelu): 22 | """ 23 | Quantizer for ONNX ReLU layers. 24 | 25 | - Replaces standard ReLU with Int64ReLU from the `ai.onnx.contrib` domain 26 | and makes relevant additional changes to the graph. 27 | - Validates that all required ReLU parameters are present. 28 | """ 29 | 30 | def __init__( 31 | self: ReluQuantizer, 32 | new_initializer: list[TensorProto] | None = None, 33 | ) -> None: 34 | super().__init__() 35 | _ = new_initializer 36 | 37 | def quantize( 38 | self: ReluQuantizer, 39 | node: NodeProto, 40 | graph: GraphProto, 41 | scale_config: ScaleConfig, 42 | initializer_map: dict[str, TensorProto], 43 | ) -> list[NodeProto]: 44 | return QuantizeRelu.quantize(self, node, graph, scale_config, initializer_map) 45 | 46 | def check_supported( 47 | self: ReluQuantizer, 48 | node: NodeProto, 49 | initializer_map: dict[str, TensorProto] | None = None, 50 | ) -> None: 51 | """ 52 | Perform high-level validation to ensure that this node 53 | can be quantized safely. 54 | 55 | Args: 56 | node (onnx.NodeProto): ONNX node to be checked 57 | initializer_map (dict[str, onnx.TensorProto]): 58 | Initializer map (name of weight or bias and tensor) 59 | """ 60 | _ = node 61 | _ = initializer_map 62 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/batchnorm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | from onnxruntime_extensions import PyCustomOpDef, onnx_op 5 | 6 | from .custom_helpers import rescaling 7 | 8 | 9 | @onnx_op( 10 | op_type="Int64BatchNorm", 11 | domain="ai.onnx.contrib", 12 | inputs=[ 13 | PyCustomOpDef.dt_int64, # X (int64) 14 | PyCustomOpDef.dt_int64, # mul (int64 scaled multiplier) 15 | PyCustomOpDef.dt_int64, # add (int64 scaled adder) 16 | PyCustomOpDef.dt_int64, # scaling_factor 17 | ], 18 | outputs=[PyCustomOpDef.dt_int64], 19 | attrs={"rescale": PyCustomOpDef.dt_int64}, 20 | ) 21 | def int64_batchnorm( 22 | x: np.ndarray, 23 | mul: np.ndarray, 24 | add: np.ndarray, 25 | scaling_factor: np.ndarray | None = None, 26 | rescale: int | None = None, 27 | ) -> np.ndarray: 28 | """ 29 | Int64 BatchNorm (folded into affine transform). 30 | 31 | Computes: 32 | Y = X * mul + add 33 | where mul/add are already scaled to int64. 
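    (In the usual batch-norm folding, ``mul`` corresponds to
    gamma / sqrt(var + eps) and ``add`` to beta - mean * mul,
    both pre-quantized to the integer domain.)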
34 | 35 | Parameters 36 | ---------- 37 | x : Input int64 tensor 38 | mul : Per-channel int64 scale multipliers 39 | add : Per-channel int64 bias terms 40 | scaling_factor: factor to rescale 41 | rescale : Optional flag to apply post-scaling 42 | 43 | Returns 44 | ------- 45 | numpy.ndarray (int64) 46 | """ 47 | try: 48 | # Broadcasting shapes must match batchnorm layout: NCHW 49 | # Typically mul/add have shape [C] 50 | dims_x = len(x.shape) 51 | dim_ones = (1,) * (dims_x - 2) 52 | mul = mul.reshape(-1, *dim_ones) 53 | add = add.reshape(-1, *dim_ones) 54 | 55 | y = x * mul + add 56 | 57 | if rescale is not None: 58 | y = rescaling(scaling_factor, rescale, y) 59 | 60 | return y.astype(np.int64) 61 | 62 | except Exception as e: 63 | msg = f"Int64BatchNorm failed: {e}" 64 | raise RuntimeError(msg) from e 65 | -------------------------------------------------------------------------------- /python/frontend/commands/model_check.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import argparse 7 | 8 | from python.frontend.commands.args import MODEL_PATH 9 | from python.frontend.commands.base import BaseCommand 10 | 11 | 12 | class ModelCheckCommand(BaseCommand): 13 | """Check if a model is supported for quantization.""" 14 | 15 | name: ClassVar[str] = "model_check" 16 | aliases: ClassVar[list[str]] = ["check"] 17 | help: ClassVar[str] = "Check if the model is supported for quantization." 18 | 19 | @classmethod 20 | def configure_parser( 21 | cls: type[ModelCheckCommand], 22 | parser: argparse.ArgumentParser, 23 | ) -> None: 24 | MODEL_PATH.add_to_parser(parser) 25 | 26 | @classmethod 27 | @BaseCommand.validate_required(MODEL_PATH) 28 | @BaseCommand.validate_paths(MODEL_PATH) 29 | def run(cls: type[ModelCheckCommand], args: argparse.Namespace) -> None: 30 | import onnx # noqa: PLC0415 31 | 32 | from python.core.model_processing.onnx_quantizer.exceptions import ( # noqa: PLC0415 33 | InvalidParamError, 34 | UnsupportedOpError, 35 | ) 36 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( # noqa: PLC0415 37 | ONNXOpQuantizer, 38 | ) 39 | 40 | model = onnx.load(args.model_path) 41 | quantizer = ONNXOpQuantizer() 42 | try: 43 | quantizer.check_model(model) 44 | print(f"Model {args.model_path} is supported.") # noqa: T201 45 | except UnsupportedOpError as e: 46 | msg = ( 47 | f"Model {args.model_path} is NOT supported: " 48 | f"Unsupported operations {e.unsupported_ops}" 49 | ) 50 | raise RuntimeError(msg) from e 51 | except InvalidParamError as e: 52 | msg = f"Model {args.model_path} is NOT supported: {e.message}" 53 | raise RuntimeError(msg) from e 54 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/custom_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TypeVar 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def rescaling(scaling_factor: int, rescale: int, y: int) -> int: 9 | """Applies integer rescaling to a value based on the given scaling factor. 10 | 11 | Args: 12 | scaling_factor (int): The divisor to apply when rescaling. 13 | Must be provided if `rescale` is True. 14 | rescale (int): Whether to apply rescaling. (0 -> no rescaling, 1 -> rescaling). 15 | Y (int): The value to be rescaled. 
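
    Example:
        >>> rescaling(scaling_factor=4, rescale=1, y=100)  # 100 // 4
        25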
16 |
17 |     Raises:
18 |         ValueError: If `rescale` is 1 but `scaling_factor` is not provided.
19 |         ValueError: If `rescale` is not 0 or 1.
20 |
21 |     Returns:
22 |         int: The rescaled value if `rescale` is 1, otherwise the original value.
23 |     """
24 |     if rescale == 1:
25 |         if scaling_factor is None:
26 |             msg = "scaling_factor must be specified when rescale=1"
27 |             raise ValueError(msg)
28 |         return y // scaling_factor
29 |     if rescale == 0:
30 |         return y
31 |     msg = f"Rescale must be 0 or 1, got {rescale}"
32 |     raise ValueError(msg)
33 |
34 |
35 | def parse_attr(attr: str, default: T) -> T:
36 |     """Parses a comma-separated attribute string into a list of integers.
37 |
38 |     Args:
39 |         attr (str): Attribute to parse. If a string, it must be
40 |             comma-separated integers (e.g., "1, 2, 3").
41 |             If None, returns `default`.
42 |         default (T): Default value to return if `attr` is None.
43 |
44 |     Raises:
45 |         ValueError: If `attr` is a string but cannot be parsed into integers.
46 |
47 |     Returns:
48 |         T: Parsed list of integers if attr is provided, otherwise the default value.
49 |     """
50 |     if attr is None:
51 |         return default
52 |     try:
53 |         return [int(x.strip()) for x in attr.split(",")]
54 |     except ValueError as e:
55 |         msg = f"Invalid attribute format: {attr}"
56 |         raise ValueError(msg) from e
57 |
--------------------------------------------------------------------------------
/python/frontend/commands/compile.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from typing import TYPE_CHECKING, ClassVar
5 |
6 | if TYPE_CHECKING:
7 |     import argparse
8 |
9 | from python.core.circuits.errors import CircuitRunError
10 | from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
11 | from python.frontend.commands.args import CIRCUIT_PATH, MODEL_PATH
12 | from python.frontend.commands.base import BaseCommand
13 |
14 |
15 | class CompileCommand(BaseCommand):
16 |     """Compile an ONNX model to a circuit."""
17 |
18 |     name: ClassVar[str] = "compile"
19 |     aliases: ClassVar[list[str]] = ["comp"]
20 |     help: ClassVar[str] = (
21 |         "Compile a circuit (writes circuit + quantized model + weights)."
22 | ) 23 | 24 | @classmethod 25 | def configure_parser( 26 | cls: type[CompileCommand], 27 | parser: argparse.ArgumentParser, 28 | ) -> None: 29 | MODEL_PATH.add_to_parser(parser) 30 | CIRCUIT_PATH.add_to_parser( 31 | parser, 32 | "Output path for the compiled circuit (e.g., circuit.txt).", 33 | ) 34 | 35 | @classmethod 36 | @BaseCommand.validate_required(MODEL_PATH, CIRCUIT_PATH) 37 | @BaseCommand.validate_paths(MODEL_PATH) 38 | @BaseCommand.validate_parent_paths(CIRCUIT_PATH) 39 | def run(cls: type[CompileCommand], args: argparse.Namespace) -> None: 40 | model_name_hint = Path(args.model_path).stem 41 | circuit = cls._build_circuit(model_name_hint) 42 | 43 | circuit.model_file_name = args.model_path 44 | circuit.onnx_path = args.model_path 45 | circuit.model_path = args.model_path 46 | 47 | try: 48 | circuit.base_testing( 49 | CircuitExecutionConfig( 50 | run_type=RunType.COMPILE_CIRCUIT, 51 | circuit_path=args.circuit_path, 52 | dev_mode=False, 53 | ), 54 | ) 55 | except CircuitRunError as e: 56 | raise RuntimeError(e) from e 57 | 58 | print(f"[compile] done → circuit={args.circuit_path}") # noqa: T201 59 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import pytest 6 | 7 | if TYPE_CHECKING: 8 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( 9 | ONNXOpQuantizer, 10 | ) 11 | from python.tests.onnx_quantizer_tests.layers.base import ( 12 | LayerTestConfig, 13 | LayerTestSpec, 14 | SpecType, 15 | ) 16 | from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory 17 | from python.tests.onnx_quantizer_tests.layers_tests.base_test import ( 18 | BaseQuantizerTest, 19 | ) 20 | 21 | 22 | class TestErrorCases(BaseQuantizerTest): 23 | """Tests for ONNX model checking.""" 24 | 25 | __test__ = True 26 | 27 | @pytest.mark.unit 28 | @pytest.mark.parametrize( 29 | "test_case_data", 30 | TestLayerFactory.get_test_cases_by_type(SpecType.ERROR), # type: ignore[arg-type] 31 | ids=BaseQuantizerTest._generate_test_id, 32 | ) 33 | def test_check_model_individual_error_cases( 34 | self: TestErrorCases, 35 | quantizer: ONNXOpQuantizer, 36 | test_case_data: tuple[str, LayerTestConfig, LayerTestSpec], 37 | ) -> None: 38 | """Test each individual error test case""" 39 | layer_name, config, test_spec = test_case_data 40 | 41 | # Skips if layer is not a valid onnx layer 42 | self._check_validation_dependency(test_case_data) 43 | 44 | if test_spec.skip_reason: 45 | pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}") 46 | 47 | # Create model from layer specs 48 | model = config.create_test_model(test_spec) 49 | 50 | # Ensures that expected test is in fact raised 51 | with pytest.raises(test_spec.expected_error) as exc: 52 | quantizer.check_model(model) 53 | 54 | # Ensures the error message is as expected 55 | if isinstance(test_spec.error_match, list): 56 | for e in test_spec.error_match: 57 | assert e in str(exc.value) 58 | else: 59 | assert test_spec.error_match in str(exc.value) 60 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/mul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from onnxruntime_extensions import PyCustomOpDef, onnx_op 3 
|
4 | from .custom_helpers import rescaling
5 |
6 |
7 | @onnx_op(
8 |     op_type="Int64Mul",
9 |     domain="ai.onnx.contrib",
10 |     inputs=[
11 |         PyCustomOpDef.dt_int64,
12 |         PyCustomOpDef.dt_int64,
13 |         PyCustomOpDef.dt_int64,  # Scalar
14 |     ],
15 |     outputs=[PyCustomOpDef.dt_int64],
16 |     attrs={
17 |         "rescale": PyCustomOpDef.dt_int64,
18 |     },
19 | )
20 | def int64_mul(
21 |     a: np.ndarray,
22 |     b: np.ndarray,
23 |     scaling_factor: np.ndarray | None = None,
24 |     rescale: int | None = None,
25 | ) -> np.ndarray:
26 |     """
27 |     Performs a Mul (Hadamard product) operation on int64 input tensors.
28 |
29 |     This function is registered as a custom ONNX operator via onnxruntime_extensions
30 |     and is used in the JSTprove quantized inference pipeline.
31 |     It applies Mul and rescales the outputs back to the original scale.
32 |
33 |     Parameters
34 |     ----------
35 |     a : np.ndarray
36 |         First input tensor with dtype int64.
37 |     b : np.ndarray
38 |         Second input tensor with dtype int64.
39 |     scaling_factor : Scaling factor for rescaling the output.
40 |         Optional scalar tensor for rescaling when rescale=1.
41 |     rescale : int, optional
42 |         Whether to apply rescaling (0=no, 1=yes).
43 |
44 |     Returns
45 |     -------
46 |     numpy.ndarray
47 |         Mul tensor with dtype int64.
48 |
49 |     Notes
50 |     -----
51 |     - This op is part of the `ai.onnx.contrib` custom domain.
52 |     - ONNX Runtime Extensions is required to register this op.
53 |
54 |     References
55 |     ----------
56 |     For more information on the Mul operation, please refer to the
57 |     ONNX standard Mul operator documentation:
58 |     https://onnx.ai/onnx/operators/onnx__Mul.html
59 |     """
60 |     try:
61 |         result = a * b
62 |         result = rescaling(scaling_factor, rescale, result)
63 |         return result.astype(np.int64)
64 |     except Exception as e:
65 |         msg = f"Int64Mul failed: {e}"
66 |         raise RuntimeError(msg) from e
67 |
--------------------------------------------------------------------------------
/python/frontend/commands/witness.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, ClassVar
4 |
5 | if TYPE_CHECKING:
6 |     import argparse
7 |
8 | from python.core.circuits.errors import CircuitRunError
9 | from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
10 | from python.frontend.commands.args import (
11 |     CIRCUIT_PATH,
12 |     INPUT_PATH,
13 |     OUTPUT_PATH,
14 |     WITNESS_PATH,
15 | )
16 | from python.frontend.commands.base import BaseCommand
17 |
18 |
19 | class WitnessCommand(BaseCommand):
20 |     """Generate witness from circuit and inputs."""
21 |
22 |     name: ClassVar[str] = "witness"
23 |     aliases: ClassVar[list[str]] = ["wit"]
24 |     help: ClassVar[str] = "Generate witness using a compiled circuit."
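    # Note: the circuit passed here must come from the same compile run as the
    # quantized model; recompile whenever the ONNX model changes (docs/models.md).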
25 |
26 |     @classmethod
27 |     def configure_parser(
28 |         cls: type[WitnessCommand],
29 |         parser: argparse.ArgumentParser,
30 |     ) -> None:
31 |         CIRCUIT_PATH.add_to_parser(parser)
32 |         INPUT_PATH.add_to_parser(parser)
33 |         OUTPUT_PATH.add_to_parser(parser)
34 |         WITNESS_PATH.add_to_parser(parser)
35 |
36 |     @classmethod
37 |     @BaseCommand.validate_required(
38 |         CIRCUIT_PATH,
39 |         INPUT_PATH,
40 |         OUTPUT_PATH,
41 |         WITNESS_PATH,
42 |     )
43 |     @BaseCommand.validate_paths(CIRCUIT_PATH, INPUT_PATH)
44 |     @BaseCommand.validate_parent_paths(OUTPUT_PATH, WITNESS_PATH)
45 |     def run(cls: type[WitnessCommand], args: argparse.Namespace) -> None:
46 |         circuit = cls._build_circuit("cli")
47 |
48 |         try:
49 |             circuit.base_testing(
50 |                 CircuitExecutionConfig(
51 |                     run_type=RunType.GEN_WITNESS,
52 |                     circuit_path=args.circuit_path,
53 |                     input_file=args.input_path,
54 |                     output_file=args.output_path,
55 |                     witness_file=args.witness_path,
56 |                 ),
57 |             )
58 |         except CircuitRunError as e:
59 |             raise RuntimeError(e) from e
60 |
61 |         print(  # noqa: T201
62 |             f"[witness] wrote witness → {args.witness_path} "
63 |             f"and outputs → {args.output_path}",
64 |         )
65 |
--------------------------------------------------------------------------------
/docs/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to JSTprove
2 |
3 | Thank you for your interest in contributing! Please follow these steps to ensure your contributions are smooth and
4 | consistent with project standards.
5 |
6 | ---
7 |
8 | ## **1. Set up your development environment**
9 |
10 | 1. **Clone the repository**:
11 |
12 | ```bash
13 | git clone https://github.com/inference-labs-inc/JSTprove.git
14 | cd JSTprove
15 | ```
16 |
17 | 2. **Install dependencies with UV**:
18 |
19 | ```bash
20 | uv sync --dev
21 | uv pip install -e .
22 | ```
23 |
24 | ---
25 |
26 | ## **2. Install Git hooks**
27 |
28 | ```bash
29 | uv run pre-commit install --hook-type pre-commit --hook-type pre-push
30 | ```
31 |
32 | This ensures that every commit automatically runs the pre-commit hooks locally, including:
33 |
34 | - Rust formatting (`cargo fmt`)
35 | - Trailing newline enforcement for `.rs` and `.py` files
36 |
37 | ---
38 |
39 | ## **3. Running pre-commit manually**
40 |
41 | You can check all files in the repository at any time:
42 |
43 | ```bash
44 | uv run pre-commit run --all-files
45 | ```
46 |
47 | This is useful before pushing changes to catch any formatting issues early.
48 |
49 | ---
50 |
51 | ## **4. Committing changes**
52 |
53 | 1. Stage your changes:
54 |
55 | ```bash
56 | git add <files>
57 | ```
58 |
59 | 2. Commit:
60 |
61 | ```bash
62 | git commit -m "Your commit message"
63 | ```
64 |
65 | The pre-commit hooks will automatically run. If any errors are detected, the commit will be blocked. Fix the issues,
66 | stage the files again, and commit.
67 |
68 | ---
69 |
70 | ## **5. Pull requests**
71 |
72 | - Always make sure your branch is up-to-date with the main branch.
73 | - Ensure all pre-commit hooks pass locally before opening a PR.
74 | - Run the full test suite locally and update/add tests as needed to cover your changes.
75 | - Update documentation (README, code comments, API docs, etc.) if your changes affect usage or behavior.
76 | - The CI pipeline will also run formatting checks and tests. Any failures must be resolved before merging.
77 |
78 | ---
79 |
80 | ## **6. Formatting & newline policy**
81 |
82 | - **Rust files**: All `.rs` files should be formatted using `cargo fmt`.
83 | - **Python files**: All `.py` files must have a trailing newline at EOF.
84 | - Pre-commit hooks enforce this automatically locally and in CI.
85 |
--------------------------------------------------------------------------------
/rust/jstprove_circuits/src/circuit_functions/layers/constant.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use expander_compiler::frontend::{Config, RootAPI, Variable};
4 | use ndarray::{ArrayD, IxDyn};
5 | use serde_json::Value;
6 |
7 | use crate::circuit_functions::{
8 |     CircuitError,
9 |     layers::{LayerError, LayerKind, layer_ops::LayerOp},
10 |     utils::{constants::VALUE, onnx_model::get_param},
11 | };
12 |
13 | // -------- Struct --------
14 | #[allow(dead_code)]
15 | #[derive(Debug)]
16 | pub struct ConstantLayer {
17 |     name: String,
18 |     value: Value,
19 |     outputs: Vec<String>,
20 | }
21 |
22 | // -------- Implementations --------
23 |
24 | // TODO remove constants from python side. Incorporate into the layer that uses it instead
25 | impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConstantLayer {
26 |     // Passthrough
27 |     fn apply(
28 |         &self,
29 |         api: &mut Builder,
30 |         _input: HashMap<String, ArrayD<Variable>>,
31 |     ) -> Result<(Vec<String>, ArrayD<Variable>), CircuitError> {
32 |         let arr = ArrayD::from_shape_vec(IxDyn(&[1]), vec![api.constant(0)]).map_err(|e| {
33 |             LayerError::InvalidShape {
34 |                 layer: LayerKind::Constant,
35 |                 msg: e.to_string(),
36 |             }
37 |         })?;
38 |
39 |         Ok((self.outputs.clone(), arr))
40 |     }
41 |
42 |     fn build(
43 |         layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer,
44 |         _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams,
45 |         _optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry,
46 |         _is_rescale: bool,
47 |         _index: usize,
48 |         _layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext,
49 |     ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError> {
50 |         let params = layer
51 |             .params
52 |             .clone()
53 |             .ok_or_else(|| LayerError::MissingParameter {
54 |                 layer: LayerKind::Constant,
55 |                 param: "params".into(),
56 |             })?;
57 |         let constant = Self {
58 |             name: layer.name.clone(),
59 |             value: get_param(&layer.name, VALUE, &params)?,
60 |             outputs: layer.outputs.clone(),
61 |         };
62 |
63 |         Ok(Box::new(constant))
64 |     }
65 | }
66 |
--------------------------------------------------------------------------------
/python/frontend/commands/verify.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, ClassVar
4 |
5 | if TYPE_CHECKING:
6 |     import argparse
7 |
8 | from python.core.circuits.errors import CircuitRunError
9 | from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
10 | from python.frontend.commands.args import (
11 |     CIRCUIT_PATH,
12 |     INPUT_PATH,
13 |     OUTPUT_PATH,
14 |     PROOF_PATH,
15 |     WITNESS_PATH,
16 | )
17 | from python.frontend.commands.base import BaseCommand
18 |
19 |
20 | class VerifyCommand(BaseCommand):
21 |     """Verify a proof."""
22 |
23 |     name: ClassVar[str] = "verify"
24 |     aliases: ClassVar[list[str]] = ["ver"]
25 |     help: ClassVar[str] = "Verify a proof."
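    # Note: verification replays the public inputs/outputs, so the original
    # input and output JSON are required alongside the witness and proof.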
26 | 27 | @classmethod 28 | def configure_parser( 29 | cls: type[VerifyCommand], 30 | parser: argparse.ArgumentParser, 31 | ) -> None: 32 | CIRCUIT_PATH.add_to_parser(parser) 33 | INPUT_PATH.add_to_parser(parser) 34 | OUTPUT_PATH.add_to_parser(parser, "Path to expected outputs JSON.") 35 | WITNESS_PATH.add_to_parser(parser) 36 | PROOF_PATH.add_to_parser(parser) 37 | 38 | @classmethod 39 | @BaseCommand.validate_required( 40 | CIRCUIT_PATH, 41 | INPUT_PATH, 42 | OUTPUT_PATH, 43 | WITNESS_PATH, 44 | PROOF_PATH, 45 | ) 46 | @BaseCommand.validate_paths( 47 | CIRCUIT_PATH, 48 | INPUT_PATH, 49 | OUTPUT_PATH, 50 | WITNESS_PATH, 51 | PROOF_PATH, 52 | ) 53 | def run(cls: type[VerifyCommand], args: argparse.Namespace) -> None: 54 | circuit = cls._build_circuit("cli") 55 | 56 | try: 57 | circuit.base_testing( 58 | CircuitExecutionConfig( 59 | run_type=RunType.GEN_VERIFY, 60 | circuit_path=args.circuit_path, 61 | input_file=args.input_path, 62 | output_file=args.output_path, 63 | witness_file=args.witness_path, 64 | proof_file=args.proof_path, 65 | ecc=False, 66 | ), 67 | ) 68 | except CircuitRunError as e: 69 | raise RuntimeError(e) from e 70 | 71 | print( # noqa: T201 72 | f"[verify] verification complete for proof → {args.proof_path}", 73 | ) 74 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/flatten_config.py: -------------------------------------------------------------------------------- 1 | from python.tests.onnx_quantizer_tests.layers.base import ( 2 | e2e_test, 3 | valid_test, 4 | ) 5 | from python.tests.onnx_quantizer_tests.layers.factory import ( 6 | BaseLayerConfigProvider, 7 | LayerTestConfig, 8 | ) 9 | 10 | 11 | class FlattenConfigProvider(BaseLayerConfigProvider): 12 | """Test configuration provider for Flatten layers""" 13 | 14 | @property 15 | def layer_name(self) -> str: 16 | return "Flatten" 17 | 18 | def get_config(self) -> LayerTestConfig: 19 | return LayerTestConfig( 20 | op_type="Flatten", 21 | valid_inputs=["input"], 22 | valid_attributes={"axis": 1}, 23 | required_initializers={}, 24 | ) 25 | 26 | def get_test_specs(self) -> list: 27 | return [ 28 | # --- VALID TESTS --- 29 | valid_test("basic") 30 | .description("Basic Flatten from (1,3,4,4) to (1,48)") 31 | .tags("basic", "flatten") 32 | .build(), 33 | valid_test("flatten_axis0") 34 | .description("Flatten with axis=0 (entire tensor flattened)") 35 | .override_attrs(axis=0) 36 | .tags("flatten", "axis0") 37 | .build(), 38 | valid_test("flatten_axis2") 39 | .description("Flatten starting at axis=2") 40 | .override_attrs(axis=2) 41 | .tags("flatten", "axis2") 42 | .build(), 43 | valid_test("flatten_axis3") 44 | .description("Flatten starting at axis=3 (minimal flatten)") 45 | .override_attrs(axis=3) 46 | .tags("flatten", "axis3") 47 | .build(), 48 | e2e_test("e2e_basic") 49 | .description("End-to-end test for Flatten layer") 50 | .override_input_shapes(input=[1, 3, 4, 4]) 51 | .override_output_shapes(flatten_output=[1, 48]) 52 | .tags("e2e", "flatten") 53 | .build(), 54 | # --- EDGE CASE / SKIPPED TEST --- 55 | valid_test("large_input") 56 | .description("Large input flatten (performance test)") 57 | .override_input_shapes(input=[1, 3, 256, 256]) 58 | .tags("flatten", "large", "performance") 59 | .skip("Performance test, skipped by default") 60 | .build(), 61 | ] 62 | -------------------------------------------------------------------------------- /python/core/model_templates/circuit_template.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from secrets import randbelow 4 | 5 | from python.core.circuits.base import Circuit 6 | 7 | 8 | class SimpleCircuit(Circuit): 9 | """ 10 | Note: This template is irrelevant if using the ONNX circuit builder. 11 | The template only helps developers if they choose to incorporate other circuit 12 | builders into the framework. 13 | 14 | To begin, we need to specify some basic attributes surrounding the circuit we will 15 | be using. 16 | 17 | - `required_keys`: the variables in the input dictionary (and input file). 18 | - `name`: name of the Rust bin to be run by the circuit. 19 | - `scale_base`: base of the scaling applied to each value. 20 | - `scale_exponent`: exponent applied to the base to get the scaling factor. 21 | Scaling factor will be multiplied by each input. 22 | 23 | Other default inputs can be defined below. 24 | """ 25 | 26 | def __init__(self, file_name: str | None = None) -> None: 27 | # Initialize the base class 28 | super().__init__() 29 | self.file_name = file_name 30 | 31 | # Circuit-specific parameters 32 | self.required_keys = ["input_a", "input_b", "nonce"] 33 | self.name = "simple_circuit" # Use exact name that matches the binary 34 | 35 | self.scale_exponent = 1 36 | self.scale_base = 1 37 | 38 | self.input_a = 100 39 | self.input_b = 200 40 | self.nonce = randbelow(10_000) 41 | 42 | def get_inputs(self) -> dict[str, int]: 43 | """ 44 | Specify the inputs to the circuit, based on what was specified 45 | in `__init__`. 46 | """ 47 | return { 48 | "input_a": self.input_a, 49 | "input_b": self.input_b, 50 | "nonce": self.nonce, 51 | } 52 | 53 | def get_outputs(self, inputs: dict[str, int] | None = None) -> int: 54 | """ 55 | Compute the output of the circuit. 56 | 57 | This is overwritten from the base class to ensure computation happens 58 | only once. 59 | """ 60 | if inputs is None: 61 | inputs = { 62 | "input_a": self.input_a, 63 | "input_b": self.input_b, 64 | "nonce": self.nonce, 65 | } 66 | 67 | return inputs["input_a"] + inputs["input_b"] 68 | -------------------------------------------------------------------------------- /docs/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | This page explains what kinds of models JSTprove supports and how they're handled internally. 4 | 5 | --- 6 | 7 | ## Supported operators (current) 8 | 9 | - **Linear:** Fully Connected / **GEMM**, **MatMul**, **Add**, **Sub**, **Mul** 10 | - **Convolution:** **Conv2D** 11 | - **Activation:** **ReLU** 12 | - **Pooling:** **MaxPool2D** 13 | - **Shaping / graph ops:** **Flatten**, **Reshape**, **Constant** 14 | - **Normalization:** **BatchNorm** 15 | 16 | --- 17 | 18 | ## ONNX expectations 19 | 20 | - Export models with ops limited to **Conv2D**, **GEMM/MatMul**, **MaxPool2D**, **ReLU**, **Add**, **Sub**, **Mul**, **BatchNorm**. 21 | 22 | --- 23 | 24 | ## Quantization 25 | 26 | - Quantization is **automatic** in the pipeline during **compile**. 27 | - Internally, inputs and weights are scaled to integers, and tensors are reshaped to the expected shapes before witness generation. 28 | - The CLI's **witness** and **verify** stages take care of **rescale + reshape** via circuit helpers. 29 | 30 | --- 31 | 32 | ## Input / Output JSON 33 | 34 | - **Input JSON** should contain your model inputs as numeric arrays. 
35 | - If values are floats, they'll be **scaled and rounded** automatically during witness/verify. 36 | - If your key is named exactly `"input"` (single-input models), it will be reshaped to the model's input shape. 37 | - Multi-input models are now supported. 38 | - Make sure to match the name of the inputs to the model, to the inputs that the model expects to receive. 39 | 40 | **Single-input example (flattened vector):** 41 | ```json 42 | { 43 | "input": [0, 1, 2, 3, 4, 5] 44 | } 45 | ```` 46 | 47 | **Single-input example (already shaped, e.g., 1×1×28×28):** 48 | 49 | ```json 50 | { 51 | "input": [[[ 52 | [0, 1, 2, "... 28 values ..."], 53 | "... 28 rows ..." 54 | ]]] 55 | } 56 | ``` 57 | 58 | * **Output JSON** produced by the pipeline is written under the key `"output"`, e.g.: 59 | 60 | ```json 61 | { 62 | "output": [0, 0, 1, 0, 0, 0, 0] 63 | } 64 | ``` 65 | 66 | --- 67 | 68 | ## Best practices 69 | 70 | * Use **one** ONNX model per compile. If you change the model, **re-run compile** to refresh the circuit and quantization. 71 | * Keep a consistent set of artifacts: `circuit.txt`, `quantized.onnx`, `input.json`, `output.json`, `witness.bin`, `proof.bin`. 72 | * For large CNNs, start with a small batch size and small inputs to validate the pipeline before scaling up. 73 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/relu_config.py: -------------------------------------------------------------------------------- 1 | from python.tests.onnx_quantizer_tests.layers.base import ( 2 | e2e_test, 3 | valid_test, 4 | ) 5 | from python.tests.onnx_quantizer_tests.layers.factory import ( 6 | BaseLayerConfigProvider, 7 | LayerTestConfig, 8 | ) 9 | 10 | 11 | class ReluConfigProvider(BaseLayerConfigProvider): 12 | """Test configuration provider for Relu layers""" 13 | 14 | @property 15 | def layer_name(self) -> str: 16 | return "Relu" 17 | 18 | def get_config(self) -> LayerTestConfig: 19 | return LayerTestConfig( 20 | op_type="Relu", 21 | valid_inputs=["input"], 22 | valid_attributes={}, 23 | required_initializers={}, 24 | ) 25 | 26 | def get_test_specs(self) -> list: 27 | return [ 28 | # --- VALID TESTS --- 29 | valid_test("basic") 30 | .description("Basic ReLU activation") 31 | .tags("basic", "activation") 32 | .build(), 33 | valid_test("negative_inputs") 34 | .description("ReLU should zero out negative input values") 35 | .override_input_shapes(input=[1, 3, 4, 4]) 36 | .tags("activation", "negative_values") 37 | .build(), 38 | valid_test("high_dimension_input") 39 | .description("ReLU applied to a 5D input tensor (NCHWT layout)") 40 | .override_input_shapes(input=[1, 3, 4, 4, 2]) 41 | .tags("activation", "high_dim", "5d") 42 | .build(), 43 | valid_test("scalar_input") 44 | .description("ReLU with scalar input (edge case)") 45 | .override_input_shapes(input=[1]) 46 | .tags("activation", "scalar") 47 | .build(), 48 | e2e_test("e2e_basic") 49 | .description("End-to-end test for ReLU activation") 50 | .override_input_shapes(input=[1, 3, 4, 4]) 51 | .override_output_shapes(relu_output=[1, 3, 4, 4]) 52 | .tags("e2e", "activation") 53 | .build(), 54 | # --- EDGE CASE / SKIPPED TEST --- 55 | valid_test("large_input") 56 | .description("Large input tensor for ReLU (performance/stress test)") 57 | .override_input_shapes(input=[1, 3, 512, 512]) 58 | .tags("large", "performance", "activation") 59 | .skip("Performance test, skipped by default") 60 | .build(), 61 | ] 62 | 
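
# Note: providers such as this one are collected by TestLayerFactory (see
# layers/factory.py) and parametrized across the suites under layers_tests/.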
-------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/reshape_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from python.tests.onnx_quantizer_tests.layers.base import ( 4 | e2e_test, 5 | valid_test, 6 | ) 7 | from python.tests.onnx_quantizer_tests.layers.factory import ( 8 | BaseLayerConfigProvider, 9 | LayerTestConfig, 10 | ) 11 | 12 | 13 | class ReshapeConfigProvider(BaseLayerConfigProvider): 14 | """Test configuration provider for Reshape layers""" 15 | 16 | @property 17 | def layer_name(self) -> str: 18 | return "Reshape" 19 | 20 | def get_config(self) -> LayerTestConfig: 21 | return LayerTestConfig( 22 | op_type="Reshape", 23 | valid_inputs=["input", "shape"], 24 | valid_attributes={}, 25 | required_initializers={"shape": np.array([1, -1])}, 26 | ) 27 | 28 | def get_test_specs(self) -> list: 29 | return [ 30 | # --- VALID TESTS --- 31 | valid_test("basic") 32 | .description("Basic Reshape from (1,2,3,4) to (1,24)") 33 | .tags("basic", "reshape") 34 | .build(), 35 | valid_test("reshape_expand_dims") 36 | .description("Reshape expanding dimensions (1,24) → (1,3,8)") 37 | .override_input_shapes(input=[1, 24]) 38 | .tags("reshape", "expand") 39 | .build(), 40 | valid_test("reshape_flatten") 41 | .description("Reshape to flatten spatial dimensions (1,3,4,4) → (1,48)") 42 | .override_input_shapes(input=[1, 24]) 43 | .override_initializer("shape", np.array([1, 3, -1])) 44 | .tags("reshape", "flatten") 45 | .build(), 46 | e2e_test("e2e_basic") 47 | .description("End-to-end test for Reshape layer") 48 | .override_input_shapes(input=[1, 2, 3, 4]) 49 | .override_output_shapes(reshape_output=[1, 24]) 50 | .override_initializer("shape", np.array([1, -1])) 51 | .tags("e2e", "reshape") 52 | .build(), 53 | # --- EDGE CASE / SKIPPED TEST --- 54 | valid_test("large_input") 55 | .description("Large reshape performance test") 56 | .override_input_shapes(input=[1, 3, 256, 256]) 57 | .override_initializer("shape", np.array([1, -1])) 58 | .tags("large", "performance", "reshape") 59 | # .skip("Performance test, skipped by default") 60 | .build(), 61 | ] 62 | -------------------------------------------------------------------------------- /.github/workflows/security.yml: -------------------------------------------------------------------------------- 1 | name: Security Audit 2 | 3 | on: 4 | schedule: 5 | - cron: "0 0 * * *" # Run daily at midnight 6 | push: 7 | branches: ["main", "rustfmt"] 8 | paths: 9 | - "**/Cargo.toml" 10 | - "**/Cargo.lock" 11 | - "pyproject.toml" 12 | - "uv.lock" 13 | pull_request: 14 | branches: ["main", "rustfmt"] 15 | paths: 16 | - "**/Cargo.toml" 17 | - "**/Cargo.lock" 18 | - "pyproject.toml" 19 | - "uv.lock" 20 | 21 | jobs: 22 | rust-audit: 23 | name: Rust Security Audit 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - name: Install Rust toolchain 29 | uses: actions-rs/toolchain@v1 30 | with: 31 | toolchain: nightly-2025-03-27 32 | override: true 33 | 34 | - uses: actions-rs/audit-check@v1 35 | with: 36 | token: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | python-audit: 39 | name: Python Security Audit 40 | runs-on: ubuntu-latest 41 | steps: 42 | - uses: actions/checkout@v4 43 | 44 | - name: Set up Python 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: "3.12" 48 | 49 | - name: Install UV 50 | uses: astral-sh/setup-uv@v5 51 | 52 | - name: Install security tools 53 | run: | 54 | uv pip install 
--system safety bandit
55 |
56 |       - name: Run safety check
57 |         run: safety check
58 |
59 |       - name: Run bandit
60 |         id: bandit
61 |         run: |
62 |           set +e  # Don't exit on error
63 |
64 |           bandit -r ./python --severity-level high -f json -o bandit-report.json
65 |           BANDIT_EXIT_CODE=$?
66 |           echo "exit_code=$BANDIT_EXIT_CODE" >> $GITHUB_OUTPUT
67 |
68 |           # Also output to console for visibility
69 |           echo "::group::Bandit Security Report"
70 |           bandit -r ./python --severity-level high -f txt
71 |           echo "::endgroup::"
72 |
73 |           exit 0  # Continue to upload artifact
74 |
75 |       - name: Upload security report
76 |         uses: actions/upload-artifact@v4
77 |         with:
78 |           name: security-reports
79 |           path: bandit-report.json
80 |
81 |       - name: Check bandit results
82 |         if: steps.bandit.outputs.exit_code != '0'
83 |         run: |
84 |           echo "Bandit found security issues. Please review the report."
85 |           exit ${{ steps.bandit.outputs.exit_code }}
86 |
--------------------------------------------------------------------------------
/docs/developer-notes.md:
--------------------------------------------------------------------------------
1 | # Developer Notes
2 |
3 | Internal notes for contributors working on JSTprove (Python + Rust).
4 |
5 | > For environment setup, pre-commit, formatting policy, and PR workflow, see **[CONTRIBUTING.md](CONTRIBUTING.md)**.
6 |
7 | ---
8 |
9 | ## Repo layout
10 |
11 | ```
12 |
13 | .
14 | ├─ python/                  # CLI and pipeline
15 | │  └─ frontend/cli.py       # JSTprove CLI entrypoint
16 | ├─ python/tests/            # unit/integration tests (CLI tests, etc.)
17 | │  └─ circuit_e2e_tests/    # end-to-end suites
18 | └─ rust/
19 |    └─ jstprove_circuits/    # Rust crate: circuits + runner
20 |
21 | ```
22 |
23 | ---
24 |
25 | ## Rust binaries
26 |
27 | - Main runner: `onnx_generic_circuit`
28 | - Simple demo: `simple_circuit`
29 |
30 | You can build them manually if needed:
31 |
32 | ```bash
33 | # from repo root
34 | cargo build --release
35 | # or explicitly (if not using a workspace root):
36 | cargo build --release --manifest-path rust/jstprove_circuits/Cargo.toml
37 | ```
38 |
39 | Rebuild these manually whenever you change the Rust side of the codebase without reinstalling the whole package.
40 |
41 | Artifacts typically appear under `./target/release/`.
42 |
43 | > The CLI **compile** step will (re)build the runner automatically when needed.
44 |
45 | ---
46 |
47 | ## Python tests
48 |
49 | Before running tests, make sure to install test dependencies:
50 |
51 | ```bash
52 | uv sync --group test
53 | ```
54 |
55 | * **Unit** CLI tests **mock** `base_testing` (fast; no heavy Rust).
56 | * Integration/E2E for heavy models live elsewhere; use `simple_circuit` or small ONNX models for smoke tests.
57 |
58 | Examples:
59 |
60 | ```bash
61 | # run unit + integration markers from repo root
62 | uv run pytest --unit --integration
63 |
64 | # run e2e tests:
65 | # place the model at python/models/models_onnx/<model_name>.onnx
66 | uv run pytest --e2e -- <model_name>
67 | ```
68 |
69 | ---
70 |
71 | ## Useful environment variables
72 |
73 | Suppress the ASCII banner in non-interactive runs:
74 |
75 | ```bash
76 | export JSTPROVE_NO_BANNER=1
77 | ```
78 |
79 | ---
80 |
81 | ## Notes & conventions
82 |
83 | * The CLI uses **GenericModelONNX** by default (no circuit class/name flags).
84 | * Paths are **explicit** (no inference).
85 | * Keep artifacts from the **same compile** together: `circuit.txt` + `quantized.onnx`.
86 |   If the ONNX changes, **re-run compile**.
87 | * Run commands from the **repo root** so `./target/release/*` is resolvable.
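
A typical end-to-end smoke run looks like the following. This is a sketch: it assumes the CLI entry point is exposed as `jstprove` (check the `pyproject.toml` entry points) and uses illustrative file paths. The flag names come from `python/frontend/commands/args.py`; the `prove` flags in particular are assumed.

```bash
# compile: writes the circuit (and quantized model) from the ONNX file
jstprove compile -m python/models/models_onnx/lenet.onnx -c circuit.txt

# witness: runs quantized inference, writing outputs and the witness
jstprove witness -c circuit.txt -i input.json -o output.json -w witness.bin

# prove (flags assumed), then verify the proof against the same artifacts
jstprove prove -c circuit.txt -w witness.bin -p proof.bin
jstprove verify -c circuit.txt -i input.json -o output.json -w witness.bin -p proof.bin
```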
88 |
89 |
90 |
--------------------------------------------------------------------------------
/python/core/utils/scratch_tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import onnx
4 | from onnx import TensorProto, helper, load, shape_inference
5 | from onnx.utils import extract_model
6 |
7 |
8 | def prune_model(
9 |     model_path: str,
10 |     output_names: list[str],
11 |     save_path: str,
12 | ) -> None:
13 |     """Extract a sub-model with the same inputs and new outputs."""
14 |     model = load(model_path)
15 |
16 |     # Provide model input names and the new desired output names.
17 |     input_names = [i.name for i in model.graph.input]
18 |
19 |     extract_model(
20 |         input_path=model_path,
21 |         output_path=save_path,
22 |         input_names=input_names,
23 |         output_names=output_names,
24 |     )
25 |
26 |     print(f"Pruned model saved to {save_path}")  # noqa: T201
27 |
28 |
29 | def cut_model(
30 |     model_path: str,
31 |     output_names: list[str],
32 |     save_path: str,
33 | ) -> None:
34 |     """Replace the graph outputs with the tensors named in `output_names`."""
35 |     model = onnx.load(model_path)
36 |     model = shape_inference.infer_shapes(model)
37 |
38 |     graph = model.graph
39 |
40 |     # Remove all current outputs one by one (cannot use .clear() or assignment).
41 |     while graph.output:
42 |         graph.output.pop()
43 |
44 |     # Add new outputs.
45 |     for name in output_names:
46 |         # Look in value_info, input, or output.
47 |         candidates = list(graph.value_info) + list(graph.input) + list(graph.output)
48 |         value_info = next((vi for vi in candidates if vi.name == name), None)
49 |         if value_info is None:
50 |             msg = f"Tensor {name} not found in model graph."
51 |             raise ValueError(msg)
52 |
53 |         elem_type = value_info.type.tensor_type.elem_type
54 |         shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
55 |         new_output = helper.make_tensor_value_info(name, elem_type, shape)
56 |         graph.output.append(new_output)
57 |
58 |     for output in graph.output:
59 |         print(output)  # noqa: T201
60 |         if output.name == "/conv1/Conv_output_0":
61 |             output.type.tensor_type.elem_type = TensorProto.INT64
62 |
63 |     onnx.save(model, save_path)
64 |     print(f"Saved cut model with outputs {output_names} to {save_path}")  # noqa: T201
65 |
66 |
67 | if __name__ == "__main__":
68 |     prune_model(
69 |         model_path="models_onnx/doom.onnx",
70 |         output_names=["/Relu_3_output_0"],  # replace with your intermediate tensor
71 |         save_path="models_onnx/test_doom_cut.onnx",
72 |     )
73 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v6.0.0
4 |     hooks:
5 |       - id: end-of-file-fixer
6 |       - id: trailing-whitespace
7 |       - id: check-yaml
8 |       - id: check-toml
9 |       - id: check-added-large-files
10 |       - id: detect-private-key
11 |
12 |   # Python formatting with black
13 |   - repo: https://github.com/psf/black
14 |     rev: 25.9.0
15 |     hooks:
16 |       - id: black
17 |         language_version: python3.12
18 |         args: [--check]
19 |
20 |   # Python linting with ruff (comprehensive)
21 |   - repo: https://github.com/astral-sh/ruff-pre-commit
22 |     rev: v0.13.3
23 |     hooks:
24 |       - id: ruff
25 |         args: [
26 |           --fix,
27 |           --exit-non-zero-on-fix,
28 |           --select=ALL,  # Enable all rules for local checks
29 |           --ignore=D,  # Ignore documentation warnings
30 |           '--per-file-ignores=python/tests/**.py:S101',  #
allows assert statements in tests 31 | '--per-file-ignores=python/tests/**.py:T201', # Allows print lines in tests 32 | '--per-file-ignores=python/tests/**.py:F811', # Allows redefinition of unused name in tests (For fixtures which skip tests) 33 | '--per-file-ignores=python/tests/**.py:PLR0913', # Allows more arguments in function definition in tests (Fixtures etc.) 34 | '--per-file-ignores=python/tests/**.py:SLF001', # Allows private members to be accessed in tests (_file_info, _compile_preprocessing) 35 | '--per-file-ignores=python/tests/**.py:ARG001', # Allows unused function arguments in tests (used for mocked variables) 36 | '--per-file-ignores=python/core/model_processing/onnx_custom_ops/**.py:PLR0913', # Allows more arguments in onnx_custom_ops, due to necessity of integration BLE001 37 | '--per-file-ignores=python/**.py:BLE001', # Allows blind exception handling 38 | 39 | ] 40 | 41 | # Rust formatting and linting 42 | - repo: https://github.com/doublify/pre-commit-rust 43 | rev: v1.0 44 | hooks: 45 | - id: fmt 46 | args: ['--manifest-path=./Cargo.toml', '--all', '--', '--check'] 47 | - id: clippy 48 | args: [ 49 | '--manifest-path=./Cargo.toml', 50 | '--all', 51 | '--', 52 | -W, 'clippy::all', 53 | -W, 'clippy::pedantic', 54 | -D, warnings, 55 | ] 56 | 57 | # Dependency checks 58 | - repo: https://github.com/python-poetry/poetry 59 | rev: 2.2.1 60 | hooks: 61 | - id: poetry-check # Validates pyproject.toml 62 | - id: poetry-lock # Updates poetry.lock if needed 63 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/maxpool.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as f 6 | from onnxruntime_extensions import PyCustomOpDef, onnx_op 7 | 8 | from .custom_helpers import parse_attr 9 | 10 | 11 | @onnx_op( 12 | op_type="Int64MaxPool", 13 | domain="ai.onnx.contrib", 14 | inputs=[PyCustomOpDef.dt_int64], # input tensor 15 | outputs=[PyCustomOpDef.dt_int64], 16 | attrs={ 17 | "strides": PyCustomOpDef.dt_string, 18 | "pads": PyCustomOpDef.dt_string, 19 | "kernel_shape": PyCustomOpDef.dt_string, 20 | "dilations": PyCustomOpDef.dt_string, 21 | }, 22 | ) 23 | def int64_maxpool( 24 | x: np.ndarray, 25 | strides: str | None = None, 26 | pads: str | None = None, 27 | kernel_shape: str | None = None, 28 | dilations: str | None = None, 29 | ) -> np.ndarray: 30 | """ 31 | Performs a MaxPool operation on int64 input tensors. 32 | 33 | This function is registered as a custom ONNX operator via onnxruntime_extensions 34 | and is used in the JSTprove quantized inference pipeline. It parses ONNX-style 35 | maxpool attributes and applies maxpool. 36 | 37 | Parameters 38 | ---------- 39 | X : Input tensor with dtype int64. 40 | kernel_shape : Kernel shape (default: `[2, 2]`). 41 | pads : Padding values (default: `[0, 0, 0, 0]`). 42 | strides : Stride values (default: `[1, 1]`). 43 | dilations : dilation values (default: `[1, 1]`). 44 | 45 | 46 | Returns 47 | ------- 48 | numpy.ndarray 49 | Maxpool tensor with dtype int64. 50 | 51 | Notes 52 | ----- 53 | - This op is part of the `ai.onnx.contrib` custom domain. 54 | - ONNX Runtime Extensions is required to register this op. 
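    - Output spatial size follows standard pooling arithmetic:
      out = floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride) + 1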
55 |
56 |     References
57 |     ----------
58 |     For more information on the maxpool operation, please refer to the
59 |     ONNX standard MaxPool operator documentation:
60 |     https://onnx.ai/onnx/operators/onnx__MaxPool.html
61 |     """
62 |     try:
63 |         strides = parse_attr(strides, [1, 1])
64 |         pads = parse_attr(pads, [0, 0])
65 |         kernel_size = parse_attr(kernel_shape, [2, 2])
66 |         dilations = parse_attr(dilations, [1, 1])
67 |
68 |         x = torch.from_numpy(x)
69 |         result = f.max_pool2d(
70 |             x,
71 |             kernel_size=kernel_size,
72 |             stride=strides,
73 |             padding=pads[:2],
74 |             dilation=dilations,
75 |         )
76 |         return result.numpy().astype(np.int64)
77 |     except Exception as e:
78 |         msg = f"Int64MaxPool failed: {e}"
79 |         raise RuntimeError(msg) from e
80 |
--------------------------------------------------------------------------------
/rust/jstprove_circuits/src/circuit_functions/layers/layer_ops.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use expander_compiler::frontend::{Config, RootAPI, Variable};
4 | use ndarray::ArrayD;
5 |
6 | use crate::circuit_functions::CircuitError;
7 | use crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry;
8 | use crate::circuit_functions::utils::onnx_model::CircuitParams;
9 | use crate::circuit_functions::utils::onnx_types::ONNXLayer;
10 |
11 | pub trait LayerOp<C: Config, Builder: RootAPI<C>> {
12 |     /// Instantiated by each layer op.
13 |     /// Applies the relevant operation for that layer.
14 |     ///
15 |     /// # Arguments
16 |     /// - `api`: Mutable reference to the circuit builder.
17 |     /// - `input`: Mapping from input names to inputs of
18 |     ///   the layer.
19 |     ///
20 |     /// # Returns
21 |     /// A tuple `(output_names, output_tensor)` containing:
22 |     /// - The ordered list of output names for this layer.
23 |     /// - The computed output tensor as an `ArrayD<Variable>`.
24 |     ///
25 |     /// # Errors
26 |     /// - [`CircuitError`] if tensor operations or constraints fail,
27 |     ///   typically when a required layer input is missing.
28 |     ///   Additionally, any error propagated from underlying computation.
29 |     ///
30 |     fn apply(
31 |         &self,
32 |         api: &mut Builder,
33 |         input: HashMap<String, ArrayD<Variable>>,
34 |     ) -> Result<(Vec<String>, ArrayD<Variable>), CircuitError>;
35 |     /// Instantiated by each layer op.
36 |     /// Builds a circuit layer from an ONNX definition.
37 |     ///
38 |     /// # Arguments
39 |     /// - `layer`: The ONNX layer specification (op type, attributes, inputs, outputs, etc.).
40 |     /// - `circuit_params`: Global parameters controlling scaling and rescaling.
41 |     /// - `optimization_pattern`: Registry of graph optimization patterns to match against.
42 |     /// - `is_rescale`: Flag indicating whether rescaling logic should be applied.
43 |     /// - `index`: The index of this layer in the network.
44 |     /// - `layer_context`: Additional shared state for building layers.
45 |     ///
46 |     /// # Returns
47 |     /// A boxed `LayerOp` implementing the logic for the given ONNX layer.
48 |     ///
49 |     /// # Errors
50 |     /// - [`CircuitError`] if the layer cannot be instantiated in the circuit.
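    ///
    /// Implementations typically extract attributes from `layer.params` at
    /// build time and keep `apply` to pure tensor wiring; `FlattenLayer` and
    /// `ConstantLayer` in this module tree are two minimal examples.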
51 |     fn build(
52 |         layer: &ONNXLayer,
53 |         circuit_params: &CircuitParams,
54 |         optimization_pattern: PatternRegistry,
55 |         is_rescale: bool,
56 |         index: usize,
57 |         layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext,
58 |     ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError>
59 |     where
60 |         Self: Sized;
61 | }
62 |
--------------------------------------------------------------------------------
/python/frontend/commands/args.py:
--------------------------------------------------------------------------------
1 | """Argument specifications for CLI commands."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 |     import argparse
9 |
10 | from dataclasses import dataclass, field
11 | from typing import Any
12 |
13 |
14 | @dataclass(frozen=True)
15 | class ArgSpec:
16 |     """Specification for a command-line argument."""
17 |
18 |     name: str
19 |     flag: str
20 |     help_text: str
21 |     short: str = ""
22 |     arg_type: type | None = None
23 |     extra_kwargs: dict[str, Any] = field(default_factory=dict)
24 |
25 |     @property
26 |     def positional(self) -> str:
27 |         """Return the positional argument name."""
28 |         return f"pos_{self.name}"
29 |
30 |     def add_to_parser(
31 |         self,
32 |         parser: argparse.ArgumentParser,
33 |         help_override: str | None = None,
34 |     ) -> None:
35 |         """Add both positional and flag arguments to the parser."""
36 |         help_text = help_override or self.help_text
37 |         kwargs = {"help": help_text, **self.extra_kwargs}
38 |         if self.arg_type is not None:
39 |             kwargs["type"] = self.arg_type
40 |
41 |         if self.short:
42 |             parser.add_argument(
43 |                 self.positional,
44 |                 nargs="?",
45 |                 metavar=self.name,
46 |                 **kwargs,
47 |             )
48 |             parser.add_argument(
49 |                 self.short,
50 |                 self.flag,
51 |                 **kwargs,
52 |             )
53 |         else:
54 |             parser.add_argument(
55 |                 self.flag,
56 |                 **kwargs,
57 |             )
58 |
59 |
60 | MODEL_PATH = ArgSpec(
61 |     name="model_path",
62 |     flag="--model-path",
63 |     short="-m",
64 |     help_text="Path to the original ONNX model.",
65 | )
66 |
67 | CIRCUIT_PATH = ArgSpec(
68 |     name="circuit_path",
69 |     flag="--circuit-path",
70 |     short="-c",
71 |     help_text="Path to the compiled circuit.",
72 | )
73 |
74 | INPUT_PATH = ArgSpec(
75 |     name="input_path",
76 |     flag="--input-path",
77 |     short="-i",
78 |     help_text="Path to input JSON.",
79 | )
80 |
81 | OUTPUT_PATH = ArgSpec(
82 |     name="output_path",
83 |     flag="--output-path",
84 |     short="-o",
85 |     help_text="Path to write model outputs JSON.",
86 | )
87 |
88 | WITNESS_PATH = ArgSpec(
89 |     name="witness_path",
90 |     flag="--witness-path",
91 |     short="-w",
92 |     help_text="Path to write witness.",
93 | )
94 |
95 | PROOF_PATH = ArgSpec(
96 |     name="proof_path",
97 |     flag="--proof-path",
98 |     short="-p",
99 |     help_text="Path to write proof.",
100 | )
101 |
--------------------------------------------------------------------------------
/rust/jstprove_circuits/src/circuit_functions/layers/flatten.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use expander_compiler::frontend::{Config, RootAPI, Variable};
4 | use ndarray::ArrayD;
5 |
6 | use crate::circuit_functions::{
7 |     CircuitError,
8 |     layers::{LayerError, LayerKind, layer_ops::LayerOp},
9 |     utils::{
10 |         constants::{AXIS, INPUT},
11 |         onnx_model::{extract_params_and_expected_shape, get_input_name, get_param_or_default},
12 |         shaping::onnx_flatten,
13 |     },
14 | };
15 |
16 | // -------- Struct --------
17 | #[allow(dead_code)]
18 | #[derive(Debug)]
19 | pub struct
FlattenLayer {
20 |     name: String,
21 |     axis: usize,
22 |     input_shape: Vec<usize>,
23 |     inputs: Vec<String>,
24 |     outputs: Vec<String>,
25 | }
26 |
27 | // -------- Implementations --------
28 |
29 | impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for FlattenLayer {
30 |     fn apply(
31 |         &self,
32 |         _api: &mut Builder,
33 |         input: HashMap<String, ArrayD<Variable>>,
34 |     ) -> Result<(Vec<String>, ArrayD<Variable>), CircuitError> {
35 |         let reshape_axis = self.axis;
36 |         let input_name = get_input_name(&self.inputs, 0, LayerKind::Flatten, INPUT)?;
37 |         let layer_input = input
38 |             .get(&input_name.clone())
39 |             .ok_or_else(|| LayerError::MissingInput {
40 |                 layer: LayerKind::Flatten,
41 |                 name: input_name.clone(),
42 |             })?
43 |             .clone();
44 |
45 |         let out = onnx_flatten(layer_input.clone(), reshape_axis)?;
46 |
47 |         Ok((self.outputs.clone(), out.clone()))
48 |     }
49 |     fn build(
50 |         layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer,
51 |         _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams,
52 |         _optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry,
53 |         _is_rescale: bool,
54 |         _index: usize,
55 |         layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext,
56 |     ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError> {
57 |         let (params, expected_shape) = extract_params_and_expected_shape(layer_context, layer)
58 |             .map_err(|e| LayerError::Other {
59 |                 layer: LayerKind::Flatten,
60 |                 msg: format!("extract_params_and_expected_shape failed: {e}"),
61 |             })?;
62 |         let flatten = Self {
63 |             name: layer.name.clone(),
64 |             axis: get_param_or_default(&layer.name, AXIS, &params, Some(&1))?,
65 |             input_shape: expected_shape.clone(),
66 |             inputs: layer.inputs.clone(),
67 |             outputs: layer.outputs.clone(),
68 |         };
69 |         Ok(Box::new(flatten))
70 |     }
71 | }
72 |
--------------------------------------------------------------------------------
/python/core/model_processing/onnx_custom_ops/gemm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | from onnxruntime_extensions import PyCustomOpDef, onnx_op
5 |
6 | from .custom_helpers import rescaling
7 |
8 |
9 | @onnx_op(
10 |     op_type="Int64Gemm",
11 |     domain="ai.onnx.contrib",
12 |     inputs=[
13 |         PyCustomOpDef.dt_int64,  # X
14 |         PyCustomOpDef.dt_int64,  # W
15 |         PyCustomOpDef.dt_int64,  # B
16 |         PyCustomOpDef.dt_int64,  # Scalar
17 |     ],
18 |     outputs=[PyCustomOpDef.dt_int64],
19 |     attrs={
20 |         "alpha": PyCustomOpDef.dt_float,
21 |         "beta": PyCustomOpDef.dt_float,
22 |         "transA": PyCustomOpDef.dt_int64,
23 |         "transB": PyCustomOpDef.dt_int64,
24 |         "rescale": PyCustomOpDef.dt_int64,
25 |     },
26 | )
27 | def int64_gemm7(
28 |     a: np.ndarray,
29 |     b: np.ndarray,
30 |     c: np.ndarray | None = None,
31 |     scaling_factor: np.ndarray | None = None,
32 |     alpha: float | None = None,
33 |     beta: float | None = None,
34 |     transA: int | None = None,  # noqa: N803
35 |     transB: int | None = None,  # noqa: N803
36 |     rescale: int | None = None,
37 | ) -> np.ndarray:
38 |     """
39 |     Performs a Gemm (alternatively: Linear layer) on int64 input tensors.
40 |
41 |     This function is registered as a custom ONNX operator via onnxruntime_extensions
42 |     and is used in the JSTprove quantized inference pipeline. It parses ONNX-style
43 |     Gemm attributes, applies the Gemm,
44 |     and optionally rescales the result.
45 |
46 |     Parameters
47 |     ----------
48 |     a : Input tensor with dtype int64.
49 |     b : Gemm weight tensor with dtype int64.
50 |     c : Optional bias tensor with dtype int64.
51 |     scaling_factor : Scaling factor for rescaling the output.
52 | alpha : alpha value for Gemm operation. 53 | beta : beta value for Gemm operation. 54 | transA : Transpose the a matrix before the Gemm operation 55 | transB : Transpose the b matrix before the Gemm operation 56 | rescale : Optional flag to apply output rescaling or not. 57 | 58 | Returns 59 | ------- 60 | numpy.ndarray 61 | Gemm tensor with dtype int64. 62 | 63 | Notes 64 | ----- 65 | - This op is part of the `ai.onnx.contrib` custom domain. 66 | - ONNX Runtime Extensions is required to register this op. 67 | 68 | References 69 | ---------- 70 | For more information on the gemm operation, please refer to the 71 | ONNX standard Gemm operator documentation: 72 | https://onnx.ai/onnx/operators/onnx__Gemm.html 73 | """ 74 | try: 75 | alpha = int(alpha) 76 | beta = int(beta) 77 | 78 | a = a.T if transA else a 79 | b = b.T if transB else b 80 | 81 | result = alpha * (a @ b) 82 | 83 | if c is not None: 84 | result += beta * c 85 | 86 | result = rescaling(scaling_factor, rescale, result) 87 | return result.astype(np.int64) 88 | 89 | except Exception as e: 90 | msg = f"Int64Gemm failed: {e}" 91 | raise RuntimeError(msg) from e 92 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/base_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | import pytest 6 | from onnx import TensorProto, helper 7 | 8 | if TYPE_CHECKING: 9 | from onnx import ModelProto 10 | 11 | from python.tests.onnx_quantizer_tests.layers.base import ( 12 | LayerTestConfig, 13 | LayerTestSpec, 14 | ) 15 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( 16 | ONNXOpQuantizer, 17 | ) 18 | from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory 19 | 20 | 21 | class BaseQuantizerTest: 22 | """Base test utilities for ONNX quantizer tests.""" 23 | 24 | __test__ = False # Prevent pytest from collecting this class directly 25 | 26 | _validation_failed_cases: ClassVar[set[str]] = set() 27 | 28 | @pytest.fixture 29 | def quantizer(self) -> ONNXOpQuantizer: 30 | return ONNXOpQuantizer() 31 | 32 | @pytest.fixture 33 | def layer_configs(self) -> dict[str, LayerTestConfig]: 34 | return TestLayerFactory.get_layer_configs() 35 | 36 | @staticmethod 37 | def _generate_test_id( 38 | test_case_tuple: tuple[str, LayerTestConfig, LayerTestSpec], 39 | ) -> str: 40 | try: 41 | layer_name, _, test_spec = test_case_tuple 42 | except Exception: 43 | return str(test_case_tuple) 44 | else: 45 | return f"{layer_name}_{test_spec.name}" 46 | 47 | @classmethod 48 | def _check_validation_dependency( 49 | cls: BaseQuantizerTest, 50 | test_case_data: tuple[str, LayerTestConfig, LayerTestSpec], 51 | ) -> None: 52 | layer_name, _, test_spec = test_case_data 53 | test_case_id = f"{layer_name}_{test_spec.name}" 54 | if test_case_id in cls._validation_failed_cases: 55 | pytest.skip(f"Skipping because ONNX validation failed for {test_case_id}") 56 | 57 | @staticmethod 58 | def create_model_with_layers( 59 | layer_types: list[str], 60 | layer_configs: dict[str, LayerTestConfig], 61 | ) -> ModelProto: 62 | """Create a model composed of several layers.""" 63 | nodes, all_initializers = [], {} 64 | 65 | for i, layer_type in enumerate(layer_types): 66 | config = layer_configs[layer_type] 67 | node = config.create_node(name_suffix=f"_{i}") 68 | if i > 0: 69 | prev_output = f"{layer_types[i-1].lower()}_output_{i-1}" 
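                # Rewire each node's first input to the previous node's
                # output so the layers form a single chain.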
70 | if node.input: 71 | node.input[0] = prev_output 72 | nodes.append(node) 73 | all_initializers.update(config.create_initializers()) 74 | 75 | graph = helper.make_graph( 76 | nodes, 77 | "test_graph", 78 | [ 79 | helper.make_tensor_value_info( 80 | "input", 81 | TensorProto.FLOAT, 82 | [1, 16, 224, 224], 83 | ), 84 | ], 85 | [ 86 | helper.make_tensor_value_info( 87 | f"{layer_types[-1].lower()}_output_{len(layer_types)-1}", 88 | TensorProto.FLOAT, 89 | [1, 10], 90 | ), 91 | ], 92 | initializer=list(all_initializers.values()), 93 | ) 94 | return helper.make_model(graph) 95 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/clip.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ( 9 | BaseOpQuantizer, 10 | QuantizerBase, 11 | ScaleConfig, 12 | ) 13 | 14 | 15 | class QuantizeClip(QuantizerBase): 16 | """ 17 | Quantization traits for ONNX Clip. 18 | 19 | Semantics: 20 | - X is already scaled/cast to INT64 at the graph boundary by the converter. 21 | - Clip is elementwise + broadcasting. 22 | - The bound inputs (min, max) should live in the *same* fixed-point scale 23 | as X so that Clip(alpha*x; alpha*a, alpha*b) matches the original Clip(x; a, b). 24 | 25 | Implementation: 26 | - Treat inputs 1 and 2 (min, max) like "WB-style" slots: we let the 27 | QuantizerBase machinery rescale / cast those inputs using the same 28 | global scale factor. 29 | - No extra internal scaling input is added (USE_SCALING = False). 30 | """ 31 | 32 | OP_TYPE = "Clip" 33 | DOMAIN = "" # standard ONNX domain 34 | 35 | # We DO want WB-style handling so that min/max initializers get quantized: 36 | USE_WB = True 37 | 38 | # Clip does not introduce its own scale input; it just runs in the 39 | # existing fixed-point scale. 40 | USE_SCALING = False 41 | 42 | # Scale-plan for WB-style slots: 43 | # - Input index 1: min 44 | # - Input index 2: max 45 | # Each should be scaled once by the global alpha (same as activations). 46 | SCALE_PLAN: ClassVar = {1: 1, 2: 1} 47 | 48 | 49 | class ClipQuantizer(BaseOpQuantizer, QuantizeClip): 50 | """ 51 | Quantizer for ONNX Clip. 52 | 53 | - Keeps the node op_type as "Clip". 54 | - Ensures that any bound inputs (min, max), whether they are dynamic 55 | inputs or initializers, are converted to the same INT64 fixed-point 56 | representation as A. 57 | """ 58 | 59 | def __init__( 60 | self, 61 | new_initializers: dict[str, onnx.TensorProto] | None = None, 62 | ) -> None: 63 | # Match Max/Min/Add: we simply share the new_initializers dict 64 | # with the converter so any constants we add are collected. 65 | self.new_initializers = new_initializers 66 | 67 | def quantize( 68 | self, 69 | node: onnx.NodeProto, 70 | graph: onnx.GraphProto, 71 | scale_config: ScaleConfig, 72 | initializer_map: dict[str, onnx.TensorProto], 73 | ) -> list[onnx.NodeProto]: 74 | # Delegate to the shared QuantizerBase logic, which will: 75 | # - keep X as-is (already scaled/cast by the converter), 76 | # - rescale / cast min/max according to SCALE_PLAN, 77 | # - update initializers as needed. 
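# Worked example (illustrative, not from the source): with a scale base of 2 and an exponent of 21 (the values used in this repo's tests), a float bound min = 0.5 would be stored as round(0.5 * 2**21) = 1_048_576 in the shared INT64 fixed-point domain, i.e. at the same scale already applied to X.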
78 | return QuantizeClip.quantize(self, node, graph, scale_config, initializer_map) 79 | 80 | def check_supported( 81 | self, 82 | node: onnx.NodeProto, 83 | initializer_map: dict[str, onnx.TensorProto] | None = None, 84 | ) -> None: 85 | """ 86 | Minimal support check for Clip: 87 | 88 | - Clip is variadic elementwise with optional min/max as inputs or attrs. 89 | - We accept both forms; if attrs are present, ORT enforces semantics. 90 | - Broadcasting is ONNX-standard; we don't restrict further here. 91 | """ 92 | _ = node, initializer_map 93 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/bin/simple_circuit.rs: -------------------------------------------------------------------------------- 1 | use expander_compiler::frontend::{ 2 | BN254Config, CircuitField, Config, Define, RootAPI, Variable, declare_circuit, 3 | }; 4 | use jstprove_circuits::io::io_reader::{FileReader, IOReader}; 5 | use jstprove_circuits::runner::errors::RunError; 6 | use jstprove_circuits::runner::main_runner::{ConfigurableCircuit, get_args, handle_args}; 7 | use serde::Deserialize; 8 | 9 | declare_circuit!(Circuit { 10 | input_a: PublicVariable, 11 | input_b: PublicVariable, 12 | nonce: PublicVariable, 13 | output: PublicVariable, 14 | dummy: [Variable; 2] 15 | }); 16 | 17 | //Still to factor this out 18 | 19 | impl Define<BN254Config> for Circuit<Variable> { 20 | fn define<Builder: RootAPI<BN254Config>>(&self, api: &mut Builder) { 21 | let out = api.add(self.input_a, self.input_b); 22 | // let out1 = api.add(self.nonce, out); 23 | api.assert_is_non_zero(self.nonce); 24 | 25 | api.assert_is_equal(out, self.output); 26 | for i in 0..self.dummy.len() { 27 | api.assert_is_zero(self.dummy[i]); 28 | } 29 | } 30 | } 31 | 32 | #[derive(Deserialize, Clone)] 33 | struct InputData { 34 | value_a: u32, 35 | value_b: u32, 36 | nonce: u32, 37 | } 38 | 39 | //This is the data structure for the output data to be read in from the json file 40 | #[derive(Deserialize, Clone)] 41 | struct OutputData { 42 | output: u32, 43 | } 44 | 45 | impl ConfigurableCircuit for Circuit<Variable> { 46 | fn configure(&mut self) -> Result<(), jstprove_circuits::runner::errors::RunError> { 47 | Ok(()) 48 | } 49 | } 50 | 51 | impl<C: Config> IOReader<Circuit<CircuitField<C>>, C> for FileReader { 52 | fn read_inputs( 53 | &mut self, 54 | file_path: &str, 55 | mut assignment: Circuit<CircuitField<C>>, 56 | ) -> Result<Circuit<CircuitField<C>>, RunError> { 57 | let data: InputData = 58 | <Self as IOReader<Circuit<CircuitField<C>>, C>>::read_data_from_json::<InputData>(file_path)?; 59 | 60 | // Assign inputs to assignment 61 | assignment.input_a = CircuitField::<C>::from(data.value_a); 62 | assignment.input_b = CircuitField::<C>::from(data.value_b); 63 | assignment.nonce = CircuitField::<C>::from(data.nonce); 64 | assignment.dummy = [CircuitField::<C>::from(0); 2]; 65 | 66 | // Return the assignment 67 | Ok(assignment) 68 | } 69 | fn read_outputs( 70 | &mut self, 71 | file_path: &str, 72 | mut assignment: Circuit<CircuitField<C>>, 73 | ) -> Result<Circuit<CircuitField<C>>, RunError> { 74 | let data: OutputData = 75 | <Self as IOReader<Circuit<CircuitField<C>>, C>>::read_data_from_json::<OutputData>(file_path)?; 76 | 77 | // Assign inputs to assignment 78 | assignment.output = CircuitField::<C>::from(data.output); 79 | 80 | Ok(assignment) 81 | } 82 | fn get_path(&self) -> &str { 83 | &self.path 84 | } 85 | } 86 | 87 | fn main() { 88 | let mut file_reader = FileReader { 89 | path: "simple_circuit".to_owned(), 90 | }; 91 | 92 | let matches = get_args(); 93 | 94 | if let Err(err) = 95 | handle_args::<BN254Config, Circuit<_>, Circuit<_>, _>(&matches, &mut file_reader) 96 | { 97 | eprintln!("Error: {err}"); 98 | std::process::exit(1); 99 | } 100 | } 101 |
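// Example (not part of the source): a hypothetical input/output JSON pair matching the serde field names of InputData and OutputData above. Since the circuit constrains output == input_a + input_b, a non-zero nonce, and zeroed dummy entries, inputs {"value_a": 2, "value_b": 3, "nonce": 7} would pair with outputs {"output": 5}.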
-------------------------------------------------------------------------------- /python/tests/circuit_parent_classes/test_ort_custom_layers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | import torch 7 | from onnx import TensorProto, helper, shape_inference 8 | 9 | from python.core.model_processing.converters.onnx_converter import ONNXConverter 10 | 11 | 12 | @pytest.fixture 13 | def tiny_conv_model_path(tmp_path: Path) -> Path: 14 | # Create input and output tensor info 15 | input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4]) 16 | output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2]) 17 | 18 | # Kernel weights (3x3 ones) 19 | w_init = helper.make_tensor( 20 | name="W", 21 | data_type=TensorProto.FLOAT, 22 | dims=[1, 1, 3, 3], 23 | vals=np.ones((1 * 1 * 3 * 3), dtype=np.float32).tolist(), 24 | ) 25 | z_init = helper.make_tensor( 26 | name="Z", 27 | data_type=TensorProto.FLOAT, 28 | dims=[1], 29 | vals=np.ones((1), dtype=np.float32).tolist(), 30 | ) 31 | 32 | # Conv node with no padding, stride 1 33 | conv_node = helper.make_node( 34 | "Conv", 35 | inputs=["X", "W", "Z"], 36 | outputs=["Y"], 37 | kernel_shape=[3, 3], 38 | pads=[0, 0, 0, 0], 39 | strides=[1, 1], 40 | dilations=[1, 1], 41 | ) 42 | 43 | # Build graph and model 44 | graph = helper.make_graph( 45 | nodes=[conv_node], 46 | name="TinyConvGraph", 47 | inputs=[input_tensor], 48 | outputs=[output_tensor], 49 | initializer=[w_init, z_init], 50 | ) 51 | 52 | model = helper.make_model(graph, producer_name="tiny-conv-example") 53 | 54 | # Save to a temporary file 55 | model_path = tmp_path / "tiny_conv.onnx" 56 | onnx.save(model, str(model_path)) 57 | 58 | return model_path 59 | 60 | 61 | @pytest.mark.integration 62 | def test_tiny_conv(tiny_conv_model_path: Path, tmp_path: Path) -> None: 63 | path = tiny_conv_model_path 64 | 65 | converter = ONNXConverter() 66 | 67 | # Load and validate original model 68 | model = onnx.load(path) 69 | onnx.checker.check_model(model) 70 | 71 | # Apply shape inference and validate 72 | inferred_model = shape_inference.infer_shapes(model) 73 | onnx.checker.check_model(inferred_model) 74 | 75 | # Quantize and add custom domain 76 | new_model = converter.quantize_model(model, 2, 21) 77 | custom_domain = onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1) 78 | new_model.opset_import.append(custom_domain) 79 | onnx.checker.check_model(new_model) 80 | 81 | # Save quantized model 82 | out_path = tmp_path / "model_quant.onnx" 83 | with out_path.open("wb") as f: 84 | f.write(new_model.SerializeToString()) 85 | 86 | # Reload quantized model to ensure it is valid 87 | model_quant = onnx.load(str(out_path)) 88 | onnx.checker.check_model(model_quant) 89 | 90 | # Prepare inputs and compare outputs 91 | inputs = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4) 92 | outputs_true = converter.run_model_onnx_runtime(path, inputs) 93 | outputs_quant = converter.run_model_onnx_runtime(out_path, inputs) 94 | 95 | true = torch.tensor(np.array(outputs_true), dtype=torch.float32) 96 | quant = torch.tensor(np.array(outputs_quant), dtype=torch.float32) / (2**21) 97 | 98 | assert torch.allclose(true, quant, rtol=1e-3, atol=1e-5), "Outputs do not match" 99 | -------------------------------------------------------------------------------- /Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "rust/jstprove_circuits", 4 | ] 5 | 6 | 7 | resolver = "2" 8 | # Important to avoid the multiple workspaces error 9 | exclude = ["ExpanderCompilerCollection"] 10 | 11 | 12 | [workspace.dependencies] 13 | serde_json = "1.0" 14 | peakmem-alloc = "0.3.0" 15 | csv = "1.1" 16 | ark-std = "0.4.0" 17 | rand = "0.8.5" 18 | chrono = "0.4" 19 | clap = { version = "4.1", features = ["derive"] } 20 | ethnum = "1.5.0" 21 | tiny-keccak = { version = "2.0", features = ["keccak"] } 22 | halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = ["bits",] } 23 | arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 24 | expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "circuit" } 25 | gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 26 | gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 27 | expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "bin" } 28 | mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 29 | expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "transcript" } 30 | crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "main"} 31 | expander_compiler = {git = "https://github.com/PolyhedraZK/ExpanderCompilerCollection"} 32 | circuit-std-rs = {git = "https://github.com/PolyhedraZK/ExpanderCompilerCollection"} 33 | babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 34 | gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "gkr_engine" } 35 | gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 36 | goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } 37 | serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "main"} 38 | mpi = "0.8.0" 39 | thiserror = "2.0.15" 40 | strum = "0.27.2" 41 | strum_macros = "0.27.2" 42 | once_cell = "1.21.3" 43 | 44 | # ndarray = "0.16.1" 45 | 46 | ndarray = { version = "0.16.1", features = ["serde"] } 47 | serde = { version = "1.0", features = ["derive"] } 48 | 49 | [patch."https://github.com/rsmpi/rsmpi"] 50 | mpi = { git = "https://github.com/0pendansor/rsmpi", branch = "patch-1" } 51 | mpi-sys = { git = "https://github.com/0pendansor/rsmpi", branch = "patch-1", package = "mpi-sys" } 52 | 53 | [patch.crates-io] 54 | mpi = { git = "https://github.com/0pendansor/rsmpi", branch = "patch-1" } 55 | mpi-sys = { git = "https://github.com/0pendansor/rsmpi", branch = "patch-1", package = "mpi-sys" } 56 | 57 | 58 | # serde_json = "1.0" 59 | 60 | 61 | # [workspace.metadata.cross.target.x86_64-unknown-linux-gnu] 62 | # # Add custom pre-build steps or configurations if needed 63 | # pre-build = [ 64 | # "apt-get update && apt-get install -y libssl-dev && apt-get install libopenmpi-dev && apt-get install openmpi-bin && apt-get install libmpich-dev && apt-get install pkg-config", 65 | # "sudo apt install openmpi-bin openmpi-common libopenmpi-dev", 66 | # "mpirun --version" 67 | # # openmpi-bin \ 68 | # # libopenmpi-dev \ libmpich-dev 69 | # # apt-get install -y \ 70 | # # libmpich-dev \ 71 | # # libopenmpi-dev \ 72 | # # pkg-config \ 73 | # # libclang-dev \ 74 | # # clang \ 75 | # ] 76 | # 
[workspace.metadata.cross.target.x86_64-unknown-linux-gnu] 77 | # pre-build = [ 78 | # "apt-get update && apt-get install -y openmpi-bin libopenmpi-dev" 79 | # ] 80 | 81 | # [workspace.metadata.cross.target.armv7-unknown-linux-gnueabi] 82 | # image = "my_rust_mpi_image" 83 | -------------------------------------------------------------------------------- /python/frontend/cli.py: -------------------------------------------------------------------------------- 1 | # python/frontend/cli.py 2 | """JSTprove CLI.""" 3 | 4 | from __future__ import annotations 5 | 6 | import argparse 7 | import os 8 | import sys 9 | from typing import TYPE_CHECKING 10 | 11 | if TYPE_CHECKING: 12 | from python.frontend.commands import BaseCommand 13 | 14 | from python.frontend.commands import ( 15 | BenchCommand, 16 | CompileCommand, 17 | ModelCheckCommand, 18 | ProveCommand, 19 | VerifyCommand, 20 | WitnessCommand, 21 | ) 22 | from python.frontend.commands.base import HiddenPositionalHelpFormatter 23 | 24 | BANNER_TITLE = r""" 25 | 888888 .d8888b. 88888888888 26 | "88b d88P Y88b 888 27 | 888 Y88b. 888 28 | 888 "Y888b. 888 88888b. 888d888 .d88b. 888 888 .d88b. 29 | 888 "Y88b. 888 888 "88b 888P" d88""88b 888 888 d8P Y8b 30 | 888 "888 888 888 888 888 888 888 Y88 88P 88888888 31 | 88P Y88b d88P 888 888 d88P 888 Y88..88P Y8bd8P Y8b. 32 | 888 "Y8888P" 888 88888P" 888 "Y88P" Y88P "Y8888 33 | .d88P 888 34 | .d88P" 888 35 | 888P" 888 36 | """ 37 | 38 | COMMANDS: list[type[BaseCommand]] = [ 39 | ModelCheckCommand, 40 | CompileCommand, 41 | WitnessCommand, 42 | ProveCommand, 43 | VerifyCommand, 44 | BenchCommand, 45 | ] 46 | 47 | 48 | def print_header() -> None: 49 | """Print the CLI banner (no side-effects at import time).""" 50 | print( # noqa: T201 51 | BANNER_TITLE 52 | + "\n" 53 | + "JSTprove — Verifiable ML by Inference Labs\n" 54 | + "Based on Polyhedra Network's Expander (GKR-based proving system)\n", 55 | ) 56 | 57 | 58 | def main(argv: list[str] | None = None) -> int: 59 | """ 60 | Entry point for the JSTprove CLI. 61 | 62 | Returns: 63 | 0 on success, 1 on error. 
64 | """ 65 | argv = sys.argv[1:] if argv is None else argv 66 | 67 | parser = argparse.ArgumentParser( 68 | prog="jst", 69 | description="ZKML CLI (compile, witness, prove, verify).", 70 | allow_abbrev=False, 71 | ) 72 | parser.add_argument( 73 | "--no-banner", 74 | action="store_true", 75 | help="Suppress the startup banner.", 76 | ) 77 | 78 | subparsers = parser.add_subparsers(dest="cmd", required=True) 79 | 80 | command_map = {} 81 | for command_cls in COMMANDS: 82 | cmd_parser = subparsers.add_parser( 83 | command_cls.name, 84 | aliases=command_cls.aliases, 85 | help=command_cls.help, 86 | allow_abbrev=False, 87 | formatter_class=HiddenPositionalHelpFormatter, 88 | ) 89 | command_cls.configure_parser(cmd_parser) 90 | command_map[command_cls.name] = command_cls 91 | for alias in command_cls.aliases: 92 | command_map[alias] = command_cls 93 | 94 | args = parser.parse_args(argv) 95 | 96 | if not args.no_banner and not os.environ.get("JSTPROVE_NO_BANNER"): 97 | print_header() 98 | 99 | try: 100 | command_cls = command_map[args.cmd] 101 | command_cls.run(args) 102 | except (ValueError, FileNotFoundError, PermissionError, RuntimeError) as e: 103 | print(f"Error: {e}", file=sys.stderr) # noqa: T201 104 | return 1 105 | except SystemExit: 106 | raise 107 | except Exception as e: 108 | print(f"Error: {e}", file=sys.stderr) # noqa: T201 109 | return 1 110 | 111 | return 0 112 | 113 | 114 | if __name__ == "__main__": 115 | raise SystemExit(main()) 116 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/layers/reshape.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use expander_compiler::frontend::{Config, RootAPI, Variable}; 4 | use ndarray::{ArrayD, IxDyn}; 5 | 6 | use crate::circuit_functions::{ 7 | CircuitError, 8 | layers::{LayerError, LayerKind, layer_ops::LayerOp}, 9 | utils::{ 10 | constants::{INPUT, INPUT_SHAPE}, 11 | onnx_model::{extract_params_and_expected_shape, get_input_name, get_param_or_default}, 12 | shaping::infer_reshape_shape, 13 | }, 14 | }; 15 | 16 | // -------- Struct -------- 17 | #[allow(dead_code)] 18 | #[derive(Debug)] 19 | pub struct ReshapeLayer { 20 | name: String, 21 | shape: Vec, 22 | input_shape: Vec, 23 | inputs: Vec, 24 | outputs: Vec, 25 | } 26 | // -------- Implementations -------- 27 | 28 | impl> LayerOp for ReshapeLayer { 29 | fn apply( 30 | &self, 31 | _api: &mut Builder, 32 | input: HashMap>, 33 | ) -> Result<(Vec, ArrayD), CircuitError> { 34 | let reshape_shape = self.shape.clone(); 35 | let input_name = get_input_name(&self.inputs, 0, LayerKind::Conv, INPUT)?; 36 | let layer_input = input 37 | .get(&input_name.clone()) 38 | .ok_or_else(|| LayerError::MissingInput { 39 | layer: LayerKind::Conv, 40 | name: input_name.clone(), 41 | })? 
42 | .clone(); 43 | 44 | let inferred_shape = infer_reshape_shape(layer_input.len(), &reshape_shape)?; 45 | 46 | let out = layer_input 47 | .into_shape_with_order(IxDyn(&inferred_shape)) 48 | .map_err(|_| LayerError::InvalidShape { 49 | layer: LayerKind::Reshape, 50 | msg: format!("Cannot reshape into {inferred_shape:?}"), 51 | })?; 52 | 53 | Ok((self.outputs.clone(), out.clone())) 54 | } 55 | 56 | fn build( 57 | layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer, 58 | _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams, 59 | _optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry, 60 | _is_rescale: bool, 61 | _index: usize, 62 | layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext, 63 | ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError> { 64 | let shape_name = get_input_name(&layer.inputs, 1, LayerKind::Reshape, INPUT_SHAPE)?; 65 | let (params, expected_shape) = extract_params_and_expected_shape(layer_context, layer) 66 | .map_err(|e| LayerError::Other { 67 | layer: LayerKind::Reshape, 68 | msg: format!("extract_params_and_expected_shape failed: {e}"), 69 | })?; 70 | let output_shape = layer_context.shapes_map.get(&layer.outputs.clone()[0]); 71 | let output_shape_isize: Option<Vec<isize>> = output_shape.map(|v| { 72 | v.iter() 73 | .filter_map(|&x| x.try_into().ok()) // convert usize -> isize, ignore if it fails 74 | .collect() 75 | }); 76 | 77 | let shape: Vec<isize> = get_param_or_default( 78 | &layer.name, 79 | shape_name, 80 | &params, 81 | output_shape_isize.as_ref(), 82 | )?; 83 | 84 | let reshape = Self { 85 | name: layer.name.clone(), 86 | input_shape: expected_shape.clone(), 87 | inputs: layer.inputs.clone(), 88 | outputs: layer.outputs.clone(), 89 | shape, 90 | }; 91 | Ok(Box::new(reshape)) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/layers/sub.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | /// External crate imports 4 | use ndarray::ArrayD; 5 | 6 | /// `ExpanderCompilerCollection` imports 7 | use expander_compiler::frontend::{Config, RootAPI, Variable}; 8 | 9 | use crate::circuit_functions::gadgets::linear_algebra::matrix_subtraction; 10 | use crate::circuit_functions::utils::onnx_model::get_optional_w_or_b; 11 | use crate::circuit_functions::utils::tensor_ops::{ 12 | broadcast_two_arrays, load_array_constants_or_get_inputs, 13 | }; 14 | use crate::circuit_functions::{ 15 | CircuitError, 16 | layers::{LayerError, LayerKind, layer_ops::LayerOp}, 17 | utils::{ 18 | constants::INPUT, 19 | graph_pattern_matching::PatternRegistry, 20 | onnx_model::{extract_params_and_expected_shape, get_input_name}, 21 | }, 22 | }; 23 | 24 | // -------- Struct -------- 25 | #[allow(dead_code)] 26 | #[derive(Debug)] 27 | pub struct SubLayer { 28 | name: String, 29 | // weights: ArrayD<i64>, //This should be an optional field 30 | optimization_pattern: PatternRegistry, 31 | input_shape: Vec<usize>, 32 | inputs: Vec<String>, 33 | outputs: Vec<String>, 34 | initializer_a: Option<ArrayD<i64>>, 35 | initializer_b: Option<ArrayD<i64>>, 36 | } 37 | 38 | // -------- Implementation -------- 39 | 40 | impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for SubLayer { 41 | fn apply( 42 | &self, 43 | api: &mut Builder, 44 | input: HashMap<String, ArrayD<Variable>>, 45 | ) -> Result<(Vec<String>, ArrayD<Variable>), CircuitError> { 46 | let a_name = get_input_name(&self.inputs, 0, LayerKind::Sub, INPUT)?; 47 | let b_name = get_input_name(&self.inputs, 1, LayerKind::Sub, INPUT)?; 48 | 49 | let a_input =
load_array_constants_or_get_inputs( 50 | api, 51 | &input, 52 | a_name, 53 | &self.initializer_a, 54 | LayerKind::Sub, 55 | )?; 56 | 57 | let b_input = load_array_constants_or_get_inputs( 58 | api, 59 | &input, 60 | b_name, 61 | &self.initializer_b, 62 | LayerKind::Sub, 63 | )?; 64 | 65 | let (a_bc, b_bc) = broadcast_two_arrays(&a_input, &b_input)?; 66 | 67 | // Matrix subtraction 68 | let result = matrix_subtraction(api, &a_bc, b_bc, LayerKind::Sub)?; 69 | Ok((self.outputs.clone(), result)) 70 | } 71 | fn build( 72 | layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer, 73 | _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams, 74 | optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry, 75 | _is_rescale: bool, 76 | _index: usize, 77 | layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext, 78 | ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError> { 79 | let (_params, expected_shape) = extract_params_and_expected_shape(layer_context, layer) 80 | .map_err(|e| LayerError::Other { 81 | layer: LayerKind::Sub, 82 | msg: format!("extract_params_and_expected_shape failed: {e}"), 83 | })?; 84 | 85 | let initializer_a = get_optional_w_or_b(layer_context, &layer.inputs[0])?; 86 | let initializer_b = get_optional_w_or_b(layer_context, &layer.inputs[1])?; 87 | 88 | let sub = Self { 89 | name: layer.name.clone(), 90 | optimization_pattern, 91 | input_shape: expected_shape.clone(), 92 | inputs: layer.inputs.clone(), 93 | outputs: layer.outputs.clone(), 94 | initializer_a, 95 | initializer_b, 96 | }; 97 | Ok(Box::new(sub)) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from python.core.model_processing.onnx_quantizer.exceptions import ( 4 | REPORTING_URL, 5 | InvalidParamError, 6 | QuantizationError, 7 | UnsupportedOpError, 8 | ) 9 | 10 | 11 | @pytest.mark.unit 12 | def test_quantization_error_message() -> None: 13 | custom_msg = "Something went wrong." 14 | with pytest.raises(QuantizationError) as exc_info: 15 | raise QuantizationError(custom_msg) 16 | assert "This model is not supported by JSTprove." in str(exc_info.value) 17 | assert custom_msg in str(exc_info.value) 18 | 19 | assert REPORTING_URL in str(exc_info.value) 20 | 21 | assert "Submit model support requests via the JSTprove channel:" in str( 22 | exc_info.value, 23 | ) 24 | 25 | 26 | @pytest.mark.unit 27 | def test_invalid_param_error_basic() -> None: 28 | with pytest.raises(InvalidParamError) as exc_info: 29 | raise InvalidParamError( 30 | node_name="Conv_1", 31 | op_type="Conv", 32 | message="Missing 'strides' attribute.", 33 | ) 34 | err_msg = str(exc_info.value) 35 | assert "Invalid parameters in node 'Conv_1'" in err_msg 36 | assert "(op_type='Conv')" in err_msg 37 | assert "Missing 'strides' attribute."
in err_msg 38 | assert "[Attribute:" not in err_msg 39 | assert "[Expected:" not in err_msg 40 | 41 | msg = "" 42 | 43 | with pytest.raises(QuantizationError) as exc_info_quantization: 44 | raise QuantizationError(msg) 45 | # Assert contains generic error message from quantization error 46 | assert str(exc_info_quantization.value) in err_msg 47 | 48 | 49 | @pytest.mark.unit 50 | def test_invalid_param_error_with_attr_and_expected() -> None: 51 | with pytest.raises(InvalidParamError) as exc_info: 52 | raise InvalidParamError( 53 | node_name="MaxPool_3", 54 | op_type="MaxPool", 55 | message="Kernel shape is invalid.", 56 | attr_key="kernel_shape", 57 | expected="a list of 2 positive integers", 58 | ) 59 | err_msg = str(exc_info.value) 60 | assert "Invalid parameters in node 'MaxPool_3'" in err_msg 61 | assert "[Attribute: kernel_shape]" in err_msg 62 | assert "[Expected: a list of 2 positive integers]" in err_msg 63 | msg = "" 64 | 65 | with pytest.raises(QuantizationError) as exc_info_quantization: 66 | raise QuantizationError(msg) 67 | # Assert contains generic error message from quantization error 68 | assert str(exc_info_quantization.value) in err_msg 69 | 70 | 71 | @pytest.mark.unit 72 | def test_unsupported_op_error_with_node() -> None: 73 | with pytest.raises(UnsupportedOpError) as exc_info: 74 | raise UnsupportedOpError(op_type="Resize", node_name="Resize_42") 75 | err_msg = str(exc_info.value) 76 | assert "Unsupported op type: 'Resize'" in err_msg 77 | assert "in node 'Resize_42'" in err_msg 78 | assert "documentation for supported layers" in err_msg 79 | msg = "" 80 | 81 | with pytest.raises(QuantizationError) as exc_info_quantization: 82 | raise QuantizationError(msg) 83 | # Assert contains generic error message from quantization error 84 | assert str(exc_info_quantization.value) in err_msg 85 | 86 | 87 | @pytest.mark.unit 88 | def test_unsupported_op_error_without_node() -> None: 89 | with pytest.raises(UnsupportedOpError) as exc_info: 90 | raise UnsupportedOpError(op_type="Upsample") 91 | err_msg = str(exc_info.value) 92 | assert "Unsupported op type: 'Upsample'" in err_msg 93 | assert "in node" not in err_msg 94 | msg = "" 95 | 96 | with pytest.raises(QuantizationError) as exc_info_quantization: 97 | raise QuantizationError(msg) 98 | # Assert contains generic error message from quantization error 99 | assert str(exc_info_quantization.value) in err_msg 100 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/layers/add.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | /// External crate imports 4 | use ndarray::ArrayD; 5 | 6 | /// `ExpanderCompilerCollection` imports 7 | use expander_compiler::frontend::{Config, RootAPI, Variable}; 8 | 9 | use crate::circuit_functions::gadgets::linear_algebra::matrix_addition; 10 | use crate::circuit_functions::utils::onnx_model::get_optional_w_or_b; 11 | use crate::circuit_functions::utils::tensor_ops::{ 12 | broadcast_two_arrays, load_array_constants_or_get_inputs, 13 | }; 14 | use crate::circuit_functions::{ 15 | CircuitError, 16 | layers::{LayerError, LayerKind, layer_ops::LayerOp}, 17 | utils::{ 18 | constants::INPUT, 19 | graph_pattern_matching::PatternRegistry, 20 | onnx_model::{extract_params_and_expected_shape, get_input_name}, 21 | }, 22 | }; 23 | 24 | // -------- Struct -------- 25 | #[allow(dead_code)] 26 | #[derive(Debug)] 27 | pub struct AddLayer { 28 | name: String, 29 | // 
weights: ArrayD<i64>, //This should be an optional field 30 | optimization_pattern: PatternRegistry, 31 | input_shape: Vec<usize>, 32 | inputs: Vec<String>, 33 | outputs: Vec<String>, 34 | initializer_a: Option<ArrayD<i64>>, 35 | initializer_b: Option<ArrayD<i64>>, 36 | } 37 | 38 | // -------- Implementation -------- 39 | 40 | impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for AddLayer { 41 | fn apply( 42 | &self, 43 | api: &mut Builder, 44 | input: HashMap<String, ArrayD<Variable>>, 45 | ) -> Result<(Vec<String>, ArrayD<Variable>), CircuitError> { 46 | let a_name = get_input_name(&self.inputs, 0, LayerKind::Add, INPUT)?; 47 | let b_name = get_input_name(&self.inputs, 1, LayerKind::Add, INPUT)?; 48 | 49 | let a_input = load_array_constants_or_get_inputs( 50 | api, 51 | &input, 52 | a_name, 53 | &self.initializer_a, 54 | LayerKind::Add, 55 | )?; 56 | 57 | let b_input = load_array_constants_or_get_inputs( 58 | api, 59 | &input, 60 | b_name, 61 | &self.initializer_b, 62 | LayerKind::Add, 63 | )?; 64 | 65 | let (a_bc, b_bc) = broadcast_two_arrays(&a_input, &b_input)?; 66 | 67 | // Matrix multiplication and bias addition 68 | let result = matrix_addition(api, &a_bc, b_bc, LayerKind::Add)?; 69 | Ok((self.outputs.clone(), result)) 70 | } 71 | fn build( 72 | layer: &crate::circuit_functions::utils::onnx_types::ONNXLayer, 73 | _circuit_params: &crate::circuit_functions::utils::onnx_model::CircuitParams, 74 | optimization_pattern: crate::circuit_functions::utils::graph_pattern_matching::PatternRegistry, 75 | _is_rescale: bool, 76 | _index: usize, 77 | layer_context: &crate::circuit_functions::utils::build_layers::BuildLayerContext, 78 | ) -> Result<Box<dyn LayerOp<C, Builder>>, CircuitError> { 79 | let (_params, expected_shape) = extract_params_and_expected_shape(layer_context, layer) 80 | .map_err(|e| LayerError::Other { 81 | layer: LayerKind::Add, 82 | msg: format!("extract_params_and_expected_shape failed: {e}"), 83 | })?; 84 | 85 | let initializer_a = get_optional_w_or_b(layer_context, &layer.inputs[0])?; 86 | let initializer_b = get_optional_w_or_b(layer_context, &layer.inputs[1])?; 87 | 88 | let add = Self { 89 | name: layer.name.clone(), 90 | // weights: get_w_or_b(&layer_context.w_and_b_map, &layer.inputs[1])?, 91 | optimization_pattern, 92 | input_shape: expected_shape.clone(), 93 | inputs: layer.inputs.clone(), 94 | outputs: layer.outputs.clone(), 95 | initializer_a, 96 | initializer_b, 97 | }; 98 | Ok(Box::new(add)) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_custom_ops/conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as f 6 | from onnxruntime_extensions import PyCustomOpDef, onnx_op 7 | 8 | from .custom_helpers import parse_attr, rescaling 9 | 10 | 11 | @onnx_op( 12 | op_type="Int64Conv", 13 | domain="ai.onnx.contrib", 14 | inputs=[ 15 | PyCustomOpDef.dt_int64, # X 16 | PyCustomOpDef.dt_int64, # W 17 | PyCustomOpDef.dt_int64, # B 18 | PyCustomOpDef.dt_int64, # scaling factor 19 | ], 20 | outputs=[PyCustomOpDef.dt_int64], 21 | attrs={ 22 | "auto_pad": PyCustomOpDef.dt_string, 23 | "strides": PyCustomOpDef.dt_string, 24 | "pads": PyCustomOpDef.dt_string, 25 | "dilations": PyCustomOpDef.dt_string, 26 | "group": PyCustomOpDef.dt_int64, 27 | "kernel_shape": PyCustomOpDef.dt_string, 28 | "rescale": PyCustomOpDef.dt_int64, 29 | }, 30 | ) 31 | def int64_conv( 32 | x: np.ndarray, 33 | w: np.ndarray, 34 | b: np.ndarray | None = None, 35 | scaling_factor: np.ndarray | None = None, 36 | auto_pad: str | None =
None, 37 | dilations: str | None = None, 38 | group: int | None = None, 39 | kernel_shape: str | None = None, 40 | pads: str | None = None, 41 | strides: str | None = None, 42 | rescale: int | None = None, 43 | ) -> np.ndarray: 44 | """ 45 | Performs a convolution on int64 input tensors. 46 | 47 | This function is registered as a custom ONNX operator via onnxruntime_extensions 48 | and is used in the JSTprove quantized inference pipeline. It parses ONNX-style 49 | convolution attributes, applies convolution 50 | and optionally rescales the result. 51 | 52 | Parameters 53 | ---------- 54 | X : Input tensor with dtype int64. 55 | W : Convolution weight tensor with dtype int64. 56 | B : Optional bias tensor with dtype int64. 57 | scaling_factor : Scaling factor for rescaling the output. 58 | auto_pad : Optional ONNX auto padding type (`SAME_UPPER`, `SAME_LOWER`, `VALID`). 59 | dilations : Dilation values for the convolution (default: `[1, 1]`). 60 | group : Group value for the convolution (default: 1). 61 | kernel_shape : Kernel shape (default: `[3, 3]`). 62 | pads : Padding values (default: `[0, 0, 0, 0]`). 63 | strides : Stride values (default: `[1, 1]`). 64 | rescale : Optional flag to apply output rescaling or not. 65 | 66 | Returns 67 | ------- 68 | numpy.ndarray 69 | Convolved tensor with dtype int64. 70 | 71 | Notes 72 | ----- 73 | - This op is part of the `ai.onnx.contrib` custom domain. 74 | - ONNX Runtime Extensions is required to register this op. 75 | 76 | References 77 | ---------- 78 | For more information on the convolution operation, please refer to the 79 | ONNX standard Conv operator documentation: 80 | https://onnx.ai/onnx/operators/onnx__Conv.html 81 | """ 82 | _ = auto_pad 83 | try: 84 | strides = parse_attr(strides, [1, 1]) 85 | dilations = parse_attr(dilations, [1, 1]) 86 | pads = parse_attr(pads, [0, 0, 0, 0]) 87 | kernel_shape = parse_attr(kernel_shape, [3, 3]) 88 | 89 | x = torch.from_numpy(x) 90 | w = torch.from_numpy(w) 91 | b = torch.from_numpy(b) if b is not None else None # bias is optional 92 | 93 | result = ( 94 | f.conv2d( 95 | x, 96 | w, 97 | bias=b, 98 | stride=strides, 99 | padding=pads[:2], 100 | dilation=dilations, 101 | groups=group if group is not None else 1, # ONNX default group is 1 102 | ) 103 | .numpy() 104 | .astype(np.int64) 105 | ) 106 | result = rescaling(scaling_factor, rescale, result) 107 | return result.astype(np.int64) 108 | 109 | except Exception as e: 110 | msg = f"Int64Conv failed: {e}" 111 | raise RuntimeError(msg) from e 112 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/utils/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Debug, Error)] 4 | pub enum UtilsError { 5 | #[error("Missing param '{param}' for layer '{layer}'")] 6 | MissingParam { layer: String, param: String }, 7 | 8 | #[error("{layer_name} is missing input: {name}")] 9 | MissingInput { layer_name: String, name: String }, 10 | 11 | #[error("Failed to parse param '{param}' for layer '{layer}': {source}")] 12 | ParseError { 13 | layer: String, 14 | param: String, 15 | #[source] 16 | source: serde_json::Error, 17 | }, 18 | 19 | #[error("Cannot convert variable of type '{initial_var_type}' to '{converted_var_type}'")] 20 | ValueConversionError { 21 | initial_var_type: String, 22 | converted_var_type: String, 23 | }, 24 | 25 | #[error("Inputs length mismatch: got {got}, required {required}")] 26 | InputDataLengthMismatch { got: usize, required: usize }, 27 | 28 | #[error("Missing tensor '{tensor}' in weights
map")] 29 | MissingTensor { tensor: String }, 30 | 31 | #[error("Bitstring too long: bit index {value} exceeds limit {max}")] 32 | ValueTooLarge { value: usize, max: u128 }, 33 | 34 | #[error("Expected number, but got {value}")] 35 | InvalidNumber { value: serde_json::Value }, 36 | 37 | #[error("Graph error: {0}")] 38 | GraphPatternError(#[from] PatternError), 39 | 40 | #[error("Array conversion error: {0}")] 41 | ArrayConversionError(#[from] ArrayConversionError), 42 | 43 | #[error("Rescaling error: {0}")] 44 | RescaleError(#[from] RescaleError), 45 | 46 | #[error("Build error: {0}")] 47 | BuildError(#[from] BuildError), 48 | } 49 | 50 | #[derive(Debug, thiserror::Error)] 51 | pub enum ArrayConversionError { 52 | #[error("Invalid array structure: expected {expected}, found {found}")] 53 | InvalidArrayStructure { expected: String, found: String }, 54 | 55 | #[error("Invalid number for target type")] 56 | InvalidNumber, 57 | 58 | #[error("Shape error: {0}")] 59 | ShapeError(#[from] ndarray::ShapeError), 60 | 61 | #[error("Invalid axis {axis} for rank {rank}")] 62 | InvalidAxis { axis: usize, rank: usize }, 63 | } 64 | 65 | #[derive(Debug, Error)] 66 | pub enum PatternError { 67 | #[error("Optimization matches have inconsistent patterns: expected {expected}, got {got}")] 68 | InconsistentPattern { expected: String, got: String }, 69 | 70 | #[error("Outputs do not match: expected all in {expected:?}, but got {actual:?}")] 71 | OutputMismatch { 72 | expected: Vec, 73 | actual: Vec, 74 | }, 75 | 76 | #[error("Optimization pattern has no layers")] 77 | EmptyMatch, 78 | 79 | #[error( 80 | "Developer error: Empty optimization pattern {pattern} has been attempted to be created. This is not allowed" 81 | )] 82 | EmptyPattern { pattern: String }, 83 | } 84 | 85 | #[derive(Debug, Error)] 86 | pub enum RescaleError { 87 | #[error("Exponent too large for {type_name} shift: scaling_exponent={exp}")] 88 | ScalingExponentTooLargeError { exp: usize, type_name: &'static str }, 89 | 90 | #[error("Exponent too large for {type_name} shift: shift_exponent={exp}")] 91 | ShiftExponentTooLargeError { exp: usize, type_name: &'static str }, 92 | 93 | #[error("Bit decomposition failed for {var_name} into {n_bits} bits")] 94 | BitDecompositionError { var_name: String, n_bits: usize }, 95 | 96 | #[error("Bit reconstruction failed for {var_name} into {n_bits} bits")] 97 | BitReconstructionError { var_name: String, n_bits: usize }, 98 | } 99 | 100 | #[derive(thiserror::Error, Debug)] 101 | pub enum BuildError { 102 | #[error("Pattern matcher failed: {0}")] 103 | PatternMatcher(#[from] PatternError), 104 | 105 | #[error("Unsupported layer type: {0}")] 106 | UnsupportedLayer(String), 107 | 108 | #[error("Layer build failed: {0}")] 109 | LayerBuild(String), 110 | } 111 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/min_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | 5 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 6 | from python.tests.onnx_quantizer_tests.layers.base import ( 7 | e2e_test, 8 | edge_case_test, 9 | valid_test, 10 | ) 11 | from python.tests.onnx_quantizer_tests.layers.factory import ( 12 | BaseLayerConfigProvider, 13 | LayerTestConfig, 14 | ) 15 | 16 | 17 | class MinConfigProvider(BaseLayerConfigProvider): 18 | """Test configuration provider for elementwise Min""" 19 | 20 | @property 21 | def 
layer_name(self) -> str: 22 | return "Min" 23 | 24 | def get_config(self) -> LayerTestConfig: 25 | return LayerTestConfig( 26 | op_type="Min", 27 | valid_inputs=["A", "B"], 28 | valid_attributes={}, # Min has no layer-specific attributes 29 | required_initializers={}, # default: both A and B are dynamic inputs 30 | input_shapes={ 31 | "A": [1, 3, 4, 4], 32 | "B": [1, 3, 4, 4], 33 | }, 34 | output_shapes={ 35 | "min_output": [1, 3, 4, 4], 36 | }, 37 | ) 38 | 39 | def get_test_specs(self) -> list: 40 | rng = np.random.default_rng(TEST_RNG_SEED) 41 | return [ 42 | # --- VALID TESTS --- 43 | valid_test("basic") 44 | .description("Basic elementwise Min of two same-shaped tensors") 45 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 46 | .tags("basic", "elementwise", "min") 47 | .build(), 48 | valid_test("broadcast_min") 49 | .description("Min with Numpy-style broadcasting along spatial dimensions") 50 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 51 | .tags("broadcast", "elementwise", "min", "onnx14") 52 | .build(), 53 | valid_test("initializer_min") 54 | .description("Min where B is an initializer instead of an input") 55 | .override_input_shapes(A=[1, 3, 4, 4]) 56 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 57 | .tags("initializer", "elementwise", "min", "onnxruntime") 58 | .build(), 59 | # --- E2E TESTS --- 60 | e2e_test("e2e_min") 61 | .description("End-to-end Min test with random inputs") 62 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 63 | .override_output_shapes(min_output=[1, 3, 4, 4]) 64 | .tags("e2e", "min", "2d") 65 | .build(), 66 | e2e_test("e2e_broadcast_min") 67 | .description( 68 | "End-to-end Min with Numpy-style broadcasting along spatial dimensions", 69 | ) 70 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 71 | .override_output_shapes(min_output=[1, 3, 4, 4]) 72 | .tags("e2e", "broadcast", "elementwise", "min", "onnx14") 73 | .build(), 74 | e2e_test("e2e_initializer_min") 75 | .description("End-to-end Min where B is an initializer") 76 | .override_input_shapes(A=[1, 3, 4, 4]) 77 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 78 | .override_output_shapes(min_output=[1, 3, 4, 4]) 79 | .tags("e2e", "initializer", "elementwise", "min", "onnxruntime") 80 | .build(), 81 | # --- EDGE / STRESS --- 82 | edge_case_test("empty_tensor") 83 | .description("Min with empty tensor input (zero elements)") 84 | .override_input_shapes(A=[0], B=[0]) 85 | .override_output_shapes(min_output=[0]) 86 | .tags("edge", "empty", "min") 87 | .build(), 88 | valid_test("large_tensor") 89 | .description("Large tensor min performance/stress test") 90 | .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256]) 91 | .tags("large", "performance", "min") 92 | .skip("Performance test, skipped by default") 93 | .build(), 94 | ] 95 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import pytest 6 | from onnx import TensorProto, helper 7 | 8 | from python.core.model_processing.onnx_quantizer.exceptions import ( 9 | InvalidParamError, 10 | UnsupportedOpError, 11 | ) 12 | 13 | if TYPE_CHECKING: 14 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( 15 | ONNXOpQuantizer, 16 | ) 17 | from python.tests.onnx_quantizer_tests.layers.base import ( 18 | 
LayerTestConfig, 19 | LayerTestSpec, 20 | SpecType, 21 | ) 22 | from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory 23 | from python.tests.onnx_quantizer_tests.layers_tests.base_test import ( 24 | BaseQuantizerTest, 25 | ) 26 | 27 | 28 | class TestCheckModel(BaseQuantizerTest): 29 | """Tests for ONNX model checking.""" 30 | 31 | __test__ = True 32 | 33 | @pytest.mark.unit 34 | @pytest.mark.parametrize( 35 | "test_case_data", 36 | TestLayerFactory.get_test_cases_by_type(SpecType.VALID), # type: ignore[arg-type] 37 | ids=BaseQuantizerTest._generate_test_id, 38 | ) 39 | def test_check_model_individual_valid_cases( 40 | self: TestCheckModel, 41 | quantizer: ONNXOpQuantizer, 42 | test_case_data: tuple[str, LayerTestConfig, LayerTestSpec], 43 | ) -> None: 44 | """Test each individual valid test case""" 45 | layer_name, config, test_spec = test_case_data 46 | 47 | # Skips if layer is not a valid onnx layer 48 | self._check_validation_dependency(test_case_data) 49 | 50 | if test_spec.skip_reason: 51 | pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}") 52 | 53 | # Create model from layer specs 54 | model = config.create_test_model(test_spec) 55 | 56 | try: 57 | quantizer.check_model(model) 58 | except (InvalidParamError, UnsupportedOpError) as e: 59 | pytest.fail(f"Model check failed for {layer_name}.{test_spec.name}: {e}") 60 | except Exception as e: 61 | pytest.fail(f"Model check failed for {layer_name}.{test_spec.name}: {e}") 62 | 63 | @pytest.mark.unit 64 | def test_check_model_unsupported_layer_fails( 65 | self: TestCheckModel, 66 | quantizer: ONNXOpQuantizer, 67 | ) -> None: 68 | """Test that models with unsupported layers fail validation""" 69 | # Create model with unsupported operation 70 | unsupported_node = helper.make_node( 71 | "UnsupportedOp", 72 | inputs=["input"], 73 | outputs=["output"], 74 | name="unsupported", 75 | ) 76 | 77 | graph = helper.make_graph( 78 | [unsupported_node], 79 | "test_graph", 80 | [ 81 | helper.make_tensor_value_info( 82 | "input", 83 | TensorProto.FLOAT, 84 | [1, 16, 224, 224], 85 | ), 86 | ], 87 | [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 10])], 88 | ) 89 | 90 | model = helper.make_model(graph) 91 | 92 | with pytest.raises(UnsupportedOpError): 93 | quantizer.check_model(model) 94 | 95 | @pytest.mark.unit 96 | @pytest.mark.parametrize( 97 | "layer_combination", 98 | [ 99 | ["Conv", "Relu"], 100 | ["Conv", "Relu", "MaxPool"], 101 | ["Gemm", "Relu"], 102 | ["Conv", "Reshape", "Gemm"], 103 | ["Conv", "Flatten", "Gemm"], 104 | ], 105 | ) 106 | def test_check_model_multi_layer_passes( 107 | self: TestCheckModel, 108 | quantizer: ONNXOpQuantizer, 109 | layer_configs: dict[str, LayerTestConfig], 110 | layer_combination: list[str], 111 | ) -> None: 112 | """Test that models with multiple supported layers pass validation""" 113 | model = self.create_model_with_layers(layer_combination, layer_configs) 114 | # Should not raise any exception 115 | quantizer.check_model(model) 116 | -------------------------------------------------------------------------------- /python/core/utils/errors.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | class CircuitExecutionError(Exception): 5 | """Base exception for all circuit execution-related errors.""" 6 | 7 | def __init__(self: CircuitExecutionError, message: str) -> None: 8 | super().__init__(message) 9 | self.message = message 10 | 11 | 12 | class 
MissingFileError(CircuitExecutionError): 13 | """Raised when a file cannot be found.""" 14 | 15 | def __init__(self: MissingFileError, message: str, path: str | None = None) -> None: 16 | full_message = message if path is None else f"{message} [Path: {path}]" 17 | super().__init__(full_message) 18 | self.path = path 19 | 20 | 21 | class FileCacheError(CircuitExecutionError): 22 | """Raised when reading or writing cached output fails.""" 23 | 24 | def __init__(self: FileCacheError, message: str, path: str | None = None) -> None: 25 | full_message = message if path is None else f"{message} [Path: {path}]" 26 | super().__init__(full_message) 27 | self.path = path 28 | 29 | 30 | class ProofBackendError(CircuitExecutionError): 31 | """Raised when a Cargo command fails.""" 32 | 33 | def __init__( 34 | self: ProofBackendError, 35 | message: str, 36 | command: list[str] | None = None, 37 | returncode: int | None = None, 38 | stdout: str | None = None, 39 | stderr: str | None = None, 40 | ) -> None: 41 | parts = [message] 42 | if command is not None: 43 | command2 = [str(c) for c in command] 44 | parts.append(f"Command: {' '.join(command2)}") 45 | command = command2 46 | if returncode is not None: 47 | parts.append(f"Exit code: {returncode}") 48 | if stdout: 49 | parts.append(f"STDOUT:\n{stdout}") 50 | if stderr: 51 | parts.append(f"STDERR:\n{stderr}") 52 | full_message = "\n".join(parts) 53 | super().__init__(full_message) 54 | self.command = command 55 | self.returncode = returncode 56 | self.stdout = stdout 57 | self.stderr = stderr 58 | 59 | 60 | class ProofSystemNotImplementedError(CircuitExecutionError): 61 | """Raised when a proof system is not implemented.""" 62 | 63 | def __init__(self: ProofSystemNotImplementedError, proof_system: object) -> None: 64 | message = f"Proof system '{proof_system}' is not implemented."
65 | super().__init__(message) 66 | self.proof_system = proof_system 67 | 68 | 69 | class CircuitUtilsError(Exception): 70 | """Base exception for layer utility errors.""" 71 | 72 | 73 | class InputFileError(CircuitUtilsError): 74 | """Raised when reading an input file fails.""" 75 | 76 | def __init__( 77 | self: InputFileError, 78 | file_path: str, 79 | message: str, 80 | *, 81 | cause: Exception | None = None, 82 | ) -> None: 83 | full_msg = f"Failed to read input file '{file_path}': {message}" 84 | super().__init__(full_msg) 85 | self.file_path = file_path 86 | self.__cause__ = cause 87 | 88 | 89 | class MissingCircuitAttributeError(CircuitUtilsError): 90 | """Raised when a required attribute is missing or not set.""" 91 | 92 | def __init__( 93 | self: MissingCircuitAttributeError, 94 | attribute_name: str, 95 | context: str | None = None, 96 | ) -> None: 97 | msg = f"Required attribute '{attribute_name}' is missing" 98 | if context: 99 | msg += f" ({context})" 100 | super().__init__(msg) 101 | self.attribute_name = attribute_name 102 | 103 | 104 | class ShapeMismatchError(CircuitUtilsError): 105 | """Raised when reshaping tensors fails due to incompatible shapes.""" 106 | 107 | def __init__( 108 | self: ShapeMismatchError, 109 | expected_shape: list[int], 110 | actual_shape: list[int], 111 | ) -> None: 112 | super().__init__( 113 | f"Cannot reshape tensor of shape {actual_shape}" 114 | f" to expected shape {expected_shape}", 115 | ) 116 | self.expected_shape = expected_shape 117 | self.actual_shape = actual_shape 118 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/constant.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ClassVar 4 | 5 | import numpy as np 6 | import onnx 7 | from onnx import numpy_helper 8 | 9 | from python.core.model_processing.onnx_quantizer.exceptions import ( 10 | HandlerImplementationError, 11 | ) 12 | from python.core.model_processing.onnx_quantizer.layers.base import ( 13 | BaseOpQuantizer, 14 | ScaleConfig, 15 | ) 16 | 17 | 18 | class ConstantQuantizer(BaseOpQuantizer): 19 | """ 20 | Quantizer for ONNX Constant node. 21 | 22 | This quantizer only modifies constants that are: 23 | - Numeric tensors 24 | - Used directly in computation 25 | 26 | Constants used for shape, indexing, or other non-numeric roles are left unchanged. 27 | """ 28 | 29 | DATA_OPS: ClassVar = { 30 | "Add", 31 | "Mul", 32 | "Conv", 33 | "MatMul", 34 | "Sub", 35 | "Div", 36 | "Gemm", 37 | } # ops that consume numeric constants 38 | 39 | def __init__( 40 | self: ConstantQuantizer, 41 | new_initializer: list[onnx.TensorProto] | None = None, 42 | ) -> None: 43 | super().__init__() 44 | _ = new_initializer 45 | 46 | def quantize( 47 | self: ConstantQuantizer, 48 | node: onnx.NodeProto, 49 | graph: onnx.GraphProto, 50 | scale_config: ScaleConfig, 51 | initializer_map: dict[str, onnx.TensorProto], 52 | ) -> list[onnx.NodeProto]: 53 | """Apply quantization scaling to a constant if it is used in 54 | numeric computation. 55 | 56 | Args: 57 | node (onnx.NodeProto): The Constant node to quantize. 58 | graph (onnx.GraphProto): The ONNX graph. 59 | scale_config (ScaleConfig): Scaling configuration; its base and 60 | exponent define the fixed-point scale factor (rescaling does 61 | not affect this op type
62 | in some cases). 63 | initializer_map (dict[str, onnx.TensorProto]): 64 | Map of initializer names to tensor data. 65 | 66 | Returns: 67 | list[onnx.NodeProto]: The modified node (possibly unchanged). 68 | 69 | Raises: 70 | HandlerImplementationError: If tensor is unreadable 71 | """ 72 | _ = initializer_map 73 | self.validate_node_has_output(node) 74 | 75 | output_name = node.output[0] 76 | 77 | is_data_constant = any( 78 | output_name in n.input and n.op_type in self.DATA_OPS for n in graph.node 79 | ) 80 | 81 | if not is_data_constant: 82 | # Skip quantization for non-numeric constants 83 | return [node] 84 | 85 | # Safe to quantize: numeric constant used in computation 86 | for attr in node.attribute: 87 | if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: 88 | try: 89 | arr = numpy_helper.to_array(attr.t).astype(np.float64) 90 | except Exception as e: 91 | raise HandlerImplementationError( 92 | op_type="Constant", 93 | message="Failed to read tensor from Constant node" 94 | f" '{node.name}': {e}", 95 | ) from e 96 | 97 | arr *= self.get_scaling( 98 | scale_config.base, 99 | scale_config.exponent, 100 | ) 101 | attr.t.CopyFrom(numpy_helper.from_array(arr, name="")) 102 | 103 | node.name += "_quant" 104 | return [node] 105 | 106 | def check_supported( 107 | self: ConstantQuantizer, 108 | node: onnx.NodeProto, 109 | initializer_map: dict[str, onnx.TensorProto] | None = None, 110 | ) -> None: 111 | """All Constant nodes are supported... For now. 112 | 113 | Args: 114 | node (onnx.NodeProto): Node to be checked 115 | initializer_map (dict[str, onnx.TensorProto], optional): 116 | Map of initializer names to tensor data. Defaults to None. 117 | """ 118 | _ = node, initializer_map 119 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | # CLI Reference 2 | 3 | The JSTprove CLI runs four steps: **compile → witness → prove → verify**. It's intentionally barebones: no circuit class flags, no path inference. You must pass correct paths. 4 | 5 | ## Installation 6 | 7 | **For development (editable install in venv):** 8 | ```bash 9 | uv sync && uv pip install -e . 10 | ``` 11 | 12 | **For regular use (global install):** 13 | ```bash 14 | uv tool install . 15 | ``` 16 | 17 | --- 18 | 19 | ## Synopsis 20 | 21 | ```bash 22 | jst [--no-banner] <command> [options] 23 | ``` 24 | 25 | * `--no-banner` — suppress the ASCII header. 26 | * Abbreviations are **disabled**; use the full subcommand or an alias. 27 | 28 | --- 29 | 30 | ## Help 31 | 32 | ```bash 33 | jst --help 34 | jst <command> --help 35 | # e.g. 36 | jst witness --help 37 | ``` 38 | 39 | --- 40 | 41 | ## Example paths used below 42 | 43 | * ONNX model: `python/models/models_onnx/lenet.onnx` 44 | * Example input JSON: `python/models/inputs/lenet_input.json` 45 | * Artifacts: `artifacts/lenet/*` 46 | 47 | --- 48 | 49 | ## Commands 50 | 51 | ### model_check 52 | 53 | Check whether the supplied ONNX model meets the criteria for JSTprove's currently supported layers and parameters. 54 | 55 | **Options** 56 | 57 | * `-m, --model-path <path>` (required) — original ONNX model 58 | 59 | **Example** 60 | 61 | ```bash 62 | jst model_check \ 63 | -m python/models/models_onnx/lenet.onnx 64 | ``` 65 | 66 | ### compile (alias: `comp`) 67 | 68 | Generate a circuit file and a **quantized ONNX** model.
69 | 70 | **Options** 71 | 72 | * `-m, --model-path <path>` (required) — original ONNX model 73 | * `-c, --circuit-path <path>` (required) — output circuit path 74 | 75 | **Example** 76 | 77 | ```bash 78 | jst compile \ 79 | -m python/models/models_onnx/lenet.onnx \ 80 | -c artifacts/lenet/circuit.txt 81 | ``` 82 | 83 | --- 84 | 85 | ### witness (alias: `wit`) 86 | 87 | Reshapes/scales inputs, runs the quantized model to produce outputs, and writes the witness. 88 | 89 | **Options** 90 | 91 | * `-c, --circuit-path <path>` (required) — compiled circuit 92 | * `-i, --input-path <path>` (required) — input JSON 93 | * `-o, --output-path <path>` (required) — output JSON (written) 94 | * `-w, --witness-path <path>` (required) — witness file (written) 95 | 96 | **Example** 97 | 98 | ```bash 99 | jst witness \ 100 | -c artifacts/lenet/circuit.txt \ 101 | -i python/models/inputs/lenet_input.json \ 102 | -o artifacts/lenet/output.json \ 103 | -w artifacts/lenet/witness.bin 104 | ``` 105 | 106 | --- 107 | 108 | ### prove (alias: `prov`) 109 | 110 | Create a proof from the circuit + witness. 111 | 112 | **Options** 113 | 114 | * `-c, --circuit-path <path>` (required) — compiled circuit 115 | * `-w, --witness-path <path>` (required) — witness file 116 | * `-p, --proof-path <path>` (required) — proof file (written) 117 | 118 | **Example** 119 | 120 | ```bash 121 | jst prove \ 122 | -c artifacts/lenet/circuit.txt \ 123 | -w artifacts/lenet/witness.bin \ 124 | -p artifacts/lenet/proof.bin 125 | ``` 126 | 127 | --- 128 | 129 | ### verify (alias: `ver`) 130 | 131 | Verify the proof. 132 | 133 | **Options** 134 | 135 | * `-c, --circuit-path <path>` (required) — compiled circuit 136 | * `-i, --input-path <path>` (required) — input JSON 137 | * `-o, --output-path <path>` (required) — expected outputs JSON 138 | * `-w, --witness-path <path>` (required) — witness file 139 | * `-p, --proof-path <path>` (required) — proof file 140 | 141 | **Example** 142 | 143 | ```bash 144 | jst verify \ 145 | -c artifacts/lenet/circuit.txt \ 146 | -i python/models/inputs/lenet_input.json \ 147 | -o artifacts/lenet/output.json \ 148 | -w artifacts/lenet/witness.bin \ 149 | -p artifacts/lenet/proof.bin 150 | ``` 151 | 152 | --- 153 | 154 | ## Short flags 155 | 156 | * `-m, -c, -i, -o, -w, -p` 157 | 158 | ## Command aliases 159 | 160 | * `compile` → `comp` 161 | * `witness` → `wit` 162 | * `prove` → `prov` 163 | * `verify` → `ver` 164 | 165 | --- 166 | 167 | ## Notes & gotchas 168 | 169 | * The default circuit is **GenericModelONNX**; you don’t pass a circuit class or name. 170 | * All paths are **mandatory**; no automatic discovery or inference. 171 | * If the runner isn’t found, make sure you’re launching from the **repo root**. 172 | * The **compile** step will auto-build the runner if needed.
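---

## Full pipeline example

A sketch chaining all four steps with the example paths from this page (nothing beyond the per-command examples above; adjust paths to your own layout):

```bash
jst compile -m python/models/models_onnx/lenet.onnx -c artifacts/lenet/circuit.txt
jst witness -c artifacts/lenet/circuit.txt -i python/models/inputs/lenet_input.json \
  -o artifacts/lenet/output.json -w artifacts/lenet/witness.bin
jst prove -c artifacts/lenet/circuit.txt -w artifacts/lenet/witness.bin -p artifacts/lenet/proof.bin
jst verify -c artifacts/lenet/circuit.txt -i python/models/inputs/lenet_input.json \
  -o artifacts/lenet/output.json -w artifacts/lenet/witness.bin -p artifacts/lenet/proof.bin
```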
173 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/max_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | 5 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 6 | from python.tests.onnx_quantizer_tests.layers.base import ( 7 | e2e_test, 8 | edge_case_test, 9 | valid_test, 10 | ) 11 | from python.tests.onnx_quantizer_tests.layers.base import ( 12 | BaseLayerConfigProvider, 13 | LayerTestConfig, LayerTestSpec, 14 | ) 15 | 16 | 17 | class MaxConfigProvider(BaseLayerConfigProvider): 18 | """Test configuration provider for elementwise Max""" 19 | 20 | @property 21 | def layer_name(self) -> str: 22 | return "Max" 23 | 24 | def get_config(self) -> LayerTestConfig: 25 | return LayerTestConfig( 26 | op_type="Max", 27 | valid_inputs=["A", "B"], 28 | valid_attributes={}, # Max has no layer-specific attributes 29 | required_initializers={}, # default: both A and B are dynamic inputs 30 | input_shapes={ 31 | "A": [1, 3, 4, 4], 32 | "B": [1, 3, 4, 4], 33 | }, 34 | output_shapes={ 35 | "max_output": [1, 3, 4, 4], 36 | }, 37 | ) 38 | 39 | def get_test_specs(self) -> list[LayerTestSpec]: 40 | rng = np.random.default_rng(TEST_RNG_SEED) 41 | return [ 42 | # --- VALID TESTS --- 43 | valid_test("basic") 44 | .description("Basic elementwise Max of two same-shaped tensors") 45 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 46 | .tags("basic", "elementwise", "max") 47 | .build(), 48 | valid_test("broadcast_max") 49 | .description("Max with Numpy-style broadcasting along spatial dimensions") 50 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 51 | .tags("broadcast", "elementwise", "max", "onnx14") 52 | .build(), 53 | valid_test("initializer_max") 54 | .description("Max where B is an initializer instead of an input") 55 | .override_input_shapes(A=[1, 3, 4, 4]) 56 | .override_initializer( 57 | "B", 58 | rng.normal(0, 1, (1, 3, 4, 4)).astype(np.float32), 59 | ) 60 | .tags("initializer", "elementwise", "max", "onnxruntime") 61 | .build(), 62 | # --- E2E TESTS --- 63 | e2e_test("e2e_max") 64 | .description("End-to-end Max test with random inputs") 65 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 66 | .override_output_shapes(max_output=[1, 3, 4, 4]) 67 | .tags("e2e", "max", "2d") 68 | .build(), 69 | e2e_test("e2e_broadcast_max") 70 | .description( 71 | "End-to-end Max with Numpy-style broadcasting along spatial dimensions", 72 | ) 73 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 74 | .override_output_shapes(max_output=[1, 3, 4, 4]) 75 | .tags("e2e", "broadcast", "elementwise", "max", "onnx14") 76 | .build(), 77 | e2e_test("e2e_initializer_max") 78 | .description("End-to-end Max where B is an initializer") 79 | .override_input_shapes(A=[1, 3, 4, 4]) 80 | .override_initializer( 81 | "B", 82 | rng.normal(0, 1, (1, 3, 4, 4)).astype(np.float32), 83 | ) 84 | .override_output_shapes(max_output=[1, 3, 4, 4]) 85 | .tags("e2e", "initializer", "elementwise", "max", "onnxruntime") 86 | .build(), 87 | # --- EDGE / STRESS --- 88 | edge_case_test("empty_tensor") 89 | .description("Max with empty tensor input (zero elements)") 90 | .override_input_shapes(A=[0], B=[0]) 91 | .override_output_shapes(max_output=[0]) 92 | .tags("edge", "empty", "max") 93 | .build(), 94 | edge_case_test("large_tensor") 95 | .description("Large tensor max performance/stress test") 96 | .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256]) 97 | 
.tags("large", "performance", "max") 98 | .skip("Performance test, skipped by default") 99 | .build(), 100 | ] 101 | -------------------------------------------------------------------------------- /python/core/model_processing/onnx_quantizer/layers/gemm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, ClassVar 4 | 5 | if TYPE_CHECKING: 6 | import onnx 7 | 8 | from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError 9 | from python.core.model_processing.onnx_quantizer.layers.base import ( 10 | BaseOpQuantizer, 11 | QuantizerBase, 12 | ScaleConfig, 13 | ) 14 | 15 | 16 | class QuantizeGemm(QuantizerBase): 17 | OP_TYPE = "Int64Gemm" 18 | USE_WB = True 19 | USE_SCALING = True 20 | DEFAULT_ATTRS: ClassVar = {"transA": 0, "transB": 0} 21 | SCALE_PLAN: ClassVar = {1: 1, 2: 2} 22 | 23 | 24 | class GemmQuantizer(BaseOpQuantizer, QuantizeGemm): 25 | """ 26 | Quantizer for ONNX Gemm layers. 27 | 28 | - Replaces standard Gemm with Int64Gemm from the `ai.onnx.contrib` 29 | domain and makes relevant additional changes to the graph. 30 | - Validates that all required Gemm parameters are present. 31 | """ 32 | 33 | def __init__( 34 | self: GemmQuantizer, 35 | new_initializers: list[onnx.TensorProto] | None = None, 36 | ) -> None: 37 | super().__init__() 38 | # Only replace if caller provided something 39 | if new_initializers is not None: 40 | self.new_initializers = new_initializers 41 | 42 | def quantize( 43 | self: GemmQuantizer, 44 | node: onnx.NodeProto, 45 | graph: onnx.GraphProto, 46 | scale_config: ScaleConfig, 47 | initializer_map: dict[str, onnx.TensorProto], 48 | ) -> list[onnx.NodeProto]: 49 | return QuantizeGemm.quantize(self, node, graph, scale_config, initializer_map) 50 | 51 | def check_supported( 52 | self: GemmQuantizer, 53 | node: onnx.NodeProto, 54 | initializer_map: dict[str, onnx.TensorProto] | None = None, 55 | ) -> None: 56 | """ 57 | Perform high-level validation to ensure that this node 58 | can be quantized safely. 59 | 60 | Args: 61 | node (onnx.NodeProto): ONNX node to be checked 62 | initializer_map (dict[str, onnx.TensorProto]): 63 | Initializer map (name of weight or bias and tensor) 64 | 65 | Raises: 66 | InvalidParamError: If any requirement is not met. 
67 | """ 68 | _ = initializer_map 69 | num_valid_inputs = 2 70 | # Ensure inputs exist 71 | if len(node.input) < num_valid_inputs: 72 | raise InvalidParamError( 73 | node.name, 74 | node.op_type, 75 | f"Expected at least 2 inputs (input, weights), got {len(node.input)}", 76 | ) 77 | num_valid_inputs = 3 78 | 79 | if len(node.input) < num_valid_inputs: 80 | raise InvalidParamError( 81 | node.name, 82 | node.op_type, 83 | "Expected at least 3 inputs (input, weights, bias)" 84 | f", got {len(node.input)}", 85 | ) 86 | 87 | # Validate attributes with defaults 88 | attrs = {attr.name: attr for attr in node.attribute} 89 | alpha = getattr(attrs.get("alpha"), "f", 1.0) 90 | beta = getattr(attrs.get("beta"), "f", 1.0) 91 | trans_a = getattr(attrs.get("transA"), "i", 0) 92 | trans_b = getattr(attrs.get("transB"), "i", 1) 93 | 94 | if alpha != 1.0: 95 | raise InvalidParamError( 96 | node.name, 97 | node.op_type, 98 | f"alpha value of {alpha} not supported", 99 | "alpha", 100 | "1.0", 101 | ) 102 | if beta != 1.0: 103 | raise InvalidParamError( 104 | node.name, 105 | node.op_type, 106 | f"beta value of {beta} not supported", 107 | "beta", 108 | "1.0", 109 | ) 110 | if trans_a not in [0, 1]: 111 | raise InvalidParamError( 112 | node.name, 113 | node.op_type, 114 | f"transA value of {trans_a} not supported", 115 | "transA", 116 | "(0,1)", 117 | ) 118 | if trans_b not in [0, 1]: 119 | raise InvalidParamError( 120 | node.name, 121 | node.op_type, 122 | f"transB value of {trans_b} not supported", 123 | "transB", 124 | "(0,1)", 125 | ) 126 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/add_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 4 | from python.tests.onnx_quantizer_tests.layers.base import ( 5 | BaseLayerConfigProvider, 6 | LayerTestConfig, 7 | LayerTestSpec, 8 | e2e_test, 9 | edge_case_test, 10 | valid_test, 11 | ) 12 | 13 | 14 | class AddConfigProvider(BaseLayerConfigProvider): 15 | """Test configuration provider for Add layer""" 16 | 17 | @property 18 | def layer_name(self) -> str: 19 | return "Add" 20 | 21 | def get_config(self) -> LayerTestConfig: 22 | return LayerTestConfig( 23 | op_type="Add", 24 | valid_inputs=["A", "B"], 25 | valid_attributes={}, # Add has no layer-specific attributes 26 | required_initializers={}, 27 | input_shapes={ 28 | "A": [1, 3, 4, 4], 29 | "B": [1, 3, 4, 4], 30 | }, 31 | output_shapes={ 32 | "add_output": [1, 3, 4, 4], 33 | }, 34 | ) 35 | 36 | def get_test_specs(self) -> list[LayerTestSpec]: 37 | rng = np.random.default_rng(TEST_RNG_SEED) 38 | return [ 39 | # --- VALID TESTS --- 40 | valid_test("basic") 41 | .description("Basic elementwise Add of two same-shaped tensors") 42 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 43 | .tags("basic", "elementwise", "add") 44 | .build(), 45 | valid_test("broadcast_add") 46 | .description("Add with Numpy-style broadcasting along spatial dimensions") 47 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 48 | .tags("broadcast", "elementwise", "add", "onnx14") 49 | .build(), 50 | valid_test("initializer_add") 51 | .description( 52 | "Add where second input (B) is a tensor initializer instead of input", 53 | ) 54 | .override_input_shapes(A=[1, 3, 4, 4]) 55 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 56 | .tags("initializer", "elementwise", "add", "onnxruntime") 57 | .build(), 58 | 
valid_test("scalar_add") 59 | .description("Add scalar (initializer) to tensor") 60 | .override_input_shapes(A=[1, 3, 4, 4]) 61 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 62 | .tags("scalar", "elementwise", "add") 63 | .build(), 64 | # --- E2E TESTS --- 65 | e2e_test("e2e_add") 66 | .description("End-to-end Add test with random inputs") 67 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 68 | .override_output_shapes(add_output=[1, 3, 4, 4]) 69 | .tags("e2e", "add", "2d") 70 | .build(), 71 | e2e_test("e2e_initializer_add") 72 | .description( 73 | "Add where second input (B) is a tensor initializer instead of input", 74 | ) 75 | .override_input_shapes(A=[1, 3, 4, 4]) 76 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 77 | .tags("initializer", "elementwise", "add", "onnxruntime") 78 | .build(), 79 | e2e_test("e2e_broadcast_add") 80 | .description("Add with Numpy-style broadcasting along spatial dimensions") 81 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 82 | .tags("broadcast", "elementwise", "add", "onnx14") 83 | .build(), 84 | e2e_test("e2e_scalar_add") 85 | .description("Add scalar (initializer) to tensor") 86 | .override_input_shapes(A=[1, 3, 4, 4]) 87 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 88 | .tags("scalar", "elementwise", "add") 89 | .build(), 90 | # # --- EDGE CASES --- 91 | edge_case_test("empty_tensor") 92 | .description("Add with empty tensor input (zero elements)") 93 | .override_input_shapes(A=[0], B=[0]) 94 | .tags("edge", "empty", "add") 95 | .build(), 96 | edge_case_test("large_tensor") 97 | .description("Large tensor add performance/stress test") 98 | .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256]) 99 | .tags("large", "performance", "add") 100 | .skip("Performance test, skipped by default") 101 | .build(), 102 | ] 103 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/mul_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 4 | from python.tests.onnx_quantizer_tests.layers.base import ( 5 | BaseLayerConfigProvider, 6 | LayerTestConfig, 7 | LayerTestSpec, 8 | e2e_test, 9 | edge_case_test, 10 | valid_test, 11 | ) 12 | 13 | 14 | class MulConfigProvider(BaseLayerConfigProvider): 15 | """Test configuration provider for Mul layer""" 16 | 17 | @property 18 | def layer_name(self) -> str: 19 | return "Mul" 20 | 21 | def get_config(self) -> LayerTestConfig: 22 | return LayerTestConfig( 23 | op_type="Mul", 24 | valid_inputs=["A", "B"], 25 | valid_attributes={}, # Mul has no layer-specific attributes 26 | required_initializers={}, 27 | input_shapes={ 28 | "A": [1, 3, 4, 4], 29 | "B": [1, 3, 4, 4], 30 | }, 31 | output_shapes={ 32 | "mul_output": [1, 3, 4, 4], 33 | }, 34 | ) 35 | 36 | def get_test_specs(self) -> list[LayerTestSpec]: 37 | rng = np.random.default_rng(TEST_RNG_SEED) 38 | return [ 39 | # --- VALID TESTS --- 40 | valid_test("basic") 41 | .description("Basic elementwise Mul of two same-shaped tensors") 42 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 43 | .tags("basic", "elementwise", "Mul") 44 | .build(), 45 | valid_test("broadcast_mul") 46 | .description("mul with Numpy-style broadcasting along spatial dimensions") 47 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 48 | .tags("broadcast", "elementwise", "mul", "onnx14") 49 | .build(), 50 | 
valid_test("initializer_mul") 51 | .description( 52 | "mul where second input (B) is a tensor initializer instead of input", 53 | ) 54 | .override_input_shapes(A=[1, 3, 4, 4]) 55 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 56 | .tags("initializer", "elementwise", "mul", "onnxruntime") 57 | .build(), 58 | valid_test("scalar_mul") 59 | .description("mul scalar (initializer) to tensor") 60 | .override_input_shapes(A=[1, 3, 4, 4]) 61 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 62 | .tags("scalar", "elementwise", "mul") 63 | .build(), 64 | # # --- E2E TESTS --- 65 | e2e_test("e2e_mul") 66 | .description("End-to-end mul test with random inputs") 67 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 68 | .override_output_shapes(mul_output=[1, 3, 4, 4]) 69 | .tags("e2e", "mul", "2d") 70 | .build(), 71 | e2e_test("e2e_initializer_mul") 72 | .description( 73 | "mul where second input (B) is a tensor initializer instead of input", 74 | ) 75 | .override_input_shapes(A=[1, 3, 4, 4]) 76 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 77 | .tags("initializer", "elementwise", "mul", "onnxruntime") 78 | .build(), 79 | e2e_test("e2e_broadcast_mul") 80 | .description("mul with Numpy-style broadcasting along spatial dimensions") 81 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 82 | .tags("broadcast", "elementwise", "mul", "onnx14") 83 | .build(), 84 | e2e_test("e2e_scalar_mul") 85 | .description("mul scalar (initializer) to tensor") 86 | .override_input_shapes(A=[1, 3, 4, 4]) 87 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 88 | .tags("scalar", "elementwise", "mul") 89 | .build(), 90 | # # --- EDGE CASES --- 91 | edge_case_test("empty_tensor") 92 | .description("mul with empty tensor input (zero elements)") 93 | .override_input_shapes(A=[0], B=[0]) 94 | .tags("edge", "empty", "mul") 95 | .build(), 96 | edge_case_test("large_tensor") 97 | .description("Large tensor mul performance/stress test") 98 | .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256]) 99 | .tags("large", "performance", "mul") 100 | .skip("Performance test, skipped by default") 101 | .build(), 102 | ] 103 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers/sub_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 4 | from python.tests.onnx_quantizer_tests.layers.base import ( 5 | BaseLayerConfigProvider, 6 | LayerTestConfig, 7 | LayerTestSpec, 8 | e2e_test, 9 | edge_case_test, 10 | valid_test, 11 | ) 12 | 13 | 14 | class SubConfigProvider(BaseLayerConfigProvider): 15 | """Test configuration provider for Sub layer""" 16 | 17 | @property 18 | def layer_name(self) -> str: 19 | return "Sub" 20 | 21 | def get_config(self) -> LayerTestConfig: 22 | return LayerTestConfig( 23 | op_type="Sub", 24 | valid_inputs=["A", "B"], 25 | valid_attributes={}, # Sub has no layer-specific attributes 26 | required_initializers={}, 27 | input_shapes={ 28 | "A": [1, 3, 4, 4], 29 | "B": [1, 3, 4, 4], 30 | }, 31 | output_shapes={ 32 | "sub_output": [1, 3, 4, 4], 33 | }, 34 | ) 35 | 36 | def get_test_specs(self) -> list[LayerTestSpec]: 37 | rng = np.random.default_rng(TEST_RNG_SEED) 38 | return [ 39 | # --- VALID TESTS --- 40 | valid_test("basic") 41 | .description("Basic elementwise Sub of two same-shaped tensors") 42 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 43 
| .tags("basic", "elementwise", "Sub") 44 | .build(), 45 | valid_test("broadcast_Sub") 46 | .description("Sub with Numpy-style broadcasting along spatial dimensions") 47 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 48 | .tags("broadcast", "elementwise", "Sub", "onnx14") 49 | .build(), 50 | valid_test("initializer_Sub") 51 | .description( 52 | "Sub where second input (B) is a tensor initializer instead of input", 53 | ) 54 | .override_input_shapes(A=[1, 3, 4, 4]) 55 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 56 | .tags("initializer", "elementwise", "Sub", "onnxruntime") 57 | .build(), 58 | valid_test("scalar_Sub") 59 | .description("Sub scalar (initializer) to tensor") 60 | .override_input_shapes(A=[1, 3, 4, 4]) 61 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 62 | .tags("scalar", "elementwise", "Sub") 63 | .build(), 64 | # --- E2E TESTS --- 65 | e2e_test("e2e_Sub") 66 | .description("End-to-end Sub test with random inputs") 67 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4]) 68 | .override_output_shapes(sub_output=[1, 3, 4, 4]) 69 | .tags("e2e", "Sub", "2d") 70 | .build(), 71 | e2e_test("e2e_initializer_Sub") 72 | .description( 73 | "Sub where second input (B) is a tensor initializer instead of input", 74 | ) 75 | .override_input_shapes(A=[1, 3, 4, 4]) 76 | .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4))) 77 | .tags("initializer", "elementwise", "Sub", "onnxruntime") 78 | .build(), 79 | e2e_test("e2e_broadcast_Sub") 80 | .description("Sub with Numpy-style broadcasting along spatial dimensions") 81 | .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1]) 82 | .tags("broadcast", "elementwise", "Sub", "onnx14") 83 | .build(), 84 | e2e_test("e2e_scalar_Sub") 85 | .description("Sub scalar (initializer) to tensor") 86 | .override_input_shapes(A=[1, 3, 4, 4]) 87 | .override_initializer("B", np.array([2.0], dtype=np.float32)) 88 | .tags("scalar", "elementwise", "Sub") 89 | .build(), 90 | # # --- EDGE CASES --- 91 | edge_case_test("empty_tensor") 92 | .description("Sub with empty tensor input (zero elements)") 93 | .override_input_shapes(A=[0], B=[0]) 94 | .tags("edge", "empty", "Sub") 95 | .build(), 96 | edge_case_test("large_tensor") 97 | .description("Large tensor Sub performance/stress test") 98 | .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256]) 99 | .tags("large", "performance", "Sub") 100 | .skip("Performance test, skipped by default") 101 | .build(), 102 | ] 103 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( 7 | ONNXOpQuantizer, 8 | ) 9 | from python.tests.onnx_quantizer_tests.layers.base import LayerTestConfig, SpecType 10 | from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory 11 | 12 | 13 | class TestScalability: 14 | """Tests (meta) to verify the framework scales with new layers""" 15 | 16 | @pytest.mark.unit 17 | def test_adding_new_layer_config(self: TestScalability) -> None: 18 | """Test that adding new layer configs is straightforward""" 19 | two = 2 20 | # Simulate adding a new layer type 21 | new_layer_config = LayerTestConfig( 22 | op_type="NewCustomOp", 23 | valid_inputs=["input", "custom_param"], 24 | 
valid_attributes={"custom_attr": 42}, 25 | required_initializers={"custom_param": np.array([1, 2, 3])}, 26 | ) 27 | 28 | # Verify config can create nodes and initializers 29 | node = new_layer_config.create_node() 30 | assert node.op_type == "NewCustomOp" 31 | assert len(node.input) == two 32 | 33 | initializers = new_layer_config.create_initializers() 34 | assert "custom_param" in initializers 35 | 36 | @pytest.mark.unit 37 | def test_layer_config_extensibility(self: TestScalability) -> None: 38 | """Test that layer configs consists of all registered handlers""" 39 | configs = TestLayerFactory.get_layer_configs() 40 | 41 | # Verify all expected layers are present 42 | unsupported = ONNXOpQuantizer().handlers.keys() - set(configs.keys()) 43 | assert unsupported == set(), ( 44 | f"The following layers are not being configured for testing: {unsupported}." 45 | " Please add configuration in tests/onnx_quantizer_tests/layers/" 46 | ) 47 | 48 | # Verify each config has required components 49 | for layer_type, config in configs.items(): 50 | err_msg = ( 51 | f"Quantization test config is not supported yet for {layer_type}" 52 | " and must be implemented" 53 | ) 54 | assert config.op_type == layer_type, err_msg 55 | assert isinstance( 56 | config.valid_inputs, 57 | list, 58 | ), err_msg 59 | assert isinstance( 60 | config.valid_attributes, 61 | dict, 62 | ), err_msg 63 | assert isinstance( 64 | config.required_initializers, 65 | dict, 66 | ), err_msg 67 | 68 | @pytest.mark.unit 69 | def test_every_layer_has_basic_and_e2e(self: TestScalability) -> None: 70 | """Each registered layer must have at least one basic/valid test 71 | and one e2e test.""" 72 | missing_basic = [] 73 | missing_e2e = [] 74 | 75 | # iterate over registered layers 76 | for layer_name in TestLayerFactory.get_available_layers(): 77 | cases = TestLayerFactory.get_test_cases_by_layer(layer_name) 78 | specs = [spec for _, _config, spec in cases] 79 | 80 | # Consider a test "basic" if: 81 | # - it has tag 'basic' or 'valid', OR 82 | # - its spec_type is SpecType.VALID (if you use SpecType) 83 | has_basic = any( 84 | ( 85 | "basic" in getattr(s, "tags", set()) 86 | or "valid" in getattr(s, "tags", set()) 87 | or getattr(s, "spec_type", None) == SpecType.VALID 88 | ) 89 | for s in specs 90 | ) 91 | 92 | # Consider a test "e2e" if: 93 | # - it has tag 'e2e', OR 94 | # - its spec_type is SpecType.E2E (if you use that enum) 95 | has_e2e = any( 96 | ( 97 | "e2e" in getattr(s, "tags", set()) 98 | or getattr(s, "spec_type", None) == SpecType.E2E 99 | ) 100 | for s in specs 101 | ) 102 | 103 | if not has_basic: 104 | missing_basic.append(layer_name) 105 | if not has_e2e: 106 | missing_e2e.append(layer_name) 107 | 108 | assert not missing_basic, f"Layers missing a basic/valid test: {missing_basic}" 109 | assert not missing_e2e, f"Layers missing an e2e test: {missing_e2e}" 110 | -------------------------------------------------------------------------------- /.github/workflows/e2e-tests.yml: -------------------------------------------------------------------------------- 1 | name: End-to-End Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | 8 | jobs: 9 | formatting-and-linting: 10 | name: Check Formatting and Linting 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 0 # Required for file change detection 16 | 17 | # Setup Python 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.12" 22 | 23 | # Setup Rust 24 | - name: 
Install Rust toolchain 25 | uses: actions-rs/toolchain@v1 26 | with: 27 | toolchain: nightly-2025-03-27 28 | components: rustfmt, clippy 29 | override: true 30 | 31 | - name: Install OMPI 32 | run: | 33 | sudo apt-get update && sudo apt-get install -y \ 34 | libopenmpi-dev \ 35 | pkg-config \ 36 | libclang-dev \ 37 | clang \ 38 | openmpi-bin 39 | 40 | # Install UV 41 | - name: Install UV 42 | uses: astral-sh/setup-uv@v5 43 | 44 | # Run formatters and linters 45 | - name: Install the project 46 | run: uv sync --group test 47 | 48 | 49 | - name: Get changed files 50 | id: changed-files 51 | uses: tj-actions/changed-files@v46 52 | with: 53 | files: | 54 | **/*.py 55 | **/*.rs 56 | 57 | - name: Run pre-commit on changed files 58 | if: steps.changed-files.outputs.any_changed == 'true' 59 | uses: pre-commit/action@v3.0.1 60 | with: 61 | extra_args: --files ${{ steps.changed-files.outputs.all_changed_files}} 62 | 63 | # Additional Rust checks (easier to just run on all) 64 | - name: Run cargo fmt check 65 | run: cargo fmt --all -- --check 66 | 67 | - name: Run cargo clippy 68 | run: cargo clippy --workspace --all-targets --all-features -- -D warnings 69 | 70 | e2e-test: 71 | name: End-to-End Testing 72 | runs-on: ubuntu-latest 73 | strategy: 74 | matrix: 75 | python-version: ["3.10", "3.11", "3.12"] # Test all supported versions 76 | steps: 77 | - uses: actions/checkout@v4 78 | with: 79 | fetch-depth: 0 80 | 81 | - name: Set up Python ${{ matrix.python-version }} 82 | uses: actions/setup-python@v5 83 | with: 84 | python-version: ${{ matrix.python-version }} 85 | 86 | - name: Install Rust toolchain 87 | uses: actions-rs/toolchain@v1 88 | with: 89 | toolchain: nightly-2025-03-27 90 | override: true 91 | 92 | - name: Cache Rust dependencies 93 | uses: actions/cache@v4 94 | with: 95 | path: | 96 | ~/.cargo/registry 97 | ~/.cargo/git 98 | target 99 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 100 | 101 | - name: Install UV 102 | uses: astral-sh/setup-uv@v5 103 | 104 | - name: Install the project 105 | run: uv sync --group test 106 | 107 | - name: Install OMPI 108 | run: | 109 | sudo apt-get update && sudo apt-get install -y \ 110 | libopenmpi-dev \ 111 | pkg-config \ 112 | libclang-dev \ 113 | clang \ 114 | openmpi-bin 115 | 116 | - name: Clone ECC 117 | run: git clone https://github.com/PolyhedraZK/Expander.git 118 | 119 | - name: Run E2E Tests 120 | run: | 121 | mkdir -p report 122 | uv run pytest --e2e --model lenet --source onnx --html=report/e2e_tests.html 123 | 124 | - name: Upload test report 125 | uses: actions/upload-artifact@v4 126 | with: 127 | name: e2e-test-report-python-${{ matrix.python-version }} 128 | path: report 129 | 130 | python-setup: 131 | name: Python Setup Check 132 | runs-on: ubuntu-latest 133 | steps: 134 | - uses: actions/checkout@v4 135 | - name: Verify Python setup 136 | run: | 137 | if [ ! -f pyproject.toml ]; then 138 | echo "pyproject.toml is missing" 139 | exit 1 140 | fi 141 | if [ ! -f uv.lock ]; then 142 | echo "uv.lock is missing" 143 | exit 1 144 | fi 145 | 146 | rust-setup: 147 | name: Rust Setup Check 148 | runs-on: ubuntu-latest 149 | steps: 150 | - uses: actions/checkout@v4 151 | - name: Verify Rust setup 152 | run: | 153 | if [ ! -f Cargo.toml ]; then 154 | echo "Cargo.toml is missing" 155 | exit 1 156 | fi 157 | -------------------------------------------------------------------------------- /rust/jstprove_circuits/src/circuit_functions/hints/bits.rs: -------------------------------------------------------------------------------- 1 | //! 
Unconstrained bit-manipulation helpers. 2 | //! 3 | //! These helpers are *not* soundness-critical — they are witnesses only. 4 | //! Used only for legacy or experimental bit-decomposition flows. 5 | 6 | /// External crate imports 7 | use ethnum::U256; 8 | 9 | /// `ExpanderCompilerCollection` imports 10 | use expander_compiler::frontend::{CircuitField, Config, RootAPI, Variable}; 11 | 12 | // Trait giving CircuitField::MODULUS 13 | use expander_compiler::field::FieldArith; 14 | 15 | /// Internal crate imports 16 | use crate::circuit_functions::{CircuitError, utils::UtilsError}; 17 | 18 | // ----------------------------------------------------------------------------- 19 | // FUNCTION: unconstrained_to_bits 20 | // ----------------------------------------------------------------------------- 21 | 22 | /// Extracts the `n_bits` least significant bits of a field element, using 23 | /// *unconstrained* bit operations, and returns them in **little-endian order**. 24 | /// 25 | /// This is a lightweight helper used only when no constraints are required on 26 | /// the bit pattern. It is *not* a sound range-check: higher bits of the value are 27 | /// simply discarded. 28 | /// 29 | /// # Overview 30 | /// 31 | /// Given an input field element `x`, the function repeatedly: 32 | /// 33 | /// 1. Computes `bit_0 = x AND 1` using `unconstrained_bit_and`. 34 | /// 2. Appends `bit_0` to the output list. 35 | /// 3. Updates `x = x >> 1` using `unconstrained_shift_r`. 36 | /// 37 | /// After `n_bits` iterations, the result is: 38 | /// 39 | /// bits[0] = least significant bit of input 40 | /// bits[1] = next bit 41 | /// ... 42 | /// bits[n_bits - 1] = most significant of the extracted bits 43 | /// 44 | /// This helper mirrors a CPU right-shift loop, but none of the bits are enforced 45 | /// to be boolean and no reconstruction constraint is added. If soundness is 46 | /// required, callers must pair this with `constrained_reconstruct_from_bits` 47 | /// (or with LogUp-based range checks). 48 | /// 49 | /// # Arguments 50 | /// 51 | /// - `api`: the circuit builder, providing unconstrained bitwise operations. 52 | /// - `input`: value from which the `n_bits` LSBs will be extracted. 53 | /// - `n_bits`: number of least significant bits to extract. 54 | /// 55 | /// # Returns 56 | /// 57 | /// A `Vec<Variable>` of length `n_bits`, storing the least significant bits of 58 | /// `input` in little-endian order. 59 | /// 60 | /// # Errors 61 | /// 62 | /// - Returns `CircuitError::Other` if `n_bits == 0`. 63 | /// - Returns `UtilsError::ValueTooLarge` if `n_bits` does not fit in `u32`. 64 | /// - Returns `CircuitError::Other` if `2^n_bits >= MODULUS/2`. 65 | /// (This guards against extracting more bits than make sense for the field.) 66 | /// 67 | /// # Example 68 | /// 69 | /// For `input = 43` (binary `101011`) and `n_bits = 4`, the function returns: 70 | /// 71 | /// [1, 1, 0, 1] 72 | /// 73 | /// corresponding to the 4 least significant bits `1011` in little-endian form. 74 | pub fn unconstrained_to_bits<C: Config, Builder: RootAPI<C>>( 75 | api: &mut Builder, 76 | input: Variable, 77 | n_bits: usize, 78 | ) -> Result<Vec<Variable>, CircuitError> { 79 | if n_bits == 0 { 80 | return Err(CircuitError::Other("Cannot convert to 0 bits".into())); 81 | } 82 | 83 | // Prevent `U256::pow` overflow/wrap for absurdly large exponents. 84 | // U256 can represent up to 2^256 − 1, but 2^256 itself overflows. 85 | // Therefore, the maximum safe exponent is 255. 
86 | let max_bits_for_u256_exp: usize = 256; 87 | if n_bits >= max_bits_for_u256_exp { 88 | return Err(CircuitError::Other(format!( 89 | "unconstrained_to_bits: n_bits too large (max {})", 90 | max_bits_for_u256_exp - 1 91 | ))); 92 | } 93 | 94 | let base: U256 = U256::from(2u32); 95 | if base.pow( 96 | u32::try_from(n_bits).map_err(|_| UtilsError::ValueTooLarge { 97 | value: n_bits, 98 | max: u128::from(u32::MAX), 99 | })?, 100 | ) >= (CircuitField::<C>::MODULUS / 2) 101 | { 102 | return Err(CircuitError::Other( 103 | "unconstrained_to_bits: n_bits too large (require 2^n_bits < MODULUS/2)".into(), 104 | )); 105 | } 106 | 107 | let mut least_significant_bits = Vec::with_capacity(n_bits); 108 | let mut current = input; 109 | 110 | for _ in 0..n_bits { 111 | // Extract bit 0 of `current` 112 | let bit = api.unconstrained_bit_and(current, 1u32); 113 | least_significant_bits.push(bit); 114 | // Shift right by one 115 | current = api.unconstrained_shift_r(current, 1u32); 116 | } 117 | 118 | Ok(least_significant_bits) 119 | } 120 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | if TYPE_CHECKING: 6 | from _pytest.config import Config, Parser 7 | from _pytest.nodes import Item 8 | from _pytest.python import Metafunc 9 | import pytest 10 | 11 | from python.core.utils.model_registry import get_models_to_test, list_available_models 12 | 13 | 14 | def pytest_addoption(parser: Parser) -> None: 15 | parser.addoption( 16 | "--model", 17 | action="append", 18 | default=None, 19 | help="Model(s) to test. Use multiple times to test more than one.", 20 | ) 21 | parser.addoption( 22 | "--list-models", 23 | action="store_true", 24 | default=False, 25 | help="List all available circuit models.", 26 | ) 27 | parser.addoption( 28 | "--source", 29 | action="store", 30 | choices=["class", "pytorch", "onnx"], 31 | default=None, 32 | help="Restrict models to a specific source: class, pytorch, or onnx.", 33 | ) 34 | parser.addoption( 35 | "--unit", 36 | action="store_true", 37 | default=False, 38 | help="Run only unit tests.", 39 | ) 40 | parser.addoption( 41 | "--integration", 42 | action="store_true", 43 | default=False, 44 | help="Run only integration tests.", 45 | ) 46 | parser.addoption( 47 | "--e2e", 48 | action="store_true", 49 | default=False, 50 | help="Run only end-to-end tests.", 51 | ) 52 | 53 | 54 | def pytest_collection_modifyitems(config: Config, items: list[Item]) -> None: 55 | run_unit = config.getoption("--unit") 56 | run_integration = config.getoption("--integration") 57 | run_e2e = config.getoption("--e2e") 58 | 59 | # If no filters set, run all 60 | if not any([run_unit, run_integration, run_e2e]): 61 | return 62 | 63 | selected = [] 64 | deselected = [] 65 | 66 | for item in items: 67 | has_unit = "unit" in item.keywords 68 | has_integration = "integration" in item.keywords and "e2e" not in item.keywords 69 | has_e2e = "e2e" in item.keywords 70 | 71 | if ( 72 | (run_unit and has_unit) 73 | or (run_integration and has_integration) 74 | or (run_e2e and has_e2e) 75 | ): 76 | selected.append(item) 77 | else: 78 | deselected.append(item) 79 | 80 | items[:] = selected 81 | config.hook.pytest_deselected(items=deselected) 82 | 83 | 84 | def pytest_generate_tests(metafunc: Metafunc) -> None: 85 | if "model_fixture" in metafunc.fixturenames: 86 | selected_models = 
metafunc.config.getoption("model") 87 | selected_source = metafunc.config.getoption("source") 88 | 89 | models = get_models_to_test(selected_models, selected_source) 90 | ids = [ 91 | f"{model.name}:{model.source}" for model in models 92 | ] # Extract readable names 93 | 94 | metafunc.parametrize( 95 | "model_fixture", 96 | models, 97 | indirect=True, 98 | scope="module", 99 | ids=ids, 100 | ) 101 | 102 | 103 | def pytest_configure(config: Config) -> None: 104 | # If the --list-models option is used, list models and exit 105 | if config.getoption("list_models"): 106 | available_models = list_available_models() 107 | print("\nAvailable Circuit Models:") # noqa: T201 108 | for model in available_models: 109 | print(f"- {model}") # noqa: T201 110 | pytest.exit( 111 | "Exiting after listing available models.", 112 | ) # This prevents tests from running 113 | 114 | 115 | @pytest.fixture(scope="session", autouse=True) 116 | def ensure_dev_mode_compile_for_e2e( 117 | request: pytest.FixtureRequest, 118 | ) -> None: 119 | """ 120 | Ensure that rust code is recompiled before e2e tests are performed. 121 | """ 122 | # Only run this for e2e tests 123 | if not request.config.getoption("--e2e"): 124 | return 125 | 126 | # Skip if there are no e2e tests being run 127 | if not any("e2e" in item.keywords for item in request.session.items): 128 | return 129 | 130 | import subprocess # noqa: PLC0415 131 | 132 | result = subprocess.run( 133 | ["cargo", "build", "--release"], # noqa: S607 134 | check=True, 135 | capture_output=True, 136 | text=True, 137 | ) 138 | 139 | print("stdout:", result.stdout) # noqa: T201 140 | print("stderr:", result.stderr) # noqa: T201 141 | 142 | # On initial tests this approach works. If this breaks, we can run 143 | # compilation of a basic circuit with dev_mode = True 144 | -------------------------------------------------------------------------------- /python/tests/onnx_quantizer_tests/test_registered_quantizers.py: -------------------------------------------------------------------------------- 1 | # This file performs very basic integration tests on each registered quantizer 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx import helper 7 | 8 | from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig 9 | from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ( 10 | ONNXOpQuantizer, 11 | ) 12 | from python.tests.onnx_quantizer_tests import TEST_RNG_SEED 13 | 14 | 15 | @pytest.fixture 16 | def dummy_graph() -> onnx.GraphProto: 17 | return onnx.GraphProto() 18 | 19 | 20 | def mock_initializer_map(input_names: list[str]) -> dict[str, onnx.TensorProto]: 21 | rng = np.random.default_rng(TEST_RNG_SEED) 22 | return { 23 | name: onnx.helper.make_tensor( 24 | name=name, 25 | data_type=onnx.TensorProto.FLOAT, 26 | dims=[2, 2], # minimal shape 27 | vals=rng.random(4, dtype=np.float32).tolist(), 28 | ) 29 | for name in input_names 30 | } 31 | 32 | 33 | def get_required_input_names(op_type: str) -> list[str]: 34 | try: 35 | schema = onnx.defs.get_schema(op_type) 36 | return [ 37 | inp.name or f"input{i}" 38 | for i, inp in enumerate(schema.inputs) 39 | if inp.option != 1 40 | ] # 1 = optional 41 | except Exception: 42 | return ["input0"] # fallback 43 | 44 | 45 | def validate_quantized_node(node_result: onnx.NodeProto, op_type: str) -> None: 46 | """Validate a single quantized node.""" 47 | assert isinstance(node_result, onnx.NodeProto), f"Invalid node type for {op_type}" 48 | assert node_result.op_type, f"Missing 
op_type for {op_type}" 49 | assert node_result.output, f"Missing outputs for {op_type}" 50 | 51 | try: 52 | # Create a minimal graph with dummy IOs to satisfy ONNX requirements 53 | temp_graph = onnx.GraphProto() 54 | temp_graph.name = "temp_graph" 55 | 56 | for inp in node_result.input: 57 | if not any(vi.name == inp for vi in temp_graph.input): 58 | temp_graph.input.append( 59 | onnx.helper.make_tensor_value_info( 60 | inp, 61 | onnx.TensorProto.FLOAT, 62 | [1], 63 | ), 64 | ) 65 | 66 | for out in node_result.output: 67 | if not any(vi.name == out for vi in temp_graph.output): 68 | temp_graph.output.append( 69 | onnx.helper.make_tensor_value_info( 70 | out, 71 | onnx.TensorProto.FLOAT, 72 | [1], 73 | ), 74 | ) 75 | 76 | temp_graph.node.append(node_result) 77 | 78 | # Explicit opset imports for default and contrib domains 79 | temp_model = onnx.helper.make_model( 80 | temp_graph, 81 | opset_imports=[ 82 | onnx.helper.make_opsetid("", 22), 83 | onnx.helper.make_opsetid("ai.onnx.contrib", 1), 84 | ], 85 | ) 86 | 87 | onnx.checker.check_model(temp_model) 88 | except onnx.checker.ValidationError as e: 89 | pytest.fail(f"ONNX node validation failed for {op_type}: {e}") 90 | 91 | 92 | @pytest.mark.integration 93 | @pytest.mark.parametrize("op_type", list(ONNXOpQuantizer().handlers.keys())) 94 | def test_registered_quantizer_quantize( 95 | op_type: str, 96 | dummy_graph: onnx.GraphProto, 97 | ) -> None: 98 | quantizer = ONNXOpQuantizer() 99 | handler = quantizer.handlers[op_type] 100 | 101 | inputs = get_required_input_names(op_type) 102 | dummy_initializer_map = mock_initializer_map(inputs) 103 | 104 | dummy_node = helper.make_node( 105 | op_type=op_type, 106 | inputs=inputs, 107 | outputs=["dummy_output"], 108 | ) 109 | 110 | result = handler.quantize( 111 | node=dummy_node, 112 | graph=dummy_graph, 113 | scale_config=ScaleConfig(exponent=10, base=2, rescale=True), 114 | initializer_map=dummy_initializer_map, 115 | ) 116 | assert result is not None 117 | 118 | # Enhanced assertions: validate result type and structure 119 | if isinstance(result, list): 120 | assert len(result) > 0, f"Quantize returned empty list for {op_type}" 121 | for node_result in result: 122 | validate_quantized_node(node_result, op_type) 123 | else: 124 | if inputs: 125 | # Only assert if this op actually requires inputs 126 | assert ( 127 | result.input 128 | ), f"Missing inputs for {op_type}; required_inputs={inputs}" 129 | 130 | validate_quantized_node(result, op_type) 131 | --------------------------------------------------------------------------------
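The pytest options registered in conftest.py above drive test selection for the whole suite. For reference, a few invocations that exercise them, as a sketch assuming the environment has been synced first with `uv sync --group test` (the same setup the CI workflows use):

```bash
# List the registered circuit models, then exit without running tests
uv run pytest --list-models

# Run only unit tests, or only integration tests
uv run pytest --unit
uv run pytest --integration

# Run the end-to-end suite for one model, as the e2e-tests workflow does
# (--e2e triggers `cargo build --release` via the session fixture first)
uv run pytest --e2e --model lenet --source onnx
```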