├── sample.json
├── rustfmt.toml
├── challenge
├── sample.onnx
├── layer_setup
│   └── .gitignore
├── src
│   ├── util
│   │   ├── err.rs
│   │   ├── random.rs
│   │   ├── iter.rs
│   │   ├── copy_constraint.rs
│   │   ├── config.rs
│   │   ├── arithmetic.rs
│   │   ├── serialization.rs
│   │   ├── shape.rs
│   │   ├── poly.rs
│   │   └── onnx.rs
│   ├── main.rs
│   ├── util.rs
│   ├── basic_block
│   │   ├── transpose.rs
│   │   ├── rope.rs
│   │   ├── clip.rs
│   │   ├── id.rs
│   │   ├── less.rs
│   │   ├── reshape.rs
│   │   ├── constant.rs
│   │   ├── split.rs
│   │   ├── ops.rs
│   │   ├── sub.rs
│   │   ├── sort.rs
│   │   ├── eq.rs
│   │   ├── concat.rs
│   │   ├── add.rs
│   │   └── range.rs
│   ├── layer
│   │   ├── neg.rs
│   │   ├── not.rs
│   │   ├── shape.rs
│   │   ├── xor.rs
│   │   ├── clip.rs
│   │   ├── arithmetic.rs
│   │   ├── range.rs
│   │   ├── where.rs
│   │   ├── and.rs
│   │   ├── gather.rs
│   │   ├── constantofshape.rs
│   │   ├── cast.rs
│   │   ├── flatten.rs
│   │   ├── nonlinear.rs
│   │   ├── equal.rs
│   │   ├── less.rs
│   │   ├── transpose.rs
│   │   ├── rope.rs
│   │   ├── split.rs
│   │   ├── expand.rs
│   │   ├── tile.rs
│   │   ├── sqrt.rs
│   │   ├── reducemean.rs
│   │   ├── gathernd.rs
│   │   ├── concat.rs
│   │   ├── scatternd.rs
│   │   ├── softmax.rs
│   │   ├── slice.rs
│   │   ├── reshape.rs
│   │   ├── pool.rs
│   │   ├── mul.rs
│   │   ├── pow.rs
│   │   ├── gemm.rs
│   │   └── squeeze.rs
│   ├── bin
│   │   ├── witness_gen.rs
│   │   └── search_sf.rs
│   ├── lib.rs
│   ├── ptau.rs
│   └── layer.rs
├── .gitignore
├── .github
│   └── workflows
│       ├── rust.yml
│       └── gpu.yml
├── config.yaml
├── test_gpu_scripts
│   ├── notify_enqueue.py
│   ├── add_gpu_dependencies.py
│   ├── test_gpu.sbatch
│   ├── notify_test_result.py
│   ├── notify_result.py
│   └── README.md
├── Cargo.toml
├── scratch
│   ├── bert
│   │   └── replace_reshape_trans.py
│   └── gptj
│       ├── replace_multihead.py
│       └── replace_gelu.py
└── README.md

/sample.json:
--------------------------------------------------------------------------------
1 | {"input_data": [[0.09, 0.13, 0.24, 0.05]]}
--------------------------------------------------------------------------------
/rustfmt.toml:
--------------------------------------------------------------------------------
1 | tab_spaces = 2
2 | max_width = 150
3 | chain_width = 130
--------------------------------------------------------------------------------
/challenge:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uiuc-kang-lab/zk-torch/HEAD/challenge
--------------------------------------------------------------------------------
/sample.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uiuc-kang-lab/zk-torch/HEAD/sample.onnx
--------------------------------------------------------------------------------
/layer_setup/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/src/util/err.rs:
--------------------------------------------------------------------------------
1 | use std::fmt;
2 |
3 | #[derive(Debug, Clone)]
4 | pub struct CQOutOfRangeError {
5 |   pub input: i128,
6 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | Cargo.lock
3 | .DS_Store
4 | *.swp
5 | inputsEnc
6 | modelsEnc
7 | outputsEnc
8 | proofs
9 | setups
10 | models
11 | acc_proofs
12 | final_proofs
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | use zk_torch::util::zktorch_kernel;
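// Typical invocation (as exercised by .github/workflows/rust.yml below; RUST_LOG per the note in main):
//   RUST_LOG=debug cargo run --bin zk_torch -- config.yaml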
2 | 3 | fn main() { 4 | // please export RUST_LOG=debug; the debug logs for timing will be printed 5 | zktorch_kernel(); 6 | } 7 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Nightly 20 | run: rustup override set nightly 21 | - name: Build 22 | run: cargo build 23 | - name: Run 24 | run: cargo run --bin zk_torch -- config.yaml 25 | - name: Test 26 | run: cargo test 27 | 28 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | pub use arithmetic::*; 2 | pub use config::*; 3 | pub use copy_constraint::*; 4 | pub use err::*; 5 | pub use fft::*; 6 | pub use fold::*; 7 | pub use iter::*; 8 | pub use msm::*; 9 | pub use onnx::*; 10 | pub use poly::*; 11 | pub use prover::*; 12 | pub use random::*; 13 | pub use serialization::*; 14 | pub use shape::*; 15 | pub use verifier::*; 16 | 17 | pub mod arithmetic; 18 | pub mod config; 19 | pub mod copy_constraint; 20 | pub mod err; 21 | pub mod fft; 22 | pub mod fold; 23 | pub mod iter; 24 | pub mod msm; 25 | pub mod onnx; 26 | pub mod poly; 27 | pub mod prover; 28 | pub mod random; 29 | pub mod serialization; 30 | pub mod shape; 31 | pub mod verifier; 32 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | task: sample 2 | onnx: 3 | model_path: sample.onnx 4 | input_path: sample.json 5 | ptau: 6 | ptau_path: challenge 7 | pow_len_log: 7 8 | loaded_pow_len_log: 7 9 | sf: 10 | scale_factor_log: 3 11 | cq_range_log: 6 12 | cq_range_lower_log: 5 13 | prover: 14 | model_path: models 15 | setup_path: setups 16 | enc_model_path: modelsEnc 17 | enc_input_path: inputsEnc 18 | enc_output_path: outputsEnc 19 | proof_path: proofs 20 | acc_proof_path: acc_proofs 21 | final_proof_path: final_proofs 22 | enable_layer_setup: true 23 | verifier: 24 | enc_model_path: modelsEnc 25 | enc_input_path: inputsEnc 26 | enc_output_path: outputsEnc 27 | proof_path: proofs 28 | -------------------------------------------------------------------------------- /test_gpu_scripts/notify_enqueue.py: -------------------------------------------------------------------------------- 1 | from slack_sdk import WebClient 2 | import sys 3 | 4 | pr_number = sys.argv[1] 5 | token = sys.argv[2] 6 | commit_hash = str(sys.argv[3])[:7] 7 | 8 | m = ' starts to run the gpu test for commit: ' + commit_hash + '. ' 9 | note = 'If no message is sent after this, the test is either still running or being killed.' 
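# How this script is invoked (arguments are read from sys.argv above; test_gpu.sbatch
# passes them as "$1" "$2" "$3"):
#   python notify_enqueue.py <pr_number> <slack_token> <commit_hash>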
10 | message = 'Event #' + str(pr_number) + m + note 11 | 12 | 13 | # reference: https://www.datacamp.com/tutorial/how-to-send-slack-messages-with-python 14 | # Set up a WebClient with the Slack OAuth token 15 | client = WebClient(token=token) 16 | 17 | # Send a message 18 | client.chat_postMessage( 19 | channel="zk-torch-test-gpu", 20 | text=message, 21 | username="SLURM bot" 22 | ) 23 | -------------------------------------------------------------------------------- /test_gpu_scripts/add_gpu_dependencies.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | gpu_dir = str(sys.argv[1]) 4 | 5 | with open("Cargo.toml", "r") as f: 6 | contents = f.readlines() 7 | 8 | for i, c in enumerate(contents): 9 | if c.strip() == "[dependencies]": 10 | # add dependencies after [dependencies] 11 | c = c + 'icicle-cuda-runtime = { path = "' + gpu_dir +'/gpu/icicle/wrappers/rust/icicle-cuda-runtime" }\n' 12 | c = c + 'icicle-core = { path = "' + gpu_dir +'/gpu/icicle/wrappers/rust/icicle-core", features = ["arkworks"]}\n' 13 | c = c + 'icicle-bn254 = { path = "' + gpu_dir +'/gpu/icicle/wrappers/rust/icicle-curves/icicle-bn254" , features = ["arkworks", "g2"]}\n' 14 | contents[i] = c 15 | break 16 | 17 | with open("Cargo.toml", "w") as f: 18 | contents = "".join(contents) 19 | f.write(contents) 20 | -------------------------------------------------------------------------------- /src/util/random.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Random utilities: 3 | * The functions are used for adding randomness to the RNG and 4 | * setting the random device for GPU computations. 5 | */ 6 | #![allow(unused_imports)] 7 | use rand::{rngs::StdRng, Rng, RngCore, SeedableRng}; 8 | use sha3::{Digest, Keccak256}; 9 | 10 | pub fn add_randomness(rng: &mut StdRng, mut bytes: Vec) { 11 | let mut buf = vec![0u8; 32]; 12 | rng.fill_bytes(&mut buf); 13 | bytes.append(&mut buf); 14 | let mut buf = [0u8; 32]; 15 | let mut hasher = Keccak256::new(); 16 | hasher.update(bytes); 17 | hasher.finalize_into((&mut buf).into()); 18 | *rng = StdRng::from_seed(buf); 19 | } 20 | 21 | #[cfg(feature = "gpu")] 22 | pub fn gpu_set_random_device() { 23 | let mut rng = StdRng::from_entropy(); 24 | icicle_cuda_runtime::device::set_device(rng.gen_range(0..1)).unwrap(); 25 | } 26 | -------------------------------------------------------------------------------- /src/basic_block/transpose.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, SRS}; 2 | use crate::util; 3 | use ark_bn254::Fr; 4 | use ndarray::ArrayD; 5 | 6 | #[derive(Debug)] 7 | pub struct TransposeBasicBlock { 8 | pub perm: Vec, 9 | } 10 | 11 | impl BasicBlock for TransposeBasicBlock { 12 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 13 | assert!(inputs.len() == 1); 14 | assert!(*self.perm.last().unwrap() == self.perm.len() - 1); 15 | Ok(vec![inputs[0].view().permuted_axes(&self.perm[..]).to_owned()]) 16 | } 17 | 18 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 19 | let n = self.perm.len(); 20 | vec![inputs[0].view().permuted_axes(&self.perm[..n - 1]).to_owned()] 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /test_gpu_scripts/test_gpu.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH 
--partition=ddkang 3 | #SBATCH --time=01:00:00 4 | #SBATCH --nodes=1 5 | #SBATCH --mem=32G 6 | #SBATCH --nodelist=ccc0419 7 | #SBATCH --ntasks-per-node=1 8 | #SBATCH --job-name="zktorch_gh_action" 9 | #SBATCH --output=zktorch_gh_action.out 10 | #SBATCH --account=ddkang-cs-eng 11 | #SBATCH --gres=gpu:1 12 | 13 | module load cmake/3.26.3 14 | module load cuda/12.4 15 | module load python/3.9.16 16 | 17 | srun python test_gpu_scripts/notify_enqueue.py "$1" "$2" "$3" 18 | srun rustup override set nightly 19 | 20 | export LIBCLANG_PATH=/projects/illinois/eng/cs/ddkang/bjchen4/gpu/llvm-project/build/lib 21 | srun --export=ALL cargo run --bin zk_torch --features gpu -- config.yaml 22 | srun python test_gpu_scripts/notify_result.py "$1" "$2" "$3" 23 | 24 | srun --export=ALL cargo test --features gpu 25 | srun python test_gpu_scripts/notify_test_result.py "$1" "$2" "$3" 26 | -------------------------------------------------------------------------------- /src/basic_block/rope.rs: -------------------------------------------------------------------------------- 1 | use super::BasicBlock; 2 | use crate::util; 3 | use ark_bn254::Fr; 4 | use ndarray::{arr1, ArrayD}; 5 | 6 | #[derive(Debug)] 7 | pub struct RoPEBasicBlock { 8 | pub token_i: usize, 9 | pub output_SF: usize, 10 | } 11 | 12 | impl BasicBlock for RoPEBasicBlock { 13 | fn run(&self, _model: &ArrayD, _inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 14 | let mut r1 = vec![]; 15 | let mut r2 = vec![]; 16 | for i in 0..64 { 17 | let x = (self.token_i as f64) / (10000_f64.powf((i as f64) / 64_f64)); 18 | let mut a = x.cos(); 19 | let mut b = x.sin(); 20 | a *= (1 << self.output_SF) as f64; 21 | b *= (1 << self.output_SF) as f64; 22 | let a = Fr::from(a.round() as i128); 23 | let b = Fr::from(b.round() as i128); 24 | r1.push(a); 25 | r2.push(b); 26 | } 27 | Ok(vec![arr1(&r1).into_dyn(), arr1(&r2).into_dyn()]) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/basic_block/clip.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ndarray::{arr0, azip, ArrayD, IxDyn}; 5 | use rand::rngs::StdRng; 6 | use rayon::prelude::*; 7 | 8 | #[derive(Debug)] 9 | pub struct ClipBasicBlock { 10 | pub min: f32, 11 | pub max: f32, 12 | } 13 | impl BasicBlock for ClipBasicBlock { 14 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 15 | assert!(inputs.len() == 1); 16 | let shape = inputs[0].shape(); 17 | let out = util::array_into_iter(inputs[0]) 18 | .map(|x| { 19 | let mut x = util::fr_to_int(*x) as f32; 20 | x = x.max(self.min).min(self.max); 21 | Fr::from(x.round() as i32) 22 | }) 23 | .collect::>(); 24 | Ok(vec![ArrayD::from_shape_vec(shape, out).unwrap()]) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zk_torch" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | ark-std = { version = "0.4", features = ["parallel"] } 8 | ark-ff = { version = "0.4", features = ["parallel"] } 9 | ark-ec = "0.4" 10 | ark-poly = { version = "0.4", features = ["parallel"] } 11 | ark-bn254 = "0.4" 12 | ark-serialize = { version = "0.4", features = ["derive"] } 13 | serde = { 
version = "1.0.201", features = ["derive"] } 14 | serde_json = "1.0" 15 | serde_yaml = "0.9" 16 | once_cell = "1.15" 17 | log = { version = "0.4.14", default-features = false } 18 | env_logger = { version = "0.10" } 19 | plonky2 = { version = "0.2.2", features = ["timing"] } 20 | rand = "0.8.5" 21 | rayon = "1" 22 | ndarray = { version = "0.15.6", features = ["serde", "rayon"] } 23 | tract-onnx = "=0.21.6" 24 | sha3 = "0.10.8" 25 | bincode = "1.3.3" 26 | itertools = "0.13.0" 27 | downcast-rs = "2.0.1" 28 | 29 | [features] 30 | debug = [] 31 | gpu = [] 32 | mock_prove = [] 33 | fold = [] 34 | -------------------------------------------------------------------------------- /src/util/iter.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Iteration utilities: 3 | * The functions are used for iterating over arrays and vectors. 4 | * Each function/macro has CPU and GPU implementations. 5 | */ 6 | use ndarray::ArrayD; 7 | #[cfg(feature = "gpu")] 8 | use rayon::prelude::*; 9 | 10 | #[cfg(feature = "gpu")] 11 | pub fn array_into_iter(x: &ArrayD) -> impl ParallelIterator { 12 | x.into_par_iter() 13 | } 14 | 15 | #[cfg(not(feature = "gpu"))] 16 | pub fn array_into_iter(x: &ArrayD) -> impl Iterator { 17 | x.into_iter() 18 | } 19 | 20 | #[cfg(feature = "gpu")] 21 | pub fn vec_iter(x: &Vec) -> impl ParallelIterator { 22 | x.par_iter() 23 | } 24 | 25 | #[cfg(not(feature = "gpu"))] 26 | pub fn vec_iter(x: &Vec) -> impl Iterator { 27 | x.iter() 28 | } 29 | 30 | #[macro_export] 31 | macro_rules! ndarr_azip { 32 | ($($arg:tt)*) => { 33 | #[cfg(feature = "gpu")] 34 | { 35 | par_azip!($($arg)*) 36 | } 37 | #[cfg(not(feature = "gpu"))] 38 | { 39 | azip!($($arg)*) 40 | } 41 | }; 42 | } 43 | -------------------------------------------------------------------------------- /src/basic_block/id.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ndarray::{arr0, azip, ArrayD, IxDyn}; 5 | use rand::rngs::StdRng; 6 | 7 | #[derive(Debug)] 8 | pub struct IdBasicBlock; 9 | impl BasicBlock for IdBasicBlock { 10 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 11 | Ok(inputs.iter().map(|&x| x.clone()).collect()) 12 | } 13 | 14 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 15 | inputs.iter().map(|&x| x.clone()).collect() 16 | } 17 | 18 | fn verify( 19 | &self, 20 | _srs: &SRS, 21 | _model: &ArrayD, 22 | inputs: &Vec<&ArrayD>, 23 | outputs: &Vec<&ArrayD>, 24 | _proof: (&Vec, &Vec, &Vec), 25 | _rng: &mut StdRng, 26 | _cache: ProveVerifyCache, 27 | ) -> Vec { 28 | assert!(inputs == outputs); 29 | vec![] 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/layer/neg.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::Zero; 7 | use ndarray::{arr1, ArrayD}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | pub struct NegLayer; 12 | 13 | impl Layer for NegLayer { 14 | fn graph( 15 | input_shapes: &Vec<&Vec>, 16 | input_types: &Vec, 17 | _constants: &Vec, DatumType)>>, 18 | _attributes: 
&Vec<&AttributeProto>,
19 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
20 |     let mut graph = Graph::new();
21 |     let zero = graph.addBB(Box::new(Const2BasicBlock {
22 |       c: arr1(&vec![Fr::zero(); util::next_pow(*input_shapes[0].last().unwrap() as u32) as usize]).into_dyn(),
23 |     }));
24 |     let layer = graph.addBB(Box::new(RepeaterBasicBlock {
25 |       basic_block: Box::new(SubBasicBlock {}),
26 |       N: 1,
27 |     }));
28 |     let zero_output = graph.addNode(zero, vec![]);
29 |     let layer_output = graph.addNode(layer, vec![(zero_output, 0), (-1, 0)]);
30 |     graph.outputs.push((layer_output, 0));
31 |     (graph, vec![util::broadcastDims(input_shapes, 0)], vec![input_types[0]])
32 |   }
33 | }
--------------------------------------------------------------------------------
/test_gpu_scripts/notify_test_result.py:
--------------------------------------------------------------------------------
1 | from slack_sdk import WebClient
2 | import sys
3 |
4 | pr_number = sys.argv[1]
5 | token = sys.argv[2]
6 | commit_hash = str(sys.argv[3])[:7]
7 |
8 | # The log is read inside the try block below so that a missing file is
9 | # handled by the FileNotFoundError branch instead of crashing the script.
10 |
11 | commit_str = ' (commit: ' + commit_hash + ')'
12 | message = 'error found when cargo test in event #' + str(pr_number) + commit_str
13 |
14 | # Check if log file exists
15 | try:
16 |     with open("zktorch_gh_action.out", "r") as f:
17 |         contents = f.readlines()
18 |     # cargo test prints "... 0 failed; ..." in its summary when every test passes
19 |     for line in contents:
20 |         if "0 failed;" in line:
21 |             message = 'Event #' + str(pr_number) + commit_str + ' successfully passed cargo test on CC gpu'
22 |             break
23 | except FileNotFoundError:
24 |     message = 'error found when cargo test in event #' + str(pr_number) + commit_str
25 |
26 | # reference: https://www.datacamp.com/tutorial/how-to-send-slack-messages-with-python
27 | # Set up a WebClient with the Slack OAuth token
28 | client = WebClient(token=token)
29 |
30 | # Send a message
31 | client.chat_postMessage(
32 |     channel="zk-torch-test-gpu",
33 |     text=message,
34 |     username="SLURM bot"
35 | )
--------------------------------------------------------------------------------
/test_gpu_scripts/notify_result.py:
--------------------------------------------------------------------------------
1 | from slack_sdk import WebClient
2 | import sys
3 |
4 | pr_number = sys.argv[1]
5 | token = sys.argv[2]
6 | commit_hash = str(sys.argv[3])[:7]
7 |
8 | # The log is read inside the try block below so that a missing file is
9 | # handled by the FileNotFoundError branch instead of crashing the script.
10 |
11 | commit_str = ' (commit: ' + commit_hash + ')'
12 | message = 'error found when cargo run in event #' + str(pr_number) + commit_str
13 |
14 | # Check if log file exists
15 | try:
16 |     with open("zktorch_gh_action.out", "r") as f:
17 |         contents = f.readlines()
18 |     # Check if the log file contains the string Cargo run was successful.
19 |     for line in contents:
20 |         if "Cargo run was successful."
in line: 21 | message = 'Event #' + str(pr_number) + commit_str + ' successfully passed cargo run on CC gpu' 22 | break 23 | except FileNotFoundError: 24 | message = 'error found when cargo run in event #' + str(pr_number) + commit_str 25 | 26 | # reference: https://www.datacamp.com/tutorial/how-to-send-slack-messages-with-python 27 | # Set up a WebClient with the Slack OAuth token 28 | client = WebClient(token=token) 29 | 30 | # Send a message 31 | client.chat_postMessage( 32 | channel="zk-torch-test-gpu", 33 | text=message, 34 | username="SLURM bot" 35 | ) 36 | -------------------------------------------------------------------------------- /src/basic_block/less.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ndarray::{arr0, azip, ArrayD, IxDyn}; 5 | use rand::rngs::StdRng; 6 | 7 | // perform element-wise less than comparison for two 1-d arrays 8 | #[derive(Debug)] 9 | pub struct LessBasicBlock; 10 | impl BasicBlock for LessBasicBlock { 11 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 12 | assert!(inputs.len() == 2 && inputs[0].ndim() <= 1 && inputs[1].ndim() <= 1); 13 | let mut r = ArrayD::zeros(IxDyn(&[std::cmp::max(inputs[0].len(), inputs[1].len())])); 14 | if inputs[0].len() == 1 && inputs[1].ndim() > 0 { 15 | azip!((r in &mut r, &x in inputs[1]) *r = Fr::from((util::fr_to_int(x) >= util::fr_to_int(*inputs[0].first().unwrap())) as i32)); 16 | } else if inputs[1].len() == 1 { 17 | azip!((r in &mut r, &x in inputs[0]) *r = Fr::from((util::fr_to_int(x) < util::fr_to_int(*inputs[1].first().unwrap())) as i32)); 18 | } else { 19 | azip!((r in &mut r, &x in inputs[0], &y in inputs[1]) *r = Fr::from((util::fr_to_int(x) < util::fr_to_int(y)) as i32)); 20 | } 21 | Ok(vec![r]) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/bin/witness_gen.rs: -------------------------------------------------------------------------------- 1 | use ark_bn254::Fr; 2 | use ndarray::ArrayD; 3 | use plonky2::{timed, util::timing::TimingTree}; 4 | use zk_torch::graph::Graph; 5 | use zk_torch::{onnx, util, CONFIG}; 6 | 7 | fn run( 8 | inputs: &Vec<&ArrayD>, 9 | graph: &Graph, 10 | models: &Vec<&ArrayD>, 11 | timing: &mut TimingTree, 12 | ) -> Result>>, util::CQOutOfRangeError> { 13 | // Run: 14 | timed!(timing, "run witness generation", graph.run(inputs, models)) 15 | } 16 | fn main() { 17 | // Timing 18 | let mut timing = TimingTree::default(); 19 | // please export RUST_LOG=debug; the debug logs for timing will be printed 20 | env_logger::init(); 21 | 22 | let onnx_file_name = &CONFIG.onnx.model_path; 23 | let (graph, models) = onnx::load_file(onnx_file_name); 24 | 25 | let input_path = &CONFIG.onnx.input_path; 26 | let inputs = if std::path::Path::new(input_path).exists() { 27 | util::load_inputs_from_json_for_onnx(onnx_file_name, input_path) 28 | } else { 29 | util::generate_fake_inputs_for_onnx(onnx_file_name) 30 | }; 31 | let inputs = inputs.iter().map(|x| x).collect(); 32 | let models = models.iter().map(|x| &x.0).collect(); 33 | let _outputs = run(&inputs, &graph, &models, &mut timing); 34 | 35 | timing.print(); 36 | println!("Witness generation done"); 37 | } 38 | -------------------------------------------------------------------------------- /src/layer/not.rs: 
-------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::One; 7 | use ndarray::{arr1, ArrayD}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | pub struct NotLayer; 12 | 13 | impl Layer for NotLayer { 14 | fn graph( 15 | input_shapes: &Vec<&Vec>, 16 | _input_types: &Vec, 17 | _constants: &Vec, DatumType)>>, 18 | _attributes: &Vec<&AttributeProto>, 19 | ) -> (Graph, Vec>, Vec) { 20 | let mut graph = Graph::new(); 21 | let bool_check = graph.addBB(Box::new(BooleanCheckBasicBlock {})); 22 | let one = graph.addBB(Box::new(Const2BasicBlock { 23 | c: arr1(&vec![Fr::one(); *input_shapes[0].last().unwrap()]).into_dyn(), 24 | })); 25 | let layer = graph.addBB(Box::new(RepeaterBasicBlock { 26 | basic_block: Box::new(SubBasicBlock {}), 27 | N: 1, 28 | })); 29 | let _ = graph.addNode(bool_check, vec![(-1, 0)]); 30 | let one_output = graph.addNode(one, vec![]); 31 | let layer_output = graph.addNode(layer, vec![(one_output, 0), (-1, 0)]); 32 | graph.outputs.push((layer_output, 0)); 33 | (graph, vec![util::broadcastDims(input_shapes, 0)], vec![DatumType::Bool]) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | #![allow(non_upper_case_globals)] 3 | #![allow(unused_imports)] 4 | pub mod basic_block; 5 | pub mod graph; 6 | pub mod layer; 7 | pub mod onnx; 8 | pub mod ptau; 9 | #[cfg(test)] 10 | pub mod tests; 11 | pub mod util; 12 | 13 | use once_cell::sync::Lazy; 14 | use std::env; 15 | use std::fs::{self, File}; 16 | use std::io::Read; 17 | use std::path::Path; 18 | 19 | pub static CONFIG_FILE: Lazy = Lazy::new(|| { 20 | let args: Vec = env::args().collect(); 21 | if args.len() != 2 { 22 | panic!("Usage: cargo run -- "); 23 | } 24 | args[1].clone() 25 | }); 26 | 27 | // Define a static CONFIG that holds the loaded configuration 28 | pub static CONFIG: Lazy = Lazy::new(|| { 29 | let mut file = File::open(&*CONFIG_FILE).expect("Could not open config"); 30 | let mut contents = String::new(); 31 | file.read_to_string(&mut contents).expect("Could not read config"); 32 | 33 | serde_yaml::from_str(&contents).expect("Could not parse config") 34 | }); 35 | 36 | pub static LAYER_SETUP_DIR: Lazy = Lazy::new(|| { 37 | let dir = format!( 38 | "layer_setup/{}_{}_{}", 39 | CONFIG.sf.scale_factor_log, CONFIG.sf.cq_range_log, CONFIG.sf.cq_range_lower_log 40 | ); 41 | assert!(Path::new(&dir).exists() || fs::create_dir_all(&dir).is_ok()); 42 | dir 43 | }); 44 | -------------------------------------------------------------------------------- /src/layer/shape.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::Zero; 7 | use ndarray::{arr1, ArrayD}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | #[derive(Debug)] 12 | pub struct ShapeBasicBlock; 13 | impl BasicBlock for ShapeBasicBlock { 14 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 15 | let shape: Vec<_> = inputs[0].shape().iter().map(|&x| Fr::from(x as i32)).collect(); 16 | let shape = arr1(&shape).into_dyn(); 17 | let 
padded_shape = util::pad_to_pow_of_two(&shape, &Fr::zero());
18 |     Ok(vec![padded_shape])
19 |   }
20 | }
21 |
22 | pub struct ShapeLayer;
23 | impl Layer for ShapeLayer {
24 |   fn graph(
25 |     input_shapes: &Vec<&Vec<usize>>,
26 |     _input_types: &Vec<DatumType>,
27 |     _constants: &Vec<Option<&(ArrayD<Fr>, DatumType)>>,
28 |     _attributes: &Vec<&AttributeProto>,
29 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
30 |     let mut graph = Graph::new();
31 |     let shape = graph.addBB(Box::new(ShapeBasicBlock {}));
32 |     let shape_output = graph.addNode(shape, vec![(-1, 0)]);
33 |     graph.outputs.push((shape_output, 0));
34 |     (graph, vec![vec![input_shapes[0].len()]], vec![DatumType::I64])
35 |   }
36 | }
--------------------------------------------------------------------------------
/src/util/copy_constraint.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copy constraint utilities:
3 |  * The functions are used for constructing the permutation and
4 |  * padding_partitions fields in the CopyConstraintBasicBlock.
5 |  */
6 | use crate::util::pad_to_pow_of_two;
7 | use ndarray::{ArrayD, IxDyn};
8 |
9 | // Helper function to get the indices of the reshaped tensor
10 | // Note that the input_shape and output_shape are non-padded
11 | pub fn get_reshape_indices(input_shape: Vec<usize>, output_shape: Vec<usize>) -> ArrayD<Option<IxDyn>> {
12 |   let indices = ArrayD::from_shape_fn(input_shape.as_slice(), |index| Some(index.clone()));
13 |   let output_indices = indices.view().into_shape(&output_shape[..]).unwrap().to_owned();
14 |
15 |   let padded_indices = pad_to_pow_of_two(&output_indices, &None);
16 |   padded_indices
17 | }
18 |
19 | pub fn get_reshape_transpose_indices(input_shape: Vec<usize>, output_shape: Vec<usize>, axes: Vec<usize>) -> ArrayD<Option<IxDyn>> {
20 |   let indices = ArrayD::from_shape_fn(input_shape.as_slice(), |index| Some(index.clone()));
21 |   let output_indices = indices.view().into_shape(&output_shape[..]).unwrap().to_owned();
22 |
23 |   let mut permuted_indices = output_indices.clone();
24 |   permuted_indices = permuted_indices.permuted_axes(IxDyn(&axes));
25 |
26 |   let padded_indices = pad_to_pow_of_two(&permuted_indices, &None);
27 |   padded_indices
28 | }
--------------------------------------------------------------------------------
/test_gpu_scripts/README.md:
--------------------------------------------------------------------------------
1 | # GPU Test
2 | `test_gpu_scripts` contains the scripts used to test the GPU feature on the Illinois Campus Cluster (CC) and ensure that it works correctly.
3 |
4 | ## Workflow
5 | The workflow for GPU testing is defined in `.github/workflows/gpu.yml` and consists of the following steps:
6 | 1. `uses: actions/checkout@v4`: pulls the repo onto the VM hosted by GitHub
7 | 2. `name: Copy files over to the cluster`: copies the repo from the VM to CC via scp
8 | 3. `name: Sleep for 1 minute`: gives CC time to prepare for the next step
9 | 4. `name: Execute script to enqueue job`: connects to CC via ssh, appends the dependencies needed for GPU testing to `Cargo.toml`, and submits the sbatch job to the CC SLURM node
10 |
11 | ## Notification
12 | After step 4 above, the user receives a notice in our Slack channel `#zk-torch-test-gpu` that the job has been submitted. Once the job is done, the user receives another notice in the same channel.
13 |
14 | ## Notes
15 | - The GitHub Actions run may occasionally show error messages (e.g., `Error: Timed out while waiting for handshake`). This happens because the VM hosted by GitHub cannot always reach the Campus Cluster network.
In this case, the workflow will fail and the user will have to re-run it when the Campus Cluster network is accessible (i.e., click `Checks` tab and `Re-run all jobs` button under the PR). 16 | -------------------------------------------------------------------------------- /src/util/config.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, Deserialize, Serialize, Clone)] 4 | pub struct Config { 5 | pub task: String, 6 | pub onnx: OnnxConfig, 7 | pub ptau: PtauConfig, 8 | pub sf: ScaleFactorConfig, 9 | pub prover: ProverConfig, 10 | pub verifier: VerifierConfig, 11 | } 12 | 13 | #[derive(Debug, Deserialize, Serialize, Clone)] 14 | pub struct PtauConfig { 15 | pub ptau_path: String, 16 | pub pow_len_log: usize, 17 | pub loaded_pow_len_log: usize, 18 | } 19 | 20 | #[derive(Debug, Deserialize, Serialize, Clone)] 21 | pub struct OnnxConfig { 22 | pub model_path: String, 23 | pub input_path: String, 24 | } 25 | 26 | #[derive(Debug, Deserialize, Serialize, Clone)] 27 | pub struct ScaleFactorConfig { 28 | pub scale_factor_log: usize, 29 | pub cq_range_log: usize, 30 | pub cq_range_lower_log: usize, 31 | } 32 | 33 | #[derive(Debug, Deserialize, Serialize, Clone)] 34 | pub struct ProverConfig { 35 | pub model_path: String, 36 | pub setup_path: String, 37 | pub enc_model_path: String, 38 | pub enc_input_path: String, 39 | pub enc_output_path: String, 40 | pub proof_path: String, 41 | pub acc_proof_path: String, 42 | pub final_proof_path: String, 43 | pub enable_layer_setup: bool, 44 | } 45 | 46 | #[derive(Debug, Deserialize, Serialize, Clone)] 47 | pub struct VerifierConfig { 48 | pub enc_model_path: String, 49 | pub enc_input_path: String, 50 | pub enc_output_path: String, 51 | pub proof_path: String, 52 | } 53 | -------------------------------------------------------------------------------- /src/layer/xor.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct XorLayer; 11 | impl Layer for XorLayer { 12 | fn graph( 13 | input_shapes: &Vec<&Vec>, 14 | _input_types: &Vec, 15 | _constants: &Vec, DatumType)>>, 16 | _attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | let mut graph = Graph::new(); 19 | let bool_check = graph.addBB(Box::new(BooleanCheckBasicBlock {})); 20 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 21 | basic_block: Box::new(SubBasicBlock {}), 22 | N: 1, 23 | })); 24 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 25 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 26 | basic_block: Box::new(MulBasicBlock { len }), 27 | N: 1, 28 | })); 29 | let _ = graph.addNode(bool_check, vec![(-1, 0)]); 30 | let _ = graph.addNode(bool_check, vec![(-2, 0)]); 31 | let sub_output = graph.addNode(sub, vec![(-1, 0), (-2, 0)]); 32 | let xor_output = graph.addNode(mul, vec![(sub_output, 0), (sub_output, 0)]); // XOR(a, b) = PointwiseMul((a - b), (a - b)) 33 | graph.outputs.push((xor_output, 0)); 34 | (graph, vec![util::broadcastDims(input_shapes, 0)], vec![DatumType::Bool]) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/layer/clip.rs: 
--------------------------------------------------------------------------------
1 | use crate::basic_block::*;
2 | use crate::graph::*;
3 | use crate::layer::Layer;
4 | use crate::onnx;
5 | use crate::util;
6 | use ark_bn254::Fr;
7 | use ndarray::ArrayD;
8 | use tract_onnx::pb::AttributeProto;
9 | use tract_onnx::prelude::DatumType;
10 |
11 | pub struct ClipLayer;
12 | impl Layer for ClipLayer {
13 |   fn graph(
14 |     input_shapes: &Vec<&Vec<usize>>,
15 |     input_types: &Vec<DatumType>,
16 |     constants: &Vec<Option<&(ArrayD<Fr>, DatumType)>>,
17 |     _attributes: &Vec<&AttributeProto>,
18 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
19 |     let mut graph = Graph::new();
20 |     let min = util::fr_to_int(constants[1].unwrap().0.as_slice().unwrap()[0]) as f32;
21 |     let max = util::fr_to_int(constants[2].unwrap().0.as_slice().unwrap()[0]) as f32;
22 |
23 |     let clip = graph.addBB(Box::new(ClipBasicBlock { min, max }));
24 |     let clip_output = graph.addNode(clip, vec![(-1, 0)]);
25 |     let clip_check = graph.addBB(Box::new(RepeaterBasicBlock {
26 |       basic_block: Box::new(CQ2BasicBlock {
27 |         n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(),
28 |         setup: Some((Box::new(ClipBasicBlock { min, max }), *onnx::CQ_RANGE_LOWER, *onnx::CQ_RANGE)),
29 |       }),
30 |       N: 1,
31 |     }));
32 |     let _ = graph.addNode(clip_check, vec![(-1, 0), (clip_output, 0)]);
33 |
34 |     graph.outputs.push((clip_output, 0));
35 |
36 |     (graph, vec![input_shapes[0].clone()], vec![input_types[0]])
37 |   }
38 | }
--------------------------------------------------------------------------------
/src/util/arithmetic.rs:
--------------------------------------------------------------------------------
1 | /*
2 |  * Arithmetic utilities:
3 |  * The functions are used for converting between Fr and i128, for calculating powers of Fr,
4 |  * and for pointwise operations on u32 or f64.
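 * For example, fr_to_int below treats field elements in the upper half of the field as
 * negative, so Fr::from(-3) maps back to -3i128 rather than to a huge positive value.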
5 | */ 6 | use ark_bn254::Fr; 7 | use ark_ff::PrimeField; 8 | 9 | pub fn fr_to_int(x: Fr) -> i128 { 10 | let a: i128 = 1; 11 | if x < Fr::from(a << 127) { 12 | x.into_bigint().0[0] as i128 13 | } else { 14 | -((-x).into_bigint().0[0] as i128) 15 | } 16 | } 17 | 18 | pub fn calc_pow(alpha: Fr, n: usize) -> Vec { 19 | let mut pow: Vec = vec![alpha; n]; 20 | if n > 0 { 21 | for i in 0..n - 1 { 22 | pow[i + 1] = pow[i] * alpha; 23 | } 24 | } 25 | pow 26 | } 27 | 28 | pub fn next_pow(n: u32) -> u32 { 29 | if n == 0 { 30 | return 1; 31 | } 32 | let mut v = n; 33 | v -= 1; 34 | v |= v >> 1; 35 | v |= v >> 2; 36 | v |= v >> 4; 37 | v |= v >> 8; 38 | v |= v >> 16; 39 | v += 1; 40 | v 41 | } 42 | 43 | /// Computes erf(x) approximation using A&S formula 7.1.26 44 | pub fn erf(x: f64) -> f64 { 45 | let a1 = 0.254829592; 46 | let a2 = -0.284496736; 47 | let a3 = 1.421413741; 48 | let a4 = -1.453152027; 49 | let a5 = 1.061405429; 50 | let p = 0.3275911; 51 | let sign = if x < 0.0 { -1.0 } else { 1.0 }; 52 | let x = x.abs(); 53 | let t = 1.0 / (1.0 + p * x); 54 | let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp(); 55 | sign * y 56 | } 57 | 58 | pub fn gelu(x: f64) -> f64 { 59 | let sqrt_2 = 2.0_f64.sqrt(); 60 | let y = 0.5 * x * (1.0 + erf(x / sqrt_2)); 61 | y 62 | } 63 | -------------------------------------------------------------------------------- /.github/workflows/gpu.yml: -------------------------------------------------------------------------------- 1 | name: SLURM Enqueue Workflow 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | enqueue: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Copy files over to the cluster 15 | uses: garygrossgarten/github-action-scp@release 16 | with: 17 | local: . 18 | remote: "${{ secrets.GPU_ACTION_DIR }}/gh_actions/${{ github.event.pull_request.head.sha }}" 19 | host: ${{ secrets.HOST_NAME }} 20 | username: ${{ secrets.USR_NAME }} 21 | password: ${{ secrets.PASSWORD }} 22 | - name: Sleep for 1 minute 23 | uses: jakejarvis/wait-action@master 24 | with: 25 | time: '1m' 26 | - name: Execute script to enqueue job in cluster 27 | uses: appleboy/ssh-action@v0.1.3 28 | with: 29 | host: ${{ secrets.HOST_NAME }} 30 | username: ${{ secrets.USR_NAME }} 31 | password: ${{ secrets.PASSWORD }} 32 | script: | 33 | cd "${{ secrets.GPU_ACTION_DIR }}/gh_actions/${{ github.event.pull_request.head.sha }}" 34 | python test_gpu_scripts/add_gpu_dependencies.py "${{ secrets.GPU_ACTION_DIR }}" 35 | mv test_gpu_scripts/test_gpu.sbatch ${{ github.event.pull_request.head.sha }}_test_gpu.sbatch 36 | module load anaconda/2023-Mar/3 37 | source activate slack 38 | sbatch ${{ github.event.pull_request.head.sha }}_test_gpu.sbatch "${{ github.event.number }}" "${{ secrets.SLACK_TOKEN }}" "${{ github.event.pull_request.head.sha }}" 39 | -------------------------------------------------------------------------------- /src/layer/arithmetic.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | macro_rules! 
define_arithmetic_layer { 11 | ($struct_name:ident, $basic_block:ident) => { 12 | pub struct $struct_name; 13 | 14 | impl Layer for $struct_name { 15 | fn graph( 16 | input_shapes: &Vec<&Vec>, 17 | input_types: &Vec, 18 | _constants: &Vec, DatumType)>>, 19 | _attributes: &Vec<&AttributeProto>, 20 | ) -> (Graph, Vec>, Vec) { 21 | let mut graph = Graph::new(); 22 | let layer = if input_shapes[0].len() == 0 && input_shapes[1].len() == 0 { 23 | graph.addBB(Box::new($basic_block {})) 24 | } else { 25 | graph.addBB(Box::new(RepeaterBasicBlock { 26 | basic_block: Box::new($basic_block {}), 27 | N: 1, 28 | })) 29 | }; 30 | let output_shape = if input_shapes[0].len() == 0 && input_shapes[1].len() == 0 { 31 | input_shapes[0].clone() 32 | } else { 33 | util::broadcastDims(input_shapes, 0) 34 | }; 35 | let layer_output = graph.addNode(layer, vec![(-1, 0), (-2, 0)]); 36 | graph.outputs.push((layer_output, 0)); 37 | (graph, vec![output_shape], vec![input_types[0]]) 38 | } 39 | } 40 | }; 41 | } 42 | 43 | // Using the macro to define AddLayer and SubLayer 44 | define_arithmetic_layer!(AddLayer, AddBasicBlock); 45 | define_arithmetic_layer!(SubLayer, SubBasicBlock); 46 | -------------------------------------------------------------------------------- /src/layer/range.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct RangeLayer; 11 | impl Layer for RangeLayer { 12 | fn graph( 13 | _input_shapes: &Vec<&Vec>, 14 | _input_types: &Vec, 15 | constants: &Vec, DatumType)>>, 16 | _attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | let mut graph = Graph::new(); 19 | let mut all_are_constant = true; 20 | 21 | // check if constants are all Some 22 | for i in 0..3 { 23 | if constants[i].is_none() { 24 | all_are_constant = false; 25 | break; 26 | } 27 | } 28 | 29 | let mut length = 0; 30 | if all_are_constant { 31 | let start = util::fr_to_int(constants[0].unwrap().0.as_slice().unwrap()[0]); 32 | let limit = util::fr_to_int(constants[1].unwrap().0.as_slice().unwrap()[0]); 33 | let delta = util::fr_to_int(constants[2].unwrap().0.as_slice().unwrap()[0]); 34 | 35 | // all fields are constant 36 | let range = graph.addBB(Box::new(RangeConstBasicBlock { 37 | start: start, 38 | limit: limit, 39 | delta: delta, 40 | })); 41 | let range_output = graph.addNode(range, vec![]); 42 | graph.outputs.push((range_output, 0)); 43 | 44 | let mut start = start; 45 | while start < limit { 46 | start += delta; 47 | length += 1; 48 | } 49 | } else { 50 | panic!("Don't support non-constant range yet"); 51 | } 52 | 53 | (graph, vec![vec![length]], vec![constants[0].unwrap().1]) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/basic_block/reshape.rs: -------------------------------------------------------------------------------- 1 | use crate::util; 2 | 3 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 4 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 5 | use ndarray::ArrayD; 6 | use rand::rngs::StdRng; 7 | 8 | #[derive(Debug)] 9 | pub struct ReshapeBasicBlock { 10 | pub shape: Vec, 11 | } 12 | 13 | impl BasicBlock for ReshapeBasicBlock { 14 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> 
{ 15 | assert!(inputs.len() == 1); 16 | assert!(inputs[0].shape().last() == self.shape.last()); 17 | let result = match inputs[0].view().into_shape(&self.shape[..]) { 18 | Ok(view) => view.to_owned(), 19 | Err(_) => inputs[0].to_shape(&self.shape[..]).unwrap().into_owned(), 20 | }; 21 | 22 | Ok(vec![result]) 23 | } 24 | 25 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 26 | let n = self.shape.len(); 27 | let result = match inputs[0].view().into_shape(&self.shape[..n - 1]) { 28 | Ok(view) => view.to_owned(), 29 | Err(_) => inputs[0].to_shape(&self.shape[..n - 1]).unwrap().into_owned(), 30 | }; 31 | 32 | vec![result] 33 | } 34 | 35 | fn verify( 36 | &self, 37 | _srs: &SRS, 38 | _model: &ArrayD, 39 | inputs: &Vec<&ArrayD>, 40 | outputs: &Vec<&ArrayD>, 41 | _proof: (&Vec, &Vec, &Vec), 42 | _rng: &mut StdRng, 43 | _cache: ProveVerifyCache, 44 | ) -> Vec { 45 | let n = self.shape.len(); 46 | let reshaped = match inputs[0].view().into_shape(&self.shape[..n - 1]) { 47 | Ok(view) => view.to_owned(), 48 | Err(_) => inputs[0].to_shape(&self.shape[..n - 1]).unwrap().into_owned(), 49 | }; 50 | assert!(outputs[0] == &reshaped); 51 | 52 | vec![] 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/basic_block/constant.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ndarray::{arr0, ArrayD, IxDyn}; 5 | use rand::rngs::StdRng; 6 | 7 | #[derive(Debug)] 8 | pub struct ConstBasicBlock; 9 | impl BasicBlock for ConstBasicBlock { 10 | fn run(&self, model: &ArrayD, _inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 11 | Ok(vec![model.clone()]) 12 | } 13 | 14 | fn encodeOutputs(&self, _srs: &SRS, model: &ArrayD, _inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 15 | vec![model.clone()] 16 | } 17 | 18 | fn verify( 19 | &self, 20 | _srs: &SRS, 21 | model: &ArrayD, 22 | _inputs: &Vec<&ArrayD>, 23 | outputs: &Vec<&ArrayD>, 24 | _proof: (&Vec, &Vec, &Vec), 25 | _rng: &mut StdRng, 26 | _cache: ProveVerifyCache, 27 | ) -> Vec { 28 | assert!(model == outputs[0]); 29 | 30 | vec![] 31 | } 32 | } 33 | 34 | #[derive(Debug)] 35 | pub struct Const2BasicBlock { 36 | pub c: ArrayD, 37 | } 38 | 39 | impl BasicBlock for Const2BasicBlock { 40 | fn run(&self, _model: &ArrayD, _inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 41 | Ok(vec![self.c.clone()]) 42 | } 43 | } 44 | 45 | // ConstOfShapeBasicBlock is a basic block that creates a constant tensor of a given shape and value. 46 | // It requires no proving since the constant value is known as a public input. 
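// A minimal usage sketch (hypothetical values; arr0 comes from the ndarray import above):
//   let bb = ConstOfShapeBasicBlock { c: Fr::from(5), shape: vec![2, 3] };
//   let out = bb.run(&arr0(Fr::from(0)).into_dyn(), &vec![]).unwrap();
//   // out[0] is a 2x3 tensor with every entry equal to Fr::from(5)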
47 | #[derive(Debug)] 48 | pub struct ConstOfShapeBasicBlock { 49 | pub c: Fr, 50 | pub shape: Vec, 51 | } 52 | impl BasicBlock for ConstOfShapeBasicBlock { 53 | fn run(&self, _model: &ArrayD, _inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 54 | Ok(vec![ArrayD::from_elem(IxDyn(&self.shape), self.c)]) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/ptau.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use ark_bn254::{Fq, Fq2, G1Affine, G1Projective, G2Affine, G2Projective}; 3 | use ark_ff::PrimeField; 4 | use rayon::prelude::*; 5 | use std::fs::File; 6 | use std::io::{Read, Seek, SeekFrom}; 7 | 8 | pub fn load_file(filename: &str, n: usize, m: usize) -> SRS { 9 | let powers_length = 1 << n; 10 | let powers_g1_length = (powers_length << 1) - 1; 11 | 12 | let mut file = File::open(filename).unwrap(); 13 | let mut bytes = vec![0; 64 * (1 << m) + 1]; 14 | file.seek(SeekFrom::Start(64)).unwrap(); 15 | file.read_exact(&mut bytes).unwrap(); 16 | 17 | let g1: Vec = (0..1 << m) 18 | .into_par_iter() 19 | .map(|i| { 20 | let start = i * 64; 21 | let x = Fq::from_be_bytes_mod_order(&bytes[start..start + 32]); 22 | let y = Fq::from_be_bytes_mod_order(&bytes[start + 32..start + 64]); 23 | G1Affine::new_unchecked(x, y) 24 | }) 25 | .collect(); 26 | let g1_p: Vec = g1.par_iter().map(|x| (*x).into()).collect(); 27 | 28 | let mut bytes = vec![0; 128 * (1 << m) + 1]; 29 | file.seek(SeekFrom::Start(64 + 64 * powers_g1_length)).unwrap(); 30 | file.read_exact(&mut bytes).unwrap(); 31 | 32 | let g2: Vec = (0..1 << m) 33 | .into_par_iter() 34 | .map(|i| { 35 | let start = 128 * i; 36 | let a = Fq::from_be_bytes_mod_order(&bytes[start..start + 32]); 37 | let b = Fq::from_be_bytes_mod_order(&bytes[start + 32..start + 64]); 38 | let c = Fq::from_be_bytes_mod_order(&bytes[start + 64..start + 96]); 39 | let d = Fq::from_be_bytes_mod_order(&bytes[start + 96..start + 128]); 40 | G2Affine::new_unchecked(Fq2 { c0: b, c1: a }, Fq2 { c0: d, c1: c }) 41 | }) 42 | .collect(); 43 | let g2_p: Vec = g2.par_iter().map(|x| (*x).into()).collect(); 44 | 45 | let res = SRS { 46 | Y1A: g1[g2.len() - 1], 47 | Y2A: g2[g2.len() - 1], 48 | Y1P: g1_p[g2.len() - 1], 49 | Y2P: g2_p[g2.len() - 1], 50 | X1A: g1, 51 | X2A: g2, 52 | X1P: g1_p, 53 | X2P: g2_p, 54 | }; 55 | 56 | res 57 | } 58 | -------------------------------------------------------------------------------- /src/basic_block/split.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ark_std::Zero; 5 | use ndarray::{ArrayD, Axis}; 6 | use rand::rngs::StdRng; 7 | 8 | #[derive(Debug)] 9 | pub struct SplitBasicBlock { 10 | pub axis: usize, 11 | pub split: Vec, 12 | } 13 | 14 | impl BasicBlock for SplitBasicBlock { 15 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 16 | assert!(inputs.len() == 1); 17 | assert!(self.axis < inputs[0].ndim() - 1); 18 | assert!(inputs[0].shape()[self.axis] == self.split.iter().sum::()); 19 | let mut r = vec![]; 20 | // use split_at 21 | let mut b = inputs[0].view(); 22 | for &s in self.split.iter() { 23 | let (a, remaining) = b.split_at(Axis(self.axis), s); 24 | b = remaining; 25 | r.push(a.to_owned()); 26 | } 27 | Ok(r) 28 | } 29 | 30 | fn 
encodeOutputs(&self, _srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 31 | let mut r = vec![]; 32 | // use split_at 33 | let mut b = inputs[0].view(); 34 | for &s in self.split.iter() { 35 | let (a, remaining) = b.split_at(Axis(self.axis), s); 36 | b = remaining; 37 | r.push(a.to_owned()); 38 | } 39 | r 40 | } 41 | 42 | fn verify( 43 | &self, 44 | _srs: &SRS, 45 | _model: &ArrayD, 46 | inputs: &Vec<&ArrayD>, 47 | outputs: &Vec<&ArrayD>, 48 | _proof: (&Vec, &Vec, &Vec), 49 | _rng: &mut StdRng, 50 | _cache: ProveVerifyCache, 51 | ) -> Vec { 52 | let mut b = inputs[0].view(); 53 | for i in 0..outputs.len() { 54 | let (a, remaining) = b.split_at(Axis(self.axis), self.split[i]); 55 | b = remaining; 56 | outputs[i].iter().zip(a.iter()).for_each(|(input, output)| { 57 | assert!(input == output); 58 | }); 59 | } 60 | vec![] 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/layer/where.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{arr1, ArrayD}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct WhereLayer; 11 | impl Layer for WhereLayer { 12 | fn graph( 13 | input_shapes: &Vec<&Vec>, 14 | input_types: &Vec, 15 | _constants: &Vec, DatumType)>>, 16 | _attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | //condition, X, Y 19 | //condition * X + (1-condition) * Y 20 | let mut graph = Graph::new(); 21 | let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock { 22 | basic_block: Box::new(MulScalarBasicBlock {}), 23 | N: 1, 24 | })); 25 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 26 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 27 | basic_block: Box::new(MulBasicBlock { len }), 28 | N: 1, 29 | })); 30 | let one = graph.addBB(Box::new(Const2BasicBlock { 31 | c: arr1(&vec![Fr::from(1); util::next_pow(*input_shapes[0].last().unwrap() as u32) as usize]).into_dyn(), 32 | })); 33 | let add = graph.addBB(Box::new(RepeaterBasicBlock { 34 | basic_block: Box::new(AddBasicBlock {}), 35 | N: 1, 36 | })); 37 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 38 | basic_block: Box::new(SubBasicBlock {}), 39 | N: 1, 40 | })); 41 | 42 | let one_output = graph.addNode(one, vec![]); 43 | let mul1_output = graph.addNode(if input_shapes[1].len() == 0 { mul_scalar } else { mul }, vec![(-1, 0), (-2, 0)]); 44 | let sub_output = graph.addNode(sub, vec![(one_output, 0), (-1, 0)]); 45 | let mul2_output = graph.addNode(if input_shapes[2].len() == 0 { mul_scalar } else { mul }, vec![(sub_output, 0), (-3, 0)]); 46 | let add_output = graph.addNode(add, vec![(mul1_output, 0), (mul2_output, 0)]); 47 | graph.outputs.push((add_output, 0)); 48 | (graph, vec![input_shapes[0].clone()], vec![input_types[0]]) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/layer/and.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct AndLayer; 11 | impl Layer for AndLayer { 12 | fn graph( 13 | input_shapes: &Vec<&Vec>, 14 | 
_input_types: &Vec, 15 | _constants: &Vec, DatumType)>>, 16 | _attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | let mut graph = Graph::new(); 19 | let bool_check = graph.addBB(Box::new(BooleanCheckBasicBlock {})); 20 | let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock { 21 | basic_block: Box::new(MulScalarBasicBlock {}), 22 | N: 1, 23 | })); 24 | 25 | let _ = graph.addNode(bool_check, vec![(-1, 0)]); 26 | let _ = graph.addNode(bool_check, vec![(-2, 0)]); 27 | // If any of the inputs are scalars, use the scalar version of the mul basic block. 28 | let mul_basicblock = if input_shapes[1].len() == 0 || input_shapes[0].len() == 0 { 29 | mul_scalar 30 | // else use the normal version of the mul basic block. 31 | } else { 32 | let len = if input_shapes[0].len() > input_shapes[1].len() { 33 | util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize 34 | } else { 35 | util::next_pow(input_shapes[1][input_shapes[1].len() - 1] as u32) as usize 36 | }; 37 | graph.addBB(Box::new(RepeaterBasicBlock { 38 | basic_block: Box::new(MulBasicBlock { len }), 39 | N: 1, 40 | })) 41 | }; 42 | // If the first input is a scalar, swap the inputs, because the mul scalar basic block expects the scalar to be the second input. 43 | let and_output = if input_shapes[0].len() == 0 { 44 | graph.addNode(mul_basicblock, vec![(-2, 0), (-1, 0)]) 45 | } else { 46 | graph.addNode(mul_basicblock, vec![(-1, 0), (-2, 0)]) 47 | }; 48 | 49 | graph.outputs.push((and_output, 0)); 50 | (graph, vec![util::broadcastDims(input_shapes, 0)], vec![DatumType::Bool]) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/basic_block/ops.rs: -------------------------------------------------------------------------------- 1 | use super::BasicBlock; 2 | use crate::util; 3 | use ark_bn254::Fr; 4 | use ndarray::ArrayD; 5 | use rayon::iter::ParallelIterator; 6 | 7 | macro_rules! make_basic_block { 8 | ( 9 | $name:ident, 10 | $operation:block 11 | ) => { 12 | #[derive(Debug)] 13 | pub struct $name { 14 | pub input_SF: usize, 15 | pub output_SF: usize, 16 | } 17 | impl BasicBlock for $name { 18 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 19 | assert!(inputs.len() == 1); 20 | let shape = inputs[0].shape(); 21 | let out = util::array_into_iter(inputs[0]) 22 | .map(|x| { 23 | let mut x = util::fr_to_int(*x) as f64; 24 | x /= (1 << self.input_SF) as f64; 25 | x = $operation(x); 26 | x *= (1 << self.output_SF) as f64; 27 | Fr::from(x.round() as i128) 28 | }) 29 | .collect::>(); 30 | 31 | Ok(vec![ArrayD::from_shape_vec(shape, out).unwrap()]) 32 | } 33 | } 34 | }; 35 | } 36 | 37 | make_basic_block!(ExpBasicBlock, { |x: f64| { x.exp() } }); 38 | make_basic_block!(LogBasicBlock, { |x: f64| { x.ln() } }); 39 | make_basic_block!(ReLUBasicBlock, { 40 | |x: f64| { 41 | if x < 0f64 { 42 | 0f64 43 | } else { 44 | x 45 | } 46 | } 47 | }); 48 | make_basic_block!(SqrtBasicBlock, { |x: f64| { x.sqrt() } }); 49 | make_basic_block!(ChangeSFBasicBlock, { |x: f64| { x } }); 50 | make_basic_block!(ErfBasicBlock, { |x: f64| { util::erf(x) } }); 51 | make_basic_block!(SigmoidBasicBlock, { |x: f64| { x.exp() / (1. 
+ x.exp()) } }); 52 | make_basic_block!(TanhBasicBlock, { |x: f64| { x.tanh() } }); 53 | make_basic_block!(CeilBasicBlock, { |x: f64| { x.ceil() } }); 54 | make_basic_block!(NegBasicBlock, { |x: f64| { -x } }); 55 | make_basic_block!(CosBasicBlock, { |x: f64| { x.cos() } }); 56 | make_basic_block!(SinBasicBlock, { |x: f64| { x.sin() } }); 57 | make_basic_block!(TanBasicBlock, { |x: f64| { x.tan() } }); 58 | make_basic_block!(ReciprocalBasicBlock, { |x: f64| { 1. / x } }); 59 | make_basic_block!(GeLUBasicBlock, { |x: f64| { util::gelu(x) } }); 60 | -------------------------------------------------------------------------------- /src/basic_block/sub.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ndarray::{arr0, azip, ArrayD, IxDyn}; 5 | use rand::rngs::StdRng; 6 | 7 | #[derive(Debug)] 8 | pub struct SubBasicBlock; 9 | impl BasicBlock for SubBasicBlock { 10 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 11 | assert!(inputs.len() == 2 && inputs[0].ndim() <= 1 && inputs[1].ndim() <= 1); 12 | let mut r = ArrayD::zeros(IxDyn(&[std::cmp::max(inputs[0].len(), inputs[1].len())])); 13 | if inputs[0].len() == 1 && inputs[1].ndim() == 0 { 14 | // speicial case: [1] - [] 15 | r = inputs[0].map(|x| x - inputs[1].first().unwrap()); 16 | } else if inputs[0].len() == 1 { 17 | azip!((r in &mut r, &y in inputs[1]) *r = *inputs[0].first().unwrap() - y); 18 | } else if inputs[1].len() == 1 { 19 | azip!((r in &mut r, &x in inputs[0]) *r = x - *inputs[1].first().unwrap()); 20 | } else { 21 | azip!((r in &mut r, &x in inputs[0], &y in inputs[1]) *r = x - y); 22 | } 23 | Ok(vec![r]) 24 | } 25 | 26 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, outputs: &Vec<&ArrayD>) -> Vec> { 27 | let a = &inputs[0].first().unwrap(); 28 | let b = &inputs[1].first().unwrap(); 29 | vec![arr0(Data { 30 | raw: outputs[0].clone().into_raw_vec(), 31 | poly: (&a.poly) - (&b.poly), 32 | g1: a.g1 - b.g1, 33 | r: a.r - b.r, 34 | }) 35 | .into_dyn()] 36 | } 37 | 38 | fn verify( 39 | &self, 40 | _srs: &SRS, 41 | _model: &ArrayD, 42 | inputs: &Vec<&ArrayD>, 43 | outputs: &Vec<&ArrayD>, 44 | _proof: (&Vec, &Vec, &Vec), 45 | _rng: &mut StdRng, 46 | _cache: ProveVerifyCache, 47 | ) -> Vec { 48 | let a = inputs[0].first().unwrap(); 49 | let b = inputs[1].first().unwrap(); 50 | let c = outputs[0].first().unwrap(); 51 | // Verify f(x)-g(x)=h(x) 52 | assert!(a.g1 - b.g1 == c.g1); 53 | vec![] 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/layer/gather.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{ArrayD, Axis}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | #[derive(Debug)] 11 | pub struct GatherBasicBlock; 12 | impl BasicBlock for GatherBasicBlock { 13 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 14 | let mut v = Vec::new(); 15 | inputs[1].for_each(|x| { 16 | let idx = util::fr_to_int(*x) as usize; 17 | v.extend_from_slice(inputs[0].index_axis(Axis(0), idx).to_slice().unwrap()); 18 | }); 19 | let mut shape = 
inputs[1].shape().to_vec(); 20 | shape.extend_from_slice(&inputs[0].shape()[1..]); 21 | let v = ArrayD::from_shape_vec(shape, v).unwrap(); 22 | Ok(vec![v]) 23 | } 24 | } 25 | 26 | pub struct GatherLayer; 27 | impl Layer for GatherLayer { 28 | fn graph( 29 | input_shapes: &Vec<&Vec>, 30 | input_types: &Vec, 31 | constants: &Vec, DatumType)>>, 32 | _attributes: &Vec<&AttributeProto>, 33 | ) -> (Graph, Vec>, Vec) { 34 | let mut graph = Graph::new(); 35 | let mut indices_output = -2; 36 | // Handle the case where indices are not an input 37 | if input_shapes[1].len() == 0 || constants.len() > 1 && constants[1].is_some() { 38 | let indices = constants[1].unwrap().0.mapv(|x| { 39 | if x > Fr::from(input_shapes[0][0] as i128) { 40 | Fr::from(input_shapes[0][0] as i128 + util::fr_to_int(x)) 41 | } else { 42 | x 43 | } 44 | }); 45 | let indices = graph.addBB(Box::new(Const2BasicBlock { c: indices })); 46 | indices_output = graph.addNode(indices, vec![]); 47 | } 48 | let gather = graph.addBB(Box::new(GatherBasicBlock {})); 49 | let output = graph.addNode(gather, vec![(-1, 0), (indices_output, 0)]); 50 | graph.outputs.push((output, 0)); 51 | let mut output_shape = input_shapes[1].clone(); 52 | output_shape.extend_from_slice(&input_shapes[0][1..]); 53 | (graph, vec![output_shape], vec![input_types[0]]) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/util/serialization.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Serialization utilities for converting between serde and ark_serialize. 3 | * And other file I/O utilities. 4 | */ 5 | use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; 6 | use std::collections::hash_map::DefaultHasher; 7 | use std::fs::{self, File}; 8 | use std::hash::{Hash, Hasher}; 9 | 10 | // For serialization, ArrayD uses serde while G1Affine uses ark_serialize. 
11 | // In order to bridge between the two, the following code snippet is used: 12 | // https://github.com/arkworks-rs/algebra/issues/178#issuecomment-1413219278 13 | pub fn ark_se(a: &A, s: S) -> Result 14 | where 15 | S: serde::Serializer, 16 | { 17 | let mut bytes = vec![]; 18 | a.serialize_compressed(&mut bytes).map_err(serde::ser::Error::custom)?; 19 | s.serialize_bytes(&bytes) 20 | } 21 | 22 | pub fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result 23 | where 24 | D: serde::de::Deserializer<'de>, 25 | { 26 | let s: Vec = serde::de::Deserialize::deserialize(data)?; 27 | let a = A::deserialize_compressed_unchecked(s.as_slice()); 28 | a.map_err(serde::de::Error::custom) 29 | } 30 | 31 | pub fn measure_file_size(file_path: &str) -> u64 { 32 | let file = File::open(file_path).unwrap(); 33 | let metadata = file.metadata().unwrap(); 34 | let file_size_bytes = metadata.len(); 35 | println!("{} size: {}", file_path, format_file_size(file_size_bytes)); 36 | file_size_bytes 37 | } 38 | 39 | pub fn format_file_size(bytes: u64) -> String { 40 | const KB: f64 = 1024.0; 41 | const MB: f64 = KB * 1024.0; 42 | const GB: f64 = MB * 1024.0; 43 | 44 | if bytes as f64 >= GB { 45 | format!("{:.2} GB", bytes as f64 / GB) 46 | } else if bytes as f64 >= MB { 47 | format!("{:.2} MB", bytes as f64 / MB) 48 | } else if bytes as f64 >= KB { 49 | format!("{:.2} KB", bytes as f64 / KB) 50 | } else { 51 | format!("{} bytes", bytes) 52 | } 53 | } 54 | 55 | pub fn hash_str(s: &str) -> String { 56 | let mut hasher = DefaultHasher::new(); 57 | s.hash(&mut hasher); 58 | let hash_value = hasher.finish(); 59 | hash_value.to_string() 60 | } 61 | 62 | pub fn file_exists(path: &str) -> bool { 63 | fs::metadata(path).is_ok() 64 | } 65 | -------------------------------------------------------------------------------- /src/layer/constantofshape.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ndarray::ArrayD; 8 | use tract_onnx::pb::tensor_proto::DataType; 9 | use tract_onnx::pb::AttributeProto; 10 | use tract_onnx::prelude::DatumType; 11 | 12 | // Generate a tensor with a given value (the value is in the ONNX attribute) and shape (the shape is in the input tensor) 13 | // reference: https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html 14 | pub struct ConstOfShapeLayer; 15 | impl Layer for ConstOfShapeLayer { 16 | fn graph( 17 | _input_shapes: &Vec<&Vec>, 18 | _input_types: &Vec, 19 | constants: &Vec, DatumType)>>, 20 | attributes: &Vec<&AttributeProto>, 21 | ) -> (Graph, Vec>, Vec) { 22 | let mut graph = Graph::new(); 23 | 24 | let attr_val = attributes.iter().filter(|x| x.name == "value").next().unwrap(); 25 | let dtype = DataType::from_i32(attr_val.r#type).unwrap().into(); 26 | let datum_type = match dtype { 27 | DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => DatumType::I64, 28 | DataType::Uint8 | DataType::Uint16 | DataType::Uint32 | DataType::Uint64 => DatumType::I64, 29 | DataType::Double | DataType::Float16 | DataType::Float => DatumType::F32, 30 | _ => panic!("Unsupported data type"), 31 | }; 32 | let value = match datum_type { 33 | DatumType::I64 => Fr::from(attr_val.t.clone().unwrap().raw_data[0]), 34 | DatumType::F32 => Fr::from((attr_val.t.clone().unwrap().raw_data[0] as f32 * onnx::SF_FLOAT.read().unwrap().to_owned()).round() as i32), 35 | _ => panic!("Unsupported data type"), 36 | }; 
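// [Editor's sketch, not part of the original file; the scale factor 4096 below is an assumed value for illustration.]
// The match above lifts the ONNX `value` attribute into the field: integer types map
// directly via Fr::from, while floats are multiplied by the global scale factor
// (onnx::SF_FLOAT) and rounded before the lift, e.g.:
//   let sf_float = 4096.0_f32;                           // hypothetical SF_FLOAT
//   let quantized = (0.5_f32 * sf_float).round() as i32; // 0.5 -> 2048
//   let field_value = Fr::from(quantized);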
37 | let endShape: Vec = constants[0].unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as usize).filter(|x| *x != 0).collect(); 38 | let endShape_padded: Vec = endShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 39 | 40 | let constantOfShape = graph.addBB(Box::new(ConstOfShapeBasicBlock { 41 | c: value, 42 | shape: endShape_padded.clone(), 43 | })); 44 | let output = graph.addNode(constantOfShape, vec![]); 45 | graph.outputs.push((output, 0)); 46 | (graph, vec![endShape], vec![datum_type]) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/layer/cast.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use crate::util::datumtype_to_sf; 7 | use ark_bn254::Fr; 8 | use ndarray::ArrayD; 9 | use tract_onnx::pb::AttributeProto; 10 | use tract_onnx::prelude::DatumType; 11 | 12 | pub struct CastLayer; 13 | impl Layer for CastLayer { 14 | fn graph( 15 | input_shapes: &Vec<&Vec>, 16 | input_types: &Vec, 17 | _constants: &Vec, DatumType)>>, 18 | attributes: &Vec<&AttributeProto>, 19 | ) -> (Graph, Vec>, Vec) { 20 | let mut graph = Graph::new(); 21 | let to = match attributes.iter().filter(|x| x.name == "to").next() { 22 | Some(v) => vec![util::datatype_to_datumtype(v.i as i32)], 23 | None => vec![input_types[0]], 24 | }; 25 | let input_SF = datumtype_to_sf(input_types[0]); 26 | let output_SF = datumtype_to_sf(to[0]); 27 | let id = if input_SF == output_SF { 28 | graph.addBB(Box::new(IdBasicBlock {})) 29 | } else { 30 | graph.addBB(Box::new(ChangeSFBasicBlock { input_SF, output_SF })) 31 | }; 32 | let change_sf_check = if input_shapes[0].len() == 0 { 33 | graph.addBB(Box::new(CQ2BasicBlock { 34 | n: 1, 35 | setup: Some(( 36 | Box::new(ChangeSFBasicBlock { 37 | input_SF: input_SF, 38 | output_SF: output_SF, 39 | }), 40 | *onnx::CQ_RANGE_LOWER, 41 | *onnx::CQ_RANGE, 42 | )), 43 | })) 44 | } else { 45 | graph.addBB(Box::new(RepeaterBasicBlock { 46 | basic_block: Box::new(CQ2BasicBlock { 47 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 48 | setup: Some(( 49 | Box::new(ChangeSFBasicBlock { 50 | input_SF: input_SF, 51 | output_SF: output_SF, 52 | }), 53 | *onnx::CQ_RANGE_LOWER, 54 | *onnx::CQ_RANGE, 55 | )), 56 | }), 57 | N: 1, 58 | })) 59 | }; 60 | let id_output = graph.addNode(id, vec![(-1, 0)]); 61 | if input_SF != output_SF { 62 | let _ = graph.addNode(change_sf_check, vec![(-1, 0), (id_output, 0)]); 63 | } 64 | graph.outputs.push((id_output, 0)); 65 | (graph, vec![input_shapes[0].clone()], to) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/layer/flatten.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{ArrayD, IxDyn}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | fn get_permutation(input_shape: &[usize], axis: usize) -> (ArrayD>, Vec) { 11 | assert!(axis < input_shape.len()); 12 | let output_shape = if axis == 0 { 13 | vec![1, input_shape.iter().product()] 14 | } else { 15 | vec![input_shape[..axis].iter().product(), input_shape[axis..].iter().product()] 16 | }; 17 | 18 | let permutation = ArrayD::from_shape_fn(input_shape, |index| Some(index)); 19 | let 
permutation = permutation.view().into_shape(&output_shape.clone()[..]).unwrap().to_owned(); 20 | let padded_permutation = util::pad_to_pow_of_two(&permutation, &None); 21 | 22 | (padded_permutation, output_shape) 23 | } 24 | 25 | // https://onnx.ai/onnx/operators/onnx__Flatten.html 26 | // Flattens the input tensor into a 2D matrix. 27 | // If input tensor has shape (d_0, d_1, ..., d_n) then the output will have shape (d_0 × d_1 × ... × d_{axis-1}, d_{axis} × d_{axis+1} × ... × dn). 28 | pub struct FlattenLayer; 29 | impl Layer for FlattenLayer { 30 | fn graph( 31 | input_shapes: &Vec<&Vec>, 32 | input_types: &Vec, 33 | _constants: &Vec, DatumType)>>, 34 | attributes: &Vec<&AttributeProto>, 35 | ) -> (Graph, Vec>, Vec) { 36 | let mut graph = Graph::new(); 37 | 38 | let axis: isize = attributes.iter().filter(|x| x.name == "axis").next().unwrap().i as isize; 39 | let axis = (if axis < 0 { input_shapes[0].len() as isize + axis } else { axis }) as usize; 40 | 41 | let padded_input_shape: Vec = input_shapes[0].iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 42 | 43 | let (permutation, output_shape) = get_permutation(&input_shapes[0], axis); 44 | 45 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 46 | permutation: permutation, 47 | input_dim: IxDyn(&padded_input_shape), 48 | padding_partition: copy_constraint::PaddingEnum::Zero, 49 | })); 50 | 51 | let output = graph.addNode(cc, vec![(-1, 0)]); 52 | graph.outputs.push((output, 0)); 53 | 54 | (graph, vec![output_shape], vec![input_types[0]]) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/layer/nonlinear.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | macro_rules! 
define_nonlinear_layer { 11 | ($struct_name:ident, $basic_block:ident) => { 12 | pub struct $struct_name; 13 | 14 | impl Layer for $struct_name { 15 | fn graph( 16 | input_shapes: &Vec<&Vec>, 17 | input_types: &Vec, 18 | _constants: &Vec, DatumType)>>, 19 | _attributes: &Vec<&AttributeProto>, 20 | ) -> (Graph, Vec>, Vec) { 21 | let mut graph = Graph::new(); 22 | let sf_log = onnx::SF_LOG.read().unwrap().to_owned(); 23 | let layer = graph.addBB(Box::new($basic_block { 24 | input_SF: sf_log, 25 | output_SF: sf_log, 26 | })); 27 | let layer_check = graph.addBB(Box::new(RepeaterBasicBlock { 28 | basic_block: Box::new(CQ2BasicBlock { 29 | n: if input_shapes[0].len() == 0 { 30 | 1 31 | } else { 32 | input_shapes[0][input_shapes[0].len() - 1].next_power_of_two() 33 | }, 34 | setup: Some(( 35 | Box::new($basic_block { 36 | input_SF: sf_log, 37 | output_SF: sf_log, 38 | }), 39 | *onnx::CQ_RANGE_LOWER, 40 | *onnx::CQ_RANGE, 41 | )), 42 | }), 43 | N: 1, 44 | })); 45 | let layer_output = graph.addNode(layer, vec![(-1, 0)]); 46 | let _ = graph.addNode(layer_check, vec![(-1, 0), (layer_output, 0)]); 47 | graph.outputs.push((layer_output, 0)); 48 | (graph, vec![input_shapes[0].clone()], vec![input_types[0]]) 49 | } 50 | } 51 | }; 52 | } 53 | 54 | // Using the macro to define nonlinear layers 55 | define_nonlinear_layer!(ReLULayer, ReLUBasicBlock); 56 | define_nonlinear_layer!(CeilLayer, CeilBasicBlock); 57 | define_nonlinear_layer!(ErfLayer, ErfBasicBlock); 58 | define_nonlinear_layer!(ExpLayer, ExpBasicBlock); 59 | define_nonlinear_layer!(SigmoidLayer, SigmoidBasicBlock); 60 | define_nonlinear_layer!(TanhLayer, TanhBasicBlock); 61 | define_nonlinear_layer!(CosLayer, CosBasicBlock); 62 | define_nonlinear_layer!(SinLayer, SinBasicBlock); 63 | define_nonlinear_layer!(TanLayer, TanBasicBlock); 64 | define_nonlinear_layer!(ReciprocalLayer, ReciprocalBasicBlock); 65 | define_nonlinear_layer!(GeLULayer, GeLUBasicBlock); 66 | -------------------------------------------------------------------------------- /src/basic_block/sort.rs: -------------------------------------------------------------------------------- 1 | use super::BasicBlock; 2 | use crate::{ 3 | basic_block::{Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}, 4 | onnx, 5 | util::{self, calc_pow}, 6 | }; 7 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 8 | use ark_ff::Field; 9 | use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial}; 10 | use ark_serialize::CanonicalSerialize; 11 | use ark_std::{cmp::max, One, UniformRand, Zero}; 12 | use ndarray::{arr0, arr1, azip, s, ArrayD, Axis}; 13 | use rand::{rngs::StdRng, SeedableRng}; 14 | use rayon::prelude::*; 15 | use std::ops::{Add, Mul, Sub}; 16 | 17 | // SortBasicBlock is a basic block that sorts the input data in ascending or descending order. 18 | // It takes two inputs: the data and the original indices. 19 | // It returns two tensors: the sorted data and the sorted indices. 20 | // Note 1: please always remember to perform a one-to-one mapping check and an order check after sorting. 21 | // Note 2: we need len to be passed in as a parameter because the data may be padded with zeros, and we need to ignore the padded zeros when sorting. 
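// [Editor's note: illustrative trace, not part of the original file.] With
// descending = false and len = 3 on a power-of-two-padded input, run() below behaves as:
//   data    = [5, 2, 9, 0]   // trailing 0 is padding and is left in place
//   indices = [0, 1, 2, 3]
//   sorted_data    = [2, 5, 9, 0]
//   sorted_indices = [1, 0, 2, 3]
// The permutation itself is not proven here, which is why the one-to-one mapping
// check and order check from Note 1 must be added by the caller.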
22 | #[derive(Debug)] 23 | pub struct SortBasicBlock { 24 | pub descending: bool, 25 | pub len: usize, 26 | } 27 | impl BasicBlock for SortBasicBlock { 28 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 29 | assert!(inputs.len() == 2 && inputs[0].ndim() == 1 && inputs[1].ndim() == 1 && inputs[0].len() == inputs[1].len()); 30 | let data = inputs[0].slice(s![..self.len]).to_owned().into_dyn(); 31 | let data_tail = inputs[0].slice(s![self.len..]).iter().cloned().collect::>(); 32 | let indices = inputs[1].slice(s![..self.len]).to_owned().into_dyn(); 33 | let indices_tail = inputs[1].slice(s![self.len..]).iter().cloned().collect::>(); 34 | 35 | // Pair the data and indices 36 | let mut paired: Vec<_> = data.into_iter().zip(indices.into_iter()).collect(); 37 | // Sort by the first element of the tuple (data value) 38 | paired.sort_by_key(|&(data, _)| data); 39 | if self.descending { 40 | // Reverse the sorted data to get descending order 41 | paired.reverse(); 42 | } 43 | 44 | // Separate the sorted data and indices 45 | let (sorted_data, sorted_indices): (Vec<_>, Vec<_>) = util::vec_iter(&paired).map(|(data, index)| (data, index)).unzip(); 46 | // Concatenate the sorted data and indices with the tail 47 | let sorted_data = sorted_data.into_iter().chain(data_tail).collect::>(); 48 | let sorted_indices = sorted_indices.into_iter().chain(indices_tail).collect::>(); 49 | 50 | let (sorted_data, sorted_indices) = (arr1(&sorted_data).into_dyn(), arr1(&sorted_indices).into_dyn()); 51 | 52 | Ok(vec![sorted_data, sorted_indices]) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/layer/equal.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ndarray::{arr1, ArrayD}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | pub struct EqualLayer; 12 | impl Layer for EqualLayer { 13 | fn graph( 14 | input_shapes: &Vec<&Vec>, 15 | _input_types: &Vec, 16 | _constants: &Vec, DatumType)>>, 17 | _attributes: &Vec<&AttributeProto>, 18 | ) -> (Graph, Vec>, Vec) { 19 | let mut graph = Graph::new(); 20 | let one = graph.addBB(Box::new(Const2BasicBlock { 21 | c: arr1(&vec![Fr::from(1); util::next_pow(*input_shapes[0].last().unwrap() as u32) as usize]).into_dyn(), 22 | })); 23 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 24 | basic_block: Box::new(SubBasicBlock {}), 25 | N: 1, 26 | })); 27 | let add = graph.addBB(Box::new(RepeaterBasicBlock { 28 | basic_block: Box::new(AddBasicBlock {}), 29 | N: 1, 30 | })); 31 | let equal = graph.addBB(Box::new(RepeaterBasicBlock { 32 | basic_block: Box::new(ElementwiseEqBasicBlock {}), 33 | N: 1, 34 | })); 35 | let eq = graph.addBB(Box::new(RepeaterBasicBlock { 36 | basic_block: Box::new(EqBasicBlock {}), 37 | N: 1, 38 | })); 39 | 40 | let len = util::next_pow(input_shapes[1][input_shapes[1].len() - 1] as u32) as usize; 41 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 42 | basic_block: Box::new(MulBasicBlock { len }), 43 | N: 1, 44 | })); 45 | 46 | let nonzero_check = graph.addBB(Box::new(RepeaterBasicBlock { 47 | basic_block: Box::new(CQBasicBlock { 48 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 49 | setup: util::CQArrayType::NonZero, 50 | }), 51 | N: 1, 52 | })); 53 | 54 | let equal_output = graph.addNode(equal, 
vec![(-1, 0), (-2, 0)]); // a == b 55 | let one_output = graph.addNode(one, vec![]); 56 | let not_equal_output = graph.addNode(sub, vec![(one_output, 0), (equal_output, 0)]); 57 | 58 | let a_equal_b = graph.addNode(mul, vec![(-1, 0), (equal_output, 0)]); // a * (a == b) 59 | let b_equal_a = graph.addNode(mul, vec![(-2, 0), (equal_output, 0)]); // b * (a == b) 60 | let a_minus_b = graph.addNode(sub, vec![(-1, 0), (-2, 0)]); // a - b 61 | let a_not_equal_b = graph.addNode(mul, vec![(a_minus_b, 0), (not_equal_output, 0)]); // (a - b) * (1 - (a == b)) 62 | let add_output = graph.addNode(add, vec![(a_not_equal_b, 0), (equal_output, 0)]); // should be all nonzeros 63 | 64 | let _eq_check = graph.addNode(eq, vec![(a_equal_b, 0), (b_equal_a, 0)]); 65 | let _nonzero_check = graph.addNode(nonzero_check, vec![(add_output, 0)]); 66 | 67 | graph.outputs.push((equal_output, 0)); 68 | (graph, vec![util::broadcastDims(input_shapes, 0)], vec![DatumType::Bool]) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/basic_block/eq.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ark_poly::univariate::DensePolynomial; 5 | use ndarray::{ArrayD, IxDyn, Zip}; 6 | use rand::rngs::StdRng; 7 | 8 | #[derive(Debug)] 9 | pub struct EqBasicBlock; 10 | impl BasicBlock for EqBasicBlock { 11 | fn prove( 12 | &self, 13 | srs: &SRS, 14 | _setup: (&Vec, &Vec, &Vec>), 15 | _model: &ArrayD, 16 | inputs: &Vec<&ArrayD>, 17 | _outputs: &Vec<&ArrayD>, 18 | _rng: &mut StdRng, 19 | _cache: ProveVerifyCache, 20 | ) -> (Vec, Vec, Vec) { 21 | assert!(inputs.len() == 2 && inputs[0].ndim() <= 1 && inputs[1].ndim() <= 1); 22 | // Blinding 23 | let C = srs.X1P[0] * (inputs[0].first().unwrap().r - inputs[1].first().unwrap().r); 24 | (vec![C], Vec::new(), Vec::new()) 25 | } 26 | 27 | fn verify( 28 | &self, 29 | srs: &SRS, 30 | _model: &ArrayD, 31 | inputs: &Vec<&ArrayD>, 32 | _outputs: &Vec<&ArrayD>, 33 | proof: (&Vec, &Vec, &Vec), 34 | _rng: &mut StdRng, 35 | _cache: ProveVerifyCache, 36 | ) -> Vec { 37 | // Verify f(x)+g(x)=h(x) 38 | vec![vec![ 39 | (inputs[0].first().unwrap().g1, srs.X2A[0]), 40 | (-inputs[1].first().unwrap().g1, srs.X2A[0]), 41 | (-proof.0[0], srs.Y2A), 42 | ]] 43 | } 44 | } 45 | 46 | // ElementwiseEqBasicBlock is a basic block that performs elementwise equality comparison. 
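// [Editor's sketch, not part of the original file.] The three branches in run() below
// cover scalar-vs-tensor broadcasting in either direction plus the elementwise case,
// comparing the integer values embedded in Fr, e.g.:
//   inputs[0] = [3]         // length-1 input is broadcast
//   inputs[1] = [3, 7, 3]
//   output    = [1, 0, 1]   // returned as Fr::from(1) / Fr::from(0)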
47 | #[derive(Debug)] 48 | pub struct ElementwiseEqBasicBlock; 49 | impl BasicBlock for ElementwiseEqBasicBlock { 50 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 51 | assert!(inputs.len() == 2 && inputs[0].ndim() <= 1 && inputs[1].ndim() <= 1); 52 | let mut r = ArrayD::zeros(IxDyn(&[std::cmp::max(inputs[0].len(), inputs[1].len())])); 53 | // broadcast inputs[0] to compare with each element in inputs[1] 54 | if inputs[0].len() == 1 && inputs[1].ndim() > 0 { 55 | Zip::from(r.view_mut()) 56 | .and(inputs[1].view()) 57 | .for_each(|r, &x| *r = (util::fr_to_int(x) == util::fr_to_int(*inputs[0].first().unwrap())) as u8); 58 | // broadcast inputs[1] to compare with each element in inputs[0] 59 | } else if inputs[1].len() == 1 && inputs[0].ndim() > 0 { 60 | Zip::from(r.view_mut()) 61 | .and(inputs[0].view()) 62 | .for_each(|r, &x| *r = (util::fr_to_int(x) == util::fr_to_int(*inputs[1].first().unwrap())) as u8); 63 | // elementwise comparison 64 | } else { 65 | Zip::from(r.view_mut()) 66 | .and(inputs[0].view()) 67 | .and(inputs[1].view()) 68 | .for_each(|r, &x, &y| *r = (util::fr_to_int(x) == util::fr_to_int(y)) as u8); 69 | } 70 | 71 | Ok(vec![r.map(|&x| Fr::from(x)).into_dyn()]) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/util/shape.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Shape utilities: 3 | * The functions are used for shape-related operations, such as 4 | * slicing and padding arrays. 5 | */ 6 | use ark_bn254::Fr; 7 | use ndarray::{ArrayD, Axis, Dimension, IxDyn, Slice, SliceInfo}; 8 | 9 | pub fn slice_nd_array(arr: ArrayD, indices: &[usize]) -> ArrayD { 10 | // Create slices from the indices 11 | let slices: Vec<_> = indices.iter().map(|&i| (0..i).into()).collect(); 12 | 13 | // Convert slices into a SliceInfo instance 14 | let slice_info = unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(slices).unwrap() }; 15 | 16 | // Slice the array 17 | arr.slice_move(slice_info) 18 | } 19 | 20 | pub fn flatten_last_dimension(arr: &ArrayD) -> ArrayD> { 21 | let shape = arr.shape().to_vec(); 22 | let new_shape = IxDyn(&shape[..shape.len() - 1]); 23 | 24 | ArrayD::from_shape_fn(new_shape, |idx| { 25 | let mut full_idx = idx.as_array_view().to_vec(); 26 | full_idx.push(0); 27 | let slice = arr.slice_each_axis(|ax| { 28 | if ax.axis.index() < full_idx.len() - 1 { 29 | ndarray::Slice::from(full_idx[ax.axis.index()]..=full_idx[ax.axis.index()]) 30 | } else { 31 | ndarray::Slice::from(..) 32 | } 33 | }); 34 | slice.to_owned().into_raw_vec() 35 | }) 36 | } 37 | 38 | // Pads each dimension of input by the corresponding amount in padding on both ends. 39 | pub fn pad(input: &ArrayD, padding: &Vec<[usize; 2]>, pad_val: &G) -> ArrayD { 40 | let tmp = input.into_iter().collect(); 41 | let input = ArrayD::from_shape_vec(input.raw_dim(), tmp).unwrap(); 42 | assert_eq!(input.ndim(), padding.len()); 43 | let mut padded_shape = input.raw_dim(); 44 | for (ax, (&ax_len, &[pad_lo, pad_hi])) in input.shape().iter().zip(padding).enumerate() { 45 | padded_shape[ax] = ax_len + pad_lo + pad_hi; 46 | } 47 | 48 | let mut padded = ArrayD::from_elem(padded_shape, pad_val); 49 | let padded_dim = padded.raw_dim(); 50 | { 51 | // Select portion of padded array that needs to be copied from the 52 | // original array. 
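// [Editor's note, added for illustration.] For an axis of original length 3 with
// pad_lo = 1 and pad_hi = 2, the padded axis has length 6 and the slice taken below
// is 1..(6 - 2) = 1..4, i.e. exactly the region that receives the original data.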
53 | let mut orig_portion = padded.view_mut(); 54 | for (ax, &[pad_lo, pad_hi]) in padding.iter().enumerate() { 55 | orig_portion.slice_axis_inplace(Axis(ax), Slice::from(pad_lo as isize..padded_dim[ax] as isize - (pad_hi as isize))); 56 | } 57 | // Copy the data from the original array. 58 | orig_portion.assign(&input); 59 | } 60 | 61 | let dim = padded.raw_dim(); 62 | let tmp = padded.into_iter().map(|x| x.clone()).collect(); 63 | let padded = ArrayD::from_shape_vec(dim, tmp).unwrap(); 64 | 65 | padded 66 | } 67 | 68 | pub fn pad_to_pow_of_two(input: &ArrayD, pad_val: &G) -> ArrayD { 69 | let padding: Vec<_> = input.shape().iter().map(|x| [0, x.next_power_of_two() - x]).collect(); 70 | pad(&input, &padding, &pad_val) 71 | } 72 | 73 | pub fn broadcastDims(dims: &Vec<&Vec>, N: usize) -> Vec { 74 | let len = dims.iter().map(|x| x.len()).max().unwrap(); 75 | (0..len - N) 76 | .map(|i| dims.iter().map(|dim| if dim.len() >= len - i { dim[i + dim.len() - len] } else { 1 }).max().unwrap()) 77 | .collect() 78 | } 79 | -------------------------------------------------------------------------------- /src/basic_block/concat.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ark_std::Zero; 5 | use ndarray::{indices, Array1, ArrayD, Axis, Slice, SliceInfoElem}; 6 | use rand::rngs::StdRng; 7 | use std::fmt::Debug; 8 | 9 | fn concatenate_sliced(inputs: &Vec<&ArrayD>, axis: usize, input_shapes: &Vec>) -> ArrayD { 10 | let sliced_views: Vec<_> = inputs 11 | .iter() 12 | .zip(input_shapes.iter()) 13 | .map(|(x, shape)| { 14 | let slice_info: Vec = (0..x.ndim()) 15 | .map(|i| { 16 | if i < shape.len() - 1 { 17 | SliceInfoElem::from(..shape[i]) 18 | } else { 19 | SliceInfoElem::from(..) 
20 | } 21 | }) 22 | .collect(); 23 | x.slice(slice_info.as_slice()) 24 | }) 25 | .collect(); 26 | 27 | ndarray::concatenate(Axis(axis), &sliced_views).unwrap() 28 | } 29 | 30 | // support concat over any dim except for the last 31 | #[derive(Debug)] 32 | pub struct ConcatBasicBlock { 33 | pub axis: usize, 34 | pub input_shapes: Vec>, 35 | } 36 | 37 | impl BasicBlock for ConcatBasicBlock { 38 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 39 | assert!(self.axis != inputs[0].shape().len() - 1); 40 | let r = concatenate_sliced(inputs, self.axis, &self.input_shapes); 41 | Ok(vec![util::pad_to_pow_of_two(&r, &Fr::zero())]) 42 | } 43 | 44 | fn encodeOutputs(&self, srs: &SRS, _model: &ArrayD, inputs: &Vec<&ArrayD>, _outputs: &Vec<&ArrayD>) -> Vec> { 45 | if inputs[0].ndim() == 0 { 46 | let r_vec = inputs.iter().map(|input| input.first().unwrap().clone()).collect::>(); 47 | let r = Array1::from_vec(r_vec).into_dyn(); 48 | vec![r] 49 | } else { 50 | let N = inputs[0].first().unwrap().raw.len(); 51 | let r = concatenate_sliced(inputs, self.axis, &self.input_shapes); 52 | let data = Data::new(srs, &vec![Fr::zero(); N]); 53 | vec![util::pad_to_pow_of_two(&r, &data)] 54 | } 55 | } 56 | 57 | fn verify( 58 | &self, 59 | _srs: &SRS, 60 | _model: &ArrayD, 61 | inputs: &Vec<&ArrayD>, 62 | outputs: &Vec<&ArrayD>, 63 | _proof: (&Vec, &Vec, &Vec), 64 | _rng: &mut StdRng, 65 | _cache: ProveVerifyCache, 66 | ) -> Vec { 67 | if inputs[0].ndim() == 0 { 68 | let r = inputs.iter().map(|input| input.first().unwrap().clone()).collect::>(); 69 | let r_enc = outputs[0]; 70 | for i in 0..r.len() { 71 | assert!(r[i] == r_enc[i], "Mismatch at index {:?}", i); 72 | } 73 | } else { 74 | let r = concatenate_sliced(inputs, self.axis, &self.input_shapes); 75 | let r_enc = outputs[0]; 76 | 77 | for indices in ndarray::indices(r.shape()) { 78 | let r_val = &r[&indices]; 79 | let r_enc_val = &r_enc[&indices]; 80 | assert!(r_val == r_enc_val, "Mismatch at indices {:?}", indices); 81 | } 82 | } 83 | vec![] 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/layer.rs: -------------------------------------------------------------------------------- 1 | use crate::graph::Graph; 2 | pub use and::AndLayer; 3 | pub use arithmetic::{AddLayer, SubLayer}; 4 | use ark_bn254::Fr; 5 | pub use cast::CastLayer; 6 | pub use clip::ClipLayer; 7 | pub use concat::ConcatLayer; 8 | pub use constantofshape::ConstOfShapeLayer; 9 | pub use conv::{ConvLayer, ConvTransposeLayer}; 10 | pub use div::{DivLayer, ModLayer}; 11 | pub use einsum::EinsumLayer; 12 | pub use equal::EqualLayer; 13 | pub use expand::ExpandLayer; 14 | pub use flatten::FlattenLayer; 15 | pub use gather::GatherLayer; 16 | pub use gathernd::GatherNDLayer; 17 | pub use gemm::GemmLayer; 18 | pub use less::LessLayer; 19 | pub use lstm::LSTMLayer; 20 | pub use matmul::{MatMulLayer, MultiHeadMatMulLayer}; 21 | pub use max::{MaxLayer, MinLayer}; 22 | pub use mul::MulLayer; 23 | use ndarray::ArrayD; 24 | pub use neg::NegLayer; 25 | pub use new_conv::{ConcatConv3dLayer, Conv2dLayer, Conv3dLayer, Conv3dTransposeLayer, MultiHeadConv2dLayer}; 26 | pub use new_maxpool::MaxPool2dLayer; 27 | pub use nonlinear::*; 28 | pub use norm::{BatchNormLayer, CustomInstanceNormLayer, InstanceNormLayer}; 29 | pub use not::NotLayer; 30 | pub use pow::PowLayer; 31 | pub use r#where::WhereLayer; 32 | pub use range::RangeLayer; 33 | pub use reducemean::ReduceMeanLayer; 34 | pub use reshape::{ReshapeLayer, 
ReshapeTransLayer}; 35 | pub use resize::{CustomResizeLayer, ResizeLayer}; 36 | pub use rope::{RopeConstLayer, RopeRotateLayer}; 37 | pub use scatternd::ScatterNDLayer; 38 | pub use shape::ShapeLayer; 39 | pub use slice::SliceLayer; 40 | pub use softmax::SoftmaxLayer; 41 | pub use split::SplitLayer; 42 | pub use sqrt::SqrtLayer; 43 | pub use squeeze::{SqueezeLayer, UnsqueezeLayer}; 44 | pub use tile::TileLayer; 45 | pub use topk::{ArgMaxLayer, TopKLayer}; 46 | use tract_onnx::{pb::AttributeProto, prelude::DatumType}; 47 | pub use transpose::TransposeLayer; 48 | pub use xor::XorLayer; 49 | 50 | pub mod and; 51 | pub mod arithmetic; 52 | pub mod cast; 53 | pub mod clip; 54 | pub mod concat; 55 | pub mod constantofshape; 56 | pub mod conv; 57 | pub mod div; 58 | pub mod einsum; 59 | pub mod equal; 60 | pub mod expand; 61 | pub mod flatten; 62 | pub mod gather; 63 | pub mod gathernd; 64 | pub mod gemm; 65 | pub mod less; 66 | pub mod lstm; 67 | pub mod matmul; 68 | pub mod max; 69 | pub mod mul; 70 | pub mod neg; 71 | pub mod new_conv; 72 | pub mod new_maxpool; 73 | pub mod nonlinear; 74 | pub mod norm; 75 | pub mod not; 76 | pub mod pool; 77 | pub mod pow; 78 | pub mod range; 79 | pub mod reducemean; 80 | pub mod reshape; 81 | pub mod resize; 82 | pub mod rope; 83 | pub mod scatternd; 84 | pub mod shape; 85 | pub mod slice; 86 | pub mod softmax; 87 | pub mod split; 88 | pub mod sqrt; 89 | pub mod squeeze; 90 | pub mod tile; 91 | pub mod topk; 92 | pub mod transpose; 93 | pub mod r#where; 94 | pub mod xor; 95 | 96 | // Most output types will only depend on an input type but for e.g., Range layer depends on the type of the constants 97 | pub trait Layer { 98 | fn graph( 99 | input_shapes: &Vec<&Vec>, 100 | input_types: &Vec, 101 | constants: &Vec, DatumType)>>, 102 | attributes: &Vec<&AttributeProto>, 103 | ) -> (Graph, Vec>, Vec); 104 | } 105 | -------------------------------------------------------------------------------- /src/layer/less.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ndarray::{arr1, ArrayD}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | // Less layer performs `less`, an element-wise logical comparison of two tensors. 12 | pub struct LessLayer; 13 | impl Layer for LessLayer { 14 | fn graph( 15 | input_shapes: &Vec<&Vec>, 16 | _input_types: &Vec, 17 | _constants: &Vec, DatumType)>>, 18 | _attributes: &Vec<&AttributeProto>, 19 | ) -> (Graph, Vec>, Vec) { 20 | // Inputs: A, B 21 | // Outputs: L = (A < B); then 1 - L = (A >= B). We can view them as selection of indices. 
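// [Editor's sketch, not part of the original file.] A worked example of the two
// checks spelled out below, with A = [1, 5] and B = [3, 3], so A - B = [-2, 2]
// and L = [1, 0]:
//   Check 1: (A - B) * L - (1 - L) = [-2, -1], all negative
//   Check 2: (A - B) * (1 - L)     = [ 0,  2], all non-negative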
22 | // Check 1: (A - B) * L + (-1) * (1 - L) < 0 because A - B will always < 0 at indices of A < B and we set values at other indices as -1 23 | // Check 1 is equivalent to (A - B) * L - (1 - L) < 0 24 | // Check 2: 0 * L + (A - B) * (1 - L) >= 0 because A - B will always >= 0 at indices of A >= B and we set values at other indices as 0 25 | // Check 2 is equivalent to (A - B) * (1 - L) >= 0 26 | let mut graph = Graph::new(); 27 | 28 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 29 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 30 | basic_block: Box::new(MulBasicBlock { len }), 31 | N: 1, 32 | })); 33 | let less = graph.addBB(Box::new(RepeaterBasicBlock { 34 | basic_block: Box::new(LessBasicBlock {}), 35 | N: 1, 36 | })); 37 | let one = graph.addBB(Box::new(Const2BasicBlock { 38 | c: arr1(&vec![Fr::from(1); *input_shapes[0].last().unwrap()]).into_dyn(), 39 | })); 40 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 41 | basic_block: Box::new(SubBasicBlock {}), 42 | N: 1, 43 | })); 44 | let negative_check = graph.addBB(Box::new(RepeaterBasicBlock { 45 | basic_block: Box::new(CQBasicBlock { 46 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 47 | setup: util::CQArrayType::Negative, 48 | }), 49 | N: 1, 50 | })); 51 | let non_negative_check = graph.addBB(Box::new(RepeaterBasicBlock { 52 | basic_block: Box::new(CQBasicBlock { 53 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 54 | setup: util::CQArrayType::NonNegative, 55 | }), 56 | N: 1, 57 | })); 58 | 59 | let one_output = graph.addNode(one, vec![]); 60 | let less_output = graph.addNode(less, vec![(-1, 0), (-2, 0)]); 61 | let one_minus_less_output = graph.addNode(sub, vec![(one_output, 0), (less_output, 0)]); 62 | let sub_output = graph.addNode(sub, vec![(-1, 0), (-2, 0)]); 63 | let mul1_output = graph.addNode(mul, vec![(sub_output, 0), (less_output, 0)]); 64 | let check1_output = graph.addNode(sub, vec![(mul1_output, 0), (one_minus_less_output, 0)]); 65 | let check2_output = graph.addNode(mul, vec![(sub_output, 0), (one_minus_less_output, 0)]); 66 | let _ = graph.addNode(negative_check, vec![(check1_output, 0)]); 67 | let _ = graph.addNode(non_negative_check, vec![(check2_output, 0)]); 68 | graph.outputs.push((less_output, 0)); 69 | (graph, vec![util::broadcastDims(input_shapes, 0)], vec![DatumType::Bool]) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/layer/transpose.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct TransposeLayer; 11 | impl Layer for TransposeLayer { 12 | fn graph( 13 | input_shapes: &Vec<&Vec>, 14 | input_types: &Vec, 15 | _constants: &Vec, DatumType)>>, 16 | attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | let mut graph = Graph::new(); 19 | 20 | let axes: Vec<_> = attributes.iter().filter(|x| x.name == "perm").next().unwrap().ints.iter().map(|x| *x as usize).collect(); 21 | let n = axes.len(); 22 | let endShape = axes.iter().map(|i| input_shapes[0][*i]).collect(); 23 | 24 | if *axes.last().unwrap() == n - 1 { 25 | let transpose = graph.addBB(Box::new(TransposeBasicBlock { perm: axes.clone() })); 26 | let output = graph.addNode(transpose, vec![(-1, 0)]); 27 | 
graph.outputs.push((output, 0)); 28 | } else { 29 | let pos = axes.iter().position(|&x| x == n - 1).unwrap(); 30 | let mut perm = axes.clone(); 31 | // Keep n-1 in last index and move axes[n-1] to n-2 32 | // With this, we have the values 33 | // If pos != n-2, n-1 34 | // pos n-2 n-1 35 | // axes[n-2] axes[n-1] n-1 36 | // If pos == n-2 37 | // pos/n-2 n-1 38 | // axes[n-1] n-1 39 | // pos == n-1 is covered by the if case 40 | perm[pos] = axes[n - 2]; 41 | perm[n - 2] = axes[n - 1]; 42 | perm[n - 1] = n - 1; 43 | let transpose = graph.addBB(Box::new(TransposeBasicBlock { perm })); 44 | let intermediate = graph.addNode(transpose, vec![(-1, 0)]); 45 | // Swap the last two 46 | // If pos != n-2, n-1 47 | // pos n-2 n-1 48 | // axes[n-2] n-1 axes[n-1] 49 | // If pos == n-2 50 | // pos/n-2 n-1 51 | // n-1 axes[n-1] 52 | let (a, b) = (n - 1, axes[n - 1]); 53 | let (mut c, mut d) = (input_shapes[0][a], input_shapes[0][b]); 54 | c = util::next_pow(c as u32) as usize; 55 | d = util::next_pow(d as u32) as usize; 56 | let permutation = ((0..c).map(|x| x * d).collect(), (0..d).collect()); 57 | let permute = graph.addBB(Box::new(RepeaterBasicBlock { 58 | basic_block: Box::new(PermuteBasicBlock { permutation, n: d, m: c }), 59 | N: 2, 60 | })); 61 | let permute_output = graph.addNode(permute, vec![(intermediate, 0)]); 62 | // If pos swap happened, correct the swap 63 | // If pos != n-2, n-1 64 | // swap swaps pos and n-2 indices 65 | // pos n-2 n-1 66 | // n-1 axes[n-2] axes[n-1] 67 | // If pos == n-2 68 | // pos/n-2 n-1 69 | // n-1 axes[n-1] 70 | let output = if pos == n - 2 { 71 | permute_output 72 | } else { 73 | let mut swap: Vec<_> = (0..n).collect(); 74 | swap[pos] = n - 2; 75 | swap[n - 2] = pos; 76 | let transpose_1 = graph.addBB(Box::new(TransposeBasicBlock { perm: swap })); 77 | graph.addNode(transpose_1, vec![(permute_output, 0)]) 78 | }; 79 | graph.outputs.push((output, 0)); 80 | } 81 | 82 | (graph, vec![endShape], vec![input_types[0]]) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/bin/search_sf.rs: -------------------------------------------------------------------------------- 1 | use core::panic; 2 | use zk_torch::{onnx, util, CONFIG}; 3 | 4 | // Run the witness_gen binary with a new sf and return the std output 5 | fn witness_gen(onnx_file_name: &str) -> String { 6 | let (graph, models) = onnx::load_file(onnx_file_name); 7 | 8 | let input_path = &CONFIG.onnx.input_path; 9 | let inputs = if std::path::Path::new(input_path).exists() { 10 | util::load_inputs_from_json_for_onnx(onnx_file_name, input_path) 11 | } else { 12 | util::generate_fake_inputs_for_onnx(onnx_file_name) 13 | }; 14 | let inputs = inputs.iter().map(|x| x).collect(); 15 | let models = models.iter().map(|x| &x.0).collect(); 16 | let outputs = graph.run(&inputs, &models); 17 | if outputs.is_err() { 18 | "CQ Error".to_string() 19 | } else { 20 | "Success".to_string() 21 | } 22 | } 23 | 24 | fn update_sf(new_sf_log: usize) { 25 | let mut sf_log = onnx::SF_LOG.write().unwrap(); 26 | *sf_log = new_sf_log; 27 | drop(sf_log); 28 | let mut sf = onnx::SF.write().unwrap(); 29 | *sf = 1 << new_sf_log; 30 | drop(sf); 31 | let mut sf_float = onnx::SF_FLOAT.write().unwrap(); 32 | *sf_float = (1 << new_sf_log) as f32; 33 | drop(sf_float); 34 | } 35 | 36 | // Given the CQ range, search for the optimal scale factor for the given model 37 | fn search_optimal_sf(onnx_file_name: &str, cq_range_log: usize) -> usize { 38 | let loaded_pow_len_log = CONFIG.ptau.loaded_pow_len_log; 39 | 
assert!(cq_range_log < loaded_pow_len_log); 40 | let mut min_sf = 0; 41 | let mut max_sf = cq_range_log - 1; 42 | let mut current_sf = 0; 43 | let mut opt_sf = 0; 44 | let mut prev_sfs: Vec = Vec::new(); 45 | // Binary search for the optimal scale factor 46 | // In each iteration, we try with the new scale factor, which is 47 | // the average of the min and max scale factors. 48 | while min_sf <= max_sf && prev_sfs.iter().find(|&&x| x == current_sf).is_none() { 49 | println!("==> Trying scale factor: 2^{}", current_sf); 50 | 51 | // Update the global scale factor by the new scale factor 52 | update_sf(current_sf); 53 | 54 | let stdout = witness_gen(onnx_file_name); 55 | 56 | // Check if the std output contains the success message 57 | if stdout.contains("Success") { 58 | // If the output contains "Success", then the 59 | // optimal scale factor may be larger than the current scale factor 60 | // Set the minimum scale factor to the current scale factor 61 | min_sf = current_sf; 62 | opt_sf = current_sf; 63 | } else { 64 | // If the output does not contain "Success", then the current scale factor is too high 65 | // Set the maximum scale factor to the current scale factor 66 | if current_sf == 0 { 67 | // If the current scale factor is 0, then the CQ range is too small for the given circuit 68 | panic!("CQ range is too small for the given circuit"); 69 | } 70 | max_sf = current_sf; 71 | } 72 | prev_sfs.push(current_sf); 73 | current_sf = ((min_sf + max_sf) as f64 / 2.0).round() as usize; 74 | } 75 | opt_sf 76 | } 77 | 78 | fn main() { 79 | let cq_range_log = CONFIG.sf.cq_range_log; 80 | let onnx_file_name = &CONFIG.onnx.model_path; 81 | let optimal_sf = search_optimal_sf(onnx_file_name, cq_range_log); 82 | println!("==> Given the CQ range, the optimal scale factor for this model is 2^{}", optimal_sf); 83 | println!("==> Please set 'scale_factor_log={}' in the config file", optimal_sf); 84 | } 85 | -------------------------------------------------------------------------------- /src/layer/rope.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ark_std::Zero; 8 | use ndarray::Dimension; 9 | use ndarray::{ArrayD, IxDyn}; 10 | use tract_onnx::pb::tensor_proto::DataType; 11 | use tract_onnx::pb::AttributeProto; 12 | use tract_onnx::prelude::DatumType; 13 | 14 | // Rotate the last dimension of the input tensor: 15 | // [x_0, x_1, x_2, x_3, ..., x_{n-2}, x_{n-1}] -> [x_1, x_0, x_3, x_2, ..., x_{n-1}, x_{n-2}] 16 | fn get_rope_rotate_indices(input_shape: &Vec) -> ArrayD> { 17 | let indices = ArrayD::from_shape_fn(input_shape.as_slice(), |index| { 18 | let index_len = index.ndim(); 19 | let index_last = index[index_len - 1]; 20 | let new_index_last = 2 * (index_last / 2) + 1 - (index_last % 2); 21 | let mut index = index.clone(); 22 | index[index_len - 1] = new_index_last; 23 | Some(index) 24 | }); 25 | let indices = util::pad_to_pow_of_two(&indices, &None); 26 | indices 27 | } 28 | 29 | // Generate the RoPE sign-constant tensor (-1, 1, -1, 1, ...); its shape is derived from the constant input 30 | pub struct RopeConstLayer; 31 | impl Layer for RopeConstLayer { 32 | fn graph( 33 | _input_shapes: &Vec<&Vec>, 34 | input_types: &Vec, 35 | constants: &Vec, DatumType)>>, 36 | _attributes: &Vec<&AttributeProto>, 37 | ) -> (Graph, Vec>, Vec) { 38 | let mut graph = Graph::new(); 39 | let inputShape: Vec = 40 | 
constants[0].unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as usize).filter(|x| *x != 0).collect(); 41 | let inputShape_last = inputShape[inputShape.len() - 1]; 42 | // -1, 1, -1, 1, ... 43 | let rope_constant = (0..inputShape_last).map(|i| Fr::from((-1 as i32).pow((i + 1) as u32)).into()).collect(); 44 | let mut endShape: Vec = (0..inputShape.len()).map(|_| 1).collect(); 45 | let endShape_len = endShape.len(); 46 | endShape[endShape_len - 1] = inputShape_last; 47 | let rope_constant_tensor = ArrayD::from_shape_vec(endShape.clone(), rope_constant).unwrap(); 48 | let rope_constant_tensor = util::pad_to_pow_of_two(&rope_constant_tensor, &Fr::zero()); 49 | 50 | let constant = graph.addBB(Box::new(Const2BasicBlock { c: rope_constant_tensor })); 51 | let output = graph.addNode(constant, vec![]); 52 | graph.outputs.push((output, 0)); 53 | (graph, vec![endShape], vec![input_types[0]]) 54 | } 55 | } 56 | 57 | pub struct RopeRotateLayer; 58 | impl Layer for RopeRotateLayer { 59 | fn graph( 60 | input_shapes: &Vec<&Vec>, 61 | input_types: &Vec, 62 | _constants: &Vec, DatumType)>>, 63 | _attributes: &Vec<&AttributeProto>, 64 | ) -> (Graph, Vec>, Vec) { 65 | let mut graph = Graph::new(); 66 | let inputShape = input_shapes[0].clone(); 67 | let startShape_padded: Vec = inputShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 68 | let permutation = get_rope_rotate_indices(&inputShape); 69 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 70 | permutation: permutation.clone(), 71 | input_dim: IxDyn(&startShape_padded), 72 | padding_partition: copy_constraint::PaddingEnum::Zero, 73 | })); 74 | let output = graph.addNode(cc, vec![(-1, 0)]); 75 | graph.outputs.push((output, 0)); 76 | (graph, vec![inputShape], vec![input_types[0]]) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/layer/split.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::ArrayD; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | pub struct SplitLayer; 11 | impl Layer for SplitLayer { 12 | fn graph( 13 | input_shapes: &Vec<&Vec>, 14 | input_types: &Vec, 15 | _constants: &Vec, DatumType)>>, 16 | attributes: &Vec<&AttributeProto>, 17 | ) -> (Graph, Vec>, Vec) { 18 | let mut graph = Graph::new(); 19 | 20 | // This code is for Opset 11; in the latest version of ONNX, "split" is in inputs instead of the attributes 21 | let axis: isize = attributes.iter().filter(|x| x.name == "axis").next().unwrap().i as isize; 22 | let axis = (if axis < 0 { input_shapes[0].len() as isize + axis } else { axis }) as usize; 23 | let split = attributes.iter().filter(|x| x.name == "split").next().unwrap().ints.iter().map(|x| *x as usize).collect::>(); 24 | 25 | let mut outputShapes = vec![]; 26 | for i in 0..split.len() { 27 | let mut outputShape = input_shapes[0].clone(); 28 | outputShape[axis] = split[i]; 29 | outputShapes.push(outputShape); 30 | } 31 | 32 | if axis == input_shapes[0].len() - 1 { 33 | // permute inputs 34 | let n = input_shapes[0].len(); 35 | let mut a = input_shapes[0][n - 2]; 36 | let mut b = input_shapes[0][n - 1]; 37 | (a, b) = (util::next_pow(a as u32) as usize, util::next_pow(b as u32) as usize); 38 | let permutation = ((0..b).map(|x| x * a).collect(), (0..a).map(|x| x).collect()); 39 | let permute = 
graph.addBB(Box::new(RepeaterBasicBlock { 40 | basic_block: Box::new(PermuteBasicBlock { 41 | permutation: permutation, 42 | n: a, 43 | m: b, 44 | }), 45 | N: 2, 46 | })); 47 | let split_bb = graph.addBB(Box::new(SplitBasicBlock { 48 | axis: (axis - 1) as usize, 49 | split: split.clone(), 50 | })); 51 | let mut permute_backs = vec![]; 52 | for i in 0..split.len() { 53 | let (mut a, mut b) = (outputShapes[i][n - 2], outputShapes[i][n - 1]); 54 | (a, b) = (util::next_pow(a as u32) as usize, util::next_pow(b as u32) as usize); 55 | let permutation_back = ((0..a).map(|x| x * b).collect(), (0..b).collect()); 56 | let permute_back = graph.addBB(Box::new(RepeaterBasicBlock { 57 | basic_block: Box::new(PermuteBasicBlock { 58 | permutation: permutation_back, 59 | n: b, 60 | m: a, 61 | }), 62 | N: 2, 63 | })); 64 | permute_backs.push(permute_back); 65 | } 66 | 67 | let permute_output = graph.addNode(permute, vec![(-1, 0)]); 68 | let split_output = graph.addNode(split_bb, vec![(permute_output, 0)]); 69 | for i in 0..split.len() { 70 | let output = graph.addNode(permute_backs[i], vec![(split_output, i)]); 71 | graph.outputs.push((output, 0)); 72 | } 73 | } else { 74 | let split_bb = graph.addBB(Box::new(SplitBasicBlock { 75 | axis: axis as usize, 76 | split: split.clone(), 77 | })); 78 | let split_output = graph.addNode(split_bb, vec![(-1, 0)]); 79 | for i in 0..split.len() { 80 | graph.outputs.push((split_output, i)); 81 | } 82 | } 83 | 84 | let num_outputs = outputShapes.len(); 85 | (graph, outputShapes, vec![input_types[0]; num_outputs]) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/layer/expand.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::One; 7 | use ndarray::ArrayD; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | #[derive(Debug)] 12 | pub struct ExpandBasicBlock; 13 | impl BasicBlock for ExpandBasicBlock { 14 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 15 | let newShape: Vec<_> = inputs[1].as_slice().unwrap().iter().map(|&x| util::fr_to_int(x) as usize).filter(|x| *x != 0).collect(); 16 | let padded_newShape: Vec<_> = newShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 17 | Ok(vec![inputs[0].broadcast(padded_newShape).unwrap().into_owned()]) 18 | } 19 | } 20 | 21 | pub struct ExpandLayer; 22 | impl Layer for ExpandLayer { 23 | fn graph( 24 | input_shapes: &Vec<&Vec>, 25 | input_types: &Vec, 26 | constants: &Vec, DatumType)>>, 27 | _attributes: &Vec<&AttributeProto>, 28 | ) -> (Graph, Vec>, Vec) { 29 | let shape0 = input_shapes[0].clone(); 30 | let shape1: Vec<_> = constants[1].unwrap().0.as_slice().unwrap().iter().map(|&x| util::fr_to_int(x) as usize).filter(|x| *x != 0).collect(); 31 | let newShape = vec![shape0.clone(), shape1.clone()]; 32 | let newShape: Vec<_> = newShape.iter().map(|x| x).collect(); 33 | let newShape = util::broadcastDims(&newShape, 0); 34 | let newShape_padded: Vec<_> = newShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 35 | 36 | let mut graph = Graph::new(); 37 | // check if the last dimension of the input shape is equal to the last dimension of the new shape 38 | // and the product of the input shape is less than the product of the new shape 39 | // if so, use ExpandBasicBlock without proving. 
Otherwise, use ConstOfShapeBasicBlock and MulScalarBasicBlock 40 | let shape0_product = shape0.iter().fold(1, |acc, x| acc * x); 41 | let shape1_product = shape1.iter().fold(1, |acc, x| acc * x); 42 | if *input_shapes[0].last().unwrap() == *newShape.clone().last().unwrap() && shape0_product <= shape1_product { 43 | let expand = graph.addBB(Box::new(ExpandBasicBlock {})); 44 | let expand_output = graph.addNode(expand, vec![(-1, 0), (-2, 0)]); 45 | graph.outputs.push((expand_output, 0)); 46 | } else { 47 | let constantOfShape = graph.addBB(Box::new(ConstOfShapeBasicBlock { 48 | c: Fr::one(), 49 | shape: newShape_padded.clone(), 50 | })); 51 | let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock { 52 | basic_block: Box::new(MulScalarBasicBlock {}), 53 | N: 1, 54 | })); 55 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 56 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 57 | basic_block: Box::new(MulBasicBlock { len }), 58 | N: 1, 59 | })); 60 | let constantOfShape_output = graph.addNode(constantOfShape, vec![]); 61 | let expand_output = if *input_shapes[0].last().unwrap() == 1 { 62 | graph.addNode(mul_scalar, vec![(constantOfShape_output, 0), (-1, 0)]) 63 | } else if *newShape_padded.last().unwrap() == 1 { 64 | graph.addNode(mul_scalar, vec![(-1, 0), (constantOfShape_output, 0)]) 65 | } else { 66 | graph.addNode(mul, vec![(-1, 0), (constantOfShape_output, 0)]) 67 | }; 68 | graph.outputs.push((expand_output, 0)); 69 | } 70 | 71 | (graph, vec![newShape], vec![input_types[0]]) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/util/poly.rs: -------------------------------------------------------------------------------- 1 | use ark_bn254::Fr; 2 | use ark_ff::Zero; 3 | use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial}; 4 | use ark_std::One; 5 | use rayon::prelude::*; 6 | 7 | fn elementwise_product(vecs: &[Vec]) -> Vec 8 | where 9 | T: std::iter::Product + std::marker::Send + std::marker::Sync + Copy + std::ops::Mul + 'static, 10 | { 11 | // Assuming vecs is non-empty and all vectors have the same length 12 | let m = vecs[0].len(); 13 | 14 | (0..m).into_par_iter().map(|i| vecs[0][i] * vecs[1][i]).collect() 15 | } 16 | 17 | pub fn mul_two_polys(polys: &Vec>) -> DensePolynomial { 18 | assert!(polys.len() == 2); 19 | if polys[0].is_zero() || polys[1].is_zero() { 20 | return DensePolynomial::zero(); 21 | } 22 | let N: usize = polys.iter().map(|p| p.coeffs.len()).sum(); 23 | let domain = GeneralEvaluationDomain::::new(N); 24 | if domain.is_some() { 25 | let domain = domain.unwrap(); 26 | let p_evals = polys.par_iter().map(|p| domain.fft(&p.coeffs)).collect::>(); 27 | let p_evals = elementwise_product(&p_evals); 28 | DensePolynomial::from_coefficients_vec(domain.ifft(&p_evals)) 29 | } else { 30 | karatsuba_multiply(&polys[0], &polys[1]) 31 | } 32 | } 33 | 34 | fn mul_by_xn(poly: &DensePolynomial, n: usize) -> DensePolynomial { 35 | let mut new_coeffs = vec![Fr::zero(); n]; 36 | new_coeffs.extend(poly.coeffs().iter().cloned()); 37 | DensePolynomial::from_coefficients_vec(new_coeffs) 38 | } 39 | 40 | fn karatsuba_multiply(a: &DensePolynomial, b: &DensePolynomial) -> DensePolynomial { 41 | let n = std::cmp::max(a.degree(), b.degree()) + 1; 42 | if n <= 1 << 27 { 43 | return a * b; 44 | } 45 | let m = n / 2; 46 | let (a0, a1) = karatsuba_split(a, m); 47 | let (b0, b1) = karatsuba_split(b, m); 48 | let z0 = karatsuba_multiply(&a0, &b0); 49 
| let z2 = karatsuba_multiply(&a1, &b1); 50 | let a0_plus_a1 = &a0 + &a1; 51 | let b0_plus_b1 = &b0 + &b1; 52 | let z1 = karatsuba_multiply(&a0_plus_a1, &b0_plus_b1); 53 | let mut result = mul_by_xn(&z2, 2 * m); 54 | result += &z0; 55 | result = &result + &mul_by_xn(&(&(&z1 - &z2) - &z0), m); 56 | result 57 | } 58 | 59 | // Multiply a list of polynomials in parallel 60 | // TODO: explore if there exists a more efficient parallel algorithm 61 | pub fn mul_polys(polys: &Vec>) -> DensePolynomial { 62 | // Base case: if the list has only one polynomial, return it directly 63 | if polys.len() == 1 { 64 | return polys[0].clone(); 65 | } 66 | 67 | // Parallel recursive case: pairwise multiply the polynomials 68 | let next_level: Vec> = polys 69 | .par_chunks(2) // Parallelize processing in chunks of 2 70 | .map(|chunk| { 71 | if chunk.len() == 2 { 72 | // If there are two polynomials in the chunk, multiply them 73 | mul_two_polys(&vec![chunk[0].clone(), chunk[1].clone()]) 74 | } else { 75 | // If there's only one polynomial in the chunk, return it 76 | chunk[0].clone() 77 | } 78 | }) 79 | .collect(); 80 | 81 | // Recursively call mul_polys on the next level until we get the root 82 | mul_polys(&next_level) 83 | } 84 | 85 | fn karatsuba_split(p: &DensePolynomial, m: usize) -> (DensePolynomial, DensePolynomial) { 86 | let coeffs = p.coeffs(); 87 | let low = DensePolynomial::from_coefficients_vec(coeffs[..m.min(coeffs.len())].to_vec()); 88 | let high = DensePolynomial::from_coefficients_vec(coeffs[m.min(coeffs.len())..].to_vec()); 89 | (low, high) 90 | } 91 | 92 | // Split poly into degree n-1 polynomials 93 | pub fn split_polynomial(poly: &DensePolynomial, n: usize) -> Vec> { 94 | poly.coeffs().chunks(n).map(|chunk| DensePolynomial::from_coefficients_vec(chunk.to_vec())).collect() 95 | } 96 | -------------------------------------------------------------------------------- /src/layer/tile.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::{squeeze::UnsqueezeBasicBlock, Layer}; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{concatenate, ArrayD, Axis, IxDyn}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | // Helper function to get the indices of the tiled tensor 11 | fn get_tile_indices(input_shape: Vec, repeats: Vec) -> ArrayD> { 12 | let output_shape: Vec<_> = input_shape.iter().zip(repeats.iter()).map(|(x, y)| x * y).collect(); 13 | let padded_output_shape: Vec<_> = output_shape.iter().map(|x| util::next_pow(*x as u32) as usize).collect(); 14 | 15 | // first generate indices for the input tensor 16 | let mut tiled = ArrayD::from_shape_fn(input_shape.as_slice(), |index| Some(index.clone())); 17 | // then repeat the indices r time(s) along each axis, where r is the corresponding element in repeats 18 | for (i, repeat) in repeats.iter().enumerate() { 19 | tiled = concatenate(Axis(i), std::iter::repeat(tiled.view()).take(*repeat).collect::>().as_slice()).unwrap(); 20 | } 21 | assert!(tiled.shape() == output_shape.as_slice()); 22 | // finally pad the tiled tensor to the next power of 2 23 | let padded_tiled = util::pad_to_pow_of_two(&tiled, &None); 24 | assert!(padded_tiled.shape() == padded_output_shape.as_slice()); 25 | 26 | padded_tiled 27 | } 28 | 29 | // TileLayer is a layer that repeats the input tensor along each axis according to the repeats. 
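// [Editor's note, added for illustration.] For example, an input of shape [2, 3]
// with repeats = [2, 2] yields an output of shape [4, 6], which get_tile_indices
// above pads to the next powers of two, i.e. [4, 8].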
30 | // The functionality is equivalent to numpy.tile(arr, repeats) 31 | // reference: https://numpy.org/doc/stable/reference/generated/numpy.tile.html 32 | pub struct TileLayer; 33 | impl Layer for TileLayer { 34 | fn graph( 35 | input_shapes: &Vec<&Vec>, 36 | input_types: &Vec, 37 | constants: &Vec, DatumType)>>, 38 | _attributes: &Vec<&AttributeProto>, 39 | ) -> (Graph, Vec>, Vec) { 40 | let mut graph = Graph::new(); 41 | let repeats: Vec<_> = constants[1].unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x)).collect(); 42 | let mut repeats: Vec<_> = repeats.iter().map(|x| *x as usize).collect(); 43 | 44 | let mut input_shape = input_shapes[0].clone(); 45 | let mut input_index = -1; 46 | 47 | // when the input.ndim() is shorter than repeats.len(), we need to unsqueeze the input 48 | if input_shape.len() < repeats.len() { 49 | // append 1s at the beginning of the input_shape 50 | let diff = repeats.len() - input_shape.len(); 51 | input_shape = std::iter::repeat(1).take(diff).chain(input_shape.iter().cloned()).collect::>(); 52 | let unsq = graph.addBB(Box::new(UnsqueezeBasicBlock {})); 53 | input_index = graph.addNode(unsq, vec![(input_index, 0)]); 54 | for _ in 0..diff - 1 { 55 | input_index = graph.addNode(unsq, vec![(input_index, 0)]); 56 | } 57 | // when the input.ndim() is longer than repeats.len(), we need to pad the repeats 58 | } else if input_shape.len() > repeats.len() { 59 | // append 1s at the beginning of the repeats 60 | repeats = std::iter::repeat(1).take(input_shape.len() - repeats.len()).chain(repeats.iter().cloned()).collect(); 61 | } 62 | // now input_shape.len() should be equal to repeats.len() 63 | assert!(input_shape.len() == repeats.len()); 64 | let permutation = get_tile_indices(input_shape, repeats); 65 | 66 | let padded_output_shape = permutation.shape().to_vec(); 67 | let padded_input_shape: Vec<_> = input_shapes[0].iter().map(|x| util::next_pow(*x as u32) as usize).collect(); 68 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 69 | permutation, 70 | input_dim: IxDyn(&padded_input_shape), 71 | padding_partition: copy_constraint::PaddingEnum::Zero, 72 | })); 73 | let tiled_output = graph.addNode(cc, vec![(input_index, 0)]); 74 | graph.outputs.push((tiled_output, 0)); 75 | 76 | (graph, vec![padded_output_shape], vec![input_types[0]]) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/basic_block/add.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 2 | use crate::util; 3 | use ark_bn254::{Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 4 | use ark_ff::Zero; 5 | use ark_poly::univariate::DensePolynomial; 6 | use ndarray::{arr0, azip, ArrayD, IxDyn}; 7 | use rand::rngs::StdRng; 8 | 9 | // This basic block is used to add two inputs together 10 | // Note: The inputs are expected to have at most 1 dimension 11 | #[derive(Debug)] 12 | pub struct AddBasicBlock; 13 | impl BasicBlock for AddBasicBlock { 14 | fn run(&self, _model: &ArrayD, inputs: &Vec<&ArrayD>) -> Result>, util::CQOutOfRangeError> { 15 | assert!(inputs.len() == 2 && inputs[0].ndim() <= 1 && inputs[1].ndim() <= 1); 16 | let mut r = ArrayD::zeros(IxDyn(&[std::cmp::max(inputs[0].len(), inputs[1].len())])); 17 | if inputs[0].len() == 1 && inputs[1].ndim() > 0 { 18 | azip!((r in &mut r, &x in inputs[1]) *r = x + inputs[0].first().unwrap()); 19 | } else if inputs[1].len() == 1 { 20 | azip!((r in &mut r, &x in 
21 | } else { 22 | azip!((r in &mut r, &x in inputs[0], &y in inputs[1]) *r = x + y); 23 | } 24 | Ok(vec![r]) 25 | } 26 | 27 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD<Fr>, inputs: &Vec<&ArrayD<Data>>, outputs: &Vec<&ArrayD<Fr>>) -> Vec<ArrayD<Data>> { 28 | let a = &inputs[0].first().unwrap(); 29 | let b = &inputs[1].first().unwrap(); 30 | vec![arr0(Data { 31 | raw: outputs[0].clone().into_raw_vec(), 32 | poly: (&a.poly) + (&b.poly), 33 | g1: a.g1 + b.g1, 34 | r: a.r + b.r, 35 | }) 36 | .into_dyn()] 37 | } 38 | 39 | fn verify( 40 | &self, 41 | _srs: &SRS, 42 | _model: &ArrayD<Fr>, 43 | inputs: &Vec<&ArrayD<DataEnc>>, 44 | outputs: &Vec<&ArrayD<DataEnc>>, 45 | _proof: (&Vec<G1Affine>, &Vec<G2Affine>, &Vec<Fr>), 46 | _rng: &mut StdRng, 47 | _cache: ProveVerifyCache, 48 | ) -> Vec<PairingCheck> { 49 | let a = inputs[0].first().unwrap(); 50 | let b = inputs[1].first().unwrap(); 51 | let c = outputs[0].first().unwrap(); 52 | // Verify f(x)+g(x)=h(x) 53 | assert!(a.g1 + b.g1 == c.g1); 54 | vec![] 55 | } 56 | } 57 | 58 | // This basic block is used to add multiple inputs together 59 | // Note: The inputs are expected to have the same shape 60 | #[derive(Debug)] 61 | pub struct BatchAddBasicBlock; 62 | impl BasicBlock for BatchAddBasicBlock { 63 | fn run(&self, _model: &ArrayD<Fr>, inputs: &Vec<&ArrayD<Fr>>) -> Result<Vec<ArrayD<Fr>>, util::CQOutOfRangeError> { 64 | // The inputs are expected to have the same shape 65 | assert!(inputs.iter().all(|x| x.shape() == inputs[0].shape())); 66 | let mut r = ArrayD::zeros(IxDyn(&[inputs[0].len()])); 67 | for i in 0..inputs.len() { 68 | azip!((r in &mut r, &x in inputs[i]) *r = *r + x); 69 | } 70 | Ok(vec![r]) 71 | } 72 | 73 | fn encodeOutputs(&self, _srs: &SRS, _model: &ArrayD<Fr>, inputs: &Vec<&ArrayD<Data>>, outputs: &Vec<&ArrayD<Fr>>) -> Vec<ArrayD<Data>> { 74 | vec![arr0(Data { 75 | raw: outputs[0].clone().into_raw_vec(), 76 | poly: inputs.iter().fold(DensePolynomial::zero(), |acc, x| acc + x.first().unwrap().poly.clone()), 77 | g1: inputs.iter().fold(G1Projective::zero(), |acc, x| acc + x.first().unwrap().g1), 78 | r: inputs.iter().fold(Fr::zero(), |acc, x| acc + x.first().unwrap().r), 79 | }) 80 | .into_dyn()] 81 | } 82 | 83 | fn verify( 84 | &self, 85 | _srs: &SRS, 86 | _model: &ArrayD<Fr>, 87 | inputs: &Vec<&ArrayD<DataEnc>>, 88 | outputs: &Vec<&ArrayD<DataEnc>>, 89 | _proof: (&Vec<G1Affine>, &Vec<G2Affine>, &Vec<Fr>), 90 | _rng: &mut StdRng, 91 | _cache: ProveVerifyCache, 92 | ) -> Vec<PairingCheck> { 93 | let inputs_g1 = inputs.iter().fold(G1Projective::zero(), |acc, x| acc + x.first().unwrap().g1); 94 | let c_g1 = outputs[0].first().unwrap().g1; 95 | // Verify f1(x)+f2(x)+...+fn(x)=h(x) 96 | assert!(inputs_g1 == c_g1); 97 | vec![] 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/basic_block/range.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_snake_case)] 2 | #![allow(non_upper_case_globals)] 3 | use super::{BasicBlock, Data, DataEnc, PairingCheck, ProveVerifyCache, SRS}; 4 | use crate::{ 5 | onnx, 6 | util::{self, calc_pow}, 7 | }; 8 | use ark_bn254::{Bn254, Fr, G1Affine, G1Projective, G2Affine, G2Projective}; 9 | use ark_ec::{pairing::Pairing, AffineRepr}; 10 | use ark_ff::Field; 11 | use ark_poly::{ 12 | evaluations::univariate::Evaluations, univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial, 13 | }; 14 | use ark_serialize::CanonicalSerialize; 15 | use ark_std::{ 16 | ops::{Add, Mul, Sub}, 17 | One, UniformRand, Zero, 18 | }; 19 | use ndarray::{arr1, indices, ArrayD, ArrayView, ArrayView1, ArrayViewD, Axis, Dim, Dimension, IxDyn, IxDynImpl, NdIndex, Shape, Zip};
20 | use rand::{rngs::StdRng, SeedableRng}; 21 | use rayon::prelude::*; 22 | use std::{ 23 | cmp::{max, min}, 24 | collections::HashMap, 25 | iter::{once, repeat}, 26 | }; 27 | 28 | // RangeConstBasicBlock is a basic block that creates a tensor of a range of values. 29 | // The range is defined by three constants: the start, limit, and delta values. 30 | #[derive(Debug)] 31 | pub struct RangeConstBasicBlock { 32 | pub start: i128, 33 | pub limit: i128, 34 | pub delta: i128, 35 | } 36 | impl BasicBlock for RangeConstBasicBlock { 37 | fn run(&self, _model: &ArrayD<Fr>, _inputs: &Vec<&ArrayD<Fr>>) -> Result<Vec<ArrayD<Fr>>, util::CQOutOfRangeError> { 38 | let element_num = max(0, ((self.limit - self.start) + self.delta - 1) / self.delta); 39 | let mut r = vec![]; 40 | let mut x = self.start; 41 | while x < self.limit { 42 | r.push(Fr::from(x)); 43 | x += self.delta; 44 | } 45 | let element_num_pad = util::next_pow(element_num as u32) as usize; 46 | while r.len() < element_num_pad { 47 | r.push(Fr::zero()); 48 | } 49 | Ok(vec![arr1(&r).into_dyn()]) 50 | } 51 | 52 | #[cfg(not(feature = "mock_prove"))] 53 | fn setup(&self, srs: &SRS, _model: &ArrayD<Fr>) -> (Vec<G1Projective>, Vec<G2Projective>, Vec<DensePolynomial<Fr>>) { 54 | let element_num = max(0, ((self.limit - self.start) + self.delta - 1) / self.delta); 55 | let element_num_pad = util::next_pow(element_num as u32) as usize; 56 | let domain = GeneralEvaluationDomain::<Fr>::new(element_num_pad).unwrap(); 57 | 58 | let mut r = vec![]; 59 | let mut x = self.start; 60 | while x < self.limit { 61 | r.push(Fr::from(x)); 62 | x += self.delta; 63 | } 64 | while r.len() < element_num_pad { 65 | r.push(Fr::zero()); 66 | } 67 | let range_poly = DensePolynomial::from_coefficients_vec(domain.ifft(&r)); 68 | let range_x = util::msm(&srs.X1A, &range_poly.coeffs); 69 | (vec![range_x], vec![], vec![]) 70 | } 71 | 72 | #[cfg(feature = "mock_prove")] 73 | fn setup(&self, srs: &SRS, _model: &ArrayD<Fr>) -> (Vec<G1Projective>, Vec<G2Projective>, Vec<DensePolynomial<Fr>>) { 74 | eprintln!("\x1b[93mWARNING\x1b[0m: MockSetup is enabled. This is only for testing purposes.");
75 | (vec![srs.X1P[0].clone()], vec![], vec![]) 76 | } 77 | 78 | fn prove( 79 | &self, 80 | srs: &SRS, 81 | setup: (&Vec<G1Projective>, &Vec<G2Projective>, &Vec<DensePolynomial<Fr>>), 82 | _model: &ArrayD<Fr>, 83 | _inputs: &Vec<&ArrayD<Data>>, 84 | outputs: &Vec<&ArrayD<Data>>, 85 | _rng: &mut StdRng, 86 | _cache: ProveVerifyCache, 87 | ) -> (Vec<G1Affine>, Vec<G2Affine>, Vec<Fr>) { 88 | let C = srs.Y1P * outputs[0].first().unwrap().r; 89 | (vec![setup.0[0].into(), C.into()], vec![], vec![]) 90 | } 91 | 92 | fn verify( 93 | &self, 94 | _srs: &SRS, 95 | _model: &ArrayD<Fr>, 96 | _inputs: &Vec<&ArrayD<DataEnc>>, 97 | outputs: &Vec<&ArrayD<DataEnc>>, 98 | proof: (&Vec<G1Affine>, &Vec<G2Affine>, &Vec<Fr>), 99 | _rng: &mut StdRng, 100 | _cache: ProveVerifyCache, 101 | ) -> Vec<PairingCheck> { 102 | assert!(proof.0[0] + proof.0[1] == outputs[0].first().unwrap().g1); 103 | vec![] 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/layer/sqrt.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ark_std::One; 8 | use ndarray::{arr1, ArrayD}; 9 | use rayon::iter::ParallelIterator; 10 | use tract_onnx::pb::AttributeProto; 11 | use tract_onnx::prelude::DatumType; 12 | 13 | pub struct SqrtLayer; 14 | impl Layer for SqrtLayer { 15 | fn graph( 16 | input_shapes: &Vec<&Vec<usize>>, 17 | input_types: &Vec<DatumType>, 18 | _constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>, 19 | _attributes: &Vec<&AttributeProto>, 20 | ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) { 21 | let mut graph = Graph::new(); 22 | let sf_log = onnx::SF_LOG.read().unwrap().to_owned(); 23 | let sqrt = graph.addBB(Box::new(SqrtBasicBlock { 24 | input_SF: sf_log, 25 | output_SF: sf_log, 26 | })); 27 | let sf = onnx::SF.read().unwrap().to_owned(); 28 | let sf_const = graph.addBB(Box::new(Const2BasicBlock { 29 | c: arr1(&vec![Fr::from(sf as i32)]).into_dyn(), 30 | })); 31 | let two_const = graph.addBB(Box::new(Const2BasicBlock { 32 | c: arr1(&vec![Fr::from(2)]).into_dyn(), 33 | })); 34 | let add = graph.addBB(Box::new(RepeaterBasicBlock { 35 | basic_block: Box::new(AddBasicBlock {}), 36 | N: 1, 37 | })); 38 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 39 | basic_block: Box::new(SubBasicBlock {}), 40 | N: 1, 41 | })); 42 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 43 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 44 | basic_block: Box::new(MulBasicBlock { len }), 45 | N: 1, 46 | })); 47 | let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock { 48 | basic_block: Box::new(MulScalarBasicBlock {}), 49 | N: 1, 50 | })); 51 | let non_negative_check = graph.addBB(Box::new(RepeaterBasicBlock { 52 | basic_block: Box::new(CQBasicBlock { 53 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 54 | setup: util::CQArrayType::NonNegative, 55 | }), 56 | N: 1, 57 | })); 58 | let negative_check = graph.addBB(Box::new(RepeaterBasicBlock { 59 | basic_block: Box::new(CQBasicBlock { 60 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 61 | setup: util::CQArrayType::NonPositive, 62 | }), 63 | N: 1, 64 | })); 65 | 66 | // SqrtBB(x) = sqrt(x/SF)*SF + eps (where -1 < eps < 1) 67 | let sqrt_output = graph.addNode(sqrt, vec![(-1, 0)]); 68 | // The following operations are to check if sqrt_output is correct 69 | // square_sqrt = SqrtBB(x)^2 = x*SF + 2*sqrt(x/SF)*SF*eps + eps^2 70 | let square_sqrt = graph.addNode(mul, vec![(sqrt_output, 0), (sqrt_output, 0)]); 71 | // scale_input_by_sf = x*SF 72
| let sf_const_output = graph.addNode(sf_const, vec![]); 73 | let scale_input_by_sf = graph.addNode(mul_scalar, vec![(-1, 0), (sf_const_output, 0)]); 74 | // difference = SqrtBB(x)^2 - x*SF = 2*sqrt(x/SF)*SF*eps + eps^2 = 2*SqrtBB(x)*eps - eps^2 75 | // Because -1 < eps < 1, -2*SqrtBB(x) < 2*SqrtBB(x)*eps < 2*SqrtBB(x) and -1 < -eps^2 < 0. 76 | // Therefore, -1 - 2*SqrtBB(x) < difference < 2*SqrtBB(x). 77 | // The following two inequalities should hold: 78 | // 1. difference + 2*SqrtBB(x) >= 0 79 | // 2. difference - 2*SqrtBB(x) < 0 80 | let difference = graph.addNode(sub, vec![(square_sqrt, 0), (scale_input_by_sf, 0)]); 81 | // scale_output_by_2 = 2*SqrtBB(x) 82 | let two_const_output = graph.addNode(two_const, vec![]); 83 | let scale_output_by_2 = graph.addNode(mul_scalar, vec![(sqrt_output, 0), (two_const_output, 0)]); 84 | // d_plus_scale_output_by_2 = difference + 2*SqrtBB(x) 85 | let d_plus_scale_output_by_2 = graph.addNode(add, vec![(difference, 0), (scale_output_by_2, 0)]); 86 | // d_minus_scale_output_by_2 = difference - 2*SqrtBB(x) 87 | let d_minus_scale_output_by_2 = graph.addNode(sub, vec![(difference, 0), (scale_output_by_2, 0)]); 88 | let _ = graph.addNode(non_negative_check, vec![(d_plus_scale_output_by_2, 0)]); 89 | let _ = graph.addNode(negative_check, vec![(d_minus_scale_output_by_2, 0)]); 90 | 91 | graph.outputs.push((sqrt_output, 0)); 92 | (graph, vec![input_shapes[0].clone()], vec![input_types[0]]) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /scratch/bert/replace_reshape_trans.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import helper, shape_inference 3 | 4 | def replace_reshape_transpose(model): 5 | graph = model.graph 6 | nodes = graph.node 7 | 8 | # Build mappings from tensor names to nodes 9 | tensor_producers = {} 10 | tensor_consumers = {} 11 | for node in nodes: 12 | for output_name in node.output: 13 | tensor_producers[output_name] = node 14 | for input_name in node.input: 15 | tensor_consumers.setdefault(input_name, []).append(node) 16 | 17 | nodes_to_remove = [] 18 | nodes_to_add = [] 19 | 20 | for idx, transpose_node in enumerate(list(nodes)): 21 | print("node") 22 | if transpose_node.op_type != 'Transpose': 23 | continue 24 | 25 | transpose_input = transpose_node.input[0] 26 | transpose_output = transpose_node.output[0] 27 | 28 | # Check if the input comes from a Reshape node 29 | reshape_node = tensor_producers.get(transpose_input) 30 | if reshape_node is None or reshape_node.op_type != 'Reshape': 31 | continue 32 | 33 | # Ensure that the Reshape node's output is only consumed by this Transpose node 34 | consumers_of_reshape_output = tensor_consumers.get(reshape_node.output[0], []) 35 | if len(consumers_of_reshape_output) != 1: 36 | continue # Can't replace if Reshape output has other consumers 37 | 38 | # Collect inputs for the custom node 39 | reshape_inputs = reshape_node.input 40 | custom_node_output = transpose_output 41 | 42 | # Get attributes from the Transpose node 43 | transpose_perm = None 44 | for attr in transpose_node.attribute: 45 | if attr.name == 'perm': 46 | transpose_perm = attr.ints 47 | 48 | # Create attributes for the custom node 49 | custom_attrs = {} 50 | if transpose_perm is not None: 51 | custom_attrs['perm'] = transpose_perm 52 | 53 | # Create the custom ReshapeTrans node 54 | custom_node = helper.make_node( 55 | 'ReshapeTrans', 56 | inputs=reshape_inputs, 57 | outputs=[custom_node_output], 58 | 
| name='ReshapeTrans_' + str(idx), 59 | **custom_attrs 60 | ) 61 | 62 | # Record nodes to remove 63 | nodes_to_remove.extend([reshape_node, transpose_node]) 64 | 65 | # Insert the new node at the position of the Transpose node 66 | nodes_to_add.append((idx, custom_node)) 67 | 68 | # Update the mappings 69 | # Remove old producer mappings 70 | tensor_producers.pop(reshape_node.output[0], None) 71 | tensor_producers.pop(transpose_node.output[0], None) 72 | # Add new producer mapping 73 | tensor_producers[custom_node_output] = custom_node 74 | 75 | # Update consumer mappings 76 | for input_name in reshape_node.input: 77 | consumers = tensor_consumers.get(input_name, []) 78 | if reshape_node in consumers: 79 | consumers.remove(reshape_node) 80 | for input_name in transpose_node.input: 81 | consumers = tensor_consumers.get(input_name, []) 82 | if transpose_node in consumers: 83 | consumers.remove(transpose_node) 84 | for input_name in custom_node.input: 85 | tensor_consumers.setdefault(input_name, []).append(custom_node) 86 | 87 | # Insert new nodes into the graph 88 | c = 0 89 | for idx, custom_node in nodes_to_add: 90 | graph.node.insert(idx + c, custom_node) 91 | c += 1 92 | 93 | # Remove old nodes from the graph 94 | for node in nodes_to_remove: 95 | if node in graph.node: 96 | graph.node.remove(node) 97 | 98 | # (Optional) Infer shapes to ensure consistency 99 | #model = shape_inference.infer_shapes(model) 100 | 101 | return model 102 | 103 | # Load the original model 104 | model = onnx.load('Bert.onnx') 105 | 106 | # Apply the pattern replacement 107 | model = replace_reshape_transpose(model) 108 | 109 | # Check the model for correctness 110 | #onnx.checker.check_model(model) 111 | 112 | # Save the modified model 113 | onnx.save(model, 'Bert_replaced.onnx') 114 | -------------------------------------------------------------------------------- /src/layer/reducemean.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ndarray::ArrayD; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | // ReduceMeanLayer is a layer that returns the mean of the input tensor along one or two given axes 12 | // More than two axes are not supported for now 13 | pub struct ReduceMeanLayer; 14 | impl Layer for ReduceMeanLayer { 15 | fn graph( 16 | input_shapes: &Vec<&Vec<usize>>, 17 | input_types: &Vec<DatumType>, 18 | _constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>, 19 | attributes: &Vec<&AttributeProto>, 20 | ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) { 21 | let mut graph = Graph::new(); 22 | 23 | let axes: Vec<_> = match attributes.iter().filter(|x| x.name == "axes").next() { 24 | Some(x) => x.ints.iter().map(|x| *x).collect(), 25 | None => vec![(input_shapes[0].len() - 1) as i64], 26 | }; 27 | 28 | let axes: Vec<_> = axes 29 | .iter() 30 | .map(|&x| { 31 | if x < 0 { 32 | (input_shapes[0].len() as i64 + x) as usize 33 | } else { 34 | x as usize 35 | } 36 | }) 37 | .collect(); 38 | 39 | // Only support reducing along one or two axes 40 | assert!(axes.len() == 1 || axes.len() == 2); 41 | // the last axis must be among the reduced axes 42 | assert!(axes.iter().any(|&x| x == input_shapes[0].len() - 1)); 43 | 44 | let n = input_shapes[0].len(); 45 | let mut a = input_shapes[0][n - 1]; 46 | a = util::next_pow(a as u32) as usize; 47 | let permutation = (vec![0], (0..a).collect()); 48 | // PermuteBasicBlock is used to permute the last two dimensions for the case of reducing along two axes
49 | // (we need it because our mean computation is done along the last dimension) 50 | let permute = graph.addBB(Box::new(RepeaterBasicBlock { 51 | basic_block: Box::new(PermuteBasicBlock { 52 | permutation, 53 | n: a, 54 | m: 1, 55 | }), 56 | N: 2, 57 | })); 58 | 59 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 60 | let sum = graph.addBB(Box::new(RepeaterBasicBlock { 61 | basic_block: Box::new(SumBasicBlock { len }), 62 | N: 1, 63 | })); 64 | let div = graph.addBB(Box::new(DivConstBasicBlock { 65 | c: input_shapes[0][input_shapes[0].len() - 1] as f32, 66 | })); 67 | let div_check = graph.addBB(Box::new(RepeaterBasicBlock { 68 | basic_block: Box::new(CQ2BasicBlock { 69 | n: 1, 70 | setup: Some(( 71 | Box::new(DivConstBasicBlock { 72 | c: input_shapes[0][input_shapes[0].len() - 1] as f32, 73 | }), 74 | *onnx::CQ_RANGE_LOWER, 75 | *onnx::CQ_RANGE, 76 | )), 77 | }), 78 | N: 1, 79 | })); 80 | let sum_output = graph.addNode(sum, vec![(-1, 0)]); 81 | let div_output = graph.addNode(div, vec![(sum_output, 0)]); 82 | let _ = graph.addNode(div_check, vec![(sum_output, 0), (div_output, 0)]); 83 | 84 | if axes.len() == 2 { 85 | let permute_output = graph.addNode(permute, vec![(div_output, 0)]); 86 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 2] as u32) as usize; 87 | let sum = graph.addBB(Box::new(RepeaterBasicBlock { 88 | basic_block: Box::new(SumBasicBlock { len }), 89 | N: 1, 90 | })); 91 | // The second reduction must divide by the size of the second-to-last axis 92 | let div2 = graph.addBB(Box::new(DivConstBasicBlock { 93 | c: input_shapes[0][input_shapes[0].len() - 2] as f32, 94 | })); 95 | let div_check2 = graph.addBB(Box::new(RepeaterBasicBlock { 96 | basic_block: Box::new(CQ2BasicBlock { 97 | n: 1, 98 | setup: Some(( 99 | Box::new(DivConstBasicBlock { 100 | c: input_shapes[0][input_shapes[0].len() - 2] as f32, 101 | }), 102 | *onnx::CQ_RANGE_LOWER, 103 | *onnx::CQ_RANGE, 104 | )), 105 | }), 106 | N: 1, 107 | })); 108 | let sum_output1 = graph.addNode(sum, vec![(permute_output, 0)]); 109 | let div_output1 = graph.addNode(div2, vec![(sum_output1, 0)]); 110 | let _ = graph.addNode(div_check2, vec![(sum_output1, 0), (div_output1, 0)]); 111 | graph.outputs.push((div_output1, 0)); 112 | let mut outputShape = input_shapes[0].clone(); 113 | outputShape[input_shapes[0].len() - 1] = 1; 114 | outputShape[input_shapes[0].len() - 2] = 1; 115 | (graph, vec![outputShape], vec![input_types[0]]) 116 | } else if axes.len() == 1 { 117 | graph.outputs.push((div_output, 0)); 118 | let mut outputShape = input_shapes[0].clone(); 119 | outputShape[input_shapes[0].len() - 1] = 1; 120 | (graph, vec![outputShape], vec![input_types[0]]) 121 | } else { 122 | panic!("Only support reducing along one or two axes"); 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/layer/gathernd.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::iterable::Iterable; 7 | use ndarray::Dimension; 8 | use ndarray::{ArrayD, Axis, IxDyn}; 9 | use tract_onnx::pb::AttributeProto; 10 | use tract_onnx::prelude::DatumType; 11 | 12 | // array: the N-dimensional array 13 | // n_minus_1_index: the partial index into the leading dimensions of the array 14 | fn get_sub_array<T>(array: ArrayD<T>, n_minus_1_index: &[usize]) -> ArrayD<T> 15 | where 16 | T: Clone, 17 | { 18 | let mut sub_array = array.clone(); 19 | for &index in n_minus_1_index.iter() { 20 | let s = sub_array.view(); 21 | sub_array = s.index_axis(Axis(0), index).to_owned(); 22 | } 23 | sub_array 24 | } 25 | 26 | fn get_gathernd_masks(input_shape: &[usize], indices: &ArrayD<usize>, batch_dims: usize) -> (ArrayD<Option<IxDyn>>, Vec<usize>) { 27 | assert!(indices.shape()[indices.ndim() - 1] <= input_shape.len() - batch_dims); 28 | // ref: https://docs.openvino.ai/2022.3/openvino_docs_ops_movement_GatherND_8.html 29 | let output_shape: Vec<usize> = if
indices.shape()[indices.ndim() - 1] == input_shape.len() - batch_dims { 30 | // the indices shape, excluding the last dimension 31 | indices.shape().iter().take(indices.ndim() - 1).cloned().collect() 32 | } else { 33 | // indices.shape[:batch_dims] + list(indices.shape)[batch_dims:-1] + list(data.shape)[batch_dims + indices.shape[-1]:]. 34 | let mut output_shape = vec![]; 35 | output_shape.extend_from_slice(&indices.shape()[..indices.ndim() - 1]); 36 | output_shape.extend_from_slice(&input_shape[batch_dims + indices.shape()[indices.ndim() - 1]..]); 37 | output_shape 38 | }; 39 | 40 | // permutation[i_0, ..., i_{K-2},:,...,:] = [indices[i_0, ..., i_{K-2}],:,...,:] 41 | let permutation = ArrayD::from_shape_fn(output_shape.clone(), |idx| { 42 | let mut v = vec![]; 43 | // select the partial index from 0..indices.len() - 1 44 | let mut partial_idx = vec![]; 45 | for i in 0..indices.ndim() - 1 { 46 | partial_idx.push(idx[i]); 47 | } 48 | let sub_array = get_sub_array(indices.clone(), &partial_idx); 49 | v.extend(sub_array.as_slice().unwrap()); 50 | for i in indices.ndim() - 1..idx.ndim() { 51 | v.push(idx[i]); 52 | } 53 | Some(IxDyn(&v)) 54 | }); 55 | 56 | let padded_permutation = util::pad_to_pow_of_two(&permutation, &None); 57 | (padded_permutation, output_shape) 58 | } 59 | 60 | // reference (v13): https://onnx.ai/onnx/operators/onnx__GatherND.html 61 | pub struct GatherNDLayer; 62 | impl Layer for GatherNDLayer { 63 | fn graph( 64 | input_shapes: &Vec<&Vec<usize>>, 65 | input_types: &Vec<DatumType>, 66 | constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>, 67 | attributes: &Vec<&AttributeProto>, 68 | ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) { 69 | let mut graph = Graph::new(); 70 | 71 | let indices = if constants[1].is_none() { 72 | // we cannot handle non-constant indices because we need to know the shape of the indices to compile the graph in ZKTorch 73 | panic!("GatherNDLayer: indices must be a constant"); 74 | } else { 75 | constants[1].unwrap().0.map(|x| util::fr_to_int(*x) as usize) 76 | }; 77 | 78 | // attributes may contain batch_dims, but we only support batch_dims = 0 for now 79 | let batch_dims: usize = if attributes.iter().find(|x| x.name == "batch_dims").is_none() { 80 | 0 81 | } else { 82 | let b = attributes.iter().filter(|x| x.name == "batch_dims").next().unwrap().i as usize; 83 | if b != 0 { 84 | panic!("GatherNDLayer: only support the case where batch_dims = 0"); 85 | } else { 86 | b 87 | } 88 | }; 89 | 90 | let data_shape = input_shapes[0].clone(); 91 | let padded_data_shape: Vec<_> = data_shape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 92 | 93 | let (permutation, output_shape) = get_gathernd_masks(&data_shape, &indices, batch_dims); 94 | 95 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 96 | permutation, 97 | input_dim: IxDyn(&padded_data_shape), 98 | padding_partition: copy_constraint::PaddingEnum::Zero, 99 | })); 100 | 101 | let output = graph.addNode(cc, vec![(-1, 0)]); 102 | graph.outputs.push((output, 0)); 103 | 104 | (graph, vec![output_shape], vec![input_types[0]]) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/layer/concat.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ark_std::Zero; 7 | use ndarray::{ArrayD, IxDyn}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | // This function returns N outputs where N is the number of
inputs. 12 | // Each output is an array with the same shape as the final concatenation array. 13 | // And the value at each index is the index of the corresponding input array. 14 | // For example, [1], [1], [1] -> [0, None, None], [None, 0, None], [None, None, 0] 15 | // such that we can use the indices to copy the input arrays to a padded array and add them together into the final output array. 16 | fn get_concat_indices(input_shapes: &Vec<&Vec>, output_shape: &Vec, axis: usize) -> Vec>> { 17 | let mut indices = vec![]; 18 | let mut axis_offset = 0; 19 | for i in 0..input_shapes.len() { 20 | let output = ArrayD::from_shape_fn(output_shape.as_slice(), |index| { 21 | if index[axis] >= axis_offset && index[axis] < axis_offset + input_shapes[i][axis] { 22 | let mut new_index = index.clone(); 23 | new_index[axis] = index[axis] - axis_offset; 24 | Some(new_index) 25 | } else { 26 | None 27 | } 28 | }); 29 | axis_offset += input_shapes[i][axis]; 30 | let output = util::pad_to_pow_of_two(&output, &None); 31 | indices.push(output); 32 | } 33 | indices 34 | } 35 | 36 | // Concatenate the input arrays along the specified axis. 37 | // If the axis is the last axis, we copy the input arrays to a padded array by Copy Constraint and add them together. 38 | // Otherwise, we directly concatenate the input arrays. 39 | pub struct ConcatLayer; 40 | impl Layer for ConcatLayer { 41 | fn graph( 42 | input_shapes: &Vec<&Vec>, 43 | input_types: &Vec, 44 | _constants: &Vec, DatumType)>>, 45 | attributes: &Vec<&AttributeProto>, 46 | ) -> (Graph, Vec>, Vec) { 47 | let mut graph = Graph::new(); 48 | 49 | // Extract the 'axis' attribute and adjust for negative values 50 | let axis: isize = attributes.iter().filter(|x| x.name == "axis").next().unwrap().i as isize; 51 | let axis = (if axis < 0 { input_shapes[0].len() as isize + axis } else { axis }) as usize; 52 | // Compute the output shape after concatenation 53 | let mut outputShape = input_shapes[0].clone(); 54 | outputShape[axis] = input_shapes.iter().map(|x| x[axis as usize]).sum(); 55 | // If concatenating along the last axis, use copy constraint as the output commitment changes 56 | if axis == input_shapes[0].len() - 1 { 57 | let mut padded_output_shape = outputShape.clone(); 58 | padded_output_shape[axis] = util::next_pow(padded_output_shape[axis] as u32) as usize; 59 | let permutations = get_concat_indices(input_shapes, &padded_output_shape, axis); 60 | let mut cc_basicblocks = vec![]; 61 | for i in 0..input_shapes.len() { 62 | let padded_input_shape: Vec = input_shapes[i].iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 63 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 64 | permutation: permutations[i].clone(), 65 | input_dim: IxDyn(&padded_input_shape), 66 | padding_partition: copy_constraint::PaddingEnum::Zero, 67 | })); 68 | cc_basicblocks.push(cc); 69 | } 70 | let add = graph.addBB(Box::new(RepeaterBasicBlock { 71 | basic_block: Box::new(AddBasicBlock {}), 72 | N: 1, 73 | })); 74 | 75 | let mut cc_outputs = vec![]; 76 | for i in 0..input_shapes.len() { 77 | let cc_output = graph.addNode(cc_basicblocks[i], vec![(-(i as i32 + 1), 0)]); 78 | cc_outputs.push((cc_output, 0)); 79 | } 80 | // add 2 cc_outputs at a time until only 1 output is left 81 | while cc_outputs.len() > 1 { 82 | let add_output = graph.addNode(add, vec![cc_outputs.pop().unwrap(), cc_outputs.pop().unwrap()]); 83 | cc_outputs.push((add_output, 0)); 84 | } 85 | let final_output = cc_outputs.pop().unwrap(); 86 | graph.outputs.push(final_output); 87 | } else { 88 | // 
If not concatenating along the last axis, directly concatenate 89 | let n_input = input_shapes.len(); 90 | let concat = graph.addBB(Box::new(ConcatBasicBlock { 91 | axis: axis as usize, 92 | input_shapes: input_shapes.iter().map(|x| (*x).clone()).collect(), 93 | })); 94 | let concat_input: Vec<_> = (0..n_input).map(|i| (-(i as i32 + 1), 0)).collect(); 95 | let concat_output = graph.addNode(concat, concat_input); 96 | graph.outputs.push((concat_output, 0)); 97 | } 98 | 99 | (graph, vec![outputShape], vec![input_types[0]]) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/layer/scatternd.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{ArrayD, Axis, Dim, IxDyn}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | fn get_masks(input_shape: &[usize], indices: &ArrayD) -> (ArrayD>, ArrayD>) { 11 | let preserve = ArrayD::from_shape_fn(input_shape, |index| Some(index)); 12 | let update = ArrayD::from_shape_fn(input_shape, |_| None); 13 | let mut preserve = util::pad_to_pow_of_two(&preserve, &None); 14 | let mut update = util::pad_to_pow_of_two(&update, &None); 15 | let indices_usize = indices.map(|x| util::fr_to_int(*x) as usize); 16 | let indices_shape = indices.shape(); 17 | let update_indices = &indices_shape[..indices_shape.len() - 1]; 18 | 19 | let mut current_index = vec![]; 20 | let mut all_indices = vec![]; 21 | ndindex(update_indices, &mut current_index, &mut all_indices); 22 | let update_indices = indices_usize.lanes(Axis(indices_usize.ndim() - 1)); 23 | for (idx, update_idx) in all_indices.iter().zip(update_indices) { 24 | if update_idx.len() > input_shape.len() { 25 | // use only the first n elements of update_idx, where n is the length of input_shape 26 | let mut new_update_idx = update_idx.to_vec(); 27 | new_update_idx.truncate(input_shape.len()); 28 | let copy_update_idx = Dim(new_update_idx.to_vec()); 29 | let copy_idx = Dim(idx.clone()); 30 | preserve[copy_update_idx.clone()] = None; 31 | update[copy_update_idx] = Some(copy_idx.clone()); 32 | } else if update_idx.len() < input_shape.len() { 33 | let input_shape_extra_dims = input_shape[input_shape.len() - update_idx.len()..].to_vec(); 34 | let mut current_index = vec![]; 35 | let mut extra_indices = vec![]; 36 | ndindex(&input_shape_extra_dims, &mut current_index, &mut extra_indices); 37 | for extra_idx in extra_indices { 38 | // concat the extra indices to the update_idx 39 | let mut new_update_idx = update_idx.to_vec(); 40 | new_update_idx.extend(extra_idx.clone()); 41 | let copy_update_idx = Dim(new_update_idx.to_vec()); 42 | // concat the extra indices to the idx 43 | let mut new_idx = idx.to_vec(); 44 | new_idx.extend(extra_idx); 45 | let copy_idx = Dim(new_idx.to_vec()); 46 | 47 | preserve[copy_update_idx.clone()] = None; 48 | update[copy_update_idx] = Some(copy_idx.clone()); 49 | } 50 | } else { 51 | preserve[Dim(update_idx.to_vec())] = None; 52 | update[Dim(update_idx.to_vec())] = Some(Dim(idx.clone())); 53 | } 54 | } 55 | 56 | (preserve, update) 57 | } 58 | 59 | fn ndindex(shape: &[usize], current_index: &mut Vec, all_indices: &mut Vec>) { 60 | if current_index.len() == shape.len() { 61 | all_indices.push(current_index.clone()); 62 | return; 63 | } 64 | 65 | let dim = shape[current_index.len()]; 66 | for i in 0..dim { 67 | current_index.push(i); 68 
| ndindex(shape, current_index, all_indices); 69 | current_index.pop(); 70 | } 71 | } 72 | 73 | // https://onnx.ai/onnx/operators/onnx__ScatterND.html 74 | pub struct ScatterNDLayer; 75 | impl Layer for ScatterNDLayer { 76 | fn graph( 77 | input_shapes: &Vec<&Vec>, 78 | input_types: &Vec, 79 | constants: &Vec, DatumType)>>, 80 | _attributes: &Vec<&AttributeProto>, 81 | ) -> (Graph, Vec>, Vec) { 82 | let mut graph = Graph::new(); 83 | // TODO: handle cases when attribute is not none 84 | let indices = constants[1].unwrap().0; 85 | 86 | let (permutation_preserve, permutation_update) = get_masks(&input_shapes[0], &indices); 87 | let input_shape_0_padded: Vec<_> = input_shapes[0].iter().map(|x| util::next_pow(*x as u32) as usize).collect(); 88 | let input_shape_2_padded: Vec<_> = input_shapes[2].iter().map(|x| util::next_pow(*x as u32) as usize).collect(); 89 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 90 | permutation: permutation_preserve, 91 | input_dim: IxDyn(&input_shape_0_padded), 92 | padding_partition: copy_constraint::PaddingEnum::Zero, 93 | })); 94 | let cc1 = graph.addBB(Box::new(CopyConstraintBasicBlock { 95 | permutation: permutation_update, 96 | input_dim: IxDyn(&input_shape_2_padded), 97 | padding_partition: copy_constraint::PaddingEnum::Zero, 98 | })); 99 | let add = graph.addBB(Box::new(RepeaterBasicBlock { 100 | basic_block: Box::new(AddBasicBlock {}), 101 | N: 1, 102 | })); 103 | 104 | let data_to_preserve = graph.addNode(cc, vec![(-1, 0)]); 105 | let data_to_update = graph.addNode(cc1, vec![(-3, 0)]); 106 | let add_output = graph.addNode(add, vec![(data_to_preserve, 0), (data_to_update, 0)]); 107 | graph.outputs.push((add_output, 0)); 108 | 109 | (graph, vec![input_shapes[0].clone()], vec![input_types[0]]) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/layer/softmax.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::onnx; 5 | use crate::util; 6 | use ark_bn254::Fr; 7 | use ndarray::ArrayD; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | pub struct SoftmaxLayer; 12 | impl Layer for SoftmaxLayer { 13 | fn graph( 14 | input_shapes: &Vec<&Vec>, 15 | input_types: &Vec, 16 | _constants: &Vec, DatumType)>>, 17 | _attributes: &Vec<&AttributeProto>, 18 | ) -> (Graph, Vec>, Vec) { 19 | let mut graph = Graph::new(); 20 | let sf_log = onnx::SF_LOG.read().unwrap().to_owned(); 21 | let max = graph.addBB(Box::new(MaxBasicBlock {})); 22 | let sub = graph.addBB(Box::new(RepeaterBasicBlock { 23 | basic_block: Box::new(SubBasicBlock {}), 24 | N: 1, 25 | })); 26 | let exp = graph.addBB(Box::new(ExpBasicBlock { 27 | input_SF: sf_log, 28 | output_SF: sf_log, 29 | })); 30 | let exp_check = graph.addBB(Box::new(RepeaterBasicBlock { 31 | basic_block: Box::new(CQ2BasicBlock { 32 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 33 | setup: Some(( 34 | Box::new(ExpBasicBlock { 35 | input_SF: sf_log, 36 | output_SF: sf_log, 37 | }), 38 | (-(*onnx::CQ_RANGE as i32) + 1) as i128, 39 | *onnx::CQ_RANGE, 40 | )), 41 | }), 42 | N: 1, 43 | })); 44 | let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize; 45 | let sum = graph.addBB(Box::new(RepeaterBasicBlock { 46 | basic_block: Box::new(SumBasicBlock { len }), 47 | N: 1, 48 | })); 49 | let reciprocal = graph.addBB(Box::new(ReciprocalBasicBlock { 50 | input_SF: 
sf_log, 51 | output_SF: sf_log, 52 | })); 53 | let rec_check = graph.addBB(Box::new(RepeaterBasicBlock { 54 | basic_block: Box::new(CQ2BasicBlock { 55 | n: 1, 56 | setup: Some(( 57 | Box::new(ReciprocalBasicBlock { 58 | input_SF: sf_log, 59 | output_SF: sf_log, 60 | }), 61 | 0, 62 | *onnx::CQ_RANGE, 63 | )), 64 | }), 65 | N: 1, 66 | })); 67 | let mul = graph.addBB(Box::new(RepeaterBasicBlock { 68 | basic_block: Box::new(MulScalarBasicBlock {}), 69 | N: 1, 70 | })); 71 | let change_SF = graph.addBB(Box::new(ChangeSFBasicBlock { 72 | input_SF: sf_log * 2, 73 | output_SF: sf_log, 74 | })); 75 | let change_SF_check = graph.addBB(Box::new(RepeaterBasicBlock { 76 | basic_block: Box::new(CQ2BasicBlock { 77 | n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(), 78 | setup: Some(( 79 | Box::new(ChangeSFBasicBlock { 80 | input_SF: sf_log * 2, 81 | output_SF: sf_log, 82 | }), 83 | *onnx::CQ_RANGE_LOWER, 84 | *onnx::CQ_RANGE, 85 | )), 86 | }), 87 | N: 1, 88 | })); 89 | 90 | // The proving idea is as follows 91 | // 1. m = max(X): 92 | // We first compute the maximum value of the input array. 93 | // 2. X - m: 94 | // We subtract the maximum value from each element of the input array. 95 | // 3. e^(X - m) * SF: 96 | // We compute the exponential of each element of the input array. 97 | // And we use "exp_check" to ensure that the output is within the CQ range. 98 | // 4. SUM(e^(X - m)) * SF: 99 | // We compute the sum of the exponential of each element of the input array. 100 | // 5. SF / SUM(e^(X - m)): 101 | // We compute the reciprocal of the sum of the exponential of each element of the input array. 102 | // And we use "rec_check" to ensure that the output is within the CQ range. 103 | // 6. [e^(X - m) * SF] * [SF / SUM(e^(X - m))]: 104 | // We multiply the output from step 3 and step 5. 105 | // 7. [e^(X - m) * SF] * [SF / SUM(e^(X - m))] --> [e^(X - m) / SUM(e^(X - m))] * SF 106 | // Change the scale factor of the output to the original scale factor. 
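// Worked example with illustrative numbers (SF = 2^12 = 4096 here is an assumption for
// exposition, not the repo default): for a row X - m = [-1.0, 0.0], i.e. [-4096, 0] in fixed point,
//   step 3: [e^-1, e^0] * SF    ~= [1507, 4096]
//   step 4: 1507 + 4096          = 5603 (~= 1.3679 * SF)
//   step 5: SF^2 / 5603         ~= 2994 (~= SF / 1.3679)
//   step 6: [1507, 4096] * 2994  = [4511958, 12263424] (softmax values at scale SF^2)
//   step 7: divide by SF        ~= [1101, 2994], i.e. softmax([-1, 0]) ~= [0.269, 0.731] at scale SF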
107 | let max_output = graph.addNode(max, vec![(-1, 0)]); 108 | let sub_output = graph.addNode(sub, vec![(-1, 0), (max_output, 0)]); 109 | let exp_output = graph.addNode(exp, vec![(sub_output, 0)]); 110 | let _ = graph.addNode(exp_check, vec![(sub_output, 0), (exp_output, 0)]); 111 | let sum_output = graph.addNode(sum, vec![(exp_output, 0)]); 112 | let rec_output = graph.addNode(reciprocal, vec![(sum_output, 0)]); 113 | let _ = graph.addNode(rec_check, vec![(sum_output, 0), (rec_output, 0)]); 114 | let mul_output = graph.addNode(mul, vec![(exp_output, 0), (rec_output, 0)]); 115 | let output = graph.addNode(change_SF, vec![(mul_output, 0)]); 116 | let _ = graph.addNode(change_SF_check, vec![(mul_output, 0), (output, 0)]); 117 | graph.outputs.push((output, 0)); 118 | 119 | (graph, vec![input_shapes[0].clone()], vec![input_types[0]]) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/layer/slice.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::Layer; 4 | use crate::util; 5 | use ark_bn254::Fr; 6 | use ndarray::{ArrayD, IxDyn}; 7 | use tract_onnx::pb::AttributeProto; 8 | use tract_onnx::prelude::DatumType; 9 | 10 | fn combinations(vecs: &Vec>) -> Vec> { 11 | // Recursive function to generate combinations 12 | fn combine(vecs: &[Vec], current: Vec, result: &mut Vec>) { 13 | if vecs.is_empty() { 14 | result.push(current); 15 | } else { 16 | for item in &vecs[0] { 17 | let mut new_current = current.clone(); 18 | new_current.push(item.clone()); 19 | combine(&vecs[1..], new_current, result); 20 | } 21 | } 22 | } 23 | 24 | let mut result = Vec::new(); 25 | combine(&vecs, Vec::new(), &mut result); 26 | result 27 | } 28 | 29 | fn get_slice( 30 | input_dim: &Vec, 31 | starts: &mut Vec, 32 | ends: &mut Vec, 33 | axes: &mut Vec, 34 | steps: &mut Vec, 35 | ) -> (ArrayD>, Vec, Vec) { 36 | let rank = input_dim.len(); 37 | let mut result_idx = vec![vec![]; rank]; 38 | let mut real_output_shape = vec![0; rank]; 39 | let mut real_ends = ends.clone(); 40 | 41 | let input_shape_pad: Vec<_> = input_dim.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 42 | 43 | if starts.len() < rank { 44 | for i in 0..rank { 45 | if axes.contains(&i) { 46 | continue; 47 | } 48 | starts.insert(i, 0); 49 | ends.insert(i, input_shape_pad[i]); 50 | real_ends.insert(i, input_dim[i]); 51 | axes.insert(i, i); 52 | steps.insert(i, 1); 53 | } 54 | } 55 | 56 | for (i, &axis) in axes.iter().enumerate() { 57 | let step = steps[i]; 58 | let mut start = starts[i]; 59 | let end = ends[i]; 60 | let mut real_end = real_ends[i]; 61 | if end > input_shape_pad[i] { 62 | real_end = input_dim[i]; 63 | } 64 | while start < real_end { 65 | result_idx[axis].push(start); 66 | real_output_shape[axis] += 1; 67 | start += step; 68 | } 69 | } 70 | let combination_result = combinations(&result_idx); 71 | let f = combination_result.iter().map(|v| Some(IxDyn(v))).collect(); 72 | let result = ArrayD::from_shape_vec(real_output_shape.clone(), f).unwrap(); 73 | let result = util::pad_to_pow_of_two(&result, &None); 74 | (result, real_output_shape, input_shape_pad) 75 | } 76 | 77 | // https://onnx.ai/onnx/operators/onnx__Slice.html 78 | pub struct SliceLayer; 79 | impl Layer for SliceLayer { 80 | fn graph( 81 | input_shapes: &Vec<&Vec>, 82 | input_types: &Vec, 83 | constants: &Vec, DatumType)>>, 84 | _attributes: &Vec<&AttributeProto>, 85 | ) -> (Graph, Vec>, Vec) { 86 | let mut graph = 
Graph::new(); 87 | let starts: Vec<_> = constants[1].unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as i32).collect(); 88 | let ends: Vec<_> = constants[2].unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as i32).collect(); 89 | // steps and axes might be optional 90 | let axes: Vec<_> = match constants.get(3) { 91 | Some(x) => x.unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as i32).collect(), 92 | None => (0..input_shapes[0].len()).map(|x| x as i32).collect(), 93 | }; 94 | let mut steps: Vec<_> = match constants.get(4) { 95 | Some(x) => x.unwrap().0.as_slice().unwrap().iter().map(|x| util::fr_to_int(*x) as usize).collect(), 96 | None => vec![1; starts.len()], 97 | }; 98 | let mut axes: Vec<_> = axes 99 | .iter() 100 | .map(|&x| { 101 | if x < 0 { 102 | (input_shapes[0].len() as i32 + x) as usize 103 | } else { 104 | x as usize 105 | } 106 | }) 107 | .collect(); 108 | let mut starts: Vec<_> = starts 109 | .iter() 110 | .enumerate() 111 | .map(|(i, &x)| { 112 | if x < 0 { 113 | (input_shapes[0][axes[i]] as i32 + x) as usize 114 | } else { 115 | x as usize 116 | } 117 | }) 118 | .collect(); 119 | let mut ends: Vec<_> = ends 120 | .iter() 121 | .enumerate() 122 | .map(|(i, &x)| { 123 | if x < 0 { 124 | (input_shapes[0][axes[i]] as i32 + x + 1) as usize 125 | } else { 126 | x as usize 127 | } 128 | }) 129 | .collect(); 130 | 131 | let (permutation, output_shape, input_shape_pad) = get_slice(&input_shapes[0], &mut starts, &mut ends, &mut axes, &mut steps); 132 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 133 | permutation, 134 | input_dim: IxDyn(&input_shape_pad), 135 | padding_partition: copy_constraint::PaddingEnum::Zero, 136 | })); 137 | let slice_output = graph.addNode(cc, vec![(-1, 0)]); 138 | graph.outputs.push((slice_output, 0)); 139 | 140 | (graph, vec![output_shape], vec![input_types[0]]) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ZKTorch 2 | 3 | 4 | 5 | ## Overview 6 | 7 | [Zero-knowledge (ZK) proofs of ML model inference](https://medium.com/@danieldkang/bridging-the-gap-how-zk-snarks-bring-transparency-to-private-ml-models-with-zkml-e0e59708c2fc) help provide transparency to users without requiring model owners to share model weights. Past work on these provers can be placed into two categories. The first method compiles the ML model into a low-level circuit, and the second method uses custom cryptographic protocols designed only for a specific class of models. Unfortunately, the first method is highly inefficient, and the second method does not generalize well. 8 | 9 | ZKTorch is an end-to-end proving system for compiling ML model inference computation into ZK circuits from ONNX models by compiling layers into a set of specialized cryptographic operations, which we call basic blocks. It is built on top of a parallel extension to the Mira accumulation scheme, enabling succinct proofs with minimal accumulation overhead. We support all edge models in the [MLPerf Edge Inference Suite v4.1](https://github.com/mlcommons/inference_policies/blob/master/inference_rules.adoc#benchmarks-1), covering convolutional neural networks (CNNs), recurrent neural networks (RNNs), and large language models (LLMs). Overall, ZKTorch supports 61 layers with a total of 20 basic blocks. With the Mira accumulator extension, we condense proofs of the same basic block type. 
10 | 11 | ![zk_torch_readme](https://github.com/user-attachments/assets/6715728d-1818-4ee2-9732-35fafc53976c) 12 | 13 | ## Prerequisites 14 | 15 | ### Install Rust 16 | ``` 17 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 18 | ``` 19 | 20 | ### Install nightly Rust 21 | ``` 22 | rustup override set nightly 23 | ``` 24 | 25 | ## Run the example 26 | 27 | The example runs ZKTorch with Mira-style folding enabled on the ONNX file and configurations specified in `config.yaml`. The ONNX file contains a small model with two fully-connected layers and two ReLU layers. 28 | 29 | ``` 30 | cargo run --release --bin zk_torch --features fold -- config.yaml 31 | ``` 32 | 33 | ## How to run custom experiments 34 | 35 | 36 | 37 | ### Update the ptau file 38 | 39 | Most models will require a larger powers of tau (ptau) file than the provided `challenge` file, which has `pow_len_log=7`. The size of the ptau file needed (`pow_len_log` in `config.yaml`) depends on the magnitude of the quantized values to support in the inference computation (`cq_range_log` in `config.yaml`) as well as the sizes of the inputs to certain layers. In most cases, the former will be the deciding factor, with the constraint `cq_range_log` < `pow_len_log`. 40 | 41 | To produce a larger file, please refer to the following instructions to generate one with the `snarkjs` tool: 42 | https://github.com/iden3/snarkjs?tab=readme-ov-file#1-start-a-new-powers-of-tau-ceremony. For step 1, you can replace `14` with the desired value for the `pow_len_log` and then directly follow the remaining instructions through step 4, which produces the file with the `snarkjs powersoftau export challenge` command. 43 | 44 | Then, update `config.yaml` based on the produced ptau file. 45 | 46 | ``` 47 | ptau_path: <path to the ptau file> 48 | pow_len_log: <log2 of the ptau file size> 49 | loaded_pow_len_log: <log2 of the number of powers to load> 50 | cq_range_log: <log2 of the CQ lookup range> 51 | cq_range_lower_log: <log2 of the magnitude of the CQ range lower bound> 52 | ``` 53 | For example, here is a valid configuration for `challenge_0003` produced by the example instructions: 54 | 55 | ``` 56 | ptau_path: challenge_0003 57 | pow_len_log: 14 58 | loaded_pow_len_log: 14 59 | cq_range_log: 6 60 | cq_range_lower_log: 5 61 | ``` 62 | 63 | ### Replace the model and input in `config.yaml` 64 | `model_path` should contain the path to the ONNX file to compile. `input_path` can be left blank or contain a JSON file similar to the example below, replacing the value with your input tensor. 65 | 66 | `{"input_data": [[0.09, 0.13, 0.24, 0.05]]}` 67 | 68 | If it is left blank or the provided path does not exist, ZKTorch will generate a random input tensor based on the input shape specified in the ONNX file, and will throw an error if no input shape is available. 69 | 70 | Update the `config.yaml`: 71 | ``` 72 | model_path: <path to the ONNX model> 73 | input_path: <path to the input JSON file> 74 | ``` 75 | 76 | ### Use customized scale factor 77 | Update `config.yaml` based on the desired quantization scale factor. A short sketch of the resulting input quantization is given at the end of this section. 78 | ``` 79 | scale_factor_log: <log2 of the scale factor> 80 | ``` 81 | ### Run experiment 82 | If you change the ptau file after `layer_setup/`, `models`, and `setups` have been produced from a previous run, please delete them before proving. 83 | 84 | Then run 85 | ``` 86 | cargo run --release --bin zk_torch --features fold -- <path to your config.yaml> 87 | ``` 88 | To just run proving (e.g., for testing purposes), you can additionally add the `mock_prove` feature (`--features mock_prove,fold`).
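The quantization implied by `scale_factor_log` can be sketched as follows (the `quantize` helper and the concrete values here are illustrative; the rounding mirrors how ZKTorch loads float inputs):

```
// Sketch: ZKTorch-style input quantization, q = round(x * 2^scale_factor_log).
fn quantize(x: f64, scale_factor_log: u32) -> i64 {
  (x * (1u64 << scale_factor_log) as f64).round() as i64
}

fn main() {
  // With scale_factor_log = 12, the values from sample.json,
  // [0.09, 0.13, 0.24, 0.05], become [369, 532, 983, 205].
  for x in [0.09f64, 0.13, 0.24, 0.05] {
    println!("{} -> {}", x, quantize(x, 12));
  }
}
```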
89 | 90 | The outputs consist of input, model (including weights and lookup tables), output, and setup encodings, as well as the proof before accumulation, accumulation-specific proofs, and the final proof after accumulation for the prover and verifier. The output paths are specified in `config.yaml`. 91 | -------------------------------------------------------------------------------- /src/layer/reshape.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::{squeeze::UnsqueezeBasicBlock, Layer}; 4 | use crate::util::{self, get_reshape_indices, get_reshape_transpose_indices}; 5 | use ark_bn254::Fr; 6 | use ark_std::Zero; 7 | use ndarray::{ArrayD, IxDyn}; 8 | use tract_onnx::pb::AttributeProto; 9 | use tract_onnx::prelude::DatumType; 10 | 11 | pub struct ReshapeLayer; 12 | impl Layer for ReshapeLayer { 13 | fn graph( 14 | input_shapes: &Vec<&Vec>, 15 | input_types: &Vec, 16 | constants: &Vec, DatumType)>>, 17 | _attributes: &Vec<&AttributeProto>, 18 | ) -> (Graph, Vec>, Vec) { 19 | let mut graph = Graph::new(); 20 | 21 | let startShape = input_shapes[0]; 22 | let mut endShape: Vec<_> = constants[1] 23 | .unwrap() 24 | .0 25 | .as_slice() 26 | .unwrap() 27 | .iter() 28 | .enumerate() 29 | .map(|(i, x)| { 30 | if i < input_shapes[1][0] { 31 | // If a shape dimension is 0, then we replace the value with the corresponding input dimension 32 | if *x == Fr::zero() { 33 | input_shapes[0][i] as i32 34 | } else { 35 | util::fr_to_int(*x) as i32 36 | } 37 | } else { 38 | 0 39 | } 40 | }) 41 | .filter(|x| *x != 0) 42 | .collect(); 43 | if let Some(i) = endShape.iter().position(|&x| x == -1) { 44 | let a = input_shapes[0].iter().fold(1, |x, &y| x * y) as i32; 45 | let b = endShape.iter().fold(-1, |x, &y| x * y); 46 | endShape[i] = a / b; 47 | } 48 | let endShape: Vec<_> = endShape.iter().map(|&x| x as usize).filter(|x| *x != 0).collect(); 49 | let endShape_padded: Vec<_> = endShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 50 | let startShape_padded: Vec<_> = startShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 51 | // check if the product of startShape_padded is equal to the product of endShape_padded 52 | let equal = startShape_padded.iter().fold(1, |x, &y| x * y) == endShape_padded.iter().fold(1, |x, &y| x * y); 53 | 54 | if equal && (startShape.last() == endShape.last()) { 55 | let reshape = graph.addBB(Box::new(ReshapeBasicBlock { 56 | shape: endShape_padded.clone(), 57 | })); 58 | let output = graph.addNode(reshape, vec![(-1, 0)]); 59 | graph.outputs.push((output, 0)); 60 | } else if startShape.len() == 0 { 61 | // special case: arr0 --> [1,1,...] 
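// e.g., reshaping a 0-d scalar to [1, 1] applies UnsqueezeBasicBlock twice: arr0 -> [1] -> [1, 1]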
62 | let unsq = graph.addBB(Box::new(UnsqueezeBasicBlock {})); 63 | let mut unsq_output = graph.addNode(unsq, vec![(-1, 0)]); 64 | for _ in 0..endShape.len() - 1 { 65 | unsq_output = graph.addNode(unsq, vec![(unsq_output, 0)]); 66 | } 67 | graph.outputs.push((unsq_output, 0)); 68 | } else { 69 | let permutation = get_reshape_indices(startShape.clone(), endShape.clone()); 70 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 71 | permutation: permutation.clone(), 72 | input_dim: IxDyn(&startShape_padded), 73 | padding_partition: copy_constraint::PaddingEnum::Zero, 74 | })); 75 | let output = graph.addNode(cc, vec![(-1, 0)]); 76 | graph.outputs.push((output, 0)); 77 | } 78 | 79 | (graph, vec![endShape], vec![input_types[0]]) 80 | } 81 | } 82 | 83 | pub struct ReshapeTransLayer; 84 | impl Layer for ReshapeTransLayer { 85 | fn graph( 86 | input_shapes: &Vec<&Vec>, 87 | input_types: &Vec, 88 | constants: &Vec, DatumType)>>, 89 | attributes: &Vec<&AttributeProto>, 90 | ) -> (Graph, Vec>, Vec) { 91 | let mut graph = Graph::new(); 92 | 93 | let axes: Vec<_> = attributes.iter().filter(|x| x.name == "perm").next().unwrap().ints.iter().map(|x| *x as usize).collect(); 94 | 95 | let startShape = input_shapes[0]; 96 | let mut endShape: Vec<_> = constants[1] 97 | .unwrap() 98 | .0 99 | .as_slice() 100 | .unwrap() 101 | .iter() 102 | .enumerate() 103 | .map(|(i, x)| { 104 | if i < input_shapes[1][0] { 105 | // If a shape dimension is 0, then we replace the value with the corresponding input dimension 106 | if *x == Fr::zero() { 107 | input_shapes[0][i] as i32 108 | } else { 109 | util::fr_to_int(*x) as i32 110 | } 111 | } else { 112 | 0 113 | } 114 | }) 115 | .filter(|x| *x != 0) 116 | .collect(); 117 | if let Some(i) = endShape.iter().position(|&x| x == -1) { 118 | let a = input_shapes[0].iter().fold(1, |x, &y| x * y) as i32; 119 | let b = endShape.iter().fold(-1, |x, &y| x * y); 120 | endShape[i] = a / b; 121 | } 122 | let endShape: Vec<_> = endShape.iter().map(|&x| x as usize).filter(|x| *x != 0).collect(); 123 | 124 | let startShape_padded: Vec<_> = startShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect(); 125 | 126 | let permutation = get_reshape_transpose_indices(startShape.clone(), endShape.clone(), axes.clone()); 127 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 128 | permutation: permutation.clone(), 129 | input_dim: IxDyn(&startShape_padded), 130 | padding_partition: copy_constraint::PaddingEnum::Zero, 131 | })); 132 | let output = graph.addNode(cc, vec![(-1, 0)]); 133 | graph.outputs.push((output, 0)); 134 | 135 | let endShape: Vec<_> = axes.iter().map(|i| endShape[*i]).collect(); 136 | (graph, vec![endShape], vec![input_types[0]]) 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/layer/pool.rs: -------------------------------------------------------------------------------- 1 | use crate::basic_block::*; 2 | use crate::graph::*; 3 | use crate::layer::conv::{out_hw, splat_pad}; 4 | use crate::layer::Layer; 5 | use crate::onnx; 6 | use crate::util; 7 | use ark_bn254::Fr; 8 | use ndarray::{arr1, indices, ArrayD, Dim, Dimension, IxDyn}; 9 | use tract_onnx::pb::AttributeProto; 10 | use tract_onnx::prelude::DatumType; 11 | 12 | // This constructs the permutation for CopyConstraintBasicBlock to be inputted into MaxProofBasicBlock. 
The output is a (product of the pool output dims * number of input channels) x (product of kernel_dims) permutation, where each row corresponds to one max operation and contains the set of arguments to that max. 13 | // ci is the number of input channels 14 | fn splat_input(input_shape: &Vec<usize>, strides: &Vec<usize>, pads: &Vec<usize>, ci: usize, kernel_dims: &Vec<usize>) -> Vec<Vec<Option<IxDyn>>> { 15 | let dims = input_shape[2..].to_vec(); 16 | let mut padding = vec![[0, 0], [0, 0]]; 17 | for i in 0..dims.len() { 18 | padding.push([pads[i], pads[dims.len() + i]]); 19 | } 20 | 21 | let inp_shape = Dim(IxDyn(input_shape)); 22 | let inp = ArrayD::from_shape_vec(inp_shape.clone(), indices(inp_shape).into_iter().map(|x| x.into_dyn()).collect()).unwrap(); 23 | 24 | let inp_pad = util::pad(&inp, &padding, &IxDyn::zeros(input_shape.len())); 25 | 26 | let out_dims = out_hw(&dims, &strides, &kernel_dims, &padding[2..].to_vec(), false); 27 | 28 | let mut inp_cells = vec![]; 29 | let mut input_row_idx = 0; 30 | 31 | // (out_dims product * inp_channels x kernel_dims product) 32 | for batch in 0..inp.shape()[0] { 33 | for out_idx in indices(out_dims.clone()) { 34 | for ck in 0..ci { 35 | inp_cells.push(vec![]); 36 | for kernel_idx in indices(IxDyn(&kernel_dims)) { 37 | let mut idx = vec![batch, ck]; 38 | idx.append(&mut (0..dims.len()).map(|i| out_idx[i] * strides[i] + kernel_idx[i]).collect()); 39 | inp_cells[input_row_idx].push(Some(inp_pad[IxDyn(&idx)].clone())); 40 | } 41 | input_row_idx += 1; 42 | } 43 | } 44 | } 45 | inp_cells 46 | } 47 | 48 | pub struct MaxPoolLayer; 49 | impl Layer for MaxPoolLayer { 50 | fn graph( 51 | input_shapes: &Vec<&Vec<usize>>, 52 | input_types: &Vec<DatumType>, 53 | _constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>, 54 | attributes: &Vec<&AttributeProto>, 55 | ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) { 56 | let mut graph = Graph::new(); 57 | let dims = input_shapes[0][2..].to_vec(); 58 | 59 | let kernel_shape: Vec<_> = attributes.iter().filter(|x| x.name == "kernel_shape").next().unwrap().ints.iter().map(|x| *x as usize).collect(); 60 | 61 | let strides: Vec<_> = match attributes.iter().filter(|x| x.name == "strides").next() { 62 | Some(v) => v.ints.iter().map(|x| *x as usize).collect(), 63 | None => vec![1; dims.len()], 64 | }; 65 | let pads: Vec<_> = match attributes.iter().filter(|x| x.name == "pads").next() { 66 | Some(v) => v.ints.iter().map(|x| *x as usize).collect(), 67 | None => vec![0; 2 * dims.len()], 68 | }; 69 | 70 | // Splat input 71 | let ch = input_shapes[0][1]; 72 | let permutation = splat_input(&input_shapes[0], &strides, &pads, ch, &kernel_shape); 73 | let permutation_padded = splat_pad(&permutation, &None); 74 | let input_shape_padded: Vec<_> = input_shapes[0].iter().map(|i| i.next_power_of_two()).collect(); 75 | 76 | let cc = graph.addBB(Box::new(CopyConstraintBasicBlock { 77 | permutation: permutation_padded, 78 | input_dim: IxDyn(&input_shape_padded), 79 | padding_partition: copy_constraint::PaddingEnum::Max(Fr::from(*onnx::CQ_RANGE_LOWER)), 80 | })); 81 | 82 | // Prove max over each row 83 | let max = graph.addBB(Box::new(RepeaterBasicBlock { 84 | basic_block: Box::new(MaxProofBasicBlock { 85 | cq_range_lower: *onnx::CQ_RANGE_LOWER, 86 | }), 87 | N: 1, 88 | })); 89 | 90 | // Reshape into output shape 91 | let mut padding = vec![[0, 0], [0, 0]]; 92 | for i in 0..dims.len() { 93 | padding.push([pads[i], pads[dims.len() + i]]); 94 | } 95 | let mut output_shape = input_shapes[0][..2].to_vec(); 96 | output_shape.append(&mut out_hw(&dims, &strides, &kernel_shape, &padding[2..].to_vec(), false)); 97 | let mut reshape_shape: Vec<_> =
output_shape.iter().map(|x| x.next_power_of_two()).collect(); 98 | reshape_shape.push(1); 99 | let reshape1 = graph.addBB(Box::new(ReshapeBasicBlock { shape: reshape_shape })); 100 | let a = output_shape[output_shape.len() - 1].next_power_of_two(); 101 | let transpose1 = ((0..1).map(|x| x * a).collect(), (0..a).collect()); 102 | let permute = graph.addBB(Box::new(RepeaterBasicBlock { 103 | basic_block: Box::new(PermuteBasicBlock { 104 | permutation: transpose1, 105 | n: a, 106 | m: 1, 107 | }), 108 | N: 2, 109 | })); 110 | 111 | let reshape2 = graph.addBB(Box::new(ReshapeBasicBlock { 112 | shape: output_shape.iter().map(|x| x.next_power_of_two()).collect(), 113 | })); 114 | 115 | let range_check = graph.addBB(Box::new(RepeaterBasicBlock { 116 | basic_block: Box::new(CQBasicBlock { 117 | n: output_shape[output_shape.len() - 1].next_power_of_two(), 118 | setup: util::CQArrayType::NonNegative, 119 | }), 120 | N: 1, 121 | })); 122 | 123 | let cc_output = graph.addNode(cc, vec![(-1, 0)]); 124 | let max_output = graph.addNode(max, vec![(cc_output, 0)]); 125 | let reshape1_output = graph.addNode(reshape1, vec![(max_output, 0)]); 126 | let permute_output = graph.addNode(permute, vec![(reshape1_output, 0)]); 127 | let reshape2_output = graph.addNode(reshape2, vec![(permute_output, 0)]); 128 | let _ = graph.addNode(range_check, vec![(max_output, 1)]); 129 | graph.outputs.push((reshape2_output, 0)); 130 | (graph, vec![output_shape], vec![input_types[0]]) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/util/onnx.rs: -------------------------------------------------------------------------------- 1 | use crate::onnx; 2 | /* 3 | * ONNX utilities: 4 | * The function(s) are used for ONNX-related operations. 5 | * For example, generate fake inputs for ONNX models. 6 | */ 7 | use crate::util::pad_to_pow_of_two; 8 | use ark_bn254::Fr; 9 | use ark_std::Zero; 10 | use ndarray::ArrayD; 11 | use rand::{rngs::StdRng, Rng, SeedableRng}; 12 | use serde::Deserialize; 13 | use tract_onnx::pb::{tensor_proto::DataType, type_proto::Tensor}; 14 | use tract_onnx::prelude::{DatumType, Framework}; 15 | 16 | // This function is used for getting the shape of an ONNX input tensor 17 | pub fn get_shape_from_onnx_tensor(tensor: &Tensor) -> Vec { 18 | tensor 19 | .shape 20 | .as_ref() 21 | .unwrap() 22 | .dim 23 | .iter() 24 | .map(|x| { 25 | if let tract_onnx::pb::tensor_shape_proto::dimension::Value::DimValue(x) = x.value.as_ref().unwrap() { 26 | *x as usize 27 | } else { 28 | panic!("Unknown dimension") 29 | } 30 | }) 31 | .collect::>() 32 | } 33 | 34 | // This function is used for generating fake inputs for onnx models 35 | // Fake inputs are random field (i.e., Fr) elements whose shapes and types match those described in the input tensors of an ONNX model. 36 | // Generating these when loading an ONNX file saves us from creating different input tensors ourselves when testing new ONNX. 
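// (For instance, a Float input declared with shape [1, 4] yields an ArrayD<Fr> of shape [1, 4]
// whose entries encode small random integers; see generate_fake_tensor below.)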
37 | // It is only for testing purposes
38 | pub fn generate_fake_inputs_for_onnx(filename: &str) -> Vec<ArrayD<Fr>> {
39 |   let onnx = tract_onnx::onnx();
40 |   let onnx_graph = onnx.proto_model_for_path(filename).unwrap().graph.unwrap();
41 |
42 |   let mut inputs = vec![];
43 |
44 |   for onnx_input in onnx_graph.input.iter() {
45 |     let tract_onnx::pb::type_proto::Value::TensorType(t) = onnx_input.r#type.as_ref().unwrap().value.as_ref().unwrap();
46 |     let shape = get_shape_from_onnx_tensor(t);
47 |
48 |     let input = generate_fake_tensor(t.elem_type(), shape);
49 |     let input = pad_to_pow_of_two(&input, &Fr::zero());
50 |     inputs.push(input);
51 |   }
52 |   inputs
53 | }
54 |
55 | pub fn generate_fake_tensor(dtype: DataType, shape: Vec<usize>) -> ArrayD<Fr> {
56 |   eprintln!("\x1b[93mWARNING\x1b[0m: Generating fake tensor for ONNX model. This is only for testing purposes.");
57 |   let mut rng = StdRng::from_entropy();
58 |   let val_num = shape.iter().fold(1, |acc, x| acc * x);
59 |   let input = match dtype {
60 |     DataType::Float | DataType::Float16 | DataType::Double => (0..val_num).map(|_| Fr::from(rng.gen_range(-2..2))).collect(),
61 |     DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => (0..val_num).map(|_| Fr::from(1)).collect(),
62 |     DataType::Uint8 | DataType::Uint16 | DataType::Uint32 | DataType::Uint64 => (0..val_num).map(|_| Fr::from(1)).collect(),
63 |     DataType::Bool => (0..val_num).map(|_| Fr::from(rng.gen_range(0..2))).collect(),
64 |     _ => panic!("Unsupported constant type: {:?}", dtype),
65 |   };
66 |   ArrayD::from_shape_vec(shape, input).unwrap()
67 | }
68 |
69 | // Converts ints for the DataType enum into DatumType
70 | // https://docs.rs/tract-onnx/latest/tract_onnx/pb/tensor_proto/enum.DataType.html
71 | pub fn datatype_to_datumtype(t: i32) -> DatumType {
72 |   match t {
73 |     2 | 3 | 4 | 5 | 6 | 7 | 12 | 13 => DatumType::I64,
74 |     1 | 10 | 11 => DatumType::F32,
75 |     8 => DatumType::String,
76 |     9 => DatumType::Bool,
77 |     _ => panic!("DatumType {:?} not supported", t),
78 |   }
79 | }
80 |
81 | #[derive(Deserialize, Debug)]
82 | struct InputData {
83 |   input_data: Vec<Vec<f64>>,
84 | }
85 |
86 | pub fn load_inputs_from_json_for_onnx(onnx_name: &str, json_name: &str) -> Vec<ArrayD<Fr>> {
87 |   let onnx = tract_onnx::onnx();
88 |   let onnx_graph = onnx.proto_model_for_path(onnx_name).unwrap().graph.unwrap();
89 |   let mut inputs = vec![];
90 |
91 |   let json = std::fs::read_to_string(json_name).expect("Failed to read file");
92 |   let json: InputData = serde_json::from_str(&json).unwrap();
93 |
94 |   for (i, onnx_input) in onnx_graph.input.iter().enumerate() {
95 |     let tract_onnx::pb::type_proto::Value::TensorType(t) = onnx_input.r#type.as_ref().unwrap().value.as_ref().unwrap();
96 |     let shape = get_shape_from_onnx_tensor(t);
97 |
98 |     let input = match t.elem_type() {
99 |       DataType::Float | DataType::Float16 | DataType::Double => {
100 |         let input: Vec<Fr> = json.input_data[i]
101 |           .iter()
102 |           .map(|x| {
103 |             let y = (*x * onnx::SF_FLOAT.read().unwrap().to_owned() as f64).round();
104 |             Fr::from(y as i32)
105 |           })
106 |           .collect();
107 |         input
108 |       }
109 |
110 |       DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
111 |         let input: Vec<Fr> = json.input_data[i].iter().map(|x| Fr::from(*x as i32)).collect();
112 |         input
113 |       }
114 |       DataType::Uint8 | DataType::Uint16 | DataType::Uint32 | DataType::Uint64 => {
115 |         let input: Vec<Fr> = json.input_data[i].iter().map(|x| Fr::from(*x as u32)).collect();
116 |         input
117 |       }
118 |       DataType::Bool => {
119 |         let input: Vec<Fr> = json.input_data[i].iter().map(|x| Fr::from(*x as u8)).collect();
120 |         input
121 |       }
122 |       _ => panic!("Unsupported constant type: {:?}", t.elem_type()),
123 |     };
124 |     let input = ArrayD::from_shape_vec(shape, input).unwrap();
125 |     let input = pad_to_pow_of_two(&input, &Fr::zero());
126 |     inputs.push(input);
127 |   }
128 |   inputs
129 | }
130 |
131 | // Converts DatumType to the corresponding scale factor
132 | // It should only be used in the IN_SF/OUT_SF of nonlinear basic blocks
133 | pub fn datumtype_to_sf(t: DatumType) -> usize {
134 |   match t {
135 |     DatumType::I32 => 1,
136 |     DatumType::I64 => 1,
137 |     DatumType::Bool => 1,
138 |     DatumType::F32 => onnx::SF_LOG.read().unwrap().to_owned(),
139 |     _ => panic!("DatumType {:?} not supported", t),
140 |   }
141 | }
142 |
--------------------------------------------------------------------------------
/src/layer/mul.rs:
--------------------------------------------------------------------------------
1 | use crate::basic_block::*;
2 | use crate::graph::*;
3 | use crate::layer::Layer;
4 | use crate::onnx;
5 | use crate::util;
6 | use ark_bn254::Fr;
7 | use ark_std::One;
8 | use ndarray::ArrayD;
9 | use tract_onnx::pb::AttributeProto;
10 | use tract_onnx::prelude::DatumType;
11 |
12 | pub struct MulLayer;
13 | impl Layer for MulLayer {
14 |   fn graph(
15 |     input_shapes: &Vec<&Vec<usize>>,
16 |     input_types: &Vec<DatumType>,
17 |     _constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>,
18 |     _attributes: &Vec<&AttributeProto>,
19 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
20 |     let mut graph = Graph::new();
21 |     let mul_scalar = if input_shapes[0].len() == input_shapes[1].len() && input_shapes[0].len() == 0 {
22 |       graph.addBB(Box::new(MulScalarBasicBlock {}))
23 |     } else {
24 |       graph.addBB(Box::new(RepeaterBasicBlock {
25 |         basic_block: Box::new(MulScalarBasicBlock {}),
26 |         N: 1,
27 |       }))
28 |     };
29 |     let sf_log = onnx::SF_LOG.read().unwrap().to_owned();
30 |     let change_SF = graph.addBB(Box::new(ChangeSFBasicBlock {
31 |       input_SF: sf_log * 2,
32 |       output_SF: sf_log,
33 |     }));
34 |     let change_SF_check = if input_shapes[0].len() == input_shapes[1].len() && input_shapes[0].len() == 0 {
35 |       graph.addBB(Box::new(CQ2BasicBlock {
36 |         n: 1,
37 |         setup: Some((
38 |           Box::new(ChangeSFBasicBlock {
39 |             input_SF: sf_log * 2,
40 |             output_SF: sf_log,
41 |           }),
42 |           *onnx::CQ_RANGE_LOWER,
43 |           *onnx::CQ_RANGE,
44 |         )),
45 |       }))
46 |     } else {
47 |       graph.addBB(Box::new(RepeaterBasicBlock {
48 |         basic_block: Box::new(CQ2BasicBlock {
49 |           n: if input_shapes[1].len() == 0 {
50 |             input_shapes[0][input_shapes[0].len() - 1].next_power_of_two()
51 |           } else {
52 |             std::cmp::max(
53 |               input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(),
54 |               input_shapes[1][input_shapes[1].len() - 1].next_power_of_two(),
55 |             )
56 |           },
57 |           setup: Some((
58 |             Box::new(ChangeSFBasicBlock {
59 |               input_SF: sf_log * 2,
60 |               output_SF: sf_log,
61 |             }),
62 |             *onnx::CQ_RANGE_LOWER,
63 |             *onnx::CQ_RANGE,
64 |           )),
65 |         }),
66 |         N: 1,
67 |       }))
68 |     };
69 |     // If any of the inputs are scalars, use the scalar version of the mul basic block.
70 |     // If the first input is a scalar, swap the inputs, because the mul scalar basic block expects the scalar to be the second input. If the last dimension differs between the two inputs, broadcast.
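// Added illustration (a minimal sketch, not from the original source): the broadcast
// branch below first materializes the smaller operand at the larger shape by
// multiplying a ones tensor with it, and only then performs a plain elementwise mul.
// Outside the proof system, the same two steps look like this in plain ndarray
// (`broadcast_mul_sketch` is a hypothetical helper, for illustration only):
//
//   use ndarray::{ArrayD, IxDyn};
//
//   fn broadcast_mul_sketch(a: &ArrayD<f64>, b: &ArrayD<f64>) -> ArrayD<f64> {
//     // assume a: [.., n] and b: [.., 1]
//     let ones = ArrayD::<f64>::ones(IxDyn(a.shape()));
//     let b_expanded = &ones * b; // broadcast b up to a's shape
//     a * &b_expanded             // now a purely elementwise product
//   }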
71 |     let mul_output = if input_shapes[0].len() == 0 {
72 |       graph.addNode(mul_scalar, vec![(-2, 0), (-1, 0)])
73 |     } else if input_shapes[1].len() > 0 && input_shapes[0].last().unwrap() != input_shapes[1].last().unwrap() {
74 |       let (broadcast_inp, mul_inp, broadcast_idx) = if input_shapes[0].last().unwrap() > input_shapes[1].last().unwrap() {
75 |         (-2, -1, 0)
76 |       } else {
77 |         (-1, -2, 1)
78 |       };
79 |       let constantOfShape = graph.addBB(Box::new(ConstOfShapeBasicBlock {
80 |         c: Fr::one(),
81 |         shape: input_shapes[broadcast_idx].iter().map(|x| x.next_power_of_two()).collect(),
82 |       }));
83 |       let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock {
84 |         basic_block: Box::new(MulScalarBasicBlock {}),
85 |         N: 1,
86 |       }));
87 |       let constantOfShape_output = graph.addNode(constantOfShape, vec![]);
88 |       if *input_shapes[0].last().unwrap() == 1 || *input_shapes[1].last().unwrap() == 1 {
89 |         let broadcast_output = graph.addNode(mul_scalar, vec![(constantOfShape_output, 0), (broadcast_inp, 0)]);
90 |         let mul_inp_idx = (-mul_inp - 1) as usize;
91 |         let len = util::next_pow(input_shapes[mul_inp_idx][input_shapes[mul_inp_idx].len() - 1] as u32) as usize;
92 |         let mul = graph.addBB(Box::new(RepeaterBasicBlock {
93 |           basic_block: Box::new(MulBasicBlock { len }),
94 |           N: 1,
95 |         }));
96 |         graph.addNode(mul, vec![(mul_inp, 0), (broadcast_output, 0)])
97 |       } else {
98 |         let inp_shape = input_shapes[broadcast_idx];
99 |         let len = util::next_pow(inp_shape[inp_shape.len() - 1] as u32) as usize;
100 |         let mul = graph.addBB(Box::new(RepeaterBasicBlock {
101 |           basic_block: Box::new(MulBasicBlock { len }),
102 |           N: 1,
103 |         }));
104 |         let broadcast_output = graph.addNode(mul, vec![(constantOfShape_output, 0), (broadcast_inp, 0)]);
105 |         graph.addNode(mul, vec![(broadcast_output, 0), (mul_inp, 0)])
106 |       }
107 |     } else {
108 |       let mul_basicblock = if input_shapes[1].len() == 0 || input_shapes[0].len() == 0 {
109 |         mul_scalar
110 |       } else {
111 |         let len = if input_shapes[0].len() > input_shapes[1].len() {
112 |           util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize
113 |         } else {
114 |           util::next_pow(input_shapes[1][input_shapes[1].len() - 1] as u32) as usize
115 |         };
116 |         graph.addBB(Box::new(RepeaterBasicBlock {
117 |           basic_block: Box::new(MulBasicBlock { len }),
118 |           N: 1,
119 |         }))
120 |       };
121 |       graph.addNode(mul_basicblock, vec![(-1, 0), (-2, 0)])
122 |     };
123 |
124 |     if input_types[0].is_float() {
125 |       let change_SF_output = graph.addNode(change_SF, vec![(mul_output, 0)]);
126 |       let _ = graph.addNode(change_SF_check, vec![(mul_output, 0), (change_SF_output, 0)]);
127 |       graph.outputs.push((change_SF_output, 0));
128 |     } else if input_types[0].is_integer() {
129 |       graph.outputs.push((mul_output, 0));
130 |     } else {
131 |       panic!("Mul input type {:?} is not supported", input_types[0]);
132 |     }
133 |     (graph, vec![util::broadcastDims(input_shapes, 0)], vec![input_types[0]])
134 |   }
135 | }
--------------------------------------------------------------------------------
/src/layer/pow.rs:
--------------------------------------------------------------------------------
1 | use crate::basic_block::*;
2 | use crate::graph::*;
3 | use crate::layer::Layer;
4 | use crate::onnx;
5 | use crate::util;
6 | use ark_bn254::Fr;
7 | use ark_std::One;
8 | use ndarray::ArrayD;
9 | use rayon::iter::ParallelIterator;
10 | use tract_onnx::pb::AttributeProto;
11 | use tract_onnx::prelude::DatumType;
12 |
13 | #[derive(Debug)]
14 | pub struct PrecomputedPowBasicBlock {
15 |   pub input_SF: usize,
16 |   pub output_SF: usize,
17 | }
18 | impl BasicBlock for PrecomputedPowBasicBlock {
19 |   fn run(&self, _model: &ArrayD<Fr>, inputs: &Vec<&ArrayD<Fr>>) -> Result<Vec<ArrayD<Fr>>, util::CQOutOfRangeError> {
20 |     let base = util::fr_to_int(*inputs[0].first().unwrap()) as f64;
21 |     let shape = inputs[1].shape();
22 |     let out = util::array_into_iter(inputs[1])
23 |       .map(|x| {
24 |         let mut b = base;
25 |         b /= (1 << self.input_SF) as f64;
26 |         b = b.powi(util::fr_to_int(*x).try_into().unwrap());
27 |         b *= (1 << self.output_SF) as f64;
28 |         Fr::from(b.round() as i64)
29 |       })
30 |       .collect::<Vec<_>>();
31 |     Ok(vec![ArrayD::from_shape_vec(shape, out).unwrap()])
32 |   }
33 | }
34 |
35 | pub struct PowLayer;
36 | impl Layer for PowLayer {
37 |   fn graph(
38 |     input_shapes: &Vec<&Vec<usize>>,
39 |     input_types: &Vec<DatumType>,
40 |     constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>,
41 |     _attributes: &Vec<&AttributeProto>,
42 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
43 |     let mut graph = Graph::new();
44 |     assert!(constants[1].is_some());
45 |
46 |     let sf_log = onnx::SF_LOG.read().unwrap().to_owned();
47 |     let sf_float = onnx::SF_FLOAT.read().unwrap().to_owned();
48 |     // Note: the following code is a workaround for the case where constants[0] is a scalar and constants[1] is a tensor
49 |     // If we want to formally prove this, we need to
50 |     // (1) either implement a new basic block that can handle this case
51 |     // (2) or perform element-wise pow and copy the result to the output tensor
52 |     // both of which are a little bit complicated.
53 |     // Fortunately, this case only happens in the precomputable part of RoPE embedding for now.
54 |     // So, we can just use a simple basic block that can handle this case without proving.
55 |     // TODO: think about how to handle this case in a more general way later
56 |     // if both constants[0] and constants[1] are Some
57 |     if constants[0].is_some() && constants[1].is_some() {
58 |       // if constants[0].len() == 1 and constants[1].len() > 1
59 |       if constants[0].unwrap().0.len() == 1 && constants[1].unwrap().0.len() > 1 {
60 |         let c_vec = match constants[1].unwrap().1 {
61 |           DatumType::I32 | DatumType::I64 => constants[1].unwrap().0.iter().map(|x| *x).collect::<Vec<Fr>>(),
62 |           DatumType::F32 => constants[1].unwrap().0.iter().map(|x| Fr::from((util::fr_to_int(*x) as f32 / sf_float) as i32)).collect::<Vec<Fr>>(),
63 |           _ => panic!("unsupported type"),
64 |         };
65 |         let shape = constants[1].unwrap().0.shape();
66 |         let const2 = graph.addBB(Box::new(Const2BasicBlock {
67 |           c: ArrayD::from_shape_vec(shape, c_vec).unwrap(),
68 |         }));
69 |         let pow = graph.addBB(Box::new(PrecomputedPowBasicBlock {
70 |           input_SF: sf_log,
71 |           output_SF: sf_log,
72 |         }));
73 |         let const2_output = graph.addNode(const2, vec![]);
74 |         let pow_output = graph.addNode(pow, vec![(-1, 0), (const2_output, 0)]);
75 |         graph.outputs.push((pow_output, 0));
76 |         return (graph, vec![input_shapes[1].clone()], vec![input_types[0]]);
77 |       }
78 |     }
79 |
80 |     assert!(constants[1].unwrap().0.len() == 1);
81 |     let N = match constants[1].unwrap().1 {
82 |       DatumType::I32 | DatumType::I64 => util::fr_to_int(*constants[1].unwrap().0.first().unwrap()),
83 |       DatumType::F32 => (util::fr_to_int(*constants[1].unwrap().0.first().unwrap()) as f32 / sf_float) as i128,
84 |       _ => panic!("unsupported type"),
85 |     };
86 |
87 |     assert!(N >= 0);
88 |     if N == 0 {
89 |       let endShape_padded: Vec<usize> = input_shapes[0].clone().iter().map(|&x| util::next_pow(x as u32) as usize).collect();
90 |       let one = graph.addBB(Box::new(ConstOfShapeBasicBlock {
91 |         c: Fr::one(),
92 |         shape: endShape_padded.clone(),
93 |       }));
94 |       let one_output = graph.addNode(one, vec![]);
95 |       graph.outputs.push((one_output, 0));
96 |       return (graph, vec![input_shapes[0].clone()], vec![input_types[0]]);
97 |     }
98 |
99 |     let len = util::next_pow(input_shapes[0][input_shapes[0].len() - 1] as u32) as usize;
100 |     let mul = graph.addBB(Box::new(RepeaterBasicBlock {
101 |       basic_block: Box::new(MulBasicBlock { len }),
102 |       N: 1,
103 |     }));
104 |     let change_SF = graph.addBB(Box::new(ChangeSFBasicBlock {
105 |       input_SF: sf_log * 2,
106 |       output_SF: sf_log,
107 |     }));
108 |     let change_SF_check = graph.addBB(Box::new(RepeaterBasicBlock {
109 |       basic_block: Box::new(CQ2BasicBlock {
110 |         n: input_shapes[0][input_shapes[0].len() - 1].next_power_of_two(),
111 |         setup: Some((
112 |           Box::new(ChangeSFBasicBlock {
113 |             input_SF: sf_log * 2,
114 |             output_SF: sf_log,
115 |           }),
116 |           *onnx::CQ_RANGE_LOWER,
117 |           *onnx::CQ_RANGE,
118 |         )),
119 |       }),
120 |       N: 1,
121 |     }));
122 |     let mut mul_output = -1;
123 |     let mut change_SF_output = -1;
124 |     // TODO: when N > 2, it is better to use a more efficient way to calculate the power, such as the approach in nonlinear.rs
125 |     for _i in 1..N {
126 |       mul_output = graph.addNode(mul, vec![(-1, 0), (mul_output, 0)]);
127 |       change_SF_output = graph.addNode(change_SF, vec![(mul_output, 0)]);
128 |       let _ = graph.addNode(change_SF_check, vec![(mul_output, 0), (change_SF_output, 0)]);
129 |     }
130 |
131 |     graph.outputs.push((change_SF_output, 0));
132 |     (graph, vec![input_shapes[0].clone()], vec![input_types[0]])
133 |   }
134 | }
--------------------------------------------------------------------------------
/src/layer/gemm.rs:
--------------------------------------------------------------------------------
1 | use crate::basic_block::*;
2 | use crate::graph::*;
3 | use crate::layer::Layer;
4 | use crate::onnx;
5 | use crate::util;
6 | use ark_bn254::Fr;
7 | use ndarray::{arr1, ArrayD};
8 | use tract_onnx::pb::AttributeProto;
9 | use tract_onnx::prelude::DatumType;
10 |
11 | // Gemm computes Y = alpha * A' * B' + beta * C, where
12 | // the first input tensor A has shape (M, K) or (K, M),
13 | // the second input tensor B has shape (K, N) or (N, K),
14 | // (optional) the third input tensor C is broadcastable to shape (M, N),
15 | // and output tensor Y has shape (M, N).
16 | // A will be transposed to A' before doing the computation if attribute transA is non-zero, same for B and transB.
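// Added illustration (a minimal sketch, not from the original source): with a scale
// factor s = 2^sf_log, every fixed-point multiplication doubles the scale, which is
// why the graph below rescales (ChangeSF) once after the matmul and once more after
// multiplying by alpha. In plain integer arithmetic (`fixed_mul` is a hypothetical
// helper, for illustration only):
//
//   // quantize x as round(x * s); a and b are both at scale s here
//   fn fixed_mul(a: i64, b: i64, sf_log: u32) -> i64 {
//     (a * b) >> sf_log // a*b is at scale s^2; shift back down to scale s
//   }
//   // e.g. s = 2^8: 1.5 -> 384, 2.0 -> 512, and (384 * 512) >> 8 == 768 == 3.0 * s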
17 | pub struct GemmLayer;
18 | impl Layer for GemmLayer {
19 |   fn graph(
20 |     input_shapes: &Vec<&Vec<usize>>,
21 |     input_types: &Vec<DatumType>,
22 |     _constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>,
23 |     attributes: &Vec<&AttributeProto>,
24 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
25 |     let mut graph = Graph::new();
26 |
27 |     let alpha = if attributes.iter().any(|x| x.name == "alpha") {
28 |       attributes.iter().filter(|x| x.name == "alpha").next().unwrap().f
29 |     } else {
30 |       1.0
31 |     };
32 |     let beta = if attributes.iter().any(|x| x.name == "beta") {
33 |       attributes.iter().filter(|x| x.name == "beta").next().unwrap().f
34 |     } else {
35 |       1.0
36 |     };
37 |     let transA = if attributes.iter().any(|x| x.name == "transA") {
38 |       attributes.iter().filter(|x| x.name == "transA").next().unwrap().i as usize
39 |     } else {
40 |       0
41 |     };
42 |     let transB = if attributes.iter().any(|x| x.name == "transB") {
43 |       attributes.iter().filter(|x| x.name == "transB").next().unwrap().i as usize
44 |     } else {
45 |       0
46 |     };
47 |
48 |     let (M, K_a) = if transA == 0 {
49 |       (input_shapes[0][0], input_shapes[0][1])
50 |     } else {
51 |       (input_shapes[0][1], input_shapes[0][0])
52 |     };
53 |     let (K_b, N) = if transB == 0 {
54 |       (input_shapes[1][0], input_shapes[1][1])
55 |     } else {
56 |       (input_shapes[1][1], input_shapes[1][0])
57 |     };
58 |     assert!(K_a == K_b);
59 |
60 |     let add = graph.addBB(Box::new(RepeaterBasicBlock {
61 |       basic_block: Box::new(AddBasicBlock {}),
62 |       N: 1,
63 |     }));
64 |     let mul_scalar = graph.addBB(Box::new(RepeaterBasicBlock {
65 |       basic_block: Box::new(MulScalarBasicBlock {}),
66 |       N: 1,
67 |     }));
68 |     let M_pad = util::next_pow(M as u32) as usize;
69 |     let N_pad = util::next_pow(N as u32) as usize;
70 |     let K_pad = util::next_pow(K_a as u32) as usize;
71 |     let matmul = graph.addBB(Box::new(RepeaterBasicBlock {
72 |       basic_block: Box::new(MatMulBasicBlock { m: K_pad, n: N_pad }),
73 |       N: 2,
74 |     }));
75 |     let sf_log = onnx::SF_LOG.read().unwrap().to_owned();
76 |     let change_SF = graph.addBB(Box::new(ChangeSFBasicBlock {
77 |       input_SF: sf_log * 2,
78 |       output_SF: sf_log,
79 |     }));
80 |     let change_SF_check = graph.addBB(Box::new(RepeaterBasicBlock {
81 |       basic_block: Box::new(CQ2BasicBlock {
82 |         n: N.next_power_of_two(),
83 |         setup: Some((
84 |           Box::new(ChangeSFBasicBlock {
85 |             input_SF: sf_log * 2,
86 |             output_SF: sf_log,
87 |           }),
88 |           *onnx::CQ_RANGE_LOWER,
89 |           *onnx::CQ_RANGE,
90 |         )),
91 |       }),
92 |       N: 1,
93 |     }));
94 |     let sf_float = onnx::SF_FLOAT.read().unwrap().to_owned();
95 |     let alpha = graph.addBB(Box::new(Const2BasicBlock {
96 |       c: arr1(&vec![Fr::from((alpha * sf_float) as i64)]).into_dyn(),
97 |     }));
98 |     let beta = graph.addBB(Box::new(Const2BasicBlock {
99 |       c: arr1(&vec![Fr::from((beta * sf_float) as i64)]).into_dyn(),
100 |     }));
101 |
102 |     let permutation_A = ((0..M_pad).map(|x| x * K_pad).collect(), (0..K_pad).collect());
103 |     let permutation_B = ((0..N_pad).map(|x| x * K_pad).collect(), (0..K_pad).collect());
104 |     let mut A_output = -1;
105 |     let mut B_output = -2;
106 |     if transA != 0 {
107 |       let transpose_A = graph.addBB(Box::new(RepeaterBasicBlock {
108 |         basic_block: Box::new(PermuteBasicBlock {
109 |           permutation: permutation_A,
110 |           n: K_pad,
111 |           m: M_pad,
112 |         }),
113 |         N: 2,
114 |       }));
115 |       A_output = graph.addNode(transpose_A, vec![(A_output, 0)]);
116 |     }
117 |     if transB == 0 {
118 |       let transpose_B = graph.addBB(Box::new(RepeaterBasicBlock {
119 |         basic_block: Box::new(PermuteBasicBlock {
120 |           permutation: permutation_B,
121 |           n: K_pad,
122 |           m: N_pad,
123 |         }),
124 |         N: 2,
125 |       }));
126 |       B_output = graph.addNode(transpose_B, vec![(B_output, 0)]);
127 |     }
128 |
129 |     let alpha_output = graph.addNode(alpha, vec![]);
130 |     let matmul_output = graph.addNode(matmul, vec![(A_output, 0), (B_output, 0)]);
131 |     let change_SF_output = graph.addNode(change_SF, vec![(matmul_output, 0)]);
132 |     let _ = graph.addNode(change_SF_check, vec![(matmul_output, 0), (change_SF_output, 0)]);
133 |     let mul_output_AB = graph.addNode(mul_scalar, vec![(change_SF_output, 0), (alpha_output, 0)]);
134 |     let mut output = graph.addNode(change_SF, vec![(mul_output_AB, 0)]);
135 |     let _ = graph.addNode(change_SF_check, vec![(mul_output_AB, 0), (output, 0)]);
136 |     if input_shapes.len() > 2 {
137 |       // C exists
138 |       let beta_output = graph.addNode(beta, vec![]);
139 |       let mul_output_C = graph.addNode(mul_scalar, vec![(-3, 0), (beta_output, 0)]);
140 |       let change_SF_output_C = graph.addNode(change_SF, vec![(mul_output_C, 0)]);
141 |       let _ = graph.addNode(change_SF_check, vec![(mul_output_C, 0), (change_SF_output_C, 0)]);
142 |       output = graph.addNode(add, vec![(output, 0), (change_SF_output_C, 0)]);
143 |     }
144 |
145 |     graph.outputs.push((output, 0));
146 |
147 |     let output_shape = vec![M, N];
148 |     (graph, vec![output_shape], vec![input_types[0]])
149 |   }
150 | }
151 |
--------------------------------------------------------------------------------
/src/layer/squeeze.rs:
--------------------------------------------------------------------------------
1 | use crate::basic_block::*;
2 | use crate::graph::*;
3 | use crate::layer::Layer;
4 | use crate::util::{self, get_reshape_indices};
5 | use ark_bn254::Fr;
6 | use ndarray::{ArrayD, Axis, IxDyn};
7 | use tract_onnx::pb::AttributeProto;
8 | use tract_onnx::prelude::DatumType;
9 |
10 | // Squeeze the input tensor by removing all dimensions of size 1.
11 | // If axes is provided, remove the dimensions specified by axes.
12 | // Otherwise, remove all dimensions of size 1.
13 | // If the last dimension is squeezed, a plain reshape is not enough because the last dimension affects the commitment, so the layer goes through a copy constraint instead.
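// Added illustration (a minimal sketch, not from the original source): the plain
// shape arithmetic behind the two cases handled below. When the last dimension
// survives the squeeze, a cheap reshape suffices; when axes remove the last
// dimension, the layout relative to the committed last axis changes, so a copy
// constraint is emitted instead. (`squeeze_shape` is a hypothetical helper, for
// illustration only; the real code works with i64 axes and negative-axis mapping.)
//
//   fn squeeze_shape(start: &[usize], axes: &[usize]) -> Vec<usize> {
//     start.iter().enumerate().filter(|(i, _)| !axes.contains(i)).map(|(_, &d)| d).collect()
//   }
//   // squeeze_shape(&[2, 1, 3, 1], &[1]) == [2, 3, 1]  -> reshape path
//   // squeeze_shape(&[2, 1, 3, 1], &[1, 3]) == [2, 3]  -> copy-constraint path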
14 | pub struct SqueezeLayer;
15 | impl Layer for SqueezeLayer {
16 |   fn graph(
17 |     input_shapes: &Vec<&Vec<usize>>,
18 |     input_types: &Vec<DatumType>,
19 |     constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>,
20 |     attributes: &Vec<&AttributeProto>,
21 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
22 |     let mut graph = Graph::new();
23 |
24 |     let axes_result = attributes.iter().filter(|x| x.name == "axes").next();
25 |     let mut axes: Vec<_>;
26 |     if let Some(x) = axes_result {
27 |       // axes is provided
28 |       axes = x.ints.iter().map(|x| *x as i64).collect();
29 |     } else {
30 |       // axes is not provided
31 |       axes = match constants.get(1) {
32 |         Some(x) => x.unwrap().0.iter().map(|x| util::fr_to_int(*x) as i64).collect(),
33 |         _ => input_shapes[0].iter().enumerate().filter(|(_, x)| **x == 1).map(|(i, _)| i as i64).collect(),
34 |       };
35 |     }
36 |
37 |     // map negative axes to positive
38 |     axes = axes.iter().map(|&x| if x < 0 { input_shapes[0].len() as i64 + x } else { x }).collect();
39 |
40 |     let startShape = input_shapes[0];
41 |     assert!(axes.iter().all(|&x| startShape[x as usize] == 1));
42 |     let endShape: Vec<_> = startShape.iter().enumerate().filter(|(i, _)| !axes.contains(&(*i as i64))).map(|(_, x)| *x).collect();
43 |
44 |     if startShape.last() == endShape.last() {
45 |       let reshape = graph.addBB(Box::new(ReshapeBasicBlock {
46 |         shape: endShape.clone().iter().map(|&x| util::next_pow(x as u32) as usize).collect(),
47 |       }));
48 |       let output = graph.addNode(reshape, vec![(-1, 0)]);
49 |       graph.outputs.push((output, 0));
50 |     } else {
51 |       let startShape_padded: Vec<_> = startShape.iter().map(|&x| util::next_pow(x as u32) as usize).collect();
52 |       let permutation = get_reshape_indices(startShape.clone(), endShape.clone());
53 |       let cc = graph.addBB(Box::new(CopyConstraintBasicBlock {
54 |         permutation: permutation.clone(),
55 |         input_dim: IxDyn(&startShape_padded),
56 |         padding_partition: copy_constraint::PaddingEnum::Zero,
57 |       }));
58 |       let output = graph.addNode(cc, vec![(-1, 0)]);
59 |       graph.outputs.push((output, 0));
60 |     }
61 |     (graph, vec![endShape], vec![input_types[0]])
62 |   }
63 | }
64 |
65 | #[derive(Debug)]
66 | pub struct UnsqueezeBasicBlock;
67 | impl BasicBlock for UnsqueezeBasicBlock {
68 |   fn run(&self, _model: &ArrayD<Fr>, inputs: &Vec<&ArrayD<Fr>>) -> Result<Vec<ArrayD<Fr>>, util::CQOutOfRangeError> {
69 |     // unsqueeze the input tensor
70 |     let r = inputs[0].clone();
71 |     let r = r.insert_axis(Axis(0));
72 |     Ok(vec![r])
73 |   }
74 | }
75 |
76 | // Unsqueeze the input tensor by adding a dimension of size 1 at the specified axis.
77 | // If the last dimension is unsqueezed, we need to permute the tensor before reshaping because the last dimension affects the commitment.
78 | // Otherwise, when the last dimension is not unsqueezed or an arr0 is unsqueezed (special case), we can directly reshape it.
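// Added illustration (a minimal sketch, not from the original source): the axis
// normalization and output shape used below, as plain arithmetic
// (`unsqueeze_shape` is a hypothetical helper, for illustration only):
//
//   fn unsqueeze_shape(start: &[usize], axis: isize) -> Vec<usize> {
//     let rank = start.len() as isize;
//     let axis = if axis < 0 { rank + axis + 1 } else { axis } as usize;
//     let mut out = start.to_vec();
//     out.insert(axis, 1);
//     out
//   }
//   // unsqueeze_shape(&[2, 3], 0) == [1, 2, 3]   -> reshape path
//   // unsqueeze_shape(&[2, 3], -1) == [2, 3, 1]  -> permute-then-reshape path
//   // unsqueeze_shape(&[], 0) == [1]             -> special arr0 case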
79 | pub struct UnsqueezeLayer;
80 | impl Layer for UnsqueezeLayer {
81 |   fn graph(
82 |     input_shapes: &Vec<&Vec<usize>>,
83 |     input_types: &Vec<DatumType>,
84 |     constants: &Vec<Option<(&ArrayD<Fr>, DatumType)>>,
85 |     attributes: &Vec<&AttributeProto>,
86 |   ) -> (Graph, Vec<Vec<usize>>, Vec<DatumType>) {
87 |     let mut graph = Graph::new();
88 |
89 |     let axis: isize = match attributes.iter().filter(|x| x.name == "axes").next() {
90 |       Some(v) => v.ints[0] as isize,
91 |       None => util::fr_to_int(constants[1].unwrap().0[0]) as isize,
92 |     };
93 |     let axis = if axis < 0 { input_shapes[0].len() as isize + axis + 1 } else { axis };
94 |
95 |     let startShape = input_shapes[0];
96 |     let endShape: Vec<_> = (0..startShape.len() + 1)
97 |       .map(|x| {
98 |         if x == axis as usize {
99 |           1
100 |         } else {
101 |           if x > axis as usize {
102 |             startShape[x - 1]
103 |           } else {
104 |             startShape[x]
105 |           }
106 |         }
107 |       })
108 |       .collect();
109 |
110 |     if startShape.last() == endShape.last() {
111 |       let reshape = graph.addBB(Box::new(ReshapeBasicBlock {
112 |         shape: endShape.clone().iter().map(|&x| util::next_pow(x as u32) as usize).collect(),
113 |       }));
114 |       let output = graph.addNode(reshape, vec![(-1, 0)]);
115 |       graph.outputs.push((output, 0));
116 |     } else if startShape.last() > endShape.last() {
117 |       let n = endShape.len();
118 |       let mut a = endShape[n - 2];
119 |       assert!(*startShape.last().unwrap() == a);
120 |       let mut intermediateShape = endShape[..n - 2].to_vec();
121 |       intermediateShape.push(1);
122 |       intermediateShape.push(*startShape.last().unwrap());
123 |       intermediateShape.iter_mut().for_each(|x| *x = util::next_pow(*x as u32) as usize);
124 |       let reshape = graph.addBB(Box::new(ReshapeBasicBlock { shape: intermediateShape }));
125 |       a = util::next_pow(a as u32) as usize;
126 |       let permutation = ((0..a).map(|x| x).collect(), vec![0]);
127 |       let permute = graph.addBB(Box::new(RepeaterBasicBlock {
128 |         basic_block: Box::new(PermuteBasicBlock {
129 |           permutation: permutation,
130 |           n: 1,
131 |           m: a,
132 |         }),
133 |         N: 2,
134 |       }));
135 |       let intermediate = graph.addNode(reshape, vec![(-1, 0)]);
136 |       let output = graph.addNode(permute, vec![(intermediate, 0)]);
137 |       graph.outputs.push((output, 0));
138 |     } else {
139 |       // special case (startShape.last() < endShape.last()): [] --> [1]
140 |       let unsqueeze = graph.addBB(Box::new(UnsqueezeBasicBlock {}));
141 |       let unsqueeze_output = graph.addNode(unsqueeze, vec![(-1, 0)]);
142 |       graph.outputs.push((unsqueeze_output, 0));
143 |     }
144 |
145 |     (graph, vec![endShape], vec![input_types[0]])
146 |   }
147 | }
148 |
--------------------------------------------------------------------------------
/scratch/gptj/replace_multihead.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | from onnx import helper, shape_inference
3 |
4 | def replace_for_multihead(model):
5 |     graph = model.graph
6 |     nodes = graph.node
7 |
8 |     # Build mappings from tensor names to nodes
9 |     tensor_producers = {}
10 |     tensor_consumers = {}
11 |     for node in nodes:
12 |         for output_name in node.output:
13 |             tensor_producers[output_name] = node
14 |         for input_name in node.input:
15 |             tensor_consumers.setdefault(input_name, []).append(node)
16 |
17 |     final_nodes_to_remove = []
18 |     nodes_to_add = []
19 |     shape_dict = {}
20 |
21 |     for idx, node in enumerate(list(nodes)):
22 |         shapes = []
23 |         nodes_to_remove = []
24 |         if node.op_type != 'Reshape':
25 |             nodes_to_remove = []
26 |             continue
27 |         reshape_node_0 = node
28 |         nodes_to_remove.append(reshape_node_0)
29 |
30 |         reshape_input = reshape_node_0.input
31 |         original_matmul_inputs = None
32 |         original_matmul_outputs = []
33 |         multihead_output = reshape_node_0.output[0]
34 |         multihead_input = reshape_input[1]
35 |
36 |         # Check if the input 1 comes from a Concat node
37 |         input_name_0 = reshape_input[0]
38 |         original_matmul_outputs.append(input_name_0)
39 |         input_name_1 = reshape_input[1]
40 |         input_node_0 = tensor_producers.get(input_name_0)
41 |         input_node_1 = tensor_producers.get(input_name_1)
42 |
43 |         shape_input = None
44 |         if input_node_1.op_type != 'Concat':
45 |             nodes_to_remove = []
46 |             continue
47 |         else:
48 |             concat_node = input_node_1
49 |             concat_input_0 = concat_node.input[0]
50 |             concat_input_1 = concat_node.input[1]
51 |             input_node = tensor_producers.get(concat_input_0)
52 |             if input_node == None or input_node.op_type != 'Unsqueeze':
53 |                 nodes_to_remove = []
54 |                 continue
55 |             else:
56 |                 unsqueeze_input = input_node.input[0]
57 |                 input_node = tensor_producers.get(unsqueeze_input)
58 |                 if input_node.op_type != 'Gather':
59 |                     nodes_to_remove = []
60 |                     continue
61 |                 else:
62 |                     gather_input = input_node.input[0]
63 |                     input_node = tensor_producers.get(gather_input)
64 |                     if input_node.op_type != 'Shape':
65 |                         nodes_to_remove = []
66 |                         continue
67 |                     else:
68 |                         shape_input = input_node.input[0]
69 |                         original_matmul_outputs.append(shape_input)
70 |                         shapes.append(input_node.name)
71 |
72 |             input_node = tensor_producers.get(concat_input_1)
73 |             if input_node == None or input_node.op_type != 'Unsqueeze':
74 |                 nodes_to_remove = []
75 |                 continue
76 |             else:
77 |                 unsqueeze_input = input_node.input[0]
78 |                 input_node = tensor_producers.get(unsqueeze_input)
79 |                 if input_node.op_type != 'Gather':
80 |                     nodes_to_remove = []
81 |                     continue
82 |                 else:
83 |                     gather_input = input_node.input[0]
84 |                     input_node = tensor_producers.get(gather_input)
85 |                     if input_node.op_type != 'Shape':
86 |                         nodes_to_remove = []
87 |                         continue
88 |                     else:
89 |                         shape_input = input_node.input[0]
90 |                         original_matmul_outputs.append(shape_input)
91 |                         shapes.append(input_node.name)
92 |
93 |         if original_matmul_outputs[0] != original_matmul_outputs[1] or original_matmul_outputs[0] != original_matmul_outputs[2]:
94 |             nodes_to_remove = []
95 |             continue
96 |
97 |
98 |         if input_node_0.op_type != 'MatMul':
99 |             nodes_to_remove = []
100 |             continue
101 |         else:
102 |             nodes_to_remove.append(input_node_0)
103 |             original_matmul_inputs = input_node_0.input
104 |             multihead_input = [original_matmul_inputs[0], original_matmul_inputs[1], multihead_input]
105 |             for shape in shapes:
106 |                 shape_dict[shape] = multihead_input[0]
107 |
108 |         custom_node = helper.make_node(
109 |             'MultiHeadMatMul',
110 |             inputs=multihead_input,
111 |             outputs=[multihead_output],
112 |             name='MultiHeadMatMul_' + reshape_node_0.name.split('_')[-1]
113 |         )
114 |
115 |         final_nodes_to_remove.extend(nodes_to_remove)
116 |         nodes_to_add.append((idx, custom_node))
117 |
118 |         update_node_list = nodes_to_remove
119 |
120 |         # Update the mappings
121 |         # Remove old producer mappings
122 |         for node in update_node_list:
123 |             tensor_producers.pop(node.output[0], None)
124 |             for input_name in node.input:
125 |                 consumers = tensor_consumers.get(input_name, [])
126 |                 if node in consumers:
127 |                     consumers.remove(node)
128 |
129 |         # Add new producer mapping
130 |         tensor_producers[multihead_output] = custom_node
131 |         for input_name in custom_node.input:
132 |             tensor_consumers.setdefault(input_name, []).append(custom_node)
133 |
134 |     # Insert new nodes into the graph
135 |     c = 0
136 |     for idx, custom_node in nodes_to_add:
137 |         graph.node.insert(idx + c, custom_node)
138 |         c += 1
139 |
140 |     # Remove old nodes from the graph
141 |     for node in final_nodes_to_remove:
142 |         if node in graph.node:
143 |             graph.node.remove(node)
144 |
145 |     # update the shape_dict
146 |     for node in graph.node:
147 |         if node.op_type == 'Shape' and node.name in shape_dict:
148 |             node.input[0] = shape_dict[node.name]
149 |
150 |     # (Optional) Infer shapes to ensure consistency
151 |     #model = shape_inference.infer_shapes(model)
152 |     return model
153 |
154 | # Load the original model
155 | model = onnx.load('model_gelu_rope.onnx', load_external_data=False)
156 |
157 | # Apply the pattern replacement
158 | model = replace_for_multihead(model)
159 |
160 | # Save the modified model
161 | onnx.save(model, 'model_gelu_rope_multi.onnx')
162 |
--------------------------------------------------------------------------------
/scratch/gptj/replace_gelu.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | from onnx import helper, shape_inference
3 |
4 | def replace_gelu(model):
5 |     graph = model.graph
6 |     nodes = graph.node
7 |
8 |     # Build mappings from tensor names to nodes
9 |     tensor_producers = {}
10 |     tensor_consumers = {}
11 |     for node in nodes:
12 |         for output_name in node.output:
13 |             tensor_producers[output_name] = node
14 |         for input_name in node.input:
15 |             tensor_consumers.setdefault(input_name, []).append(node)
16 |
17 |     final_nodes_to_remove = []
18 |     nodes_to_add = []
19 |
20 |     for idx, mul_node in enumerate(list(nodes)):
21 |         mul_node_0 = None
22 |         mul_node_1 = None
23 |         add_node_0 = None
24 |         tanh_node_0 = None
25 |         mul_node_2 = None
26 |         add_node_1 = None
27 |         mul_node_3 = None
28 |         pow_node_0 = None
29 |
30 |         nodes_to_remove = []
31 |         if mul_node.op_type != 'Mul':
32 |             nodes_to_remove = []
33 |             continue
34 |         mul_node_0 = mul_node
35 |         nodes_to_remove.append(mul_node)
36 |
37 |         mul_input = mul_node.input
38 |         gelu_output = mul_node.output[0]
39 |         gelu_inputs = []
40 |
41 |         # Check if the input comes from a Mul node
42 |         input_name_0 = mul_input[0]
43 |         input_node = tensor_producers.get(input_name_0)
44 |
45 |         if input_node.op_type != 'Mul':
46 |             nodes_to_remove = []
47 |             continue
48 |         else:
49 |             mul_node = input_node
50 |             mul_node_1 = mul_node
51 |             nodes_to_remove.append(mul_node)
52 |             gelu_inputs.append(mul_node.input[0])
53 |
54 |         input_name_1 = mul_input[1]
55 |         input_node = tensor_producers.get(input_name_1)
56 |
57 |         if input_node.op_type != 'Add':
58 |             nodes_to_remove = []
59 |             continue
60 |         else:
61 |             add_node = input_node
62 |             add_node_0 = add_node
63 |             nodes_to_remove.append(add_node)
64 |             add_input = add_node.input[0]
65 |             input_node = tensor_producers.get(add_input)
66 |             if input_node.op_type != 'Tanh':
67 |                 nodes_to_remove = []
68 |                 continue
69 |             else:
70 |                 tanh_node = input_node
71 |                 tanh_node_0 = tanh_node
72 |                 nodes_to_remove.append(tanh_node)
73 |                 tanh_input = tanh_node.input[0]
74 |                 input_node = tensor_producers.get(tanh_input)
75 |                 if input_node.op_type != 'Mul':
76 |                     nodes_to_remove = []
77 |                     continue
78 |                 else:
79 |                     mul_node = input_node
80 |                     mul_node_2 = mul_node
81 |                     nodes_to_remove.append(mul_node)
82 |                     mul_input = mul_node.input[0]
83 |                     input_node = tensor_producers.get(mul_input)
84 |                     if input_node.op_type != 'Add':
85 |                         nodes_to_remove = []
86 |                         continue
87 |                     else:
88 |                         add_node = input_node
89 |                         add_node_1 = add_node
90 |                         nodes_to_remove.append(add_node)
91 |                         add_input = add_node.input[1]
92 |                         gelu_inputs.append(add_node.input[0])
93 |                         input_node = tensor_producers.get(add_input)
94 |                         if input_node.op_type != 'Mul':
95 |                             nodes_to_remove = []
96 |                             continue
97 |                         else:
98 |                             mul_node = input_node
99 |                             mul_node_3 = mul_node
100 |                             nodes_to_remove.append(mul_node)
101 |                             mul_input = mul_node.input[0]
102 |                             input_node = tensor_producers.get(mul_input)
103 |                             if input_node.op_type != 'Pow':
104 |                                 nodes_to_remove = []
105 |                                 continue
106 |                             else:
107 |                                 pow_node = input_node
108 |                                 pow_node_0 = pow_node
109 |                                 nodes_to_remove.append(pow_node)
110 |                                 gelu_inputs.append(pow_node.input[0])
111 |
112 |         # all of them should be the same
113 |         if len(gelu_inputs) != 3 or gelu_inputs[0] != gelu_inputs[1] or gelu_inputs[0] != gelu_inputs[2]:
114 |             nodes_to_remove = []
115 |             continue
116 |
117 |         custom_gelu_input = gelu_inputs[0]
118 |         custom_gelu_output = gelu_output
119 |
120 |         custom_node = helper.make_node(
121 |             'Gelu',
122 |             inputs=[custom_gelu_input],
123 |             outputs=[custom_gelu_output],
124 |             name='Gelu_' + mul_node.name
125 |         )
126 |
127 |         final_nodes_to_remove.extend(nodes_to_remove)
128 |         nodes_to_add.append((idx, custom_node))
129 |
130 |         update_node_list = [mul_node_0, mul_node_1, add_node_0, tanh_node_0, mul_node_2, add_node_1, mul_node_3, pow_node_0]
131 |
132 |         # Update the mappings
133 |         # Remove old producer mappings
134 |         for node in update_node_list:
135 |             tensor_producers.pop(node.output[0], None)
136 |             for input_name in node.input:
137 |                 consumers = tensor_consumers.get(input_name, [])
138 |                 if node in consumers:
139 |                     consumers.remove(node)
140 |
141 |         # Add new producer mapping
142 |         tensor_producers[custom_gelu_output] = custom_node
143 |         for input_name in custom_node.input:
144 |             tensor_consumers.setdefault(input_name, []).append(custom_node)
145 |
146 |     # Insert new nodes into the graph
147 |     c = 0
148 |     for idx, custom_node in nodes_to_add:
149 |         graph.node.insert(idx + c, custom_node)
150 |         c += 1
151 |
152 |     # Remove old nodes from the graph
153 |     for node in final_nodes_to_remove:
154 |         if node in graph.node:
155 |             graph.node.remove(node)
156 |
157 |     # (Optional) Infer shapes to ensure consistency
158 |     #model = shape_inference.infer_shapes(model)
159 |     return model
160 |
161 | # Load the original model
162 | model = onnx.load('GPTj.onnx', load_external_data=False)
163 |
164 | # Apply the pattern replacement
165 | model = replace_gelu(model)
166 |
167 | # Save the modified model
168 | onnx.save(model, 'GPTj_gelu.onnx')
169 |
--------------------------------------------------------------------------------