├── .env.example
├── .github
│   └── workflows
│       ├── build.yml
│       └── coverage.yml
├── .gitignore
├── Cargo.toml
├── LICENSE
├── README.md
├── docs
│   └── operations.md
├── input
│   └── circuit.circom
├── src
│   ├── a_gate_type.rs
│   ├── circom.rs
│   ├── circom
│   │   ├── parser.rs
│   │   └── type_analysis.rs
│   ├── cli.rs
│   ├── compiler.rs
│   ├── lib.rs
│   ├── main.rs
│   ├── process.rs
│   ├── program.rs
│   ├── runtime.rs
│   └── topological_sort.rs
└── tests
    ├── circuits
    │   ├── integration
    │   │   ├── addZero.circom
    │   │   ├── arrayAssignment.circom
    │   │   ├── constantSum.circom
    │   │   ├── directOutput.circom
    │   │   ├── indexOutOfBounds.circom
    │   │   ├── infixOps.circom
    │   │   ├── mainTemplateArgument.circom
    │   │   ├── matElemMul.circom
    │   │   ├── prefixOps.circom
    │   │   ├── sum.circom
    │   │   ├── underConstrained.circom
    │   │   └── xEqX.circom
    │   └── machine-learning
    │       ├── ArgMax.circom
    │       ├── AveragePooling2D.circom
    │       ├── BatchNormalization2D.circom
    │       ├── Conv1D.circom
    │       ├── Conv2D.circom
    │       ├── Dense.circom
    │       ├── DepthwiseConv2D.circom
    │       ├── Flatten2D.circom
    │       ├── GlobalAveragePooling2D.circom
    │       ├── GlobalMaxPooling2D.circom
    │       ├── GlobalSumPooling2D.circom
    │       ├── MaxPooling2D.circom
    │       ├── NaiveSearch.circom
    │       ├── PointwiseConv2D.circom
    │       ├── ReLU.circom
    │       ├── SeparableConv2D.circom
    │       ├── SumPooling2D.circom
    │       ├── Zanh.circom
    │       ├── ZeLU.circom
    │       ├── Zigmoid.circom
    │       ├── circomlib-matrix
    │       │   ├── matElemMul.circom
    │       │   ├── matElemSum.circom
    │       │   └── matMul.circom
    │       ├── circomlib
    │       │   ├── aliascheck.circom
    │       │   ├── babyjub.circom
    │       │   ├── binsum.circom
    │       │   ├── bitify.circom
    │       │   ├── comparators.circom
    │       │   ├── compconstant.circom
    │       │   ├── escalarmulany.circom
    │       │   ├── escalarmulfix.circom
    │       │   ├── mimc.circom
    │       │   ├── montgomery.circom
    │       │   ├── mux3.circom
    │       │   ├── sign.circom
    │       │   └── switcher.circom
    │       ├── crypto
    │       │   ├── ecdh.circom
    │       │   ├── encrypt.circom
    │       │   └── publickey_derivation.circom
    │       ├── fc.circom
    │       ├── util.circom
    │       └── utils-comp.circom
    └── integration.rs
/.env.example:
--------------------------------------------------------------------------------
1 | LOG_LEVEL="debug"
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | env:
12 | CARGO_TERM_COLOR: always
13 |
14 | jobs:
15 | build:
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - uses: actions/checkout@v3
20 |
21 | - uses: actions/cache@v3
22 | with:
23 | path: |
24 | ~/.cargo
25 | ~/.rustup/toolchains
26 | target
27 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
28 | restore-keys: |
29 | ${{ runner.os }}-cargo-
30 |
31 | - name: Install Rust
32 | run: rustup toolchain install stable
33 |
34 | - name: Build
35 | run: cargo build --verbose
36 |
37 | - name: Clippy
38 | run: cargo clippy --verbose -- -D warnings
39 |
40 | - name: Tests
41 | run: cargo test --verbose
42 |
43 | - name: Fmt
44 | run: cargo fmt -- --check
45 |
--------------------------------------------------------------------------------
/.github/workflows/coverage.yml:
--------------------------------------------------------------------------------
1 | name: Coverage
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | env:
12 | CARGO_TERM_COLOR: always
13 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
14 |
15 | jobs:
16 | coverage:
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 |
22 | - uses: actions/cache@v3
23 | with:
24 | path: |
25 | ~/.cargo
26 | ~/.rustup/toolchains
27 | target
28 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
29 | restore-keys: |
30 | ${{ runner.os }}-cargo-
31 |
32 | - name: Install Rust
33 | run: rustup toolchain install stable
34 |
35 | - name: Install cargo-llvm-cov
36 | uses: taiki-e/install-action@cargo-llvm-cov
37 |
38 | - name: Generate code coverage
39 | run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
40 |
41 | - name: Upload coverage to Codecov
42 | uses: codecov/codecov-action@v4
43 | with:
44 | files: lcov.info
45 | fail_ci_if_error: true
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated by Cargo
2 | # will have compiled files and executables
3 | debug/
4 | target/
5 |
6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
8 | Cargo.lock
9 |
10 | # These are backup files generated by rustfmt
11 | **/*.rs.bk
12 |
13 | # MSVC Windows builds of rustc generate these, which store debugging information
14 | *.pdb
15 | .DS_Store
16 |
17 | # Local
18 | .env
19 | .vscode/
20 |
21 | # Output
22 | output/
23 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "circom-2-arithc"
3 | version = "0.1.0"
4 | edition = "2021"
5 | resolver = "1" # Fixes lalrpop issue, see: https://github.com/lalrpop/lalrpop/issues/616
6 |
7 | [dependencies]
8 | clap = { version = "4.5.4", features = ["derive"] }
9 | dotenv = "0.15.0"
10 | env_logger = "0.11.1"
11 | log = "0.4.20"
12 | rand = "0.8.5"
13 | regex = "1.10.3"
14 | serde_json = "1.0"
15 | serde = { version = "1.0.196", features = ["derive"] }
16 | thiserror = "1.0.59"
17 | strum_macros = "0.26.4"
18 | strum = "0.26.2"
19 | sim-circuit = { git = "https://github.com/brech1/sim-circuit" }
20 | bristol-circuit = { git = "https://github.com/voltrevo/bristol-circuit", rev = "2a8b001" }
21 | boolify = { git = "https://github.com/voltrevo/boolify", rev = "6376405" }
22 |
23 | # DSL
24 | circom-circom_algebra = { git = "https://github.com/iden3/circom", package = "circom_algebra", rev = "e8e125e" }
25 | circom-code_producers = { git = "https://github.com/iden3/circom", package = "code_producers", rev = "e8e125e" }
26 | circom-compiler = { git = "https://github.com/iden3/circom", package = "compiler", rev = "e8e125e" }
27 | circom-constant_tracking = { git = "https://github.com/iden3/circom", package = "constant_tracking", rev = "e8e125e" }
28 | circom-constraint_generation = { git = "https://github.com/iden3/circom", package = "constraint_generation", rev = "e8e125e" }
29 | circom-constraint_list = { git = "https://github.com/iden3/circom", package = "constraint_list", rev = "e8e125e" }
30 | circom-constraint_writers = { git = "https://github.com/iden3/circom", package = "constraint_writers", rev = "e8e125e" }
31 | circom-dag = { git = "https://github.com/iden3/circom", package = "dag", rev = "e8e125e" }
32 | circom-parser = { git = "https://github.com/iden3/circom", package = "parser", rev = "e8e125e" }
33 | circom-program_structure = { git = "https://github.com/iden3/circom", package = "program_structure", rev = "e8e125e" }
34 | circom-type_analysis = { git = "https://github.com/iden3/circom", package = "type_analysis", rev = "e8e125e" }
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Nam Ngo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Circom To Arithmetic Circuit
2 |
3 | [![MIT licensed][mit-badge]][mit-url]
4 | [![Build Status][actions-badge]][actions-url]
5 | [![codecov](https://codecov.io/github/namnc/circom-2-arithc/graph/badge.svg)](https://app.codecov.io/github/namnc/circom-2-arithc/)
6 |
7 | [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg
8 | [mit-url]: https://github.com/namnc/circom-2-arithc/blob/main/LICENSE
9 | [actions-badge]: https://github.com/namnc/circom-2-arithc/actions/workflows/build.yml/badge.svg
10 | [actions-url]: https://github.com/namnc/circom-2-arithc/actions?query=branch%3Amain
11 |
12 | This library enables the creation of arithmetic circuits from circom programs.
13 |
14 | ## Supported Circom Features
15 |
16 | | Category | Type | Supported |
17 | | --------------- | ------------------------ | :-------: |
18 | | **Statements** | `InitializationBlock` | ✅ |
19 | | | `Block` | ✅ |
20 | | | `Substitution` | ✅ |
21 | | | `Declaration` | ✅ |
22 | | | `IfThenElse` | ✅ |
23 | | | `While` | ✅ |
24 | | | `Return` | ✅ |
25 | | | `MultSubstitution` | ❌ |
26 | | | `UnderscoreSubstitution` | ❌ |
27 | | | `ConstraintEquality` | ❌ |
28 | | | `LogCall` | ❌ |
29 | | | `Assert` | ✅ |
30 | | **Expressions** | `Call` | ✅ |
31 | | | `InfixOp` | ✅ |
32 | | | `Number` | ✅ |
33 | | | `Variable` | ✅ |
34 | | | `PrefixOp` | ✅ |
35 | | | `InlineSwitchOp` | ❌ |
36 | | | `ParallelOp` | ❌ |
37 | | | `AnonymousComp` | ✅ |
38 | | | `ArrayInLine` | ❌ |
39 | | | `Tuple` | ✅ |
40 | | | `UniformArray` | ❌ |
41 |
42 | ## Circomlib
43 |
44 | WIP
45 |
46 | ## Requirements
47 |
48 | - Rust: To install, follow the instructions found [here](https://www.rust-lang.org/tools/install).
49 |
50 | ## Getting Started
51 |
52 | - Save your circom program as `input/circuit.circom` (the default input path; pass `--input <path>` to compile a different file).
53 |
54 | - Build the program
55 |
56 | ```bash
57 | cargo build --release
58 | ```
59 |
60 | - Run the compilation
61 |
62 | ```bash
63 | cargo run --release
64 | ```
65 |
66 | The compiled circuit and circuit report can be found in the `./output` directory.
67 |
68 | ### Boolean Circuits
69 |
70 | Although this library is named after arithmetic circuits, the CLI integrates [boolify](https://github.com/voltrevo/boolify), allowing further compilation down to boolean circuits.
71 |
72 | To achieve this, add `--boolify-width DESIRED_INT_WIDTH` to your command:
73 |
74 | ```bash
75 | cargo run --release -- --boolify-width 16
76 | ```
77 |
78 | ## ZK/MPC/FHE Backends
79 |
80 | - [circom-mp-spdz](https://github.com/namnc/circom-mp-spdz)
81 |
82 | ## Contributing
83 |
84 | Contributions are welcome!
85 |
86 | ## License
87 |
88 | This project is licensed under the MIT License - see the LICENSE file for details.
89 |
--------------------------------------------------------------------------------
/docs/operations.md:
--------------------------------------------------------------------------------
1 | # Operations
2 |
3 | This document outlines the high-level operations and how we process statements, expressions and gates creation.
4 |
5 | ## Main Component
6 |
7 | - It's a template call.
8 | - Includes initialized variables, likely in an `InitializationBlock`.
9 | - The body is traversed using `traverse_sequence_of_statements`.
10 |
11 | ## Statements
12 |
13 | Each statement in the component's body is handled by `traverse_statement`. These include:
14 |
15 | - **DataItem Declaration**: Declaring variables or signals, either as a single scalar or as an array. Here we just add a `DataItem` based on the declared type and dimensions.
16 | - **If-Then-Else**: Evaluates conditions (variables or function calls) with `execute_expression`, then executes the chosen path using `traverse_sequence_of_statements`.
17 | - **Loops (While/For)**: Similar to `if-then-else`, but repeats based on a condition (breaks if it's `false`). Uses `traverse_sequence_of_statements` for the loop body.
18 | - **ConstraintEquality**: Probably only for ZK, not MPC.
19 | - **Return Statement**: In a function body, this statement assigns the result directly to the variable on the left-hand side of the call. For instance, in `a = func()`, the return value of `func()` is assigned to `a`. This also applies when `func()` is part of a larger expression, like in `a = a + func()`, where the return value is used as part of the expression calculation.
20 | - **Assert**: Probably only for ZK, not MPC.
21 | - **Substitution**: Like `a <== b + c`, the right-hand side is an expression processed by `traverse_expression`. This is the primary instance where a substitution statement is executed. If the left-hand side is a variable, we use `execute_expression` for execution instead of `traverse_expression`.
22 | - **Block**: Traversed with `traverse_sequence_of_statements`.
23 | - **LogCall**: For debugging, not used.
24 | - **UnderscoreSubstitution**: Probably an anonymous substitution, to be handled later (if `circomlib-ml` uses it).
25 |
26 | ## Expressions
27 |
28 | Parts of statements, like the right side of a substitution or a flow control (if/loop) condition.
29 |
30 | - **Number**: A constant. We return its value and can add it as a named variable in the context for ease in mixed signal-variable expressions, like naming "1" for the value 1.
31 | - **Infix-Op**: If the right-hand side is a variable, we use `execute_infix_op`. If not, we use `traverse_infix_op`. It gets complex when variables are mixed with signals, like if rhs is a signal. In such cases, we use `traverse_infix_op` and might need to create an intermediate signal. For example, in `sum[1] = sum[0] + input_A[0]*input_B[0]`, we generate an intermediate signal for `input_A[0]*input_B[0]`. If one of the operands is a signal, we create an auto-named signal. If both are variables, we simply execute `execute_infix_op` and return the value. This value can be treated as a constant variable in expressions involving signals, like "1" with value 1. The `execute_infix_op` is a straightforward operation, like in `a = b + c` where `b` is `1` and `c` is `2`, resulting in `a` being `3`.
32 | - **Prefix-Op**: Handled like `infix-op`.
33 | - **Inline-Switch**: Acts like a quick If-Then-Else.
34 | - **ParallelOp**: Not addressed.
35 | - **Variable**: Returns a signal ID for a signal, or the value for a variable.
36 | - **Call**: We start by identifying if the call is a template or a function.
37 | - For both templates and functions, we map the defined arguments to their initialized values at the call time.
38 | - Next, we process the body of the template or function using `traverse_sequence_of_statements`.
39 | - In the case of a _function_, we assume it processes only variables, as covered above (i.e., we assume a function body does not include a template call).
40 | - For a _template_ call, it's similar to processing the main call. However, there's an additional step of delayed mapping for input and output signals. After traversing the template, we map the template's signals to the caller's signals. For example, in `component c = Template()` where `Template` has input signal `I` and output signal `O`, these are mapped in the caller's code as `c.I = I` and `c.O = O`.
41 | - This mapping may be going on in the `traverse_sequence_of_statements`, during `execute_delayed_declarations` if it's a complete template.
42 | - `if is_complete_template { execute_delayed_declarations(program_archive, runtime, actual_node, flags); }`.
43 | - **AnonymousComponent**: Not addressed.
44 | - **ArrayInLine**: Not addressed.
45 | - **Tuple**: Not addressed.
46 | - **UniformArray**: Not addressed.
47 |
48 | ## Creating Gates and Operations
49 |
50 | - Gates are only created when processing `traverse_infix_op`.
51 | - Based on the operation, a specific gate like a fan-in-2 gate may be added to the circuit.
52 | - For example, when we encounter `traverse_infix_op` and it results in `id_1 = id_2 + id_3`, we create an add gate with `id_2` and `id_3` as inputs and `id_1` as the output.
53 |
54 | ### Special Gates
55 |
56 | For the Comparison, Negative/Positive, and Zero Check gates, Circom implements these using an advisory approach. Our strategy should be to identify these specific gates while processing template calls within `traverse_expression`, by matching the template name and substituting dedicated gates for comparison, sign checking, and zero equality.
57 |
--------------------------------------------------------------------------------
/input/circuit.circom:
--------------------------------------------------------------------------------
1 | // from 0xZKML/zk-mnist
2 |
3 | pragma circom 2.0.0;
4 |
// Two-way switch driven by `sel`:
//   sel = 0  ->  (outL, outR) = (L, R)
//   sel = 1  ->  (outL, outR) = (R, L)
// `sel` is expected to be 0 or 1; this template does not constrain it to be boolean.
template Switcher() {
    signal input sel;
    signal input L;
    signal input R;
    signal output outL;
    signal output outR;

    signal aux;

    aux <== (R-L)*sel; // We create aux in order to have only one multiplication
    outL <== aux + L;  // L + (R-L)*sel
    outR <== -aux + R; // R - (R-L)*sel
}
18 |
// Outputs the index of the largest element of `in`.
// The comparison is strict (`>`), so on ties the earliest index wins.
template ArgMax (n) {
    signal input in[n];
    signal output out;

    // assert (out < n);
    signal gts[n]; // store comparators
    component switchers[n+1]; // switcher for comparing maxs
    component aswitchers[n+1]; // switcher for arg max
    // Index 0 of both switcher arrays is never instantiated; the loop fills 1..n.

    signal maxs[n+1];  // maxs[i+1]: largest value among in[0..i]; maxs[0] seeds it with in[0]
    signal amaxs[n+1]; // amaxs[i+1]: index of maxs[i+1]; amaxs[0] seeds it with 0

    maxs[0] <== in[0];
    amaxs[0] <== 0;
    for(var i = 0; i < n; i++) {
        gts[i] <== in[i] > maxs[i]; // 1 when in[i] strictly beats the running maximum, else 0
        switchers[i+1] = Switcher();
        aswitchers[i+1] = Switcher();

        // outL takes R when sel = 1 and L when sel = 0: keep the old max unless in[i] wins.
        switchers[i+1].sel <== gts[i];
        switchers[i+1].L <== maxs[i];
        switchers[i+1].R <== in[i];

        // Same selection for the index: keep the old arg max unless in[i] wins.
        aswitchers[i+1].sel <== gts[i];
        aswitchers[i+1].L <== amaxs[i];
        aswitchers[i+1].R <== i;
        amaxs[i+1] <== aswitchers[i+1].outL;
        maxs[i+1] <== switchers[i+1].outL;
    }

    out <== amaxs[n];
}
51 |
component main = ArgMax(2);

// Example input matching ArgMax(2) above: "out" is the index of the largest element of "in".
/* INPUT = {
    "in": ["2","3"],
    "out": "1"
} */
--------------------------------------------------------------------------------
/src/a_gate_type.rs:
--------------------------------------------------------------------------------
1 | use circom_program_structure::ast::ExpressionInfixOpcode;
2 | use serde::{Deserialize, Serialize};
3 | use strum_macros::{Display as StrumDisplay, EnumString};
4 |
/// The supported arithmetic gate types.
///
/// Circom infix operators are mapped onto these variants by the
/// `From<&ExpressionInfixOpcode>` implementation in this module.
/// The strum `Display`/`EnumString` and serde derives use the variant names as their string
/// form, so renaming a variant changes how it is printed, parsed and serialized.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)]
pub enum AGateType {
    /// From `ExpressionInfixOpcode::Add`.
    AAdd,
    /// From `ExpressionInfixOpcode::Div`.
    ADiv,
    /// From `ExpressionInfixOpcode::Eq`.
    AEq,
    /// From `ExpressionInfixOpcode::GreaterEq`.
    AGEq,
    /// From `ExpressionInfixOpcode::Greater`.
    AGt,
    /// From `ExpressionInfixOpcode::LesserEq`.
    ALEq,
    /// From `ExpressionInfixOpcode::Lesser`.
    ALt,
    /// From `ExpressionInfixOpcode::Mul`.
    AMul,
    /// From `ExpressionInfixOpcode::NotEq`.
    ANeq,
    /// From `ExpressionInfixOpcode::Sub`.
    ASub,
    /// From `ExpressionInfixOpcode::BitXor`.
    AXor,
    /// From `ExpressionInfixOpcode::Pow`.
    APow,
    /// From `ExpressionInfixOpcode::IntDiv`.
    AIntDiv,
    /// From `ExpressionInfixOpcode::Mod`.
    AMod,
    /// From `ExpressionInfixOpcode::ShiftL`.
    AShiftL,
    /// From `ExpressionInfixOpcode::ShiftR`.
    AShiftR,
    /// From `ExpressionInfixOpcode::BoolOr`.
    ABoolOr,
    /// From `ExpressionInfixOpcode::BoolAnd`.
    ABoolAnd,
    /// From `ExpressionInfixOpcode::BitOr`.
    ABitOr,
    /// From `ExpressionInfixOpcode::BitAnd`.
    ABitAnd,
}
29 |
30 | impl From<&ExpressionInfixOpcode> for AGateType {
31 | fn from(opcode: &ExpressionInfixOpcode) -> Self {
32 | match opcode {
33 | ExpressionInfixOpcode::Mul => AGateType::AMul,
34 | ExpressionInfixOpcode::Div => AGateType::ADiv,
35 | ExpressionInfixOpcode::Add => AGateType::AAdd,
36 | ExpressionInfixOpcode::Sub => AGateType::ASub,
37 | ExpressionInfixOpcode::Pow => AGateType::APow,
38 | ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
39 | ExpressionInfixOpcode::Mod => AGateType::AMod,
40 | ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
41 | ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
42 | ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
43 | ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
44 | ExpressionInfixOpcode::Lesser => AGateType::ALt,
45 | ExpressionInfixOpcode::Greater => AGateType::AGt,
46 | ExpressionInfixOpcode::Eq => AGateType::AEq,
47 | ExpressionInfixOpcode::NotEq => AGateType::ANeq,
48 | ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
49 | ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
50 | ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
51 | ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
52 | ExpressionInfixOpcode::BitXor => AGateType::AXor,
53 | }
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/circom.rs:
--------------------------------------------------------------------------------
//! # Circom Module
//!
//! This module is a slightly modified version of the original code from the `circom`
//! repository's sub-modules related to the circom program structure, project configuration
//! and compilation.

pub mod parser;
pub mod type_analysis;

/// Version string passed to `circom_parser::run_parser` by `parser::parse_project`.
pub const VERSION: &str = "2.1.0";
10 |
--------------------------------------------------------------------------------
/src/circom/parser.rs:
--------------------------------------------------------------------------------
1 | use crate::{circom::VERSION, cli::Args, program::ProgramError};
2 | use circom_parser::run_parser;
3 | use circom_program_structure::{error_definition::Report, program_archive::ProgramArchive};
4 |
5 | pub fn parse_project(args: &Args) -> Result {
6 | let initial_file = args.input.to_str().unwrap().to_string();
7 | match run_parser(initial_file, VERSION, vec![]) {
8 | Result::Err((file_library, report_collection)) => {
9 | Report::print_reports(&report_collection, &file_library);
10 | Result::Err(ProgramError::ParsingError)
11 | }
12 | Result::Ok((program_archive, warnings)) => {
13 | Report::print_reports(&warnings, &program_archive.file_library);
14 | Result::Ok(program_archive)
15 | }
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/circom/type_analysis.rs:
--------------------------------------------------------------------------------
1 | use crate::program::ProgramError;
2 | use circom_program_structure::{error_definition::Report, program_archive::ProgramArchive};
3 | use circom_type_analysis::check_types::check_types;
4 |
5 | pub fn analyse_project(program_archive: &mut ProgramArchive) -> Result<(), ProgramError> {
6 | match check_types(program_archive) {
7 | Err(errs) => {
8 | Report::print_reports(&errs, program_archive.get_file_library());
9 | Err(ProgramError::AnalysisError)
10 | }
11 | Ok(warns) => {
12 | Report::print_reports(&warns, program_archive.get_file_library());
13 | Ok(())
14 | }
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/cli.rs:
--------------------------------------------------------------------------------
1 | use std::path::{Path, PathBuf};
2 |
3 | use clap::{Parser, ValueEnum};
4 | use serde::{Deserialize, Serialize};
5 |
// Type used for values in the MPC backend, selected with `--value-type` (see `Args`).
// Plain `//` comments are used here on purpose: clap's `ValueEnum` derive turns `///` doc
// comments on variants into possible-value help text, which would change `--help` output.
// The explicit serde renames produce the same names that `rename_all = "lowercase"` would.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ValueType {
    // Serialized as "sint"; this is the `Default` variant.
    #[serde(rename = "sint")]
    #[default]
    Sint,
    // Serialized as "sfloat".
    #[serde(rename = "sfloat")]
    Sfloat,
}
15 |
// Command-line arguments of the compiler binary.
// Notes in this struct use plain `//` comments on purpose: clap's derive turns `///` doc
// comments into `--help` text, so adding or editing them would change the CLI output.
#[derive(Parser)]
#[clap(name = "Arithmetic Circuits Compiler")]
#[command(disable_help_subcommand = true)]
pub struct Args {
    /// Input file to process
    #[arg(
        short,
        long,
        help = "Path to the input file",
        default_value = "./input/circuit.circom"
    )]
    pub input: PathBuf,

    // NOTE(review): the doc comment below says "file", but the explicit `help` text, which is
    // what clap displays, describes this as the output *directory*.
    /// Output file to write the result
    #[arg(
        short,
        long,
        help = "Path to the directory where the output will be written",
        default_value = "./output/"
    )]
    pub output: PathBuf,

    // Value type for the MPC backend; defaults to `ValueType::Sint`.
    #[arg(
        short,
        long,
        value_enum,
        help = "Type that'll be used for values in MPC backend",
        default_value_t = ValueType::Sint,
    )]
    pub value_type: ValueType,

    // Integer bit width for the optional boolean-circuit conversion; `None` (the default)
    // means no conversion is requested.
    #[arg(
        long,
        help = "Optional: Convert to a boolean circuit by using integers with this number of bits",
        default_value = None,
    )]
    pub boolify_width: Option,
}
54 |
impl Args {
    /// Builds an `Args` value directly from its fields, without parsing a command line
    /// (so none of the clap defaults declared on the struct are applied).
    pub fn new(
        input: PathBuf,
        output: PathBuf,
        value_type: ValueType,
        boolify_width: Option,
    ) -> Self {
        Self {
            input,
            output,
            value_type,
            boolify_width,
        }
    }
}
70 |
/// Returns the path `<output_path>/<filename>.<ext>`.
///
/// `output_path` is treated as a directory: `"<filename>.<ext>"` is appended to it as a new
/// path component (an absolute `filename` replaces the whole path, as with [`Path::join`]).
pub fn build_output(output_path: &Path, filename: &str, ext: &str) -> PathBuf {
    output_path.join(format!("{filename}.{ext}"))
}
77 |
#[cfg(test)]
mod tests {
    use super::*;

    /// `build_output` appends `<filename>.<ext>` to the given directory.
    #[test]
    fn test_build_output() {
        let produced = build_output(Path::new("./output"), "result", "txt");

        assert_eq!(produced, PathBuf::from("./output/result.txt"));
    }
}
94 |
--------------------------------------------------------------------------------
/src/compiler.rs:
--------------------------------------------------------------------------------
1 | //! # Circuit Module
2 | //!
3 | //! This module defines the data structures used to represent the arithmetic circuit.
4 |
5 | use crate::{
6 | a_gate_type::AGateType, cli::ValueType, program::ProgramError,
7 | topological_sort::topological_sort,
8 | };
9 | use bristol_circuit::{BristolCircuit, CircuitInfo, ConstantInfo, Gate};
10 | use log::debug;
11 | use serde::{Deserialize, Serialize};
12 | use std::collections::{HashMap, HashSet};
13 | use thiserror::Error;
14 |
/// Represents a signal in the circuit, with a name and an optional value.
///
/// When `value` is `Some`, `Compiler::add_signal` marks the node it creates for this signal
/// as constant.
#[derive(Debug, Serialize, Deserialize)]
pub struct Signal {
    /// Signal name; `Compiler::get_signals` filters signals by a prefix of this name.
    name: String,
    /// Known value of the signal, if any.
    value: Option,
}

impl Signal {
    /// Creates a new signal from its name and optional known value.
    pub fn new(name: String, value: Option) -> Self {
        Self { name, value }
    }
}
28 |
29 | /// Represents a node in the circuit, a collection of signals.
30 | /// The `is_const` and `is_out` fields saves us some iterations over the signals and gates.
31 | #[derive(Default, Debug, Serialize, Deserialize)]
32 | pub struct Node {
33 | is_const: bool,
34 | is_out: bool,
35 | signals: Vec,
36 | }
37 |
38 | impl Node {
39 | /// Creates a new empty node.
40 | pub fn new() -> Self {
41 | Self {
42 | signals: Vec::new(),
43 | is_const: false,
44 | is_out: false,
45 | }
46 | }
47 |
48 | /// Creates a new node with an initial signal.
49 | pub fn new_with_signal(signal_id: u32, is_const: bool, is_out: bool) -> Self {
50 | Self {
51 | signals: vec![signal_id],
52 | is_const,
53 | is_out,
54 | }
55 | }
56 |
57 | /// Adds a set of signals to the node.
58 | pub fn add_signals(&mut self, signals: &Vec) {
59 | self.signals.extend(signals);
60 | }
61 |
62 | /// Checks if the node contains a signal.
63 | pub fn contains_signal(&self, signal_id: &u32) -> bool {
64 | self.signals.contains(signal_id)
65 | }
66 |
67 | /// Gets the signals of the node.
68 | pub fn get_signals(&self) -> &Vec {
69 | &self.signals
70 | }
71 |
72 | /// Sets the node as an output node.
73 | pub fn set_output(&mut self, is_out: bool) {
74 | self.is_out = is_out;
75 | }
76 |
77 | /// Sets the node as a constant node.
78 | pub fn set_const(&mut self, is_const: bool) {
79 | self.is_const = is_const;
80 | }
81 | }
82 |
83 | /// Represents a circuit gate, with a left-hand input, right-hand input, and output node identifiers.
84 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
85 | pub struct ArithmeticGate {
86 | pub op: AGateType,
87 | pub lh_in: u32,
88 | pub rh_in: u32,
89 | pub out: u32,
90 | }
91 |
92 | impl ArithmeticGate {
93 | /// Creates a new gate.
94 | pub fn new(op: AGateType, lh_in: u32, rh_in: u32, out: u32) -> Self {
95 | Self {
96 | op,
97 | lh_in,
98 | rh_in,
99 | out,
100 | }
101 | }
102 | }
103 |
/// Compilation data structure representing an arithmetic circuit with extra information, including
/// a set of variables and gates.
///
/// Signals are grouped into nodes (one node per signal in `add_signal`, merged by
/// `add_connection`), and gates connect nodes by node id.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct Compiler {
    /// NOTE(review): presumably the counter behind `get_node_id` (defined outside this view).
    node_count: u32,
    /// Input signals, extended by `add_inputs`.
    inputs: HashMap,
    /// Output signals, extended by `add_outputs`.
    outputs: HashMap,
    /// Declared signals keyed by signal id (see `add_signal`).
    signals: HashMap,
    /// Nodes keyed by node id.
    nodes: HashMap,
    /// Gates in insertion order; their inputs and output are node ids.
    gates: Vec,
    /// Value type for the MPC backend, set by `update_type`.
    value_type: ValueType,
}
116 |
117 | impl Compiler {
118 | pub fn new() -> Compiler {
119 | Compiler {
120 | node_count: 0,
121 | inputs: HashMap::new(),
122 | outputs: HashMap::new(),
123 | signals: HashMap::new(),
124 | nodes: HashMap::new(),
125 | gates: Vec::new(),
126 | value_type: Default::default(),
127 | }
128 | }
129 |
130 | pub fn add_inputs(&mut self, inputs: HashMap) {
131 | self.inputs.extend(inputs);
132 | }
133 |
134 | pub fn add_outputs(&mut self, outputs: HashMap) {
135 | self.outputs.extend(outputs);
136 | }
137 |
138 | /// Adds a new signal to the circuit.
139 | pub fn add_signal(
140 | &mut self,
141 | id: u32,
142 | name: String,
143 | value: Option,
144 | ) -> Result<(), CircuitError> {
145 | // Check that the signal isn't already declared
146 | if self.signals.contains_key(&id) {
147 | return Err(CircuitError::SignalAlreadyDeclared);
148 | }
149 |
150 | // Create a new signal
151 | let signal = Signal::new(name, value);
152 | self.signals.insert(id, signal);
153 |
154 | // Create a new node
155 | let node = Node::new_with_signal(id, value.is_some(), false);
156 | debug!("{:?}", node);
157 | let node_id = self.get_node_id();
158 | self.nodes.insert(node_id, node);
159 |
160 | Ok(())
161 | }
162 |
163 | pub fn get_signals(&self, filter: String) -> HashMap {
164 | let mut ret = HashMap::new();
165 | for (signal_id, signal) in self.signals.iter() {
166 | if signal.name.starts_with(filter.as_str()) {
167 | ret.insert(*signal_id, signal.name.to_string());
168 | }
169 | }
170 | ret
171 | }
172 |
173 | /// Adds a new gate to the circuit.
174 | pub fn add_gate(
175 | &mut self,
176 | gate_type: AGateType,
177 | lhs_signal_id: u32,
178 | rhs_signal_id: u32,
179 | output_signal_id: u32,
180 | ) -> Result<(), CircuitError> {
181 | // Get the signal node ids
182 | let node_ids = {
183 | let mut nodes: [u32; 3] = [0; 3];
184 |
185 | for (&id, node) in self.nodes.iter() {
186 | if node.contains_signal(&lhs_signal_id) {
187 | nodes[0] = id;
188 | }
189 | if node.contains_signal(&rhs_signal_id) {
190 | nodes[1] = id;
191 | }
192 | if node.contains_signal(&output_signal_id) {
193 | nodes[2] = id;
194 | }
195 | }
196 |
197 | nodes
198 | };
199 |
200 | // Set the output node as an output node
201 | self.nodes.get_mut(&node_ids[2]).unwrap().set_output(true);
202 |
203 | // Create gate
204 | let gate = ArithmeticGate::new(gate_type, node_ids[0], node_ids[1], node_ids[2]);
205 | debug!("{:?}", gate);
206 | self.gates.push(gate);
207 |
208 | Ok(())
209 | }
210 |
211 | /// Creates a connection between two signals in the circuit.
212 | /// This is finding the nodes that contain these signals and merging them.
213 | pub fn add_connection(&mut self, a: u32, b: u32) -> Result<(), CircuitError> {
214 | // Get the signal node ids
215 | let n = Node::new();
216 | let nodes = {
217 | let mut nodes: [(u32, &Node); 2] = [(0, &n); 2];
218 |
219 | for (&id, node) in self.nodes.iter() {
220 | if node.contains_signal(&a) {
221 | nodes[0] = (id, node);
222 | }
223 | if node.contains_signal(&b) {
224 | nodes[1] = (id, node);
225 | }
226 | }
227 |
228 | nodes
229 | };
230 |
231 | let (node_a_id, node_a) = nodes[0];
232 | let (node_b_id, node_b) = nodes[1];
233 |
234 | // If both signals are in the same node, no action is needed
235 | if node_a_id == node_b_id {
236 | return Ok(());
237 | }
238 | // Check for output and constant nodes
239 | if node_a.is_out && node_b.is_out {
240 | return Err(CircuitError::CannotMergeOutputNodes);
241 | }
242 |
243 | if node_a.is_const && node_b.is_const {
244 | return Err(CircuitError::CannotMergeConstantNodes);
245 | }
246 |
247 | // Merge the nodes into a new node
248 | let mut merged_node = Node::new();
249 |
250 | // Set the new node as an output and constant
251 | merged_node.set_output(node_a.is_out || node_b.is_out);
252 | merged_node.set_const(node_a.is_const || node_b.is_const);
253 |
254 | merged_node.add_signals(node_a.get_signals());
255 | merged_node.add_signals(node_b.get_signals());
256 |
257 | let merged_node_id = self.get_node_id();
258 |
259 | // Update connections in gates to point to the new merged node
260 | self.gates.iter_mut().for_each(|gate| {
261 | if gate.lh_in == node_a_id || gate.lh_in == node_b_id {
262 | gate.lh_in = merged_node_id;
263 | }
264 | if gate.rh_in == node_a_id || gate.rh_in == node_b_id {
265 | gate.rh_in = merged_node_id;
266 | }
267 | if gate.out == node_a_id || gate.out == node_b_id {
268 | gate.out = merged_node_id;
269 | }
270 | });
271 |
272 | // Remove the old nodes and insert the new merged node
273 | self.nodes.remove(&node_a_id);
274 | self.nodes.remove(&node_b_id);
275 | self.nodes.insert(merged_node_id, merged_node);
276 |
277 | Ok(())
278 | }
279 |
280 | pub fn update_type(&mut self, value_type: ValueType) -> Result<(), CircuitError> {
281 | self.value_type = value_type;
282 |
283 | Ok(())
284 | }
285 |
286 | /// Generates a circuit report with input and output signals information.
287 | pub fn generate_circuit_report(&self) -> Result {
288 | // Split input and output nodes
289 | let mut input_nodes = Vec::new();
290 | let mut output_nodes = Vec::new();
291 | self.nodes.iter().for_each(|(&id, node)| {
292 | if node.is_out {
293 | output_nodes.push(id);
294 | } else {
295 | input_nodes.push(id);
296 | }
297 | });
298 |
299 | // Remove output nodes that are inputs to gates
300 | output_nodes.retain(|&id| {
301 | self.gates
302 | .iter()
303 | .all(|gate| gate.lh_in != id && gate.rh_in != id)
304 | });
305 |
306 | // Sort
307 | input_nodes.sort_unstable();
308 | output_nodes.sort_unstable();
309 |
310 | // Generate reports
311 | let inputs = self.generate_signal_reports(&input_nodes);
312 | let outputs = self.generate_signal_reports(&output_nodes);
313 |
314 | Ok(CircuitReport {
315 | inputs,
316 | outputs,
317 | value_type: self.value_type,
318 | })
319 | }
320 |
321 | pub fn build_circuit(&self) -> Result {
322 | // First build up these maps so we can easily see which node id to use
323 | let mut input_to_node_id = HashMap::::new();
324 | let mut constant_to_node_id_and_value = HashMap::::new();
325 | let mut output_to_node_id = HashMap::::new();
326 |
327 | for (node_id, node) in self.nodes.iter() {
328 | // Each node has a list of signal ids which all correspond to that node
329 | // The compiler associates IO with signals, so here we bridge the gap so we get
330 | // IO <=> node instead of IO <=> signal <=> node
331 | for signal_id in node.get_signals() {
332 | if let Some(input_name) = self.inputs.get(signal_id) {
333 | let prev = input_to_node_id.insert(input_name.clone(), *node_id);
334 |
335 | if prev.is_some() {
336 | return Err(CircuitError::Inconsistency {
337 | message: format!("Duplicate input {}", input_name),
338 | });
339 | }
340 | }
341 |
342 | if let Some(output_name) = self.outputs.get(signal_id) {
343 | let prev = output_to_node_id.insert(output_name.clone(), *node_id);
344 |
345 | if prev.is_some() {
346 | return Err(CircuitError::Inconsistency {
347 | message: format!("Duplicate output {}", output_name),
348 | });
349 | }
350 | }
351 |
352 | let signal = &self.signals[signal_id];
353 |
354 | if let Some(value) = signal.value {
355 | constant_to_node_id_and_value.insert(
356 | format!("{}_{}", signal.name.clone(), signal_id),
357 | (*node_id, value.to_string()),
358 | );
359 | }
360 | }
361 | }
362 |
363 | {
364 | // We want inputs at the start and outputs at the end
365 | // We won't be able to do that if a node is used for both input and output
366 | // That shouldn't happen, so we check here that it doesn't happen
367 |
368 | let node_id_to_input_name = input_to_node_id
369 | .iter()
370 | .map(|(name, node_id)| (node_id, name))
371 | .collect::>();
372 |
373 | for (output_name, output_node_id) in &output_to_node_id {
374 | if let Some(input_name) = node_id_to_input_name.get(output_node_id) {
375 | return Err(CircuitError::Inconsistency {
376 | message: format!(
377 | "Node {} used for both input {} and output {}",
378 | output_node_id, input_name, output_name
379 | ),
380 | });
381 | }
382 | }
383 | }
384 |
385 | // Now node ids are like wire ids, but the compiler generates them in a way that leaves a
386 | // lot of gaps. So we assign new wire ids so they'll be sequential instead. We also do this
387 | // ensure inputs are at the start and outputs are at the end.
388 | let mut node_id_to_wire_id = HashMap::::new();
389 | let mut next_wire_id = 0;
390 |
391 | // First inputs
392 | for node_id in input_to_node_id.values() {
393 | node_id_to_wire_id.insert(*node_id, next_wire_id);
394 | next_wire_id += 1;
395 | }
396 |
397 | // For the intermediate nodes, we need the gates in topological order so that the wires are
398 | // assigned in the order they are needed. The topological order is also needed to comply
399 | // with bristol format and allow for easy evaluation.
400 |
401 | let mut node_id_to_required_gate = HashMap::::new();
402 |
403 | for (gate_id, gate) in self.gates.iter().enumerate() {
404 | // the gate.out node depends on this gate
405 | node_id_to_required_gate.insert(gate.out, gate_id);
406 | }
407 |
408 | let sorted_gate_ids = topological_sort(self.gates.len(), &|gate_id: usize| {
409 | let gate = &self.gates[gate_id];
410 | let mut deps = Vec::::new();
411 |
412 | if let Some(required_gate_id) = node_id_to_required_gate.get(&gate.lh_in) {
413 | deps.push(*required_gate_id);
414 | }
415 |
416 | if let Some(required_gate_id) = node_id_to_required_gate.get(&gate.rh_in) {
417 | deps.push(*required_gate_id);
418 | }
419 |
420 | deps
421 | })?;
422 |
423 | let output_node_ids = output_to_node_id.values().collect::>();
424 |
425 | // Now that the gates are in order, we can assign wire ids to each node in the order they
426 | // are seen
427 | for gate_id in &sorted_gate_ids {
428 | let gate = &self.gates[*gate_id];
429 |
430 | for node_id in &[gate.lh_in, gate.rh_in, gate.out] {
431 | if output_node_ids.contains(node_id) {
432 | // Output wires are excluded so that they can all be at the end
433 | continue;
434 | }
435 |
436 | if node_id_to_wire_id.contains_key(node_id) {
437 | continue;
438 | }
439 |
440 | node_id_to_wire_id.insert(*node_id, next_wire_id);
441 | next_wire_id += 1;
442 | }
443 | }
444 |
445 | // Assign wire ids to output nodes
446 | for node_id in output_to_node_id.values() {
447 | node_id_to_wire_id.insert(*node_id, next_wire_id);
448 | next_wire_id += 1;
449 | }
450 |
451 | // Now we can create the new gates using topological order and the new wire ids
452 | let mut new_gates = Vec::::new();
453 | for gate_id in sorted_gate_ids {
454 | let gate = &self.gates[gate_id];
455 |
456 | new_gates.push(Gate {
457 | inputs: vec![
458 | node_id_to_wire_id[&gate.lh_in] as usize,
459 | node_id_to_wire_id[&gate.rh_in] as usize,
460 | ],
461 | outputs: vec![node_id_to_wire_id[&gate.out] as usize],
462 | op: gate.op.to_string(),
463 | });
464 | }
465 |
466 | let mut constants = HashMap::::new();
467 |
468 | for (name, (node_id, value)) in constant_to_node_id_and_value {
469 | constants.insert(
470 | name,
471 | ConstantInfo {
472 | value,
473 | wire_index: node_id_to_wire_id[&node_id] as usize,
474 | },
475 | );
476 | }
477 |
478 | Ok(BristolCircuit {
479 | wire_count: next_wire_id as usize,
480 | info: CircuitInfo {
481 | input_name_to_wire_index: input_to_node_id
482 | .iter()
483 | .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize))
484 | .collect(),
485 | constants,
486 | output_name_to_wire_index: output_to_node_id
487 | .iter()
488 | .map(|(name, node_id)| (name.clone(), node_id_to_wire_id[node_id] as usize))
489 | .collect(),
490 | },
491 | gates: new_gates,
492 | io_widths: None,
493 | })
494 | }
495 |
496 | /// Returns a node id and increments the count.
497 | fn get_node_id(&mut self) -> u32 {
498 | self.node_count += 1;
499 | self.node_count
500 | }
501 |
502 | /// Generates signal reports for a set of node IDs.
503 | fn generate_signal_reports(&self, nodes: &[u32]) -> Vec {
504 | nodes
505 | .iter()
506 | .map(|&id| {
507 | let signals = self
508 | .nodes
509 | .get(&id)
510 | .expect("Node ID not found in node map")
511 | .get_signals();
512 |
513 | let (names, value) = signals.iter().fold((Vec::new(), None), |mut acc, &sig_id| {
514 | let signal = self
515 | .signals
516 | .get(&sig_id)
517 | .expect("Signal ID not found in signal map");
518 |
519 | if !signal.name.contains("random_") {
520 | acc.0.push(signal.name.clone());
521 | }
522 | if signal.value.is_some() {
523 | acc.1 = signal.value;
524 | }
525 | acc
526 | });
527 |
528 | SignalReport { id, names, value }
529 | })
530 | .collect()
531 | }
532 | }
533 |
534 | /// The full circuit report, containing input and output signals information.
535 | #[derive(Debug, Serialize, Deserialize)]
536 | pub struct CircuitReport {
537 | inputs: Vec,
538 | outputs: Vec,
539 | value_type: ValueType,
540 | }
541 |
542 | /// A single node report, with a list of signal names and an optional value.
543 | #[derive(Debug, Serialize, Deserialize)]
544 | pub struct SignalReport {
545 | id: u32,
546 | names: Vec,
547 | value: Option,
548 | }
549 |
550 | #[derive(Debug, Error)]
551 | pub enum CircuitError {
552 | #[error("Cannot merge constant nodes")]
553 | CannotMergeConstantNodes,
554 | #[error("Cannot merge output nodes")]
555 | CannotMergeOutputNodes,
556 | #[error("Constant value already set for variable")]
557 | ConstantValueAlreadySet,
558 | #[error("Signal is not connected to any node")]
559 | DisconnectedSignal,
560 | #[error(transparent)]
561 | IOError(#[from] std::io::Error),
562 | #[error(transparent)]
563 | ParseIntError(#[from] std::num::ParseIntError),
564 | #[error("Signal already declared")]
565 | SignalAlreadyDeclared,
566 | #[error("unsupported gate type: {0}")]
567 | UnsupportedGateType(String),
568 | #[error("Unprocessed node")]
569 | UnprocessedNode,
570 | #[error("Cyclic dependency: {message}")]
571 | CyclicDependency { message: String },
572 | #[error("Inconsistency: {message}")]
573 | Inconsistency { message: String },
574 | #[error("Parsing error: {message}")]
575 | ParsingError { message: String },
576 | }
577 |
578 | impl From for ProgramError {
579 | fn from(e: CircuitError) -> Self {
580 | ProgramError::CircuitError(e)
581 | }
582 | }
583 |
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a compiler holding one value-less signal per id in `ids`.
    ///
    /// Signals are added in slice order, and each is named `signal<id>` (e.g. `signal1`).
    fn compiler_with_signals(ids: &[u32]) -> Compiler {
        let mut compiler = Compiler::new();
        for &id in ids {
            compiler
                .add_signal(id, format!("signal{}", id), None)
                .unwrap();
        }
        compiler
    }

    #[test]
    fn test_node_with_signal() {
        let node = Node::new_with_signal(1, true, false);
        assert_eq!(node.signals, [1]);
        assert!(node.is_const && !node.is_out);
    }

    #[test]
    fn test_node_add_signal() {
        let mut node = Node::new();
        node.add_signals(&vec![1, 2, 3]);
        assert_eq!(node.signals.len(), 3);
        for signal in 1..=3 {
            assert!(node.contains_signal(&signal));
        }
    }

    #[test]
    fn test_node_contains_signal() {
        let node = Node::new_with_signal(1, true, false);
        assert!(node.contains_signal(&1));
        assert!(!node.contains_signal(&2));
    }

    #[test]
    fn test_node_set_output() {
        let mut node = Node::new();
        node.set_output(true);
        assert!(node.is_out);
    }

    #[test]
    fn test_node_set_const() {
        let mut node = Node::new();
        node.set_const(true);
        assert!(node.is_const);
    }

    #[test]
    fn test_compiler_add_inputs() {
        let mut compiler = Compiler::new();
        let inputs: HashMap<_, _> = vec![(1, String::from("input1")), (2, String::from("input2"))]
            .into_iter()
            .collect();
        compiler.add_inputs(inputs);

        assert_eq!(compiler.inputs.len(), 2);
        assert_eq!(compiler.inputs[&1], "input1");
        assert_eq!(compiler.inputs[&2], "input2");
    }

    #[test]
    fn test_compiler_add_outputs() {
        let mut compiler = Compiler::new();
        let outputs: HashMap<_, _> =
            vec![(3, String::from("output1")), (4, String::from("output2"))]
                .into_iter()
                .collect();
        compiler.add_outputs(outputs);

        assert_eq!(compiler.outputs.len(), 2);
        assert_eq!(compiler.outputs[&3], "output1");
        assert_eq!(compiler.outputs[&4], "output2");
    }

    #[test]
    fn test_compiler_add_signal() {
        let mut compiler = Compiler::new();
        let outcome = compiler.add_signal(1, String::from("signal1"), None);

        assert!(outcome.is_ok());
        assert_eq!(compiler.signals.len(), 1);
        assert!(compiler.signals.contains_key(&1));
        assert_eq!(compiler.signals[&1].name, "signal1");
    }

    #[test]
    fn test_compiler_add_duplicated_signal() {
        let mut compiler = compiler_with_signals(&[1]);
        let outcome = compiler.add_signal(1, String::from("signal1"), None);

        assert!(matches!(outcome, Err(CircuitError::SignalAlreadyDeclared)));
    }

    #[test]
    fn test_compiler_get_signals() {
        let mut compiler = compiler_with_signals(&[1]);
        compiler
            .add_signal(2, String::from("filter_signal"), None)
            .unwrap();

        let matching = compiler.get_signals(String::from("filter"));

        assert_eq!(matching.len(), 1);
        assert_eq!(matching[&2], "filter_signal");
    }

    #[test]
    fn test_compiler_add_gate() {
        let mut compiler = compiler_with_signals(&[1, 2, 3]);

        let outcome = compiler.add_gate(AGateType::AAdd, 1, 2, 3);

        assert!(outcome.is_ok());
        assert_eq!(compiler.gates.len(), 1);
        let gate = &compiler.gates[0];
        assert_eq!(gate.op, AGateType::AAdd);
        assert_eq!((gate.lh_in, gate.rh_in, gate.out), (1, 2, 3));
    }

    #[test]
    fn test_compiler_add_connection() {
        let mut compiler = compiler_with_signals(&[1, 2, 3]);

        // Connecting signals 1 and 2 replaces their two nodes with one merged node.
        let outcome = compiler.add_connection(1, 2);

        assert!(outcome.is_ok());
        assert_eq!(compiler.nodes.len(), 2);

        let merged = compiler.nodes.get(&4).unwrap();
        assert_eq!(merged.signals.len(), 2);
        assert!(merged.contains_signal(&1) && merged.contains_signal(&2));
    }

    #[test]
    fn test_compiler_add_connection_same_node() {
        let mut compiler = compiler_with_signals(&[1, 2]);
        compiler.add_connection(1, 2).unwrap();

        // Reconnecting signals that already share a node must be a no-op.
        let outcome = compiler.add_connection(1, 2);

        assert!(outcome.is_ok());
        assert_eq!(compiler.nodes.len(), 1);
    }

    #[test]
    fn test_compiler_add_connection_output_nodes() {
        let mut compiler = compiler_with_signals(&[1, 2]);
        for id in [1, 2] {
            compiler.nodes.get_mut(&id).unwrap().set_output(true);
        }

        let outcome = compiler.add_connection(1, 2);

        assert!(matches!(outcome, Err(CircuitError::CannotMergeOutputNodes)));
    }

    #[test]
    fn test_compiler_add_connection_constant_nodes() {
        let mut compiler = Compiler::new();
        compiler
            .add_signal(1, String::from("signal1"), Some(1))
            .unwrap();
        compiler
            .add_signal(2, String::from("signal2"), Some(2))
            .unwrap();

        let outcome = compiler.add_connection(1, 2);

        assert!(matches!(
            outcome,
            Err(CircuitError::CannotMergeConstantNodes)
        ));
    }
}
796 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! # Circom To Arithmetic Circuit
2 | //!
3 | //! This library provides the functionality to convert a Circom program into an arithmetic circuit.
4 |
5 | pub mod a_gate_type;
6 | pub mod circom;
7 | pub mod cli;
8 | pub mod compiler;
9 | pub mod process;
10 | pub mod program;
11 | pub mod runtime;
12 |
13 | mod topological_sort;
14 |
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | use boolify::boolify;
2 | use circom_2_arithc::{
3 | cli::{build_output, Args},
4 | program::{compile, ProgramError},
5 | };
6 | use clap::Parser;
7 | use dotenv::dotenv;
8 | use env_logger::{init_from_env, Env};
9 | use serde_json::to_string_pretty;
10 | use std::{
11 | fs::{self, File},
12 | io::Write,
13 | };
14 |
15 | fn main() -> Result<(), ProgramError> {
16 | dotenv().ok();
17 | init_from_env(Env::default().filter_or("LOG_LEVEL", "info"));
18 |
19 | let args = Args::parse();
20 |
21 | let compiler = compile(&args)?;
22 | let report = compiler.generate_circuit_report()?;
23 |
24 | let output_dir = args.output.clone();
25 | fs::create_dir_all(output_dir.clone())
26 | .map_err(|_| ProgramError::OutputDirectoryCreationError)?;
27 |
28 | let mut circuit = compiler.build_circuit()?;
29 |
30 | if let Some(boolify_width) = args.boolify_width {
31 | circuit = boolify(&circuit, boolify_width);
32 | }
33 |
34 | let output_file_path = build_output(&output_dir, "circuit", "txt");
35 | circuit.write_bristol(&mut File::create(output_file_path)?)?;
36 |
37 | // let output_file_path_json = build_output(&output_dir, "circuit", "json");
38 | // File::create(output_file_path_json)?.write_all(serde_json::to_string_pretty(&circuit)?.as_bytes())?;
39 |
40 | // let output_debug_path_json = build_output(&output_dir, "debug", "json");
41 | // File::create(output_debug_path_json)?.write_all(serde_json::to_string_pretty(&compiler)?.as_bytes())?;
42 |
43 | let output_file_path = build_output(&output_dir, "circuit_info", "json");
44 | File::create(output_file_path)?.write_all(to_string_pretty(&circuit.info)?.as_bytes())?;
45 |
46 | let report_file_path = build_output(&output_dir, "report", "json");
47 | File::create(report_file_path)?.write_all(to_string_pretty(&report)?.as_bytes())?;
48 |
49 | Ok(())
50 | }
51 |
--------------------------------------------------------------------------------
/src/process.rs:
--------------------------------------------------------------------------------
1 | //! # Process Module
2 | //!
3 | //! Handles execution of statements and expressions for arithmetic circuit generation within a `Runtime` environment.
4 |
5 | use crate::a_gate_type::AGateType;
6 | use crate::compiler::Compiler;
7 | use crate::program::ProgramError;
8 | use crate::runtime::{
9 | generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, NestedValue,
10 | Runtime, RuntimeError, Signal, SubAccess, RETURN_VAR,
11 | };
12 | use circom_circom_algebra::num_traits::ToPrimitive;
13 | use circom_program_structure::ast::{
14 | Access, AssignOp, Expression, ExpressionInfixOpcode, ExpressionPrefixOpcode, Statement,
15 | };
16 | use circom_program_structure::program_archive::ProgramArchive;
17 | use std::cell::RefCell;
18 | use std::collections::HashMap;
19 | use std::rc::Rc;
20 |
21 | /// Processes a sequence of statements.
22 | pub fn process_statements(
23 | ac: &mut Compiler,
24 | runtime: &mut Runtime,
25 | program_archive: &ProgramArchive,
26 | statements: &[Statement],
27 | ) -> Result<(), ProgramError> {
28 | for statement in statements {
29 | process_statement(ac, runtime, program_archive, statement)?;
30 | }
31 |
32 | Ok(())
33 | }
34 |
35 | /// Processes a single statement.
36 | pub fn process_statement(
37 | ac: &mut Compiler,
38 | runtime: &mut Runtime,
39 | program_archive: &ProgramArchive,
40 | statement: &Statement,
41 | ) -> Result<(), ProgramError> {
42 | match statement {
43 | Statement::InitializationBlock {
44 | initializations, ..
45 | } => process_statements(ac, runtime, program_archive, initializations),
46 | Statement::Block { stmts, .. } => process_statements(ac, runtime, program_archive, stmts),
47 | Statement::Substitution {
48 | var,
49 | access,
50 | rhe,
51 | op,
52 | ..
53 | } => handle_substitution(ac, runtime, program_archive, var, access, rhe, op),
54 | Statement::Declaration {
55 | xtype,
56 | name,
57 | dimensions,
58 | ..
59 | } => {
60 | let data_type = DataType::try_from(xtype)?;
61 | let dim_access: Vec = dimensions
62 | .iter()
63 | .map(|expression| process_expression(ac, runtime, program_archive, expression))
64 | .collect::, ProgramError>>()?;
65 |
66 | let signal_gen = runtime.get_signal_gen();
67 | let ctx = runtime.current_context()?;
68 | let dimensions: Vec = dim_access
69 | .iter()
70 | .map(|dim_access| {
71 | ctx.get_variable_value(dim_access)?
72 | .ok_or(ProgramError::EmptyDataItem)
73 | })
74 | .collect::, ProgramError>>()?;
75 | ctx.declare_item(data_type.clone(), name, &dimensions, signal_gen)?;
76 |
77 | // If the declared item is a signal we should add it to the arithmetic circuit
78 | if data_type == DataType::Signal {
79 | let mut signal_access = DataAccess::new(name, Vec::new());
80 |
81 | if dimensions.is_empty() {
82 | let signal_id = ctx.get_signal_id(&signal_access)?;
83 | ac.add_signal(
84 | signal_id,
85 | signal_access.access_str(ctx.get_ctx_name()),
86 | None,
87 | )?;
88 | } else {
89 | let mut indices: Vec = vec![0; dimensions.len()];
90 |
91 | loop {
92 | // Set access and get signal id for the current indices
93 | signal_access.set_access(u32_to_access(&indices));
94 | let signal_id = ctx.get_signal_id(&signal_access)?;
95 | ac.add_signal(
96 | signal_id,
97 | signal_access.access_str(ctx.get_ctx_name()),
98 | None,
99 | )?;
100 |
101 | // Increment indices
102 | if !increment_indices(&mut indices, &dimensions)? {
103 | break;
104 | }
105 | }
106 | }
107 | }
108 |
109 | Ok(())
110 | }
111 | Statement::IfThenElse {
112 | cond,
113 | if_case,
114 | else_case,
115 | ..
116 | } => {
117 | let access = process_expression(ac, runtime, program_archive, cond)?;
118 | let result = runtime
119 | .current_context()?
120 | .get_variable_value(&access)?
121 | .ok_or(ProgramError::EmptyDataItem)?;
122 |
123 | if result == 0 {
124 | if let Some(else_statement) = else_case {
125 | runtime.push_context(true, "IF_FALSE".to_string())?;
126 | process_statement(ac, runtime, program_archive, else_statement)?;
127 | runtime.pop_context(true)?;
128 | Ok(())
129 | } else {
130 | Ok(())
131 | }
132 | } else {
133 | runtime.push_context(true, "IF_TRUE".to_string())?;
134 | process_statement(ac, runtime, program_archive, if_case)?;
135 | runtime.pop_context(true)?;
136 | Ok(())
137 | }
138 | }
139 | Statement::While { cond, stmt, .. } => {
140 | runtime.push_context(true, "WHILE_PRE".to_string())?;
141 | loop {
142 | let access = process_expression(ac, runtime, program_archive, cond)?;
143 | let result = runtime
144 | .current_context()?
145 | .get_variable_value(&access)?
146 | .ok_or(ProgramError::EmptyDataItem)?;
147 |
148 | if result == 0 {
149 | break;
150 | }
151 |
152 | runtime.push_context(true, "WHILE_EXE".to_string())?;
153 | process_statement(ac, runtime, program_archive, stmt)?;
154 | runtime.pop_context(true)?;
155 | }
156 | runtime.pop_context(true)?;
157 |
158 | Ok(())
159 | }
160 | Statement::Return { value, .. } => {
161 | let return_access = process_expression(ac, runtime, program_archive, value)?;
162 |
163 | let signal_gen = runtime.get_signal_gen();
164 | let ctx = runtime.current_context()?;
165 | let return_value = ctx
166 | .get_variable_value(&return_access)?
167 | .ok_or(ProgramError::EmptyDataItem)?;
168 |
169 | ctx.declare_item(DataType::Variable, RETURN_VAR, &[], signal_gen)?;
170 | ctx.set_variable(&DataAccess::new(RETURN_VAR, vec![]), Some(return_value))?;
171 |
172 | Ok(())
173 | }
174 | Statement::Assert { arg, .. } => {
175 | let access = process_expression(ac, runtime, program_archive, arg)?;
176 | let result = runtime
177 | .current_context()?
178 | .get_variable_value(&access)?
179 | .ok_or(ProgramError::EmptyDataItem)?;
180 |
181 | if result == 0 {
182 | return Err(ProgramError::RuntimeError(RuntimeError::AssertionFailed));
183 | }
184 |
185 | Ok(())
186 | }
187 | _ => Err(ProgramError::StatementNotImplemented),
188 | }
189 | }
190 |
191 | /// Handles a substitution statement
192 | fn handle_substitution(
193 | ac: &mut Compiler,
194 | runtime: &mut Runtime,
195 | program_archive: &ProgramArchive,
196 | var: &str,
197 | access: &[Access],
198 | rhe: &Expression,
199 | op: &AssignOp,
200 | ) -> Result<(), ProgramError> {
201 | let lh_access = build_access(ac, runtime, program_archive, var, access)?;
202 | let rh_access = process_expression(ac, runtime, program_archive, rhe)?;
203 |
204 | let signal_gen = runtime.get_signal_gen();
205 | let ctx = runtime.current_context()?;
206 | match ctx.get_item_data_type(var)? {
207 | DataType::Variable => {
208 | // Assign the evaluated right-hand side to the left-hand side
209 | let value = ctx.get_variable_value(&rh_access)?;
210 | ctx.set_variable(&lh_access, value)?;
211 | }
212 | DataType::Component => match op {
213 | AssignOp::AssignVar => {
214 | // Component instantiation
215 | let signal_map = ctx.get_component_map(&rh_access)?;
216 | ctx.set_component(&lh_access, signal_map)?;
217 | }
218 | AssignOp::AssignConstraintSignal => {
219 | // Component signal assignment
220 | match ctx.get_component_signal_content(&lh_access)? {
221 | NestedValue::Array(signal) => {
222 | let assigned_signal_array =
223 | match get_signal_content_for_access(ctx, &rh_access)? {
224 | NestedValue::Array(array) => array,
225 | _ => return Err(ProgramError::InvalidDataType),
226 | };
227 |
228 | connect_signal_arrays(ac, &signal, &assigned_signal_array)?;
229 | }
230 | NestedValue::Value(_) => {
231 | let component_signal = ctx.get_component_signal_id(&lh_access)?;
232 | let assigned_signal =
233 | get_signal_for_access(ac, ctx, signal_gen, &rh_access)?;
234 |
235 | ac.add_connection(assigned_signal, component_signal)?;
236 | }
237 | }
238 | }
239 | _ => return Err(ProgramError::OperationNotSupported),
240 | },
241 | DataType::Signal => {
242 | match rhe {
243 | Expression::Variable { .. } => match ctx.get_signal_content(&lh_access)? {
244 | NestedValue::Array(signal) => {
245 | // Connect the signals in the arrays
246 | let assigned_signal_array =
247 | match get_signal_content_for_access(ctx, &rh_access)? {
248 | NestedValue::Array(array) => array,
249 | _ => return Err(ProgramError::InvalidDataType),
250 | };
251 |
252 | connect_signal_arrays(ac, &signal, &assigned_signal_array)?;
253 | }
254 | NestedValue::Value(signal_id) => {
255 | let gate_output_id =
256 | get_signal_for_access(ac, ctx, signal_gen, &rh_access)?;
257 |
258 | ac.add_connection(gate_output_id, signal_id)?;
259 | }
260 | },
261 | Expression::Call { .. }
262 | | Expression::InfixOp { .. }
263 | | Expression::PrefixOp { .. }
264 | | Expression::Number(_, _) => {
265 | // Get the signal identifiers and connect them
266 | let given_output_id = ctx.get_signal_id(&lh_access)?;
267 | let gate_output_id = get_signal_for_access(ac, ctx, signal_gen, &rh_access)?;
268 |
269 | ac.add_connection(gate_output_id, given_output_id)?;
270 | }
271 | _ => return Err(ProgramError::SignalSubstitutionNotImplemented),
272 | }
273 | }
274 | }
275 |
276 | Ok(())
277 | }
278 |
279 | /// Processes an expression and returns an access to the result.
280 | pub fn process_expression(
281 | ac: &mut Compiler,
282 | runtime: &mut Runtime,
283 | program_archive: &ProgramArchive,
284 | expression: &Expression,
285 | ) -> Result {
286 | match expression {
287 | Expression::Call { id, args, .. } => handle_call(ac, runtime, program_archive, id, args),
288 | Expression::InfixOp {
289 | lhe, infix_op, rhe, ..
290 | } => handle_infix_op(ac, runtime, program_archive, infix_op, lhe, rhe),
291 | Expression::PrefixOp { prefix_op, rhe, .. } => {
292 | handle_prefix_op(ac, runtime, program_archive, prefix_op, rhe)
293 | }
294 | Expression::Number(_, value) => {
295 | let signal_gen = runtime.get_signal_gen();
296 | let access = runtime
297 | .current_context()?
298 | .declare_random_item(signal_gen, DataType::Variable)?;
299 |
300 | runtime.current_context()?.set_variable(
301 | &access,
302 | Some(value.to_u32().ok_or(ProgramError::ParsingError)?),
303 | )?;
304 |
305 | Ok(access)
306 | }
307 | Expression::Variable { name, access, .. } => {
308 | build_access(ac, runtime, program_archive, name, access)
309 | }
310 | _ => Err(ProgramError::ExpressionNotImplemented),
311 | }
312 | }
313 |
314 | /// Handles function and template calls.
315 | fn handle_call(
316 | ac: &mut Compiler,
317 | runtime: &mut Runtime,
318 | program_archive: &ProgramArchive,
319 | id: &str,
320 | args: &[Expression],
321 | ) -> Result {
322 | // Determine if the call is to a function or a template and get argument names and body
323 | let is_function = program_archive.contains_function(id);
324 | let (arg_names, body) = if is_function {
325 | let function_data = program_archive.get_function_data(id);
326 | (
327 | function_data.get_name_of_params().clone(),
328 | function_data.get_body_as_vec().to_vec(),
329 | )
330 | } else if program_archive.contains_template(id) {
331 | let template_data = program_archive.get_template_data(id);
332 | (
333 | template_data.get_name_of_params().clone(),
334 | template_data.get_body_as_vec().to_vec(),
335 | )
336 | } else {
337 | return Err(ProgramError::UndefinedFunctionOrTemplate);
338 | };
339 |
340 | let arg_values = args
341 | .iter()
342 | .map(|arg_expr| {
343 | process_expression(ac, runtime, program_archive, arg_expr).and_then(|value_access| {
344 | runtime
345 | .current_context()?
346 | .get_variable_value(&value_access)?
347 | .ok_or(ProgramError::EmptyDataItem)
348 | })
349 | })
350 | .collect::, ProgramError>>()?;
351 |
352 | // Create a new execution context
353 | runtime.push_context(false, id.to_string())?;
354 |
355 | // Set arguments in the new context
356 | for (arg_name, &arg_value) in arg_names.iter().zip(&arg_values) {
357 | let signal_gen = runtime.get_signal_gen();
358 | runtime
359 | .current_context()?
360 | .declare_item(DataType::Variable, arg_name, &[], signal_gen)?;
361 | runtime
362 | .current_context()?
363 | .set_variable(&DataAccess::new(arg_name, vec![]), Some(arg_value))?;
364 | }
365 |
366 | // Process the function/template body
367 | process_statements(ac, runtime, program_archive, &body)?;
368 |
369 | // Get return values
370 | let mut function_return: Option = None;
371 | let mut component_return: HashMap = HashMap::new();
372 |
373 | if is_function {
374 | if let Ok(value) = runtime
375 | .current_context()?
376 | .get_variable_value(&DataAccess::new(RETURN_VAR, vec![]))
377 | {
378 | function_return = value;
379 | }
380 | } else {
381 | // Retrieve input and output signals
382 | let template_data = program_archive.get_template_data(id);
383 | let input_signals = template_data.get_inputs();
384 | let output_signals = template_data.get_outputs();
385 |
386 | // Store ids in the component
387 | for (signal, _) in input_signals.iter().chain(output_signals.iter()) {
388 | let ids = runtime.current_context()?.get_signal(signal)?;
389 | component_return.insert(signal.to_string(), ids);
390 | }
391 | }
392 |
393 | // Return to parent context
394 | runtime.pop_context(false)?;
395 | let signal_gen = runtime.get_signal_gen();
396 | let ctx = runtime.current_context()?;
397 | let return_access =
398 | DataAccess::new(&format!("{}_{}_{}", id, RETURN_VAR, generate_u32()), vec![]);
399 |
400 | if is_function {
401 | ctx.declare_item(
402 | DataType::Variable,
403 | &return_access.get_name(),
404 | &[],
405 | signal_gen,
406 | )?;
407 | ctx.set_variable(&return_access, function_return)?;
408 | } else {
409 | ctx.declare_item(
410 | DataType::Component,
411 | &return_access.get_name(),
412 | &[],
413 | signal_gen,
414 | )?;
415 | ctx.set_component(&return_access, component_return)?;
416 | }
417 |
418 | Ok(return_access)
419 | }
420 |
421 | /// Handles an infix operation.
422 | /// - If both inputs are variables, it directly computes the operation.
423 | /// - If one or both inputs are signals, it constructs the corresponding circuit gate.
424 | ///
425 | /// Returns the access to a variable containing the result of the operation or the signal of the output gate.
426 | fn handle_infix_op(
427 | ac: &mut Compiler,
428 | runtime: &mut Runtime,
429 | program_archive: &ProgramArchive,
430 | op: &ExpressionInfixOpcode,
431 | lhe: &Expression,
432 | rhe: &Expression,
433 | ) -> Result {
434 | let lhe_access = process_expression(ac, runtime, program_archive, lhe)?;
435 | let rhe_access = process_expression(ac, runtime, program_archive, rhe)?;
436 |
437 | let signal_gen: Rc> = runtime.get_signal_gen();
438 | let ctx = runtime.current_context()?;
439 |
440 | // Determine the data types of the left and right operands
441 | let lhs_data_type = ctx.get_item_data_type(&lhe_access.get_name())?;
442 | let rhs_data_type = ctx.get_item_data_type(&rhe_access.get_name())?;
443 |
444 | // Handle the case where both inputs are variables
445 | if lhs_data_type == DataType::Variable && rhs_data_type == DataType::Variable {
446 | let lhs_value = ctx
447 | .get_variable_value(&lhe_access)?
448 | .ok_or(ProgramError::EmptyDataItem)?;
449 | let rhs_value = ctx
450 | .get_variable_value(&rhe_access)?
451 | .ok_or(ProgramError::EmptyDataItem)?;
452 |
453 | let op_res = execute_op(lhs_value, rhs_value, op)?;
454 | let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?;
455 | ctx.set_variable(&item_access, Some(op_res))?;
456 |
457 | return Ok(item_access);
458 | }
459 |
460 | // Handle cases where one or both inputs are signals
461 | let lhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &lhe_access)?;
462 | let rhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &rhe_access)?;
463 |
464 | // Construct the corresponding circuit gate
465 | let gate_type = AGateType::from(op);
466 | let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?;
467 | let output_id = ctx.get_signal_id(&output_signal)?;
468 |
469 | // Add output signal and gate to the circuit
470 | ac.add_signal(
471 | output_id,
472 | output_signal.access_str(ctx.get_ctx_name()),
473 | None,
474 | )?;
475 | ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?;
476 |
477 | Ok(output_signal)
478 | }
479 |
480 | /// Handles a prefix operation.
481 | /// - If input is a variable, it directly computes the operation.
482 | /// - If input is a signal, it handles it like an infix op against a constant.
483 | ///
484 | /// Returns the access to a variable containing the result of the operation or the signal of the output gate.
485 | fn handle_prefix_op(
486 | ac: &mut Compiler,
487 | runtime: &mut Runtime,
488 | program_archive: &ProgramArchive,
489 | op: &ExpressionPrefixOpcode,
490 | rhe: &Expression,
491 | ) -> Result {
492 | let rhe_access = process_expression(ac, runtime, program_archive, rhe)?;
493 |
494 | let signal_gen: Rc> = runtime.get_signal_gen();
495 | let ctx = runtime.current_context()?;
496 |
497 | // Determine the data type of the operand
498 | let rhs_data_type = ctx.get_item_data_type(&rhe_access.get_name())?;
499 |
500 | // Handle the variable case
501 | if rhs_data_type == DataType::Variable {
502 | let rhs_value = ctx
503 | .get_variable_value(&rhe_access)?
504 | .ok_or(ProgramError::EmptyDataItem)?;
505 |
506 | let op_res = execute_prefix_op(op, rhs_value)?;
507 | let item_access = ctx.declare_random_item(signal_gen, DataType::Variable)?;
508 | ctx.set_variable(&item_access, Some(op_res))?;
509 |
510 | return Ok(item_access);
511 | }
512 |
513 | let (lhs_value, infix_op) = to_equivalent_infix(op);
514 | let lhs_id = make_constant(ac, ctx, signal_gen.clone(), lhs_value)?;
515 |
516 | // Handle signal input
517 | let rhs_id = get_signal_for_access(ac, ctx, signal_gen.clone(), &rhe_access)?;
518 |
519 | // Construct the corresponding circuit gate
520 | let gate_type = AGateType::from(&infix_op);
521 | let output_signal = ctx.declare_random_item(signal_gen, DataType::Signal)?;
522 | let output_id = ctx.get_signal_id(&output_signal)?;
523 |
524 | // Add output signal and gate to the circuit
525 | ac.add_signal(
526 | output_id,
527 | output_signal.access_str(ctx.get_ctx_name()),
528 | None,
529 | )?;
530 | ac.add_gate(gate_type, lhs_id, rhs_id, output_id)?;
531 |
532 | Ok(output_signal)
533 | }
534 |
535 | /// Returns a signal id for a given access
536 | /// - If the access is a signal or a component, it returns the corresponding signal id.
537 | /// - If the access is a variable, it adds a constant variable to the circuit and returns the corresponding signal id.
538 | fn get_signal_for_access(
539 | ac: &mut Compiler,
540 | ctx: &mut Context,
541 | signal_gen: Rc>,
542 | access: &DataAccess,
543 | ) -> Result {
544 | match ctx.get_item_data_type(&access.get_name())? {
545 | DataType::Signal => Ok(ctx.get_signal_id(access)?),
546 | DataType::Variable => {
547 | // Get variable value
548 | let value = ctx
549 | .get_variable_value(access)?
550 | .ok_or(ProgramError::EmptyDataItem)?;
551 |
552 | make_constant(ac, ctx, signal_gen, value)
553 | }
554 | DataType::Component => Ok(ctx.get_component_signal_id(access)?),
555 | }
556 | }
557 |
558 | fn make_constant(
559 | ac: &mut Compiler,
560 | ctx: &mut Context,
561 | signal_gen: Rc>,
562 | value: u32,
563 | ) -> Result {
564 | let signal_access = DataAccess::new(&format!("const_signal_{}", value), vec![]);
565 | // Try to get signal id if it exists
566 | if let Ok(id) = ctx.get_signal_id(&signal_access) {
567 | Ok(id)
568 | } else {
569 | // If it doesn't exist, declare it and add it to the circuit
570 | ctx.declare_item(DataType::Signal, &signal_access.get_name(), &[], signal_gen)?;
571 | let signal_id = ctx.get_signal_id(&signal_access)?;
572 | ac.add_signal(
573 | signal_id,
574 | signal_access.access_str(ctx.get_ctx_name()),
575 | Some(value),
576 | )?;
577 | Ok(signal_id)
578 | }
579 | }
580 |
581 | /// Returns the content of a signal for a given access
582 | fn get_signal_content_for_access(
583 | ctx: &Context,
584 | access: &DataAccess,
585 | ) -> Result, ProgramError> {
586 | match ctx.get_item_data_type(&access.get_name())? {
587 | DataType::Signal => Ok(ctx.get_signal_content(access)?),
588 | DataType::Component => Ok(ctx.get_component_signal_content(access)?),
589 | _ => Err(ProgramError::InvalidDataType),
590 | }
591 | }
592 |
593 | /// Connects two composed signals
594 | fn connect_signal_arrays(
595 | ac: &mut Compiler,
596 | a: &[NestedValue],
597 | b: &[NestedValue],
598 | ) -> Result<(), ProgramError> {
599 | // Verify that the arrays have the same length
600 | if a.len() != b.len() {
601 | return Err(ProgramError::InvalidDataType);
602 | }
603 |
604 | for (a, b) in a.iter().zip(b.iter()) {
605 | match (a, b) {
606 | (NestedValue::Value(a), NestedValue::Value(b)) => {
607 | ac.add_connection(*a, *b)?;
608 | }
609 | (NestedValue::Array(a), NestedValue::Array(b)) => {
610 | connect_signal_arrays(ac, a, b)?;
611 | }
612 | _ => return Err(ProgramError::InvalidDataType),
613 | }
614 | }
615 |
616 | Ok(())
617 | }
618 |
619 | /// Builds a DataAccess from an Access array
620 | fn build_access(
621 | ac: &mut Compiler,
622 | runtime: &mut Runtime,
623 | program_archive: &ProgramArchive,
624 | name: &str,
625 | access: &[Access],
626 | ) -> Result {
627 | let mut access_vec = Vec::new();
628 |
629 | for a in access.iter() {
630 | match a {
631 | Access::ArrayAccess(expression) => {
632 | let index_access = process_expression(ac, runtime, program_archive, expression)?;
633 | let index = runtime
634 | .current_context()?
635 | .get_variable_value(&index_access)?
636 | .ok_or(ProgramError::EmptyDataItem)?;
637 | access_vec.push(SubAccess::Array(index));
638 | }
639 | Access::ComponentAccess(signal) => {
640 | access_vec.push(SubAccess::Component(signal.to_string()));
641 | }
642 | }
643 | }
644 |
645 | Ok(DataAccess::new(name, access_vec))
646 | }
647 |
648 | /// Executes an operation on two u32 values, performing the specified arithmetic or logical computation.
649 | fn execute_op(lhs: u32, rhs: u32, op: &ExpressionInfixOpcode) -> Result {
650 | let res = match op {
651 | ExpressionInfixOpcode::Mul => lhs * rhs,
652 | ExpressionInfixOpcode::Div => {
653 | if rhs == 0 {
654 | return Err(ProgramError::OperationError("Division by zero".to_string()));
655 | }
656 |
657 | lhs / rhs
658 | }
659 | ExpressionInfixOpcode::Add => lhs + rhs,
660 | ExpressionInfixOpcode::Sub => {
661 | if lhs < rhs {
662 | return Err(ProgramError::OperationError(
663 | "Subtraction underflow".to_string(),
664 | ));
665 | }
666 |
667 | lhs - rhs
668 | }
669 | ExpressionInfixOpcode::Pow => lhs.pow(rhs),
670 | ExpressionInfixOpcode::IntDiv => {
671 | if rhs == 0 {
672 | return Err(ProgramError::OperationError(
673 | "Integer division by zero".to_string(),
674 | ));
675 | }
676 |
677 | lhs / rhs
678 | }
679 | ExpressionInfixOpcode::Mod => {
680 | if rhs == 0 {
681 | return Err(ProgramError::OperationError("Modulo by zero".to_string()));
682 | }
683 |
684 | lhs % rhs
685 | }
686 | ExpressionInfixOpcode::ShiftL => lhs << rhs,
687 | ExpressionInfixOpcode::ShiftR => lhs >> rhs,
688 | ExpressionInfixOpcode::LesserEq => {
689 | if lhs <= rhs {
690 | 1
691 | } else {
692 | 0
693 | }
694 | }
695 | ExpressionInfixOpcode::GreaterEq => {
696 | if lhs >= rhs {
697 | 1
698 | } else {
699 | 0
700 | }
701 | }
702 | ExpressionInfixOpcode::Lesser => {
703 | if lhs < rhs {
704 | 1
705 | } else {
706 | 0
707 | }
708 | }
709 | ExpressionInfixOpcode::Greater => {
710 | if lhs > rhs {
711 | 1
712 | } else {
713 | 0
714 | }
715 | }
716 | ExpressionInfixOpcode::Eq => {
717 | if lhs == rhs {
718 | 1
719 | } else {
720 | 0
721 | }
722 | }
723 | ExpressionInfixOpcode::NotEq => {
724 | if lhs != rhs {
725 | 1
726 | } else {
727 | 0
728 | }
729 | }
730 | ExpressionInfixOpcode::BoolOr => {
731 | if lhs != 0 || rhs != 0 {
732 | 1
733 | } else {
734 | 0
735 | }
736 | }
737 | ExpressionInfixOpcode::BoolAnd => {
738 | if lhs != 0 && rhs != 0 {
739 | 1
740 | } else {
741 | 0
742 | }
743 | }
744 | ExpressionInfixOpcode::BitOr => lhs | rhs,
745 | ExpressionInfixOpcode::BitAnd => lhs & rhs,
746 | ExpressionInfixOpcode::BitXor => lhs ^ rhs,
747 | };
748 |
749 | Ok(res)
750 | }
751 |
752 | /// Executes a prefix operation on a u32 value, performing the specified arithmetic or logical computation.
753 | fn execute_prefix_op(op: &ExpressionPrefixOpcode, rhs: u32) -> Result {
754 | let (lhs_value, infix_op) = to_equivalent_infix(op);
755 | execute_op(lhs_value, rhs, &infix_op)
756 | }
757 |
758 | fn to_equivalent_infix(op: &ExpressionPrefixOpcode) -> (u32, ExpressionInfixOpcode) {
759 | match op {
760 | ExpressionPrefixOpcode::Sub => (0, ExpressionInfixOpcode::Sub),
761 | ExpressionPrefixOpcode::BoolNot => (0, ExpressionInfixOpcode::Eq),
762 | ExpressionPrefixOpcode::Complement => (u32::MAX, ExpressionInfixOpcode::BitXor),
763 | }
764 | }
765 |
#[cfg(test)]
mod tests {
    use super::*;
    use circom_program_structure::ast::{ExpressionInfixOpcode, ExpressionPrefixOpcode};

    #[test]
    fn test_execute_op() {
        // (lhs, rhs, operator, expected result)
        let cases = [
            (3, 4, ExpressionInfixOpcode::Add, 7),
            (10, 5, ExpressionInfixOpcode::Sub, 5),
            (6, 3, ExpressionInfixOpcode::Mul, 18),
            (9, 3, ExpressionInfixOpcode::Div, 3),
            (7, 3, ExpressionInfixOpcode::Mod, 1),
            (2, 3, ExpressionInfixOpcode::Pow, 8),
            (8, 2, ExpressionInfixOpcode::ShiftL, 32),
            (8, 2, ExpressionInfixOpcode::ShiftR, 2),
            (5, 5, ExpressionInfixOpcode::Eq, 1),
            (5, 4, ExpressionInfixOpcode::NotEq, 1),
            (1, 0, ExpressionInfixOpcode::BoolOr, 1),
            (1, 1, ExpressionInfixOpcode::BoolAnd, 1),
            (1, 1, ExpressionInfixOpcode::BitOr, 1),
            (1, 1, ExpressionInfixOpcode::BitAnd, 1),
            (1, 1, ExpressionInfixOpcode::BitXor, 0),
        ];

        for (lhs, rhs, op, expected) in cases.iter() {
            assert_eq!(execute_op(*lhs, *rhs, op).unwrap(), *expected);
        }
    }

    #[test]
    fn test_execute_op_errors() {
        let zero_divisor_ops = [
            ExpressionInfixOpcode::Div,
            ExpressionInfixOpcode::IntDiv,
            ExpressionInfixOpcode::Mod,
        ];

        for op in zero_divisor_ops.iter() {
            assert!(execute_op(10, 0, op).is_err());
        }
    }

    #[test]
    fn test_execute_prefix_op() {
        // 0 - 5 cannot be represented as a u32.
        let underflow = execute_prefix_op(&ExpressionPrefixOpcode::Sub, 5).unwrap_err();
        assert_eq!(
            underflow.to_string(),
            "Operation error: Subtraction underflow"
        );

        // !0 == 1 and !1 == 0
        assert_eq!(
            execute_prefix_op(&ExpressionPrefixOpcode::BoolNot, 0).unwrap(),
            1
        );
        assert_eq!(
            execute_prefix_op(&ExpressionPrefixOpcode::BoolNot, 1).unwrap(),
            0
        );

        // ~0b1010
        assert_eq!(
            execute_prefix_op(&ExpressionPrefixOpcode::Complement, 0b1010).unwrap(),
            0b1111_1111_1111_1111_1111_1111_1111_0101
        );
    }

    #[test]
    fn test_to_equivalent_infix() {
        let (sub_left, sub_op) = to_equivalent_infix(&ExpressionPrefixOpcode::Sub);
        assert_eq!(sub_left, 0);
        assert!(matches!(sub_op, ExpressionInfixOpcode::Sub));

        let (not_left, not_op) = to_equivalent_infix(&ExpressionPrefixOpcode::BoolNot);
        assert_eq!(not_left, 0);
        assert!(matches!(not_op, ExpressionInfixOpcode::Eq));

        let (compl_left, compl_op) = to_equivalent_infix(&ExpressionPrefixOpcode::Complement);
        assert_eq!(compl_left, u32::MAX);
        assert!(matches!(compl_op, ExpressionInfixOpcode::BitXor));
    }
}
840 |
--------------------------------------------------------------------------------
/src/program.rs:
--------------------------------------------------------------------------------
1 | //! # Program Module
2 | //!
3 | //! This module processes the circom input program to build the arithmetic circuit.
4 |
5 | use crate::{
6 | circom::{parser::parse_project, type_analysis::analyse_project},
7 | cli::Args,
8 | compiler::{CircuitError, Compiler},
9 | process::{process_expression, process_statements},
10 | runtime::{DataAccess, DataType, Runtime, RuntimeError},
11 | };
12 | use bristol_circuit::BristolCircuitError;
13 | use circom_program_structure::ast::Expression;
14 | use std::io;
15 | use thiserror::Error;
16 |
17 | /// Parses a given Circom program and constructs an arithmetic circuit from it.
18 | pub fn compile(args: &Args) -> Result {
19 | let mut compiler = Compiler::new();
20 | let mut runtime = Runtime::new();
21 | let mut program_archive = parse_project(args)?;
22 |
23 | analyse_project(&mut program_archive)?;
24 |
25 | match program_archive.get_main_expression() {
26 | Expression::Call { id, args, .. } => {
27 | let template_data = program_archive.get_template_data(id);
28 |
29 | // Get values
30 | let mut values: Vec