├── .clang-format ├── .gitignore ├── .gitmodules ├── CHANGES.md ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── azure-pipelines.yml ├── eva ├── CMakeLists.txt ├── ckks │ ├── CMakeLists.txt │ ├── always_rescaler.h │ ├── ckks_compiler.h │ ├── ckks_config.cpp │ ├── ckks_config.h │ ├── ckks_parameters.h │ ├── ckks_signature.h │ ├── eager_relinearizer.h │ ├── eager_waterline_rescaler.h │ ├── encode_inserter.h │ ├── encryption_parameter_selector.h │ ├── lazy_relinearizer.h │ ├── lazy_waterline_rescaler.h │ ├── levels_checker.h │ ├── minimum_rescaler.h │ ├── mod_switcher.h │ ├── parameter_checker.h │ ├── rescaler.h │ ├── scales_checker.h │ └── seal_lowering.h ├── common │ ├── CMakeLists.txt │ ├── constant_folder.h │ ├── multicore_program_traversal.h │ ├── program_traversal.h │ ├── reduction_balancer.h │ ├── reference_executor.cpp │ ├── reference_executor.h │ ├── rotation_keys_selector.h │ ├── type_deducer.h │ └── valuation.h ├── eva.cpp ├── eva.h ├── ir │ ├── CMakeLists.txt │ ├── attribute_list.cpp │ ├── attribute_list.h │ ├── attributes.cpp │ ├── attributes.h │ ├── constant_value.h │ ├── ops.h │ ├── program.cpp │ ├── program.h │ ├── term.cpp │ ├── term.h │ ├── term_map.h │ └── types.h ├── seal │ ├── CMakeLists.txt │ ├── seal.cpp │ ├── seal.h │ └── seal_executor.h ├── serialization │ ├── CMakeLists.txt │ ├── ckks.proto │ ├── ckks_serialization.cpp │ ├── eva.proto │ ├── eva_format_version.h │ ├── eva_serialization.cpp │ ├── known_type.cpp │ ├── known_type.h │ ├── known_type.proto │ ├── save_load.cpp │ ├── save_load.h │ ├── seal.proto │ └── seal_serialization.cpp ├── util │ ├── CMakeLists.txt │ ├── galois.cpp │ ├── galois.h │ ├── logging.cpp │ ├── logging.h │ └── overloaded.h ├── version.cpp └── version.h ├── examples ├── .gitignore ├── baboon.png ├── image_processing.py ├── requirements.txt └── serialization.py ├── python ├── .gitignore ├── CMakeLists.txt ├── eva │ ├── CMakeLists.txt │ ├── 
__init__.py │ ├── ckks │ │ └── __init__.py │ ├── metric.py │ ├── seal │ │ └── __init__.py │ ├── std │ │ └── numeric.py │ └── wrapper.cpp └── setup.py.in ├── scripts └── clang-format-all.sh └── tests ├── all.py ├── bug_fixes.py ├── common.py ├── features.py ├── large_programs.py └── std.py /.clang-format: -------------------------------------------------------------------------------- 1 | AllowShortIfStatementsOnASingleLine: true 2 | BasedOnStyle: llvm 3 | FixNamespaceComments: true 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.*/ 2 | /build/ 3 | /dist/ 4 | eva.egg-info/ 5 | __pycache__/ 6 | 7 | # In-source build files 8 | Makefile 9 | CMakeCache.txt 10 | cmake_install.cmake 11 | CPackConfig.cmake 12 | CPackSourceConfig.cmake 13 | compile_commands.json 14 | CMakeFiles/ 15 | *.pb.cc 16 | *.pb.h 17 | *.a 18 | *.so 19 | *.whl -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/pybind11"] 2 | path = third_party/pybind11 3 | url = https://github.com/pybind/pybind11 4 | ignore = untracked 5 | [submodule "third_party/Galois"] 6 | path = third_party/Galois 7 | url = https://github.com/IntelligentSoftwareSystems/Galois.git 8 | ignore = untracked 9 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | # List of Changes 2 | 3 | ## Version 1.0.1 4 | 5 | ### Bug Fixes 6 | 7 | - Fixed issue where modulus switching was inserted for non-encrypted values [(PR 9)](https://github.com/microsoft/EVA/pull/9). 8 | 9 | ### Other 10 | 11 | - Validated compatibility with SEAL v3.6.4 and updated README.md. 12 | 13 | ## Version 1.0.0 14 | 15 | First open source release. 
16 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | cmake_minimum_required(VERSION 3.13) 5 | cmake_policy(SET CMP0079 NEW) 6 | cmake_policy(SET CMP0076 NEW) 7 | 8 | set(CMAKE_CXX_STANDARD 17) 9 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 10 | set(CMAKE_CXX_EXTENSIONS OFF) 11 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 12 | 13 | project(eva 14 | VERSION 1.0.1 15 | LANGUAGES CXX 16 | ) 17 | 18 | option(USE_GALOIS "Use the Galois library for multicore homomorphic evaluation" OFF) 19 | if(USE_GALOIS) 20 | message("Galois based multicore support enabled") 21 | add_definitions(-DEVA_USE_GALOIS) 22 | endif() 23 | 24 | find_package(SEAL 3.6 REQUIRED) 25 | find_package(Protobuf 3.6 REQUIRED) 26 | find_package(Python COMPONENTS Interpreter Development) 27 | 28 | if(NOT Python_VERSION_MAJOR EQUAL 3) 29 | message(FATAL_ERROR "EVA requires Python 3. Please ensure you have it 30 | installed in a location searched by CMake.") 31 | endif() 32 | 33 | add_subdirectory(third_party/pybind11) 34 | if(USE_GALOIS) 35 | add_subdirectory(third_party/Galois EXCLUDE_FROM_ALL) 36 | endif() 37 | 38 | add_subdirectory(eva) 39 | add_subdirectory(python) 40 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | The EVA project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | Please submit all pull requests on the **contrib** branch. We will handle the final merge onto the main branch. 9 | 10 | When you submit a pull request, a CLA-bot will automatically determine whether you need 11 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 12 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Microsoft Corporation. 
2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EVA - Compiler for Microsoft SEAL 2 | 3 | EVA is a compiler for homomorphic encryption, that automates away the parts that require cryptographic expertise. 4 | This gives you a simple way to write programs that operate on encrypted data without having access to the secret key. 5 | 6 | Think of EVA as the "C compiler" of the homomorphic world. Homomorphic computations written in EVA IR (Encrypted Vector Arithmetic Intermediate Representation) get compiled to the "assembly" of the homomorphic encryption library API. Just like C compilers free you from tricky tasks like register allocation, EVA frees you from *encryption parameter selection, rescaling insertion, relinearization*... 
7 | 8 | EVA targets [Microsoft SEAL](https://github.com/microsoft/SEAL) — the industry leading library for fully-homomorphic encryption — and currently supports the CKKS scheme for deep computations on encrypted approximate fixed-point arithmetic. 9 | 10 | ## Getting Started 11 | 12 | EVA is a native library written in C++17 with bindings for Python. Both Linux and Windows are supported. The instructions below show how to get started with EVA on Ubuntu. For building on Windows [EVA's Azure Pipelines script](azure-pipelines.yml) is a useful reference. 13 | 14 | ### Installing Dependencies 15 | 16 | To install dependencies on Ubuntu 20.04: 17 | ``` 18 | sudo apt install cmake libboost-all-dev libprotobuf-dev protobuf-compiler 19 | ``` 20 | 21 | Clang is recommended for compilation, as SEAL is faster when compiled with it. To install clang and set it as default: 22 | ``` 23 | sudo apt install clang 24 | sudo update-alternatives --install /usr/bin/cc cc /usr/bin/clang 100 25 | sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/clang++ 100 26 | ``` 27 | 28 | Next install Microsoft SEAL version 3.6: 29 | ``` 30 | git clone -b v3.6.4 https://github.com/microsoft/SEAL.git 31 | cd SEAL 32 | cmake -DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF . 33 | make -j 34 | sudo make install 35 | ``` 36 | *Note that SEAL has to be installed with transparent ciphertext checking turned off, as it is not possible in general to statically ensure a program will not produce a transparent ciphertext. This does not affect the security of ciphertexts encrypted with SEAL.* 37 | 38 | ### Building and Installing EVA 39 | 40 | #### Building EVA 41 | 42 | EVA builds with CMake version ≥ 3.13: 43 | ``` 44 | git submodule update --init 45 | cmake . 46 | make -j 47 | ``` 48 | The build process creates a `setup.py` file in `python/`. 
To install the package for development with PIP: 49 | ``` 50 | python3 -m pip install -e python/ 51 | ``` 52 | To create a Python Wheel package for distribution in the current directory: 53 | ``` 54 | python3 python/setup.py bdist_wheel --dist-dir='.' 55 | ``` 56 | 57 | To check that the installed Python package is working correctly, run all tests with: 58 | ``` 59 | python3 tests/all.py 60 | ``` 61 | 62 | EVA does not yet support installing the native library for use in other CMake projects (contributions very welcome). 63 | 64 | #### Multicore Support 65 | 66 | EVA features highly scalable multicore support using the [Galois library](https://github.com/IntelligentSoftwareSystems/Galois). It is included as a submodule, but is turned off by default for faster builds and easier debugging. To build EVA with Galois, configure with `USE_GALOIS=ON`: 67 | ``` 68 | cmake -DUSE_GALOIS=ON . 69 | ``` 70 | 71 | ### Running the Examples 72 | 73 | The examples use EVA's Python APIs. To install dependencies with PIP: 74 | ``` 75 | python3 -m pip install -r examples/requirements.txt 76 | ``` 77 | 78 | For example, to run the image processing example in `examples/`: 79 | ``` 80 | cd examples/ 81 | python3 image_processing.py 82 | ``` 83 | This will compile and run homomorphic evaluations of a Sobel edge detection filter and a Harris corner detection filter on `examples/baboon.png`, producing results of homomorphic evaluation in `*_encrypted.png` and reference results from normal execution in `*_reference.png`. 84 | The script also reports the mean squared error between these for each filter. 85 | 86 | ## Programming with PyEVA 87 | 88 | PyEVA is a thin Python-embedded DSL for producing EVA programs. 89 | We will walk you through compiling a PyEVA program with EVA and running it on top of SEAL.
90 | 91 | ### Writing and Compiling Programs 92 | 93 | A program to evaluate a fixed polynomial $3x^2 + 5x - 2$ on 1024 encrypted values can be written: 94 | ``` 95 | from eva import * 96 | poly = EvaProgram('Polynomial', vec_size=1024) 97 | with poly: 98 | x = Input('x') 99 | Output('y', 3*x**2 + 5*x - 2) 100 | ``` 101 | Next, we will compile this program for the [CKKS encryption scheme](https://eprint.iacr.org/2016/421.pdf). 102 | Two additional pieces of information EVA currently requires to compile for CKKS are the *fixed-point scale for inputs* and the *maximum ranges of coefficients in outputs*, both expressed as a number of bits: 103 | ``` 104 | poly.set_output_ranges(30) 105 | poly.set_input_scales(30) 106 | ``` 107 | Now the program can be compiled: 108 | ``` 109 | from eva.ckks import * 110 | compiler = CKKSCompiler() 111 | compiled_poly, params, signature = compiler.compile(poly) 112 | ``` 113 | The `compile` method transforms the program in-place and returns: 114 | 115 | 1. the compiled program; 116 | 2. encryption parameters for Microsoft SEAL with which the program can be executed; 117 | 3. a signature object that specifies how inputs and outputs need to be encoded and decoded. 118 | 119 | The compiled program can be inspected by printing it in the DOT format for the [Graphviz](https://graphviz.org/) visualization software: 120 | ``` 121 | print(compiled_poly.to_DOT()) 122 | ``` 123 | The output can be viewed as a graph in, for example, one of the many Graphviz editors available online.
124 | 125 | ### Generating Keys and Encrypting Inputs 126 | 127 | Encryption keys can now be generated using the encryption parameters: 128 | ``` 129 | from eva.seal import * 130 | public_ctx, secret_ctx = generate_keys(params) 131 | ``` 132 | Next, a dictionary of inputs is created and encrypted using the public context and the program signature: 133 | ``` 134 | inputs = { 'x': [i for i in range(compiled_poly.vec_size)] } 135 | encInputs = public_ctx.encrypt(inputs, signature) 136 | ``` 137 | 138 | ### Homomorphic Execution 139 | 140 | Everything is now in place for executing the program with Microsoft SEAL: 141 | ``` 142 | encOutputs = public_ctx.execute(compiled_poly, encInputs) 143 | ``` 144 | 145 | ### Decrypting Results 146 | 147 | Finally, the outputs can be decrypted using the secret context: 148 | ``` 149 | outputs = secret_ctx.decrypt(encOutputs, signature) 150 | ``` 151 | For debugging, it is often useful to compare homomorphic results to those of an unencrypted computation. 152 | The `evaluate` function can be used to execute an EVA program on unencrypted data. 153 | The two sets of results can then be compared using, for example, the mean squared error: 154 | 155 | ``` 156 | from eva.metric import valuation_mse 157 | reference = evaluate(compiled_poly, inputs) 158 | print('MSE', valuation_mse(outputs, reference)) 159 | ``` 160 | 161 | ## Contributing 162 | 163 | The EVA project welcomes contributions and suggestions. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details. 164 | 165 | ## Credits 166 | 167 | This project is a collaboration between Microsoft Research's Research in Software Engineering (RiSE) group and its Cryptography and Privacy Research group. 168 | 169 | A huge credit goes to [Dr. Roshan Dathathri](https://roshandathathri.github.io/), who as an intern built the first version of EVA, along with all the transformations required for targeting the CKKS scheme efficiently and the parallelizing runtime required to make execution scale.
170 | 171 | Many thanks to [Sangeeta Chowdhary](https://www.ilab.cs.rutgers.edu/~sc1696/), who as an intern put a huge amount of work into making EVA ready for release. 172 | 173 | ## Publications 174 | 175 | Roshan Dathathri, Blagovesta Kostova, Olli Saarikivi, Wei Dai, Kim Laine, Madanlal Musuvathi. *EVA: An Encrypted Vector Arithmetic Language and Compiler for Efficient Homomorphic Computation*. PLDI 2020. [arXiv](https://arxiv.org/abs/1912.11951) [DOI](https://doi.org/10.1145/3385412.3386023) 176 | 177 | Roshan Dathathri, Olli Saarikivi, Hao Chen, Kim Laine, Kristin Lauter, Saeed Maleki, Madanlal Musuvathi, Todd Mytkowicz. *CHET: An Optimizing Compiler for Fully-Homomorphic Neural-Network Inferencing*. PLDI 2019. [DOI](https://doi.org/10.1145/3314221.3314628) 178 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, you can contact the EVA compiler team at 10 | [evacompiler@microsoft.com](mailto:evacompiler@microsoft.com). 11 | We are very interested in helping early adopters start to use EVA. 12 | 13 | ## Microsoft Support Policy 14 | 15 | Support for EVA is limited to the resources listed above. -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # EVA pipeline 2 | 3 | trigger: 4 | - master 5 | 6 | pool: 7 | vmImage: 'windows-latest' 8 | 9 | steps: 10 | - task: UsePythonVersion@0 11 | displayName: 'Ensure Python 3.x' 12 | inputs: 13 | versionSpec: '3.x' 14 | addToPath: true 15 | architecture: 'x64' 16 | 17 | - task: securedevelopmentteam.vss-secure-development-tools.build-task-credscan.CredScan@2 18 | displayName: 'Run CredScan' 19 | inputs: 20 | toolMajorVersion: 'V2' 21 | outputFormat: sarif 22 | debugMode: false 23 | 24 | - task: CmdLine@2 25 | displayName: 'Get SEAL source code' 26 | inputs: 27 | script: | 28 | rem Use github repo 29 | git clone https://github.com/microsoft/SEAL.git 30 | cd SEAL 31 | rem Use 3.6.0 specifically 32 | git checkout 3.6.0 33 | workingDirectory: '$(Build.SourcesDirectory)/third_party' 34 | 35 | - task: CMake@1 36 | displayName: 'Configure SEAL' 37 
| inputs: 38 | cmakeArgs: '-DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF -DALLOW_COMMAND_LINE_BUILD=ON -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF -DSEAL_USE_ZSTD=OFF .' 39 | workingDirectory: $(Build.SourcesDirectory)/third_party/SEAL 40 | 41 | - task: MSBuild@1 42 | displayName: 'Build SEAL' 43 | inputs: 44 | solution: '$(Build.SourcesDirectory)/third_party/SEAL/SEAL.sln' 45 | msbuildArchitecture: 'x64' 46 | platform: 'x64' 47 | configuration: 'Debug' 48 | 49 | - task: CmdLine@2 50 | displayName: 'Get vcpkg' 51 | inputs: 52 | script: 'git clone https://github.com/microsoft/vcpkg.git' 53 | workingDirectory: '$(Build.SourcesDirectory)/third_party' 54 | 55 | - task: CmdLine@2 56 | displayName: 'Bootstrap vcpkg' 57 | inputs: 58 | script: '$(Build.SourcesDirectory)/third_party/vcpkg/bootstrap-vcpkg.bat' 59 | workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg' 60 | 61 | - task: PowerShell@2 62 | displayName: 'Get protobuf compiler' 63 | inputs: 64 | targetType: 'inline' 65 | script: | 66 | mkdir protobuf 67 | cd protobuf 68 | Invoke-WebRequest -Uri "https://github.com/protocolbuffers/protobuf/releases/download/v3.15.8/protoc-3.15.8-win64.zip" -OutFile protobufc.zip 69 | Expand-Archive -LiteralPath protobufc.zip -DestinationPath protobufc 70 | workingDirectory: '$(Build.SourcesDirectory)/third_party' 71 | 72 | - task: CmdLine@2 73 | displayName: 'Install protobuf library' 74 | inputs: 75 | script: '$(Build.SourcesDirectory)/third_party/vcpkg/vcpkg.exe install protobuf[zlib]:x64-windows' 76 | workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg' 77 | 78 | - task: CmdLine@2 79 | displayName: 'Create build directory' 80 | inputs: 81 | script: 'mkdir build' 82 | workingDirectory: '$(Build.SourcesDirectory)' 83 | 84 | - task: CMake@1 85 | displayName: 'Configure EVA' 86 | inputs: 87 | cmakeArgs: .. 
-DSEAL_DIR=$(Build.SourcesDirectory)/third_party/SEAL/cmake -DProtobuf_INCLUDE_DIR=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/include -DProtobuf_LIBRARY=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/lib/libprotobuf.lib -DProtobuf_PROTOC_EXECUTABLE=$(Build.SourcesDirectory)/third_party/protobuf/protobufc/bin/protoc.exe 88 | workingDirectory: '$(Build.SourcesDirectory)/build' 89 | 90 | - task: MSBuild@1 91 | displayName: 'Build EVA' 92 | inputs: 93 | solution: '$(Build.SourcesDirectory)/build/eva.sln' 94 | msbuildArchitecture: 'x64' 95 | platform: 'x64' 96 | configuration: 'Debug' 97 | 98 | - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 99 | displayName: 'Component Detection' 100 | 101 | - task: securedevelopmentteam.vss-secure-development-tools.build-task-publishsecurityanalysislogs.PublishSecurityAnalysisLogs@2 102 | displayName: 'Publish Security Analysis Logs' 103 | 104 | - task: PublishBuildArtifacts@1 105 | displayName: 'Publish build artifacts' 106 | inputs: 107 | PathtoPublish: '$(Build.ArtifactStagingDirectory)' 108 | artifactName: windows-drop 109 | -------------------------------------------------------------------------------- /eva/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | add_library(eva STATIC 5 | eva.cpp 6 | version.cpp 7 | ) 8 | 9 | # TODO: everything except SEAL::seal should be make PRIVATE 10 | target_link_libraries(eva PUBLIC SEAL::seal protobuf::libprotobuf) 11 | if(USE_GALOIS) 12 | target_link_libraries(eva PUBLIC Galois::shmem numa) 13 | endif() 14 | target_include_directories(eva 15 | PUBLIC 16 | $ 17 | $ 18 | $ 19 | ) 20 | target_compile_definitions(eva PRIVATE EVA_VERSION_STR="${PROJECT_VERSION}") 21 | 22 | add_subdirectory(util) 23 | add_subdirectory(serialization) 24 | add_subdirectory(ir) 25 | add_subdirectory(common) 26 | add_subdirectory(ckks) 27 | add_subdirectory(seal) 28 | -------------------------------------------------------------------------------- /eva/ckks/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | target_sources(eva PRIVATE 5 | ckks_config.cpp 6 | ) 7 | -------------------------------------------------------------------------------- /eva/ckks/always_rescaler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ckks/rescaler.h" 7 | 8 | namespace eva { 9 | 10 | class AlwaysRescaler : public Rescaler { 11 | std::uint32_t minScale; 12 | 13 | public: 14 | AlwaysRescaler(Program &g, TermMap &type, 15 | TermMapOptional &scale) 16 | : Rescaler(g, type, scale) { 17 | // ASSUME: minScale is max among all the cipher inputs' scale 18 | minScale = 0; 19 | for (auto &source : program.getSources()) { 20 | if (scale[source] > minScale) minScale = scale[source]; 21 | } 22 | assert(minScale != 0); 23 | } 24 | 25 | void 26 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 27 | if (term->numOperands() == 0) return; // inputs 28 | if (type[term] == Type::Raw) { 29 | handleRawScale(term); 30 | return; 31 | } 32 | 33 | if (isRescaleOp(term->op)) return; // already processed 34 | 35 | if (!isMultiplicationOp(term->op)) { 36 | // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, 37 | // Op::RotateRightConst copy scale of the first operand 38 | scale[term] = scale[term->operandAt(0)]; 39 | if (isAdditionOp(term->op)) { 40 | // Op::Add, Op::Sub 41 | // assert that all operands have the same scale 42 | for (auto &operand : term->getOperands()) { 43 | if (type[operand] != Type::Raw) { 44 | assert(scale[term] == scale[operand] || type[operand] == Type::Raw); 45 | } 46 | } 47 | } 48 | return; 49 | } 50 | 51 | // Op::Mul only 52 | // ASSUME: only two operands 53 | std::uint32_t multScale = 0; 54 | for (auto &operand : term->getOperands()) { 55 | multScale += scale[operand]; 56 | } 57 | assert(multScale != 0); 58 | scale[term] = multScale; 59 | 60 | // always rescale 61 | insertRescale(term, multScale - minScale); 62 | } 63 | }; 64 | 65 | } // namespace eva 66 | -------------------------------------------------------------------------------- /eva/ckks/ckks_config.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT license. 3 | 4 | #include "eva/ckks/ckks_config.h" 5 | #include "eva/util/logging.h" 6 | #include 7 | 8 | namespace eva { 9 | 10 | CKKSConfig::CKKSConfig( 11 | const std::unordered_map &configMap) { 12 | for (const auto &entry : configMap) { 13 | const auto &option = entry.first; 14 | const auto &valueStr = entry.second; 15 | if (option == "balance_reductions") { 16 | std::istringstream is(valueStr); 17 | is >> std::boolalpha >> balanceReductions; 18 | if (is.bad()) { 19 | warn("Could not parse boolean in balance_reductions=%s. Falling back " 20 | "to default.", 21 | valueStr.c_str()); 22 | } 23 | } else if (option == "rescaler") { 24 | if (valueStr == "lazy_waterline") { 25 | rescaler = CKKSRescaler::LazyWaterline; 26 | } else if (valueStr == "eager_waterline") { 27 | rescaler = CKKSRescaler::EagerWaterline; 28 | } else if (valueStr == "always") { 29 | rescaler = CKKSRescaler::Always; 30 | } else if (valueStr == "minimum") { 31 | rescaler = CKKSRescaler::Minimum; 32 | } else { 33 | // Please update this warning message when adding new options to the 34 | // cases above 35 | warn("Unknown value rescaler=%s. Available rescalers are " 36 | "lazy_waterline, eager_waterline, always, minimum. Falling back " 37 | "to default.", 38 | valueStr.c_str()); 39 | } 40 | } else if (option == "lazy_relinearize") { 41 | std::istringstream is(valueStr); 42 | is >> std::boolalpha >> lazyRelinearize; 43 | if (is.bad()) { 44 | warn("Could not parse boolean in lazy_relinearize=%s. 
Falling back to " 45 | "default.", 46 | valueStr.c_str()); 47 | } 48 | } else if (option == "security_level") { 49 | std::istringstream is(valueStr); 50 | is >> securityLevel; 51 | if (is.bad()) { 52 | throw std::runtime_error( 53 | "Could not parse unsigned int in security_level=" + valueStr); 54 | } 55 | } else if (option == "quantum_safe") { 56 | std::istringstream is(valueStr); 57 | is >> std::boolalpha >> quantumSafe; 58 | if (is.bad()) { 59 | throw std::runtime_error("Could not parse boolean in quantum_safe=" + 60 | valueStr); 61 | } 62 | } else if (option == "warn_vec_size") { 63 | std::istringstream is(valueStr); 64 | is >> std::boolalpha >> warnVecSize; 65 | if (is.bad()) { 66 | warn("Could not parse boolean in warn_vec_size=%s. Falling " 67 | "back to default.", 68 | valueStr.c_str()); 69 | } 70 | } else { 71 | warn("Unknown option %s. Available options are:\n%s", option.c_str(), 72 | OPTIONS_HELP_MESSAGE); 73 | } 74 | } 75 | } 76 | 77 | std::string CKKSConfig::toString(int indent) const { 78 | auto indentStr = std::string(indent, ' '); 79 | std::stringstream s; 80 | s << std::boolalpha; 81 | s << indentStr << "balance_reductions = " << balanceReductions; 82 | s << '\n'; 83 | s << indentStr << "rescaler = "; 84 | switch (rescaler) { 85 | case CKKSRescaler::LazyWaterline: 86 | s << "lazy_waterline"; 87 | break; 88 | case CKKSRescaler::EagerWaterline: 89 | s << "eager_waterline"; 90 | break; 91 | case CKKSRescaler::Always: 92 | s << "always"; 93 | break; 94 | case CKKSRescaler::Minimum: 95 | s << "minimum"; 96 | break; 97 | } 98 | s << '\n'; 99 | s << indentStr << "lazy_relinearize = " << lazyRelinearize; 100 | s << '\n'; 101 | s << indentStr << "security_level = " << securityLevel; 102 | s << '\n'; 103 | s << indentStr << "quantum_safe = " << quantumSafe; 104 | s << '\n'; 105 | s << indentStr << "warn_vec_size = " << warnVecSize; 106 | return s.str(); 107 | } 108 | 109 | } // namespace eva 110 | 
-------------------------------------------------------------------------------- /eva/ckks/ckks_config.h: --------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include <cstdint>
#include <string>
#include <unordered_map>

namespace eva {

// Help text listing every option accepted by the string-map constructor of
// CKKSConfig. Keep it in sync with the option parser.
// clang-format off
const char *const OPTIONS_HELP_MESSAGE =
    "balance_reductions - Balance trees of mul, add or sub operations. bool (default=true)\n"
    "rescaler - Rescaling policy. One of: lazy_waterline (default), eager_waterline, always, minimum\n"
    "lazy_relinearize - Relinearize as late as possible. bool (default=true)\n"
    "security_level - How many bits of security parameters should be selected for. int (default=128)\n"
    "quantum_safe - Select quantum safe parameters. bool (default=false)\n"
    "warn_vec_size - Warn about possibly inefficient vector size selection. bool (default=true)";
// clang-format on

// Rescaling policies selectable with the "rescaler" option.
enum class CKKSRescaler { LazyWaterline, EagerWaterline, Always, Minimum };

// Controls the behavior of CKKSCompiler
class CKKSConfig {
public:
  // All options take the default values given below.
  CKKSConfig() {}

  // Starts from the defaults and applies the options named in `configMap`
  // (keys and values as listed in OPTIONS_HELP_MESSAGE).
  CKKSConfig(const std::unordered_map<std::string, std::string> &configMap);

  // Human-readable "name = value" listing, one option per line, each line
  // prefixed by `indent` spaces.
  std::string toString(int indent = 0) const;

  bool balanceReductions = true;
  CKKSRescaler rescaler = CKKSRescaler::LazyWaterline;
  bool lazyRelinearize = true;
  uint32_t securityLevel = 128;
  bool quantumSafe = false;

  // Warnings
  bool warnVecSize = true;
};

} // namespace eva

-------------------------------------------------------------------------------- /eva/ckks/ckks_parameters.h: --------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
3 | 4 | #pragma once 5 | 6 | #include "eva/serialization/ckks.pb.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace eva { 13 | 14 | struct CKKSParameters { 15 | std::vector primeBits; // in log-scale 16 | std::set rotations; 17 | std::uint32_t polyModulusDegree; 18 | }; 19 | 20 | std::unique_ptr serialize(const CKKSParameters &); 21 | std::unique_ptr deserialize(const msg::CKKSParameters &); 22 | 23 | } // namespace eva 24 | -------------------------------------------------------------------------------- /eva/ckks/ckks_signature.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/types.h" 7 | #include "eva/serialization/ckks.pb.h" 8 | #include 9 | #include 10 | #include 11 | 12 | namespace eva { 13 | 14 | // TODO: make these structs immutable 15 | 16 | struct CKKSEncodingInfo { 17 | Type inputType; 18 | int scale; 19 | int level; 20 | 21 | CKKSEncodingInfo(Type inputType, int scale, int level) 22 | : inputType(inputType), scale(scale), level(level) {} 23 | }; 24 | 25 | struct CKKSSignature { 26 | int vecSize; 27 | std::unordered_map inputs; 28 | 29 | CKKSSignature(int vecSize, 30 | std::unordered_map inputs) 31 | : vecSize(vecSize), inputs(inputs) {} 32 | }; 33 | 34 | std::unique_ptr serialize(const CKKSSignature &); 35 | std::unique_ptr deserialize(const msg::CKKSSignature &); 36 | 37 | } // namespace eva 38 | -------------------------------------------------------------------------------- /eva/ckks/eager_relinearizer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class EagerRelinearizer { 12 | Program &program; 13 | TermMap &type; 14 | TermMapOptional &scale; 15 | 16 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } 17 | 18 | bool isUnencryptedType(const Type &type) { return type != Type::Cipher; } 19 | 20 | bool areAllOperandsEncrypted(Term::Ptr &term) { 21 | for (auto &op : term->getOperands()) { 22 | if (isUnencryptedType(type[op])) { 23 | return false; 24 | } 25 | } 26 | return true; 27 | } 28 | 29 | public: 30 | EagerRelinearizer(Program &g, TermMap &type, 31 | TermMapOptional &scale) 32 | : program(g), type(type), scale(scale) {} 33 | 34 | void 35 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 36 | auto &operands = term->getOperands(); 37 | if (operands.size() == 0) return; // inputs 38 | 39 | auto op = term->op; 40 | 41 | if (!isMultiplicationOp(op)) return; 42 | 43 | // Op::Multiply only 44 | // ASSUME: only two operands 45 | bool encryptedOps = areAllOperandsEncrypted(term); 46 | if (!encryptedOps) return; 47 | 48 | auto relinNode = program.makeTerm(Op::Relinearize, {term}); 49 | type[relinNode] = type[term]; 50 | scale[relinNode] = scale[term]; 51 | 52 | term->replaceOtherUsesWith(relinNode); 53 | } 54 | }; 55 | 56 | } // namespace eva 57 | -------------------------------------------------------------------------------- /eva/ckks/eager_waterline_rescaler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 

#pragma once

#include "eva/ckks/rescaler.h"
#include "eva/util/logging.h"

namespace eva {

// Forward-pass rescaler using a fixed 60-bit rescale and a "waterline":
// after each multiplication, the product is rescaled by 60 bits repeatedly
// for as long as its scale stays at or above minScale + 60. Additions and
// subtractions first bring all non-raw operands up to the largest operand
// scale by multiplying them with a uniform constant 1 encoded at the missing
// scale.
class EagerWaterlineRescaler : public Rescaler {
  // The waterline: the largest scale among the program's sources.
  std::uint32_t minScale;
  // Number of bits removed by each inserted rescale.
  const std::uint32_t fixedRescale = 60;

public:
  EagerWaterlineRescaler(Program &g, TermMap<Type> &type,
                         TermMapOptional<std::uint32_t> &scale)
      : Rescaler(g, type, scale) {
    // ASSUME: minScale is max among all the inputs' scale
    minScale = 0;
    for (auto &source : program.getSources()) {
      if (scale[source] > minScale) minScale = scale[source];
    }
    assert(minScale != 0);
  }

  void
  operator()(Term::Ptr &term) { // must only be used with forward pass traversal
    if (term->numOperands() == 0) return; // inputs
    if (type[term] == Type::Raw) {
      // Raw (unencrypted, unencoded) terms just take the max operand scale.
      handleRawScale(term);
      return;
    }

    if (isRescaleOp(term->op)) return; // already processed

    if (!isMultiplicationOp(term->op)) {
      // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst,
      // Op::RotateRightConst copy scale of the first operand
      scale[term] = scale[term->operandAt(0)];
      if (isAdditionOp(term->op)) {
        // Op::Add, Op::Sub
        auto maxScale = scale[term];
        for (auto &operand : term->getOperands()) {
          // Here we allow raw operands to possibly raise the scale
          if (scale[operand] > maxScale) maxScale = scale[operand];
        }
        for (auto &operand : term->getOperands()) {
          if (scale[operand] < maxScale && type[operand] != Type::Raw) {
            log(Verbosity::Trace,
                "Scaling up t%i from scale %i to match other addition operands "
                "at scale %i",
                operand->index, scale[operand], maxScale);

            // Multiply by 1 encoded at the missing scale difference so the
            // product lands exactly on maxScale.
            auto scaleConstant = program.makeUniformConstant(1);
            scale[scaleConstant] = maxScale - scale[operand];
            scaleConstant->set<EncodeAtScaleAttribute>(scale[scaleConstant]);

            // NOTE(review): type[] is not assigned for scaleConstant or
            // mulNode here; presumably a later pass fills it in - confirm.
            auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant});
            scale[mulNode] = maxScale;

            // TODO: Not obviously correct as it's modifying inside
            // iteration. Refine API to make this less surprising.
            term->replaceOperand(operand, mulNode);
          }
        }
        // assert that all operands have the same scale
        for (auto &operand : term->getOperands()) {
          assert(maxScale == scale[operand] || type[operand] == Type::Raw);
        }
        scale[term] = maxScale;
      }
      return;
    }

    // Op::Mul only
    // ASSUME: only two operands
    // The product's scale is the sum of the operand scales.
    std::uint32_t multScale = 0;
    for (auto &operand : term->getOperands()) {
      multScale += scale[operand];
    }
    assert(multScale != 0);
    scale[term] = multScale;

    // rescale only if above the waterline
    // Each insertRescale() chains a new Rescale after `temp` and redirects
    // temp's other users to it.
    auto temp = term;
    while (multScale >= (fixedRescale + minScale)) {
      temp = insertRescale(temp, fixedRescale);
      multScale -= fixedRescale;
      assert(multScale == scale[temp]);
    }
  }
};

} // namespace eva

-------------------------------------------------------------------------------- /eva/ckks/encode_inserter.h: --------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class EncodeInserter { 12 | Program &program; 13 | TermMap &type; 14 | TermMapOptional &scale; 15 | 16 | bool isRawType(const Type &type) { return type == Type::Raw; } 17 | bool isCipherType(const Type &type) { return type == Type::Cipher; } 18 | bool isAdditionOp(const Op &op_code) { 19 | return ((op_code == Op::Add) || (op_code == Op::Sub)); 20 | } 21 | 22 | auto insertEncodeNode(Op op, const Term::Ptr &other, const Term::Ptr &term) { 23 | auto newNode = program.makeTerm(Op::Encode, {term}); 24 | type[newNode] = Type::Plain; 25 | if (isAdditionOp(op)) { 26 | scale[newNode] = scale[other]; 27 | } else { 28 | scale[newNode] = scale[term]; 29 | } 30 | newNode->set(scale[newNode]); 31 | return newNode; 32 | } 33 | 34 | public: 35 | EncodeInserter(Program &g, TermMap &type, 36 | TermMapOptional &scale) 37 | : program(g), type(type), scale(scale) {} 38 | 39 | void 40 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 41 | auto &operands = term->getOperands(); 42 | if (operands.size() == 0) return; // inputs 43 | 44 | assert(operands.size() <= 2); 45 | if (operands.size() == 2) { 46 | auto &leftOperand = operands[0]; 47 | auto &rightOperand = operands[1]; 48 | auto op1 = leftOperand->op; 49 | if (isCipherType(type[leftOperand]) && isRawType(type[rightOperand])) { 50 | auto newTerm = insertEncodeNode(term->op, leftOperand, rightOperand); 51 | term->replaceOperand(rightOperand, newTerm); 52 | } 53 | 54 | if (isCipherType(type[rightOperand]) && isRawType(type[leftOperand])) { 55 | auto newTerm = insertEncodeNode(term->op, rightOperand, leftOperand); 56 | term->replaceOperand(leftOperand, newTerm); 57 | } 58 | } 59 | } 60 | }; 61 | 62 | } // namespace eva 63 | -------------------------------------------------------------------------------- /eva/ckks/encryption_parameter_selector.h: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace eva { 14 | 15 | class EncryptionParametersSelector { 16 | public: 17 | EncryptionParametersSelector(Program &g, 18 | TermMapOptional &scales, 19 | TermMap &types) 20 | : program_(g), scales_(scales), terms_(g), types(types) {} 21 | 22 | void operator()(const Term::Ptr &term) { 23 | // This function computes, for each term, the set of coeff_modulus primes 24 | // needed to reach that term, taking only into account rescalings. Primes 25 | // needed to hold the output values are not included. For example, input 26 | // terms require no extra primes, so for input terms this function will 27 | // assign an empty set of primes. The example below shows parameters 28 | // assigned in a simple example computation, where we rescale by 40 bits: 29 | // 30 | // In_1:{} In_2:{} In_3:{} 31 | // \ \ / 32 | // \ \ / 33 | // \ * MULTIPLY:{} 34 | // \ | 35 | // \ | 36 | // \ * RESCALE:{40} 37 | // \ | 38 | // \ | 39 | // -----* ADD:{40} 40 | // | 41 | // | 42 | // Out_1:{40} 43 | 44 | // This function must only be used with forward pass traversal, as it 45 | // expects operand terms to have been processed already. 
46 | if (types[term] == Type::Raw || term->op == Op::Encode) { 47 | return; 48 | } 49 | auto &operands = term->getOperands(); 50 | 51 | // Nothing to do for inputs 52 | if (operands.size() > 0) { 53 | // Get the parameters for this term 54 | auto &parms = terms_[term]; 55 | 56 | for (auto &operand : operands) { 57 | // Get the parameters for each operand (forward pass) 58 | auto &operandParms = terms_[operand]; 59 | 60 | // Set the parameters for this term to be the maximum over operands 61 | if (operandParms.size() > parms.size()) { 62 | parms = operandParms; 63 | } 64 | } 65 | 66 | // Adjust the parameters if this term is a rescale operation 67 | // NOTE: This is ignoring modulus switches, but still works because there 68 | // is always a longest path with no modulus switches. 69 | // TODO: Validate this claim and generalize to include modulus switches. 70 | if (isRescaleOp(term->op)) { 71 | auto newSize = parms.size() + 1; 72 | 73 | // By how much are we rescaling? 74 | auto divisor = term->get(); 75 | assert(divisor != 0); 76 | 77 | // Add the required scaling factor to the parameters 78 | parms.push_back(divisor); 79 | assert(parms.size() == newSize); 80 | } 81 | } 82 | } 83 | 84 | inline void free(const Term::Ptr &term) { terms_[term].clear(); } 85 | 86 | auto getEncryptionParameters() { 87 | // This function returns the encryption parameters (really just a list of 88 | // prime bit counts for the coeff_modulus) needed for this computation. It 89 | // can be called after forward pass traversal has computed the rescaling 90 | // primes for all terms. 91 | // 92 | // The logic is simple: we loop over each output term as those have the 93 | // largest (largest number of primes) parameter sets after forward 94 | // traversal, and find the largest parameter set among those. This set will 95 | // work globally for the computation. 
Since the parameters are not taking 96 | // into account the need for storing the result for the output terms, we 97 | // need to add one or more additional primes to the parameters, depending on 98 | // the scales and the ranges of the terms. For example, if the output term 99 | // has a parameter set {40} after forward traversal, with a scale and range 100 | // of 40 and 16 bits, respectively, the result requires an additional 56-bit 101 | // prime in the parameter set. This prime is always added in the set before 102 | // the rescaling primes, so in this case the function would return {56,40}. 103 | // If the scale and range are very large, this function will add more than 104 | // one extra prime. 105 | 106 | std::vector parms; 107 | 108 | // The size in bits needed for the output value; this includes the scale and 109 | // the range 110 | std::uint32_t maxOutputSize = 0; 111 | 112 | // The bit count of the largest prime appearing in the parameters 113 | std::uint32_t maxParm = 0; 114 | 115 | // The largest (largest number of primes) set of parameters required among 116 | // all output terms 117 | std::uint32_t maxLen = 0; 118 | 119 | // Loop over each output term 120 | for (auto &entry : program_.getOutputs()) { 121 | auto &output = entry.second; 122 | 123 | // The size for this output term equals the range attribute (bits) plus 124 | // the scale (bits) 125 | auto size = output->get(); 126 | size += scales_[output]; 127 | 128 | // Update maxOutputSize 129 | if (size > maxOutputSize) maxOutputSize = size; 130 | 131 | // Get the parameters for the current output term 132 | auto &oParms = terms_[output]; 133 | 134 | // Update maxLen (number of primes) 135 | if (maxLen < oParms.size()) maxLen = oParms.size(); 136 | 137 | // Update maxParm (largest prime) 138 | for (auto &parm : oParms) { 139 | if (parm > maxParm) maxParm = parm; 140 | } 141 | } 142 | 143 | // Ensure that the output size is non-zero 144 | assert(maxOutputSize != 0); 145 | 146 | if (maxOutputSize > 
60) { 147 | // If the required output size is larger than 60 bits, we need to increase 148 | // the parameters with more than one additional primes. 149 | 150 | // In this case maxPrime is always 60 bits 151 | maxParm = 60; 152 | 153 | // Add 60-bit primes for as long as needed 154 | while (maxOutputSize >= 60) { 155 | parms.push_back(60); 156 | maxOutputSize -= 60; 157 | } 158 | 159 | // Add one more prime if needed 160 | if (maxOutputSize > 0) { 161 | // TODO: The minimum should probably depend on poly_modulus_degree 162 | parms.push_back(std::max(20u, maxOutputSize)); 163 | } 164 | } else { 165 | // The output size is less than 60 bits so the output parameters require 166 | // only one additional prime. 167 | 168 | // Update maxParm 169 | if (maxOutputSize > maxParm) maxParm = maxOutputSize; 170 | 171 | // Add the required prime to the parameters for this term 172 | parms.push_back(maxParm); 173 | } 174 | 175 | // Finally, loop over all output terms and add the largest parameter set to 176 | // parms after what was pushed above. 177 | for (auto &entry : program_.getOutputs()) { 178 | auto &output = entry.second; 179 | 180 | // Get the parameters for the current output term 181 | auto &oParms = terms_[output]; 182 | 183 | // If this output node has the longest parameter set, use it 184 | if (maxLen == oParms.size()) { 185 | parms.insert(parms.end(), oParms.rbegin(), oParms.rend()); 186 | 187 | // Exit the for loop; we have our parameter set 188 | break; 189 | } 190 | } 191 | 192 | // Add maxParm to result parameters; this is the "key prime". 193 | // TODO: This might be too aggressive. We can try smaller primes here as 194 | // well, which in some cases is advantageous as it may result in smaller 195 | // poly_modulus_degree, even though the noise growth may be a bit larger. 
196 | parms.push_back(maxParm); 197 | 198 | return parms; 199 | } 200 | 201 | private: 202 | Program &program_; 203 | TermMapOptional &scales_; 204 | TermMap> terms_; 205 | TermMap &types; 206 | 207 | inline bool isRescaleOp(const Op &op_code) { return op_code == Op::Rescale; } 208 | }; 209 | 210 | } // namespace eva 211 | -------------------------------------------------------------------------------- /eva/ckks/lazy_relinearizer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class LazyRelinearizer { 12 | Program &program; 13 | TermMap &type; 14 | TermMapOptional &scale; 15 | TermMap pending; // maintains whether relinearization is pending 16 | std::uint32_t count; 17 | std::uint32_t countTotal; 18 | 19 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } 20 | 21 | bool isRotationOp(const Op &op_code) { 22 | return ((op_code == Op::RotateLeftConst) || 23 | (op_code == Op::RotateRightConst)); 24 | } 25 | 26 | bool isUnencryptedType(const Type &type) { return type != Type::Cipher; } 27 | 28 | bool areAllOperandsEncrypted(Term::Ptr &term) { 29 | for (auto &op : term->getOperands()) { 30 | assert(type[op] != Type::Undef); 31 | if (isUnencryptedType(type[op])) { 32 | return false; 33 | } 34 | } 35 | return true; 36 | } 37 | 38 | bool isEncryptedMultOp(Term::Ptr &term) { 39 | return (isMultiplicationOp(term->op) && areAllOperandsEncrypted(term)); 40 | } 41 | 42 | public: 43 | LazyRelinearizer(Program &g, TermMap &type, 44 | TermMapOptional &scale) 45 | : program(g), type(type), scale(scale), pending(g) { 46 | count = 0; 47 | countTotal = 0; 48 | } 49 | 50 | ~LazyRelinearizer() { 51 | // TODO: move these to a logging system 52 | // std::cout << "Number of delayed relin: " << count << 
"\n"; 53 | // std::cout << "Number of relin: " << countTotal << "\n"; 54 | } 55 | 56 | void 57 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 58 | auto &operands = term->getOperands(); 59 | if (operands.size() == 0) return; // inputs 60 | 61 | bool delayed = false; 62 | 63 | if (isEncryptedMultOp(term)) { 64 | assert(pending[term] == false); 65 | pending[term] = true; 66 | delayed = true; 67 | } else if (pending[term] == false) { 68 | return; 69 | } 70 | 71 | bool mustInsert = false; 72 | assert(term->numUses() > 0); 73 | auto firstUse = term->getUses()[0]; 74 | for (auto &use : term->getUses()) { 75 | if (isEncryptedMultOp(use) || isRotationOp(use->op) || 76 | use->op == Op::Output || (firstUse != use)) { // different uses 77 | mustInsert = true; 78 | break; 79 | } 80 | } 81 | 82 | if (mustInsert) { 83 | auto relinNode = program.makeTerm(Op::Relinearize, {term}); 84 | ++countTotal; 85 | 86 | type[relinNode] = type[term]; 87 | scale[relinNode] = scale[term]; 88 | term->replaceOtherUsesWith(relinNode); 89 | } else { 90 | if (delayed) ++count; 91 | for (auto &use : term->getUses()) { 92 | pending[use] = true; 93 | } 94 | } 95 | } 96 | }; 97 | 98 | } // namespace eva 99 | -------------------------------------------------------------------------------- /eva/ckks/levels_checker.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | #include 10 | #include 11 | 12 | namespace eva { 13 | 14 | class LevelsChecker { 15 | public: 16 | LevelsChecker(Program &g, TermMap &types) 17 | : program_(g), types_(types), levels_(g) {} 18 | 19 | void operator()(const Term::Ptr &term) { 20 | // This function verifies that the levels are compatibile. 
It assumes the 21 | // operand terms are processed already, so it must only be used with forward 22 | // pass traversal. 23 | 24 | if (term->numOperands() == 0) { 25 | // If this is a source node, get the encoding level 26 | levels_[term] = term->get(); 27 | } else { 28 | // For other terms, the operands must all have matching level. First find 29 | // the level of any of the ciphertext operands. 30 | std::size_t operandLevel; 31 | for (auto &operand : term->getOperands()) { 32 | if (types_[operand] == Type::Cipher) { 33 | operandLevel = levels_[operand]; 34 | break; 35 | } 36 | } 37 | 38 | // Next verify that all operands have the same level. 39 | for (auto &operand : term->getOperands()) { 40 | if (types_[operand] == Type::Cipher) { 41 | auto operandLevel2 = levels_[operand]; 42 | assert(operandLevel == operandLevel2); 43 | } 44 | } 45 | 46 | // Incremenet the level for a rescale or modulus switch 47 | std::size_t level = operandLevel; 48 | if (isRescaleOp(term->op) || isModSwitchOp(term->op)) { 49 | ++level; 50 | } 51 | levels_[term] = level; 52 | } 53 | } 54 | 55 | void free(const Term::Ptr &term) { 56 | // No-op 57 | } 58 | 59 | private: 60 | Program &program_; 61 | TermMap &types_; 62 | 63 | // Maintains the reverse level (leaves have 0, roots have max) 64 | TermMap levels_; 65 | 66 | bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); } 67 | 68 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } 69 | }; 70 | 71 | } // namespace eva 72 | -------------------------------------------------------------------------------- /eva/ckks/minimum_rescaler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ckks/rescaler.h" 7 | #include "eva/util/logging.h" 8 | 9 | namespace eva { 10 | 11 | class MinimumRescaler : public Rescaler { 12 | std::uint32_t minScale; 13 | const std::uint32_t maxRescale = 60; 14 | 15 | public: 16 | MinimumRescaler(Program &g, TermMap &type, 17 | TermMapOptional &scale) 18 | : Rescaler(g, type, scale) { 19 | // ASSUME: minScale is max among all the inputs' scale 20 | minScale = 0; 21 | for (auto &source : program.getSources()) { 22 | if (scale[source] > minScale) minScale = scale[source]; 23 | } 24 | assert(minScale != 0); 25 | } 26 | 27 | void 28 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 29 | auto &operands = term->getOperands(); 30 | if (operands.size() == 0) return; // inputs 31 | if (type[term] == Type::Raw) { 32 | handleRawScale(term); 33 | return; 34 | } 35 | 36 | auto op = term->op; 37 | 38 | if (isRescaleOp(op)) return; // already processed 39 | 40 | if (!isMultiplicationOp(op)) { 41 | // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst, 42 | // Op::RotateRightConst copy scale of the first operand 43 | for (auto &operand : operands) { 44 | assert(operand->op != Op::Constant); 45 | assert(scale[operand] != 0); 46 | scale[term] = scale[operand]; 47 | break; 48 | } 49 | if (isAdditionOp(op)) { 50 | // Op::Add, Op::Sub 51 | auto maxScale = scale[term]; 52 | for (auto &operand : operands) { 53 | // Here we allow raw operands to possibly raise the scale 54 | if (scale[operand] > maxScale) maxScale = scale[operand]; 55 | } 56 | for (auto &operand : operands) { 57 | if (scale[operand] < maxScale && type[operand] != Type::Raw) { 58 | log(Verbosity::Trace, 59 | "Scaling up t%i from scale %i to match other addition operands " 60 | "at scale %i", 61 | operand->index, scale[operand], maxScale); 62 | 63 | auto scaleConstant = program.makeUniformConstant(1); 64 | scale[scaleConstant] = maxScale - scale[operand]; 65 | scaleConstant->set(scale[scaleConstant]); 66 | 67 
| auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); 68 | scale[mulNode] = maxScale; 69 | 70 | // TODO: Not obviously correct as it's modifying inside 71 | // iteration. 72 | // Refine API to make this less surprising. 73 | term->replaceOperand(operand, mulNode); 74 | } 75 | } 76 | // assert that all operands have the same scale 77 | for (auto &operand : operands) { 78 | assert(maxScale == scale[operand] || type[operand] == Type::Raw); 79 | } 80 | scale[term] = maxScale; 81 | } 82 | return; 83 | } 84 | 85 | // Op::Multiply only 86 | // ASSUME: only two operands 87 | std::vector operandsCopy; 88 | for (auto &operand : operands) { 89 | operandsCopy.push_back(operand); 90 | } 91 | assert(operandsCopy.size() == 2); 92 | std::uint32_t multScale = scale[operandsCopy[0]] + scale[operandsCopy[1]]; 93 | assert(multScale != 0); 94 | scale[term] = multScale; 95 | 96 | auto minOfScales = scale[operandsCopy[0]]; 97 | if (minOfScales > scale[operandsCopy[1]]) 98 | minOfScales = scale[operandsCopy[1]]; 99 | auto rescaleBy = minOfScales - minScale; 100 | if (rescaleBy > maxRescale) rescaleBy = maxRescale; 101 | if ((2 * rescaleBy) >= maxRescale) { 102 | // rescale after multiplication is inevitable 103 | // to reduce the growth of scale, rescale both operands before 104 | // multiplication 105 | assert(rescaleBy <= maxRescale); 106 | insertRescaleBetween(operandsCopy[0], term, rescaleBy); 107 | if (operandsCopy[0] != operandsCopy[1]) { 108 | insertRescaleBetween(operandsCopy[1], term, rescaleBy); 109 | } 110 | 111 | scale[term] = multScale - (2 * rescaleBy); 112 | } else { 113 | // rescale only if above the waterline 114 | auto temp = term; 115 | while (multScale >= (maxRescale + minScale)) { 116 | temp = insertRescale(temp, maxRescale); 117 | multScale -= maxRescale; 118 | assert(multScale == scale[temp]); 119 | } 120 | } 121 | } 122 | }; 123 | 124 | } // namespace eva 125 | -------------------------------------------------------------------------------- 
/eva/ckks/mod_switcher.h: --------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include "eva/ir/program.h"
#include "eva/ir/term_map.h"

namespace eva {

// Backward-pass transformation that inserts ModSwitch terms so that every use
// of a term sees it at the level that use expects. Levels are counted in
// reverse (terms closest to the outputs have 0). When the pass is destroyed,
// the final levels are written onto sources and Encode terms as
// EncodeAtLevelAttribute.
class ModSwitcher {
  Program &program;
  TermMap<Type> &type;
  TermMapOptional<std::uint32_t> &scale;
  TermMap<std::uint32_t>
      level; // maintains the reverse level (leaves have 0, roots have max)
  std::vector<Term::Ptr> encodeNodes;

  // Creates ModSwitch(term) carrying term's scale and the given reverse level.
  // NOTE(review): type[newNode] is not assigned here; confirm a later pass
  // (e.g. type deduction) fills it in before it is read.
  Term::Ptr insertModSwitchNode(Term::Ptr &term, std::uint32_t termLevel) {
    auto newNode = program.makeTerm(Op::ModSwitch, {term});
    scale[newNode] = scale[term];
    level[newNode] = termLevel;
    return newNode;
  }

  bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }

  // Currently unused within this class.
  bool isCipherType(const Term::Ptr &term) const {
    return type[term] == Type::Cipher;
  }

public:
  ModSwitcher(Program &g, TermMap<Type> &type,
              TermMapOptional<std::uint32_t> &scale)
      : program(g), type(type), scale(scale), level(g) {}

  // Side-effecting destructor: converts reverse levels into forward levels
  // (maxLevel - reverse level) and stores them on sources and Encode terms.
  ~ModSwitcher() {
    auto sources = program.getSources();
    std::uint32_t maxLevel = 0;
    for (auto &source : sources) {
      if (level[source] > maxLevel) maxLevel = level[source];
    }
    for (auto &source : sources) {
      auto curLevel = maxLevel - level[source];
      source->set<EncodeAtLevelAttribute>(curLevel);
    }

    for (auto &encode : encodeNodes) {
      encode->set<EncodeAtLevelAttribute>(maxLevel - level[encode]);
    }
  }

  void operator()(
      Term::Ptr &term) { // must only be used with backward pass traversal
    if (term->numUses() == 0) return;

    //we do not want to add modswitch for nodes of type raw
    if (type[term] == Type::Raw) return;

    if (term->op == Op::Encode) {
      encodeNodes.push_back(term);
    }
    // Group the uses of `term` by their reverse level.
    std::map<std::uint32_t, std::vector<Term::Ptr>> useLevels; // ordered map
    for (auto &use : term->getUses()) {
      useLevels[level[use]].push_back(use);
    }

    std::uint32_t termLevel = 0;
    if (useLevels.size() > 1) {
      // The term itself sits at the highest use level; uses at lower levels
      // are fed through a chain of ModSwitch terms, one per level step.
      auto useLevel = useLevels.rbegin(); // max to min
      termLevel = useLevel->first;
      ++useLevel;

      auto temp = term;
      auto tempLevel = termLevel;
      while (useLevel != useLevels.rend()) {
        auto expectedLevel = useLevel->first;
        while (tempLevel > expectedLevel) {
          temp = insertModSwitchNode(temp, tempLevel);
          --tempLevel;
        }
        for (auto &use : useLevel->second) {
          use->replaceOperand(term, temp);
        }
        ++useLevel;
      }
    } else {
      assert(useLevels.size() == 1);
      termLevel = useLevels.begin()->first;
    }
    // A rescale consumes one level, so its input sits one level higher.
    if (isRescaleOp(term->op)) {
      ++termLevel;
    }
    level[term] = termLevel;
  }
};

} // namespace eva

-------------------------------------------------------------------------------- /eva/ckks/parameter_checker.h: --------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace eva { 14 | 15 | class InconsistentParameters : public std::runtime_error { 16 | public: 17 | InconsistentParameters(const std::string &msg) : std::runtime_error(msg) {} 18 | }; 19 | 20 | class ParameterChecker { 21 | TermMap &types; 22 | 23 | public: 24 | ParameterChecker(Program &g, TermMap &types) 25 | : program_(g), parms_(g), types(types) {} 26 | 27 | void operator()(const Term::Ptr &term) { 28 | // Must only be used with forward pass traversal 29 | auto &operands = term->getOperands(); 30 | if (types[term] == Type::Raw || term->op == Op::Encode) { 31 | return; 32 | } 33 | if (operands.size() > 0) { 34 | // Get the parameters for this term 35 | auto &parms = parms_[term]; 36 | // Loop over operands 37 | for (auto &operand : operands) { 38 | // Get the parameters for the operand 39 | auto &operandParms = parms_[operand]; 40 | 41 | // Nothing to do if the operand parameters are empty; the operand sets 42 | // no requirements on this node 43 | if (operandParms.size() > 0) { 44 | if (parms.size() > 0) { 45 | // If the parameters for this term are already set (from a different 46 | // operand), they must match the current operand's parameters 47 | if (operandParms.size() != parms.size()) { 48 | throw InconsistentParameters( 49 | "Two operands require different number of primes"); 50 | } 51 | 52 | // Loop over the primes in the parameters for this term 53 | for (std::size_t i = 0; i < parms.size(); ++i) { 54 | if (parms[i] == 0) { 55 | // If any of the primes is zero (indicating a previous modulus 56 | // switch operand term, fill in its true value from the current 57 | // operand 58 | parms[i] = operandParms[i]; 59 | } else if (operandParms[i] != 0) { 60 | // If the operand prime is non-zero, require equality 61 | if (parms[i] != operandParms[i]) { 62 | throw InconsistentParameters( 63 | "Primes required 
by two operands do not match"); 64 | } 65 | } 66 | } 67 | } else { 68 | // This is the first operand to impose conditions on this term; 69 | // copy the parameters from the operand 70 | parms = operandParms; 71 | } 72 | } 73 | } 74 | 75 | if (isModSwitchOp(term->op)) { 76 | // Is this a modulus switch? If so, add an extra (placeholder) zero 77 | parms.push_back(0); 78 | } else if (isRescaleOp(term->op)) { 79 | // Is this a rescale? Then add a prime of the requested size 80 | auto divisor = term->get(); 81 | assert(divisor != 0); 82 | parms.push_back(divisor); 83 | } 84 | } else { 85 | // Get the parameters for this term 86 | auto &parms = parms_[term]; 87 | std::uint32_t level = term->get(); 88 | while (level > 0) { 89 | parms.push_back(0); 90 | level--; 91 | } 92 | } 93 | } 94 | 95 | void free(const Term::Ptr &term) { parms_[term].clear(); } 96 | 97 | private: 98 | Program &program_; 99 | TermMap> parms_; 100 | 101 | bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); } 102 | 103 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } 104 | }; 105 | 106 | } // namespace eva 107 | -------------------------------------------------------------------------------- /eva/ckks/rescaler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class Rescaler { 12 | protected: 13 | Program &program; 14 | TermMap &type; 15 | TermMapOptional &scale; 16 | 17 | Rescaler(Program &g, TermMap &type, 18 | TermMapOptional &scale) 19 | : program(g), type(type), scale(scale) {} 20 | 21 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } 22 | 23 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } 24 | 25 | bool isAdditionOp(const Op &op_code) { 26 | return ((op_code == Op::Add) || (op_code == Op::Sub)); 27 | } 28 | 29 | auto insertRescale(Term::Ptr term, std::uint32_t rescaleBy) { 30 | // auto scale = term->getScale(); 31 | auto rescaleNode = program.makeRescale(term, rescaleBy); 32 | type[rescaleNode] = type[term]; 33 | scale[rescaleNode] = scale[term] - rescaleBy; 34 | 35 | term->replaceOtherUsesWith(rescaleNode); 36 | 37 | return rescaleNode; 38 | } 39 | 40 | void insertRescaleBetween(Term::Ptr term1, Term::Ptr term2, 41 | std::uint32_t rescaleBy) { 42 | auto rescaleNode = program.makeRescale(term1, rescaleBy); 43 | type[rescaleNode] = type[term1]; 44 | scale[rescaleNode] = scale[term1] - rescaleBy; 45 | 46 | term2->replaceOperand(term1, rescaleNode); 47 | } 48 | 49 | void handleRawScale(Term::Ptr term) { 50 | if (term->numOperands() > 0) { 51 | int maxScale = 0; 52 | for (auto &operand : term->getOperands()) { 53 | if (scale.at(operand) > maxScale) maxScale = scale.at(operand); 54 | } 55 | scale[term] = maxScale; 56 | } 57 | } 58 | }; 59 | 60 | } // namespace eva 61 | -------------------------------------------------------------------------------- /eva/ckks/scales_checker.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | #include 10 | #include 11 | 12 | namespace eva { 13 | 14 | class ScalesChecker { 15 | public: 16 | ScalesChecker(Program &g, TermMapOptional &scales, 17 | TermMap &types) 18 | : program_(g), scales_(g), types_(types) {} 19 | 20 | void operator()(const Term::Ptr &term) { 21 | // Must only be used with forward pass traversal 22 | if (types_[term] == Type::Raw) { 23 | return; 24 | } 25 | auto &operands = term->getOperands(); 26 | 27 | // Nothing to do for source terms 28 | if (term->op == Op::Input || term->op == Op::Encode) { 29 | scales_[term] = term->get(); 30 | if (scales_.at(term) == 0) { 31 | if (term->op == Op::Input) { 32 | throw std::runtime_error("Program has an input with 0 scale"); 33 | } else { 34 | throw std::logic_error("Compiled program results in a 0 scale term"); 35 | } 36 | } 37 | } else if (term->op == Op::Mul) { 38 | assert(term->numOperands() == 2); 39 | std::uint32_t scale = 0; 40 | for (auto &operand : operands) { 41 | scale += scales_.at(operand); 42 | } 43 | if (scale == 0) { 44 | throw std::logic_error("Compiled program results in a 0 scale term"); 45 | } 46 | scales_[term] = scale; 47 | } else if (term->op == Op::Rescale) { 48 | assert(term->numOperands() == 1); 49 | auto divisor = term->get(); 50 | auto operandScale = scales_.at(term->operandAt(0)); 51 | std::uint32_t scale = operandScale - divisor; 52 | if (scale == 0) { 53 | throw std::logic_error("Compiled program results in a 0 scale term"); 54 | } 55 | scales_[term] = scale; 56 | 57 | } else if (isAdditionOp(term->op)) { 58 | std::uint32_t scale = 0; 59 | for (auto &operand : operands) { 60 | if (scale == 0) { 61 | scale = scales_.at(operand); 62 | } else { 63 | if (scale != scales_.at(operand)) { 64 | throw std::logic_error("Addition or subtraction in program has " 65 | "operands of non-equal scale"); 66 | } 67 | } 68 | } 69 | if (scale == 0) { 70 | throw 
std::logic_error("Compiled program results in a 0 scale term"); 71 | } 72 | scales_[term] = scale; 73 | } else { 74 | auto scale = scales_.at(term->operandAt(0)); 75 | if (scale == 0) { 76 | throw std::logic_error("Compiled program results in a 0 scale term"); 77 | } 78 | scales_[term] = scale; 79 | } 80 | } 81 | 82 | void free(const Term::Ptr &term) { 83 | // No-op 84 | } 85 | 86 | private: 87 | Program &program_; 88 | TermMapOptional scales_; 89 | TermMap &types_; 90 | 91 | bool isAdditionOp(const Op &op_code) { 92 | return ((op_code == Op::Add) || (op_code == Op::Sub)); 93 | } 94 | }; 95 | 96 | } // namespace eva 97 | -------------------------------------------------------------------------------- /eva/ckks/seal_lowering.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class SEALLowering { 12 | Program &program; 13 | TermMap &type; 14 | 15 | public: 16 | SEALLowering(Program &g, TermMap &type) : program(g), type(type) {} 17 | 18 | void 19 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 20 | 21 | // SEAL does not support plaintext subtraction with a plaintext on the left 22 | // hand side, so lower to a negation and addition. 
23 | if (term->op == Op::Sub && type[term->operandAt(0)] != Type::Cipher && 24 | type[term->operandAt(1)] == Type::Cipher) { 25 | auto negation = program.makeTerm(Op::Negate, {term->operandAt(1)}); 26 | auto addition = program.makeTerm(Op::Add, {term->operandAt(0), negation}); 27 | term->replaceAllUsesWith(addition); 28 | } 29 | } 30 | }; 31 | 32 | } // namespace eva 33 | -------------------------------------------------------------------------------- /eva/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | target_sources(eva PRIVATE 5 | reference_executor.cpp 6 | ) 7 | -------------------------------------------------------------------------------- /eva/common/constant_folder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class ConstantFolder { 12 | Program &program; 13 | TermMapOptional &scale; 14 | std::vector scratch1, scratch2; 15 | 16 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); } 17 | 18 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); } 19 | 20 | bool isAdditionOp(const Op &op_code) { 21 | return ((op_code == Op::Add) || (op_code == Op::Sub)); 22 | } 23 | 24 | void replaceNodeWithConstant(Term::Ptr term, 25 | const std::vector &output, 26 | double termScale) { 27 | // TODO: optimize output representations 28 | auto constant = program.makeDenseConstant(output); 29 | scale[constant] = termScale; 30 | constant->set(scale[constant]); 31 | 32 | term->replaceAllUsesWith(constant); 33 | assert(term->numUses() == 0); 34 | } 35 | 36 | void add(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) { 37 | auto &input1 = args1->get()->expand( 38 | scratch1, program.getVecSize()); 39 | auto &input2 = args2->get()->expand( 40 | scratch2, program.getVecSize()); 41 | 42 | std::vector outputValue(input1.size()); 43 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) { 44 | outputValue[i] = input1[i] + input2[i]; 45 | } 46 | 47 | replaceNodeWithConstant(output, outputValue, 48 | std::max(scale[args1], scale[args2])); 49 | } 50 | 51 | void sub(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) { 52 | auto &input1 = args1->get()->expand( 53 | scratch1, program.getVecSize()); 54 | auto &input2 = args2->get()->expand( 55 | scratch2, program.getVecSize()); 56 | 57 | std::vector outputValue(input1.size()); 58 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) { 59 | outputValue[i] = input1[i] - input2[i]; 60 | } 61 | 62 | replaceNodeWithConstant(output, outputValue, 63 | std::max(scale[args1], scale[args2])); 64 | } 65 | 66 | void mul(Term::Ptr output, const Term::Ptr &args1, const 
Term::Ptr &args2) { 67 | auto &input1 = args1->get()->expand( 68 | scratch1, program.getVecSize()); 69 | auto &input2 = args2->get()->expand( 70 | scratch2, program.getVecSize()); 71 | 72 | std::vector outputValue(input1.size()); 73 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) { 74 | outputValue[i] = input1[i] * input2[i]; 75 | } 76 | 77 | replaceNodeWithConstant(output, outputValue, 78 | std::max(scale[args1], scale[args2])); 79 | } 80 | 81 | void leftRotate(Term::Ptr output, const Term::Ptr &args1, 82 | std::int32_t shift) { 83 | auto &input1 = args1->get()->expand( 84 | scratch1, program.getVecSize()); 85 | 86 | while (shift > 0 && shift >= input1.size()) 87 | shift -= input1.size(); 88 | while (shift < 0) 89 | shift += input1.size(); 90 | 91 | std::vector outputValue(input1.size()); 92 | for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) { 93 | outputValue[i] = input1[i + shift]; 94 | } 95 | for (std::uint64_t i = 0; i < shift; ++i) { 96 | outputValue[outputValue.size() - shift + i] = input1[i]; 97 | } 98 | 99 | replaceNodeWithConstant(output, outputValue, scale[args1]); 100 | } 101 | 102 | void rightRotate(Term::Ptr output, const Term::Ptr &args1, 103 | std::int32_t shift) { 104 | auto &input1 = args1->get()->expand( 105 | scratch1, program.getVecSize()); 106 | 107 | while (shift > 0 && shift >= input1.size()) 108 | shift -= input1.size(); 109 | while (shift < 0) 110 | shift += input1.size(); 111 | 112 | std::vector outputValue(input1.size()); 113 | for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) { 114 | outputValue[i + shift] = input1[i]; 115 | } 116 | for (std::uint64_t i = 0; i < shift; ++i) { 117 | outputValue[i] = input1[outputValue.size() - shift + i]; 118 | } 119 | 120 | replaceNodeWithConstant(output, outputValue, scale[args1]); 121 | } 122 | 123 | void negate(Term::Ptr output, const Term::Ptr &args1) { 124 | auto &input1 = args1->get()->expand( 125 | scratch1, program.getVecSize()); 126 | 127 | std::vector 
outputValue(input1.size()); 128 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) { 129 | outputValue[i] = -input1[i]; 130 | } 131 | 132 | replaceNodeWithConstant(output, outputValue, scale[args1]); 133 | } 134 | 135 | public: 136 | ConstantFolder(Program &g, TermMapOptional &scale) 137 | : program(g), scale(scale) {} 138 | 139 | void 140 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 141 | auto &args = term->getOperands(); 142 | if (args.size() == 0) return; // inputs 143 | 144 | for (auto &arg : args) { 145 | if (arg->op != Op::Constant) return; 146 | } 147 | 148 | auto op_code = term->op; 149 | switch (op_code) { 150 | case Op::Add: 151 | assert(args.size() == 2); 152 | add(term, args[0], args[1]); 153 | break; 154 | case Op::Sub: 155 | assert(args.size() == 2); 156 | sub(term, args[0], args[1]); 157 | break; 158 | case Op::Mul: 159 | assert(args.size() == 2); 160 | mul(term, args[0], args[1]); 161 | break; 162 | case Op::RotateLeftConst: 163 | assert(args.size() == 1); 164 | leftRotate(term, args[0], term->get()); 165 | break; 166 | case Op::RotateRightConst: 167 | assert(args.size() == 1); 168 | rightRotate(term, args[0], term->get()); 169 | break; 170 | case Op::Negate: 171 | assert(args.size() == 1); 172 | negate(term, args[0]); 173 | break; 174 | case Op::Output: 175 | [[fallthrough]]; 176 | case Op::Encode: 177 | break; 178 | case Op::Relinearize: 179 | [[fallthrough]]; 180 | case Op::ModSwitch: 181 | [[fallthrough]]; 182 | case Op::Rescale: 183 | throw std::logic_error("Encountered HE specific operation " + 184 | getOpName(op_code) + 185 | " in unencrypted computation"); 186 | default: 187 | throw std::logic_error("Unhandled op " + getOpName(op_code)); 188 | } 189 | } 190 | }; 191 | 192 | } // namespace eva 193 | -------------------------------------------------------------------------------- /eva/common/multicore_program_traversal.h: -------------------------------------------------------------------------------- 
1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include "eva/util/galois.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace eva { 19 | 20 | class MulticoreProgramTraversal { 21 | public: 22 | MulticoreProgramTraversal(Program &g) : program_(g) {} 23 | 24 | template void forwardPass(Evaluator &eval) { 25 | TermMap predecessors(program_); 26 | TermMap successors(program_); 27 | 28 | // Add the source terms 29 | galois::InsertBag readyNodes; 30 | for (auto source : program_.getSources()) { 31 | readyNodes.push_back(source); 32 | } 33 | 34 | // Enumerate predecessors and successors 35 | galois::for_each( 36 | galois::iterate(readyNodes), 37 | [&](const Term::Ptr &term, auto &ctx) { 38 | // For each term, iterate over its uses 39 | for (auto &use : term->getUses()) { 40 | // Increment the number of successors 41 | ++successors[term]; 42 | 43 | // Increment the number of predecessors 44 | if ((++predecessors[use]) == 1) { 45 | // Only first predecessor will push so each use is added once 46 | ctx.push_back(use); 47 | } 48 | } 49 | }, 50 | galois::wl>(), 51 | galois::no_stats(), 52 | galois::loopname("ForwardCountPredecessorsSuccessors")); 53 | 54 | // Traverse the program 55 | galois::for_each( 56 | galois::iterate(readyNodes), 57 | [&](const Term::Ptr &term, auto &ctx) { 58 | // Process the current term 59 | eval(term); 60 | 61 | // Free operands if their successors are done 62 | for (auto &operand : term->getOperands()) { 63 | if ((--successors[operand]) == 0) { 64 | // Only last successor will free 65 | eval.free(operand); 66 | } 67 | } 68 | 69 | // Execute (ready) uses if their predecessors are done 70 | for (auto &use : term->getUses()) { 71 | if ((--predecessors[use]) == 0) { 72 | // Only last predecessor will push 73 | 
ctx.push_back(use); 74 | } 75 | } 76 | }, 77 | galois::wl>(), 78 | galois::no_stats(), galois::loopname("ForwardTraversal")); 79 | 80 | // TODO: Reinstate these checks 81 | // for (auto& predecessor : predecessors) assert(predecessor == 0); 82 | // for (auto& successor : successors) assert(successor == 0); 83 | } 84 | 85 | template void backwardPass(Evaluator &eval) { 86 | TermMap predecessors(program_); 87 | TermMap successors(program_); 88 | 89 | // Add the sink terms 90 | galois::InsertBag readyNodes; 91 | for (auto &sink : program_.getSinks()) { 92 | readyNodes.push_back(sink); 93 | } 94 | 95 | // Enumerate predecessors and successors 96 | galois::for_each( 97 | galois::iterate(readyNodes), 98 | [&](const Term::Ptr &term, auto &ctx) { 99 | // For each term, iterate over its operands 100 | for (auto &operand : term->getOperands()) { 101 | // Increment the number of predecessors 102 | ++predecessors[term]; 103 | 104 | // Increment the number of successors for the operand 105 | if ((++successors[operand]) == 1) { 106 | // Only first successor will push so each operand is added once 107 | ctx.push_back(operand); 108 | } 109 | } 110 | }, 111 | galois::wl>(), 112 | galois::no_stats(), 113 | galois::loopname("BackwardCountPredecessorsSuccessors")); 114 | 115 | // Traverse the program 116 | galois::for_each( 117 | galois::iterate(readyNodes), 118 | [&](const Term::Ptr &term, auto &ctx) { 119 | // Process the current term 120 | eval(term); 121 | 122 | // Free uses if their predecessors are done 123 | for (auto &use : term->getUses()) { 124 | if ((--predecessors[use]) == 0) { 125 | // Only last predecessor will free 126 | eval.free(use); 127 | } 128 | } 129 | 130 | // Execute (ready) operands if their successors are done 131 | for (auto &operand : term->getOperands()) { 132 | if ((--successors[operand]) == 0) { 133 | // Only last successor will push 134 | ctx.push_back(operand); 135 | } 136 | } 137 | }, 138 | galois::wl>(), 139 | galois::no_stats(), 
galois::loopname("BackwardTraversal")); 140 | 141 | // TODO: Reinstate these checks 142 | // for (auto& predecessor : predecessors) assert(predecessor == 0); 143 | // for (auto& successor : successors) assert(successor == 0); 144 | } 145 | 146 | private: 147 | Program &program_; 148 | GaloisGuard galoisGuard_; 149 | }; 150 | 151 | } // namespace eva 152 | -------------------------------------------------------------------------------- /eva/common/program_traversal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include "eva/util/logging.h" 9 | #include 10 | #include 11 | 12 | namespace eva { 13 | 14 | /* 15 | Implements efficient forward and backward traversals of Program in the 16 | presence of modifications during traversal. 17 | The rewriter is called for each term in the Program exactly once. 18 | Rewriters must not modify the Program in such a way that terms that are 19 | not uses/operands (for forward/backward traversal, respectively) of the 20 | current term are enabled. With such modifications the whole program is 21 | not guaranteed to be traversed. 22 | */ 23 | class ProgramTraversal { 24 | Program &program; 25 | 26 | TermMap ready; 27 | TermMap processed; 28 | 29 | template bool arePredecessorsDone(const Term::Ptr &term) { 30 | for (auto &operand : isForward ? term->getOperands() : term->getUses()) { 31 | if (!processed[operand]) return false; 32 | } 33 | return true; 34 | } 35 | 36 | template 37 | void traverse(Rewriter &&rewrite) { 38 | processed.clear(); 39 | ready.clear(); 40 | 41 | std::vector readyNodes = 42 | isForward ? program.getSources() : program.getSinks(); 43 | for (auto &term : readyNodes) { 44 | ready[term] = true; 45 | } 46 | // Used for remembering uses/operands before rewrite is called. 
Using a 47 | // vector here is fine because duplicates in the list are handled 48 | // gracefully. 49 | std::vector checkList; 50 | 51 | while (readyNodes.size() != 0) { 52 | // Pop term to transform 53 | auto term = readyNodes.back(); 54 | readyNodes.pop_back(); 55 | 56 | // If this term is removed, we will lose uses/operands of this term. 57 | // Remember them here for checking readyness after the rewrite. 58 | checkList.clear(); 59 | for (auto &succ : isForward ? term->getUses() : term->getOperands()) { 60 | checkList.push_back(succ); 61 | } 62 | 63 | log(Verbosity::Trace, "Processing term with index=%lu", term->index); 64 | rewrite(term); 65 | processed[term] = true; 66 | 67 | // If transform adds new sources/sinks add them to ready terms. 68 | for (auto &leaf : isForward ? program.getSources() : program.getSinks()) { 69 | if (!ready[leaf]) { 70 | readyNodes.push_back(leaf); 71 | ready[leaf] = true; 72 | } 73 | } 74 | 75 | // Also check current uses/operands in case any new ones were added. 76 | for (auto &succ : isForward ? term->getUses() : term->getOperands()) { 77 | checkList.push_back(succ); 78 | } 79 | 80 | // Push and mark uses/operands that are ready to be processed. 81 | for (auto &succ : checkList) { 82 | if (!ready[succ] && arePredecessorsDone(succ)) { 83 | readyNodes.push_back(succ); 84 | ready[succ] = true; 85 | } 86 | } 87 | } 88 | } 89 | 90 | public: 91 | ProgramTraversal(Program &g) : program(g), processed(g), ready(g) {} 92 | 93 | template void forwardPass(Rewriter &&rewrite) { 94 | traverse(std::forward(rewrite)); 95 | } 96 | 97 | template void backwardPass(Rewriter &&rewrite) { 98 | traverse(std::forward(rewrite)); 99 | } 100 | }; 101 | 102 | } // namespace eva 103 | -------------------------------------------------------------------------------- /eva/common/reduction_balancer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | 10 | namespace eva { 11 | 12 | /* 13 | This pass combines nodes to reduce the depth of the tree. 14 | Suppose you have expression tree 15 | 16 | * 17 | / \ 18 | * t5(c) => * 19 | / \ / \ \ 20 | t1(c) t2(c) t1 t2 t5 21 | 22 | 23 | Before combining first it checks if some node have only one use and both 24 | of these nodes have same op then these two nodes are combined into one node 25 | with children of both the nodes. 26 | This pass helps to get the flat form of an expression so that later on it can 27 | be expanded to get a expression in a balanced form. 28 | For example (a * (b * (c * d))) => (a * b * c * d) => (a * b) * (c * d) 29 | */ 30 | class ReductionCombiner { 31 | Program &program; 32 | 33 | bool isReductionOp(const Op &op_code) { 34 | return ((op_code == Op::Add) || (op_code == Op::Mul)); 35 | } 36 | 37 | public: 38 | ReductionCombiner(Program &g) : program(g) {} 39 | 40 | void 41 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 42 | if (!term->isInternal() || !isReductionOp(term->op)) return; 43 | 44 | auto uses = term->getUses(); 45 | if (uses.size() == 1) { 46 | auto &use = uses[0]; 47 | if (use->op == term->op) { 48 | // combine term and its use 49 | while (use->eraseOperand(term)) { 50 | for (auto &operand : term->getOperands()) { 51 | // add term's operands to use's operands 52 | use->addOperand(operand); 53 | } 54 | } 55 | } 56 | } 57 | } 58 | }; 59 | 60 | class ReductionLogExpander { 61 | Program &program; 62 | TermMap &type; 63 | TermMapOptional scale; 64 | std::vector operands, nextOperands; 65 | std::map> sortedOperands; 66 | 67 | bool isReductionOp(const Op &op_code) { 68 | return ((op_code == Op::Add) || (op_code == Op::Mul)); 69 | } 70 | 71 | public: 72 | ReductionLogExpander(Program &g, TermMap &type) 73 | : program(g), type(type), scale(g) {} 74 | 75 | void 
operator()(Term::Ptr &term) { 76 | if (term->op == Op::Rescale || term->op == Op::ModSwitch) { 77 | throw std::logic_error("Rescale or ModSwitch encountered, but " 78 | "ReductionLogExpander uses scale as" 79 | " a proxy for level and assumes rescaling has not " 80 | "been performed yet."); 81 | } 82 | 83 | // Calculate the scales that we would get without any rescaling. Terms at a 84 | // similar scale will likely end up having the same level in typical 85 | // rescaling policies, which helps the sorting group terms of the same level 86 | // together. 87 | if (term->numOperands() == 0) { 88 | scale[term] = term->get(); 89 | } else if (term->op == Op::Mul) { 90 | scale[term] = std::accumulate( 91 | term->getOperands().begin(), term->getOperands().end(), 0, 92 | [&](auto &sum, auto &operand) { return sum + scale.at(operand); }); 93 | } else { 94 | scale[term] = std::accumulate(term->getOperands().begin(), 95 | term->getOperands().end(), 0, 96 | [&](auto &max, auto &operand) { 97 | return std::max(max, scale.at(operand)); 98 | }); 99 | } 100 | 101 | if (isReductionOp(term->op) && term->numOperands() > 2) { 102 | // We sort operands into constants, plaintext and raw, then ciphertexts by 103 | // scale. This helps avoid unnecessary accumulation of scale. 104 | for (auto &operand : term->getOperands()) { 105 | auto order = 0; 106 | if (type[operand] == Type::Plain || type[operand] == Type::Raw) { 107 | order = 1; 108 | } else if (type[operand] == Type::Cipher) { 109 | order = 2 + scale.at(operand); 110 | } 111 | sortedOperands[order].push_back(operand); 112 | } 113 | for (auto &op : sortedOperands) { 114 | operands.insert(operands.end(), op.second.begin(), op.second.end()); 115 | } 116 | 117 | // Expand the sorted operands into a balanced reduction tree by pairing 118 | // adjacent operands until only one remains. 
119 | assert(operands.size() >= 2); 120 | while (operands.size() > 2) { 121 | std::size_t i = 0; 122 | while ((i + 1) < operands.size()) { 123 | auto &leftOperand = operands[i]; 124 | auto &rightOperand = operands[i + 1]; 125 | auto newTerm = 126 | program.makeTerm(term->op, {leftOperand, rightOperand}); 127 | nextOperands.push_back(newTerm); 128 | i += 2; 129 | } 130 | if (i < operands.size()) { 131 | assert((i + 1) == operands.size()); 132 | nextOperands.push_back(operands[i]); 133 | } 134 | operands = nextOperands; 135 | nextOperands.clear(); 136 | } 137 | 138 | assert(operands.size() == 2); 139 | term->setOperands(operands); 140 | 141 | operands.clear(); 142 | nextOperands.clear(); 143 | sortedOperands.clear(); 144 | } 145 | } 146 | }; 147 | 148 | } // namespace eva 149 | -------------------------------------------------------------------------------- /eva/common/reference_executor.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/common/reference_executor.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | using namespace std; 11 | 12 | namespace eva { 13 | 14 | void ReferenceExecutor::leftRotate(vector &output, 15 | const Term::Ptr &args, int32_t shift) { 16 | auto &input = terms_.at(args); 17 | 18 | // Reserve enough space for output 19 | output.clear(); 20 | output.reserve(input.size()); 21 | 22 | while (shift > 0 && shift >= input.size()) 23 | shift -= input.size(); 24 | while (shift < 0) 25 | shift += input.size(); 26 | 27 | // Shift left and copy to output 28 | copy_n(input.cbegin() + shift, input.size() - shift, back_inserter(output)); 29 | copy_n(input.cbegin(), shift, back_inserter(output)); 30 | } 31 | 32 | void ReferenceExecutor::rightRotate(vector &output, 33 | const Term::Ptr &args, int32_t shift) { 34 | auto &input = terms_.at(args); 35 | 36 | // Reserve enough space for output 37 | output.clear(); 38 | output.reserve(input.size()); 39 | 40 | while (shift > 0 && shift >= input.size()) 41 | shift -= input.size(); 42 | while (shift < 0) 43 | shift += input.size(); 44 | 45 | // Shift right and copy to output 46 | copy_n(input.cend() - shift, shift, back_inserter(output)); 47 | copy_n(input.cbegin(), input.size() - shift, back_inserter(output)); 48 | } 49 | 50 | void ReferenceExecutor::negate(vector &output, const Term::Ptr &args) { 51 | auto &input = terms_.at(args); 52 | 53 | // Reserve enough space for output 54 | output.clear(); 55 | output.reserve(input.size()); 56 | transform(input.cbegin(), input.cend(), back_inserter(output), 57 | std::negate()); 58 | } 59 | 60 | void ReferenceExecutor::operator()(const Term::Ptr &term) { 61 | // Must only be used with forward pass traversal 62 | auto &output = terms_[term]; 63 | 64 | auto op = term->op; 65 | auto args = term->getOperands(); 66 | 67 | switch (op) { 68 | case Op::Input: 69 | // Nothing to do for inputs 70 | break; 71 | case Op::Constant: 72 | // A constant (vector) is expanded to the number 
of slots (vecSize_ here) 73 | term->get()->expandTo(output, vecSize_); 74 | break; 75 | case Op::Add: 76 | assert(args.size() == 2); 77 | binOp>(output, args[0], args[1]); 78 | break; 79 | case Op::Sub: 80 | assert(args.size() == 2); 81 | binOp>(output, args[0], args[1]); 82 | break; 83 | case Op::Mul: 84 | assert(args.size() == 2); 85 | binOp>(output, args[0], args[1]); 86 | break; 87 | case Op::RotateLeftConst: 88 | assert(args.size() == 1); 89 | leftRotate(output, args[0], term->get()); 90 | break; 91 | case Op::RotateRightConst: 92 | assert(args.size() == 1); 93 | rightRotate(output, args[0], term->get()); 94 | break; 95 | case Op::Negate: 96 | assert(args.size() == 1); 97 | negate(output, args[0]); 98 | break; 99 | case Op::Encode: 100 | [[fallthrough]]; 101 | case Op::Output: 102 | [[fallthrough]]; 103 | case Op::Relinearize: 104 | [[fallthrough]]; 105 | case Op::ModSwitch: 106 | [[fallthrough]]; 107 | case Op::Rescale: 108 | // Copy argument value for outputs 109 | assert(args.size() == 1); 110 | output = terms_[args[0]]; 111 | break; 112 | default: 113 | assert(false); 114 | } 115 | } 116 | 117 | } // namespace eva 118 | -------------------------------------------------------------------------------- /eva/common/reference_executor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/common/valuation.h" 7 | #include "eva/ir/program.h" 8 | #include "eva/ir/term_map.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace eva { 15 | 16 | // Executes unencrypted computation 17 | class ReferenceExecutor { 18 | public: 19 | ReferenceExecutor(Program &g) 20 | : program_(g), vecSize_(g.getVecSize()), terms_(g) {} 21 | 22 | ReferenceExecutor(const ReferenceExecutor ©) = delete; 23 | 24 | ReferenceExecutor &operator=(const ReferenceExecutor &assign) = delete; 25 | 26 | template 27 | void setInputs(const std::unordered_map &inputs) { 28 | for (auto &in : inputs) { 29 | auto term = program_.getInput(in.first); 30 | terms_[term] = in.second; // TODO: can we avoid this copy? 31 | if (terms_[term].size() != vecSize_) { 32 | throw std::runtime_error( 33 | "The length of all inputs must be the same as program's vector " 34 | "size. Input " + 35 | in.first + " has length " + std::to_string(terms_[term].size()) + 36 | ", but vector size is " + std::to_string(vecSize_)); 37 | } 38 | } 39 | } 40 | 41 | void operator()(const Term::Ptr &term); 42 | 43 | void free(const Term::Ptr &term) { 44 | if (term->op == Op::Output) return; 45 | terms_[term].clear(); 46 | } 47 | 48 | void getOutputs(Valuation &outputs) { 49 | for (auto &out : program_.getOutputs()) { 50 | outputs[out.first] = terms_[out.second]; 51 | } 52 | } 53 | 54 | private: 55 | Program &program_; 56 | std::uint64_t vecSize_; 57 | TermMapOptional> terms_; 58 | 59 | template 60 | void binOp(std::vector &out, const Term::Ptr &args1, 61 | const Term::Ptr &args2) { 62 | auto &in1 = terms_.at(args1); 63 | auto &in2 = terms_.at(args2); 64 | assert(in1.size() == in2.size()); 65 | 66 | out.clear(); 67 | out.reserve(in1.size()); 68 | transform(in1.cbegin(), in1.cend(), in2.cbegin(), back_inserter(out), Op()); 69 | } 70 | 71 | void leftRotate(std::vector &output, const Term::Ptr &args, 72 | std::int32_t shift); 73 | 74 | void rightRotate(std::vector 
&output, const Term::Ptr &args, 75 | std::int32_t shift); 76 | 77 | void negate(std::vector &output, const Term::Ptr &args); 78 | }; 79 | 80 | } // namespace eva 81 | -------------------------------------------------------------------------------- /eva/common/rotation_keys_selector.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace eva { 14 | 15 | class RotationKeysSelector { 16 | public: 17 | RotationKeysSelector(Program &g, const TermMap &type) 18 | : program_(g), type(type) {} 19 | 20 | void operator()(const Term::Ptr &term) { 21 | auto op = term->op; 22 | 23 | // Nothing to do if this is not a rotation 24 | if (!isLeftRotationOp(op) && !isRightRotationOp(op)) return; 25 | 26 | // No rotation keys needed for raw computation 27 | if (type[term] == Type::Raw) return; 28 | 29 | // Add the rotation count 30 | auto rotation = term->get(); 31 | keys_.insert(isRightRotationOp(op) ? -rotation : rotation); 32 | } 33 | 34 | void free(const Term::Ptr &term) { 35 | // No-op 36 | } 37 | 38 | auto getRotationKeys() { 39 | // Return the set of rotations needed 40 | return keys_; 41 | } 42 | 43 | private: 44 | Program &program_; 45 | const TermMap &type; 46 | std::set keys_; 47 | 48 | bool isLeftRotationOp(const Op &op_code) { 49 | return (op_code == Op::RotateLeftConst); 50 | } 51 | 52 | bool isRightRotationOp(const Op &op_code) { 53 | return (op_code == Op::RotateRightConst); 54 | } 55 | }; 56 | 57 | } // namespace eva 58 | -------------------------------------------------------------------------------- /eva/common/type_deducer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term_map.h" 8 | 9 | namespace eva { 10 | 11 | class TypeDeducer { 12 | Program &program; 13 | TermMap &types; 14 | 15 | public: 16 | TypeDeducer(Program &g, TermMap &types) : program(g), types(types) {} 17 | 18 | void 19 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal 20 | auto &operands = term->getOperands(); 21 | if (operands.size() > 0) { // not an input/root 22 | Type inferred = Type::Raw; // Plain if not Cipher 23 | for (auto &operand : operands) { 24 | if (types[operand] == Type::Cipher) 25 | inferred = Type::Cipher; // Cipher if any operand is Cipher 26 | } 27 | if (term->op == Op::Encode) { 28 | types[term] = Type::Plain; 29 | } else { 30 | types[term] = inferred; 31 | } 32 | } else if (term->op == Op::Constant) { 33 | types[term] = Type::Raw; 34 | } else { 35 | types[term] = term->get(); 36 | } 37 | } 38 | }; 39 | 40 | } // namespace eva 41 | -------------------------------------------------------------------------------- /eva/common/valuation.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace eva { 11 | 12 | using Valuation = std::unordered_map>; 13 | 14 | } 15 | -------------------------------------------------------------------------------- /eva/eva.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/eva.h" 5 | #include "eva/common/program_traversal.h" 6 | #include "eva/common/reference_executor.h" 7 | #include "eva/common/valuation.h" 8 | 9 | namespace eva { 10 | 11 | Valuation evaluate(Program &program, const Valuation &inputs) { 12 | Valuation outputs; 13 | ProgramTraversal programTraverse(program); 14 | ReferenceExecutor ge(program); 15 | 16 | ge.setInputs(inputs); 17 | programTraverse.forwardPass(ge); 18 | ge.getOutputs(outputs); 19 | 20 | return outputs; 21 | } 22 | 23 | } // namespace eva 24 | -------------------------------------------------------------------------------- /eva/eva.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ckks/ckks_compiler.h" 7 | #include "eva/ir/program.h" 8 | #include "eva/seal/seal.h" 9 | #include "eva/serialization/save_load.h" 10 | #include "eva/version.h" 11 | 12 | namespace eva { 13 | 14 | Valuation evaluate(Program &program, const Valuation &inputs); 15 | 16 | } 17 | -------------------------------------------------------------------------------- /eva/ir/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | target_sources(eva PRIVATE 5 | term.cpp 6 | program.cpp 7 | attribute_list.cpp 8 | attributes.cpp 9 | ) 10 | -------------------------------------------------------------------------------- /eva/ir/attribute_list.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/ir/attribute_list.h" 5 | #include "eva/ir/attributes.h" 6 | #include "eva/util/overloaded.h" 7 | #include 8 | 9 | using namespace std; 10 | 11 | namespace eva { 12 | 13 | bool AttributeList::has(AttributeKey k) const { 14 | const AttributeList *curr = this; 15 | while (true) { 16 | if (curr->key < k) { 17 | if (curr->tail) { 18 | curr = curr->tail.get(); 19 | } else { 20 | return false; 21 | } 22 | } else { 23 | return curr->key == k; 24 | } 25 | } 26 | } 27 | 28 | const AttributeValue &AttributeList::get(AttributeKey k) const { 29 | const AttributeList *curr = this; 30 | while (true) { 31 | if (curr->key == k) { 32 | return curr->value; 33 | } else if (curr->key < k && curr->tail) { 34 | curr = curr->tail.get(); 35 | } else { 36 | throw out_of_range("Attribute not in list: " + getAttributeName(k)); 37 | } 38 | } 39 | } 40 | 41 | void AttributeList::set(AttributeKey k, AttributeValue v) { 42 | if (this->key == 0) { 43 | this->key = k; 44 | this->value = move(v); 45 | } else { 46 | AttributeList *curr = this; 47 | AttributeList *prev = nullptr; 48 | while (true) { 49 | if (curr->key < k) { 50 | if (curr->tail) { 51 | prev = curr; 52 | curr = curr->tail.get(); 53 | } else { // Insert at end 54 | // AttributeList constructor is private 55 | curr->tail = unique_ptr{new AttributeList(k, move(v))}; 56 | return; 57 | } 58 | } else if (curr->key > k) { 59 | if (prev) { // Insert between 60 | // AttributeList constructor is private 61 | auto newList = 62 | unique_ptr{new AttributeList(k, move(v))}; 63 | newList->tail = move(prev->tail); 64 | prev->tail = move(newList); 65 | } else { // Insert at beginning 66 | // AttributeList constructor is private 67 | curr->tail = 68 | unique_ptr{new AttributeList(move(*curr))}; 69 | curr->key = k; 70 | curr->value = move(v); 71 | } 72 | return; 73 | } else { 74 | assert(curr->key == k); 75 | curr->value = move(v); 76 | return; 77 | } 78 | } 79 | } 80 | } 81 | 82 | void AttributeList::assignAttributesFrom(const 
AttributeList &other) { 83 | if (this->key != 0) { 84 | this->key = 0; 85 | this->value = std::monostate(); 86 | this->tail = nullptr; 87 | } 88 | if (other.key == 0) { 89 | return; 90 | } 91 | AttributeList *lhs = this; 92 | const AttributeList *rhs = &other; 93 | while (true) { 94 | lhs->key = rhs->key; 95 | lhs->value = rhs->value; 96 | if (rhs->tail) { 97 | rhs = rhs->tail.get(); 98 | lhs->tail = std::make_unique(); 99 | lhs = lhs->tail.get(); 100 | } else { 101 | return; 102 | } 103 | } 104 | } 105 | 106 | } // namespace eva 107 | -------------------------------------------------------------------------------- /eva/ir/attribute_list.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/constant_value.h" 7 | #include "eva/ir/types.h" 8 | #include "eva/serialization/eva.pb.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace eva { 18 | 19 | using AttributeValue = std::variant>; 21 | 22 | template struct IsInVariant; 23 | template 24 | struct IsInVariant> 25 | : std::bool_constant<(... 
|| std::is_same{})> {}; 26 | 27 | using AttributeKey = std::uint8_t; 28 | 29 | template struct Attribute { 30 | static_assert(IsInVariant::value, 31 | "Attribute type not in AttributeValue"); 32 | static_assert(Key > 0, "Keys must be strictly positive"); 33 | static_assert(Key <= std::numeric_limits::max(), 34 | "Key larger than current AttributeKey type"); 35 | 36 | using Value = T; 37 | static constexpr AttributeKey key = Key; 38 | 39 | static bool isValid(AttributeKey k, const AttributeValue &v) { 40 | return k == Key && std::holds_alternative(v); 41 | } 42 | }; 43 | 44 | class AttributeList { 45 | public: 46 | AttributeList() : key(0), tail(nullptr) {} 47 | 48 | // This function is defined in eva/serialization/eva_serialization.cpp 49 | void loadAttribute(const msg::Attribute &msg); 50 | 51 | // This function is defined in eva/serialization/eva_serialization.cpp 52 | void serializeAttributes(std::function addMsg) const; 53 | 54 | template bool has() const { return has(TAttr::key); } 55 | 56 | template const typename TAttr::Value &get() const { 57 | return std::get(get(TAttr::key)); 58 | } 59 | 60 | template void set(typename TAttr::Value value) { 61 | set(TAttr::key, std::move(value)); 62 | } 63 | 64 | void assignAttributesFrom(const AttributeList &other); 65 | 66 | private: 67 | AttributeKey key; 68 | AttributeValue value; 69 | std::unique_ptr tail; 70 | 71 | AttributeList(AttributeKey k, AttributeValue v) 72 | : key(k), value(std::move(v)) {} 73 | 74 | bool has(AttributeKey k) const; 75 | const AttributeValue &get(AttributeKey k) const; 76 | void set(AttributeKey k, AttributeValue v); 77 | }; 78 | 79 | } // namespace eva 80 | -------------------------------------------------------------------------------- /eva/ir/attributes.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/ir/attributes.h" 5 | #include 6 | 7 | using namespace std; 8 | 9 | namespace eva { 10 | 11 | #define X(name, type) name::isValid(k, v) || 12 | bool isValidAttribute(AttributeKey k, const AttributeValue &v) { 13 | return EVA_ATTRIBUTES false; 14 | } 15 | #undef X 16 | 17 | #define X(name, type) \ 18 | case detail::name##Index: \ 19 | return #name; 20 | string getAttributeName(AttributeKey k) { 21 | switch (k) { 22 | EVA_ATTRIBUTES 23 | default: 24 | throw runtime_error("Unknown attribute key"); 25 | } 26 | } 27 | #undef X 28 | 29 | } // namespace eva 30 | -------------------------------------------------------------------------------- /eva/ir/attributes.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/attribute_list.h" 7 | #include 8 | #include 9 | 10 | namespace eva { 11 | 12 | #define EVA_ATTRIBUTES \ 13 | X(RescaleDivisorAttribute, std::uint32_t) \ 14 | X(RotationAttribute, std::int32_t) \ 15 | X(ConstantValueAttribute, std::shared_ptr) \ 16 | X(TypeAttribute, Type) \ 17 | X(RangeAttribute, std::uint32_t) \ 18 | X(EncodeAtScaleAttribute, std::uint32_t) \ 19 | X(EncodeAtLevelAttribute, std::uint32_t) 20 | 21 | namespace detail { 22 | enum AttributeIndex { 23 | RESERVE_EMPTY_ATTRIBUTE_KEY = 0, 24 | #define X(name, type) name##Index, 25 | EVA_ATTRIBUTES 26 | #undef X 27 | }; 28 | } // namespace detail 29 | 30 | #define X(name, type) using name = Attribute; 31 | EVA_ATTRIBUTES 32 | #undef X 33 | 34 | bool isValidAttribute(AttributeKey k, const AttributeValue &v); 35 | 36 | std::string getAttributeName(AttributeKey k); 37 | 38 | } // namespace eva 39 | -------------------------------------------------------------------------------- /eva/ir/constant_value.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 
Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/serialization/eva.pb.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace eva { 15 | 16 | class ConstantValue { 17 | public: 18 | ConstantValue(std::size_t size) : size(size) {} 19 | virtual ~ConstantValue() {} 20 | virtual const std::vector &expand(std::vector &scratch, 21 | std::size_t slots) const = 0; 22 | virtual void expandTo(std::vector &result, 23 | std::size_t slots) const = 0; 24 | virtual bool isZero() const = 0; 25 | virtual void serialize(msg::ConstantValue &msg) const = 0; 26 | 27 | protected: 28 | std::size_t size; 29 | 30 | void validateSlots(std::size_t slots) const { 31 | if (slots < size) { 32 | throw std::runtime_error("Slots must be at least size of constant"); 33 | } 34 | if (slots % size != 0) { 35 | throw std::runtime_error("Size must exactly divide slots"); 36 | } 37 | } 38 | }; 39 | 40 | class DenseConstantValue : public ConstantValue { 41 | public: 42 | DenseConstantValue(std::size_t size, std::vector values) 43 | : ConstantValue(size), values(values) { 44 | if (size % values.size() != 0) { 45 | throw std::runtime_error( 46 | "DenseConstantValue size must exactly divide size"); 47 | } 48 | } 49 | 50 | const std::vector &expand(std::vector &scratch, 51 | std::size_t slots) const override { 52 | validateSlots(slots); 53 | if (values.size() == slots) { 54 | return values; 55 | } else { 56 | scratch.clear(); 57 | for (int r = slots / values.size(); r > 0; --r) { 58 | scratch.insert(scratch.end(), values.begin(), values.end()); 59 | } 60 | return scratch; 61 | } 62 | } 63 | 64 | void expandTo(std::vector &result, std::size_t slots) const override { 65 | validateSlots(slots); 66 | result.clear(); 67 | result.reserve(slots); 68 | for (int r = slots / values.size(); r > 0; --r) { 69 | result.insert(result.end(), values.begin(), values.end()); 70 | } 71 | } 72 | 73 | bool 
isZero() const override { 74 | for (double value : values) { 75 | if (value != 0) return false; 76 | } 77 | return true; 78 | } 79 | 80 | void serialize(msg::ConstantValue &msg) const override { 81 | msg.set_size(size); 82 | auto valuesMsg = msg.mutable_values(); 83 | valuesMsg->Reserve(values.size()); 84 | for (const auto &value : values) { 85 | valuesMsg->Add(value); 86 | } 87 | } 88 | 89 | private: 90 | std::vector values; 91 | }; 92 | 93 | class SparseConstantValue : public ConstantValue { 94 | public: 95 | SparseConstantValue(std::size_t size, 96 | std::vector> values) 97 | : ConstantValue(size), values(values) {} 98 | 99 | const std::vector &expand(std::vector &scratch, 100 | std::size_t slots) const override { 101 | validateSlots(slots); 102 | scratch.clear(); 103 | scratch.resize(slots); 104 | for (auto &entry : values) { 105 | for (int i = 0; i < slots; i += values.size()) { 106 | scratch.at(entry.first + i) = entry.second; 107 | } 108 | } 109 | return scratch; 110 | } 111 | 112 | void expandTo(std::vector &result, std::size_t slots) const override { 113 | validateSlots(slots); 114 | result.clear(); 115 | result.resize(slots); 116 | for (auto &entry : values) { 117 | for (int i = 0; i < slots; i += values.size()) { 118 | result.at(entry.first + i) = entry.second; 119 | } 120 | } 121 | } 122 | 123 | bool isZero() const override { 124 | // TODO: this assumes no repeated indices 125 | for (auto entry : values) { 126 | if (entry.second != 0) return false; 127 | } 128 | return true; 129 | } 130 | 131 | void serialize(msg::ConstantValue &msg) const override { 132 | msg.set_size(size); 133 | for (const auto &pair : values) { 134 | msg.add_sparse_indices(pair.first); 135 | msg.add_values(pair.second); 136 | } 137 | } 138 | 139 | private: 140 | std::vector> values; 141 | }; 142 | 143 | std::unique_ptr serialize(const ConstantValue &obj); 144 | std::shared_ptr deserialize(const msg::ConstantValue &msg); 145 | 146 | } // namespace eva 147 | 
-------------------------------------------------------------------------------- /eva/ir/ops.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | 9 | namespace eva { 10 | 11 | #define EVA_OPS \ 12 | X(Undef, 0) \ 13 | X(Input, 1) \ 14 | X(Output, 2) \ 15 | X(Constant, 3) \ 16 | X(Negate, 10) \ 17 | X(Add, 11) \ 18 | X(Sub, 12) \ 19 | X(Mul, 13) \ 20 | X(RotateLeftConst, 14) \ 21 | X(RotateRightConst, 15) \ 22 | X(Relinearize, 20) \ 23 | X(ModSwitch, 21) \ 24 | X(Rescale, 22) \ 25 | X(Encode, 23) 26 | 27 | enum class Op { 28 | #define X(op, code) op = code, 29 | EVA_OPS 30 | #undef X 31 | }; 32 | 33 | inline bool isValidOp(Op op) { 34 | switch (op) { 35 | #define X(op, code) case Op::op: 36 | EVA_OPS 37 | #undef X 38 | return true; 39 | default: 40 | return false; 41 | } 42 | } 43 | 44 | inline std::string getOpName(Op op) { 45 | switch (op) { 46 | #define X(op, code) \ 47 | case Op::op: \ 48 | return #op; 49 | EVA_OPS 50 | #undef X 51 | default: 52 | throw std::runtime_error("Invalid op"); 53 | } 54 | } 55 | 56 | } // namespace eva 57 | -------------------------------------------------------------------------------- /eva/ir/program.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/ir/program.h" 5 | #include "eva/common/program_traversal.h" 6 | #include "eva/ir/term_map.h" 7 | #include "eva/util/logging.h" 8 | #include 9 | 10 | using namespace std; 11 | 12 | namespace eva { 13 | 14 | // TODO: maybe replace with smart iterator to avoid allocation 15 | vector toTermPtrs(const unordered_set &terms) { 16 | vector termPtrs; 17 | termPtrs.reserve(terms.size()); 18 | for (auto &term : terms) { 19 | termPtrs.emplace_back(term->shared_from_this()); 20 | } 21 | return termPtrs; 22 | } 23 | 24 | vector Program::getSources() const { 25 | return toTermPtrs(this->sources); 26 | } 27 | 28 | vector Program::getSinks() const { return toTermPtrs(this->sinks); } 29 | 30 | std::unique_ptr Program::deepCopy() { 31 | auto newProg = std::make_unique(getName(), getVecSize()); 32 | TermMap oldToNew(*this); 33 | ProgramTraversal traversal(*this); 34 | traversal.forwardPass([&](Term::Ptr &term) { 35 | auto newTerm = newProg->makeTerm(term->op); 36 | oldToNew[term] = newTerm; 37 | newTerm->assignAttributesFrom(*term); 38 | for (auto &operand : term->getOperands()) { 39 | newTerm->addOperand(oldToNew[operand]); 40 | } 41 | }); 42 | for (auto &entry : inputs) { 43 | newProg->inputs[entry.first] = oldToNew[entry.second]; 44 | } 45 | for (auto &entry : outputs) { 46 | newProg->outputs[entry.first] = oldToNew[entry.second]; 47 | } 48 | return newProg; 49 | } 50 | 51 | uint64_t Program::allocateIndex() { 52 | // TODO: reuse released indices to save space in TermMap instances 53 | uint64_t index = nextTermIndex++; 54 | for (TermMapBase *termMap : termMaps) { 55 | termMap->resize(nextTermIndex); 56 | } 57 | return index; 58 | } 59 | 60 | void Program::initTermMap(TermMapBase &termMap) { 61 | termMap.resize(nextTermIndex); 62 | } 63 | 64 | void Program::registerTermMap(TermMapBase *termMap) { 65 | termMaps.emplace_back(termMap); 66 | } 67 | 68 | void Program::unregisterTermMap(TermMapBase *termMap) { 69 | auto iter = find(termMaps.begin(), termMaps.end(), 
termMap); 70 | if (iter == termMaps.end()) { 71 | throw runtime_error("TermMap to unregister not found"); 72 | } else { 73 | termMaps.erase(iter); 74 | } 75 | } 76 | 77 | template 78 | void dumpAttribute(stringstream &s, Term *term, std::string label) { 79 | if (term->has()) { 80 | s << ", " << label << "=" << term->get(); 81 | } 82 | } 83 | 84 | // Print an attribute in DOT format as a box outside the term 85 | template 86 | void toDOTAttributeAsNode(stringstream &s, Term *term, std::string label) { 87 | if (term->has()) { 88 | s << "t" << term->index << "_" << getAttributeName(Attr::key) 89 | << " [shape=box label=\"" << label << "=" << term->get() 90 | << "\"];\n"; 91 | s << "t" << term->index << "_" << getAttributeName(Attr::key) << " -> t" 92 | << term->index << ";\n"; 93 | } 94 | } 95 | 96 | string Program::dump(TermMapOptional &scales, 97 | TermMap &types, 98 | TermMap &level) const { 99 | // TODO: switch to use a non-parallel generic traversal 100 | stringstream s; 101 | s << getName() << "(){\n"; 102 | 103 | // Add all terms in topologically sorted order 104 | uint64_t nextIndex = 0; 105 | unordered_map indices; 106 | stack> work; 107 | for (const auto &sink : getSinks()) { 108 | work.emplace(true, sink.get()); 109 | } 110 | while (!work.empty()) { 111 | bool visit = work.top().first; 112 | auto term = work.top().second; 113 | work.pop(); 114 | if (indices.count(term)) { 115 | continue; 116 | } 117 | if (visit) { 118 | work.emplace(false, term); 119 | for (const auto &operand : term->getOperands()) { 120 | work.emplace(true, operand.get()); 121 | } 122 | } else { 123 | auto index = nextIndex; 124 | nextIndex += 1; 125 | indices[term] = index; 126 | s << "t" << term->index << " = " << getOpName(term->op); 127 | if (term->has()) { 128 | s << "(" << term->get() << ")"; 129 | } 130 | if (term->has()) { 131 | s << "(" << term->get() << ")"; 132 | } 133 | if (term->has()) { 134 | s << ":" << getTypeName(term->get()); 135 | } 136 | for (int i = 0; i < 
term->numOperands(); ++i) { 137 | s << " t" << term->operandAt(i)->index; 138 | } 139 | dumpAttribute(s, term, "range"); 140 | dumpAttribute(s, term, "level"); 141 | if (types[*term] == Type::Cipher) 142 | s << ", " 143 | << "s" 144 | << "=" << scales[*term] << ", t=cipher "; 145 | else if (types[*term] == Type::Raw) 146 | s << ", " 147 | << "s" 148 | << "=" << scales[*term] << ", t=raw "; 149 | else 150 | s << ", " 151 | << "s" 152 | << "=" << scales[*term] << ", t=plain "; 153 | s << "\n"; 154 | // ConstantValue TODO: printing constant values for simple cases 155 | } 156 | } 157 | 158 | s << "}\n"; 159 | return s.str(); 160 | } 161 | 162 | string Program::toDOT() const { 163 | // TODO: switch to use a non-parallel generic traversal 164 | stringstream s; 165 | 166 | s << "digraph \"" << getName() << "\" {\n"; 167 | 168 | // Add all terms in topologically sorted order 169 | uint64_t nextIndex = 0; 170 | unordered_map indices; 171 | stack> work; 172 | for (const auto &sink : getSinks()) { 173 | work.emplace(true, sink.get()); 174 | } 175 | while (!work.empty()) { 176 | bool visit = work.top().first; 177 | auto term = work.top().second; 178 | work.pop(); 179 | if (indices.count(term)) { 180 | continue; 181 | } 182 | if (visit) { 183 | work.emplace(false, term); 184 | for (const auto &operand : term->getOperands()) { 185 | work.emplace(true, operand.get()); 186 | } 187 | } else { 188 | auto index = nextIndex; 189 | nextIndex += 1; 190 | indices[term] = index; 191 | 192 | // Operands are guaranteed to have been added 193 | s << "t" << term->index << " [label=\"" << getOpName(term->op); 194 | if (term->has()) { 195 | s << "(" << term->get() << ")"; 196 | } 197 | if (term->has()) { 198 | s << "(" << term->get() << ")"; 199 | } 200 | if (term->has()) { 201 | s << " : " << getTypeName(term->get()); 202 | } 203 | s << "\""; // End label 204 | s << "];\n"; 205 | for (int i = 0; i < term->numOperands(); ++i) { 206 | s << "t" << term->operandAt(i)->index << " -> t" << 
term->index 207 | << " [label=\"" << i << "\"];\n"; 208 | } 209 | toDOTAttributeAsNode(s, term, "range"); 210 | toDOTAttributeAsNode(s, term, "scale"); 211 | toDOTAttributeAsNode(s, term, "level"); 212 | // ConstantValue TODO: printing constant values for simple cases 213 | } 214 | } 215 | 216 | s << "}\n"; 217 | 218 | return s.str(); 219 | } 220 | 221 | } // namespace eva 222 | -------------------------------------------------------------------------------- /eva/ir/program.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/constant_value.h" 7 | #include "eva/ir/term.h" 8 | #include "eva/serialization/eva.pb.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace eva { 19 | 20 | template class TermMapOptional; 21 | template class TermMap; 22 | class TermMapBase; 23 | 24 | class Program { 25 | public: 26 | Program(std::string name, std::uint64_t vecSize) 27 | : name(name), vecSize(vecSize), nextTermIndex(0) { 28 | if (vecSize == 0) { 29 | throw std::runtime_error("Vector size must be non-zero"); 30 | } 31 | if ((vecSize & (vecSize - 1)) != 0) { 32 | throw std::runtime_error("Vector size must be a power-of-two"); 33 | } 34 | } 35 | 36 | Program(const Program ©) = delete; 37 | 38 | Program &operator=(const Program &assign) = delete; 39 | 40 | Term::Ptr makeTerm(Op op, const std::vector &operands = {}) { 41 | auto term = std::make_shared(op, *this); 42 | if (operands.size() > 0) { 43 | term->setOperands(operands); 44 | } 45 | return term; 46 | } 47 | 48 | Term::Ptr makeConstant(std::unique_ptr value) { 49 | auto term = makeTerm(Op::Constant); 50 | term->set(std::move(value)); 51 | return term; 52 | } 53 | 54 | Term::Ptr makeDenseConstant(std::vector values) { 55 | return makeConstant(std::make_unique(vecSize, 
values)); 56 | } 57 | 58 | Term::Ptr makeUniformConstant(double value) { 59 | return makeDenseConstant({value}); 60 | } 61 | 62 | Term::Ptr makeInput(const std::string &name, Type type = Type::Cipher) { 63 | auto term = makeTerm(Op::Input); 64 | term->set(type); 65 | inputs.emplace(name, term); 66 | return term; 67 | } 68 | 69 | Term::Ptr makeOutput(std::string name, const Term::Ptr &term) { 70 | auto output = makeTerm(Op::Output, {term}); 71 | outputs.emplace(name, output); 72 | return output; 73 | } 74 | 75 | Term::Ptr makeLeftRotation(const Term::Ptr &term, std::int32_t slots) { 76 | auto rotation = makeTerm(Op::RotateLeftConst, {term}); 77 | rotation->set(slots); 78 | return rotation; 79 | } 80 | 81 | Term::Ptr makeRightRotation(const Term::Ptr &term, std::int32_t slots) { 82 | auto rotation = makeTerm(Op::RotateRightConst, {term}); 83 | rotation->set(slots); 84 | return rotation; 85 | } 86 | 87 | Term::Ptr makeRescale(const Term::Ptr &term, std::uint32_t rescaleBy) { 88 | auto rescale = makeTerm(Op::Rescale, {term}); 89 | rescale->set(rescaleBy); 90 | return rescale; 91 | } 92 | 93 | Term::Ptr getInput(std::string name) const { 94 | if (inputs.find(name) == inputs.end()) { 95 | std::stringstream s; 96 | s << "No input named " << name; 97 | throw std::out_of_range(s.str()); 98 | } 99 | return inputs.at(name); 100 | } 101 | 102 | const auto &getInputs() const { return inputs; } 103 | 104 | const auto &getOutputs() const { return outputs; } 105 | 106 | std::string getName() const { return name; } 107 | void setName(std::string newName) { name = newName; } 108 | 109 | std::uint32_t getVecSize() const { return vecSize; } 110 | 111 | std::vector getSources() const; 112 | 113 | std::vector getSinks() const; 114 | 115 | // Make a deep copy of this program 116 | std::unique_ptr deepCopy(); 117 | 118 | std::string toDOT() const; 119 | std::string dump(TermMapOptional &scales, 120 | TermMap &types, 121 | TermMap &level) const; 122 | 123 | private: 124 | std::uint64_t 
allocateIndex(); 125 | void initTermMap(TermMapBase &termMap); 126 | void registerTermMap(TermMapBase *annotation); 127 | void unregisterTermMap(TermMapBase *annotation); 128 | 129 | std::string name; 130 | std::uint32_t vecSize; 131 | 132 | // These are managed automatically by Term 133 | std::unordered_set sources; 134 | std::unordered_set sinks; 135 | 136 | std::uint64_t nextTermIndex; 137 | std::vector termMaps; 138 | 139 | // These members must currently be last, because their destruction triggers 140 | // associated Terms to be destructed, which still use the sources and sinks 141 | // structures above. 142 | // TODO: move away from shared ownership for Terms and have Program own them 143 | // uniquely. It is an error to hold onto a Term longer than a Program, but 144 | // the shared_ptr is misleading on this regard. 145 | std::unordered_map outputs; 146 | std::unordered_map inputs; 147 | 148 | friend class Term; 149 | friend class TermMapBase; 150 | friend std::unique_ptr serialize(const Program &); 151 | friend std::unique_ptr deserialize(const msg::Program &); 152 | }; 153 | 154 | std::unique_ptr deserialize(const msg::Program &); 155 | 156 | } // namespace eva 157 | -------------------------------------------------------------------------------- /eva/ir/term.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/ir/term.h" 5 | #include "eva/ir/program.h" 6 | #include 7 | #include 8 | 9 | using namespace std; 10 | 11 | namespace eva { 12 | 13 | Term::Term(Op op, Program &program) 14 | : op(op), program(program), index(program.allocateIndex()) { 15 | program.sources.insert(this); 16 | program.sinks.insert(this); 17 | } 18 | 19 | Term::~Term() { 20 | for (Ptr &operand : operands) { 21 | operand->eraseUse(this); 22 | } 23 | if (operands.empty()) { 24 | program.sources.erase(this); 25 | } 26 | assert(uses.empty()); 27 | program.sinks.erase(this); 28 | } 29 | 30 | void Term::addOperand(const Term::Ptr &term) { 31 | if (operands.empty()) { 32 | program.sources.erase(this); 33 | } 34 | operands.emplace_back(term); 35 | term->addUse(this); 36 | } 37 | 38 | bool Term::eraseOperand(const Ptr &term) { 39 | auto iter = find(operands.begin(), operands.end(), term); 40 | if (iter != operands.end()) { 41 | term->eraseUse(this); 42 | operands.erase(iter); 43 | if (operands.empty()) { 44 | program.sources.insert(this); 45 | } 46 | return true; 47 | } 48 | return false; 49 | } 50 | 51 | bool Term::replaceOperand(Ptr oldTerm, Ptr newTerm) { 52 | bool replaced = false; 53 | for (Ptr &operand : operands) { 54 | if (operand == oldTerm) { 55 | operand = newTerm; 56 | oldTerm->eraseUse(this); 57 | newTerm->addUse(this); 58 | replaced = true; 59 | } 60 | } 61 | return replaced; 62 | } 63 | 64 | void Term::replaceUsesWithIf(Ptr term, function predicate) { 65 | auto thisPtr = shared_from_this(); // TODO: avoid this and similar 66 | // unnecessary reference counting 67 | for (auto &use : getUses()) { 68 | if (predicate(use)) { 69 | use->replaceOperand(thisPtr, term); 70 | } 71 | } 72 | } 73 | 74 | void Term::replaceAllUsesWith(Ptr term) { 75 | replaceUsesWithIf(term, [](const Ptr &) { return true; }); 76 | } 77 | 78 | void Term::replaceOtherUsesWith(Ptr term) { 79 | replaceUsesWithIf(term, [&](const Ptr &use) { return use != term; }); 80 | } 81 | 82 | void 
Term::setOperands(vector o) { 83 | if (operands.empty()) { 84 | program.sources.erase(this); 85 | } 86 | 87 | for (auto &operand : operands) { 88 | operand->eraseUse(this); 89 | } 90 | operands = move(o); 91 | for (auto &operand : operands) { 92 | operand->addUse(this); 93 | } 94 | 95 | if (operands.empty()) { 96 | program.sources.insert(this); 97 | } 98 | } 99 | 100 | size_t Term::numOperands() const { return operands.size(); } 101 | 102 | Term::Ptr Term::operandAt(size_t i) { return operands.at(i); } 103 | 104 | const vector &Term::getOperands() const { return operands; } 105 | 106 | size_t Term::numUses() { return uses.size(); } 107 | 108 | vector Term::getUses() { 109 | vector u; 110 | for (Term *use : uses) { 111 | u.emplace_back(use->shared_from_this()); 112 | } 113 | return u; 114 | } 115 | 116 | bool Term::isInternal() const { 117 | return ((operands.size() != 0) && (uses.size() != 0)); 118 | } 119 | 120 | void Term::addUse(Term *term) { 121 | if (uses.empty()) { 122 | program.sinks.erase(this); 123 | } 124 | uses.emplace_back(term); 125 | } 126 | 127 | bool Term::eraseUse(Term *term) { 128 | auto iter = find(uses.begin(), uses.end(), term); 129 | assert(iter != uses.end()); 130 | uses.erase(iter); 131 | if (uses.empty()) { 132 | program.sinks.insert(this); 133 | return true; 134 | } 135 | return false; 136 | } 137 | 138 | ostream &operator<<(ostream &s, const Term &term) { 139 | s << term.index << ':' << getOpName(term.op) << '('; 140 | bool first = true; 141 | for (const auto &operand : term.getOperands()) { 142 | if (first) { 143 | first = false; 144 | } else { 145 | s << ','; 146 | } 147 | s << operand->index; 148 | } 149 | s << ')'; 150 | return s; 151 | } 152 | 153 | } // namespace eva 154 | -------------------------------------------------------------------------------- /eva/ir/term.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 
2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "eva/ir/attributes.h" 7 | #include "eva/ir/ops.h" 8 | #include "eva/ir/types.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace eva { 18 | 19 | class Program; 20 | 21 | class Term : public AttributeList, public std::enable_shared_from_this { 22 | public: 23 | using Ptr = std::shared_ptr; 24 | 25 | Term(Op opcode, Program &program); 26 | ~Term(); 27 | 28 | void addOperand(const Ptr &term); 29 | bool eraseOperand(const Ptr &term); 30 | bool replaceOperand(Ptr oldTerm, Ptr newTerm); 31 | void setOperands(std::vector o); 32 | std::size_t numOperands() const; 33 | Ptr operandAt(size_t i); 34 | const std::vector &getOperands() const; 35 | 36 | void replaceUsesWithIf(Ptr term, std::function); 37 | void replaceAllUsesWith(Ptr term); 38 | void replaceOtherUsesWith(Ptr term); 39 | 40 | std::size_t numUses(); 41 | std::vector getUses(); 42 | 43 | bool isInternal() const; 44 | 45 | const Op op; 46 | Program &program; 47 | 48 | // Unique index for this Term in the owning Program. Managed by Program 49 | // and used to index into TermMap instances. 50 | std::uint64_t index; 51 | 52 | friend std::ostream &operator<<(std::ostream &s, const Term &term); 53 | 54 | private: 55 | std::vector operands; // use->def chain (unmanaged pointers) 56 | std::vector uses; // def->use chain (managed pointers) 57 | 58 | void addUse(Term *term); 59 | bool eraseUse(Term *term); 60 | }; 61 | 62 | } // namespace eva 63 | -------------------------------------------------------------------------------- /eva/ir/term_map.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ir/program.h" 7 | #include "eva/ir/term.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace eva { 16 | 17 | class TermMapBase { 18 | public: 19 | TermMapBase(Program &p) : program(&p) { program->registerTermMap(this); } 20 | ~TermMapBase() { program->unregisterTermMap(this); } 21 | TermMapBase(const TermMapBase &other) : program(other.program) { 22 | program->registerTermMap(this); 23 | } 24 | TermMapBase &operator=(const TermMapBase &other) = default; 25 | 26 | friend class Program; 27 | 28 | protected: 29 | void init() { program->initTermMap(*this); } 30 | 31 | std::uint64_t getIndex(const Term &term) const { return term.index; } 32 | 33 | private: 34 | virtual void resize(std::size_t size) = 0; 35 | 36 | Program *program; 37 | }; 38 | 39 | template class TermMap : TermMapBase { 40 | public: 41 | TermMap(Program &p) : TermMapBase(p) { init(); } 42 | 43 | TValue &operator[](const Term &term) { return values.at(getIndex(term)); } 44 | 45 | const TValue &operator[](const Term &term) const { 46 | return values.at(getIndex(term)); 47 | } 48 | 49 | TValue &operator[](const Term::Ptr &term) { return this->operator[](*term); } 50 | 51 | const TValue &operator[](const Term::Ptr &term) const { 52 | return this->operator[](*term); 53 | } 54 | 55 | void clear() { values.assign(values.size(), {}); } 56 | 57 | private: 58 | void resize(std::size_t size) override { values.resize(size); } 59 | 60 | std::deque values; 61 | }; 62 | 63 | template <> class TermMap : TermMapBase { 64 | public: 65 | TermMap(Program &p) : TermMapBase(p) { init(); } 66 | 67 | std::vector::reference operator[](const Term &term) { 68 | return values.at(getIndex(term)); 69 | } 70 | 71 | bool operator[](const Term &term) const { return values.at(getIndex(term)); } 72 | 73 | std::vector::reference operator[](const Term::Ptr &term) { 74 | return this->operator[](*term); 75 | } 76 | 77 | bool operator[](const 
Term::Ptr &term) const { 78 | return this->operator[](*term); 79 | } 80 | 81 | void clear() { values.assign(values.size(), false); } 82 | 83 | private: 84 | void resize(std::size_t size) override { values.resize(size); } 85 | 86 | std::vector values; 87 | }; 88 | 89 | template class TermMapOptional : TermMapBase { 90 | public: 91 | TermMapOptional(Program &p) : TermMapBase(p) { init(); } 92 | 93 | TOptionalValue &operator[](const Term &term) { 94 | auto &value = values.at(getIndex(term)); 95 | if (!value.has_value()) { 96 | value.emplace(); 97 | } 98 | return *value; 99 | } 100 | 101 | TOptionalValue &operator[](const Term::Ptr &term) { 102 | return this->operator[](*term); 103 | } 104 | 105 | TOptionalValue &at(const Term &term) { 106 | return values.at(getIndex(term)).value(); 107 | } 108 | 109 | TOptionalValue &at(const Term::Ptr &term) { return this->at(*term); } 110 | 111 | bool has(const Term &term) const { 112 | return values.at(getIndex(term)).has_value(); 113 | } 114 | 115 | bool has(const Term::Ptr &term) const { return has(*term); } 116 | 117 | void clear() { values.assign(values.size(), std::nullopt); } 118 | 119 | private: 120 | void resize(std::size_t size) override { values.resize(size); } 121 | 122 | std::deque> values; 123 | }; 124 | 125 | } // namespace eva 126 | -------------------------------------------------------------------------------- /eva/ir/types.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | 9 | namespace eva { 10 | 11 | #define EVA_TYPES \ 12 | X(Undef, 0) \ 13 | X(Cipher, 1) \ 14 | X(Raw, 2) \ 15 | X(Plain, 3) 16 | 17 | enum class Type : std::int32_t { 18 | #define X(type, code) type = code, 19 | EVA_TYPES 20 | #undef X 21 | }; 22 | 23 | inline std::string getTypeName(Type type) { 24 | switch (type) { 25 | #define X(type, code) \ 26 | case Type::type: \ 27 | return #type; 28 | EVA_TYPES 29 | #undef X 30 | default: 31 | throw std::runtime_error("Invalid type"); 32 | } 33 | } 34 | 35 | } // namespace eva 36 | -------------------------------------------------------------------------------- /eva/seal/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | target_sources(eva PRIVATE 5 | seal.cpp 6 | ) 7 | -------------------------------------------------------------------------------- /eva/seal/seal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/seal/seal.h" 5 | #include "eva/common/program_traversal.h" 6 | #include "eva/common/valuation.h" 7 | #include "eva/seal/seal_executor.h" 8 | #include "eva/util/logging.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #ifdef EVA_USE_GALOIS 16 | #include "eva/common/multicore_program_traversal.h" 17 | #include "eva/util/galois.h" 18 | #endif 19 | 20 | using namespace std; 21 | 22 | namespace eva { 23 | 24 | SEALValuation SEALPublic::encrypt(const Valuation &inputs, 25 | const CKKSSignature &signature) { 26 | size_t slotCount = encoder.slot_count(); 27 | if (slotCount < signature.vecSize) { 28 | throw runtime_error("Vector size cannot be larger than slot count"); 29 | } 30 | if (slotCount % signature.vecSize != 0) { 31 | throw runtime_error("Vector size must exactly divide the slot count"); 32 | } 33 | 34 | SEALValuation sealInputs(context); 35 | for (auto &in : inputs) { 36 | 37 | // With multicore sealInputs is initialized first, so that multiple threads 38 | // can be used to encode and encrypt values into it at the same time without 39 | // making structural changes. 40 | #ifdef EVA_USE_GALOIS 41 | sealInputs[in.first] = {}; 42 | } 43 | 44 | // Start a second parallel loop to encrypt inputs. 
45 | GaloisGuard galois; 46 | galois::do_all( 47 | galois::iterate(inputs), 48 | [&](auto &in) { 49 | #endif 50 | auto name = in.first; 51 | auto &v = in.second; 52 | auto vSize = v.size(); 53 | // TODO remove this check 54 | if (vSize != signature.vecSize) { 55 | throw runtime_error("Input size does not match program vector size"); 56 | } 57 | auto info = signature.inputs.at(name); 58 | 59 | auto ctxData = context.first_context_data(); 60 | for (size_t i = 0; i < info.level; ++i) { 61 | ctxData = ctxData->next_context_data(); 62 | } 63 | 64 | if (info.inputType == Type::Cipher || info.inputType == Type::Plain) { 65 | seal::Plaintext plain; 66 | 67 | if (vSize == 1) { 68 | encoder.encode(v[0], ctxData->parms_id(), pow(2.0, info.scale), 69 | plain); 70 | } else { 71 | vector vec(slotCount); 72 | assert(vSize <= slotCount); 73 | assert((slotCount % vSize) == 0); 74 | auto replicas = (slotCount / vSize); 75 | for (uint32_t r = 0; r < replicas; ++r) { 76 | for (uint64_t i = 0; i < vSize; ++i) { 77 | vec[(r * vSize) + i] = v[i]; 78 | } 79 | } 80 | encoder.encode(vec, ctxData->parms_id(), pow(2.0, info.scale), 81 | plain); 82 | } 83 | if (info.inputType == Type::Cipher) { 84 | seal::Ciphertext cipher; 85 | encryptor.encrypt(plain, cipher); 86 | sealInputs[name] = move(cipher); 87 | } else if (info.inputType == Type::Plain) { 88 | sealInputs[name] = move(plain); 89 | } 90 | } else { 91 | sealInputs[name] = std::shared_ptr( 92 | new DenseConstantValue(signature.vecSize, v)); 93 | } 94 | } 95 | #ifdef EVA_USE_GALOIS 96 | // Finish the parallel loop if using multicore support 97 | , 98 | galois::no_stats(), galois::loopname("EncryptInputs")); 99 | #endif 100 | 101 | return sealInputs; 102 | } 103 | 104 | SEALValuation SEALPublic::execute(Program &program, 105 | const SEALValuation &inputs) { 106 | #ifdef EVA_USE_GALOIS 107 | // Do multicore evaluation if multicore support is available 108 | GaloisGuard galois; 109 | MulticoreProgramTraversal programTraverse(program); 110 | 
#else 111 | // Otherwise fall back to singlecore evaluation 112 | ProgramTraversal programTraverse(program); 113 | #endif 114 | auto sealExecutor = SEALExecutor(program, context, encoder, encryptor, 115 | evaluator, galoisKeys, relinKeys); 116 | sealExecutor.setInputs(inputs); 117 | programTraverse.forwardPass(sealExecutor); 118 | 119 | SEALValuation encOutputs(context); 120 | sealExecutor.getOutputs(encOutputs); 121 | return encOutputs; 122 | } 123 | 124 | Valuation SEALSecret::decrypt(const SEALValuation &encOutputs, 125 | const CKKSSignature &signature) { 126 | Valuation outputs; 127 | std::vector tempVec; 128 | for (auto &out : encOutputs) { 129 | auto name = out.first; 130 | visit(Overloaded{[&](const seal::Ciphertext &cipher) { 131 | seal::Plaintext plain; 132 | decryptor.decrypt(cipher, plain); 133 | encoder.decode(plain, outputs[name]); 134 | }, 135 | [&](const seal::Plaintext &plain) { 136 | encoder.decode(plain, outputs[name]); 137 | }, 138 | [&](const std::shared_ptr &raw) { 139 | auto &scratch = tempVec; 140 | outputs[name] = raw->expand(scratch, signature.vecSize); 141 | }}, 142 | out.second); 143 | outputs.at(name).resize(signature.vecSize); 144 | } 145 | return outputs; 146 | } 147 | 148 | seal::SEALContext getSEALContext(const seal::EncryptionParameters ¶ms) { 149 | static unordered_map cache; 150 | 151 | // clean cache except for the required entry 152 | for (auto iter = cache.begin(); iter != cache.end();) { 153 | // accessing the context data increases the reference count by one 154 | // Another reference is incremented by cache entry 155 | if (iter->second.key_context_data().use_count() == 2 && 156 | iter->first != params) { 157 | iter = cache.erase(iter); 158 | } else { 159 | ++iter; 160 | } 161 | } 162 | 163 | // find SEALContext 164 | if (cache.count(params) != 0) { 165 | seal::SEALContext result = cache.at(params); 166 | return result; 167 | } else { 168 | auto result = cache.emplace(make_pair( 169 | params, seal::SEALContext(params, true, 
seal::sec_level_type::none))); 170 | return result.first->second; 171 | } 172 | } 173 | 174 | tuple, unique_ptr> 175 | generateKeys(const CKKSParameters &abstractParams) { 176 | vector logQs(abstractParams.primeBits.begin(), 177 | abstractParams.primeBits.end()); 178 | 179 | auto params = seal::EncryptionParameters(seal::scheme_type::ckks); 180 | params.set_poly_modulus_degree(abstractParams.polyModulusDegree); 181 | params.set_coeff_modulus( 182 | seal::CoeffModulus::Create(abstractParams.polyModulusDegree, logQs)); 183 | 184 | auto context = getSEALContext(params); 185 | 186 | seal::KeyGenerator keygen(context); 187 | vector rotationsVec(abstractParams.rotations.begin(), 188 | abstractParams.rotations.end()); 189 | 190 | seal::PublicKey public_key; 191 | seal::GaloisKeys galois_keys; 192 | seal::RelinKeys relin_keys; 193 | 194 | keygen.create_public_key(public_key); 195 | keygen.create_galois_keys(rotationsVec, galois_keys); 196 | keygen.create_relin_keys(relin_keys); 197 | 198 | auto secretCtx = make_unique(context, keygen.secret_key()); 199 | auto publicCtx = 200 | make_unique(context, public_key, galois_keys, relin_keys); 201 | 202 | return make_tuple(move(publicCtx), move(secretCtx)); 203 | } 204 | 205 | } // namespace eva 206 | -------------------------------------------------------------------------------- /eva/seal/seal.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ckks/ckks_parameters.h" 7 | #include "eva/ckks/ckks_signature.h" 8 | #include "eva/common/valuation.h" 9 | #include "eva/ir/program.h" 10 | #include "eva/serialization/seal.pb.h" 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace eva { 20 | 21 | using SchemeValue = std::variant>; 23 | 24 | class SEALValuation { 25 | public: 26 | SEALValuation(const seal::EncryptionParameters ¶ms) : params(params) {} 27 | SEALValuation(const seal::SEALContext &context) 28 | : params(context.key_context_data()->parms()) {} 29 | 30 | auto &operator[](const std::string &name) { return values[name]; } 31 | auto begin() { return values.begin(); } 32 | auto begin() const { return values.begin(); } 33 | auto end() { return values.end(); } 34 | auto end() const { return values.end(); } 35 | 36 | private: 37 | seal::EncryptionParameters params; 38 | std::unordered_map values; 39 | 40 | friend std::unique_ptr serialize(const SEALValuation &); 41 | }; 42 | 43 | std::unique_ptr deserialize(const msg::SEALValuation &); 44 | 45 | class SEALPublic { 46 | public: 47 | SEALPublic(seal::SEALContext ctx, seal::PublicKey pk, seal::GaloisKeys gk, 48 | seal::RelinKeys rk) 49 | : context(ctx), publicKey(pk), galoisKeys(gk), relinKeys(rk), 50 | encoder(ctx), encryptor(ctx, publicKey), evaluator(ctx) {} 51 | 52 | SEALValuation encrypt(const Valuation &inputs, 53 | const CKKSSignature &signature); 54 | 55 | SEALValuation execute(Program &program, const SEALValuation &inputs); 56 | 57 | private: 58 | seal::SEALContext context; 59 | 60 | seal::PublicKey publicKey; 61 | seal::GaloisKeys galoisKeys; 62 | seal::RelinKeys relinKeys; 63 | 64 | seal::CKKSEncoder encoder; 65 | seal::Encryptor encryptor; 66 | seal::Evaluator evaluator; 67 | 68 | friend std::unique_ptr serialize(const SEALPublic &); 69 | }; 70 | 71 | std::unique_ptr deserialize(const msg::SEALPublic &); 72 | 73 | class SEALSecret { 74 | public: 
75 | SEALSecret(seal::SEALContext ctx, seal::SecretKey sk) 76 | : context(ctx), secretKey(sk), encoder(ctx), decryptor(ctx, secretKey) {} 77 | 78 | Valuation decrypt(const SEALValuation &encOutputs, 79 | const CKKSSignature &signature); 80 | 81 | private: 82 | seal::SEALContext context; 83 | 84 | seal::SecretKey secretKey; 85 | 86 | seal::CKKSEncoder encoder; 87 | seal::Decryptor decryptor; 88 | 89 | friend std::unique_ptr serialize(const SEALSecret &); 90 | }; 91 | 92 | std::unique_ptr deserialize(const msg::SEALSecret &); 93 | 94 | seal::SEALContext getSEALContext(const seal::EncryptionParameters ¶ms); 95 | 96 | std::tuple, std::unique_ptr> 97 | generateKeys(const CKKSParameters &abstractParams); 98 | 99 | } // namespace eva 100 | -------------------------------------------------------------------------------- /eva/serialization/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS known_type.proto eva.proto ckks.proto seal.proto) 5 | add_library(protobuf OBJECT ${PROTO_SRCS} ${PROTO_HDRS}) 6 | target_include_directories(protobuf PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) 7 | target_link_libraries(protobuf PUBLIC protobuf::libprotobuf) 8 | 9 | target_sources(eva PRIVATE 10 | $ 11 | known_type.cpp 12 | save_load.cpp 13 | eva_serialization.cpp 14 | ckks_serialization.cpp 15 | seal_serialization.cpp 16 | ) 17 | -------------------------------------------------------------------------------- /eva/serialization/ckks.proto: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | syntax = "proto3"; 5 | 6 | package eva.msg; 7 | 8 | message CKKSParameters { 9 | repeated uint32 prime_bits = 1; 10 | repeated int32 rotations = 2; 11 | uint32 poly_modulus_degree = 3; 12 | } 13 | 14 | message CKKSEncodingInfo { 15 | int32 input_type = 1; 16 | int32 scale = 2; 17 | int32 level = 3; 18 | } 19 | 20 | message CKKSSignature { 21 | int32 vec_size = 1; 22 | map inputs = 2; 23 | } 24 | -------------------------------------------------------------------------------- /eva/serialization/ckks_serialization.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #include "eva/ckks/ckks_parameters.h" 5 | #include "eva/ckks/ckks_signature.h" 6 | #include "eva/serialization/ckks.pb.h" 7 | #include 8 | #include 9 | 10 | using namespace std; 11 | 12 | namespace eva { 13 | 14 | unique_ptr serialize(const CKKSParameters &obj) { 15 | // Create a new protobuf message 16 | auto msg = make_unique(); 17 | 18 | // Save the prime bit counts 19 | auto primeBitsMsg = msg->mutable_prime_bits(); 20 | primeBitsMsg->Reserve(obj.primeBits.size()); 21 | for (const auto &bits : obj.primeBits) { 22 | primeBitsMsg->Add(bits); 23 | } 24 | 25 | // Save the rotations that are needed 26 | auto rotationsMsg = msg->mutable_rotations(); 27 | rotationsMsg->Reserve(obj.rotations.size()); 28 | for (const auto &rotation : obj.rotations) { 29 | rotationsMsg->Add(rotation); 30 | } 31 | 32 | // Save the polynomial modulus degree 33 | msg->set_poly_modulus_degree(obj.polyModulusDegree); 34 | 35 | return msg; 36 | } 37 | 38 | unique_ptr deserialize(const msg::CKKSParameters &msg) { 39 | // Create a new CKKSParameters object 40 | auto obj = make_unique(); 41 | 42 | // Load the values from the protobuf message 43 | obj->primeBits = {msg.prime_bits().begin(), msg.prime_bits().end()}; 44 | obj->rotations = {msg.rotations().begin(), msg.rotations().end()}; 
45 | obj->polyModulusDegree = msg.poly_modulus_degree(); 46 | 47 | return obj; 48 | } 49 | 50 | unique_ptr serialize(const CKKSSignature &obj) { 51 | // Create a new protobuf message 52 | auto msg = make_unique(); 53 | 54 | // Save the vector size 55 | msg->set_vec_size(obj.vecSize); 56 | 57 | // Save the input map 58 | auto &inputsMap = *msg->mutable_inputs(); 59 | for (auto &[key, info] : obj.inputs) { 60 | auto &infoMsg = inputsMap[key]; 61 | infoMsg.set_input_type(static_cast(info.inputType)); 62 | infoMsg.set_scale(info.scale); 63 | infoMsg.set_level(info.level); 64 | } 65 | 66 | return msg; 67 | } 68 | 69 | unique_ptr deserialize(const msg::CKKSSignature &msg) { 70 | // Create a new map of CKKSEncodingInfo objects and load the data 71 | unordered_map inputs; 72 | for (auto &[key, infoMsg] : msg.inputs()) { 73 | inputs.emplace(key, 74 | CKKSEncodingInfo(static_cast(infoMsg.input_type()), 75 | infoMsg.scale(), infoMsg.level())); 76 | } 77 | 78 | // Return a new CKKSSignature object 79 | return make_unique(msg.vec_size(), move(inputs)); 80 | } 81 | 82 | } // namespace eva 83 | -------------------------------------------------------------------------------- /eva/serialization/eva.proto: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | syntax = "proto3"; 5 | 6 | package eva.msg; 7 | 8 | message Term { 9 | uint32 op = 1; 10 | // Absolute indices to list of terms 11 | repeated uint64 operands = 2; 12 | repeated Attribute attributes = 3; 13 | } 14 | 15 | message ConstantValue { 16 | uint32 size = 1; 17 | // If sparse_indices is set then values are interpreted as a sparse set of values 18 | // Otherwise values is interpreted as dense with broadcasting semantics and size must divide vec_size 19 | // If values is empty then the whole constant is zero 20 | repeated double values = 2; 21 | repeated uint32 sparse_indices = 3; 22 | } 23 | 24 | message Attribute { 25 | uint32 key = 1; 26 | oneof value { 27 | uint32 uint32 = 2; 28 | sint32 int32 = 3; 29 | uint32 type = 4; 30 | ConstantValue constant_value = 5; 31 | } 32 | } 33 | 34 | message TermName { 35 | uint64 term = 1; 36 | string name = 2; 37 | } 38 | 39 | message Program { 40 | uint32 ir_version = 1; 41 | string name = 2; 42 | uint32 vec_size = 3; 43 | repeated Term terms = 4; 44 | repeated TermName inputs = 5; 45 | repeated TermName outputs = 6; 46 | } 47 | -------------------------------------------------------------------------------- /eva/serialization/eva_format_version.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | namespace eva { 9 | 10 | // Bump the version for any changes that break serialization 11 | const std::int32_t EVA_FORMAT_VERSION = 2; 12 | 13 | } // namespace eva 14 | -------------------------------------------------------------------------------- /eva/serialization/known_type.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/serialization/known_type.h" 5 | 6 | using namespace std; 7 | 8 | namespace eva { 9 | 10 | namespace { 11 | 12 | inline void dispatchKnownTypeDeserialize(KnownType &obj, 13 | const msg::KnownType &msg) { 14 | // Try loading msg until the correct type is found 15 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::Program); 16 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::CKKSParameters); 17 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::CKKSSignature); 18 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALValuation); 19 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALPublic); 20 | EVA_KNOWN_TYPE_TRY_DESERIALIZE(msg::SEALSecret); 21 | 22 | // This is not a known type 23 | throw runtime_error("Unknown inner message type " + 24 | msg.contents().type_url()); 25 | } 26 | 27 | } // namespace 28 | 29 | KnownType deserialize(const msg::KnownType &msg) { 30 | KnownType obj; 31 | dispatchKnownTypeDeserialize(obj, msg); 32 | return obj; 33 | } 34 | 35 | } // namespace eva 36 | -------------------------------------------------------------------------------- /eva/serialization/known_type.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/ckks/ckks_parameters.h" 7 | #include "eva/ir/program.h" 8 | #include "eva/seal/seal.h" 9 | #include "eva/serialization/known_type.pb.h" 10 | #include "eva/util/overloaded.h" 11 | #include 12 | #include 13 | #include 14 | 15 | #define EVA_KNOWN_TYPE_TRY_DESERIALIZE(MsgType) \ 16 | do { \ 17 | if (msg.contents().Is()) { \ 18 | MsgType inner; \ 19 | if (!msg.contents().UnpackTo(&inner)) { \ 20 | throw std::runtime_error("Unpacking inner message failed"); \ 21 | } \ 22 | obj = deserialize(inner); \ 23 | return; \ 24 | } \ 25 | } while (false) 26 | 27 | namespace eva { 28 | 29 | // Represents any serializable EVA object 30 | using KnownType = 31 | std::variant, std::unique_ptr, 32 | std::unique_ptr, std::unique_ptr, 33 | std::unique_ptr, std::unique_ptr>; 34 | 35 | KnownType deserialize(const msg::KnownType &msg); 36 | 37 | } // namespace eva 38 | -------------------------------------------------------------------------------- /eva/serialization/known_type.proto: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | syntax = "proto3"; 5 | 6 | package eva.msg; 7 | 8 | import "google/protobuf/any.proto"; 9 | 10 | message KnownType { 11 | google.protobuf.Any contents = 1; 12 | string creator = 2; 13 | } 14 | -------------------------------------------------------------------------------- /eva/serialization/save_load.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/serialization/save_load.h" 5 | 6 | using namespace std; 7 | 8 | namespace eva { 9 | 10 | KnownType load(istream &in) { 11 | msg::KnownType msg; 12 | if (msg.ParseFromIstream(&in)) { 13 | return deserialize(msg); 14 | } else { 15 | throw runtime_error("Could not parse message"); 16 | } 17 | } 18 | 19 | KnownType loadFromFile(const string &path) { 20 | ifstream in(path); 21 | if (in.fail()) { 22 | throw runtime_error("Could not open file"); 23 | } 24 | return load(in); 25 | } 26 | 27 | KnownType loadFromString(const string &str) { 28 | msg::KnownType msg; 29 | if (msg.ParseFromString(str)) { 30 | return deserialize(msg); 31 | } else { 32 | throw runtime_error("Could not parse message"); 33 | } 34 | } 35 | 36 | } // namespace eva 37 | -------------------------------------------------------------------------------- /eva/serialization/save_load.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "eva/serialization/known_type.h" 7 | #include "eva/version.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace eva { 14 | 15 | KnownType load(std::istream &in); 16 | KnownType loadFromFile(const std::string &path); 17 | KnownType loadFromString(const std::string &str); 18 | 19 | template T load(std::istream &in) { return std::get(load(in)); } 20 | 21 | template T loadFromFile(const std::string &path) { 22 | return std::get(loadFromFile(path)); 23 | } 24 | 25 | template T loadFromString(const std::string &str) { 26 | return std::get(loadFromString(str)); 27 | } 28 | 29 | namespace detail { 30 | template void serializeKnownType(const T &obj, msg::KnownType &msg) { 31 | auto inner = serialize(obj); 32 | msg.set_creator("EVA " + version()); 33 | msg.mutable_contents()->PackFrom(*inner); 34 | } 35 | } // namespace detail 36 | 37 | template void save(const T &obj, std::ostream &out) { 38 | msg::KnownType msg; 39 | detail::serializeKnownType(obj, msg); 40 | if (!msg.SerializeToOstream(&out)) { 41 | throw std::runtime_error("Could not serialize message"); 42 | } 43 | } 44 | 45 | template void saveToFile(const T &obj, const std::string &path) { 46 | std::ofstream out(path); 47 | if (out.fail()) { 48 | throw std::runtime_error("Could not open file"); 49 | } 50 | save(obj, out); 51 | } 52 | 53 | template std::string saveToString(const T &obj) { 54 | msg::KnownType msg; 55 | detail::serializeKnownType(obj, msg); 56 | std::string str; 57 | if (msg.SerializeToString(&str)) { 58 | return str; 59 | } else { 60 | throw std::runtime_error("Could not serialize message"); 61 | } 62 | } 63 | 64 | } // namespace eva 65 | -------------------------------------------------------------------------------- /eva/serialization/seal.proto: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | syntax = "proto3"; 5 | 6 | package eva.msg; 7 | 8 | import "eva.proto"; 9 | 10 | message SEALObject { 11 | enum SEALType { 12 | UNKNOWN = 0; 13 | CIPHERTEXT = 1; 14 | PLAINTEXT = 2; 15 | SECRET_KEY = 3; 16 | PUBLIC_KEY = 4; 17 | GALOIS_KEYS = 5; 18 | RELIN_KEYS = 6; 19 | ENCRYPTION_PARAMETERS = 7; 20 | } 21 | SEALType seal_type = 1; 22 | bytes data = 2; 23 | } 24 | 25 | message SEALPublic { 26 | SEALObject encryption_parameters = 1; 27 | SEALObject public_key = 2; 28 | SEALObject galois_keys = 3; 29 | SEALObject relin_keys = 4; 30 | } 31 | 32 | message SEALSecret { 33 | SEALObject encryption_parameters = 1; 34 | SEALObject secret_key = 2; 35 | } 36 | 37 | message SEALValuation { 38 | SEALObject encryption_parameters = 1; 39 | map values = 2; 40 | map raw_values = 3; 41 | } 42 | -------------------------------------------------------------------------------- /eva/serialization/seal_serialization.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/seal/seal.h" 5 | #include "eva/util/overloaded.h" 6 | #include 7 | #include 8 | #include 9 | 10 | using namespace std; 11 | 12 | namespace eva { 13 | 14 | using SEALObject = msg::SEALObject; 15 | 16 | template auto getSEALTypeTag(); 17 | 18 | template <> auto getSEALTypeTag() { 19 | return SEALObject::CIPHERTEXT; 20 | } 21 | 22 | template <> auto getSEALTypeTag() { 23 | return SEALObject::PLAINTEXT; 24 | } 25 | 26 | template <> auto getSEALTypeTag() { 27 | return SEALObject::SECRET_KEY; 28 | } 29 | 30 | template <> auto getSEALTypeTag() { 31 | return SEALObject::PUBLIC_KEY; 32 | } 33 | 34 | template <> auto getSEALTypeTag() { 35 | return SEALObject::GALOIS_KEYS; 36 | } 37 | 38 | template <> auto getSEALTypeTag() { 39 | return SEALObject::RELIN_KEYS; 40 | } 41 | 42 | template <> auto getSEALTypeTag() { 43 | return SEALObject::ENCRYPTION_PARAMETERS; 44 | } 45 | 46 | template void serializeSEALType(const T &obj, SEALObject *msg) { 47 | // Get an upper bound for the size from SEAL; use default compression mode 48 | auto maxSize = obj.save_size(seal::Serialization::compr_mode_default); 49 | 50 | // Set up a buffer (std::string) 51 | // We allocate the string into a std::unique_ptr and eventually pass ownership 52 | // to the Protobuf message below 53 | auto data = make_unique(); 54 | data->resize(maxSize); 55 | 56 | // Note, since C++11 std::string is guaranteed to be contiguous 57 | auto actualSize = 58 | obj.save(reinterpret_cast(&data->operator[](0)), 59 | maxSize, seal::Serialization::compr_mode_default); 60 | data->resize(actualSize); 61 | 62 | // Change ownership of the data string to msg 63 | msg->set_allocated_data(data.release()); 64 | 65 | // Set the type tag to indicate the SEAL object type 66 | msg->set_seal_type(getSEALTypeTag()); 67 | } 68 | 69 | template void deserializeSEALType(T &obj, const SEALObject &msg) { 70 | // Unknown type; throw 71 | if (msg.seal_type() == SEALObject::UNKNOWN) { 72 | throw runtime_error("SEAL message type 
set to UNKNOWN"); 73 | } 74 | 75 | // Type of obj is incompatible with the type indicated in msg 76 | if (msg.seal_type() != getSEALTypeTag()) { 77 | throw runtime_error("SEAL message type mismatch"); 78 | } 79 | 80 | // Load the SEAL object 81 | obj.load(reinterpret_cast(msg.data().c_str()), 82 | msg.data().size()); 83 | } 84 | 85 | template 86 | void deserializeSEALTypeWithContext(const seal::SEALContext &context, T &obj, 87 | const SEALObject &msg) { 88 | // Most SEAL objects require the SEALContext for safe loading 89 | // Unknown type; throw 90 | if (msg.seal_type() == SEALObject::UNKNOWN) { 91 | throw runtime_error("SEAL message type set to UNKNOWN"); 92 | } 93 | 94 | // Type of obj is incompatible with the type indicated in msg 95 | if (msg.seal_type() != getSEALTypeTag()) { 96 | throw runtime_error("SEAL message type mismatch"); 97 | } 98 | 99 | // Load the SEAL object and check its validity against given context 100 | obj.load(context, 101 | reinterpret_cast(msg.data().c_str()), 102 | msg.data().size()); 103 | } 104 | 105 | unique_ptr deserialize(const msg::SEALValuation &msg) { 106 | // Deserialize a SEAL valuation: either plaintexts or ciphertexts 107 | // First need to load the encryption parameters and obtain the context 108 | seal::EncryptionParameters encParams; 109 | deserializeSEALType(encParams, msg.encryption_parameters()); 110 | auto context = getSEALContext(encParams); 111 | 112 | // Create the destination valuation and load the correct type 113 | auto obj = make_unique(encParams); 114 | for (const auto &entry : msg.values()) { 115 | auto &value = obj->operator[](entry.first); 116 | 117 | // Load the correct kind of object based on value 118 | switch (entry.second.seal_type()) { 119 | case SEALObject::CIPHERTEXT: { 120 | value = seal::Ciphertext(); 121 | deserializeSEALTypeWithContext(context, get(value), 122 | entry.second); 123 | break; 124 | } 125 | case SEALObject::PLAINTEXT: { 126 | value = seal::Plaintext(); 127 | 
deserializeSEALTypeWithContext(context, get(value), 128 | entry.second); 129 | break; 130 | } 131 | default: 132 | throw runtime_error("Not a ciphertext or plaintext"); 133 | } 134 | } 135 | 136 | // Deserialize the raw part of the valuation 137 | for (const auto &entry : msg.raw_values()) { 138 | obj->operator[](entry.first) = deserialize(entry.second); 139 | } 140 | 141 | return obj; 142 | } 143 | 144 | unique_ptr serialize(const SEALValuation &obj) { 145 | // Create the Protobuf message and save the encryption parameters 146 | auto msg = make_unique(); 147 | serializeSEALType(obj.params, msg->mutable_encryption_parameters()); 148 | // Serialize a SEAL valuation: either plaintexts or ciphertexts 149 | auto &valuesMsg = *msg->mutable_values(); 150 | auto &rawValuesMsg = *msg->mutable_raw_values(); 151 | for (const auto &entry : obj) { 152 | // Visit entry.second with an overloaded lambda function; we need to specify 153 | // handling for both possible data types (plaintexts and ciphertexts) 154 | visit(Overloaded{[&](const seal::Ciphertext &cipher) { 155 | serializeSEALType(cipher, &valuesMsg[entry.first]); 156 | }, 157 | [&](const seal::Plaintext &plain) { 158 | serializeSEALType(plain, &valuesMsg[entry.first]); 159 | }, 160 | [&](const std::shared_ptr raw) { 161 | raw->serialize(rawValuesMsg[entry.first]); 162 | }}, 163 | entry.second); 164 | } 165 | 166 | return msg; 167 | } 168 | 169 | unique_ptr serialize(const SEALPublic &obj) { 170 | // Serialize a SEALPublic object 171 | auto msg = make_unique(); 172 | 173 | // Save the encryption parameters 174 | serializeSEALType(obj.context.key_context_data()->parms(), 175 | msg->mutable_encryption_parameters()); 176 | 177 | // Save the different public keys 178 | serializeSEALType(obj.publicKey, msg->mutable_public_key()); 179 | serializeSEALType(obj.galoisKeys, msg->mutable_galois_keys()); 180 | serializeSEALType(obj.relinKeys, msg->mutable_relin_keys()); 181 | 182 | return msg; 183 | } 184 | 185 | unique_ptr 
deserialize(const msg::SEALPublic &msg) { 186 | // Deserialize a SEALPublic object 187 | // Load the encryption parameters and acquire a SEALContext; this is needed 188 | // for safe loading of the other objects 189 | seal::EncryptionParameters encParams; 190 | deserializeSEALType(encParams, msg.encryption_parameters()); 191 | auto context = getSEALContext(encParams); 192 | 193 | // Load the different public keys 194 | seal::PublicKey pk; 195 | deserializeSEALTypeWithContext(context, pk, msg.public_key()); 196 | seal::GaloisKeys gk; 197 | deserializeSEALTypeWithContext(context, gk, msg.galois_keys()); 198 | seal::RelinKeys rk; 199 | deserializeSEALTypeWithContext(context, rk, msg.relin_keys()); 200 | 201 | return make_unique(context, pk, gk, rk); 202 | } 203 | 204 | unique_ptr serialize(const SEALSecret &obj) { 205 | // Serialize a SEALSecret object 206 | auto msg = make_unique(); 207 | 208 | // Save the encryption parameters 209 | serializeSEALType(obj.context.key_context_data()->parms(), 210 | msg->mutable_encryption_parameters()); 211 | 212 | // Save the secret key 213 | serializeSEALType(obj.secretKey, msg->mutable_secret_key()); 214 | return msg; 215 | } 216 | 217 | unique_ptr deserialize(const msg::SEALSecret &msg) { 218 | // Deserialize a SEALSecret object 219 | // Load the encryption parameters and acquire a SEALContext; this is needed 220 | // for safe loading of the other objects 221 | seal::EncryptionParameters encParams; 222 | deserializeSEALType(encParams, msg.encryption_parameters()); 223 | auto context = getSEALContext(encParams); 224 | 225 | // Load the secret key 226 | seal::SecretKey sk; 227 | deserializeSEALTypeWithContext(context, sk, msg.secret_key()); 228 | 229 | return make_unique(context, sk); 230 | } 231 | 232 | } // namespace eva 233 | -------------------------------------------------------------------------------- /eva/util/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | if(USE_GALOIS) 5 | target_sources(eva PRIVATE 6 | galois.cpp 7 | ) 8 | endif() 9 | 10 | target_sources(eva PRIVATE 11 | logging.cpp 12 | ) 13 | -------------------------------------------------------------------------------- /eva/util/galois.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #include "eva/util/galois.h" 5 | 6 | namespace eva { 7 | 8 | GaloisGuard::GaloisGuard() { 9 | // Galois doesn't exit quietly, so lets just leak it instead. 10 | // It was also crashing on exit when this decision was made. 11 | static galois::SharedMemSys *galois = new galois::SharedMemSys(); 12 | } 13 | 14 | } // namespace eva 15 | -------------------------------------------------------------------------------- /eva/util/galois.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | 9 | namespace eva { 10 | 11 | struct GaloisGuard { 12 | GaloisGuard(); 13 | }; 14 | 15 | } // namespace eva 16 | -------------------------------------------------------------------------------- /eva/util/logging.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #include "eva/util/logging.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace eva { 11 | 12 | int getUserVerbosity() { 13 | static int userVerbosity = 0; 14 | static bool parsed = false; 15 | if (!parsed) { 16 | if (const char *envP = std::getenv("EVA_VERBOSITY")) { 17 | auto envStr = std::string(envP); 18 | try { 19 | userVerbosity = std::stoi(envStr); 20 | } catch (std::invalid_argument e) { 21 | std::transform(envStr.begin(), envStr.end(), envStr.begin(), ::tolower); 22 | if (envStr == "silent") { 23 | userVerbosity = 0; 24 | } else if (envStr == "info") { 25 | userVerbosity = (int)Verbosity::Info; 26 | } else if (envStr == "debug") { 27 | userVerbosity = (int)Verbosity::Debug; 28 | } else if (envStr == "trace") { 29 | userVerbosity = (int)Verbosity::Trace; 30 | } else { 31 | std::cerr << "Invalid verbosity EVA_VERBOSITY=" << envStr 32 | << " Defaulting to silent.\n"; 33 | userVerbosity = 0; 34 | } 35 | } 36 | } 37 | parsed = true; 38 | } 39 | return userVerbosity; 40 | } 41 | 42 | void log(Verbosity verbosity, const char *fmt, ...) { 43 | if (getUserVerbosity() >= (int)verbosity) { 44 | printf("EVA: "); 45 | va_list args; 46 | va_start(args, fmt); 47 | vprintf(fmt, args); 48 | va_end(args); 49 | printf("\n"); 50 | fflush(stdout); 51 | } 52 | } 53 | 54 | bool verbosityAtLeast(Verbosity verbosity) { 55 | return getUserVerbosity() >= (int)verbosity; 56 | } 57 | 58 | void warn(const char *fmt, ...) { 59 | fprintf(stderr, "WARNING: "); 60 | va_list args; 61 | va_start(args, fmt); 62 | vfprintf(stderr, fmt, args); 63 | va_end(args); 64 | fprintf(stderr, "\n"); 65 | fflush(stderr); 66 | } 67 | 68 | } // namespace eva 69 | -------------------------------------------------------------------------------- /eva/util/logging.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace eva { 11 | 12 | enum class Verbosity { 13 | Info = 1, 14 | Debug = 2, 15 | Trace = 3, 16 | }; 17 | 18 | void log(Verbosity verbosity, const char *fmt, ...); 19 | bool verbosityAtLeast(Verbosity verbosity); 20 | void warn(const char *fmt, ...); 21 | 22 | } // namespace eva 23 | -------------------------------------------------------------------------------- /eva/util/overloaded.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | namespace eva { 9 | 10 | // The "Overloaded" trick to create convenient overloaded function objects for 11 | // use in std::visit. 12 | template struct Overloaded : Ts... { 13 | // Bring the various operator() overloads to this namespace 14 | using Ts::operator()...; 15 | }; 16 | 17 | // Add a user-defined deduction guide for the class template 18 | template Overloaded(Ts...) -> Overloaded; 19 | 20 | } // namespace eva 21 | -------------------------------------------------------------------------------- /eva/version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 3 | 4 | #include "version.h" 5 | 6 | using namespace std; 7 | 8 | namespace eva { 9 | 10 | string version() { return EVA_VERSION_STR; } 11 | 12 | } // namespace eva 13 | -------------------------------------------------------------------------------- /eva/version.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. All rights reserved. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include 7 | 8 | namespace eva { 9 | 10 | std::string version(); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.eva 2 | *.evaparams 3 | *.evasignature 4 | *.sealpublic 5 | *.sealsecret 6 | *.sealvals 7 | *_encrypted.png 8 | *_reference.png -------------------------------------------------------------------------------- /examples/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/EVA/7ea9d34b0380c25c03fe529fe9a755a6477144ce/examples/baboon.png -------------------------------------------------------------------------------- /examples/image_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | from eva import EvaProgram, Input, Output, evaluate 5 | from eva.ckks import CKKSCompiler 6 | from eva.seal import generate_keys 7 | from eva.metric import valuation_mse 8 | from PIL import Image 9 | import numpy as np 10 | 11 | def convolution(image, width, filter): 12 | for i in range(len(filter)): 13 | for j in range(len(filter[0])): 14 | rotated = image << i * width + j 15 | partial = rotated * filter[i][j] 16 | if i == 0 and j == 0: 17 | convolved = partial 18 | else: 19 | convolved += partial 20 | return convolved 21 | 22 | def convolutionXY(image, width, filter): 23 | for i in range(len(filter)): 24 | for j in range(len(filter[0])): 25 | rotated = image << (i * width + j) 26 | horizontal = rotated * filter[i][j] 27 | vertical = rotated * filter[j][i] 28 | if i == 0 and j == 0: 29 | Ix = horizontal 30 | Iy = vertical 31 | else: 32 | Ix += horizontal 33 | Iy += vertical 34 | return Ix, Iy 35 | 36 | h = 64 37 | w = 64 38 | 39 | sobel = EvaProgram('sobel', vec_size=h*w) 40 | with sobel: 41 | image = Input('image') 42 | 43 | sobel_filter = [ 44 | [-1, 0, 1], 45 | [-2, 0, 2], 46 | [-1, 0, 1]] 47 | 48 | a1 = 2.2137874823876622 49 | a2 = -1.0984324107372518 50 | a3 = 0.17254603006834726 51 | 52 | conv_hor, conv_ver = convolutionXY(image, w, sobel_filter) 53 | 54 | conv_hor2 = conv_hor**2 55 | conv_ver2 = conv_ver**2 56 | dsq = conv_hor2 + conv_ver2 57 | dsq2 = dsq * dsq 58 | dsq3 = dsq2 * dsq 59 | 60 | Output('image', dsq * a1 + dsq2 * a2 + dsq3 * a3) 61 | 62 | sobel.set_input_scales(25) 63 | sobel.set_output_ranges(10) 64 | 65 | harris = EvaProgram('harris', vec_size=h*w) 66 | with harris: 67 | image = Input('image') 68 | 69 | sobel_filter = [ 70 | [-1, 0, 1], 71 | [-2, 0, 2], 72 | [-1, 0, 1]] 73 | pool = [ 74 | [1, 1, 1], 75 | [1, 1, 1], 76 | [1, 1, 1]] 77 | 78 | c = 0.04 79 | 80 | Ix, Iy = convolutionXY(image, w, sobel_filter) 81 | 82 | Ixx = Ix**2 83 | Iyy = Iy**2 84 | Ixy = Ix * Iy 85 | 86 | #FIX: masking may be needed here (to handle boundaries) 87 | 
88 | Sxx = convolution(Ixx, w, pool) 89 | Syy = convolution(Iyy, w, pool) 90 | Sxy = convolution(Ixy, w, pool) 91 | 92 | SxxSyy = Sxx * Syy 93 | SxySxy = Sxy * Sxy 94 | det = SxxSyy - SxySxy 95 | trace = Sxx + Syy 96 | 97 | Output('image', det - trace**2 * c) 98 | 99 | harris.set_input_scales(30) 100 | harris.set_output_ranges(20) 101 | 102 | def read_input_image(): 103 | image = Image.open('baboon.png').convert('L') 104 | image_data = [x / 255.0 for x in list(image.getdata())] 105 | return {'image': image_data} 106 | 107 | def write_output_image(outputs, tag): 108 | enc_result_image = Image.new('L', (w, h)) 109 | enc_result_image.putdata([x * 255.0 for x in outputs['image'][0:h*w]]) 110 | enc_result_image.save(f'baboon_{tag}.png', "PNG") 111 | 112 | if __name__ == "__main__": 113 | inputs = read_input_image() 114 | 115 | for prog in [sobel, harris]: 116 | print(f'Compiling {prog.name}') 117 | 118 | compiler = CKKSCompiler() 119 | compiled, params, signature = compiler.compile(prog) 120 | public_ctx, secret_ctx = generate_keys(params) 121 | enc_inputs = public_ctx.encrypt(inputs, signature) 122 | enc_outputs = public_ctx.execute(compiled, enc_inputs) 123 | outputs = secret_ctx.decrypt(enc_outputs, signature) 124 | 125 | write_output_image(outputs, f'{compiled.name}_encrypted') 126 | 127 | reference = evaluate(compiled, inputs) 128 | write_output_image(reference, f'{compiled.name}_reference') 129 | 130 | print('MSE', valuation_mse(outputs, reference)) 131 | print() 132 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow -------------------------------------------------------------------------------- /examples/serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | from eva import EvaProgram, Input, Output, evaluate, save, load 5 | from eva.ckks import CKKSCompiler 6 | from eva.seal import generate_keys 7 | from eva.metric import valuation_mse 8 | import numpy as np 9 | 10 | ################################################# 11 | print('Compile time') 12 | 13 | poly = EvaProgram('Polynomial', vec_size=8) 14 | with poly: 15 | x = Input('x') 16 | Output('y', 3*x**2 + 5*x - 2) 17 | 18 | poly.set_output_ranges(20) 19 | poly.set_input_scales(20) 20 | 21 | compiler = CKKSCompiler() 22 | poly, params, signature = compiler.compile(poly) 23 | 24 | save(poly, 'poly.eva') 25 | save(params, 'poly.evaparams') 26 | save(signature, 'poly.evasignature') 27 | 28 | ################################################# 29 | print('Key generation time') 30 | 31 | params = load('poly.evaparams') 32 | 33 | public_ctx, secret_ctx = generate_keys(params) 34 | 35 | save(public_ctx, 'poly.sealpublic') 36 | save(secret_ctx, 'poly.sealsecret') 37 | 38 | ################################################# 39 | print('Runtime on client') 40 | 41 | signature = load('poly.evasignature') 42 | public_ctx = load('poly.sealpublic') 43 | 44 | inputs = { 45 | 'x': [i for i in range(signature.vec_size)] 46 | } 47 | encInputs = public_ctx.encrypt(inputs, signature) 48 | 49 | save(encInputs, 'poly_inputs.sealvals') 50 | 51 | ################################################# 52 | print('Runtime on server') 53 | 54 | poly = load('poly.eva') 55 | public_ctx = load('poly.sealpublic') 56 | encInputs = load('poly_inputs.sealvals') 57 | 58 | encOutputs = public_ctx.execute(poly, encInputs) 59 | 60 | save(encOutputs, 'poly_outputs.sealvals') 61 | 62 | ################################################# 63 | print('Back on client') 64 | 65 | secret_ctx = load('poly.sealsecret') 66 | encOutputs = load('poly_outputs.sealvals') 67 | 68 | outputs = secret_ctx.decrypt(encOutputs, signature) 69 | 70 | reference = evaluate(poly, inputs) 71 | print('Expected', reference) 72 | 
print('Got', outputs) 73 | print('MSE', valuation_mse(outputs, reference)) 74 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | # In-source build files 2 | setup.py 3 | setup.py.after_configure -------------------------------------------------------------------------------- /python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | configure_file(setup.py.in setup.py.after_configure) 5 | file(GENERATE OUTPUT setup.py INPUT ${CMAKE_CURRENT_BINARY_DIR}/setup.py.after_configure) 6 | 7 | add_subdirectory(eva) 8 | 9 | add_custom_target(python ALL DEPENDS eva_py) 10 | add_custom_command(TARGET python 11 | POST_BUILD 12 | COMMAND ${CMAKE_COMMAND} -E copy_directory 13 | ${CMAKE_CURRENT_SOURCE_DIR}/eva 14 | ${CMAKE_CURRENT_BINARY_DIR}/eva 15 | ) 16 | -------------------------------------------------------------------------------- /python/eva/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | pybind11_add_module(eva_py 5 | wrapper.cpp 6 | ) 7 | 8 | target_compile_features(eva_py PUBLIC cxx_std_17) 9 | target_link_libraries(eva_py PRIVATE eva) 10 | if (MSVC) 11 | target_link_libraries(eva_py PUBLIC bcrypt) 12 | endif() 13 | set_target_properties(eva_py PROPERTIES OUTPUT_NAME _eva) 14 | -------------------------------------------------------------------------------- /python/eva/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | from ._eva import * 5 | import numbers 6 | import psutil 7 | 8 | # Find the number of CPU cores available to this process. This has to happen before Galois is initialized because it 9 | # messes with the CPU affinity of the process. 10 | _default_num_threads = len(psutil.Process().cpu_affinity()) 11 | # Initialize Galois here (trying to do it in the static initialization step of the native library hangs). 12 | _global_guard = _eva._GaloisGuard() 13 | # Set the default number of threads to use to match the cores. 14 | set_num_threads(_default_num_threads) 15 | 16 | _current_program = None 17 | def _curr(): 18 | """ Returns the EvaProgram that is currently in context """ 19 | global _current_program 20 | if _current_program == None: 21 | raise RuntimeError("No Program in context") 22 | return _current_program 23 | 24 | def _py_to_term(x, program): 25 | """ Maps supported types into EVA terms """ 26 | if isinstance(x, Expr): 27 | return x.term 28 | elif isinstance(x, list): 29 | return program._make_dense_constant(x) 30 | elif isinstance(x, numbers.Number): 31 | return program._make_uniform_constant(x) 32 | elif isinstance(x, Term): 33 | return x 34 | else: 35 | raise TypeError("No conversion to Term available for " + str(x)) 36 | 37 | def py_to_eva(x, program = None): 38 | """ Maps supported types into EVA terms. May be used in library functions 39 | to provide uniform support for Expr instances and python types that 40 | are convertible into constants in EVA programs. 41 | 42 | Parameters 43 | ---------- 44 | x : eva.Expr, EVA native Term, list or a number 45 | The value to be converted to an Expr 46 | program : EvaProgram, optional 47 | The program a new term is created in (if necessary). If None then 48 | the program currently in context is used (again if necessary). 
49 | """ 50 | if isinstance(x, Expr): 51 | return x 52 | else: 53 | if program == None: 54 | program = _curr() 55 | return Expr(_py_to_term(x, program), program) 56 | 57 | class Expr(): 58 | """ Wrapper for EVA's native Term class. Provides operator overloads that 59 | create terms in the associated EvaProgram. 60 | 61 | Attributes 62 | ---------- 63 | term 64 | The EVA native term 65 | program : eva.EVAProgram 66 | The program the wrapped term is in 67 | """ 68 | 69 | def __init__(self, term, program): 70 | self.term = term 71 | self.program = program 72 | 73 | def __add__(self,other): 74 | """ Create a new addition term """ 75 | return Expr(self.program._make_term(Op.Add, [self.term, _py_to_term(other, self.program)]), self.program) 76 | 77 | def __radd__(self,other): 78 | """ Create a new addition term """ 79 | return Expr(self.program._make_term(Op.Add, [_py_to_term(other, self.program), self.term]), self.program) 80 | 81 | def __sub__(self,other): 82 | """ Create a new subtraction term """ 83 | return Expr(self.program._make_term(Op.Sub, [self.term, _py_to_term(other, self.program)]), self.program) 84 | 85 | def __rsub__(self,other): 86 | """ Create a new subtraction term """ 87 | return Expr(self.program._make_term(Op.Sub, [_py_to_term(other, self.program), self.term]), self.program) 88 | 89 | def __mul__(self,other): 90 | """ Create a new multiplication term """ 91 | return Expr(self.program._make_term(Op.Mul, [self.term, _py_to_term(other, self.program)]), self.program) 92 | 93 | def __rmul__(self,other): 94 | """ Create a new multiplication term """ 95 | return Expr(self.program._make_term(Op.Mul, [_py_to_term(other, self.program), self.term]), self.program) 96 | 97 | def __pow__(self,exponent): 98 | """ Create exponentiation as nested multiplication terms """ 99 | if exponent < 1: 100 | raise ValueError("exponent must be greater than zero, got " + exponent) 101 | result = self.term 102 | for i in range(exponent-1): 103 | result = 
self.program._make_term(Op.Mul, [result, self.term]) 104 | return Expr(result, self.program) 105 | 106 | def __lshift__(self,rotation): 107 | """ Create a left rotation term """ 108 | return Expr(self.program._make_left_rotation(self.term, rotation), self.program) 109 | 110 | def __rshift__(self,rotation): 111 | """ Create a right rotation term """ 112 | return Expr(self.program._make_right_rotation(self.term, rotation), self.program) 113 | 114 | def __neg__(self): 115 | """ Create a negation term """ 116 | return Expr(self.program._make_term(Op.Negate, [self.term]), self.program) 117 | 118 | class EvaProgram(Program): 119 | """ A wrapper for EVA's native Program class. Acts as a context manager to 120 | set the program the Input and Output free functions operate on. """ 121 | 122 | def __init__(self, name, vec_size): 123 | """ Create a new EvaProgram with a name and a vector size 124 | 125 | Parameters 126 | ---------- 127 | name : str 128 | The name of the program 129 | vec_size : int 130 | The number of elements in all values in the program 131 | Must be a power-of-two 132 | """ 133 | super().__init__(name, vec_size) 134 | 135 | def __enter__(self): 136 | global _current_program 137 | if _current_program != None: 138 | raise RuntimeError("There is already an EVA Program in context") 139 | _current_program = self 140 | 141 | def __exit__(self, exc_type, exc_value, exc_traceback): 142 | global _current_program 143 | if _current_program != self: 144 | raise RuntimeError("This program is not currently in context") 145 | _current_program = None 146 | 147 | def Input(name, is_encrypted=True): 148 | """ Create a new named input term in the current EvaProgram 149 | 150 | Parameters 151 | ---------- 152 | name : str 153 | The name of the input 154 | is_encrypted : bool, optional 155 | Whether this input should be encrypted or not (default: True) 156 | """ 157 | program = _curr() 158 | return Expr(program._make_input(name, Type.Cipher if is_encrypted else Type.Raw), 
program) 159 | 160 | def Output(name, expr): 161 | """ Create a new named output term in the current EvaProgram 162 | 163 | Parameters 164 | ---------- 165 | name : str 166 | The name of the output 167 | is_encrypted : bool, optional 168 | Whether this input should be encrypted or not (default: True) 169 | """ 170 | program = _curr() 171 | program._make_output(name, _py_to_term(expr, program)) 172 | -------------------------------------------------------------------------------- /python/eva/ckks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from .._eva._ckks import * 5 | -------------------------------------------------------------------------------- /python/eva/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as _np 5 | 6 | def valuation_mse(a,b): 7 | """ Calculate the total mean squared error between two valuations 8 | 9 | Parameters 10 | ---------- 11 | a,b : dict from names to list of numbers 12 | Must have the same structure 13 | """ 14 | if set(a.keys()) != set(b.keys()): 15 | raise ValueError("Valuations must have the same keys") 16 | mse = 0 17 | for k in a.keys(): 18 | mse += _np.mean((_np.array(a[k]) - _np.array(b[k]))**2) 19 | return mse / len(a) 20 | -------------------------------------------------------------------------------- /python/eva/seal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | from .._eva._seal import * 5 | -------------------------------------------------------------------------------- /python/eva/std/numeric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | from eva import py_to_eva 4 | 5 | def horizontal_sum(x): 6 | """ Sum together all elements of a vector. The result is replicated in all 7 | elements of the returned vector. 8 | 9 | Parameters 10 | ---------- 11 | x : an EVA compatible type (see eva.py_to_eva) 12 | The vector to sum together 13 | """ 14 | 15 | x = py_to_eva(x) 16 | i = 1 17 | while i < x.program.vec_size: 18 | y = x << i 19 | x = x + y 20 | i <<= 1 21 | return x 22 | -------------------------------------------------------------------------------- /python/setup.py.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | from setuptools import setup, find_packages 5 | from setuptools.dist import Distribution 6 | 7 | class BinaryDistribution(Distribution): 8 | """Distribution which always forces a binary package with platform name""" 9 | def has_ext_modules(foo): 10 | return True 11 | 12 | setup( 13 | name='eva', 14 | version='${PROJECT_VERSION}', 15 | author='Microsoft Research EVA compiler team', 16 | author_email='evacompiler@microsoft.com', 17 | description='Compiler for the Microsoft SEAL homomorphic encryption library', 18 | packages=find_packages('${CMAKE_CURRENT_BINARY_DIR}'), 19 | package_data={ 20 | 'eva': ['$'], 21 | }, 22 | distclass=BinaryDistribution, 23 | install_requires=[ 24 | 'psutil', 25 | ], 26 | ) 27 | -------------------------------------------------------------------------------- /scripts/clang-format-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Microsoft Corporation. All rights reserved. 4 | # Licensed under the MIT license. 5 | 6 | BASE_DIR=$(dirname "$0") 7 | PROJECT_ROOT_DIR=$BASE_DIR/../ 8 | shopt -s globstar 9 | clang-format -i $PROJECT_ROOT_DIR/eva/**/*.h 10 | clang-format -i $PROJECT_ROOT_DIR/eva/**/*.cpp 11 | -------------------------------------------------------------------------------- /tests/all.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | from bug_fixes import * 5 | from features import * 6 | from large_programs import * 7 | from std import * 8 | 9 | if __name__ == '__main__': 10 | unittest.main() -------------------------------------------------------------------------------- /tests/bug_fixes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | import unittest 5 | from common import * 6 | from eva import EvaProgram, Input, Output 7 | 8 | class BugFixes(EvaTestCase): 9 | 10 | def test_high_inner_term_scale(self): 11 | """ Test lazy waterline rescaler with a program causing a high inner term scale 12 | 13 | This test was added for a bug that was an interaction between 14 | rescaling not being inserted (causing high scales to be accumulated) 15 | and parameter selection not handling high scales in inner terms.""" 16 | 17 | prog = EvaProgram('HighInnerTermScale', vec_size=4) 18 | with prog: 19 | x1 = Input('x1') 20 | x2 = Input('x2') 21 | Output('y', x1*x1*x2) 22 | 23 | prog.set_output_ranges(20) 24 | prog.set_input_scales(60) 25 | 26 | self.assert_compiles_and_matches_reference(prog, config={'rescaler':'lazy_waterline'}) 27 | 28 | @unittest.skip('not fixed in SEAL yet') 29 | def test_large_and_small(self): 30 | """ Check that a ciphertext with very large and small values decodes accurately 31 | 32 | This test was added to track a common bug in CKKS implementations, 33 | where double precision floating points used in decoding fail to 34 | provide good accuracy for small values in ciphertexts when other 35 | very large values are present.""" 36 | 37 | prog = EvaProgram('LargeAndSmall', vec_size=4) 38 | with prog: 39 | x = Input('x') 40 | Output('y', pow(x,8)) 41 | 42 | prog.set_output_ranges(60) 43 | prog.set_input_scales(60) 44 | 45 | inputs = { 46 | 'x': [0,1,10,100] 47 | } 48 | 49 | self.assert_compiles_and_matches_reference(prog, inputs, config={'warn_vec_size':'false'}) 50 | 51 | def test_output_rescaled(self): 52 | """ Check that the lazy waterline policy rescales outputs 53 | 54 | This test was added for a bug where outputs could be returned with 55 | more primes in their modulus than necessary, which causes them to 56 | take more space when serialized.""" 57 | 58 | prog = EvaProgram('OutputRescaled', vec_size=4) 59 | with prog: 60 | x = Input('x') 61 | Output('y', x*x) 62 | 63 | 
prog.set_output_ranges(20) 64 | prog.set_input_scales(60) 65 | 66 | compiler = CKKSCompiler(config={'rescaler':'lazy_waterline', 'warn_vec_size':'false'}) 67 | prog, params, signature = compiler.compile(prog) 68 | self.assertEqual(params.prime_bits, [60, 20, 60, 60]) 69 | 70 | if __name__ == '__main__': 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import unittest 5 | from random import uniform 6 | from eva import evaluate 7 | from eva.ckks import CKKSCompiler 8 | from eva.seal import generate_keys 9 | from eva.metric import valuation_mse 10 | 11 | class EvaTestCase(unittest.TestCase): 12 | def assert_compiles_and_matches_reference(self, prog, inputs = None, config={}): 13 | if inputs == None: 14 | inputs = { name: [uniform(-2,2) for _ in range(prog.vec_size)] 15 | for name in prog.inputs } 16 | config['warn_vec_size'] = 'false' 17 | 18 | reference = evaluate(prog, inputs) 19 | 20 | compiler = CKKSCompiler(config = config) 21 | compiled_prog, params, signature = compiler.compile(prog) 22 | 23 | reference_compiled = evaluate(compiled_prog, inputs) 24 | ref_mse = valuation_mse(reference, reference_compiled) 25 | self.assertTrue(ref_mse < 0.0000000001, 26 | f"Mean squared error was {ref_mse}") 27 | 28 | public_ctx, secret_ctx = generate_keys(params) 29 | encInputs = public_ctx.encrypt(inputs, signature) 30 | encOutputs = public_ctx.execute(compiled_prog, encInputs) 31 | outputs = secret_ctx.decrypt(encOutputs, signature) 32 | 33 | he_mse = valuation_mse(outputs, reference) 34 | self.assertTrue(he_mse < 0.01, f"Mean squared error was {he_mse}") 35 | 36 | return (compiled_prog, params, signature) -------------------------------------------------------------------------------- /tests/features.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import unittest 5 | import tempfile 6 | import os 7 | from common import * 8 | from eva import EvaProgram, Input, Output, save, load 9 | 10 | class Features(EvaTestCase): 11 | def test_bin_ops(self): 12 | """ Test all binary ops """ 13 | 14 | for binOp in [lambda a, b: a + b, lambda a, b: a - b, lambda a, b: a * b]: 15 | for enc1 in [False, True]: 16 | for enc2 in [False, True]: 17 | prog = EvaProgram('BinOp', vec_size = 64) 18 | with prog: 19 | a = Input('a', enc1) 20 | b = Input('b', enc2) 21 | Output('y', binOp(a,b)) 22 | 23 | prog.set_output_ranges(20) 24 | prog.set_input_scales(30) 25 | 26 | self.assert_compiles_and_matches_reference(prog, 27 | config={'warn_vec_size':'false'}) 28 | 29 | def test_unary_ops(self): 30 | """ Test all unary ops """ 31 | 32 | for unOp in [lambda x: x, lambda x: -x, lambda x: x**3, lambda x: 42]: 33 | for enc in [False, True]: 34 | prog = EvaProgram('UnOp', vec_size = 64) 35 | with prog: 36 | x = Input('x', enc) 37 | Output('y', unOp(x)) 38 | 39 | prog.set_output_ranges(20) 40 | prog.set_input_scales(30) 41 | 42 | self.assert_compiles_and_matches_reference(prog, 43 | config={'warn_vec_size':'false'}) 44 | 45 | def test_rotations(self): 46 | """ Test all rotations """ 47 | 48 | for rotOp in [lambda x, r: x << r, lambda x, r: x >> r]: 49 | for enc in [False, True]: 50 | for rot in range(-2,2): 51 | prog = EvaProgram('RotOp', vec_size = 8) 52 | with prog: 53 | x = Input('x') 54 | Output('y', rotOp(x,rot)) 55 | 56 | prog.set_output_ranges(20) 57 | prog.set_input_scales(30) 58 | 59 | self.assert_compiles_and_matches_reference(prog, 60 | config={'warn_vec_size':'false'}) 61 | 62 | def test_unencrypted_computation(self): 63 | """ Test computation on unencrypted values """ 64 | 65 | for enc1 in [False, True]: 66 | for enc2 in [False, True]: 67 | prog = 
EvaProgram('UnencryptedInputs', vec_size=128) 68 | with prog: 69 | x1 = Input('x1', enc1) 70 | x2 = Input('x2', enc2) 71 | Output('y', pow(x2,3) + x1*x2) 72 | 73 | prog.set_output_ranges(20) 74 | prog.set_input_scales(30) 75 | 76 | self.assert_compiles_and_matches_reference(prog, 77 | config={'warn_vec_size':'false'}) 78 | 79 | def test_security_levels(self): 80 | """ Check that all supported security levels work """ 81 | 82 | security_levels = ['128','192','256'] 83 | quantum_safety = ['false','true'] 84 | 85 | for s in security_levels: 86 | for q in quantum_safety: 87 | prog = EvaProgram('SecurityLevel', vec_size=512) 88 | with prog: 89 | x = Input('x') 90 | Output('y', 5*x*x + 3*x + x<<12 + 10) 91 | 92 | prog.set_output_ranges(20) 93 | prog.set_input_scales(30) 94 | 95 | self.assert_compiles_and_matches_reference(prog, 96 | config={'security_level':s, 'quantum_safe':q, 'warn_vec_size':'false'}) 97 | 98 | @unittest.expectedFailure 99 | def test_unsupported_security_level(self): 100 | """ Check that unsupported security levels error out """ 101 | 102 | prog = EvaProgram('SecurityLevel', vec_size=512) 103 | with prog: 104 | x = Input('x') 105 | Output('y', 5*x*x + 3*x + x<<12 + 10) 106 | 107 | prog.set_output_ranges(20) 108 | prog.set_input_scales(30) 109 | 110 | self.assert_compiles_and_matches_reference(prog, 111 | config={'security_level':'1024','warn_vec_size':'false'}) 112 | 113 | def test_reduction_balancer(self): 114 | """ Check that reductions are balanced under balance_reductions=true """ 115 | 116 | prog = EvaProgram('ReductionTree', vec_size=16384) 117 | with prog: 118 | x1 = Input('x1') 119 | x2 = Input('x2') 120 | x3 = Input('x3') 121 | x4 = Input('x4') 122 | Output('y', (x1*(x2*(x3*x4))) + (x1+(x2+(x3+x4)))) 123 | 124 | prog.set_output_ranges(20) 125 | prog.set_input_scales(60) 126 | 127 | progc, params, signature = self.assert_compiles_and_matches_reference(prog, 128 | config={'rescaler':'always', 'balance_reductions':'false', 
'warn_vec_size':'false'}) 129 | self.assertEqual(params.prime_bits, [60, 20, 60, 60, 60, 60]) 130 | 131 | progc, params, signature = self.assert_compiles_and_matches_reference(prog, 132 | config={'rescaler':'always', 'balance_reductions':'true', 'warn_vec_size':'false'}) 133 | self.assertEqual(params.prime_bits, [60, 20, 60, 60, 60]) 134 | 135 | def test_seal_no_throw_on_transparent(self): 136 | """ Check that SEAL is compiled with -DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF 137 | 138 | An HE compiler cannot in general work with transparent ciphertext detection 139 | turned on because it is not possible to statically detect all situations that 140 | result in them. For example, x1-x2 is transparent only if the user gives the 141 | same ciphertext as both inputs.""" 142 | 143 | prog = EvaProgram('Transparent', vec_size=4096) 144 | with prog: 145 | x = Input('x') 146 | Output('y', x-x+x*0) 147 | 148 | prog.set_output_ranges(20) 149 | prog.set_input_scales(30) 150 | 151 | self.assert_compiles_and_matches_reference(prog, 152 | config={'warn_vec_size':'false'}) 153 | 154 | def test_serialization(self): 155 | """ Test (de)serialization and check that results stay the same """ 156 | 157 | poly = EvaProgram('Polynomial', vec_size=4096) 158 | with poly: 159 | x = Input('x') 160 | Output('y', 3*x**2 + 5*x - 2) 161 | 162 | poly.set_output_ranges(20) 163 | poly.set_input_scales(30) 164 | 165 | inputs = { 166 | 'x': [i for i in range(poly.vec_size)] 167 | } 168 | reference = evaluate(poly, inputs) 169 | 170 | compiler = CKKSCompiler(config={'warn_vec_size':'false'}) 171 | poly, params, signature = compiler.compile(poly) 172 | 173 | with tempfile.TemporaryDirectory() as tmp_dir: 174 | tmp_path = lambda x: os.path.join(tmp_dir, x) 175 | 176 | save(poly, tmp_path('poly.eva')) 177 | save(params, tmp_path('poly.evaparams')) 178 | save(signature, tmp_path('poly.evasignature')) 179 | 180 | # Key generation time 181 | 182 | params = load(tmp_path('poly.evaparams')) 183 | 184 | 
public_ctx, secret_ctx = generate_keys(params) 185 | 186 | save(public_ctx, tmp_path('poly.sealpublic')) 187 | save(secret_ctx, tmp_path('poly.sealsecret')) 188 | 189 | # Runtime on client 190 | 191 | signature = load(tmp_path('poly.evasignature')) 192 | public_ctx = load(tmp_path('poly.sealpublic')) 193 | 194 | encInputs = public_ctx.encrypt(inputs, signature) 195 | 196 | save(encInputs, tmp_path('poly_inputs.sealvals')) 197 | 198 | # Runtime on server 199 | 200 | poly = load(tmp_path('poly.eva')) 201 | public_ctx = load(tmp_path('poly.sealpublic')) 202 | encInputs = load(tmp_path('poly_inputs.sealvals')) 203 | 204 | encOutputs = public_ctx.execute(poly, encInputs) 205 | 206 | save(encOutputs, tmp_path('poly_outputs.sealvals')) 207 | 208 | # Runtime back on client 209 | 210 | secret_ctx = load(tmp_path('poly.sealsecret')) 211 | encOutputs = load(tmp_path('poly_outputs.sealvals')) 212 | 213 | outputs = secret_ctx.decrypt(encOutputs, signature) 214 | 215 | reference_compiled = evaluate(poly, inputs) 216 | self.assertTrue(valuation_mse(reference, reference_compiled) < 0.0000000001) 217 | self.assertTrue(valuation_mse(outputs, reference) < 0.01) 218 | 219 | if __name__ == '__main__': 220 | unittest.main() -------------------------------------------------------------------------------- /tests/large_programs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 
3 | 4 | import unittest 5 | import math 6 | from common import * 7 | from eva import EvaProgram, Input, Output 8 | 9 | class LargePrograms(EvaTestCase): 10 | def test_sobel_configs(self): 11 | """ Check accuracy of Sobel filter on random image with various compiler configurations """ 12 | 13 | def convolutionXY(image, width, filter): 14 | for i in range(len(filter)): 15 | for j in range(len(filter[0])): 16 | rotated = image << (i * width + j) 17 | horizontal = rotated * filter[i][j] 18 | vertical = rotated * filter[j][i] 19 | if i == 0 and j == 0: 20 | Ix = horizontal 21 | Iy = vertical 22 | else: 23 | Ix += horizontal 24 | Iy += vertical 25 | return Ix, Iy 26 | 27 | h = 90 28 | w = 90 29 | 30 | sobel = EvaProgram('sobel', vec_size=2**(math.ceil(math.log(h*w, 2)))) 31 | with sobel: 32 | image = Input('image') 33 | 34 | sobel_filter = [ 35 | [-1, 0, 1], 36 | [-2, 0, 2], 37 | [-1, 0, 1]] 38 | 39 | a1 = 2.2137874823876622 40 | a2 = -1.0984324107372518 41 | a3 = 0.17254603006834726 42 | 43 | conv_hor, conv_ver = convolutionXY(image, w, sobel_filter) 44 | x = conv_hor**2 + conv_ver**2 45 | Output('image', x * a1 + x**2 * a2 + x**3 * a3) 46 | 47 | sobel.set_input_scales(45) 48 | sobel.set_output_ranges(20) 49 | 50 | for rescaler in ['lazy_waterline','eager_waterline','always']: 51 | for balance_reductions in ['true','false']: 52 | self.assert_compiles_and_matches_reference(sobel, 53 | config={'rescaler':rescaler,'balance_reductions':balance_reductions}) 54 | 55 | def test_regression(self): 56 | """ Test batched compilation and execution of multiple linear regression programs """ 57 | 58 | linreg = EvaProgram('linear_regression', vec_size=2048) 59 | with linreg: 60 | p = 63 61 | 62 | x = [Input(f'x{i}') for i in range(p)] 63 | e = Input('e') 64 | b0 = 6.56 65 | b = [i * 0.732 for i in range(p)] 66 | 67 | y = e + b0 68 | for i in range(p): 69 | t = x[i] * b[i] 70 | y += t 71 | 72 | Output('y', y) 73 | 74 | linreg.set_input_scales(40) 75 | linreg.set_output_ranges(30) 76 | 
77 | linreg_inputs = {'e': [(linreg.vec_size - i) * 0.001 for i in range(linreg.vec_size)]} 78 | for i in range(p): 79 | linreg_inputs[f'x{i}'] = [i * j * 0.01 for j in range(linreg.vec_size)] 80 | 81 | polyreg = EvaProgram('polynomial_regression', vec_size=4096) 82 | with polyreg: 83 | p = 4 84 | 85 | x = Input('x') 86 | e = Input('e') 87 | b0 = 6.56 88 | b = [i * 0.732 for i in range(p)] 89 | 90 | y = e + b0 91 | for i in range(p): 92 | x_i = x 93 | for j in range(i): 94 | x_i = x_i * x 95 | t = x_i * b[i] 96 | y += t 97 | 98 | Output('y', y) 99 | 100 | polyreg.set_input_scales(40) 101 | polyreg.set_output_ranges(30) 102 | 103 | polyreg_inputs = { 104 | 'x': [i * 0.01 for i in range(polyreg.vec_size)], 105 | 'e': [(polyreg.vec_size - i) * 0.001 for i in range(polyreg.vec_size)], 106 | } 107 | 108 | multireg = EvaProgram('multivariate_regression', vec_size=2048) 109 | with multireg: 110 | p = 63 111 | k = 4 112 | 113 | x = [Input(f'x{i}') for i in range(p)] 114 | e = [Input(f'e{j}') for j in range(k)] 115 | b0 = [j * 0.56 for j in range(k)] 116 | b = [[k * i * 0.732 for i in range(p)] for j in range(k)] 117 | 118 | y = [0 for j in range(k)] 119 | for j in range(k): 120 | y[j] = e[j] + b0[j] 121 | for i in range(p): 122 | t = x[i] * b[j][i] 123 | y[j] += t 124 | 125 | for j in range(k): 126 | Output(f'y{j}', y[j]) 127 | 128 | multireg.set_input_scales(40) 129 | multireg.set_output_ranges(30) 130 | 131 | multireg_inputs = {} 132 | for i in range(p): 133 | multireg_inputs[f'x{i}'] = [i * j * 0.01 for j in range(multireg.vec_size)] 134 | for j in range(k): 135 | multireg_inputs[f'e{j}'] = [(multireg.vec_size - i) * j * 0.001 for i in range(multireg.vec_size)] 136 | 137 | compiler = CKKSCompiler(config={'warn_vec_size':'false'}) 138 | 139 | for prog, inputs in [(linreg, linreg_inputs), (polyreg, polyreg_inputs), (multireg, multireg_inputs)]: 140 | compiled_prog, params, signature = compiler.compile(prog) 141 | public_ctx, secret_ctx = generate_keys(params) 142 | 
enc_inputs = public_ctx.encrypt(inputs, signature) 143 | enc_outputs = public_ctx.execute(compiled_prog, enc_inputs) 144 | outputs = secret_ctx.decrypt(enc_outputs, signature) 145 | reference = evaluate(compiled_prog, inputs) 146 | self.assertTrue(valuation_mse(outputs, reference) < 0.01) 147 | 148 | if __name__ == '__main__': 149 | unittest.main() -------------------------------------------------------------------------------- /tests/std.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT license. 3 | 4 | import unittest 5 | from common import * 6 | from eva import EvaProgram, Input, Output 7 | from eva.std.numeric import horizontal_sum 8 | 9 | class Std(EvaTestCase): 10 | def test_horizontal_sum(self): 11 | """ Test eva.std.numeric.horizontal_sum """ 12 | 13 | for enc in [True, False]: 14 | prog = EvaProgram('HorizontalSum', vec_size = 2048) 15 | with prog: 16 | x = Input('x', is_encrypted=enc) 17 | y = horizontal_sum(x) 18 | Output('y', y) 19 | 20 | prog.set_output_ranges(25) 21 | prog.set_input_scales(33) 22 | 23 | self.assert_compiles_and_matches_reference(prog, 24 | config={'warn_vec_size':'false'}) 25 | 26 | prog = EvaProgram('HorizontalSumConstant', vec_size = 2048) 27 | with prog: 28 | y = horizontal_sum([1 for _ in range(prog.vec_size)]) 29 | Output('y', y) 30 | 31 | prog.set_output_ranges(25) 32 | prog.set_input_scales(33) 33 | 34 | self.assert_compiles_and_matches_reference(prog, 35 | config={'warn_vec_size':'false'}) 36 | 37 | if __name__ == '__main__': 38 | unittest.main() --------------------------------------------------------------------------------