├── .github └── workflows │ ├── github_pages.yml │ ├── publish_to_pypi.yml │ └── python-app.yml ├── .gitignore ├── CITATION.cff ├── LICENSE.txt ├── README.md ├── docs ├── api │ ├── abstract_nodes_api.md │ ├── addition_chains_api.md │ ├── circuits_api.md │ ├── code_generation_api.md │ ├── nodes_api.md │ └── pareto_fronts_api.md ├── config.md ├── example_circuits.md ├── getting_started.md ├── images │ └── oraqle_logo_cropped.svg ├── index.md └── tutorial_running_exps.md ├── main.cpp ├── mkdocs.yml ├── oraqle ├── __init__.py ├── add_chains │ ├── __init__.py │ ├── addition_chains.py │ ├── addition_chains_front.py │ ├── addition_chains_heuristic.py │ ├── addition_chains_mod.py │ ├── memoization.py │ └── solving.py ├── circuits │ ├── __init__.py │ ├── aes.py │ ├── cardio.py │ ├── median.py │ ├── mimc.py │ ├── sorting.py │ └── veto_voting.py ├── compiler │ ├── __init__.py │ ├── arithmetic │ │ ├── __init__.py │ │ ├── exponentiation.py │ │ └── subtraction.py │ ├── boolean │ │ ├── __init__.py │ │ ├── bool_and.py │ │ ├── bool_neg.py │ │ └── bool_or.py │ ├── circuit.py │ ├── comparison │ │ ├── __init__.py │ │ ├── comparison.py │ │ ├── equality.py │ │ └── in_upper_half.py │ ├── control_flow │ │ ├── __init__.py │ │ └── conditional.py │ ├── func2poly.py │ ├── graphviz.py │ ├── instructions.py │ ├── nodes │ │ ├── __init__.py │ │ ├── abstract.py │ │ ├── arbitrary_arithmetic.py │ │ ├── binary_arithmetic.py │ │ ├── fixed.py │ │ ├── flexible.py │ │ ├── leafs.py │ │ ├── non_commutative.py │ │ ├── unary_arithmetic.py │ │ └── univariate.py │ ├── poly2circuit.py │ └── polynomials │ │ ├── __init__.py │ │ └── univariate.py ├── config.py ├── demo │ ├── depth_aware_equality.ipynb │ ├── playground.ipynb │ ├── small_comparison_bgv.ipynb │ └── veto_voting.ipynb ├── examples │ ├── depth_aware_comparison.py │ ├── depth_aware_equality.py │ ├── long_and.py │ ├── small_comparison.py │ ├── small_polynomial.py │ ├── visualize_circuits.py │ └── wahc2024_presentation │ │ ├── 1_high-level.py │ │ ├── 
2_arith_step1.py │ │ ├── 3_arith_step2.py │ │ ├── 5_code_gen.py │ │ └── rebalancing.py ├── experiments │ ├── depth_aware_arithmetization │ │ └── execution │ │ │ ├── bench_cardio_circuits.py │ │ │ ├── bench_equality.py │ │ │ ├── cardio_circuits.py │ │ │ ├── comparisons.py │ │ │ ├── equality_first_prime_mods_exec.py │ │ │ ├── poly_evaluation_pareto_front.py │ │ │ ├── run_all.sh │ │ │ └── veto_voting_per_mod.py │ └── oraqle_spotlight │ │ ├── examples │ │ ├── and_16.py │ │ ├── common_expressions.py │ │ ├── equality_31.py │ │ ├── equality_and_comparison.py │ │ └── t2_comparison.py │ │ └── experiments │ │ ├── comparisons │ │ └── comparisons_bench.py │ │ ├── large_equality │ │ ├── .gitignore │ │ ├── CMakeLists.txt │ │ └── large_equality.py │ │ └── veto_voting_minimal_cost.py └── helib_template │ ├── .gitignore │ ├── CMakeLists.txt │ ├── __init__.py │ └── main.cpp ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── ruff.toml ├── setup.cfg └── tests ├── test_circuit_sizes_costs.py ├── test_poly2circuit.py └── test_sugar_expressions.py /.github/workflows/github_pages.yml: -------------------------------------------------------------------------------- 1 | name: Deploy docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Configure Git Credentials 14 | run: | 15 | git config user.name github-actions[bot] 16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: 3.* 20 | - run: pip install -r requirements_dev.txt 21 | - run: mkdocs gh-deploy --force 22 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 
| build: 7 | name: Build distribution 📦 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.x" 16 | - name: Install pypa/build 17 | run: >- 18 | python3 -m 19 | pip install 20 | build 21 | --user 22 | - name: Build a binary wheel and a source tarball 23 | run: python3 -m build 24 | - name: Store the distribution packages 25 | uses: actions/upload-artifact@v3 26 | with: 27 | name: python-package-distributions 28 | path: dist/ 29 | 30 | publish-to-pypi: 31 | name: >- 32 | Publish Python 🐍 distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 34 | needs: 35 | - build 36 | runs-on: ubuntu-latest 37 | environment: 38 | name: pypi 39 | url: https://pypi.org/p/oraqle 40 | permissions: 41 | id-token: write 42 | 43 | steps: 44 | - name: Download all the dists 45 | uses: actions/download-artifact@v3 46 | with: 47 | name: python-package-distributions 48 | path: dist/ 49 | - name: Publish distribution 📦 to PyPI 50 | uses: pypa/gh-action-pypi-publish@release/v1 51 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | matrix: 21 | os: [ubuntu-latest, macos-latest, windows-latest] 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 3.10 26 | uses: 
actions/setup-python@v3 27 | with: 28 | python-version: "3.10" 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install ruff pytest 33 | pip install -r requirements.txt 34 | pip install -e . 35 | - name: Lint with ruff 36 | run: | 37 | ruff check 38 | - name: Test with pytest 39 | run: | 40 | pytest 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /venv 2 | *.dot 3 | .idea/** 4 | __pycache__/** 5 | /build 6 | .DS_Store 7 | *.egg-info* 8 | instructions.txt 9 | .ipynb_checkpoints/ 10 | *.pdf 11 | *.pkl 12 | .sphinx_build/ 13 | /dist 14 | *.svg 15 | *.pyc 16 | oraqle/addchain_cache.db 17 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you want to cite the compiler as a whole, please cite this work. See the citations page in the documentation for specific works e.g. about depth-aware arithmetization." 
3 | authors: 4 | - family-names: "Vos" 5 | given-names: "Jelle" 6 | orcid: "https://orcid.org/0000-0002-3979-9740" 7 | - family-names: "Conti" 8 | given-names: "Mauro" 9 | orcid: "https://orcid.org/0000-0002-3612-1934" 10 | - family-names: "Erkin" 11 | given-names: "Zekeriya" 12 | orcid: "https://orcid.org/0000-0001-8932-4703" 13 | title: "Oraqle: A Depth-Aware Secure Computation Compiler" 14 | version: 0.1.0 15 | doi: 10.1145/3689945.3694808 16 | date-released: 2024-09-11 17 | url: "https://github.com/jellevos/oraqle" 18 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jelle Vos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Oraqle logo
3 | A secure computation compiler 4 |

5 | 6 | The oraqle compiler lets you generate arithmetic circuits from high-level Python code. It also lets you generate code using HElib. 7 | 8 | This repository uses a fork of fhegen as a dependency and adapts some of the code from [fhegen](https://github.com/Crypto-TII/fhegen), which was written by Johannes Mono, Chiara Marcolla, Georg Land, Tim Güneysu, and Najwa Aaraj. You can read their theoretical work at: https://eprint.iacr.org/2022/706. 9 | 10 | See [our documentation](https://jelle-vos.nl/oraqle) for more details. 11 | 12 | ## Setting up 13 | The best way to get things up and running is using a virtual environment: 14 | - Set up a virtualenv using `python3 -m venv venv` in the directory. 15 | - Enter the virtual environment using `source venv/bin/activate`. 16 | - Install the requirements using `pip install -r requirements.txt`. 17 | - *To overcome import problems*, run `pip install -e .`, which will create links to your files (so you do not need to re-install after every change). 18 | 19 | We are currently setting up documentation to be rendered using GitHub Actions. 20 | -------------------------------------------------------------------------------- /docs/api/abstract_nodes_api.md: -------------------------------------------------------------------------------- 1 | # Abstract nodes API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 4 | 5 | If you want to extend the oraqle compiler, or implement your own high-level nodes, it is easiest to extend one of the existing abstract node classes. 6 | -------------------------------------------------------------------------------- /docs/api/addition_chains_api.md: -------------------------------------------------------------------------------- 1 | # Addition chains API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version.
4 | 5 | The `add_chains` module contains tools for generating addition chains. 6 | 7 | ::: oraqle.add_chains 8 | options: 9 | heading_level: 2 10 | show_submodules: true 11 | show_if_no_docstring: false 12 | -------------------------------------------------------------------------------- /docs/api/circuits_api.md: -------------------------------------------------------------------------------- 1 | # Circuits API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 4 | 5 | 6 | ## High-level circuits 7 | ::: oraqle.compiler.circuit.Circuit 8 | options: 9 | heading_level: 3 10 | 11 | 12 | ## Arithmetic circuits 13 | ::: oraqle.compiler.circuit.ArithmeticCircuit 14 | options: 15 | heading_level: 3 16 | -------------------------------------------------------------------------------- /docs/api/code_generation_api.md: -------------------------------------------------------------------------------- 1 | # Code generation API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 4 | 5 | The easiest way is using: 6 | ```python3 7 | arithmetic_circuit.generate_code() 8 | ``` 9 | 10 | ## Arithmetic instructions 11 | If you want to extend the oraqle compiler, or implement your own code generation, you can use the following instructions to do so. 12 | 13 | ??? info "Abstract instruction" 14 | ::: oraqle.compiler.instructions.ArithmeticInstruction 15 | options: 16 | heading_level: 3 17 | 18 | ??? info "InputInstruction" 19 | ::: oraqle.compiler.instructions.InputInstruction 20 | options: 21 | heading_level: 3 22 | 23 | ??? info "AdditionInstruction" 24 | ::: oraqle.compiler.instructions.AdditionInstruction 25 | options: 26 | heading_level: 3 27 | 28 | ??? info "MultiplicationInstruction" 29 | ::: oraqle.compiler.instructions.MultiplicationInstruction 30 | options: 31 | heading_level: 3 32 | 33 | ??? 
info "ConstantAdditionInstruction" 34 | ::: oraqle.compiler.instructions.ConstantAdditionInstruction 35 | options: 36 | heading_level: 3 37 | 38 | ??? info "ConstantMultiplicationInstruction" 39 | ::: oraqle.compiler.instructions.ConstantMultiplicationInstruction 40 | options: 41 | heading_level: 3 42 | 43 | ??? info "OutputInstruction" 44 | ::: oraqle.compiler.instructions.OutputInstruction 45 | options: 46 | heading_level: 3 47 | 48 | 49 | ## Generating arithmetic programs 50 | ::: oraqle.compiler.instructions.ArithmeticProgram 51 | options: 52 | heading_level: 3 53 | 54 | 55 | ## Generating code for HElib 56 | ... 57 | -------------------------------------------------------------------------------- /docs/api/nodes_api.md: -------------------------------------------------------------------------------- 1 | # Nodes API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 4 | 5 | ## Boolean operations 6 | 7 | ??? info "AND operation" 8 | ::: oraqle.compiler.boolean.bool_and.And 9 | options: 10 | heading_level: 3 11 | 12 | ??? info "OR operation" 13 | ::: oraqle.compiler.boolean.bool_or.Or 14 | options: 15 | heading_level: 3 16 | 17 | ??? info "NEG operation" 18 | ::: oraqle.compiler.boolean.bool_neg.Neg 19 | options: 20 | heading_level: 3 21 | 22 | 23 | ## Arithmetic operations 24 | These operations are fundamental arithmetic operations, so they will stay the same when they are arithmetized. 25 | 26 | 27 | ## High-level arithmetic operations 28 | 29 | ??? info "Subtraction" 30 | ::: oraqle.compiler.arithmetic.subtraction.Subtraction 31 | options: 32 | heading_level: 3 33 | 34 | ??? info "Exponentiation" 35 | ::: oraqle.compiler.arithmetic.exponentiation.Power 36 | options: 37 | heading_level: 3 38 | 39 | 40 | ## Polynomial evaluation 41 | 42 | ??? 
info "Univariate polynomial evaluation" 43 | ::: oraqle.compiler.polynomials.univariate.UnivariatePoly 44 | options: 45 | heading_level: 3 46 | 47 | 48 | ## Control flow 49 | 50 | ??? info "If-else statement" 51 | ::: oraqle.compiler.control_flow.conditional.IfElse 52 | options: 53 | heading_level: 3 54 | -------------------------------------------------------------------------------- /docs/api/pareto_fronts_api.md: -------------------------------------------------------------------------------- 1 | # Pareto fronts API 2 | !!! warning 3 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 4 | 5 | If you are using depth-aware arithmetization, you will find that the compiler does not output one arithmetic circuit. 6 | Instead, it outputs a Pareto front, which represents the best circuits that it could generate trading off two metrics: 7 | The *multiplicative depth* and the *multiplicative size/cost*. 8 | This page briefly explains the API for interfacing with these Pareto fronts. 9 | 10 | ## The abstract base class 11 | 12 | ??? info "Abstract ParetoFront" 13 | ::: oraqle.compiler.nodes.abstract.ParetoFront 14 | options: 15 | heading_level: 3 16 | 17 | ## Depth-size and depth-cost fronts 18 | 19 | -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # Configuration parameters 2 | 3 | ::: oraqle.config 4 | options: 5 | heading_level: 2 6 | show_submodules: true 7 | show_if_no_docstring: false 8 | -------------------------------------------------------------------------------- /docs/example_circuits.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | Some of these example circuits are untested and may be incorrect. 
3 | 4 | ::: oraqle.circuits 5 | options: 6 | heading_level: 3 7 | show_submodules: true 8 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | In 5 minutes, this page will guide you through how to install oraqle, how to specify high-level programs, and how to arithmetize your first circuit! 3 | 4 | ## Installation 5 | Simply install the most recent version of the Oraqle compiler using: 6 | ``` 7 | pip install oraqle 8 | ``` 9 | 10 | We use continuous integration to test every build of the Oraqle compiler on Windows, MacOS, and Unix systems. 11 | If you do run into problems, feel free to [open an issue on GitHub](https://github.com/jellevos/oraqle/issues)! 12 | 13 | ## Specifying high-level programs 14 | Let's start with importing `galois`, which represents our plaintext algebra. 15 | We will also immediately import the relevant oraqle classes for our little example: 16 | ```python3 17 | from galois import GF 18 | 19 | from oraqle.compiler.circuit import Circuit 20 | from oraqle.compiler.nodes.leafs import Input 21 | ``` 22 | 23 | For this example, we will use 31 as our plaintext modulus. This algebra is denoted by `GF(31)`. 24 | Let's create a few inputs that represent elements in this algebra: 25 | ```python3 26 | gf = GF(31) 27 | 28 | x = Input("x", gf) 29 | y = Input("y", gf) 30 | z = Input("z", gf) 31 | ``` 32 | 33 | We can now perform some operations on these elements, and they do not have to be arithmetic operations! 34 | For example, we can perform equality checks or comparisons: 35 | ``` 36 | comparison = x < y 37 | equality = y == z 38 | both = comparison & equality 39 | ``` 40 | 41 | While we have specified some operations, we have not yet established this as a circuit. We will do so now: 42 | ```python3 43 | circuit = Circuit([both]) 44 | ``` 45 | 46 | And that's it! We are done specifying our first high-level circuit.
47 | As you can see this is all very similar to writing a regular Python program. 48 | If you want to visualize this high-level circuit before we continue with arithmetizing it, you can run the following (if you have graphviz installed): 49 | ```python3 50 | circuit.to_pdf("high_level_circuit.pdf") 51 | ``` 52 | 53 | !!! tip 54 | If you do not have graphviz installed, you can instead call: 55 | ```python3 56 | circuit.to_dot("high_level_circuit.dot") 57 | ``` 58 | After that, you can copy the file contents to [an online graphviz viewer](https://dreampuf.github.io/GraphvizOnline)! 59 | 60 | ## Arithmetizing your first circuit 61 | At this point, arithmetization is a breeze, because the oraqle compiler takes care of these steps. 62 | We can create an arithmetic circuit and visualize it using the following snippet: 63 | ```python3 64 | arithmetic_circuit = circuit.arithmetize() 65 | arithmetic_circuit.to_pdf("arithmetic_circuit.pdf") 66 | ``` 67 | 68 | You will notice that it's quite a large circuit. But how large is it exactly? 69 | This is a question that we can ask to the oraqle compiler: 70 | ```python3 71 | print("Depth:", arithmetic_circuit.multiplicative_depth()) 72 | print("Size:", arithmetic_circuit.multiplicative_size()) 73 | print("Cost:", arithmetic_circuit.multiplicative_cost(0.7)) 74 | ``` 75 | 76 | In the last line, we asked the compiler to output the multiplicative cost, considering that squaring operations are cheaper than regular multiplications. 77 | We weighed this cost with a factor 0.7. 78 | 79 | Now that we have an arithmetic circuit, we can use homomorphic encryption to evaluate it! 80 | If you are curious about executing these circuits for real, consider reading [the code generation tutorial](tutorial_running_exps.md). 81 | 82 | !!! warning 83 | There are many homomorphic encryption libraries that do not support plaintext moduli that are not NTT-friendly. The plaintext modulus we chose (31) is not NTT-friendly. 
84 | In fact, only very few primes are NTT-friendly, and they are somewhat large. This is why, right now, the oraqle compiler only implements code generation for HElib. 85 | HElib is (as far as we are aware) the only library that supports plaintext moduli that are not NTT-friendly. 86 | -------------------------------------------------------------------------------- /docs/images/oraqle_logo_cropped.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 39 | 41 | 46 | 49 | 52 | 56 | 60 | 64 | 68 | 72 | 76 | 81 | 86 | 89 | 93 | 97 | 98 | 101 | 105 | 109 | 110 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to oraqle 2 |
3 | Oraqle logo
4 | A secure computation compiler 5 |
6 | 7 | Simply install the most recent version of the Oraqle compiler using: 8 | ``` 9 | pip install oraqle==0.1.0 10 | ``` 11 | 12 | Consider checking out our [getting started page](getting_started.md) to help you get up to speed with arithmetizing circuits! 13 | 14 | ## API reference 15 | !!! warning 16 | In this version of Oraqle, the API is still prone to changes. Paths and names can change between any version. 17 | 18 | For an API reference, you can check out the pages for [circuits](api/circuits_api.md) and for [nodes](api/nodes_api.md). 19 | -------------------------------------------------------------------------------- /docs/tutorial_running_exps.md: -------------------------------------------------------------------------------- 1 | # Tutorial: Running experiments 2 | !!! failure 3 | This section is currently missing. Please see the [code generation API](api/code_generation_api.md) for some documentation for now. 4 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | typedef helib::Ptxt ptxt_t; 10 | typedef helib::Ctxt ctxt_t; 11 | 12 | std::map input_map; 13 | 14 | void parse_arguments(int argc, char* argv[]) { 15 | for (int i = 1; i < argc; ++i) { 16 | std::string argument(argv[i]); 17 | size_t pos = argument.find('='); 18 | if (pos != std::string::npos) { 19 | std::string key = argument.substr(0, pos); 20 | int value = std::stoi(argument.substr(pos + 1)); 21 | input_map[key] = value; 22 | } 23 | } 24 | } 25 | 26 | int extract_input(const std::string& name) { 27 | if (input_map.find(name) != input_map.end()) { 28 | return input_map[name]; 29 | } else { 30 | std::cerr << "Error: " << name << " not found" << std::endl; 31 | return -1; 32 | } 33 | } 34 | 35 | int main(int argc, char* argv[]) { 36 | // Parse the inputs 37 | parse_arguments(argc, argv); 38 | 
39 | // Set up the HE parameters 40 | unsigned long p = 257; 41 | unsigned long m = 65536; 42 | unsigned long r = 1; 43 | unsigned long bits = 449; 44 | unsigned long c = 3; 45 | helib::Context context = helib::ContextBuilder() 46 | .m(m) 47 | .p(p) 48 | .r(r) 49 | .bits(bits) 50 | .c(c) 51 | .build(); 52 | 53 | 54 | // Generate keys 55 | helib::SecKey secret_key(context); 56 | secret_key.GenSecKey(); 57 | helib::addSome1DMatrices(secret_key); 58 | const helib::PubKey& public_key = secret_key; 59 | 60 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Oraqle 2 | 3 | nav: 4 | - index.md 5 | - getting_started.md 6 | - tutorial_running_exps.md 7 | - API reference: 8 | - api/circuits_api.md 9 | - api/nodes_api.md 10 | - api/code_generation_api.md 11 | - api/pareto_fronts_api.md 12 | - api/abstract_nodes_api.md 13 | - api/addition_chains_api.md 14 | - example_circuits.md 15 | - config.md 16 | 17 | plugins: 18 | - search 19 | - mkdocstrings: 20 | handlers: 21 | python: 22 | options: 23 | show_root_heading: true 24 | allow_inspection: false 25 | show_submodules: false 26 | show_root_full_path: false 27 | show_symbol_type_heading: true 28 | # show_symbol_type_toc: true This currently causes a bug 29 | docstring_style: google 30 | follow_wrapped_lines: true 31 | crosslink_types: true # Makes types clickable 32 | crosslink_types_style: 'sphinx' # Default or sphinx style 33 | annotations_path: brief 34 | inherited_members: true 35 | members_order: source 36 | show_if_no_docstring: true 37 | separate_signature: false 38 | show_source: false 39 | docstring_section_style: list 40 | 41 | theme: 42 | name: material 43 | highlightjs: true 44 | 45 | markdown_extensions: 46 | - admonition 47 | - pymdownx.superfences 48 | - pymdownx.inlinehilite 49 | - pymdownx.critic 50 | - pymdownx.details 51 | - pymdownx.tasklist 52 | - pymdownx.tabbed 53 | - 
def chain_depth(
    chain: List[Tuple[int, int]],
    precomputed_values: Optional[Tuple[Tuple[int, int], ...]] = None,
    modulus: Optional[int] = None,
) -> int:
    """Compute the depth of an addition chain given as pairs of summands.

    The value 1 starts at depth 0. Every step `(x, y)` produces `x + y` (reduced
    modulo `modulus` when one is given) at one more than the deeper of its two
    summands. Entries of `precomputed_values` are `(value, depth)` pairs that
    seed the table before the chain is walked.

    Parameters:
        chain: The steps of the chain, each a pair of previously available values.
        precomputed_values: Optional `(value, depth)` pairs that are available up front.
        modulus: If given, summands and sums are reduced modulo this number.

    Raises:
        KeyError: If a step refers to a value that has not been produced yet.

    Returns:
        The largest depth of any value in the table, including precomputed ones.
    """
    depth_of = {1: 0}
    if precomputed_values is not None:
        depth_of.update(precomputed_values)

    for left, right in chain:
        total = left + right
        if modulus is not None:
            left, right, total = left % modulus, right % modulus, total % modulus
        depth_of[total] = 1 + max(depth_of[left], depth_of[right])

    return max(depth_of.values())
Optional[Tuple[Tuple[int, int], ...]] = None, 37 | ) -> List[Tuple[int, List[Tuple[int, int]]]]: 38 | """Returns a Pareto front of addition chains, trading of cost and depth.""" 39 | if target == 1: 40 | return [(0, [])] 41 | 42 | if modulus is not None: 43 | assert target <= modulus 44 | 45 | # Find the lowest depth chain using square & multiply (SaM) 46 | sam_depth = math.ceil(math.log2(target)) 47 | sam_cost = math.ceil(math.log2(target)) * squaring_cost + hw(target) - 1 48 | sam_target = target 49 | 50 | # If there is a modulus, we should also consider it to find an upper bound on the cost of a minimum-depth chain 51 | if modulus is not None: 52 | current_target = target + modulus - 1 53 | while math.log2(current_target) <= sam_depth: 54 | current_cost = ( 55 | math.ceil(math.log2(current_target)) * squaring_cost + hw(current_target) - 1 56 | ) 57 | if current_cost < sam_cost: 58 | sam_cost = current_cost 59 | sam_target = target 60 | current_target += modulus - 1 61 | 62 | # Find the cheapest chain (i.e. 
no depth constraints) 63 | min_size = size_lower_bound(target) if precomputed_values is None else 1 64 | if modulus is None: 65 | cheapest_chain = add_chain( 66 | target, 67 | None, 68 | sam_cost, 69 | squaring_cost, 70 | solver, 71 | encoding, 72 | thurber, 73 | min_size, 74 | precomputed_values, 75 | ) 76 | else: 77 | cheapest_chain = add_chain_modp( 78 | target, 79 | modulus, 80 | None, 81 | sam_cost, 82 | squaring_cost, 83 | solver, 84 | encoding, 85 | thurber, 86 | min_size, 87 | precomputed_values, 88 | ) 89 | 90 | # If no cheapest chain is found that satisfies these bounds, then square and multiply had the same cost 91 | if cheapest_chain is None: 92 | sam_chain = [] 93 | for i in range(math.ceil(math.log2(sam_target))): 94 | sam_chain.append((2**i, 2**i)) 95 | previous = 1 96 | for i in range(math.ceil(math.log2(sam_target))): 97 | if (sam_target >> i) & 1: 98 | sam_chain.append((previous, 2**i)) 99 | previous += 2**i 100 | return [(sam_depth, sam_chain)] 101 | 102 | add_size = len(cheapest_chain) # TODO: Check that this is indeed a valid bound 103 | add_cost = sum(squaring_cost if x == y else 1.0 for x, y in cheapest_chain) 104 | add_depth = chain_depth(cheapest_chain, precomputed_values, modulus=modulus) 105 | 106 | # Go through increasing depth and decrease the previous size, until we reach the cost of square and multiply 107 | pareto_front = [] 108 | current_depth = sam_depth 109 | current_cost = sam_cost 110 | while current_cost > add_cost and current_depth < add_depth: 111 | if modulus is None: 112 | chain = add_chain( 113 | target, 114 | current_depth, 115 | current_cost, 116 | squaring_cost, 117 | solver, 118 | encoding, 119 | thurber, 120 | add_size, 121 | precomputed_values, 122 | ) 123 | else: 124 | chain = add_chain_modp( 125 | target, 126 | modulus, 127 | current_depth, 128 | current_cost, 129 | squaring_cost, 130 | solver, 131 | encoding, 132 | thurber, 133 | add_size, 134 | precomputed_values, 135 | ) 136 | 137 | if chain is not None: 138 | # 
def _mul(current_chain: List[Tuple[int, int]], other_chain: List[Tuple[int, int]]):
    """Extend `current_chain` in place so that it ends at the product of both chains' targets.

    Chains are lists of index pairs: element 0 is the value 1 and step `i` creates
    element `i + 1` as the sum of the two referenced elements. Shifting every index
    of `other_chain` by the current length makes its element 0 coincide with the
    last element of `current_chain`, so its steps are replayed on top of that value.
    """
    offset = len(current_chain)
    current_chain.extend([(a + offset, b + offset) for a, b in other_chain])


def _chain_with_index(n: int, k: int) -> Tuple[List[Tuple[int, int]], int]:
    """Return a chain for `n` that passes through `k`, plus the element index of `k` in it."""
    q, r = divmod(n, k)
    if r in {0, 1}:
        # n = q * k (+ 1): build k, multiply by q, and add 1 if needed.
        steps = _minchain(k)
        index_k = len(steps)
        _mul(steps, _minchain(q))
        if r == 1:
            steps.append((0, len(steps)))
        return steps, index_k

    # n = q * k + r: build k through r, multiply by q, then add r.
    # The index of r comes from the recursive call; the end of that sub-chain is
    # k, not r, so it cannot be used as the index of r.
    steps, index_r = _chain_with_index(k, r)
    index_k = len(steps)
    _mul(steps, _minchain(q))
    steps.append((index_r, len(steps)))
    return steps, index_k


def _chain(n, k) -> List[Tuple[int, int]]:
    """Return an index-pair addition chain for `n` that is built through `k`.

    With `n = q * k + r`, the chain first reaches `k` (passing through `r` when
    `r > 1`), multiplies by `q`, and finally adds `r`.
    """
    return _chain_with_index(n, k)[0]


def _minchain(n: int) -> List[Tuple[int, int]]:
    """Return an index-pair addition chain for `n >= 1`, chosen heuristically.

    Powers of two use repeated doubling, 3 is hard-coded, and every other target
    is split around `k = n // 2**(floor(log2(n)) // 2)` and delegated to `_chain`.
    """
    log_n = n.bit_length() - 1
    if n == 1 << log_n:
        return [(i, i) for i in range(log_n)]
    if n == 3:
        return [(0, 0), (0, 1)]
    k = n // (1 << (log_n // 2))
    return _chain(n, k)
76 | 77 | Raises: # noqa: DOC502 78 | TimeoutError: If the global MAXSAT_TIMEOUT is not None, and it is reached before a maxsat instance could be solved. 79 | 80 | Returns: 81 | An addition chain. 82 | """ 83 | # We want to do better than square and multiply, so we find an upper bound 84 | sam_cost = math.ceil(math.log2(target)) * squaring_cost + hw(target) - 1 85 | 86 | # Apply CSE to the square & mutliply chain 87 | if precomputed_values is not None: 88 | for exp, depth in precomputed_values: 89 | if exp > 0 and (exp & (exp - 1)) == 0 and depth == math.log2(exp): 90 | sam_cost -= squaring_cost 91 | 92 | try: 93 | addition_chain = None 94 | if modulus is not None and modulus <= 200: 95 | addition_chain = add_chain_modp( 96 | target, 97 | modulus, 98 | None, 99 | sam_cost, 100 | squaring_cost, 101 | solver, 102 | encoding, 103 | thurber, 104 | min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1, 105 | precomputed_values=precomputed_values, 106 | ) 107 | elif target <= 1000: 108 | addition_chain = add_chain( 109 | target, 110 | None, 111 | sam_cost, 112 | squaring_cost, 113 | solver, 114 | encoding, 115 | thurber, 116 | min_size=math.ceil(math.log2(target)) if precomputed_values is None else 1, 117 | precomputed_values=precomputed_values, 118 | ) 119 | 120 | if addition_chain is not None: 121 | addition_chain = extract_indices( 122 | addition_chain, precomputed_values=None if precomputed_values is None else list(k for k, _ in precomputed_values), modulus=modulus 123 | ) 124 | except TimeoutError: 125 | # The MaxSAT solver timed out, so we resort to a heuristic 126 | pass 127 | 128 | if addition_chain is None: 129 | # If no other addition chain algorithm has been called or if we could not do better than square and multiply 130 | 131 | # Uses the minchain algorithm from ["Addition chains using continued fractions."][BBBD1989] 132 | # The implementation was adapted from the `addchain` Rust crate (https://github.com/str4d/addchain). 
def hw(n: int) -> int:
    """Returns the Hamming weight of n, i.e. the number of ones in its binary representation.

    Raises:
        ValueError: If `n` is negative (a negative integer has no finite binary expansion,
            and the bit-clearing loop this replaces would never terminate on one).
    """
    if n < 0:
        raise ValueError(f"The Hamming weight is only defined for non-negative integers, got {n}.")

    return bin(n).count("1")
def chain_cost(chain: List[Tuple[int, int]], squaring_cost: float) -> float:
    """Returns the total cost of an addition chain.

    A step that adds a value to itself (a doubling, i.e. a squaring in the exponentiation
    view) costs `squaring_cost`; every other step costs 1.0.
    """
    step_costs = [squaring_cost if left == right else 1.0 for left, right in chain]
    return sum(step_costs)
58 | """ 59 | if precomputed_values is not None: 60 | # The shortest chain in (t + (k-1)p, t + kp] will have length at least k 61 | # The cheapest chain in (t + (k-1)p, t + kp] will have cost at least k / sqr_cost 62 | best_chain = None 63 | 64 | k = 0 65 | while (k / squaring_cost) < strict_cost_max: 66 | # Add multiples of the precomputed_values 67 | new_precomputed_values = [] 68 | for precomputed_value, depth in precomputed_values: 69 | for i in range(k + 1): 70 | new_precomputed_values.append((precomputed_value + i * modulus, depth)) 71 | 72 | chain = add_chain( 73 | target + k * modulus, 74 | max_depth, 75 | strict_cost_max, 76 | squaring_cost, 77 | solver, 78 | encoding, 79 | thurber, 80 | min_size=max(min_size, k), 81 | precomputed_values=tuple(new_precomputed_values), 82 | ) 83 | 84 | if chain is not None: 85 | cost = chain_cost(chain, squaring_cost) 86 | strict_cost_max = min(strict_cost_max, cost) 87 | best_chain = chain 88 | 89 | k += 1 90 | 91 | return best_chain 92 | 93 | best_chain = None 94 | best_cost = None 95 | 96 | current_target = target 97 | 98 | i = 0 99 | 100 | while cost_lower_bound_monotonic(current_target, squaring_cost) < strict_cost_max and ( 101 | max_depth is None or math.ceil(math.log2(current_target)) <= max_depth 102 | ): 103 | tightest_min_size = max(size_lower_bound(current_target), min_size) 104 | if (tightest_min_size * squaring_cost) >= ( 105 | strict_cost_max if best_cost is None else min(strict_cost_max, best_cost) 106 | ): 107 | current_target += modulus 108 | continue 109 | 110 | chain = add_chain( 111 | current_target, 112 | max_depth, 113 | strict_cost_max, 114 | squaring_cost, 115 | solver, 116 | encoding, 117 | thurber, 118 | tightest_min_size, 119 | precomputed_values, 120 | ) 121 | 122 | if chain is not None: 123 | cost = chain_cost(chain, squaring_cost) 124 | if best_cost is None or cost < best_cost: 125 | best_cost = cost 126 | best_chain = chain 127 | strict_cost_max = min(best_cost, strict_cost_max) 128 | 129 | 
# Adapted from: https://stackoverflow.com/questions/16463582/memoize-to-disk-python-persistent-memoization
def cache_to_disk(ignore_args: Set[str]):
    """This decorator caches the calls to this function in a file on disk, ignoring the arguments listed in `ignore_args`.

    The cache key is a SHA3-256 digest of the string representation of the remaining
    arguments, taken in the order of the function signature. Arguments are bound through the
    signature, so positional arguments, keyword arguments and omitted defaults all map to the
    same key.

    Parameters:
        ignore_args: Names of parameters that must not influence the cache key.

    Returns:
        A cached output
    """
    # Imported locally so that this block does not depend on a new module-level import
    from functools import wraps

    # Always opens the database in the root of where the package is located
    oraqle_path = files(oraqle)
    database_path = oraqle_path.joinpath(ADDCHAIN_CACHE_FILENAME)
    d = shelve.open(str(database_path))  # noqa: SIM115

    def decorator(func):
        signature = inspect.signature(func)
        signature_args = list(signature.parameters.keys())
        assert all(arg in signature_args for arg in ignore_args)

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Binding (instead of indexing kwargs directly) also covers parameters for which
            # the caller relies on the default value; those previously raised a KeyError.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            relevant_args = [
                value for name, value in bound.arguments.items() if name not in ignore_args
            ]

            h = sha3_256()
            h.update(str(relevant_args).encode('ascii'))
            hashed_args = h.hexdigest()

            if hashed_args not in d:
                d[hashed_args] = func(*args, **kwargs)
            return d[hashed_args]

        return wrapped_func

    return decorator
def extract_indices(
    sequence: List[Tuple[int, int]],
    precomputed_values: Optional[Sequence[int]] = None,
    modulus: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Returns the indices for each step of the addition chain.

    If n precomputed values are provided, then these are considered to be the first n indices after x (i.e. x has index 0, followed by 1, ..., n representing the precomputed values).

    If a `modulus` is provided, all values (including the precomputed ones) are identified
    with their residue modulo `modulus`, so a step may refer to any representative of a
    value that was computed earlier.

    Parameters:
        sequence: The addition chain as pairs of *values* that are added in each step.
        precomputed_values: Values that are available before the first step of the chain.
        modulus: Optional modulus under which values are considered equal.

    Raises:
        KeyError: If a step refers to a value that has not been computed (or precomputed) before.
    """

    def _residue(value: int) -> int:
        return value if modulus is None else value % modulus

    # Maps each known value (reduced if there is a modulus) to the index where it is available
    indices = {_residue(1): 0}
    offset = 1
    if precomputed_values is not None:
        for value in precomputed_values:
            # Store the reduced value: lookups below are reduced as well, so an unreduced
            # key would never be found when value >= modulus.
            indices[_residue(value)] = offset
            offset += 1

    ans_sequence = []
    for position, (left, right) in enumerate(sequence):
        ans_sequence.append((indices[_residue(left)], indices[_residue(right)]))
        indices[_residue(left + right)] = position + offset

    return ans_sequence
hard part is unsatisfiable 111 | signal.setitimer(signal.ITIMER_REAL, 0) 112 | return None 113 | 114 | rc2.process_core() 115 | 116 | if rc2.cost >= strict_cost_max: 117 | signal.setitimer(signal.ITIMER_REAL, 0) 118 | return None 119 | 120 | signal.setitimer(signal.ITIMER_REAL, 0) 121 | rc2.model = rc2.oracle.get_model() # type: ignore 122 | 123 | # Return None if the model could not be solved 124 | if rc2.model is None: 125 | return None 126 | 127 | # Extract the model 128 | if rc2.model is None and rc2.pool.top == 0: 129 | # we seem to have been given an empty formula 130 | # so let's transform the None model returned to [] 131 | rc2.model = [] 132 | 133 | rc2.model = filter(lambda inp: abs(inp) in rc2.vmap.i2e, rc2.model) # type: ignore 134 | rc2.model = map(lambda inp: int(math.copysign(rc2.vmap.i2e[abs(inp)], inp)), rc2.model) 135 | rc2.model = sorted(rc2.model, key=abs) 136 | 137 | return rc2.model 138 | except TimeoutError as err: 139 | raise TimeoutError from err 140 | -------------------------------------------------------------------------------- /oraqle/circuits/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains example circuits and tools for generating them.""" 2 | -------------------------------------------------------------------------------- /oraqle/circuits/aes.py: -------------------------------------------------------------------------------- 1 | """This module implements a high-level AES encryption circuit for a constant key.""" 2 | from typing import List 3 | 4 | from aeskeyschedule import key_schedule 5 | from galois import GF 6 | 7 | from oraqle.compiler.arithmetic.exponentiation import Power 8 | from oraqle.compiler.circuit import Circuit 9 | from oraqle.compiler.nodes import Constant 10 | from oraqle.compiler.nodes.abstract import Node 11 | from oraqle.compiler.nodes.leafs import Input 12 | 13 | gf = GF(2**8) 14 | 15 | 16 | def encrypt(plaintext: List[Node], key: bytes) 
-> List[Node]: 17 | """Returns an AES encryption circuit for a constant `key`.""" 18 | mix = [Constant(gf(2)), Constant(gf(3)), Constant(gf(1)), Constant(gf(1))] 19 | 20 | round_keys = [[Constant(gf(byte)) for byte in round_key] for round_key in key_schedule(key)] 21 | 22 | def additions(nodes: List[Node]) -> Node: 23 | node_iter = iter(nodes) 24 | out = next(node_iter) + next(node_iter) 25 | for node in node_iter: 26 | out += node 27 | return out 28 | 29 | def sbox(node: Node, method="minchain") -> Node: 30 | if method == "hardcoded": 31 | x2 = node.mul(node, flatten=False) 32 | x3 = node.mul(x2, flatten=False) 33 | x6 = x3.mul(x3, flatten=False) 34 | x12 = x6.mul(x6, flatten=False) 35 | x15 = x12.mul(x3, flatten=False) 36 | x30 = x15.mul(x15, flatten=False) 37 | x60 = x30.mul(x30, flatten=False) 38 | x63 = x60.mul(x3, flatten=False) 39 | x126 = x63.mul(x63, flatten=False) 40 | x127 = node.mul(x126, flatten=False) 41 | x254 = x127.mul(x127, flatten=False) 42 | return x254 43 | elif method == "minchain": 44 | return Power(node, 254, gf) 45 | else: 46 | raise Exception(f"Invalid method: {method}.") 47 | 48 | # AddRoundKey 49 | b = [round_key + plaintext_byte for round_key, plaintext_byte in zip(round_keys[0], plaintext)] 50 | 51 | for round in range(9): 52 | # SubBytes (modular inverse) 53 | b = [sbox(b[j], method="hardcoded") for j in range(16)] 54 | 55 | # ShiftRows 56 | b[1], b[5], b[9], b[13] = b[5], b[9], b[13], b[1] 57 | b[2], b[6], b[10], b[14] = b[10], b[14], b[2], b[6] 58 | b[3], b[7], b[11], b[15] = b[15], b[3], b[7], b[11] 59 | 60 | # MixColumns 61 | b = [additions([mix[(j + i) % 4] * b[j // 4 + i] for i in range(4)]) for j in range(16)] 62 | 63 | # AddRoundKey 64 | b = [round_key + b[j] for j, round_key in zip(range(16), round_keys[round + 1])] 65 | b: List[Node] 66 | 67 | return b 68 | 69 | 70 | if __name__ == "__main__": 71 | # TODO: Consider if we want to support degree > 1 72 | circuit = Circuit( 73 | encrypt([Input(f"{i}", gf) for i in range(16)], 
b"abcdabcdabcdabcd") 74 | ).arithmetize() 75 | print(circuit) 76 | print(circuit.multiplicative_depth()) 77 | print(circuit.multiplicative_size()) 78 | circuit.eliminate_subexpressions() 79 | print(circuit.multiplicative_depth()) 80 | print(circuit.multiplicative_size()) 81 | 82 | # TODO: Test if it corresponds to a plaintext implementation of AES 83 | 84 | 85 | def test_aes_128(): # noqa: D103 86 | # Only checks if no errors occur 87 | Circuit(encrypt([Input(f"{i}", gf) for i in range(16)], b"abcdabcdabcdabcd")).arithmetize() 88 | -------------------------------------------------------------------------------- /oraqle/circuits/cardio.py: -------------------------------------------------------------------------------- 1 | """This module implements the cardio circuit that is often used in benchmarking compilers, see: https://arxiv.org/abs/2101.07078.""" 2 | from typing import Type 3 | from galois import GF, FieldArray 4 | 5 | from oraqle.compiler.boolean.bool_neg import Neg 6 | from oraqle.compiler.boolean.bool_or import any_ 7 | from oraqle.compiler.circuit import Circuit 8 | from oraqle.compiler.nodes import Input 9 | from oraqle.compiler.nodes.abstract import Node 10 | from oraqle.compiler.nodes.arbitrary_arithmetic import sum_ 11 | 12 | 13 | def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node: 14 | """Returns the cardio circuit from https://arxiv.org/abs/2101.07078.""" 15 | man = Input("man", gf) 16 | smoking = Input("smoking", gf) 17 | age = Input("age", gf) 18 | diabetic = Input("diabetic", gf) 19 | hbp = Input("hbp", gf) 20 | cholesterol = Input("cholesterol", gf) 21 | weight = Input("weight", gf) 22 | height = Input("height", gf) 23 | activity = Input("activity", gf) 24 | alcohol = Input("alcohol", gf) 25 | 26 | return sum_( 27 | man & (age > 50), 28 | Neg(man, gf) & (age > 60), 29 | smoking, 30 | diabetic, 31 | hbp, 32 | cholesterol < 40, 33 | weight > (height - 90), # This might underflow if the modulus is too small 34 | activity < 30, 35 | man 
& (alcohol > 3), 36 | Neg(man, gf) & (alcohol > 2), 37 | ) 38 | 39 | 40 | def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node: 41 | """Returns a variant of the cardio circuit that returns a Boolean indicating whether any risk factor returned true.""" 42 | man = Input("man", gf) 43 | smoking = Input("smoking", gf) 44 | age = Input("age", gf) 45 | diabetic = Input("diabetic", gf) 46 | hbp = Input("hbp", gf) 47 | cholesterol = Input("cholesterol", gf) 48 | weight = Input("weight", gf) 49 | height = Input("height", gf) 50 | activity = Input("activity", gf) 51 | alcohol = Input("alcohol", gf) 52 | 53 | return any_( 54 | man & (age > 50), 55 | Neg(man, gf) & (age > 60), 56 | smoking, 57 | diabetic, 58 | hbp, 59 | cholesterol < 40, 60 | weight > (height - 90), # This might underflow if the modulus is too small 61 | activity < 30, 62 | man & (alcohol > 3), 63 | Neg(man, gf) & (alcohol > 2), 64 | ) 65 | 66 | 67 | def test_cardio_p101(): # noqa: D103 68 | gf = GF(101) 69 | circuit = Circuit([construct_cardio_risk_circuit(gf)]) 70 | 71 | for _, _, arithmetization in circuit.arithmetize_depth_aware(): 72 | assert arithmetization.evaluate({ 73 | "man": gf(1), 74 | "woman": gf(0), 75 | "age": gf(50), 76 | "smoking": gf(0), 77 | "diabetic": gf(0), 78 | "hbp": gf(0), 79 | "cholesterol": gf(45), 80 | "weight": gf(10), 81 | "height": gf(100), 82 | "activity": gf(90), 83 | "alcohol": gf(3), 84 | })[0] == 0 85 | 86 | assert arithmetization.evaluate({ 87 | "man": gf(0), 88 | "woman": gf(1), 89 | "age": gf(50), 90 | "smoking": gf(0), 91 | "diabetic": gf(0), 92 | "hbp": gf(0), 93 | "cholesterol": gf(45), 94 | "weight": gf(10), 95 | "height": gf(100), 96 | "activity": gf(90), 97 | "alcohol": gf(3), 98 | })[0] == 1 99 | 100 | assert arithmetization.evaluate({ 101 | "man": gf(1), 102 | "woman": gf(0), 103 | "age": gf(50), 104 | "smoking": gf(0), 105 | "diabetic": gf(0), 106 | "hbp": gf(0), 107 | "cholesterol": gf(39), 108 | "weight": gf(10), 109 | "height": gf(100), 110 
def gen_median_circuit(inputs: Sequence[int], gf: Type[FieldArray]):
    """Returns a naive circuit for finding the median value of `inputs`.

    The inputs are fully sorted with a bubble-sort network of conditional swaps, after which
    the middle element is selected. For an even number of inputs there are two middle
    elements; the circuit outputs the upper one.

    Parameters:
        inputs: Labels of the inputs; one `Input` node named `"Input {v}"` is created per label.
        gf: The field over which the circuit is defined.

    Raises:
        ValueError: If `inputs` is empty.
    """
    if len(inputs) == 0:
        raise ValueError("Cannot compute the median of zero inputs.")

    outputs = [Input(f"Input {v}", gf) for v in inputs]

    for i in range(len(outputs) - 1, -1, -1):
        for j in range(i):
            outputs[j], outputs[j + 1] = cswp(outputs[j], outputs[j + 1])  # type: ignore

    # With 0-based indexing, len // 2 is the middle element for odd lengths and the upper
    # middle element for even lengths (len // 2 + 1 is out of range for two inputs).
    return Circuit([outputs[len(outputs) // 2]])
def cswp(lhs: Node, rhs: Node) -> Tuple[Node, Node]:
    """Conditionally swap `lhs` and `rhs` so that the smaller of the two comes first.

    The comparison result selects `lhs` when `lhs < rhs` and `rhs` otherwise; the other
    output is recovered as the sum of both inputs minus the selected one.

    Returns:
        A tuple representing (lower, higher)
    """
    is_smaller = lhs < rhs

    lower = is_smaller * (lhs - rhs) + rhs
    higher = lhs + rhs - lower

    return lower, higher
def gen_veto_voting_circuit(participants: int, gf: Type[FieldArray]):
    """Returns a veto voting circuit between the number of `participants`.

    One input named `"Input {i}"` is created per participant, and the circuit outputs the
    OR of all of them.

    Parameters:
        participants: The number of voters.
        gf: The field over which the circuit is defined.
    """
    # A list (rather than a set) keeps the operands in participant order, so the generated
    # circuit does not depend on hash iteration order.
    # NOTE(review): assumes `any_` does not rely on receiving its operands in set order.
    input_nodes = [Input(f"Input {i}", gf) for i in range(participants)]
    return Circuit([any_(*input_nodes)])
# TODO: Think about the role of Power when there are also Products
class Power(UnivariateNode):
    """Represents an exponentiation: x ** constant.

    Exponents are interpreted modulo `order - 1`, where `order` is the number of elements of
    the field: the nonzero elements form a multiplicative group of that size, so exponents
    that agree modulo `order - 1` yield the same power. For prime fields this equals
    `characteristic - 1`.
    """

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return f"pow_{self._exponent}"

    @property
    def _node_label(self) -> str:
        return f"Pow: {self._exponent}"

    def __init__(self, node: Node, exponent: int, gf: Type[FieldArray]):
        """Initialize a `Power` node that exponentiates `node` with `exponent`."""
        self._exponent = exponent
        super().__init__(node, gf)

    def _operation_inner(self, input: FieldArray, gf: Type[FieldArray]) -> FieldArray:
        return input**self._exponent  # type: ignore

    def _arithmetize_inner(self, strategy: str) -> "Node":
        if strategy == "naive":
            # Square & multiply
            if self._exponent < 1:
                raise ValueError(f"Cannot arithmetize a power with exponent {self._exponent}.")

            # bit_length (instead of ceil(log2(...))) also covers exponents that are an exact
            # power of two, whose only set bit is the most significant one.
            bits = self._exponent.bit_length()

            # nodes[i] holds x^(2^i)
            nodes = [self._node.arithmetize(strategy)]
            for i in range(bits - 1):
                nodes.append(nodes[i].mul(nodes[i], flatten=False))

            # Multiply the squarings that correspond to the set bits of the exponent
            previous = None
            for i in range(bits):
                if (self._exponent >> i) & 1:
                    if previous is None:
                        previous = nodes[i]
                    else:
                        previous = nodes[i].mul(previous, flatten=False)

            assert previous is not None
            return previous

        assert strategy == "best-effort"

        addition_chain = add_chain_guaranteed(self._exponent, self._gf.order - 1, squaring_cost=1.0)

        nodes = [self._node.arithmetize(strategy).to_arithmetic()]

        for i, j in addition_chain:
            nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

        return nodes[-1]

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # TODO: While generating the front, we can take into account the maximum cost etc. implied by the depth-aware arithmetization of the operand
        # Only small fields take the modulus into account; the same modulus must be used for
        # generating the chains and for translating them into indices.
        modulus = self._gf.order - 1 if self._gf.order <= 257 else None
        front = gen_pareto_front(self._exponent, modulus, cost_of_squaring)

        # The index form of each chain does not depend on the operand, so compute it once
        indexed_front = [(depth, extract_indices(chain, modulus=modulus)) for depth, chain in front]

        final_front = CostParetoFront(cost_of_squaring)

        for depth1, _, node in self._node.arithmetize_depth_aware(cost_of_squaring):
            for depth2, index_chain in indexed_front:
                nodes = [node]

                for i, j in index_chain:
                    nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

                final_front.add(nodes[-1], depth=depth1 + depth2)

        return final_front
"""Represents a subtraction, which can be arithmetized using addition and constant-multiplication.""" 11 | 12 | @property 13 | def _overriden_graphviz_attributes(self) -> dict: 14 | return {"shape": "square", "style": "rounded,filled", "fillcolor": "cornsilk"} 15 | 16 | @property 17 | def _hash_name(self) -> str: 18 | return "sub" 19 | 20 | @property 21 | def _node_label(self) -> str: 22 | return "-" 23 | 24 | def _operation_inner(self, x, y) -> FieldArray: 25 | return x - y 26 | 27 | def _arithmetize_inner(self, strategy: str) -> Node: 28 | # TODO: Reorganize the files: let the arithmetic folder only contain pure arithmetic (including add and mul) and move exponentiation elsewhere. 29 | # TODO: For schemes that support subtraction we do not need to do this. We should only do this transformation during the compiler stage. 30 | return (self._left.arithmetize(strategy) + (Constant(-self._gf(1)) * self._right.arithmetize(strategy))).arithmetize(strategy) # type: ignore # TODO: Should we always perform a final arithmetization in every node for constant folding? E.g. in Node? 
31 | 32 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 33 | result = self._left + (Constant(-self._gf(1)) * self._right) 34 | front = result.arithmetize_depth_aware(cost_of_squaring) 35 | return front 36 | 37 | 38 | def test_evaluate_mod5(): # noqa: D103 39 | gf = GF(5) 40 | 41 | a = Input("a", gf) 42 | b = Input("b", gf) 43 | node = Subtraction(a, b, gf) 44 | 45 | assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(1) 46 | node.clear_cache(set()) 47 | assert node.evaluate({"a": gf(4), "b": gf(1)}) == gf(3) 48 | node.clear_cache(set()) 49 | assert node.evaluate({"a": gf(1), "b": gf(3)}) == gf(3) 50 | node.clear_cache(set()) 51 | assert node.evaluate({"a": gf(0), "b": gf(4)}) == gf(1) 52 | 53 | 54 | def test_evaluate_arithmetized_mod5(): # noqa: D103 55 | gf = GF(5) 56 | 57 | a = Input("a", gf) 58 | b = Input("b", gf) 59 | node = Subtraction(a, b, gf).arithmetize("best-effort") 60 | node.clear_cache(set()) 61 | 62 | assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(1) 63 | node.clear_cache(set()) 64 | assert node.evaluate({"a": gf(4), "b": gf(1)}) == gf(3) 65 | node.clear_cache(set()) 66 | assert node.evaluate({"a": gf(1), "b": gf(3)}) == gf(3) 67 | node.clear_cache(set()) 68 | assert node.evaluate({"a": gf(0), "b": gf(4)}) == gf(1) 69 | -------------------------------------------------------------------------------- /oraqle/compiler/boolean/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains nodes for expressing common Boolean operations.""" 2 | -------------------------------------------------------------------------------- /oraqle/compiler/boolean/bool_neg.py: -------------------------------------------------------------------------------- 1 | """Classes for describing Boolean negation.""" 2 | from galois import FieldArray 3 | 4 | from oraqle.compiler.arithmetic.subtraction import Subtraction 5 | from oraqle.compiler.nodes.abstract import 
class Or(CommutativeUniqueReducibleNode):
    """Performs an OR operation over several operands. The user must ensure that the operands are Booleans."""

    @property
    def _hash_name(self) -> str:
        return "or"

    @property
    def _node_label(self) -> str:
        return "OR"

    def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
        return self._gf(bool(a) | bool(b))

    def _arithmetize_inner(self, strategy: str) -> Node:
        """Arithmetize using De Morgan's law: OR(x_1, ..., x_n) = NEG(AND(NEG(x_1), ..., NEG(x_n)))."""
        # FIXME: Handle what happens when arithmetize outputs a constant!
        # TODO: Also consider the arithmetization using randomness
        return Neg(
            And(
                {
                    UnoverloadedWrapper(Neg(operand.node.arithmetize(strategy), self._gf))
                    for operand in self._operands
                },
                self._gf,
            ),
            self._gf,
        ).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        """Compose the Pareto fronts of the operands into a front for their disjunction.

        Constant operands are folded using OR semantics (the same semantics as `or_flatten`):
        a true constant makes the whole disjunction true, a false constant is ignored, and a
        disjunction without any remaining operands is false.
        """
        # TODO: This is mostly copied from AND
        new_operands: Set[CostParetoFront] = set()
        for operand in self._operands:
            new_operand = operand.node.arithmetize_depth_aware(cost_of_squaring)
            new_operands.add(new_operand)

        if len(new_operands) == 0:
            # The empty disjunction is false
            return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)
        elif len(new_operands) == 1:
            return next(iter(new_operands))

        # TODO: We can check if any of the element in new_operands are constants and return early

        front = CostParetoFront(cost_of_squaring)

        # TODO: This is brute force composition
        for operands in itertools.product(*(iter(new_operand) for new_operand in new_operands)):
            checked_operands = []
            for depth, cost, node in operands:
                if isinstance(node, Constant):
                    assert node._value in {0, 1}
                    if node._value == 1:
                        # A true operand absorbs the entire OR
                        return CostParetoFront.from_leaf(Constant(self._gf(1)), cost_of_squaring)
                    # A false operand does not influence the OR, so it is dropped
                else:
                    checked_operands.append((depth, cost, node))

            if len(checked_operands) == 0:
                # All operands were false constants
                return CostParetoFront.from_leaf(Constant(self._gf(0)), cost_of_squaring)

            if len(checked_operands) == 1:
                depth, cost, node = checked_operands[0]
                front.add(node, depth, cost)
                continue

            this_front = _find_depth_cost_front(
                checked_operands,
                self._gf,
                float("inf"),
                squaring_cost=cost_of_squaring,
                is_and=False,
            )
            front.add_front(this_front)

        return front

    def or_flatten(self, other: Node) -> Node:
        """Performs an OR operation with `other`, flattening the `Or` node if either of the two is also an `Or` and absorbing `Constant`s.

        Returns:
            An `Or` node containing the flattened OR operation, or a `Constant` node.
        """
        if isinstance(other, Constant):
            if bool(other._value):
                # OR with true is always true
                return Constant(self._gf(1))
            else:
                # OR with false leaves this node unchanged
                return self

        if isinstance(other, Or):
            return Or(self._operands | other._operands, self._gf)

        # Copy the operands so that this node is not mutated
        new_operands = self._operands.copy()
        new_operands.add(UnoverloadedWrapper(other))
        return Or(new_operands, self._gf)
node.arithmetize_depth_aware(cost_of_squaring=1.0) 139 | 140 | for _, _, n in front: 141 | n.clear_cache(set()) 142 | assert n.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) 143 | n.clear_cache(set()) 144 | assert n.evaluate({"a": gf(0), "b": gf(1)}) == gf(1) 145 | n.clear_cache(set()) 146 | assert n.evaluate({"a": gf(1), "b": gf(0)}) == gf(1) 147 | n.clear_cache(set()) 148 | assert n.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) 149 | 150 | 151 | def test_evaluate_arithmetized_mod3(): # noqa: D103 152 | gf = GF(3) 153 | 154 | a = Input("a", gf) 155 | b = Input("b", gf) 156 | node = (a | b).arithmetize("best-effort") 157 | 158 | node.clear_cache(set()) 159 | assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(0) 160 | node.clear_cache(set()) 161 | assert node.evaluate({"a": gf(0), "b": gf(1)}) == gf(1) 162 | node.clear_cache(set()) 163 | assert node.evaluate({"a": gf(1), "b": gf(0)}) == gf(1) 164 | node.clear_cache(set()) 165 | assert node.evaluate({"a": gf(1), "b": gf(1)}) == gf(1) 166 | 167 | 168 | def test_evaluate_arithmetized_depth_aware_50_mod31(): # noqa: D103 169 | gf = GF(31) 170 | 171 | xs = {Input(f"x{i}", gf) for i in range(50)} 172 | node = Or({UnoverloadedWrapper(x) for x in xs}, gf) 173 | front = node.arithmetize_depth_aware(cost_of_squaring=1.0) 174 | 175 | for _, _, n in front: 176 | n.clear_cache(set()) 177 | assert n.evaluate({f"x{i}": gf(0) for i in range(50)}) == gf(0) 178 | n.clear_cache(set()) 179 | assert n.evaluate({f"x{i}": gf(i % 2) for i in range(50)}) == gf(1) 180 | n.clear_cache(set()) 181 | assert n.evaluate({f"x{i}": gf(1) for i in range(50)}) == gf(1) 182 | -------------------------------------------------------------------------------- /oraqle/compiler/comparison/__init__.py: -------------------------------------------------------------------------------- 1 | """Package containing tools for expressing equality and comparison operations.""" 2 | -------------------------------------------------------------------------------- 
class IsNonZero(UnivariateNode):
    """This node represents a non-zero check: x != 0.

    It evaluates to 1 when the operand is non-zero and to 0 when the operand is zero.
    """

    @property
    def _node_shape(self) -> str:
        return "box"

    @property
    def _hash_name(self) -> str:
        return "is_nonzero"

    @property
    def _node_label(self) -> str:
        return "!= 0"

    def _operation_inner(self, input: FieldArray) -> FieldArray:
        # Return a field element (like `Equals` does) rather than a raw comparison result
        return self._gf(int(input != 0))

    def _arithmetize_inner(self, strategy: str) -> Node:
        # x ** (q - 1) equals 1 for every non-zero element of a field of order q, and 0 for x = 0
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # Same exponentiation as in `_arithmetize_inner`, but returning the depth-cost front
        return Power(self._node, self._gf.order - 1, self._gf).arithmetize_depth_aware(
            cost_of_squaring
        )
).arithmetize(strategy) 59 | 60 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 61 | return Neg( 62 | IsNonZero(Subtraction(self._left, self._right, self._gf), self._gf), 63 | self._gf, 64 | ).arithmetize_depth_aware(cost_of_squaring) 65 | 66 | 67 | def test_evaluate_mod5(): # noqa: D103 68 | gf = GF(5) 69 | 70 | a = Input("a", gf) 71 | b = Input("b", gf) 72 | node = Equals(a, b, gf) 73 | 74 | assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(0) 75 | node.clear_cache(set()) 76 | assert node.evaluate({"a": gf(4), "b": gf(4)}) == gf(1) 77 | node.clear_cache(set()) 78 | assert node.evaluate({"a": gf(1), "b": gf(2)}) == gf(0) 79 | node.clear_cache(set()) 80 | assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(1) 81 | 82 | 83 | def test_evaluate_arithmetized_mod5(): # noqa: D103 84 | gf = GF(5) 85 | 86 | a = Input("a", gf) 87 | b = Input("b", gf) 88 | node = Equals(a, b, gf).arithmetize("best-effort") 89 | node.clear_cache(set()) 90 | 91 | assert node.evaluate({"a": gf(3), "b": gf(2)}) == gf(0) 92 | node.clear_cache(set()) 93 | assert node.evaluate({"a": gf(4), "b": gf(4)}) == gf(1) 94 | node.clear_cache(set()) 95 | assert node.evaluate({"a": gf(1), "b": gf(2)}) == gf(0) 96 | node.clear_cache(set()) 97 | assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(1) 98 | 99 | 100 | def test_equality_equivalence_commutative(): # noqa: D103 101 | gf = GF(5) 102 | 103 | a = Input("a", gf) 104 | b = Input("b", gf) 105 | 106 | assert (a == b).is_equivalent(b == a) 107 | -------------------------------------------------------------------------------- /oraqle/compiler/control_flow/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains control flow functions.""" 2 | -------------------------------------------------------------------------------- /oraqle/compiler/control_flow/conditional.py: -------------------------------------------------------------------------------- 1 | 
class IfElse(FixedNode):
    """A node that selects between two operands based on a Boolean condition."""

    @property
    def _node_label(self):
        return "If"

    @property
    def _hash_name(self):
        return "if_else"

    def __init__(self, condition: Node, positive: Node, negative: Node, gf: Type[FieldArray]):
        """Initialize an if-else node that outputs `positive` when `condition` is true and `negative` otherwise."""
        self._condition = condition
        self._positive = positive
        self._negative = negative
        super().__init__(gf)

    def __hash__(self) -> int:
        # The order of the operands matters, so they are hashed as an ordered tuple
        key = (self._hash_name, self._condition, self._positive, self._negative)
        return hash(key)

    def is_equivalent(self, other: Node) -> bool:  # noqa: D102
        if not isinstance(other, self.__class__):
            return False

        # Compare the operands pairwise, stopping at the first pair that differs
        if not self._condition.is_equivalent(other._condition):
            return False
        if not self._positive.is_equivalent(other._positive):
            return False
        return self._negative.is_equivalent(other._negative)

    def operands(self) -> List[Node]:  # noqa: D102
        return [self._condition, self._positive, self._negative]

    def set_operands(self, operands: List[Node]):  # noqa: D102
        # Same order as returned by `operands`: condition, positive, negative
        self._condition = operands[0]
        self._positive = operands[1]
        self._negative = operands[2]

    def operation(self, operands: List[FieldArray]) -> FieldArray:  # noqa: D102
        selector = operands[0]
        assert selector == 0 or selector == 1

        if selector == 1:
            return operands[1]
        return operands[2]

    def _arithmetize_inner(self, strategy: str) -> Node:
        # condition * (positive - negative) + negative equals positive when the condition is 1
        # and negative when the condition is 0
        difference = self._positive - self._negative
        selected = self._condition * difference + self._negative
        return selected.arithmetize(strategy)

    def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
        # Same expression as in `_arithmetize_inner`, arithmetized in a depth-aware manner
        difference = self._positive - self._negative
        selected = self._condition * difference + self._negative
        return selected.arithmetize_depth_aware(cost_of_squaring)
def principal_character(x, prime_modulus):
    """Computes the principal character. Modulo the prime, this expression is 0 when x = 0 and 1 otherwise (Fermat's little theorem). Only works for prime moduli.

    The power is returned as-is: this function does not reduce the result modulo `prime_modulus`.

    Returns:
        The principal character x**(p-1).
    """
    return x ** (prime_modulus - 1)
PDF using graphviz.""" 9 | 10 | def __init__(self): 11 | """Initialize an empty DotFile.""" 12 | self._nodes: List[Dict[str, str]] = [] 13 | self._links: List[Tuple[int, int, Dict[str, str]]] = [] 14 | 15 | def add_node(self, **kwargs) -> int: 16 | """Adds a node to the file. The keyword arguments are directly put into the DOT file. 17 | 18 | For example, one can specify a label, a color, a style, etc... 19 | 20 | Returns: 21 | The identifier of this node in this `DotFile`. 22 | """ 23 | node_id = len(self._nodes) 24 | self._nodes.append(kwargs) 25 | 26 | return node_id 27 | 28 | def add_link(self, from_id: int, to_id: int, **kwargs): 29 | """Adds an unformatted link between the nodes with `from_id` and `to_id`. The keyword arguments are directly put into the DOT file.""" 30 | self._links.append((from_id, to_id, kwargs)) 31 | 32 | def to_file(self, filename: str): 33 | """Writes the DOT file to the given filename as a directed graph called 'G'.""" 34 | with open(filename, mode="w", encoding="utf-8") as file: 35 | file.write("digraph G {\n") 36 | file.write('forcelabels="true";\n') 37 | file.write("graph [nodesep=0.25,ranksep=0.6];") # nodesep, ranksep 38 | 39 | # Write all the nodes 40 | for node_id, attributes in enumerate(self._nodes): 41 | transformed_attributes = ",".join( 42 | [f'{key}="{value}"' for key, value in attributes.items()] 43 | ) 44 | file.write(f"n{node_id} [{transformed_attributes}];\n") 45 | 46 | # Write all the links 47 | for from_id, to_id, attributes in self._links: 48 | if len(attributes) == 0: 49 | file.write(f"n{from_id}->n{to_id};\n") 50 | else: 51 | text = f"n{from_id}->n{to_id} [" 52 | text += ",".join((f"{key}={value}" for key, value in attributes.items())) 53 | text += "];\n" 54 | file.write(text) 55 | 56 | file.write("}\n") 57 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | """The 
nodes package contains a collection of fundamental abstract and concrete nodes.""" 2 | from oraqle.compiler.nodes.abstract import Node 3 | from oraqle.compiler.nodes.binary_arithmetic import Addition, Multiplication 4 | from oraqle.compiler.nodes.leafs import Constant, Input 5 | 6 | __all__ = ['Addition', 'Constant', 'Input', 'Multiplication', 'Node'] 7 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/binary_arithmetic.py: -------------------------------------------------------------------------------- 1 | """Module containing binary arithmetic nodes: additions and multiplications between non-constant nodes.""" 2 | from abc import abstractmethod 3 | from typing import List, Optional, Set, Tuple, Type 4 | 5 | from galois import FieldArray 6 | 7 | from oraqle.compiler.instructions import ( 8 | AdditionInstruction, 9 | ArithmeticInstruction, 10 | MultiplicationInstruction, 11 | ) 12 | from oraqle.compiler.nodes.abstract import ( 13 | ArithmeticNode, 14 | CostParetoFront, 15 | Node, 16 | iterate_increasing_depth, 17 | select_stack_index, 18 | ) 19 | from oraqle.compiler.nodes.fixed import BinaryNode 20 | from oraqle.compiler.nodes.leafs import Constant 21 | 22 | 23 | class CommutativeBinaryNode(BinaryNode): 24 | """This node has two operands and implements a commutative operation between arithmetic nodes.""" 25 | 26 | def __init__( 27 | self, 28 | left: Node, 29 | right: Node, 30 | gf: Type[FieldArray], 31 | ): 32 | """Initialize the binary node with operands `left` and `right`.""" 33 | self._left = left 34 | self._right = right 35 | super().__init__(gf) 36 | 37 | @abstractmethod 38 | def _operation_inner(self, x: FieldArray, y: FieldArray) -> FieldArray: 39 | """Applies the binary operation on x and y.""" 40 | 41 | def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 42 | return self._operation_inner(operands[0], operands[1]) 43 | 44 | def operands(self) -> List[Node]: # noqa: D102 45 | 
class CommutativeArithmeticBinaryNode(CommutativeBinaryNode):
    """A commutative binary node whose two operands are arithmetic nodes.

    This class caches the multiplicative depth and the sets of multiplications and squarings, and
    it generates the instruction for this node. Whether the node multiplies or adds is read from
    `self._is_multiplication`.
    """

    def __init__(
        self,
        left: ArithmeticNode,
        right: ArithmeticNode,
        gf: Type[FieldArray],
    ):
        """Initialize this binary node with the given `left` and `right` operands.

        Raises:
            Exception: Neither `left` nor `right` is allowed to be a `Constant`.
        """
        super().__init__(left, right, gf)

        # Caches that are filled lazily by the methods below
        self._multiplications: Optional[Set[int]] = None
        self._squarings: Optional[Set[int]] = None
        self._depth_cache: Optional[int] = None

        if any(isinstance(operand, Constant) for operand in (left, right)):
            self._is_multiplication = False
            raise Exception("This should be a constant.")

    def multiplicative_depth(self) -> int:  # noqa: D102
        if self._depth_cache is not None:
            return self._depth_cache

        # A multiplication adds one level on top of the deepest operand; an addition adds none
        left_depth = self._left.multiplicative_depth()
        right_depth = self._right.multiplicative_depth()
        self._depth_cache = self._is_multiplication + max(left_depth, right_depth)

        return self._depth_cache

    def multiplications(self) -> Set[int]:  # noqa: D102
        if self._multiplications is not None:
            return self._multiplications

        # Collect the identifiers of all multiplications below this node
        collected = set()
        for operand in self.operands():
            collected.update(operand.multiplications())  # type: ignore
        self._multiplications = collected

        if self._is_multiplication:
            self._multiplications.add(id(self))

        return self._multiplications

    # TODO: Squaring should probably be a UniveriateNode
    def squarings(self) -> Set[int]:  # noqa: D102
        if self._squarings is not None:
            return self._squarings

        collected = set()
        for operand in self.operands():
            collected.update(operand.squarings())  # type: ignore
        self._squarings = collected

        # A multiplication is a squaring when both operands are the very same object
        if self._is_multiplication and id(self._left) == id(self._right):
            self._squarings.add(id(self))

        return self._squarings

    def create_instructions(  # noqa: D102
        self,
        instructions: List[ArithmeticInstruction],
        stack_counter: int,
        stack_occupied: List[bool],
    ) -> Tuple[int, int]:
        self._left: ArithmeticNode
        self._right: ArithmeticNode

        # The instruction for this node was already generated
        if self._instruction_cache is not None:
            return self._instruction_cache, stack_counter

        left_index, stack_counter = self._left.create_instructions(
            instructions, stack_counter, stack_occupied
        )
        right_index, stack_counter = self._right.create_instructions(
            instructions, stack_counter, stack_occupied
        )

        # FIXME: Is it possible for e.g. self._left._instruction_cache to be None?

        # Free the stack slot of an operand once its parent count drops to zero
        for operand in (self._left, self._right):
            operand._parent_count -= 1
            if operand._parent_count == 0:
                stack_occupied[operand._instruction_cache] = False  # type: ignore

        self._instruction_cache = select_stack_index(stack_occupied)

        instruction_type = (
            MultiplicationInstruction if self._is_multiplication else AdditionInstruction
        )
        instructions.append(instruction_type(self._instruction_cache, left_index, right_index))

        return self._instruction_cache, stack_counter
cost_of_squaring: float) -> CostParetoFront: 207 | front = CostParetoFront(cost_of_squaring) 208 | 209 | for res1, res2 in iterate_increasing_depth( 210 | self._left.arithmetize_depth_aware(cost_of_squaring), 211 | self._right.arithmetize_depth_aware(cost_of_squaring), 212 | ): 213 | d1, _, e1 = res1 214 | d2, _, e2 = res2 215 | 216 | # TODO: Do we use + here for flattening? 217 | front.add(Addition(e1, e2, self._gf), depth=max(d1, d2)) 218 | 219 | assert not front.is_empty() 220 | return front 221 | 222 | 223 | class Multiplication(CommutativeArithmeticBinaryNode, ArithmeticNode): 224 | """Performs modular multiplication of two previous nodes in an arithmetic circuit.""" 225 | 226 | @property 227 | def _overriden_graphviz_attributes(self) -> dict: 228 | return {"shape": "square", "style": "rounded,filled", "fillcolor": "lightpink"} 229 | 230 | @property 231 | def _hash_name(self) -> str: 232 | return "mul" 233 | 234 | @property 235 | def _node_label(self) -> str: 236 | return "×" # noqa: RUF001 237 | 238 | def __init__( 239 | self, 240 | left: ArithmeticNode, 241 | right: ArithmeticNode, 242 | gf: Type[FieldArray], 243 | ): 244 | """Initialize a modular multiplication between `left` and `right`.""" 245 | assert isinstance(left, ArithmeticNode) 246 | assert isinstance(right, ArithmeticNode) 247 | 248 | self._is_multiplication = True 249 | super().__init__(left, right, gf) 250 | 251 | def _operation_inner(self, x, y): 252 | return x * y 253 | 254 | # TODO: This is very hacky! Arithmetic nodes should simply not have to be arithmetized... 
255 | def arithmetize(self, strategy: str) -> Node: # noqa: D102 256 | self._left = self._left.arithmetize(strategy) 257 | self._right = self._right.arithmetize(strategy) 258 | return self 259 | 260 | def _arithmetize_inner(self, strategy: str) -> Node: 261 | raise NotImplementedError() 262 | 263 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 264 | return CostParetoFront.from_node(self, cost_of_squaring) 265 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/fixed.py: -------------------------------------------------------------------------------- 1 | """Module containing fixed nodes: nodes with a fixed number of inputs.""" 2 | from abc import abstractmethod 3 | from typing import Callable, Dict, List 4 | 5 | from galois import FieldArray 6 | 7 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node 8 | 9 | 10 | class FixedNode(Node): 11 | """A node with a fixed number of operands.""" 12 | 13 | @abstractmethod 14 | def operands(self) -> List["Node"]: 15 | """Returns the operands (children) of this node. The list can be empty.""" 16 | 17 | @abstractmethod 18 | def set_operands(self, operands: List["Node"]): 19 | """Overwrites the operands of this node.""" 20 | # TODO: Consider replacing this method with a graph traversal method that applies a function on all operands and replaces them. 
21 | 22 | 23 | def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 24 | for operand in self.operands(): 25 | function(operand) 26 | 27 | 28 | def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 29 | self.set_operands([function(operand) for operand in self.operands()]) 30 | # TODO: These caches should only be cleared if this is an ArithmeticNode 31 | self._multiplications = None 32 | self._squarings = None 33 | self._depth_cache = None 34 | 35 | 36 | def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 37 | # TODO: Remove modulus in this method and store it in each node instead. Alternatively, add `modulus` to methods such as `flatten` as well. 38 | if self._evaluate_cache is None: 39 | self._evaluate_cache = self.operation( 40 | [operand.evaluate(actual_inputs) for operand in self.operands()] 41 | ) 42 | 43 | return self._evaluate_cache 44 | 45 | @abstractmethod 46 | def operation(self, operands: List[FieldArray]) -> FieldArray: 47 | """Evaluates this node on the specified operands.""" 48 | 49 | def arithmetize(self, strategy: str) -> "Node": # noqa: D102 50 | if self._arithmetize_cache is None: 51 | if self._arithmetize_depth_cache is not None: 52 | return self._arithmetize_depth_cache.get_lowest_value() # type: ignore 53 | 54 | # If we know all operands we can simply evaluate this node 55 | operands = self.operands() 56 | if len(operands) > 0 and all( 57 | hasattr(operand, "_value") for operand in operands 58 | ): # This is a hacky way of checking whether the operands are all constant 59 | from oraqle.compiler.nodes.leafs import Constant 60 | 61 | self._arithmetize_cache = Constant(self.operation([operand._value for operand in self.operands()])) # type: ignore 62 | else: 63 | self._arithmetize_cache = self._arithmetize_inner(strategy) 64 | 65 | return self._arithmetize_cache 66 | 67 | @abstractmethod 68 | def _arithmetize_inner(self, strategy: str) -> "Node": 69 
| pass 70 | 71 | # TODO: Reduce code duplication 72 | 73 | def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront: # noqa: D102 74 | if self._arithmetize_depth_cache is None: 75 | if self._arithmetize_cache is not None: 76 | raise Exception("This should not happen") 77 | 78 | # If we know all operands we can simply evaluate this node 79 | operands = self.operands() 80 | if len(operands) > 0 and all( 81 | hasattr(operand, "_value") for operand in operands 82 | ): # This is a hacky way of checking whether the operands are all constant 83 | from oraqle.compiler.nodes.leafs import Constant 84 | 85 | self._arithmetize_depth_cache = CostParetoFront.from_leaf(Constant(self.operation([operand._value for operand in self.operands()])), cost_of_squaring) # type: ignore 86 | else: 87 | self._arithmetize_depth_cache = self._arithmetize_depth_aware_inner( 88 | cost_of_squaring 89 | ) 90 | 91 | assert self._arithmetize_depth_cache is not None 92 | return self._arithmetize_depth_cache 93 | 94 | @abstractmethod 95 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 96 | pass 97 | 98 | 99 | class BinaryNode(FixedNode): 100 | """A node with two operands.""" 101 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/flexible.py: -------------------------------------------------------------------------------- 1 | """Module containing nodes with a flexible number of operands.""" 2 | from abc import abstractmethod 3 | from collections import Counter 4 | from functools import reduce 5 | from typing import Callable 6 | from typing import Counter as CounterType 7 | from typing import Dict, Optional, Set, Type 8 | 9 | from galois import FieldArray 10 | 11 | from oraqle.compiler.graphviz import DotFile 12 | from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper 13 | from oraqle.compiler.nodes.leafs import Constant 14 | 15 | 16 | class FlexibleNode(Node): 
17 | """A node with an arbitrary number of operands. The operation must be reducible using a binary associative operation.""" 18 | 19 | # TODO: Ensure that when all inputs are constants, the node is replaced with its evaluation 20 | 21 | def arithmetize(self, strategy: str) -> Node: # noqa: D102 22 | if self._arithmetize_cache is None: 23 | self._arithmetize_cache = self._arithmetize_inner(strategy) 24 | 25 | return self._arithmetize_cache 26 | 27 | @abstractmethod 28 | def _arithmetize_inner(self, strategy: str) -> "Node": 29 | pass 30 | 31 | def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront: # noqa: D102 32 | if self._arithmetize_depth_cache is None: 33 | self._arithmetize_depth_cache = self._arithmetize_depth_aware_inner(cost_of_squaring) 34 | 35 | return self._arithmetize_depth_cache 36 | 37 | @abstractmethod 38 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 39 | pass 40 | 41 | 42 | class CommutativeUniqueReducibleNode(FlexibleNode): 43 | """A node with an operation that is reducible without taking order into account: i.e. it has a binary operation that is associative and commutative. 44 | 45 | The operands are unique, i.e. the same operand will never appear twice. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | operands: Set[UnoverloadedWrapper], 51 | gf: Type[FieldArray], 52 | ): 53 | """Initialize a node with the given set as the operands. 
None of the operands can be a constant.""" 54 | self._operands = operands 55 | assert not any(isinstance(operand.node, Constant) for operand in self._operands) 56 | assert len(operands) > 1 57 | super().__init__(gf) 58 | 59 | def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 60 | for operand in self._operands: 61 | function(operand.node) 62 | 63 | def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 64 | self._operands = {UnoverloadedWrapper(function(operand.node)) for operand in self._operands} 65 | 66 | def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 67 | if self._evaluate_cache is None: 68 | self._evaluate_cache = reduce( 69 | self._inner_operation, 70 | (operand.node.evaluate(actual_inputs) for operand in self._operands), 71 | ) 72 | 73 | return self._evaluate_cache 74 | 75 | @abstractmethod 76 | def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray: 77 | """Perform the reducible operation performed by this node (order should not matter).""" 78 | 79 | def __hash__(self) -> int: 80 | if self._hash is None: 81 | # The hash is commutative 82 | hashes = sorted([hash(operand) for operand in self._operands]) 83 | self._hash = hash((self._hash_name, tuple(hashes))) 84 | 85 | return self._hash 86 | 87 | def is_equivalent(self, other: Node) -> bool: # noqa: D102 88 | if not isinstance(other, self.__class__): 89 | return False 90 | 91 | if hash(self) != hash(other): 92 | return False 93 | 94 | return self._operands == other._operands 95 | 96 | 97 | class CommutativeMultiplicityReducibleNode(FlexibleNode): 98 | """A node with an operation that is reducible without taking order into account: i.e. 
it has a binary operation that is associative and commutative.""" 99 | 100 | def __init__( 101 | self, 102 | operands: CounterType[UnoverloadedWrapper], 103 | gf: Type[FieldArray], 104 | constant: Optional[FieldArray] = None, 105 | ): 106 | """Initialize a reducible node with the given `Counter` representing the operands, none of which is allowed to be a constant.""" 107 | super().__init__(gf) 108 | self._constant = self._identity if constant is None else constant 109 | self._operands = operands 110 | assert not any(isinstance(operand, Constant) for operand in self._operands) 111 | assert (sum(operands.values()) + (self._constant != self._identity)) > 1 112 | assert isinstance(next(iter(self._operands)), UnoverloadedWrapper) 113 | 114 | @property 115 | @abstractmethod 116 | def _identity(self) -> FieldArray: 117 | pass 118 | 119 | def apply_function_to_operands(self, function: Callable[[Node], None]): # noqa: D102 120 | for operand in self._operands: 121 | function(operand.node) 122 | 123 | def replace_operands_using_function(self, function: Callable[[Node], Node]): # noqa: D102 124 | # FIXME: What if there is only one operand remaining? 
125 | self._operands = Counter( 126 | { 127 | UnoverloadedWrapper(function(operand.node)): count 128 | for operand, count in self._operands.items() 129 | } 130 | ) 131 | assert not any(isinstance(operand.node, Constant) for operand in self._operands) 132 | assert (sum(self._operands.values()) + (self._constant != self._identity)) > 1 133 | 134 | def __hash__(self) -> int: 135 | if self._hash is None: 136 | # The hash is commutative 137 | hashes = sorted( 138 | [(hash(operand.node), count) for operand, count in self._operands.items()] 139 | ) 140 | self._hash = hash((self._hash_name, tuple(hashes), int(self._constant))) 141 | 142 | return self._hash 143 | 144 | def is_equivalent(self, other: Node) -> bool: # noqa: D102 145 | if not isinstance(other, self.__class__): 146 | return False 147 | 148 | if hash(self) != hash(other): 149 | return False 150 | 151 | return self._operands == other._operands and self._constant == other._constant 152 | 153 | def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 154 | if self._to_graph_cache is None: 155 | super().to_graph(graph_builder) 156 | self._to_graph_cache: int 157 | 158 | if self._constant != self._identity: 159 | # TODO: Add known_by 160 | graph_builder.add_link( 161 | graph_builder.add_node(label=str(self._constant)), self._to_graph_cache 162 | ) 163 | 164 | return self._to_graph_cache 165 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/leafs.py: -------------------------------------------------------------------------------- 1 | """Module containing leaf nodes: i.e. 
nodes without an input.""" 2 | from typing import Any, Dict, List, Set, Tuple, Type 3 | 4 | from galois import FieldArray 5 | 6 | from oraqle.compiler.graphviz import DotFile 7 | from oraqle.compiler.instructions import ArithmeticInstruction, InputInstruction 8 | from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node, select_stack_index 9 | from oraqle.compiler.nodes.fixed import FixedNode 10 | 11 | 12 | class ArithmeticLeafNode(FixedNode, ArithmeticNode): 13 | """An ArithmeticLeafNode is an ArithmeticNode with no inputs.""" 14 | 15 | def operands(self) -> List[Node]: # noqa: D102 16 | return [] 17 | 18 | def set_operands(self, operands: List["Node"]): # noqa: D102 19 | pass 20 | 21 | def _arithmetize_inner(self, strategy: str) -> Node: 22 | return self 23 | 24 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 25 | return CostParetoFront.from_leaf(self, cost_of_squaring) 26 | 27 | def multiplicative_depth(self) -> int: # noqa: D102 28 | return 0 29 | 30 | def multiplicative_size(self) -> int: # noqa: D102 31 | return 0 32 | 33 | def multiplications(self) -> Set[int]: # noqa: D102 34 | return set() 35 | 36 | def squarings(self) -> Set[int]: # noqa: D102 37 | return set() 38 | 39 | 40 | # TODO: Merge ArithmeticInput and Input using multiple inheritance 41 | class Input(ArithmeticLeafNode): 42 | """Represents a named input to the arithmetic circuit.""" 43 | 44 | @property 45 | def _overriden_graphviz_attributes(self) -> dict: 46 | return {"shape": "circle", "style": "filled", "fillcolor": "lightsteelblue1"} 47 | 48 | @property 49 | def _hash_name(self) -> str: 50 | return "input" 51 | 52 | @property 53 | def _node_label(self) -> str: 54 | return self._name 55 | 56 | def __init__(self, name: str, gf: Type[FieldArray]) -> None: 57 | """Initialize an input with the given `name`.""" 58 | super().__init__(gf) 59 | self._name = name 60 | 61 | 62 | def operation(self, operands: List[FieldArray]) -> FieldArray: # 
noqa: D102 63 | raise Exception() 64 | 65 | 66 | def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray: # noqa: D102 67 | return actual_inputs[self._name] 68 | 69 | 70 | def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 71 | if self._to_graph_cache is None: 72 | label = self._name 73 | 74 | self._to_graph_cache = graph_builder.add_node( 75 | label=label, **self._overriden_graphviz_attributes 76 | ) 77 | 78 | return self._to_graph_cache 79 | 80 | def __hash__(self) -> int: 81 | return hash(self._name) 82 | 83 | 84 | def is_equivalent(self, other: Node) -> bool: # noqa: D102 85 | if not isinstance(other, self.__class__): 86 | return False 87 | 88 | return self._name == other._name 89 | 90 | 91 | def create_instructions( # noqa: D102 92 | self, 93 | instructions: List[ArithmeticInstruction], 94 | stack_counter: int, 95 | stack_occupied: List[bool], 96 | ) -> Tuple[int, int]: 97 | if self._instruction_cache is None: 98 | self._instruction_cache = select_stack_index(stack_occupied) 99 | instructions.append(InputInstruction(self._instruction_cache, self._name)) 100 | 101 | return self._instruction_cache, stack_counter 102 | 103 | 104 | class Constant(ArithmeticLeafNode): 105 | """Represents a Node with a constant value.""" 106 | 107 | @property 108 | def _overriden_graphviz_attributes(self) -> dict: 109 | return {"style": "filled", "fillcolor": "grey80", "shape": "circle"} 110 | 111 | @property 112 | def _hash_name(self) -> str: 113 | return "constant" 114 | 115 | @property 116 | def _node_label(self) -> str: 117 | return str(self._value) 118 | 119 | def __init__(self, value: FieldArray): 120 | """Initialize a Node with the given `value`.""" 121 | super().__init__(value.__class__) 122 | self._value = value 123 | 124 | 125 | def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 126 | return self._value 127 | 128 | 129 | def to_graph(self, graph_builder: DotFile) -> Any: # noqa: D102 130 | if self._to_graph_cache is 
None: 131 | label = str(self._value) 132 | 133 | self._to_graph_cache = graph_builder.add_node( 134 | label=label, **self._overriden_graphviz_attributes 135 | ) 136 | 137 | return self._to_graph_cache 138 | 139 | def __hash__(self) -> int: 140 | return hash(int(self._value)) 141 | 142 | 143 | def is_equivalent(self, other: Node) -> bool: # noqa: D102 144 | if not isinstance(other, self.__class__): 145 | return False 146 | 147 | return self._value == other._value 148 | 149 | 150 | def add(self, other: "Node", flatten=True) -> "Node": # noqa: D102 151 | if isinstance(other, Constant): 152 | return Constant(self._value + other._value) 153 | 154 | return other.add(self, flatten) 155 | 156 | 157 | def mul(self, other: "Node", flatten=True) -> "Node": # noqa: D102 158 | if isinstance(other, Constant): 159 | return Constant(self._value * other._value) 160 | 161 | return other.mul(self, flatten) 162 | 163 | 164 | def bool_or(self, other: "Node", flatten=True) -> Node: # noqa: D102 165 | if isinstance(other, Constant): 166 | return Constant(self._gf(bool(self._value) | bool(other._value))) 167 | 168 | return other.bool_or(self, flatten) 169 | 170 | def bool_and(self, other: "Node", flatten=True) -> Node: # noqa: D102 171 | if isinstance(other, Constant): 172 | return Constant(self._gf(bool(self._value) & bool(other._value))) 173 | 174 | return other.bool_and(self, flatten) 175 | 176 | def create_instructions( # noqa: D102 177 | self, 178 | instructions: List[ArithmeticInstruction], 179 | stack_counter: int, 180 | stack_occupied: List[bool], 181 | ) -> Tuple[int]: 182 | raise NotImplementedError("The circuit is a constant.") 183 | 184 | 185 | class DummyNode(FixedNode): 186 | """A DummyNode is a fixed node with no inputs and no behavior.""" 187 | 188 | def operands(self) -> List[Node]: # noqa: D102 189 | return [] 190 | 191 | def set_operands(self, operands: List["Node"]): # noqa: D102 192 | pass 193 | 
-------------------------------------------------------------------------------- /oraqle/compiler/nodes/non_commutative.py: -------------------------------------------------------------------------------- 1 | """A collection of abstract nodes representing operations that are non-commutative.""" 2 | from abc import abstractmethod 3 | from typing import List, Type 4 | 5 | from galois import FieldArray 6 | 7 | from oraqle.compiler.graphviz import DotFile 8 | from oraqle.compiler.nodes.abstract import Node 9 | from oraqle.compiler.nodes.fixed import BinaryNode 10 | 11 | 12 | class NonCommutativeBinaryNode(BinaryNode): 13 | """Represents a non-commutative binary operation such as `x < y` or `x - y`.""" 14 | 15 | def __init__(self, left, right, gf: Type[FieldArray]): 16 | """Initialize a Node that performs an operation between two operands that is not commutative.""" 17 | self._left = left 18 | self._right = right 19 | super().__init__(gf) 20 | 21 | @abstractmethod 22 | def _operation_inner(self, x, y) -> FieldArray: 23 | """Applies the binary operation on x and y.""" 24 | 25 | def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 26 | return self._operation_inner(operands[0], operands[1]) 27 | 28 | def operands(self) -> List[Node]: # noqa: D102 29 | return [self._left, self._right] 30 | 31 | def set_operands(self, operands: List["Node"]): # noqa: D102 32 | self._left = operands[0] 33 | self._right = operands[1] 34 | 35 | def __hash__(self) -> int: 36 | if self._hash is None: 37 | left_hash = hash(self._left) 38 | right_hash = hash(self._right) 39 | 40 | self._hash = hash((self._hash_name, (left_hash, right_hash))) 41 | 42 | return self._hash 43 | 44 | def is_equivalent(self, other: Node) -> bool: # noqa: D102 45 | if not isinstance(other, self.__class__): 46 | return False 47 | 48 | if hash(self) != hash(other): 49 | return False 50 | 51 | return self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right) 52 | 53 | def
to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 54 | if self._to_graph_cache is None: 55 | attributes = {"shape": "box"} 56 | attributes.update(self._overriden_graphviz_attributes) 57 | 58 | self._to_graph_cache = graph_builder.add_node( 59 | label=self._node_label, 60 | **attributes, 61 | ) 62 | 63 | left = self._left.to_graph(graph_builder) 64 | right = self._right.to_graph(graph_builder) 65 | 66 | graph_builder.add_link(left, self._to_graph_cache, headport="nw") 67 | graph_builder.add_link(right, self._to_graph_cache, headport="ne") 68 | 69 | return self._to_graph_cache 70 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/unary_arithmetic.py: -------------------------------------------------------------------------------- 1 | """This module contains `ArithmeticNode`s with a single input: Constant additions and constant multiplications.""" 2 | from typing import List, Optional, Set, Tuple 3 | 4 | from galois import FieldArray 5 | 6 | from oraqle.compiler.graphviz import DotFile 7 | from oraqle.compiler.instructions import ( 8 | ArithmeticInstruction, 9 | ConstantAdditionInstruction, 10 | ConstantMultiplicationInstruction, 11 | ) 12 | from oraqle.compiler.nodes.abstract import ArithmeticNode, CostParetoFront, Node, select_stack_index 13 | from oraqle.compiler.nodes.univariate import UnivariateNode 14 | 15 | # TODO: There is (going to be) a lot of code duplication between these two classes 16 | 17 | 18 | class ConstantAddition(UnivariateNode, ArithmeticNode): 19 | """This node represents an addition of a constant to another node.""" 20 | 21 | @property 22 | def _overriden_graphviz_attributes(self) -> dict: 23 | return {"style": "rounded,filled", "fillcolor": "grey80"} 24 | 25 | @property 26 | def _node_shape(self) -> str: 27 | return "square" 28 | 29 | @property 30 | def _hash_name(self) -> str: 31 | return f"constant_add_{self._constant}" 32 | 33 | @property 34 | def _node_label(self) ->
str: 35 | return "+" 36 | 37 | def __init__(self, node: ArithmeticNode, constant: FieldArray): 38 | """Represents the operation `constant + node`.""" 39 | super().__init__(node, constant.__class__) 40 | self._constant = constant 41 | assert constant != 0 42 | 43 | self._depth_cache: Optional[int] = None 44 | 45 | 46 | def _operation_inner(self, input: FieldArray) -> FieldArray: 47 | return input + self._constant 48 | 49 | 50 | def multiplicative_depth(self) -> int: # noqa: D102 51 | if self._depth_cache is None: 52 | self._depth_cache = self._node.multiplicative_depth() 53 | 54 | return self._depth_cache 55 | 56 | 57 | def multiplications(self) -> Set[int]: # noqa: D102 58 | return self._node.multiplications() 59 | 60 | 61 | def squarings(self) -> Set[int]: # noqa: D102 62 | return self._node.squarings() 63 | 64 | 65 | def create_instructions( # noqa: D102 66 | self, 67 | instructions: List[ArithmeticInstruction], 68 | stack_counter: int, 69 | stack_occupied: List[bool], 70 | ) -> Tuple[int, int]: 71 | self._node: ArithmeticNode 72 | 73 | if self._instruction_cache is None: 74 | operand_index, stack_counter = self._node.create_instructions( 75 | instructions, stack_counter, stack_occupied 76 | ) 77 | 78 | self._node._parent_count -= 1 79 | if self._node._parent_count == 0: 80 | stack_occupied[self._node._instruction_cache] = False # type: ignore 81 | 82 | self._instruction_cache = select_stack_index(stack_occupied) 83 | 84 | instructions.append( 85 | ConstantAdditionInstruction(self._instruction_cache, operand_index, self._constant) 86 | ) 87 | 88 | return self._instruction_cache, stack_counter 89 | 90 | 91 | def _arithmetize_inner(self, strategy: str) -> Node: 92 | return self 93 | 94 | 95 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 96 | front = CostParetoFront(cost_of_squaring) 97 | for _, _, node in self._node.arithmetize_depth_aware(cost_of_squaring): 98 | front.add(ConstantAddition(node, self._constant)) 99 | return 
front 100 | 101 | 102 | def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 103 | if self._to_graph_cache is None: 104 | super().to_graph(graph_builder) 105 | self._to_graph_cache: int 106 | 107 | # TODO: Add known_by 108 | graph_builder.add_link( 109 | graph_builder.add_node( 110 | label=str(self._constant), shape="circle", style="filled", fillcolor="grey92" 111 | ), 112 | self._to_graph_cache, 113 | ) 114 | 115 | return self._to_graph_cache 116 | 117 | 118 | class ConstantMultiplication(UnivariateNode, ArithmeticNode): 119 | """This node represents a multiplication of another node with a constant.""" 120 | 121 | @property 122 | def _overriden_graphviz_attributes(self) -> dict: 123 | return {"style": "rounded,filled", "fillcolor": "grey80"} 124 | 125 | @property 126 | def _node_shape(self) -> str: 127 | return "square" 128 | 129 | @property 130 | def _hash_name(self) -> str: 131 | return f"constant_mul_{self._constant}" 132 | 133 | @property 134 | def _node_label(self) -> str: 135 | return "×" # noqa: RUF001 136 | 137 | def __init__(self, node: Node, constant: FieldArray): 138 | """Represents the operation `constant * node`.""" 139 | super().__init__(node, constant.__class__) 140 | self._constant = constant 141 | assert constant != 0 142 | assert constant != 1 143 | 144 | self._depth_cache: Optional[int] = None 145 | 146 | def _operation_inner(self, input: FieldArray) -> FieldArray: 147 | return input * self._constant # type: ignore 148 | 149 | 150 | def multiplicative_depth(self) -> int: # noqa: D102 151 | if self._depth_cache is None: 152 | self._depth_cache = self._node.multiplicative_depth() # type: ignore 153 | 154 | return self._depth_cache # type: ignore 155 | 156 | 157 | def multiplications(self) -> Set[int]: # noqa: D102 158 | return self._node.multiplications() # type: ignore 159 | 160 | 161 | def squarings(self) -> Set[int]: # noqa: D102 162 | return self._node.squarings() # type: ignore 163 | 164 | 165 | def create_instructions( # noqa: 
D102 166 | self, 167 | instructions: List[ArithmeticInstruction], 168 | stack_counter: int, 169 | stack_occupied: List[bool], 170 | ) -> Tuple[int, int]: 171 | self._node: ArithmeticNode 172 | 173 | if self._instruction_cache is None: 174 | operand_index, stack_counter = self._node.create_instructions( 175 | instructions, stack_counter, stack_occupied 176 | ) 177 | 178 | self._node._parent_count -= 1 179 | if self._node._parent_count == 0: 180 | stack_occupied[self._node._instruction_cache] = False # type: ignore 181 | 182 | self._instruction_cache = select_stack_index(stack_occupied) 183 | 184 | instructions.append( 185 | ConstantMultiplicationInstruction( 186 | self._instruction_cache, operand_index, self._constant 187 | ) 188 | ) 189 | 190 | return self._instruction_cache, stack_counter 191 | 192 | 193 | def _arithmetize_inner(self, strategy: str) -> Node: 194 | return self 195 | 196 | 197 | def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: 198 | front = CostParetoFront(cost_of_squaring) 199 | for _, _, node in self._node.arithmetize_depth_aware(cost_of_squaring): 200 | front.add(ConstantMultiplication(node, self._constant)) 201 | return front 202 | 203 | 204 | def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 205 | if self._to_graph_cache is None: 206 | super().to_graph(graph_builder) 207 | self._to_graph_cache: int 208 | 209 | # TODO: Add known_by 210 | graph_builder.add_link( 211 | graph_builder.add_node( 212 | label=str(self._constant), shape="circle", style="filled", fillcolor="grey92" 213 | ), 214 | self._to_graph_cache, 215 | ) 216 | 217 | return self._to_graph_cache 218 | -------------------------------------------------------------------------------- /oraqle/compiler/nodes/univariate.py: -------------------------------------------------------------------------------- 1 | """Abstract nodes for univariate operations.""" 2 | 3 | from abc import abstractmethod 4 | from typing import List, Type 5 | 6 | from 
galois import FieldArray 7 | 8 | from oraqle.compiler.graphviz import DotFile 9 | from oraqle.compiler.nodes.abstract import Node 10 | from oraqle.compiler.nodes.fixed import FixedNode 11 | from oraqle.compiler.nodes.leafs import Constant 12 | 13 | 14 | class UnivariateNode(FixedNode): 15 | """An abstract node with a single input.""" 16 | 17 | @property 18 | @abstractmethod 19 | def _node_shape(self) -> str: 20 | """Graphviz node shape.""" 21 | 22 | def __init__(self, node: Node, gf: Type[FieldArray]): 23 | """Initialize a univariate node.""" 24 | self._node = node 25 | assert not isinstance(node, Constant) 26 | super().__init__(gf) 27 | 28 | 29 | def operands(self) -> List["Node"]: # noqa: D102 30 | return [self._node] 31 | 32 | 33 | def set_operands(self, operands: List["Node"]): # noqa: D102 34 | self._node = operands[0] 35 | 36 | @abstractmethod 37 | def _operation_inner(self, input: FieldArray) -> FieldArray: 38 | """Evaluate the operation on the input. This method does not have to cache.""" 39 | 40 | 41 | def operation(self, operands: List[FieldArray]) -> FieldArray: # noqa: D102 42 | return self._operation_inner(operands[0]) 43 | 44 | 45 | def to_graph(self, graph_builder: DotFile) -> int: # noqa: D102 46 | if self._to_graph_cache is None: 47 | attributes = {} 48 | 49 | attributes.update(self._overriden_graphviz_attributes) 50 | 51 | self._to_graph_cache = graph_builder.add_node( 52 | label=self._node_label, shape=self._node_shape, **attributes 53 | ) 54 | 55 | graph_builder.add_link(self._node.to_graph(graph_builder), self._to_graph_cache) 56 | 57 | return self._to_graph_cache 58 | 59 | def __hash__(self) -> int: 60 | if self._hash is None: 61 | self._hash = hash((self._hash_name, self._node)) 62 | 63 | return self._hash 64 | 65 | def is_equivalent(self, other: Node) -> bool: 66 | """Check whether `self` is semantically equivalent to `other`. 67 | 68 | This function may have false negatives but it should never return false positives. 
def _scalar_to_constant(number, gf, modulus: int) -> Constant:
    """Turn a sympy integer into a constant node over `gf`."""
    if isinstance(number, NegativeOne):
        return Constant(-gf(1))
    return Constant(gf(int(number) % modulus))


def _combine_arguments(expression, gf, modulus: int, inputs: Dict[str, Input], combine) -> Node:
    """Fold the arguments of a sympy `Add` or `Mul` into a single node using `combine`."""
    first, second, *rest = expression.args

    # The first argument can be a scalar. Sympy represents 0, 1 and -1 with the
    # singleton classes `Zero`, `One` and `NegativeOne`, so `first.func == Integer`
    # misses them; all of them are subclasses of `Integer`.
    if isinstance(first, Integer):
        left = _scalar_to_constant(first, gf, modulus)
    else:
        left = construct_subcircuit(first, gf, modulus, inputs)

    # TODO: Replace this entire part with a sum/product node
    result = combine(left, construct_subcircuit(second, gf, modulus, inputs))
    for arg in rest:
        # Later arguments go on the left-hand side, exactly like the original loop.
        result = combine(construct_subcircuit(arg, gf, modulus, inputs), result)

    return result


def construct_subcircuit(expression, gf, modulus: int, inputs: Dict[str, Input]) -> Node:
    """Build a circuit with a single output given an expression of simple arithmetic operations in Sympy.

    The expression is traversed recursively. Sums and products are folded argument by argument,
    powers become a `Product` with a multiplicity, and symbols become `Input` nodes.
    `inputs` is updated in place: a symbol that was seen before reuses its existing `Input`.

    Raises:
    ------
    Exception: Exponents must be positive integers, and only additions, multiplications, powers and
        symbols are supported, or an exception will be raised.

    Returns:
    -------
    A subcircuit (Node) computing the given sympy expression.

    """
    if expression.func == Add:
        return _combine_arguments(expression, gf, modulus, inputs, lambda lhs, rhs: lhs + rhs)

    if expression.func == Mul:
        return _combine_arguments(expression, gf, modulus, inputs, lambda lhs, rhs: lhs * rhs)

    if expression.func == Pow:
        base, exponent = expression.args
        if not isinstance(exponent, Integer):
            raise Exception("There was an exponent with a non-integer exponent")
        if int(exponent) <= 0:
            # A negative multiplicity has no meaning in a product of nodes.
            raise Exception("There was an exponent that is not a positive integer")
        # Change powers to series of multiplications
        subcircuit = construct_subcircuit(base, gf, modulus, inputs)
        # TODO: This is not the most efficient way; we can use re-balancing.
        return Product(
            Counter({UnoverloadedWrapper(subcircuit): int(exponent)}), gf
        )  # FIXME: This could be flattened

    if expression.func == Symbol:
        assert len(expression.args) == 0
        var = str(expression)
        if var not in inputs:
            inputs[var] = Input(var, gf)
        return inputs[var]

    raise Exception(
        f"The expression contained an invalid operation (not one implemented in arithmetic circuits): {expression.func}."
    )


def construct_circuit(polynomials: List[Poly], modulus: int) -> Tuple[Circuit, Type[FieldArray]]:
    """Construct an arithmetic circuit from a list of polynomials and the fixed modulus.

    All polynomials share one table of inputs, so a variable that occurs in several polynomials
    is represented by a single `Input` node.

    Returns:
    -------
    A circuit outputting the evaluation of each polynomial, and the field that was created for `modulus`.

    """
    inputs: Dict[str, Input] = {}
    gf = GF(modulus)
    outputs = [construct_subcircuit(poly.expr, gf, modulus, inputs) for poly in polynomials]
    return Circuit(outputs), gf
circuit.eliminate_subexpressions() 137 | 138 | # Output a DOT file for this arithmetic circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/) 139 | circuit.to_graph("max_7.dot") 140 | 141 | # Print the resulting metrics of the circuit 142 | print("depth", circuit.multiplicative_depth()) 143 | print("size", circuit.multiplicative_size()) 144 | 145 | # Test that given x=4 and y=2 indeed max(x, y) = 4 146 | assert circuit.evaluate({"x": gf(4), "y": gf(2)}) == [4] 147 | 148 | # Output a DOT file for this arithmetic circuit (you can visualize it using https://dreampuf.github.io/GraphvizOnline/) 149 | circuit.to_graph("max_7.dot") 150 | -------------------------------------------------------------------------------- /oraqle/compiler/polynomials/__init__.py: -------------------------------------------------------------------------------- 1 | """The polynomials package contains nodes for performing polynomial evaluation. 2 | 3 | In a finite field, the set of polyfunctions is the same as the set of all functions. 4 | So, you can perform any function by interpolating a polynomial. 5 | """ 6 | -------------------------------------------------------------------------------- /oraqle/config.py: -------------------------------------------------------------------------------- 1 | """This module contains global configuration options. 2 | 3 | !!! warning 4 | This is almost certainly going to be removed in the future. 5 | We do not want oraqle to have a global configuration, but this is currently an intentional evil to prevent large refactors in the initial versions. 6 | """ 7 | from typing import Annotated, Optional 8 | 9 | 10 | Seconds = Annotated[float, "seconds"] 11 | MAXSAT_TIMEOUT: Optional[Seconds] = None 12 | """Time-out for individual calls to the MaxSAT solver. 13 | 14 | !!! danger 15 | This causes non-deterministic behavior! 16 | 17 | !!! 
bug 18 | There is currently a chance to get `AttributeError`s, which is a bug caused by pysat trying to delete an oracle that does not exist. 19 | There is no current workaround for this.""" 20 | 21 | 22 | PS_METHOD_FACTOR_K: float = 2.0 23 | """Approximation factor for the PS-method, higher is better. 24 | 25 | The Paterson-Stockmeyer method takes a value k, that is theoretically optimal when k = sqrt(2 * degree). 26 | However, sometimes it is better to try other values of k (e.g. due to rounding and to trade off depth and cost). 27 | This factor, let's call it f, is used to limit the candidate values of k that we try: [1, f * sqrt(2 * degree)).""" 28 | -------------------------------------------------------------------------------- /oraqle/demo/playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b5e62be9-cada-42d2-a2bb-7f3ee38aec51", 6 | "metadata": {}, 7 | "source": [ 8 | "# Playground" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "df1eaaad-6ad2-4601-8d12-3b02d9254bfa", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from galois import GF\n", 19 | "\n", 20 | "from circuit_compiler.compiler.boolean.bool_and import And\n", 21 | "from circuit_compiler.compiler.circuit import Circuit\n", 22 | "from circuit_compiler.compiler.nodes.leafs import Input\n", 23 | "\n", 24 | "gf = GF(5)\n", 25 | "\n", 26 | "xs = [Input(f\"x{i}\", gf) for i in range(11)]\n", 27 | "\n", 28 | "output = And(set(xs), gf)\n", 29 | "\n", 30 | "circuit = Circuit(outputs=[output], gf=gf)\n", 31 | "circuit.display_graph()" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "ce27d985-b20c-4f7e-a303-929598c61c17", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "naive_arithmetic_circuit = circuit.arithmetize(\"naive\")\n", 42 | "naive_arithmetic_circuit.display_graph()" 43 | ] 44 
| }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "d5c0531f-f50b-4852-b81c-833a813eb235", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "circuit._clear_cache()\n", 53 | "better_arithmetic_circuit = circuit.arithmetize(\"best-effort\")\n", 54 | "better_arithmetic_circuit.display_graph()" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "Python 3 (ipykernel)", 61 | "language": "python", 62 | "name": "python3" 63 | }, 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.9.6" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 5 79 | } 80 | -------------------------------------------------------------------------------- /oraqle/demo/small_comparison_bgv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0f2abd68-5065-49c2-aefa-65ca3c8be8f8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Compiling homomorphic encryption circuits made easy" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1f425b04-35ab-4a1c-8ed4-20ecdc7d2901", 14 | "metadata": {}, 15 | "source": [ 16 | "#### The only boilerplate consists of defining the plaintext space and the inputs of the program." 
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "18d03a72-d22a-4f54-9a68-ab31507d1e34", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "from galois import GF\n", 27 | "\n", 28 | "from circuit_compiler.compiler.nodes.leafs import Input\n", 29 | "\n", 30 | "gf = GF(11)\n", 31 | "\n", 32 | "a = Input(\"a\", gf)\n", 33 | "b = Input(\"b\", gf)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "7a7890f4-c770-4699-acba-ec2e6796a5bb", 39 | "metadata": {}, 40 | "source": [ 41 | "#### Programmers can use the primitives that they are used to." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "0dd02769-50cc-4eb1-a9e0-896b944d9b28", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "output = a < b" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "8a26b9ca-2441-48e1-8aad-4b626755485e", 57 | "metadata": {}, 58 | "source": [ 59 | "#### A circuit can have an arbitrary number of outputs; here we only have one." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "d00fa605-4510-4393-bdb0-4dd54a21f5f8", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from circuit_compiler.compiler.circuit import Circuit\n", 70 | "\n", 71 | "circuit = Circuit(outputs=[output], gf=gf)\n", 72 | "circuit.display_graph()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "fc7c6e33-a7ad-4e2f-a742-40653160a0ca", 78 | "metadata": {}, 79 | "source": [ 80 | "#### Turning high-level circuits into arithmetic circuits is a fully automatic process that improves on the state of the art in multiple ways." 
81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "a441c9f5-de63-4253-bbb6-4b63511acc67", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "arithmetic_circuit = circuit.arithmetize()\n", 91 | "arithmetic_circuit.display_graph()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "33a64549-4081-4fb8-9631-1f007b368dfa", 97 | "metadata": {}, 98 | "source": [ 99 | "#### The compiler implements a form of semantic subexpression elimination that significantly optimizes large circuits." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "1cd8bfaf-8113-444a-812c-b2a4fe124cec", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "arithmetic_circuit.eliminate_subexpressions()\n", 110 | "arithmetic_circuit.display_graph()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "a89d7c56-ef33-4ac6-b06a-0f88d45aff91", 116 | "metadata": {}, 117 | "source": [ 118 | "#### This much smaller circuit is still correct!" 
119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "5d50141a-b84a-4ac4-93e7-a7a1cf484688", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "import tabulate\n", 129 | "\n", 130 | "for val_a in range(11):\n", 131 | " for val_b in range(11):\n", 132 | " assert arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)}) == gf(val_a < val_b)\n", 133 | "\n", 134 | "data = [[arithmetic_circuit.evaluate({\"a\": gf(val_a), \"b\": gf(val_b)})[0] for val_a in range(11)] for val_b in range(11)]\n", 135 | "\n", 136 | "table = tabulate.tabulate(data, tablefmt='html')\n", 137 | "table" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "Python 3 (ipykernel)", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.9.6" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 5 162 | } 163 | -------------------------------------------------------------------------------- /oraqle/examples/depth_aware_comparison.py: -------------------------------------------------------------------------------- 1 | """Depth-aware arithmetization of a comparison modulo 101.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.nodes.leafs import Input 7 | 8 | gf = GF(101) 9 | cost_of_squaring = 1.0 10 | 11 | a = Input("a", gf) 12 | b = Input("b", gf) 13 | 14 | output = a < b 15 | 16 | circuit = Circuit(outputs=[output]) 17 | circuit.to_graph("high_level_circuit.dot") 18 | 19 | arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring) 20 | 21 | for depth, cost, arithmetic_circuit in arithmetic_circuits: 22 | assert 
arithmetic_circuit.multiplicative_depth() == depth 23 | assert arithmetic_circuit.multiplicative_cost(cost_of_squaring) == cost 24 | 25 | print("pre CSE", depth, cost) 26 | 27 | arithmetic_circuit.eliminate_subexpressions() 28 | 29 | print( 30 | "post CSE", 31 | arithmetic_circuit.multiplicative_depth(), 32 | arithmetic_circuit.multiplicative_cost(cost_of_squaring), 33 | ) 34 | -------------------------------------------------------------------------------- /oraqle/examples/depth_aware_equality.py: -------------------------------------------------------------------------------- 1 | """Depth-aware arithmetization for an equality operation modulo 31.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.comparison.equality import Equals 7 | from oraqle.compiler.nodes.leafs import Input 8 | 9 | gf = GF(31) 10 | 11 | a = Input("a", gf) 12 | b = Input("b", gf) 13 | 14 | output = Equals(a, b, gf) 15 | 16 | circuit = Circuit(outputs=[output]) 17 | 18 | arithmetic_circuits = circuit.arithmetize_depth_aware(cost_of_squaring=1.0) 19 | 20 | if __name__ == "__main__": 21 | circuit.to_pdf("high_level_circuit.pdf") 22 | for depth, size, arithmetic_circuit in arithmetic_circuits: 23 | arithmetic_circuit.to_pdf(f"arithmetic_circuit_d{depth}_s{size}.pdf") 24 | -------------------------------------------------------------------------------- /oraqle/examples/long_and.py: -------------------------------------------------------------------------------- 1 | """Arithmetization of an AND operation between 15 inputs.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.boolean.bool_and import And 6 | from oraqle.compiler.circuit import Circuit 7 | from oraqle.compiler.nodes.abstract import UnoverloadedWrapper 8 | from oraqle.compiler.nodes.leafs import Input 9 | 10 | gf = GF(5) 11 | 12 | xs = [Input(f"x{i}", gf) for i in range(15)] 13 | 14 | output = And(set(UnoverloadedWrapper(x) for x in xs), gf) 15 | 16 | circuit = 
Circuit(outputs=[output]) 17 | circuit.to_graph("high_level_circuit.dot") 18 | 19 | arithmetic_circuit = circuit.arithmetize() 20 | arithmetic_circuit.to_graph("arithmetic_circuit.dot") 21 | -------------------------------------------------------------------------------- /oraqle/examples/small_comparison.py: -------------------------------------------------------------------------------- 1 | """Arithmetizes a comparison modulo 11 with a constant.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.nodes.leafs import Constant, Input 7 | 8 | gf = GF(11) 9 | 10 | a = Input("a", gf) 11 | b = Constant(gf(3)) # Input("b") 12 | 13 | output = a < b 14 | 15 | circuit = Circuit(outputs=[output]) 16 | circuit.to_graph("high_level_circuit.dot") 17 | 18 | arithmetic_circuit = circuit.arithmetize() 19 | arithmetic_circuit.to_graph("arithmetic_circuit.dot") 20 | -------------------------------------------------------------------------------- /oraqle/examples/small_polynomial.py: -------------------------------------------------------------------------------- 1 | """Creates graphs for the arithmetization of a small polynomial evaluation.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.nodes.leafs import Input 7 | from oraqle.compiler.polynomials.univariate import UnivariatePoly 8 | 9 | gf = GF(11) 10 | 11 | x = Input("x", gf) 12 | 13 | output = UnivariatePoly(x, [gf(1), gf(2), gf(3), gf(4), gf(5), gf(6), gf(1)], gf) 14 | 15 | circuit = Circuit(outputs=[output]) 16 | circuit.to_graph("high_level_circuit.dot") 17 | 18 | arithmetic_circuit = circuit.arithmetize() 19 | arithmetic_circuit.to_graph("arithmetic_circuit.dot") 20 | -------------------------------------------------------------------------------- /oraqle/examples/visualize_circuits.py: -------------------------------------------------------------------------------- 1 | """Visualization of three circuits 
computing an OR operation on 7 inputs.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.arithmetic.exponentiation import Power 6 | from oraqle.compiler.boolean.bool_neg import Neg 7 | from oraqle.compiler.circuit import ArithmeticCircuit, Circuit 8 | from oraqle.compiler.nodes.binary_arithmetic import Multiplication 9 | from oraqle.compiler.nodes.leafs import Input 10 | 11 | gf = GF(5) 12 | 13 | x1 = Input("x1", gf) 14 | x2 = Input("x2", gf) 15 | x3 = Input("x3", gf) 16 | x4 = Input("x4", gf) 17 | x5 = Input("x5", gf) 18 | x6 = Input("x6", gf) 19 | x7 = Input("x7", gf) 20 | 21 | sum1 = x1 + x2 + x3 + x4 22 | exp1 = Power(sum1, 4, gf) 23 | 24 | sum2 = x5 + x6 + x7 + exp1 25 | exp2 = Power(sum2, 4, gf) 26 | 27 | circuit = Circuit([exp2]) 28 | arithmetic_circuit = circuit.arithmetize() 29 | arithmetic_circuit.to_graph("arithmetic_circuit1.dot") 30 | 31 | 32 | inv1 = Neg(x1, gf) 33 | inv2 = Neg(x2, gf) 34 | inv3 = Neg(x3, gf) 35 | inv4 = Neg(x4, gf) 36 | inv5 = Neg(x5, gf) 37 | inv6 = Neg(x6, gf) 38 | 39 | mul1 = inv1 * inv2 40 | invmul1 = Neg(mul1, gf) 41 | 42 | mul2 = inv3 * inv4 43 | invmul2 = Neg(mul2, gf) 44 | 45 | mul3 = inv5 * inv6 46 | invmul3 = Neg(mul3, gf) 47 | 48 | add1 = mul1 + mul2 49 | add2 = mul3 + add1 50 | 51 | add3 = add2 + x7 52 | 53 | exp = Power(add3, 4, gf) 54 | 55 | circuit = Circuit([exp]) 56 | arithmetic_circuit = circuit.arithmetize() 57 | arithmetic_circuit.to_graph("arithmetic_circuit2.dot") 58 | 59 | 60 | inv1 = Neg(x1, gf).arithmetize("best-effort").to_arithmetic() 61 | inv2 = Neg(x2, gf).arithmetize("best-effort").to_arithmetic() 62 | inv3 = Neg(x3, gf).arithmetize("best-effort").to_arithmetic() 63 | inv4 = Neg(x4, gf).arithmetize("best-effort").to_arithmetic() 64 | inv5 = Neg(x5, gf).arithmetize("best-effort").to_arithmetic() 65 | inv6 = Neg(x6, gf).arithmetize("best-effort").to_arithmetic() 66 | inv7 = Neg(x7, gf).arithmetize("best-effort").to_arithmetic() 67 | 68 | mul1 = Multiplication(inv1, inv2, gf) 69 | mul2 = 
Multiplication(inv3, inv4, gf) 70 | mul3 = Multiplication(inv5, inv6, gf) 71 | 72 | mul4 = Multiplication(mul1, mul2, gf) 73 | mul5 = Multiplication(mul3, inv7, gf) 74 | 75 | mul6 = Multiplication(mul4, mul5, gf) 76 | 77 | inv = Neg(mul6, gf).arithmetize("best-effort").to_arithmetic() 78 | 79 | arithmetic_circuit = ArithmeticCircuit([inv]) 80 | arithmetic_circuit.to_graph("arithmetic_circuit3.dot") 81 | -------------------------------------------------------------------------------- /oraqle/examples/wahc2024_presentation/1_high-level.py: -------------------------------------------------------------------------------- 1 | """Renders a high-level comparison circuit.""" 2 | from galois import GF 3 | 4 | from oraqle.compiler.circuit import Circuit 5 | from oraqle.compiler.nodes.leafs import Input 6 | 7 | 8 | if __name__ == "__main__": 9 | gf = GF(101) 10 | 11 | alex = Input("a", gf) 12 | blake = Input("b", gf) 13 | 14 | output = alex < blake 15 | circuit = Circuit(outputs=[output]) 16 | 17 | circuit.to_svg("high_level.svg") 18 | -------------------------------------------------------------------------------- /oraqle/examples/wahc2024_presentation/2_arith_step1.py: -------------------------------------------------------------------------------- 1 | """Show the first step of arithmetization of a comparison circuit.""" 2 | from galois import GF 3 | 4 | from oraqle.compiler.boolean.bool_neg import Neg 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.comparison.comparison import SemiStrictComparison 7 | from oraqle.compiler.nodes.leafs import Constant, Input 8 | 9 | if __name__ == "__main__": 10 | gf = GF(101) 11 | 12 | alex = Input("a", gf) 13 | blake = Input("b", gf) 14 | 15 | output = alex < blake 16 | 17 | 18 | p = output._gf.characteristic 19 | 20 | if output._less_than: 21 | left = output._left 22 | right = output._right 23 | else: 24 | left = output._right 25 | right = output._left 26 | 27 | left = left.arithmetize("best-effort") 28 | right = 
right.arithmetize("best-effort") 29 | 30 | left_is_small = SemiStrictComparison( 31 | left, Constant(output._gf(p // 2)), less_than=True, gf=output._gf 32 | ) 33 | right_is_small = SemiStrictComparison( 34 | right, Constant(output._gf(p // 2)), less_than=True, gf=output._gf 35 | ) 36 | 37 | # Test whether left and right are in the same range 38 | same_range = (left_is_small & right_is_small) + ( 39 | Neg(left_is_small, output._gf) & Neg(right_is_small, output._gf) 40 | ) 41 | 42 | # Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison 43 | comparison = SemiStrictComparison(left, right, less_than=True, gf=output._gf) 44 | result = same_range * comparison 45 | 46 | # Performs left < right when one if small and the other is large 47 | right_is_larger = left_is_small & Neg(right_is_small, output._gf) 48 | result += right_is_larger 49 | 50 | 51 | circuit = Circuit(outputs=[result]) 52 | 53 | circuit.to_svg("arith_step1.svg") 54 | -------------------------------------------------------------------------------- /oraqle/examples/wahc2024_presentation/3_arith_step2.py: -------------------------------------------------------------------------------- 1 | """Show the last step of arithmetization of a comparison circuit.""" 2 | from galois import GF 3 | 4 | from oraqle.compiler.circuit import Circuit 5 | from oraqle.compiler.nodes.leafs import Input 6 | 7 | if __name__ == "__main__": 8 | gf = GF(101) 9 | 10 | alex = Input("a", gf) 11 | blake = Input("b", gf) 12 | 13 | output = alex < blake 14 | 15 | front = output.arithmetize_depth_aware(cost_of_squaring=1.0) 16 | print(front) 17 | 18 | _, tup = front._nodes_by_depth.popitem() 19 | _, node = tup 20 | circuit = Circuit(outputs=[node]) 21 | 22 | circuit.to_svg("arith_step2.svg") 23 | -------------------------------------------------------------------------------- /oraqle/examples/wahc2024_presentation/5_code_gen.py: 
-------------------------------------------------------------------------------- 1 | """Generates code for the comparison circuit.""" 2 | from galois import GF 3 | 4 | from oraqle.compiler.circuit import Circuit 5 | from oraqle.compiler.nodes.leafs import Input 6 | 7 | if __name__ == "__main__": 8 | gf = GF(101) 9 | 10 | alex = Input("a", gf) 11 | blake = Input("b", gf) 12 | 13 | output = alex < blake 14 | circuit = Circuit(outputs=[output]) 15 | 16 | front = circuit.arithmetize_depth_aware() 17 | 18 | for _, _, arithmetic_circuit in front: 19 | program = arithmetic_circuit.generate_code("example.cpp") 20 | -------------------------------------------------------------------------------- /oraqle/examples/wahc2024_presentation/rebalancing.py: -------------------------------------------------------------------------------- 1 | """Renders two circuits, one with a balanced product tree and one with an imbalanced tree.""" 2 | from galois import GF 3 | 4 | from oraqle.compiler.circuit import Circuit 5 | from oraqle.compiler.nodes.leafs import Input 6 | 7 | if __name__ == "__main__": 8 | gf = GF(101) 9 | 10 | a = Input("a", gf) 11 | b = Input("b", gf) 12 | c = Input("c", gf) 13 | d = Input("d", gf) 14 | 15 | output = a * b * c * d 16 | circuit_good = Circuit(outputs=[output]) 17 | circuit_good = circuit_good.arithmetize_depth_aware() # FIXME: This should also work with arithmetize 18 | circuit_good[0][2].to_svg("rebalancing_good.svg") 19 | 20 | ab = a.mul(b, flatten=False) 21 | abc = ab.mul(c, flatten=False) 22 | abcd = abc.mul(d, flatten=False) 23 | circuit_bad = Circuit(outputs=[abcd]) 24 | circuit_bad = circuit_bad.arithmetize() 25 | circuit_bad.to_svg("rebalancing_bad.svg") 26 | -------------------------------------------------------------------------------- /oraqle/experiments/depth_aware_arithmetization/execution/bench_cardio_circuits.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from typing 
def gen_params() -> Dict[str, int]:
    """Draw one random set of named integer parameters.

    Each value is sampled with `random.randint`, so both bounds are inclusive. The entries are
    drawn (and inserted) in the order in which they are listed below.

    Returns:
    -------
    A dictionary mapping each parameter name to its randomly drawn value.

    """
    bounds = (
        # Flags, drawn from {0, 1}
        ("man", 0, 1),
        ("smoking", 0, 1),
        ("diabetic", 0, 1),
        ("hbp", 0, 1),
        # Integer-valued measurements
        ("age", 0, 100),
        ("cholesterol", 0, 60),
        ("weight", 40, 150),
        ("height", 80, 210),
        ("activity", 0, 250),
        ("alcohol", 0, 5),
    )

    return {name: random.randint(low, high) for name, low, high in bounds}
-------------------------------------------------------------------------------- /oraqle/experiments/depth_aware_arithmetization/execution/bench_equality.py: -------------------------------------------------------------------------------- 1 | from galois import GF 2 | 3 | from oraqle.compiler.circuit import Circuit 4 | from oraqle.compiler.nodes.leafs import Input 5 | 6 | 7 | if __name__ == "__main__": 8 | iterations = 10 9 | 10 | for p in [29, 43, 61, 101, 131]: 11 | gf = GF(p) 12 | 13 | x = Input("x", gf) 14 | y = Input("y", gf) 15 | 16 | circuit = Circuit([x == y]) 17 | 18 | for d, c, arith in circuit.arithmetize_depth_aware(0.75): 19 | print(d, c, arith.run_using_helib(10, True, False, x=13, y=19)) 20 | 21 | arith = circuit.arithmetize('naive') 22 | print('square and multiply', arith.multiplicative_depth(), arith.multiplicative_size(), arith.multiplicative_cost(0.75), arith.run_using_helib(10, True, False, x=13, y=19)) 23 | -------------------------------------------------------------------------------- /oraqle/experiments/depth_aware_arithmetization/execution/cardio_circuits.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from galois import GF 4 | 5 | from oraqle.circuits.cardio import ( 6 | construct_cardio_elevated_risk_circuit, 7 | construct_cardio_risk_circuit, 8 | ) 9 | from oraqle.compiler.circuit import Circuit 10 | 11 | if __name__ == "__main__": 12 | gf = GF(257) 13 | 14 | for cost_of_squaring in [0.5, 0.75, 1.0]: 15 | print(f"--- Cardio risk assessment ({cost_of_squaring}) ---") 16 | circuit = Circuit([construct_cardio_risk_circuit(gf)]) 17 | 18 | start = time.monotonic() 19 | front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring) 20 | print("Run time:", time.monotonic() - start, "s") 21 | 22 | for depth, cost, arithmetic_circuit in front: 23 | print(depth, cost) 24 | arithmetic_circuit.to_graph(f"cardio_arith_d{depth}_c{cost}.dot") 25 | 26 | print(f"--- Cardio elevated 
risk assessment ({cost_of_squaring}) ---") 27 | circuit = Circuit([construct_cardio_elevated_risk_circuit(gf)]) 28 | 29 | start = time.monotonic() 30 | front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring) 31 | print("Run time:", time.monotonic() - start, "s") 32 | 33 | for depth, cost, arithmetic_circuit in front: 34 | print(depth, cost) 35 | arithmetic_circuit.to_graph(f"cardio_elevated_arith_d{depth}_c{cost}.dot") 36 | -------------------------------------------------------------------------------- /oraqle/experiments/depth_aware_arithmetization/execution/comparisons.py: -------------------------------------------------------------------------------- 1 | from galois import GF 2 | 3 | from oraqle.compiler.circuit import Circuit 4 | from oraqle.compiler.comparison.comparison import ( 5 | IliashenkoZuccaSemiLessThan, 6 | SemiStrictComparison, 7 | T2SemiLessThan, 8 | ) 9 | from oraqle.compiler.nodes.leafs import Input 10 | 11 | if __name__ == "__main__": 12 | iterations = 10 13 | 14 | for p in [29, 43, 61, 101, 131]: 15 | gf = GF(p) 16 | 17 | x = Input("x", gf) 18 | y = Input("y", gf) 19 | 20 | print(f"-------- p = {p}: ---------") 21 | our_circuit = Circuit([SemiStrictComparison(x, y, less_than=True, gf=gf)]) 22 | our_front = our_circuit.arithmetize_depth_aware() 23 | print("Our circuits:", our_front) 24 | 25 | our_front[0][2].to_graph(f"comp_{p}_ours.dot") 26 | for d, s, circ in our_front: 27 | print(d, s, circ.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22)) 28 | 29 | t2_circuit = Circuit([T2SemiLessThan(x, y, gf)]) 30 | t2_arithmetization = t2_circuit.arithmetize() 31 | print( 32 | "T2 circuit:", 33 | t2_arithmetization.multiplicative_depth(), 34 | t2_arithmetization.multiplicative_size(), 35 | t2_arithmetization.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22) 36 | ) 37 | t2_arithmetization.eliminate_subexpressions() 38 | print( 39 | "T2 circuit CSE:", 40 | 
def experiment(
    t: int, squaring_cost: float
) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]:
    """Generate the front of addition chains for `t - 1`, taking the modulus `t - 1` into account.

    Returns:
    -------
    A list with a `(depth, cost, chain)` tuple for every chain in the generated front, and the time
    in seconds that generating the front took (computing depth and cost is not included).

    """
    target = t - 1

    began = time.monotonic()
    front = gen_pareto_front(
        target,
        modulus=target,
        squaring_cost=squaring_cost,
        solver="glucose42",
        encoding=1,
        thurber=True,
    )
    elapsed = time.monotonic() - began

    summary = []
    for _, chain in front:
        depth = chain_depth(chain, modulus=target)
        summary.append((depth, chain_cost(chain, squaring_cost), chain))

    return summary, elapsed


def experiment2(
    t: int, squaring_cost: float
) -> Tuple[List[Tuple[int, float, List[Tuple[int, int]]]], float]:
    """Generate the front of addition chains for `t - 1` without passing a modulus.

    Returns:
    -------
    A list with a `(depth, cost, chain)` tuple for every chain in the generated front, and the time
    in seconds that generating the front took (computing depth and cost is not included).

    """
    began = time.monotonic()
    front = gen_pareto_front(
        t - 1,
        modulus=None,
        squaring_cost=squaring_cost,
        solver="glucose42",
        encoding=1,
        thurber=True,
    )
    elapsed = time.monotonic() - began

    summary = []
    for _, chain in front:
        depth = chain_depth(chain)
        summary.append((depth, chain_cost(chain, squaring_cost), chain))

    return summary, elapsed
if __name__ == "__main__":
    # Flip to True to regenerate the pickled results before plotting.
    run_experiments = False

    squaring_costs = [0.5, 0.75, 1.0]
    # File-name suffix -> worker that produces the data stored under that suffix.
    variants = (("mod", experiment), ("nomod", experiment2))

    if run_experiments:
        multiprocessing.set_start_method("fork")
        threads = 4

        primes = list(sieve.primerange(300))[:30]  # [:50]

        # The context manager guarantees the worker processes are cleaned up, also when
        # one of the experiments raises.
        with multiprocessing.Pool(threads) as pool:
            for suffix, worker in variants:
                for sqr_cost in squaring_costs:
                    print(f"Computing for {sqr_cost}")
                    experiment_sqr_cost = partial(worker, squaring_cost=sqr_cost)
                    outs = list(pool.map(experiment_sqr_cost, primes))

                    with open(f"equality_experiment_{sqr_cost}_{suffix}.pkl", mode="wb") as file:
                        pickle.dump((primes, outs), file)

    # Visualize
    # NOTE: pickle.load can execute arbitrary code; only load files written by this script.
    loaded = {}
    for suffix, _ in variants:
        for sqr_cost in squaring_costs:
            with open(f"equality_experiment_{sqr_cost}_{suffix}.pkl", "rb") as file:
                loaded[suffix, sqr_cost] = pickle.load(file)

    # All the primes should match
    primes = loaded["mod", 1.0][0]
    for stored_primes, _ in loaded.values():
        assert primes == stored_primes

    # All the chains should match (not in theory, but for this visualization they should).
    # zip() stops at the shorter sequence, so only the common prefix is compared.
    for sqr_cost in squaring_costs:
        outputs_mod = loaded["mod", sqr_cost][1]
        outputs_nomod = loaded["nomod", sqr_cost][1]
        assert all(
            all(x == y for x, y in zip(a[0], b[0])) for a, b in zip(outputs_mod, outputs_nomod)
        )

    for sqr_cost in squaring_costs:
        plot_specific_outputs(
            loaded["mod", sqr_cost][1],
            loaded["nomod", sqr_cost][1],
            primes,
            squaring_cost=sqr_cost,
        )
if __name__ == "__main__":
    # The circuits built below are deep; the default recursion limit is raised for them.
    sys.setrecursionlimit(15000)

    shape_size = 150
    linewidth = 2.5
    squaring_cost = 1.0

    plt.figure(figsize=(3.5, 4.4))

    p = 127  # 31
    gf = GF(p)
    degree = p - 1

    x = Input("x", gf)
    poly = UnivariatePoly.from_function(x, gf, lambda x: x % 7)
    coefficients = poly._coefficients

    # One front shared by all strategies, so that it ends up holding the overall best points.
    front = SizeParetoFront()

    def measure(node):
        """Arithmetize ``node`` and return its (multiplicative depth, multiplicative size)."""
        circ = Circuit([node]).arithmetize()
        return circ.multiplicative_depth(), circ.multiplicative_size()

    def sweep(evaluate, verbose=False):
        """Run ``evaluate(k)`` for every k, record each result in ``front``, return the points."""
        points = set()
        for k in range(1, len(coefficients)):
            res, _ = evaluate(k)
            depth, size = measure(res)
            front.add(res, depth, size)  # type: ignore
            points.add((depth, size))
            if verbose:
                print(k, depth, size)
        return points

    def on_front(points):
        """Keep only the points whose size equals the size stored in ``front`` for that depth."""
        return [
            (depth, size)
            for depth, size in points
            if depth in front._nodes_by_depth and front._nodes_by_depth[depth][0] == size  # type: ignore
        ]

    strategies = {
        "ps": {
            "announce": "Paterson & Stockmeyer",
            "legend": "Paterson & Stockmeyer",
            "evaluate": lambda k: _eval_poly(x, coefficients, k, gf, squaring_cost),
            "marker": (3, 2, 40),
            "zorder": 10,
            "alpha": 0.4,
            "color": "tab:blue",
            "verbose": True,
        },
        "bsgs": {
            "announce": "Baby-step giant-step",
            "legend": "Baby-step giant-step",
            "evaluate": lambda k: _eval_poly_alternative(x, coefficients, k, gf),
            "marker": (3, 2, 0),
            "zorder": 11,
            "alpha": 0.45,
            "color": "tab:orange",
            "verbose": False,
        },
        "dc": {
            "announce": "Divide and conquer",
            "legend": "Divide & Conquer",
            "evaluate": lambda k: _eval_poly_divide_conquer(x, coefficients, k, gf, squaring_cost),
            "marker": (3, 2, 80),
            "zorder": 11,
            "alpha": 0.45,
            "color": "tab:green",
            "verbose": False,
        },
    }

    # Generate points: every (depth, size) reached by a strategy is drawn faded. No colour is
    # passed here, so the order of these calls determines the colour each strategy gets.
    points_by_strategy = {}
    for key in ("ps", "bsgs", "dc"):
        spec = strategies[key]
        print(spec["announce"])
        points = sweep(spec["evaluate"], verbose=spec["verbose"])
        points_by_strategy[key] = points
        plt.scatter(
            [depth for depth, _ in points],
            [size for _, size in points],
            marker=spec["marker"],  # type: ignore
            zorder=spec["zorder"],
            alpha=spec["alpha"],
            s=shape_size,
            linewidth=linewidth,
        )

    # Plot the front: points that survived in the shared front are redrawn opaque and labelled.
    for key in ("bsgs", "ps", "dc"):
        spec = strategies[key]
        survivors = on_front(points_by_strategy[key])
        plt.scatter(
            [depth for depth, _ in survivors],
            [size for _, size in survivors],
            marker=spec["marker"],  # type: ignore
            zorder=10,
            color=spec["color"],
            s=shape_size,
            label=spec["legend"],
            linewidth=linewidth,
        )

    # Highlight the Paterson & Stockmeyer circuit for k = round(sqrt(degree / 2)).
    k = round(math.sqrt(degree / 2))
    res, _ = _eval_poly(x, coefficients, k, gf, squaring_cost)
    depth, size = measure(res)
    plt.scatter(
        depth,
        size,
        marker="o",
        s=shape_size + 50,
        facecolors="none",
        edgecolors="black",
    )
    plt.text(
        depth,
        size + 0.4,
        f"k = {k}",
        ha="center",
        fontsize=8,
    )

    plt.xlim((5, 15))
    plt.ylim((15, 30))

    plt.gca().set_aspect("equal")

    plt.gca().xaxis.set_minor_locator(MultipleLocator(1))
    plt.gca().yaxis.set_minor_locator(MultipleLocator(1))

    plt.grid(True, which="both", zorder=1, alpha=0.5)

    plt.xlabel("Multiplicative depth")
    plt.ylabel("Multiplicative size")

    plt.legend(fontsize="small")

    plt.savefig("poly_eval_front_2.pdf", bbox_inches="tight")
    plt.show()
def plot_fronts(fronts: List[CostParetoFront], color, label, *, first_k: int = 2, **kwargs):
    """Scatter the costs of a sequence of fronts against their operand count.

    The i-th front in ``fronts`` is drawn at x-position ``first_k + i``. Every
    ``(depth, cost, _)`` entry of a front becomes one point at height ``cost`` whose
    marker is the tuple ``(depth, 2, 0)``, so the marker shape encodes the depth.

    Args:
        fronts: Fronts to plot, ordered by increasing operand count.
        color: Colour used for every point and for the legend entry.
        label: Legend label; attached to a single empty scatter so it appears once.
        first_k: Operand count belonging to ``fronts[0]``.
        **kwargs: Extra keyword arguments forwarded to ``plt.scatter``. For the data
            points, ``marker``, ``s`` and ``linewidth`` are always overridden.
    """
    # Empty scatter that only exists to contribute one legend entry.
    plt.scatter([], [], color=color, label=label, **kwargs)

    # Built once; the caller's kwargs are left untouched.
    point_style = {**kwargs, "s": 16, "linewidth": 0.5}

    for k, front in enumerate(fronts, start=first_k):
        for depth, cost, _ in front:
            point_style["marker"] = (depth, 2, 0)
            plt.scatter(k, cost, color=color, **point_style)
if __name__ == "__main__":
    # Build the conjunction of sixteen inputs x1..x16 over GF(17), arithmetize it and
    # render the resulting arithmetic circuit to a PDF.
    field = GF(17)
    operands = [Input(f"x{index}", field) for index in range(1, 17)]

    high_level = Circuit([all_(*operands)])
    high_level.arithmetize().to_pdf("conjunction.pdf")
def generate_nodes() -> Tuple[Node, Node]:
    """Build two expressions over GF(31) that are written differently on purpose.

    Both combine a comparison of ``x`` and ``y`` with a sum of ``z1..z4`` using ``&``,
    but with the operands flipped (``x < y`` vs ``y > x``), the summands reordered and
    regrouped, and the two sides of ``&`` swapped.

    Returns:
        The pair ``(first, second)`` of expression nodes.
    """
    field = GF(31)

    x = Input("x", field)
    y = Input("y", field)
    z1, z2, z3, z4 = (Input(name, field) for name in ("z1", "z2", "z3", "z4"))

    first = (x < y) & sum_(z1, z2, z3, z4)
    second = (sum_(z3, z2, z4) + z1) & (y > x)

    return first, second
def run_benchmark(arithmetic_circuit: ArithmeticCircuit, *, iterations: int = 10) -> float:
    """Generate, build and run the benchmark program for one arithmetic circuit.

    Code is generated into ``main.cpp``, built with ``make`` in the current working
    directory, and ``./main`` is executed with random inputs ``x`` and ``y`` drawn
    from ``[0, p - 1]`` where ``p`` is the characteristic of the circuit's field.

    Args:
        arithmetic_circuit: Circuit to benchmark.
        iterations: Number of iterations requested from the generated program; the
            reported time is divided by this number.

    Returns:
        The last line of the program's output, parsed as a float and divided by
        ``iterations``.

    Raises:
        subprocess.CalledProcessError: If ``make`` fails.
        RuntimeError: If the benchmark exits with a non-zero status or prints nothing.
        AssertionError: If any output line before the last does not end in ``"1"``.
    """
    # Prepare the benchmark
    arithmetic_circuit.generate_code("main.cpp", iterations=iterations, measure_time=True)
    subprocess.run("make", capture_output=True, check=True)

    # Run the benchmark
    p = arithmetic_circuit._gf.characteristic
    command = [
        "./main",
        f"x={random.randint(0, p - 1)}",
        f"y={random.randint(0, p - 1)}",
    ]
    print("Running:", " ".join(command))
    result = subprocess.run(command, capture_output=True, text=True, check=False)

    if result.returncode != 0:
        print("stderr:")
        print(result.stderr)
        print()
        print("stdout:")
        print(result.stdout)
        # Parsing the output of a failed run would only produce a misleading error later.
        raise RuntimeError(f"Benchmark exited with status {result.returncode}")

    print(result.stdout)
    lines = result.stdout.splitlines()
    if not lines:
        raise RuntimeError("Benchmark produced no output")

    # Check if the noise was not too large. Raised explicitly (rather than with `assert`)
    # so the check also runs under `python -O`.
    for line in lines[:-1]:
        if not line.endswith("1"):
            raise AssertionError(f"Unexpected benchmark output line: {line!r}")

    run_time = float(lines[-1]) / iterations
    print(p, run_time)

    return run_time
0.368583), (0.416362,), (0.40343,), (0.385652,), (0.437486,), (0.481356,), (0.522607, 0.504944), (0.526451,), (0.5904119999999999, 0.5146740000000001), (0.592896,), (0.621265, 0.598357)] 89 | t2_times = [0.0156379, 0.0938689, 0.23473899999999998, 0.319668, 0.366707, 0.6632450000000001, 1.8380299999999998, 1.14859, 2.9022200000000002, 3.2060299999999997, 3.5419899999999997, 4.53918, 5.02624, 5.4439, 8.64118, 6.6267499999999995, 6.99609, 9.21295] 90 | 91 | if slides: 92 | plt.figure(figsize=(7, 4)) 93 | else: 94 | plt.figure(figsize=(4, 2)) 95 | plt.grid(axis="y", zorder=-1000, alpha=0.5) 96 | 97 | plt.scatter( 98 | range(len(primes)), t2_times, marker="_", label="T2's Circuit", color="tab:orange", s=100 if slides else None 99 | ) 100 | 101 | for x, ts in enumerate(our_times): 102 | for t in ts: 103 | plt.scatter( 104 | x, 105 | t, 106 | marker="_", 107 | label="Oraqle's circuits" if x == 0 else None, 108 | color="tab:cyan", 109 | s=100 if slides else None 110 | ) 111 | 112 | plt.xticks(range(len(primes)), primes, fontsize=8) # type: ignore 113 | 114 | plt.xlabel("Modulus") 115 | plt.ylabel("Run time (s)") 116 | 117 | plt.legend() 118 | 119 | plt.savefig(f"t2_comparison{'_slides' if slides else ''}.pdf", bbox_inches="tight") 120 | plt.show() 121 | -------------------------------------------------------------------------------- /oraqle/experiments/oraqle_spotlight/experiments/large_equality/.gitignore: -------------------------------------------------------------------------------- 1 | /CMakeFiles 2 | CMakeCache.txt 3 | cmake_install.cmake 4 | helib.log 5 | Makefile 6 | main 7 | -------------------------------------------------------------------------------- /oraqle/experiments/oraqle_spotlight/experiments/large_equality/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR) 2 | 3 | set(CMAKE_CXX_STANDARD 17) 4 | set(CMAKE_CXX_EXTENSIONS OFF) 5 | 
def generate_circuits(bits: int) -> List[Tuple[int, ArithmeticCircuit, int, float]]:
    """Build limb-wise equality circuits for the first ten primes and collect their fronts.

    For each prime ``p`` the inputs are split into ``ceil(bits / log2(p))`` limbs, the
    circuit "all limbs of x equal the matching limbs of y" is constructed, and it is
    arithmetized with ``arithmetize_depth_aware(0.75)``. Progress and timings are
    printed along the way.

    Args:
        bits: Total bit width of the compared values.

    Returns:
        One ``(p, arithmetic_circuit, f0, f1)`` tuple per entry ``f`` of every front,
        where ``f0``/``f1``/``arithmetic_circuit`` are ``f[0]``/``f[1]``/``f[2]``.
    """
    collected = []
    timings = []

    overall_start = time.monotonic()
    for p in list(sieve.primerange(300))[:10]:  # [:55]  # p <= 257
        # (6, 63.0): p=2
        # (7, 58.0): p=5
        # (8, 51.0): p=17

        limbs = math.ceil(bits / math.log2(p))
        gf = GF(p)

        xs = [Input(f"x{i}", gf) for i in range(limbs)]
        ys = [Input(f"y{i}", gf) for i in range(limbs)]
        limb_equalities = [x_limb == y_limb for x_limb, y_limb in zip(xs, ys)]
        circuit = Circuit([all_(*limb_equalities)])

        arithmetization_start = time.monotonic()
        front = circuit.arithmetize_depth_aware(0.75)

        print(f"{p}.", end=" ")
        for entry in front:
            collected.append((p, entry[2], entry[0], entry[1]))
            print(entry[0], entry[1], end=" ")

        elapsed = time.monotonic() - arithmetization_start
        print(elapsed)
        timings.append((p, elapsed))

    print(timings)
    print("Total time", time.monotonic() - overall_start)

    return collected
circuits = generate_circuits(bits) 62 | 63 | results = [] 64 | for p, arithmetic_circuit, d, c in circuits: 65 | # Prepare the benchmark 66 | params = arithmetic_circuit.generate_code("main.cpp", iterations=10, measure_time=True) 67 | subprocess.run("make", check=True) 68 | 69 | # Run the benchmark 70 | command = ["./main"] 71 | limbs = math.ceil(bits / math.log2(p)) 72 | for i in range(limbs): 73 | command.append(f"x{i}={random.randint(0, p - 1)}") 74 | command.append(f"y{i}={random.randint(0, p - 1)}") 75 | print("Running:", " ".join(command)) 76 | result = subprocess.run(command, capture_output=True, text=True, check=False) 77 | 78 | if result.returncode != 0: 79 | print("stderr:") 80 | print(result.stderr) 81 | print() 82 | print("stdout:") 83 | print(result.stdout) 84 | 85 | # Check if the noise was not too large 86 | print(result.stdout) 87 | lines = result.stdout.splitlines() 88 | for line in lines[:-1]: 89 | assert line.endswith("1") 90 | 91 | run_time = float(lines[-1]) / 10 92 | print(p, run_time, d, c, params) 93 | results.append((p, d, c, params, run_time)) 94 | 95 | print(results) 96 | 97 | if generate_table: 98 | gen_times = [(2, 0.007554411888122559), (3, 0.06264467351138592), (5, 8.457202550023794), (7, 0.05447225831449032), (11, 0.0478445328772068), (13, 0.052152080461382866), (17, 0.04349260404706001), (19, 0.04553743451833725), (23, 0.05198719538748264), (29, 0.046183058992028236)] 99 | results = [(2, 6, 63.0, (16383, 1, 142, 3), 3.27577), (3, 7, 60.75, (32768, 1, 170, 3), 1.51993), (5, 7, 58.0, (32768, 1, 178, 3), 1.7679099999999999), (5, 8, 55.5, (32768, 1, 197, 3), 1.93994), (7, 8, 74.0, (32768, 1, 206, 3), 2.90913), (7, 9, 70.0, (32768, 1, 226, 3), 2.6624600000000003), (7, 10, 69.5, (32768, 1, 246, 3), 3.00814), (11, 9, 69.25, (32768, 1, 228, 3), 2.50603), (11, 12, 68.25, (32768, 1, 300, 3), 3.25469), (13, 9, 68.75, (32768, 1, 237, 3), 2.67845), (13, 10, 67.75, (32768, 1, 237, 3), 2.7718), (13, 11, 66.0, (32768, 1, 237, 3), 2.56386), (13, 12, 
65.0, (32768, 1, 301, 3), 3.10959), (17, 8, 51.0, (32768, 1, 217, 3), 1.8792300000000002), (19, 9, 79.0, (32768, 1, 238, 3), 2.85011), (19, 10, 68.0, (32768, 1, 259, 3), 2.8636500000000003), (23, 9, 89.0, (32768, 1, 248, 3), 4.135730000000001), (23, 10, 80.0, (32768, 1, 270, 3), 3.75128), (29, 9, 83.0, (32768, 1, 249, 3), 3.7119), (29, 10, 75.0, (32768, 1, 271, 3), 3.46666)] 100 | 101 | gen_times = {p: t for p, t in gen_times} 102 | 103 | for p, d, c, params, run_time in results: 104 | print(f"{p} & {d} & {c} & {params[0]} & {params[1]} & {params[2]} & {params[3]} & {round(gen_times[p], 2)} & {round(run_time, 2)} \\\\") 105 | -------------------------------------------------------------------------------- /oraqle/experiments/oraqle_spotlight/experiments/veto_voting_minimal_cost.py: -------------------------------------------------------------------------------- 1 | """Finds the minimum cost for veto voting circuits for different prime moduli.""" 2 | 3 | from sympy import sieve 4 | 5 | from oraqle.compiler.boolean.bool_and import _minimum_cost 6 | 7 | exponentiation_results = { 8 | 2: ([(0, 0.0)], 8.633400000002123e-05), 9 | 3: ([(1, 0.75)], 4.6670000000137435e-06), 10 | 5: ([(2, 1.5)], 7.695799999996034e-05), 11 | 7: ([(3, 2.5)], 0.0053472920000000035), 12 | 11: ([(4, 3.25)], 0.007671625000000015), 13 | 13: ([(4, 3.25)], 0.002812749999999975), 14 | 17: ([(4, 3.0)], 7.891700000001167e-05), 15 | 19: ([(5, 4.0)], 0.012155541999999964), 16 | 23: ([(5, 5.0)], 0.03937258299999996), 17 | 29: ([(5, 5.0)], 0.018942542000000007), 18 | 31: ([(5, 6.0), (6, 5.0)], 0.064326), 19 | 37: ([(6, 4.75)], 0.019883207999999986), 20 | 41: ([(6, 4.75)], 0.02284237499999997), 21 | 43: ([(6, 5.75)], 0.03223737499999996), 22 | 47: ([(6, 6.75), (7, 6.0)], 0.607119292), 23 | 53: ([(6, 5.75)], 0.03940958299999997), 24 | 59: ([(6, 6.75)], 1.243811584), 25 | 61: ([(6, 6.75), (7, 5.75)], 0.446000167), 26 | 67: ([(7, 5.5)], 0.051902208000000005), 27 | 71: ([(7, 6.5)], 0.18221370799999997), 28 | 73: 
def run_experiments():
    """Print, per number of operands, the lowest cost found and the prime achieving it.

    For every prime in ``list(sieve.primerange(300))[1:50]`` and every ``k`` from 2 to 50,
    ``_minimum_cost`` is evaluated with the second component of the first entry stored
    for that prime in ``exponentiation_results``. A header line is printed per prime,
    followed at the end by one ``k cost p`` line per ``k``.
    """
    max_k = 50
    operand_counts = range(2, max_k + 1)

    # k -> (lowest cost so far, prime that achieved it)
    best = {k: (100000000.0, None) for k in operand_counts}

    # This is for sqr = 0.75 mul
    for p in list(sieve.primerange(300))[1:50]:
        print(f"------ p = {p} ------")
        exponentiation_cost = exponentiation_results[p][0][0][1]
        for k in operand_counts:
            cost = _minimum_cost(k, exponentiation_cost, p)
            if cost < best[k][0]:
                best[k] = (cost, p)

    for k, (cost, p) in best.items():
        print(k, cost, p)
| run_experiments() 83 | -------------------------------------------------------------------------------- /oraqle/helib_template/.gitignore: -------------------------------------------------------------------------------- 1 | /CMakeFiles 2 | CMakeCache.txt 3 | cmake_install.cmake 4 | helib.log 5 | Makefile 6 | main 7 | -------------------------------------------------------------------------------- /oraqle/helib_template/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10.2 FATAL_ERROR) 2 | 3 | set(CMAKE_CXX_STANDARD 17) 4 | set(CMAKE_CXX_EXTENSIONS OFF) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | 7 | find_package(helib) 8 | add_executable(main main.cpp) 9 | target_link_libraries(main helib) 10 | -------------------------------------------------------------------------------- /oraqle/helib_template/__init__.py: -------------------------------------------------------------------------------- 1 | """Template containing all the things to build an HElib program.""" 2 | -------------------------------------------------------------------------------- /oraqle/helib_template/main.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | typedef helib::Ptxt ptxt_t; 9 | typedef helib::Ctxt ctxt_t; 10 | 11 | std::map input_map; 12 | 13 | void parse_arguments(int argc, char* argv[]) { 14 | for (int i = 1; i < argc; ++i) { 15 | std::string argument(argv[i]); 16 | size_t pos = argument.find('='); 17 | if (pos != std::string::npos) { 18 | std::string key = argument.substr(0, pos); 19 | int value = std::stoi(argument.substr(pos + 1)); 20 | input_map[key] = value; 21 | } 22 | } 23 | } 24 | 25 | int extract_input(const std::string& name) { 26 | if (input_map.find(name) != input_map.end()) { 27 | return input_map[name]; 28 | } else { 29 | std::cerr << "Error: " << name << " not found" << std::endl; 30 | 
return -1; 31 | } 32 | } 33 | 34 | int main(int argc, char* argv[]) { 35 | // Parse the inputs 36 | parse_arguments(argc, argv); 37 | 38 | // Set up the HE parameters 39 | unsigned long p = 5; 40 | unsigned long m = 8192; 41 | unsigned long r = 1; 42 | unsigned long bits = 72; 43 | unsigned long c = 3; 44 | helib::Context context = helib::ContextBuilder<helib::BGV>() 45 | .m(m) 46 | .p(p) 47 | .r(r) 48 | .bits(bits) 49 | .c(c) 50 | .build(); 51 | 52 | 53 | // Generate keys 54 | helib::SecKey secret_key(context); 55 | secret_key.GenSecKey(); 56 | helib::addSome1DMatrices(secret_key); 57 | const helib::PubKey& public_key = secret_key; 58 | 59 | // Encrypt the inputs 60 | std::vector<long> vec_x(1, extract_input("x")); 61 | ptxt_t ptxt_x(context, vec_x); 62 | ctxt_t ciph_x(public_key); 63 | public_key.Encrypt(ciph_x, ptxt_x); 64 | std::vector<long> vec_y(1, extract_input("y")); 65 | ptxt_t ptxt_y(context, vec_y); 66 | ctxt_t ciph_y(public_key); 67 | public_key.Encrypt(ciph_y, ptxt_y); 68 | 69 | // Perform the actual circuit 70 | ctxt_t stack_0 = ciph_x; 71 | ctxt_t stack_1 = ciph_y; 72 | stack_1 *= 4l; 73 | stack_0 += stack_1; 74 | stack_0 *= stack_0; 75 | stack_0 *= stack_0; 76 | stack_0 *= 4l; 77 | stack_0 += 1l; 78 | ptxt_t decrypted(context); 79 | secret_key.Decrypt(decrypted, stack_0); 80 | std::cout << decrypted << std::endl; 81 | 82 | return 0; 83 | } 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "oraqle" 3 | description = "Secure computation compiler for homomorphic encryption and arithmetic circuits in general" 4 | version = "0.1.6" 5 | requires-python = ">= 3.8" 6 | authors = [ 7 | {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"}, 8 | ] 9 | maintainers = [ 10 | {name = "Jelle Vos", email = "J.V.Vos@tudelft.nl"} 11 | ] 12 | readme = "README.md" 13 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sympy 2 | six 3 | galois>=0.3.8 4 | aeskeyschedule 5 | python-sat 6 | git+https://github.com/jellevos/fhegen.git 7 | matplotlib 8 | ipython 9 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | gensafeprime 3 | graphviz 4 | tabulate 5 | ruff 6 | mkdocs 7 | mkdocstrings[python] 8 | mkautodoc 9 | mkdocs-material 10 | pymdown-extensions 11 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".ipynb_checkpoints", 7 | ".mypy_cache", 8 | ".pyenv", 9 | ".pytest_cache", 10 | ".pytype", 11 | ".ruff_cache", 12 | ".venv", 13 | ".vscode", 14 | "__pypackages__", 15 | "_build", 16 | "build", 17 | "dist", 18 | "node_modules", 19 | "site-packages", 20 | "venv", 21 | ] 22 | 23 | line-length = 100 24 | indent-width = 4 25 | target-version = "py38" 26 | 27 | [lint] 28 | preview = true 29 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 30 | # McCabe complexity (`C901`) by default. 31 | select = ["W", "E4", "E7", "E9", "F", "ERA001", "B", "D", "DOC", "PLW", "B", "SIM", "UP", "PLR", "RUF", "PIE"] 32 | ignore = ["E203", "E501", "E731", "D105", "W293", "PLR2004", "PLR6301"] 33 | 34 | # Allow fix for all enabled rules (when `--fix`) is provided. 35 | fixable = ["ALL"] 36 | unfixable = [] 37 | 38 | # Allow unused variables when underscore-prefixed. 
39 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 40 | 41 | [lint.per-file-ignores] 42 | "oraqle/experiments/*" = ["D", "DOC"] 43 | 44 | [lint.pydocstyle] 45 | # Use Google-style docstrings. 46 | convention = "google" 47 | 48 | [format] 49 | preview = true 50 | quote-style = "double" 51 | indent-style = "space" 52 | skip-magic-trailing-comma = false 53 | line-ending = "auto" 54 | docstring-code-format = true 55 | docstring-code-line-length = "dynamic" 56 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | python_files = *.py 3 | norecursedirs = venv build 4 | -------------------------------------------------------------------------------- /tests/test_circuit_sizes_costs.py: -------------------------------------------------------------------------------- 1 | """Test file for testing circuits sizes.""" 2 | 3 | from collections import Counter 4 | 5 | from galois import GF 6 | 7 | from oraqle.compiler.nodes.abstract import ArithmeticNode, UnoverloadedWrapper 8 | from oraqle.compiler.nodes.arbitrary_arithmetic import Sum 9 | from oraqle.compiler.nodes.leafs import Constant, Input 10 | 11 | 12 | def test_size_exponentiation_chain(): 13 | """Test.""" 14 | gf = GF(101) 15 | 16 | x = Input("x", gf) 17 | 18 | x = x.mul(x, flatten=False) 19 | x = x.mul(x, flatten=False) 20 | x = x.mul(x, flatten=False) 21 | 22 | x = x.to_arithmetic() 23 | assert isinstance(x, ArithmeticNode) 24 | assert ( 25 | x.multiplicative_size() == 3 26 | ), f"((x^2)^2)^2 should be 3 multiplications, but counted {x.multiplicative_size()}" 27 | assert x.multiplicative_cost(0.5) == 1.5 28 | 29 | 30 | def test_size_sum_of_products(): 31 | """Test.""" 32 | gf = GF(101) 33 | 34 | a = Input("a", gf) 35 | b = Input("b", gf) 36 | c = Input("c", gf) 37 | d = Input("d", gf) 38 | 39 | ab = a * b 40 | cd = c * d 41 | 42 | out = ab + cd 43 | out = 
out.to_arithmetic() 44 | 45 | assert isinstance(out, ArithmeticNode) 46 | assert ( 47 | out.multiplicative_size() == 2 48 | ), f"a * b + c * d should be 2 multiplications, but counted {out.multiplicative_size()}" 49 | assert out.multiplicative_cost(0.7) == 2 50 | 51 | 52 | def test_size_linear_function(): 53 | """Test.""" 54 | gf = GF(101) 55 | 56 | a = Input("a", gf) 57 | b = Input("b", gf) 58 | c = Input("c", gf) 59 | 60 | out = Sum( 61 | Counter({UnoverloadedWrapper(a): 1, UnoverloadedWrapper(b): 3, UnoverloadedWrapper(c): 1}), 62 | gf, 63 | gf(2), 64 | ) 65 | 66 | out = out.to_arithmetic() 67 | assert out.multiplicative_size() == 0 68 | assert out.multiplicative_cost(0.5) == 0 69 | 70 | 71 | def test_size_duplicate_nodes(): 72 | """Test.""" 73 | gf = GF(101) 74 | 75 | x = Input("x", gf) 76 | 77 | add1 = x.add(Constant(gf(1))) 78 | add2 = x.add(Constant(gf(1))) 79 | 80 | mul1 = x.mul(x, flatten=False) 81 | mul2 = x.mul(x, flatten=False) 82 | 83 | add3 = mul2.add(add2, flatten=False) 84 | 85 | mul3 = mul1.mul(add3, flatten=False) 86 | 87 | out = add1.add(mul3, flatten=False) 88 | 89 | out = out.to_arithmetic() 90 | 91 | assert isinstance(out, ArithmeticNode) 92 | assert out.multiplicative_size() == 3 93 | assert out.multiplicative_cost(0.7) == 2.4 94 | -------------------------------------------------------------------------------- /tests/test_poly2circuit.py: -------------------------------------------------------------------------------- 1 | """Test file for generating circuits using polynomial interpolation.""" 2 | 3 | import itertools 4 | 5 | from oraqle.compiler.func2poly import interpolate_polynomial 6 | from oraqle.compiler.poly2circuit import construct_circuit 7 | 8 | 9 | def _construct_and_test_circuit_from_bivariate_lambda(function, modulus: int, cse=False): 10 | poly = interpolate_polynomial(function, modulus, ["x", "y"]) 11 | circuit, gf = construct_circuit([poly], modulus) 12 | circuit = circuit.arithmetize() 13 | 14 | if cse: 15 | 
circuit.eliminate_subexpressions() 16 | 17 | for x, y in itertools.product(range(modulus), repeat=2): 18 | print(function, x, y) 19 | assert circuit.evaluate({"x": gf(x), "y": gf(y)}) == [function(x, y)] 20 | 21 | 22 | def test_inequality_mod7(): 23 | """Tests x != y (mod 7).""" 24 | _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: int(x != y), modulus=7) 25 | 26 | 27 | def test_inequality_mod13(): 28 | """Tests x != y (mod 13).""" 29 | _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: int(x != y), modulus=13) 30 | 31 | 32 | def test_max_mod7(): 33 | """Tests max(x, y) (mod 7).""" 34 | _construct_and_test_circuit_from_bivariate_lambda(max, modulus=7) 35 | 36 | 37 | def test_max_mod13(): 38 | """Tests max(x, y) (mod 13).""" 39 | _construct_and_test_circuit_from_bivariate_lambda(max, modulus=13) 40 | 41 | 42 | def test_xor_mod11(): 43 | """Tests x ^ y (mod 11).""" 44 | _construct_and_test_circuit_from_bivariate_lambda(lambda x, y: (x ^ y) % 11, modulus=11) 45 | 46 | 47 | def test_inequality_mod11_cse(): 48 | """Tests x != y (mod 11) with CSE.""" 49 | _construct_and_test_circuit_from_bivariate_lambda( 50 | lambda x, y: int(x != y), modulus=11, cse=True 51 | ) 52 | 53 | 54 | def test_max_mod7_cse(): 55 | """Tests max(x, y) (mod 7) with CSE.""" 56 | _construct_and_test_circuit_from_bivariate_lambda(max, modulus=7, cse=True) 57 | 58 | 59 | def test_xor_mod13_cse(): 60 | """Tests x ^ y (mod 13) with CSE.""" 61 | _construct_and_test_circuit_from_bivariate_lambda( 62 | lambda x, y: (x ^ y) % 13, modulus=13, cse=True 63 | ) 64 | -------------------------------------------------------------------------------- /tests/test_sugar_expressions.py: -------------------------------------------------------------------------------- 1 | """Test file for sugar expressions.""" 2 | 3 | from galois import GF 4 | 5 | from oraqle.compiler.circuit import Circuit 6 | from oraqle.compiler.nodes.arbitrary_arithmetic import sum_ 7 | from oraqle.compiler.nodes.leafs import 
Input 8 | 9 | 10 | def test_sum(): 11 | """Tests the sum_ function.""" 12 | gf = GF(127) 13 | 14 | a = Input("a", gf) 15 | b = Input("b", gf) 16 | 17 | arithmetic_circuit = Circuit([sum_(a, 4, b, 3)]).arithmetize() 18 | 19 | for val_a in range(127): 20 | for val_b in range(127): 21 | expected = gf(val_a) + gf(val_b) + gf(7) 22 | assert arithmetic_circuit.evaluate({"a": gf(val_a), "b": gf(val_b)}) == expected 23 | --------------------------------------------------------------------------------