├── .github └── workflows │ └── continuous-integration-workflow.yml ├── .gitignore ├── Cargo.toml ├── README.md ├── benches ├── multiplicative_inverse.rs └── nth_root.rs └── src ├── bimap_util.rs ├── binary_arithmetic.rs ├── bitwise_operations.rs ├── boolean_algebra.rs ├── comparisons.rs ├── constraint.rs ├── curves ├── edwards.rs ├── jubjub.rs ├── mod.rs ├── montgomery.rs └── weierstrass.rs ├── davies_meyer.rs ├── expression.rs ├── field.rs ├── field_arithmetic.rs ├── gadget.rs ├── gadget_builder.rs ├── gadget_traits.rs ├── group.rs ├── lcg.rs ├── lib.rs ├── matrices.rs ├── merkle_damgard.rs ├── merkle_trees.rs ├── mimc.rs ├── miyaguchi_preneel.rs ├── permutations.rs ├── poseidon.rs ├── random_access.rs ├── rescue.rs ├── signature.rs ├── sorting.rs ├── splitting.rs ├── sponge.rs ├── test_util.rs ├── util.rs ├── verify_permutation.rs ├── wire.rs ├── wire_values.rs └── witness_generator.rs /.github/workflows/continuous-integration-workflow.yml: -------------------------------------------------------------------------------- 1 | name: Continuous Integration 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: ~ 7 | 8 | jobs: 9 | test-std: 10 | name: Test with std 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@master 14 | - name: Test 15 | uses: icepuma/rust-action@master 16 | with: 17 | args: cargo test 18 | test-no-std: 19 | name: Test without std 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@master 23 | - name: Test 24 | uses: icepuma/rust-action@master 25 | with: 26 | args: cargo test --no-default-features 27 | build-wasm: 28 | name: Build with a WASM target 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: actions/checkout@master 32 | - name: Check for errors with a WASM target 33 | uses: icepuma/rust-action@master 34 | with: 35 | args: rustup target add wasm32-unknown-unknown && cargo check --target wasm32-unknown-unknown 36 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | *.iml 5 | .idea/ 6 | .vscode -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "r1cs" 3 | description = "A library for building R1CS gadgets" 4 | version = "0.4.7" 5 | authors = ["Daniel Lubarov ", "Brendan Farmer "] 6 | readme = "README.md" 7 | license = "MIT OR Apache-2.0" 8 | repository = "https://github.com/mir-protocol/r1cs" 9 | documentation = "https://docs.rs/r1cs" 10 | keywords = ["R1CS", "cryptography", "SNARK"] 11 | categories = ["cryptography", "no-std"] 12 | maintenance = { status = "actively-developed" } 13 | edition = "2018" 14 | 15 | [features] 16 | default = ["std"] 17 | std = ["num/std", "num-traits/std", "itertools/use_std", "bimap/std"] 18 | 19 | [dev-dependencies] 20 | criterion = "0.3.5" 21 | 22 | [dependencies] 23 | bimap = { version = "0.4.0" } 24 | itertools = { version = "0.8.0" } 25 | num = { version = "0.4.0", features = ["rand"] } 26 | num-traits = { version = "0.2.14" } 27 | 28 | [[bench]] 29 | name = "nth_root" 30 | harness = false 31 | 32 | [[bench]] 33 | name = "multiplicative_inverse" 34 | harness = false 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # r1cs [![Crates.io](https://img.shields.io/crates/v/r1cs)](https://crates.io/crates/r1cs) [![docs.rs](https://docs.rs/r1cs/badge.svg)](https://docs.rs/r1cs) 2 | 3 | This is a Rust library for building R1CS gadgets over prime fields, which are useful in SNARKs and other argument systems. 4 | 5 | An R1CS instance is defined by three matrices, `A`, `B` and `C`.
These encode the following NP-complete decision problem: does there exist a witness vector `w` such that `Aw ∘ Bw = Cw`? 6 | 7 | A *gadget* for some R1CS instance takes a set of inputs, which are a subset of the witness vector. If the given inputs are valid, it extends the input set into a complete witness vector which satisfies the R1CS instance. 8 | 9 | 10 | ## Features 11 | 12 | The goal of this library is to make SNARK programming easy. To that end, we support a broad set of features, including some fairly high-level abstractions: 13 | 14 | - Basic operations on field elements, such as multiplication, division, and comparisons 15 | - Type-safe boolean operations, such as `GadgetBuilder::and` and `GadgetBuilder::bitwise_and` 16 | - Type-safe binary operations, such as `GadgetBuilder::binary_sum` 17 | - `GadgetBuilder::assert_permutation`, which efficiently verifies a permutation using an AS-Waksman network 18 | - Methods for sorting lists of expressions, such as `GadgetBuilder::sort_ascending` 19 | - Methods for working with Merkle trees, such as `GadgetBuilder::merkle_tree_root` 20 | - Common cryptographic constructions such as Merkle-Damgård, Davies-Meyer, and Sponge functions 21 | - R1CS-friendly primitives like MiMC, Poseidon and Rescue 22 | 23 | 24 | ## Core types 25 | 26 | `Field` is a trait representing prime fields. An `Element<F>` is an element of the prime field `F`. 27 | 28 | A `Wire` is an element of the witness vector. An `Expression` is a linear combination of wires. 29 | 30 | A `BooleanWire` is a `Wire` which has been constrained in such a way that it can only equal 0 or 1. Similarly, a `BooleanExpression` is an `Expression` which has been so constrained. 31 | 32 | A `BinaryWire` is a vector of `BooleanWire`s. Similarly, a `BinaryExpression` is a vector of `BooleanExpression`s.
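As a minimal illustration of the condition `Aw ∘ Bw = Cw` stated above (where `∘` is element-wise multiplication), the sketch below checks it directly with plain integer arithmetic modulo a small prime. It does not use this library's API: the modulus `P = 257`, the helper names `dot`, `is_satisfied` and `cube_constraints`, and the convention that the witness starts with a constant `1` entry are all assumptions made for illustration.

```rust
// Toy prime modulus (assumption: 257 is chosen only to keep numbers small).
const P: u64 = 257;

// A dense R1CS matrix: one row per constraint, one column per witness entry.
// Entries are assumed to already be reduced mod P.
type Matrix = Vec<Vec<u64>>;

/// Inner product of a matrix row with the witness vector, reduced mod P.
fn dot(row: &[u64], w: &[u64]) -> u64 {
    row.iter().zip(w).fold(0u64, |acc, (a, b)| (acc + a * b) % P)
}

/// Returns true if (A·w) ∘ (B·w) = C·w holds for every row, where ∘ is
/// element-wise multiplication and all arithmetic is mod P.
fn is_satisfied(a: &Matrix, b: &Matrix, c: &Matrix, w: &[u64]) -> bool {
    a.len() == b.len()
        && b.len() == c.len()
        && a.iter().zip(b).zip(c).all(|((ra, rb), rc)| {
            dot(ra, w) * dot(rb, w) % P == dot(rc, w)
        })
}

/// Hypothetical encoding of y = x^3 with witness layout w = [1, x, x^2, x^3].
/// Row 0 enforces x * x = x^2; row 1 enforces x^2 * x = x^3.
/// The leading 1 is unused here; by convention it supplies constant terms.
fn cube_constraints() -> (Matrix, Matrix, Matrix) {
    let a = vec![vec![0, 1, 0, 0], vec![0, 0, 1, 0]];
    let b = vec![vec![0, 1, 0, 0], vec![0, 1, 0, 0]];
    let c = vec![vec![0, 0, 1, 0], vec![0, 0, 0, 1]];
    (a, b, c)
}
```

With this hypothetical layout, the witness `[1, 5, 25, 125]` satisfies both rows (5 · 5 = 25 and 25 · 5 = 125), while `[1, 5, 25, 124]` violates the second row. In these terms, a gadget's job is to extend the input entries into a full `w` for which every row holds.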
 33 | 34 | 35 | ## Basic example 36 | 37 | Here's a simple gadget which computes the cube of a BN128 field element: 38 | 39 | ```rust 40 | // Create a gadget which takes a single input, x, and computes x*x*x. 41 | let mut builder = GadgetBuilder::<Bn128>::new(); 42 | let x = builder.wire(); 43 | let x_exp = Expression::from(x); 44 | let x_squared = builder.product(&x_exp, &x_exp); 45 | let x_cubed = builder.product(&x_squared, &x_exp); 46 | let gadget = builder.build(); 47 | 48 | // This structure maps wires to their (field element) values. Since 49 | // x is our input, we will assign it a value before executing the 50 | // gadget. Other wires will be computed by the gadget. 51 | let mut values = values!(x => 5u8.into()); 52 | 53 | // Execute the gadget and assert that all constraints were satisfied. 54 | let constraints_satisfied = gadget.execute(&mut values); 55 | assert!(constraints_satisfied); 56 | 57 | // Check the result. 58 | assert_eq!(Element::from(125u8), x_cubed.evaluate(&values)); 59 | ``` 60 | 61 | This can also be done more succinctly with `builder.exp(x_exp, 3)`, which performs exponentiation by squaring. 62 | 63 | 64 | ## Custom fields 65 | 66 | You can define a custom field by implementing the `Field` trait. As an example, here's the definition of `Bn128` which was referenced above: 67 | 68 | ```rust 69 | pub struct Bn128 {} 70 | 71 | impl Field for Bn128 { 72 | fn order() -> BigUint { 73 | BigUint::from_str( 74 | "21888242871839275222246405745257275088548364400416034343698204186575808495617" 75 | ).unwrap() 76 | } 77 | } 78 | ``` 79 | 80 | 81 | ## Cryptographic tools 82 | 83 | Suppose we wanted to hash a vector of `Expression`s. One approach would be to take a block cipher like MiMC, transform it into a one-way compression function using the Davies-Meyer construction, and transform that into a hash function using the Merkle-Damgård construction.
We could do that like so: 84 | 85 | ```rust 86 | fn hash<F: Field>( 87 | builder: &mut GadgetBuilder<F>, 88 | blocks: &[Expression<F>] 89 | ) -> Expression<F> { 90 | let cipher = MiMCBlockCipher::default(); 91 | let compress = DaviesMeyer::new(cipher); 92 | let hash = MerkleDamgard::new_defaults(compress); 93 | hash.hash(builder, blocks) 94 | } 95 | ``` 96 | 97 | 98 | ## Permutation networks 99 | 100 | To verify that two lists are permutations of one another, you can use `assert_permutation`. This is implemented using AS-Waksman permutation networks, which permute `n` items using roughly `n log_2(n) - n` switches. Each switch involves two constraints: one "is boolean" check, and one constraint for routing. 101 | 102 | Permutation networks make it easy to implement sorting gadgets, which we provide in the form of `sort_ascending` and `sort_descending`. 103 | 104 | 105 | ## Non-determinism 106 | 107 | Suppose we wish to compute the multiplicative inverse of a field element `x`. While this is possible to do in a deterministic arithmetic circuit, it is prohibitively expensive. What we can do instead is have the user compute `x_inv = 1 / x`, provide the result as a witness element, and add a constraint in the R1CS instance to verify that `x * x_inv = 1`. 108 | 109 | `GadgetBuilder` supports such non-deterministic computations via its `generator` method, which can be used like so: 110 | 111 | ```rust 112 | fn inverse<F: Field>(builder: &mut GadgetBuilder<F>, x: Expression<F>) -> Expression<F> { 113 | // Create a new witness element for x_inv. 114 | let x_inv = builder.wire(); 115 | 116 | // Add the constraint x * x_inv = 1. 117 | builder.assert_product(&x, &Expression::from(x_inv), 118 | &Expression::one()); 119 | 120 | // Non-deterministically generate x_inv = 1 / x.
 121 | builder.generator( 122 | x.dependencies(), 123 | move |values: &mut WireValues<F>| { 124 | let x_value = x.evaluate(values); 125 | let x_inv_value = x_value.multiplicative_inverse(); 126 | values.set(x_inv, x_inv_value); 127 | }, 128 | ); 129 | 130 | // Output x_inv. 131 | x_inv.into() 132 | } 133 | ``` 134 | 135 | This is roughly equivalent to the built-in `GadgetBuilder::inverse` method, with slight modifications for readability. 136 | 137 | 138 | ## Backends 139 | 140 | The [r1cs-zkinterface](https://crates.io/crates/r1cs-zkinterface) crate can be used to export these gadgets to the standard zkinterface format. 141 | 142 | There is also a direct backend for [bellman](https://crates.io/crates/bellman) via the [r1cs-bellman](https://crates.io/crates/r1cs-bellman) crate. 143 | 144 | 145 | ## Disclaimer 146 | 147 | This code has not been thoroughly reviewed or tested, and should not be used in any production systems. 148 | -------------------------------------------------------------------------------- /benches/multiplicative_inverse.rs: -------------------------------------------------------------------------------- 1 | use criterion::Criterion; 2 | use criterion::criterion_group; 3 | use criterion::criterion_main; 4 | 5 | use r1cs::{Bls12_381, Element, LCG}; 6 | 7 | fn criterion_benchmark(c: &mut Criterion) { 8 | type F = Bls12_381; 9 | 10 | let mut lcg = LCG::new(); 11 | 12 | c.bench_function("1/x", move |b| b.iter(|| { 13 | let x = lcg.next_element::<F>(); 14 | let x_inv = x.multiplicative_inverse(); 15 | assert_eq!(x * x_inv, Element::one()); 16 | })); 17 | } 18 | 19 | criterion_group!(benches, criterion_benchmark); 20 | criterion_main!(benches); 21 | -------------------------------------------------------------------------------- /benches/nth_root.rs: -------------------------------------------------------------------------------- 1 | use criterion::Criterion; 2 | use criterion::criterion_group; 3 | use criterion::criterion_main; 4 | 5 | use r1cs::{Bls12_381,
 Element, GadgetBuilder, LCG, MonomialPermutation, Permutation, values}; 6 | 7 | fn criterion_benchmark(c: &mut Criterion) { 8 | type F = Bls12_381; 9 | let n = Element::from(5u8); 10 | let x_n = MonomialPermutation::<F>::new(n.clone()); 11 | 12 | let mut builder = GadgetBuilder::<F>::new(); 13 | let residue = builder.wire(); 14 | let root = x_n.inverse(&mut builder, &residue.into()); 15 | let gadget = builder.build(); 16 | let mut lcg = LCG::new(); 17 | 18 | c.bench_function("x^{1/5}", move |b| b.iter(|| { 19 | let residue_value = lcg.next_element(); 20 | let mut values = values!(residue => residue_value.clone()); 21 | gadget.execute(&mut values); 22 | 23 | assert_eq!(root.evaluate(&values).exponentiation(&n), residue_value); 24 | })); 25 | } 26 | 27 | criterion_group!(benches, criterion_benchmark); 28 | criterion_main!(benches); 29 | -------------------------------------------------------------------------------- /src/bimap_util.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | #[cfg(feature = "std")] 5 | use std::collections::BTreeMap; 6 | #[cfg(not(feature = "std"))] 7 | use alloc::collections::btree_map::BTreeMap; 8 | 9 | use std::hash::Hash; 10 | 11 | use bimap::BiMap; 12 | use itertools::enumerate; 13 | 14 | /// Given two lists which are permutations of one another, creates a BiMap which maps an index in 15 | /// one list to an index in the other list with the same associated value. 16 | /// 17 | /// If the lists contain duplicates, then multiple permutations with this property exist, and an 18 | /// arbitrary one of them will be returned.
19 | pub fn bimap_from_lists(a: Vec, b: Vec) -> BiMap { 20 | assert_eq!(a.len(), b.len(), "Vectors differ in length"); 21 | 22 | let mut b_values_to_indices = BTreeMap::new(); 23 | for (i, value) in enumerate(b) { 24 | b_values_to_indices.entry(value).or_insert_with(Vec::new).push(i); 25 | } 26 | 27 | let mut bimap = BiMap::new(); 28 | for (i, value) in enumerate(a) { 29 | if let Some(j) = b_values_to_indices.get_mut(&value).and_then(Vec::pop) { 30 | bimap.insert(i, j); 31 | } else { 32 | panic!("Value in first list not found in second list"); 33 | } 34 | } 35 | 36 | bimap 37 | } 38 | 39 | #[cfg(test)] 40 | mod tests { 41 | use crate::bimap_util::bimap_from_lists; 42 | #[cfg(not(feature = "std"))] 43 | use alloc::vec::Vec; 44 | 45 | #[test] 46 | fn empty_lists() { 47 | let empty: Vec = Vec::new(); 48 | let bimap = bimap_from_lists(empty.clone(), empty); 49 | assert!(bimap.is_empty()); 50 | } 51 | 52 | #[test] 53 | fn without_duplicates() { 54 | let bimap = bimap_from_lists(vec!['a', 'b', 'c'], vec!['b', 'c', 'a']); 55 | assert_eq!(bimap.get_by_left(&0), Some(&2)); 56 | assert_eq!(bimap.get_by_left(&1), Some(&0)); 57 | assert_eq!(bimap.get_by_left(&2), Some(&1)); 58 | } 59 | 60 | #[test] 61 | fn with_duplicates() { 62 | let first = vec!['a', 'a', 'b']; 63 | let second = vec!['a', 'b', 'a']; 64 | let bimap = bimap_from_lists(first.clone(), second.clone()); 65 | for i in 0..3 { 66 | let j = *bimap.get_by_left(&i).unwrap(); 67 | assert_eq!(first[i], second[j]); 68 | } 69 | } 70 | 71 | #[test] 72 | #[should_panic] 73 | fn lengths_differ() { 74 | bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b']); 75 | } 76 | 77 | #[test] 78 | #[should_panic] 79 | fn not_a_permutation() { 80 | bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b', 'b']); 81 | } 82 | } -------------------------------------------------------------------------------- /src/binary_arithmetic.rs: -------------------------------------------------------------------------------- 1 | use itertools::Itertools; 2 | 
use num::BigUint; 3 | use num_traits::{One, Zero}; 4 | 5 | use crate::expression::{BinaryExpression, Expression}; 6 | use crate::field::{Element, Field}; 7 | use crate::gadget_builder::GadgetBuilder; 8 | use crate::wire_values::WireValues; 9 | 10 | impl GadgetBuilder { 11 | /// Add two binary expressions in a widening manner. The result will be one bit longer than the 12 | /// longer of the two inputs. 13 | pub fn binary_sum( 14 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 15 | ) -> BinaryExpression { 16 | self.binary_summation(&[x.clone(), y.clone()]) 17 | } 18 | 19 | /// Add two binary expressions, ignoring any overflow. 20 | pub fn binary_sum_ignoring_overflow( 21 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 22 | ) -> BinaryExpression { 23 | self.binary_summation_ignoring_overflow(&[x.clone(), y.clone()]) 24 | } 25 | 26 | /// Add two binary expressions while asserting that overflow does not occur. 27 | pub fn binary_sum_asserting_no_overflow( 28 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 29 | ) -> BinaryExpression { 30 | self.binary_summation_asserting_no_overflow(&[x.clone(), y.clone()]) 31 | } 32 | 33 | /// Add an arbitrary number of binary expressions. The result will be at least one bit longer than the 34 | /// longest input. 35 | pub fn binary_summation(&mut self, terms: &[BinaryExpression]) -> BinaryExpression { 36 | // We will non-deterministically generate the sum bits, join the binary expressions, and 37 | // verify the summation on those field elements. 38 | 39 | let mut max_sum = BigUint::zero(); 40 | for term in terms { 41 | let max_term = (BigUint::one() << term.len()) - BigUint::one(); 42 | max_sum += max_term; 43 | } 44 | let sum_bits = max_sum.bits() as usize; 45 | 46 | // TODO: Generalize this addition function to support larger operands. 47 | // We can split the bits into chunks and perform addition on joined chunks. 
 48 | assert!(sum_bits < Element::<F>::max_bits(), 49 | "Binary operands are too large to fit in a field element."); 50 | 51 | let sum_wire = self.binary_wire(sum_bits); 52 | let sum = BinaryExpression::from(&sum_wire); 53 | 54 | let sum_of_terms = Expression::sum_of_expressions( 55 | &terms.iter().map(BinaryExpression::join).collect_vec()); 56 | self.assert_equal(&sum_of_terms, &sum.join()); 57 | 58 | self.generator( 59 | sum_of_terms.dependencies(), 60 | move |values: &mut WireValues| { 61 | let sum_element = sum_of_terms.evaluate(values); 62 | let sum_biguint = sum_element.to_biguint(); 63 | values.set_binary_unsigned(&sum_wire, sum_biguint); 64 | }, 65 | ); 66 | 67 | sum 68 | } 69 | 70 | /// Add an arbitrary number of binary expressions, ignoring any overflow. 71 | pub fn binary_summation_ignoring_overflow(&mut self, terms: &[BinaryExpression]) 72 | -> BinaryExpression { 73 | let input_bits = terms.iter().fold(0, |x, y| x.max(y.len())); 74 | let mut sum = self.binary_summation(terms); 75 | sum.truncate(input_bits); 76 | sum 77 | } 78 | 79 | /// Add an arbitrary number of binary expressions, asserting that overflow does not occur. 80 | pub fn binary_summation_asserting_no_overflow(&mut self, terms: &[BinaryExpression]) 81 | -> BinaryExpression { 82 | let input_bits = terms.iter().fold(0, |x, y| x.max(y.len())); 83 | let mut sum = self.binary_summation(terms); 84 | let carry = BinaryExpression { bits: sum.bits[input_bits..].to_vec() }; 85 | self.binary_assert_zero(&carry); 86 | sum.truncate(input_bits); 87 | sum 88 | } 89 | 90 | /// Assert that a binary expression is zero. 91 | pub fn binary_assert_zero(&mut self, x: &BinaryExpression) { 92 | // The expression may be too large to fit in a single field element, so we will join chunks 93 | // and assert that each chunk is zero. The chunk size is chosen such that overflow is 94 | // impossible, even if all bits are 1.
 95 | let bits = Element::<F>::max_bits() - 1; 96 | for chunk in x.chunks(bits) { 97 | self.assert_zero(&chunk.join()); 98 | } 99 | } 100 | } 101 | 102 | #[cfg(test)] 103 | mod tests { 104 | use num::BigUint; 105 | use num_traits::Zero; 106 | 107 | use crate::expression::BinaryExpression; 108 | use crate::gadget_builder::GadgetBuilder; 109 | use crate::test_util::F257; 110 | 111 | #[test] 112 | fn binary_sum() { 113 | let mut builder = GadgetBuilder::<F257>::new(); 114 | let x = builder.binary_wire(4); 115 | let y = builder.binary_wire(4); 116 | let sum = builder.binary_sum(&BinaryExpression::from(&x), &BinaryExpression::from(&y)); 117 | let gadget = builder.build(); 118 | 119 | // 10 + 3 = 13. 120 | let mut values = binary_unsigned_values!( 121 | &x => &BigUint::from(10u8), &y => &BigUint::from(3u8)); 122 | assert!(gadget.execute(&mut values)); 123 | assert_eq!(BigUint::from(13u8), sum.evaluate(&values)); 124 | 125 | // 10 + 11 = 21. 126 | let mut values = binary_unsigned_values!( 127 | &x => &BigUint::from(10u8), &y => &BigUint::from(11u8)); 128 | assert!(gadget.execute(&mut values)); 129 | assert_eq!(BigUint::from(21u8), sum.evaluate(&values)); 130 | } 131 | 132 | #[test] 133 | fn binary_sum_ignoring_overflow() { 134 | let mut builder = GadgetBuilder::<F257>::new(); 135 | let x = builder.binary_wire(4); 136 | let y = builder.binary_wire(4); 137 | let sum = builder.binary_sum_ignoring_overflow( 138 | &BinaryExpression::from(&x), &BinaryExpression::from(&y)); 139 | let gadget = builder.build(); 140 | 141 | // 10 + 3 = 13. 142 | let mut values = binary_unsigned_values!( 143 | &x => &BigUint::from(10u8), &y => &BigUint::from(3u8)); 144 | assert!(gadget.execute(&mut values)); 145 | assert_eq!(BigUint::from(13u8), sum.evaluate(&values)); 146 | 147 | // 10 + 11 = 21 % 16 = 5.
 148 | let mut values = binary_unsigned_values!( 149 | &x => &BigUint::from(10u8), &y => &BigUint::from(11u8)); 150 | assert!(gadget.execute(&mut values)); 151 | assert_eq!(BigUint::from(5u8), sum.evaluate(&values)); 152 | } 153 | 154 | #[test] 155 | fn binary_sum_asserting_no_overflow() { 156 | let mut builder = GadgetBuilder::<F257>::new(); 157 | let x = builder.binary_wire(4); 158 | let y = builder.binary_wire(4); 159 | let sum = builder.binary_sum_asserting_no_overflow( 160 | &BinaryExpression::from(&x), &BinaryExpression::from(&y)); 161 | let gadget = builder.build(); 162 | 163 | // 10 + 3 = 13. 164 | let mut values = binary_unsigned_values!( 165 | &x => &BigUint::from(10u8), &y => &BigUint::from(3u8)); 166 | assert!(gadget.execute(&mut values)); 167 | assert_eq!(BigUint::from(13u8), sum.evaluate(&values)); 168 | 169 | // 10 + 11 = [error]. 170 | let mut values = binary_unsigned_values!( 171 | &x => &BigUint::from(10u8), &y => &BigUint::from(11u8)); 172 | assert!(!gadget.execute(&mut values)); 173 | } 174 | 175 | // TODO: Test inputs with differing lengths. 176 | 177 | // TODO: Test summations with more than two terms. 178 | 179 | #[test] 180 | fn assert_zero_f257() { 181 | let mut builder = GadgetBuilder::<F257>::new(); 182 | let x_bits = 10; 183 | let x_wire = builder.binary_wire(x_bits); 184 | let x_exp = BinaryExpression::from(&x_wire); 185 | builder.binary_assert_zero(&x_exp); 186 | let gadget = builder.build(); 187 | 188 | let mut values_0 = binary_unsigned_values!(&x_wire => &BigUint::zero()); 189 | assert!(gadget.execute(&mut values_0)); 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /src/bitwise_operations.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with bitwise operations such as rotations, bitwise AND, and 2 | //! so forth.
3 | 4 | use crate::expression::{BinaryExpression, BooleanExpression}; 5 | use crate::field::Field; 6 | use crate::gadget_builder::GadgetBuilder; 7 | 8 | impl GadgetBuilder { 9 | /// The bitwise negation of a binary expression `x`, a.k.a. `~x`. 10 | pub fn bitwise_not(&mut self, x: &BinaryExpression) -> BinaryExpression { 11 | let bits = x.bits.iter() 12 | .map(|w| self.not(w)) 13 | .collect(); 14 | BinaryExpression { bits } 15 | } 16 | 17 | /// The bitwise conjunction of two binary expressions `x` and `y`, a.k.a. `x & y`. 18 | pub fn bitwise_and( 19 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 20 | ) -> BinaryExpression { 21 | assert_eq!(x.len(), y.len()); 22 | let l = x.len(); 23 | let bits = (0..l).map(|i| 24 | self.and(&x.bits[i], &y.bits[i]) 25 | ).collect(); 26 | BinaryExpression { bits } 27 | } 28 | 29 | /// The bitwise disjunction of two binary expressions `x` and `y`, a.k.a. `x | y`. 30 | pub fn bitwise_or( 31 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 32 | ) -> BinaryExpression { 33 | assert_eq!(x.len(), y.len()); 34 | let l = x.len(); 35 | let bits = (0..l).map(|i| 36 | self.or(&x.bits[i], &y.bits[i]) 37 | ).collect(); 38 | BinaryExpression { bits } 39 | } 40 | 41 | /// The bitwise exclusive disjunction of two binary expressions `x` and `y`, a.k.a. `x ^ y`. 42 | pub fn bitwise_xor( 43 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 44 | ) -> BinaryExpression { 45 | assert_eq!(x.len(), y.len()); 46 | let l = x.len(); 47 | let bits = (0..l).map(|i| 48 | self.xor(&x.bits[i], &y.bits[i]) 49 | ).collect(); 50 | BinaryExpression { bits } 51 | } 52 | 53 | /// Rotate bits in the direction of increasing significance. This is equivalent to "left rotate" 54 | /// in most programming languages. 55 | pub fn bitwise_rotate_inc_significance( 56 | &mut self, x: &BinaryExpression, n: usize, 57 | ) -> BinaryExpression { 58 | let l = x.len(); 59 | let bits = (0..l).map(|i| { 60 | // This is equivalent to (i - n) mod l. 
 61 | let from_idx = (l + i - n % l) % l; 62 | x.bits[from_idx].clone() 63 | }).collect(); 64 | BinaryExpression { bits } 65 | } 66 | 67 | /// Rotate bits in the direction of decreasing significance. This is equivalent to "right 68 | /// rotate" in most programming languages. 69 | pub fn bitwise_rotate_dec_significance( 70 | &mut self, x: &BinaryExpression, n: usize, 71 | ) -> BinaryExpression { 72 | let l = x.len(); 73 | let bits = (0..l).map(|i| { 74 | let from_idx = (i + n) % l; 75 | x.bits[from_idx].clone() 76 | }).collect(); 77 | BinaryExpression { bits } 78 | } 79 | 80 | /// Shift bits in the direction of increasing significance, discarding bits on the most 81 | /// significant end and inserting zeros on the least significant end. This is equivalent to 82 | /// "left shift" in most programming languages. 83 | pub fn bitwise_shift_inc_significance( 84 | &mut self, x: &BinaryExpression, n: usize, 85 | ) -> BinaryExpression { 86 | let bits = (0..x.len()).map(|i| { 87 | if i < n { 88 | BooleanExpression::_false() 89 | } else { 90 | let from_idx = i - n; 91 | x.bits[from_idx].clone() 92 | } 93 | }).collect(); 94 | BinaryExpression { bits } 95 | } 96 | 97 | /// Shift bits in the direction of decreasing significance, discarding bits on the least 98 | /// significant end and inserting zeros on the most significant end. This is equivalent to 99 | /// "right shift" in most programming languages.
 100 | pub fn bitwise_shift_dec_significance( 101 | &mut self, x: &BinaryExpression, n: usize, 102 | ) -> BinaryExpression { 103 | let l = x.len(); 104 | let bits = (0..l).map(|i| { 105 | if n < l - i { // Equivalent to i + n < l; since i < l, this cannot underflow even when n > l. 106 | let from_idx = i + n; 107 | x.bits[from_idx].clone() 108 | } else { 109 | BooleanExpression::_false() 110 | } 111 | }).collect(); 112 | BinaryExpression { bits } 113 | } 114 | } 115 | 116 | #[cfg(test)] 117 | mod tests { 118 | use num::BigUint; 119 | 120 | use crate::expression::BinaryExpression; 121 | use crate::gadget_builder::GadgetBuilder; 122 | use crate::test_util::F257; 123 | 124 | #[test] 125 | fn bitwise_not() { 126 | let mut builder = GadgetBuilder::<F257>::new(); 127 | let x = builder.binary_wire(8); 128 | let not_x = builder.bitwise_not(&BinaryExpression::from(&x)); 129 | let gadget = builder.build(); 130 | 131 | // ~00010011 = 11101100. 132 | let mut values = binary_unsigned_values!(&x => &BigUint::from(0b00010011u32)); 133 | assert!(gadget.execute(&mut values)); 134 | assert_eq!(BigUint::from(0b11101100u32), not_x.evaluate(&values)); 135 | } 136 | 137 | #[test] 138 | fn bitwise_and() { 139 | let mut builder = GadgetBuilder::<F257>::new(); 140 | let x = builder.binary_wire(8); 141 | let y = builder.binary_wire(8); 142 | let x_and_y = builder.bitwise_and(&BinaryExpression::from(&x), &BinaryExpression::from(&y)); 143 | let gadget = builder.build(); 144 | 145 | // 0 & 0 = 0. 146 | let mut values_0_0 = binary_unsigned_values!( 147 | &x => &BigUint::from(0u32), 148 | &y => &BigUint::from(0u32)); 149 | assert!(gadget.execute(&mut values_0_0)); 150 | assert_eq!(BigUint::from(0u32), x_and_y.evaluate(&values_0_0)); 151 | 152 | // 255 & 0 = 0. 153 | let mut values_255_0 = binary_unsigned_values!( 154 | &x => &BigUint::from(0b11111111u32), 155 | &y => &BigUint::from(0u32)); 156 | assert!(gadget.execute(&mut values_255_0)); 157 | assert_eq!(BigUint::from(0u32), x_and_y.evaluate(&values_255_0)); 158 | 159 | // 255 & 255 = 255.
 160 | let mut values_255_255 = binary_unsigned_values!( 161 | &x => &BigUint::from(0b11111111u32), 162 | &y => &BigUint::from(0b11111111u32)); 163 | assert!(gadget.execute(&mut values_255_255)); 164 | assert_eq!(BigUint::from(0b11111111u32), x_and_y.evaluate(&values_255_255)); 165 | 166 | // 11111100 & 00111111 = 00111100. 167 | let mut values_11111100_00111111 = binary_unsigned_values!( 168 | &x => &BigUint::from(0b11111100u32), 169 | &y => &BigUint::from(0b00111111u32)); 170 | assert!(gadget.execute(&mut values_11111100_00111111)); 171 | assert_eq!(BigUint::from(0b00111100u32), x_and_y.evaluate(&values_11111100_00111111)); 172 | } 173 | 174 | #[test] 175 | fn bitwise_rotate_dec_significance() { 176 | let mut builder = GadgetBuilder::<F257>::new(); 177 | let x = builder.binary_wire(8); 178 | let x_rot = builder.bitwise_rotate_dec_significance(&BinaryExpression::from(&x), 3); 179 | let gadget = builder.build(); 180 | 181 | // 00000000 rotated right by 3 = 00000000. 182 | let mut values_zero = binary_unsigned_values!(&x => &BigUint::from(0u32)); 183 | assert!(gadget.execute(&mut values_zero)); 184 | assert_eq!(BigUint::from(0u32), x_rot.evaluate(&values_zero)); 185 | 186 | // 00010011 rotated right by 3 = 01100010. 187 | let mut values_nonzero = binary_unsigned_values!(&x => &BigUint::from(0b00010011u32)); 188 | assert!(gadget.execute(&mut values_nonzero)); 189 | assert_eq!(BigUint::from(0b01100010u32), x_rot.evaluate(&values_nonzero)); 190 | } 191 | 192 | #[test] 193 | fn bitwise_rotate_dec_significance_multiple_wraps() { 194 | let mut builder = GadgetBuilder::<F257>::new(); 195 | let x = builder.binary_wire(8); 196 | let x_rot = builder.bitwise_rotate_dec_significance(&BinaryExpression::from(&x), 19); 197 | let gadget = builder.build(); 198 | 199 | // 00010011 rotated right by 19 = 00010011 rotated right by 3 = 01100010.
200 | let mut values = binary_unsigned_values!(&x => &BigUint::from(0b00010011u32)); 201 | assert!(gadget.execute(&mut values)); 202 | assert_eq!(BigUint::from(0b01100010u32), x_rot.evaluate(&values)); 203 | } 204 | 205 | // TODO: Tests for shift methods 206 | } -------------------------------------------------------------------------------- /src/boolean_algebra.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with boolean algebra methods. 2 | 3 | use crate::expression::{BooleanExpression, Expression}; 4 | use crate::field::Field; 5 | use crate::gadget_builder::GadgetBuilder; 6 | 7 | impl GadgetBuilder { 8 | /// The negation of a boolean value. 9 | pub fn not(&mut self, x: &BooleanExpression) -> BooleanExpression { 10 | BooleanExpression::new_unsafe(Expression::one() - x.expression()) 11 | } 12 | 13 | /// The conjunction of two boolean values. 14 | pub fn and( 15 | &mut self, x: &BooleanExpression, y: &BooleanExpression, 16 | ) -> BooleanExpression { 17 | BooleanExpression::new_unsafe(self.product(x.expression(), y.expression())) 18 | } 19 | 20 | /// The disjunction of two boolean values. 21 | pub fn or( 22 | &mut self, x: &BooleanExpression, y: &BooleanExpression, 23 | ) -> BooleanExpression { 24 | let x_exp = x.expression(); 25 | let y_exp = y.expression(); 26 | BooleanExpression::new_unsafe( 27 | x_exp + y_exp - self.product(x_exp, y_exp)) 28 | } 29 | 30 | /// The exclusive disjunction of two boolean values. 
 31 | pub fn xor( 32 | &mut self, x: &BooleanExpression, y: &BooleanExpression, 33 | ) -> BooleanExpression { 34 | let x_exp = x.expression(); 35 | let y_exp = y.expression(); 36 | BooleanExpression::new_unsafe(x_exp + y_exp - self.product(x_exp, y_exp) * 2u128) 37 | } 38 | } 39 | 40 | #[cfg(test)] 41 | mod tests { 42 | use crate::expression::BooleanExpression; 43 | use crate::gadget_builder::GadgetBuilder; 44 | use crate::test_util::F257; 45 | 46 | #[test] 47 | fn and() { 48 | let mut builder = GadgetBuilder::<F257>::new(); 49 | let (x, y) = (builder.boolean_wire(), builder.boolean_wire()); 50 | let and = builder.and(&BooleanExpression::from(x), &BooleanExpression::from(y)); 51 | let gadget = builder.build(); 52 | 53 | let mut values00 = boolean_values!(x => false, y => false); 54 | assert!(gadget.execute(&mut values00)); 55 | assert_eq!(false, and.evaluate(&values00)); 56 | 57 | let mut values01 = boolean_values!(x => false, y => true); 58 | assert!(gadget.execute(&mut values01)); 59 | assert_eq!(false, and.evaluate(&values01)); 60 | 61 | let mut values10 = boolean_values!(x => true, y => false); 62 | assert!(gadget.execute(&mut values10)); 63 | assert_eq!(false, and.evaluate(&values10)); 64 | 65 | let mut values11 = boolean_values!(x => true, y => true); 66 | assert!(gadget.execute(&mut values11)); 67 | assert_eq!(true, and.evaluate(&values11)); 68 | } 69 | 70 | #[test] 71 | fn or() { 72 | let mut builder = GadgetBuilder::<F257>::new(); 73 | let (x, y) = (builder.boolean_wire(), builder.boolean_wire()); 74 | let or = builder.or(&BooleanExpression::from(x), &BooleanExpression::from(y)); 75 | let gadget = builder.build(); 76 | 77 | let mut values00 = boolean_values!(x => false, y => false); 78 | assert!(gadget.execute(&mut values00)); 79 | assert_eq!(false, or.evaluate(&values00)); 80 | 81 | let mut values01 = boolean_values!(x => false, y => true); 82 | assert!(gadget.execute(&mut values01)); 83 | assert_eq!(true, or.evaluate(&values01)); 84 | 85 | let mut values10 =
boolean_values!(x => true, y => false); 86 | assert!(gadget.execute(&mut values10)); 87 | assert_eq!(true, or.evaluate(&values10)); 88 | 89 | let mut values11 = boolean_values!(x => true, y => true); 90 | assert!(gadget.execute(&mut values11)); 91 | assert_eq!(true, or.evaluate(&values11)); 92 | } 93 | 94 | #[test] 95 | fn xor() { 96 | let mut builder = GadgetBuilder::::new(); 97 | let (x, y) = (builder.boolean_wire(), builder.boolean_wire()); 98 | let xor = builder.xor(&BooleanExpression::from(x), &BooleanExpression::from(y)); 99 | let gadget = builder.build(); 100 | 101 | let mut values00 = boolean_values!(x => false, y => false); 102 | assert!(gadget.execute(&mut values00)); 103 | assert_eq!(false, xor.evaluate(&values00)); 104 | 105 | let mut values01 = boolean_values!(x => false, y => true); 106 | assert!(gadget.execute(&mut values01)); 107 | assert_eq!(true, xor.evaluate(&values01)); 108 | 109 | let mut values10 = boolean_values!(x => true, y => false); 110 | assert!(gadget.execute(&mut values10)); 111 | assert_eq!(true, xor.evaluate(&values10)); 112 | 113 | let mut values11 = boolean_values!(x => true, y => true); 114 | assert!(gadget.execute(&mut values11)); 115 | assert_eq!(false, xor.evaluate(&values11)); 116 | } 117 | } -------------------------------------------------------------------------------- /src/comparisons.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with methods for comparing native field elements. 2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | use itertools::enumerate; 7 | 8 | use crate::expression::{BinaryExpression, BooleanExpression, Expression}; 9 | use crate::field::{Element, Field}; 10 | use crate::gadget_builder::GadgetBuilder; 11 | use crate::wire_values::WireValues; 12 | use crate::util::concat; 13 | 14 | impl GadgetBuilder { 15 | /// Assert that `x < y`. 
16 | pub fn assert_lt(&mut self, x: &Expression, y: &Expression) { 17 | let lt = self.lt(x, y); 18 | self.assert_true(&lt); 19 | } 20 | 21 | /// Assert that `x <= y`. 22 | pub fn assert_le(&mut self, x: &Expression, y: &Expression) { 23 | let le = self.le(x, y); 24 | self.assert_true(&le); 25 | } 26 | 27 | /// Assert that `x > y`. 28 | pub fn assert_gt(&mut self, x: &Expression, y: &Expression) { 29 | let gt = self.gt(x, y); 30 | self.assert_true(&gt); 31 | } 32 | 33 | /// Assert that `x >= y`. 34 | pub fn assert_ge(&mut self, x: &Expression, y: &Expression) { 35 | let ge = self.ge(x, y); 36 | self.assert_true(&ge); 37 | } 38 | 39 | /// Assert that `x < y`. 40 | pub fn assert_lt_binary(&mut self, x: &BinaryExpression, y: &BinaryExpression) { 41 | let lt = self.lt_binary(x, y); 42 | self.assert_true(&lt); 43 | } 44 | 45 | /// Assert that `x <= y`. 46 | pub fn assert_le_binary(&mut self, x: &BinaryExpression, y: &BinaryExpression) { 47 | let le = self.le_binary(x, y); 48 | self.assert_true(&le); 49 | } 50 | 51 | /// Assert that `x > y`. 52 | pub fn assert_gt_binary(&mut self, x: &BinaryExpression, y: &BinaryExpression) { 53 | let gt = self.gt_binary(x, y); 54 | self.assert_true(&gt); 55 | } 56 | 57 | /// Assert that `x >= y`. 58 | pub fn assert_ge_binary(&mut self, x: &BinaryExpression, y: &BinaryExpression) 59 | { 60 | let ge = self.ge_binary(x, y); 61 | self.assert_true(&ge); 62 | } 63 | 64 | /// Returns `x < y`. 65 | pub fn lt(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 66 | self.cmp(x, y, true, true) 67 | } 68 | 69 | /// Returns `x <= y`. 70 | pub fn le(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 71 | self.cmp(x, y, true, false) 72 | } 73 | 74 | /// Returns `x > y`. 75 | pub fn gt(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 76 | self.cmp(x, y, false, true) 77 | } 78 | 79 | /// Returns `x >= y`.
80 | pub fn ge(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 81 | self.cmp(x, y, false, false) 82 | } 83 | 84 | /// Returns `x < y`. 85 | pub fn lt_binary( 86 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 87 | ) -> BooleanExpression { 88 | self.cmp_binary(x, y, true, true) 89 | } 90 | 91 | /// Returns `x <= y`. 92 | pub fn le_binary( 93 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 94 | ) -> BooleanExpression { 95 | self.cmp_binary(x, y, true, false) 96 | } 97 | 98 | /// Returns `x > y`. 99 | pub fn gt_binary( 100 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 101 | ) -> BooleanExpression { 102 | self.cmp_binary(x, y, false, true) 103 | } 104 | 105 | /// Returns `x >= y`. 106 | pub fn ge_binary( 107 | &mut self, x: &BinaryExpression, y: &BinaryExpression, 108 | ) -> BooleanExpression { 109 | self.cmp_binary(x, y, false, false) 110 | } 111 | 112 | fn cmp( 113 | &mut self, x: &Expression, y: &Expression, less: bool, strict: bool, 114 | ) -> BooleanExpression { 115 | let (x_bin, y_bin) = if less { 116 | // We're asserting x <[=] y. We don't need x's canonical encoding, because the 117 | // non-canonical encoding would give x_bin > |F| and thus x_bin > y_bin, rendering the 118 | // instance unsatisfiable. 119 | // TODO: This only holds for assertions, not evaluations. 120 | (self.split_allowing_ambiguity(x), self.split(y)) 121 | } else { 122 | // Similarly, here we're asserting y <[=] x, so we don't need y's canonical encoding. 123 | (self.split(x), self.split_allowing_ambiguity(y)) 124 | }; 125 | self.cmp_binary(&x_bin, &y_bin, less, strict) 126 | } 127 | 128 | // TODO: Consider identifying the first differing chunk with a single field element rather than 129 | // a bitmask. This will mean doing random access later. 
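The chunk width used by `cmp_binary` comes from the cost model in `cmp_constraints` and the search in `cmp_chunk_bits`, both defined later in this file. Below is a minimal sketch of that search as free functions over plain `usize`. The `max_bits` parameter is an assumption of the sketch, standing in for `Element::<F>::max_bits()`.

```rust
/// Constraint-count model copied from `cmp_constraints`:
/// three constraints per chunk, plus 2, plus the chunk width.
fn cmp_constraints(operand_bits: usize, chunk_bits: usize) -> usize {
    // Ceiling division: the last chunk may be only partially filled.
    let chunks = (operand_bits + chunk_bits - 1) / chunk_bits;
    3 * chunks + 2 + chunk_bits
}

/// Picks the chunk width in `1..max_bits` with the fewest constraints.
/// `max_bits` is a stand-in (assumption) for `Element::<F>::max_bits()`.
/// On ties, `min_by_key` returns the first (smallest) width.
fn cmp_chunk_bits(operand_bits: usize, max_bits: usize) -> usize {
    (1..max_bits.max(2))
        .min_by_key(|&chunk_bits| cmp_constraints(operand_bits, chunk_bits))
        .expect("range is never empty")
}
```

For a 254-bit operand the formula bottoms out at 58 constraints, reached first at 26-bit chunks (29 and 32 tie). `min_by_key` keeps the first minimum, which matches the original loop, since that loop only replaces its best candidate on a strict improvement.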
130 | fn cmp_binary( 131 | &mut self, 132 | x_bits: &BinaryExpression, 133 | y_bits: &BinaryExpression, 134 | less: bool, strict: bool, 135 | ) -> BooleanExpression { 136 | assert_eq!(x_bits.len(), y_bits.len()); 137 | let operand_bits = x_bits.len(); 138 | 139 | // We will chunk both bit vectors, then have the prover supply a mask which identifies the 140 | // most significant pair of chunks that differ. Credit to Ahmed Kosba, who described this technique. 141 | let chunk_bits = Self::cmp_chunk_bits(operand_bits); 142 | let x_chunks: Vec> = x_bits.chunks(chunk_bits) 143 | .iter().map(BinaryExpression::join).collect(); 144 | let y_chunks: Vec> = y_bits.chunks(chunk_bits) 145 | .iter().map(BinaryExpression::join).collect(); 146 | let chunks = x_chunks.len(); 147 | 148 | // Create a mask bit for each chunk index. mask[i] must equal 1 iff i is the most significant 149 | // index where the chunks differ, otherwise 0. If no chunks differ, all mask bits must equal 0. 150 | let mask = self.wires(chunks); 151 | // Each mask bit wire must equal 0 or 1. 152 | for &m in &mask { 153 | self.assert_boolean(&Expression::from(m)); 154 | } 155 | // The sum of all mask bits must equal 0 or 1, so that at most one mask bit can equal 1.
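The mask logic is easier to follow on plain integers. The sketch below separates the two roles: `first_diff_mask` plays the witness generator (scanning from the most significant end, as the `.rev()` loop does), and `check_mask` evaluates the conditions that `cmp_binary` enforces as constraints. Both names are hypothetical, chunks are assumed to be stored least significant first, booleanity comes for free from `bool`, and the extra non-strict nonzero check is left out.

```rust
/// Witness side (hypothetical helper): mark the most significant index at
/// which the chunk vectors differ. Chunks are least significant first.
fn first_diff_mask(x: &[i64], y: &[i64]) -> Vec<bool> {
    assert_eq!(x.len(), y.len());
    let mut mask = vec![false; x.len()];
    let mut seen_diff = false;
    for i in (0..x.len()).rev() {
        let diff = x[i] != y[i];
        mask[i] = diff && !seen_diff;
        seen_diff |= diff;
    }
    mask
}

/// Verifier side (hypothetical helper): the checks `cmp_binary` turns into
/// constraints, evaluated directly. Returns the selected chunk difference,
/// i.e. the dot product of the mask with (x - y), if every check passes.
fn check_mask(x: &[i64], y: &[i64], mask: &[bool]) -> Option<i64> {
    if x.is_empty() {
        return Some(0);
    }
    let m: Vec<i64> = mask.iter().map(|&b| b as i64).collect();
    // At most one mask bit may be set.
    if m.iter().sum::<i64>() > 1 {
        return None;
    }
    // Chunks above the selected index must be equal:
    // diff_seen * (x_i - y_i) == 0, where diff_seen sums the lower mask bits.
    let mut diff_seen = m[0];
    for i in 1..x.len() {
        if diff_seen * (x[i] - y[i]) != 0 {
            return None;
        }
        diff_seen += m[i];
    }
    Some((0..x.len()).map(|i| m[i] * (x[i] - y[i])).sum())
}
```

A mask that points at a less significant differing chunk is rejected, because some more significant chunk pair then has a nonzero difference multiplied by `diff_seen = 1`.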
156 | let diff_exists = self.assert_boolean(&Expression::sum_of_wires(&mask)); 157 | 158 | { 159 | let x_chunks = x_chunks.clone(); 160 | let y_chunks = y_chunks.clone(); 161 | let mask = mask.clone(); 162 | self.generator( 163 | concat(&[x_bits.dependencies(), y_bits.dependencies()]), 164 | move |values: &mut WireValues| { 165 | let mut seen_diff: bool = false; 166 | for (i, &mask_bit) in enumerate(&mask).rev() { 167 | let x_chunk_value = x_chunks[i].evaluate(values); 168 | let y_chunk_value = y_chunks[i].evaluate(values); 169 | let diff = x_chunk_value != y_chunk_value; 170 | let mask_bit_value = diff && !seen_diff; 171 | seen_diff |= diff; 172 | values.set(mask_bit, mask_bit_value.into()); 173 | } 174 | }, 175 | ); 176 | } 177 | 178 | // Compute the dot product of the mask vector with (x_chunks - y_chunks). 179 | let diff_chunk = (0..chunks).fold(Expression::zero(), 180 | |sum, i| { let diff = &x_chunks[i] - &y_chunks[i]; sum + self.product(&Expression::from(mask[i]), &diff) }); 181 | 182 | // Verify that any more significant pairs of chunks are equal. 183 | // diff_seen tracks whether a mask bit of 1 has been observed at a less significant chunk index. 184 | let mut diff_seen: Expression = mask[0].into(); 185 | for i in 1..chunks { 186 | // If diff_seen = 1, we require that x_chunk = y_chunk. 187 | // Equivalently, we require that diff_seen * (x_chunk - y_chunk) = 0. 188 | self.assert_product(&diff_seen, 189 | &(&x_chunks[i] - &y_chunks[i]), 190 | &Expression::zero()); 191 | diff_seen += Expression::from(mask[i]);//TODO: Alter loop format to remove extraneous add op 192 | } 193 | 194 | // If the mask has a 1 bit, then the corresponding pair of chunks must differ. We only need 195 | // this check for non-strict comparisons though, since for strict comparisons, the 196 | // comparison operation applied to the selected chunks will enforce that they differ. 197 | if !strict { 198 | // If no mask bit is set, select 1 (an arbitrary nonzero value) so that the check passes trivially.
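Once a single pair of chunks has been selected, `cmp_subtractive` decides the comparison from one bit of a shifted difference. A minimal sketch of that trick on `u64` integers, assuming both operands fit in `bits` bits; the name mirrors the original, but this is plain arithmetic rather than a gadget.

```rust
/// Compare two `bits`-bit values by inspecting bit `bits` of
///   base + (y - x)  when `less`, or  base + (x - y)  otherwise,
/// where base = 2^bits - strict, as in `cmp_subtractive`.
fn cmp_subtractive(x: u64, y: u64, less: bool, strict: bool, bits: u32) -> bool {
    assert!(bits < 63, "this sketch uses u64 arithmetic");
    assert!(x < (1u64 << bits) && y < (1u64 << bits), "operands must fit in `bits` bits");
    let base = (1u64 << bits) - strict as u64;
    // base >= 2^bits - 1 >= either operand, so the subtraction never underflows.
    let z = if less { base + y - x } else { base + x - y };
    // z < 2^(bits + 1), so bit `bits` is its most significant bit.
    (z >> bits) & 1 == 1
}
```

For `less` and `strict`, the bit is set iff `2^bits - 1 + y - x >= 2^bits`, i.e. iff `y - x >= 1`; the other three cases follow the same pattern with the offset and sign adjusted.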
199 | let nonzero = self.selection(&diff_exists, &diff_chunk, &Expression::one()); 200 | self.assert_nonzero(&nonzero); 201 | } 202 | 203 | // Finally, apply a different comparison algorithm to the (small) differing chunks. 204 | self.cmp_subtractive(diff_chunk, less, strict, chunk_bits) 205 | } 206 | 207 | /// Given a diff of `x - y`, compare `x` and `y`. 208 | fn cmp_subtractive(&mut self, diff: Expression, 209 | less: bool, strict: bool, bits: usize) -> BooleanExpression { 210 | // As an example, assume less=false and strict=false. In that case, we compute 211 | // 2^bits + x - y 212 | // and check the most significant bit, i.e., the one with index `bits`. 213 | // x >= y iff that bit is set. The other cases are similar. 214 | // TODO: If `bits` is very large, base might not fit in a field element. Need to generalize 215 | // this to work with arbitrary bit widths, or at least an assertion to fail gracefully. 216 | let base = Expression::from( 217 | (Element::one() << bits) - Element::from(strict)); 218 | let z = base + if less { -diff } else { diff }; 219 | self.split_bounded(&z, bits + 1).bits[bits].clone() 220 | } 221 | 222 | /// The number of constraints used by `cmp_binary`, given a certain chunk size. 223 | fn cmp_constraints(operand_bits: usize, chunk_bits: usize) -> usize { 224 | let chunks = (operand_bits + chunk_bits - 1) / chunk_bits; 225 | 3 * chunks + 2 + chunk_bits 226 | } 227 | 228 | /// The optimal number of bits per chunk for the comparison algorithm used in `cmp_binary`.
229 | fn cmp_chunk_bits(operand_bits: usize) -> usize { 230 | let mut best_chunk_bits = 1; 231 | let mut best_constraints = Self::cmp_constraints(operand_bits, 1); 232 | for chunk_bits in 2..Element::::max_bits() { 233 | let constraints = Self::cmp_constraints(operand_bits, chunk_bits); 234 | if constraints < best_constraints { 235 | best_chunk_bits = chunk_bits; 236 | best_constraints = constraints; 237 | } 238 | } 239 | best_chunk_bits 240 | } 241 | } 242 | 243 | #[cfg(test)] 244 | mod tests { 245 | use crate::Bn128; 246 | use crate::expression::Expression; 247 | use crate::field::Element; 248 | use crate::gadget_builder::GadgetBuilder; 249 | use crate::test_util::assert_eq_false; 250 | use crate::test_util::assert_eq_true; 251 | 252 | #[test] 253 | fn comparisons() { 254 | let mut builder = GadgetBuilder::::new(); 255 | let (x, y) = (builder.wire(), builder.wire()); 256 | let x_exp = Expression::from(x); 257 | let y_exp = Expression::from(y); 258 | let lt = builder.lt(&x_exp, &y_exp); 259 | let le = builder.le(&x_exp, &y_exp); 260 | let gt = builder.gt(&x_exp, &y_exp); 261 | let ge = builder.ge(&x_exp, &y_exp); 262 | let gadget = builder.build(); 263 | 264 | let mut values_42_63 = values!(x => 42u8.into(), y => 63u8.into()); 265 | assert!(gadget.execute(&mut values_42_63)); 266 | assert_eq_true(&lt, &values_42_63); 267 | assert_eq_true(&le, &values_42_63); 268 | assert_eq_false(&gt, &values_42_63); 269 | assert_eq_false(&ge, &values_42_63); 270 | 271 | let mut values_42_42 = values!(x => 42u8.into(), y => 42u8.into()); 272 | assert!(gadget.execute(&mut values_42_42)); 273 | assert_eq_false(&lt, &values_42_42); 274 | assert_eq_true(&le, &values_42_42); 275 | assert_eq_false(&gt, &values_42_42); 276 | assert_eq_true(&ge, &values_42_42); 277 | 278 | let mut values_42_41 = values!(x => 42u8.into(), y => 41u8.into()); 279 | assert!(gadget.execute(&mut values_42_41)); 280 | assert_eq_false(&lt, &values_42_41); 281 | assert_eq_false(&le, &values_42_41); 282 | assert_eq_true(&gt,
&values_42_41); 283 | assert_eq_true(&ge, &values_42_41); 284 | 285 | // This is a white-box test. Since the implementation compares multi-bit chunks (tens of 286 | // bits each), all the numbers in the preceding tests fit into the least significant 287 | // chunk. So let's try some larger numbers. In particular, let's have x < y but have the 288 | // least significant chunk of x exceed that of y, to make sure the more significant chunk 289 | // takes precedence. 290 | let mut values_large_lt = values!( 291 | x => Element::from(1u128 << 80 | 1u128), 292 | y => Element::from(1u128 << 81)); 293 | assert!(gadget.execute(&mut values_large_lt)); 294 | assert_eq_true(&lt, &values_large_lt); 295 | assert_eq_true(&le, &values_large_lt); 296 | assert_eq_false(&gt, &values_large_lt); 297 | assert_eq_false(&ge, &values_large_lt); 298 | } 299 | } -------------------------------------------------------------------------------- /src/constraint.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::fmt::Formatter; 3 | 4 | use crate::expression::Expression; 5 | use crate::field::Field; 6 | use crate::wire_values::WireValues; 7 | 8 | /// A rank-1 constraint of the form a * b = c, where a, b, and c are linear combinations of wires.
9 | #[derive(Clone, Debug)] 10 | pub struct Constraint { 11 | pub a: Expression, 12 | pub b: Expression, 13 | pub c: Expression, 14 | } 15 | 16 | impl Constraint { 17 | pub fn evaluate(&self, wire_values: &WireValues) -> bool { 18 | let a_value = self.a.evaluate(wire_values); 19 | let b_value = self.b.evaluate(wire_values); 20 | let c_value = self.c.evaluate(wire_values); 21 | a_value * b_value == c_value 22 | } 23 | } 24 | 25 | impl fmt::Display for Constraint { 26 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 27 | let a_str = if self.a.num_terms() >= 2 { 28 | format!("({})", self.a) 29 | } else { 30 | format!("{}", self.a) 31 | }; 32 | 33 | let b_str = if self.b.num_terms() >= 2 { 34 | format!("({})", self.b) 35 | } else { 36 | format!("{}", self.b) 37 | }; 38 | 39 | write!(f, "{} * {} = {}", a_str, b_str, self.c) 40 | } 41 | } -------------------------------------------------------------------------------- /src/curves/edwards.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | use std::marker::PhantomData; 4 | 5 | use crate::{Element, Evaluable, Expression, Field, GadgetBuilder, Group, GroupExpression, WireValues}; 6 | 7 | /// Trait used to represent Edwards Curves and Twisted Edwards Curves. Note that the `a` 8 | /// parameter can be set to 1 to represent the less-general non-twisted Edwards Curves. 
9 | pub trait EdwardsCurve { 10 | fn a() -> Element; 11 | fn d() -> Element; 12 | } 13 | 14 | pub struct EdwardsGroup> { 15 | phantom_f: PhantomData<*const F>, 16 | phantom_c: PhantomData<*const C>, 17 | } 18 | 19 | 20 | impl> Group for EdwardsGroup { 21 | 22 | type GroupElement = EdwardsPoint; 23 | type GroupExpression = EdwardsExpression; 24 | 25 | fn identity_element() -> Self::GroupElement { 26 | EdwardsPoint::new(Element::zero(), Element::one()) 27 | } 28 | 29 | /// Adds two points on an `EdwardsCurve` using the standard algorithm for Twisted Edwards 30 | /// Curves. 31 | // TODO: Add special case for variable + constant addition. 32 | // TODO: This uses 7 constraints, but we can get this down to 6, as described in the Zcash spec. 33 | fn add_expressions( 34 | builder: &mut GadgetBuilder, 35 | lhs: &Self::GroupExpression, 36 | rhs: &Self::GroupExpression, 37 | ) -> Self::GroupExpression { 38 | let a = C::a(); 39 | let d = C::d(); 40 | let EdwardsExpression { x: x1, y: y1, .. } = lhs; 41 | let EdwardsExpression { x: x2, y: y2, .. } = rhs; 42 | let x1y2 = builder.product(&x1, &y2); 43 | let x2y1 = builder.product(&y1, &x2); 44 | let x1x2 = builder.product(&x1, &x2); 45 | let x1x2y1y2 = builder.product(&x1y2, &x2y1); 46 | let y1y2 = builder.product(&y1, &y2); 47 | let x3 = builder.quotient_unsafe( 48 | &(x1y2 + x2y1), 49 | &(&x1x2y1y2 * &d + Expression::one())); 50 | let y3 = builder.quotient_unsafe( 51 | &(y1y2 - &x1x2 * &a), 52 | &(&x1x2y1y2 * -&d + Expression::one())); 53 | EdwardsExpression::new_unsafe(x3, y3) 54 | } 55 | 56 | // TODO: improve constraint count 57 | /// Naive implementation of the doubling algorithm for twisted Edwards curves. 58 | /// 59 | /// Note that this algorithm requires the point to be of odd order, which, in the case 60 | /// of prime-order subgroups of Edwards curves, is satisfied. 
61 | fn double_expression( 62 | builder: &mut GadgetBuilder, 63 | point: &Self::GroupExpression, 64 | ) -> Self::GroupExpression { 65 | let EdwardsExpression { x, y, .. } = point; 66 | let a = C::a(); 67 | 68 | let xy = builder.product(&x, &y); 69 | let xx = builder.product(&x, &x); 70 | let yy = builder.product(&y, &y); 71 | let x_2 = builder.quotient_unsafe(&(&xy * Element::from(2u8)), &(&xx * &a + &yy)); 72 | let y_2 = builder.quotient_unsafe(&(&yy - &xx * &a), 73 | &(-&xx * &a - &yy + Expression::from(2u8))); 74 | 75 | EdwardsExpression::new_unsafe(x_2, y_2) 76 | } 77 | 78 | // TODO: implement Daira's algorithm from https://github.com/zcash/zcash/issues/3924 79 | // TODO: optimize for fixed-base multiplication using windowing, given a constant expression 80 | } 81 | 82 | /// An embedded Edwards curve point defined over the same base field as 83 | /// the constraint system, with affine coordinates as elements. 84 | pub struct EdwardsPoint> { 85 | pub x: Element, 86 | pub y: Element, 87 | phantom: PhantomData<*const C>, 88 | } 89 | 90 | impl> Clone for EdwardsPoint { 91 | fn clone(&self) -> Self { 92 | EdwardsPoint { 93 | x: self.x.clone(), 94 | y: self.y.clone(), 95 | phantom: PhantomData, 96 | } 97 | } 98 | } 99 | 100 | impl> Clone for EdwardsExpression { 101 | fn clone(&self) -> Self { 102 | EdwardsExpression { 103 | x: self.x.clone(), 104 | y: self.y.clone(), 105 | phantom: PhantomData, 106 | } 107 | } 108 | } 109 | 110 | impl> EdwardsPoint { 111 | pub fn new(x: Element, y: Element) -> EdwardsPoint { 112 | assert!(C::a() * &x * &x + &y * &y == Element::one() + C::d() * &x * &x * &y * &y, 113 | "Point must be contained on the curve."); 114 | EdwardsPoint { x, y, phantom: PhantomData } 115 | } 116 | 117 | pub fn compressed_element(&self) -> &Element { 118 | &self.y 119 | } 120 | } 121 | 122 | pub struct EdwardsExpression> { 123 | pub x: Expression, 124 | pub y: Expression, 125 | phantom: PhantomData<*const C>, 126 | } 127 | 128 | impl> EdwardsExpression { 129 | 
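The `add_expressions` and `double_expression` formulas above can be checked exhaustively on a toy twisted Edwards curve. The sketch below works over F_13 with a = -1 and d = 2; these parameters are assumptions chosen for illustration and are not JubJub's. Because a is a square and d is a non-square mod 13, the addition law is complete, so neither denominator can vanish.

```rust
// Toy parameters (assumptions for this sketch only, not JubJub's):
// p = 13, a = -1 (a square mod 13), d = 2 (a non-square mod 13).
const P: u64 = 13;
const A: u64 = P - 1;
const D: u64 = 2;

// Field helpers; every argument is assumed to be already reduced mod P.
fn mul(a: u64, b: u64) -> u64 { a * b % P }
fn add_f(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub_f(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn div(a: u64, b: u64) -> u64 {
    assert!(b != 0, "denominator vanished");
    // b^(P-2) is the inverse of b by Fermat's little theorem.
    let (mut base, mut exp, mut inv) = (b, P - 2, 1);
    while exp > 0 {
        if exp & 1 == 1 { inv = mul(inv, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    mul(a, inv)
}

type Point = (u64, u64);

/// a*x^2 + y^2 == 1 + d*x^2*y^2
fn on_curve((x, y): Point) -> bool {
    let (xx, yy) = (mul(x, x), mul(y, y));
    add_f(mul(A, xx), yy) == add_f(1, mul(D, mul(xx, yy)))
}

/// Same formula as `add_expressions`.
fn edwards_add((x1, y1): Point, (x2, y2): Point) -> Point {
    let t = mul(D, mul(mul(x1, x2), mul(y1, y2))); // d*x1*x2*y1*y2
    let x3 = div(add_f(mul(x1, y2), mul(x2, y1)), add_f(1, t));
    let y3 = div(sub_f(mul(y1, y2), mul(A, mul(x1, x2))), sub_f(1, t));
    (x3, y3)
}

/// Same formula as `double_expression`.
fn edwards_double((x, y): Point) -> Point {
    let (axx, yy) = (mul(A, mul(x, x)), mul(y, y));
    let x2 = div(mul(2, mul(x, y)), add_f(axx, yy));
    let y2 = div(sub_f(yy, axx), sub_f(2, add_f(axx, yy)));
    (x2, y2)
}

/// Brute-force list of all affine points on the toy curve.
fn points() -> Vec<Point> {
    (0..P)
        .flat_map(|x| (0..P).map(move |y| (x, y)))
        .filter(|&p| on_curve(p))
        .collect()
}
```

Checking that `edwards_double(p)` equals `edwards_add(p, p)` for every point confirms that the doubling formula is the addition formula with the curve equation substituted into its denominators. Checking that `(-x, y)` added to `(x, y)` yields `(0, 1)` mirrors the `add_and_negate` test.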
130 | /// Safely creates an `EdwardsExpression` from two coordinates of type `Expression`. 131 | /// Automatically generates constraints asserting that the resulting point 132 | /// lies on the curve `C`. 133 | pub fn new( 134 | builder: &mut GadgetBuilder, 135 | x: Expression, 136 | y: Expression, 137 | ) -> EdwardsExpression { 138 | let x_squared = builder.product(&x, &x); 139 | let y_squared = builder.product(&y, &y); 140 | let x_squared_y_squared = builder.product(&x_squared, &y_squared); 141 | builder.assert_equal(&(&x_squared * C::a() + &y_squared), 142 | &(&x_squared_y_squared * C::d() + Expression::one())); 143 | EdwardsExpression::new_unsafe(x, y) 144 | } 145 | 146 | /// Creates an `EdwardsExpression` from two arbitrary coordinates of type `Expression`. 147 | /// This method adds no constraints, so it should only be used when the coordinates are 148 | /// already known to lie on the curve. 149 | pub fn new_unsafe(x: Expression, y: Expression) -> EdwardsExpression { 150 | EdwardsExpression { x, y, phantom: PhantomData } 151 | } 152 | } 153 | 154 | impl> GroupExpression for EdwardsExpression { 155 | fn compressed(&self) -> &Expression { &self.y } 156 | fn to_components(&self) -> Vec> { vec![self.x.clone(), self.y.clone()] } 157 | 158 | /// Given two group components of type `Expression`, creates an `EdwardsExpression`. Used 159 | /// in the generic implementation of scalar multiplication for groups.
160 | fn from_components_unsafe(mut components: Vec>) -> Self { 161 | let x = components.remove(0); 162 | let y = components.remove(0); 163 | Self::new_unsafe(x, y) 164 | } 165 | } 166 | 167 | impl> From<&EdwardsPoint> for EdwardsExpression { 168 | fn from(point: &EdwardsPoint) -> Self { 169 | EdwardsExpression { 170 | x: Expression::from(&point.x), 171 | y: Expression::from(&point.y), 172 | phantom: PhantomData, 173 | } 174 | } 175 | } 176 | 177 | impl> From<(Element, Element)> for EdwardsExpression { 178 | fn from(coordinates: (Element, Element)) -> Self { 179 | let point = EdwardsPoint::new(coordinates.0, coordinates.1); 180 | EdwardsExpression::from(&point) 181 | } 182 | } 183 | 184 | impl> Evaluable> for EdwardsExpression { 185 | fn evaluate( 186 | &self, 187 | wire_values: &WireValues, 188 | ) -> EdwardsPoint { 189 | EdwardsPoint { 190 | x: self.x.evaluate(wire_values), 191 | y: self.y.evaluate(wire_values), 192 | phantom: PhantomData, 193 | } 194 | } 195 | } 196 | 197 | #[cfg(test)] 198 | mod tests { 199 | use std::str::FromStr; 200 | 201 | use crate::{EdwardsExpression, Expression, GadgetBuilder, Group, WireValues, EdwardsGroup}; 202 | use crate::field::{Bls12_381, Element}; 203 | use crate::{JubJub}; 204 | 205 | #[test] 206 | fn point_on_curve() { 207 | let x = Element::from_str( 208 | "11076627216317271660298050606127911965867021807910416450833192264015104452986" 209 | ).unwrap(); 210 | let y = Element::from_str( 211 | "44412834903739585386157632289020980010620626017712148233229312325549216099227" 212 | ).unwrap(); 213 | 214 | let x_exp = Expression::from(x); 215 | let y_exp = Expression::from(y); 216 | 217 | let mut builder = GadgetBuilder::::new(); 218 | let p = EdwardsExpression::::new( 219 | &mut builder, x_exp, y_exp); 220 | 221 | let gadget = builder.build(); 222 | assert!(gadget.execute(&mut WireValues::new())); 223 | } 224 | 225 | #[test] 226 | fn point_not_on_curve_with_expressions() { 227 | let x = Element::from_str( 228 | 
"11076627216317271660298050606127911965867021807910416450833192264015104452986" 229 | ).unwrap(); 230 | let y = Element::from_str( 231 | "44412834903739585386157632289020980010620626017712148233229312325549216099226" 232 | ).unwrap(); 233 | 234 | let x_exp = Expression::from(x); 235 | let y_exp = Expression::from(y); 236 | 237 | let mut builder = GadgetBuilder::::new(); 238 | let p 239 | = EdwardsExpression::::new( 240 | &mut builder, 241 | x_exp, 242 | y_exp 243 | ); 244 | 245 | let gadget = builder.build(); 246 | assert!(!gadget.execute(&mut WireValues::new())); 247 | } 248 | 249 | #[test] 250 | #[should_panic] 251 | fn point_not_on_curve() { 252 | let x = Element::from_str( 253 | "11076627216317271660298050606127911965867021807910416450833192264015104452985" 254 | ).unwrap(); 255 | 256 | let y = Element::from_str( 257 | "44412834903739585386157632289020980010620626017712148233229312325549216099227" 258 | ).unwrap(); 259 | 260 | EdwardsExpression::::from((x, y)); 261 | } 262 | 263 | #[test] 264 | fn add_and_negate() { 265 | let x1 = Element::::from_str( 266 | "11076627216317271660298050606127911965867021807910416450833192264015104452986" 267 | ).unwrap(); 268 | let y1 = Element::::from_str( 269 | "44412834903739585386157632289020980010620626017712148233229312325549216099227" 270 | ).unwrap(); 271 | 272 | let p1 273 | = EdwardsExpression::::from((x1, y1)); 274 | 275 | let p2 276 | = EdwardsExpression::::new_unsafe(-p1.x.clone(), p1.y.clone()); 277 | 278 | let mut builder = GadgetBuilder::::new(); 279 | let p3 = EdwardsGroup::::add_expressions(&mut builder, &p1, &p2); 280 | let gadget = builder.build(); 281 | let mut values = WireValues::new(); 282 | gadget.execute(&mut values); 283 | assert_eq!(p3.x.evaluate(&values), Element::zero()); 284 | assert_eq!(p3.y.evaluate(&values), Element::one()); 285 | } 286 | 287 | #[test] 288 | fn mul_scalar() { 289 | let x1 = Element::::from_str( 290 | "11076627216317271660298050606127911965867021807910416450833192264015104452986" 
291 | ).unwrap(); 292 | let y1 = Element::::from_str( 293 | "44412834903739585386157632289020980010620626017712148233229312325549216099227" 294 | ).unwrap(); 295 | 296 | let scalar = Expression::::from( 297 | Element::::from_str( 298 | "444128349033229312325549216099227444128349033229312325549216099220000000" 299 | ).unwrap() 300 | ); 301 | 302 | let p1 303 | = EdwardsExpression::::from((x1, y1)); 304 | 305 | let mut builder = GadgetBuilder::::new(); 306 | let p3 = EdwardsGroup::::mul_scalar_expression( 307 | &mut builder, 308 | &p1, 309 | &scalar, 310 | ); 311 | let gadget = builder.build(); 312 | let mut values = WireValues::new(); 313 | gadget.execute(&mut values); 314 | 315 | // TODO: include assertion 316 | } 317 | } -------------------------------------------------------------------------------- /src/curves/jubjub.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use crate::{Bls12_381, CyclicGenerator, EdwardsCurve, EdwardsGroup, EdwardsPoint, Element, CyclicSubgroup}; 4 | 5 | pub struct JubJub; 6 | 7 | pub type JubJubPrimeSubgroup = CyclicSubgroup, JubJub>; 8 | 9 | impl EdwardsCurve for JubJub { 10 | fn a() -> Element { 11 | -Element::one() 12 | } 13 | 14 | fn d() -> Element { 15 | Element::from_str( 16 | "19257038036680949359750312669786877991949435402254120286184196891950884077233" 17 | ).unwrap() 18 | } 19 | } 20 | 21 | impl CyclicGenerator> for JubJub { 22 | fn generator_element() -> EdwardsPoint { 23 | let x = Element::from_str( 24 | "11076627216317271660298050606127911965867021807910416450833192264015104452986" 25 | ).unwrap(); 26 | let y = Element::from_str( 27 | "44412834903739585386157632289020980010620626017712148233229312325549216099227" 28 | ).unwrap(); 29 | 30 | EdwardsPoint::new(x, y) 31 | } 32 | } 33 | 34 | #[cfg(test)] 35 | mod tests { 36 | #[test] 37 | fn subgroup_check() { 38 | //TODO: verify that generator is valid and generates a subgroup of prime order with appropriate 
cofactor 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/curves/mod.rs: -------------------------------------------------------------------------------- 1 | pub use edwards::*; 2 | pub use jubjub::*; 3 | pub use montgomery::*; 4 | pub use weierstrass::*; 5 | 6 | mod edwards; 7 | mod jubjub; 8 | mod montgomery; 9 | mod weierstrass; 10 | -------------------------------------------------------------------------------- /src/curves/montgomery.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::{Element, Expression, Field}; 4 | 5 | /// A Montgomery curve. 6 | pub trait MontgomeryCurve { 7 | fn a() -> Element; 8 | fn b() -> Element; 9 | } 10 | 11 | /// An embedded Montgomery curve point defined over the same base field as 12 | /// the constraint system, with affine coordinates as elements. 13 | pub struct MontgomeryPoint> { 14 | pub x: Element, 15 | pub y: Element, 16 | phantom: PhantomData<*const C>, 17 | } 18 | 19 | /// An embedded Montgomery curve point defined over the same base field 20 | /// as the field used in the constraint system, with affine coordinates as 21 | /// expressions. 22 | pub struct MontgomeryExpression> { 23 | pub x: Expression, 24 | pub y: Expression, 25 | phantom: PhantomData<*const C>, 26 | } 27 | -------------------------------------------------------------------------------- /src/curves/weierstrass.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::{Element, Expression, Field}; 4 | 5 | /// A short Weierstrass curve. 6 | pub trait WeierstrassCurve { 7 | fn a() -> Element; 8 | fn b() -> Element; 9 | } 10 | 11 | /// An embedded Weierstrass curve point defined over the same base field as 12 | /// the constraint system, with affine coordinates as elements. 
13 | pub struct WeierstrassPoint> { 14 | pub x: Element, 15 | pub y: Element, 16 | phantom: PhantomData<*const C>, 17 | } 18 | 19 | /// An embedded Weierstrass curve point defined over the same base field 20 | /// as the field used in the constraint system, with affine coordinates as 21 | /// expressions. 22 | pub struct WeierstrassExpression> { 23 | pub x: Expression, 24 | pub y: Expression, 25 | phantom: PhantomData<*const C>, 26 | } 27 | 28 | /// An embedded Weierstrass curve point defined over the same base field 29 | /// as the field used in the constraint system, with projective coordinates 30 | /// as expressions. 31 | pub struct ProjWeierstrassExpression> { 32 | pub x: Expression, 33 | pub y: Expression, 34 | pub z: Expression, 35 | phantom: PhantomData<*const C>, 36 | } 37 | -------------------------------------------------------------------------------- /src/davies_meyer.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with an implementation of the Davies-Meyer construction. 2 | 3 | use std::marker::PhantomData; 4 | 5 | use crate::expression::Expression; 6 | use crate::field::Field; 7 | use crate::gadget_builder::GadgetBuilder; 8 | use crate::gadget_traits::{BlockCipher, CompressionFunction}; 9 | 10 | /// The additive variant of Davies-Meyer, which creates a one-way compression function from a block 11 | /// cipher. 12 | pub struct DaviesMeyer> { 13 | cipher: BC, 14 | phantom: PhantomData<*const F>, 15 | } 16 | 17 | impl> DaviesMeyer { 18 | /// Create a new Davies-Meyer compression function from the given block cipher. 
19 | pub fn new(cipher: BC) -> Self { 20 | DaviesMeyer { cipher, phantom: PhantomData } 21 | } 22 | } 23 | 24 | impl> CompressionFunction for DaviesMeyer { 25 | fn compress(&self, builder: &mut GadgetBuilder, x: &Expression, y: &Expression) 26 | -> Expression { 27 | self.cipher.encrypt(builder, y, x) + x 28 | } 29 | } 30 | 31 | #[cfg(test)] 32 | mod tests { 33 | use crate::davies_meyer::DaviesMeyer; 34 | use crate::expression::Expression; 35 | use crate::field::{Element, Field}; 36 | use crate::gadget_builder::GadgetBuilder; 37 | use crate::gadget_traits::{BlockCipher, CompressionFunction}; 38 | use crate::test_util::F7; 39 | 40 | #[test] 41 | fn davies_meyer() { 42 | // We will use a trivial cipher to keep the test simple. 43 | // The cipher is: (k, i) -> 2k + 4i + 3ki 44 | struct TestCipher; 45 | 46 | impl BlockCipher for TestCipher { 47 | fn encrypt(&self, builder: &mut GadgetBuilder, key: &Expression, 48 | input: &Expression) -> Expression { 49 | let product = builder.product(key, input); 50 | key * 2 + input * 4 + product * 3 51 | } 52 | 53 | fn decrypt(&self, _builder: &mut GadgetBuilder, _key: &Expression, 54 | _output: &Expression) -> Expression { 55 | panic!("Should never be called") 56 | } 57 | } 58 | 59 | let mut builder = GadgetBuilder::::new(); 60 | let x_wire = builder.wire(); 61 | let y_wire = builder.wire(); 62 | let x = Expression::from(x_wire); 63 | let y = Expression::from(y_wire); 64 | let dm = DaviesMeyer::new(TestCipher); 65 | let dm_output = dm.compress(&mut builder, &x, &y); 66 | let gadget = builder.build(); 67 | 68 | let mut values = values!(x_wire => 2u8.into(), y_wire => 3u8.into()); 69 | assert!(gadget.execute(&mut values)); 70 | // The result should be (2y + 4x + 3yx) + x = 6 + 8 + 18 + 2 = 34, which is 6 in F7.
71 | assert_eq!(Element::from(6u8), dm_output.evaluate(&values)); 72 | } 73 | } -------------------------------------------------------------------------------- /src/field.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::fmt; 3 | use std::fmt::Formatter; 4 | use std::hash::{Hash, Hasher}; 5 | use std::marker::PhantomData; 6 | use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Shl, Sub, SubAssign}; 7 | use std::str::FromStr; 8 | 9 | use num::bigint::ParseBigIntError; 10 | use num::BigUint; 11 | use num_traits::One; 12 | use num_traits::Zero; 13 | 14 | /// A prime-order field. 15 | pub trait Field: 'static { 16 | /// The (prime) order of this field. 17 | fn order() -> BigUint; 18 | } 19 | 20 | /// The scalar field of the BN128 curve. 21 | #[derive(Debug)] 22 | pub struct Bn128 {} 23 | 24 | impl Field for Bn128 { 25 | fn order() -> BigUint { 26 | BigUint::from_str( 27 | "21888242871839275222246405745257275088548364400416034343698204186575808495617" 28 | ).unwrap() 29 | } 30 | } 31 | 32 | /// The scalar field of the BLS12-381 curve. 33 | #[derive(Debug)] 34 | pub struct Bls12_381 {} 35 | 36 | impl Field for Bls12_381 { 37 | fn order() -> BigUint { 38 | BigUint::from_str( 39 | "52435875175126190479447740508185965837690552500527637822603658699938581184513" 40 | ).unwrap() 41 | } 42 | } 43 | 44 | /// An element of a prime field. 45 | #[derive(Debug)] 46 | pub struct Element { 47 | n: BigUint, 48 | /// F needs to be present in a struct field, otherwise the compiler will complain that it is 49 | /// unused. In reality it is used, but only at compile time. For example, some functions take an 50 | /// `Element` and call `F::order()`.
51 | phantom: PhantomData<*const F>, 52 | } 53 | 54 | impl Element { 55 | pub fn zero() -> Self { 56 | Self::from(BigUint::zero()) 57 | } 58 | 59 | pub fn one() -> Self { 60 | Self::from(BigUint::one()) 61 | } 62 | 63 | pub fn largest_element() -> Self { 64 | Self::from(F::order() - BigUint::one()) 65 | } 66 | 67 | pub fn to_biguint(&self) -> &BigUint { 68 | &self.n 69 | } 70 | 71 | pub fn is_zero(&self) -> bool { 72 | self.to_biguint().is_zero() 73 | } 74 | 75 | pub fn is_nonzero(&self) -> bool { 76 | !self.to_biguint().is_zero() 77 | } 78 | 79 | pub fn is_one(&self) -> bool { 80 | self.to_biguint().is_one() 81 | } 82 | 83 | pub fn multiplicative_inverse(&self) -> Self { 84 | assert!(!self.is_zero(), "Zero does not have a multiplicative inverse"); 85 | // From Fermat's little theorem. 86 | // TODO: Use a faster method, like the one described in "Fast Modular Reciprocals". 87 | // Or just wait for https://github.com/rust-num/num-bigint/issues/60 88 | self.exponentiation(&-Self::from(2u8)) 89 | } 90 | 91 | /// Like `multiplicative_inverse`, except that zero is mapped to itself rather than causing a 92 | /// panic. 93 | pub fn multiplicative_inverse_or_zero(&self) -> Self { 94 | if self.is_zero() { 95 | Self::zero() 96 | } else { 97 | self.multiplicative_inverse() 98 | } 99 | } 100 | 101 | pub fn exponentiation(&self, power: &Self) -> Self { 102 | Self::from(self.to_biguint().modpow(power.to_biguint(), &F::order())) 103 | } 104 | 105 | pub fn integer_division(&self, rhs: &Self) -> Self { 106 | Self::from(self.to_biguint() / rhs.to_biguint()) 107 | } 108 | 109 | pub fn integer_modulus(&self, rhs: &Self) -> Self { 110 | Self::from(self.to_biguint() % rhs.to_biguint()) 111 | } 112 | 113 | pub fn gcd(&self, rhs: &Self) -> Self { 114 | // This is just the Euclidean algorithm. 
115 | if rhs.is_zero() { 116 | self.clone() 117 | } else { 118 | rhs.gcd(&self.integer_modulus(rhs)) 119 | } 120 | } 121 | 122 | pub fn is_prime(&self) -> bool { 123 | let mut divisor = Self::from(2u8); 124 | while divisor.to_biguint() * divisor.to_biguint() <= *self.to_biguint() { // Square as integers; a field square would wrap around modulo `F::order()`. 125 | let divides = self.integer_modulus(&divisor).is_zero(); 126 | if divides { 127 | return false; 128 | } 129 | divisor += Element::one(); 130 | } 131 | true 132 | } 133 | 134 | /// The number of bits needed to encode every element of `F`. 135 | pub fn max_bits() -> usize { 136 | Self::largest_element().bits() 137 | } 138 | 139 | /// The number of bits needed to encode this particular field element. 140 | pub fn bits(&self) -> usize { 141 | self.to_biguint().bits() as usize 142 | } 143 | 144 | /// Returns the i'th least significant bit; for example, `x.bit(0)` returns the least 145 | /// significant bit of `x`. Returns false if `i` is at or beyond `self.bits()`. 146 | pub fn bit(&self, i: usize) -> bool { 147 | ((self.to_biguint() >> i) & BigUint::one()).is_one() 148 | } 149 | } 150 | 151 | impl From for Element { 152 | fn from(n: BigUint) -> Element { 153 | assert!(n < F::order(), "Out of range"); 154 | Element { n, phantom: PhantomData } 155 | } 156 | } 157 | 158 | impl From for Element { 159 | fn from(n: usize) -> Element { 160 | Element::from(BigUint::from(n)) 161 | } 162 | } 163 | 164 | impl From for Element { 165 | fn from(n: u128) -> Element { 166 | Element::from(BigUint::from(n)) 167 | } 168 | } 169 | 170 | impl From for Element { 171 | fn from(n: u64) -> Element { 172 | Element::from(BigUint::from(n)) 173 | } 174 | } 175 | 176 | impl From for Element { 177 | fn from(n: u32) -> Element { 178 | Element::from(BigUint::from(n)) 179 | } 180 | } 181 | 182 | impl From for Element { 183 | fn from(n: u16) -> Element { 184 | Element::from(BigUint::from(n)) 185 | } 186 | } 187 | 188 | impl From for Element { 189 | fn from(n: u8) -> Element { 190 | Element::from(BigUint::from(n)) 191 | } 192 | } 193 | 194 | impl From for Element { 195
| fn from(b: bool) -> Element { 196 | Element::from(b as u128) 197 | } 198 | } 199 | 200 | impl FromStr for Element { 201 | type Err = ParseBigIntError; 202 | 203 | fn from_str(s: &str) -> Result { 204 | BigUint::from_str(s).map(Element::from) 205 | } 206 | } 207 | 208 | impl PartialEq for Element { 209 | fn eq(&self, other: &Self) -> bool { 210 | self.to_biguint() == other.to_biguint() 211 | } 212 | } 213 | 214 | impl Eq for Element {} 215 | 216 | impl Clone for Element { 217 | fn clone(&self) -> Self { 218 | Element::from(self.to_biguint().clone()) 219 | } 220 | } 221 | 222 | impl Hash for Element { 223 | fn hash(&self, state: &mut H) { 224 | self.n.hash(state) 225 | } 226 | } 227 | 228 | impl Ord for Element { 229 | fn cmp(&self, other: &Self) -> Ordering { 230 | self.n.cmp(&other.n) 231 | } 232 | } 233 | 234 | impl PartialOrd for Element { 235 | fn partial_cmp(&self, other: &Self) -> Option { 236 | Some(self.cmp(other)) 237 | } 238 | } 239 | 240 | impl Neg for Element { 241 | type Output = Element; 242 | 243 | fn neg(self) -> Element { 244 | -&self 245 | } 246 | } 247 | 248 | impl Neg for &Element { 249 | type Output = Element; 250 | 251 | fn neg(self) -> Element { 252 | if self.is_zero() { 253 | Element::zero() 254 | } else { 255 | Element::from(F::order() - self.to_biguint()) 256 | } 257 | } 258 | } 259 | 260 | impl Add> for Element { 261 | type Output = Element; 262 | 263 | fn add(self, rhs: Element) -> Element { 264 | &self + &rhs 265 | } 266 | } 267 | 268 | impl Add<&Element> for Element { 269 | type Output = Element; 270 | 271 | fn add(self, rhs: &Element) -> Element { 272 | &self + rhs 273 | } 274 | } 275 | 276 | impl Add> for &Element { 277 | type Output = Element; 278 | 279 | fn add(self, rhs: Element) -> Element { 280 | self + &rhs 281 | } 282 | } 283 | 284 | impl Add<&Element> for &Element { 285 | type Output = Element; 286 | 287 | fn add(self, rhs: &Element) -> Element { 288 | Element::from((self.to_biguint() + rhs.to_biguint()) % F::order()) 289 | 
} 290 | } 291 | 292 | impl AddAssign for Element { 293 | fn add_assign(&mut self, rhs: Element) { 294 | *self += &rhs; 295 | } 296 | } 297 | 298 | impl AddAssign<&Element> for Element { 299 | fn add_assign(&mut self, rhs: &Element) { 300 | *self = &*self + rhs; 301 | } 302 | } 303 | 304 | impl Sub> for Element { 305 | type Output = Element; 306 | 307 | fn sub(self, rhs: Element) -> Element { 308 | &self - &rhs 309 | } 310 | } 311 | 312 | impl Sub<&Element> for Element { 313 | type Output = Element; 314 | 315 | fn sub(self, rhs: &Element) -> Element { 316 | &self - rhs 317 | } 318 | } 319 | 320 | impl Sub> for &Element { 321 | type Output = Element; 322 | 323 | fn sub(self, rhs: Element) -> Element { 324 | self - &rhs 325 | } 326 | } 327 | 328 | impl Sub<&Element> for &Element { 329 | type Output = Element; 330 | 331 | fn sub(self, rhs: &Element) -> Element { 332 | self + -rhs 333 | } 334 | } 335 | 336 | impl SubAssign for Element { 337 | fn sub_assign(&mut self, rhs: Element) { 338 | *self -= &rhs; 339 | } 340 | } 341 | 342 | impl SubAssign<&Element> for Element { 343 | fn sub_assign(&mut self, rhs: &Element) { 344 | *self = &*self - rhs; 345 | } 346 | } 347 | 348 | impl Mul> for Element { 349 | type Output = Element; 350 | 351 | fn mul(self, rhs: Element) -> Element { 352 | &self * &rhs 353 | } 354 | } 355 | 356 | impl Mul<&Element> for Element { 357 | type Output = Element; 358 | 359 | fn mul(self, rhs: &Element) -> Element { 360 | &self * rhs 361 | } 362 | } 363 | 364 | impl Mul> for &Element { 365 | type Output = Element; 366 | 367 | fn mul(self, rhs: Element) -> Element { 368 | self * &rhs 369 | } 370 | } 371 | 372 | impl Mul<&Element> for &Element { 373 | type Output = Element; 374 | 375 | fn mul(self, rhs: &Element) -> Element { 376 | Element::from((self.to_biguint() * rhs.to_biguint()) % F::order()) 377 | } 378 | } 379 | 380 | impl Mul for Element { 381 | type Output = Element; 382 | 383 | fn mul(self, rhs: u128) -> Element { 384 | &self * rhs 385 | } 386 | 
} 387 | 388 | impl Mul for &Element { 389 | type Output = Element; 390 | 391 | fn mul(self, rhs: u128) -> Element { 392 | self * Element::from(rhs) 393 | } 394 | } 395 | 396 | impl MulAssign for Element { 397 | fn mul_assign(&mut self, rhs: Element) { 398 | *self *= &rhs; 399 | } 400 | } 401 | 402 | impl MulAssign<&Element> for Element { 403 | fn mul_assign(&mut self, rhs: &Element) { 404 | *self = self.clone() * rhs; 405 | } 406 | } 407 | 408 | impl MulAssign for Element { 409 | fn mul_assign(&mut self, rhs: u128) { 410 | *self = self.clone() * rhs; 411 | } 412 | } 413 | 414 | impl Div> for Element { 415 | type Output = Element; 416 | 417 | fn div(self, rhs: Element) -> Element { 418 | &self / &rhs 419 | } 420 | } 421 | 422 | impl Div<&Element> for Element { 423 | type Output = Element; 424 | 425 | fn div(self, rhs: &Element) -> Element { 426 | &self / rhs 427 | } 428 | } 429 | 430 | impl Div> for &Element { 431 | type Output = Element; 432 | 433 | fn div(self, rhs: Element) -> Element { 434 | self / &rhs 435 | } 436 | } 437 | 438 | impl Div<&Element> for &Element { 439 | type Output = Element; 440 | 441 | #[allow(clippy::suspicious_arithmetic_impl)] 442 | fn div(self, rhs: &Element) -> Element { 443 | self * rhs.multiplicative_inverse() 444 | } 445 | } 446 | 447 | impl Div for Element { 448 | type Output = Element; 449 | 450 | fn div(self, rhs: u128) -> Element { 451 | &self / rhs 452 | } 453 | } 454 | 455 | impl Div for &Element { 456 | type Output = Element; 457 | 458 | fn div(self, rhs: u128) -> Element { 459 | self / Element::from(rhs) 460 | } 461 | } 462 | 463 | impl DivAssign for Element { 464 | fn div_assign(&mut self, rhs: Element) { 465 | *self /= &rhs; 466 | } 467 | } 468 | 469 | impl DivAssign<&Element> for Element { 470 | fn div_assign(&mut self, rhs: &Element) { 471 | *self = self.clone() / rhs; 472 | } 473 | } 474 | 475 | impl DivAssign for Element { 476 | fn div_assign(&mut self, rhs: u128) { 477 | *self = self.clone() / rhs; 478 | } 479 | } 480 | 
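Every operator impl above follows one rule: compute on the underlying integer and reduce modulo `F::order()`. Negation maps a nonzero `a` to `p - a`, and division multiplies by `multiplicative_inverse`, which computes `a^(p-2)` by Fermat's little theorem. The following is a minimal standalone sketch of those rules, not the crate's code. It assumes the small prime `P = 7` (matching the `F7` test field) and uses plain `u64` in place of `BigUint`; the helper names `pow_mod`, `add`, `neg`, `sub`, `mul`, `inv` and `div` are hypothetical.

```rust
// Hypothetical sketch: arithmetic in the prime field F_7 using u64,
// mirroring how the `Element` operator impls reduce modulo the field order.
const P: u64 = 7; // assumption: a small prime, like the F7 test field

/// Square-and-multiply modular exponentiation (the role `BigUint::modpow` plays above).
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut result = 1 % P;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    result
}

fn add(a: u64, b: u64) -> u64 { (a + b) % P }

/// Additive inverse: zero maps to itself, any other `a` maps to `P - a`.
fn neg(a: u64) -> u64 { if a == 0 { 0 } else { P - a } }

/// Subtraction as addition of the negation, like `Sub<&Element> for &Element`.
fn sub(a: u64, b: u64) -> u64 { add(a, neg(b)) }

fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Fermat's little theorem: for nonzero a, a^(P-2) is the inverse of a.
fn inv(a: u64) -> u64 {
    assert!(a % P != 0, "Zero does not have a multiplicative inverse");
    pow_mod(a, P - 2)
}

/// Division as multiplication by the inverse, like `Div<&Element> for &Element`.
fn div(a: u64, b: u64) -> u64 { mul(a, inv(b)) }

fn main() {
    assert_eq!(add(5, 5), 3); // wraps around: 10 mod 7
    assert_eq!(mul(3, 3), 2); // 9 mod 7
    assert_eq!(neg(1), 6);    // -1 is 6 in F_7
    assert_eq!(sub(2, 5), 4); // -3 is 4 in F_7
    assert_eq!(inv(2), 4);    // 2 * 4 = 8 = 1 in F_7
    assert_eq!(div(3, 2), 5); // 3 * 4 = 12 = 5 in F_7
    println!("all F_7 checks passed");
}
```

The assertions in `main` reproduce test vectors from the `addition_overflow`, `multiplication_overflow`, `additive_inverse` and `multiplicative_inverse` tests in `field.rs`. With a 254-bit modulus such as `Bn128`'s, `u64` products would overflow, which is one reason the crate works with `BigUint`.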
481 | impl Shl for Element { 482 | type Output = Element; 483 | 484 | fn shl(self, rhs: usize) -> Element { 485 | &self << rhs 486 | } 487 | } 488 | 489 | impl Shl for &Element { 490 | type Output = Element; 491 | 492 | fn shl(self, rhs: usize) -> Element { 493 | Element::from(self.to_biguint() << rhs) 494 | } 495 | } 496 | 497 | impl fmt::Display for Element { 498 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 499 | write!(f, "{}", self.to_biguint()) 500 | } 501 | } 502 | 503 | #[cfg(test)] 504 | mod tests { 505 | use std::iter; 506 | 507 | use itertools::assert_equal; 508 | 509 | use crate::field::Element; 510 | use crate::test_util::{F257, F7}; 511 | 512 | #[test] 513 | fn addition() { 514 | type F = F257; 515 | 516 | assert_eq!( 517 | Element::::from(2u8), 518 | Element::one() + Element::one()); 519 | 520 | assert_eq!( 521 | Element::::from(33u8), 522 | Element::from(13u8) + Element::from(20u8)); 523 | } 524 | 525 | #[test] 526 | fn addition_overflow() { 527 | type F = F7; 528 | 529 | assert_eq!( 530 | Element::::from(3u8), 531 | Element::from(5u8) + Element::from(5u8)); 532 | } 533 | 534 | #[test] 535 | fn additive_inverse() { 536 | type F = F7; 537 | 538 | assert_eq!( 539 | Element::::from(6u8), 540 | -Element::one()); 541 | 542 | assert_eq!( 543 | Element::::zero(), 544 | Element::from(5u8) + -Element::from(5u8)); 545 | } 546 | 547 | #[test] 548 | fn multiplicative_inverse() { 549 | type F = F7; 550 | 551 | // Verified with a bit of Python code: 552 | // >>> f = 7 553 | // >>> [[y for y in range(f) if x * y % f == 1] for x in range(f)] 554 | // [[], [1], [4], [5], [2], [3], [6]] 555 | assert_eq!(Element::::from(0u8), Element::from(0u8).multiplicative_inverse_or_zero()); 556 | assert_eq!(Element::::from(1u8), Element::from(1u8).multiplicative_inverse_or_zero()); 557 | assert_eq!(Element::::from(4u8), Element::from(2u8).multiplicative_inverse_or_zero()); 558 | assert_eq!(Element::::from(5u8), Element::from(3u8).multiplicative_inverse_or_zero()); 559 | 
assert_eq!(Element::::from(2u8), Element::from(4u8).multiplicative_inverse_or_zero()); 560 | assert_eq!(Element::::from(3u8), Element::from(5u8).multiplicative_inverse_or_zero()); 561 | assert_eq!(Element::::from(6u8), Element::from(6u8).multiplicative_inverse_or_zero()); 562 | } 563 | 564 | #[test] 565 | fn multiplication_overflow() { 566 | type F = F7; 567 | 568 | assert_eq!( 569 | Element::::from(2u8), 570 | Element::from(3u8) * Element::from(3u8)); 571 | } 572 | 573 | #[test] 574 | fn bits_0() { 575 | let x = Element::::zero(); 576 | let n: usize = 20; 577 | assert_equal( 578 | iter::repeat(false).take(n), 579 | (0..n).map(|i| x.bit(i))); 580 | } 581 | 582 | #[test] 583 | fn bits_19() { 584 | let x = Element::::from(19u8); 585 | assert_eq!(true, x.bit(0)); 586 | assert_eq!(true, x.bit(1)); 587 | assert_eq!(false, x.bit(2)); 588 | assert_eq!(false, x.bit(3)); 589 | assert_eq!(true, x.bit(4)); 590 | assert_eq!(false, x.bit(5)); 591 | assert_eq!(false, x.bit(6)); 592 | assert_eq!(false, x.bit(7)); 593 | assert_eq!(false, x.bit(8)); 594 | assert_eq!(false, x.bit(9)); 595 | } 596 | 597 | #[test] 598 | fn order_of_elements() { 599 | type F = F257; 600 | for i in 0u8..50 { 601 | assert!(Element::::from(i) < Element::::from(i + 1)); 602 | } 603 | } 604 | } 605 | -------------------------------------------------------------------------------- /src/field_arithmetic.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with native field arithmetic methods. 2 | 3 | use crate::expression::{BooleanExpression, Expression}; 4 | use crate::field::{Element, Field}; 5 | use crate::gadget_builder::GadgetBuilder; 6 | use crate::wire_values::WireValues; 7 | use crate::util::concat; 8 | 9 | impl GadgetBuilder { 10 | /// The product of two `Expression`s `x` and `y`, i.e. `x * y`. 
11 | pub fn product(&mut self, x: &Expression, y: &Expression) -> Expression { 12 | if let Some(c) = x.as_constant() { 13 | return y * c; 14 | } 15 | if let Some(c) = y.as_constant() { 16 | return x * c; 17 | } 18 | 19 | let product = self.wire(); 20 | let product_exp = Expression::from(product); 21 | self.assert_product(x, y, &product_exp); 22 | 23 | { 24 | let x = x.clone(); 25 | let y = y.clone(); 26 | self.generator( 27 | concat(&[x.dependencies(), y.dependencies()]), 28 | move |values: &mut WireValues| { 29 | let product_value = x.evaluate(values) * y.evaluate(values); 30 | values.set(product, product_value); 31 | }, 32 | ); 33 | } 34 | 35 | product_exp 36 | } 37 | 38 | /// `x^p` for a constant `p`. 39 | pub fn exponentiation(&mut self, x: &Expression, p: &Element) -> Expression { 40 | // This is exponentiation by squaring. For each 1 bit of p, multiply by the associated 41 | // square power. 42 | let mut product_exp = Expression::one(); 43 | let mut last_square = Expression::zero(); 44 | 45 | for i in 0..p.bits() { 46 | let square = if i == 0 { 47 | x.clone() 48 | } else { 49 | self.product(&last_square, &last_square) 50 | }; 51 | 52 | if p.bit(i) { 53 | product_exp = self.product(&product_exp, &square); 54 | } 55 | 56 | last_square = square; 57 | } 58 | product_exp 59 | } 60 | 61 | /// Returns `1 / x`, assuming `x` is non-zero. If `x` is zero, the gadget will not be 62 | /// satisfiable. 63 | pub fn inverse(&mut self, x: &Expression) -> Expression { 64 | let x_inv = self.wire(); 65 | self.assert_product(x, &Expression::from(x_inv), &Expression::one()); 66 | 67 | let x = x.clone(); 68 | self.generator( 69 | x.dependencies(), 70 | move |values: &mut WireValues| { 71 | let x_value = x.evaluate(values); 72 | let inverse_value = x_value.multiplicative_inverse(); 73 | values.set(x_inv, inverse_value); 74 | }, 75 | ); 76 | 77 | x_inv.into() 78 | } 79 | 80 | /// Like `inverse`, except that zero is mapped to itself rather than being prohibited. 
81 | pub fn inverse_or_zero(&mut self, x: &Expression) -> Expression { 82 | let x_inv_or_zero = self.wire(); 83 | let nonzero = self.nonzero(x); 84 | self.assert_product(x, &Expression::from(x_inv_or_zero), nonzero.expression()); 85 | self.assert_product(&(Expression::one() - nonzero.expression()), &Expression::from(x_inv_or_zero), &Expression::zero()); // Forces the result to 0 when x = 0. 86 | let x = x.clone(); 87 | self.generator( 88 | x.dependencies(), 89 | move |values: &mut WireValues| { 90 | let x_value = x.evaluate(values); 91 | values.set(x_inv_or_zero, x_value.multiplicative_inverse_or_zero()); 92 | }, 93 | ); 94 | 95 | x_inv_or_zero.into() 96 | } 97 | 98 | /// Returns `x / y`, assuming `y` is non-zero. If `y` is zero, the gadget will not be 99 | /// satisfiable. 100 | pub fn quotient(&mut self, x: &Expression, y: &Expression) -> Expression { 101 | let y_inv = self.inverse(y); 102 | self.product(x, &y_inv) 103 | } 104 | 105 | /// Returns `x / y`, assuming `y` is non-zero. This is equivalent to `quotient` except that its 106 | /// single constraint, `y * q = x`, is also satisfied by any `q` when `x = y = 0`. The witness generator 107 | /// panics on a zero divisor, but a malicious prover could supply an arbitrary quotient for `0 / 0`. 108 | /// 109 | /// This method uses a single constraint, whereas `quotient` uses two, so this method may be 110 | /// preferable in cases where `0 / 0` cannot possibly arise. 111 | pub fn quotient_unsafe(&mut self, x: &Expression, y: &Expression) -> Expression { 112 | let q = self.wire(); 113 | self.assert_product(y, &Expression::from(q), x); 114 | let (x, y) = (x.clone(), y.clone()); 115 | self.generator( 116 | [x.dependencies(), y.dependencies()].concat(), 117 | move |values: &mut WireValues| { 118 | let x_value = x.evaluate(values); 119 | let y_value = y.evaluate(values); 120 | assert!(y_value.is_nonzero(), "Division by zero"); 121 | let q_value = x_value / y_value; 122 | values.set(q, q_value) 123 | } 124 | ); 125 | Expression::from(q) 126 | } 127 | 128 | /// Returns `x mod y`, assuming `y` is non-zero. If `y` is zero, the gadget will not be 129 | /// satisfiable.
130 | pub fn modulus(&mut self, x: &Expression, y: &Expression) -> Expression { 131 | // We will non-deterministically compute a quotient q and remainder r such that: 132 | // y * q = x - r 133 | // r < y 134 | 135 | let q = self.wire(); 136 | let r = self.wire(); 137 | self.assert_product(y, &Expression::from(q), &(x - Expression::from(r))); 138 | self.assert_lt(&Expression::from(r), y); 139 | 140 | { 141 | let x = x.clone(); 142 | let y = y.clone(); 143 | self.generator( 144 | concat(&[x.dependencies(), y.dependencies()]), 145 | move |values: &mut WireValues| { 146 | let x_value = x.evaluate(values); 147 | let y_value = y.evaluate(values); 148 | values.set(q, x_value.integer_division(&y_value)); 149 | values.set(r, x_value.integer_modulus(&y_value)); 150 | }, 151 | ); 152 | } 153 | 154 | r.into() 155 | } 156 | 157 | /// Returns whether `x` divides `y`, i.e. `x | y`. 158 | pub fn divides(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 159 | let m = self.modulus(y, x); 160 | self.zero(&m) 161 | } 162 | } 163 | 164 | #[cfg(test)] 165 | mod tests { 166 | use crate::expression::Expression; 167 | use crate::field::Element; 168 | use crate::gadget_builder::GadgetBuilder; 169 | use crate::test_util::{assert_eq_false, assert_eq_true, F257}; 170 | 171 | #[test] 172 | fn exp() { 173 | let mut builder = GadgetBuilder::::new(); 174 | let x = builder.wire(); 175 | let x_exp_0 = builder.exponentiation(&Expression::from(x), &Element::from(0u8)); 176 | let x_exp_1 = builder.exponentiation(&Expression::from(x), &Element::from(1u8)); 177 | let x_exp_2 = builder.exponentiation(&Expression::from(x), &Element::from(2u8)); 178 | let x_exp_3 = builder.exponentiation(&Expression::from(x), &Element::from(3u8)); 179 | let gadget = builder.build(); 180 | 181 | let mut values = values!(x => 3u8.into()); 182 | assert!(gadget.execute(&mut values)); 183 | assert_eq!(Element::from(1u8), x_exp_0.evaluate(&values)); 184 | assert_eq!(Element::from(3u8), x_exp_1.evaluate(&values)); 
185 | assert_eq!(Element::from(9u8), x_exp_2.evaluate(&values)); 186 | assert_eq!(Element::from(27u8), x_exp_3.evaluate(&values)); 187 | } 188 | 189 | #[test] 190 | #[should_panic] 191 | fn invert_zero() { 192 | let mut builder = GadgetBuilder::::new(); 193 | let x = builder.wire(); 194 | builder.inverse(&Expression::from(x)); 195 | let gadget = builder.build(); 196 | 197 | let mut values = values!(x => 0u8.into()); 198 | gadget.execute(&mut values); 199 | } 200 | 201 | #[test] 202 | fn divides() { 203 | let mut builder = GadgetBuilder::::new(); 204 | let x = builder.wire(); 205 | let y = builder.wire(); 206 | let divides = builder.divides(&Expression::from(x), &Expression::from(y)); 207 | let gadget = builder.build(); 208 | 209 | let mut values_1_1 = values!(x => 1u8.into(), y => 1u8.into()); 210 | assert!(gadget.execute(&mut values_1_1)); 211 | assert_eq_true(&divides, &values_1_1); 212 | 213 | let mut values_3_6 = values!(x => 3u8.into(), y => 6u8.into()); 214 | assert!(gadget.execute(&mut values_3_6)); 215 | assert_eq_true(&divides, &values_3_6); 216 | 217 | let mut values_3_7 = values!(x => 3u8.into(), y => 7u8.into()); 218 | assert!(gadget.execute(&mut values_3_7)); 219 | assert_eq_false(&divides, &values_3_7); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /src/gadget.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use crate::constraint::Constraint; 5 | use crate::field::Field; 6 | use crate::wire_values::WireValues; 7 | use crate::witness_generator::WitnessGenerator; 8 | 9 | /// An R1CS gadget. 10 | pub struct Gadget { 11 | /// The set of rank-1 constraints which define the R1CS instance. 12 | pub constraints: Vec>, 13 | /// The set of generators used to generate a complete witness from inputs. 14 | pub witness_generators: Vec>, 15 | } 16 | 17 | impl Gadget { 18 | /// The number of constraints in this gadget.
19 | pub fn size(&self) -> usize { 20 | self.constraints.len() 21 | } 22 | 23 | /// Execute the gadget, and return whether all constraints were satisfied. 24 | pub fn execute(&self, wire_values: &mut WireValues) -> bool { 25 | let mut pending_generators: Vec<&WitnessGenerator> = self.witness_generators.iter().collect(); 26 | 27 | // TODO: This repeatedly enumerates all generators, whether or not any of their dependencies 28 | // have been generated. A better approach would be to create a map from wires to generators 29 | // which depend on those wires. Then when a wire is assigned a value, we could efficiently 30 | // check for generators which are now ready to run, and place them in a queue. 31 | loop { 32 | let mut made_progress = false; 33 | pending_generators.retain(|generator| { 34 | if wire_values.contains_all(generator.inputs()) { 35 | generator.generate(wire_values); 36 | made_progress = true; 37 | false 38 | } else { 39 | true 40 | } 41 | }); 42 | 43 | if !made_progress { 44 | break; 45 | } 46 | } 47 | 48 | assert_eq!(pending_generators.len(), 0, "Some generators never received inputs"); 49 | 50 | self.constraints.iter().all(|constraint| constraint.evaluate(wire_values)) 51 | } 52 | } 53 | 54 | #[cfg(test)] 55 | mod tests { 56 | use crate::expression::Expression; 57 | use crate::gadget_builder::GadgetBuilder; 58 | use crate::test_util::F257; 59 | use crate::wire_values::WireValues; 60 | 61 | #[test] 62 | fn constraint_not_satisfied() { 63 | let mut builder = GadgetBuilder::::new(); 64 | let (x, y) = (builder.wire(), builder.wire()); 65 | builder.assert_equal(&Expression::from(x), &Expression::from(y)); 66 | let gadget = builder.build(); 67 | 68 | let mut values = values!(x => 42u8.into(), y => 43u8.into()); 69 | let constraints_satisfied = gadget.execute(&mut values); 70 | assert!(!constraints_satisfied); 71 | } 72 | 73 | #[test] 74 | #[should_panic] 75 | fn missing_generator() { 76 | let mut builder = GadgetBuilder::::new(); 77 | let (x, y, z) = 
(builder.wire(), builder.wire(), builder.wire()); 78 | builder.assert_product(&Expression::from(x), &Expression::from(y), &Expression::from(z)); 79 | let gadget = builder.build(); 80 | 81 | let mut values = values!(x => 2u8.into(), y => 3u8.into()); 82 | gadget.execute(&mut values); 83 | } 84 | 85 | #[test] 86 | #[should_panic] 87 | fn missing_input() { 88 | let mut builder = GadgetBuilder::::new(); 89 | let x = builder.wire(); 90 | builder.inverse(&Expression::from(x)); 91 | let gadget = builder.build(); 92 | 93 | let mut values = WireValues::new(); 94 | gadget.execute(&mut values); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/gadget_builder.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use crate::constraint::Constraint; 5 | use crate::expression::{BooleanExpression, Expression}; 6 | use crate::field::{Element, Field}; 7 | use crate::gadget::Gadget; 8 | use crate::wire::{BinaryWire, BooleanWire, Wire}; 9 | use crate::wire_values::WireValues; 10 | use crate::witness_generator::WitnessGenerator; 11 | 12 | pub struct GadgetBuilder { 13 | next_wire_index: u32, 14 | constraints: Vec>, 15 | witness_generators: Vec>, 16 | } 17 | 18 | /// A utility for building `Gadget`s. See the readme for examples. 19 | #[allow(clippy::new_without_default)] 20 | impl GadgetBuilder { 21 | /// Creates a new `GadgetBuilder`, starting with no constraints or generators. 22 | pub fn new() -> Self { 23 | GadgetBuilder { 24 | next_wire_index: 1, 25 | constraints: Vec::new(), 26 | witness_generators: Vec::new(), 27 | } 28 | } 29 | 30 | /// Add a wire to the gadget. It will start with no generator and no associated constraints. 
31 | pub fn wire(&mut self) -> Wire { 32 | let index = self.next_wire_index; 33 | self.next_wire_index += 1; 34 | Wire { index } 35 | } 36 | 37 | /// Add a wire to the gadget, whose value is constrained to equal 0 or 1. 38 | pub fn boolean_wire(&mut self) -> BooleanWire { 39 | let w = self.wire(); 40 | self.assert_boolean(&Expression::from(w)); 41 | BooleanWire::new_unsafe(w) 42 | } 43 | 44 | /// Add `n` wires to the gadget. They will start with no generator and no associated 45 | /// constraints. 46 | pub fn wires(&mut self, n: usize) -> Vec { 47 | (0..n).map(|_i| self.wire()).collect() 48 | } 49 | 50 | /// Add a binary wire comprised of `n` bits to the gadget. 51 | pub fn binary_wire(&mut self, n: usize) -> BinaryWire { 52 | BinaryWire { bits: (0..n).map(|_i| self.boolean_wire()).collect() } 53 | } 54 | 55 | /// Add a generator function for setting certain wire values. 56 | pub fn generator(&mut self, dependencies: Vec, generate: T) 57 | where T: Fn(&mut WireValues) + 'static { 58 | self.witness_generators.push(WitnessGenerator::new(dependencies, generate)); 59 | } 60 | 61 | /// x == y 62 | pub fn equal(&mut self, x: &Expression, y: &Expression) -> BooleanExpression { 63 | self.zero(&(x - y)) 64 | } 65 | 66 | /// x == 0 67 | pub fn zero(&mut self, x: &Expression) -> BooleanExpression { 68 | let nonzero = self.nonzero(x); 69 | self.not(&nonzero) 70 | } 71 | 72 | /// x != 0 73 | pub fn nonzero(&mut self, x: &Expression) -> BooleanExpression { 74 | // See the Pinocchio paper for an explanation. 
75 | let (y, m) = (self.wire(), self.wire()); 76 | let (y_exp, m_exp) = (Expression::from(y), Expression::from(m)); 77 | self.assert_product(x, &m_exp, &y_exp); 78 | self.assert_product(&(Expression::one() - &y_exp), x, &Expression::zero()); 79 | 80 | let x = x.clone(); 81 | self.generator( 82 | x.dependencies(), 83 | move |values: &mut WireValues| { 84 | let x_value = x.evaluate(values); 85 | let y_value = if x_value.is_nonzero() { 86 | Element::one() 87 | } else { 88 | Element::zero() 89 | }; 90 | let m_value: Element = if x_value.is_nonzero() { 91 | &y_value / x_value 92 | } else { 93 | // The value of m doesn't matter if x = 0. 94 | Element::one() 95 | }; 96 | values.set(m, m_value); 97 | values.set(y, y_value); 98 | }, 99 | ); 100 | 101 | // y can only be 0 or 1 based on the constraints above. 102 | BooleanExpression::new_unsafe(y_exp) 103 | } 104 | 105 | /// Returns `x` if `c` is true and `y` otherwise, computed as `y + c * (x - y)`. 106 | pub fn selection( 107 | &mut self, c: &BooleanExpression, x: &Expression, y: &Expression, 108 | ) -> Expression { 109 | y + self.product(c.expression(), &(x - y)) 110 | } 111 | 112 | /// Assert that x * y = z. 113 | pub fn assert_product(&mut self, x: &Expression, y: &Expression, z: &Expression) { 114 | self.constraints.push(Constraint { 115 | a: x.clone(), 116 | b: y.clone(), 117 | c: z.clone(), 118 | }); 119 | } 120 | 121 | /// Assert that the given quantity is either 0 or 1, and return it as a `BooleanExpression`. 122 | pub fn assert_boolean(&mut self, x: &Expression) -> BooleanExpression { 123 | self.assert_product(x, &(x - Expression::one()), &Expression::zero()); 124 | BooleanExpression::new_unsafe(x.clone()) 125 | } 126 | 127 | /// Assert that x == y. 128 | pub fn assert_equal(&mut self, x: &Expression, y: &Expression) { 129 | self.assert_product(x, &Expression::one(), y); 130 | } 131 | 132 | /// Assert that x != y.
133 | pub fn assert_nonequal(&mut self, x: &Expression, y: &Expression) { 134 | let difference = x - y; 135 | self.assert_nonzero(&difference); 136 | } 137 | 138 | /// Assert that x == 0. 139 | pub fn assert_zero(&mut self, x: &Expression) { 140 | self.assert_equal(x, &Expression::zero()); 141 | } 142 | 143 | /// Assert that x != 0. 144 | pub fn assert_nonzero(&mut self, x: &Expression) { 145 | // A field element is non-zero iff it has a multiplicative inverse. 146 | // We don't care what the inverse is, but calling inverse(x) will require that it exists. 147 | self.inverse(x); 148 | } 149 | 150 | /// Assert that x == 1. 151 | pub fn assert_true(&mut self, x: &BooleanExpression) { 152 | self.assert_equal(x.expression(), &Expression::one()); 153 | } 154 | 155 | /// Assert that x == 0. 156 | pub fn assert_false(&mut self, x: &BooleanExpression) { 157 | self.assert_equal(x.expression(), &Expression::zero()); 158 | } 159 | 160 | /// Builds the gadget. 161 | pub fn build(self) -> Gadget { 162 | Gadget { 163 | constraints: self.constraints, 164 | witness_generators: self.witness_generators, 165 | } 166 | } 167 | } 168 | 169 | #[cfg(test)] 170 | mod tests { 171 | use crate::expression::{BooleanExpression, Expression}; 172 | use crate::field::Element; 173 | use crate::gadget_builder::GadgetBuilder; 174 | use crate::test_util::{assert_eq_false, assert_eq_true, F257}; 175 | 176 | #[test] 177 | fn assert_binary_0_1() { 178 | let mut builder = GadgetBuilder::::new(); 179 | let x = builder.wire(); 180 | builder.assert_boolean(&Expression::from(x)); 181 | let gadget = builder.build(); 182 | 183 | // With x = 0, the constraint should be satisfied. 184 | let mut values0 = values!(x => 0u8.into()); 185 | assert!(gadget.execute(&mut values0)); 186 | 187 | // With x = 1, the constraint should be satisfied. 
188 | let mut values1 = values!(x => 1u8.into()); 189 | assert!(gadget.execute(&mut values1)); 190 | } 191 | 192 | #[test] 193 | fn assert_binary_2() { 194 | let mut builder = GadgetBuilder::::new(); 195 | let x = builder.wire(); 196 | builder.assert_boolean(&Expression::from(x)); 197 | let gadget = builder.build(); 198 | 199 | // With x = 2, the constraint should NOT be satisfied. 200 | let mut values2 = values!(x => 2u8.into()); 201 | assert!(!gadget.execute(&mut values2)); 202 | } 203 | 204 | #[test] 205 | fn selection() { 206 | let mut builder = GadgetBuilder::::new(); 207 | let (c, x, y) = (builder.boolean_wire(), builder.wire(), builder.wire()); 208 | let selection = builder.selection( 209 | &BooleanExpression::from(c), &Expression::from(x), &Expression::from(y)); 210 | let gadget = builder.build(); 211 | 212 | let values_3_5 = values!(x => 3u8.into(), y => 5u8.into()); 213 | 214 | let mut values_0_3_5 = values_3_5.clone(); 215 | values_0_3_5.set_boolean(c, false); 216 | assert!(gadget.execute(&mut values_0_3_5)); 217 | assert_eq!(Element::from(5u8), selection.evaluate(&values_0_3_5)); 218 | 219 | let mut values_1_3_5 = values_3_5.clone(); 220 | values_1_3_5.set_boolean(c, true); 221 | assert!(gadget.execute(&mut values_1_3_5)); 222 | assert_eq!(Element::from(3u8), selection.evaluate(&values_1_3_5)); 223 | } 224 | 225 | #[test] 226 | fn equal() { 227 | let mut builder = GadgetBuilder::::new(); 228 | let (x, y) = (builder.wire(), builder.wire()); 229 | let equal = builder.equal(&Expression::from(x), &Expression::from(y)); 230 | let gadget = builder.build(); 231 | 232 | let mut values_7_7 = values!(x => 7u8.into(), y => 7u8.into()); 233 | assert!(gadget.execute(&mut values_7_7)); 234 | assert_eq_true(&equal, &values_7_7); 235 | 236 | let mut values_6_7 = values!(x => 6u8.into(), y => 7u8.into()); 237 | assert!(gadget.execute(&mut values_6_7)); 238 | assert_eq_false(&equal, &values_6_7); 239 | 240 | let mut values_7_13 = values!(x => 7u8.into(), y => 
13u8.into()); 241 | assert!(gadget.execute(&mut values_7_13)); 242 | assert_eq_false(&equal, &values_7_13); 243 | } 244 | } 245 | -------------------------------------------------------------------------------- /src/gadget_traits.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use itertools::Itertools; 5 | 6 | use crate::{Element, Expression, Field, GadgetBuilder, WireValues}; 7 | 8 | /// A symmetric-key block cipher. 9 | pub trait BlockCipher { 10 | /// Encrypt the given input using the given key. 11 | fn encrypt(&self, builder: &mut GadgetBuilder, key: &Expression, input: &Expression) 12 | -> Expression; 13 | 14 | /// Decrypt the given output using the given key. 15 | fn decrypt(&self, builder: &mut GadgetBuilder, key: &Expression, output: &Expression) 16 | -> Expression; 17 | 18 | /// Like `encrypt`, but actually evaluates the encryption function rather than just adding it 19 | /// to a `GadgetBuilder`. 20 | fn encrypt_evaluate(&self, key: &Element, input: &Element) -> Element { 21 | let mut builder = GadgetBuilder::new(); 22 | let encrypted = self.encrypt( 23 | &mut builder, &Expression::from(key), &Expression::from(input)); 24 | let mut values = WireValues::new(); 25 | builder.build().execute(&mut values); 26 | encrypted.evaluate(&values) 27 | } 28 | 29 | /// Like `decrypt`, but actually evaluates the decryption function rather than just adding it 30 | /// to a `GadgetBuilder`. 31 | fn decrypt_evaluate(&self, key: &Element, output: &Element) -> Element { 32 | let mut builder = GadgetBuilder::new(); 33 | let decrypted = self.decrypt( 34 | &mut builder, &Expression::from(key), &Expression::from(output)); 35 | let mut values = WireValues::new(); 36 | builder.build().execute(&mut values); 37 | decrypted.evaluate(&values) 38 | } 39 | } 40 | 41 | /// A function which compresses two field elements into one, and is intended to be one-way. 
42 | pub trait CompressionFunction { 43 | /// Compress two field elements into one. 44 | fn compress(&self, builder: &mut GadgetBuilder, x: &Expression, y: &Expression) 45 | -> Expression; 46 | 47 | /// Like `compress`, but actually evaluates the compression function rather than just adding it 48 | /// to a `GadgetBuilder`. 49 | fn compress_evaluate(&self, x: &Element, y: &Element) -> Element { 50 | let mut builder = GadgetBuilder::new(); 51 | let compressed = self.compress(&mut builder, &Expression::from(x), &Expression::from(y)); 52 | let mut values = WireValues::new(); 53 | builder.build().execute(&mut values); 54 | compressed.evaluate(&values) 55 | } 56 | } 57 | 58 | /// A permutation of single field elements. 59 | pub trait Permutation { 60 | /// Permute the given field element. 61 | fn permute(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression; 62 | 63 | /// Like `permute`, but actually evaluates the permutation rather than just adding it to a 64 | /// `GadgetBuilder`. 65 | fn permute_evaluate(&self, x: &Element) -> Element { 66 | let mut builder = GadgetBuilder::new(); 67 | let permuted = self.permute(&mut builder, &Expression::from(x)); 68 | let mut values = WireValues::new(); 69 | builder.build().execute(&mut values); 70 | permuted.evaluate(&values) 71 | } 72 | 73 | /// Apply the inverse of this permutation to the given field element. 74 | fn inverse(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression; 75 | 76 | /// Like `inverse`, but actually evaluates the inverse permutation rather than just adding it to 77 | /// a `GadgetBuilder`. 78 | fn inverse_evaluate(&self, x: &Element) -> Element { 79 | let mut builder = GadgetBuilder::new(); 80 | let inverse = self.inverse(&mut builder, &Expression::from(x)); 81 | let mut values = WireValues::new(); 82 | builder.build().execute(&mut values); 83 | inverse.evaluate(&values) 84 | } 85 | } 86 | 87 | /// A permutation whose inputs and outputs consist of multiple field elements. 
88 | pub trait MultiPermutation { 89 | /// The size of the permutation, in field elements. 90 | fn width(&self) -> usize; 91 | 92 | /// Permute the given sequence of field elements. 93 | fn permute(&self, builder: &mut GadgetBuilder, inputs: &[Expression]) 94 | -> Vec>; 95 | 96 | /// Like `permute`, but actually evaluates the permutation rather than just adding it to a 97 | /// `GadgetBuilder`. 98 | fn permute_evaluate(&self, inputs: &[Element]) -> Vec> { 99 | let mut builder = GadgetBuilder::new(); 100 | let input_expressions = inputs.iter().map(Expression::from).collect_vec(); 101 | let permuted = self.permute(&mut builder, &input_expressions); 102 | let mut values = WireValues::new(); 103 | builder.build().execute(&mut values); 104 | permuted.iter().map(|exp| exp.evaluate(&values)).collect() 105 | } 106 | 107 | /// Apply the inverse of this permutation to the given sequence of field elements. 108 | fn inverse(&self, builder: &mut GadgetBuilder, outputs: &[Expression]) 109 | -> Vec>; 110 | 111 | /// Like `inverse`, but actually evaluates the inverse permutation rather than just adding it to 112 | /// a `GadgetBuilder`. 113 | fn inverse_evaluate(&self, outputs: &[Element]) -> Vec> { 114 | let mut builder = GadgetBuilder::new(); 115 | let output_expressions = outputs.iter().map(Expression::from).collect_vec(); 116 | let inversed = self.inverse(&mut builder, &output_expressions); 117 | let mut values = WireValues::new(); 118 | builder.build().execute(&mut values); 119 | inversed.iter().map(|exp| exp.evaluate(&values)).collect() 120 | } 121 | } 122 | 123 | /// A function which hashes a sequence of field elements, outputting a single field element. 124 | pub trait HashFunction { 125 | fn hash(&self, builder: &mut GadgetBuilder, blocks: &[Expression]) -> Expression; 126 | 127 | /// Like `hash`, but actually evaluates the hash function rather than just adding it to a 128 | /// `GadgetBuilder`. 
129 | fn hash_evaluate(&self, blocks: &[Element]) -> Element { 130 | let mut builder = GadgetBuilder::new(); 131 | let block_expressions = blocks.iter().map(Expression::from).collect_vec(); 132 | let hash = self.hash(&mut builder, &block_expressions); 133 | let mut values = WireValues::new(); 134 | builder.build().execute(&mut values); 135 | hash.evaluate(&values) 136 | } 137 | } -------------------------------------------------------------------------------- /src/group.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | use std::marker::PhantomData; 4 | 5 | use crate::{BooleanExpression, Element, Evaluable, Expression, Field, GadgetBuilder, WireValues}; 6 | 7 | pub trait Group where Self::GroupExpression: for<'a> From<&'a Self::GroupElement>, 8 | Self::GroupExpression: Evaluable, 9 | Self::GroupExpression: GroupExpression, 10 | Self::GroupExpression: Clone { 11 | type GroupElement; 12 | type GroupExpression; 13 | 14 | fn identity_element() -> Self::GroupElement; 15 | 16 | fn identity_expression() -> Self::GroupExpression { 17 | Self::GroupExpression::from(&Self::identity_element()) 18 | } 19 | 20 | fn add_expressions( 21 | builder: &mut GadgetBuilder, 22 | lhs: &Self::GroupExpression, 23 | rhs: &Self::GroupExpression, 24 | ) -> Self::GroupExpression; 25 | 26 | fn add_elements( 27 | lhs: &Self::GroupElement, 28 | rhs: &Self::GroupElement, 29 | ) -> Self::GroupElement { 30 | let lhs_exp = Self::GroupExpression::from(lhs); 31 | let rhs_exp = Self::GroupExpression::from(rhs); 32 | 33 | let mut builder = GadgetBuilder::new(); 34 | let sum = Self::add_expressions(&mut builder, &lhs_exp, &rhs_exp); 35 | let mut values = WireValues::new(); 36 | builder.build().execute(&mut values); 37 | sum.evaluate(&values) 38 | } 39 | 40 | fn double_expression( 41 | builder: &mut GadgetBuilder, 42 | expression: &Self::GroupExpression, 43 | ) -> Self::GroupExpression { 44 | 
Self::add_expressions(builder, expression, expression) 45 | } 46 | 47 | fn double_element(element: &Self::GroupElement) -> Self::GroupElement { 48 | Self::add_elements(element, element) 49 | } 50 | 51 | /// Performs scalar multiplication in constraints by first splitting up a scalar into 52 | /// a binary representation, and then performing the naive double-and-add algorithm. This 53 | /// implementation is generic for all groups. 54 | fn mul_scalar_expression( 55 | builder: &mut GadgetBuilder, 56 | expression: &Self::GroupExpression, 57 | scalar: &Expression, 58 | ) -> Self::GroupExpression { 59 | let scalar_binary = builder.split_allowing_ambiguity(scalar); 60 | 61 | let mut sum = Self::identity_expression(); 62 | let mut current = expression.clone(); 63 | for bit in scalar_binary.bits { 64 | let boolean_product = Self::mul_boolean_expression(builder, &current, &bit); 65 | sum = Self::add_expressions(builder, &sum, &boolean_product); 66 | current = Self::double_expression(builder, &current); 67 | } 68 | sum 69 | } 70 | 71 | /// Like `mul_scalar_expression`, but actually evaluates the scalar multiplication rather than 72 | /// just adding it to a `GadgetBuilder`. 73 | fn mul_scalar_element( 74 | element: &Self::GroupElement, 75 | scalar: &Element, 76 | ) -> Self::GroupElement { 77 | let mut builder = GadgetBuilder::new(); 78 | let new_point = Self::mul_scalar_expression( 79 | &mut builder, 80 | &Self::GroupExpression::from(element), 81 | &Expression::from(scalar), 82 | ); 83 | let mut values = WireValues::new(); 84 | builder.build().execute(&mut values); 85 | new_point.evaluate(&values) 86 | } 87 | 88 | /// Given a boolean expression, return the given group expression if the boolean is true, 89 | /// otherwise return the identity.
90 | fn mul_boolean_expression( 91 | builder: &mut GadgetBuilder, 92 | expression: &Self::GroupExpression, 93 | boolean: &BooleanExpression, 94 | ) -> Self::GroupExpression { 95 | let coordinates = expression.to_components(); 96 | 97 | let mut r = Vec::new(); 98 | let ic = Self::identity_expression().to_components(); 99 | 100 | for (i, x) in coordinates.iter().enumerate() { 101 | r.push(builder.selection(boolean, &x, &ic[i])); 102 | } 103 | 104 | Self::GroupExpression::from_components_unsafe(r) 105 | } 106 | } 107 | 108 | /// A trait that defines a generator `g` for a cyclic group in which every element 109 | /// is defined as `g^a` for some scalar `a`. 110 | pub trait CyclicGroup: Group { 111 | fn generator_element() -> Self::GroupElement; 112 | 113 | fn generator_expression() -> Self::GroupExpression { 114 | Self::GroupExpression::from(&Self::generator_element()) 115 | } 116 | } 117 | 118 | pub trait CyclicGenerator> { 119 | fn generator_element() -> G::GroupElement; 120 | 121 | fn generator_expression() -> G::GroupExpression { 122 | G::GroupExpression::from(&Self::generator_element()) 123 | } 124 | } 125 | 126 | pub struct CyclicSubgroup, C: CyclicGenerator> { 127 | phantom_f: PhantomData<*const F>, 128 | phantom_g: PhantomData<*const G>, 129 | phantom_c: PhantomData<*const C>, 130 | } 131 | 132 | impl, C: CyclicGenerator> Group for CyclicSubgroup { 133 | type GroupElement = G::GroupElement; 134 | type GroupExpression = G::GroupExpression; 135 | 136 | fn identity_element() -> Self::GroupElement { 137 | G::identity_element() 138 | } 139 | 140 | fn add_expressions( 141 | builder: &mut GadgetBuilder, 142 | lhs: &Self::GroupExpression, 143 | rhs: &Self::GroupExpression 144 | ) -> Self::GroupExpression { 145 | G::add_expressions(builder, lhs, rhs) 146 | } 147 | 148 | fn double_expression( 149 | builder: &mut GadgetBuilder, 150 | expression: &Self::GroupExpression 151 | ) -> Self::GroupExpression { 152 | G::double_expression(builder, expression) 153 | } 154 | } 155 | 
156 | impl, C: CyclicGenerator> CyclicGroup for CyclicSubgroup { 157 | fn generator_element() -> Self::GroupElement { 158 | C::generator_element() 159 | } 160 | 161 | fn generator_expression() -> Self::GroupExpression { 162 | C::generator_expression() 163 | } 164 | } 165 | 166 | /// Applies a (not necessarily injective) map, defined from a group to the field, 167 | /// to an expression corresponding to an element in the group. 168 | pub trait GroupExpression { 169 | fn compressed(&self) -> &Expression; 170 | fn to_components(&self) -> Vec>; 171 | fn from_components_unsafe(components: Vec>) -> Self; 172 | } -------------------------------------------------------------------------------- /src/lcg.rs: -------------------------------------------------------------------------------- 1 | //! This module provides a linear congruential generator for (not cryptographically secure) random 2 | //! data. 3 | 4 | #[cfg(not(feature = "std"))] 5 | use alloc::vec::Vec; 6 | 7 | use num::BigUint; 8 | use num_traits::One; 9 | 10 | use crate::field::{Element, Field}; 11 | 12 | /// A simple linear congruential generator, with parameters taken from Numerical Recipes.
13 | #[derive(Default)] 14 | pub struct LCG { 15 | state: u32 16 | } 17 | 18 | impl LCG { 19 | pub fn new() -> Self { 20 | LCG { state: 0 } 21 | } 22 | 23 | pub fn next_u32(&mut self) -> u32 { 24 | self.state = self.state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223); 25 | self.state 26 | } 27 | 28 | pub fn next_element(&mut self) -> Element { 29 | Element::from(self.next_biguint(F::order())) 30 | } 31 | 32 | pub fn next_biguint(&mut self, limit_exclusive: BigUint) -> BigUint { 33 | let bits = (&limit_exclusive - BigUint::one()).bits() as usize; 34 | loop { 35 | let n = self.next_biguint_bits(bits); 36 | if n < limit_exclusive { 37 | return n; 38 | } 39 | } 40 | } 41 | 42 | fn next_biguint_bits(&mut self, bits: usize) -> BigUint { 43 | let full_chunks = bits / 32; 44 | let remaining_bits = bits % 32; 45 | let partial_chunk = remaining_bits > 0; 46 | 47 | let mut chunk_data = Vec::new(); 48 | for _i in 0..full_chunks { 49 | chunk_data.push(self.next_u32()); 50 | } 51 | if partial_chunk { 52 | chunk_data.push(self.next_u32() % (1 << remaining_bits)) 53 | } 54 | BigUint::new(chunk_data) 55 | } 56 | } 57 | 58 | #[cfg(test)] 59 | mod tests { 60 | use crate::lcg::LCG; 61 | 62 | #[test] 63 | fn next_u32() { 64 | let mut lcg = LCG::new(); 65 | assert_eq!(lcg.next_u32(), 1013904223); 66 | assert_eq!(lcg.next_u32(), 1196435762); 67 | assert_eq!(lcg.next_u32(), 3519870697); 68 | assert_eq!(lcg.next_u32(), 2868466484); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // TODO: Copy some examples etc. here when the API is more stable. 2 | 3 | //! This is a rust library for building R1CS gadgets over prime fields, which are useful in SNARKs 4 | //! and other argument systems. 5 | //! 6 | //! See the [readme](https://github.com/mir-protocol/r1cs) for more information and examples. 
7 | 8 | #![cfg_attr(not(feature = "std"), no_std)] 9 | 10 | #[cfg(feature = "std")] 11 | #[macro_use] 12 | extern crate std; 13 | 14 | #[cfg(not(feature = "std"))] 15 | #[macro_use] 16 | extern crate core as std; 17 | 18 | #[cfg(not(feature = "std"))] 19 | #[macro_use] 20 | extern crate alloc; 21 | 22 | pub use num; 23 | 24 | pub use constraint::*; 25 | pub use curves::*; 26 | pub use davies_meyer::*; 27 | pub use expression::*; 28 | pub use field::*; 29 | pub use gadget::*; 30 | pub use gadget_builder::*; 31 | pub use gadget_traits::*; 32 | pub use group::*; 33 | pub use lcg::*; 34 | pub use matrices::*; 35 | pub use merkle_damgard::*; 36 | pub use merkle_trees::*; 37 | pub use mimc::*; 38 | pub use miyaguchi_preneel::*; 39 | pub use permutations::*; 40 | pub use poseidon::*; 41 | pub use rescue::*; 42 | pub use sponge::*; 43 | pub use wire::*; 44 | pub use wire_values::*; 45 | pub use witness_generator::*; 46 | 47 | #[macro_use] 48 | mod wire_values; 49 | 50 | mod bimap_util; 51 | mod binary_arithmetic; 52 | mod bitwise_operations; 53 | mod boolean_algebra; 54 | mod comparisons; 55 | mod constraint; 56 | mod curves; 57 | mod davies_meyer; 58 | mod expression; 59 | mod field; 60 | mod field_arithmetic; 61 | mod gadget; 62 | mod gadget_builder; 63 | mod gadget_traits; 64 | mod group; 65 | mod lcg; 66 | mod matrices; 67 | mod merkle_damgard; 68 | mod merkle_trees; 69 | mod mimc; 70 | mod miyaguchi_preneel; 71 | mod permutations; 72 | mod poseidon; 73 | mod random_access; 74 | mod rescue; 75 | mod signature; 76 | mod sorting; 77 | mod splitting; 78 | mod sponge; 79 | mod util; 80 | mod verify_permutation; 81 | mod wire; 82 | mod witness_generator; 83 | 84 | #[cfg(test)] 85 | mod test_util; 86 | -------------------------------------------------------------------------------- /src/matrices.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use std::ops::Mul; 5 | 6 | use 
crate::{Element, Expression, Field}; 7 | 8 | /// A matrix of prime field elements. 9 | pub struct ElementMatrix { 10 | rows: Vec>>, 11 | } 12 | 13 | impl ElementMatrix { 14 | pub fn new(rows: Vec>>) -> Self { 15 | assert!(!rows.is_empty(), "Expected at least one row"); 16 | let num_cols = rows[0].len(); 17 | assert!(num_cols > 0, "Expected at least one column"); 18 | for row in rows.iter() { 19 | assert_eq!(row.len(), num_cols, "Rows must have uniform length"); 20 | } 21 | ElementMatrix { rows } 22 | } 23 | } 24 | 25 | impl Clone for ElementMatrix { 26 | fn clone(&self) -> Self { 27 | ElementMatrix { rows: self.rows.clone() } 28 | } 29 | } 30 | 31 | impl Mul<&[Element]> for &ElementMatrix { 32 | type Output = Vec>; 33 | 34 | fn mul(self, rhs: &[Element]) -> Self::Output { 35 | self.rows.iter() 36 | .map(|row| row.iter().zip(rhs.iter()).fold( 37 | Element::zero(), |sum, (row_i, val)| sum + val * row_i)) 38 | .collect() 39 | } 40 | } 41 | 42 | impl Mul<&[Expression]> for &ElementMatrix { 43 | type Output = Vec>; 44 | 45 | fn mul(self, rhs: &[Expression]) -> Self::Output { 46 | self.rows.iter() 47 | .map(|row| row.iter().zip(rhs.iter()).fold( 48 | Expression::zero(), |sum, (row_i, val)| sum + val * row_i)) 49 | .collect() 50 | } 51 | } 52 | 53 | impl Mul<&[Element]> for ElementMatrix { 54 | type Output = Vec>; 55 | 56 | fn mul(self, rhs: &[Element]) -> Self::Output { 57 | &self * rhs 58 | } 59 | } 60 | 61 | impl Mul<&[Expression]> for ElementMatrix { 62 | type Output = Vec>; 63 | 64 | fn mul(self, rhs: &[Expression]) -> Self::Output { 65 | &self * rhs 66 | } 67 | } 68 | 69 | /// A Maximum Distance Separable matrix. 70 | pub struct MdsMatrix { 71 | matrix: ElementMatrix, 72 | } 73 | 74 | impl MdsMatrix { 75 | pub fn new(rows: Vec>>) -> Self { 76 | // TODO: Verify the MDS diffusion property.
77 | MdsMatrix { matrix: ElementMatrix::new(rows) } 78 | } 79 | 80 | pub fn inverse(&self) -> Self { 81 | unimplemented!("TODO: Implement inverse") 82 | } 83 | } 84 | 85 | impl Clone for MdsMatrix { 86 | fn clone(&self) -> Self { 87 | MdsMatrix { matrix: self.matrix.clone() } 88 | } 89 | } 90 | 91 | impl Mul<&[Element]> for &MdsMatrix { 92 | type Output = Vec>; 93 | 94 | fn mul(self, rhs: &[Element]) -> Self::Output { 95 | &self.matrix * rhs 96 | } 97 | } 98 | 99 | impl Mul<&[Expression]> for &MdsMatrix { 100 | type Output = Vec>; 101 | 102 | fn mul(self, rhs: &[Expression]) -> Self::Output { 103 | &self.matrix * rhs 104 | } 105 | } 106 | 107 | impl Mul<&[Element]> for MdsMatrix { 108 | type Output = Vec>; 109 | 110 | fn mul(self, rhs: &[Element]) -> Self::Output { 111 | self.matrix * rhs 112 | } 113 | } 114 | 115 | impl Mul<&[Expression]> for MdsMatrix { 116 | type Output = Vec>; 117 | 118 | fn mul(self, rhs: &[Expression]) -> Self::Output { 119 | self.matrix * rhs 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod tests { 125 | #[test] 126 | fn matrix_vector_multiplication() { 127 | // TODO 128 | } 129 | } -------------------------------------------------------------------------------- /src/merkle_damgard.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with an implementation of the Merkle-Damgård construction. 2 | 3 | use crate::{CompressionFunction, HashFunction}; 4 | use crate::expression::Expression; 5 | use crate::field::{Element, Field}; 6 | use crate::gadget_builder::GadgetBuilder; 7 | use crate::lcg::LCG; 8 | 9 | /// A hash function based on the Merkle–Damgård construction. 10 | pub struct MerkleDamgard> { 11 | initial_value: Element, 12 | compress: CF, 13 | } 14 | 15 | impl> MerkleDamgard { 16 | /// Creates a Merkle–Damgård hash function from the given initial value and one-way compression 17 | /// function. 
18 | pub fn new(initial_value: Element, compress: CF) -> Self { 19 | MerkleDamgard { initial_value, compress } 20 | } 21 | 22 | /// Creates a Merkle–Damgård hash function from the given one-way compression function. Uses a 23 | /// simple LCG (seeded with 0) as a source of randomness for the initial value. 24 | pub fn new_defaults(compress: CF) -> Self { 25 | let initial_value = LCG::new().next_element(); 26 | Self::new(initial_value, compress) 27 | } 28 | } 29 | 30 | impl> HashFunction for MerkleDamgard { 31 | fn hash(&self, builder: &mut GadgetBuilder, blocks: &[Expression]) -> Expression { 32 | let mut current = Expression::from(&self.initial_value); 33 | for block in blocks { 34 | current = self.compress.compress(builder, &current, block); 35 | } 36 | 37 | // Length padding: compress the number of blocks into the final state. 38 | self.compress.compress(builder, &current, &Expression::from(blocks.len())) 39 | } 40 | } 41 | 42 | #[cfg(test)] 43 | mod tests { 44 | use crate::{CompressionFunction, HashFunction, MerkleDamgard}; 45 | use crate::expression::Expression; 46 | use crate::field::{Element, Field}; 47 | use crate::gadget_builder::GadgetBuilder; 48 | use crate::test_util::F7; 49 | 50 | #[test] 51 | fn merkle_damgard() { 52 | // We will use a trivial compression function to keep the test simple.
53 | struct TestCompress; 54 | 55 | impl CompressionFunction for TestCompress { 56 | fn compress( 57 | &self, _builder: &mut GadgetBuilder, x: &Expression, y: &Expression, 58 | ) -> Expression { 59 | x * 2 + y * 3 60 | } 61 | } 62 | 63 | let mut builder = GadgetBuilder::::new(); 64 | let x_wire = builder.wire(); 65 | let y_wire = builder.wire(); 66 | let x = Expression::from(x_wire); 67 | let y = Expression::from(y_wire); 68 | let blocks = &[x, y]; 69 | let md = MerkleDamgard::new(Element::from(2u8), TestCompress); 70 | let hash = md.hash(&mut builder, blocks); 71 | let gadget = builder.build(); 72 | 73 | let mut values = values!(x_wire => 3u8.into(), y_wire => 4u8.into()); 74 | assert!(gadget.execute(&mut values)); 75 | // initial value: 2 76 | // after 3: 2*2 + 3*3 = 6 77 | // after 4: 6*2 + 4*3 = 3 78 | // after 2 (length): 3*2 + 2*3 = 5 79 | assert_eq!(Element::from(5u8), hash.evaluate(&values)); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/merkle_trees.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use crate::expression::{BinaryExpression, BooleanExpression, Expression}; 5 | use crate::field::Field; 6 | use crate::gadget_builder::GadgetBuilder; 7 | use crate::gadget_traits::CompressionFunction; 8 | 9 | /// The path from a leaf to the root of a binary Merkle tree. 10 | #[derive(Debug)] 11 | pub struct MerklePath { 12 | /// The sequence of "turns" when traversing up the tree. The value of each bit indicates the 13 | /// index of the target node relative to its parent. For example, a zero bit indicates that the 14 | /// target node is the left child, and its sibling is the right child. 15 | prefix: BinaryExpression, 16 | /// The sequence of (hashes of) sibling nodes which are encountered along the path up the tree. 
17 | siblings: Vec>, 18 | } 19 | 20 | impl MerklePath { 21 | pub fn new(prefix: BinaryExpression, siblings: Vec>) -> Self { 22 | assert_eq!(prefix.len(), siblings.len()); 23 | MerklePath { prefix, siblings } 24 | } 25 | } 26 | 27 | impl Clone for MerklePath { 28 | fn clone(&self) -> Self { 29 | MerklePath { 30 | prefix: self.prefix.clone(), 31 | siblings: self.siblings.clone(), 32 | } 33 | } 34 | } 35 | 36 | impl GadgetBuilder { 37 | /// Update an intermediate hash value in a Merkle tree, given the sibling at the current layer. 38 | fn merkle_tree_step( 39 | &mut self, 40 | node: &Expression, 41 | sibling: &Expression, 42 | prefix_bit: &BooleanExpression, 43 | compress: &CF, 44 | ) -> Expression where CF: CompressionFunction { 45 | let left = self.selection(prefix_bit, sibling, node); 46 | let right = sibling + node - &left; 47 | compress.compress(self, &left, &right) 48 | } 49 | 50 | /// Compute a Merkle root given a leaf value and its Merkle path. 51 | pub fn merkle_tree_root( 52 | &mut self, 53 | leaf: &Expression, 54 | path: &MerklePath, 55 | compress: &CF, 56 | ) -> Expression where CF: CompressionFunction { 57 | let mut current = leaf.clone(); 58 | for (prefix_bit, sibling) in path.prefix.bits.iter().zip(path.siblings.iter()) { 59 | current = self.merkle_tree_step( 60 | &current, sibling, prefix_bit, compress); 61 | } 62 | current 63 | } 64 | 65 | pub fn assert_merkle_tree_membership( 66 | &mut self, 67 | leaf: &Expression, 68 | purported_root: &Expression, 69 | path: &MerklePath, 70 | compress: &CF, 71 | ) where CF: CompressionFunction { 72 | let computed_root = self.merkle_tree_root(leaf, path, compress); 73 | self.assert_equal(purported_root, &computed_root) 74 | } 75 | } 76 | 77 | #[cfg(test)] 78 | mod tests { 79 | use num::BigUint; 80 | 81 | use crate::expression::{BinaryExpression, BooleanExpression, Expression}; 82 | use crate::field::{Element, Field}; 83 | use crate::gadget_builder::GadgetBuilder; 84 | use crate::gadget_traits::CompressionFunction; 85 | use
crate::merkle_trees::MerklePath; 86 | use crate::test_util::{F257, F7}; 87 | 88 | #[test] 89 | fn merkle_step() { 90 | let mut builder = GadgetBuilder::::new(); 91 | let node = builder.wire(); 92 | let sibling = builder.wire(); 93 | let is_right = builder.boolean_wire(); 94 | let parent_hash = builder.merkle_tree_step( 95 | &Expression::from(node), &Expression::from(sibling), 96 | &BooleanExpression::from(is_right), &TestCompress); 97 | let gadget = builder.build(); 98 | 99 | let mut values_3_4 = values!(node => 3u8.into(), sibling => 4u8.into()); 100 | values_3_4.set_boolean(is_right, false); 101 | assert!(gadget.execute(&mut values_3_4)); 102 | assert_eq!(Element::from(10u8), parent_hash.evaluate(&values_3_4)); 103 | 104 | let mut values_4_3 = values!(node => 3u8.into(), sibling => 4u8.into()); 105 | values_4_3.set_boolean(is_right, true); 106 | assert!(gadget.execute(&mut values_4_3)); 107 | assert_eq!(Element::from(11u8), parent_hash.evaluate(&values_4_3)); 108 | } 109 | 110 | #[test] 111 | fn merkle_root() { 112 | let mut builder = GadgetBuilder::::new(); 113 | let prefix_wire = builder.binary_wire(3); 114 | let (sibling_1, sibling_2, sibling_3) = (builder.wire(), builder.wire(), builder.wire()); 115 | let path = MerklePath::new( 116 | BinaryExpression::from(&prefix_wire), 117 | vec![sibling_1.into(), sibling_2.into(), sibling_3.into()]); 118 | let root_hash = builder.merkle_tree_root(&Expression::one(), &path, &TestCompress); 119 | let gadget = builder.build(); 120 | 121 | let mut values = values!( 122 | sibling_1 => 3u8.into(), 123 | sibling_2 => 3u8.into(), 124 | sibling_3 => 9u8.into()); 125 | values.set_binary_unsigned(&prefix_wire, &BigUint::from(0b010u8)); 126 | assert!(gadget.execute(&mut values)); 127 | // The leaf is 1; the first parent hash is 2*1 + 3 = 5; the next parent hash is 128 | // 2*3 + 5 = 11; the root is 2*11 + 9 = 31. 
129 | assert_eq!(Element::from(31u8), root_hash.evaluate(&values)); 130 | } 131 | 132 | // Tests that longer (depth-8) paths, such as those used in sparse Merkle trees, are supported. 133 | #[test] 134 | fn large_merkle_root() { 135 | let mut builder = GadgetBuilder::::new(); 136 | let prefix_wire = builder.binary_wire(8); 137 | let (sibling_1, sibling_2, sibling_3, sibling_4, sibling_5, sibling_6, sibling_7, sibling_8) 138 | = (builder.wire(), builder.wire(), builder.wire(), builder.wire(), builder.wire(), builder.wire(), builder.wire(), builder.wire()); 139 | let path = MerklePath::new( 140 | BinaryExpression::from(&prefix_wire), 141 | vec![sibling_1.into(), sibling_2.into(), sibling_3.into(), sibling_4.into(), sibling_5.into(), sibling_6.into(), sibling_7.into(), sibling_8.into()]); 142 | let root_hash = builder.merkle_tree_root(&Expression::one(), &path, &TestCompress); 143 | let gadget = builder.build(); 144 | 145 | let mut values = values!( 146 | sibling_1 => 1u8.into(), 147 | sibling_2 => 1u8.into(), 148 | sibling_3 => 1u8.into(), 149 | sibling_4 => 1u8.into(), 150 | sibling_5 => 1u8.into(), 151 | sibling_6 => 1u8.into(), 152 | sibling_7 => 1u8.into(), 153 | sibling_8 => 1u8.into() 154 | ); 155 | values.set_binary_unsigned(&prefix_wire, &BigUint::from(0b00000000u8)); 156 | assert!(gadget.execute(&mut values)); 157 | // The leaf is 1; each step computes 2*x + 1 (mod 7), giving 3, 0, 1, 3, 0, 1, 3 after seven steps; 158 | // the root is 2*3 + 1 = 7 = 0 (mod 7). 159 | assert_eq!( 160 | Element::from(0u8), 161 | root_hash.evaluate(&values)); 162 | } 163 | 164 | // A dummy compression function which returns 2x + y.
165 | struct TestCompress; 166 | 167 | impl CompressionFunction for TestCompress { 168 | fn compress(&self, _builder: &mut GadgetBuilder, x: &Expression, y: &Expression) 169 | -> Expression { 170 | x * 2 + y 171 | } 172 | } 173 | } -------------------------------------------------------------------------------- /src/mimc.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with an implementation of MiMC. 2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | use crate::expression::Expression; 7 | use crate::field::{Element, Field}; 8 | use crate::gadget_builder::GadgetBuilder; 9 | use crate::gadget_traits::{BlockCipher, Permutation}; 10 | use crate::lcg::LCG; 11 | use crate::MonomialPermutation; 12 | 13 | /// The MiMC block cipher. 14 | pub struct MiMCBlockCipher { 15 | round_constants: Vec>, 16 | round_permutation: MonomialPermutation, 17 | } 18 | 19 | impl MiMCBlockCipher { 20 | /// Creates an instance of the MiMC block cipher with the given round constants, which should be 21 | /// generated randomly. 22 | /// 23 | /// The number of rounds will be `round_constants.len() + 1`, since the first round has no 24 | /// random constant. 25 | fn new(round_constants: &[Element]) -> Self { 26 | let round_permutation = MonomialPermutation::new(Element::from(3u8)); 27 | let round_constants = round_constants.to_vec(); 28 | MiMCBlockCipher { round_permutation, round_constants } 29 | } 30 | } 31 | 32 | impl Default for MiMCBlockCipher { 33 | /// Configures MiMC with the number of rounds recommended in the paper. Uses a simple LCG 34 | /// (seeded with 0) as the source of randomness for these constants. 
35 | fn default() -> Self { 36 | let mut round_constants = Vec::new(); 37 | let mut lcg = LCG::new(); 38 | for _r in 0..mimc_recommended_rounds::() { 39 | round_constants.push(lcg.next_element()); 40 | } 41 | MiMCBlockCipher::new(&round_constants) 42 | } 43 | } 44 | 45 | impl BlockCipher for MiMCBlockCipher { 46 | fn encrypt(&self, builder: &mut GadgetBuilder, key: &Expression, input: &Expression) 47 | -> Expression { 48 | let mut current = input.clone(); 49 | 50 | // In the first round, there is no round constant, so just add the key. 51 | current += key; 52 | 53 | // Cube the current value. 54 | current = self.round_permutation.permute(builder, &current); 55 | 56 | for round_constant in self.round_constants.iter() { 57 | // Add the key and the random round constant. 58 | current += key + Expression::from(round_constant); 59 | 60 | // Cube the current value. 61 | current = self.round_permutation.permute(builder, &current); 62 | } 63 | 64 | // Final key addition, as per the spec. 65 | current + key 66 | } 67 | 68 | fn decrypt(&self, builder: &mut GadgetBuilder, key: &Expression, output: &Expression) 69 | -> Expression { 70 | let mut current = output.clone(); 71 | 72 | // Undo the final key addition. 73 | current -= key; 74 | 75 | for round_constant in self.round_constants.iter().rev() { 76 | // Undo the cubing permutation. 77 | current = self.round_permutation.inverse(builder, &current); 78 | 79 | // Undo the key and random round constant additions. 80 | current -= key + Expression::from(round_constant); 81 | } 82 | 83 | // Undo the first round cubing and key addition. (There is no constant in the first round.) 84 | current = self.round_permutation.inverse(builder, &current); 85 | current - key 86 | } 87 | } 88 | 89 | /// The MiMC permutation, which is equivalent to MiMC encryption with a key of zero. 90 | // TODO: Consider merging the two structs. 91 | // TODO: Implement Default.
92 | pub struct MiMCPermutation { 93 | cipher: MiMCBlockCipher 94 | } 95 | 96 | impl Permutation for MiMCPermutation { 97 | fn permute(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 98 | // As per the paper, we just use a key of zero. 99 | self.cipher.encrypt(builder, &Expression::zero(), x) 100 | } 101 | 102 | fn inverse(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 103 | self.cipher.decrypt(builder, &Expression::zero(), x) 104 | } 105 | } 106 | 107 | /// The recommended number of rounds to use in MiMC, based on the paper. 108 | fn mimc_recommended_rounds() -> usize { 109 | let n = Element::::max_bits(); 110 | (n as f64 / 3f64.log2()).ceil() as usize 111 | } 112 | 113 | #[cfg(test)] 114 | mod tests { 115 | use crate::expression::Expression; 116 | use crate::field::Element; 117 | use crate::gadget_builder::GadgetBuilder; 118 | use crate::gadget_traits::BlockCipher; 119 | use crate::mimc::MiMCBlockCipher; 120 | use crate::test_util::{F11, F7}; 121 | 122 | #[test] 123 | fn mimc_encrypt_and_decrypt() { 124 | let mut builder = GadgetBuilder::::new(); 125 | let key_wire = builder.wire(); 126 | let input_wire = builder.wire(); 127 | let key = Expression::from(key_wire); 128 | let input = Expression::from(input_wire); 129 | let mimc = MiMCBlockCipher::default(); 130 | let encrypted = mimc.encrypt(&mut builder, &key, &input); 131 | let decrypted = mimc.decrypt(&mut builder, &key, &encrypted); 132 | let gadget = builder.build(); 133 | 134 | let mut values = values!(key_wire => 2u8.into(), input_wire => 3u8.into()); 135 | assert!(gadget.execute(&mut values)); 136 | assert_eq!(input.evaluate(&values), decrypted.evaluate(&values)); 137 | } 138 | 139 | #[test] 140 | fn mimc_f11() { 141 | let constants = &[Element::from(5u8), Element::from(7u8)]; 142 | 143 | let mut builder = GadgetBuilder::::new(); 144 | let key_wire = builder.wire(); 145 | let input_wire = builder.wire(); 146 | let key = Expression::from(key_wire); 147 | let input = 
Expression::from(input_wire); 148 | let mimc = MiMCBlockCipher::new(constants); 149 | let mimc_output = mimc.encrypt(&mut builder, &key, &input); 150 | let gadget = builder.build(); 151 | 152 | let mut values = values!(key_wire => 3u8.into(), input_wire => 6u8.into()); 153 | assert!(gadget.execute(&mut values)); 154 | assert_eq!(Element::from(2u8), mimc_output.evaluate(&values)); 155 | } 156 | 157 | /// MiMC is incompatible with F_7, because cubing is not a permutation in this field. 158 | #[test] 159 | #[should_panic] 160 | fn mimc_f7_incompatible() { 161 | MiMCBlockCipher::::default(); 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/miyaguchi_preneel.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::expression::Expression; 4 | use crate::field::Field; 5 | use crate::gadget_builder::GadgetBuilder; 6 | use crate::gadget_traits::{BlockCipher, CompressionFunction}; 7 | 8 | /// The additive variant of Miyaguchi-Preneel, which creates a one-way compression function from a 9 | /// block cipher. 10 | pub struct MiyaguchiPreneel> { 11 | cipher: BC, 12 | phantom: PhantomData<*const F>, 13 | } 14 | 15 | impl> MiyaguchiPreneel { 16 | /// Create a new Miyaguchi-Preneel compression function from the given block cipher. 
17 | pub fn new(cipher: BC) -> Self { 18 | MiyaguchiPreneel { cipher, phantom: PhantomData } 19 | } 20 | } 21 | 22 | impl> CompressionFunction for MiyaguchiPreneel { 23 | fn compress(&self, builder: &mut GadgetBuilder, x: &Expression, y: &Expression) 24 | -> Expression { 25 | self.cipher.encrypt(builder, x, y) + x + y 26 | } 27 | } 28 | 29 | #[cfg(test)] 30 | mod tests { 31 | use crate::expression::Expression; 32 | use crate::field::{Element, Field}; 33 | use crate::gadget_builder::GadgetBuilder; 34 | use crate::gadget_traits::{BlockCipher, CompressionFunction}; 35 | use crate::miyaguchi_preneel::MiyaguchiPreneel; 36 | use crate::test_util::F7; 37 | 38 | #[test] 39 | fn miyaguchi_preneel() { 40 | // We will use a trivial cipher to keep the test simple. 41 | // The cipher is: (k, i) -> 2k + 4i + 3ki 42 | struct TestCipher; 43 | 44 | impl BlockCipher for TestCipher { 45 | fn encrypt(&self, builder: &mut GadgetBuilder, key: &Expression, 46 | input: &Expression) -> Expression { 47 | let product = builder.product(key, input); 48 | key * 2 + input * 4 + product * 3 49 | } 50 | 51 | fn decrypt(&self, _builder: &mut GadgetBuilder, _key: &Expression, 52 | _output: &Expression) -> Expression { 53 | panic!("Should never be called") 54 | } 55 | } 56 | 57 | let mut builder = GadgetBuilder::::new(); 58 | let x_wire = builder.wire(); 59 | let y_wire = builder.wire(); 60 | let x = Expression::from(x_wire); 61 | let y = Expression::from(y_wire); 62 | let mp = MiyaguchiPreneel::new(TestCipher); 63 | let mp_output = mp.compress(&mut builder, &x, &y); 64 | let gadget = builder.build(); 65 | 66 | let mut values = values!(x_wire => 2u8.into(), y_wire => 3u8.into()); 67 | assert!(gadget.execute(&mut values)); 68 | // The result should be: (2x + 4y + 3xy) + x + y = 4 + 12 + 18 + 2 + 3 = 39 = 4. 
69 | assert_eq!(Element::from(4u8), mp_output.evaluate(&values)); 70 | } 71 | } -------------------------------------------------------------------------------- /src/permutations.rs: -------------------------------------------------------------------------------- 1 | use num::{BigUint, Integer}; 2 | use num_traits::One; 3 | 4 | use crate::{Element, Expression, Field, GadgetBuilder, Permutation, WireValues}; 5 | 6 | /// The permutation `1 / x`, with zero being mapped to itself. 7 | pub struct InversePermutation; 8 | 9 | impl Permutation for InversePermutation { 10 | fn permute(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 11 | builder.inverse_or_zero(x) 12 | } 13 | 14 | fn inverse(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 15 | builder.inverse_or_zero(x) 16 | } 17 | } 18 | 19 | /// The permutation `x^n`. 20 | pub struct MonomialPermutation { 21 | n: Element, 22 | } 23 | 24 | impl MonomialPermutation { 25 | /// Creates a new monomial permutation with the given exponent. 26 | /// 27 | /// This method will panic if `x^n` is not a permutation of `F`. 28 | pub fn new(n: Element) -> Self { 29 | // It is well-known that x^n is a permutation of F_q iff gcd(n, q - 1) = 1. See, for 30 | // example, Theorem 1.14 in "Permutation Polynomials of Finite Fields" [Shallue 12].
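// Worked example: in F_11, q - 1 = 10 and gcd(3, 10) = 1, so x^3 is a permutation of F_11;
// in F_7, q - 1 = 6 and gcd(3, 6) = 3, so x^3 is not (this is what the `not_a_permutation`
// test relies on). `Element::largest_element()` in the assertion below is q - 1.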
31 | assert!(Element::largest_element().gcd(&n).is_one(), 32 | "x^{} is not a permutation of F", n); 33 | MonomialPermutation { n } 34 | } 35 | } 36 | 37 | impl Permutation for MonomialPermutation { 38 | fn permute(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 39 | builder.exponentiation(x, &self.n) 40 | } 41 | 42 | fn inverse(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 43 | let root_wire = builder.wire(); 44 | let root = Expression::from(root_wire); 45 | let exponentiation = builder.exponentiation(&root, &self.n); 46 | builder.assert_equal(&exponentiation, x); 47 | 48 | // By Fermat's little theorem, x^p = x mod p, so if n divides p, then x^(p / n)^n = x mod p. 49 | // Further, since x^(p - 1) = 1 mod p, x^((p + (p - 1)*k) / n)^n = x mod p for any positive k, 50 | // provided that n divides p + (p - 1)*k. Thus we start with p, and repeatedly add 51 | // p - 1 until we find an exponent divisible by n. 52 | //TODO: find a solution that isn't O(p) 53 | let mut exponent_times_n = F::order(); 54 | let exponent = loop { 55 | exponent_times_n += F::order() - BigUint::one(); 56 | if exponent_times_n.is_multiple_of(self.n.to_biguint()) { 57 | break Element::from(exponent_times_n / self.n.to_biguint()); 58 | } 59 | }; 60 | 61 | let x = x.clone(); 62 | builder.generator( 63 | x.dependencies(), 64 | move |values: &mut WireValues| { 65 | let root_value = x.evaluate(values).exponentiation(&exponent); 66 | values.set(root_wire, root_value); 67 | }); 68 | 69 | root 70 | } 71 | } 72 | 73 | #[cfg(test)] 74 | mod tests { 75 | use crate::{Element, Expression, GadgetBuilder, MonomialPermutation, Permutation}; 76 | use crate::test_util::{F11, F7}; 77 | 78 | #[test] 79 | fn cube_and_cube_root() { 80 | let mut builder = GadgetBuilder::::new(); 81 | let permutation = MonomialPermutation::new(Element::from(3u8)); 82 | let x_wire = builder.wire(); 83 | let x = Expression::from(x_wire); 84 | let x_cubed = permutation.permute(&mut builder, &x); 85 
| let cube_root = permutation.inverse(&mut builder, &x_cubed); 86 | let gadget = builder.build(); 87 | 88 | for i in 0u8..11 { 89 | let mut values = values!(x_wire => i.into()); 90 | assert!(gadget.execute(&mut values)); 91 | assert_eq!(Element::from(i), cube_root.evaluate(&values)); 92 | } 93 | } 94 | 95 | #[test] 96 | #[should_panic] 97 | fn not_a_permutation() { 98 | // x^3 is not a permutation in F_7, since gcd(3, 7-1) = 3 != 1. 99 | MonomialPermutation::::new(Element::from(3u8)); 100 | } 101 | } -------------------------------------------------------------------------------- /src/poseidon.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | #[cfg(not(feature = "std"))] 4 | use alloc::boxed::Box; 5 | 6 | use crate::{Element, Expression, Field, GadgetBuilder, InversePermutation, MdsMatrix, MonomialPermutation, MultiPermutation, Permutation}; 7 | 8 | const DEFAULT_SECURITY_BITS: usize = 128; 9 | 10 | /// An S-box that can be used with Poseidon. 11 | #[derive(Copy, Clone, Debug)] 12 | pub enum PoseidonSbox { 13 | Exponentiation3, 14 | Exponentiation5, 15 | Inverse, 16 | } 17 | 18 | /// The Poseidon permutation. 19 | pub struct Poseidon { 20 | /// The size of the permutation, in field elements. 21 | width: usize, 22 | /// The number of full and partial rounds to use. 23 | num_rounds: NumberOfRounds, 24 | /// The S-box to apply in the sub words layer. 25 | sbox: PoseidonSbox, 26 | /// The MDS matrix to apply in the mix layer. 27 | mds_matrix: MdsMatrix, 28 | } 29 | 30 | /// Builds a `Poseidon` instance. 31 | pub struct PoseidonBuilder { 32 | /// The size of the permutation, in field elements. 33 | width: usize, 34 | /// The number of full and partial rounds to use. 35 | num_rounds: Option, 36 | /// The S-box to apply in the sub words layer. 37 | sbox: Option, 38 | /// The desired (classical) security level, in bits.
39 | security_bits: Option, 40 | /// The MDS matrix to apply in the mix layer. 41 | mds_matrix: Option>, 42 | } 43 | 44 | impl PoseidonBuilder { 45 | pub fn new(width: usize) -> Self { 46 | PoseidonBuilder { 47 | width, 48 | num_rounds: None, 49 | sbox: None, 50 | security_bits: None, 51 | mds_matrix: None, 52 | } 53 | } 54 | 55 | pub fn sbox(&mut self, sbox: PoseidonSbox) -> &mut Self { 56 | self.sbox = Some(sbox); 57 | self 58 | } 59 | 60 | pub fn num_rounds(&mut self, num_rounds: NumberOfRounds) -> &mut Self { 61 | self.num_rounds = Some(num_rounds); 62 | self 63 | } 64 | 65 | pub fn security_bits(&mut self, security_bits: usize) -> &mut Self { 66 | self.security_bits = Some(security_bits); 67 | self 68 | } 69 | 70 | pub fn mds_matrix(&mut self, mds_matrix: MdsMatrix) -> &mut Self { 71 | self.mds_matrix = Some(mds_matrix); 72 | self 73 | } 74 | 75 | pub fn build(&self) -> Poseidon { 76 | let width = self.width; 77 | 78 | // TODO: Generate a default MDS matrix instead of making the caller supply one. 79 | let mds_matrix = self.mds_matrix.clone().expect("MDS matrix required for now"); 80 | 81 | // If an S-box is not specified, determine the optimal choice based on the guidance in the 82 | // paper. 83 | let sbox = self.sbox.unwrap_or_else( 84 | || match Element::::largest_element() { 85 | ref x if x.gcd(&3u8.into()).is_one() => PoseidonSbox::Exponentiation3, 86 | ref x if x.gcd(&5u8.into()).is_one() => PoseidonSbox::Exponentiation5, 87 | _ => PoseidonSbox::Inverse, 88 | }); 89 | 90 | if self.num_rounds.is_some() && self.security_bits.is_some() { 91 | panic!("Cannot specify both the number of rounds and the desired security level"); 92 | } 93 | 94 | // Determine the optimal numbers of full and partial rounds. 
95 | let num_rounds = self.num_rounds.unwrap_or_else( 96 | || secure_num_rounds_padded::(sbox, width, 97 | self.security_bits.unwrap_or(DEFAULT_SECURITY_BITS))); 98 | 99 | Poseidon { width, num_rounds, sbox, mds_matrix } 100 | } 101 | } 102 | 103 | /// The number of full and partial rounds to use in an instance of Poseidon. 104 | #[derive(Copy, Clone, Debug)] 105 | pub struct NumberOfRounds { 106 | full: usize, 107 | partial: usize, 108 | } 109 | 110 | impl Poseidon { 111 | fn sbox_permute(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 112 | self.sbox_to_permutation().permute(builder, x) 113 | } 114 | 115 | fn sbox_inverse(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 116 | self.sbox_to_permutation().inverse(builder, x) 117 | } 118 | 119 | fn sbox_to_permutation(&self) -> Box> { 120 | match &self.sbox { 121 | PoseidonSbox::Inverse => Box::new(InversePermutation), 122 | PoseidonSbox::Exponentiation3 => Box::new(MonomialPermutation::new(Element::from(3u8))), 123 | PoseidonSbox::Exponentiation5 => Box::new(MonomialPermutation::new(Element::from(5u8))), 124 | } 125 | } 126 | } 127 | 128 | impl MultiPermutation for Poseidon { 129 | fn width(&self) -> usize { 130 | self.width 131 | } 132 | 133 | fn permute(&self, builder: &mut GadgetBuilder, inputs: &[Expression]) 134 | -> Vec> { 135 | assert_eq!(inputs.len(), self.width); 136 | 137 | let rounds = self.num_rounds.full + self.num_rounds.partial; 138 | assert!(self.num_rounds.full % 2 == 0, "asymmetric permutation configuration"); 139 | let full_rounds_per_side = self.num_rounds.full / 2; 140 | 141 | let mut current = inputs.to_vec(); 142 | for round in 0..rounds { 143 | // Sub words layer. 
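// Worked example: with 4 full and 6 partial rounds (the configuration used by the
// `poseidon_x3_f11` test in this file), `rounds` is 10 and `full_rounds_per_side` is 2, so
// rounds 0, 1, 8 and 9 apply the S-box to every element, while rounds 2 through 7 apply it
// to `current[0]` only.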
144 | let full = round < full_rounds_per_side || round >= rounds - full_rounds_per_side; 145 | if full { 146 | current = current.iter() 147 | .map(|exp| self.sbox_permute(builder, exp)) 148 | .collect(); 149 | } else { 150 | current[0] = self.sbox_permute(builder, &current[0]); 151 | } 152 | 153 | // Mix layer. 154 | current = &self.mds_matrix * current.as_slice(); 155 | } 156 | 157 | current 158 | } 159 | 160 | fn inverse(&self, builder: &mut GadgetBuilder, outputs: &[Expression]) 161 | -> Vec> { 162 | assert_eq!(outputs.len(), self.width); 163 | 164 | let rounds = self.num_rounds.full + self.num_rounds.partial; 165 | assert!(self.num_rounds.full % 2 == 0, "asymmetric permutation configuration"); 166 | let full_rounds_per_side = self.num_rounds.full / 2; 167 | 168 | let inverse_mds_matrix = self.mds_matrix.inverse(); 169 | 170 | let mut current = outputs.to_vec(); 171 | for round in 0..rounds { 172 | // Mix layer. 173 | current = &inverse_mds_matrix * current.as_slice(); 174 | 175 | // Sub words layer. 176 | let full = round < full_rounds_per_side || round >= rounds - full_rounds_per_side; 177 | if full { 178 | current = current.iter() 179 | .map(|exp| self.sbox_inverse(builder, exp)) 180 | .collect(); 181 | } else { 182 | current[0] = self.sbox_inverse(builder, &current[0]); 183 | } 184 | } 185 | 186 | current 187 | } 188 | } 189 | 190 | /// Selects a number of full and partial rounds so as to provide plausible security, including a 191 | /// reasonable security margin as suggested by the Poseidon authors.
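// Padding example (hypothetical unpadded values): if the unpadded search returned 6 full and
// 40 partial rounds, the padded configuration would be 6 + 2 = 8 full rounds and
// round(40 * 1.075) = 43 partial rounds.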
192 | fn secure_num_rounds_padded( 193 | sbox: PoseidonSbox, width: usize, security_bits: usize, 194 | ) -> NumberOfRounds { 195 | let unpadded = secure_num_rounds_unpadded::(sbox, width, security_bits); 196 | NumberOfRounds { 197 | full: unpadded.full + 2, 198 | partial: (unpadded.partial as f64 * 1.075).round() as usize, 199 | } 200 | } 201 | 202 | fn secure_num_rounds_unpadded( 203 | sbox: PoseidonSbox, width: usize, security_bits: usize, 204 | ) -> NumberOfRounds { 205 | let mut full = 6; 206 | let mut best_rounds = NumberOfRounds { 207 | full, 208 | partial: secure_partial_rounds_unpadded::(sbox, width, full, security_bits), 209 | }; 210 | let mut best_sboxes = num_sboxes(width, best_rounds); 211 | 212 | loop { 213 | // We increment by 2 to maintain symmetry. 214 | full += 2; 215 | 216 | let rounds = NumberOfRounds { 217 | full, 218 | partial: secure_partial_rounds_unpadded::(sbox, width, full, security_bits), 219 | }; 220 | let sboxes = num_sboxes(width, rounds); 221 | 222 | if sboxes > best_sboxes { 223 | // The cost is starting to increase. Terminate with the best configuration we found. 224 | break best_rounds; 225 | } 226 | 227 | best_rounds = rounds; 228 | best_sboxes = sboxes; 229 | } 230 | } 231 | 232 | fn secure_partial_rounds_unpadded( 233 | sbox: PoseidonSbox, width: usize, full_rounds: usize, security_bits: usize, 234 | ) -> usize { 235 | // We could do an exponential search here, but brute force seems fast enough. 
236 | let mut partial = 0; 237 | loop { 238 | let num_rounds = NumberOfRounds { full: full_rounds, partial }; 239 | if !is_attackable::(sbox, width, num_rounds, security_bits) { 240 | break partial; 241 | } 242 | partial += 1; 243 | } 244 | } 245 | 246 | fn is_attackable( 247 | sbox: PoseidonSbox, width: usize, num_rounds: NumberOfRounds, security_bits: usize, 248 | ) -> bool { 249 | match sbox { 250 | PoseidonSbox::Exponentiation3 => is_attackable_exponentiation_3::( 251 | width, num_rounds, security_bits), 252 | PoseidonSbox::Exponentiation5 => is_attackable_exponentiation_5::( 253 | width, num_rounds, security_bits), 254 | PoseidonSbox::Inverse => is_attackable_inverse::( 255 | width, num_rounds, security_bits), 256 | } 257 | } 258 | 259 | fn is_attackable_exponentiation_3( 260 | width: usize, num_rounds: NumberOfRounds, security_bits: usize, 261 | ) -> bool { 262 | let inequality_1 = (num_rounds.full + num_rounds.partial) as f64 263 | <= 2f64.log(3f64) * min_n_m::(security_bits) + (width as f64).log2(); 264 | let inequality_2a = (num_rounds.full + num_rounds.partial) as f64 265 | <= 0.32 * min_n_m::(security_bits); 266 | let inequality_2b = ((width - 1) * num_rounds.full + num_rounds.partial) as f64 267 | <= 0.18 * min_n_m::(security_bits) - 1.0; 268 | inequality_1 || inequality_2a || inequality_2b 269 | } 270 | 271 | fn is_attackable_exponentiation_5( 272 | width: usize, num_rounds: NumberOfRounds, security_bits: usize, 273 | ) -> bool { 274 | let inequality_1 = (num_rounds.full + num_rounds.partial) as f64 275 | <= 2f64.log(5f64) * min_n_m::(security_bits) + (width as f64).log2(); 276 | let inequality_2a = (num_rounds.full + num_rounds.partial) as f64 277 | <= 0.21 * min_n_m::(security_bits); 278 | let inequality_2b = ((width - 1) * num_rounds.full + num_rounds.partial) as f64 279 | <= 0.14 * min_n_m::(security_bits) - 1.0; 280 | inequality_1 || inequality_2a || inequality_2b 281 | } 282 | 283 | fn is_attackable_inverse( 284 | width: usize, num_rounds: 
NumberOfRounds, security_bits: usize, 285 | ) -> bool { 286 | let inequality_1 = num_rounds.full as f64 * (width as f64).log2() + num_rounds.partial as f64 287 | <= (width as f64).log2() + 0.5 + min_n_m::(security_bits); 288 | // In the paper, inequality (2a) is identical to (1) for the case of 1/x, so we omit it. 289 | let inequality_2 = ((width - 1) * num_rounds.full + num_rounds.partial) as f64 290 | <= 0.25 * min_n_m::(security_bits) - 1.0; 291 | inequality_1 || inequality_2 292 | } 293 | 294 | /// The minimum of the field size (in bits) and the security level, which the paper calls 295 | /// `min{n, M}`. 296 | fn min_n_m(security_bits: usize) -> f64 { 297 | security_bits.min(Element::::max_bits()) as f64 298 | } 299 | 300 | fn num_sboxes(width: usize, num_rounds: NumberOfRounds) -> usize { 301 | num_rounds.full * width + num_rounds.partial 302 | } 303 | 304 | #[cfg(test)] 305 | mod tests { 306 | use itertools::Itertools; 307 | 308 | use crate::{Expression, GadgetBuilder, MdsMatrix, MultiPermutation, PoseidonBuilder}; 309 | use crate::poseidon::NumberOfRounds; 310 | use crate::PoseidonSbox::Exponentiation3; 311 | use crate::test_util::F11; 312 | 313 | #[test] 314 | fn poseidon_x3_f11() { 315 | let mds_matrix = MdsMatrix::::new(vec![ 316 | vec![2u8.into(), 3u8.into(), 1u8.into(), 1u8.into()], 317 | vec![1u8.into(), 2u8.into(), 3u8.into(), 1u8.into()], 318 | vec![1u8.into(), 1u8.into(), 2u8.into(), 3u8.into()], 319 | vec![3u8.into(), 1u8.into(), 1u8.into(), 2u8.into()], 320 | ]); 321 | 322 | let poseidon = PoseidonBuilder::new(4) 323 | .sbox(Exponentiation3) 324 | .num_rounds(NumberOfRounds { full: 4, partial: 6 }) 325 | .mds_matrix(mds_matrix) 326 | .build(); 327 | 328 | let mut builder = GadgetBuilder::new(); 329 | let input_wires = builder.wires(4); 330 | let input_exps = input_wires.iter().map(Expression::from).collect_vec(); 331 | let _outputs = poseidon.permute(&mut builder, &input_exps); 332 | let gadget = builder.build(); 333 | 334 | let mut values = 
values!( 335 | input_wires[0] => 0u8.into(), input_wires[1] => 1u8.into(), 336 | input_wires[2] => 2u8.into(), input_wires[3] => 3u8.into()); 337 | assert!(gadget.execute(&mut values)); 338 | } 339 | } -------------------------------------------------------------------------------- /src/random_access.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with methods for randomly accessing lists. 2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | use crate::BooleanExpression; 7 | use crate::expression::Expression; 8 | use crate::field::Field; 9 | use crate::gadget_builder::GadgetBuilder; 10 | 11 | impl GadgetBuilder { 12 | /// Accesses the `i`th element of `items`, where `i` may be a dynamic expression. Assumes 13 | /// `i < items.len()`. 14 | /// 15 | /// Note that this gadget does not return 0 for an out-of-range index: since odd-sized layers carry their last element up unchanged, such an index may select one of the existing items. If you want to prohibit 16 | /// out-of-range indices, you can do so with a separate call to `assert_lt`. 17 | pub fn random_access( 18 | &mut self, 19 | items: &[Expression], 20 | index: &Expression, 21 | ) -> Expression { 22 | // Determine the minimum number of bits needed to encode the index. 23 | let mut bits = 0; 24 | while 1 << bits < items.len() { 25 | bits += 1; 26 | } 27 | 28 | let index_binary = self.split_bounded(index, bits); 29 | self.random_access_binary(items, index_binary.bits) 30 | } 31 | 32 | /// Like `random_access`, but with a binary index. 33 | fn random_access_binary( 34 | &mut self, 35 | items: &[Expression], 36 | mut index_bits: Vec>, 37 | ) -> Expression { 38 | // Imagine a binary tree whose leaves consist of the given items, where a layer with an odd number of nodes 39 | // carries its last node up unchanged. We can think of each bit of the index as filtering a single layer of the 40 | // tree. For example, the first (least significant) index bit selects between pairs of 41 | // leaves.
After filtering each layer in this manner, we are left with a single value 42 | // corresponding to the root of the tree. 43 | 44 | // This leads to a natural recursive solution. Each call of this method will filter the 45 | // deepest layer of the tree, then recurse, until we are left with a singleton list. 46 | 47 | if items.len() == 1 { 48 | assert!(index_bits.is_empty()); 49 | return items[0].clone(); 50 | } 51 | 52 | let lsb = index_bits.remove(0); 53 | let num_parents = (items.len() + 1) / 2; 54 | let mut parent_layer = Vec::with_capacity(num_parents); 55 | for parent_index in 0..num_parents { 56 | let left_child_index = parent_index * 2; 57 | let right_child_index = parent_index * 2 + 1; 58 | let left_child = &items[left_child_index]; 59 | if right_child_index == items.len() { 60 | parent_layer.push(left_child.clone()); 61 | } else { 62 | let right_child = &items[right_child_index]; 63 | parent_layer.push(self.selection(&lsb, right_child, left_child)); 64 | } 65 | } 66 | 67 | self.random_access_binary(&parent_layer, index_bits) 68 | } 69 | } 70 | 71 | #[cfg(test)] 72 | mod tests { 73 | use itertools::Itertools; 74 | 75 | use crate::expression::Expression; 76 | use crate::field::Element; 77 | use crate::gadget_builder::GadgetBuilder; 78 | use crate::test_util::F257; 79 | use crate::wire_values::WireValues; 80 | 81 | #[test] 82 | fn random_access() { 83 | let n = 10; 84 | let mut builder = GadgetBuilder::::new(); 85 | let item_wires = builder.wires(n); 86 | let item_exps = item_wires.iter().map(Expression::from).collect_vec(); 87 | let index_wire = builder.wire(); 88 | let index_exp = Expression::from(index_wire); 89 | let result = builder.random_access(&item_exps, &index_exp); 90 | let gadget = builder.build(); 91 | 92 | let mut wire_values = WireValues::new(); 93 | for i in 0..n { 94 | wire_values.set(item_wires[i], Element::from(i)); 95 | } 96 | 97 | for i in 0..n { 98 | let mut wire_values_i = wire_values.clone(); 99 | wire_values_i.set(index_wire, 
Element::from(i)); 100 | assert!(gadget.execute(&mut wire_values_i)); 101 | assert_eq!(Element::from(i), result.evaluate(&wire_values_i)); 102 | } 103 | } 104 | } -------------------------------------------------------------------------------- /src/rescue.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use crate::{Element, Expression, Field, GadgetBuilder, MdsMatrix, MonomialPermutation, MultiPermutation, Permutation}; 5 | 6 | const DEFAULT_SECURITY_BITS: usize = 128; 7 | const SECURITY_MARGIN: usize = 2; 8 | const MINIMUM_ROUNDS: usize = 10; 9 | 10 | /// The Rescue permutation. 11 | pub struct Rescue { 12 | /// The size of the permutation, in field elements. 13 | width: usize, 14 | /// The degree of the permutation monomial. 15 | alpha: Element, 16 | /// The number of rounds to use. 17 | num_rounds: usize, 18 | /// The MDS matrix to apply after each permutation layer. 19 | mds_matrix: MdsMatrix, 20 | } 21 | 22 | impl Rescue { 23 | fn pi_1(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 24 | MonomialPermutation::new(self.alpha.clone()).permute(builder, x) 25 | } 26 | 27 | fn pi_2(&self, builder: &mut GadgetBuilder, x: &Expression) -> Expression { 28 | MonomialPermutation::new(self.alpha.clone()).inverse(builder, x) 29 | } 30 | } 31 | 32 | impl MultiPermutation for Rescue { 33 | fn width(&self) -> usize { 34 | self.width 35 | } 36 | 37 | fn permute(&self, builder: &mut GadgetBuilder, inputs: &[Expression]) 38 | -> Vec> { 39 | let mut current = inputs.to_vec(); 40 | for _round in 0..self.num_rounds { 41 | current = current.iter().map(|exp| self.pi_1(builder, exp)).collect(); 42 | current = &self.mds_matrix * current.as_slice(); 43 | current = current.iter().map(|exp| self.pi_2(builder, exp)).collect(); 44 | current = &self.mds_matrix * current.as_slice(); 45 | } 46 | current 47 | } 48 | 49 | fn inverse(&self, _builder: &mut GadgetBuilder, 
_outputs: &[Expression]) 50 | -> Vec> { 51 | unimplemented!("TODO: implement inverse Rescue") 52 | } 53 | } 54 | 55 | /// Builds a `Rescue` instance. 56 | pub struct RescueBuilder { 57 | /// The size of the permutation, in field elements. 58 | width: usize, 59 | /// The degree of the permutation monomial. 60 | alpha: Option>, 61 | /// The number of rounds to use. 62 | num_rounds: Option, 63 | /// The desired (classical) security level, in bits. 64 | security_bits: Option, 65 | /// The MDS matrix to apply after each permutation layer. 66 | mds_matrix: Option>, 67 | } 68 | 69 | impl RescueBuilder { 70 | pub fn new(width: usize) -> Self { 71 | assert!(width > 0, "Permutation width must be non-zero"); 72 | RescueBuilder { 73 | width, 74 | alpha: None, 75 | num_rounds: None, 76 | security_bits: None, 77 | mds_matrix: None, 78 | } 79 | } 80 | 81 | pub fn alpha(&mut self, alpha: Element) -> &mut Self { 82 | self.alpha = Some(alpha); 83 | self 84 | } 85 | 86 | pub fn num_rounds(&mut self, num_rounds: usize) -> &mut Self { 87 | self.num_rounds = Some(num_rounds); 88 | self 89 | } 90 | 91 | pub fn security_bits(&mut self, security_bits: usize) -> &mut Self { 92 | self.security_bits = Some(security_bits); 93 | self 94 | } 95 | 96 | pub fn mds_matrix(&mut self, mds_matrix: MdsMatrix) -> &mut Self { 97 | self.mds_matrix = Some(mds_matrix); 98 | self 99 | } 100 | 101 | pub fn build(&self) -> Rescue { 102 | let width = self.width; 103 | let alpha = self.alpha.clone().unwrap_or_else(Self::smallest_alpha); 104 | 105 | // TODO: Generate a default MDS matrix instead of making the caller supply one. 
106 | let mds_matrix = self.mds_matrix.clone().expect("MDS matrix required for now"); 107 | 108 | if self.num_rounds.is_some() && self.security_bits.is_some() { 109 | panic!("Cannot specify both the number of rounds and the desired security level"); 110 | } 111 | let num_rounds = self.num_rounds.unwrap_or_else( 112 | || Self::secure_num_rounds( 113 | self.security_bits.unwrap_or(DEFAULT_SECURITY_BITS), 114 | width)); 115 | 116 | Rescue { width, alpha, num_rounds, mds_matrix } 117 | } 118 | 119 | /// Finds the smallest prime `a >= 3` such that `x^a` is a permutation of `F`, or equivalently, 120 | /// `gcd(|F| - 1, a) = 1`. 121 | fn smallest_alpha() -> Element { 122 | let largest_element = Element::::largest_element(); 123 | let mut alpha = Element::::from(3u8); 124 | while !largest_element.gcd(&alpha).is_one() { 125 | // Increment alpha to the next prime. 126 | alpha += Element::one(); 127 | while !alpha.is_prime() { 128 | alpha += Element::one(); 129 | } 130 | } 131 | alpha 132 | } 133 | 134 | fn secure_num_rounds(security_bits: usize, width: usize) -> usize { 135 | // As per the paper, a Gröbner basis attack is lower bounded by 2^{4 * width * rounds}.
136 | // Thus, attackable_rounds = security_bits / (4 * width) 137 | let attackable_rounds = integer_division_ceil(security_bits, 4 * width); 138 | (attackable_rounds * SECURITY_MARGIN).max(MINIMUM_ROUNDS) 139 | } 140 | } 141 | 142 | fn integer_division_ceil(n: usize, m: usize) -> usize { 143 | (n + m - 1) / m 144 | } 145 | 146 | #[cfg(test)] 147 | mod tests { 148 | use crate::MdsMatrix; 149 | use crate::rescue::RescueBuilder; 150 | use crate::test_util::F11; 151 | 152 | #[test] 153 | fn rescue_permutation_f11() { 154 | let mds_matrix = MdsMatrix::::new(vec![ 155 | vec![2u8.into(), 3u8.into(), 1u8.into(), 1u8.into()], 156 | vec![1u8.into(), 2u8.into(), 3u8.into(), 1u8.into()], 157 | vec![1u8.into(), 1u8.into(), 2u8.into(), 3u8.into()], 158 | vec![3u8.into(), 1u8.into(), 1u8.into(), 2u8.into()], 159 | ]); 160 | 161 | let _rescue = RescueBuilder::new(2).security_bits(128).mds_matrix(mds_matrix).build(); 162 | 163 | // TODO: Verify execution. 164 | } 165 | } -------------------------------------------------------------------------------- /src/signature.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::{CompressionFunction, Expression, Field, GadgetBuilder, GroupExpression, CyclicGroup}; 4 | 5 | pub trait SignatureScheme, CF: CompressionFunction> { 6 | fn verify( 7 | builder: &mut GadgetBuilder, 8 | signature: &SignatureExpression, 9 | message: &Expression, 10 | public_key: &C::GroupExpression, 11 | compress: &CF, 12 | ); 13 | } 14 | 15 | pub struct Schnorr, CF: CompressionFunction> { 16 | phantom_f: PhantomData<*const F>, 17 | phantom_c: PhantomData<*const C>, 18 | phantom_cf: PhantomData<*const CF>, 19 | } 20 | 21 | /// Struct to represent a Schnorr Signature. 
22 | /// 23 | /// Assumes that the message has already been hashed to a field element. 24 | /// The signature is a pair of scalars `(s, e)` such that `r_v = sg + ey`, where `g` is the generator. 25 | /// The public key is a group element `y = xg`, where `x` is the private key. 26 | pub struct SignatureExpression { 27 | pub s: Expression, 28 | pub e: Expression, 29 | } 30 | 31 | impl, CF: CompressionFunction> SignatureScheme for Schnorr { 32 | /// Generates constraints to verify that a Schnorr signature for a message is valid, 33 | /// given a public key and a secure compression function. 34 | /// 35 | /// Requires a preimage-resistant hash function for full security. 36 | /// 37 | /// A naive implementation that has not been optimized or audited. 38 | // TODO: optimize scalar multiplication for a fixed generator 39 | fn verify( 40 | builder: &mut GadgetBuilder, 41 | signature: &SignatureExpression, 42 | message: &Expression, 43 | public_key: &C::GroupExpression, 44 | compress: &CF, 45 | ) { 46 | let generator = C::generator_expression(); 47 | let gs = C::mul_scalar_expression( 48 | builder, 49 | &generator, 50 | &signature.s); 51 | let ye = C::mul_scalar_expression( 52 | builder, 53 | public_key, 54 | &signature.e); 55 | let gs_ye = C::add_expressions(builder, &gs, &ye); 56 | 57 | // TODO: verify that compressing the Edwards Curve point to the Y-coordinate is valid 58 | let hash_check = compress.compress(builder, &gs_ye.compressed(), &message); 59 | builder.assert_equal(&hash_check, &signature.e); 60 | } 61 | } 62 | 63 | #[cfg(test)] 64 | mod tests { 65 | use std::str::FromStr; 66 | 67 | use crate::{CyclicGenerator, EdwardsExpression, Expression, GadgetBuilder, Group, WireValues, JubJub, JubJubPrimeSubgroup}; 68 | use crate::CompressionFunction; 69 | use crate::field::{Bls12_381, Element, Field}; 70 | use crate::signature::{Schnorr, SignatureExpression, SignatureScheme}; 71 | 72 | #[test] 73 | fn verify() { 74 | // Generate signature 75 | let generator = JubJub::generator_element(); 76 | 77 | let private_key =
Element::from_str("4372820819045374670962167435360035096875258").unwrap(); 78 | 79 | let mut builder = GadgetBuilder::::new(); 80 | 81 | let public_key 82 | = JubJubPrimeSubgroup::mul_scalar_element(&generator, &private_key); 83 | 84 | let nonce = Element::from_str("5434290453746709621674353600312312").unwrap(); 85 | 86 | let r 87 | = JubJubPrimeSubgroup::mul_scalar_element(&generator, &nonce); 88 | 89 | let compress = TestCompress {}; 90 | 91 | let message = Element::from_str("12345").unwrap(); 92 | 93 | let e = compress.compress_evaluate(&r.compressed_element(), &message); 94 | 95 | let s = &nonce - &private_key * &e; 96 | 97 | let signature = SignatureExpression { s: Expression::from(s), e: Expression::from(e) }; 98 | 99 | let mut builder = GadgetBuilder::::new(); 100 | 101 | Schnorr::::verify( 102 | &mut builder, 103 | &signature, 104 | &Expression::from(message), 105 | &EdwardsExpression::from(&public_key), 106 | &compress, 107 | ); 108 | 109 | let gadget = builder.build(); 110 | let mut values = WireValues::new(); 111 | gadget.execute(&mut values); 112 | 113 | //TODO: include test vectors 114 | } 115 | 116 | // A dummy compression function which returns 2x + y. 117 | struct TestCompress; 118 | 119 | impl CompressionFunction for TestCompress { 120 | fn compress(&self, _builder: &mut GadgetBuilder, x: &Expression, y: &Expression) 121 | -> Expression { 122 | x * 2 + y 123 | } 124 | } 125 | } -------------------------------------------------------------------------------- /src/sorting.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with a method for sorting lists of field elements. 
2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | use itertools::enumerate; 7 | 8 | use crate::expression::Expression; 9 | use crate::field::{Element, Field}; 10 | use crate::gadget_builder::GadgetBuilder; 11 | use crate::wire::Wire; 12 | use crate::wire_values::WireValues; 13 | 14 | impl<F: Field> GadgetBuilder<F> { 15 | /// Sorts field elements in ascending order. Panics if `inputs` is empty. 16 | pub fn sort_ascending(&mut self, inputs: &[Expression<F>]) -> Vec<Expression<F>> { 17 | let n = inputs.len(); 18 | 19 | let output_wires: Vec<Wire> = self.wires(n); 20 | let outputs: Vec<Expression<F>> = output_wires.iter().map(Expression::from).collect(); 21 | 22 | // First we assert that the input and output lists are permutations of one another, i.e., 23 | // that they contain the same values. 24 | self.assert_permutation(inputs, &outputs); 25 | 26 | // Then, we assert the order of each adjacent pair of output values. Note that assert_le 27 | // would internally split each input into its binary form. To avoid splitting intermediate 28 | // items twice, we will explicitly split here, and call assert_le_binary instead. 29 | // Also note that only the purportedly largest item (i.e. the last one) needs to be split 30 | // canonically. If one of the other elements were to be split into its non-canonical 31 | // binary encoding, that binary expression would be greater than the last element, rendering 32 | // the instance unsatisfiable. 33 | let mut outputs_binary = Vec::new(); 34 | for out in outputs.iter().take(n - 1) { 35 | outputs_binary.push(self.split_allowing_ambiguity(out)); 36 | } 37 | outputs_binary.push(self.split(&outputs[n - 1])); 38 | 39 | for i in 0..(n - 1) { 40 | let a = &outputs_binary[i]; 41 | let b = &outputs_binary[i + 1]; 42 | self.assert_le_binary(a, b); 43 | } 44 | 45 | let inputs = inputs.to_vec(); 46 | self.generator( 47 | inputs.iter().flat_map(Expression::dependencies).collect(), 48 | move |values: &mut WireValues<F>| { 49 | // Evaluate all the inputs, sort that list of field elements, and output that.
50 | let mut items: Vec> = 51 | inputs.iter().map(|exp| exp.evaluate(values)).collect(); 52 | items.sort(); 53 | for (i, item) in enumerate(items) { 54 | values.set(output_wires[i], item); 55 | } 56 | }); 57 | 58 | outputs 59 | } 60 | 61 | /// Sorts field elements in descending order. 62 | pub fn sort_descending(&mut self, inputs: &[Expression]) -> Vec> { 63 | let mut items = self.sort_ascending(inputs); 64 | items.reverse(); 65 | items 66 | } 67 | } 68 | 69 | #[cfg(test)] 70 | mod tests { 71 | use crate::expression::Expression; 72 | use crate::field::Element; 73 | use crate::gadget_builder::GadgetBuilder; 74 | use crate::test_util::F257; 75 | 76 | #[test] 77 | fn sort_4_ascending() { 78 | let mut builder = GadgetBuilder::::new(); 79 | let (a, b, c, d) = (builder.wire(), builder.wire(), builder.wire(), builder.wire()); 80 | let outputs = builder.sort_ascending(&vec![ 81 | Expression::from(a), Expression::from(b), Expression::from(c), Expression::from(d)]); 82 | let gadget = builder.build(); 83 | 84 | let mut values = values!( 85 | a => 4u8.into(), b => 7u8.into(), c => 0u8.into(), d => 1u8.into()); 86 | assert!(gadget.execute(&mut values)); 87 | assert_eq!(Element::from(0u8), outputs[0].evaluate(&values)); 88 | assert_eq!(Element::from(1u8), outputs[1].evaluate(&values)); 89 | assert_eq!(Element::from(4u8), outputs[2].evaluate(&values)); 90 | assert_eq!(Element::from(7u8), outputs[3].evaluate(&values)); 91 | } 92 | 93 | #[test] 94 | fn sort_4_descending() { 95 | let mut builder = GadgetBuilder::::new(); 96 | let (a, b, c, d) = (builder.wire(), builder.wire(), builder.wire(), builder.wire()); 97 | let outputs = builder.sort_descending(&vec![ 98 | Expression::from(a), Expression::from(b), Expression::from(c), Expression::from(d)]); 99 | let gadget = builder.build(); 100 | 101 | let mut values = values!( 102 | a => 4u8.into(), b => 7u8.into(), c => 0u8.into(), d => 1u8.into()); 103 | assert!(gadget.execute(&mut values)); 104 | assert_eq!(Element::from(7u8), 
outputs[0].evaluate(&values)); 105 | assert_eq!(Element::from(4u8), outputs[1].evaluate(&values)); 106 | assert_eq!(Element::from(1u8), outputs[2].evaluate(&values)); 107 | assert_eq!(Element::from(0u8), outputs[3].evaluate(&values)); 108 | } 109 | } -------------------------------------------------------------------------------- /src/splitting.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with methods for splitting field elements into bits. 2 | 3 | use crate::expression::{BinaryExpression, Expression}; 4 | use crate::field::{Element, Field}; 5 | use crate::gadget_builder::GadgetBuilder; 6 | use crate::wire_values::WireValues; 7 | 8 | impl GadgetBuilder { 9 | /// Split an arbitrary field element `x` into its canonical binary representation. 10 | pub fn split(&mut self, x: &Expression) -> BinaryExpression { 11 | let result = self.split_without_range_check(x, Element::::max_bits()); 12 | self.assert_lt_binary(&result, &BinaryExpression::from(F::order())); 13 | result 14 | } 15 | 16 | /// Split an arbitrary field element `x` into a binary representation. Unlike `split`, this 17 | /// method permits two distinct binary decompositions: the canonical one, and another 18 | /// representation where the weighted sum of bits overflows the field size. This minimizes 19 | /// constraints, but the ambiguity can be a security problem depending on the context. If in 20 | /// doubt, use `split` instead. 21 | pub fn split_allowing_ambiguity(&mut self, x: &Expression) -> BinaryExpression { 22 | self.split_without_range_check(x, Element::::max_bits()) 23 | } 24 | 25 | /// Split `x` into `bits` bit wires. This method assumes `x < 2^bits < |F|`. Note that only one 26 | /// binary representation is possible here, since `bits` bits is not enough to overflow the 27 | /// field size. 
28 | pub fn split_bounded(&mut self, x: &Expression, bits: usize) -> BinaryExpression { 29 | assert!(bits < Element::::max_bits()); 30 | self.split_without_range_check(x, bits) 31 | } 32 | 33 | fn split_without_range_check(&mut self, x: &Expression, bits: usize) -> BinaryExpression { 34 | let binary_wire = self.binary_wire(bits); 35 | let binary_exp = BinaryExpression::from(&binary_wire); 36 | let weighted_sum = binary_exp.join_allowing_overflow(); 37 | self.assert_equal(x, &weighted_sum); 38 | 39 | let x = x.clone(); 40 | self.generator( 41 | x.dependencies(), 42 | move |values: &mut WireValues| { 43 | let value = x.evaluate(values); 44 | assert!(value.bits() <= bits); 45 | for i in 0..bits { 46 | values.set_boolean(binary_wire.bits[i], value.bit(i)); 47 | } 48 | }, 49 | ); 50 | 51 | binary_exp 52 | } 53 | } 54 | 55 | #[cfg(test)] 56 | mod tests { 57 | use crate::Bn128; 58 | use crate::expression::Expression; 59 | use crate::gadget_builder::GadgetBuilder; 60 | 61 | #[test] 62 | fn split_19_32() { 63 | let mut builder = GadgetBuilder::::new(); 64 | let wire = builder.wire(); 65 | let bit_wires = builder.split_bounded(&Expression::from(wire), 32); 66 | let gadget = builder.build(); 67 | 68 | let mut wire_values = values!(wire => 19u8.into()); 69 | assert!(gadget.execute(&mut wire_values)); 70 | 71 | assert_eq!(true, bit_wires.bits[0].evaluate(&wire_values)); 72 | assert_eq!(true, bit_wires.bits[1].evaluate(&wire_values)); 73 | assert_eq!(false, bit_wires.bits[2].evaluate(&wire_values)); 74 | assert_eq!(false, bit_wires.bits[3].evaluate(&wire_values)); 75 | assert_eq!(true, bit_wires.bits[4].evaluate(&wire_values)); 76 | assert_eq!(false, bit_wires.bits[5].evaluate(&wire_values)); 77 | assert_eq!(false, bit_wires.bits[6].evaluate(&wire_values)); 78 | assert_eq!(false, bit_wires.bits[7].evaluate(&wire_values)); 79 | assert_eq!(false, bit_wires.bits[8].evaluate(&wire_values)); 80 | assert_eq!(false, bit_wires.bits[9].evaluate(&wire_values)); 81 | assert_eq!(false, 
bit_wires.bits[10].evaluate(&wire_values)); 82 | assert_eq!(false, bit_wires.bits[11].evaluate(&wire_values)); 83 | assert_eq!(false, bit_wires.bits[12].evaluate(&wire_values)); 84 | assert_eq!(false, bit_wires.bits[13].evaluate(&wire_values)); 85 | assert_eq!(false, bit_wires.bits[14].evaluate(&wire_values)); 86 | assert_eq!(false, bit_wires.bits[15].evaluate(&wire_values)); 87 | assert_eq!(false, bit_wires.bits[16].evaluate(&wire_values)); 88 | assert_eq!(false, bit_wires.bits[17].evaluate(&wire_values)); 89 | assert_eq!(false, bit_wires.bits[18].evaluate(&wire_values)); 90 | assert_eq!(false, bit_wires.bits[19].evaluate(&wire_values)); 91 | assert_eq!(false, bit_wires.bits[20].evaluate(&wire_values)); 92 | assert_eq!(false, bit_wires.bits[21].evaluate(&wire_values)); 93 | assert_eq!(false, bit_wires.bits[22].evaluate(&wire_values)); 94 | assert_eq!(false, bit_wires.bits[23].evaluate(&wire_values)); 95 | assert_eq!(false, bit_wires.bits[24].evaluate(&wire_values)); 96 | assert_eq!(false, bit_wires.bits[25].evaluate(&wire_values)); 97 | assert_eq!(false, bit_wires.bits[26].evaluate(&wire_values)); 98 | assert_eq!(false, bit_wires.bits[27].evaluate(&wire_values)); 99 | assert_eq!(false, bit_wires.bits[28].evaluate(&wire_values)); 100 | assert_eq!(false, bit_wires.bits[29].evaluate(&wire_values)); 101 | assert_eq!(false, bit_wires.bits[30].evaluate(&wire_values)); 102 | assert_eq!(false, bit_wires.bits[31].evaluate(&wire_values)); 103 | } 104 | } -------------------------------------------------------------------------------- /src/sponge.rs: -------------------------------------------------------------------------------- 1 | //! This module provides an implementation of the sponge construction.
2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | use core::iter; 7 | use std::marker::PhantomData; 8 | 9 | use itertools::{enumerate, Itertools}; 10 | 11 | use crate::{GadgetBuilder, MultiPermutation}; 12 | use crate::Expression; 13 | use crate::Field; 14 | use crate::util::concat; 15 | 16 | /// A sponge function. 17 | /// 18 | /// In a SNARK setting, efficiency demands that the two sections of sponge state memory (R and C) be 19 | /// stored in separate field elements, so that inputs can be efficiently added to R without 20 | /// affecting C. 21 | pub struct Sponge> { 22 | permutation: MP, 23 | bitrate: usize, 24 | capacity: usize, 25 | phantom: PhantomData<*const F>, 26 | } 27 | 28 | impl> Sponge { 29 | /// Create a new sponge function. 30 | /// 31 | /// # Parameters 32 | /// - `permutation` - the permutation with which to transform state memory 33 | /// - `bitrate` - the size of the input section, in field elements 34 | /// - `capacity` - the size of the capacity section, in field elements 35 | pub fn new(permutation: MP, bitrate: usize, capacity: usize) -> Self { 36 | assert_eq!(bitrate + capacity, permutation.width(), 37 | "Sponge state memory size must match permutation size"); 38 | Sponge { permutation, bitrate, capacity, phantom: PhantomData } 39 | } 40 | 41 | pub fn evaluate( 42 | &self, builder: &mut GadgetBuilder, inputs: &[Expression], output_len: usize, 43 | ) -> Vec> { 44 | let mut input_section = iter::repeat(Expression::zero()) 45 | .take(self.bitrate).collect_vec(); 46 | let mut capacity_section = iter::repeat(Expression::zero()) 47 | .take(self.capacity).collect_vec(); 48 | 49 | let chunks = inputs.chunks(self.bitrate); 50 | for chunk in chunks { 51 | // Add this chunk to the input section. 52 | for (i, element) in enumerate(chunk) { 53 | input_section[i] += element; 54 | } 55 | 56 | // Apply the permutation. 
57 | let old_state = concat(&[input_section, capacity_section]); 58 | let new_state = self.permutation.permute(builder, &old_state); 59 | assert_eq!(old_state.len(), new_state.len()); 60 | let (new_input, new_capacity) = new_state.split_at(self.bitrate); 61 | input_section = new_input.to_vec(); 62 | capacity_section = new_capacity.to_vec(); 63 | } 64 | 65 | let mut outputs = input_section.clone(); 66 | while outputs.len() < output_len { 67 | // Apply the permutation. 68 | let old_state = concat(&[input_section, capacity_section]); 69 | let new_state = self.permutation.permute(builder, &old_state); 70 | assert_eq!(old_state.len(), new_state.len()); 71 | let (new_input, new_capacity) = new_state.split_at(self.bitrate); 72 | input_section = new_input.to_vec(); 73 | capacity_section = new_capacity.to_vec(); 74 | 75 | outputs.extend(input_section.clone()) 76 | } 77 | 78 | // If output_len is not a multiple of the bitrate, then the code above would have added more 79 | // output elements than we actually want to return. 80 | outputs.truncate(output_len); 81 | 82 | outputs 83 | } 84 | } 85 | 86 | #[cfg(test)] 87 | mod tests { 88 | #[cfg(not(feature = "std"))] 89 | use alloc::vec::Vec; 90 | use crate::{Element, Expression, Field, GadgetBuilder, MultiPermutation, Sponge}; 91 | use crate::test_util::F7; 92 | 93 | #[test] 94 | fn sponge_1_1_1_f7() { 95 | // We will use a trivial permutation to keep the test simple. 96 | // It transforms (x, y) into (2y, 3x).
97 | struct TestPermutation; 98 | 99 | impl MultiPermutation for TestPermutation { 100 | fn width(&self) -> usize { 101 | 2 102 | } 103 | 104 | fn permute( 105 | &self, _builder: &mut GadgetBuilder, inputs: &[Expression], 106 | ) -> Vec> { 107 | assert_eq!(inputs.len(), 2); 108 | let x = &inputs[0]; 109 | let y = &inputs[1]; 110 | vec![y * Element::from(2u8), x * Element::from(3u8)] 111 | } 112 | 113 | fn inverse( 114 | &self, _builder: &mut GadgetBuilder, outputs: &[Expression], 115 | ) -> Vec> { 116 | assert_eq!(outputs.len(), 2); 117 | let x = &outputs[0]; 118 | let y = &outputs[1]; 119 | vec![y / Element::from(3u8), x / Element::from(2u8)] 120 | } 121 | } 122 | 123 | let mut builder = GadgetBuilder::::new(); 124 | let x_wire = builder.wire(); 125 | let y_wire = builder.wire(); 126 | let x = Expression::from(x_wire); 127 | let y = Expression::from(y_wire); 128 | let blocks = &[x, y]; 129 | let sponge = Sponge::new(TestPermutation, 1, 1); 130 | let hash = sponge.evaluate(&mut builder, blocks, 1); 131 | assert_eq!(hash.len(), 1); 132 | let hash = &hash[0]; 133 | let gadget = builder.build(); 134 | 135 | let mut values = values!(x_wire => 3u8.into(), y_wire => 4u8.into()); 136 | assert!(gadget.execute(&mut values)); 137 | // It transforms (x, y) into (2y, 3x). 138 | // Initial state: (0, 0) 139 | // After adding 3: (3, 0) 140 | // After permuting: (0, 2) 141 | // After adding 4: (4, 2) 142 | // After permuting: (4, 5) 143 | // Output: 4 144 | assert_eq!(Element::from(4u8), hash.evaluate(&values)); 145 | } 146 | } -------------------------------------------------------------------------------- /src/test_util.rs: -------------------------------------------------------------------------------- 1 | //! This module contains test helper functions. 
2 | 3 | use std::borrow::Borrow; 4 | 5 | use num::BigUint; 6 | 7 | use crate::expression::BooleanExpression; 8 | use crate::field::Field; 9 | use crate::wire_values::WireValues; 10 | 11 | pub fn assert_eq_true<F, T>(x: T, values: &WireValues<F>) 12 | where F: Field, T: Borrow<BooleanExpression<F>> { 13 | assert_eq!(true, x.borrow().evaluate(values)); 14 | } 15 | 16 | pub fn assert_eq_false<F, T>(x: T, values: &WireValues<F>) 17 | where F: Field, T: Borrow<BooleanExpression<F>> { 18 | assert_eq!(false, x.borrow().evaluate(values)); 19 | } 20 | 21 | #[derive(Debug)] 22 | pub struct F7 {} 23 | 24 | impl Field for F7 { 25 | fn order() -> BigUint { 26 | BigUint::from(7u8) 27 | } 28 | } 29 | 30 | #[derive(Debug)] 31 | pub struct F11 {} 32 | 33 | impl Field for F11 { 34 | fn order() -> BigUint { 35 | BigUint::from(11u8) 36 | } 37 | } 38 | 39 | #[derive(Debug)] 40 | pub struct F257 {} 41 | 42 | impl Field for F257 { 43 | fn order() -> BigUint { 44 | BigUint::from(257u16) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | #[cfg(not(feature = "std"))] 5 | use alloc::string::String; 6 | 7 | use core::borrow::Borrow; 8 | 9 | /// Like SliceConcatExt::concat, but works in stable with no_std. 10 | /// See https://github.com/rust-lang/rust/issues/27747 11 | pub fn concat<T: Clone, V: Borrow<[T]>>(vecs: &[V]) -> Vec<T> { 12 | let size = vecs.iter().map(|slice| slice.borrow().len()).sum(); 13 | let mut result = Vec::with_capacity(size); 14 | for v in vecs { 15 | result.extend_from_slice(v.borrow()) 16 | } 17 | result 18 | } 19 | 20 | /// Like SliceConcatExt::join for strings, but works in stable with no_std.
21 | /// See https://github.com/rust-lang/rust/issues/27747 22 | pub fn join<S: Borrow<str>>(sep: &str, strings: &[S]) -> String { 23 | let mut builder = String::new(); 24 | for s in strings { 25 | if !builder.is_empty() { 26 | builder += sep; 27 | } 28 | builder += s.borrow(); 29 | } 30 | builder 31 | } -------------------------------------------------------------------------------- /src/verify_permutation.rs: -------------------------------------------------------------------------------- 1 | //! This module extends GadgetBuilder with a method for verifying permutations. 2 | 3 | #[cfg(not(feature = "std"))] 4 | use alloc::vec::Vec; 5 | 6 | #[cfg(feature = "std")] 7 | use std::collections::BTreeMap; 8 | #[cfg(not(feature = "std"))] 9 | use alloc::collections::btree_map::BTreeMap; 10 | 11 | use crate::bimap_util::bimap_from_lists; 12 | use crate::expression::{BooleanExpression, Expression}; 13 | use crate::field::{Element, Field}; 14 | use crate::gadget_builder::GadgetBuilder; 15 | use crate::wire::{BooleanWire, Wire}; 16 | use crate::wire_values::WireValues; 17 | use crate::util::concat; 18 | 19 | impl<F: Field> GadgetBuilder<F> { 20 | /// Assert that two lists of expressions evaluate to permutations of one another. 21 | /// 22 | /// This is currently implemented using an AS-Waksman permutation network, although that could 23 | /// change in the future. See "On Arbitrary Waksman Networks and their Vulnerability". 24 | pub fn assert_permutation(&mut self, a: &[Expression<F>], b: &[Expression<F>]) { 25 | assert_eq!(a.len(), b.len(), "Permutation must have the same number of inputs and outputs"); 26 | 27 | match a.len() { 28 | // Two empty lists are permutations of one another, trivially. 29 | 0 => (), 30 | // Two singleton lists are permutations of one another as long as their items are equal. 31 | 1 => self.assert_equal(&a[0], &b[0]), 32 | // For the 2x2 case, we're implementing a switch gadget. The switch will be controlled 33 | // with a binary wire.
It will swap the order of its inputs iff that wire is true. 34 | 2 => self.assert_permutation_2x2(&a[0], &a[1], &b[0], &b[1]), 35 | // For larger lists, we recursively use two smaller permutation networks. 36 | _ => self.assert_permutation_recursive(a, b) 37 | } 38 | } 39 | 40 | /// Assert that [a, b] is a permutation of [c, d]. 41 | fn assert_permutation_2x2(&mut self, 42 | a: &Expression<F>, b: &Expression<F>, 43 | c: &Expression<F>, d: &Expression<F>) { 44 | let (switch, c_target, d_target) = self.create_switch(a, b); 45 | self.assert_equal(c, &c_target); 46 | self.assert_equal(d, &d_target); 47 | let a = a.clone(); 48 | let b = b.clone(); 49 | let c = c.clone(); 50 | let d = d.clone(); 51 | self.generator( 52 | concat(&[a.dependencies(), b.dependencies(), c.dependencies(), d.dependencies()]), 53 | move |values: &mut WireValues<F>| { 54 | let a_value = a.evaluate(values); 55 | let b_value = b.evaluate(values); 56 | let c_value = c.evaluate(values); 57 | let d_value = d.evaluate(values); 58 | if a_value == c_value && b_value == d_value { 59 | values.set_boolean(switch, false); 60 | } else if a_value == d_value && b_value == c_value { 61 | values.set_boolean(switch, true); 62 | } else { 63 | panic!("No permutation from [{}, {}] to [{}, {}]", 64 | a_value, b_value, c_value, d_value); 65 | } 66 | }); 67 | } 68 | 69 | /// Creates a 2x2 switch given the two input expressions. Returns three things: the (boolean) 70 | /// switch wire and the two output expressions. The order of the outputs will match that of the 71 | /// inputs if the switch wire is set to false, otherwise the order will be swapped.
72 | fn create_switch(&mut self, a: &Expression, b: &Expression) 73 | -> (BooleanWire, Expression, Expression) { 74 | let switch = self.boolean_wire(); 75 | let c = self.selection(&BooleanExpression::from(switch), b, a); 76 | let d = a + b - &c; 77 | (switch, c, d) 78 | } 79 | 80 | fn assert_permutation_recursive(&mut self, a: &[Expression], b: &[Expression]) { 81 | let n = a.len(); 82 | let even = n % 2 == 0; 83 | 84 | let mut child_1_a = Vec::new(); 85 | let mut child_1_b = Vec::new(); 86 | let mut child_2_a = Vec::new(); 87 | let mut child_2_b = Vec::new(); 88 | 89 | // See Figure 8 in the AS-Waksman paper. 90 | let a_num_switches = n / 2; 91 | let b_num_switches = if even { a_num_switches - 1 } else { a_num_switches }; 92 | 93 | let mut a_switches = Vec::new(); 94 | let mut b_switches = Vec::new(); 95 | for i in 0..a_num_switches { 96 | let (switch, out_1, out_2) = self.create_switch(&a[i * 2], &a[i * 2 + 1]); 97 | a_switches.push(switch); 98 | child_1_a.push(out_1); 99 | child_2_a.push(out_2); 100 | } 101 | for i in 0..b_num_switches { 102 | let (switch, out_1, out_2) = self.create_switch(&b[i * 2], &b[i * 2 + 1]); 103 | b_switches.push(switch); 104 | child_1_b.push(out_1); 105 | child_2_b.push(out_2); 106 | } 107 | 108 | // See Figure 8 in the AS-Waksman paper. 
109 | if even { 110 | child_1_b.push(b[n - 2].clone()); 111 | child_2_b.push(b[n - 1].clone()); 112 | } else { 113 | child_2_a.push(a[n - 1].clone()); 114 | child_2_b.push(b[n - 1].clone()); 115 | } 116 | 117 | self.assert_permutation(&child_1_a, &child_1_b); 118 | self.assert_permutation(&child_2_a, &child_2_b); 119 | 120 | let a_deps: Vec = a.iter().flat_map(Expression::dependencies).collect(); 121 | let b_deps: Vec = b.iter().flat_map(Expression::dependencies).collect(); 122 | 123 | let a = a.to_vec(); 124 | let b = b.to_vec(); 125 | self.generator( 126 | concat(&[a_deps, b_deps]), 127 | move |values: &mut WireValues| { 128 | let a_values: Vec> = a.iter().map(|exp| exp.evaluate(values)).collect(); 129 | let b_values: Vec> = b.iter().map(|exp| exp.evaluate(values)).collect(); 130 | route(a_values, b_values, &a_switches, &b_switches, values); 131 | }); 132 | } 133 | } 134 | 135 | /// Generates switch settings for a single layer of the recursive network. 136 | fn route(a_values: Vec>, b_values: Vec>, 137 | a_switches: &[BooleanWire], b_switches: &[BooleanWire], 138 | values: &mut WireValues) { 139 | assert_eq!(a_values.len(), b_values.len()); 140 | let n = a_values.len(); 141 | let even = n % 2 == 0; 142 | let ab_map = bimap_from_lists(a_values, b_values); 143 | let switches = [a_switches, b_switches]; 144 | 145 | let ab_map_by_side = |side: usize, index: usize| -> usize { 146 | *match side { 147 | 0 => ab_map.get_by_left(&index), 148 | 1 => ab_map.get_by_right(&index), 149 | _ => panic!("Expected side to be 0 or 1") 150 | }.unwrap() 151 | }; 152 | 153 | // We maintain two maps for wires which have been routed to a particular subnetwork on one side 154 | // of the network (left or right) but not the other. The keys are wire indices, and the values 155 | // are subnetwork indices. 
156 | let mut partial_routes = [BTreeMap::new(), BTreeMap::new()]; 157 | 158 | // After we route a wire on one side, we find the corresponding wire on the other side and check 159 | // if it still needs to be routed. If so, we add it to partial_routes. 160 | let enqueue_other_side = |partial_routes: &mut [BTreeMap], 161 | values: &mut WireValues, 162 | side: usize, this_i: usize, subnet: bool| { 163 | let other_side = 1 - side; 164 | let other_i = ab_map_by_side(side, this_i); 165 | let other_switch_i = other_i / 2; 166 | 167 | if other_switch_i >= switches[other_side].len() { 168 | // The other wire doesn't go through a switch, so there's no routing to be done. 169 | return; 170 | } 171 | 172 | if values.contains_boolean(switches[other_side][other_switch_i]) { 173 | // The other switch has already been routed. 174 | return; 175 | } 176 | 177 | let other_i_sibling = 4 * other_switch_i + 1 - other_i; 178 | if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) { 179 | // The other switch's sibling is already pending routing. 180 | assert_ne!(subnet, sibling_subnet); 181 | } else { 182 | let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet); 183 | if let Some(old_subnet) = opt_old_subnet { 184 | assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)"); 185 | } 186 | } 187 | }; 188 | 189 | // See Figure 8 in the AS-Waksman paper. 190 | if even { 191 | enqueue_other_side(&mut partial_routes, values, 1, n - 2, false); 192 | enqueue_other_side(&mut partial_routes, values, 1, n - 1, true); 193 | } else { 194 | enqueue_other_side(&mut partial_routes, values, 0, n - 1, true); 195 | enqueue_other_side(&mut partial_routes, values, 1, n - 1, true); 196 | } 197 | 198 | let route_switch = |partial_routes: &mut [BTreeMap], values: &mut WireValues, 199 | side: usize, switch_index: usize, swap: bool| { 200 | // First, we actually set the switch configuration. 
201 | values.set_boolean(switches[side][switch_index], swap); 202 | 203 | // Then, we enqueue the two corresponding wires on the other side of the network, to ensure 204 | // that they get routed in the next step. 205 | let this_i_1 = switch_index * 2; 206 | let this_i_2 = this_i_1 + 1; 207 | enqueue_other_side(partial_routes, values, side, this_i_1, swap); 208 | enqueue_other_side(partial_routes, values, side, this_i_2, !swap); 209 | }; 210 | 211 | // If partial_routes[side] is empty, then we can route any switch next. For efficiency, we will 212 | // simply do top-down scans (one on the left side, one on the right side) for switches which 213 | // have not yet been routed. These variables represent the positions of those two scans. 214 | let mut scan_index = [0, 0]; 215 | 216 | // Until both scans complete, we alternate back and forth between the left and right switch 217 | // layers. We process any partially routed wires for that side, or if there aren't any, we route 218 | // the next switch in our scan. 219 | while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() { 220 | for side in 0..=1 { 221 | if !partial_routes[side].is_empty() { 222 | for (this_i, subnet) in partial_routes[side].clone().into_iter() { 223 | let this_first_switch_input = this_i % 2 == 0; 224 | let swap = this_first_switch_input == subnet; 225 | let this_switch_i = this_i / 2; 226 | route_switch(&mut partial_routes, values, side, this_switch_i, swap); 227 | } 228 | partial_routes[side].clear(); 229 | } else { 230 | // We can route any switch next. Continue our scan for pending switches. 231 | while scan_index[side] < switches[side].len() 232 | && values.contains_boolean(switches[side][scan_index[side]]) { 233 | scan_index[side] += 1; 234 | } 235 | if scan_index[side] < switches[side].len() { 236 | // Either switch configuration would work; we arbitrarily choose to not swap.
237 | route_switch(&mut partial_routes, values, side, scan_index[side], false); 238 | scan_index[side] += 1; 239 | } 240 | } 241 | } 242 | } 243 | } 244 | 245 | #[cfg(test)] 246 | mod tests { 247 | use itertools::Itertools; 248 | 249 | use crate::expression::Expression; 250 | use crate::gadget_builder::GadgetBuilder; 251 | use crate::test_util::F257; 252 | use crate::wire_values::WireValues; 253 | 254 | #[test] 255 | fn route_2x2() { 256 | let mut builder = GadgetBuilder::::new(); 257 | builder.assert_permutation( 258 | &[1u8.into(), 2u8.into()], 259 | &[2u8.into(), 1u8.into()]); 260 | let gadget = builder.build(); 261 | assert!(gadget.execute(&mut WireValues::new())); 262 | } 263 | 264 | #[test] 265 | fn route_3x3() { 266 | let mut builder = GadgetBuilder::::new(); 267 | builder.assert_permutation( 268 | &[1u8.into(), 2u8.into(), 3u8.into()], 269 | &[2u8.into(), 1u8.into(), 3u8.into()]); 270 | let gadget = builder.build(); 271 | assert!(gadget.execute(&mut WireValues::new())); 272 | } 273 | 274 | #[test] 275 | fn route_4x4() { 276 | let mut builder = GadgetBuilder::::new(); 277 | builder.assert_permutation( 278 | &[1u8.into(), 2u8.into(), 3u8.into(), 4u8.into()], 279 | &[2u8.into(), 1u8.into(), 4u8.into(), 3u8.into()]); 280 | let gadget = builder.build(); 281 | assert!(gadget.execute(&mut WireValues::new())); 282 | } 283 | 284 | #[test] 285 | fn route_5x5() { 286 | let mut builder = GadgetBuilder::::new(); 287 | let a = builder.wires(5); 288 | let b = builder.wires(5); 289 | let a_exp = a.iter().map(Expression::from).collect_vec(); 290 | let b_exp = b.iter().map(Expression::from).collect_vec(); 291 | builder.assert_permutation(&a_exp, &b_exp); 292 | let gadget = builder.build(); 293 | 294 | let mut values_normal = values!( 295 | a[0] => 0u8.into(), a[1] => 1u8.into(), a[2] => 2u8.into(), a[3] => 3u8.into(), a[4] => 4u8.into(), 296 | b[0] => 1u8.into(), b[1] => 4u8.into(), b[2] => 0u8.into(), b[3] => 3u8.into(), b[4] => 2u8.into()); 297 | 
assert!(gadget.execute(&mut values_normal)); 298 | 299 | let mut values_with_duplicates = values!( 300 | a[0] => 0u8.into(), a[1] => 1u8.into(), a[2] => 2u8.into(), a[3] => 0u8.into(), a[4] => 1u8.into(), 301 | b[0] => 1u8.into(), b[1] => 1u8.into(), b[2] => 0u8.into(), b[3] => 0u8.into(), b[4] => 2u8.into()); 302 | assert!(gadget.execute(&mut values_with_duplicates)); 303 | } 304 | 305 | #[test] 306 | #[should_panic] 307 | fn not_a_permutation() { 308 | let mut builder = GadgetBuilder::::new(); 309 | builder.assert_permutation( 310 | &[1u8.into(), 2u8.into(), 2u8.into()], 311 | &[1u8.into(), 2u8.into(), 1u8.into()]); 312 | let gadget = builder.build(); 313 | // The generator should fail, since there's no possible routing. 314 | gadget.execute(&mut WireValues::new()); 315 | } 316 | 317 | #[test] 318 | #[should_panic] 319 | fn lengths_differ() { 320 | let mut builder = GadgetBuilder::::new(); 321 | builder.assert_permutation( 322 | &[1u8.into(), 2u8.into(), 3u8.into()], 323 | &[1u8.into(), 2u8.into()]); 324 | } 325 | } -------------------------------------------------------------------------------- /src/wire.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "std"))] 2 | use alloc::vec::Vec; 3 | 4 | use std::cmp::Ordering; 5 | use std::fmt; 6 | use std::fmt::Formatter; 7 | 8 | /// A wire represents a witness element. 9 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] 10 | pub struct Wire { 11 | pub index: u32, 12 | } 13 | 14 | impl Wire { 15 | /// A special wire whose value is always set to 1. This is used to create `Expression`s with 16 | /// constant terms. 17 | pub const ONE: Wire = Wire { index: 0 }; 18 | } 19 | 20 | impl Ord for Wire { 21 | fn cmp(&self, other: &Self) -> Ordering { 22 | // For presentation, we want the 1 wire to be last. Otherwise use ascending index order. 
        if *self == Wire::ONE && *other == Wire::ONE {
            Ordering::Equal
        } else if *self == Wire::ONE {
            Ordering::Greater
        } else if *other == Wire::ONE {
            Ordering::Less
        } else {
            self.index.cmp(&other.index)
        }
    }
}

impl PartialOrd for Wire {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl fmt::Display for Wire {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        if self.index == 0 {
            write!(f, "1")
        } else {
            write!(f, "wire_{}", self.index)
        }
    }
}

/// A `Wire` whose value is constrained to be binary.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct BooleanWire {
    pub wire: Wire,
}

/// A `Wire` which has been constrained in such a way that it can only take on a value of 0 or 1.
impl BooleanWire {
    /// Construct a BooleanWire from an arbitrary wire. This is only safe if you separately
    /// constrain the wire to equal 0 or 1.
    ///
    /// Users should not normally call this method directly; use a method like
    /// `GadgetBuilder::boolean_wire()` instead.
    pub fn new_unsafe(wire: Wire) -> Self {
        BooleanWire { wire }
    }

    pub fn wire(self) -> Wire {
        self.wire
    }
}

/// A "binary wire" which consists of several bits, each one being a boolean wire.
#[derive(Clone, Debug)]
pub struct BinaryWire {
    /// The list of bits, ordered from least significant to most significant.
    pub bits: Vec<BooleanWire>,
}

#[allow(clippy::len_without_is_empty)]
impl BinaryWire {
    /// The number of bits.
    pub fn len(&self) -> usize {
        self.bits.len()
    }
}
--------------------------------------------------------------------------------
/src/wire_values.rs:
--------------------------------------------------------------------------------
#[cfg(feature = "std")]
use std::collections::BTreeMap;
#[cfg(not(feature = "std"))]
use alloc::collections::btree_map::BTreeMap;

use num::BigUint;
use num_traits::One;

use crate::expression::BooleanExpression;
use crate::field::{Element, Field};
use crate::wire::{BinaryWire, BooleanWire, Wire};

/// An assignment of wire values, where each value is an element of the field `F`.
#[derive(Default, Debug)]
pub struct WireValues<F: Field> {
    values: BTreeMap<Wire, Element<F>>,
}

impl<F: Field> WireValues<F> {
    /// Creates an assignment in which only `Wire::ONE` has a value, namely 1.
    pub fn new() -> Self {
        let mut values = BTreeMap::new();
        values.insert(Wire::ONE, Element::one());
        WireValues { values }
    }

    pub fn as_map(&self) -> &BTreeMap<Wire, Element<F>> {
        &self.values
    }

    /// Returns the value of `wire`, panicking if it has not been set.
    pub fn get(&self, wire: Wire) -> &Element<F> {
        assert!(self.values.contains_key(&wire), "No value for {}", wire);
        &self.values[&wire]
    }

    pub fn get_boolean(&self, wire: BooleanWire) -> bool {
        BooleanExpression::from(wire).evaluate(self)
    }

    /// Assigns `value` to `wire`. Panics if `wire` already has a value.
    pub fn set(&mut self, wire: Wire, value: Element<F>) {
        let old_value = self.values.insert(wire, value);
        assert!(old_value.is_none());
    }

    pub fn set_boolean(&mut self, wire: BooleanWire, value: bool) {
        self.set(wire.wire(), Element::from(value));
    }

    /// Assigns the bits of `value` to the bits of `wire`, least significant bit first. Panics if
    /// `value` does not fit in `wire.len()` bits.
    pub fn set_binary_unsigned(&mut self, wire: &BinaryWire, value: &BigUint) {
        let l = wire.len();
        assert!(value.bits() <= l as u64, "Value does not fit");

        for i in 0..l {
            let bit = ((value >> i) & BigUint::one()).is_one();
            self.set_boolean(wire.bits[i], bit);
        }
    }

    pub fn contains(&self, wire: Wire) -> bool {
        self.values.contains_key(&wire)
    }

    pub fn contains_boolean(&self, wire: BooleanWire) -> bool {
        self.contains(wire.wire)
    }

    pub fn contains_all(&self, wires: &[Wire]) -> bool {
        wires.iter().all(|&wire| self.contains(wire))
    }
}

impl<F: Field> Clone for WireValues<F> {
    fn clone(&self) -> Self {
        WireValues { values: self.values.clone() }
    }
}

/// Something whose value, of type `R`, can be computed from an assignment of wire values.
pub trait Evaluable<F: Field, R> {
    fn evaluate(&self, wire_values: &WireValues<F>) -> R;
}

/// Creates an instance of `WireValues` from the given wires and field element values.
#[macro_export]
macro_rules! values {
    ( $( $wire:expr => $value:expr ),* ) => {
        {
            let mut values = $crate::WireValues::new();
            $(
                values.set($wire, $value);
            )*
            values
        }
    }
}

/// Creates an instance of `WireValues` from the given boolean wires and boolean values.
#[macro_export]
macro_rules! boolean_values {
    ( $( $wire:expr => $value:expr ),* ) => {
        {
            let mut values = $crate::WireValues::new();
            $(
                values.set_boolean($wire, $value);
            )*
            values
        }
    }
}

/// Creates an instance of `WireValues` from the given binary wires and `BigUint` values.
#[macro_export]
macro_rules! binary_unsigned_values {
    ( $( $wire:expr => $value:expr ),* ) => {
        {
            let mut values = $crate::WireValues::new();
            $(
                values.set_binary_unsigned($wire, $value);
            )*
            values
        }
    }
}
--------------------------------------------------------------------------------
/src/witness_generator.rs:
--------------------------------------------------------------------------------
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;

use crate::field::Field;
use crate::wire::Wire;
use crate::wire_values::WireValues;

/// Generates some elements of the witness.
pub struct WitnessGenerator<F: Field> {
    inputs: Vec<Wire>,
    generator: Box<dyn Fn(&mut WireValues<F>)>,
}

impl<F: Field> WitnessGenerator<F> {
    /// Creates a new `WitnessGenerator`.
    ///
    /// # Arguments
    /// * `inputs` - the wires whose values must be set before this generator can run
    /// * `generate` - a function which generates some elements of the witness
    pub fn new<T>(inputs: Vec<Wire>, generate: T) -> Self
        where T: Fn(&mut WireValues<F>) + 'static {
        WitnessGenerator {
            inputs,
            generator: Box::new(generate),
        }
    }

    /// The wires whose values must be set before this generator can run.
    pub fn inputs(&self) -> &[Wire] {
        &self.inputs
    }

    /// Run the generator.
    pub fn generate(&self, values: &mut WireValues<F>) {
        (*self.generator)(values)
    }
}
--------------------------------------------------------------------------------
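`WitnessGenerator` only records which wires must be set before it can run; the code that drives generators (presumably `Gadget::execute`, which is not part of this excerpt) is not shown here. Below is a minimal, self-contained sketch of one plausible strategy, not necessarily the crate's: a fixed-point loop that repeatedly runs every pending generator whose inputs are all set, until a full pass runs nothing new. `Wire`, `Values`, `Generator` and `run_generators` are hypothetical simplified stand-ins (values are `u64`s rather than field elements), not the crate's API.

```rust
use std::collections::BTreeMap;

// Hypothetical, simplified stand-ins for `Wire`, `WireValues<F>` and
// `WitnessGenerator<F>`; values are plain `u64`s instead of field elements.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Wire(u32);

#[derive(Default, Debug)]
struct Values(BTreeMap<Wire, u64>);

impl Values {
    fn contains_all(&self, wires: &[Wire]) -> bool {
        wires.iter().all(|w| self.0.contains_key(w))
    }

    fn get(&self, wire: Wire) -> u64 {
        self.0[&wire]
    }

    fn set(&mut self, wire: Wire, value: u64) {
        // Mirrors `WireValues::set`: a wire may be assigned only once.
        let old = self.0.insert(wire, value);
        assert!(old.is_none(), "{:?} already set", wire);
    }
}

struct Generator {
    inputs: Vec<Wire>,
    generate: Box<dyn Fn(&mut Values)>,
}

/// Runs every not-yet-run generator whose inputs are all set, repeating until a
/// full pass runs nothing new. Returns the number of generators that never ran.
fn run_generators(generators: &[Generator], values: &mut Values) -> usize {
    let mut done = vec![false; generators.len()];
    loop {
        let mut progressed = false;
        for (i, g) in generators.iter().enumerate() {
            if !done[i] && values.contains_all(&g.inputs) {
                (g.generate)(values);
                done[i] = true;
                progressed = true;
            }
        }
        if !progressed {
            return done.iter().filter(|&&d| !d).count();
        }
    }
}

fn main() {
    let (x, y, z) = (Wire(1), Wire(2), Wire(3));
    // Deliberately listed out of dependency order: z needs y, and y needs x.
    let generators = vec![
        Generator {
            inputs: vec![y],
            generate: Box::new(move |v: &mut Values| {
                let t = v.get(y);
                v.set(z, t + 1);
            }),
        },
        Generator {
            inputs: vec![x],
            generate: Box::new(move |v: &mut Values| {
                let t = v.get(x);
                v.set(y, t * t);
            }),
        },
    ];

    let mut values = Values::default();
    values.set(x, 3);
    assert_eq!(run_generators(&generators, &mut values), 0);
    assert_eq!(values.get(z), 10); // z = 3 * 3 + 1
}
```

Because every pass rescans all pending generators, this sketch is quadratic in the worst case (for example, a chain of generators listed in reverse dependency order); an index from each wire to the generators waiting on it would avoid the rescans. A non-zero return value means some generator's inputs were never set.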