├── .github
└── workflows
│ ├── github_pages.yml
│ ├── publish_to_pypi.yml
│ └── python-app.yml
├── .gitignore
├── CITATION.cff
├── LICENSE.txt
├── README.md
├── docs
├── api
│ ├── abstract_nodes_api.md
│ ├── addition_chains_api.md
│ ├── circuits_api.md
│ ├── code_generation_api.md
│ ├── nodes_api.md
│ └── pareto_fronts_api.md
├── config.md
├── example_circuits.md
├── getting_started.md
├── images
│ └── oraqle_logo_cropped.svg
├── index.md
└── tutorial_running_exps.md
├── main.cpp
├── mkdocs.yml
├── oraqle
├── __init__.py
├── add_chains
│ ├── __init__.py
│ ├── addition_chains.py
│ ├── addition_chains_front.py
│ ├── addition_chains_heuristic.py
│ ├── addition_chains_mod.py
│ ├── memoization.py
│ └── solving.py
├── circuits
│ ├── __init__.py
│ ├── aes.py
│ ├── cardio.py
│ ├── median.py
│ ├── mimc.py
│ ├── sorting.py
│ └── veto_voting.py
├── compiler
│ ├── __init__.py
│ ├── arithmetic
│ │ ├── __init__.py
│ │ ├── exponentiation.py
│ │ └── subtraction.py
│ ├── boolean
│ │ ├── __init__.py
│ │ ├── bool_and.py
│ │ ├── bool_neg.py
│ │ └── bool_or.py
│ ├── circuit.py
│ ├── comparison
│ │ ├── __init__.py
│ │ ├── comparison.py
│ │ ├── equality.py
│ │ └── in_upper_half.py
│ ├── control_flow
│ │ ├── __init__.py
│ │ └── conditional.py
│ ├── func2poly.py
│ ├── graphviz.py
│ ├── instructions.py
│ ├── nodes
│ │ ├── __init__.py
│ │ ├── abstract.py
│ │ ├── arbitrary_arithmetic.py
│ │ ├── binary_arithmetic.py
│ │ ├── fixed.py
│ │ ├── flexible.py
│ │ ├── leafs.py
│ │ ├── non_commutative.py
│ │ ├── unary_arithmetic.py
│ │ └── univariate.py
│ ├── poly2circuit.py
│ └── polynomials
│ │ ├── __init__.py
│ │ └── univariate.py
├── config.py
├── demo
│ ├── depth_aware_equality.ipynb
│ ├── playground.ipynb
│ ├── small_comparison_bgv.ipynb
│ └── veto_voting.ipynb
├── examples
│ ├── depth_aware_comparison.py
│ ├── depth_aware_equality.py
│ ├── long_and.py
│ ├── small_comparison.py
│ ├── small_polynomial.py
│ ├── visualize_circuits.py
│ └── wahc2024_presentation
│ │ ├── 1_high-level.py
│ │ ├── 2_arith_step1.py
│ │ ├── 3_arith_step2.py
│ │ ├── 5_code_gen.py
│ │ └── rebalancing.py
├── experiments
│ ├── depth_aware_arithmetization
│ │ └── execution
│ │ │ ├── bench_cardio_circuits.py
│ │ │ ├── bench_equality.py
│ │ │ ├── cardio_circuits.py
│ │ │ ├── comparisons.py
│ │ │ ├── equality_first_prime_mods_exec.py
│ │ │ ├── poly_evaluation_pareto_front.py
│ │ │ ├── run_all.sh
│ │ │ └── veto_voting_per_mod.py
│ └── oraqle_spotlight
│ │ ├── examples
│ │ ├── and_16.py
│ │ ├── common_expressions.py
│ │ ├── equality_31.py
│ │ ├── equality_and_comparison.py
│ │ └── t2_comparison.py
│ │ └── experiments
│ │ ├── comparisons
│ │ └── comparisons_bench.py
│ │ ├── large_equality
│ │ ├── .gitignore
│ │ ├── CMakeLists.txt
│ │ └── large_equality.py
│ │ └── veto_voting_minimal_cost.py
└── helib_template
│ ├── .gitignore
│ ├── CMakeLists.txt
│ ├── __init__.py
│ └── main.cpp
├── pyproject.toml
├── requirements.txt
├── requirements_dev.txt
├── ruff.toml
├── setup.cfg
└── tests
├── test_circuit_sizes_costs.py
├── test_poly2circuit.py
└── test_sugar_expressions.py
/.github/workflows/github_pages.yml:
--------------------------------------------------------------------------------
---
# Build the MkDocs site and publish it to the gh-pages branch whenever main is updated.
name: Deploy docs

on:
  push:
    branches:
      - main

# gh-deploy pushes a commit, so the token needs write access to repository contents.
permissions:
  contents: write

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4

      # mkdocs gh-deploy creates a commit, so git needs an author identity.
      - name: Configure Git Credentials
        run: |
          git config user.name github-actions[bot]
          git config user.email 41898282+github-actions[bot]@users.noreply.github.com

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.*'

      - name: Install documentation dependencies
        run: pip install -r requirements_dev.txt

      - name: Build and deploy to GitHub Pages
        run: mkdocs gh-deploy --force
22 |
--------------------------------------------------------------------------------
/.github/workflows/publish_to_pypi.yml:
--------------------------------------------------------------------------------
name: Publish Python 🐍 distribution 📦 to PyPI

# Build on every push; the publish job below only runs for tag pushes.
on: push

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        # v3 of the artifact actions is deprecated by GitHub and no longer runs; v4 is required.
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/')  # only publish to PyPI on tag pushes
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/oraqle
    permissions:
      id-token: write  # needed for PyPI trusted publishing (OIDC)

    steps:
      - name: Download all the dists
        # Must match the major version of upload-artifact used in the build job.
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
51 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python application

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

permissions:
  contents: read

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]

    steps:
      # checkout@v3 and setup-python@v3 run on the deprecated Node 16 runtime; use the current majors.
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff pytest
          pip install -r requirements.txt
          pip install -e .
      - name: Lint with ruff
        run: |
          ruff check
      - name: Test with pytest
        run: |
          pytest
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /venv
2 | *.dot
3 | .idea/**
4 | __pycache__/**
5 | /build
6 | .DS_Store
7 | *.egg-info*
8 | instructions.txt
9 | .ipynb_checkpoints/
10 | *.pdf
11 | *.pkl
12 | .sphinx_build/
13 | /dist
14 | *.svg
15 | *.pyc
16 | oraqle/addchain_cache.db
17 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you want to cite the compiler as a whole, please cite this work. See the citations page in the documentation for specific works e.g. about depth-aware arithmetization."
3 | authors:
4 | - family-names: "Vos"
5 | given-names: "Jelle"
6 | orcid: "https://orcid.org/0000-0002-3979-9740"
7 | - family-names: "Conti"
8 | given-names: "Mauro"
9 | orcid: "https://orcid.org/0000-0002-3612-1934"
10 | - family-names: "Erkin"
11 | given-names: "Zekeriya"
12 | orcid: "https://orcid.org/0000-0001-8932-4703"
13 | title: "Oraqle: A Depth-Aware Secure Computation Compiler"
14 | version: 0.1.0
15 | doi: 10.1145/3689945.3694808
16 | date-released: 2024-09-11
17 | url: "https://github.com/jellevos/oraqle"
18 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jelle Vos
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | A secure computation compiler
4 |
5 |
6 | The oraqle compiler lets you generate arithmetic circuits from high-level Python code. It also lets you generate code using HElib.
7 |
8 | This repository uses a fork of fhegen as a dependency and adapts some of the code from [fhegen](https://github.com/Crypto-TII/fhegen), which was written by Johannes Mono, Chiara Marcolla, Georg Land, Tim Güneysu, and Najwa Aaraj. You can read their theoretical work at: https://eprint.iacr.org/2022/706.
9 |
10 | See [our documentation](https://jelle-vos.nl/oraqle) for more details.
11 |
12 | ## Setting up
13 | The best way to get things up and running is using a virtual environment:
14 | - Set up a virtualenv using `python3 -m venv venv` in the directory.
15 | - Enter the virtual environment using `source venv/bin/activate`.
16 | - Install the requirements using `pip install -r requirements.txt`.
17 | - *To overcome import problems*, run `pip install -e .`, which will create links to your files (so you do not need to re-install after every change).
18 |
19 | We are currently setting up documentation to be rendered using GitHub Actions.
20 |
--------------------------------------------------------------------------------
/docs/api/abstract_nodes_api.md:
--------------------------------------------------------------------------------
1 | # Abstract nodes API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 | If you want to extend the oraqle compiler, or implement your own high-level nodes, it is easiest to extend one of the existing abstract node classes.
6 |
--------------------------------------------------------------------------------
/docs/api/addition_chains_api.md:
--------------------------------------------------------------------------------
1 | # Addition chains API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 | The `add_chains` module contains tools for generating addition chains.
6 |
7 | ::: oraqle.add_chains
8 | options:
9 | heading_level: 2
10 | show_submodules: true
11 | show_if_no_docstring: false
12 |
--------------------------------------------------------------------------------
/docs/api/circuits_api.md:
--------------------------------------------------------------------------------
1 | # Circuits API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 |
6 | ## High-level circuits
7 | ::: oraqle.compiler.circuit.Circuit
8 | options:
9 | heading_level: 3
10 |
11 |
12 | ## Arithmetic circuits
13 | ::: oraqle.compiler.circuit.ArithmeticCircuit
14 | options:
15 | heading_level: 3
16 |
--------------------------------------------------------------------------------
/docs/api/code_generation_api.md:
--------------------------------------------------------------------------------
1 | # Code generation API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 | The easiest way is using:
6 | ```python3
7 | arithmetic_circuit.generate_code()
8 | ```
9 |
10 | ## Arithmetic instructions
11 | If you want to extend the oraqle compiler, or implement your own code generation, you can use the following instructions to do so.
12 |
13 | ??? info "Abstract instruction"
14 | ::: oraqle.compiler.instructions.ArithmeticInstruction
15 | options:
16 | heading_level: 3
17 |
18 | ??? info "InputInstruction"
19 | ::: oraqle.compiler.instructions.InputInstruction
20 | options:
21 | heading_level: 3
22 |
23 | ??? info "AdditionInstruction"
24 | ::: oraqle.compiler.instructions.AdditionInstruction
25 | options:
26 | heading_level: 3
27 |
28 | ??? info "MultiplicationInstruction"
29 | ::: oraqle.compiler.instructions.MultiplicationInstruction
30 | options:
31 | heading_level: 3
32 |
33 | ??? info "ConstantAdditionInstruction"
34 | ::: oraqle.compiler.instructions.ConstantAdditionInstruction
35 | options:
36 | heading_level: 3
37 |
38 | ??? info "ConstantMultiplicationInstruction"
39 | ::: oraqle.compiler.instructions.ConstantMultiplicationInstruction
40 | options:
41 | heading_level: 3
42 |
43 | ??? info "OutputInstruction"
44 | ::: oraqle.compiler.instructions.OutputInstruction
45 | options:
46 | heading_level: 3
47 |
48 |
49 | ## Generating arithmetic programs
50 | ::: oraqle.compiler.instructions.ArithmeticProgram
51 | options:
52 | heading_level: 3
53 |
54 |
55 | ## Generating code for HElib
56 | ...
57 |
--------------------------------------------------------------------------------
/docs/api/nodes_api.md:
--------------------------------------------------------------------------------
1 | # Nodes API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 | ## Boolean operations
6 |
7 | ??? info "AND operation"
8 | ::: oraqle.compiler.boolean.bool_and.And
9 | options:
10 | heading_level: 3
11 |
12 | ??? info "OR operation"
13 | ::: oraqle.compiler.boolean.bool_or.Or
14 | options:
15 | heading_level: 3
16 |
17 | ??? info "NEG operation"
18 | ::: oraqle.compiler.boolean.bool_neg.Neg
19 | options:
20 | heading_level: 3
21 |
22 |
23 | ## Arithmetic operations
24 | These operations are fundamental arithmetic operations, so they will stay the same when they are arithmetized.
25 |
26 |
27 | ## High-level arithmetic operations
28 |
29 | ??? info "Subtraction"
30 | ::: oraqle.compiler.arithmetic.subtraction.Subtraction
31 | options:
32 | heading_level: 3
33 |
34 | ??? info "Exponentiation"
35 | ::: oraqle.compiler.arithmetic.exponentiation.Power
36 | options:
37 | heading_level: 3
38 |
39 |
40 | ## Polynomial evaluation
41 |
42 | ??? info "Univariate polynomial evaluation"
43 | ::: oraqle.compiler.polynomials.univariate.UnivariatePoly
44 | options:
45 | heading_level: 3
46 |
47 |
48 | ## Control flow
49 |
50 | ??? info "If-else statement"
51 | ::: oraqle.compiler.control_flow.conditional.IfElse
52 | options:
53 | heading_level: 3
54 |
--------------------------------------------------------------------------------
/docs/api/pareto_fronts_api.md:
--------------------------------------------------------------------------------
1 | # Pareto fronts API
2 | !!! warning
3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 |
5 | If you are using depth-aware arithmetization, you will find that the compiler does not output one arithmetic circuit.
6 | Instead, it outputs a Pareto front, which represents the best circuits that it could generate when trading off two metrics:
7 | the *multiplicative depth* and the *multiplicative size/cost*.
8 | This page briefly explains the API for interfacing with these Pareto fronts.
9 |
10 | ## The abstract base class
11 |
12 | ??? info "Abstract ParetoFront"
13 | ::: oraqle.compiler.nodes.abstract.ParetoFront
14 | options:
15 | heading_level: 3
16 |
17 | ## Depth-size and depth-cost fronts
18 |
19 |
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 | # Configuration parameters
2 |
3 | ::: oraqle.config
4 | options:
5 | heading_level: 2
6 | show_submodules: true
7 | show_if_no_docstring: false
8 |
--------------------------------------------------------------------------------
/docs/example_circuits.md:
--------------------------------------------------------------------------------
1 | !!! warning
2 | Some of these example circuits are untested and may be incorrect.
3 |
4 | ::: oraqle.circuits
5 | options:
6 | heading_level: 3
7 | show_submodules: true
8 |
--------------------------------------------------------------------------------
/docs/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting started
2 | In 5 minutes, this page will guide you through how to install oraqle, how to specify high-level programs, and how to arithmetize your first circuit!
3 |
4 | ## Installation
5 | Simply install the most recent version of the Oraqle compiler using:
6 | ```
7 | pip install oraqle
8 | ```
9 |
10 | We use continuous integration to test every build of the Oraqle compiler on Windows, macOS, and Linux.
11 | If you do run into problems, feel free to [open an issue on GitHub](https://github.com/jellevos/oraqle/issues)!
12 |
13 | ## Specifying high-level programs
14 | Let's start by importing `galois`, which we use to represent our plaintext algebra.
15 | We will also immediately import the relevant oraqle classes for our little example:
16 | ```python3
17 | from galois import GF
18 |
19 | from oraqle.compiler.circuit import Circuit
20 | from oraqle.compiler.nodes.leafs import Input
21 | ```
22 |
23 | For this example, we will use 31 as our plaintext modulus. This algebra is denoted by `GF(31)`.
24 | Let's create a few inputs that represent elements in this algebra:
25 | ```python3
26 | gf = GF(31)
27 |
28 | x = Input("x", gf)
29 | y = Input("y", gf)
30 | z = Input("z", gf)
31 | ```
32 |
33 | We can now perform some operations on these elements, and they do not have to be arithmetic operations!
34 | For example, we can perform equality checks or comparisons:
35 | ```python3
36 | comparison = x < y
37 | equality = y == z
38 | both = comparison & equality
39 | ```
40 |
41 | While we have specified some operations, we have not yet established this as a circuit. We will do so now:
42 | ```python3
43 | circuit = Circuit([both])
44 | ```
45 |
46 | And that's it! We are done specifying our first high-level circuit.
47 | As you can see this is all very similar to writing a regular Python program.
48 | If you want to visualize this high-level circuit before we continue with arithmetizing it, you can run the following (if you have graphviz installed):
49 | ```python3
50 | circuit.to_pdf("high_level_circuit.pdf")
51 | ```
52 |
53 | !!! tip
54 | If you do not have graphviz installed, you can instead call:
55 | ```python3
56 | circuit.to_dot("high_level_circuit.dot")
57 | ```
58 | After that, you can copy the file contents to [an online graphviz viewer](https://dreampuf.github.io/GraphvizOnline)!
59 |
60 | ## Arithmetizing your first circuit
61 | At this point, arithmetization is a breeze, because the oraqle compiler takes care of these steps.
62 | We can create an arithmetic circuit and visualize it using the following snippet:
63 | ```python3
64 | arithmetic_circuit = circuit.arithmetize()
65 | arithmetic_circuit.to_pdf("arithmetic_circuit.pdf")
66 | ```
67 |
68 | You will notice that it's quite a large circuit. But how large is it exactly?
69 | This is a question that we can ask to the oraqle compiler:
70 | ```python3
71 | print("Depth:", arithmetic_circuit.multiplicative_depth())
72 | print("Size:", arithmetic_circuit.multiplicative_size())
73 | print("Cost:", arithmetic_circuit.multiplicative_cost(0.7))
74 | ```
75 |
76 | In the last line, we asked the compiler to output the multiplicative cost, considering that squaring operations are cheaper than regular multiplications.
77 | Here, we weighted the cost of each squaring operation with a factor of 0.7.
78 |
79 | Now that we have an arithmetic circuit, we can use homomorphic encryption to evaluate it!
80 | If you are curious about executing these circuits for real, consider reading [the code generation tutorial](tutorial_running_exps.md).
81 |
82 | !!! warning
83 | Many homomorphic encryption libraries only support plaintext moduli that are NTT-friendly. The plaintext modulus we chose (31) is not NTT-friendly.
84 | In fact, only very few primes are NTT-friendly, and they are somewhat large. This is why, right now, the oraqle compiler only implements code generation for HElib.
85 | HElib is (as far as we are aware) the only library that supports plaintext moduli that are not NTT-friendly.
86 |
--------------------------------------------------------------------------------
/docs/images/oraqle_logo_cropped.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
19 |
39 |
41 |
46 |
49 |
52 |
56 |
60 |
64 |
68 |
72 |
76 |
81 |
86 |
89 |
93 |
97 |
98 |
101 |
105 |
109 |
110 |
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome to oraqle
2 |
3 |
4 |
A secure computation compiler
5 |
6 |
7 | Simply install the most recent version of the Oraqle compiler using:
8 | ```
9 | pip install oraqle
10 | ```
11 |
12 | Consider checking out our [getting started page](getting_started.md) to help you get up to speed with arithmetizing circuits!
13 |
14 | ## API reference
15 | !!! warning
16 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
17 |
18 | For an API reference, you can check out the pages for [circuits](api/circuits_api.md) and for [nodes](api/nodes_api.md).
19 |
--------------------------------------------------------------------------------
/docs/tutorial_running_exps.md:
--------------------------------------------------------------------------------
1 | # Tutorial: Running experiments
2 | !!! failure
3 | This section is currently missing. Please see the [code generation API](api/code_generation_api.md) for some documentation for now.
4 |
--------------------------------------------------------------------------------
/main.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include
3 | #include
4 | #include
5 | #include
6 |
7 | #include
8 |
9 | typedef helib::Ptxt ptxt_t;
10 | typedef helib::Ctxt ctxt_t;
11 |
12 | std::map input_map;
13 |
14 | void parse_arguments(int argc, char* argv[]) {
15 | for (int i = 1; i < argc; ++i) {
16 | std::string argument(argv[i]);
17 | size_t pos = argument.find('=');
18 | if (pos != std::string::npos) {
19 | std::string key = argument.substr(0, pos);
20 | int value = std::stoi(argument.substr(pos + 1));
21 | input_map[key] = value;
22 | }
23 | }
24 | }
25 |
26 | int extract_input(const std::string& name) {
27 | if (input_map.find(name) != input_map.end()) {
28 | return input_map[name];
29 | } else {
30 | std::cerr << "Error: " << name << " not found" << std::endl;
31 | return -1;
32 | }
33 | }
34 |
35 | int main(int argc, char* argv[]) {
36 | // Parse the inputs
37 | parse_arguments(argc, argv);
38 |
39 | // Set up the HE parameters
40 | unsigned long p = 257;
41 | unsigned long m = 65536;
42 | unsigned long r = 1;
43 | unsigned long bits = 449;
44 | unsigned long c = 3;
45 | helib::Context context = helib::ContextBuilder()
46 | .m(m)
47 | .p(p)
48 | .r(r)
49 | .bits(bits)
50 | .c(c)
51 | .build();
52 |
53 |
54 | // Generate keys
55 | helib::SecKey secret_key(context);
56 | secret_key.GenSecKey();
57 | helib::addSome1DMatrices(secret_key);
58 | const helib::PubKey& public_key = secret_key;
59 |
60 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Oraqle
2 |
3 | nav:
4 | - index.md
5 | - getting_started.md
6 | - tutorial_running_exps.md
7 | - API reference:
8 | - api/circuits_api.md
9 | - api/nodes_api.md
10 | - api/code_generation_api.md
11 | - api/pareto_fronts_api.md
12 | - api/abstract_nodes_api.md
13 | - api/addition_chains_api.md
14 | - example_circuits.md
15 | - config.md
16 |
17 | plugins:
18 | - search
19 | - mkdocstrings:
20 | handlers:
21 | python:
22 | options:
23 | show_root_heading: true
24 | allow_inspection: false
25 | show_submodules: false
26 | show_root_full_path: false
27 | show_symbol_type_heading: true
28 | # show_symbol_type_toc: true  # Disabled because it currently causes a bug
29 | docstring_style: google
30 | follow_wrapped_lines: true
31 | crosslink_types: true # Makes types clickable
32 | crosslink_types_style: 'sphinx' # Default or sphinx style
33 | annotations_path: brief
34 | inherited_members: true
35 | members_order: source
36 | show_if_no_docstring: true
37 | separate_signature: false
38 | show_source: false
39 | docstring_section_style: list
40 |
41 | theme:
42 | name: material
43 | highlightjs: true
44 |
45 | markdown_extensions:
46 | - admonition
47 | - pymdownx.superfences
48 | - pymdownx.inlinehilite
49 | - pymdownx.critic
50 | - pymdownx.details
51 | - pymdownx.tasklist
52 | - pymdownx.tabbed
53 | - pymdownx.magiclink
54 | - pymdownx.tilde
55 | - toc:
56 | permalink: true
57 | toc_depth: 3
58 |
--------------------------------------------------------------------------------
/oraqle/__init__.py:
--------------------------------------------------------------------------------
1 | """This module contains the oraqle compiler, tools, and example circuits."""
2 |
--------------------------------------------------------------------------------
/oraqle/add_chains/__init__.py:
--------------------------------------------------------------------------------
1 | """Tools for generating addition chains using different constraints and objectives."""
2 |
--------------------------------------------------------------------------------
/oraqle/add_chains/addition_chains_front.py:
--------------------------------------------------------------------------------
1 | """Tools for generating addition chains that trade off depth and cost."""
2 | import math
3 | from typing import List, Optional, Tuple
4 |
5 | from oraqle.add_chains.addition_chains import add_chain
6 | from oraqle.add_chains.addition_chains_mod import add_chain_modp, hw, size_lower_bound
7 |
8 |
def chain_depth(
    chain: List[Tuple[int, int]],
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
    modulus: Optional[int] = None,
) -> int:
    """Return the depth of the addition chain.

    The value 1 starts at depth 0. Every step ``(x, y)`` produces ``x + y`` one level deeper than
    the deepest of its two operands.

    Args:
        chain: The steps of the addition chain, as pairs of values that are added together.
        precomputed_values: Optional ``(value, depth)`` pairs that are available before the chain starts.
        modulus: If given, every produced and looked-up value is reduced modulo this number.

    Returns:
        The largest depth among all known values (1, the precomputed values, and the chain's results).
    """
    depth_of = {1: 0}
    depth_of.update(precomputed_values or ())

    def normalize(value: int) -> int:
        return value if modulus is None else value % modulus

    for lhs, rhs in chain:
        depth_of[normalize(lhs + rhs)] = 1 + max(depth_of[normalize(lhs)], depth_of[normalize(rhs)])

    return max(depth_of.values())
27 |
28 |
def _square_and_multiply_chain(target: int) -> List[Tuple[int, int]]:
    """Return the square-and-multiply addition chain for `target` (a positive integer).

    The chain first doubles 1 up to the highest power of two that does not exceed `target`, and then
    adds the remaining set bits of `target`, from least to most significant. Each step ``(x, y)``
    means ``x + y``.
    """
    top_bit = target.bit_length() - 1

    # Doublings: 1 -> 2 -> 4 -> ... -> 2**top_bit.
    chain = [(2**i, 2**i) for i in range(top_bit)]

    # The lowest set bit is already available (1 or one of the doublings), so start accumulating from it.
    set_bits = [i for i in range(top_bit + 1) if (target >> i) & 1]
    accumulated = 2 ** set_bits[0]
    for i in set_bits[1:]:
        chain.append((accumulated, 2**i))
        accumulated += 2**i

    return chain


def gen_pareto_front(  # noqa: PLR0912, PLR0913, PLR0917
    target: int,
    modulus: Optional[int],
    squaring_cost: float,
    solver="glucose42",
    encoding=1,
    thurber=True,
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
) -> List[Tuple[int, List[Tuple[int, int]]]]:
    """Returns a Pareto front of addition chains, trading off cost and depth.

    Args:
        target: The integer that the addition chains must reach.
        modulus: If not None, chains are searched with `add_chain_modp` using this modulus; `target` may not exceed it.
        squaring_cost: The cost of a doubling (squaring), relative to a regular addition (multiplication), which costs 1.0.
        solver: Passed on to the chain search.
        encoding: Passed on to the chain search.
        thurber: Passed on to the chain search.
        precomputed_values: Optional `(value, depth)` pairs that are already available.

    Returns:
        A list of `(depth, chain)` pairs with increasing depth, where each chain is a list of `(x, y)` steps meaning `x + y`.

    Raises:
        AssertionError: If `modulus` is given and `target > modulus`.
    """
    if target == 1:
        return [(0, [])]

    if modulus is not None:
        assert target <= modulus

    # Find the lowest depth chain using square & multiply (SaM).
    # ceil(log2(target)) squarings is one more than SaM needs when target is not a power of two,
    # so sam_cost is a (slightly loose) upper bound on the cost of a minimum-depth chain.
    sam_depth = math.ceil(math.log2(target))
    sam_cost = math.ceil(math.log2(target)) * squaring_cost + hw(target) - 1
    sam_target = target

    # If there is a modulus, we should also consider it to find an upper bound on the cost of a minimum-depth chain
    if modulus is not None:
        current_target = target + modulus - 1
        while math.log2(current_target) <= sam_depth:
            current_cost = (
                math.ceil(math.log2(current_target)) * squaring_cost + hw(current_target) - 1
            )
            if current_cost < sam_cost:
                sam_cost = current_cost
                # Remember the equivalent target that achieved this cost, so the fallback chain matches it.
                sam_target = current_target
            current_target += modulus - 1

    # Find the cheapest chain (i.e. no depth constraints)
    min_size = size_lower_bound(target) if precomputed_values is None else 1
    if modulus is None:
        cheapest_chain = add_chain(
            target,
            None,
            sam_cost,
            squaring_cost,
            solver,
            encoding,
            thurber,
            min_size,
            precomputed_values,
        )
    else:
        cheapest_chain = add_chain_modp(
            target,
            modulus,
            None,
            sam_cost,
            squaring_cost,
            solver,
            encoding,
            thurber,
            min_size,
            precomputed_values,
        )

    # If no cheapest chain is found that satisfies these bounds, then square and multiply had the same cost
    if cheapest_chain is None:
        return [(sam_depth, _square_and_multiply_chain(sam_target))]

    add_size = len(cheapest_chain)  # TODO: Check that this is indeed a valid bound
    add_cost = sum(squaring_cost if x == y else 1.0 for x, y in cheapest_chain)
    add_depth = chain_depth(cheapest_chain, precomputed_values, modulus=modulus)

    # Go through increasing depth and decrease the previous size, until we reach the cost of square and multiply
    pareto_front = []
    current_depth = sam_depth
    current_cost = sam_cost
    while current_cost > add_cost and current_depth < add_depth:
        if modulus is None:
            chain = add_chain(
                target,
                current_depth,
                current_cost,
                squaring_cost,
                solver,
                encoding,
                thurber,
                add_size,
                precomputed_values,
            )
        else:
            chain = add_chain_modp(
                target,
                modulus,
                current_depth,
                current_cost,
                squaring_cost,
                solver,
                encoding,
                thurber,
                add_size,
                precomputed_values,
            )

        if chain is not None:
            # Add to the Pareto front
            pareto_front.append((current_depth, chain))
            current_cost = sum(squaring_cost if x == y else 1.0 for x, y in chain)

        current_depth += 1

    # Add the final chain and return
    if add_cost < current_cost or len(pareto_front) == 0:
        pareto_front.append((add_depth, cheapest_chain))

    return pareto_front
149 |
150 |
def test_gen_exponentiation_front_small():
    """The front for target 2 consists of a single chain: one doubling at depth 1."""
    expected = [(1, [(1, 1)])]
    assert gen_pareto_front(2, None, 0.75) == expected
154 |
--------------------------------------------------------------------------------
/oraqle/add_chains/addition_chains_heuristic.py:
--------------------------------------------------------------------------------
1 | """This module contains functions for finding addition chains, while sometimes resorting to heuristics to prevent long computations."""
2 |
3 | from functools import lru_cache
4 | import math
5 | from typing import List, Optional, Tuple
6 |
7 | from oraqle.add_chains.addition_chains import add_chain
8 | from oraqle.add_chains.addition_chains_mod import add_chain_modp, hw
9 | from oraqle.add_chains.solving import extract_indices
10 |
11 |
12 | def _mul(current_chain: List[Tuple[int, int]], other_chain: List[Tuple[int, int]]):
13 | length = len(current_chain)
14 | for a, b in other_chain:
15 | current_chain.append((a + length, b + length))
16 |
17 |
18 | def _chain(n, k) -> List[Tuple[int, int]]:
19 | q = n // k
20 | r = n % k
21 | if r in {0, 1}:
22 | chain_k = _minchain(k)
23 | _mul(chain_k, _minchain(q))
24 | if r == 1:
25 | chain_k.append((0, len(chain_k)))
26 | return chain_k
27 | else:
28 | chain_k = _chain(k, r)
29 | index_r = len(chain_k)
30 | _mul(chain_k, _minchain(q))
31 | chain_k.append((index_r, len(chain_k)))
32 | return chain_k
33 |
34 |
35 | def _minchain(n: int) -> List[Tuple[int, int]]:
36 | log_n = n.bit_length() - 1
37 | if n == 1 << log_n:
38 | return [(i, i) for i in range(log_n)]
39 | elif n == 3:
40 | return [(0, 0), (0, 1)]
41 | else:
42 | k = n // (1 << (log_n // 2))
43 | return _chain(n, k)
44 |
45 |
@lru_cache
def add_chain_guaranteed(  # noqa: PLR0913, PLR0917
    target: int,
    modulus: Optional[int],
    squaring_cost: float,
    solver: str = "glucose421",
    encoding: int = 1,
    thurber: bool = True,
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
) -> List[Tuple[int, int]]:
    """Always generates an addition chain for a given target, which is suboptimal if the inputs are too large.

    In some cases, the result is not necessarily optimal. These are the cases where we resort to a heuristic.
    This currently happens if:
    - The target exceeds 1000 (and no modulus of at most 200 is given).
    - The modulus (if provided) exceeds 200 and the target exceeds 1000.
    - MAXSAT_TIMEOUT is not None and a MaxSAT instance timed out.
    - The solver could not find a chain cheaper than the square & multiply estimate.

    !!! note
        This function is useful for preventing long computation, but the result is not guaranteed to be (close to) optimal.
        Unlike `add_chain`, this function will always return an addition chain.

    Parameters:
        target: The target integer.
        modulus: Modulus to take into account. In an exponentiation chain, this is the modulus in the exponent, i.e. x^target mod p corresponds to `modulus = p - 1`.
        squaring_cost: The cost of doubling (squaring), compared to other additions (multiplications), which cost 1.0.
        solver: Name of the SAT solver, e.g. "glucose421" for glucose 4.2.1. See: https://pysathq.github.io/docs/html/api/solvers.html.
        encoding: The encoding to use for cardinality constraints. See: https://pysathq.github.io/docs/html/api/card.html#pysat.card.EncType.
        thurber: Whether to use the Thurber bounds, which provide lower bounds for the elements in the chain. The bounds are ignored when `precomputed_values` is provided.
        precomputed_values: If there are any precomputed values that can be used for free, they can be specified as a tuple of pairs (value, chain_depth).

    A `TimeoutError` raised by the solver is caught here; in that case the minchain heuristic is used instead.

    Returns:
        An addition chain as (index, index) pairs. Index 0 is the input x; when `precomputed_values` is given, indices
        1..len(precomputed_values) refer to the precomputed values; each step then appends the next index.
    """
    # We want to do better than square and multiply, so we find an upper bound
    sam_cost = math.ceil(math.log2(target)) * squaring_cost + hw(target) - 1

    # Apply CSE to the square & multiply chain
    if precomputed_values is not None:
        for exp, depth in precomputed_values:
            # A precomputed power of two x^(2^d) available at depth d saves one squaring
            if exp > 0 and (exp & (exp - 1)) == 0 and depth == math.log2(exp):
                sam_cost -= squaring_cost

    try:
        addition_chain = None
        if modulus is not None and modulus <= 200:
            # Small modulus: search chains for target + k * modulus as well
            addition_chain = add_chain_modp(
                target,
                modulus,
                None,
                sam_cost,
                squaring_cost,
                solver,
                encoding,
                thurber,
                min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1,
                precomputed_values=precomputed_values,
            )
        elif target <= 1000:
            addition_chain = add_chain(
                target,
                None,
                sam_cost,
                squaring_cost,
                solver,
                encoding,
                thurber,
                min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1,
                precomputed_values=precomputed_values,
            )

        if addition_chain is not None:
            # The solver's chain is expressed in values; convert it to (index, index) pairs
            addition_chain = extract_indices(
                addition_chain, precomputed_values=None if precomputed_values is None else list(k for k, _ in precomputed_values), modulus=modulus
            )
    except TimeoutError:
        # The MaxSAT solver timed out, so we resort to a heuristic
        pass

    if addition_chain is None:
        # If no other addition chain algorithm has been called or if we could not do better than square and multiply

        # Uses the minchain algorithm from ["Addition chains using continued fractions."][BBBD1989]
        # The implementation was adapted from the `addchain` Rust crate (https://github.com/str4d/addchain).
        # This algorithm is not optimal: Below 1000 it requires one too many multiplication in 29 cases.
        addition_chain = _minchain(target)

        if precomputed_values is not None:
            # We must shift the indices in the addition chain
            # (index 0 stays the input x; the precomputed values occupy indices 1..shift)
            shift = len(precomputed_values)
            addition_chain = [(0 if x == 0 else x + shift, 0 if y == 0 else y + shift) for (x, y) in addition_chain]

    assert addition_chain is not None

    return addition_chain
144 |
--------------------------------------------------------------------------------
/oraqle/add_chains/addition_chains_mod.py:
--------------------------------------------------------------------------------
1 | """Tools for computing addition chains, taking into account the modular nature of the algebra."""
2 | import math
3 | from typing import List, Optional, Tuple
4 |
5 | from oraqle.add_chains.addition_chains import add_chain
6 |
7 |
def hw(n: int) -> int:
    """Returns the Hamming weight of n, i.e. the number of 1-bits in its binary representation.

    Each iteration clears the lowest set bit (`n & (n - 1)`), so the loop runs once per set bit.
    For negative `n` the loop does not terminate.
    """
    weight = 0
    remaining = n
    while remaining != 0:
        remaining &= remaining - 1
        weight += 1
    return weight
16 |
17 |
def size_lower_bound(target: int) -> int:
    """Returns a lower bound on the size of the addition chain for this target.

    The bound is the ceiling of the largest of three estimates, each built from log2(target) and the Hamming weight.
    """
    log_target = math.log2(target)
    weight = hw(target)
    estimates = (
        log_target + math.log2(weight) - 2.13,
        log_target,
        log_target + math.log(weight, 3) - 1,
    )
    return math.ceil(max(estimates))
27 |
28 |
def cost_lower_bound_monotonic(target: int, squaring_cost: float) -> float:
    """Returns a lower bound on the cost of the addition chain for this target. The bound is guaranteed to grow monotonically with the target.

    It charges `squaring_cost` for each of the ceil(log2(target)) steps that any chain for `target` needs.
    """
    minimum_steps = math.ceil(math.log2(target))
    return minimum_steps * squaring_cost
32 |
33 |
def chain_cost(chain: List[Tuple[int, int]], squaring_cost: float) -> float:
    """Returns the cost of the addition chain, considering doubling (squaring) to be cheaper than other additions (multiplications).

    A step `(a, a)` costs `squaring_cost`; every other step costs 1.0.
    """
    return sum(1.0 if lhs != rhs else squaring_cost for lhs, rhs in chain)
37 |
38 |
def add_chain_modp(  # noqa: PLR0913, PLR0917
    target: int,
    modulus: int,
    max_depth: Optional[int],
    strict_cost_max: float,
    squaring_cost: float,
    solver,
    encoding,
    thurber,
    min_size: int,
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
) -> Optional[List[Tuple[int, int]]]:
    """Computes an addition chain for target modulo p with the given constraints and optimization parameters.

    A chain for `target + k * modulus` (k >= 0) is also a chain for `target` modulo `modulus`. This function
    calls `add_chain` for such candidate targets in increasing order and keeps the cheapest chain found. After each
    success, the cost bound is lowered to that chain's cost.

    The precomputed_powers are an optional set of powers that have previously been computed along with their depth.
    This means that those powers can be reused for free.

    Parameters:
        target: The target integer.
        modulus: The modulus in which the target is considered.
        max_depth: Maximum depth of the chain, or None for no limit.
        strict_cost_max: Only chains that are strictly cheaper than this are considered.
        squaring_cost: The cost of doubling (squaring); other additions cost 1.0.
        solver: Name of the SAT solver, passed on to `add_chain`.
        encoding: Cardinality encoding, passed on to `add_chain`.
        thurber: Whether to use the Thurber bounds, passed on to `add_chain`.
        min_size: Lower bound on the chain size, passed on to `add_chain`.
        precomputed_values: Optional tuple of (value, chain_depth) pairs that can be used for free.

    Returns:
        If it exists, a minimal addition chain meeting the given constraints and optimization parameters.
    """
    if precomputed_values is not None:
        # The shortest chain in (t + (k-1)p, t + kp] will have length at least k
        # The cheapest chain in (t + (k-1)p, t + kp] will have cost at least k / sqr_cost
        # NOTE(review): if every step costs at least min(1, squaring_cost), a chain of length >= k costs at least
        # k * squaring_cost. That is smaller than k / squaring_cost when squaring_cost < 1, so the stopping rule
        # below may prune candidates too early. Confirm the intended bound.
        best_chain = None

        k = 0
        while (k / squaring_cost) < strict_cost_max:
            # Add multiples of the precomputed_values
            new_precomputed_values = [
                (precomputed_value + i * modulus, depth)
                for precomputed_value, depth in precomputed_values
                for i in range(k + 1)
            ]

            chain = add_chain(
                target + k * modulus,
                max_depth,
                strict_cost_max,
                squaring_cost,
                solver,
                encoding,
                thurber,
                min_size=max(min_size, k),
                precomputed_values=tuple(new_precomputed_values),
            )

            if chain is not None:
                strict_cost_max = min(strict_cost_max, chain_cost(chain, squaring_cost))
                best_chain = chain

            k += 1

        return best_chain

    best_chain = None
    best_cost = None

    current_target = target

    while cost_lower_bound_monotonic(current_target, squaring_cost) < strict_cost_max and (
        max_depth is None or math.ceil(math.log2(current_target)) <= max_depth
    ):
        tightest_min_size = max(size_lower_bound(current_target), min_size)
        # `strict_cost_max` is lowered to the best cost found so far (below), so it already is the tightest bound;
        # skip targets whose shortest possible chain cannot beat it.
        if (tightest_min_size * squaring_cost) >= strict_cost_max:
            current_target += modulus
            continue

        chain = add_chain(
            current_target,
            max_depth,
            strict_cost_max,
            squaring_cost,
            solver,
            encoding,
            thurber,
            tightest_min_size,
            precomputed_values,
        )

        if chain is not None:
            cost = chain_cost(chain, squaring_cost)
            if best_cost is None or cost < best_cost:
                best_cost = cost
                best_chain = chain
            strict_cost_max = min(best_cost, strict_cost_max)

        current_target += modulus

    return best_chain
133 |
134 |
def test_add_chain_modp_over_modulus():
    """The expected chain for 62 modulo 66 reaches 62 + 66 = 128 using seven squarings."""
    chain = add_chain_modp(
        62, 66, None, 8.0, 0.75,
        solver="glucose42", encoding=1, thurber=True, min_size=1, precomputed_values=None,
    )
    doublings = [(2**i, 2**i) for i in range(7)]
    assert chain == doublings
149 |
150 |
def test_add_chain_modp_precomputations():
    """With 65 precomputed (at depth 5), 64 modulo 66 takes one squaring: 65 + 65 = 130 = 64 + 66."""
    precomputed = ((65, 5),)
    chain = add_chain_modp(
        64, 66, None, 2.0, 0.75,
        solver="glucose42", encoding=1, thurber=True, min_size=1, precomputed_values=precomputed,
    )
    assert chain == [(65, 65)]
165 |
166 |
167 | if __name__ == "__main__":
168 | print(add_chain_modp(
169 | 254,
170 | 255,
171 | None,
172 | 8.0,
173 | 0.5,
174 | solver="glucose42",
175 | encoding=1,
176 | thurber=True,
177 | min_size=11,
178 | precomputed_values=None,
179 | ))
180 |
--------------------------------------------------------------------------------
/oraqle/add_chains/memoization.py:
--------------------------------------------------------------------------------
1 | """This module contains tools for memoizing addition chains, as these are expensive to compute."""
2 | from hashlib import sha3_256
3 | from importlib.resources import files
4 | import inspect
5 | import shelve
6 | from typing import Set
7 |
8 | from sympy import sieve
9 |
10 | import oraqle
11 |
12 |
13 | ADDCHAIN_CACHE_FILENAME = "addchain_cache"
14 |
15 |
# Adapted from: https://stackoverflow.com/questions/16463582/memoize-to-disk-python-persistent-memoization
def cache_to_disk(ignore_args: Set[str]):
    """This decorator caches the calls to this function in a file on disk, ignoring the arguments listed in `ignore_args`.

    The cache key is the SHA3-256 hash of the `str` of the list of argument values, in signature order, leaving out
    the parameters named in `ignore_args`. Arguments the caller omits are filled in with their defaults before
    hashing. A call that relies on a default therefore gets the same key as one that passes that value explicitly.

    Parameters:
        ignore_args: Names of parameters of the decorated function that must not influence the cache key.

    Returns:
        A decorator that looks up (and stores) the decorated function's results in the on-disk cache.
    """
    import atexit
    import functools

    # Always opens the database in the root of where the package is located
    oraqle_path = files(oraqle)
    database_path = oraqle_path.joinpath(ADDCHAIN_CACHE_FILENAME)
    d = shelve.open(str(database_path))  # noqa: SIM115
    # The shelf stays open for the lifetime of the process; close (and thereby sync) it at interpreter exit.
    atexit.register(d.close)

    def decorator(func):
        signature = inspect.signature(func)
        signature_args = list(signature.parameters.keys())
        assert all(arg in signature_args for arg in ignore_args)

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            # Bind to the signature so that omitted (defaulted) arguments are included instead of raising KeyError
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            relevant_args = [bound.arguments[arg] for arg in signature_args if arg not in ignore_args]

            h = sha3_256()
            h.update(str(relevant_args).encode("ascii"))
            hashed_args = h.hexdigest()

            if hashed_args not in d:
                d[hashed_args] = func(*args, **kwargs)
            return d[hashed_args]

        return wrapped_func

    return decorator
50 |
51 |
52 | if __name__ == "__main__":
53 | from oraqle.add_chains.addition_chains_front import gen_pareto_front
54 |
55 | # Precompute addition chains for x^(p-1) mod p for the first 30 primes p
56 | primes = list(sieve.primerange(300))[:30]
57 | for sqr_cost in [0.5, 0.75, 1.0]:
58 | print(f"Computing for {sqr_cost}")
59 |
60 | for p in primes:
61 | gen_pareto_front(
62 | p - 1,
63 | modulus=p - 1,
64 | squaring_cost=sqr_cost,
65 | solver="glucose42",
66 | encoding=1,
67 | thurber=True,
68 | )
69 |
--------------------------------------------------------------------------------
/oraqle/add_chains/solving.py:
--------------------------------------------------------------------------------
1 | """Tools for solving SAT formulations."""
2 | import math
3 | import signal
4 | from typing import List, Optional, Sequence, Tuple
5 |
6 | from pysat.examples.rc2 import RC2
7 | from pysat.formula import WCNF
8 |
9 |
def solve(wcnf: WCNF, solver: str, strict_cost_max: Optional[float]) -> Optional[List[int]]:
    """This code is adapted from pysat's internal code to stop when we have reached a maximum cost.

    The core-guided loop returns None as soon as `rc2.cost` reaches `strict_cost_max`, or when the hard part of the
    formula is unsatisfiable.

    Returns:
        A list containing the assignment (where 3 indicates that 3=True and -3 indicates that 3=False), or None if the wcnf is unsatisfiable.
    """
    rc2 = RC2(wcnf, solver)

    if strict_cost_max is None:
        strict_cost_max = float("inf")

    while not rc2.oracle.solve(assumptions=rc2.sels + rc2.sums):  # type: ignore
        rc2.get_core()

        if not rc2.core:
            # core is empty, i.e. hard part is unsatisfiable
            return None

        rc2.process_core()

        if rc2.cost >= strict_cost_max:
            return None

    rc2.model = rc2.oracle.get_model()  # type: ignore

    # Check for an empty formula before giving up on a None model; with the checks the other way around,
    # this branch could never be reached.
    if rc2.model is None and rc2.pool.top == 0:
        # we seem to have been given an empty formula
        # so let's transform the None model returned to []
        rc2.model = []

    # Return None if the model could not be solved
    if rc2.model is None:
        return None

    # Extract the model
    rc2.model = filter(lambda inp: abs(inp) in rc2.vmap.i2e, rc2.model)  # type: ignore
    rc2.model = map(lambda inp: int(math.copysign(rc2.vmap.i2e[abs(inp)], inp)), rc2.model)
    rc2.model = sorted(rc2.model, key=abs)

    return rc2.model
50 |
51 |
def extract_indices(
    sequence: List[Tuple[int, int]],
    precomputed_values: Optional[Sequence[int]] = None,
    modulus: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Returns the indices for each step of the addition chain.

    `sequence` is a chain of (value, value) pairs; the result replaces each value by the index of the step that produced it.
    If n precomputed values are provided, then these are considered to be the first n indices after x (i.e. x has index 0, followed by 1, ..., n representing the precomputed values).
    When `modulus` is given, values are identified by their residue modulo `modulus`. This applies to the initial
    value 1 and the precomputed values too, so precomputed values >= modulus are found.
    """

    def key(value: int) -> int:
        return value if modulus is None else value % modulus

    indices = {key(1): 0}
    offset = 1
    for v in precomputed_values or ():
        indices[key(v)] = offset
        offset += 1

    ans_sequence = []
    for position, (i, j) in enumerate(sequence):
        ans_sequence.append((indices[key(i)], indices[key(j)]))
        # The value produced at this step gets the next free index
        indices[key(i + j)] = position + offset

    return ans_sequence
81 |
82 |
def solve_with_time_limit(wcnf: WCNF, solver: str, strict_cost_max: Optional[float], timeout_secs: float) -> Optional[List[int]]:
    """This code is adapted from pysat's internal code to stop when we have reached a maximum cost.

    Unlike `solve`, this uses `solve_limited` and a `SIGALRM` interval timer of `timeout_secs` seconds. The timer is
    always disarmed, and the previous `SIGALRM` handler restored, before the search returns or raises. No alarm
    can therefore fire later in unrelated code.

    NOTE(review): Python only allows `signal.signal` in the main thread of the main interpreter, so this function must be called from there.

    Raises:
        TimeoutError: When a timeout occurs (after `timeout_secs` seconds)

    Returns:
        A list containing the assignment (where 3 indicates that 3=True and -3 indicates that 3=False), or None if the wcnf is unsatisfiable.
    """

    def timeout_handler(signum, frame):
        raise TimeoutError

    # Set the timeout
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.setitimer(signal.ITIMER_REAL, timeout_secs)

    try:
        # TODO: Reduce code duplication: we only changed solve to solve_limited
        rc2 = RC2(wcnf, solver)

        if strict_cost_max is None:
            strict_cost_max = float("inf")

        while not rc2.oracle.solve_limited(assumptions=rc2.sels + rc2.sums, expect_interrupt=True):  # type: ignore
            rc2.get_core()

            if not rc2.core:
                # core is empty, i.e. hard part is unsatisfiable
                return None

            rc2.process_core()

            if rc2.cost >= strict_cost_max:
                return None
    finally:
        # Runs on every exit path (solution, early return, timeout or any other exception)
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous_handler if previous_handler is not None else signal.SIG_DFL)

    rc2.model = rc2.oracle.get_model()  # type: ignore

    if rc2.model is None and rc2.pool.top == 0:
        # we seem to have been given an empty formula
        # so let's transform the None model returned to []
        rc2.model = []

    # Return None if the model could not be solved
    if rc2.model is None:
        return None

    # Extract the model
    rc2.model = filter(lambda inp: abs(inp) in rc2.vmap.i2e, rc2.model)  # type: ignore
    rc2.model = map(lambda inp: int(math.copysign(rc2.vmap.i2e[abs(inp)], inp)), rc2.model)
    rc2.model = sorted(rc2.model, key=abs)

    return rc2.model
140 |
--------------------------------------------------------------------------------
/oraqle/circuits/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains example circuits and tools for generating them."""
2 |
--------------------------------------------------------------------------------
/oraqle/circuits/aes.py:
--------------------------------------------------------------------------------
1 | """This module implements a high-level AES encryption circuit for a constant key."""
2 | from typing import List
3 |
4 | from aeskeyschedule import key_schedule
5 | from galois import GF
6 |
7 | from oraqle.compiler.arithmetic.exponentiation import Power
8 | from oraqle.compiler.circuit import Circuit
9 | from oraqle.compiler.nodes import Constant
10 | from oraqle.compiler.nodes.abstract import Node
11 | from oraqle.compiler.nodes.leafs import Input
12 |
# Field of the AES state bytes (used for all constants in `encrypt`).
# NOTE(review): this relies on galois's default irreducible polynomial for GF(2**8). AES is defined over
# x^8 + x^4 + x^3 + x + 1 (0x11B). Confirm that the default matches, or pass `irreducible_poly` explicitly.
gf = GF(2**8)
14 |
15 |
def encrypt(plaintext: List[Node], key: bytes) -> List[Node]:
    """Returns an AES encryption circuit for a constant `key`.

    The 16 state bytes are stored column-major (as in FIPS-197): `b[j]` is the byte in row `j % 4` of column `j // 4`.

    NOTE(review): the generated circuit still deviates from full AES-128:
    - only 9 rounds are generated and all use MixColumns, while AES-128 has 10 rounds whose last one omits MixColumns;
    - SubBytes is only the inversion x^254, without the affine transformation of the AES S-box;
    - the field `gf` must use the AES polynomial for the constants to match AES.

    Parameters:
        plaintext: The 16 plaintext byte nodes.
        key: The constant key; its round keys are computed with `key_schedule`.

    Returns:
        The 16 output byte nodes.
    """
    # First row of the MixColumns matrix; row r is this row rotated right by r positions.
    mix = [Constant(gf(2)), Constant(gf(3)), Constant(gf(1)), Constant(gf(1))]

    round_keys = [[Constant(gf(byte)) for byte in round_key] for round_key in key_schedule(key)]

    def additions(nodes: List[Node]) -> Node:
        # Sums at least two nodes
        node_iter = iter(nodes)
        out = next(node_iter) + next(node_iter)
        for node in node_iter:
            out += node
        return out

    def sbox(node: Node, method="minchain") -> Node:
        if method == "hardcoded":
            # Fixed addition chain for x^254: 1, 2, 3, 6, 12, 15, 30, 60, 63, 126, 127, 254
            x2 = node.mul(node, flatten=False)
            x3 = node.mul(x2, flatten=False)
            x6 = x3.mul(x3, flatten=False)
            x12 = x6.mul(x6, flatten=False)
            x15 = x12.mul(x3, flatten=False)
            x30 = x15.mul(x15, flatten=False)
            x60 = x30.mul(x30, flatten=False)
            x63 = x60.mul(x3, flatten=False)
            x126 = x63.mul(x63, flatten=False)
            x127 = node.mul(x126, flatten=False)
            x254 = x127.mul(x127, flatten=False)
            return x254
        elif method == "minchain":
            return Power(node, 254, gf)
        else:
            raise Exception(f"Invalid method: {method}.")

    # AddRoundKey
    b: List[Node] = [round_key + plaintext_byte for round_key, plaintext_byte in zip(round_keys[0], plaintext)]

    for round_index in range(9):
        # SubBytes (modular inverse)
        b = [sbox(b[j], method="hardcoded") for j in range(16)]

        # ShiftRows: row r is rotated left by r positions
        b[1], b[5], b[9], b[13] = b[5], b[9], b[13], b[1]
        b[2], b[6], b[10], b[14] = b[10], b[14], b[2], b[6]
        b[3], b[7], b[11], b[15] = b[15], b[3], b[7], b[11]

        # MixColumns: b'[4c + r] = sum_i M[r][i] * b[4c + i] with M[r][i] = mix[(i - r) % 4],
        # so every output byte only combines the four bytes of its own column.
        b = [additions([mix[(i - j % 4) % 4] * b[4 * (j // 4) + i] for i in range(4)]) for j in range(16)]

        # AddRoundKey
        b = [round_key + b[j] for j, round_key in zip(range(16), round_keys[round_index + 1])]

    return b
68 |
69 |
70 | if __name__ == "__main__":
71 | # TODO: Consider if we want to support degree > 1
72 | circuit = Circuit(
73 | encrypt([Input(f"{i}", gf) for i in range(16)], b"abcdabcdabcdabcd")
74 | ).arithmetize()
75 | print(circuit)
76 | print(circuit.multiplicative_depth())
77 | print(circuit.multiplicative_size())
78 | circuit.eliminate_subexpressions()
79 | print(circuit.multiplicative_depth())
80 | print(circuit.multiplicative_size())
81 |
82 | # TODO: Test if it corresponds to a plaintext implementation of AES
83 |
84 |
def test_aes_128():
    """Only checks that building and arithmetizing the AES circuit raises no errors."""
    plaintext_inputs = [Input(f"{i}", gf) for i in range(16)]
    outputs = encrypt(plaintext_inputs, b"abcdabcdabcdabcd")
    Circuit(outputs).arithmetize()
88 |
--------------------------------------------------------------------------------
/oraqle/circuits/cardio.py:
--------------------------------------------------------------------------------
1 | """This module implements the cardio circuit that is often used in benchmarking compilers, see: https://arxiv.org/abs/2101.07078."""
2 | from typing import Type
3 | from galois import GF, FieldArray
4 |
5 | from oraqle.compiler.boolean.bool_neg import Neg
6 | from oraqle.compiler.boolean.bool_or import any_
7 | from oraqle.compiler.circuit import Circuit
8 | from oraqle.compiler.nodes import Input
9 | from oraqle.compiler.nodes.abstract import Node
10 | from oraqle.compiler.nodes.arbitrary_arithmetic import sum_
11 |
12 |
def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:
    """Returns the cardio circuit from https://arxiv.org/abs/2101.07078.

    The output is the sum of ten risk-factor terms computed from the patient inputs.
    """
    man, smoking, age, diabetic, hbp, cholesterol, weight, height, activity, alcohol = (
        Input(name, gf)
        for name in (
            "man", "smoking", "age", "diabetic", "hbp", "cholesterol", "weight", "height", "activity", "alcohol",
        )
    )

    return sum_(
        man & (age > 50),
        Neg(man, gf) & (age > 60),
        smoking,
        diabetic,
        hbp,
        cholesterol < 40,
        weight > (height - 90),  # This might underflow if the modulus is too small
        activity < 30,
        man & (alcohol > 3),
        Neg(man, gf) & (alcohol > 2),
    )
38 |
39 |
def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node:
    """Returns a variant of the cardio circuit that returns a Boolean indicating whether any risk factor returned true.

    It uses the same ten risk-factor terms as `construct_cardio_risk_circuit`, combined with `any_` instead of a sum.
    """
    man, smoking, age, diabetic, hbp, cholesterol, weight, height, activity, alcohol = (
        Input(name, gf)
        for name in (
            "man", "smoking", "age", "diabetic", "hbp", "cholesterol", "weight", "height", "activity", "alcohol",
        )
    )

    return any_(
        man & (age > 50),
        Neg(man, gf) & (age > 60),
        smoking,
        diabetic,
        hbp,
        cholesterol < 40,
        weight > (height - 90),  # This might underflow if the modulus is too small
        activity < 30,
        man & (alcohol > 3),
        Neg(man, gf) & (alcohol > 2),
    )
65 |
66 |
def test_cardio_p101():
    """Checks the cardio risk score of a few patients for every depth-aware arithmetization over GF(101)."""
    gf = GF(101)
    circuit = Circuit([construct_cardio_risk_circuit(gf)])

    baseline = {
        "man": gf(1),
        "woman": gf(0),
        "age": gf(50),
        "smoking": gf(0),
        "diabetic": gf(0),
        "hbp": gf(0),
        "cholesterol": gf(45),
        "weight": gf(10),
        "height": gf(100),
        "activity": gf(90),
        "alcohol": gf(3),
    }
    # Pairs of (changes to the baseline patient, expected risk score)
    cases = [
        ({}, 0),
        ({"man": gf(0), "woman": gf(1)}, 1),  # not a man, and alcohol > 2
        ({"cholesterol": gf(39)}, 1),
        ({"smoking": gf(1)}, 1),
    ]

    for _, _, arithmetization in circuit.arithmetize_depth_aware():
        for changes, expected in cases:
            assert arithmetization.evaluate({**baseline, **changes})[0] == expected
127 |
--------------------------------------------------------------------------------
/oraqle/circuits/median.py:
--------------------------------------------------------------------------------
1 | """This module implements circuits for computing the median."""
2 | from typing import Sequence, Type
3 |
4 | from galois import GF, FieldArray
5 |
6 | from oraqle.circuits.sorting import cswp
7 | from oraqle.compiler.circuit import Circuit
8 | from oraqle.compiler.nodes import Input
9 |
# Field used by the `__main__` demo below.
gf = GF(1037347783)
11 |
12 |
def gen_median_circuit(inputs: Sequence[int], gf: Type[FieldArray]):
    """Returns a naive circuit for finding the median value of `inputs`.

    The inputs are sorted in ascending order with a bubble-sort network of `cswp` comparators. The circuit outputs
    the element at index `len(inputs) // 2`: the median for an odd number of inputs, and the upper median for an even number.

    Raises:
        ValueError: If `inputs` is empty.
    """
    input_nodes = [Input(f"Input {v}", gf) for v in inputs]
    if not input_nodes:
        raise ValueError("Cannot compute the median of an empty sequence of inputs.")

    outputs = list(input_nodes)

    for i in range(len(outputs) - 1, -1, -1):
        for j in range(i):
            outputs[j], outputs[j + 1] = cswp(outputs[j], outputs[j + 1])  # type: ignore

    # Index n // 2 is valid for every n >= 1 (n // 2 + 1 was out of range for n == 2, and wrong for other even n)
    return Circuit([outputs[len(outputs) // 2]])
26 |
27 |
28 | if __name__ == "__main__":
29 | circuit = gen_median_circuit(range(10), gf)
30 | circuit.to_graph("median.dot")
31 |
--------------------------------------------------------------------------------
/oraqle/circuits/mimc.py:
--------------------------------------------------------------------------------
1 | """MIMC is an MPC-friendly cipher: https://eprint.iacr.org/2016/492."""
2 | from math import ceil, log2
3 | from random import randint
4 |
5 | from galois import GF
6 |
7 | from oraqle.compiler.circuit import Circuit
8 | from oraqle.compiler.nodes import Constant, Input, Node
9 |
# Field of order 2**129 + 17 (just above 2**129, matching the default `power_n = 129`), used by the demo and the test below.
gf = GF(680564733841876926926749214863536422929)
11 |
12 |
# TODO: Check parameters with the paper
def encrypt(plaintext: Node, key: int, power_n: int = 129) -> Node:
    """Returns an MIMC encryption circuit using a constant key.

    Each of the ceil(power_n / log2(3)) rounds adds the key and a round constant, then cubes the state. The first and
    last round constants are zero; the others are drawn with `randint`.
    """
    rounds = ceil(power_n / log2(3))

    constants = []
    for round_index in range(rounds):
        if round_index in (0, rounds - 1):
            constants.append(Constant(gf(0)))
        else:
            constants.append(Constant(gf(randint(0, 2**power_n))))
    key_constant = Constant(gf(key))

    state = plaintext
    for round_constant in constants:
        added = state + key_constant + round_constant
        state = added * added * added

    return state + key_constant
33 |
34 |
35 | if __name__ == "__main__":
36 | node = encrypt(Input("m", gf), 12345)
37 |
38 | circuit = Circuit([node]).arithmetize()
39 | print(circuit.multiplicative_depth())
40 | print(circuit.multiplicative_size())
41 |
42 | circuit.to_graph("mimc-129.dot")
43 |
44 |
def test_mimc_129():
    """MiMC-129 has ceil(129 / log2(3)) = 82 rounds, each cubing the state with two multiplications."""
    ciphertext = encrypt(Input("m", gf), 12345)

    arithmetic_circuit = Circuit([ciphertext]).arithmetize()

    expected = 164
    assert arithmetic_circuit.multiplicative_depth() == expected
    assert arithmetic_circuit.multiplicative_size() == expected
52 |
--------------------------------------------------------------------------------
/oraqle/circuits/sorting.py:
--------------------------------------------------------------------------------
1 | """This module contains sorting circuits and comparators."""
2 | from typing import Sequence, Tuple, Type
3 |
4 | from galois import GF, FieldArray
5 |
6 | from oraqle.compiler.circuit import ArithmeticCircuit, Circuit
7 | from oraqle.compiler.nodes import Input
8 | from oraqle.compiler.nodes.abstract import Node
9 |
# Small field used by the `__main__` demo below.
gf = GF(13)
11 |
12 |
def cswp(lhs: Node, rhs: Node) -> Tuple[Node, Node]:
    """Conditionally swap inputs `lhs` and `rhs` such that `lhs <= rhs`.

    The lower value is selected arithmetically as `(lhs < rhs) * (lhs - rhs) + rhs`; the higher value is what
    remains of the sum.

    Returns:
        A tuple representing (lower, higher)
    """
    is_lower = lhs < rhs
    lower = is_lower * (lhs - rhs) + rhs
    higher = lhs + rhs - lower
    return lower, higher
28 |
29 |
def gen_naive_sort_circuit(inputs: Sequence[int], gf: Type[FieldArray]) -> ArithmeticCircuit:
    """Returns a naive sorting circuit for the given sequence of `inputs`.

    The circuit is a bubble-sort network of `cswp` comparators whose outputs are in ascending order.
    """
    nodes = [Input(f"Input {v}", gf) for v in inputs]

    # After the pass ending at `upper`, position `upper` holds its final value
    for upper in reversed(range(len(nodes))):
        for j in range(upper):
            nodes[j], nodes[j + 1] = cswp(nodes[j], nodes[j + 1])  # type: ignore

    return Circuit(nodes).arithmetize()  # type: ignore
41 |
42 |
43 | if __name__ == "__main__":
44 | circuit = gen_naive_sort_circuit(range(2), gf)
45 | circuit.to_graph("sorting.dot")
46 |
--------------------------------------------------------------------------------
/oraqle/circuits/veto_voting.py:
--------------------------------------------------------------------------------
1 | """The veto voting circuit is the inverse of a consensus vote between a number of participants.
2 |
3 | The circuit is essentially a large OR operation, returning 1 if any participant vetoes (by submitting a 1).
4 | This represents a vote that anyone can veto.
5 | """
6 | from typing import Type
7 |
8 | from galois import GF, FieldArray
9 |
10 | from oraqle.compiler.boolean.bool_or import any_
11 | from oraqle.compiler.circuit import Circuit
12 | from oraqle.compiler.nodes import Input
13 |
# Small field used by the `__main__` demo below.
gf = GF(103)
15 |
16 |
def gen_veto_voting_circuit(participants: int, gf: Type[FieldArray]):
    """Returns a veto voting circuit between the number of `participants`.

    The circuit ORs (`any_`) one input per participant, named "Input 0", "Input 1", and so on.
    """
    # A list (not a set) passes the inputs to `any_` in a deterministic order; set iteration order depends on the
    # nodes' hashes, which made the generated circuit (and its graph) vary.
    input_nodes = [Input(f"Input {i}", gf) for i in range(participants)]
    return Circuit([any_(*input_nodes)])
21 |
22 |
23 | if __name__ == "__main__":
24 | circuit = gen_veto_voting_circuit(10, gf).arithmetize()
25 |
26 | circuit.eliminate_subexpressions()
27 | circuit.to_graph("veto-voting.dot")
28 |
--------------------------------------------------------------------------------
/oraqle/compiler/__init__.py:
--------------------------------------------------------------------------------
1 | """The compiler package contains the main machinery for describing high-level circuits, arithmetizing them, and generating code."""
2 |
--------------------------------------------------------------------------------
/oraqle/compiler/arithmetic/__init__.py:
--------------------------------------------------------------------------------
1 | """This module contains classes for arithmetic operations that are not simply additions or multiplications."""
2 |
--------------------------------------------------------------------------------
/oraqle/compiler/arithmetic/exponentiation.py:
--------------------------------------------------------------------------------
1 | """This module contains classes and functions for efficient exponentiation circuits."""
2 | import math
3 | from typing import Type
4 |
5 | from galois import GF, FieldArray
6 |
7 | from oraqle.add_chains.addition_chains_front import gen_pareto_front
8 | from oraqle.add_chains.addition_chains_heuristic import add_chain_guaranteed
9 | from oraqle.add_chains.solving import extract_indices
10 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
11 | from oraqle.compiler.nodes.binary_arithmetic import Multiplication
12 | from oraqle.compiler.nodes.leafs import Input
13 | from oraqle.compiler.nodes.univariate import UnivariateNode
14 |
15 |
16 | # TODO: Think about the role of Power when there are also Products
class Power(UnivariateNode):
    """Represents an exponentiation: x ** constant."""

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return f"pow_{self._exponent}"

    @property
    def _node_label(self) -> str:
        return f"Pow: {self._exponent}"

    def __init__(self, node: Node, exponent: int, gf: Type[FieldArray]):
        """Initialize a `Power` node that exponentiates `node` with `exponent`.

        Args:
            node: The operand that is raised to the power `exponent`.
            exponent: The constant integer exponent.
            gf: The finite field in which the exponentiation is computed.
        """
        # Set before calling the parent constructor: `_hash_name` and `_node_label` read it
        # (presumably the parent may use them during initialization).
        self._exponent = exponent
        super().__init__(node, gf)

    def _operation_inner(self, input: FieldArray, gf: Type[FieldArray]) -> FieldArray:
        return input**self._exponent  # type: ignore

    def _arithmetize_inner(self, strategy: str) -> "Node":
        """Arithmetize x ** exponent into multiplications.

        The "naive" strategy uses binary square & multiply. The "best-effort" strategy uses a
        (heuristic) addition chain for the exponent.

        Raises:
            ValueError: If the "naive" strategy is used with an exponent smaller than 1.
        """
        if strategy == "naive":
            if self._exponent < 1:
                raise ValueError(
                    f"The naive strategy requires a positive exponent, got {self._exponent}."
                )

            # Square & multiply: squares[i] holds x^(2^i) for every bit position i of the exponent.
            # Using bit_length() (rather than ceil(log2(e))) also covers e == 1 and exponents that
            # are exact powers of two, whose highest bit would otherwise never be visited.
            num_bits = self._exponent.bit_length()
            squares = [self._node.arithmetize(strategy)]
            for _ in range(num_bits - 1):
                squares.append(squares[-1].mul(squares[-1], flatten=False))

            # Multiply together the squares whose bit is set in the exponent.
            result = None
            for i, square in enumerate(squares):
                if (self._exponent >> i) & 1:
                    result = square if result is None else square.mul(result, flatten=False)

            assert result is not None
            return result

        assert strategy == "best-effort"

        addition_chain = add_chain_guaranteed(self._exponent, self._gf.characteristic - 1, squaring_cost=1.0)

        # nodes[k] holds x raised to the k-th element of the addition chain; every chain step
        # (i, j) multiplies two earlier elements.
        nodes = [self._node.arithmetize(strategy).to_arithmetic()]

        for i, j in addition_chain:
            nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

        return nodes[-1]

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # TODO: While generating the front, we can take into account the maximum cost etc. implied by the depth-aware arithmetization of the operand
        small_field = self._gf.characteristic <= 257

        # NOTE(review): the front is generated with `characteristic` while the indices are
        # extracted modulo `characteristic - 1`; confirm this asymmetry matches what
        # `gen_pareto_front` expects.
        front = gen_pareto_front(
            self._exponent, self._gf.characteristic if small_field else None, cost_of_squaring
        )
        modulus = self._gf.characteristic - 1 if small_field else None

        # The chains do not depend on the operand's arithmetization, so extract their indices once
        # (materialized as lists so they can be reused for every operand node).
        chains = [(depth2, list(extract_indices(chain, modulus=modulus))) for depth2, chain in front]

        final_front = CostParetoFront(cost_of_squaring)

        for depth1, _, node in self._node.arithmetize_depth_aware(cost_of_squaring):
            for depth2, indices in chains:
                nodes = [node]

                for i, j in indices:
                    nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

                final_front.add(nodes[-1], depth=depth1 + depth2)

        return final_front
94 |
95 |
def test_depth_aware_arithmetization():  # noqa: D103
    gf = GF(31)

    x = Input("x", gf)
    node = Power(x, 30, gf)
    front = node.arithmetize_depth_aware(cost_of_squaring=1.0)
    node.clear_cache(set())

    # x^30 = x^(p-1) in GF(31): 0 for x == 0 and 1 for every other element.
    for _, _, n in front:
        n.clear_cache(set())
        assert n.evaluate({"x": gf(0)}) == 0

        for xx in range(1, 31):
            # Clear before every evaluation (as the other tests do); otherwise a cached result
            # from a previous input could presumably be returned and hide a wrong circuit.
            n.clear_cache(set())
            assert n.evaluate({"x": gf(xx)}) == 1
110 |
--------------------------------------------------------------------------------
/oraqle/compiler/arithmetic/subtraction.py:
--------------------------------------------------------------------------------
1 | """This module contains classes for representing subtraction: x - y."""
2 | from galois import GF, FieldArray
3 |
4 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
5 | from oraqle.compiler.nodes.leafs import Constant, Input
6 | from oraqle.compiler.nodes.non_commutative import NonCommutativeBinaryNode
7 |
8 |
class Subtraction(NonCommutativeBinaryNode):
    """Represents a subtraction x - y.

    Arithmetization rewrites it as x + (-1) * y, so it only needs an addition and a
    multiplication by a constant.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return dict(shape="square", style="rounded,filled", fillcolor="cornsilk")

    @property
    def _hash_name(self) -> str:
        return "sub"

    @property
    def _node_label(self) -> str:
        return "-"

    def _operation_inner(self, x, y) -> FieldArray:
        difference = x - y
        return difference

    def _arithmetize_inner(self, strategy: str) -> Node:
        # TODO: Reorganize the files: let the arithmetic folder only contain pure arithmetic (including add and mul) and move exponentiation elsewhere.
        # TODO: For schemes that support subtraction we do not need to do this. We should only do this transformation during the compiler stage.
        left = self._left.arithmetize(strategy)
        minus_one = Constant(-self._gf(1))
        right = self._right.arithmetize(strategy)
        # TODO: Should we always perform a final arithmetization in every node for constant folding? E.g. in Node?
        return (left + minus_one * right).arithmetize(strategy)  # type: ignore

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        negated_right = Constant(-self._gf(1)) * self._right
        return (self._left + negated_right).arithmetize_depth_aware(cost_of_squaring)
36 |
37 |
def test_evaluate_mod5():  # noqa: D103
    gf = GF(5)

    a = Input("a", gf)
    b = Input("b", gf)
    node = Subtraction(a, b, gf)

    # (a, b, expected a - b mod 5)
    cases = [(3, 2, 1), (4, 1, 3), (1, 3, 3), (0, 4, 1)]
    for index, (a_val, b_val, expected) in enumerate(cases):
        if index > 0:
            node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
52 |
53 |
def test_evaluate_arithmetized_mod5():  # noqa: D103
    gf = GF(5)

    a = Input("a", gf)
    b = Input("b", gf)
    node = Subtraction(a, b, gf).arithmetize("best-effort")

    # (a, b, expected a - b mod 5); the cache is cleared before each evaluation.
    for a_val, b_val, expected in ((3, 2, 1), (4, 1, 3), (1, 3, 3), (0, 4, 1)):
        node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
69 |
--------------------------------------------------------------------------------
/oraqle/compiler/boolean/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains nodes for expressing common Boolean operations."""
2 |
--------------------------------------------------------------------------------
/oraqle/compiler/boolean/bool_neg.py:
--------------------------------------------------------------------------------
1 | """Classes for describing Boolean negation."""
2 | from galois import FieldArray
3 |
4 | from oraqle.compiler.arithmetic.subtraction import Subtraction
5 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
6 | from oraqle.compiler.nodes.leafs import Constant
7 | from oraqle.compiler.nodes.univariate import UnivariateNode
8 |
9 |
class Neg(UnivariateNode):
    """A node that negates a Boolean input: 0 becomes 1 and 1 becomes 0."""

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return "neg"

    @property
    def _node_label(self) -> str:
        return "NEG"

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        # Compare explicitly (as `IfElse.operation` does) instead of testing membership in
        # `{0, 1}`: a set lookup hashes the operand, and galois field elements are numpy arrays,
        # which are unhashable.
        assert input == 0 or input == 1
        return self._gf(int(not bool(input)))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # NEG(x) = 1 - x for a Boolean x.
        return Subtraction(
            Constant(self._gf(1)), self._node.arithmetize(strategy), self._gf
        ).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        return Subtraction(Constant(self._gf(1)), self._node, self._gf).arithmetize_depth_aware(
            cost_of_squaring
        )
38 |
--------------------------------------------------------------------------------
/oraqle/compiler/boolean/bool_or.py:
--------------------------------------------------------------------------------
1 | """This module contains tools for evaluating OR operations between many inputs."""
2 | import itertools
3 | from typing import Set
4 |
5 | from galois import GF, FieldArray
6 |
7 | from oraqle.compiler.boolean.bool_and import And, _find_depth_cost_front
8 | from oraqle.compiler.boolean.bool_neg import Neg
9 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
10 | from oraqle.compiler.nodes.flexible import CommutativeUniqueReducibleNode
11 | from oraqle.compiler.nodes.leafs import Constant, Input
12 |
13 | # TODO: Reduce code duplication between OR and AND
14 |
15 |
class Or(CommutativeUniqueReducibleNode):
    """Performs an OR operation over several operands. The user must ensure that the operands are Booleans."""

    @property
    def _hash_name(self) -> str:
        return "or"

    @property
    def _node_label(self) -> str:
        return "OR"

    def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
        return self._gf(bool(a) | bool(b))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # FIXME: Handle what happens when arithmetize outputs a constant!
        # TODO: Also consider the arithmetization using randomness
        # De Morgan: OR(a, b, ...) == NEG(AND(NEG(a), NEG(b), ...)).
        return Neg(
            And(
                {
                    UnoverloadedWrapper(Neg(operand.node.arithmetize(strategy), self._gf))
                    for operand in self._operands
                },
                self._gf,
            ),
            self._gf,
        ).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        """Build a depth/cost Pareto front of arithmetizations of this OR.

        Constant operands are folded with OR semantics:
        - a constant 1 makes the whole OR equal to 1;
        - a constant 0 is the identity and is dropped;
        - an OR without remaining operands equals 0.
        """
        # TODO: This is mostly copied from AND
        new_operands: Set[CostParetoFront] = set()
        for operand in self._operands:
            new_operand = operand.node.arithmetize_depth_aware(cost_of_squaring)
            new_operands.add(new_operand)

        if len(new_operands) == 0:
            # The empty OR is false: 0 is the identity element of OR.
            return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)
        elif len(new_operands) == 1:
            return next(iter(new_operands))

        # TODO: We can check if any of the element in new_operands are constants and return early

        front = CostParetoFront(cost_of_squaring)

        # TODO: This is brute force composition
        for operands in itertools.product(*(iter(new_operand) for new_operand in new_operands)):
            checked_operands = []
            for depth, cost, node in operands:
                if isinstance(node, Constant):
                    # Explicit comparisons: a set lookup would need to hash the field element.
                    assert node._value == 0 or node._value == 1
                    if node._value == 1:
                        # A true operand makes the whole OR true (absorbing element).
                        return CostParetoFront.from_leaf(Constant(self._gf(1)), cost_of_squaring)
                    # A false operand does not influence the OR, so it is dropped.
                else:
                    checked_operands.append((depth, cost, node))

            if len(checked_operands) == 0:
                # Every operand was a constant false.
                return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)

            if len(checked_operands) == 1:
                depth, cost, node = checked_operands[0]
                front.add(node, depth, cost)
                continue

            this_front = _find_depth_cost_front(
                checked_operands,
                self._gf,
                float("inf"),
                squaring_cost=cost_of_squaring,
                is_and=False,
            )
            front.add_front(this_front)

        return front

    def or_flatten(self, other: Node) -> Node:
        """Performs an OR operation with `other`, flattening the `Or` node if either of the two is also an `Or` and absorbing `Constant`s.

        Returns:
            An `Or` node containing the flattened OR operation, or a `Constant` node.
        """
        if isinstance(other, Constant):
            if bool(other._value):
                return Constant(self._gf(1))
            else:
                return self

        if isinstance(other, Or):
            return Or(self._operands | other._operands, self._gf)

        new_operands = self._operands.copy()
        new_operands.add(UnoverloadedWrapper(other))
        return Or(new_operands, self._gf)
108 |
109 |
def any_(*operands: Node) -> Or:
    """Build an `Or` node over `operands` that is true when at least one operand is true.

    The field of the resulting node is taken from the first operand, so at least one operand
    must be given.
    """
    assert len(operands) > 0
    wrapped = {UnoverloadedWrapper(operand) for operand in operands}
    return Or(wrapped, operands[0]._gf)
114 |
115 |
def test_evaluate_mod3():  # noqa: D103
    gf = GF(3)

    a = Input("a", gf)
    b = Input("b", gf)
    node = a | b

    truth_table = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 1)]
    for position, ((a_val, b_val), expected) in enumerate(truth_table):
        if position:
            node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
130 |
131 |
def test_evaluate_arithmetized_depth_aware_mod2():  # noqa: D103
    gf = GF(2)

    a = Input("a", gf)
    b = Input("b", gf)
    node = a | b
    front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

    truth_table = ((0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1))
    for _, _, candidate in front:
        for a_val, b_val, expected in truth_table:
            candidate.clear_cache(set())
            assert candidate.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
149 |
150 |
def test_evaluate_arithmetized_mod3():  # noqa: D103
    gf = GF(3)

    a = Input("a", gf)
    b = Input("b", gf)
    node = (a | b).arithmetize("best-effort")

    truth_table = ((0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1))
    for a_val, b_val, expected in truth_table:
        node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
166 |
167 |
def test_evaluate_arithmetized_depth_aware_50_mod31():  # noqa: D103
    gf = GF(31)

    inputs = {Input(f"x{i}", gf) for i in range(50)}
    node = Or({UnoverloadedWrapper(inp) for inp in inputs}, gf)
    front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

    # Each pattern maps an input index to its Boolean value, paired with the expected OR.
    patterns = [(lambda i: 0, 0), (lambda i: i % 2, 1), (lambda i: 1, 1)]
    for _, _, candidate in front:
        for pattern, expected in patterns:
            candidate.clear_cache(set())
            assignment = {f"x{i}": gf(pattern(i)) for i in range(50)}
            assert candidate.evaluate(assignment) == gf(expected)
182 |
--------------------------------------------------------------------------------
/oraqle/compiler/comparison/__init__.py:
--------------------------------------------------------------------------------
1 | """Package containing tools for expressing equality and comparison operations."""
2 |
--------------------------------------------------------------------------------
/oraqle/compiler/comparison/equality.py:
--------------------------------------------------------------------------------
1 | """This module contains classes for representing equality checks."""
2 | from galois import GF, FieldArray
3 |
4 | from oraqle.compiler.arithmetic.exponentiation import Power
5 | from oraqle.compiler.arithmetic.subtraction import Subtraction
6 | from oraqle.compiler.boolean.bool_neg import Neg
7 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
8 | from oraqle.compiler.nodes.binary_arithmetic import CommutativeBinaryNode
9 | from oraqle.compiler.nodes.leafs import Input
10 | from oraqle.compiler.nodes.univariate import UnivariateNode
11 |
12 |
class IsNonZero(UnivariateNode):
    """This node represents a non-zero check: it outputs 1 if x != 0 and 0 if x == 0."""

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return "is_nonzero"

    @property
    def _node_label(self) -> str:
        return "!= 0"

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        # Return a field element, like `Equals` and `Neg` do, instead of the raw numpy boolean
        # produced by `input != 0`, so the result can take part in further field arithmetic.
        return self._gf(int(input != 0))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # In a finite field of order q, x^(q-1) equals 1 for every x != 0 and 0 for x == 0.
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize_depth_aware(
            cost_of_squaring
        )
38 |
39 |
class Equals(CommutativeBinaryNode):
    """Equality test x == y, yielding 1 when the operands are equal and 0 otherwise.

    Arithmetization expresses it as NEG(IsNonZero(x - y)).
    """

    @property
    def _hash_name(self) -> str:
        return "equals"

    @property
    def _node_label(self) -> str:
        return "=="

    def _operation_inner(self, x, y) -> FieldArray:
        return self._gf(1 if x == y else 0)

    def _arithmetize_inner(self, strategy: str) -> Node:
        difference = Subtraction(self._left, self._right, self._gf)
        is_different = IsNonZero(difference, self._gf)
        return Neg(is_different, self._gf).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        difference = Subtraction(self._left, self._right, self._gf)
        is_different = IsNonZero(difference, self._gf)
        return Neg(is_different, self._gf).arithmetize_depth_aware(cost_of_squaring)
65 |
66 |
def test_evaluate_mod5():  # noqa: D103
    gf = GF(5)

    a = Input("a", gf)
    b = Input("b", gf)
    node = Equals(a, b, gf)

    # (a, b, expected a == b)
    cases = [(3, 2, 0), (4, 4, 1), (1, 2, 0), (0, 0, 1)]
    for index, (a_val, b_val, expected) in enumerate(cases):
        if index > 0:
            node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
81 |
82 |
def test_evaluate_arithmetized_mod5():  # noqa: D103
    gf = GF(5)

    a = Input("a", gf)
    b = Input("b", gf)
    node = Equals(a, b, gf).arithmetize("best-effort")

    # (a, b, expected a == b); the cache is cleared before each evaluation.
    for a_val, b_val, expected in ((3, 2, 0), (4, 4, 1), (1, 2, 0), (0, 0, 1)):
        node.clear_cache(set())
        assert node.evaluate({"a": gf(a_val), "b": gf(b_val)}) == gf(expected)
98 |
99 |
def test_equality_equivalence_commutative():  # noqa: D103
    gf = GF(5)

    a, b = Input("a", gf), Input("b", gf)

    forward = a == b
    backward = b == a
    assert forward.is_equivalent(backward)
107 |
--------------------------------------------------------------------------------
/oraqle/compiler/control_flow/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains control flow functions."""
2 |
--------------------------------------------------------------------------------
/oraqle/compiler/control_flow/conditional.py:
--------------------------------------------------------------------------------
1 | """This module contains tools for evaluating conditional statements."""
2 | from typing import List, Type
3 |
4 | from galois import GF, FieldArray
5 |
6 | from oraqle.compiler.circuit import Circuit
7 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
8 | from oraqle.compiler.nodes.fixed import FixedNode
9 | from oraqle.compiler.nodes.leafs import Constant, Input
10 |
11 |
class IfElse(FixedNode):
    """Selects between two values based on a Boolean condition.

    The node evaluates to `positive` when `condition` is 1 and to `negative` when it is 0.
    """

    @property
    def _node_label(self):
        return "If"

    @property
    def _hash_name(self):
        return "if_else"

    def __init__(self, condition: Node, positive: Node, negative: Node, gf: Type[FieldArray]):
        """Initialize an if-else node that outputs `positive` if `condition` is true and `negative` otherwise."""
        self._condition, self._positive, self._negative = condition, positive, negative
        super().__init__(gf)

    def __hash__(self) -> int:
        key = (self._hash_name, self._condition, self._positive, self._negative)
        return hash(key)

    def is_equivalent(self, other: Node) -> bool:  # noqa: D102
        if not isinstance(other, self.__class__):
            return False
        if not self._condition.is_equivalent(other._condition):
            return False
        if not self._positive.is_equivalent(other._positive):
            return False
        return self._negative.is_equivalent(other._negative)

    def operands(self) -> List[Node]:  # noqa: D102
        return [self._condition, self._positive, self._negative]

    def set_operands(self, operands: List[Node]):  # noqa: D102
        self._condition, self._positive, self._negative = operands[0], operands[1], operands[2]

    def operation(self, operands: List[FieldArray]) -> FieldArray:  # noqa: D102
        condition = operands[0]
        assert condition == 0 or condition == 1
        if condition == 1:
            return operands[1]
        return operands[2]

    def _arithmetize_inner(self, strategy: str) -> Node:
        # condition * (positive - negative) + negative is positive for condition == 1 and negative for 0.
        selected = self._condition * (self._positive - self._negative) + self._negative
        return selected.arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        selected = self._condition * (self._positive - self._negative) + self._negative
        return selected.arithmetize_depth_aware(cost_of_squaring)
64 |
65 |
def if_else(condition: Node, positive: Node, negative: Node) -> IfElse:
    """Create an if-else clause over nodes that all live in the same field.

    Returns:
        An `IfElse` node that equals `positive` if `condition` is true, and `negative` otherwise.
    """
    field = condition._gf
    assert field == positive._gf
    assert field == negative._gf
    return IfElse(condition, positive, negative, field)
75 |
76 |
def test_if_else():  # noqa: D103
    gf = GF(11)

    a = Input("a", gf)
    b = Input("b", gf)

    circuit = Circuit([if_else(a == b, Constant(gf(3)), Constant(gf(5)))])

    pairs = [(val_a, val_b) for val_a in range(11) for val_b in range(11)]
    for val_a, val_b in pairs:
        expected = gf(3) if val_a == val_b else gf(5)
        assert circuit.evaluate({"a": gf(val_a), "b": gf(val_b)}) == expected
93 |
94 |
def test_if_else_arithmetized():  # noqa: D103
    gf = GF(11)

    a = Input("a", gf)
    b = Input("b", gf)

    selection = if_else(a == b, Constant(gf(3)), Constant(gf(5)))
    arithmetic_circuit = Circuit([selection]).arithmetize()

    pairs = [(val_a, val_b) for val_a in range(11) for val_b in range(11)]
    for val_a, val_b in pairs:
        expected = gf(3) if val_a == val_b else gf(5)
        assert arithmetic_circuit.evaluate({"a": gf(val_a), "b": gf(val_b)}) == expected
111 |
--------------------------------------------------------------------------------
/oraqle/compiler/func2poly.py:
--------------------------------------------------------------------------------
1 | """Tools for interpolating polynomials from arbitrary functions."""
2 | import itertools
3 | from typing import Callable, List
4 |
5 | from sympy import Poly, symbols
6 |
7 |
def principal_character(x, prime_modulus):
    """Computes the principal character x**(p-1) for a prime modulus p.

    By Fermat's little theorem, modulo p this expression equals 0 when x = 0 and 1 for every
    other x. Only works for prime moduli. The power is not reduced modulo p here; the caller is
    responsible for working modulo p (e.g. by wrapping the result in a polynomial with
    `modulus=p`). Consequently 1 - principal_character(x - a, p) is an indicator of x == a.

    Returns:
        The principal character x**(p-1).
    """
    return x ** (prime_modulus - 1)
15 |
16 |
def interpolate_polynomial(
    function: Callable[..., int], prime_modulus: int, input_names: List[str]
) -> Poly:
    """Interpolates a polynomial for the given function. This is currently only implemented for prime moduli. This function interpolates the polynomial on all possible inputs.

    The result is the sum, over all input tuples a, of function(a) times the product of the
    indicator polynomials [x_k == a_k]. Each indicator is 1 - (x_k - a_k)^(p-1).

    Args:
        function: Maps one integer per input (each in range(prime_modulus)) to an output in range(prime_modulus).
        prime_modulus: The prime p defining the field.
        input_names: The names of the polynomial variables, one per function argument.

    Returns:
        A sympy `Poly` object representing the unique polynomial that evaluates to the same outputs for all inputs as `function`.
    """
    variables = symbols(input_names)

    # indicators[k][v] is the polynomial in variables[k] that is 1 when variables[k] == v and 0
    # otherwise (mod p). They are built once here instead of once per input tuple.
    indicators = [
        [
            Poly(
                1 - principal_character(variable - value, prime_modulus),
                variable,
                modulus=prime_modulus,
            )
            for value in range(prime_modulus)
        ]
        for variable in variables
    ]

    poly = 0

    for inputs in itertools.product(range(prime_modulus), repeat=len(input_names)):
        output = function(*inputs)
        assert 0 <= output < prime_modulus

        if output == 0:
            # A zero output contributes nothing to the interpolated sum.
            continue

        product = output
        for value, indicator_row in zip(inputs, indicators):
            product *= indicator_row[value]
        product = Poly(product, variables, modulus=prime_modulus)

        poly += product
        poly = Poly(poly, variables, modulus=prime_modulus)

    return Poly(poly, variables, modulus=prime_modulus)
45 |
--------------------------------------------------------------------------------
/oraqle/compiler/graphviz.py:
--------------------------------------------------------------------------------
1 | """This module contains classes and functions for visualizing circuits using graphviz."""
2 | from typing import Dict, List, Tuple
3 |
# Graphviz node attributes, presumably used to highlight expensive operations.
expensive_style = {"shape": "diamond"}


class DotFile:
    """A `DotFile` is a graph description format that can be rendered to e.g. PDF using graphviz."""

    def __init__(self):
        """Initialize an empty DotFile."""
        self._nodes: List[Dict[str, str]] = []
        self._links: List[Tuple[int, int, Dict[str, str]]] = []

    def add_node(self, **kwargs) -> int:
        """Adds a node to the file. The keyword arguments are written as quoted attributes in the DOT file.

        For example, one can specify a label, a color, a style, etc...

        Returns:
            The identifier of this node in this `DotFile`.
        """
        node_id = len(self._nodes)
        self._nodes.append(kwargs)

        return node_id

    def add_link(self, from_id: int, to_id: int, **kwargs):
        """Adds an unformatted link between the nodes with `from_id` and `to_id`. The keyword arguments are directly put into the DOT file."""
        self._links.append((from_id, to_id, kwargs))

    @staticmethod
    def _quote(value) -> str:
        """Return `value` as a DOT double-quoted string, escaping embedded double quotes."""
        escaped = str(value).replace('"', '\\"')
        return f'"{escaped}"'

    def to_file(self, filename: str):
        """Writes the DOT file to the given filename as a directed graph called 'G'."""
        with open(filename, mode="w", encoding="utf-8") as file:
            file.write("digraph G {\n")
            file.write('forcelabels="true";\n')
            file.write("graph [nodesep=0.25,ranksep=0.6];\n")  # nodesep, ranksep

            # Write all the nodes; values are quoted and escaped so labels may contain quotes.
            for node_id, attributes in enumerate(self._nodes):
                transformed_attributes = ",".join(
                    f"{key}={self._quote(value)}" for key, value in attributes.items()
                )
                file.write(f"n{node_id} [{transformed_attributes}];\n")

            # Write all the links (attribute values are written verbatim).
            for from_id, to_id, attributes in self._links:
                if len(attributes) == 0:
                    file.write(f"n{from_id}->n{to_id};\n")
                else:
                    text = f"n{from_id}->n{to_id} ["
                    text += ",".join(f"{key}={value}" for key, value in attributes.items())
                    text += "];\n"
                    file.write(text)

            file.write("}\n")
57 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | """The nodes package contains a collection of fundamental abstract and concrete nodes."""
2 | from oraqle.compiler.nodes.abstract import Node
3 | from oraqle.compiler.nodes.binary_arithmetic import Addition, Multiplication
4 | from oraqle.compiler.nodes.leafs import Constant, Input
5 |
# Names re-exported as the public API of the nodes package.
__all__ = ["Addition", "Constant", "Input", "Multiplication", "Node"]
7 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/binary_arithmetic.py:
--------------------------------------------------------------------------------
1 | """Module containing binary arithmetic nodes: additions and multiplications between non-constant nodes."""
2 | from abc import abstractmethod
3 | from typing import List, Optional, Set, Tuple, Type
4 |
5 | from galois import FieldArray
6 |
7 | from oraqle.compiler.instructions import (
8 | AdditionInstruction,
9 | ArithmeticInstruction,
10 | MultiplicationInstruction,
11 | )
12 | from oraqle.compiler.nodes.abstract import (
13 | ArithmeticNode,
14 | CostParetoFront,
15 | Node,
16 | iterate_increasing_depth,
17 | select_stack_index,
18 | )
19 | from oraqle.compiler.nodes.fixed import BinaryNode
20 | from oraqle.compiler.nodes.leafs import Constant
21 |
22 |
class CommutativeBinaryNode(BinaryNode):
    """A node with two operands whose operation is commutative: op(x, y) == op(y, x).

    Both the hash and the equivalence check ignore the order of the two operands.
    """

    def __init__(
        self,
        left: Node,
        right: Node,
        gf: Type[FieldArray],
    ):
        """Store the `left` and `right` operands and initialize the parent node with field `gf`."""
        self._left = left
        self._right = right
        super().__init__(gf)

    @abstractmethod
    def _operation_inner(self, x: FieldArray, y: FieldArray) -> FieldArray:
        """Applies the binary operation on x and y."""

    def operation(self, operands: List[FieldArray]) -> FieldArray:  # noqa: D102
        x, y = operands[0], operands[1]
        return self._operation_inner(x, y)

    def operands(self) -> List[Node]:  # noqa: D102
        return [self._left, self._right]

    def set_operands(self, operands: List[ArithmeticNode]):  # noqa: D102
        self._left, self._right = operands[0], operands[1]

    def __hash__(self) -> int:
        if self._hash is None:
            # Sorting the operand hashes makes the resulting hash independent of operand order.
            ordered = tuple(sorted((hash(self._left), hash(self._right))))
            self._hash = hash((self._hash_name, ordered))

        return self._hash

    def is_equivalent(self, other: Node) -> bool:  # noqa: D102
        if not isinstance(other, self.__class__):
            return False

        # Cheap rejection before the structural comparison.
        if hash(self) != hash(other):
            return False

        # Equivalent if the operands match in the same order or in swapped order.
        same_order = self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right)
        if same_order:
            return same_order
        return self._left.is_equivalent(other._right) and self._right.is_equivalent(other._left)
75 |
76 |
class CommutativeArithmeticBinaryNode(CommutativeBinaryNode):
    """This node has two operands and implements a commutative operation between arithmetic nodes.

    Subclasses set `self._is_multiplication` before calling this constructor (as `Addition`
    does). It decides whether the node counts as a multiplication for the depth, multiplication
    and squaring bookkeeping, and which instruction `create_instructions` emits.
    """

    def __init__(
        self,
        left: ArithmeticNode,
        right: ArithmeticNode,
        gf: Type[FieldArray],
    ):
        """Initialize this binary node with the given `left` and `right` operands.

        Raises:
            Exception: Neither `left` nor `right` is allowed to be a `Constant`.
        """
        super().__init__(left, right, gf)

        # Lazily filled caches for the methods below.
        self._multiplications: Optional[Set[int]] = None
        self._squarings: Optional[Set[int]] = None
        self._depth_cache: Optional[int] = None

        if isinstance(left, Constant) or isinstance(right, Constant):
            # Constant operands must be folded away instead of ending up in an arithmetic node.
            raise Exception(
                "Operands of an arithmetic binary node must not be `Constant`s; constant operands should be folded instead."
            )

    def multiplicative_depth(self) -> int:  # noqa: D102
        if self._depth_cache is None:
            # A multiplication adds one level on top of its deepest operand; an addition adds none.
            self._depth_cache = self._is_multiplication + max(
                self._left.multiplicative_depth(), self._right.multiplicative_depth()
            )

        return self._depth_cache

    def multiplications(self) -> Set[int]:  # noqa: D102
        if self._multiplications is None:
            self._multiplications = set().union(
                *(operand.multiplications() for operand in self.operands())  # type: ignore
            )
            if self._is_multiplication:
                self._multiplications.add(id(self))

        return self._multiplications

    # TODO: Squaring should probably be a UniveriateNode
    def squarings(self) -> Set[int]:  # noqa: D102
        if self._squarings is None:
            self._squarings = set().union(*(operand.squarings() for operand in self.operands()))  # type: ignore
            # A multiplication of a node with itself is a squaring.
            if self._is_multiplication and self._left is self._right:
                self._squarings.add(id(self))

        return self._squarings

    def create_instructions(  # noqa: D102
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        self._left: ArithmeticNode
        self._right: ArithmeticNode

        if self._instruction_cache is None:
            left_index, stack_counter = self._left.create_instructions(
                instructions, stack_counter, stack_occupied
            )
            right_index, stack_counter = self._right.create_instructions(
                instructions, stack_counter, stack_occupied
            )

            # FIXME: Is it possible for e.g. self._left._instruction_cache to be None?

            # Once an operand has no remaining parent waiting for it, its stack slot can be reused.
            self._left._parent_count -= 1
            if self._left._parent_count == 0:
                stack_occupied[self._left._instruction_cache] = False  # type: ignore

            self._right._parent_count -= 1
            if self._right._parent_count == 0:
                stack_occupied[self._right._instruction_cache] = False  # type: ignore

            self._instruction_cache = select_stack_index(stack_occupied)

            if self._is_multiplication:
                instructions.append(
                    MultiplicationInstruction(self._instruction_cache, left_index, right_index)
                )
            else:
                instructions.append(
                    AdditionInstruction(self._instruction_cache, left_index, right_index)
                )

        return self._instruction_cache, stack_counter
167 |
168 |
# FIXME: This order should probably change
class Addition(CommutativeArithmeticBinaryNode, ArithmeticNode):
    """Performs modular addition of two previous nodes in an arithmetic circuit.

    The node is already arithmetic: `arithmetize` only arithmetizes the operands in
    place. The depth-aware variant builds a new `Addition` for every pair of operand
    alternatives yielded by `iterate_increasing_depth`.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return dict(shape="square", style="rounded,filled", fillcolor="grey80")

    @property
    def _hash_name(self) -> str:
        return "add"

    @property
    def _node_label(self) -> str:
        return "+"

    def __init__(
        self,
        left: ArithmeticNode,
        right: ArithmeticNode,
        gf: Type[FieldArray],
    ):
        """Initialize a modular addition between `left` and `right`."""
        # Flag read by the depth / multiplication bookkeeping of the binary base class.
        self._is_multiplication = False
        super().__init__(left, right, gf)

    def _operation_inner(self, x, y):
        return x + y

    def arithmetize(self, strategy: str) -> Node:
        """Arithmetize both operands in place and return this node unchanged."""
        self._left = self._left.arithmetize(strategy)
        self._right = self._right.arithmetize(strategy)
        return self

    def _arithmetize_inner(self, strategy: str) -> Node:
        raise NotImplementedError()

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        result_front = CostParetoFront(cost_of_squaring)
        left_front = self._left.arithmetize_depth_aware(cost_of_squaring)
        right_front = self._right.arithmetize_depth_aware(cost_of_squaring)

        for (left_depth, _, left_node), (right_depth, _, right_node) in iterate_increasing_depth(
            left_front, right_front
        ):
            # An addition adds no depth: the result is as deep as its deepest operand.
            # TODO: Do we use + here for flattening?
            result_front.add(
                Addition(left_node, right_node, self._gf), depth=max(left_depth, right_depth)
            )

        assert not result_front.is_empty()
        return result_front
221 |
222 |
class Multiplication(CommutativeArithmeticBinaryNode, ArithmeticNode):
    """Performs modular multiplication of two previous nodes in an arithmetic circuit.

    Both operands must already be `ArithmeticNode`s. Depth-aware arithmetization wraps
    this node as-is with `CostParetoFront.from_node`.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return dict(shape="square", style="rounded,filled", fillcolor="lightpink")

    @property
    def _hash_name(self) -> str:
        return "mul"

    @property
    def _node_label(self) -> str:
        return "×"  # noqa: RUF001

    def __init__(
        self,
        left: ArithmeticNode,
        right: ArithmeticNode,
        gf: Type[FieldArray],
    ):
        """Initialize a modular multiplication between `left` and `right`.

        Raises:
            AssertionError: If `left` or `right` is not an `ArithmeticNode`.
        """
        for operand in (left, right):
            assert isinstance(operand, ArithmeticNode)

        # Flag read by the depth / multiplication bookkeeping of the binary base class.
        self._is_multiplication = True
        super().__init__(left, right, gf)

    def _operation_inner(self, x, y):
        return x * y

    # TODO: This is very hacky! Arithmetic nodes should simply not have to be arithmetized...
    def arithmetize(self, strategy: str) -> Node:
        """Arithmetize both operands in place and return this node unchanged."""
        self._left = self._left.arithmetize(strategy)
        self._right = self._right.arithmetize(strategy)
        return self

    def _arithmetize_inner(self, strategy: str) -> Node:
        raise NotImplementedError()

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # Unlike Addition, no operand alternatives are enumerated here.
        return CostParetoFront.from_node(self, cost_of_squaring)
265 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/fixed.py:
--------------------------------------------------------------------------------
1 | """Module containing fixed nodes: nodes with a fixed number of inputs."""
2 | from abc import abstractmethod
3 | from typing import Callable, Dict, List
4 |
5 | from galois import FieldArray
6 |
7 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node
8 |
9 |
class FixedNode(Node):
    """A node with a fixed number of operands.

    Subclasses expose their operands through `operands`/`set_operands` and implement
    `operation`, `_arithmetize_inner` and `_arithmetize_depth_aware_inner`. This class
    memoizes evaluation and arithmetization. When every operand is a constant, it
    replaces the node by a `Constant` during arithmetization.
    """

    @abstractmethod
    def operands(self) -> List["Node"]:
        """Returns the operands (children) of this node. The list can be empty."""

    @abstractmethod
    def set_operands(self, operands: List["Node"]):
        """Overwrites the operands of this node."""
        # TODO: Consider replacing this method with a graph traversal method that applies a function on all operands and replaces them.

    def apply_function_to_operands(self, function: Callable[[Node], None]):
        """Call `function` on each operand, in the order returned by `operands()`."""
        for operand in self.operands():
            function(operand)

    def replace_operands_using_function(self, function: Callable[[Node], Node]):
        """Replace each operand by `function(operand)` and invalidate the arithmetic caches."""
        self.set_operands([function(operand) for operand in self.operands()])
        # The operands changed, so the cached multiplication/squaring sets and depth are stale.
        # TODO: These caches should only be cleared if this is an ArithmeticNode
        self._multiplications = None
        self._squarings = None
        self._depth_cache = None

    def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray:
        """Evaluate this node on `actual_inputs`, memoizing the result in `_evaluate_cache`."""
        # TODO: Remove modulus in this method and store it in each node instead. Alternatively, add `modulus` to methods such as `flatten` as well.
        if self._evaluate_cache is None:
            self._evaluate_cache = self.operation(
                [operand.evaluate(actual_inputs) for operand in self.operands()]
            )

        return self._evaluate_cache

    @abstractmethod
    def operation(self, operands: List[FieldArray]) -> FieldArray:
        """Evaluates this node on the specified operands."""

    def _fold_constant_operands(self):
        """Return a `Constant` holding this node's value if all operands are constant, else `None`.

        Nodes without operands are never folded.
        """
        operands = self.operands()
        # This is a hacky way of checking whether the operands are all constant
        if len(operands) == 0 or not all(hasattr(operand, "_value") for operand in operands):
            return None

        # Imported locally: `leafs` imports this module, so a top-level import would be circular.
        from oraqle.compiler.nodes.leafs import Constant

        return Constant(self.operation([operand._value for operand in operands]))  # type: ignore

    def arithmetize(self, strategy: str) -> "Node":
        """Return an arithmetic version of this node, memoized in `_arithmetize_cache`.

        If a depth-aware arithmetization was computed before, the value returned by its
        `get_lowest_value()` is returned instead (without being cached here).
        """
        if self._arithmetize_cache is None:
            if self._arithmetize_depth_cache is not None:
                return self._arithmetize_depth_cache.get_lowest_value()  # type: ignore

            # If we know all operands we can simply evaluate this node
            folded = self._fold_constant_operands()
            if folded is not None:
                self._arithmetize_cache = folded
            else:
                self._arithmetize_cache = self._arithmetize_inner(strategy)

        return self._arithmetize_cache

    @abstractmethod
    def _arithmetize_inner(self, strategy: str) -> "Node":
        pass

    def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront:
        """Return the Pareto front of arithmetizations of this node, memoized per node.

        Raises:
            Exception: If `arithmetize` already populated `_arithmetize_cache` for this node.
        """
        if self._arithmetize_depth_cache is None:
            if self._arithmetize_cache is not None:
                raise Exception("This should not happen")

            # If we know all operands we can simply evaluate this node
            folded = self._fold_constant_operands()
            if folded is not None:
                self._arithmetize_depth_cache = CostParetoFront.from_leaf(folded, cost_of_squaring)
            else:
                self._arithmetize_depth_cache = self._arithmetize_depth_aware_inner(
                    cost_of_squaring
                )

        assert self._arithmetize_depth_cache is not None
        return self._arithmetize_depth_cache

    @abstractmethod
    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        pass
97 |
98 |
class BinaryNode(FixedNode):
    """A node with two operands.

    Adds nothing to `FixedNode` in the visible code. Subclasses such as
    `NonCommutativeBinaryNode` store the two operands and implement the abstract methods.
    """
101 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/flexible.py:
--------------------------------------------------------------------------------
1 | """Module containing nodes with a flexible number of operands."""
2 | from abc import abstractmethod
3 | from collections import Counter
4 | from functools import reduce
5 | from typing import Callable
6 | from typing import Counter as CounterType
7 | from typing import Dict, Optional, Set, Type
8 |
9 | from galois import FieldArray
10 |
11 | from oraqle.compiler.graphviz import DotFile
12 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
13 | from oraqle.compiler.nodes.leafs import Constant
14 |
15 |
class FlexibleNode(Node):
    """A node with an arbitrary number of operands. The operation must be reducible using a binary associative operation.

    Both arithmetization entry points compute their result once, via the abstract
    `_arithmetize_inner` / `_arithmetize_depth_aware_inner`, and memoize it on the node.
    """

    # TODO: Ensure that when all inputs are constants, the node is replaced with its evaluation

    def arithmetize(self, strategy: str) -> Node:
        """Return the arithmetized node, computing it with `_arithmetize_inner` on first use."""
        if self._arithmetize_cache is not None:
            return self._arithmetize_cache

        self._arithmetize_cache = self._arithmetize_inner(strategy)
        return self._arithmetize_cache

    @abstractmethod
    def _arithmetize_inner(self, strategy: str) -> "Node":
        pass

    def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront:
        """Return the depth-aware Pareto front, computing it on first use."""
        if self._arithmetize_depth_cache is not None:
            return self._arithmetize_depth_cache

        self._arithmetize_depth_cache = self._arithmetize_depth_aware_inner(cost_of_squaring)
        return self._arithmetize_depth_cache

    @abstractmethod
    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        pass
40 |
41 |
class CommutativeUniqueReducibleNode(FlexibleNode):
    """A node with an operation that is reducible without taking order into account: i.e. it has a binary operation that is associative and commutative.

    The operands are unique, i.e. the same operand will never appear twice. They are
    stored as a set of `UnoverloadedWrapper`s; the wrapped node is accessed via `.node`.
    """

    def __init__(
        self,
        operands: Set[UnoverloadedWrapper],
        gf: Type[FieldArray],
    ):
        """Initialize a node with the given set as the operands. None of the operands can be a constant.

        Raises:
            AssertionError: If a wrapped operand is a `Constant` or fewer than two operands are given.
        """
        self._operands = operands
        assert not any(isinstance(wrapped.node, Constant) for wrapped in self._operands)
        assert len(operands) > 1
        super().__init__(gf)

    def apply_function_to_operands(self, function: Callable[[Node], None]):
        """Call `function` on every unwrapped operand node."""
        for wrapped in self._operands:
            function(wrapped.node)

    def replace_operands_using_function(self, function: Callable[[Node], Node]):
        """Replace every operand by `function(operand)`, re-wrapping the results in a new set."""
        replaced = set()
        for wrapped in self._operands:
            replaced.add(UnoverloadedWrapper(function(wrapped.node)))
        self._operands = replaced

    def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray:
        """Reduce the operands' values with `_inner_operation`, memoizing the result."""
        if self._evaluate_cache is None:
            operand_values = (wrapped.node.evaluate(actual_inputs) for wrapped in self._operands)
            self._evaluate_cache = reduce(self._inner_operation, operand_values)

        return self._evaluate_cache

    @abstractmethod
    def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
        """Perform the reducible operation performed by this node (order should not matter)."""

    def __hash__(self) -> int:
        if self._hash is None:
            # Sorting the operand hashes makes the hash independent of set iteration order.
            operand_hashes = tuple(sorted(hash(wrapped) for wrapped in self._operands))
            self._hash = hash((self._hash_name, operand_hashes))

        return self._hash

    def is_equivalent(self, other: Node) -> bool:
        """Return whether `other` is the same kind of node over the same operand set."""
        if not isinstance(other, self.__class__) or hash(self) != hash(other):
            return False

        return self._operands == other._operands
95 |
96 |
class CommutativeMultiplicityReducibleNode(FlexibleNode):
    """A node with an operation that is reducible without taking order into account: i.e. it has a binary operation that is associative and commutative.

    An operand may occur several times. The operands are a `Counter` mapping each
    `UnoverloadedWrapper` to its multiplicity. A constant operand is kept separately
    in `_constant`, which defaults to the operation's `_identity`.
    """

    def __init__(
        self,
        operands: CounterType[UnoverloadedWrapper],
        gf: Type[FieldArray],
        constant: Optional[FieldArray] = None,
    ):
        """Initialize a reducible node with the given `Counter` representing the operands, none of which is allowed to be a constant.

        Raises:
            AssertionError: If there are fewer than two operands in total (counting
                multiplicities and a non-identity constant), if a key is not an
                `UnoverloadedWrapper`, or if a wrapped operand is a `Constant`.
        """
        super().__init__(gf)
        self._constant = self._identity if constant is None else constant
        self._operands = operands
        assert (sum(operands.values()) + (self._constant != self._identity)) > 1
        assert all(isinstance(operand, UnoverloadedWrapper) for operand in self._operands)
        # The keys are wrappers (never `Constant`s themselves), so inspect the wrapped node;
        # constants belong in `constant` instead.
        assert not any(isinstance(operand.node, Constant) for operand in self._operands)

    @property
    @abstractmethod
    def _identity(self) -> FieldArray:
        """The identity element of the operation; used as the default `_constant`."""

    def apply_function_to_operands(self, function: Callable[[Node], None]):
        """Call `function` on every unwrapped operand node (once per distinct operand)."""
        for operand in self._operands:
            function(operand.node)

    def replace_operands_using_function(self, function: Callable[[Node], Node]):
        """Replace every operand by `function(operand)`, keeping its multiplicity.

        Raises:
            AssertionError: If a replacement is a `Constant` or fewer than two operands remain.
        """
        # FIXME: What if there is only one operand remaining?
        self._operands = Counter(
            {
                UnoverloadedWrapper(function(operand.node)): count
                for operand, count in self._operands.items()
            }
        )
        assert not any(isinstance(operand.node, Constant) for operand in self._operands)
        assert (sum(self._operands.values()) + (self._constant != self._identity)) > 1

    def __hash__(self) -> int:
        if self._hash is None:
            # The hash is commutative: sort (hash, multiplicity) pairs before hashing.
            hashes = sorted(
                [(hash(operand.node), count) for operand, count in self._operands.items()]
            )
            self._hash = hash((self._hash_name, tuple(hashes), int(self._constant)))

        return self._hash

    def is_equivalent(self, other: Node) -> bool:
        """Return whether `other` has the same class, operand multiset and constant."""
        if not isinstance(other, self.__class__):
            return False

        if hash(self) != hash(other):
            return False

        return self._operands == other._operands and self._constant == other._constant

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this node to the graph; a non-identity constant is drawn as an extra input node."""
        if self._to_graph_cache is None:
            super().to_graph(graph_builder)
            self._to_graph_cache: int

            if self._constant != self._identity:
                # TODO: Add known_by
                graph_builder.add_link(
                    graph_builder.add_node(label=str(self._constant)), self._to_graph_cache
                )

        return self._to_graph_cache
165 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/leafs.py:
--------------------------------------------------------------------------------
1 | """Module containing leaf nodes: i.e. nodes without an input."""
2 | from typing import Any, Dict, List, Set, Tuple, Type
3 |
4 | from galois import FieldArray
5 |
6 | from oraqle.compiler.graphviz import DotFile
7 | from oraqle.compiler.instructions import ArithmeticInstruction, InputInstruction
8 | from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node, select_stack_index
9 | from oraqle.compiler.nodes.fixed import FixedNode
10 |
11 |
class ArithmeticLeafNode(FixedNode, ArithmeticNode):
    """An ArithmeticLeafNode is an ArithmeticNode with no inputs.

    A leaf has no operands, depth 0 and no multiplications or squarings. It
    arithmetizes to itself and to a single-leaf Pareto front.
    """

    def operands(self) -> List[Node]:
        """A leaf has no operands."""
        return list()

    def set_operands(self, operands: List["Node"]):
        """Ignore `operands`: a leaf has nothing to overwrite."""

    def _arithmetize_inner(self, strategy: str) -> Node:
        return self

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        return CostParetoFront.from_leaf(self, cost_of_squaring)

    def multiplicative_depth(self) -> int:
        """A leaf performs no multiplications, so its depth is 0."""
        return 0

    def multiplicative_size(self) -> int:
        """A leaf contains no multiplications."""
        return 0

    def multiplications(self) -> Set[int]:
        """A leaf contains no multiplication nodes."""
        return set()

    def squarings(self) -> Set[int]:
        """A leaf contains no squaring nodes."""
        return set()
38 |
39 |
# TODO: Merge ArithmeticInput and Input using multiple inheritance
class Input(ArithmeticLeafNode):
    """Represents a named input to the arithmetic circuit.

    Inputs are identified by their name: it is used for hashing, equivalence,
    evaluation lookups and the emitted `InputInstruction`.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return dict(shape="circle", style="filled", fillcolor="lightsteelblue1")

    @property
    def _hash_name(self) -> str:
        return "input"

    @property
    def _node_label(self) -> str:
        return self._name

    def __init__(self, name: str, gf: Type[FieldArray]) -> None:
        """Initialize an input with the given `name`."""
        super().__init__(gf)
        self._name = name

    def operation(self, operands: List[FieldArray]) -> FieldArray:
        """Not supported for inputs; use `evaluate` instead."""
        raise Exception()

    def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray:
        """Look up this input's value by name in `actual_inputs`."""
        return actual_inputs[self._name]

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this input to the graph once, labelled with its name."""
        if self._to_graph_cache is None:
            self._to_graph_cache = graph_builder.add_node(
                label=self._name, **self._overriden_graphviz_attributes
            )

        return self._to_graph_cache

    def __hash__(self) -> int:
        return hash(self._name)

    def is_equivalent(self, other: Node) -> bool:
        """Two inputs are equivalent exactly when they have the same name."""
        return isinstance(other, self.__class__) and self._name == other._name

    def create_instructions(
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        """Emit an `InputInstruction` into a fresh stack slot the first time this is called."""
        if self._instruction_cache is None:
            slot = select_stack_index(stack_occupied)
            self._instruction_cache = slot
            instructions.append(InputInstruction(slot, self._name))

        return self._instruction_cache, stack_counter
102 |
103 |
class Constant(ArithmeticLeafNode):
    """Represents a Node with a constant value.

    The field type of the node is taken from the class of the given value.
    `add`/`mul`/`bool_or`/`bool_and` between two constants are folded into a new
    `Constant`; otherwise the operation is delegated to the other operand.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return {"style": "filled", "fillcolor": "grey80", "shape": "circle"}

    @property
    def _hash_name(self) -> str:
        return "constant"

    @property
    def _node_label(self) -> str:
        return str(self._value)

    def __init__(self, value: FieldArray):
        """Initialize a Node with the given `value`."""
        super().__init__(value.__class__)
        self._value = value

    def operation(self, operands: List[FieldArray]) -> FieldArray:
        """Return the constant value; `operands` is ignored."""
        return self._value

    def to_graph(self, graph_builder: DotFile) -> Any:
        """Add this constant to the graph once, labelled with its value."""
        if self._to_graph_cache is None:
            label = str(self._value)

            self._to_graph_cache = graph_builder.add_node(
                label=label, **self._overriden_graphviz_attributes
            )

        return self._to_graph_cache

    def __hash__(self) -> int:
        return hash(int(self._value))

    def is_equivalent(self, other: Node) -> bool:
        """Two constants are equivalent when their values are equal."""
        if not isinstance(other, self.__class__):
            return False

        return self._value == other._value

    def add(self, other: "Node", flatten=True) -> "Node":
        """Fold `self + other` if `other` is a constant, otherwise delegate to `other.add`."""
        if isinstance(other, Constant):
            return Constant(self._value + other._value)

        return other.add(self, flatten)

    def mul(self, other: "Node", flatten=True) -> "Node":
        """Fold `self * other` if `other` is a constant, otherwise delegate to `other.mul`."""
        if isinstance(other, Constant):
            return Constant(self._value * other._value)

        return other.mul(self, flatten)

    def bool_or(self, other: "Node", flatten=True) -> Node:
        """Fold the boolean OR if `other` is a constant, otherwise delegate to `other.bool_or`."""
        if isinstance(other, Constant):
            return Constant(self._gf(bool(self._value) | bool(other._value)))

        return other.bool_or(self, flatten)

    def bool_and(self, other: "Node", flatten=True) -> Node:
        """Fold the boolean AND if `other` is a constant, otherwise delegate to `other.bool_and`."""
        if isinstance(other, Constant):
            return Constant(self._gf(bool(self._value) & bool(other._value)))

        return other.bool_and(self, flatten)

    def create_instructions(
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        """Not supported: a circuit that is a constant has no instructions to emit.

        The return annotation matches the other `create_instructions` implementations.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("The circuit is a constant.")
183 |
184 |
class DummyNode(FixedNode):
    """A DummyNode is a fixed node with no inputs and no behavior."""

    def operands(self) -> List[Node]:
        """A dummy node has no operands."""
        return list()

    def set_operands(self, operands: List["Node"]):
        """Ignore `operands`: there is nothing to overwrite."""
193 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/non_commutative.py:
--------------------------------------------------------------------------------
1 | """A collection of abstract nodes representing operations that are non-commutative."""
2 | from abc import abstractmethod
3 | from typing import List, Type
4 |
5 | from galois import FieldArray
6 |
7 | from oraqle.compiler.graphviz import DotFile
8 | from oraqle.compiler.nodes.abstract import Node
9 | from oraqle.compiler.nodes.fixed import BinaryNode
10 |
11 |
class NonCommutativeBinaryNode(BinaryNode):
    """Represents a non-commutative binary operation such as `x < y` or `x - y`.

    The operands are kept in order as `_left` and `_right`; hashing and equivalence
    both respect that order.
    """

    def __init__(self, left, right, gf: Type[FieldArray]):
        """Initialize a Node that performs an operation between two operands that is not commutative."""
        self._left = left
        self._right = right
        super().__init__(gf)

    @abstractmethod
    def _operation_inner(self, x, y) -> FieldArray:
        """Applies the binary operation on x and y."""

    def operation(self, operands: List[FieldArray]) -> FieldArray:
        """Apply `_operation_inner` to the first two values, in order."""
        first, second = operands[0], operands[1]
        return self._operation_inner(first, second)

    def operands(self) -> List[Node]:
        """Return `[left, right]`."""
        return [self._left, self._right]

    def set_operands(self, operands: List["Node"]):
        """Overwrite the left and right operands with the first two entries of `operands`."""
        self._left = operands[0]
        self._right = operands[1]

    def __hash__(self) -> int:
        if self._hash is None:
            # Order matters, so the operand hashes are deliberately not sorted.
            ordered_hashes = (hash(self._left), hash(self._right))
            self._hash = hash((self._hash_name, ordered_hashes))

        return self._hash

    def is_equivalent(self, other: Node) -> bool:
        """Return whether `other` is the same operation on pairwise-equivalent operands."""
        if not isinstance(other, self.__class__) or hash(self) != hash(other):
            return False

        return self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right)

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this node to the graph once, linking the left operand to the `nw` port and the right to `ne`."""
        if self._to_graph_cache is None:
            attributes = {"shape": "box", **self._overriden_graphviz_attributes}
            self._to_graph_cache = graph_builder.add_node(label=self._node_label, **attributes)

            left_id = self._left.to_graph(graph_builder)
            right_id = self._right.to_graph(graph_builder)

            graph_builder.add_link(left_id, self._to_graph_cache, headport="nw")
            graph_builder.add_link(right_id, self._to_graph_cache, headport="ne")

        return self._to_graph_cache
70 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/unary_arithmetic.py:
--------------------------------------------------------------------------------
1 | """This module contains `ArithmeticNode`s with a single input: Constant additions and constant multiplications."""
2 | from typing import List, Optional, Set, Tuple
3 |
4 | from galois import FieldArray
5 |
6 | from oraqle.compiler.graphviz import DotFile
7 | from oraqle.compiler.instructions import (
8 | ArithmeticInstruction,
9 | ConstantAdditionInstruction,
10 | ConstantMultiplicationInstruction,
11 | )
12 | from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node, select_stack_index
13 | from oraqle.compiler.nodes.univariate import UnivariateNode
14 |
15 | # TODO: There is (going to be) a lot of code duplication between these two classes
16 |
17 |
class ConstantAddition(UnivariateNode, ArithmeticNode):
    """This node represents an addition of a (non-zero) constant to another node.

    Adding a known constant does not change the multiplicative depth, nor the sets of
    multiplications and squarings of the operand.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return {"style": "rounded,filled", "fillcolor": "grey80"}

    @property
    def _node_shape(self) -> str:
        return "square"

    @property
    def _hash_name(self) -> str:
        return f"constant_add_{self._constant}"

    @property
    def _node_label(self) -> str:
        return "+"

    def __init__(self, node: ArithmeticNode, constant: FieldArray):
        """Represents the operation `constant + node`.

        Raises:
            AssertionError: If `constant` is zero or `node` is a `Constant`.
        """
        super().__init__(node, constant.__class__)
        self._constant = constant
        assert constant != 0

        self._depth_cache: Optional[int] = None

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        return input + self._constant

    def multiplicative_depth(self) -> int:
        """Return the operand's multiplicative depth, memoized in `_depth_cache`."""
        if self._depth_cache is None:
            self._depth_cache = self._node.multiplicative_depth()

        return self._depth_cache

    def multiplications(self) -> Set[int]:
        """Return the operand's multiplications; the constant addition adds none."""
        return self._node.multiplications()

    def squarings(self) -> Set[int]:
        """Return the operand's squarings; the constant addition adds none."""
        return self._node.squarings()

    def create_instructions(
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        """Emit the operand's instructions followed by a `ConstantAdditionInstruction`.

        Returns:
            A tuple `(stack index holding this node's result, stack_counter)`.
        """
        self._node: ArithmeticNode

        if self._instruction_cache is None:
            operand_index, stack_counter = self._node.create_instructions(
                instructions, stack_counter, stack_occupied
            )

            # Decrement the operand's parent count; at zero its stack slot is marked free.
            self._node._parent_count -= 1
            if self._node._parent_count == 0:
                stack_occupied[self._node._instruction_cache] = False  # type: ignore

            self._instruction_cache = select_stack_index(stack_occupied)

            instructions.append(
                ConstantAdditionInstruction(self._instruction_cache, operand_index, self._constant)
            )

        return self._instruction_cache, stack_counter

    def _arithmetize_inner(self, strategy: str) -> Node:
        return self

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        front = CostParetoFront(cost_of_squaring)
        for _, _, node in self._node.arithmetize_depth_aware(cost_of_squaring):
            front.add(ConstantAddition(node, self._constant))
        return front

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this node to the graph and draw the constant as an extra input node."""
        if self._to_graph_cache is None:
            super().to_graph(graph_builder)
            self._to_graph_cache: int

            # TODO: Add known_by
            graph_builder.add_link(
                graph_builder.add_node(
                    label=str(self._constant), shape="circle", style="filled", fillcolor="grey92"
                ),
                self._to_graph_cache,
            )

        return self._to_graph_cache
116 |
117 |
class ConstantMultiplication(UnivariateNode, ArithmeticNode):
    """This node represents a multiplication of another node with a constant.

    The constant must be neither 0 nor 1. Multiplying by a known constant leaves the
    operand's multiplicative depth, multiplications and squarings unchanged.
    """

    @property
    def _overriden_graphviz_attributes(self) -> dict:
        return dict(style="rounded,filled", fillcolor="grey80")

    @property
    def _node_shape(self) -> str:
        return "square"

    @property
    def _hash_name(self) -> str:
        return f"constant_mul_{self._constant}"

    @property
    def _node_label(self) -> str:
        return "×"  # noqa: RUF001

    def __init__(self, node: Node, constant: FieldArray):
        """Represents the operation `constant * node`.

        Raises:
            AssertionError: If `constant` is 0 or 1, or if `node` is a `Constant`.
        """
        super().__init__(node, constant.__class__)
        self._constant = constant
        assert constant != 0
        assert constant != 1

        self._depth_cache: Optional[int] = None

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        return input * self._constant  # type: ignore

    def multiplicative_depth(self) -> int:
        """Return the operand's multiplicative depth (memoized in `_depth_cache`)."""
        if self._depth_cache is not None:
            return self._depth_cache  # type: ignore

        self._depth_cache = self._node.multiplicative_depth()  # type: ignore
        return self._depth_cache  # type: ignore

    def multiplications(self) -> Set[int]:
        """Delegate to the operand; scaling by a constant adds no multiplication node."""
        return self._node.multiplications()  # type: ignore

    def squarings(self) -> Set[int]:
        """Delegate to the operand; scaling by a constant adds no squaring node."""
        return self._node.squarings()  # type: ignore

    def create_instructions(
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        """Emit the operand's instructions followed by a `ConstantMultiplicationInstruction`."""
        self._node: ArithmeticNode

        if self._instruction_cache is not None:
            return self._instruction_cache, stack_counter

        operand_index, stack_counter = self._node.create_instructions(
            instructions, stack_counter, stack_occupied
        )

        # Decrement the operand's parent count; at zero its stack slot is marked free.
        self._node._parent_count -= 1
        if self._node._parent_count == 0:
            stack_occupied[self._node._instruction_cache] = False  # type: ignore

        self._instruction_cache = select_stack_index(stack_occupied)
        instructions.append(
            ConstantMultiplicationInstruction(self._instruction_cache, operand_index, self._constant)
        )
        return self._instruction_cache, stack_counter

    def _arithmetize_inner(self, strategy: str) -> Node:
        return self

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        result_front = CostParetoFront(cost_of_squaring)
        for _, _, operand_alternative in self._node.arithmetize_depth_aware(cost_of_squaring):
            result_front.add(ConstantMultiplication(operand_alternative, self._constant))
        return result_front

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this node to the graph and draw the constant as an extra input node."""
        if self._to_graph_cache is None:
            super().to_graph(graph_builder)
            self._to_graph_cache: int

            # TODO: Add known_by
            constant_id = graph_builder.add_node(
                label=str(self._constant), shape="circle", style="filled", fillcolor="grey92"
            )
            graph_builder.add_link(constant_id, self._to_graph_cache)

        return self._to_graph_cache
218 |
--------------------------------------------------------------------------------
/oraqle/compiler/nodes/univariate.py:
--------------------------------------------------------------------------------
1 | """Abstract nodes for univariate operations."""
2 |
3 | from abc import abstractmethod
4 | from typing import List, Type
5 |
6 | from galois import FieldArray
7 |
8 | from oraqle.compiler.graphviz import DotFile
9 | from oraqle.compiler.nodes.abstract import Node
10 | from oraqle.compiler.nodes.fixed import FixedNode
11 | from oraqle.compiler.nodes.leafs import Constant
12 |
13 |
class UnivariateNode(FixedNode):
    """An abstract node with a single input.

    The operand is stored in `_node` and may not be a `Constant`.
    """

    @property
    @abstractmethod
    def _node_shape(self) -> str:
        """Graphviz node shape."""

    def __init__(self, node: Node, gf: Type[FieldArray]):
        """Initialize a univariate node.

        Raises:
            AssertionError: If `node` is a `Constant`.
        """
        self._node = node
        assert not isinstance(node, Constant)
        super().__init__(gf)

    def operands(self) -> List["Node"]:
        """Return a one-element list containing the operand."""
        return [self._node]

    def set_operands(self, operands: List["Node"]):
        """Replace the operand with the first entry of `operands`."""
        self._node = operands[0]

    @abstractmethod
    def _operation_inner(self, input: FieldArray) -> FieldArray:
        """Evaluate the operation on the input. This method does not have to cache."""

    def operation(self, operands: List[FieldArray]) -> FieldArray:
        """Apply `_operation_inner` to the single operand value."""
        (only_value, *_rest) = operands[:1] if False else (operands[0],)
        return self._operation_inner(only_value)

    def to_graph(self, graph_builder: DotFile) -> int:
        """Add this node to the graph once and link its operand to it."""
        if self._to_graph_cache is None:
            attributes = dict(self._overriden_graphviz_attributes)
            self._to_graph_cache = graph_builder.add_node(
                label=self._node_label, shape=self._node_shape, **attributes
            )

            operand_id = self._node.to_graph(graph_builder)
            graph_builder.add_link(operand_id, self._to_graph_cache)

        return self._to_graph_cache

    def __hash__(self) -> int:
        if self._hash is None:
            self._hash = hash((self._hash_name, self._node))

        return self._hash

    def is_equivalent(self, other: Node) -> bool:
        """Check whether `self` is semantically equivalent to `other`.

        This function may have false negatives but it should never return false positives.

        Returns:
            `True` if `self` is semantically equivalent to `other`, `False` if they are not
            or if they cannot be shown to be equivalent.
        """
        if not isinstance(other, self.__class__) or hash(self) != hash(other):
            return False

        return self._node.is_equivalent(other._node)
82 |
--------------------------------------------------------------------------------
/oraqle/compiler/poly2circuit.py:
--------------------------------------------------------------------------------
1 | """Module for automatic circuit generation for any functions with any number of inputs.
2 |
3 | Warning: These circuits can be very large!
4 | """
5 |
6 | from collections import Counter
7 | from typing import Dict, List, Tuple, Type
8 |
9 | from galois import GF, FieldArray
10 | from sympy import Add, Integer, Mul, Poly, Pow, Symbol
11 | from sympy.core.numbers import NegativeOne
12 |
13 | from oraqle.compiler.circuit import Circuit
14 | from oraqle.compiler.func2poly import interpolate_polynomial
15 | from oraqle.compiler.nodes import Constant, Input, Node
16 | from oraqle.compiler.nodes.abstract import UnoverloadedWrapper
17 | from oraqle.compiler.nodes.arbitrary_arithmetic import Product
18 |
19 |
def construct_subcircuit(expression, gf, modulus: int, inputs: Dict[str, Input]) -> Node:  # noqa: PLR0912
    """Build a circuit with a single output given an expression of simple arithmetic operations in Sympy.

    Supported expressions are:
    - integer literals;
    - symbols;
    - sums (`Add`);
    - products (`Mul`);
    - powers (`Pow`) with a positive integer exponent.

    Integer literals are reduced modulo `modulus`.

    Args:
        expression: The sympy expression to translate.
        gf: The finite field in which the circuit computes.
        modulus: The modulus used to reduce integer literals.
        inputs: `Input` nodes by symbol name. The dictionary is updated in place, so every
            occurrence of a symbol (also across calls sharing the dictionary) maps to one `Input`.

    Raises:
        Exception: If an exponent is not a positive integer, or if the expression contains an
            operation that is not implemented in arithmetic circuits.

    Returns:
        A subcircuit (Node) computing the given sympy expression.
    """
    import operator

    # `Integer` also matches sympy's singletons `Zero`, `One` and `NegativeOne`, which subclass it.
    # Matching on `.func` missed `One`, so e.g. `x + 1` used to raise.
    # Python's `%` maps negative literals into [0, modulus).
    if isinstance(expression, Integer):
        return Constant(gf(int(expression) % modulus))

    if expression.func in (Add, Mul):
        combine = operator.add if expression.func == Add else operator.mul
        # Translate the arguments in their original order; this fixes the creation order of new inputs.
        operands = [construct_subcircuit(arg, gf, modulus, inputs) for arg in expression.args]
        # TODO: Replace this chain of binary operations with a single sum/product node
        result = combine(operands[0], operands[1])
        for operand in operands[2:]:
            result = combine(operand, result)
        return result

    if expression.func == Pow:
        base, exponent = expression.args
        # Negative exponents used to slip through and produce a negative multiplicity.
        if not isinstance(exponent, Integer) or int(exponent) < 1:
            raise Exception("There was a power with an exponent that is not a positive integer")
        # Change powers to series of multiplications
        subcircuit = construct_subcircuit(base, gf, modulus, inputs)
        # TODO: This is not the most efficient way; we can use re-balancing.
        return Product(
            Counter({UnoverloadedWrapper(subcircuit): int(exponent)}), gf
        )  # FIXME: This could be flattened

    if expression.func == Symbol:
        assert len(expression.args) == 0
        var = str(expression)
        if var not in inputs:
            inputs[var] = Input(var, gf)
        return inputs[var]

    raise Exception(
        f"The expression contained an invalid operation (not one implemented in arithmetic circuits): {expression.func}."
    )
95 |
96 |
def construct_circuit(polynomials: List[Poly], modulus: int) -> Tuple[Circuit, Type[FieldArray]]:
    """Construct an arithmetic circuit from a list of polynomials and the fixed modulus.

    All polynomials share one dictionary of inputs, so a variable that occurs in several
    polynomials becomes a single circuit input.

    Returns:
        A tuple of the circuit outputting the evaluation of each polynomial, and the field it computes in.
    """
    gf = GF(modulus)
    shared_inputs: Dict[str, Input] = {}
    outputs = [
        construct_subcircuit(polynomial.expr, gf, modulus, shared_inputs) for polynomial in polynomials
    ]
    return Circuit(outputs), gf
113 |
114 |
if __name__ == "__main__":
    # Use function max(x, y)
    function = max
    modulus = 7

    # Create a polynomial and then a circuit that evaluates this expression
    poly = interpolate_polynomial(function, modulus, ["x", "y"])
    circuit, gf = construct_circuit([poly], modulus)

    # Output a DOT file for this high-level circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/)
    circuit.to_graph("max_7_hl.dot")

    # Arithmetize the high-level circuit, afterwards it will only contain arithmetic operations
    circuit = circuit.arithmetize()
    # Write the arithmetic circuit before CSE to its own file.
    # Previously this overwrote the high-level graph in max_7_hl.dot.
    circuit.to_graph("max_7_pre_cse.dot")

    # Print the initial metrics of the circuit
    print("depth", circuit.multiplicative_depth())
    print("size", circuit.multiplicative_size())

    # Apply common subexpression elimination (CSE) to remove duplicate operations from the circuit
    circuit.eliminate_subexpressions()

    # Output a DOT file for this arithmetic circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/)
    circuit.to_graph("max_7.dot")

    # Print the resulting metrics of the circuit
    print("depth", circuit.multiplicative_depth())
    print("size", circuit.multiplicative_size())

    # Test that given x=4 and y=2 indeed max(x, y) = 4
    assert circuit.evaluate({"x": gf(4), "y": gf(2)}) == [4]
150 |
--------------------------------------------------------------------------------
/oraqle/compiler/polynomials/__init__.py:
--------------------------------------------------------------------------------
1 | """The polynomials package contains nodes for performing polynomial evaluation.
2 |
3 | Over a finite field, every function is a polynomial function (a "polyfunction").
4 | So, any function over the field can be computed by evaluating an interpolating polynomial.
5 | """
6 |
--------------------------------------------------------------------------------
/oraqle/config.py:
--------------------------------------------------------------------------------
1 | """This module contains global configuration options.
2 |
3 | !!! warning
4 | This is almost certainly going to be removed in the future.
5 | We do not want oraqle to have a global configuration, but this is currently an intentional evil to prevent large refactors in the initial versions.
6 | """
7 | from typing import Annotated, Optional
8 |
9 |
# Type alias for a duration; the `Annotated` metadata records that the unit is seconds.
Seconds = Annotated[float, "seconds"]
MAXSAT_TIMEOUT: Optional[Seconds] = None
"""Time-out (in seconds) for individual calls to the MaxSAT solver.

!!! danger
    This causes non-deterministic behavior!

!!! bug
    There is currently a chance to get `AttributeError`s, which is a bug caused by pysat trying to delete an oracle that does not exist.
    There is currently no workaround for this."""


PS_METHOD_FACTOR_K: float = 2.0
"""Approximation factor for the PS-method, higher is better.

The Paterson-Stockmeyer method takes a parameter k that is theoretically optimal when k = sqrt(2 * degree).
However, other values of k are sometimes better (e.g. due to rounding, or to trade off depth and cost).
This factor f limits the candidate values of k that are tried to the range [1, f * sqrt(2 * degree))."""
28 |
--------------------------------------------------------------------------------
/oraqle/demo/playground.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b5e62be9-cada-42d2-a2bb-7f3ee38aec51",
6 | "metadata": {},
7 | "source": [
8 | "# Playground"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "df1eaaad-6ad2-4601-8d12-3b02d9254bfa",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from galois import GF\n",
19 | "\n",
20 | "from oraqle.compiler.boolean.bool_and import And\n",
21 | "from oraqle.compiler.circuit import Circuit\n",
22 | "from oraqle.compiler.nodes.leafs import Input\n",
23 | "\n",
24 | "gf = GF(5)\n",
25 | "\n",
26 | "xs = [Input(f\"x{i}\", gf) for i in range(11)]\n",
27 | "\n",
28 | "output = And(set(xs), gf)\n",
29 | "\n",
30 | "circuit = Circuit(outputs=[output], gf=gf)\n",
31 | "circuit.display_graph()"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "id": "ce27d985-b20c-4f7e-a303-929598c61c17",
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "naive_arithmetic_circuit = circuit.arithmetize(\"naive\")\n",
42 | "naive_arithmetic_circuit.display_graph()"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "id": "d5c0531f-f50b-4852-b81c-833a813eb235",
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "circuit._clear_cache()\n",
53 | "better_arithmetic_circuit = circuit.arithmetize(\"best-effort\")\n",
54 | "better_arithmetic_circuit.display_graph()"
55 | ]
56 | }
57 | ],
58 | "metadata": {
59 | "kernelspec": {
60 | "display_name": "Python 3 (ipykernel)",
61 | "language": "python",
62 | "name": "python3"
63 | },
64 | "language_info": {
65 | "codemirror_mode": {
66 | "name": "ipython",
67 | "version": 3
68 | },
69 | "file_extension": ".py",
70 | "mimetype": "text/x-python",
71 | "name": "python",
72 | "nbconvert_exporter": "python",
73 | "pygments_lexer": "ipython3",
74 | "version": "3.9.6"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 5
79 | }
80 |
--------------------------------------------------------------------------------
/oraqle/demo/small_comparison_bgv.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0f2abd68-5065-49c2-aefa-65ca3c8be8f8",
6 | "metadata": {},
7 | "source": [
8 | "# Compiling homomorphic encryption circuits made easy"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "1f425b04-35ab-4a1c-8ed4-20ecdc7d2901",
14 | "metadata": {},
15 | "source": [
16 | "#### The only boilerplate consists of defining the plaintext space and the inputs of the program."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "18d03a72-d22a-4f54-9a68-ab31507d1e34",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "from galois import GF\n",
27 | "\n",
28 | "from oraqle.compiler.nodes.leafs import Input\n",
29 | "\n",
30 | "gf = GF(11)\n",
31 | "\n",
32 | "a = Input(\"a\", gf)\n",
33 | "b = Input(\"b\", gf)"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "id": "7a7890f4-c770-4699-acba-ec2e6796a5bb",
39 | "metadata": {},
40 | "source": [
41 | "#### Programmers can use the primitives that they are used to."
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "0dd02769-50cc-4eb1-a9e0-896b944d9b28",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "output = a < b"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "8a26b9ca-2441-48e1-8aad-4b626755485e",
57 | "metadata": {},
58 | "source": [
59 | "#### A circuit can have an arbitrary number of outputs; here we only have one."
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "d00fa605-4510-4393-bdb0-4dd54a21f5f8",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "from oraqle.compiler.circuit import Circuit\n",
70 | "\n",
71 | "circuit = Circuit(outputs=[output], gf=gf)\n",
72 | "circuit.display_graph()"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "id": "fc7c6e33-a7ad-4e2f-a742-40653160a0ca",
78 | "metadata": {},
79 | "source": [
80 | "#### Turning high-level circuits into arithmetic circuits is a fully automatic process that improves on the state of the art in multiple ways."
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "id": "a441c9f5-de63-4253-bbb6-4b63511acc67",
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "arithmetic_circuit = circuit.arithmetize()\n",
91 | "arithmetic_circuit.display_graph()"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "id": "33a64549-4081-4fb8-9631-1f007b368dfa",
97 | "metadata": {},
98 | "source": [
99 | "#### The compiler implements a form of semantic subexpression elimination that significantly optimizes large circuits."
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "id": "1cd8bfaf-8113-444a-812c-b2a4fe124cec",
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "arithmetic_circuit.eliminate_subexpressions()\n",
110 | "arithmetic_circuit.display_graph()"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "id": "a89d7c56-ef33-4ac6-b06a-0f88d45aff91",
116 | "metadata": {},
117 | "source": [
118 | "#### This much smaller circuit is still correct!"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "5d50141a-b84a-4ac4-93e7-a7a1cf484688",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "import tabulate\n",
129 | "\n",
130 | "for val_a in range(11):\n",
131 | " for val_b in range(11):\n",
132 | " assert arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)}) == gf(val_a < val_b)\n",
133 | "\n",
134 | "data = [[arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)})[0] for val_a in range(11)] for val_b in range(11)]\n",
135 | "\n",
136 | "table = tabulate.tabulate(data, tablefmt='html')\n",
137 | "table"
138 | ]
139 | }
140 | ],
141 | "metadata": {
142 | "kernelspec": {
143 | "display_name": "Python 3 (ipykernel)",
144 | "language": "python",
145 | "name": "python3"
146 | },
147 | "language_info": {
148 | "codemirror_mode": {
149 | "name": "ipython",
150 | "version": 3
151 | },
152 | "file_extension": ".py",
153 | "mimetype": "text/x-python",
154 | "name": "python",
155 | "nbconvert_exporter": "python",
156 | "pygments_lexer": "ipython3",
157 | "version": "3.9.6"
158 | }
159 | },
160 | "nbformat": 4,
161 | "nbformat_minor": 5
162 | }
163 |
--------------------------------------------------------------------------------
/oraqle/examples/depth_aware_comparison.py:
--------------------------------------------------------------------------------
1 | """Depth-aware arithmetization of a comparison modulo 101."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.nodes.leafs import Input
7 |
gf = GF(101)
cost_of_squaring = 1.0

a = Input("a", gf)
b = Input("b", gf)

output = a < b

circuit = Circuit(outputs=[output])
circuit.to_graph("high_level_circuit.dot")

arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring)

for depth, cost, arithmetic_circuit in arithmetic_circuits:
    # The front reports a depth and cost per circuit; check that they match the circuit itself.
    assert arithmetic_circuit.multiplicative_depth() == depth
    assert arithmetic_circuit.multiplicative_cost(cost_of_squaring) == cost
    print("pre CSE", depth, cost)

    # Eliminate common subexpressions (in place) and report the new metrics.
    arithmetic_circuit.eliminate_subexpressions()
    depth_after_cse = arithmetic_circuit.multiplicative_depth()
    cost_after_cse = arithmetic_circuit.multiplicative_cost(cost_of_squaring)
    print("post CSE", depth_after_cse, cost_after_cse)
34 |
--------------------------------------------------------------------------------
/oraqle/examples/depth_aware_equality.py:
--------------------------------------------------------------------------------
1 | """Depth-aware arithmetization for an equality operation modulo 31."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.comparison.equality import Equals
7 | from oraqle.compiler.nodes.leafs import Input
8 |
gf = GF(31)

# Two symbolic inputs whose equality is tested.
a, b = Input("a", gf), Input("b", gf)

output = Equals(a, b, gf)
circuit = Circuit(outputs=[output])

# All depth/cost trade-offs found by depth-aware arithmetization.
arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring=1.0)

if __name__ == "__main__":
    circuit.to_pdf("high_level_circuit.pdf")
    for depth, size, arithmetic_circuit in arithmetic_circuits:
        arithmetic_circuit.to_pdf(f"arithmetic_circuit_d{depth}_s{size}.pdf")
24 |
--------------------------------------------------------------------------------
/oraqle/examples/long_and.py:
--------------------------------------------------------------------------------
1 | """Arithmetization of an AND operation between 15 inputs."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.boolean.bool_and import And
6 | from oraqle.compiler.circuit import Circuit
7 | from oraqle.compiler.nodes.abstract import UnoverloadedWrapper
8 | from oraqle.compiler.nodes.leafs import Input
9 |
gf = GF(5)

xs = [Input(f"x{i}", gf) for i in range(15)]

# Wrap the inputs so that building the set does not go through the overloaded `==` of nodes.
output = And({UnoverloadedWrapper(x) for x in xs}, gf)

circuit = Circuit(outputs=[output])
circuit.to_graph("high_level_circuit.dot")

arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_graph("arithmetic_circuit.dot")
21 |
--------------------------------------------------------------------------------
/oraqle/examples/small_comparison.py:
--------------------------------------------------------------------------------
1 | """Arithmetizes a comparison modulo 11 with a constant."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.nodes.leafs import Constant, Input
7 |
gf = GF(11)

# Compare a symbolic input against the constant 3 (the right-hand side could also be Input("b")).
a, b = Input("a", gf), Constant(gf(3))

output = a < b

circuit = Circuit(outputs=[output])
circuit.to_graph("high_level_circuit.dot")

arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_graph("arithmetic_circuit.dot")
20 |
--------------------------------------------------------------------------------
/oraqle/examples/small_polynomial.py:
--------------------------------------------------------------------------------
1 | """Creates graphs for the arithmetization of a small polynomial evaluation."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.nodes.leafs import Input
7 | from oraqle.compiler.polynomials.univariate import UnivariatePoly
8 |
gf = GF(11)

x = Input("x", gf)

# Coefficients of the univariate polynomial, each converted to a field element.
coefficients = [gf(value) for value in (1, 2, 3, 4, 5, 6, 1)]
output = UnivariatePoly(x, coefficients, gf)

circuit = Circuit(outputs=[output])
circuit.to_graph("high_level_circuit.dot")

arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_graph("arithmetic_circuit.dot")
20 |
--------------------------------------------------------------------------------
/oraqle/examples/visualize_circuits.py:
--------------------------------------------------------------------------------
1 | """Visualization of three circuits computing an OR operation on 7 inputs."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.arithmetic.exponentiation import Power
6 | from oraqle.compiler.boolean.bool_neg import Neg
7 | from oraqle.compiler.circuit import ArithmeticCircuit, Circuit
8 | from oraqle.compiler.nodes.binary_arithmetic import Multiplication
9 | from oraqle.compiler.nodes.leafs import Input
10 |
gf = GF(5)

x1 = Input("x1", gf)
x2 = Input("x2", gf)
x3 = Input("x3", gf)
x4 = Input("x4", gf)
x5 = Input("x5", gf)
x6 = Input("x6", gf)
x7 = Input("x7", gf)

# Circuit 1: in GF(5), s^4 is 1 for s != 0 and 0 otherwise.
# So, assuming 0/1 inputs, raising a sum of at most 4 bits to the 4th power computes their OR.
sum1 = x1 + x2 + x3 + x4
exp1 = Power(sum1, 4, gf)

sum2 = x5 + x6 + x7 + exp1
exp2 = Power(sum2, 4, gf)

circuit = Circuit([exp2])
arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_graph("arithmetic_circuit1.dot")


# Circuit 2: OR pairs of inputs via De Morgan, OR(a, b) = Neg(Neg(a) * Neg(b)).
# Then OR the three pairwise results and x7 with the sum-and-power trick above.
inv1 = Neg(x1, gf)
inv2 = Neg(x2, gf)
inv3 = Neg(x3, gf)
inv4 = Neg(x4, gf)
inv5 = Neg(x5, gf)
inv6 = Neg(x6, gf)

mul1 = inv1 * inv2
invmul1 = Neg(mul1, gf)

mul2 = inv3 * inv4
invmul2 = Neg(mul2, gf)

mul3 = inv5 * inv6
invmul3 = Neg(mul3, gf)

# Sum the pairwise ORs (invmul*), not the NORs (mul*); summing the latter did not compute an OR.
add1 = invmul1 + invmul2
add2 = invmul3 + add1

add3 = add2 + x7

exp = Power(add3, 4, gf)

circuit = Circuit([exp])
arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_graph("arithmetic_circuit2.dot")


# Circuit 3: OR of all seven inputs as Neg(product of all negations), built directly as arithmetic nodes.
inv1 = Neg(x1, gf).arithmetize("best-effort").to_arithmetic()
inv2 = Neg(x2, gf).arithmetize("best-effort").to_arithmetic()
inv3 = Neg(x3, gf).arithmetize("best-effort").to_arithmetic()
inv4 = Neg(x4, gf).arithmetize("best-effort").to_arithmetic()
inv5 = Neg(x5, gf).arithmetize("best-effort").to_arithmetic()
inv6 = Neg(x6, gf).arithmetize("best-effort").to_arithmetic()
inv7 = Neg(x7, gf).arithmetize("best-effort").to_arithmetic()

mul1 = Multiplication(inv1, inv2, gf)
mul2 = Multiplication(inv3, inv4, gf)
mul3 = Multiplication(inv5, inv6, gf)

mul4 = Multiplication(mul1, mul2, gf)
mul5 = Multiplication(mul3, inv7, gf)

mul6 = Multiplication(mul4, mul5, gf)

inv = Neg(mul6, gf).arithmetize("best-effort").to_arithmetic()

arithmetic_circuit = ArithmeticCircuit([inv])
arithmetic_circuit.to_graph("arithmetic_circuit3.dot")
81 |
--------------------------------------------------------------------------------
/oraqle/examples/wahc2024_presentation/1_high-level.py:
--------------------------------------------------------------------------------
1 | """Renders a high-level comparison circuit."""
2 | from galois import GF
3 |
4 | from oraqle.compiler.circuit import Circuit
5 | from oraqle.compiler.nodes.leafs import Input
6 |
7 |
if __name__ == "__main__":
    gf = GF(101)

    # Two symbolic inputs, rendered in the graph as "a" and "b".
    alex, blake = Input("a", gf), Input("b", gf)

    circuit = Circuit(outputs=[alex < blake])
    circuit.to_svg("high_level.svg")
18 |
--------------------------------------------------------------------------------
/oraqle/examples/wahc2024_presentation/2_arith_step1.py:
--------------------------------------------------------------------------------
1 | """Show the first step of arithmetization of a comparison circuit."""
2 | from galois import GF
3 |
4 | from oraqle.compiler.boolean.bool_neg import Neg
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.comparison.comparison import SemiStrictComparison
7 | from oraqle.compiler.nodes.leafs import Constant, Input
8 |
if __name__ == "__main__":
    gf = GF(101)

    alex = Input("a", gf)
    blake = Input("b", gf)

    output = alex < blake

    p = output._gf.characteristic

    # Orient the operands so that the circuit below always computes `left < right`.
    if output._less_than:
        left, right = output._left, output._right
    else:
        left, right = output._right, output._left

    left = left.arithmetize("best-effort")
    right = right.arithmetize("best-effort")

    # Each operand is "small" when it lies below p // 2.
    left_is_small = SemiStrictComparison(
        left, Constant(output._gf(p // 2)), less_than=True, gf=output._gf
    )
    right_is_small = SemiStrictComparison(
        right, Constant(output._gf(p // 2)), less_than=True, gf=output._gf
    )

    # Test whether left and right are in the same range (both small or both large).
    same_range = (left_is_small & right_is_small) + (
        Neg(left_is_small, output._gf) & Neg(right_is_small, output._gf)
    )

    # Perform left < right on the reduced inputs.
    # If both are in the upper half, the difference is still small enough for a semi-comparison.
    comparison = SemiStrictComparison(left, right, less_than=True, gf=output._gf)
    result = same_range * comparison

    # Perform left < right when one is small and the other is large.
    right_is_larger = left_is_small & Neg(right_is_small, output._gf)
    result += right_is_larger

    circuit = Circuit(outputs=[result])
    circuit.to_svg("arith_step1.svg")
54 |
--------------------------------------------------------------------------------
/oraqle/examples/wahc2024_presentation/3_arith_step2.py:
--------------------------------------------------------------------------------
1 | """Show the last step of arithmetization of a comparison circuit."""
2 | from galois import GF
3 |
4 | from oraqle.compiler.circuit import Circuit
5 | from oraqle.compiler.nodes.leafs import Input
6 |
if __name__ == "__main__":
    gf = GF(101)

    alex = Input("a", gf)
    blake = Input("b", gf)

    output = alex < blake

    front = output.arithmetize_depth_aware(cost_of_squaring=1.0)
    print(front)

    # Take one entry off the front; its value is a pair whose second element is the arithmetized node.
    _, (_, node) = front._nodes_by_depth.popitem()
    circuit = Circuit(outputs=[node])

    circuit.to_svg("arith_step2.svg")
23 |
--------------------------------------------------------------------------------
/oraqle/examples/wahc2024_presentation/5_code_gen.py:
--------------------------------------------------------------------------------
1 | """Generates code for the comparison circuit."""
2 | from galois import GF
3 |
4 | from oraqle.compiler.circuit import Circuit
5 | from oraqle.compiler.nodes.leafs import Input
6 |
if __name__ == "__main__":
    gf = GF(101)

    alex = Input("a", gf)
    blake = Input("b", gf)

    output = alex < blake
    circuit = Circuit(outputs=[output])

    front = circuit.arithmetize_depth_aware()

    for depth, cost, arithmetic_circuit in front:
        # Give each circuit on the front its own file.
        # Writing them all to "example.cpp" kept only the last one.
        arithmetic_circuit.generate_code(f"example_d{depth}_c{cost}.cpp")
20 |
--------------------------------------------------------------------------------
/oraqle/examples/wahc2024_presentation/rebalancing.py:
--------------------------------------------------------------------------------
1 | """Renders two circuits, one with a balanced product tree and one with an imbalanced tree."""
2 | from galois import GF
3 |
4 | from oraqle.compiler.circuit import Circuit
5 | from oraqle.compiler.nodes.leafs import Input
6 |
if __name__ == "__main__":
    gf = GF(101)

    a, b, c, d = (Input(name, gf) for name in "abcd")

    # A flattened product, which arithmetization can lay out as a balanced tree.
    circuit_good = Circuit(outputs=[a * b * c * d])
    circuit_good = circuit_good.arithmetize_depth_aware()  # FIXME: This should also work with arithmetize
    circuit_good[0][2].to_svg("rebalancing_good.svg")

    # The same product built as an explicit, unflattened chain: an imbalanced tree.
    abcd = a.mul(b, flatten=False).mul(c, flatten=False).mul(d, flatten=False)
    circuit_bad = Circuit(outputs=[abcd]).arithmetize()
    circuit_bad.to_svg("rebalancing_bad.svg")
26 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/bench_cardio_circuits.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from typing import Dict
4 |
5 | from galois import GF
6 |
7 | from oraqle.circuits.cardio import (
8 | construct_cardio_elevated_risk_circuit,
9 | construct_cardio_risk_circuit,
10 | )
11 | from oraqle.compiler.circuit import Circuit
12 |
13 |
def gen_params() -> Dict[str, int]:
    """Draw one random set of inputs for the cardio circuits.

    Returns:
        A mapping from input name to an integer drawn with `random.randint` between that input's
        bounds (inclusive). Values are drawn in a fixed order, so results are reproducible under a seed.
    """
    bounds = (
        # Binary flags.
        ("man", 0, 1),
        ("smoking", 0, 1),
        ("diabetic", 0, 1),
        ("hbp", 0, 1),
        # Bounded integer measurements.
        ("age", 0, 100),
        ("cholesterol", 0, 60),
        ("weight", 40, 150),
        ("height", 80, 210),
        ("activity", 0, 250),
        ("alcohol", 0, 5),
    )
    return {name: random.randint(low, high) for name, low, high in bounds}
30 |
31 |
if __name__ == "__main__":
    gf = GF(257)
    iterations = 10

    # Each benchmark pairs a printed title with the function that builds its circuit.
    benchmarks = (
        ("Cardio risk assessment", construct_cardio_risk_circuit),
        ("Cardio elevated risk assessment", construct_cardio_elevated_risk_circuit),
    )

    for cost_of_squaring in [0.75]:
        for title, construct in benchmarks:
            print(f"--- {title} ({cost_of_squaring}) ---")
            circuit = Circuit([construct(gf)])

            start = time.monotonic()
            front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring)
            print("Compile time:", time.monotonic() - start, "s")

            for depth, cost, arithmetic_circuit in front:
                print(depth, cost)
                run_time = arithmetic_circuit.run_using_helib(iterations, True, False, **gen_params())
                print("Run time:", run_time)
60 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/bench_equality.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.circuit import Circuit
4 | from oraqle.compiler.nodes.leafs import Input
5 |
6 |
if __name__ == "__main__":
    iterations = 10
    cost_of_squaring = 0.75

    for p in [29, 43, 61, 101, 131]:
        gf = GF(p)

        x = Input("x", gf)
        y = Input("y", gf)

        circuit = Circuit([x == y])

        # Identify which prime the following measurements belong to (as comparisons.py does).
        print(f"-------- p = {p}: ---------")

        # `iterations` used to be defined but ignored in favor of a hard-coded 10.
        for d, c, arith in circuit.arithmetize_depth_aware(cost_of_squaring):
            print(d, c, arith.run_using_helib(iterations, True, False, x=13, y=19))

        # Baseline: naive square-and-multiply arithmetization.
        arith = circuit.arithmetize('naive')
        print(
            'square and multiply',
            arith.multiplicative_depth(),
            arith.multiplicative_size(),
            arith.multiplicative_cost(cost_of_squaring),
            arith.run_using_helib(iterations, True, False, x=13, y=19),
        )
23 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/cardio_circuits.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from galois import GF
4 |
5 | from oraqle.circuits.cardio import (
6 | construct_cardio_elevated_risk_circuit,
7 | construct_cardio_risk_circuit,
8 | )
9 | from oraqle.compiler.circuit import Circuit
10 |
if __name__ == "__main__":
    gf = GF(257)

    # (printed title, circuit builder, prefix of the DOT files written for its front)
    benchmarks = (
        ("Cardio risk assessment", construct_cardio_risk_circuit, "cardio_arith"),
        ("Cardio elevated risk assessment", construct_cardio_elevated_risk_circuit, "cardio_elevated_arith"),
    )

    for cost_of_squaring in [0.5, 0.75, 1.0]:
        for title, construct, file_prefix in benchmarks:
            print(f"--- {title} ({cost_of_squaring}) ---")
            circuit = Circuit([construct(gf)])

            start = time.monotonic()
            front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring)
            print("Run time:", time.monotonic() - start, "s")

            for depth, cost, arithmetic_circuit in front:
                print(depth, cost)
                arithmetic_circuit.to_graph(f"{file_prefix}_d{depth}_c{cost}.dot")
36 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/comparisons.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.circuit import Circuit
4 | from oraqle.compiler.comparison.comparison import (
5 | IliashenkoZuccaSemiLessThan,
6 | SemiStrictComparison,
7 | T2SemiLessThan,
8 | )
9 | from oraqle.compiler.nodes.leafs import Input
10 |
def report_arithmetization(label, arithmetization, iterations, inputs):
    """Print `label`, the multiplicative depth and size, and the HElib run time of an arithmetization.

    Args:
        label: Text printed first on the line.
        arithmetization: Arithmetic circuit to report on.
        iterations: Number of iterations passed to ``run_using_helib``.
        inputs: Keyword inputs (name -> value) passed to ``run_using_helib``.
    """
    print(
        label,
        arithmetization.multiplicative_depth(),
        arithmetization.multiplicative_size(),
        arithmetization.run_using_helib(iterations=iterations, measure_time=True, **inputs),
    )


if __name__ == "__main__":
    iterations = 10
    # Plaintext inputs used for every HElib run.
    inputs = {"x": 15, "y": 22}

    for p in [29, 43, 61, 101, 131]:
        gf = GF(p)

        x = Input("x", gf)
        y = Input("y", gf)

        print(f"-------- p = {p}: ---------")
        our_circuit = Circuit([SemiStrictComparison(x, y, less_than=True, gf=gf)])
        # Depth-aware arithmetization: a front of (depth, cost, arithmetic circuit) entries.
        our_front = our_circuit.arithmetize_depth_aware()
        print("Our circuits:", our_front)

        our_front[0][2].to_graph(f"comp_{p}_ours.dot")
        for depth, cost, circ in our_front:
            print(depth, cost, circ.run_using_helib(iterations=iterations, measure_time=True, **inputs))

        # Baselines, reported before and after common-subexpression elimination.
        # eliminate_subexpressions() is called for its effect on the circuit; its return value is unused.
        t2_arithmetization = Circuit([T2SemiLessThan(x, y, gf)]).arithmetize()
        report_arithmetization("T2 circuit:", t2_arithmetization, iterations, inputs)
        t2_arithmetization.eliminate_subexpressions()
        report_arithmetization("T2 circuit CSE:", t2_arithmetization, iterations, inputs)

        iz21_arithmetization = Circuit([IliashenkoZuccaSemiLessThan(x, y, gf)]).arithmetize()
        report_arithmetization("IZ21 circuits:", iz21_arithmetization, iterations, inputs)
        iz21_arithmetization.eliminate_subexpressions()
        report_arithmetization("IZ21 circuit CSE:", iz21_arithmetization, iterations, inputs)
60 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/equality_first_prime_mods_exec.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing
3 | import pickle
4 | import time
5 | from functools import partial
6 | from typing import List, Tuple
7 |
8 | from matplotlib import pyplot as plt
9 | from sympy import sieve
10 |
11 | from oraqle.add_chains.addition_chains_front import chain_depth, gen_pareto_front
12 | from oraqle.add_chains.addition_chains_mod import chain_cost, hw
13 |
14 |
def experiment(
    t: int, squaring_cost: float
) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]:
    """Generate the addition-chain Pareto front for the exponent t - 1, passing modulus=t - 1.

    Args:
        t: Modulus whose exponent t - 1 is targeted (the caller passes primes).
        squaring_cost: Relative cost of a squaring.

    Returns:
        The list of (depth, cost, chain) triples of the front, and the wall-clock time in
        seconds spent generating the front.
    """
    exponent = t - 1

    began = time.monotonic()
    front = gen_pareto_front(
        exponent,
        modulus=exponent,
        squaring_cost=squaring_cost,
        solver="glucose42",
        encoding=1,
        thurber=True,
    )
    elapsed = time.monotonic() - began

    triples = []
    for _, chain in front:
        triples.append((chain_depth(chain, modulus=exponent), chain_cost(chain, squaring_cost), chain))
    return triples, elapsed
33 |
34 |
def experiment2(
    t: int, squaring_cost: float
) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]:
    """Generate the addition-chain Pareto front for the exponent t - 1 without any modulus.

    Args:
        t: Modulus whose exponent t - 1 is targeted (the caller passes primes).
        squaring_cost: Relative cost of a squaring.

    Returns:
        The list of (depth, cost, chain) triples of the front, and the wall-clock time in
        seconds spent generating the front.
    """
    began = time.monotonic()
    front = gen_pareto_front(
        t - 1,
        modulus=None,
        squaring_cost=squaring_cost,
        solver="glucose42",
        encoding=1,
        thurber=True,
    )
    elapsed = time.monotonic() - began

    def describe(chain):
        # Depth is measured without a modulus, matching the generation above.
        return chain_depth(chain), chain_cost(chain, squaring_cost), chain

    return [describe(chain) for _, chain in front], elapsed
52 |
53 |
def plot_specific_outputs(specific_outputs, specific_outputs_nomod, primes, squaring_cost: float):
    """Plot, per prime modulus, the cost of the generated circuits against square & multiply.

    Args:
        specific_outputs: One ``(chains, duration)`` pair per prime, where ``chains`` is a list
            of ``(depth, cost, chain)`` triples generated while considering the modulus.
        specific_outputs_nomod: Same shape, generated while ignoring the modulus; only the
            durations are used.
        primes: The prime moduli, in the same order as the outputs.
        squaring_cost: Relative cost of a squaring; used for the baselines and the file name.

    Side effects:
        Saves ``equality_first_prime_mods_{squaring_cost}.pdf`` and shows the figure.
    """
    plt.figure(figsize=(9, 2.8))
    plt.grid(axis="y", zorder=-1000, alpha=0.5)

    # Square & multiply baseline for the exponent t = p - 1: ceil(log2(t)) squarings plus
    # hw(t) - 1 multiplications (hw presumably being the Hamming weight of t).
    for x, p in enumerate(primes):
        t = p - 1
        plt.scatter(
            x,
            math.ceil(math.log2(t)) * squaring_cost + hw(t) - 1,
            color="black",
            # Label the first marker so the legend entry exists for any list of primes.
            label="Square & multiply" if x == 0 else None,
            zorder=100,
            marker="_",
        )

    # One dot per generated circuit; when a prime has several, print its depth on the dot.
    for x, outputs in enumerate(specific_outputs):
        chains, _ = outputs
        for depth, cost, _ in chains:
            plt.scatter(
                x,
                cost,
                color="black",
                zorder=100,
                s=50,
                label="Optimal circuit" if x == 0 else None,
            )
            if len(chains) > 1:
                plt.text(
                    x,
                    cost - 0.05,
                    str(depth),
                    fontsize=6,
                    ha="center",
                    va="center",
                    color="white",
                    zorder=200,
                    fontweight="bold",
                )

    plt.xticks(range(len(primes)), primes, rotation=50)
    plt.yticks(range(2 * math.ceil(math.log2(primes[-1]))))

    plt.xlabel("Modulus")
    plt.ylabel("Multiplicative cost")

    # Generation times go on a secondary y-axis as overlapping bars.
    ax1 = plt.gca()
    ax2 = ax1.twinx()
    for x, outputs in enumerate(specific_outputs):
        _, duration = outputs
        ax2.bar(x, duration, color="tab:cyan", zorder=0, alpha=0.3, label="Considering modulus" if x == 0 else None)  # type: ignore
    for x, outputs in enumerate(specific_outputs_nomod):
        _, duration = outputs
        ax2.bar(x, duration, color="tab:cyan", zorder=0, alpha=1.0, label="Ignoring modulus" if x == 0 else None)  # type: ignore
    ax2.set_ylabel("Generation time [s]", color="tab:cyan", alpha=1.0)

    # Lower bound: ceil(log2(p - 1)) squarings.
    ax1.step(
        range(len(primes)),
        [squaring_cost * math.ceil(math.log2(p - 1)) for p in primes],
        zorder=10,
        color="black",
        where="mid",
        label="Lower bound",
        linestyle=":",
    )

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()  # type: ignore
    ax1.legend(lines + lines2, labels + labels2, loc="upper left", fontsize="small")

    plt.savefig(f"equality_first_prime_mods_{squaring_cost}.pdf", bbox_inches="tight")
    plt.show()
127 |
128 |
if __name__ == "__main__":
    run_experiments = False
    # Relative squaring costs for which fronts are generated and plotted.
    squaring_costs = [0.5, 0.75, 1.0]
    # (file-name suffix, experiment function): with and without considering the modulus.
    variants = [("mod", experiment), ("nomod", experiment2)]

    if run_experiments:
        multiprocessing.set_start_method("fork")
        threads = 4

        primes = list(sieve.primerange(300))[:30]  # [:50]

        # The context manager shuts the worker processes down once all experiments are done.
        with multiprocessing.Pool(threads) as pool:
            for suffix, experiment_fn in variants:
                for sqr_cost in squaring_costs:
                    print(f"Computing for {sqr_cost}")
                    experiment_sqr_cost = partial(experiment_fn, squaring_cost=sqr_cost)
                    outs = list(pool.map(experiment_sqr_cost, primes))

                    with open(f"equality_experiment_{sqr_cost}_{suffix}.pkl", mode="wb") as file:
                        pickle.dump((primes, outs), file)

    # Visualize
    # NOTE: pickle.load can execute arbitrary code; only load files produced by this script.
    loaded = {}
    for suffix, _ in variants:
        for sqr_cost in squaring_costs:
            with open(f"equality_experiment_{sqr_cost}_{suffix}.pkl", "rb") as file:
                loaded[sqr_cost, suffix] = pickle.load(file)

    # All the primes should match
    primes = loaded[1.0, "mod"][0]
    for key, (other_primes, _) in loaded.items():
        assert primes == other_primes, key

    # All the chains should match (not in theory, but for this visualization they should).
    # zip() stops at the shorter sequence, so only the common prefix of each pair is compared.
    for sqr_cost in squaring_costs:
        assert all(
            all(x == y for x, y in zip(a[0], b[0]))
            for a, b in zip(loaded[sqr_cost, "mod"][1], loaded[sqr_cost, "nomod"][1])
        )

    for sqr_cost in squaring_costs:
        plot_specific_outputs(
            loaded[sqr_cost, "mod"][1], loaded[sqr_cost, "nomod"][1], primes, squaring_cost=sqr_cost
        )
192 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/poly_evaluation_pareto_front.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 |
4 | from galois import GF
5 | from matplotlib import pyplot as plt
6 | from matplotlib.ticker import MultipleLocator
7 |
8 | from oraqle.compiler.circuit import Circuit
9 | from oraqle.compiler.nodes.abstract import SizeParetoFront
10 | from oraqle.compiler.nodes.leafs import Input
11 | from oraqle.compiler.polynomials.univariate import (
12 | UnivariatePoly,
13 | _eval_poly,
14 | _eval_poly_alternative,
15 | _eval_poly_divide_conquer,
16 | )
17 |
def collect_points(evaluate, num_coefficients, front, verbose=False):
    """Arithmetize one evaluation strategy for every k in 1..num_coefficients - 1.

    Args:
        evaluate: Callable mapping k to the output node of the polynomial evaluation.
        num_coefficients: Number of polynomial coefficients; k ranges over 1..num_coefficients - 1.
        front: Pareto front that every (node, depth, size) result is added to.
        verbose: Whether to print ``k depth size`` for each k.

    Returns:
        The set of distinct (multiplicative depth, multiplicative size) points.
    """
    points = set()
    for k in range(1, num_coefficients):
        res = evaluate(k)
        circ = Circuit([res]).arithmetize()
        depth = circ.multiplicative_depth()
        size = circ.multiplicative_size()
        front.add(res, depth, size)  # type: ignore
        if verbose:
            print(k, depth, size)
        points.add((depth, size))
    return points


def front_points(points, front):
    """Keep the (depth, size) points whose size equals the size `front` stores for that depth.

    NOTE(review): assumes ``front._nodes_by_depth[depth][0]`` is the best size at that depth.
    """
    return [
        (depth, size)
        for depth, size in points
        if depth in front._nodes_by_depth and front._nodes_by_depth[depth][0] == size  # type: ignore
    ]


def scatter_points(points, **kwargs):
    """Scatter a collection of (depth, size) points with the given matplotlib options."""
    plt.scatter([depth for depth, _ in points], [size for _, size in points], **kwargs)


if __name__ == "__main__":
    sys.setrecursionlimit(15000)

    shape_size = 150

    plt.figure(figsize=(3.5, 4.4))

    marker1 = (3, 2, 0)
    marker2 = (3, 2, 40)
    marker3 = (3, 2, 80)
    o_marker = "o"
    linewidth = 2.5

    squaring_cost = 1.0

    p = 127  # 31
    gf = GF(p)
    for d in [p - 1]:
        x = Input("x", gf)

        # Interpolate the function x -> x % 7 as a univariate polynomial over GF(p).
        poly = UnivariatePoly.from_function(x, gf, lambda x: x % 7)
        coefficients = poly._coefficients
        num_coefficients = len(coefficients)

        # Shared front over all strategies, used below to highlight the non-dominated points.
        front = SizeParetoFront()

        # Generate points (faint background markers, one series per strategy)
        print("Paterson & Stockmeyer")
        data = collect_points(
            lambda k: _eval_poly(x, coefficients, k, gf, squaring_cost)[0],
            num_coefficients,
            front,
            verbose=True,
        )
        scatter_points(data, marker=marker2, zorder=10, alpha=0.4, s=shape_size, linewidth=linewidth)

        print("Baby-step giant-step")
        data2 = collect_points(
            lambda k: _eval_poly_alternative(x, coefficients, k, gf)[0],
            num_coefficients,
            front,
        )
        scatter_points(data2, marker=marker1, zorder=11, alpha=0.45, s=shape_size, linewidth=linewidth)

        print("Divide and conquer")
        data3 = collect_points(
            lambda k: _eval_poly_divide_conquer(x, coefficients, k, gf, squaring_cost)[0],
            num_coefficients,
            front,
        )
        scatter_points(data3, marker=marker3, zorder=11, alpha=0.45, s=shape_size, linewidth=linewidth)

        # Plot the front: redraw, in full color, each strategy's points that lie on the shared front.
        scatter_points(
            front_points(data2, front),
            marker=marker1,
            zorder=10,
            color="tab:orange",
            s=shape_size,
            label="Baby-step giant-step",
            linewidth=linewidth,
        )
        scatter_points(
            front_points(data, front),
            marker=marker2,
            zorder=10,
            color="tab:blue",
            s=shape_size,
            label="Paterson & Stockmeyer",
            linewidth=linewidth,
        )
        scatter_points(
            front_points(data3, front),
            marker=marker3,
            zorder=10,
            color="tab:green",
            s=shape_size,
            label="Divide & Conquer",
            linewidth=linewidth,
        )

        # Circle the Paterson & Stockmeyer circuit for k = round(sqrt(d / 2)).
        k = round(math.sqrt(d / 2))
        res, _ = _eval_poly(x, coefficients, k, gf, squaring_cost)
        circ = Circuit([res]).arithmetize()
        depth = circ.multiplicative_depth()
        size = circ.multiplicative_size()
        plt.scatter(
            depth,
            size,
            marker=o_marker,
            s=shape_size + 50,
            facecolors="none",
            edgecolors="black",
        )
        plt.text(depth, size + 0.4, f"k = {k}", ha="center", fontsize=8)

    plt.xlim((5, 15))
    plt.ylim((15, 30))

    plt.gca().set_aspect("equal")

    plt.gca().xaxis.set_minor_locator(MultipleLocator(1))
    plt.gca().yaxis.set_minor_locator(MultipleLocator(1))

    plt.grid(True, which="both", zorder=1, alpha=0.5)

    plt.xlabel("Multiplicative depth")
    plt.ylabel("Multiplicative size")

    plt.legend(fontsize="small")

    plt.savefig("poly_eval_front_2.pdf", bbox_inches="tight")
    plt.show()
181 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/run_all.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Run every Python script located next to this script, one after another.

# Get the directory where the script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Change to the script's directory; abort rather than run scripts from the wrong place.
cd "$SCRIPT_DIR" || exit 1

# Make the glob expand to nothing (instead of the literal pattern) when no files match.
shopt -s nullglob
files=(*.py)

if [ "${#files[@]}" -eq 0 ]; then
    echo "No Python files found in the script's directory."
    exit 0
fi

for file in "${files[@]}"
do
    echo "Running $file"
    # A failing script does not stop the remaining ones from running.
    python3 "$file"
done
21 |
--------------------------------------------------------------------------------
/oraqle/experiments/depth_aware_arithmetization/execution/veto_voting_per_mod.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from galois import GF
4 | from matplotlib import pyplot as plt
5 | from sympy import sieve
6 |
7 | from oraqle.compiler.boolean.bool_and import _minimum_cost
8 | from oraqle.compiler.boolean.bool_or import Or
9 | from oraqle.compiler.circuit import Circuit
10 | from oraqle.compiler.nodes.abstract import CostParetoFront, UnoverloadedWrapper
11 | from oraqle.compiler.nodes.leafs import Input
12 | from oraqle.experiments.oraqle_spotlight.experiments.veto_voting_minimal_cost import (
13 | exponentiation_results,
14 | )
15 |
16 |
def generate_all_fronts(primes=(7, 11, 13, 17), max_k: int = 50, cost_of_squaring: float = 1.0):
    """Compute depth-aware fronts of k-operand OR circuits for each prime modulus.

    Args:
        primes: Prime moduli to evaluate.
        max_k: Largest number of operands; fronts are generated for k = 2, ..., max_k.
        cost_of_squaring: Relative cost of a squaring passed to the arithmetization.

    Returns:
        Dict mapping each prime p to the list of fronts for k = 2, ..., max_k (in order).
        Each front yields (depth, cost, arithmetic circuit) entries.
    """
    results = {}

    for p in primes:
        # The field only depends on p, so construct it once per modulus.
        gf = GF(p)
        fronts = []

        print(f"------ p = {p} ------")
        for k in range(2, max_k + 1):
            xs = [Input(f"x{i}", gf) for i in range(k)]

            circuit = Circuit([Or(set(UnoverloadedWrapper(x) for x in xs), gf)])
            front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring)

            print(f"{k}.", end=" ")
            for depth, cost, _ in front:
                print(depth, cost, end=" ")
            print()

            fronts.append(front)

        results[p] = fronts

    return results
41 |
42 |
def plot_fronts(fronts: List[CostParetoFront], color, label, first_k: int = 2, **kwargs):
    """Scatter the (number of operands, cost) points of consecutive fronts in one color.

    Args:
        fronts: Fronts of (depth, cost, circuit) entries; ``fronts[i]`` belongs to
            ``first_k + i`` operands.
        color: Matplotlib color used for every point.
        label: Legend label, attached to an empty proxy scatter.
        first_k: Number of operands of the first front (``generate_all_fronts`` starts at 2).
        **kwargs: Extra options for ``plt.scatter``; ``marker``, ``s`` and ``linewidth`` are
            overridden for the individual points.
    """
    # Proxy artist so the legend shows a single entry for the whole series.
    plt.scatter([], [], color=color, label=label, **kwargs)

    point_kwargs = dict(kwargs, s=16, linewidth=0.5)
    for k, front in enumerate(fronts, start=first_k):
        for depth, cost, _ in front:
            # Matplotlib (numsides, 2, angle) marker: an asterisk whose number of arms is the depth.
            point_kwargs["marker"] = (depth, 2, 0)
            plt.scatter(k, cost, color=color, **point_kwargs)
51 |
52 |
if __name__ == "__main__":
    fronts_by_p = generate_all_fronts()
    max_k = 50
    operand_counts = range(2, max_k + 1)

    plt.figure(figsize=(4, 4))

    # Reference line labelled "Naive": k - 1 for k operands.
    plt.plot(
        operand_counts,
        [k - 1 for k in operand_counts],
        color="gray",
        linestyle="solid",
        label="Naive",
        linewidth=0.7,
    )

    plot_fronts(fronts_by_p[7], "tab:purple", "Modulus p = 7", zorder=100)
    plot_fronts(fronts_by_p[13], "tab:green", "Modulus p = 13", zorder=100)

    # For each k, the lowest _minimum_cost over the primes 3, 5, ..., 229, using the cost of the
    # first (lowest-depth) exponentiation in exponentiation_results.
    # This is for sqr = 0.75 mul.
    # NOTE(review): generate_all_fronts() uses cost_of_squaring=1.0, so these two series may not be
    # directly comparable -- confirm which squaring cost is intended.
    primes = list(sieve.primerange(300))[1:50]
    best_costs = [
        min(_minimum_cost(k, exponentiation_results[p][0][0][1], p) for p in primes)
        for k in operand_counts
    ]

    plt.step(
        operand_counts,
        best_costs,
        zorder=10,
        color="gray",
        where="mid",
        label="Lowest for any p",
        linestyle="solid",
        linewidth=0.7,
    )

    plt.legend()

    plt.xlabel("Number of operands")
    plt.ylabel("Multiplicative size")

    plt.savefig("veto_voting.pdf", bbox_inches="tight")
    plt.show()
100 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/examples/and_16.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.boolean.bool_and import all_
4 | from oraqle.compiler.circuit import Circuit
5 | from oraqle.compiler.nodes.leafs import Input
6 |
if __name__ == "__main__":
    gf = GF(17)

    # Sixteen inputs named x1, ..., x16.
    inputs = [Input(f"x{index}", gf) for index in range(1, 17)]

    # Conjunction of all inputs, arithmetized and rendered to a PDF.
    Circuit([all_(*inputs)]).arithmetize().to_pdf("conjunction.pdf")
18 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/examples/common_expressions.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.nodes.abstract import Node
7 | from oraqle.compiler.nodes.arbitrary_arithmetic import sum_
8 | from oraqle.compiler.nodes.leafs import Input
9 |
10 |
def generate_nodes() -> Tuple[Node, Node]:
    """Build two equivalent expressions over GF(31) that are written differently.

    ``cse1`` is ``(x < y) & (z1 + z2 + z3 + z4)`` and ``cse2`` is
    ``((z3 + z2 + z4) + z1) & (y > x)``: the same computation with a flipped
    comparison, reordered summands, and swapped conjunction operands.

    Returns:
        The pair ``(cse1, cse2)``.
    """
    gf = GF(31)

    x = Input("x", gf)
    y = Input("y", gf)
    z1 = Input("z1", gf)
    z2 = Input("z2", gf)
    z3 = Input("z3", gf)
    z4 = Input("z4", gf)

    # Distinct names per sub-expression; in particular, avoid shadowing the built-in `sum`.
    less_than = x < y
    total = sum_(z1, z2, z3, z4)
    cse1 = less_than & total

    greater_than = y > x
    reordered_total = sum_(z3, z2, z4) + z1
    cse2 = reordered_total & greater_than

    return cse1, cse2
30 |
31 |
def test_cse_equivalence():
    """Both expressions built by `generate_nodes` must be recognized as equivalent."""
    first, second = generate_nodes()
    equivalent = first.is_equivalent(second)
    assert equivalent
35 |
36 |
if __name__ == "__main__":
    # Wrap each expression in its own circuit, then render both as cse1.pdf and cse2.pdf.
    circuits = [Circuit([node]) for node in generate_nodes()]

    for index, circuit in enumerate(circuits, start=1):
        circuit.to_pdf(f"cse{index}.pdf")
45 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/examples/equality_31.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.circuit import Circuit
4 | from oraqle.compiler.nodes.leafs import Input
5 |
if __name__ == "__main__":
    gf = GF(31)

    x = Input("x", gf)
    y = Input("y", gf)

    # Depth-aware arithmetization of x == y yields (depth, cost, circuit) trade-offs.
    front = Circuit([x == y]).arithmetize_depth_aware(cost_of_squaring=1.0)

    # Render one PDF per front entry, named after its depth.
    for depth, _cost, arithmetic_circuit in front:
        arithmetic_circuit.to_pdf(f"equality_{depth}.pdf")
19 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/examples/equality_and_comparison.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.circuit import Circuit
4 | from oraqle.compiler.nodes.leafs import Input
5 |
if __name__ == "__main__":
    gf = GF(31)

    x, y, z = (Input(name, gf) for name in ("x", "y", "z"))

    # (x < y) AND (y == z), rendered as a high-level circuit.
    Circuit([(x < y) & (y == z)]).to_pdf("example.pdf")
20 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/examples/t2_comparison.py:
--------------------------------------------------------------------------------
1 | from galois import GF
2 |
3 | from oraqle.compiler.circuit import Circuit
4 | from oraqle.compiler.nodes.leafs import Input
5 |
# Plaintext modulus and the prime field it defines.
p = 7
gf = GF(p)

x = Input("x", gf)
y = Input("y", gf)

# T2's comparison: for each a in the upper half {(p + 1) // 2, ..., p - 1}, the term
# 1 - (x - y - a)^(p - 1) is 1 when x - y == a and 0 otherwise (Fermat's little theorem),
# so the sum is 1 exactly when x - y lies in that upper half.
comparison = 0
for a in range((p + 1) // 2, p):
    is_difference_a = 1 - (x - y - a) ** (p - 1)
    comparison += is_difference_a

circuit = Circuit([comparison])  # type: ignore

if __name__ == "__main__":
    # Export the high-level circuit as a .dot file and as a PDF.
    circuit.to_graph("t2.dot")
    circuit.to_pdf("t2.pdf")
22 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/experiments/comparisons/comparisons_bench.py:
--------------------------------------------------------------------------------
1 | import random
2 | import subprocess
3 |
4 | from galois import GF
5 | from matplotlib import pyplot as plt
6 | from sympy import sieve
7 |
8 | from oraqle.compiler.circuit import ArithmeticCircuit, Circuit
9 | from oraqle.compiler.comparison.comparison import SemiStrictComparison, T2SemiLessThan
10 | from oraqle.compiler.nodes.leafs import Input
11 |
12 |
def run_benchmark(arithmetic_circuit: "ArithmeticCircuit", iterations: int = 10) -> float:
    """Build and run an HElib benchmark of `arithmetic_circuit` on random inputs x and y.

    The circuit is emitted to ``main.cpp``, built with ``make``, and ``./main`` is executed
    with random values in [0, p - 1] for ``x`` and ``y``. Every output line except the last
    must end in "1"; the last line is divided by `iterations` to obtain the result.

    Args:
        arithmetic_circuit: The arithmetized circuit to benchmark.
        iterations: Repetitions compiled into the benchmark; the reported total is divided by this.

    Returns:
        The run time per iteration, in the units printed by the generated program.

    Raises:
        subprocess.CalledProcessError: If ``make`` fails.
        RuntimeError: If the benchmark prints no output.
        AssertionError: If an output line (other than the last) does not end in "1".
    """
    # Prepare the benchmark
    arithmetic_circuit.generate_code("main.cpp", iterations=iterations, measure_time=True)
    subprocess.run("make", capture_output=True, check=True)

    # Run the benchmark
    p = arithmetic_circuit._gf.characteristic
    command = ["./main", f"x={random.randint(0, p - 1)}", f"y={random.randint(0, p - 1)}"]
    print("Running:", " ".join(command))
    result = subprocess.run(command, capture_output=True, text=True, check=False)

    if result.returncode != 0:
        print("stderr:")
        print(result.stderr)
        print()
        print("stdout:")
        print(result.stdout)

    # Check if the noise was not too large
    print(result.stdout)
    lines = result.stdout.splitlines()
    if not lines:
        raise RuntimeError(f"Benchmark produced no output (exit status {result.returncode})")
    for line in lines[:-1]:
        # Explicit check so it is not stripped when Python runs with -O.
        if not line.endswith("1"):
            raise AssertionError(f"Unexpected benchmark output (noise too large?): {line!r}")

    run_time = float(lines[-1]) / iterations
    print(p, run_time)

    return run_time
43 |
44 |
def benchmark_all_primes(primes):
    """Benchmark Oraqle's depth-aware comparison front and T2's circuit for every prime.

    Returns:
        A pair ``(our_times, t2_times)``: per prime, a tuple with one run time per front
        entry, and the run time of T2's circuit.
    """
    our_times = []
    t2_times = []

    for p in primes:
        gf = GF(p)

        x = Input("x", gf)
        y = Input("y", gf)

        print(f"-------- p = {p}: ---------")
        our_circuit = Circuit([SemiStrictComparison(x, y, less_than=True, gf=gf)])
        our_front = our_circuit.arithmetize_depth_aware()
        print("Our circuits:", our_front)

        our_times.append(tuple(run_benchmark(arithmetic_circuit) for _, _, arithmetic_circuit in our_front))

        t2_circuit = Circuit([T2SemiLessThan(x, y, gf)])
        t2_arithmetization = t2_circuit.arithmetize()
        print(
            "T2 circuit:",
            t2_arithmetization.multiplicative_depth(),
            t2_arithmetization.multiplicative_size(),
        )

        t2_times.append(run_benchmark(t2_arithmetization))

    return our_times, t2_times


def plot_times(primes, our_times, t2_times, slides):
    """Plot per-prime run times of T2's circuit and of Oraqle's circuits, then save and show them.

    Args:
        primes: The moduli, used as x tick labels.
        our_times: Per prime, a tuple of run times (one per Oraqle circuit).
        t2_times: Per prime, the run time of T2's circuit.
        slides: Use the larger slide layout and markers; also selects the output file name.
    """
    marker_size = 100 if slides else None

    plt.figure(figsize=(7, 4) if slides else (4, 2))
    plt.grid(axis="y", zorder=-1000, alpha=0.5)

    plt.scatter(
        range(len(primes)), t2_times, marker="_", label="T2's Circuit", color="tab:orange", s=marker_size
    )

    # One marker per circuit; only the very first carries the legend label, avoiding duplicates.
    labelled = False
    for x, ts in enumerate(our_times):
        for t in ts:
            plt.scatter(
                x,
                t,
                marker="_",
                label=None if labelled else "Oraqle's circuits",
                color="tab:cyan",
                s=marker_size,
            )
            labelled = True

    plt.xticks(range(len(primes)), primes, fontsize=8)  # type: ignore

    plt.xlabel("Modulus")
    plt.ylabel("Run time (s)")

    plt.legend()

    plt.savefig(f"t2_comparison{'_slides' if slides else ''}.pdf", bbox_inches="tight")
    plt.show()


if __name__ == "__main__":
    slides = True
    run_benchmarks = False
    gen_plots = True

    if run_benchmarks:
        primes = list(sieve.primerange(300))[2:20]
        our_times, t2_times = benchmark_all_primes(primes)

        print(primes)
        print(our_times)
        print(t2_times)

    if gen_plots:
        # Hard-coded results, presumably recorded from an earlier run of the benchmark above.
        primes = [5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]
        our_times = [(0.0156603,), (0.0523416,), (0.0954489,), (0.0936497,), (0.111959,), (0.128402,), (0.288951,), (0.42076, 0.368583), (0.416362,), (0.40343,), (0.385652,), (0.437486,), (0.481356,), (0.522607, 0.504944), (0.526451,), (0.5904119999999999, 0.5146740000000001), (0.592896,), (0.621265, 0.598357)]
        t2_times = [0.0156379, 0.0938689, 0.23473899999999998, 0.319668, 0.366707, 0.6632450000000001, 1.8380299999999998, 1.14859, 2.9022200000000002, 3.2060299999999997, 3.5419899999999997, 4.53918, 5.02624, 5.4439, 8.64118, 6.6267499999999995, 6.99609, 9.21295]

        plot_times(primes, our_times, t2_times, slides)
121 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/experiments/large_equality/.gitignore:
--------------------------------------------------------------------------------
1 | /CMakeFiles
2 | CMakeCache.txt
3 | cmake_install.cmake
4 | helib.log
5 | Makefile
6 | main
7 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/experiments/large_equality/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Fail at configure time if HElib cannot be found, instead of failing later at link time.
find_package(helib REQUIRED)
add_executable(main main.cpp)
target_link_libraries(main helib)
10 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/experiments/large_equality/large_equality.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import subprocess
4 | import time
5 | from typing import List, Tuple
6 |
7 | from galois import GF
8 | from sympy import sieve
9 |
10 | from oraqle.compiler.boolean.bool_and import all_
11 | from oraqle.compiler.circuit import ArithmeticCircuit, Circuit
12 | from oraqle.compiler.nodes.leafs import Input
13 |
14 |
def generate_circuits(
    bits: int, num_primes: int = 10, cost_of_squaring: float = 0.75
) -> List[Tuple[int, ArithmeticCircuit, int, float]]:
    """Arithmetize a limb-wise equality check of two `bits`-bit values for several primes.

    For each prime p, both operands are split into ``ceil(bits / log2(p))`` inputs over GF(p).
    The per-limb equalities are combined with `all_`, and the result is arithmetized depth-aware.

    Args:
        bits: Bit width of the compared values.
        num_primes: How many of the smallest primes below 300 to use (at most 62 exist).
        cost_of_squaring: Relative cost of a squaring passed to the depth-aware arithmetization.

    Returns:
        One ``(p, arithmetic_circuit, depth, cost)`` tuple per entry of each prime's front.
    """
    circuits = []

    primes = list(sieve.primerange(300))[:num_primes]  # [:55] # p <= 257
    start = time.monotonic()
    times = []
    for p in primes:
        # (6, 63.0): p=2
        # (7, 58.0): p=5
        # (8, 51.0): p=17

        # Number of base-p limbs needed for a `bits`-bit value.
        limbs = math.ceil(bits / math.log2(p))

        gf = GF(p)

        xs = [Input(f"x{i}", gf) for i in range(limbs)]
        ys = [Input(f"y{i}", gf) for i in range(limbs)]
        circuit = Circuit([all_(*(x == y for x, y in zip(xs, ys)))])

        # Time only the arithmetization (and reporting) for this prime.
        inbetween = time.monotonic()
        front = circuit.arithmetize_depth_aware(cost_of_squaring)

        print(f"{p}.", end=" ")

        for depth, cost, arithmetic_circuit in front:
            circuits.append((p, arithmetic_circuit, depth, cost))
            print(depth, cost, end=" ")

        inbetween_time = time.monotonic() - inbetween
        print(inbetween_time)
        times.append((p, inbetween_time))

    print(times)
    print("Total time", time.monotonic() - start)

    return circuits
51 |
52 |
def benchmark_all_circuits(circuits, bits, iterations=10):
    """Generate, build and run an HElib benchmark for every circuit, using random limb inputs.

    Args:
        circuits: ``(p, arithmetic_circuit, depth, cost)`` tuples, as from `generate_circuits`.
        bits: Bit width used to generate the circuits; determines the number of limbs.
        iterations: Repetitions compiled into the benchmark; the measured total is divided by this.

    Returns:
        ``(p, depth, cost, params, run_time)`` tuples, where ``params`` is what ``generate_code``
        returned.

    Raises:
        subprocess.CalledProcessError: If ``make`` fails.
        RuntimeError: If a benchmark prints no output.
        AssertionError: If an output line (other than the last) does not end in "1".
    """
    results = []
    for p, arithmetic_circuit, d, c in circuits:
        # Prepare the benchmark
        params = arithmetic_circuit.generate_code("main.cpp", iterations=iterations, measure_time=True)
        subprocess.run("make", check=True)

        # Run the benchmark with random values for every limb of x and y
        command = ["./main"]
        limbs = math.ceil(bits / math.log2(p))
        for i in range(limbs):
            command.append(f"x{i}={random.randint(0, p - 1)}")
            command.append(f"y{i}={random.randint(0, p - 1)}")
        print("Running:", " ".join(command))
        result = subprocess.run(command, capture_output=True, text=True, check=False)

        if result.returncode != 0:
            print("stderr:")
            print(result.stderr)
            print()
            print("stdout:")
            print(result.stdout)

        # Check if the noise was not too large
        print(result.stdout)
        lines = result.stdout.splitlines()
        if not lines:
            raise RuntimeError(f"Benchmark for p = {p} produced no output (exit status {result.returncode})")
        for line in lines[:-1]:
            if not line.endswith("1"):
                raise AssertionError(f"Unexpected benchmark output (noise too large?): {line!r}")

        # The last line is presumably the total time over all iterations.
        run_time = float(lines[-1]) / iterations
        print(p, run_time, d, c, params)
        results.append((p, d, c, params, run_time))

    return results


def print_latex_table(gen_times, results):
    """Print one LaTeX table row per benchmark result.

    Args:
        gen_times: ``(p, generation time)`` pairs; every p in `results` must appear here.
        results: ``(p, depth, cost, params, run_time)`` tuples; ``params`` has four entries.
    """
    gen_time_by_p = dict(gen_times)
    for p, d, c, params, run_time in results:
        print(f"{p} & {d} & {c} & {params[0]} & {params[1]} & {params[2]} & {params[3]} & {round(gen_time_by_p[p], 2)} & {round(run_time, 2)} \\\\")


if __name__ == "__main__":
    bits = 64
    benchmark_circuits = False
    generate_table = True

    # Run a benchmark for all circuits in the front
    if benchmark_circuits:
        # Generate all circuits per p
        circuits = generate_circuits(bits)
        results = benchmark_all_circuits(circuits, bits)
        print(results)

    if generate_table:
        # Hard-coded (p, generation time) pairs and benchmark results used for the table.
        gen_times = [(2, 0.007554411888122559), (3, 0.06264467351138592), (5, 8.457202550023794), (7, 0.05447225831449032), (11, 0.0478445328772068), (13, 0.052152080461382866), (17, 0.04349260404706001), (19, 0.04553743451833725), (23, 0.05198719538748264), (29, 0.046183058992028236)]
        results = [(2, 6, 63.0, (16383, 1, 142, 3), 3.27577), (3, 7, 60.75, (32768, 1, 170, 3), 1.51993), (5, 7, 58.0, (32768, 1, 178, 3), 1.7679099999999999), (5, 8, 55.5, (32768, 1, 197, 3), 1.93994), (7, 8, 74.0, (32768, 1, 206, 3), 2.90913), (7, 9, 70.0, (32768, 1, 226, 3), 2.6624600000000003), (7, 10, 69.5, (32768, 1, 246, 3), 3.00814), (11, 9, 69.25, (32768, 1, 228, 3), 2.50603), (11, 12, 68.25, (32768, 1, 300, 3), 3.25469), (13, 9, 68.75, (32768, 1, 237, 3), 2.67845), (13, 10, 67.75, (32768, 1, 237, 3), 2.7718), (13, 11, 66.0, (32768, 1, 237, 3), 2.56386), (13, 12, 65.0, (32768, 1, 301, 3), 3.10959), (17, 8, 51.0, (32768, 1, 217, 3), 1.8792300000000002), (19, 9, 79.0, (32768, 1, 238, 3), 2.85011), (19, 10, 68.0, (32768, 1, 259, 3), 2.8636500000000003), (23, 9, 89.0, (32768, 1, 248, 3), 4.135730000000001), (23, 10, 80.0, (32768, 1, 270, 3), 3.75128), (29, 9, 83.0, (32768, 1, 249, 3), 3.7119), (29, 10, 75.0, (32768, 1, 271, 3), 3.46666)]

        print_latex_table(gen_times, results)
105 |
--------------------------------------------------------------------------------
/oraqle/experiments/oraqle_spotlight/experiments/veto_voting_minimal_cost.py:
--------------------------------------------------------------------------------
1 | """Finds the minimum cost for veto voting circuits for different prime moduli."""
2 |
3 | from sympy import sieve
4 |
5 | from oraqle.compiler.boolean.bool_and import _minimum_cost
6 |
# Precomputed exponentiation results, keyed by prime modulus p.
#
# Each value is a tuple `(front, seconds)`:
# - `front` is a list of `(depth, cost)` pairs, ordered by increasing first entry. The
#   entries appear to be a depth/cost trade-off for computing x^(p-1) when a squaring
#   costs 0.75 multiplications, e.g. p = 7 -> (3, 2.5) = two squarings + one
#   multiplication. TODO confirm against the generator.
# - `seconds` is presumably the time it took to compute `front`. TODO confirm.
#
# NOTE(review): `run_experiments` reads only `front[0][1]`, i.e. the cost of the
# shallowest entry. That is not always the cheapest one (e.g. p = 31 has cost 6.0 at
# depth 5 but 5.0 at depth 6). Confirm this depth-first choice is intended.
exponentiation_results = {
    2: ([(0, 0.0)], 8.633400000002123e-05),
    3: ([(1, 0.75)], 4.6670000000137435e-06),
    5: ([(2, 1.5)], 7.695799999996034e-05),
    7: ([(3, 2.5)], 0.0053472920000000035),
    11: ([(4, 3.25)], 0.007671625000000015),
    13: ([(4, 3.25)], 0.002812749999999975),
    17: ([(4, 3.0)], 7.891700000001167e-05),
    19: ([(5, 4.0)], 0.012155541999999964),
    23: ([(5, 5.0)], 0.03937258299999996),
    29: ([(5, 5.0)], 0.018942542000000007),
    31: ([(5, 6.0), (6, 5.0)], 0.064326),
    37: ([(6, 4.75)], 0.019883207999999986),
    41: ([(6, 4.75)], 0.02284237499999997),
    43: ([(6, 5.75)], 0.03223737499999996),
    47: ([(6, 6.75), (7, 6.0)], 0.607119292),
    53: ([(6, 5.75)], 0.03940958299999997),
    59: ([(6, 6.75)], 1.243811584),
    61: ([(6, 6.75), (7, 5.75)], 0.446000167),
    67: ([(7, 5.5)], 0.051902208000000005),
    71: ([(7, 6.5)], 0.18221370799999997),
    73: ([(7, 5.5)], 0.044685417000000005),
    79: ([(7, 6.75)], 0.362901958),
    83: ([(7, 6.5)], 0.121000375),
    89: ([(7, 6.5)], 0.182695375),
    97: ([(7, 5.5)], 0.06858350000000002),
    101: ([(7, 6.5)], 0.38408749999999997),
    103: ([(7, 7.5), (8, 6.5)], 3.3626029170000002),
    107: ([(7, 7.5)], 8.891771667),
    109: ([(7, 7.5), (8, 6.5)], 4.596561917),
    113: ([(7, 6.5)], 0.1859389579999995),
    127: ([(7, 9.5), (8, 7.5)], 1619.89318625),
    131: ([(8, 6.25)], 0.05858354099996177),
    137: ([(8, 6.25)], 0.10623299999999991),
    139: ([(8, 7.25)], 1.2351711669999998),
    149: ([(8, 7.25)], 0.48292875),
    151: ([(8, 7.5)], 4.641820375),
    157: ([(8, 7.5)], 2.49218775),
    163: ([(8, 7.25)], 0.5001321249999999),
    167: ([(8, 8.25), (9, 7.5)], 48.444338791),
    173: ([(8, 8.25), (9, 7.5)], 37.677076833),
    179: ([(8, 8.25)], 132.232723375),
    181: ([(8, 8.25), (9, 7.25)], 53.822612083999985),
    191: ([(8, 9.25), (9, 8.25)], 907.7980847910001),
    193: ([(8, 6.25)], 0.12370429100008096),
    197: ([(8, 7.25)], 0.6496936670000002),
    199: ([(8, 8.25), (9, 7.25)], 50.102889333),
    211: ([(8, 8.25)], 83.20584475),
    223: ([(8, 10.0), (9, 8.25)], 6772.927301542),
    227: ([(8, 8.25)], 50.801469917),
    229: ([(8, 8.25)], 39.942074416000004),
}
59 |
60 |
def run_experiments():
    """Print, for every number of inputs k in [2, 50], the lowest cost and the prime achieving it.

    The odd primes 3..229 are tried in turn. For each one, `_minimum_cost` is evaluated for
    every k using the exponentiation cost recorded in `exponentiation_results`. The best
    (k, cost, p) triples are printed once all primes have been processed.
    """
    max_k = 50
    arities = range(2, max_k + 1)
    lowest_cost = [100000000.0] * (max_k + 1)
    winning_prime = [None] * (max_k + 1)

    # This is for sqr = 0.75 mul
    candidate_primes = list(sieve.primerange(300))[1:50]
    for prime in candidate_primes:
        print(f"------ p = {prime} ------")
        # The exponentiation cost only depends on the prime, so look it up once per prime.
        exp_cost = exponentiation_results[prime][0][0][1]
        for k in arities:
            slot = k - 2
            cost = _minimum_cost(k, exp_cost, prime)
            if cost < lowest_cost[slot]:
                lowest_cost[slot] = cost
                winning_prime[slot] = prime

    # zip stops at the end of `arities`, so only the slots for k = 2..max_k are printed.
    for k, cost, prime in zip(arities, lowest_cost, winning_prime):
        print(k, cost, prime)


if __name__ == "__main__":
    run_experiments()
83 |
--------------------------------------------------------------------------------
/oraqle/helib_template/.gitignore:
--------------------------------------------------------------------------------
1 | /CMakeFiles
2 | CMakeCache.txt
3 | cmake_install.cmake
4 | helib.log
5 | Makefile
6 | main
7 |
--------------------------------------------------------------------------------
/oraqle/helib_template/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR)

# Declare the project explicitly. Without a project() call, CMake warns and injects an
# implicit project(Project) at the top of the file.
project(helib_template)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Fail at configure time, rather than with an obscure link error, when HElib is missing.
find_package(helib REQUIRED)

# Build the generated HElib program from main.cpp.
add_executable(main main.cpp)
target_link_libraries(main helib)
10 |
--------------------------------------------------------------------------------
/oraqle/helib_template/__init__.py:
--------------------------------------------------------------------------------
1 | """Template containing all the things to build an HElib program."""
2 |
--------------------------------------------------------------------------------
/oraqle/helib_template/main.cpp:
--------------------------------------------------------------------------------
1 |
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include <helib/helib.h>
7 |
8 | typedef helib::Ptxt ptxt_t;
9 | typedef helib::Ctxt ctxt_t;
10 |
11 | std::map input_map;
12 |
13 | void parse_arguments(int argc, char* argv[]) {
14 | for (int i = 1; i < argc; ++i) {
15 | std::string argument(argv[i]);
16 | size_t pos = argument.find('=');
17 | if (pos != std::string::npos) {
18 | std::string key = argument.substr(0, pos);
19 | int value = std::stoi(argument.substr(pos + 1));
20 | input_map[key] = value;
21 | }
22 | }
23 | }
24 |
25 | int extract_input(const std::string& name) {
26 | if (input_map.find(name) != input_map.end()) {
27 | return input_map[name];
28 | } else {
29 | std::cerr << "Error: " << name << " not found" << std::endl;
30 | return -1;
31 | }
32 | }
33 |
34 | int main(int argc, char* argv[]) {
35 | // Parse the inputs
36 | parse_arguments(argc, argv);
37 |
38 | // Set up the HE parameters
39 | unsigned long p = 5;
40 | unsigned long m = 8192;
41 | unsigned long r = 1;
42 | unsigned long bits = 72;
43 | unsigned long c = 3;
44 | helib::Context context = helib::ContextBuilder()
45 | .m(m)
46 | .p(p)
47 | .r(r)
48 | .bits(bits)
49 | .c(c)
50 | .build();
51 |
52 |
53 | // Generate keys
54 | helib::SecKey secret_key(context);
55 | secret_key.GenSecKey();
56 | helib::addSome1DMatrices(secret_key);
57 | const helib::PubKey& public_key = secret_key;
58 |
59 | // Encrypt the inputs
60 | std::vector vec_x(1, extract_input("x"));
61 | ptxt_t ptxt_x(context, vec_x);
62 | ctxt_t ciph_x(public_key);
63 | public_key.Encrypt(ciph_x, ptxt_x);
64 | std::vector vec_y(1, extract_input("y"));
65 | ptxt_t ptxt_y(context, vec_y);
66 | ctxt_t ciph_y(public_key);
67 | public_key.Encrypt(ciph_y, ptxt_y);
68 |
69 | // Perform the actual circuit
70 | ctxt_t stack_0 = ciph_x;
71 | ctxt_t stack_1 = ciph_y;
72 | stack_1 *= 4l;
73 | stack_0 += stack_1;
74 | stack_0 *= stack_0;
75 | stack_0 *= stack_0;
76 | stack_0 *= 4l;
77 | stack_0 += 1l;
78 | ptxt_t decrypted(context);
79 | secret_key.Decrypt(decrypted, stack_0);
80 | std::cout << decrypted << std::endl;
81 |
82 | return 0;
83 | }
84 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "oraqle"
3 | description = "Secure computation compiler for homomorphic encryption and arithmetic circuits in general"
4 | version = "0.1.6"
5 | requires-python = ">= 3.8"
6 | authors = [
7 | {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"},
8 | ]
9 | maintainers = [
10 | {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"}
11 | ]
12 | readme = "README.md"
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | sympy
2 | six
3 | galois>=0.3.8
4 | aeskeyschedule
5 | python-sat
6 | git+https://github.com/jellevos/fhegen.git
7 | matplotlib
8 | ipython
9 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | gensafeprime
3 | graphviz
4 | tabulate
5 | ruff
6 | mkdocs
7 | mkdocstrings[python]
8 | mkautodoc
9 | mkdocs-material
10 | pymdown-extensions
11 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | # Exclude a variety of commonly ignored directories.
2 | exclude = [
3 | ".bzr",
4 | ".direnv",
5 | ".eggs",
6 | ".ipynb_checkpoints",
7 | ".mypy_cache",
8 | ".pyenv",
9 | ".pytest_cache",
10 | ".pytype",
11 | ".ruff_cache",
12 | ".venv",
13 | ".vscode",
14 | "__pypackages__",
15 | "_build",
16 | "build",
17 | "dist",
18 | "node_modules",
19 | "site-packages",
20 | "venv",
21 | ]
22 |
23 | line-length = 100
24 | indent-width = 4
25 | target-version = "py38"
26 |
27 | [lint]
28 | preview = true
29 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
30 | # McCabe complexity (`C901`) by default.
select = ["W", "E4", "E7", "E9", "F", "ERA001", "B", "D", "DOC", "PLW", "SIM", "UP", "PLR", "RUF", "PIE"]
32 | ignore = ["E203", "E501", "E731", "D105", "W293", "PLR2004", "PLR6301"]
33 |
# Allow fixes for all enabled rules (when `--fix` is provided).
35 | fixable = ["ALL"]
36 | unfixable = []
37 |
38 | # Allow unused variables when underscore-prefixed.
39 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
40 |
41 | [lint.per-file-ignores]
42 | "oraqle/experiments/*" = ["D", "DOC"]
43 |
44 | [lint.pydocstyle]
45 | # Use Google-style docstrings.
46 | convention = "google"
47 |
48 | [format]
49 | preview = true
50 | quote-style = "double"
51 | indent-style = "space"
52 | skip-magic-trailing-comma = false
53 | line-ending = "auto"
54 | docstring-code-format = true
55 | docstring-code-line-length = "dynamic"
56 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | python_files = *.py
3 | norecursedirs = venv build
4 |
--------------------------------------------------------------------------------
/tests/test_circuit_sizes_costs.py:
--------------------------------------------------------------------------------
1 | """Test file for testing circuits sizes."""
2 |
3 | from collections import Counter
4 |
5 | from galois import GF
6 |
7 | from oraqle.compiler.nodes.abstract import ArithmeticNode, UnoverloadedWrapper
8 | from oraqle.compiler.nodes.arbitrary_arithmetic import Sum
9 | from oraqle.compiler.nodes.leafs import Constant, Input
10 |
11 |
def test_size_exponentiation_chain():
    """Test that squaring x three times without flattening counts as three multiplications."""
    gf = GF(101)

    node = Input("x", gf)
    for _ in range(3):
        node = node.mul(node, flatten=False)

    arith = node.to_arithmetic()
    assert isinstance(arith, ArithmeticNode)
    size = arith.multiplicative_size()
    assert size == 3, f"((x^2)^2)^2 should be 3 multiplications, but counted {size}"
    assert arith.multiplicative_cost(0.5) == 1.5
28 |
29 |
def test_size_sum_of_products():
    """Test that a * b + c * d counts as two multiplications with a total cost of 2."""
    gf = GF(101)

    a, b, c, d = (Input(name, gf) for name in "abcd")

    out = (a * b + c * d).to_arithmetic()

    assert isinstance(out, ArithmeticNode)
    size = out.multiplicative_size()
    assert size == 2, f"a * b + c * d should be 2 multiplications, but counted {size}"
    assert out.multiplicative_cost(0.7) == 2
50 |
51 |
def test_size_linear_function():
    """Test that a weighted sum of inputs (coefficients 1, 3, 1 and gf(2)) needs no multiplications."""
    gf = GF(101)

    a = Input("a", gf)
    b = Input("b", gf)
    c = Input("c", gf)

    coefficients = Counter({
        UnoverloadedWrapper(a): 1,
        UnoverloadedWrapper(b): 3,
        UnoverloadedWrapper(c): 1,
    })
    linear = Sum(coefficients, gf, gf(2)).to_arithmetic()

    assert linear.multiplicative_size() == 0
    assert linear.multiplicative_cost(0.5) == 0
69 |
70 |
def test_size_duplicate_nodes():
    """Test size and cost of (x + 1) + (x * x) * ((x * x) + (x + 1)) built from duplicate nodes.

    `x + 1` and `x * x` are each constructed twice as separate nodes. The expected size of
    three multiplications implies both copies of `x * x` are counted. The expected cost of
    2.4 is consistent with two squarings weighted 0.7 plus one general multiplication.
    """
    gf = GF(101)

    x = Input("x", gf)

    add1 = x.add(Constant(gf(1)))
    add2 = x.add(Constant(gf(1)))

    mul1 = x.mul(x, flatten=False)
    mul2 = x.mul(x, flatten=False)

    add3 = mul2.add(add2, flatten=False)

    mul3 = mul1.mul(add3, flatten=False)

    out = add1.add(mul3, flatten=False)

    out = out.to_arithmetic()

    assert isinstance(out, ArithmeticNode)
    assert out.multiplicative_size() == 3
    # Compare with a tolerance: 0.7 and 2.4 are not exactly representable as floats, so
    # exact equality would depend on the order in which the individual costs are summed.
    cost = out.multiplicative_cost(0.7)
    assert abs(cost - 2.4) < 1e-9, f"expected a multiplicative cost of 2.4, but got {cost}"
94 |
--------------------------------------------------------------------------------
/tests/test_poly2circuit.py:
--------------------------------------------------------------------------------
1 | """Test file for generating circuits using polynomial interpolation."""
2 |
3 | import itertools
4 |
5 | from oraqle.compiler.func2poly import interpolate_polynomial
6 | from oraqle.compiler.poly2circuit import construct_circuit
7 |
8 |
def _construct_and_test_circuit_from_bivariate_lambda(function, modulus: int, cse=False):
    """Interpolate `function(x, y)` modulo `modulus`, build its circuit, and check every input pair.

    The circuit is arithmetized first. When `cse` is true, common subexpressions are
    eliminated before evaluation. Every (x, y) in range(modulus)^2 is then evaluated and
    compared with the single expected output `[function(x, y)]`.
    """
    poly = interpolate_polynomial(function, modulus, ["x", "y"])
    circuit, gf = construct_circuit([poly], modulus)
    arithmetic = circuit.arithmetize()
    if cse:
        arithmetic.eliminate_subexpressions()

    for x_val, y_val in itertools.product(range(modulus), repeat=2):
        print(function, x_val, y_val)
        inputs = {"x": gf(x_val), "y": gf(y_val)}
        assert arithmetic.evaluate(inputs) == [function(x_val, y_val)]
20 |
21 |
def test_inequality_mod7():
    """Check the circuit for int(x != y) on all inputs modulo 7."""
    _construct_and_test_circuit_from_bivariate_lambda(lambda u, v: int(u != v), 7)


def test_inequality_mod13():
    """Check the circuit for int(x != y) on all inputs modulo 13."""
    _construct_and_test_circuit_from_bivariate_lambda(lambda u, v: int(u != v), 13)


def test_max_mod7():
    """Check the circuit for max(x, y) on all inputs modulo 7."""
    _construct_and_test_circuit_from_bivariate_lambda(max, 7)


def test_max_mod13():
    """Check the circuit for max(x, y) on all inputs modulo 13."""
    _construct_and_test_circuit_from_bivariate_lambda(max, 13)


def test_xor_mod11():
    """Check the circuit for (x ^ y) % 11 on all inputs modulo 11."""
    _construct_and_test_circuit_from_bivariate_lambda(lambda u, v: (u ^ v) % 11, 11)
45 |
46 |
def test_inequality_mod11_cse():
    """Tests x != y (mod 11) with CSE."""
    _construct_and_test_circuit_from_bivariate_lambda(
        lambda x, y: int(x != y), modulus=11, cse=True
    )
52 |
53 |
def test_max_mod7_cse():
    """Check max(x, y) modulo 7 with common subexpression elimination enabled."""
    _construct_and_test_circuit_from_bivariate_lambda(max, 7, cse=True)


def test_xor_mod13_cse():
    """Check (x ^ y) % 13 modulo 13 with common subexpression elimination enabled."""
    _construct_and_test_circuit_from_bivariate_lambda(lambda u, v: (u ^ v) % 13, 13, cse=True)
64 |
--------------------------------------------------------------------------------
/tests/test_sugar_expressions.py:
--------------------------------------------------------------------------------
1 | """Test file for sugar expressions."""
2 |
3 | from galois import GF
4 |
5 | from oraqle.compiler.circuit import Circuit
6 | from oraqle.compiler.nodes.arbitrary_arithmetic import sum_
7 | from oraqle.compiler.nodes.leafs import Input
8 |
9 |
def test_sum():
    """Tests that sum_(a, 4, b, 3) evaluates to a + b + 7 for every a, b in GF(127)."""
    gf = GF(127)

    a = Input("a", gf)
    b = Input("b", gf)

    arithmetic_circuit = Circuit([sum_(a, 4, b, 3)]).arithmetize()

    constant = gf(7)
    for val_a in range(127):
        for val_b in range(127):
            expected = gf(val_a) + gf(val_b) + constant
            # The circuit has a single output. Compare against a one-element list, as
            # test_poly2circuit does, instead of relying on numpy broadcasting a list
            # against a scalar.
            assert arithmetic_circuit.evaluate({"a": gf(val_a), "b": gf(val_b)}) == [expected]
23 |
--------------------------------------------------------------------------------