├── .env.example ├── .github └── workflows │ ├── build.yml │ └── coverage.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── docs └── operations.md ├── input └── circuit.circom ├── src ├── a_gate_type.rs ├── circom.rs ├── circom │ ├── parser.rs │ └── type_analysis.rs ├── cli.rs ├── compiler.rs ├── lib.rs ├── main.rs ├── process.rs ├── program.rs ├── runtime.rs └── topological_sort.rs └── tests ├── circuits ├── integration │ ├── addZero.circom │ ├── arrayAssignment.circom │ ├── constantSum.circom │ ├── directOutput.circom │ ├── indexOutOfBounds.circom │ ├── infixOps.circom │ ├── mainTemplateArgument.circom │ ├── matElemMul.circom │ ├── prefixOps.circom │ ├── sum.circom │ ├── underConstrained.circom │ └── xEqX.circom └── machine-learning │ ├── ArgMax.circom │ ├── AveragePooling2D.circom │ ├── BatchNormalization2D.circom │ ├── Conv1D.circom │ ├── Conv2D.circom │ ├── Dense.circom │ ├── DepthwiseConv2D.circom │ ├── Flatten2D.circom │ ├── GlobalAveragePooling2D.circom │ ├── GlobalMaxPooling2D.circom │ ├── GlobalSumPooling2D.circom │ ├── MaxPooling2D.circom │ ├── NaiveSearch.circom │ ├── PointwiseConv2D.circom │ ├── ReLU.circom │ ├── SeparableConv2D.circom │ ├── SumPooling2D.circom │ ├── Zanh.circom │ ├── ZeLU.circom │ ├── Zigmoid.circom │ ├── circomlib-matrix │ ├── matElemMul.circom │ ├── matElemSum.circom │ └── matMul.circom │ ├── circomlib │ ├── aliascheck.circom │ ├── babyjub.circom │ ├── binsum.circom │ ├── bitify.circom │ ├── comparators.circom │ ├── compconstant.circom │ ├── escalarmulany.circom │ ├── escalarmulfix.circom │ ├── mimc.circom │ ├── montgomery.circom │ ├── mux3.circom │ ├── sign.circom │ └── switcher.circom │ ├── crypto │ ├── ecdh.circom │ ├── encrypt.circom │ └── publickey_derivation.circom │ ├── fc.circom │ ├── util.circom │ └── utils-comp.circom └── integration.rs /.env.example: -------------------------------------------------------------------------------- 1 | LOG_LEVEL="debug" 
-------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | env: 12 | CARGO_TERM_COLOR: always 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - uses: actions/cache@v3 22 | with: 23 | path: | 24 | ~/.cargo 25 | ~/.rustup/toolchains 26 | target 27 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 28 | restore-keys: | 29 | ${{ runner.os }}-cargo- 30 | 31 | - name: Install Rust 32 | run: rustup toolchain install stable 33 | 34 | - name: Build 35 | run: cargo build --verbose 36 | 37 | - name: Clippy 38 | run: cargo clippy --verbose -- -D warnings 39 | 40 | - name: Tests 41 | run: cargo test --verbose 42 | 43 | - name: Fmt 44 | run: cargo fmt -- --check 45 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | env: 12 | CARGO_TERM_COLOR: always 13 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 14 | 15 | jobs: 16 | coverage: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | 22 | - uses: actions/cache@v3 23 | with: 24 | path: | 25 | ~/.cargo 26 | ~/.rustup/toolchains 27 | target 28 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 29 | restore-keys: | 30 | ${{ runner.os }}-cargo- 31 | 32 | - name: Install Rust 33 | run: rustup toolchain install stable 34 | 35 | - name: Install cargo-llvm-cov 36 | uses: taiki-e/install-action@cargo-llvm-cov 37 | 38 | - name: Generate code coverage 39 | run: cargo llvm-cov --all-features --workspace --lcov 
--output-path lcov.info 40 | 41 | - name: Upload coverage to Codecov 42 | uses: codecov/codecov-action@v4 43 | with: 44 | files: lcov.info 45 | fail_ci_if_error: true 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | .DS_Store 16 | 17 | # Local 18 | .env 19 | .vscode/ 20 | 21 | # Output 22 | output/ 23 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "circom-2-arithc" 3 | version = "0.1.0" 4 | edition = "2021" 5 | resolver = "1" # Fixes lalrpop issue, see: https://github.com/lalrpop/lalrpop/issues/616 6 | 7 | [dependencies] 8 | clap = { version = "4.5.4", features = ["derive"] } 9 | dotenv = "0.15.0" 10 | env_logger = "0.11.1" 11 | log = "0.4.20" 12 | rand = "0.8.5" 13 | regex = "1.10.3" 14 | serde_json = "1.0" 15 | serde = { version = "1.0.196", features = ["derive"] } 16 | thiserror = "1.0.59" 17 | strum_macros = "0.26.4" 18 | strum = "0.26.2" 19 | sim-circuit = { git = "https://github.com/brech1/sim-circuit" } 20 | bristol-circuit = { git = "https://github.com/voltrevo/bristol-circuit", rev = "2a8b001" } 21 | boolify = { git = "https://github.com/voltrevo/boolify", rev = "6376405" } 22 | 23 | # DSL 24 | circom-circom_algebra = { git = "https://github.com/iden3/circom", package = "circom_algebra", rev = "e8e125e" } 25 
| circom-code_producers = { git = "https://github.com/iden3/circom", package = "code_producers", rev = "e8e125e" } 26 | circom-compiler = { git = "https://github.com/iden3/circom", package = "compiler", rev = "e8e125e" } 27 | circom-constant_tracking = { git = "https://github.com/iden3/circom", package = "constant_tracking", rev = "e8e125e" } 28 | circom-constraint_generation = { git = "https://github.com/iden3/circom", package = "constraint_generation", rev = "e8e125e" } 29 | circom-constraint_list = { git = "https://github.com/iden3/circom", package = "constraint_list", rev = "e8e125e" } 30 | circom-constraint_writers = { git = "https://github.com/iden3/circom", package = "constraint_writers", rev = "e8e125e" } 31 | circom-dag = { git = "https://github.com/iden3/circom", package = "dag", rev = "e8e125e" } 32 | circom-parser = { git = "https://github.com/iden3/circom", package = "parser", rev = "e8e125e" } 33 | circom-program_structure = { git = "https://github.com/iden3/circom", package = "program_structure", rev = "e8e125e" } 34 | circom-type_analysis = { git = "https://github.com/iden3/circom", package = "type_analysis", rev = "e8e125e" } 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nam Ngo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Circom To Arithmetic Circuit 2 | 3 | [![MIT licensed][mit-badge]][mit-url] 4 | [![Build Status][actions-badge]][actions-url] 5 | [![codecov](https://codecov.io/github/namnc/circom-2-arithc/graph/badge.svg)](https://app.codecov.io/github/namnc/circom-2-arithc/) 6 | 7 | [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg 8 | [mit-url]: https://github.com/namnc/circom-2-arithc/blob/master/LICENSE 9 | [actions-badge]: https://github.com/namnc/circom-2-arithc/actions/workflows/build.yml/badge.svg 10 | [actions-url]: https://github.com/namnc/circom-2-arithc/actions?query=branch%3Amain 11 | 12 | This library enables the creation of arithmetic circuits from circom programs. 
13 | 14 | ## Supported Circom Features 15 | 16 | | Category | Type | Supported | 17 | | --------------- | ------------------------ | :-------: | 18 | | **Statements** | `InitializationBlock` | ✅ | 19 | | | `Block` | ✅ | 20 | | | `Substitution` | ✅ | 21 | | | `Declaration` | ✅ | 22 | | | `IfThenElse` | ✅ | 23 | | | `While` | ✅ | 24 | | | `Return` | ✅ | 25 | | | `MultSubstitution` | ❌ | 26 | | | `UnderscoreSubstitution` | ❌ | 27 | | | `ConstraintEquality` | ❌ | 28 | | | `LogCall` | ❌ | 29 | | | `Assert` | ✅ | 30 | | **Expressions** | `Call` | ✅ | 31 | | | `InfixOp` | ✅ | 32 | | | `Number` | ✅ | 33 | | | `Variable` | ✅ | 34 | | | `PrefixOp` | ✅ | 35 | | | `InlineSwitchOp` | ❌ | 36 | | | `ParallelOp` | ❌ | 37 | | | `AnonymousComp` | ✅ | 38 | | | `ArrayInLine` | ❌ | 39 | | | `Tuple` | ✅ | 40 | | | `UniformArray` | ❌ | 41 | 42 | ## Circomlib 43 | 44 | WIP 45 | 46 | ## Requirements 47 | 48 | - Rust: To install, follow the instructions found [here](https://www.rust-lang.org/tools/install). 49 | 50 | ## Getting Started 51 | 52 | - Write your circom program in the `input` directory under the `circuit.circom` name. 53 | 54 | - Build the program 55 | 56 | ```bash 57 | cargo build --release 58 | ``` 59 | 60 | - Run the compilation 61 | 62 | ```bash 63 | cargo run --release 64 | ``` 65 | 66 | The compiled circuit and circuit report can be found in the `./output` directory. 67 | 68 | ### Boolean Circuits 69 | 70 | Although this library is named after arithmetic circuits, the CLI integrates [boolify](https://github.com/voltrevo/boolify) allowing further compilation down to boolean circuits. 71 | 72 | To achieve this, add `--boolify-width DESIRED_INT_WIDTH` to your command: 73 | 74 | ```bash 75 | cargo run --release -- --boolify-width 16 76 | ``` 77 | 78 | ## ZK/MPC/FHE backends: 79 | 80 | - [circom-mp-spdz](https://github.com/namnc/circom-mp-spdz) 81 | 82 | ## Contributing 83 | 84 | Contributions are welcome! 
85 | 86 | ## License 87 | 88 | This project is licensed under the MIT License - see the LICENSE file for details. 89 | -------------------------------------------------------------------------------- /docs/operations.md: -------------------------------------------------------------------------------- 1 | # Operations 2 | 3 | This document outlines the high-level operations and how we process statements, expressions and gates creation. 4 | 5 | ## Main Component 6 | 7 | - It's a template call. 8 | - Includes initialized variables, likely in an `InitializationBlock`. 9 | - The body is traversed using `traverse_sequence_of_statements`. 10 | 11 | ## Statements 12 | 13 | Each statement in the component's body is handled by `traverse_statement`. These include: 14 | 15 | - **DataItem Declaration**: Declaring variables or signals, either as single scalar or array. Here we can just add a `DataItem` based on the type and dimension 16 | - **If-Then-Else**: Evaluates conditions (variables or function calls) with `execute_expression`, then executes the chosen path using `traverse_sequence_of_statements`. 17 | - **Loops (While/For)**: Similar to `if-then-else`, but repeats based on a condition (breaks if it's `false`). Uses `traverse_sequence_of_statements` for the loop body. 18 | - **ConstraintEquality**: Probably only for ZK, not MPC. 19 | - **Return Statement**: In a function body, this statement assigns the result directly to the variable on the left-hand side of the call. For instance, in `a = func()`, the return value of `func()` is assigned to `a`. This also applies when `func()` is part of a larger expression, like in `a = a + func()`, where the return value is used as part of the expression calculation. 20 | - **Assert**: Probably only for ZK, not MPC. 21 | - **Substitution**: Like `a <== b + c`, the right-hand side is an expression processed by `traverse_expression`. This is the primary instance where a substitution statement is executed. 
If the left-hand side is a variable, we use `execute_expression` for execution instead of `traverse_expression`. 22 | - **Block**: Traversed with `traverse_sequence_of_statements`. 23 | - **LogCall**: For debugging, not used. 24 | - **UnderscoreSubstitution**: Probably an anonymous substitution, to be handled later (if `circomlib-ml` uses it). 25 | 26 | ## Expressions 27 | 28 | Expressions are parts of statements, such as the right-hand side of a substitution or the condition of a control-flow construct (if/loop). 29 | 30 | - **Number**: A constant. We return its value and can add it as a named variable in the context for ease in mixed signal-variable expressions, like naming "1" for the value 1. 31 | - **Infix-Op**: If the right-hand side is a variable, we use `execute_infix_op`. If not, we use `traverse_infix_op`. It gets complex when variables are mixed with signals, like if the right-hand operand is a signal. In such cases, we use `traverse_infix_op` and might need to create an intermediate signal. For example, in `sum[1] = sum[0] + input_A[0]*input_B[0]`, we generate an intermediate signal for `input_A[0]*input_B[0]`. If one of the operands is a signal, we create an auto-named signal. If both are variables, we simply call `execute_infix_op` and return the value. This value can be treated as a constant variable in expressions involving signals, like "1" with value 1. `execute_infix_op` performs a straightforward evaluation, like in `a = b + c` where `b` is `1` and `c` is `2`, resulting in `a` being `3`. 32 | - **Prefix-Op**: Handled like **Infix-Op**. 33 | - **Inline-Switch**: Acts like a quick If-Then-Else. 34 | - **ParallelOp**: Not addressed. 35 | - **Variable**: Returns a signal id for a signal, or the value for a variable. 36 | - **Call**: We start by identifying if the call is a template or a function. 37 | - For both templates and functions, we map the defined arguments to their initialized values at the call time.
38 | - Next, we process the body of the template or function using `traverse_sequence_of_statements`. 39 | - In the case of a _function_, we assume it processes only variables, as previously covered (i.e., we assume that a function body doesn't include a template call). 40 | - For a _template_ call, it's similar to processing the main call. However, there's an additional step of delayed mapping for input and output signals. After traversing the template, we map the template's signals to the caller's signals. For example, in `component c = Template()` where `Template` has input signal `I` and output signal `O`, these are mapped in the caller's code as `c.I = I` and `c.O = O`. 41 | - This mapping may happen in `traverse_sequence_of_statements`, during `execute_delayed_declarations`, if it's a complete template. 42 | - `if is_complete_template { execute_delayed_declarations(program_archive, runtime, actual_node, flags); }`. 43 | - **AnonymousComponent**: Not addressed. 44 | - **ArrayInLine**: Not addressed. 45 | - **Tuple**: Not addressed. 46 | - **UniformArray**: Not addressed. 47 | 48 | ## Creating Gates and Operations 49 | 50 | - Gates are only created when processing `traverse_infix_op`. 51 | - Based on the operation, a specific gate like a fan-in-2 gate may be added to the circuit. 52 | - For example, when we encounter `traverse_infix_op` and it results in `id_1 = id_2 + id_3`, we create an add gate with `id_2` and `id_3` as inputs and `id_1` as the output. 53 | 54 | ### Special Gates 55 | 56 | For the Comparison, Negative/Positive, and Zero Check gates, Circom implements these using an advisory approach. Our strategy should involve identifying these specific gates during the processing of template calls within `traverse_expression`. We do this by matching the template name and then substituting them with dedicated gates for comparison, sign checking, and zero equality.
57 | -------------------------------------------------------------------------------- /input/circuit.circom: -------------------------------------------------------------------------------- 1 | // from 0xZKML/zk-mnist 2 | 3 | pragma circom 2.0.0; 4 | // Switcher: (outL, outR) = (L, R) when sel = 0 and (R, L) when sel = 1; assumes sel is 0 or 1 (this template does not constrain it). 5 | template Switcher() { 6 | signal input sel; 7 | signal input L; 8 | signal input R; 9 | signal output outL; 10 | signal output outR; 11 | 12 | signal aux; 13 | 14 | aux <== (R-L)*sel; // We create aux in order to have only one multiplication 15 | outL <== aux + L; 16 | outR <== -aux + R; 17 | } 18 | // ArgMax: out = index of the largest element of in[n]; on ties the earliest index wins because the comparison below is strict. 19 | template ArgMax (n) { 20 | signal input in[n]; 21 | signal output out; 22 | 23 | // assert (out < n); 24 | signal gts[n]; // store comparators 25 | component switchers[n+1]; // switcher for comparing maxs 26 | component aswitchers[n+1]; // switcher for arg max 27 | 28 | signal maxs[n+1]; 29 | signal amaxs[n+1]; 30 | 31 | maxs[0] <== in[0]; 32 | amaxs[0] <== 0; 33 | for(var i = 0; i < n; i++) { 34 | gts[i] <== in[i] > maxs[i]; // 1 if in[i] is strictly greater than the running max maxs[i], else 0 35 | switchers[i+1] = Switcher(); 36 | aswitchers[i+1] = Switcher(); 37 | 38 | switchers[i+1].sel <== gts[i]; 39 | switchers[i+1].L <== maxs[i]; 40 | switchers[i+1].R <== in[i]; 41 | 42 | aswitchers[i+1].sel <== gts[i]; 43 | aswitchers[i+1].L <== amaxs[i]; 44 | aswitchers[i+1].R <== i; 45 | amaxs[i+1] <== aswitchers[i+1].outL; 46 | maxs[i+1] <== switchers[i+1].outL; 47 | } 48 | 49 | out <== amaxs[n]; 50 | } 51 | 52 | component main = ArgMax(2); 53 | // NOTE(review): this sample INPUT has 5 elements but main is instantiated as ArgMax(2); confirm which is intended. 54 | /* INPUT = { 55 | "in": ["2","3","1","5","4"], 56 | "out": "3" 57 | } */ -------------------------------------------------------------------------------- /src/a_gate_type.rs: -------------------------------------------------------------------------------- 1 | use circom_program_structure::ast::ExpressionInfixOpcode; 2 | use serde::{Deserialize, Serialize}; 3 | use strum_macros::{Display as StrumDisplay, EnumString}; 4 | 5 | /// The supported arithmetic gate types; `From<&ExpressionInfixOpcode>` maps each circom infix opcode onto one of them.
6 | #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)] 7 | pub enum AGateType { 8 | AAdd, 9 | ADiv, 10 | AEq, 11 | AGEq, 12 | AGt, 13 | ALEq, 14 | ALt, 15 | AMul, 16 | ANeq, 17 | ASub, 18 | AXor, 19 | APow, 20 | AIntDiv, 21 | AMod, 22 | AShiftL, 23 | AShiftR, 24 | ABoolOr, 25 | ABoolAnd, 26 | ABitOr, 27 | ABitAnd, 28 | } 29 | 30 | impl From<&ExpressionInfixOpcode> for AGateType { 31 | fn from(opcode: &ExpressionInfixOpcode) -> Self { 32 | match opcode { 33 | ExpressionInfixOpcode::Mul => AGateType::AMul, 34 | ExpressionInfixOpcode::Div => AGateType::ADiv, 35 | ExpressionInfixOpcode::Add => AGateType::AAdd, 36 | ExpressionInfixOpcode::Sub => AGateType::ASub, 37 | ExpressionInfixOpcode::Pow => AGateType::APow, 38 | ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv, 39 | ExpressionInfixOpcode::Mod => AGateType::AMod, 40 | ExpressionInfixOpcode::ShiftL => AGateType::AShiftL, 41 | ExpressionInfixOpcode::ShiftR => AGateType::AShiftR, 42 | ExpressionInfixOpcode::LesserEq => AGateType::ALEq, 43 | ExpressionInfixOpcode::GreaterEq => AGateType::AGEq, 44 | ExpressionInfixOpcode::Lesser => AGateType::ALt, 45 | ExpressionInfixOpcode::Greater => AGateType::AGt, 46 | ExpressionInfixOpcode::Eq => AGateType::AEq, 47 | ExpressionInfixOpcode::NotEq => AGateType::ANeq, 48 | ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr, 49 | ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd, 50 | ExpressionInfixOpcode::BitOr => AGateType::ABitOr, 51 | ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd, 52 | ExpressionInfixOpcode::BitXor => AGateType::AXor, 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/circom.rs: -------------------------------------------------------------------------------- 1 | //! # Circom Module 2 | //! 3 | //! This module is a slighly modified version of the original code from the repository `circom` sub module, related to 4 | //! 
the circom program structure, project configuration and compilation.

pub mod parser;
pub mod type_analysis;

pub const VERSION: &str = "2.1.0";

--------------------------------------------------------------------------------
/src/circom/parser.rs:
--------------------------------------------------------------------------------
use crate::{circom::VERSION, cli::Args, program::ProgramError};
use circom_parser::run_parser;
use circom_program_structure::{error_definition::Report, program_archive::ProgramArchive};

/// Parses the circom project whose entry file is `args.input`.
///
/// Parser errors and warnings are printed with `Report::print_reports`.
///
/// # Errors
///
/// Returns [`ProgramError::ParsingError`] when the circom parser rejects the project.
pub fn parse_project(args: &Args) -> Result<ProgramArchive, ProgramError> {
    // `to_str().unwrap()` panicked on a non-UTF-8 path. The lossy conversion is identical for
    // valid UTF-8 paths; for others the parser receives a mangled path, which presumably fails
    // to open and is reported as a parsing error (NOTE(review): confirm run_parser's behaviour).
    let initial_file = args.input.to_string_lossy().into_owned();
    match run_parser(initial_file, VERSION, vec![]) {
        Err((file_library, report_collection)) => {
            Report::print_reports(&report_collection, &file_library);
            Err(ProgramError::ParsingError)
        }
        Ok((program_archive, warnings)) => {
            Report::print_reports(&warnings, &program_archive.file_library);
            Ok(program_archive)
        }
    }
}

--------------------------------------------------------------------------------
/src/circom/type_analysis.rs:
--------------------------------------------------------------------------------
use crate::program::ProgramError;
use circom_program_structure::{error_definition::Report, program_archive::ProgramArchive};
use circom_type_analysis::check_types::check_types;

/// Runs circom's type checker over `program_archive`, printing its errors or warnings.
///
/// # Errors
///
/// Returns [`ProgramError::AnalysisError`] when type checking fails.
pub fn analyse_project(program_archive: &mut ProgramArchive) -> Result<(), ProgramError> {
    match check_types(program_archive) {
        Err(errs) => {
            Report::print_reports(&errs, program_archive.get_file_library());
            Err(ProgramError::AnalysisError)
        }
        Ok(warns) => {
            Report::print_reports(&warns, program_archive.get_file_library());
            Ok(())
        }
    }
}

--------------------------------------------------------------------------------
/src/cli.rs:
-------------------------------------------------------------------------------- 1 | use std::path::{Path, PathBuf}; 2 | 3 | use clap::{Parser, ValueEnum}; 4 | use serde::{Deserialize, Serialize}; 5 | // Value type used for values in the MPC backend; serialized as "sint" or "sfloat" (default: Sint). 6 | #[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum, Serialize, Deserialize, Default)] 7 | #[serde(rename_all = "lowercase")] 8 | pub enum ValueType { 9 | #[serde(rename = "sint")] 10 | #[default] 11 | Sint, 12 | #[serde(rename = "sfloat")] 13 | Sfloat, 14 | } 15 | 16 | #[derive(Parser)] 17 | #[clap(name = "Arithmetic Circuits Compiler")] 18 | #[command(disable_help_subcommand = true)] 19 | pub struct Args { 20 | /// Input file to process 21 | #[arg( 22 | short, 23 | long, 24 | help = "Path to the input file", 25 | default_value = "./input/circuit.circom" 26 | )] 27 | pub input: PathBuf, 28 | 29 | /// Directory where the output files are written 30 | #[arg( 31 | short, 32 | long, 33 | help = "Path to the directory where the output will be written", 34 | default_value = "./output/" 35 | )] 36 | pub output: PathBuf, 37 | 38 | #[arg( 39 | short, 40 | long, 41 | value_enum, 42 | help = "Type that'll be used for values in MPC backend", 43 | default_value_t = ValueType::Sint, 44 | )] 45 | pub value_type: ValueType, 46 | // When set, the circuit is converted to a boolean circuit using integers of this many bits. 47 | #[arg( 48 | long, 49 | help = "Optional: Convert to a boolean circuit by using integers with this number of bits", 50 | default_value = None, 51 | )] 52 | pub boolify_width: Option, 53 | } 54 | // Direct constructor that bypasses command-line parsing. 55 | impl Args { 56 | pub fn new( 57 | input: PathBuf, 58 | output: PathBuf, 59 | value_type: ValueType, 60 | boolify_width: Option, 61 | ) -> Self { 62 | Self { 63 | input, 64 | output, 65 | value_type, 66 | boolify_width, 67 | } 68 | } 69 | } 70 | 71 | /// Returns the path `output_path/filename.ext`. 72 | pub fn build_output(output_path: &Path, filename: &str, ext: &str) -> PathBuf { 73 | let mut file = output_path.to_path_buf(); 74 | file.push(format!("{}.{}", filename, ext)); 75 | file 76 | } 77 | 78 | #[cfg(test)] 79 | mod tests { 80 | use super::*; 81 | 82 | #[test] 83 
| fn test_build_output() { 84 | let output_path = Path::new("./output"); 85 | let filename = "result"; 86 | let ext = "txt"; 87 | 88 | let expected = PathBuf::from("./output/result.txt"); 89 | let result = build_output(output_path, filename, ext); 90 | 91 | assert_eq!(result, expected); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/compiler.rs: -------------------------------------------------------------------------------- 1 | //! # Circuit Module 2 | //! 3 | //! This module defines the data structures used to represent the arithmetic circuit. 4 | 5 | use crate::{ 6 | a_gate_type::AGateType, cli::ValueType, program::ProgramError, 7 | topological_sort::topological_sort, 8 | }; 9 | use bristol_circuit::{BristolCircuit, CircuitInfo, ConstantInfo, Gate}; 10 | use log::debug; 11 | use serde::{Deserialize, Serialize}; 12 | use std::collections::{HashMap, HashSet}; 13 | use thiserror::Error; 14 | 15 | /// Represents a signal in the circuit, with a name and an optional value. 16 | #[derive(Debug, Serialize, Deserialize)] 17 | pub struct Signal { 18 | name: String, 19 | value: Option, 20 | } 21 | 22 | impl Signal { 23 | /// Creates a new signal. 24 | pub fn new(name: String, value: Option) -> Self { 25 | Self { name, value } 26 | } 27 | } 28 | 29 | /// Represents a node in the circuit, a collection of signals. 30 | /// The `is_const` and `is_out` fields save us some iterations over the signals and gates. 31 | #[derive(Default, Debug, Serialize, Deserialize)] 32 | pub struct Node { 33 | is_const: bool, 34 | is_out: bool, 35 | signals: Vec, 36 | } 37 | 38 | impl Node { 39 | /// Creates a new empty node. 40 | pub fn new() -> Self { 41 | Self { 42 | signals: Vec::new(), 43 | is_const: false, 44 | is_out: false, 45 | } 46 | } 47 | 48 | /// Creates a new node with an initial signal.
49 | pub fn new_with_signal(signal_id: u32, is_const: bool, is_out: bool) -> Self { 50 | Self { 51 | signals: vec![signal_id], 52 | is_const, 53 | is_out, 54 | } 55 | } 56 | 57 | /// Appends the given signal ids to the node (duplicates are not filtered out). 58 | pub fn add_signals(&mut self, signals: &Vec) { 59 | self.signals.extend(signals); 60 | } 61 | 62 | /// Checks if the node contains a signal. 63 | pub fn contains_signal(&self, signal_id: &u32) -> bool { 64 | self.signals.contains(signal_id) 65 | } 66 | 67 | /// Gets the signals of the node. 68 | pub fn get_signals(&self) -> &Vec { 69 | &self.signals 70 | } 71 | 72 | /// Sets or clears the node's output flag (`Compiler::add_gate` sets it on a gate's output node). 73 | pub fn set_output(&mut self, is_out: bool) { 74 | self.is_out = is_out; 75 | } 76 | 77 | /// Sets or clears the node's constant flag. 78 | pub fn set_const(&mut self, is_const: bool) { 79 | self.is_const = is_const; 80 | } 81 | } 82 | 83 | /// Represents a circuit gate, with a left-hand input, right-hand input, and output node identifiers. 84 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] 85 | pub struct ArithmeticGate { 86 | pub op: AGateType, 87 | pub lh_in: u32, 88 | pub rh_in: u32, 89 | pub out: u32, 90 | } 91 | 92 | impl ArithmeticGate { 93 | /// Creates a new gate. 94 | pub fn new(op: AGateType, lh_in: u32, rh_in: u32, out: u32) -> Self { 95 | Self { 96 | op, 97 | lh_in, 98 | rh_in, 99 | out, 100 | } 101 | } 102 | } 103 | 104 | /// Compilation data structure representing an arithmetic circuit with extra information, including 105 | /// its signals, nodes and gates.
106 | #[derive(Default, Debug, Serialize, Deserialize)] 107 | pub struct Compiler { 108 | node_count: u32, 109 | inputs: HashMap, 110 | outputs: HashMap, 111 | signals: HashMap, 112 | nodes: HashMap, 113 | gates: Vec, 114 | value_type: ValueType, 115 | } 116 | // NOTE(review): `Compiler::new()` below builds the same value as the derived `Default`. 117 | impl Compiler { 118 | pub fn new() -> Compiler { 119 | Compiler { 120 | node_count: 0, 121 | inputs: HashMap::new(), 122 | outputs: HashMap::new(), 123 | signals: HashMap::new(), 124 | nodes: HashMap::new(), 125 | gates: Vec::new(), 126 | value_type: Default::default(), 127 | } 128 | } 129 | /// Registers input names keyed by signal id; an existing entry for the same id is overwritten. 130 | pub fn add_inputs(&mut self, inputs: HashMap) { 131 | self.inputs.extend(inputs); 132 | } 133 | /// Registers output names keyed by signal id; an existing entry for the same id is overwritten. 134 | pub fn add_outputs(&mut self, outputs: HashMap) { 135 | self.outputs.extend(outputs); 136 | } 137 | 138 | /// Adds a new signal to the circuit in a fresh node of its own; fails with `CircuitError::SignalAlreadyDeclared` if `id` is already registered. 139 | pub fn add_signal( 140 | &mut self, 141 | id: u32, 142 | name: String, 143 | value: Option, 144 | ) -> Result<(), CircuitError> { 145 | // Check that the signal isn't already declared 146 | if self.signals.contains_key(&id) { 147 | return Err(CircuitError::SignalAlreadyDeclared); 148 | } 149 | 150 | // Create a new signal 151 | let signal = Signal::new(name, value); 152 | self.signals.insert(id, signal); 153 | 154 | // Create a new node holding only this signal; it is constant iff a value was given 155 | let node = Node::new_with_signal(id, value.is_some(), false); 156 | debug!("{:?}", node); 157 | let node_id = self.get_node_id(); 158 | self.nodes.insert(node_id, node); 159 | 160 | Ok(()) 161 | } 162 | /// Returns the id and name of every signal whose name starts with `filter`. 163 | pub fn get_signals(&self, filter: String) -> HashMap { 164 | let mut ret = HashMap::new(); 165 | for (signal_id, signal) in self.signals.iter() { 166 | if signal.name.starts_with(filter.as_str()) { 167 | ret.insert(*signal_id, signal.name.to_string()); 168 | } 169 | } 170 | ret 171 | } 172 | 173 | /// Adds a new gate to the circuit.
174 | pub fn add_gate( 175 | &mut self, 176 | gate_type: AGateType, 177 | lhs_signal_id: u32, 178 | rhs_signal_id: u32, 179 | output_signal_id: u32, 180 | ) -> Result<(), CircuitError> { 181 | // Get the signal node ids 182 | let node_ids = { 183 | let mut nodes: [u32; 3] = [0; 3]; 184 | 185 | for (&id, node) in self.nodes.iter() { 186 | if node.contains_signal(&lhs_signal_id) { 187 | nodes[0] = id; 188 | } 189 | if node.contains_signal(&rhs_signal_id) { 190 | nodes[1] = id; 191 | } 192 | if node.contains_signal(&output_signal_id) { 193 | nodes[2] = id; 194 | } 195 | } 196 | 197 | nodes 198 | }; 199 | 200 | // Set the output node as an output node 201 | self.nodes.get_mut(&node_ids[2]).unwrap().set_output(true); 202 | 203 | // Create gate 204 | let gate = ArithmeticGate::new(gate_type, node_ids[0], node_ids[1], node_ids[2]); 205 | debug!("{:?}", gate); 206 | self.gates.push(gate); 207 | 208 | Ok(()) 209 | } 210 | 211 | /// Creates a connection between two signals in the circuit. 212 | /// This is finding the nodes that contain these signals and merging them. 
213 | pub fn add_connection(&mut self, a: u32, b: u32) -> Result<(), CircuitError> { 214 | // Get the signal node ids 215 | let n = Node::new(); 216 | let nodes = { 217 | let mut nodes: [(u32, &Node); 2] = [(0, &n); 2]; 218 | 219 | for (&id, node) in self.nodes.iter() { 220 | if node.contains_signal(&a) { 221 | nodes[0] = (id, node); 222 | } 223 | if node.contains_signal(&b) { 224 | nodes[1] = (id, node); 225 | } 226 | } 227 | 228 | nodes 229 | }; 230 | 231 | let (node_a_id, node_a) = nodes[0]; 232 | let (node_b_id, node_b) = nodes[1]; 233 | 234 | // If both signals are in the same node, no action is needed 235 | if node_a_id == node_b_id { 236 | return Ok(()); 237 | } 238 | // Check for output and constant nodes 239 | if node_a.is_out && node_b.is_out { 240 | return Err(CircuitError::CannotMergeOutputNodes); 241 | } 242 | 243 | if node_a.is_const && node_b.is_const { 244 | return Err(CircuitError::CannotMergeConstantNodes); 245 | } 246 | 247 | // Merge the nodes into a new node 248 | let mut merged_node = Node::new(); 249 | 250 | // Set the new node as an output and constant 251 | merged_node.set_output(node_a.is_out || node_b.is_out); 252 | merged_node.set_const(node_a.is_const || node_b.is_const); 253 | 254 | merged_node.add_signals(node_a.get_signals()); 255 | merged_node.add_signals(node_b.get_signals()); 256 | 257 | let merged_node_id = self.get_node_id(); 258 | 259 | // Update connections in gates to point to the new merged node 260 | self.gates.iter_mut().for_each(|gate| { 261 | if gate.lh_in == node_a_id || gate.lh_in == node_b_id { 262 | gate.lh_in = merged_node_id; 263 | } 264 | if gate.rh_in == node_a_id || gate.rh_in == node_b_id { 265 | gate.rh_in = merged_node_id; 266 | } 267 | if gate.out == node_a_id || gate.out == node_b_id { 268 | gate.out = merged_node_id; 269 | } 270 | }); 271 | 272 | // Remove the old nodes and insert the new merged node 273 | self.nodes.remove(&node_a_id); 274 | self.nodes.remove(&node_b_id); 275 | 
self.nodes.insert(merged_node_id, merged_node); 276 | 277 | Ok(()) 278 | } 279 | 280 | pub fn update_type(&mut self, value_type: ValueType) -> Result<(), CircuitError> { 281 | self.value_type = value_type; 282 | 283 | Ok(()) 284 | } 285 | 286 | /// Generates a circuit report with input and output signals information. 287 | pub fn generate_circuit_report(&self) -> Result { 288 | // Split input and output nodes 289 | let mut input_nodes = Vec::new(); 290 | let mut output_nodes = Vec::new(); 291 | self.nodes.iter().for_each(|(&id, node)| { 292 | if node.is_out { 293 | output_nodes.push(id); 294 | } else { 295 | input_nodes.push(id); 296 | } 297 | }); 298 | 299 | // Remove output nodes that are inputs to gates 300 | output_nodes.retain(|&id| { 301 | self.gates 302 | .iter() 303 | .all(|gate| gate.lh_in != id && gate.rh_in != id) 304 | }); 305 | 306 | // Sort 307 | input_nodes.sort_unstable(); 308 | output_nodes.sort_unstable(); 309 | 310 | // Generate reports 311 | let inputs = self.generate_signal_reports(&input_nodes); 312 | let outputs = self.generate_signal_reports(&output_nodes); 313 | 314 | Ok(CircuitReport { 315 | inputs, 316 | outputs, 317 | value_type: self.value_type, 318 | }) 319 | } 320 | 321 | pub fn build_circuit(&self) -> Result { 322 | // First build up these maps so we can easily see which node id to use 323 | let mut input_to_node_id = HashMap::::new(); 324 | let mut constant_to_node_id_and_value = HashMap::::new(); 325 | let mut output_to_node_id = HashMap::::new(); 326 | 327 | for (node_id, node) in self.nodes.iter() { 328 | // Each node has a list of signal ids which all correspond to that node 329 | // The compiler associates IO with signals, so here we bridge the gap so we get 330 | // IO <=> node instead of IO <=> signal <=> node 331 | for signal_id in node.get_signals() { 332 | if let Some(input_name) = self.inputs.get(signal_id) { 333 | let prev = input_to_node_id.insert(input_name.clone(), *node_id); 334 | 335 | if prev.is_some() { 336 | 
return Err(CircuitError::Inconsistency { 337 | message: format!("Duplicate input {}", input_name), 338 | }); 339 | } 340 | } 341 | 342 | if let Some(output_name) = self.outputs.get(signal_id) { 343 | let prev = output_to_node_id.insert(output_name.clone(), *node_id); 344 | 345 | if prev.is_some() { 346 | return Err(CircuitError::Inconsistency { 347 | message: format!("Duplicate output {}", output_name), 348 | }); 349 | } 350 | } 351 | 352 | let signal = &self.signals[signal_id]; 353 | 354 | if let Some(value) = signal.value { 355 | constant_to_node_id_and_value.insert( 356 | format!("{}_{}", signal.name.clone(), signal_id), 357 | (*node_id, value.to_string()), 358 | ); 359 | } 360 | } 361 | } 362 | 363 | { 364 | // We want inputs at the start and outputs at the end 365 | // We won't be able to do that if a node is used for both input and output 366 | // That shouldn't happen, so we check here that it doesn't happen 367 | 368 | let node_id_to_input_name = input_to_node_id 369 | .iter() 370 | .map(|(name, node_id)| (node_id, name)) 371 | .collect::>(); 372 | 373 | for (output_name, output_node_id) in &output_to_node_id { 374 | if let Some(input_name) = node_id_to_input_name.get(output_node_id) { 375 | return Err(CircuitError::Inconsistency { 376 | message: format!( 377 | "Node {} used for both input {} and output {}", 378 | output_node_id, input_name, output_name 379 | ), 380 | }); 381 | } 382 | } 383 | } 384 | 385 | // Now node ids are like wire ids, but the compiler generates them in a way that leaves a 386 | // lot of gaps. So we assign new wire ids so they'll be sequential instead. We also do this 387 | // ensure inputs are at the start and outputs are at the end. 
388 | let mut node_id_to_wire_id = HashMap::::new(); 389 | let mut next_wire_id = 0; 390 | 391 | // First inputs 392 | for node_id in input_to_node_id.values() { 393 | node_id_to_wire_id.insert(*node_id, next_wire_id); 394 | next_wire_id += 1; 395 | } 396 | 397 | // For the intermediate nodes, we need the gates in topological order so that the wires are 398 | // assigned in the order they are needed. The topological order is also needed to comply 399 | // with bristol format and allow for easy evaluation. 400 | 401 | let mut node_id_to_required_gate = HashMap::::new(); 402 | 403 | for (gate_id, gate) in self.gates.iter().enumerate() { 404 | // the gate.out node depends on this gate 405 | node_id_to_required_gate.insert(gate.out, gate_id); 406 | } 407 | 408 | let sorted_gate_ids = topological_sort(self.gates.len(), &|gate_id: usize| { 409 | let gate = &self.gates[gate_id]; 410 | let mut deps = Vec::::new(); 411 | 412 | if let Some(required_gate_id) = node_id_to_required_gate.get(&gate.lh_in) { 413 | deps.push(*required_gate_id); 414 | } 415 | 416 | if let Some(required_gate_id) = node_id_to_required_gate.get(&gate.rh_in) { 417 | deps.push(*required_gate_id); 418 | } 419 | 420 | deps 421 | })?; 422 | 423 | let output_node_ids = output_to_node_id.values().collect::>(); 424 | 425 | // Now that the gates are in order, we can assign wire ids to each node in the order they 426 | // are seen 427 | for gate_id in &sorted_gate_ids { 428 | let gate = &self.gates[*gate_id]; 429 | 430 | for node_id in &[gate.lh_in, gate.rh_in, gate.out] { 431 | if output_node_ids.contains(node_id) { 432 | // Output wires are excluded so that they can all be at the end 433 | continue; 434 | } 435 | 436 | if node_id_to_wire_id.contains_key(node_id) { 437 | continue; 438 | } 439 | 440 | node_id_to_wire_id.insert(*node_id, next_wire_id); 441 | next_wire_id += 1; 442 | } 443 | } 444 | 445 | // Assign wire ids to output nodes 446 | for node_id in output_to_node_id.values() { 447 | 
node_id_to_wire_id.insert(*node_id, next_wire_id); 448 | next_wire_id += 1; 449 | } 450 | 451 | // Now we can create the new gates using topological order and the new wire ids 452 | let mut new_gates = Vec::::new(); 453 | for gate_id in sorted_gate_ids { 454 | let gate = &self.gates[gate_id]; 455 | 456 | new_gates.push(Gate { 457 | inputs: vec![ 458 | node_id_to_wire_id[&gate.lh_in] as usize, 459 | node_id_to_wire_id[&gate.rh_in] as usize, 460 | ], 461 | outputs: vec![node_id_to_wire_id[&gate.out] as usize], 462 | op: gate.op.to_string(), 463 | }); 464 | } 465 | 466 | let mut constants = HashMap::::new(); 467 | 468 | for (name, (node_id, value)) in constant_to_node_id_and_value { 469 | constants.insert( 470 | name, 471 | ConstantInfo { 472 | value, 473 | wire_index: node_id_to_wire_id[&node_id] as usize, 474 | }, 475 | ); 476 | } 477 | 478 | Ok(BristolCircuit { 479 | wire_count: next_wire_id as usize, 480 | info: CircuitInfo { 481 | input_name_to_wire_index: input_to_node_id 482 | .iter() 483 | .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize)) 484 | .collect(), 485 | constants, 486 | output_name_to_wire_index: output_to_node_id 487 | .iter() 488 | .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize)) 489 | .collect(), 490 | }, 491 | gates: new_gates, 492 | io_widths: None, 493 | }) 494 | } 495 | 496 | /// Returns a node id and increments the count. 497 | fn get_node_id(&mut self) -> u32 { 498 | self.node_count += 1; 499 | self.node_count 500 | } 501 | 502 | /// Generates signal reports for a set of node IDs. 
///
/// Each report lists the names of the node's signals (names containing `random_`,
/// presumably compiler-generated temporaries, are hidden) and the value carried by the
/// node's signals, if any. When several signals carry a value, the last one visited wins.
///
/// # Panics
///
/// Panics if a node id is missing from the node map or a node refers to an unknown signal.
fn generate_signal_reports(&self, nodes: &[u32]) -> Vec<SignalReport> {
    let mut reports = Vec::with_capacity(nodes.len());

    for &id in nodes {
        let node = self
            .nodes
            .get(&id)
            .expect("Node ID not found in node map");

        let mut names = Vec::new();
        let mut value = None;

        for sig_id in node.get_signals() {
            let signal = self
                .signals
                .get(sig_id)
                .expect("Signal ID not found in signal map");

            if !signal.name.contains("random_") {
                names.push(signal.name.clone());
            }
            if let Some(known) = signal.value {
                value = Some(known);
            }
        }

        reports.push(SignalReport { id, names, value });
    }

    reports
}
}

/// The full circuit report, containing input and output signals information.
///
/// Field order is the serialized field order; keep it stable.
#[derive(Debug, Serialize, Deserialize)]
pub struct CircuitReport {
    /// Reports for nodes that are not marked as outputs.
    inputs: Vec<SignalReport>,
    /// Reports for output nodes that do not feed any gate.
    outputs: Vec<SignalReport>,
    /// Value type of the circuit.
    value_type: ValueType,
}

/// A single node report, with a list of signal names and an optional value.
543 | #[derive(Debug, Serialize, Deserialize)] 544 | pub struct SignalReport { 545 | id: u32, 546 | names: Vec, 547 | value: Option, 548 | } 549 | 550 | #[derive(Debug, Error)] 551 | pub enum CircuitError { 552 | #[error("Cannot merge constant nodes")] 553 | CannotMergeConstantNodes, 554 | #[error("Cannot merge output nodes")] 555 | CannotMergeOutputNodes, 556 | #[error("Constant value already set for variable")] 557 | ConstantValueAlreadySet, 558 | #[error("Signal is not connected to any node")] 559 | DisconnectedSignal, 560 | #[error(transparent)] 561 | IOError(#[from] std::io::Error), 562 | #[error(transparent)] 563 | ParseIntError(#[from] std::num::ParseIntError), 564 | #[error("Signal already declared")] 565 | SignalAlreadyDeclared, 566 | #[error("unsupported gate type: {0}")] 567 | UnsupportedGateType(String), 568 | #[error("Unprocessed node")] 569 | UnprocessedNode, 570 | #[error("Cyclic dependency: {message}")] 571 | CyclicDependency { message: String }, 572 | #[error("Inconsistency: {message}")] 573 | Inconsistency { message: String }, 574 | #[error("Parsing error: {message}")] 575 | ParsingError { message: String }, 576 | } 577 | 578 | impl From for ProgramError { 579 | fn from(e: CircuitError) -> Self { 580 | ProgramError::CircuitError(e) 581 | } 582 | } 583 | 584 | #[cfg(test)] 585 | mod tests { 586 | use super::*; 587 | 588 | #[test] 589 | fn test_node_with_signal() { 590 | let node = Node::new_with_signal(1, true, false); 591 | assert_eq!(node.signals.len(), 1); 592 | assert_eq!(node.signals[0], 1); 593 | assert!(node.is_const); 594 | assert!(!node.is_out); 595 | } 596 | 597 | #[test] 598 | fn test_node_add_signal() { 599 | let mut node = Node::new(); 600 | node.add_signals(&vec![1, 2, 3]); 601 | assert_eq!(node.signals.len(), 3); 602 | assert!(node.contains_signal(&1)); 603 | assert!(node.contains_signal(&2)); 604 | assert!(node.contains_signal(&3)); 605 | } 606 | 607 | #[test] 608 | fn test_node_contains_signal() { 609 | let node = 
Node::new_with_signal(1, true, false); 610 | assert!(node.contains_signal(&1)); 611 | assert!(!node.contains_signal(&2)); 612 | } 613 | 614 | #[test] 615 | fn test_node_set_output() { 616 | let mut node = Node::new(); 617 | node.set_output(true); 618 | assert!(node.is_out); 619 | } 620 | 621 | #[test] 622 | fn test_node_set_const() { 623 | let mut node = Node::new(); 624 | node.set_const(true); 625 | assert!(node.is_const); 626 | } 627 | 628 | #[test] 629 | fn test_compiler_add_inputs() { 630 | let mut compiler = Compiler::new(); 631 | let mut inputs = HashMap::new(); 632 | inputs.insert(1, String::from("input1")); 633 | inputs.insert(2, String::from("input2")); 634 | compiler.add_inputs(inputs); 635 | 636 | assert_eq!(compiler.inputs.len(), 2); 637 | assert_eq!(compiler.inputs[&1], "input1"); 638 | assert_eq!(compiler.inputs[&2], "input2"); 639 | } 640 | 641 | #[test] 642 | fn test_compiler_add_outputs() { 643 | let mut compiler = Compiler::new(); 644 | let mut outputs = HashMap::new(); 645 | outputs.insert(3, String::from("output1")); 646 | outputs.insert(4, String::from("output2")); 647 | compiler.add_outputs(outputs); 648 | 649 | assert_eq!(compiler.outputs.len(), 2); 650 | assert_eq!(compiler.outputs[&3], "output1"); 651 | assert_eq!(compiler.outputs[&4], "output2"); 652 | } 653 | 654 | #[test] 655 | fn test_compiler_add_signal() { 656 | let mut compiler = Compiler::new(); 657 | let result = compiler.add_signal(1, String::from("signal1"), None); 658 | 659 | assert!(result.is_ok()); 660 | assert_eq!(compiler.signals.len(), 1); 661 | assert!(compiler.signals.contains_key(&1)); 662 | assert_eq!(compiler.signals[&1].name, "signal1"); 663 | } 664 | 665 | #[test] 666 | fn test_compiler_add_duplicated_signal() { 667 | let mut compiler = Compiler::new(); 668 | compiler 669 | .add_signal(1, String::from("signal1"), None) 670 | .unwrap(); 671 | let result = compiler.add_signal(1, String::from("signal1"), None); 672 | 673 | assert!(matches!(result, 
Err(CircuitError::SignalAlreadyDeclared))); 674 | } 675 | 676 | #[test] 677 | fn test_compiler_get_signals() { 678 | let mut compiler = Compiler::new(); 679 | compiler 680 | .add_signal(1, String::from("signal1"), None) 681 | .unwrap(); 682 | compiler 683 | .add_signal(2, String::from("filter_signal"), None) 684 | .unwrap(); 685 | let filtered_signals = compiler.get_signals(String::from("filter")); 686 | 687 | assert_eq!(filtered_signals.len(), 1); 688 | assert_eq!(filtered_signals[&2], "filter_signal"); 689 | } 690 | 691 | #[test] 692 | fn test_compiler_add_gate() { 693 | let mut compiler = Compiler::new(); 694 | compiler 695 | .add_signal(1, String::from("signal1"), None) 696 | .unwrap(); 697 | compiler 698 | .add_signal(2, String::from("signal2"), None) 699 | .unwrap(); 700 | compiler 701 | .add_signal(3, String::from("signal3"), None) 702 | .unwrap(); 703 | 704 | let result = compiler.add_gate(AGateType::AAdd, 1, 2, 3); 705 | 706 | assert!(result.is_ok()); 707 | assert_eq!(compiler.gates.len(), 1); 708 | let gate = &compiler.gates[0]; 709 | assert_eq!(gate.op, AGateType::AAdd); 710 | assert_eq!(gate.lh_in, 1); 711 | assert_eq!(gate.rh_in, 2); 712 | assert_eq!(gate.out, 3); 713 | } 714 | 715 | #[test] 716 | fn test_compiler_add_connection() { 717 | let mut compiler = Compiler::new(); 718 | compiler 719 | .add_signal(1, String::from("signal1"), None) 720 | .unwrap(); 721 | compiler 722 | .add_signal(2, String::from("signal2"), None) 723 | .unwrap(); 724 | compiler 725 | .add_signal(3, String::from("signal3"), None) 726 | .unwrap(); 727 | 728 | // Adding connection between signals 1 and 2 729 | let result = compiler.add_connection(1, 2); 730 | 731 | assert!(result.is_ok()); 732 | assert_eq!(compiler.nodes.len(), 2); 733 | 734 | // Assert new node contains both signals 735 | let node = compiler.nodes.get(&4).unwrap(); 736 | assert_eq!(node.signals.len(), 2); 737 | assert!(node.contains_signal(&1)); 738 | assert!(node.contains_signal(&2)); 739 | } 740 | 741 | 
#[test] 742 | fn test_compiler_add_connection_same_node() { 743 | let mut compiler = Compiler::new(); 744 | compiler 745 | .add_signal(1, String::from("signal1"), None) 746 | .unwrap(); 747 | compiler 748 | .add_signal(2, String::from("signal2"), None) 749 | .unwrap(); 750 | 751 | compiler.add_connection(1, 2).unwrap(); 752 | // Connect the same node 753 | let result = compiler.add_connection(1, 2); 754 | 755 | assert!(result.is_ok()); 756 | // No change in number of nodes 757 | assert_eq!(compiler.nodes.len(), 1); 758 | } 759 | 760 | #[test] 761 | fn test_compiler_add_connection_output_nodes() { 762 | let mut compiler = Compiler::new(); 763 | compiler 764 | .add_signal(1, String::from("signal1"), None) 765 | .unwrap(); 766 | compiler 767 | .add_signal(2, String::from("signal2"), None) 768 | .unwrap(); 769 | 770 | // Set both nodes as output nodes 771 | compiler.nodes.get_mut(&1).unwrap().set_output(true); 772 | compiler.nodes.get_mut(&2).unwrap().set_output(true); 773 | 774 | let result = compiler.add_connection(1, 2); 775 | 776 | assert!(matches!(result, Err(CircuitError::CannotMergeOutputNodes))); 777 | } 778 | 779 | #[test] 780 | fn test_compiler_add_connection_constant_nodes() { 781 | let mut compiler = Compiler::new(); 782 | compiler 783 | .add_signal(1, String::from("signal1"), Some(1)) 784 | .unwrap(); 785 | compiler 786 | .add_signal(2, String::from("signal2"), Some(2)) 787 | .unwrap(); 788 | 789 | let result = compiler.add_connection(1, 2); 790 | assert!(matches!( 791 | result, 792 | Err(CircuitError::CannotMergeConstantNodes) 793 | )); 794 | } 795 | } 796 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Circom To Arithmetic Circuit 2 | //! 3 | //! This library provides the functionality to convert a Circom program into an arithmetic circuit. 
4 | 5 | pub mod a_gate_type; 6 | pub mod circom; 7 | pub mod cli; 8 | pub mod compiler; 9 | pub mod process; 10 | pub mod program; 11 | pub mod runtime; 12 | 13 | mod topological_sort; 14 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use boolify::boolify; 2 | use circom_2_arithc::{ 3 | cli::{build_output, Args}, 4 | program::{compile, ProgramError}, 5 | }; 6 | use clap::Parser; 7 | use dotenv::dotenv; 8 | use env_logger::{init_from_env, Env}; 9 | use serde_json::to_string_pretty; 10 | use std::{ 11 | fs::{self, File}, 12 | io::Write, 13 | }; 14 | 15 | fn main() -> Result<(), ProgramError> { 16 | dotenv().ok(); 17 | init_from_env(Env::default().filter_or("LOG_LEVEL", "info")); 18 | 19 | let args = Args::parse(); 20 | 21 | let compiler = compile(&args)?; 22 | let report = compiler.generate_circuit_report()?; 23 | 24 | let output_dir = args.output.clone(); 25 | fs::create_dir_all(output_dir.clone()) 26 | .map_err(|_| ProgramError::OutputDirectoryCreationError)?; 27 | 28 | let mut circuit = compiler.build_circuit()?; 29 | 30 | if let Some(boolify_width) = args.boolify_width { 31 | circuit = boolify(&circuit, boolify_width); 32 | } 33 | 34 | let output_file_path = build_output(&output_dir, "circuit", "txt"); 35 | circuit.write_bristol(&mut File::create(output_file_path)?)?; 36 | 37 | // let output_file_path_json = build_output(&output_dir, "circuit", "json"); 38 | // File::create(output_file_path_json)?.write_all(serde_json::to_string_pretty(&circuit)?.as_bytes())?; 39 | 40 | // let output_debug_path_json = build_output(&output_dir, "debug", "json"); 41 | // File::create(output_debug_path_json)?.write_all(serde_json::to_string_pretty(&compiler)?.as_bytes())?; 42 | 43 | let output_file_path = build_output(&output_dir, "circuit_info", "json"); 44 | File::create(output_file_path)?.write_all(to_string_pretty(&circuit.info)?.as_bytes())?; 45 
| 46 | let report_file_path = build_output(&output_dir, "report", "json"); 47 | File::create(report_file_path)?.write_all(to_string_pretty(&report)?.as_bytes())?; 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /src/process.rs: -------------------------------------------------------------------------------- 1 | //! # Process Module 2 | //! 3 | //! Handles execution of statements and expressions for arithmetic circuit generation within a `Runtime` environment. 4 | 5 | use crate::a_gate_type::AGateType; 6 | use crate::compiler::Compiler; 7 | use crate::program::ProgramError; 8 | use crate::runtime::{ 9 | generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, NestedValue, 10 | Runtime, RuntimeError, Signal, SubAccess, RETURN_VAR, 11 | }; 12 | use circom_circom_algebra::num_traits::ToPrimitive; 13 | use circom_program_structure::ast::{ 14 | Access, AssignOp, Expression, ExpressionInfixOpcode, ExpressionPrefixOpcode, Statement, 15 | }; 16 | use circom_program_structure::program_archive::ProgramArchive; 17 | use std::cell::RefCell; 18 | use std::collections::HashMap; 19 | use std::rc::Rc; 20 | 21 | /// Processes a sequence of statements. 22 | pub fn process_statements( 23 | ac: &mut Compiler, 24 | runtime: &mut Runtime, 25 | program_archive: &ProgramArchive, 26 | statements: &[Statement], 27 | ) -> Result<(), ProgramError> { 28 | for statement in statements { 29 | process_statement(ac, runtime, program_archive, statement)?; 30 | } 31 | 32 | Ok(()) 33 | } 34 | 35 | /// Processes a single statement. 36 | pub fn process_statement( 37 | ac: &mut Compiler, 38 | runtime: &mut Runtime, 39 | program_archive: &ProgramArchive, 40 | statement: &Statement, 41 | ) -> Result<(), ProgramError> { 42 | match statement { 43 | Statement::InitializationBlock { 44 | initializations, .. 45 | } => process_statements(ac, runtime, program_archive, initializations), 46 | Statement::Block { stmts, .. 
} => process_statements(ac, runtime, program_archive, stmts), 47 | Statement::Substitution { 48 | var, 49 | access, 50 | rhe, 51 | op, 52 | .. 53 | } => handle_substitution(ac, runtime, program_archive, var, access, rhe, op), 54 | Statement::Declaration { 55 | xtype, 56 | name, 57 | dimensions, 58 | .. 59 | } => { 60 | let data_type = DataType::try_from(xtype)?; 61 | let dim_access: Vec = dimensions 62 | .iter() 63 | .map(|expression| process_expression(ac, runtime, program_archive, expression)) 64 | .collect::, ProgramError>>()?; 65 | 66 | let signal_gen = runtime.get_signal_gen(); 67 | let ctx = runtime.current_context()?; 68 | let dimensions: Vec = dim_access 69 | .iter() 70 | .map(|dim_access| { 71 | ctx.get_variable_value(dim_access)? 72 | .ok_or(ProgramError::EmptyDataItem) 73 | }) 74 | .collect::, ProgramError>>()?; 75 | ctx.declare_item(data_type.clone(), name, &dimensions, signal_gen)?; 76 | 77 | // If the declared item is a signal we should add it to the arithmetic circuit 78 | if data_type == DataType::Signal { 79 | let mut signal_access = DataAccess::new(name, Vec::new()); 80 | 81 | if dimensions.is_empty() { 82 | let signal_id = ctx.get_signal_id(&signal_access)?; 83 | ac.add_signal( 84 | signal_id, 85 | signal_access.access_str(ctx.get_ctx_name()), 86 | None, 87 | )?; 88 | } else { 89 | let mut indices: Vec = vec![0; dimensions.len()]; 90 | 91 | loop { 92 | // Set access and get signal id for the current indices 93 | signal_access.set_access(u32_to_access(&indices)); 94 | let signal_id = ctx.get_signal_id(&signal_access)?; 95 | ac.add_signal( 96 | signal_id, 97 | signal_access.access_str(ctx.get_ctx_name()), 98 | None, 99 | )?; 100 | 101 | // Increment indices 102 | if !increment_indices(&mut indices, &dimensions)? { 103 | break; 104 | } 105 | } 106 | } 107 | } 108 | 109 | Ok(()) 110 | } 111 | Statement::IfThenElse { 112 | cond, 113 | if_case, 114 | else_case, 115 | .. 
116 | } => { 117 | let access = process_expression(ac, runtime, program_archive, cond)?; 118 | let result = runtime 119 | .current_context()? 120 | .get_variable_value(&access)? 121 | .ok_or(ProgramError::EmptyDataItem)?; 122 | 123 | if result == 0 { 124 | if let Some(else_statement) = else_case { 125 | runtime.push_context(true, "IF_FALSE".to_string())?; 126 | process_statement(ac, runtime, program_archive, else_statement)?; 127 | runtime.pop_context(true)?; 128 | Ok(()) 129 | } else { 130 | Ok(()) 131 | } 132 | } else { 133 | runtime.push_context(true, "IF_TRUE".to_string())?; 134 | process_statement(ac, runtime, program_archive, if_case)?; 135 | runtime.pop_context(true)?; 136 | Ok(()) 137 | } 138 | } 139 | Statement::While { cond, stmt, .. } => { 140 | runtime.push_context(true, "WHILE_PRE".to_string())?; 141 | loop { 142 | let access = process_expression(ac, runtime, program_archive, cond)?; 143 | let result = runtime 144 | .current_context()? 145 | .get_variable_value(&access)? 146 | .ok_or(ProgramError::EmptyDataItem)?; 147 | 148 | if result == 0 { 149 | break; 150 | } 151 | 152 | runtime.push_context(true, "WHILE_EXE".to_string())?; 153 | process_statement(ac, runtime, program_archive, stmt)?; 154 | runtime.pop_context(true)?; 155 | } 156 | runtime.pop_context(true)?; 157 | 158 | Ok(()) 159 | } 160 | Statement::Return { value, .. } => { 161 | let return_access = process_expression(ac, runtime, program_archive, value)?; 162 | 163 | let signal_gen = runtime.get_signal_gen(); 164 | let ctx = runtime.current_context()?; 165 | let return_value = ctx 166 | .get_variable_value(&return_access)? 167 | .ok_or(ProgramError::EmptyDataItem)?; 168 | 169 | ctx.declare_item(DataType::Variable, RETURN_VAR, &[], signal_gen)?; 170 | ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?; 171 | 172 | Ok(()) 173 | } 174 | Statement::Assert { arg, .. 
} => { 175 | let access = process_expression(ac, runtime, program_archive, arg)?; 176 | let result = runtime 177 | .current_context()? 178 | .get_variable_value(&access)? 179 | .ok_or(ProgramError::EmptyDataItem)?; 180 | 181 | if result == 0 { 182 | return Err(ProgramError::RuntimeError(RuntimeError::AssertionFailed)); 183 | } 184 | 185 | Ok(()) 186 | } 187 | _ => Err(ProgramError::StatementNotImplemented), 188 | } 189 | } 190 | 191 | /// Handles a substitution statement 192 | fn handle_substitution( 193 | ac: &mut Compiler, 194 | runtime: &mut Runtime, 195 | program_archive: &ProgramArchive, 196 | var: &str, 197 | access: &[Access], 198 | rhe: &Expression, 199 | op: &AssignOp, 200 | ) -> Result<(), ProgramError> { 201 | let lh_access = build_access(ac, runtime, program_archive, var, access)?; 202 | let rh_access = process_expression(ac, runtime, program_archive, rhe)?; 203 | 204 | let signal_gen = runtime.get_signal_gen(); 205 | let ctx = runtime.current_context()?; 206 | match ctx.get_item_data_type(var)? { 207 | DataType::Variable => { 208 | // Assign the evaluated right-hand side to the left-hand side 209 | let value = ctx.get_variable_value(&rh_access)?; 210 | ctx.set_variable(&lh_access, value)?; 211 | } 212 | DataType::Component => match op { 213 | AssignOp::AssignVar => { 214 | // Component instantiation 215 | let signal_map = ctx.get_component_map(&rh_access)?; 216 | ctx.set_component(&lh_access, signal_map)?; 217 | } 218 | AssignOp::AssignConstraintSignal => { 219 | // Component signal assignment 220 | match ctx.get_component_signal_content(&lh_access)? { 221 | NestedValue::Array(signal) => { 222 | let assigned_signal_array = 223 | match get_signal_content_for_access(ctx, &rh_access)? 
{ 224 | NestedValue::Array(array) => array, 225 | _ => return Err(ProgramError::InvalidDataType), 226 | }; 227 | 228 | connect_signal_arrays(ac, &signal, &assigned_signal_array)?; 229 | } 230 | NestedValue::Value(_) => { 231 | let component_signal = ctx.get_component_signal_id(&lh_access)?; 232 | let assigned_signal = 233 | get_signal_for_access(ac, ctx, signal_gen, &rh_access)?; 234 | 235 | ac.add_connection(assigned_signal, component_signal)?; 236 | } 237 | } 238 | } 239 | _ => return Err(ProgramError::OperationNotSupported), 240 | }, 241 | DataType::Signal => { 242 | match rhe { 243 | Expression::Variable { .. } => match ctx.get_signal_content(&lh_access)? { 244 | NestedValue::Array(signal) => { 245 | // Connect the signals in the arrays 246 | let assigned_signal_array = 247 | match get_signal_content_for_access(ctx, &rh_access)? { 248 | NestedValue::Array(array) => array, 249 | _ => return Err(ProgramError::InvalidDataType), 250 | }; 251 | 252 | connect_signal_arrays(ac, &signal, &assigned_signal_array)?; 253 | } 254 | NestedValue::Value(signal_id) => { 255 | let gate_output_id = 256 | get_signal_for_access(ac, ctx, signal_gen, &rh_access)?; 257 | 258 | ac.add_connection(gate_output_id, signal_id)?; 259 | } 260 | }, 261 | Expression::Call { .. } 262 | | Expression::InfixOp { .. } 263 | | Expression::PrefixOp { .. } 264 | | Expression::Number(_, _) => { 265 | // Get the signal identifiers and connect them 266 | let given_output_id = ctx.get_signal_id(&lh_access)?; 267 | let gate_output_id = get_signal_for_access(ac, ctx, signal_gen, &rh_access)?; 268 | 269 | ac.add_connection(gate_output_id, given_output_id)?; 270 | } 271 | _ => return Err(ProgramError::SignalSubstitutionNotImplemented), 272 | } 273 | } 274 | } 275 | 276 | Ok(()) 277 | } 278 | 279 | /// Processes an expression and returns an access to the result. 
280 | pub fn process_expression( 281 | ac: &mut Compiler, 282 | runtime: &mut Runtime, 283 | program_archive: &ProgramArchive, 284 | expression: &Expression, 285 | ) -> Result { 286 | match expression { 287 | Expression::Call { id, args, .. } => handle_call(ac, runtime, program_archive, id, args), 288 | Expression::InfixOp { 289 | lhe, infix_op, rhe, .. 290 | } => handle_infix_op(ac, runtime, program_archive, infix_op, lhe, rhe), 291 | Expression::PrefixOp { prefix_op, rhe, .. } => { 292 | handle_prefix_op(ac, runtime, program_archive, prefix_op, rhe) 293 | } 294 | Expression::Number(_, value) => { 295 | let signal_gen = runtime.get_signal_gen(); 296 | let access = runtime 297 | .current_context()? 298 | .declare_random_item(signal_gen, DataType::Variable)?; 299 | 300 | runtime.current_context()?.set_variable( 301 | &access, 302 | Some(value.to_u32().ok_or(ProgramError::ParsingError)?), 303 | )?; 304 | 305 | Ok(access) 306 | } 307 | Expression::Variable { name, access, .. } => { 308 | build_access(ac, runtime, program_archive, name, access) 309 | } 310 | _ => Err(ProgramError::ExpressionNotImplemented), 311 | } 312 | } 313 | 314 | /// Handles function and template calls. 
315 | fn handle_call( 316 | ac: &mut Compiler, 317 | runtime: &mut Runtime, 318 | program_archive: &ProgramArchive, 319 | id: &str, 320 | args: &[Expression], 321 | ) -> Result { 322 | // Determine if the call is to a function or a template and get argument names and body 323 | let is_function = program_archive.contains_function(id); 324 | let (arg_names, body) = if is_function { 325 | let function_data = program_archive.get_function_data(id); 326 | ( 327 | function_data.get_name_of_params().clone(), 328 | function_data.get_body_as_vec().to_vec(), 329 | ) 330 | } else if program_archive.contains_template(id) { 331 | let template_data = program_archive.get_template_data(id); 332 | ( 333 | template_data.get_name_of_params().clone(), 334 | template_data.get_body_as_vec().to_vec(), 335 | ) 336 | } else { 337 | return Err(ProgramError::UndefinedFunctionOrTemplate); 338 | }; 339 | 340 | let arg_values = args 341 | .iter() 342 | .map(|arg_expr| { 343 | process_expression(ac, runtime, program_archive, arg_expr).and_then(|value_access| { 344 | runtime 345 | .current_context()? 346 | .get_variable_value(&value_access)? 347 | .ok_or(ProgramError::EmptyDataItem) 348 | }) 349 | }) 350 | .collect::, ProgramError>>()?; 351 | 352 | // Create a new execution context 353 | runtime.push_context(false, id.to_string())?; 354 | 355 | // Set arguments in the new context 356 | for (arg_name, &arg_value) in arg_names.iter().zip(&arg_values) { 357 | let signal_gen = runtime.get_signal_gen(); 358 | runtime 359 | .current_context()? 360 | .declare_item(DataType::Variable, arg_name, &[], signal_gen)?; 361 | runtime 362 | .current_context()? 
363 | .set_variable(&DataAccess::new(arg_name, vec![]), Some(arg_value))?; 364 | } 365 | 366 | // Process the function/template body 367 | process_statements(ac, runtime, program_archive, &body)?; 368 | 369 | // Get return values 370 | let mut function_return: Option = None; 371 | let mut component_return: HashMap = HashMap::new(); 372 | 373 | if is_function { 374 | if let Ok(value) = runtime 375 | .current_context()? 376 | .get_variable_value(&DataAccess::new(RETURN_VAR, vec![])) 377 | { 378 | function_return = value; 379 | } 380 | } else { 381 | // Retrieve input and output signals 382 | let template_data = program_archive.get_template_data(id); 383 | let input_signals = template_data.get_inputs(); 384 | let output_signals = template_data.get_outputs(); 385 | 386 | // Store ids in the component 387 | for (signal, _) in input_signals.iter().chain(output_signals.iter()) { 388 | let ids = runtime.current_context()?.get_signal(signal)?; 389 | component_return.insert(signal.to_string(), ids); 390 | } 391 | } 392 | 393 | // Return to parent context 394 | runtime.pop_context(false)?; 395 | let signal_gen = runtime.get_signal_gen(); 396 | let ctx = runtime.current_context()?; 397 | let return_access = 398 | DataAccess::new(&format!("{}_{}_{}", id, RETURN_VAR, generate_u32()), vec![]); 399 | 400 | if is_function { 401 | ctx.declare_item( 402 | DataType::Variable, 403 | &return_access.get_name(), 404 | &[], 405 | signal_gen, 406 | )?; 407 | ctx.set_variable(&return_access, function_return)?; 408 | } else { 409 | ctx.declare_item( 410 | DataType::Component, 411 | &return_access.get_name(), 412 | &[], 413 | signal_gen, 414 | )?; 415 | ctx.set_component(&return_access, component_return)?; 416 | } 417 | 418 | Ok(return_access) 419 | } 420 | 421 | /// Handles an infix operation. 422 | /// - If both inputs are variables, it directly computes the operation. 423 | /// - If one or both inputs are signals, it constructs the corresponding circuit gate. 
424 | /// 425 | /// Returns the access to a variable containing the result of the operation or the signal of the output gate. 426 | fn handle_infix_op( 427 | ac: &mut Compiler, 428 | runtime: &mut Runtime, 429 | program_archive: &ProgramArchive, 430 | op: &ExpressionInfixOpcode, 431 | lhe: &Expression, 432 | rhe: &Expression, 433 | ) -> Result { 434 | let lhe_access = process_expression(ac, runtime, program_archive, lhe)?; 435 | let rhe_access = process_expression(ac, runtime, program_archive, rhe)?; 436 | 437 | let signal_gen: Rc> = runtime.get_signal_gen(); 438 | let ctx = runtime.current_context()?; 439 | 440 | // Determine the data types of the left and right operands 441 | let lhs_data_type = ctx.get_item_data_type(&lhe_access.get_name())?; 442 | let rhs_data_type = ctx.get_item_data_type(&rhe_access.get_name())?; 443 | 444 | // Handle the case where both inputs are variables 445 | if lhs_data_type == DataType::Variable && rhs_data_type == DataType::Variable { 446 | let lhs_value = ctx 447 | .get_variable_value(&lhe_access)? 448 | .ok_or(ProgramError::EmptyDataItem)?; 449 | let rhs_value = ctx 450 | .get_variable_value(&rhe_access)? 
451 | .ok_or(ProgramError::EmptyDataItem)?; 452 | 453 | let op_res = execute_op(lhs_value, rhs_value, op)?; 454 | let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?; 455 | ctx.set_variable(&item_access, Some(op_res))?; 456 | 457 | return Ok(item_access); 458 | } 459 | 460 | // Handle cases where one or both inputs are signals 461 | let lhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &lhe_access)?; 462 | let rhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &rhe_access)?; 463 | 464 | // Construct the corresponding circuit gate 465 | let gate_type = AGateType::from(op); 466 | let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?; 467 | let output_id = ctx.get_signal_id(&output_signal)?; 468 | 469 | // Add output signal and gate to the circuit 470 | ac.add_signal( 471 | output_id, 472 | output_signal.access_str(ctx.get_ctx_name()), 473 | None, 474 | )?; 475 | ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?; 476 | 477 | Ok(output_signal) 478 | } 479 | 480 | /// Handles a prefix operation. 481 | /// - If input is a variable, it directly computes the operation. 482 | /// - If input is a signal, it handles it like an infix op against a constant. 483 | /// 484 | /// Returns the access to a variable containing the result of the operation or the signal of the output gate. 
485 | fn handle_prefix_op( 486 | ac: &mut Compiler, 487 | runtime: &mut Runtime, 488 | program_archive: &ProgramArchive, 489 | op: &ExpressionPrefixOpcode, 490 | rhe: &Expression, 491 | ) -> Result { 492 | let rhe_access = process_expression(ac, runtime, program_archive, rhe)?; 493 | 494 | let signal_gen: Rc> = runtime.get_signal_gen(); 495 | let ctx = runtime.current_context()?; 496 | 497 | // Determine the data type of the operand 498 | let rhs_data_type = ctx.get_item_data_type(&rhe_access.get_name())?; 499 | 500 | // Handle the variable case 501 | if rhs_data_type == DataType::Variable { 502 | let rhs_value = ctx 503 | .get_variable_value(&rhe_access)? 504 | .ok_or(ProgramError::EmptyDataItem)?; 505 | 506 | let op_res = execute_prefix_op(op, rhs_value)?; 507 | let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?; 508 | ctx.set_variable(&item_access, Some(op_res))?; 509 | 510 | return Ok(item_access); 511 | } 512 | 513 | let (lhs_value, infix_op) = to_equivalent_infix(op); 514 | let lhs_id = make_constant(ac, ctx, signal_gen.clone(), lhs_value)?; 515 | 516 | // Handle signal input 517 | let rhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &rhe_access)?; 518 | 519 | // Construct the corresponding circuit gate 520 | let gate_type = AGateType::from(&infix_op); 521 | let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?; 522 | let output_id = ctx.get_signal_id(&output_signal)?; 523 | 524 | // Add output signal and gate to the circuit 525 | ac.add_signal( 526 | output_id, 527 | output_signal.access_str(ctx.get_ctx_name()), 528 | None, 529 | )?; 530 | ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?; 531 | 532 | Ok(output_signal) 533 | } 534 | 535 | /// Returns a signal id for a given access 536 | /// - If the access is a signal or a component, it returns the corresponding signal id. 537 | /// - If the access is a variable, it adds a constant variable to the circuit and returns the corresponding signal id. 
538 | fn get_signal_for_access( 539 | ac: &mut Compiler, 540 | ctx: &mut Context, 541 | signal_gen: Rc>, 542 | access: &DataAccess, 543 | ) -> Result { 544 | match ctx.get_item_data_type(&access.get_name())? { 545 | DataType::Signal => Ok(ctx.get_signal_id(access)?), 546 | DataType::Variable => { 547 | // Get variable value 548 | let value = ctx 549 | .get_variable_value(access)? 550 | .ok_or(ProgramError::EmptyDataItem)?; 551 | 552 | make_constant(ac, ctx, signal_gen, value) 553 | } 554 | DataType::Component => Ok(ctx.get_component_signal_id(access)?), 555 | } 556 | } 557 | 558 | fn make_constant( 559 | ac: &mut Compiler, 560 | ctx: &mut Context, 561 | signal_gen: Rc>, 562 | value: u32, 563 | ) -> Result { 564 | let signal_access = DataAccess::new(&format!("const_signal_{}", value), vec![]); 565 | // Try to get signal id if it exists 566 | if let Ok(id) = ctx.get_signal_id(&signal_access) { 567 | Ok(id) 568 | } else { 569 | // If it doesn't exist, declare it and add it to the circuit 570 | ctx.declare_item(DataType::Signal, &signal_access.get_name(), &[], signal_gen)?; 571 | let signal_id = ctx.get_signal_id(&signal_access)?; 572 | ac.add_signal( 573 | signal_id, 574 | signal_access.access_str(ctx.get_ctx_name()), 575 | Some(value), 576 | )?; 577 | Ok(signal_id) 578 | } 579 | } 580 | 581 | /// Returns the content of a signal for a given access 582 | fn get_signal_content_for_access( 583 | ctx: &Context, 584 | access: &DataAccess, 585 | ) -> Result, ProgramError> { 586 | match ctx.get_item_data_type(&access.get_name())? 
{ 587 | DataType::Signal => Ok(ctx.get_signal_content(access)?), 588 | DataType::Component => Ok(ctx.get_component_signal_content(access)?), 589 | _ => Err(ProgramError::InvalidDataType), 590 | } 591 | } 592 | 593 | /// Connects two composed signals 594 | fn connect_signal_arrays( 595 | ac: &mut Compiler, 596 | a: &[NestedValue], 597 | b: &[NestedValue], 598 | ) -> Result<(), ProgramError> { 599 | // Verify that the arrays have the same length 600 | if a.len() != b.len() { 601 | return Err(ProgramError::InvalidDataType); 602 | } 603 | 604 | for (a, b) in a.iter().zip(b.iter()) { 605 | match (a, b) { 606 | (NestedValue::Value(a), NestedValue::Value(b)) => { 607 | ac.add_connection(*a, *b)?; 608 | } 609 | (NestedValue::Array(a), NestedValue::Array(b)) => { 610 | connect_signal_arrays(ac, a, b)?; 611 | } 612 | _ => return Err(ProgramError::InvalidDataType), 613 | } 614 | } 615 | 616 | Ok(()) 617 | } 618 | 619 | /// Builds a DataAccess from an Access array 620 | fn build_access( 621 | ac: &mut Compiler, 622 | runtime: &mut Runtime, 623 | program_archive: &ProgramArchive, 624 | name: &str, 625 | access: &[Access], 626 | ) -> Result { 627 | let mut access_vec = Vec::new(); 628 | 629 | for a in access.iter() { 630 | match a { 631 | Access::ArrayAccess(expression) => { 632 | let index_access = process_expression(ac, runtime, program_archive, expression)?; 633 | let index = runtime 634 | .current_context()? 635 | .get_variable_value(&index_access)? 636 | .ok_or(ProgramError::EmptyDataItem)?; 637 | access_vec.push(SubAccess::Array(index)); 638 | } 639 | Access::ComponentAccess(signal) => { 640 | access_vec.push(SubAccess::Component(signal.to_string())); 641 | } 642 | } 643 | } 644 | 645 | Ok(DataAccess::new(name, access_vec)) 646 | } 647 | 648 | /// Executes an operation on two u32 values, performing the specified arithmetic or logical computation. 
649 | fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result { 650 | let res = match op { 651 | ExpressionInfixOpcode::Mul => lhs * rhs, 652 | ExpressionInfixOpcode::Div => { 653 | if rhs == 0 { 654 | return Err(ProgramError::OperationError("Division by zero".to_string())); 655 | } 656 | 657 | lhs / rhs 658 | } 659 | ExpressionInfixOpcode::Add => lhs + rhs, 660 | ExpressionInfixOpcode::Sub => { 661 | if lhs < rhs { 662 | return Err(ProgramError::OperationError( 663 | "Subtraction underflow".to_string(), 664 | )); 665 | } 666 | 667 | lhs - rhs 668 | } 669 | ExpressionInfixOpcode::Pow => lhs.pow(rhs), 670 | ExpressionInfixOpcode::IntDiv => { 671 | if rhs == 0 { 672 | return Err(ProgramError::OperationError( 673 | "Integer division by zero".to_string(), 674 | )); 675 | } 676 | 677 | lhs / rhs 678 | } 679 | ExpressionInfixOpcode::Mod => { 680 | if rhs == 0 { 681 | return Err(ProgramError::OperationError("Modulo by zero".to_string())); 682 | } 683 | 684 | lhs % rhs 685 | } 686 | ExpressionInfixOpcode::ShiftL => lhs << rhs, 687 | ExpressionInfixOpcode::ShiftR => lhs >> rhs, 688 | ExpressionInfixOpcode::LesserEq => { 689 | if lhs <= rhs { 690 | 1 691 | } else { 692 | 0 693 | } 694 | } 695 | ExpressionInfixOpcode::GreaterEq => { 696 | if lhs >= rhs { 697 | 1 698 | } else { 699 | 0 700 | } 701 | } 702 | ExpressionInfixOpcode::Lesser => { 703 | if lhs < rhs { 704 | 1 705 | } else { 706 | 0 707 | } 708 | } 709 | ExpressionInfixOpcode::Greater => { 710 | if lhs > rhs { 711 | 1 712 | } else { 713 | 0 714 | } 715 | } 716 | ExpressionInfixOpcode::Eq => { 717 | if lhs == rhs { 718 | 1 719 | } else { 720 | 0 721 | } 722 | } 723 | ExpressionInfixOpcode::NotEq => { 724 | if lhs != rhs { 725 | 1 726 | } else { 727 | 0 728 | } 729 | } 730 | ExpressionInfixOpcode::BoolOr => { 731 | if lhs != 0 || rhs != 0 { 732 | 1 733 | } else { 734 | 0 735 | } 736 | } 737 | ExpressionInfixOpcode::BoolAnd => { 738 | if lhs != 0 && rhs != 0 { 739 | 1 740 | } else { 741 | 0 742 | } 743 | 
} 744 | ExpressionInfixOpcode::BitOr => lhs | rhs, 745 | ExpressionInfixOpcode::BitAnd => lhs & rhs, 746 | ExpressionInfixOpcode::BitXor => lhs ^ rhs, 747 | }; 748 | 749 | Ok(res) 750 | } 751 | 752 | /// Executes a prefix operation on a u32 value, performing the specified arithmetic or logical computation. 753 | fn execute_prefix_op(op: &ExpressionPrefixOpcode, rhs: u32) -> Result { 754 | let (lhs_value, infix_op) = to_equivalent_infix(op); 755 | execute_op(lhs_value, rhs, &infix_op) 756 | } 757 | 758 | fn to_equivalent_infix(op: &ExpressionPrefixOpcode) -> (u32, ExpressionInfixOpcode) { 759 | match op { 760 | ExpressionPrefixOpcode::Sub => (0, ExpressionInfixOpcode::Sub), 761 | ExpressionPrefixOpcode::BoolNot => (0, ExpressionInfixOpcode::Eq), 762 | ExpressionPrefixOpcode::Complement => (u32::MAX, ExpressionInfixOpcode::BitXor), 763 | } 764 | } 765 | 766 | #[cfg(test)] 767 | mod tests { 768 | use super::*; 769 | use circom_program_structure::ast::{ExpressionInfixOpcode, ExpressionPrefixOpcode}; 770 | 771 | #[test] 772 | fn test_execute_op() { 773 | assert_eq!(execute_op(3, 4, &ExpressionInfixOpcode::Add).unwrap(), 7); 774 | assert_eq!(execute_op(10, 5, &ExpressionInfixOpcode::Sub).unwrap(), 5); 775 | assert_eq!(execute_op(6, 3, &ExpressionInfixOpcode::Mul).unwrap(), 18); 776 | assert_eq!(execute_op(9, 3, &ExpressionInfixOpcode::Div).unwrap(), 3); 777 | assert_eq!(execute_op(7, 3, &ExpressionInfixOpcode::Mod).unwrap(), 1); 778 | assert_eq!(execute_op(2, 3, &ExpressionInfixOpcode::Pow).unwrap(), 8); 779 | assert_eq!( 780 | execute_op(8, 2, &ExpressionInfixOpcode::ShiftL).unwrap(), 781 | 32 782 | ); 783 | assert_eq!(execute_op(8, 2, &ExpressionInfixOpcode::ShiftR).unwrap(), 2); 784 | assert_eq!(execute_op(5, 5, &ExpressionInfixOpcode::Eq).unwrap(), 1); 785 | assert_eq!(execute_op(5, 4, &ExpressionInfixOpcode::NotEq).unwrap(), 1); 786 | assert_eq!(execute_op(1, 0, &ExpressionInfixOpcode::BoolOr).unwrap(), 1); 787 | assert_eq!( 788 | execute_op(1, 1, 
&ExpressionInfixOpcode::BoolAnd).unwrap(), 789 | 1 790 | ); 791 | assert_eq!(execute_op(1, 1, &ExpressionInfixOpcode::BitOr).unwrap(), 1); 792 | assert_eq!(execute_op(1, 1, &ExpressionInfixOpcode::BitAnd).unwrap(), 1); 793 | assert_eq!(execute_op(1, 1, &ExpressionInfixOpcode::BitXor).unwrap(), 0); 794 | } 795 | 796 | #[test] 797 | fn test_execute_op_errors() { 798 | assert!(execute_op(10, 0, &ExpressionInfixOpcode::Div).is_err()); 799 | assert!(execute_op(10, 0, &ExpressionInfixOpcode::IntDiv).is_err()); 800 | assert!(execute_op(10, 0, &ExpressionInfixOpcode::Mod).is_err()); 801 | } 802 | 803 | #[test] 804 | fn test_execute_prefix_op() { 805 | assert_eq!( 806 | execute_prefix_op(&ExpressionPrefixOpcode::Sub, 5) 807 | .unwrap_err() 808 | .to_string(), 809 | "Operation error: Subtraction underflow" 810 | ); // 0 - 5 811 | assert_eq!( 812 | execute_prefix_op(&ExpressionPrefixOpcode::BoolNot, 0).unwrap(), 813 | 1 814 | ); // !0 == 1 815 | assert_eq!( 816 | execute_prefix_op(&ExpressionPrefixOpcode::BoolNot, 1).unwrap(), 817 | 0 818 | ); // !1 == 0 819 | assert_eq!( 820 | execute_prefix_op(&ExpressionPrefixOpcode::Complement, 0b1010).unwrap(), 821 | 0b1111_1111_1111_1111_1111_1111_1111_0101 822 | ); // ~0b1010 823 | } 824 | 825 | #[test] 826 | fn test_to_equivalent_infix() { 827 | let (value, opcode) = to_equivalent_infix(&ExpressionPrefixOpcode::Sub); 828 | assert_eq!(value, 0); 829 | assert!(matches!(opcode, ExpressionInfixOpcode::Sub)); 830 | 831 | let (value, opcode) = to_equivalent_infix(&ExpressionPrefixOpcode::BoolNot); 832 | assert_eq!(value, 0); 833 | assert!(matches!(opcode, ExpressionInfixOpcode::Eq)); 834 | 835 | let (value, opcode) = to_equivalent_infix(&ExpressionPrefixOpcode::Complement); 836 | assert_eq!(value, u32::MAX); 837 | assert!(matches!(opcode, ExpressionInfixOpcode::BitXor)); 838 | } 839 | } 840 | -------------------------------------------------------------------------------- /src/program.rs: 
-------------------------------------------------------------------------------- 1 | //! # Program Module 2 | //! 3 | //! This module processes the circom input program to build the arithmetic circuit. 4 | 5 | use crate::{ 6 | circom::{parser::parse_project, type_analysis::analyse_project}, 7 | cli::Args, 8 | compiler::{CircuitError, Compiler}, 9 | process::{process_expression, process_statements}, 10 | runtime::{DataAccess, DataType, Runtime, RuntimeError}, 11 | }; 12 | use bristol_circuit::BristolCircuitError; 13 | use circom_program_structure::ast::Expression; 14 | use std::io; 15 | use thiserror::Error; 16 | 17 | /// Parses a given Circom program and constructs an arithmetic circuit from it. 18 | pub fn compile(args: &Args) -> Result { 19 | let mut compiler = Compiler::new(); 20 | let mut runtime = Runtime::new(); 21 | let mut program_archive = parse_project(args)?; 22 | 23 | analyse_project(&mut program_archive)?; 24 | 25 | match program_archive.get_main_expression() { 26 | Expression::Call { id, args, .. } => { 27 | let template_data = program_archive.get_template_data(id); 28 | 29 | // Get values 30 | let mut values: Vec> = Vec::new(); 31 | for expression in args { 32 | let access = 33 | process_expression(&mut compiler, &mut runtime, &program_archive, expression)?; 34 | let value = runtime.current_context()?.get_variable_value(&access)?; 35 | values.push(value); 36 | } 37 | 38 | // Get and declare arguments 39 | let names = template_data.get_name_of_params(); 40 | for (name, &value) in names.iter().zip(values.iter()) { 41 | let signal_gen = runtime.get_signal_gen(); 42 | runtime.current_context()?.declare_item( 43 | DataType::Variable, 44 | name, 45 | &[], 46 | signal_gen, 47 | )?; 48 | runtime 49 | .current_context()? 
50 | .set_variable(&DataAccess::new(name, Vec::new()), value)?; 51 | } 52 | 53 | // Process the main component 54 | let statements = template_data.get_body_as_vec(); 55 | process_statements(&mut compiler, &mut runtime, &program_archive, statements)?; 56 | 57 | for (ikey, (_ivs, _ivh)) in template_data.get_inputs().iter() { 58 | let filter = format!("0.{}", ikey); 59 | compiler.add_inputs(compiler.get_signals(filter)); 60 | } 61 | 62 | for (okey, (_ovs, _ovh)) in template_data.get_outputs().iter() { 63 | let filter = format!("0.{}", okey); 64 | let signals = compiler.get_signals(filter); 65 | compiler.add_outputs(signals); 66 | } 67 | } 68 | _ => return Err(ProgramError::MainExpressionNotACall), 69 | } 70 | 71 | compiler.update_type(args.value_type)?; 72 | 73 | Ok(compiler) 74 | } 75 | 76 | /// Program errors 77 | #[derive(Error, Debug)] 78 | pub enum ProgramError { 79 | #[error("Analysis error")] 80 | AnalysisError, 81 | #[error("Call error")] 82 | CallError, 83 | #[error("Circuit error: {0}")] 84 | CircuitError(CircuitError), 85 | #[error("Empty data item")] 86 | EmptyDataItem, 87 | #[error("Expression not implemented")] 88 | ExpressionNotImplemented, 89 | #[error("Input initialization error")] 90 | InputInitializationError, 91 | #[error("Invalid data type")] 92 | InvalidDataType, 93 | #[error("IO error: {0}")] 94 | IOError(#[from] io::Error), 95 | #[error("JSON serialization error: {0}")] 96 | JsonSerializationError(#[from] serde_json::Error), 97 | #[error("Main expression not a call")] 98 | MainExpressionNotACall, 99 | #[error("Operation error: {0}")] 100 | OperationError(String), 101 | #[error("Operation not supported")] 102 | OperationNotSupported, 103 | #[error("Output directory creation error")] 104 | OutputDirectoryCreationError, 105 | #[error("Parsing error")] 106 | ParsingError, 107 | #[error("Runtime error: {0}")] 108 | RuntimeError(RuntimeError), 109 | #[error("Statement not implemented")] 110 | StatementNotImplemented, 111 | #[error("Signal 
substitution not implemented")] 112 | SignalSubstitutionNotImplemented, 113 | #[error("Undefined function or template")] 114 | UndefinedFunctionOrTemplate, 115 | #[error(transparent)] 116 | BristolCircuitError(#[from] BristolCircuitError), 117 | } 118 | -------------------------------------------------------------------------------- /src/topological_sort.rs: -------------------------------------------------------------------------------- 1 | use crate::compiler::CircuitError; 2 | 3 | pub fn topological_sort( 4 | len: usize, 5 | get_deps: &dyn Fn(usize) -> Vec, 6 | ) -> Result, CircuitError> { 7 | let mut sorted = Vec::with_capacity(len); 8 | let mut visiting = vec![false; len]; 9 | let mut visited = vec![false; len]; 10 | 11 | for i in 0..len { 12 | topological_sort_visit(i, &mut visiting, &mut visited, get_deps, &mut sorted)?; 13 | } 14 | 15 | assert!( 16 | sorted.len() == len, 17 | "Topological sort did not return all elements" 18 | ); 19 | 20 | Ok(sorted) 21 | } 22 | 23 | fn topological_sort_visit( 24 | i: usize, 25 | visiting: &mut [bool], 26 | visited: &mut [bool], 27 | get_deps: &dyn Fn(usize) -> Vec, 28 | sorted: &mut Vec, 29 | ) -> Result<(), CircuitError> { 30 | if visited[i] { 31 | return Ok(()); 32 | } 33 | 34 | if visiting[i] { 35 | return Err(CircuitError::CyclicDependency { 36 | message: format!("detected at i={}", i), 37 | }); 38 | } 39 | 40 | visiting[i] = true; 41 | 42 | for j in get_deps(i) { 43 | topological_sort_visit(j, visiting, visited, get_deps, sorted)?; 44 | } 45 | 46 | sorted.push(i); 47 | visited[i] = true; 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /tests/circuits/integration/addZero.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.1.0; 2 | 3 | template addZero() { 4 | signal input in; 5 | signal output out; 6 | 7 | out <== in + 0; 8 | } 9 | 10 | component main = addZero(); 11 | 
// --------------------------------------------------------------------------------
// /tests/circuits/integration/arrayAssignment.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: a whole 2x2 signal array is wired into a sub-component input in a single
// statement (`a.in <== a_in;` in componentB); componentA sums the four elements.
template componentA () {
    signal input in[2][2];
    signal output out;

    out <== in[0][0] + in[0][1] + in[1][0] + in[1][1];
}

template componentB() {
    signal input a_in[2][2];
    signal output out;

    component a = componentA();
    a.in <== a_in;

    out <== a.out;
}

component main = componentB();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/constantSum.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: the only output is driven by a constant expression; there are no inputs.
template constantSum() {
    signal output out;

    out <== 3 + 5;
}

component main = constantSum();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/directOutput.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: the output is assigned a literal directly.
template directOutput() {
    signal output out;
    out <== 42;
}

component main = directOutput();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/indexOutOfBounds.circom:
// --------------------------------------------------------------------------------
// This circuit should fail because of out of bounds error

pragma circom 2.1.0;

// `arr` has 10 elements but the loop writes indices 0..99.
template indexOutOfBounds() {
    signal arr[10];

    for (var i = 0; i < 100; i++) {
        arr[i] <== 1;
    }
}

component main = indexOutOfBounds();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/infixOps.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: one output per infix operator. Output names encode the operator and the input
// indices, e.g. `mul_2_3` is `x2 * x3` and `leq_4_3` is `x4 <= x3`.
template infixOps() {
    signal input x0;
    signal input x1;
    signal input x2;
    signal input x3;
    signal input x4;
    signal input x5;

    signal output mul_2_3;
    // signal output div_4_3; // unsupported for NumberU32
    signal output idiv_4_3;
    signal output add_3_4;
    signal output sub_4_1;
    signal output pow_2_4;
    signal output mod_5_3;
    signal output shl_5_1;
    signal output shr_5_1;
    signal output leq_2_3;
    signal output leq_3_3;
    signal output leq_4_3;
    signal output geq_2_3;
    signal output geq_3_3;
    signal output geq_4_3;
    signal output lt_2_3;
    signal output lt_3_3;
    signal output lt_4_3;
    signal output gt_2_3;
    signal output gt_3_3;
    signal output gt_4_3;
    signal output eq_2_3;
    signal output eq_3_3;
    signal output neq_2_3;
    signal output neq_3_3;
    signal output or_0_1;
    signal output and_0_1;
    signal output bit_or_1_3;
    signal output bit_and_1_3;
    signal output bit_xor_1_3;

    mul_2_3 <== x2 * x3;
    // div_4_3 <== x4 / x3;
    idiv_4_3 <== x4 \ x3;
    add_3_4 <== x3 + x4;
    sub_4_1 <== x4 - x1;
    pow_2_4 <== x2 ** x4;
    mod_5_3 <== x5 % x3;
    shl_5_1 <== x5 << x1;
    shr_5_1 <== x5 >> x1;
    leq_2_3 <== x2 <= x3;
    leq_3_3 <== x3 <= x3;
    leq_4_3 <== x4 <= x3;
    geq_2_3 <== x2 >= x3;
    geq_3_3 <== x3 >= x3;
    geq_4_3 <== x4 >= x3;
    lt_2_3 <== x2 < x3;
    lt_3_3 <== x3 < x3;
    lt_4_3 <== x4 < x3;
    gt_2_3 <== x2 > x3;
    gt_3_3 <== x3 > x3;
    gt_4_3 <== x4 > x3;
    eq_2_3 <== x2 == x3;
    eq_3_3 <== x3 == x3;
    neq_2_3 <== x2 != x3;
    neq_3_3 <== x3 != x3;
    or_0_1 <== x0 || x1;
    and_0_1 <== x0 && x1;
    bit_or_1_3 <== x1 | x3;
    bit_and_1_3 <== x1 & x3;
    bit_xor_1_3 <== x1 ^ x3;
}

component main = infixOps();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/mainTemplateArgument.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: the main template takes a compile-time parameter (`argument`, passed as 100).
template mainComponent (argument) {
    signal input in;
    signal output out;

    out <== in + argument;
}

component main = mainComponent(100);

// --------------------------------------------------------------------------------
// /tests/circuits/integration/matElemMul.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Matrix multiplication by element
template matElemMul (m,n) {
    signal input a[m][n];
    signal input b[m][n];
    signal output out[m][n];

    for (var i=0; i < m; i++) {
        for (var j=0; j < n; j++) {
            out[i][j] <== a[i][j] * b[i][j];
        }
    }
}

component main = matElemMul(2,2);

// --------------------------------------------------------------------------------
// /tests/circuits/integration/prefixOps.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: prefix operators on signals: negation on `a`, and `!` / `~` on each input.
template prefixOps() {
    signal input a;
    signal input b;
    signal input c;

    signal output negateA;

    signal output notA;
    signal output notB;
    signal output notC;

    signal output complementA;
    signal output complementB;
    signal output complementC;

    negateA <== -a;

    notA <== !a;
    notB <== !b;
    notC <== !c;

    complementA <== ~a;
    complementB <== ~b;
    complementC <== ~c;
}

component main = prefixOps();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/sum.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Two element sum
template sum () {
    signal input a;
    signal input b;
    signal output out;

    out <== a + b;
}

component main = sum();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/underConstrained.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: the output `x` is declared but never assigned.
template underConstrained() {
    signal output x;
}

component main = underConstrained();

// --------------------------------------------------------------------------------
// /tests/circuits/integration/xEqX.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.0;

// Fixture: a signal compared with itself (`x == x`).
template xEqX() {
    signal input x;
    signal output out;

    out <== x == x;
}

component main = xEqX();

// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/ArgMax.circom:
// --------------------------------------------------------------------------------
// from 0xZKML/zk-mnist

pragma circom 2.0.0;

include "./circomlib/switcher.circom";

// Index of the largest element of `in`. `maxs[i]` / `amaxs[i]` hold the running maximum and its
// index; at each step `gts[i]` (in[i] > maxs[i]) drives two Switchers whose `outL` feeds the
// next running value.
template ArgMax (n) {
    signal input in[n];
    signal output out;

    // assert (out < n);
    signal gts[n]; // store comparators
    component switchers[n+1]; // switcher for comparing maxs
    component aswitchers[n+1]; // switcher for arg max

    signal maxs[n+1];
    signal amaxs[n+1];

    maxs[0] <== in[0];
    amaxs[0] <== 0;
    for(var i = 0; i < n; i++) {
        // NOTE(review): the trailing comment below mentions 252 but no such constant appears
        // in this comparison — it looks stale.
        gts[i] <== in[i] > maxs[i]; // changed to 252 (maximum) for better compatibility
        switchers[i+1] = Switcher();
        aswitchers[i+1] = Switcher();

        switchers[i+1].sel <== gts[i];
        switchers[i+1].L <== maxs[i];
        switchers[i+1].R <== in[i];

        aswitchers[i+1].sel <== gts[i];
        aswitchers[i+1].L <== amaxs[i];
        aswitchers[i+1].R <== i;
        amaxs[i+1] <== aswitchers[i+1].outL;
        maxs[i+1] <== switchers[i+1].outL;
    }

    out <== amaxs[n];
}

component main = ArgMax(5);

/* INPUT = {
    "in": ["2","3","1","5","4"],
    "out": "3"
} */
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/AveragePooling2D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

include "./SumPooling2D.circom";

// AveragePooling2D layer, poolSize is required to be equal for both dimensions, might lose precision compared to SumPooling2D
template AveragePooling2D (nRows, nCols, nChannels, poolSize, strides) {
    signal input in[nRows][nCols][nChannels];
    signal output out[(nRows-poolSize)\strides+1][(nCols-poolSize)\strides+1][nChannels];
    // signal input remainder[(nRows-poolSize)\strides+1][(nCols-poolSize)\strides+1][nChannels];

    component sumPooling2D = SumPooling2D (nRows, nCols, nChannels, poolSize, strides);

    // NOTE(review): DAMAGED IN THIS DUMP. Everything from the `<` of this loop condition up to a
    // later `>` was stripped (apparently HTML-tag removal). That takes out the rest of this file
    // and the start of BatchNormalization2D.circom, including its file header. The text after
    // `i` below is BatchNormalization2D.circom line 19 onward. Restore both files from upstream.
    for (var i=0; i> n;
            }
        }
    }
}

// component main { public [ out ] } = BatchNormalization2D(1, 1, 1, 1000);

/* INPUT = {
    "in": ["123"],
    "a": ["234"],
    "b": ["345678"],
    "out": ["374"],
    "remainder": ["460"]
} */
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/Conv1D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";

// Conv1D layer with valid padding
// n = 10 to the power of the number of decimal places
template Conv1D (nInputs, nChannels, nFilters, kernelSize, strides, n) {
    signal input in[nInputs][nChannels];
    signal input weights[kernelSize][nChannels][nFilters];
    signal input bias[nFilters];
    signal output out[(nInputs-kernelSize)\strides+1][nFilters];
    // signal input remainder[(nInputs-kernelSize)\strides+1][nFilters];

    component mul[(nInputs-kernelSize)\strides+1][nChannels][nFilters];
    component elemSum[(nInputs-kernelSize)\strides+1][nChannels][nFilters];
    component sum[(nInputs-kernelSize)\strides+1][nFilters];

    for (var i=0; i<(nInputs-kernelSize)\strides+1; i++) {
        // NOTE(review): DAMAGED IN THIS DUMP. The text from `<` in the next loop condition up to
        // a later `>` was stripped, so original lines 22-39 of this file are missing.
        for (var j=0; j> n;
        }
    }
}
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/Conv2D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";

// Conv2D layer with valid padding
// n = 10 to the power of the number of decimal places
template Conv2D (nRows, nCols, nChannels, nFilters, kernelSize, strides, n) {
    signal input in[nRows][nCols][nChannels];
    signal input weights[kernelSize][kernelSize][nChannels][nFilters];
    signal input bias[nFilters];
    // NOTE(review): `out` is declared as an input here, whereas Conv1D declares it as an
    // output — confirm this is intended.
    signal input out[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nFilters];
    // signal input remainder[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nFilters];

    component mul[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nChannels][nFilters];
    component elemSum[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nChannels][nFilters];
    component sum[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nFilters];

    for (var i=0; i<(nRows-kernelSize)\strides+1; i++) {
        for (var j=0; j<(nCols-kernelSize)\strides+1; j++) {
            // NOTE(review): DAMAGED IN THIS DUMP. The text from `<` in the next loop condition up
            // to a later `>` was stripped, so original lines 23-44 of this file are missing.
            for (var k=0; k> n;
            }
        }
    }
}
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/Dense.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

include "./circomlib-matrix/matMul.circom";
// Dense layer
// n = 10 to the power of the number of decimal places
template Dense (nInputs, nOutputs, n) {
    signal input in[nInputs];
    signal input weights[nInputs][nOutputs];
    signal input bias[nOutputs];
    signal input out[nOutputs];
    signal input remainder[nOutputs];

    component dot[nOutputs];

    // NOTE(review): DAMAGED IN THIS DUMP. The text from `<` in this loop condition up to a later
    // `>` was stripped, so original lines 16-23 of this file are missing.
    for (var i=0; i> n;
    }
}
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/DepthwiseConv2D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.1;
// include "./Conv2D.circom";

include "./circomlib/sign.circom";
include "./circomlib/bitify.circom";
include "./circomlib/comparators.circom";
include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";

// Depthwise Convolution layer with valid padding
// Note that nFilters must be a multiple of nChannels
// n = 10 to the power of the number of decimal places
// component main = DepthwiseConv2D(34, 34, 8, 8, 3, 1);
template DepthwiseConv2D (nRows, nCols, nChannels, nFilters, kernelSize, strides, n) {
    var outRows = (nRows-kernelSize)\strides+1;
    var outCols = (nCols-kernelSize)\strides+1;

    signal input in[nRows][nCols][nChannels];
    signal input weights[kernelSize][kernelSize][nFilters]; // weights are 3d because depth is 1
    signal input bias[nFilters];
    // signal input remainder[outRows][outCols][nFilters];

    signal output out[outRows][outCols][nFilters];

    component mul[outRows][outCols][nFilters];
    component elemSum[outRows][outCols][nFilters];

    var valid_groups = nFilters % nChannels;
    var filtersPerChannel = nFilters / nChannels;

    // Enforces the "nFilters must be a multiple of nChannels" requirement in-circuit:
    // the remainder is routed through IsZero and its output is constrained to 1.
    signal groups;
    groups <== valid_groups;
    component is_zero = IsZero();
    is_zero.in <== groups;
    is_zero.out === 1;

    // NOTE(review): DAMAGED IN THIS DUMP. The text from `<` in this loop condition up to a later
    // `>` was stripped, so original lines 39-58 of this file are missing.
    for (var row=0; row> n;
                    }
                }
            }
        }
    }

// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/Flatten2D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

// Flatten layer with that accepts a 2D input
template Flatten2D (nRows, nCols, nChannels) {
    signal input in[nRows][nCols][nChannels];
    signal output out[nRows*nCols*nChannels];

    var idx = 0;

    // NOTE(review): DAMAGED IN THIS DUMP. Everything from the `<` of this loop condition up to a
    // later `>` was stripped. That takes out the rest of this file and, per the directory
    // listing, the GlobalAveragePooling2D, GlobalMaxPooling2D, GlobalSumPooling2D, MaxPooling2D
    // and NaiveSearch files, plus most of PointwiseConv2D.circom. The text after `i` below
    // appears to be PointwiseConv2D.circom line 32 onward. Restore from upstream.
    for (var i=0; i> n;
            }
        }
    }
}


// component main = PointwiseConv2D(32, 32, 8, 16);

// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/ReLU.circom:
// --------------------------------------------------------------------------------
pragma circom 2.0.0;

include "./util.circom";

// ReLU layer
// Output is `in` multiplied by the `IsPositive` indicator of `in` (IsPositive comes from util.circom).
template ReLU () {
    signal input in;
    signal output out;

    component isPositive = IsPositive();

    isPositive.in <== in;

    out <== in * isPositive.out;
}
// --------------------------------------------------------------------------------
// /tests/circuits/machine-learning/SeparableConv2D.circom:
// --------------------------------------------------------------------------------
pragma circom 2.1.1;

include "./PointwiseConv2D.circom";
include "./DepthwiseConv2D.circom";

// Separable convolution layer with valid padding.
// Quantization is done by the caller by multiplying float values by 10**exp.
8 | template SeparableConv2D (nRows, nCols, nChannels, nDepthFilters, nPointFilters, depthKernelSize, strides, n) { 9 | var outRows = (nRows-depthKernelSize)\strides+1; 10 | var outCols = (nCols-depthKernelSize)\strides+1; 11 | 12 | signal input in[nRows][nCols][nChannels]; 13 | signal input depthWeights[depthKernelSize][depthKernelSize][nDepthFilters]; // weights are 3d because depth is 1 14 | signal input depthBias[nDepthFilters]; 15 | // signal input depthRemainder[outRows][outCols][nDepthFilters]; 16 | signal input depthOut[outRows][outCols][nDepthFilters]; 17 | 18 | signal input pointWeights[nChannels][nPointFilters]; // weights are 2d because depthKernelSize is one 19 | signal input pointBias[nPointFilters]; 20 | 21 | // signal input pointRemainder[outRows][outCols][nPointFilters]; 22 | signal oputput pointOut[outRows][outCols][nPointFilters]; 23 | 24 | component depthConv = DepthwiseConv2D(nRows, nCols, nChannels, nDepthFilters, depthKernelSize, strides, n); 25 | component pointConv = PointwiseConv2D(outRows, outCols, nDepthFilters, nPointFilters, n); 26 | 27 | for (var filter=0; filter 0) { 15 | sum[idx] <== sum[idx-1] + a[i][j]; 16 | } 17 | idx++; 18 | } 19 | } 20 | 21 | out <== sum[m*n-1]; 22 | } -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib-matrix/matMul.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | include "matElemMul.circom"; 4 | include "matElemSum.circom"; 5 | 6 | // matrix multiplication 7 | template matMul (m,n,p) { 8 | signal input a[m][n]; 9 | signal input b[n][p]; 10 | signal output out[m][p]; 11 | 12 | component matElemMulComp[m][p]; 13 | component matElemSumComp[m][p]; 14 | 15 | for (var i=0; i < m; i++) { 16 | for (var j=0; j < p; j++) { 17 | matElemMulComp[i][j] = matElemMul(1,n); 18 | matElemSumComp[i][j] = matElemSum(1,n); 19 | for (var k=0; k < n; k++) { 20 | matElemMulComp[i][j].a[0][k] 
<== a[i][k]; 21 | matElemMulComp[i][j].b[0][k] <== b[k][j]; 22 | } 23 | for (var k=0; k < n; k++) { 24 | matElemSumComp[i][j].a[0][k] <== matElemMulComp[i][j].out[0][k]; 25 | } 26 | out[i][j] <== matElemSumComp[i][j].out; 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/aliascheck.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "compconstant.circom"; 22 | 23 | 24 | template AliasCheck() { 25 | 26 | signal input in[254]; 27 | 28 | component compConstant = CompConstant(-1); 29 | 30 | for (var i=0; i<254; i++) in[i] ==> compConstant.in[i]; 31 | 32 | compConstant.out === 0; 33 | } 34 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/babyjub.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 
5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "bitify.circom"; 22 | include "escalarmulfix.circom"; 23 | 24 | template BabyAdd() { 25 | signal input x1; 26 | signal input y1; 27 | signal input x2; 28 | signal input y2; 29 | signal output xout; 30 | signal output yout; 31 | 32 | signal beta; 33 | signal gamma; 34 | signal delta; 35 | signal tau; 36 | 37 | var a = 168700; 38 | var d = 168696; 39 | 40 | beta <== x1*y2; 41 | gamma <== y1*x2; 42 | delta <== (-a*x1+y1)*(x2 + y2); 43 | tau <== beta * gamma; 44 | 45 | xout <-- (beta + gamma) / (1+ d*tau); 46 | (1+ d*tau) * xout === (beta + gamma); 47 | 48 | yout <-- (delta + a*beta - gamma) / (1-d*tau); 49 | (1-d*tau)*yout === (delta + a*beta - gamma); 50 | } 51 | 52 | template BabyDbl() { 53 | signal input x; 54 | signal input y; 55 | signal output xout; 56 | signal output yout; 57 | 58 | component adder = BabyAdd(); 59 | adder.x1 <== x; 60 | adder.y1 <== y; 61 | adder.x2 <== x; 62 | adder.y2 <== y; 63 | 64 | adder.xout ==> xout; 65 | adder.yout ==> yout; 66 | } 67 | 68 | 69 | template BabyCheck() { 70 | signal input x; 71 | signal input y; 72 | 73 | signal x2; 74 | signal y2; 75 | 76 | var a = 168700; 77 | var d = 168696; 78 | 79 | x2 <== x*x; 80 | y2 <== y*y; 81 | 82 | a*x2 + y2 === 1 + d*x2*y2; 83 | } 84 | 85 | // Extracts the public key from private key 86 | template BabyPbk() { 87 | signal input in; 
88 | signal output Ax; 89 | signal output Ay; 90 | 91 | var BASE8[2] = [ 92 | 5299619240641551281634865583518297030282874472190772894086521144482721001553, 93 | 16950150798460657717958625567821834550301663161624707787222815936182638968203 94 | ]; 95 | 96 | component pvkBits = Num2Bits(253); 97 | pvkBits.in <== in; 98 | 99 | component mulFix = EscalarMulFix(253, BASE8); 100 | 101 | var i; 102 | for (i=0; i<253; i++) { 103 | mulFix.e[i] <== pvkBits.out[i]; 104 | } 105 | Ax <== mulFix.out[0]; 106 | Ay <== mulFix.out[1]; 107 | } 108 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/binsum.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | 20 | /* 21 | 22 | Binary Sum 23 | ========== 24 | 25 | This component creates a binary sum componet of ops operands and n bits each operand. 26 | 27 | e is Number of carries: Depends on the number of operands in the input. 28 | 29 | Main Constraint: 30 | in[0][0] * 2^0 + in[0][1] * 2^1 + ..... + in[0][n-1] * 2^(n-1) + 31 | + in[1][0] * 2^0 + in[1][1] * 2^1 + ..... + in[1][n-1] * 2^(n-1) + 32 | + .. 33 | + in[ops-1][0] * 2^0 + in[ops-1][1] * 2^1 + ..... 
+ in[ops-1][n-1] * 2^(n-1) + 34 | === 35 | out[0] * 2^0 + out[1] * 2^1 + + out[n+e-1] *2(n+e-1) 36 | 37 | To waranty binary outputs: 38 | 39 | out[0] * (out[0] - 1) === 0 40 | out[1] * (out[0] - 1) === 0 41 | . 42 | . 43 | . 44 | out[n+e-1] * (out[n+e-1] - 1) == 0 45 | 46 | */ 47 | 48 | 49 | /* 50 | This function calculates the number of extra bits in the output to do the full sum. 51 | */ 52 | pragma circom 2.0.0; 53 | 54 | function nbits(a) { 55 | var n = 1; 56 | var r = 0; 57 | while (n-1> k) & 1; 89 | 90 | // Ensure out is binary 91 | out[k] * (out[k] - 1) === 0; 92 | 93 | lout += out[k] * e2; 94 | 95 | e2 = e2+e2; 96 | } 97 | 98 | // Ensure the sum; 99 | 100 | lin === lout; 101 | } 102 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/bitify.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 
18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "comparators.circom"; 22 | include "aliascheck.circom"; 23 | 24 | 25 | template Num2Bits(n) { 26 | signal input in; 27 | signal output out[n]; 28 | var lc1=0; 29 | 30 | var e2=1; 31 | for (var i = 0; i> i) & 1; 33 | out[i] * (out[i] -1 ) === 0; 34 | lc1 += out[i] * e2; 35 | e2 = e2+e2; 36 | } 37 | 38 | lc1 === in; 39 | } 40 | 41 | template Num2Bits_strict() { 42 | signal input in; 43 | signal output out[254]; 44 | 45 | component aliasCheck = AliasCheck(); 46 | component n2b = Num2Bits(254); 47 | in ==> n2b.in; 48 | 49 | for (var i=0; i<254; i++) { 50 | n2b.out[i] ==> out[i]; 51 | n2b.out[i] ==> aliasCheck.in[i]; 52 | } 53 | } 54 | 55 | template Bits2Num(n) { 56 | signal input in[n]; 57 | signal output out; 58 | var lc1=0; 59 | 60 | var e2 = 1; 61 | for (var i = 0; i out; 67 | } 68 | 69 | template Bits2Num_strict() { 70 | signal input in[254]; 71 | signal output out; 72 | 73 | component aliasCheck = AliasCheck(); 74 | component b2n = Bits2Num(254); 75 | 76 | for (var i=0; i<254; i++) { 77 | in[i] ==> b2n.in[i]; 78 | in[i] ==> aliasCheck.in[i]; 79 | } 80 | 81 | b2n.out ==> out; 82 | } 83 | 84 | template Num2BitsNeg(n) { 85 | signal input in; 86 | signal output out[n]; 87 | var lc1=0; 88 | 89 | component isZero; 90 | 91 | isZero = IsZero(); 92 | 93 | var neg = n == 0 ? 0 : 2**n - in; 94 | 95 | for (var i = 0; i> i) & 1; 97 | out[i] * (out[i] -1 ) === 0; 98 | lc1 += out[i] * 2**i; 99 | } 100 | 101 | in ==> isZero.in; 102 | 103 | 104 | 105 | lc1 + isZero.out * 2**n === 2**n - in; 106 | } 107 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/comparators.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 
5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "bitify.circom"; 22 | include "binsum.circom"; 23 | 24 | template IsZero() { 25 | signal input in; 26 | signal output out; 27 | 28 | out <== in == 0; 29 | } 30 | 31 | 32 | template IsEqual() { 33 | signal input in[2]; 34 | signal output out; 35 | 36 | out <== in[0] == in[1]; 37 | } 38 | 39 | template ForceEqualIfEnabled() { 40 | signal input enabled; 41 | signal input in[2]; 42 | 43 | component isz = IsZero(); 44 | 45 | in[1] - in[0] ==> isz.in; 46 | 47 | (1 - isz.out)*enabled === 0; 48 | } 49 | 50 | /* 51 | // N is the number of bits the input have. 52 | // The MSF is the sign bit. 53 | template LessThan(n) { 54 | signal input in[2]; 55 | signal output out; 56 | 57 | component num2Bits0; 58 | component num2Bits1; 59 | 60 | component adder; 61 | 62 | adder = BinSum(n, 2); 63 | 64 | num2Bits0 = Num2Bits(n); 65 | num2Bits1 = Num2BitsNeg(n); 66 | 67 | in[0] ==> num2Bits0.in; 68 | in[1] ==> num2Bits1.in; 69 | 70 | var i; 71 | for (i=0;i adder.in[0][i]; 73 | num2Bits1.out[i] ==> adder.in[1][i]; 74 | } 75 | 76 | adder.out[n-1] ==> out; 77 | } 78 | */ 79 | 80 | template LessThan(n) { 81 | assert(n <= 252); 82 | signal input in[2]; 83 | signal output out; 84 | 85 | component n2b = Num2Bits(n+1); 86 | 87 | n2b.in <== in[0]+ (1< out; 105 | } 106 | 107 | // N is the number of bits the input have. 
108 | // The MSF is the sign bit. 109 | template GreaterThan(n) { 110 | signal input in[2]; 111 | signal output out; 112 | 113 | component lt = LessThan(n); 114 | 115 | lt.in[0] <== in[1]; 116 | lt.in[1] <== in[0]; 117 | lt.out ==> out; 118 | } 119 | 120 | // N is the number of bits the input have. 121 | // The MSF is the sign bit. 122 | template GreaterEqThan(n) { 123 | signal input in[2]; 124 | signal output out; 125 | 126 | component lt = LessThan(n); 127 | 128 | lt.in[0] <== in[1]; 129 | lt.in[1] <== in[0]+1; 130 | lt.out ==> out; 131 | } 132 | 133 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/compconstant.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 
18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "bitify.circom"; 22 | 23 | // Returns 1 if in (in binary) > ct 24 | 25 | template CompConstant(ct) { 26 | signal input in[254]; 27 | signal output out; 28 | 29 | signal parts[127]; 30 | signal sout; 31 | 32 | var clsb; 33 | var cmsb; 34 | var slsb; 35 | var smsb; 36 | 37 | var sum=0; 38 | 39 | var b = (1 << 128) -1; 40 | var a = 1; 41 | var e = 1; 42 | var i; 43 | 44 | for (i=0;i<127; i++) { 45 | clsb = (ct >> (i*2)) & 1; 46 | cmsb = (ct >> (i*2+1)) & 1; 47 | slsb = in[i*2]; 48 | smsb = in[i*2+1]; 49 | 50 | if ((cmsb==0)&&(clsb==0)) { 51 | parts[i] <== -b*smsb*slsb + b*smsb + b*slsb; 52 | } else if ((cmsb==0)&&(clsb==1)) { 53 | parts[i] <== a*smsb*slsb - a*slsb + b*smsb - a*smsb + a; 54 | } else if ((cmsb==1)&&(clsb==0)) { 55 | parts[i] <== b*smsb*slsb - a*smsb + a; 56 | } else { 57 | parts[i] <== -a*smsb*slsb + a; 58 | } 59 | 60 | sum = sum + parts[i]; 61 | 62 | b = b -e; 63 | a = a +e; 64 | e = e*2; 65 | } 66 | 67 | sout <== sum; 68 | 69 | component num2bits = Num2Bits(135); 70 | 71 | num2bits.in <== sout; 72 | 73 | out <== num2bits.out[127]; 74 | } 75 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/escalarmulany.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 
15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "montgomery.circom"; 22 | include "babyjub.circom"; 23 | include "comparators.circom"; 24 | 25 | template Multiplexor2() { 26 | signal input sel; 27 | signal input in[2][2]; 28 | signal output out[2]; 29 | 30 | out[0] <== (in[1][0] - in[0][0])*sel + in[0][0]; 31 | out[1] <== (in[1][1] - in[0][1])*sel + in[0][1]; 32 | } 33 | 34 | template BitElementMulAny() { 35 | signal input sel; 36 | signal input dblIn[2]; 37 | signal input addIn[2]; 38 | signal output dblOut[2]; 39 | signal output addOut[2]; 40 | 41 | component doubler = MontgomeryDouble(); 42 | component adder = MontgomeryAdd(); 43 | component selector = Multiplexor2(); 44 | 45 | 46 | sel ==> selector.sel; 47 | 48 | dblIn[0] ==> doubler.in[0]; 49 | dblIn[1] ==> doubler.in[1]; 50 | doubler.out[0] ==> adder.in1[0]; 51 | doubler.out[1] ==> adder.in1[1]; 52 | addIn[0] ==> adder.in2[0]; 53 | addIn[1] ==> adder.in2[1]; 54 | addIn[0] ==> selector.in[0][0]; 55 | addIn[1] ==> selector.in[0][1]; 56 | adder.out[0] ==> selector.in[1][0]; 57 | adder.out[1] ==> selector.in[1][1]; 58 | 59 | doubler.out[0] ==> dblOut[0]; 60 | doubler.out[1] ==> dblOut[1]; 61 | selector.out[0] ==> addOut[0]; 62 | selector.out[1] ==> addOut[1]; 63 | } 64 | 65 | // p is montgomery point 66 | // n must be <= 248 67 | // returns out in twisted edwards 68 | // Double is in montgomery to be linked; 69 | 70 | template SegmentMulAny(n) { 71 | signal input e[n]; 72 | signal input p[2]; 73 | signal output out[2]; 74 | signal output dbl[2]; 75 | 76 | component bits[n-1]; 77 | 78 | component e2m = Edwards2Montgomery(); 79 | 80 | p[0] ==> e2m.in[0]; 81 | p[1] ==> e2m.in[1]; 82 | 83 | var i; 84 | 85 | bits[0] = BitElementMulAny(); 86 | e2m.out[0] ==> bits[0].dblIn[0]; 87 | e2m.out[1] ==> bits[0].dblIn[1]; 88 | e2m.out[0] ==> bits[0].addIn[0]; 89 | e2m.out[1] ==> bits[0].addIn[1]; 90 | e[1] ==> 
bits[0].sel; 91 | 92 | for (i=1; i bits[i].dblIn[0]; 96 | bits[i-1].dblOut[1] ==> bits[i].dblIn[1]; 97 | bits[i-1].addOut[0] ==> bits[i].addIn[0]; 98 | bits[i-1].addOut[1] ==> bits[i].addIn[1]; 99 | e[i+1] ==> bits[i].sel; 100 | } 101 | 102 | bits[n-2].dblOut[0] ==> dbl[0]; 103 | bits[n-2].dblOut[1] ==> dbl[1]; 104 | 105 | component m2e = Montgomery2Edwards(); 106 | 107 | bits[n-2].addOut[0] ==> m2e.in[0]; 108 | bits[n-2].addOut[1] ==> m2e.in[1]; 109 | 110 | component eadder = BabyAdd(); 111 | 112 | m2e.out[0] ==> eadder.x1; 113 | m2e.out[1] ==> eadder.y1; 114 | -p[0] ==> eadder.x2; 115 | p[1] ==> eadder.y2; 116 | 117 | component lastSel = Multiplexor2(); 118 | 119 | e[0] ==> lastSel.sel; 120 | eadder.xout ==> lastSel.in[0][0]; 121 | eadder.yout ==> lastSel.in[0][1]; 122 | m2e.out[0] ==> lastSel.in[1][0]; 123 | m2e.out[1] ==> lastSel.in[1][1]; 124 | 125 | lastSel.out[0] ==> out[0]; 126 | lastSel.out[1] ==> out[1]; 127 | } 128 | 129 | // This function assumes that p is in the subgroup and it is different to 0 130 | 131 | template EscalarMulAny(n) { 132 | signal input e[n]; // Input in binary format 133 | signal input p[2]; // Point (Twisted format) 134 | signal output out[2]; // Point (Twisted format) 135 | 136 | var nsegments = (n-1)\148 +1; 137 | var nlastsegment = n - (nsegments-1)*148; 138 | 139 | component segments[nsegments]; 140 | component doublers[nsegments-1]; 141 | component m2e[nsegments-1]; 142 | component adders[nsegments-1]; 143 | component zeropoint = IsZero(); 144 | zeropoint.in <== p[0]; 145 | 146 | var s; 147 | var i; 148 | var nseg; 149 | 150 | for (s=0; s segments[s].e[i]; 158 | } 159 | 160 | if (s==0) { 161 | // force G8 point if input point is zero 162 | segments[s].p[0] <== p[0] + (5299619240641551281634865583518297030282874472190772894086521144482721001553 - p[0])*zeropoint.out; 163 | segments[s].p[1] <== p[1] + (16950150798460657717958625567821834550301663161624707787222815936182638968203 - p[1])*zeropoint.out; 164 | } else { 165 | 
doublers[s-1] = MontgomeryDouble(); 166 | m2e[s-1] = Montgomery2Edwards(); 167 | adders[s-1] = BabyAdd(); 168 | 169 | segments[s-1].dbl[0] ==> doublers[s-1].in[0]; 170 | segments[s-1].dbl[1] ==> doublers[s-1].in[1]; 171 | 172 | doublers[s-1].out[0] ==> m2e[s-1].in[0]; 173 | doublers[s-1].out[1] ==> m2e[s-1].in[1]; 174 | 175 | m2e[s-1].out[0] ==> segments[s].p[0]; 176 | m2e[s-1].out[1] ==> segments[s].p[1]; 177 | 178 | if (s==1) { 179 | segments[s-1].out[0] ==> adders[s-1].x1; 180 | segments[s-1].out[1] ==> adders[s-1].y1; 181 | } else { 182 | adders[s-2].xout ==> adders[s-1].x1; 183 | adders[s-2].yout ==> adders[s-1].y1; 184 | } 185 | segments[s].out[0] ==> adders[s-1].x2; 186 | segments[s].out[1] ==> adders[s-1].y2; 187 | } 188 | } 189 | 190 | if (nsegments == 1) { 191 | segments[0].out[0]*(1-zeropoint.out) ==> out[0]; 192 | segments[0].out[1]+(1-segments[0].out[1])*zeropoint.out ==> out[1]; 193 | } else { 194 | adders[nsegments-2].xout*(1-zeropoint.out) ==> out[0]; 195 | adders[nsegments-2].yout+(1-adders[nsegments-2].yout)*zeropoint.out ==> out[1]; 196 | } 197 | } 198 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/escalarmulfix.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 
15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "mux3.circom"; 22 | include "montgomery.circom"; 23 | include "babyjub.circom"; 24 | 25 | /* 26 | Window of 3 elements, it calculates 27 | out = base + base*in[0] + 2*base*in[1] + 4*base*in[2] 28 | out4 = 4*base 29 | 30 | The result should be compensated. 31 | */ 32 | 33 | /* 34 | 35 | The scalar is s = a0 + a1*2^3 + a2*2^6 + ...... + a81*2^243 36 | First We calculate Q = B + 2^3*B + 2^6*B + ......... + 2^246*B 37 | 38 | Then we calculate S1 = 2*2^246*B + (1 + a0)*B + (2^3 + a1)*B + .....+ (2^243 + a81)*B 39 | 40 | And Finaly we compute the result: RES = SQ - Q 41 | 42 | As you can see the input of the adders cannot be equal nor zero, except for the last 43 | substraction that it's done in montgomery. 44 | 45 | A good way to see it is that the accumulator input of the adder >= 2^247*B and the other input 46 | is the output of the windows that it's going to be <= 2^246*B 47 | */ 48 | template WindowMulFix() { 49 | signal input in[3]; 50 | signal input base[2]; 51 | signal output out[2]; 52 | signal output out8[2]; // Returns 8*Base (To be linked) 53 | 54 | component mux = MultiMux3(2); 55 | 56 | mux.s[0] <== in[0]; 57 | mux.s[1] <== in[1]; 58 | mux.s[2] <== in[2]; 59 | 60 | component dbl2 = MontgomeryDouble(); 61 | component adr3 = MontgomeryAdd(); 62 | component adr4 = MontgomeryAdd(); 63 | component adr5 = MontgomeryAdd(); 64 | component adr6 = MontgomeryAdd(); 65 | component adr7 = MontgomeryAdd(); 66 | component adr8 = MontgomeryAdd(); 67 | 68 | // in[0] -> 1*BASE 69 | 70 | mux.c[0][0] <== base[0]; 71 | mux.c[1][0] <== base[1]; 72 | 73 | // in[1] -> 2*BASE 74 | dbl2.in[0] <== base[0]; 75 | dbl2.in[1] <== base[1]; 76 | mux.c[0][1] <== dbl2.out[0]; 77 | mux.c[1][1] <== dbl2.out[1]; 78 | 79 | // in[2] -> 3*BASE 80 | adr3.in1[0] <== base[0]; 81 | adr3.in1[1] <== base[1]; 82 | adr3.in2[0] <== 
dbl2.out[0]; 83 | adr3.in2[1] <== dbl2.out[1]; 84 | mux.c[0][2] <== adr3.out[0]; 85 | mux.c[1][2] <== adr3.out[1]; 86 | 87 | // in[3] -> 4*BASE 88 | adr4.in1[0] <== base[0]; 89 | adr4.in1[1] <== base[1]; 90 | adr4.in2[0] <== adr3.out[0]; 91 | adr4.in2[1] <== adr3.out[1]; 92 | mux.c[0][3] <== adr4.out[0]; 93 | mux.c[1][3] <== adr4.out[1]; 94 | 95 | // in[4] -> 5*BASE 96 | adr5.in1[0] <== base[0]; 97 | adr5.in1[1] <== base[1]; 98 | adr5.in2[0] <== adr4.out[0]; 99 | adr5.in2[1] <== adr4.out[1]; 100 | mux.c[0][4] <== adr5.out[0]; 101 | mux.c[1][4] <== adr5.out[1]; 102 | 103 | // in[5] -> 6*BASE 104 | adr6.in1[0] <== base[0]; 105 | adr6.in1[1] <== base[1]; 106 | adr6.in2[0] <== adr5.out[0]; 107 | adr6.in2[1] <== adr5.out[1]; 108 | mux.c[0][5] <== adr6.out[0]; 109 | mux.c[1][5] <== adr6.out[1]; 110 | 111 | // in[6] -> 7*BASE 112 | adr7.in1[0] <== base[0]; 113 | adr7.in1[1] <== base[1]; 114 | adr7.in2[0] <== adr6.out[0]; 115 | adr7.in2[1] <== adr6.out[1]; 116 | mux.c[0][6] <== adr7.out[0]; 117 | mux.c[1][6] <== adr7.out[1]; 118 | 119 | // in[7] -> 8*BASE 120 | adr8.in1[0] <== base[0]; 121 | adr8.in1[1] <== base[1]; 122 | adr8.in2[0] <== adr7.out[0]; 123 | adr8.in2[1] <== adr7.out[1]; 124 | mux.c[0][7] <== adr8.out[0]; 125 | mux.c[1][7] <== adr8.out[1]; 126 | 127 | out8[0] <== adr8.out[0]; 128 | out8[1] <== adr8.out[1]; 129 | 130 | out[0] <== mux.out[0]; 131 | out[1] <== mux.out[1]; 132 | } 133 | 134 | 135 | /* 136 | This component does a multiplication of a escalar times a fix base 137 | Signals: 138 | e: The scalar in bits 139 | base: the base point in edwards format 140 | out: The result 141 | dbl: Point in Edwards to be linked to the next segment. 
142 | */ 143 | 144 | template SegmentMulFix(nWindows) { 145 | signal input e[nWindows*3]; 146 | signal input base[2]; 147 | signal output out[2]; 148 | signal output dbl[2]; 149 | 150 | var i; 151 | var j; 152 | 153 | // Convert the base to montgomery 154 | 155 | component e2m = Edwards2Montgomery(); 156 | e2m.in[0] <== base[0]; 157 | e2m.in[1] <== base[1]; 158 | 159 | component windows[nWindows]; 160 | component adders[nWindows]; 161 | component cadders[nWindows]; 162 | 163 | // In the last step we add an extra doubler so that numbers do not match. 164 | component dblLast = MontgomeryDouble(); 165 | 166 | for (i=0; i out[0]; 222 | cAdd.yout ==> out[1]; 223 | 224 | windows[nWindows-1].out8[0] ==> dbl[0]; 225 | windows[nWindows-1].out8[1] ==> dbl[1]; 226 | } 227 | 228 | 229 | /* 230 | This component multiplies a escalar times a fixed point BASE (twisted edwards format) 231 | Signals 232 | e: The escalar in binary format 233 | out: The output point in twisted edwards 234 | */ 235 | template EscalarMulFix(n, BASE) { 236 | signal input e[n]; // Input in binary format 237 | signal output out[2]; // Point (Twisted format) 238 | 239 | var nsegments = (n-1)\246 +1; // 249 probably would work. 
But I'm not sure and for security I keep 246 240 | var nlastsegment = n - (nsegments-1)*249; 241 | 242 | component segments[nsegments]; 243 | 244 | component m2e[nsegments-1]; 245 | component adders[nsegments-1]; 246 | 247 | var s; 248 | var i; 249 | var nseg; 250 | var nWindows; 251 | 252 | for (s=0; s m2e[s-1].in[0]; 275 | segments[s-1].dbl[1] ==> m2e[s-1].in[1]; 276 | 277 | m2e[s-1].out[0] ==> segments[s].base[0]; 278 | m2e[s-1].out[1] ==> segments[s].base[1]; 279 | 280 | if (s==1) { 281 | segments[s-1].out[0] ==> adders[s-1].x1; 282 | segments[s-1].out[1] ==> adders[s-1].y1; 283 | } else { 284 | adders[s-2].xout ==> adders[s-1].x1; 285 | adders[s-2].yout ==> adders[s-1].y1; 286 | } 287 | segments[s].out[0] ==> adders[s-1].x2; 288 | segments[s].out[1] ==> adders[s-1].y2; 289 | } 290 | } 291 | 292 | if (nsegments == 1) { 293 | segments[0].out[0] ==> out[0]; 294 | segments[0].out[1] ==> out[1]; 295 | } else { 296 | adders[nsegments-2].xout ==> out[0]; 297 | adders[nsegments-2].yout ==> out[1]; 298 | } 299 | } 300 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/mimc.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 
18 | */ 19 | pragma circom 2.0.0; 20 | 21 | template MiMC7(nrounds) { 22 | signal input x_in; 23 | signal input k; 24 | signal output out; 25 | 26 | var c[91] = [ 27 | 0, 28 | 20888961410941983456478427210666206549300505294776164667214940546594746570981, 29 | 15265126113435022738560151911929040668591755459209400716467504685752745317193, 30 | 8334177627492981984476504167502758309043212251641796197711684499645635709656, 31 | 1374324219480165500871639364801692115397519265181803854177629327624133579404, 32 | 11442588683664344394633565859260176446561886575962616332903193988751292992472, 33 | 2558901189096558760448896669327086721003508630712968559048179091037845349145, 34 | 11189978595292752354820141775598510151189959177917284797737745690127318076389, 35 | 3262966573163560839685415914157855077211340576201936620532175028036746741754, 36 | 17029914891543225301403832095880481731551830725367286980611178737703889171730, 37 | 4614037031668406927330683909387957156531244689520944789503628527855167665518, 38 | 19647356996769918391113967168615123299113119185942498194367262335168397100658, 39 | 5040699236106090655289931820723926657076483236860546282406111821875672148900, 40 | 2632385916954580941368956176626336146806721642583847728103570779270161510514, 41 | 17691411851977575435597871505860208507285462834710151833948561098560743654671, 42 | 11482807709115676646560379017491661435505951727793345550942389701970904563183, 43 | 8360838254132998143349158726141014535383109403565779450210746881879715734773, 44 | 12663821244032248511491386323242575231591777785787269938928497649288048289525, 45 | 3067001377342968891237590775929219083706800062321980129409398033259904188058, 46 | 8536471869378957766675292398190944925664113548202769136103887479787957959589, 47 | 19825444354178182240559170937204690272111734703605805530888940813160705385792, 48 | 16703465144013840124940690347975638755097486902749048533167980887413919317592, 49 | 
13061236261277650370863439564453267964462486225679643020432589226741411380501, 50 | 10864774797625152707517901967943775867717907803542223029967000416969007792571, 51 | 10035653564014594269791753415727486340557376923045841607746250017541686319774, 52 | 3446968588058668564420958894889124905706353937375068998436129414772610003289, 53 | 4653317306466493184743870159523234588955994456998076243468148492375236846006, 54 | 8486711143589723036499933521576871883500223198263343024003617825616410932026, 55 | 250710584458582618659378487568129931785810765264752039738223488321597070280, 56 | 2104159799604932521291371026105311735948154964200596636974609406977292675173, 57 | 16313562605837709339799839901240652934758303521543693857533755376563489378839, 58 | 6032365105133504724925793806318578936233045029919447519826248813478479197288, 59 | 14025118133847866722315446277964222215118620050302054655768867040006542798474, 60 | 7400123822125662712777833064081316757896757785777291653271747396958201309118, 61 | 1744432620323851751204287974553233986555641872755053103823939564833813704825, 62 | 8316378125659383262515151597439205374263247719876250938893842106722210729522, 63 | 6739722627047123650704294650168547689199576889424317598327664349670094847386, 64 | 21211457866117465531949733809706514799713333930924902519246949506964470524162, 65 | 13718112532745211817410303291774369209520657938741992779396229864894885156527, 66 | 5264534817993325015357427094323255342713527811596856940387954546330728068658, 67 | 18884137497114307927425084003812022333609937761793387700010402412840002189451, 68 | 5148596049900083984813839872929010525572543381981952060869301611018636120248, 69 | 19799686398774806587970184652860783461860993790013219899147141137827718662674, 70 | 19240878651604412704364448729659032944342952609050243268894572835672205984837, 71 | 10546185249390392695582524554167530669949955276893453512788278945742408153192, 72 | 5507959600969845538113649209272736011390582494851145043668969080335346810411, 73 
| 18177751737739153338153217698774510185696788019377850245260475034576050820091, 74 | 19603444733183990109492724100282114612026332366576932662794133334264283907557, 75 | 10548274686824425401349248282213580046351514091431715597441736281987273193140, 76 | 1823201861560942974198127384034483127920205835821334101215923769688644479957, 77 | 11867589662193422187545516240823411225342068709600734253659804646934346124945, 78 | 18718569356736340558616379408444812528964066420519677106145092918482774343613, 79 | 10530777752259630125564678480897857853807637120039176813174150229243735996839, 80 | 20486583726592018813337145844457018474256372770211860618687961310422228379031, 81 | 12690713110714036569415168795200156516217175005650145422920562694422306200486, 82 | 17386427286863519095301372413760745749282643730629659997153085139065756667205, 83 | 2216432659854733047132347621569505613620980842043977268828076165669557467682, 84 | 6309765381643925252238633914530877025934201680691496500372265330505506717193, 85 | 20806323192073945401862788605803131761175139076694468214027227878952047793390, 86 | 4037040458505567977365391535756875199663510397600316887746139396052445718861, 87 | 19948974083684238245321361840704327952464170097132407924861169241740046562673, 88 | 845322671528508199439318170916419179535949348988022948153107378280175750024, 89 | 16222384601744433420585982239113457177459602187868460608565289920306145389382, 90 | 10232118865851112229330353999139005145127746617219324244541194256766741433339, 91 | 6699067738555349409504843460654299019000594109597429103342076743347235369120, 92 | 6220784880752427143725783746407285094967584864656399181815603544365010379208, 93 | 6129250029437675212264306655559561251995722990149771051304736001195288083309, 94 | 10773245783118750721454994239248013870822765715268323522295722350908043393604, 95 | 4490242021765793917495398271905043433053432245571325177153467194570741607167, 96 | 
19596995117319480189066041930051006586888908165330319666010398892494684778526, 97 | 837850695495734270707668553360118467905109360511302468085569220634750561083, 98 | 11803922811376367215191737026157445294481406304781326649717082177394185903907, 99 | 10201298324909697255105265958780781450978049256931478989759448189112393506592, 100 | 13564695482314888817576351063608519127702411536552857463682060761575100923924, 101 | 9262808208636973454201420823766139682381973240743541030659775288508921362724, 102 | 173271062536305557219323722062711383294158572562695717740068656098441040230, 103 | 18120430890549410286417591505529104700901943324772175772035648111937818237369, 104 | 20484495168135072493552514219686101965206843697794133766912991150184337935627, 105 | 19155651295705203459475805213866664350848604323501251939850063308319753686505, 106 | 11971299749478202793661982361798418342615500543489781306376058267926437157297, 107 | 18285310723116790056148596536349375622245669010373674803854111592441823052978, 108 | 7069216248902547653615508023941692395371990416048967468982099270925308100727, 109 | 6465151453746412132599596984628739550147379072443683076388208843341824127379, 110 | 16143532858389170960690347742477978826830511669766530042104134302796355145785, 111 | 19362583304414853660976404410208489566967618125972377176980367224623492419647, 112 | 1702213613534733786921602839210290505213503664731919006932367875629005980493, 113 | 10781825404476535814285389902565833897646945212027592373510689209734812292327, 114 | 4212716923652881254737947578600828255798948993302968210248673545442808456151, 115 | 7594017890037021425366623750593200398174488805473151513558919864633711506220, 116 | 18979889247746272055963929241596362599320706910852082477600815822482192194401, 117 | 13602139229813231349386885113156901793661719180900395818909719758150455500533 118 | ]; 119 | 120 | var t; 121 | signal t2[nrounds]; 122 | signal t4[nrounds]; 123 | signal t6[nrounds]; 124 | signal t7[nrounds-1]; 125 | 126 | for 
(var i=0; i. 18 | */ 19 | 20 | /* 21 | Source: https://en.wikipedia.org/wiki/Montgomery_curve 22 | 23 | 1 + y 1 + y 24 | [u, v] = [ ------- , ---------- ] 25 | 1 - y (1 - y)x 26 | 27 | */ 28 | pragma circom 2.0.0; 29 | 30 | template Edwards2Montgomery() { 31 | signal input in[2]; 32 | signal output out[2]; 33 | 34 | out[0] <-- (1 + in[1]) / (1 - in[1]); 35 | out[1] <-- out[0] / in[0]; 36 | 37 | 38 | out[0] * (1-in[1]) === (1 + in[1]); 39 | out[1] * in[0] === out[0]; 40 | } 41 | 42 | /* 43 | 44 | u u - 1 45 | [x, y] = [ ---, ------- ] 46 | v u + 1 47 | 48 | */ 49 | template Montgomery2Edwards() { 50 | signal input in[2]; 51 | signal output out[2]; 52 | 53 | out[0] <-- in[0] / in[1]; 54 | out[1] <-- (in[0] - 1) / (in[0] + 1); 55 | 56 | out[0] * in[1] === in[0]; 57 | out[1] * (in[0] + 1) === in[0] - 1; 58 | } 59 | 60 | 61 | /* 62 | x2 - x1 63 | lamda = --------- 64 | y2 - y1 65 | 66 | x3 + A + x1 + x2 67 | x3 = B * lamda^2 - A - x1 -x2 => lamda^2 = ------------------ 68 | B 69 | 70 | y3 = (2*x1 + x2 + A)*lamda - B*lamda^3 - y1 => 71 | 72 | 73 | => y3 = lamda * ( 2*x1 + x2 + A - x3 - A - x1 - x2) - y1 => 74 | 75 | => y3 = lamda * ( x1 - x3 ) - y1 76 | 77 | ---------- 78 | 79 | y2 - y1 80 | lamda = --------- 81 | x2 - x1 82 | 83 | x3 = B * lamda^2 - A - x1 -x2 84 | 85 | y3 = lamda * ( x1 - x3 ) - y1 86 | 87 | */ 88 | 89 | template MontgomeryAdd() { 90 | signal input in1[2]; 91 | signal input in2[2]; 92 | signal output out[2]; 93 | 94 | var a = 168700; 95 | var d = 168696; 96 | 97 | var A = (2 * (a + d)) / (a - d); 98 | var B = 4 / (a - d); 99 | 100 | signal lamda; 101 | 102 | lamda <-- (in2[1] - in1[1]) / (in2[0] - in1[0]); 103 | lamda * (in2[0] - in1[0]) === (in2[1] - in1[1]); 104 | 105 | out[0] <== B*lamda*lamda - A - in1[0] -in2[0]; 106 | out[1] <== lamda * (in1[0] - out[0]) - in1[1]; 107 | } 108 | 109 | /* 110 | 111 | x1_2 = x1*x1 112 | 113 | 3*x1_2 + 2*A*x1 + 1 114 | lamda = --------------------- 115 | 2*B*y1 116 | 117 | x3 = B * lamda^2 - A - x1 -x1 118 | 119 | y3 
= lamda * ( x1 - x3 ) - y1 120 | 121 | */ 122 | template MontgomeryDouble() { 123 | signal input in[2]; 124 | signal output out[2]; 125 | 126 | var a = 168700; 127 | var d = 168696; 128 | 129 | var A = (2 * (a + d)) / (a - d); 130 | var B = 4 / (a - d); 131 | 132 | signal lamda; 133 | signal x1_2; 134 | 135 | x1_2 <== in[0] * in[0]; 136 | 137 | lamda <-- (3*x1_2 + 2*A*in[0] + 1 ) / (2*B*in[1]); 138 | lamda * (2*B*in[1]) === (3*x1_2 + 2*A*in[0] + 1 ); 139 | 140 | out[0] <== B*lamda*lamda - A - 2*in[0]; 141 | out[1] <== lamda * (in[0] - out[0]) - in[1]; 142 | } 143 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/mux3.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 
18 | */ 19 | pragma circom 2.0.0; 20 | 21 | template MultiMux3(n) { 22 | signal input c[n][8]; // Constants 23 | signal input s[3]; // Selector 24 | signal output out[n]; 25 | 26 | signal a210[n]; 27 | signal a21[n]; 28 | signal a20[n]; 29 | signal a2[n]; 30 | 31 | signal a10[n]; 32 | signal a1[n]; 33 | signal a0[n]; 34 | signal a[n]; 35 | 36 | // 4 constrains for the intermediary variables 37 | signal s10; 38 | s10 <== s[1] * s[0]; 39 | 40 | for (var i=0; i mux.s[i]; 72 | } 73 | 74 | mux.out[0] ==> out; 75 | } 76 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/sign.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | pragma circom 2.0.0; 20 | 21 | include "compconstant.circom"; 22 | 23 | template Sign() { 24 | signal input in; 25 | signal output sign; 26 | 27 | sign <== in > 0; 28 | } 29 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/circomlib/switcher.circom: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2018 0KIMS association. 
3 | 4 | This file is part of circom (Zero Knowledge Circuit Compiler). 5 | 6 | circom is a free software: you can redistribute it and/or modify it 7 | under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | circom is distributed in the hope that it will be useful, but WITHOUT 12 | ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 13 | or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public 14 | License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with circom. If not, see . 18 | */ 19 | 20 | /* 21 | Assume sel is binary. 22 | 23 | If sel == 0 then outL = L and outR=R 24 | If sel == 1 then outL = R and outR=L 25 | 26 | */ 27 | 28 | pragma circom 2.0.0; 29 | 30 | template Switcher() { 31 | signal input sel; 32 | signal input L; 33 | signal input R; 34 | signal output outL; 35 | signal output outR; 36 | 37 | signal aux; 38 | 39 | aux <== (R-L)*sel; // We create aux in order to have only one multiplication 40 | outL <== aux + L; 41 | outR <== -aux + R; 42 | } 43 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/crypto/ecdh.circom: -------------------------------------------------------------------------------- 1 | // from privacy-scaling-explorations/maci 2 | 3 | pragma circom 2.0.0; 4 | 5 | include "../circomlib/bitify.circom"; 6 | include "../circomlib/escalarmulany.circom"; 7 | 8 | template Ecdh() { 9 | // Note: private key 10 | // Needs to be hashed, and then pruned before 11 | // supplying it to the circuit 12 | signal input private_key; 13 | signal input public_key[2]; 14 | 15 | signal output shared_key; 16 | 17 | component privBits = Num2Bits(253); 18 | privBits.in <== private_key; 19 | 20 | component mulFix = EscalarMulAny(253); 21 | mulFix.p[0] <== public_key[0]; 22 | mulFix.p[1] <== 
public_key[1]; 23 | 24 | for (var i = 0; i < 253; i++) { 25 | mulFix.e[i] <== privBits.out[i]; 26 | } 27 | 28 | shared_key <== mulFix.out[0]; 29 | } -------------------------------------------------------------------------------- /tests/circuits/machine-learning/crypto/encrypt.circom: -------------------------------------------------------------------------------- 1 | //from zk-ml/linear-regression-demo 2 | 3 | pragma circom 2.0.0; 4 | 5 | include "../circomlib/mimc.circom"; 6 | 7 | template EncryptBits(N) { 8 | signal input plaintext[N]; 9 | signal input shared_key; 10 | signal output out[N+1]; 11 | 12 | component mimc = MultiMiMC7(N, 91); 13 | for (var i=0; i in[i]; 79 | 80 | switchers[i+1] = Switcher(); 81 | 82 | switchers[i+1].sel <== gts[i]; 83 | switchers[i+1].L <== maxs[i]; 84 | switchers[i+1].R <== in[i]; 85 | 86 | maxs[i+1] <== switchers[i+1].outL; 87 | } 88 | 89 | out <== maxs[n]; 90 | } 91 | -------------------------------------------------------------------------------- /tests/circuits/machine-learning/utils-comp.circom: -------------------------------------------------------------------------------- 1 | pragma circom 2.0.0; 2 | 3 | 4 | template ShiftLeft(n) { 5 | signal input in; 6 | signal output out; 7 | 8 | out <== in << n; 9 | } 10 | 11 | template ShiftRight(n) { 12 | signal input in; 13 | signal output out; 14 | 15 | out <== in >> n; 16 | } -------------------------------------------------------------------------------- /tests/integration.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::upper_case_acronyms)] 2 | 3 | use bristol_circuit::BristolCircuit; 4 | use circom_2_arithc::{a_gate_type::AGateType, cli::ValueType}; 5 | use sim_circuit::{ 6 | circuit::{CircuitBuilder, CircuitMemory, GenericCircuit, GenericCircuitExecutor}, 7 | model::{Component, Executable, Memory}, 8 | }; 9 | use std::collections::HashMap; 10 | 11 | #[derive(Debug, PartialEq, Eq, Clone)] 12 | enum ArithmeticOperation { 13 
| ADD, 14 | DIV, 15 | EQ, 16 | GEQ, 17 | GT, 18 | LEQ, 19 | LT, 20 | MUL, 21 | NEQ, 22 | SUB, 23 | XOR, 24 | POW, 25 | INTDIV, 26 | MOD, 27 | SHIFTL, 28 | SHIFTR, 29 | BOOLOR, 30 | BOOLAND, 31 | BITOR, 32 | BITAND, 33 | } 34 | 35 | impl From for ArithmeticOperation { 36 | fn from(gate_type: AGateType) -> Self { 37 | match gate_type { 38 | AGateType::AAdd => ArithmeticOperation::ADD, 39 | AGateType::ADiv => ArithmeticOperation::DIV, 40 | AGateType::AEq => ArithmeticOperation::EQ, 41 | AGateType::AGEq => ArithmeticOperation::GEQ, 42 | AGateType::AGt => ArithmeticOperation::GT, 43 | AGateType::ALEq => ArithmeticOperation::LEQ, 44 | AGateType::ALt => ArithmeticOperation::LT, 45 | AGateType::AMul => ArithmeticOperation::MUL, 46 | AGateType::ANeq => ArithmeticOperation::NEQ, 47 | AGateType::ASub => ArithmeticOperation::SUB, 48 | AGateType::AXor => ArithmeticOperation::XOR, 49 | AGateType::APow => ArithmeticOperation::POW, 50 | AGateType::AIntDiv => ArithmeticOperation::INTDIV, 51 | AGateType::AMod => ArithmeticOperation::MOD, 52 | AGateType::AShiftL => ArithmeticOperation::SHIFTL, 53 | AGateType::AShiftR => ArithmeticOperation::SHIFTR, 54 | AGateType::ABoolOr => ArithmeticOperation::BOOLOR, 55 | AGateType::ABoolAnd => ArithmeticOperation::BOOLAND, 56 | AGateType::ABitOr => ArithmeticOperation::BITOR, 57 | AGateType::ABitAnd => ArithmeticOperation::BITAND, 58 | } 59 | } 60 | } 61 | 62 | #[derive(Debug, PartialEq, Eq, Clone)] 63 | struct ArithmeticGate { 64 | operation: ArithmeticOperation, 65 | inputs: Vec, 66 | outputs: Vec, 67 | } 68 | 69 | impl Component for ArithmeticGate { 70 | fn inputs(&self) -> &[usize] { 71 | &self.inputs 72 | } 73 | 74 | fn outputs(&self) -> &[usize] { 75 | &self.outputs 76 | } 77 | 78 | fn set_inputs(&mut self, inputs: Vec) { 79 | self.inputs = inputs; 80 | } 81 | 82 | fn set_outputs(&mut self, outputs: Vec) { 83 | self.outputs = outputs; 84 | } 85 | } 86 | 87 | impl Executable> for ArithmeticGate { 88 | type Error = (); 89 | 90 | fn 
execute(&self, memory: &mut CircuitMemory) -> Result<(), Self::Error> { 91 | let a = memory.read(self.inputs[0]).unwrap(); 92 | let b = memory.read(self.inputs[1]).unwrap(); 93 | 94 | let result = match self.operation { 95 | ArithmeticOperation::ADD => a + b, 96 | ArithmeticOperation::DIV => a / b, 97 | ArithmeticOperation::EQ => (a == b) as u32, 98 | ArithmeticOperation::GEQ => (a >= b) as u32, 99 | ArithmeticOperation::GT => (a > b) as u32, 100 | ArithmeticOperation::LEQ => (a <= b) as u32, 101 | ArithmeticOperation::LT => (a < b) as u32, 102 | ArithmeticOperation::MUL => a * b, 103 | ArithmeticOperation::NEQ => (a != b) as u32, 104 | ArithmeticOperation::SUB => a - b, 105 | ArithmeticOperation::XOR => a ^ b, 106 | ArithmeticOperation::POW => a.pow(b), 107 | ArithmeticOperation::INTDIV => a / b, 108 | ArithmeticOperation::MOD => a % b, 109 | ArithmeticOperation::SHIFTL => a << b, 110 | ArithmeticOperation::SHIFTR => a >> b, 111 | ArithmeticOperation::BOOLOR => (a != 0 || b != 0) as u32, 112 | ArithmeticOperation::BOOLAND => (a != 0 && b != 0) as u32, 113 | ArithmeticOperation::BITOR => a | b, 114 | ArithmeticOperation::BITAND => a & b, 115 | }; 116 | 117 | memory.write(self.outputs[0], result).unwrap(); 118 | Ok(()) 119 | } 120 | } 121 | 122 | #[derive(Debug, PartialEq, Eq, Clone)] 123 | pub struct ArithmeticCircuit { 124 | gates: Vec, 125 | constants: HashMap, 126 | label_to_index: HashMap, 127 | input_indices: Vec, 128 | outputs: Vec, 129 | } 130 | 131 | impl ArithmeticCircuit { 132 | /// Create a new `ArithmeticCircuit` from a bristol circuit 133 | pub fn new_from_bristol(circuit: BristolCircuit) -> Result { 134 | let mut label_to_index: HashMap = HashMap::new(); 135 | let mut outputs: Vec = Vec::new(); 136 | let mut input_indices: Vec = Vec::new(); 137 | let mut gates: Vec = Vec::new(); 138 | 139 | // Get circuit inputs 140 | let inputs = circuit.info.input_name_to_wire_index; 141 | for (label, index) in inputs { 142 | label_to_index.insert(label, index); 143 
| input_indices.push(index); 144 | } 145 | 146 | // Get circuit constants 147 | let mut constants: HashMap = HashMap::new(); 148 | for (_, constant_info) in circuit.info.constants { 149 | input_indices.push(constant_info.wire_index); 150 | constants.insert( 151 | constant_info.wire_index, 152 | constant_info.value.parse().unwrap(), 153 | ); 154 | } 155 | 156 | // Get circuit outputs 157 | let output_map = circuit.info.output_name_to_wire_index; 158 | let mut output_indices = vec![]; 159 | for (label, index) in output_map { 160 | label_to_index.insert(label.clone(), index); 161 | outputs.push(label); 162 | output_indices.push(index); 163 | } 164 | 165 | // Transform and add gates 166 | for gate in circuit.gates { 167 | let operation = ArithmeticOperation::from( 168 | gate.op 169 | .parse::() 170 | .map_err(|_| "unrecognized gate")?, 171 | ); 172 | 173 | let arithmetic_gate = ArithmeticGate { 174 | operation, 175 | inputs: gate.inputs, 176 | outputs: gate.outputs, 177 | }; 178 | gates.push(arithmetic_gate); 179 | } 180 | 181 | Ok(Self { 182 | gates, 183 | constants, 184 | label_to_index, 185 | input_indices, 186 | outputs, 187 | }) 188 | } 189 | 190 | /// Run the circuit 191 | pub fn run(&self, inputs: HashMap) -> Result, &'static str> { 192 | // Build circuit 193 | let circuit = self.build_circuit(); 194 | // Instantiate a circuit executor 195 | let mut executor: GenericCircuitExecutor = 196 | GenericCircuitExecutor::new(circuit); 197 | 198 | // The executor receives a map of WireIndex -> Value 199 | let input_map: HashMap = inputs 200 | .iter() 201 | .map(|(label, value)| { 202 | let index = self 203 | .label_to_index 204 | .get(label) 205 | .ok_or("Input label not found") 206 | .unwrap(); 207 | (*index, *value) 208 | }) 209 | .collect(); 210 | 211 | // Load constants into the input map 212 | let input_map = self 213 | .constants 214 | .iter() 215 | .fold(input_map, |mut acc, (index, value)| { 216 | acc.insert(*index, *value); 217 | acc 218 | }); 219 | 220 | let 
output = executor.run(&input_map).unwrap(); 221 | 222 | // The executor returns a map of WireIndex -> Value 223 | let output_map: HashMap = self 224 | .outputs 225 | .iter() 226 | .map(|label| { 227 | let index = self 228 | .label_to_index 229 | .get(label) 230 | .ok_or("Output label not found") 231 | .unwrap(); 232 | (label.clone(), *output.get(index).unwrap()) 233 | }) 234 | .collect(); 235 | 236 | Ok(output_map) 237 | } 238 | 239 | fn build_circuit(&self) -> GenericCircuit { 240 | let mut builder = CircuitBuilder::::new(); 241 | builder.add_inputs(&self.input_indices); 242 | 243 | for gate in &self.gates { 244 | builder.add_component(gate.clone()).unwrap(); 245 | } 246 | 247 | builder.build().unwrap() 248 | } 249 | } 250 | 251 | #[cfg(test)] 252 | mod integration_tests { 253 | use super::*; 254 | use bristol_circuit::ConstantInfo; 255 | use circom_2_arithc::{cli::Args, program::compile}; 256 | 257 | fn simulation_test( 258 | circuit_path: &str, 259 | inputs: &[(&str, u32)], 260 | expected_outputs: &[(&str, u32)], 261 | ) { 262 | let compiler_input = Args::new(circuit_path.into(), "./".into(), ValueType::Sint, None); 263 | let circuit = compile(&compiler_input).unwrap().build_circuit().unwrap(); 264 | let arithmetic_circuit = ArithmeticCircuit::new_from_bristol(circuit).unwrap(); 265 | 266 | let mut input_map: HashMap = HashMap::new(); 267 | for (label, value) in inputs { 268 | input_map.insert(label.to_string(), *value); 269 | } 270 | 271 | let outputs: HashMap = arithmetic_circuit.run(input_map).unwrap(); 272 | 273 | for (label, expected_value) in expected_outputs { 274 | let value = outputs.get(*label).unwrap(); 275 | assert_eq!(value, expected_value); 276 | } 277 | } 278 | 279 | #[test] 280 | fn test_add_zero() { 281 | simulation_test( 282 | "tests/circuits/integration/addZero.circom", 283 | &[("0.in", 42)], 284 | &[("0.out", 42)], 285 | ); 286 | } 287 | 288 | #[test] 289 | fn test_infix_ops() { 290 | simulation_test( 291 | 
"tests/circuits/integration/infixOps.circom", 292 | &[ 293 | ("0.x0", 0), 294 | ("0.x1", 1), 295 | ("0.x2", 2), 296 | ("0.x3", 3), 297 | ("0.x4", 4), 298 | ("0.x5", 5), 299 | ], 300 | &[ 301 | ("0.mul_2_3", 6), 302 | // ("0.div_4_3", 1), // unsupported for NumberU32 303 | ("0.idiv_4_3", 1), 304 | ("0.add_3_4", 7), 305 | ("0.sub_4_1", 3), 306 | ("0.pow_2_4", 16), 307 | ("0.mod_5_3", 2), 308 | ("0.shl_5_1", 10), 309 | ("0.shr_5_1", 2), 310 | ("0.leq_2_3", 1), 311 | ("0.leq_3_3", 1), 312 | ("0.leq_4_3", 0), 313 | ("0.geq_2_3", 0), 314 | ("0.geq_3_3", 1), 315 | ("0.geq_4_3", 1), 316 | ("0.lt_2_3", 1), 317 | ("0.lt_3_3", 0), 318 | ("0.lt_4_3", 0), 319 | ("0.gt_2_3", 0), 320 | ("0.gt_3_3", 0), 321 | ("0.gt_4_3", 1), 322 | ("0.eq_2_3", 0), 323 | ("0.eq_3_3", 1), 324 | ("0.neq_2_3", 1), 325 | ("0.neq_3_3", 0), 326 | ("0.or_0_1", 1), 327 | ("0.and_0_1", 0), 328 | ("0.bit_or_1_3", 3), 329 | ("0.bit_and_1_3", 1), 330 | ("0.bit_xor_1_3", 2), 331 | ], 332 | ); 333 | } 334 | 335 | #[test] 336 | fn test_matrix_element_multiplication() { 337 | simulation_test( 338 | "tests/circuits/integration/matElemMul.circom", 339 | &[ 340 | ("0.a[0][0]", 2), 341 | ("0.a[0][1]", 2), 342 | ("0.a[1][0]", 2), 343 | ("0.a[1][1]", 2), 344 | ("0.b[0][0]", 2), 345 | ("0.b[0][1]", 2), 346 | ("0.b[1][0]", 2), 347 | ("0.b[1][1]", 2), 348 | ], 349 | &[ 350 | ("0.out[0][0]", 4), 351 | ("0.out[0][1]", 4), 352 | ("0.out[1][0]", 4), 353 | ("0.out[1][1]", 4), 354 | ], 355 | ); 356 | } 357 | 358 | #[test] 359 | fn test_sum() { 360 | simulation_test( 361 | "tests/circuits/integration/sum.circom", 362 | &[("0.a", 3), ("0.b", 5)], 363 | &[("0.out", 8)], 364 | ); 365 | } 366 | 367 | #[test] 368 | fn test_x_eq_x() { 369 | simulation_test( 370 | "tests/circuits/integration/xEqX.circom", 371 | &[("0.x", 37)], 372 | &[("0.out", 1)], 373 | ); 374 | } 375 | 376 | #[test] 377 | fn test_out_of_bounds() { 378 | let compiler_input = Args::new( 379 | "tests/circuits/integration/indexOutOfBounds.circom".into(), 380 | 
"./".into(), 381 | ValueType::Sint, 382 | None, 383 | ); 384 | let circuit = compile(&compiler_input); 385 | 386 | assert!(circuit.is_err()); 387 | assert_eq!( 388 | circuit.unwrap_err().to_string(), 389 | "Runtime error: Index out of bounds" 390 | ); 391 | } 392 | 393 | #[test] 394 | fn test_constant_sum() { 395 | let compiler_input = Args::new( 396 | "tests/circuits/integration/constantSum.circom".into(), 397 | "./".into(), 398 | ValueType::Sint, 399 | None, 400 | ); 401 | let circuit_res = compile(&compiler_input); 402 | 403 | assert!(circuit_res.is_ok()); 404 | 405 | let circuit = circuit_res.unwrap().build_circuit().unwrap(); 406 | 407 | assert_eq!(circuit.info.constants.len(), 1); 408 | assert_eq!( 409 | circuit.info.constants.get("0.const_signal_8_1"), 410 | Some(&ConstantInfo { 411 | value: "8".to_string(), // 5 + 3 412 | wire_index: 0 413 | }) 414 | ); 415 | } 416 | 417 | #[test] 418 | fn test_direct_output() { 419 | let compiler_input = Args::new( 420 | "tests/circuits/integration/directOutput.circom".into(), 421 | "./".into(), 422 | ValueType::Sint, 423 | None, 424 | ); 425 | let circuit_res = compile(&compiler_input); 426 | 427 | assert!(circuit_res.is_ok()); 428 | 429 | let circuit = circuit_res.unwrap().build_circuit().unwrap(); 430 | 431 | let expected_output = HashMap::from([("0.out".to_string(), 0)]); 432 | assert_eq!(circuit.info.output_name_to_wire_index, expected_output); 433 | assert_eq!(circuit.info.constants.len(), 1); 434 | assert_eq!( 435 | circuit.info.constants.get("0.const_signal_42_1"), 436 | Some(&ConstantInfo { 437 | value: "42".to_string(), 438 | wire_index: 0 439 | }) 440 | ); 441 | } 442 | 443 | #[ignore] 444 | #[test] 445 | fn test_under_constrained() { 446 | // FIXME: There should be an error instead (zero comes from default initialization, not from 447 | // running the circuit) 448 | simulation_test( 449 | "tests/circuits/integration/underConstrained.circom", 450 | &[], 451 | &[("0.x", 0)], 452 | ); 453 | } 454 | 455 | #[ignore] 
456 | #[test] 457 | fn test_prefix_ops() { 458 | // FIXME: The compiler sees several of the outputs as inputs, leading to the error below 459 | // CircuitError(Inconsistency { 460 | // message: "Node 10 used for both input 0.complementC and output 0.complementC" 461 | // }) 462 | simulation_test( 463 | "tests/circuits/integration/prefixOps.circom", 464 | &[("0.a", 0), ("0.b", 1), ("0.c", 2)], 465 | &[ 466 | ("0.negateA", 0), // -0 467 | ("0.notA", 1), // !0 468 | ("0.notB", 0), // !1 469 | ("0.notC", 0), // !2 470 | ("0.complementA", 0b_11111111_11111111_11111111_11111111), // ~0 471 | ("0.complementB", 0b_11111111_11111111_11111111_11111110), // ~1 472 | ("0.complementC", 0b_11111111_11111111_11111111_11111101), // ~2 473 | ], 474 | ); 475 | } 476 | } 477 | --------------------------------------------------------------------------------