├── .clang-format
├── .gitignore
├── .gitmodules
├── CHANGES.md
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── azure-pipelines.yml
├── eva
├── CMakeLists.txt
├── ckks
│ ├── CMakeLists.txt
│ ├── always_rescaler.h
│ ├── ckks_compiler.h
│ ├── ckks_config.cpp
│ ├── ckks_config.h
│ ├── ckks_parameters.h
│ ├── ckks_signature.h
│ ├── eager_relinearizer.h
│ ├── eager_waterline_rescaler.h
│ ├── encode_inserter.h
│ ├── encryption_parameter_selector.h
│ ├── lazy_relinearizer.h
│ ├── lazy_waterline_rescaler.h
│ ├── levels_checker.h
│ ├── minimum_rescaler.h
│ ├── mod_switcher.h
│ ├── parameter_checker.h
│ ├── rescaler.h
│ ├── scales_checker.h
│ └── seal_lowering.h
├── common
│ ├── CMakeLists.txt
│ ├── constant_folder.h
│ ├── multicore_program_traversal.h
│ ├── program_traversal.h
│ ├── reduction_balancer.h
│ ├── reference_executor.cpp
│ ├── reference_executor.h
│ ├── rotation_keys_selector.h
│ ├── type_deducer.h
│ └── valuation.h
├── eva.cpp
├── eva.h
├── ir
│ ├── CMakeLists.txt
│ ├── attribute_list.cpp
│ ├── attribute_list.h
│ ├── attributes.cpp
│ ├── attributes.h
│ ├── constant_value.h
│ ├── ops.h
│ ├── program.cpp
│ ├── program.h
│ ├── term.cpp
│ ├── term.h
│ ├── term_map.h
│ └── types.h
├── seal
│ ├── CMakeLists.txt
│ ├── seal.cpp
│ ├── seal.h
│ └── seal_executor.h
├── serialization
│ ├── CMakeLists.txt
│ ├── ckks.proto
│ ├── ckks_serialization.cpp
│ ├── eva.proto
│ ├── eva_format_version.h
│ ├── eva_serialization.cpp
│ ├── known_type.cpp
│ ├── known_type.h
│ ├── known_type.proto
│ ├── save_load.cpp
│ ├── save_load.h
│ ├── seal.proto
│ └── seal_serialization.cpp
├── util
│ ├── CMakeLists.txt
│ ├── galois.cpp
│ ├── galois.h
│ ├── logging.cpp
│ ├── logging.h
│ └── overloaded.h
├── version.cpp
└── version.h
├── examples
├── .gitignore
├── baboon.png
├── image_processing.py
├── requirements.txt
└── serialization.py
├── python
├── .gitignore
├── CMakeLists.txt
├── eva
│ ├── CMakeLists.txt
│ ├── __init__.py
│ ├── ckks
│ │ └── __init__.py
│ ├── metric.py
│ ├── seal
│ │ └── __init__.py
│ ├── std
│ │ └── numeric.py
│ └── wrapper.cpp
└── setup.py.in
├── scripts
└── clang-format-all.sh
└── tests
├── all.py
├── bug_fixes.py
├── common.py
├── features.py
├── large_programs.py
└── std.py
/.clang-format:
--------------------------------------------------------------------------------
1 | AllowShortIfStatementsOnASingleLine: true
2 | BasedOnStyle: llvm
3 | FixNamespaceComments: true
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.*/
2 | /build/
3 | /dist/
4 | eva.egg-info/
5 | __pycache__/
6 |
7 | # In-source build files
8 | Makefile
9 | CMakeCache.txt
10 | cmake_install.cmake
11 | CPackConfig.cmake
12 | CPackSourceConfig.cmake
13 | compile_commands.json
14 | CMakeFiles/
15 | *.pb.cc
16 | *.pb.h
17 | *.a
18 | *.so
19 | *.whl
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/pybind11"]
2 | path = third_party/pybind11
3 | url = https://github.com/pybind/pybind11
4 | ignore = untracked
5 | [submodule "third_party/Galois"]
6 | path = third_party/Galois
7 | url = https://github.com/IntelligentSoftwareSystems/Galois.git
8 | ignore = untracked
9 |
--------------------------------------------------------------------------------
/CHANGES.md:
--------------------------------------------------------------------------------
1 | # List of Changes
2 |
3 | ## Version 1.0.1
4 |
5 | ### Bug Fixes
6 |
7 | - Fixed issue where modulus switching was inserted for non-encrypted values [(PR 9)](https://github.com/microsoft/EVA/pull/9).
8 |
9 | ### Other
10 |
11 | - Validated compatibility with SEAL v3.6.4 and updated README.md.
12 |
13 | ## Version 1.0.0
14 |
15 | First open source release.
16 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
4 | cmake_minimum_required(VERSION 3.13)
5 | cmake_policy(SET CMP0079 NEW)
6 | cmake_policy(SET CMP0076 NEW)
7 |
8 | set(CMAKE_CXX_STANDARD 17)
9 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
10 | set(CMAKE_CXX_EXTENSIONS OFF)
11 | set(CMAKE_POSITION_INDEPENDENT_CODE ON)
12 |
13 | project(eva
14 | VERSION 1.0.1
15 | LANGUAGES CXX
16 | )
17 |
18 | option(USE_GALOIS "Use the Galois library for multicore homomorphic evaluation" OFF)
19 | if(USE_GALOIS)
20 | message("Galois based multicore support enabled")
21 | add_definitions(-DEVA_USE_GALOIS)
22 | endif()
23 |
24 | find_package(SEAL 3.6 REQUIRED)
25 | find_package(Protobuf 3.6 REQUIRED)
26 | find_package(Python COMPONENTS Interpreter Development)
27 |
28 | if(NOT Python_VERSION_MAJOR EQUAL 3)
29 | message(FATAL_ERROR "EVA requires Python 3. Please ensure you have it
30 | installed in a location searched by CMake.")
31 | endif()
32 |
33 | add_subdirectory(third_party/pybind11)
34 | if(USE_GALOIS)
35 | add_subdirectory(third_party/Galois EXCLUDE_FROM_ALL)
36 | endif()
37 |
38 | add_subdirectory(eva)
39 | add_subdirectory(python)
40 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contributing
2 |
3 | The EVA project welcomes contributions and suggestions. Most contributions require you to
4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5 | and actually do, grant us the rights to use your contribution. For details, visit
6 | https://cla.microsoft.com.
7 |
8 | Please submit all pull requests on the **contrib** branch. We will handle the final merge onto the main branch.
9 |
10 | When you submit a pull request, a CLA-bot will automatically determine whether you need
11 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
12 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) Microsoft Corporation.
2 |
3 | MIT License
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EVA - Compiler for Microsoft SEAL
2 |
3 | EVA is a compiler for homomorphic encryption that automates away the parts that require cryptographic expertise.
4 | This gives you a simple way to write programs that operate on encrypted data without having access to the secret key.
5 |
6 | Think of EVA as the "C compiler" of the homomorphic world. Homomorphic computations written in EVA IR (Encrypted Vector Arithmetic Intermediate Representation) get compiled to the "assembly" of the homomorphic encryption library API. Just like C compilers free you from tricky tasks like register allocation, EVA frees you from *encryption parameter selection, rescaling insertion, relinearization*...
7 |
8 | EVA targets [Microsoft SEAL](https://github.com/microsoft/SEAL) — the industry leading library for fully-homomorphic encryption — and currently supports the CKKS scheme for deep computations on encrypted approximate fixed-point arithmetic.
9 |
10 | ## Getting Started
11 |
12 | EVA is a native library written in C++17 with bindings for Python. Both Linux and Windows are supported. The instructions below show how to get started with EVA on Ubuntu. For building on Windows, [EVA's Azure Pipelines script](azure-pipelines.yml) is a useful reference.
13 |
14 | ### Installing Dependencies
15 |
16 | To install dependencies on Ubuntu 20.04:
17 | ```
18 | sudo apt install cmake libboost-all-dev libprotobuf-dev protobuf-compiler
19 | ```
20 |
21 | Clang is recommended for compilation, as SEAL is faster when compiled with it. To install clang and set it as default:
22 | ```
23 | sudo apt install clang
24 | sudo update-alternatives --install /usr/bin/cc cc /usr/bin/clang 100
25 | sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/clang++ 100
26 | ```
27 |
28 | Next install Microsoft SEAL version 3.6:
29 | ```
30 | git clone -b v3.6.4 https://github.com/microsoft/SEAL.git
31 | cd SEAL
32 | cmake -DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF .
33 | make -j
34 | sudo make install
35 | ```
36 | *Note that SEAL has to be installed with transparent ciphertext checking turned off, as it is not possible in general to statically ensure a program will not produce a transparent ciphertext. This does not affect the security of ciphertexts encrypted with SEAL.*
37 |
38 | ### Building and Installing EVA
39 |
40 | #### Building EVA
41 |
42 | EVA builds with CMake version ≥ 3.13:
43 | ```
44 | git submodule update --init
45 | cmake .
46 | make -j
47 | ```
48 | The build process creates a `setup.py` file in `python/`. To install the package for development with PIP:
49 | ```
50 | python3 -m pip install -e python/
51 | ```
52 | To create a Python Wheel package for distribution in `dist/`:
53 | ```
54 | python3 python/setup.py bdist_wheel --dist-dir='.'
55 | ```
56 |
57 | To check that the installed Python package is working correctly, run all tests with:
58 | ```
59 | python3 tests/all.py
60 | ```
61 |
62 | EVA does not yet support installing the native library for use in other CMake projects (contributions very welcome).
63 |
64 | #### Multicore Support
65 |
66 | EVA features highly scalable multicore support using the [Galois library](https://github.com/IntelligentSoftwareSystems/Galois). It is included as a submodule, but is turned off by default for faster builds and easier debugging. To build EVA with Galois configure with `USE_GALOIS=ON`:
67 | ```
68 | cmake -DUSE_GALOIS=ON .
69 | ```
70 |
71 | ### Running the Examples
72 |
73 | The examples use EVA's Python APIs. To install dependencies with PIP:
74 | ```
75 | python3 -m pip install -r examples/requirements.txt
76 | ```
77 |
78 | For example, to run the image processing example in EVA/examples:
79 | ```
80 | cd examples/
81 | python3 image_processing.py
82 | ```
83 | This will compile and run homomorphic evaluations of a Sobel edge detection filter and a Harris corner detection filter on `examples/baboon.png`, producing results of homomorphic evaluation in `*_encrypted.png` and reference results from normal execution in `*_reference.png`.
84 | The script also reports the mean squared error between these for each filter.
85 |
86 | ## Programming with PyEVA
87 |
88 | PyEVA is a thin Python-embedded DSL for producing EVA programs.
89 | We will walk you through compiling a PyEVA program with EVA and running it on top of SEAL.
90 |
91 | ### Writing and Compiling Programs
92 |
93 | A program to evaluate a fixed polynomial 3x<sup>2</sup>+5x-2 on 1024 encrypted values can be written:
94 | ```
95 | from eva import *
96 | poly = EvaProgram('Polynomial', vec_size=1024)
97 | with poly:
98 | x = Input('x')
99 | Output('y', 3*x**2 + 5*x - 2)
100 | ```
101 | Next we will compile this program for the [CKKS encryption scheme](https://eprint.iacr.org/2016/421.pdf).
102 | Two additional pieces of information EVA currently requires to compile for CKKS are the *fixed-point scale for inputs* and the *maximum ranges of coefficients in outputs*, both represented in number of bits:
103 | ```
104 | poly.set_output_ranges(30)
105 | poly.set_input_scales(30)
106 | ```
107 | Now the program can be compiled:
108 | ```
109 | from eva.ckks import *
110 | compiler = CKKSCompiler()
111 | compiled_poly, params, signature = compiler.compile(poly)
112 | ```
113 | The `compile` method transforms the program in-place and returns:
114 |
115 | 1. the compiled program;
116 | 2. encryption parameters for Microsoft SEAL with which the program can be executed;
117 | 3. a signature object, that specifies how inputs and outputs need to be encoded and decoded.
118 |
119 | The compiled program can be inspected by printing it in the DOT format for the [Graphviz](https://graphviz.org/) visualization software:
120 | ```
121 | print(compiled_poly.to_DOT())
122 | ```
123 | The output can be viewed as a graph in, for example, a number of Graphviz editors available online.
124 |
125 | ### Generating Keys and Encrypting Inputs
126 |
127 | Encryption keys can now be generated using the encryption parameters:
128 | ```
129 | from eva.seal import *
130 | public_ctx, secret_ctx = generate_keys(params)
131 | ```
132 | Next a dictionary of inputs is created and encrypted using the public context and the program signature:
133 | ```
134 | inputs = { 'x': [i for i in range(compiled_poly.vec_size)] }
135 | encInputs = public_ctx.encrypt(inputs, signature)
136 | ```
137 |
138 | ### Homomorphic Execution
139 |
140 | Everything is now in place for executing the program with Microsoft SEAL:
141 | ```
142 | encOutputs = public_ctx.execute(compiled_poly, encInputs)
143 | ```
144 |
145 | ### Decrypting Results
146 |
147 | Finally, the outputs can be decrypted using the secret context:
148 | ```
149 | outputs = secret_ctx.decrypt(encOutputs, signature)
150 | ```
151 | For debugging it is often useful to compare homomorphic results to unencrypted computation.
152 | The `evaluate` method can be used to execute an EVA program on unencrypted data.
153 | The two sets of results can then be compared with, for example, Mean Squared Error:
154 |
155 | ```
156 | from eva.metric import valuation_mse
157 | reference = evaluate(compiled_poly, inputs)
158 | print('MSE', valuation_mse(outputs, reference))
159 | ```
160 |
161 | ## Contributing
162 |
163 | The EVA project welcomes contributions and suggestions. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details.
164 |
165 | ## Credits
166 |
167 | This project is a collaboration between Microsoft Research's Research in Software Engineering (RiSE) group and Cryptography and Privacy Research group.
168 |
169 | A huge credit goes to [Dr. Roshan Dathathri](https://roshandathathri.github.io/), who as an intern built the first version of EVA, along with all the transformations required for targeting the CKKS scheme efficiently and the parallelizing runtime required to make execution scale.
170 |
171 | Many thanks to [Sangeeta Chowdhary](https://www.ilab.cs.rutgers.edu/~sc1696/), who as an intern put a huge amount of work into making EVA ready for release.
172 |
173 | ## Publications
174 |
175 | Roshan Dathathri, Blagovesta Kostova, Olli Saarikivi, Wei Dai, Kim Laine, Madanlal Musuvathi. *EVA: An Encrypted Vector Arithmetic Language and Compiler for Efficient Homomorphic Computation*. PLDI 2020. [arXiv](https://arxiv.org/abs/1912.11951) [DOI](https://doi.org/10.1145/3385412.3386023)
176 |
177 | Roshan Dathathri, Olli Saarikivi, Hao Chen, Kim Laine, Kristin Lauter, Saeed Maleki, Madanlal Musuvathi, Todd Mytkowicz. *CHET: An Optimizing Compiler for Fully-Homomorphic Neural-Network Inferencing*. PLDI 2019. [DOI](https://doi.org/10.1145/3314221.3314628)
178 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 |
9 | For help and questions about using this project, you can contact the EVA compiler team at
10 | [evacompiler@microsoft.com](mailto:evacompiler@microsoft.com).
11 | We are very interested in helping early adopters start to use EVA.
12 |
13 | ## Microsoft Support Policy
14 |
15 | Support for EVA is limited to the resources listed above.
--------------------------------------------------------------------------------
/azure-pipelines.yml:
--------------------------------------------------------------------------------
1 | # EVA pipeline
2 |
3 | trigger:
4 | - master
5 |
6 | pool:
7 | vmImage: 'windows-latest'
8 |
9 | steps:
10 | - task: UsePythonVersion@0
11 | displayName: 'Ensure Python 3.x'
12 | inputs:
13 | versionSpec: '3.x'
14 | addToPath: true
15 | architecture: 'x64'
16 |
17 | - task: securedevelopmentteam.vss-secure-development-tools.build-task-credscan.CredScan@2
18 | displayName: 'Run CredScan'
19 | inputs:
20 | toolMajorVersion: 'V2'
21 | outputFormat: sarif
22 | debugMode: false
23 |
24 | - task: CmdLine@2
25 | displayName: 'Get SEAL source code'
26 | inputs:
27 | script: |
28 | rem Use github repo
29 | git clone https://github.com/microsoft/SEAL.git
30 | cd SEAL
31 | rem Use 3.6.0 specifically
32 | git checkout 3.6.0
33 | workingDirectory: '$(Build.SourcesDirectory)/third_party'
34 |
35 | - task: CMake@1
36 | displayName: 'Configure SEAL'
37 | inputs:
38 | cmakeArgs: '-DSEAL_THROW_ON_TRANSPARENT_CIPHERTEXT=OFF -DALLOW_COMMAND_LINE_BUILD=ON -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF -DSEAL_USE_ZSTD=OFF .'
39 | workingDirectory: $(Build.SourcesDirectory)/third_party/SEAL
40 |
41 | - task: MSBuild@1
42 | displayName: 'Build SEAL'
43 | inputs:
44 | solution: '$(Build.SourcesDirectory)/third_party/SEAL/SEAL.sln'
45 | msbuildArchitecture: 'x64'
46 | platform: 'x64'
47 | configuration: 'Debug'
48 |
49 | - task: CmdLine@2
50 | displayName: 'Get vcpkg'
51 | inputs:
52 | script: 'git clone https://github.com/microsoft/vcpkg.git'
53 | workingDirectory: '$(Build.SourcesDirectory)/third_party'
54 |
55 | - task: CmdLine@2
56 | displayName: 'Bootstrap vcpkg'
57 | inputs:
58 | script: '$(Build.SourcesDirectory)/third_party/vcpkg/bootstrap-vcpkg.bat'
59 | workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg'
60 |
61 | - task: PowerShell@2
62 | displayName: 'Get protobuf compiler'
63 | inputs:
64 | targetType: 'inline'
65 | script: |
66 | mkdir protobuf
67 | cd protobuf
68 | Invoke-WebRequest -Uri "https://github.com/protocolbuffers/protobuf/releases/download/v3.15.8/protoc-3.15.8-win64.zip" -OutFile protobufc.zip
69 | Expand-Archive -LiteralPath protobufc.zip -DestinationPath protobufc
70 | workingDirectory: '$(Build.SourcesDirectory)/third_party'
71 |
72 | - task: CmdLine@2
73 | displayName: 'Install protobuf library'
74 | inputs:
75 | script: '$(Build.SourcesDirectory)/third_party/vcpkg/vcpkg.exe install protobuf[zlib]:x64-windows'
76 | workingDirectory: '$(Build.SourcesDirectory)/third_party/vcpkg'
77 |
78 | - task: CmdLine@2
79 | displayName: 'Create build directory'
80 | inputs:
81 | script: 'mkdir build'
82 | workingDirectory: '$(Build.SourcesDirectory)'
83 |
84 | - task: CMake@1
85 | displayName: 'Configure EVA'
86 | inputs:
87 | cmakeArgs: .. -DSEAL_DIR=$(Build.SourcesDirectory)/third_party/SEAL/cmake -DProtobuf_INCLUDE_DIR=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/include -DProtobuf_LIBRARY=$(Build.SourcesDirectory)/third_party/vcpkg/packages/protobuf_x64-windows/lib/libprotobuf.lib -DProtobuf_PROTOC_EXECUTABLE=$(Build.SourcesDirectory)/third_party/protobuf/protobufc/bin/protoc.exe
88 | workingDirectory: '$(Build.SourcesDirectory)/build'
89 |
90 | - task: MSBuild@1
91 | displayName: 'Build EVA'
92 | inputs:
93 | solution: '$(Build.SourcesDirectory)/build/eva.sln'
94 | msbuildArchitecture: 'x64'
95 | platform: 'x64'
96 | configuration: 'Debug'
97 |
98 | - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
99 | displayName: 'Component Detection'
100 |
101 | - task: securedevelopmentteam.vss-secure-development-tools.build-task-publishsecurityanalysislogs.PublishSecurityAnalysisLogs@2
102 | displayName: 'Publish Security Analysis Logs'
103 |
104 | - task: PublishBuildArtifacts@1
105 | displayName: 'Publish build artifacts'
106 | inputs:
107 | PathtoPublish: '$(Build.ArtifactStagingDirectory)'
108 | artifactName: windows-drop
109 |
--------------------------------------------------------------------------------
/eva/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
4 | add_library(eva STATIC
5 | eva.cpp
6 | version.cpp
7 | )
8 |
9 | # TODO: everything except SEAL::seal should be made PRIVATE
10 | target_link_libraries(eva PUBLIC SEAL::seal protobuf::libprotobuf)
11 | if(USE_GALOIS)
12 | target_link_libraries(eva PUBLIC Galois::shmem numa)
13 | endif()
14 | target_include_directories(eva
15 | PUBLIC
16 | $
17 | $
18 | $
19 | )
20 | target_compile_definitions(eva PRIVATE EVA_VERSION_STR="${PROJECT_VERSION}")
21 |
22 | add_subdirectory(util)
23 | add_subdirectory(serialization)
24 | add_subdirectory(ir)
25 | add_subdirectory(common)
26 | add_subdirectory(ckks)
27 | add_subdirectory(seal)
28 |
--------------------------------------------------------------------------------
/eva/ckks/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
4 | target_sources(eva PRIVATE
5 | ckks_config.cpp
6 | )
7 |
--------------------------------------------------------------------------------
/eva/ckks/always_rescaler.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ckks/rescaler.h"
7 |
8 | namespace eva {
9 |
10 | class AlwaysRescaler : public Rescaler { // Rescaler pass that calls insertRescale for every multiplication term it visits.
11 | std::uint32_t minScale; // largest scale found among the program's sources; set once in the constructor
12 | // NOTE(review): program, type, scale, handleRawScale and insertRescale are not declared here; presumably inherited from Rescaler.
13 | public:
14 | AlwaysRescaler(Program &g, TermMap &type,
15 | TermMapOptional &scale)
16 | : Rescaler(g, type, scale) {
17 | // ASSUME: minScale is the max among all the cipher inputs' scales. NOTE(review): the loop visits every entry of program.getSources() - confirm these are all cipher inputs.
18 | minScale = 0;
19 | for (auto &source : program.getSources()) {
20 | if (scale[source] > minScale) minScale = scale[source];
21 | }
22 | assert(minScale != 0); // at least one source must carry a nonzero scale
23 | }
24 | // Sets scale[term] for one term: Raw terms go to handleRawScale, non-multiplications copy their first operand's scale, multiplications get the sum of their operands' scales and then a rescale.
25 | void
26 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
27 | if (term->numOperands() == 0) return; // inputs
28 | if (type[term] == Type::Raw) {
29 | handleRawScale(term);
30 | return;
31 | }
32 |
33 | if (isRescaleOp(term->op)) return; // already processed
34 |
35 | if (!isMultiplicationOp(term->op)) {
36 | // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst,
37 | // Op::RotateRightConst copy scale of the first operand
38 | scale[term] = scale[term->operandAt(0)];
39 | if (isAdditionOp(term->op)) {
40 | // Op::Add, Op::Sub
41 | // assert that all non-raw operands have the same scale as the term
42 | for (auto &operand : term->getOperands()) {
43 | if (type[operand] != Type::Raw) {
44 | assert(scale[term] == scale[operand] || type[operand] == Type::Raw); // the Raw disjunct can never hold here: the enclosing if already excluded Raw operands
45 | }
46 | }
47 | }
48 | return;
49 | }
50 |
51 | // Op::Mul only
52 | // ASSUME: only two operands (the loop itself sums however many operands there are)
53 | std::uint32_t multScale = 0;
54 | for (auto &operand : term->getOperands()) {
55 | multScale += scale[operand];
56 | }
57 | assert(multScale != 0);
58 | scale[term] = multScale;
59 |
60 | // always rescale
61 | insertRescale(term, multScale - minScale); // passes the product scale minus minScale; presumably the amount to rescale by - confirm against Rescaler::insertRescale
62 | }
63 | };
64 |
65 | } // namespace eva
66 |
--------------------------------------------------------------------------------
/eva/ckks/ckks_config.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #include "eva/ckks/ckks_config.h"
5 | #include "eva/util/logging.h"
6 | #include
7 |
8 | namespace eva {
9 |
10 | // Builds a configuration by applying string options on top of the defaults.
11 | // Unknown options and unparsable values warn and keep the default, except
12 | // security_level and quantum_safe, which throw std::runtime_error instead.
13 | CKKSConfig::CKKSConfig(
14 |     const std::unordered_map &configMap) {
15 |   // Reads "true"/"false" into target; on a parse failure returns false and
16 |   // leaves target alone (operator>> would otherwise overwrite it with false).
17 |   const auto parseBool = [](const std::string &text, bool &target) {
18 |     bool parsed = false;
19 |     std::istringstream is(text);
20 |     if ((is >> std::boolalpha >> parsed).fail()) return false;
21 |     target = parsed;
22 |     return true;
23 |   };
24 |   for (const auto &entry : configMap) {
25 |     const auto &option = entry.first;
26 |     const auto &valueStr = entry.second;
27 |     if (option == "balance_reductions") {
28 |       if (!parseBool(valueStr, balanceReductions)) {
29 |         warn("Could not parse boolean in balance_reductions=%s. Falling back "
30 |              "to default.", valueStr.c_str());
31 |       }
32 |     } else if (option == "rescaler") {
33 |       if (valueStr == "lazy_waterline") {
34 |         rescaler = CKKSRescaler::LazyWaterline;
35 |       } else if (valueStr == "eager_waterline") {
36 |         rescaler = CKKSRescaler::EagerWaterline;
37 |       } else if (valueStr == "always") {
38 |         rescaler = CKKSRescaler::Always;
39 |       } else if (valueStr == "minimum") {
40 |         rescaler = CKKSRescaler::Minimum;
41 |       } else {
42 |         // Please update this warning message when adding new options to the
43 |         // cases above
44 |         warn("Unknown value rescaler=%s. Available rescalers are "
45 |              "lazy_waterline, eager_waterline, always, minimum. Falling back "
46 |              "to default.", valueStr.c_str());
47 |       }
48 |     } else if (option == "lazy_relinearize") {
49 |       if (!parseBool(valueStr, lazyRelinearize)) {
50 |         warn("Could not parse boolean in lazy_relinearize=%s. Falling back to "
51 |              "default.", valueStr.c_str());
52 |       }
53 |     } else if (option == "security_level") {
54 |       // Parse errors set failbit, not badbit, so fail() is the right check.
55 |       std::istringstream is(valueStr);
56 |       if ((is >> securityLevel).fail()) {
57 |         throw std::runtime_error(
58 |             "Could not parse unsigned int in security_level=" + valueStr);
59 |       }
60 |     } else if (option == "quantum_safe") {
61 |       if (!parseBool(valueStr, quantumSafe)) {
62 |         throw std::runtime_error("Could not parse boolean in quantum_safe=" +
63 |                                  valueStr);
64 |       }
65 |     } else if (option == "warn_vec_size") {
66 |       if (!parseBool(valueStr, warnVecSize)) {
67 |         warn("Could not parse boolean in warn_vec_size=%s. Falling "
68 |              "back to default.", valueStr.c_str());
69 |       }
70 |     } else {
71 |       warn("Unknown option %s. Available options are:\n%s", option.c_str(),
72 |            OPTIONS_HELP_MESSAGE);
73 |     }
74 |   }
75 | }
76 |
77 | // Formats every option as a "name = value" line prefixed by `indent` spaces.
78 | // Lines are separated by '\n'; the last line has no trailing newline.
79 | std::string CKKSConfig::toString(int indent) const {
80 |   const std::string pad(indent, ' ');
81 |   // Stays empty for a value outside the enumerators, so nothing is printed.
82 |   const char *rescalerName = "";
83 |   switch (rescaler) {
84 |   case CKKSRescaler::LazyWaterline:
85 |     rescalerName = "lazy_waterline";
86 |     break;
87 |   case CKKSRescaler::EagerWaterline:
88 |     rescalerName = "eager_waterline";
89 |     break;
90 |   case CKKSRescaler::Always:
91 |     rescalerName = "always";
92 |     break;
93 |   case CKKSRescaler::Minimum:
94 |     rescalerName = "minimum";
95 |     break;
96 |   }
97 |   std::stringstream out;
98 |   // Booleans are written as true/false rather than 1/0.
99 |   out << std::boolalpha;
100 |   out << pad << "balance_reductions = " << balanceReductions << '\n';
101 |   out << pad << "rescaler = " << rescalerName << '\n';
102 |   out << pad << "lazy_relinearize = " << lazyRelinearize << '\n';
103 |   out << pad << "security_level = " << securityLevel << '\n';
104 |   out << pad << "quantum_safe = " << quantumSafe << '\n';
105 |   out << pad << "warn_vec_size = " << warnVecSize;
106 |   return out.str();
107 | }
108 |
109 | } // namespace eva
110 |
--------------------------------------------------------------------------------
/eva/ckks/ckks_config.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include
7 | #include
8 |
9 | namespace eva {
10 |
11 | // clang-format off
12 | const char *const OPTIONS_HELP_MESSAGE = // One line per supported option: name, meaning, value type and default.
13 | "balance_reductions - Balance trees of mul, add or sub operations. bool (default=true)\n"
14 | "rescaler - Rescaling policy. One of: lazy_waterline (default), eager_waterline, always, minimum\n"
15 | "lazy_relinearize - Relinearize as late as possible. bool (default=true)\n"
16 | "security_level - How many bits of security parameters should be selected for. int (default=128)\n"
17 | "quantum_safe - Select quantum safe parameters. bool (default=false)\n"
18 | "warn_vec_size - Warn about possibly inefficient vector size selection. bool (default=true)";
19 | // clang-format on
20 |
21 | enum class CKKSRescaler { LazyWaterline, EagerWaterline, Always, Minimum }; // Rescaling policies; the "rescaler" option spells them lazy_waterline, eager_waterline, always and minimum.
22 |
23 | // Controls the behavior of CKKSCompiler. Each field holds one option and is initialised to that option's default.
24 | class CKKSConfig {
25 | public:
26 | CKKSConfig() {} // every option keeps its default
27 | CKKSConfig(const std::unordered_map &configMap); // builds a config from the entries of configMap; defined out of line
28 |
29 | std::string toString(int indent = 0) const; // one "name = value" line per option, each prefixed with `indent` spaces
30 |
31 | bool balanceReductions = true; // option "balance_reductions"
32 | CKKSRescaler rescaler = CKKSRescaler::LazyWaterline; // option "rescaler"
33 | bool lazyRelinearize = true; // option "lazy_relinearize"
34 | uint32_t securityLevel = 128; // option "security_level", in bits of security
35 | bool quantumSafe = false; // option "quantum_safe"
36 |
37 | // Warnings
38 | bool warnVecSize = true; // option "warn_vec_size"
39 | };
40 |
41 | } // namespace eva
42 |
--------------------------------------------------------------------------------
/eva/ckks/ckks_parameters.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/serialization/ckks.pb.h"
7 | #include
8 | #include
9 | #include
10 | #include
11 |
12 | namespace eva {
13 |
14 | struct CKKSParameters { // Plain data bundle of CKKS encryption parameters.
15 | std::vector primeBits; // prime sizes in log-scale, i.e. bit counts
16 | std::set rotations; // NOTE(review): presumably the rotation step counts that need keys - confirm against users of this struct
17 | std::uint32_t polyModulusDegree; // has no default member initialiser
18 | };
19 |
20 | std::unique_ptr serialize(const CKKSParameters &); // Declaration only; takes the in-memory struct and returns an owning pointer.
21 | std::unique_ptr deserialize(const msg::CKKSParameters &); // Declaration only; counterpart that takes the msg::CKKSParameters message form.
22 |
23 | } // namespace eva
24 |
--------------------------------------------------------------------------------
/eva/ckks/ckks_signature.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/types.h"
7 | #include "eva/serialization/ckks.pb.h"
8 | #include
9 | #include
10 | #include
11 |
12 | namespace eva {
13 |
14 | // TODO: make these structs immutable
15 |
16 | struct CKKSEncodingInfo { // Bundles an input's Type with a scale and a level.
17 | Type inputType;
18 | int scale;
19 | int level;
20 |
21 | CKKSEncodingInfo(Type inputType, int scale, int level) // stores the three values unchanged
22 | : inputType(inputType), scale(scale), level(level) {}
23 | };
24 |
25 | struct CKKSSignature { // A vector size together with a map describing the inputs.
26 | int vecSize;
27 | std::unordered_map inputs; // NOTE(review): key/value types are not visible here; presumably input name -> CKKSEncodingInfo - confirm
28 |
29 | CKKSSignature(int vecSize,
30 | std::unordered_map inputs) // taken by value, then copied into the member
31 | : vecSize(vecSize), inputs(inputs) {}
32 | };
33 |
34 | std::unique_ptr serialize(const CKKSSignature &); // Declaration only; takes the in-memory struct and returns an owning pointer.
35 | std::unique_ptr deserialize(const msg::CKKSSignature &); // Declaration only; counterpart that takes the msg::CKKSSignature message form.
36 |
37 | } // namespace eva
38 |
--------------------------------------------------------------------------------
/eva/ckks/eager_relinearizer.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
11 | class EagerRelinearizer { // Per-term functor: places a Relinearize term directly after every Mul whose operands are all Cipher.
12 | Program &program;
13 | TermMap &type;
14 | TermMapOptional &scale;
15 |
16 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); }
17 |
18 | bool isUnencryptedType(const Type &type) { return type != Type::Cipher; }
19 |
20 | bool areAllOperandsEncrypted(Term::Ptr &term) { // false as soon as one operand is not Cipher; true otherwise (also for no operands)
21 | for (auto &op : term->getOperands()) {
22 | if (isUnencryptedType(type[op])) {
23 | return false;
24 | }
25 | }
26 | return true;
27 | }
28 |
29 | public:
30 | EagerRelinearizer(Program &g, TermMap &type,
31 | TermMapOptional &scale)
32 | : program(g), type(type), scale(scale) {} // only stores the references, so the arguments must outlive this object
33 |
34 | void
35 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
36 | auto &operands = term->getOperands();
37 | if (operands.size() == 0) return; // inputs
38 |
39 | auto op = term->op;
40 |
41 | if (!isMultiplicationOp(op)) return;
42 |
43 | // Op::Mul only
44 | // ASSUME: only two operands
45 | bool encryptedOps = areAllOperandsEncrypted(term);
46 | if (!encryptedOps) return; // a Mul with a non-Cipher operand gets no Relinearize
47 |
48 | auto relinNode = program.makeTerm(Op::Relinearize, {term}); // new term with `term` as its only operand
49 | type[relinNode] = type[term]; // the new term inherits type ...
50 | scale[relinNode] = scale[term]; // ... and scale from the Mul
51 |
52 | term->replaceOtherUsesWith(relinNode); // NOTE(review): presumably re-points every use of term except relinNode itself - confirm in Term
53 | }
54 | };
55 |
56 | } // namespace eva
57 |
--------------------------------------------------------------------------------
/eva/ckks/eager_waterline_rescaler.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ckks/rescaler.h"
7 | #include "eva/util/logging.h"
8 |
9 | namespace eva {
10 |
11 | class EagerWaterlineRescaler : public Rescaler { // Per-term functor: assigns scales and rescales a Mul result as soon as it reaches fixedRescale + minScale.
12 | std::uint32_t minScale; // the waterline, computed in the constructor
13 | const std::uint32_t fixedRescale = 60; // amount passed to every insertRescale call
14 |
15 | public:
16 | EagerWaterlineRescaler(Program &g, TermMap &type,
17 | TermMapOptional &scale)
18 | : Rescaler(g, type, scale) {
19 | // ASSUME: minScale is the largest scale among the program's sources
20 | minScale = 0;
21 | for (auto &source : program.getSources()) {
22 | if (scale[source] > minScale) minScale = scale[source];
23 | }
24 | assert(minScale != 0); // at least one source must carry a non-zero scale
25 | }
26 |
27 | void
28 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
29 | if (term->numOperands() == 0) return; // inputs
30 | if (type[term] == Type::Raw) {
31 | handleRawScale(term); // helper not defined in this class
32 | return;
33 | }
34 |
35 | if (isRescaleOp(term->op)) return; // already processed
36 |
37 | if (!isMultiplicationOp(term->op)) {
38 | // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst,
39 | // Op::RotateRightConst copy scale of the first operand
40 | scale[term] = scale[term->operandAt(0)];
41 | if (isAdditionOp(term->op)) {
42 | // Op::Add, Op::Sub
43 | auto maxScale = scale[term];
44 | for (auto &operand : term->getOperands()) {
45 | // Here we allow raw operands to possibly raise the scale
46 | if (scale[operand] > maxScale) maxScale = scale[operand];
47 | }
48 | for (auto &operand : term->getOperands()) {
49 | if (scale[operand] < maxScale && type[operand] != Type::Raw) { // non-Raw operand below the common scale
50 | log(Verbosity::Trace,
51 | "Scaling up t%i from scale %i to match other addition operands "
52 | "at scale %i",
53 | operand->index, scale[operand], maxScale);
54 |
55 | auto scaleConstant = program.makeUniformConstant(1); // constant 1 that carries the missing scale
56 | scale[scaleConstant] = maxScale - scale[operand];
57 | scaleConstant->set(scale[scaleConstant]);
58 |
59 | auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); // operand * 1, now at maxScale
60 | scale[mulNode] = maxScale;
61 |
62 | // TODO: Not obviously correct as it's modifying inside
63 | // iteration. Refine API to make this less surprising.
64 | term->replaceOperand(operand, mulNode);
65 | }
66 | }
67 | // assert that all operands have the same scale
68 | for (auto &operand : term->getOperands()) {
69 | assert(maxScale == scale[operand] || type[operand] == Type::Raw);
70 | }
71 | scale[term] = maxScale;
72 | }
73 | return;
74 | }
75 |
76 | // Op::Mul only
77 | // ASSUME: only two operands
78 | std::uint32_t multScale = 0; // a product's scale is the sum of its operand scales
79 | for (auto &operand : term->getOperands()) {
80 | multScale += scale[operand];
81 | }
82 | assert(multScale != 0);
83 | scale[term] = multScale;
84 |
85 | // rescale only if above the waterline
86 | auto temp = term;
87 | while (multScale >= (fixedRescale + minScale)) { // each pass chains one more rescale onto the previous result
88 | temp = insertRescale(temp, fixedRescale);
89 | multScale -= fixedRescale;
90 | assert(multScale == scale[temp]); // insertRescale is expected to have recorded the reduced scale
91 | }
92 | }
93 | };
94 |
95 | } // namespace eva
96 |
--------------------------------------------------------------------------------
/eva/ckks/encode_inserter.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
11 | class EncodeInserter { // Per-term functor: wraps the Raw operand of a Cipher/Raw binary term in an Encode term of type Plain.
12 |   Program &program;
13 |   TermMap &type;
14 |   TermMapOptional &scale;
15 |
16 |   bool isRawType(const Type &type) { return type == Type::Raw; }
17 |   bool isCipherType(const Type &type) { return type == Type::Cipher; }
18 |   bool isAdditionOp(const Op &op_code) {
19 |     return ((op_code == Op::Add) || (op_code == Op::Sub));
20 |   }
21 |
22 |   auto insertEncodeNode(Op op, const Term::Ptr &other, const Term::Ptr &term) { // op: the user's op; other: its Cipher operand; term: the Raw operand to encode
23 |     auto newNode = program.makeTerm(Op::Encode, {term});
24 |     type[newNode] = Type::Plain;
25 |     if (isAdditionOp(op)) {
26 |       scale[newNode] = scale[other]; // Add/Sub: encode at the Cipher operand's scale
27 |     } else {
28 |       scale[newNode] = scale[term]; // otherwise keep the Raw operand's own scale
29 |     }
30 |     newNode->set(scale[newNode]); // also record the chosen scale on the Encode term itself
31 |     return newNode;
32 |   }
33 |
34 | public:
35 |   EncodeInserter(Program &g, TermMap &type,
36 |                  TermMapOptional &scale)
37 |       : program(g), type(type), scale(scale) {} // only stores the references, so the arguments must outlive this object
38 |
39 |   void
40 |   operator()(Term::Ptr &term) { // must only be used with forward pass traversal
41 |     auto &operands = term->getOperands();
42 |     if (operands.size() == 0) return; // inputs
43 |
44 |     assert(operands.size() <= 2);
45 |     if (operands.size() == 2) {
46 |       // Copies, not references: replaceOperand() below edits this operand list.
47 |       auto leftOperand = operands[0];
48 |       auto rightOperand = operands[1];
49 |       if (isCipherType(type[leftOperand]) && isRawType(type[rightOperand])) {
50 |         auto newTerm = insertEncodeNode(term->op, leftOperand, rightOperand);
51 |         term->replaceOperand(rightOperand, newTerm);
52 |       }
53 |
54 |       if (isCipherType(type[rightOperand]) && isRawType(type[leftOperand])) { // mirror case: Raw on the left, Cipher on the right
55 |         auto newTerm = insertEncodeNode(term->op, rightOperand, leftOperand);
56 |         term->replaceOperand(leftOperand, newTerm);
57 |       }
58 |     }
59 |   }
60 | };
61 |
62 | } // namespace eva
63 |
--------------------------------------------------------------------------------
/eva/ckks/encryption_parameter_selector.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | namespace eva {
14 |
15 | class EncryptionParametersSelector { // Forward-pass functor collecting the rescaling primes per term; getEncryptionParameters() then builds the final prime list.
16 | public:
17 | EncryptionParametersSelector(Program &g,
18 | TermMapOptional &scales,
19 | TermMap &types)
20 | : program_(g), scales_(scales), terms_(g), types(types) {}
21 |
22 | void operator()(const Term::Ptr &term) {
23 | // This function computes, for each term, the set of coeff_modulus primes
24 | // needed to reach that term, taking only into account rescalings. Primes
25 | // needed to hold the output values are not included. For example, input
26 | // terms require no extra primes, so for input terms this function will
27 | // assign an empty set of primes. The example below shows parameters
28 | // assigned in a simple example computation, where we rescale by 40 bits:
29 | //
30 | //   In_1:{}   In_2:{}   In_3:{}
31 | //       \         \       /
32 | //        \          \   /
33 | //         \           * MULTIPLY:{}
34 | //          \          |
35 | //           \         |
36 | //            \        * RESCALE:{40}
37 | //             \       |
38 | //              \      |
39 | //                -----* ADD:{40}
40 | //                     |
41 | //                     |
42 | //                 Out_1:{40}
43 |
44 | // This function must only be used with forward pass traversal, as it
45 | // expects operand terms to have been processed already.
46 | if (types[term] == Type::Raw || term->op == Op::Encode) { // Raw terms and Encode terms are skipped entirely
47 | return;
48 | }
49 | auto &operands = term->getOperands();
50 |
51 | // Nothing to do for inputs
52 | if (operands.size() > 0) {
53 | // Get the parameters for this term
54 | auto &parms = terms_[term];
55 |
56 | for (auto &operand : operands) {
57 | // Get the parameters for each operand (forward pass)
58 | auto &operandParms = terms_[operand];
59 |
60 | // Set the parameters for this term to be the maximum over operands
61 | if (operandParms.size() > parms.size()) { // "maximum" means the list with the most primes
62 | parms = operandParms;
63 | }
64 | }
65 |
66 | // Adjust the parameters if this term is a rescale operation
67 | // NOTE: This is ignoring modulus switches, but still works because there
68 | // is always a longest path with no modulus switches.
69 | // TODO: Validate this claim and generalize to include modulus switches.
70 | if (isRescaleOp(term->op)) {
71 | auto newSize = parms.size() + 1;
72 |
73 | // By how much are we rescaling?
74 | auto divisor = term->get();
75 | assert(divisor != 0);
76 |
77 | // Add the required scaling factor to the parameters
78 | parms.push_back(divisor);
79 | assert(parms.size() == newSize);
80 | }
81 | }
82 | }
83 |
84 | inline void free(const Term::Ptr &term) { terms_[term].clear(); } // empties the prime list stored for term
85 |
86 | auto getEncryptionParameters() {
87 | // This function returns the encryption parameters (really just a list of
88 | // prime bit counts for the coeff_modulus) needed for this computation. It
89 | // can be called after forward pass traversal has computed the rescaling
90 | // primes for all terms.
91 | //
92 | // The logic is simple: we loop over each output term as those have the
93 | // largest (largest number of primes) parameter sets after forward
94 | // traversal, and find the largest parameter set among those. This set will
95 | // work globally for the computation. Since the parameters are not taking
96 | // into account the need for storing the result for the output terms, we
97 | // need to add one or more additional primes to the parameters, depending on
98 | // the scales and the ranges of the terms. For example, if the output term
99 | // has a parameter set {40} after forward traversal, with a scale and range
100 | // of 40 and 16 bits, respectively, the result requires an additional 56-bit
101 | // prime in the parameter set. This prime is always added in the set before
102 | // the rescaling primes, giving {56,40}; the key prime appended at the very
103 | // end (56 here) makes the returned list {56,40,56}. If the scale and range
104 | // are very large, this function will add more than one extra prime.
105 |
106 | std::vector parms;
107 |
108 | // The size in bits needed for the output value; this includes the scale and
109 | // the range
110 | std::uint32_t maxOutputSize = 0;
111 |
112 | // The bit count of the largest prime appearing in the parameters
113 | std::uint32_t maxParm = 0;
114 |
115 | // The largest (largest number of primes) set of parameters required among
116 | // all output terms
117 | std::uint32_t maxLen = 0;
118 |
119 | // Loop over each output term
120 | for (auto &entry : program_.getOutputs()) {
121 | auto &output = entry.second;
122 |
123 | // The size for this output term equals the range attribute (bits) plus
124 | // the scale (bits)
125 | auto size = output->get();
126 | size += scales_[output];
127 |
128 | // Update maxOutputSize
129 | if (size > maxOutputSize) maxOutputSize = size;
130 |
131 | // Get the parameters for the current output term
132 | auto &oParms = terms_[output];
133 |
134 | // Update maxLen (number of primes)
135 | if (maxLen < oParms.size()) maxLen = oParms.size();
136 |
137 | // Update maxParm (largest prime)
138 | for (auto &parm : oParms) {
139 | if (parm > maxParm) maxParm = parm;
140 | }
141 | }
142 |
143 | // Ensure that the output size is non-zero
144 | assert(maxOutputSize != 0);
145 |
146 | if (maxOutputSize > 60) {
147 | // If the required output size is larger than 60 bits, we need to increase
148 | // the parameters with more than one additional prime.
149 |
150 | // In this case maxParm is always 60 bits
151 | maxParm = 60;
152 |
153 | // Add 60-bit primes for as long as needed
154 | while (maxOutputSize >= 60) {
155 | parms.push_back(60);
156 | maxOutputSize -= 60;
157 | }
158 |
159 | // Add one more prime if needed
160 | if (maxOutputSize > 0) {
161 | // TODO: The minimum should probably depend on poly_modulus_degree
162 | parms.push_back(std::max(20u, maxOutputSize)); // the remainder prime is never smaller than 20 bits
163 | }
164 | } else {
165 | // The output size is less than 60 bits so the output parameters require
166 | // only one additional prime.
167 |
168 | // Update maxParm
169 | if (maxOutputSize > maxParm) maxParm = maxOutputSize;
170 |
171 | // Add the required prime to the parameters for this term
172 | parms.push_back(maxParm);
173 | }
174 |
175 | // Finally, loop over all output terms and add the largest parameter set to
176 | // parms after what was pushed above.
177 | for (auto &entry : program_.getOutputs()) {
178 | auto &output = entry.second;
179 |
180 | // Get the parameters for the current output term
181 | auto &oParms = terms_[output];
182 |
183 | // If this output node has the longest parameter set, use it
184 | if (maxLen == oParms.size()) {
185 | parms.insert(parms.end(), oParms.rbegin(), oParms.rend()); // appended in reverse order: last rescale first
186 |
187 | // Exit the for loop; we have our parameter set
188 | break;
189 | }
190 | }
191 |
192 | // Add maxParm to result parameters; this is the "key prime".
193 | // TODO: This might be too aggressive. We can try smaller primes here as
194 | // well, which in some cases is advantageous as it may result in smaller
195 | // poly_modulus_degree, even though the noise growth may be a bit larger.
196 | parms.push_back(maxParm);
197 |
198 | return parms;
199 | }
200 |
201 | private:
202 | Program &program_;
203 | TermMapOptional &scales_;
204 | TermMap> terms_; // per-term list of rescaling primes filled in by operator()
205 | TermMap &types;
206 |
207 | inline bool isRescaleOp(const Op &op_code) { return op_code == Op::Rescale; }
208 | };
209 |
210 | } // namespace eva
211 |
--------------------------------------------------------------------------------
/eva/ckks/lazy_relinearizer.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
11 | class LazyRelinearizer { // Per-term functor: postpones the Relinearize after a cipher-cipher Mul while the result flows into a single use that does not force it.
12 | Program &program;
13 | TermMap &type;
14 | TermMapOptional &scale;
15 | TermMap pending; // maintains whether relinearization is pending
16 | std::uint32_t count; // number of cipher-cipher Muls whose relinearization was postponed
17 | std::uint32_t countTotal; // number of Relinearize terms inserted
18 |
19 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); }
20 |
21 | bool isRotationOp(const Op &op_code) {
22 | return ((op_code == Op::RotateLeftConst) ||
23 | (op_code == Op::RotateRightConst));
24 | }
25 |
26 | bool isUnencryptedType(const Type &type) { return type != Type::Cipher; }
27 |
28 | bool areAllOperandsEncrypted(Term::Ptr &term) { // false as soon as one operand is not Cipher; true otherwise
29 | for (auto &op : term->getOperands()) {
30 | assert(type[op] != Type::Undef); // operand types must already be known
31 | if (isUnencryptedType(type[op])) {
32 | return false;
33 | }
34 | }
35 | return true;
36 | }
37 |
38 | bool isEncryptedMultOp(Term::Ptr &term) { // a Mul whose operands are all Cipher
39 | return (isMultiplicationOp(term->op) && areAllOperandsEncrypted(term));
40 | }
41 |
42 | public:
43 | LazyRelinearizer(Program &g, TermMap &type,
44 | TermMapOptional &scale)
45 | : program(g), type(type), scale(scale), pending(g) {
46 | count = 0;
47 | countTotal = 0;
48 | }
49 |
50 | ~LazyRelinearizer() { // currently does nothing; the counter output below is commented out
51 | // TODO: move these to a logging system
52 | // std::cout << "Number of delayed relin: " << count << "\n";
53 | // std::cout << "Number of relin: " << countTotal << "\n";
54 | }
55 |
56 | void
57 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
58 | auto &operands = term->getOperands();
59 | if (operands.size() == 0) return; // inputs
60 |
61 | bool delayed = false; // true only when this very term is the cipher-cipher Mul
62 |
63 | if (isEncryptedMultOp(term)) {
64 | assert(pending[term] == false);
65 | pending[term] = true;
66 | delayed = true;
67 | } else if (pending[term] == false) { // nothing was passed on to this term
68 | return;
69 | }
70 |
71 | bool mustInsert = false;
72 | assert(term->numUses() > 0);
73 | auto firstUse = term->getUses()[0];
74 | for (auto &use : term->getUses()) { // relinearize now if any use is a cipher-cipher Mul, a rotation or an Output ...
75 | if (isEncryptedMultOp(use) || isRotationOp(use->op) ||
76 | use->op == Op::Output || (firstUse != use)) { // different uses
77 | mustInsert = true;
78 | break;
79 | }
80 | }
81 |
82 | if (mustInsert) {
83 | auto relinNode = program.makeTerm(Op::Relinearize, {term}); // new term with `term` as its only operand
84 | ++countTotal;
85 |
86 | type[relinNode] = type[term];
87 | scale[relinNode] = scale[term];
88 | term->replaceOtherUsesWith(relinNode); // NOTE(review): presumably re-points every use of term except relinNode itself - confirm in Term
89 | } else {
90 | if (delayed) ++count;
91 | for (auto &use : term->getUses()) { // hand the pending relinearization on to the users
92 | pending[use] = true;
93 | }
94 | }
95 | }
96 | };
97 |
98 | } // namespace eva
99 |
--------------------------------------------------------------------------------
/eva/ckks/levels_checker.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 | #include
10 | #include
11 |
12 | namespace eva {
13 |
14 | class LevelsChecker { // Per-term functor: asserts that all Cipher operands of a term share one level and records the term's own level.
15 | public:
16 |   LevelsChecker(Program &g, TermMap &types)
17 |       : program_(g), types_(types), levels_(g) {}
18 |
19 |   void operator()(const Term::Ptr &term) {
20 |     // This function verifies that the levels are compatible. It assumes the
21 |     // operand terms are processed already, so it must only be used with forward
22 |     // pass traversal.
23 |
24 |     if (term->numOperands() == 0) {
25 |       // If this is a source node, get the encoding level
26 |       levels_[term] = term->get();
27 |     } else {
28 |       // For other terms, all Cipher operands must share one level. Start from
29 |       // 0 (kept when no operand is a Cipher) and take the first Cipher's level.
30 |       std::size_t operandLevel = 0;
31 |       for (auto &operand : term->getOperands()) {
32 |         if (types_[operand] == Type::Cipher) {
33 |           operandLevel = levels_[operand];
34 |           break;
35 |         }
36 |       }
37 |
38 |       // Next verify that all Cipher operands have the same level.
39 |       for (auto &operand : term->getOperands()) {
40 |         if (types_[operand] == Type::Cipher) {
41 |           auto operandLevel2 = levels_[operand];
42 |           assert(operandLevel == operandLevel2);
43 |         }
44 |       }
45 |
46 |       // Increment the level for a rescale or modulus switch
47 |       std::size_t level = operandLevel;
48 |       if (isRescaleOp(term->op) || isModSwitchOp(term->op)) {
49 |         ++level;
50 |       }
51 |       levels_[term] = level;
52 |     }
53 |   }
54 |
55 |   void free(const Term::Ptr &term) {
56 |     // No-op
57 |   }
58 |
59 | private:
60 |   Program &program_;
61 |   TermMap &types_;
62 |
63 |   // Maintains the reverse level (leaves have 0, roots have max)
64 |   TermMap levels_;
65 |
66 |   bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); }
67 |
68 |   bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }
69 | };
70 |
71 | } // namespace eva
72 |
--------------------------------------------------------------------------------
/eva/ckks/minimum_rescaler.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ckks/rescaler.h"
7 | #include "eva/util/logging.h"
8 |
9 | namespace eva {
10 |
11 | class MinimumRescaler : public Rescaler { // Per-term functor: assigns scales; for a Mul it rescales the operands first when that is worthwhile, else the result.
12 |   std::uint32_t minScale; // the waterline, computed in the constructor
13 |   const std::uint32_t maxRescale = 60; // upper bound for a single rescale amount
14 |
15 | public:
16 |   MinimumRescaler(Program &g, TermMap &type,
17 |                   TermMapOptional &scale)
18 |       : Rescaler(g, type, scale) {
19 |     // ASSUME: minScale is the largest scale among the program's sources
20 |     minScale = 0;
21 |     for (auto &source : program.getSources()) {
22 |       if (scale[source] > minScale) minScale = scale[source];
23 |     }
24 |     assert(minScale != 0);
25 |   }
26 |
27 |   void
28 |   operator()(Term::Ptr &term) { // must only be used with forward pass traversal
29 |     auto &operands = term->getOperands();
30 |     if (operands.size() == 0) return; // inputs
31 |     if (type[term] == Type::Raw) {
32 |       handleRawScale(term); // helper not defined in this class
33 |       return;
34 |     }
35 |
36 |     auto op = term->op;
37 |
38 |     if (isRescaleOp(op)) return; // already processed
39 |
40 |     if (!isMultiplicationOp(op)) {
41 |       // Op::Add, Op::Sub, NEGATE, COPY, Op::RotateLeftConst,
42 |       // Op::RotateRightConst copy scale of the first operand
43 |       for (auto &operand : operands) { // runs once: the break below leaves after the first operand
44 |         assert(operand->op != Op::Constant);
45 |         assert(scale[operand] != 0);
46 |         scale[term] = scale[operand];
47 |         break;
48 |       }
49 |       if (isAdditionOp(op)) {
50 |         // Op::Add, Op::Sub
51 |         auto maxScale = scale[term];
52 |         for (auto &operand : operands) {
53 |           // Here we allow raw operands to possibly raise the scale
54 |           if (scale[operand] > maxScale) maxScale = scale[operand];
55 |         }
56 |         for (auto &operand : operands) {
57 |           if (scale[operand] < maxScale && type[operand] != Type::Raw) { // non-Raw operand below the common scale
58 |             log(Verbosity::Trace,
59 |                 "Scaling up t%i from scale %i to match other addition operands "
60 |                 "at scale %i",
61 |                 operand->index, scale[operand], maxScale);
62 |
63 |             auto scaleConstant = program.makeUniformConstant(1); // constant 1 that carries the missing scale
64 |             scale[scaleConstant] = maxScale - scale[operand];
65 |             scaleConstant->set(scale[scaleConstant]);
66 |
67 |             auto mulNode = program.makeTerm(Op::Mul, {operand, scaleConstant}); // operand * 1, now at maxScale
68 |             scale[mulNode] = maxScale;
69 |
70 |             // TODO: Not obviously correct as it's modifying inside
71 |             // iteration.
72 |             // Refine API to make this less surprising.
73 |             term->replaceOperand(operand, mulNode);
74 |           }
75 |         }
76 |         // assert that all operands have the same scale
77 |         for (auto &operand : operands) {
78 |           assert(maxScale == scale[operand] || type[operand] == Type::Raw);
79 |         }
80 |         scale[term] = maxScale;
81 |       }
82 |       return;
83 |     }
84 |
85 |     // Op::Mul only
86 |     // ASSUME: only two operands
87 |     std::vector operandsCopy; // copied because insertRescaleBetween() below changes the term's operands
88 |     for (auto &operand : operands) {
89 |       operandsCopy.push_back(operand);
90 |     }
91 |     assert(operandsCopy.size() == 2);
92 |     std::uint32_t multScale = scale[operandsCopy[0]] + scale[operandsCopy[1]];
93 |     assert(multScale != 0);
94 |     scale[term] = multScale;
95 |
96 |     auto minOfScales = scale[operandsCopy[0]];
97 |     if (minOfScales > scale[operandsCopy[1]])
98 |       minOfScales = scale[operandsCopy[1]];
99 |     auto rescaleBy = (minOfScales > minScale) ? minOfScales - minScale : 0u; // clamp at 0: a plain difference wraps around when an operand scale is below minScale
100 |     if (rescaleBy > maxRescale) rescaleBy = maxRescale;
101 |     if ((2 * rescaleBy) >= maxRescale) {
102 |       // rescale after multiplication is inevitable
103 |       // to reduce the growth of scale, rescale both operands before
104 |       // multiplication
105 |       assert(rescaleBy <= maxRescale);
106 |       insertRescaleBetween(operandsCopy[0], term, rescaleBy);
107 |       if (operandsCopy[0] != operandsCopy[1]) { // a squared operand is rescaled only once
108 |         insertRescaleBetween(operandsCopy[1], term, rescaleBy);
109 |       }
110 |
111 |       scale[term] = multScale - (2 * rescaleBy);
112 |     } else {
113 |       // rescale only if above the waterline
114 |       auto temp = term;
115 |       while (multScale >= (maxRescale + minScale)) {
116 |         temp = insertRescale(temp, maxRescale);
117 |         multScale -= maxRescale;
118 |         assert(multScale == scale[temp]);
119 |       }
120 |     }
121 |   }
122 | };
123 |
124 | } // namespace eva
125 |
--------------------------------------------------------------------------------
/eva/ckks/mod_switcher.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
11 | class ModSwitcher { // Per-term functor: gives each term a reverse level and inserts ModSwitch terms where the uses of a term sit at different levels.
12 | Program &program;
13 | TermMap &type;
14 | TermMapOptional &scale;
15 | TermMap
16 | level; // maintains the reverse level (leaves have 0, roots have max)
17 | std::vector encodeNodes; // Encode terms seen by operator(); revisited in the destructor
18 |
19 | Term::Ptr insertModSwitchNode(Term::Ptr &term, std::uint32_t termLevel) { // creates ModSwitch(term) with term's scale and the given level
20 | auto newNode = program.makeTerm(Op::ModSwitch, {term});
21 | scale[newNode] = scale[term];
22 | level[newNode] = termLevel;
23 | return newNode;
24 | }
25 |
26 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }
27 |
28 | bool isCipherType(const Term::Ptr &term) const {
29 | return type[term] == Type::Cipher;
30 | }
31 |
32 | public:
33 | ModSwitcher(Program &g, TermMap &type,
34 | TermMapOptional &scale)
35 | : program(g), type(type), scale(scale), level(g) {}
36 |
37 | ~ModSwitcher() { // does real work: stores maxLevel - level on every source and on every recorded Encode term
38 | auto sources = program.getSources();
39 | std::uint32_t maxLevel = 0; // largest reverse level found among the sources
40 | for (auto &source : sources) {
41 | if (level[source] > maxLevel) maxLevel = level[source];
42 | }
43 | for (auto &source : sources) {
44 | auto curLevel = maxLevel - level[source];
45 | source->set(curLevel);
46 | }
47 |
48 | for (auto &encode : encodeNodes) {
49 | encode->set(maxLevel - level[encode]);
50 | }
51 | }
52 |
53 | void operator()(
54 | Term::Ptr &term) { // must only be used with backward pass traversal
55 | if (term->numUses() == 0) return; // terms without uses keep their initial level entry
56 |
57 | // We do not want to add a ModSwitch for terms of type Raw
58 | if (type[term] == Type::Raw) return;
59 |
60 | if (term->op == Op::Encode) {
61 | encodeNodes.push_back(term);
62 | }
63 | std::map> useLevels; // ordered map
64 | for (auto &use : term->getUses()) { // group the uses by their already assigned level
65 | useLevels[level[use]].push_back(use);
66 | }
67 |
68 | std::uint32_t termLevel = 0;
69 | if (useLevels.size() > 1) {
70 | auto useLevel = useLevels.rbegin(); // max to min
71 | termLevel = useLevel->first; // the term itself sits at the highest use level
72 | ++useLevel;
73 |
74 | auto temp = term;
75 | auto tempLevel = termLevel;
76 | while (useLevel != useLevels.rend()) {
77 | auto expectedLevel = useLevel->first;
78 | while (tempLevel > expectedLevel) { // extend the ModSwitch chain down to this group's level
79 | temp = insertModSwitchNode(temp, tempLevel);
80 | --tempLevel;
81 | }
82 | for (auto &use : useLevel->second) { // uses at this level now read from the end of the chain
83 | use->replaceOperand(term, temp);
84 | }
85 | ++useLevel;
86 | }
87 | } else {
88 | assert(useLevels.size() == 1);
89 | termLevel = useLevels.begin()->first;
90 | }
91 | if (isRescaleOp(term->op)) { // a Rescale term sits one level above its uses
92 | ++termLevel;
93 | }
94 | level[term] = termLevel;
95 | }
96 | };
97 |
98 | } // namespace eva
99 |
--------------------------------------------------------------------------------
/eva/ckks/parameter_checker.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | namespace eva {
14 |
15 | class InconsistentParameters : public std::runtime_error { // Error type thrown by ParameterChecker when operands disagree on their primes.
16 | public:
17 | InconsistentParameters(const std::string &msg) : std::runtime_error(msg) {} // msg is passed on to std::runtime_error unchanged
18 | };
19 |
20 | class ParameterChecker { // Per-term functor: checks that the operands of every term agree on their prime lists; throws InconsistentParameters if not.
21 | TermMap &types;
22 |
23 | public:
24 | ParameterChecker(Program &g, TermMap &types)
25 | : program_(g), parms_(g), types(types) {}
26 |
27 | void operator()(const Term::Ptr &term) {
28 | // Must only be used with forward pass traversal
29 | auto &operands = term->getOperands();
30 | if (types[term] == Type::Raw || term->op == Op::Encode) { // Raw terms and Encode terms are skipped entirely
31 | return;
32 | }
33 | if (operands.size() > 0) {
34 | // Get the parameters for this term
35 | auto &parms = parms_[term];
36 | // Loop over operands
37 | for (auto &operand : operands) {
38 | // Get the parameters for the operand
39 | auto &operandParms = parms_[operand];
40 |
41 | // Nothing to do if the operand parameters are empty; the operand sets
42 | // no requirements on this node
43 | if (operandParms.size() > 0) {
44 | if (parms.size() > 0) {
45 | // If the parameters for this term are already set (from a different
46 | // operand), they must match the current operand's parameters
47 | if (operandParms.size() != parms.size()) {
48 | throw InconsistentParameters(
49 | "Two operands require different number of primes");
50 | }
51 |
52 | // Loop over the primes in the parameters for this term
53 | for (std::size_t i = 0; i < parms.size(); ++i) {
54 | if (parms[i] == 0) {
55 | // If any of the primes is zero (indicating a previous modulus
56 | // switch operand term), fill in its true value from the current
57 | // operand
58 | parms[i] = operandParms[i];
59 | } else if (operandParms[i] != 0) {
60 | // If the operand prime is non-zero, require equality
61 | if (parms[i] != operandParms[i]) {
62 | throw InconsistentParameters(
63 | "Primes required by two operands do not match");
64 | }
65 | }
66 | }
67 | } else {
68 | // This is the first operand to impose conditions on this term;
69 | // copy the parameters from the operand
70 | parms = operandParms;
71 | }
72 | }
73 | }
74 |
75 | if (isModSwitchOp(term->op)) {
76 | // Is this a modulus switch? If so, add an extra (placeholder) zero
77 | parms.push_back(0);
78 | } else if (isRescaleOp(term->op)) {
79 | // Is this a rescale? Then add a prime of the requested size
80 | auto divisor = term->get();
81 | assert(divisor != 0);
82 | parms.push_back(divisor);
83 | }
84 | } else { // source term: it has no operands
85 | // Get the parameters for this term
86 | auto &parms = parms_[term];
87 | std::uint32_t level = term->get();
88 | while (level > 0) { // one zero placeholder per level of the source
89 | parms.push_back(0);
90 | level--;
91 | }
92 | }
93 | }
94 |
95 | void free(const Term::Ptr &term) { parms_[term].clear(); } // empties the prime list stored for term
96 |
97 | private:
98 | Program &program_;
99 | TermMap> parms_; // per-term prime list; 0 marks a prime whose size is not known yet
100 |
101 | bool isModSwitchOp(const Op &op_code) { return (op_code == Op::ModSwitch); }
102 |
103 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }
104 | };
105 |
106 | } // namespace eva
107 |
--------------------------------------------------------------------------------
/eva/ckks/rescaler.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
// Shared state and helpers for rescaling passes. Everything, including the
// constructor, is protected, so the class is only usable through derived
// classes.
11 | class Rescaler {
12 | protected:
13 | Program &program;
14 | TermMap &type;
15 | TermMapOptional &scale;
16 |
// Stores references to the program and to the caller-owned type and scale
// maps; nothing is copied.
17 | Rescaler(Program &g, TermMap &type,
18 | TermMapOptional &scale)
19 | : program(g), type(type), scale(scale) {}
20 |
21 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }
22 |
23 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); }
24 |
25 | bool isAdditionOp(const Op &op_code) {
26 | return ((op_code == Op::Add) || (op_code == Op::Sub));
27 | }
28 |
// Creates a rescale node on `term` with divisor `rescaleBy`, gives it the
// type of `term` and the scale scale[term] - rescaleBy, then redirects the
// uses of `term` to the new node via replaceOtherUsesWith (presumably every
// use except the new node itself - confirm in Term). Returns the new node.
29 | auto insertRescale(Term::Ptr term, std::uint32_t rescaleBy) {
30 | // auto scale = term->getScale();
31 | auto rescaleNode = program.makeRescale(term, rescaleBy);
32 | type[rescaleNode] = type[term];
33 | scale[rescaleNode] = scale[term] - rescaleBy;
34 |
35 | term->replaceOtherUsesWith(rescaleNode);
36 |
37 | return rescaleNode;
38 | }
39 |
// Same as insertRescale, but only the operand slot of `term2` that refers to
// `term1` is redirected; other uses of `term1` are left alone. Returns
// nothing.
40 | void insertRescaleBetween(Term::Ptr term1, Term::Ptr term2,
41 | std::uint32_t rescaleBy) {
42 | auto rescaleNode = program.makeRescale(term1, rescaleBy);
43 | type[rescaleNode] = type[term1];
44 | scale[rescaleNode] = scale[term1] - rescaleBy;
45 |
46 | term2->replaceOperand(term1, rescaleNode);
47 | }
48 |
// Sets scale[term] to the largest operand scale. Terms without operands are
// left untouched.
49 | void handleRawScale(Term::Ptr term) {
50 | if (term->numOperands() > 0) {
// NOTE(review): maxScale is a signed int while insertRescale treats scales
// as std::uint32_t values; the element type of `scale` is not visible here -
// confirm the comparison below is not a signed/unsigned mix.
51 | int maxScale = 0;
52 | for (auto &operand : term->getOperands()) {
53 | if (scale.at(operand) > maxScale) maxScale = scale.at(operand);
54 | }
55 | scale[term] = maxScale;
56 | }
57 | }
58 | };
59 |
60 | } // namespace eva
61 |
--------------------------------------------------------------------------------
/eva/ckks/scales_checker.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 | #include
10 | #include
11 |
12 | namespace eva {
13 |
// Forward-pass visitor that recomputes the scale of every non-Raw term into
// its own map (scales_) and throws when a scale is zero or when the operands
// of an addition/subtraction have different scales.
14 | class ScalesChecker {
15 | public:
// NOTE(review): the `scales` argument is never used; scales_ is a by-value
// member constructed from the program, so the check is independent of the
// caller's map. Confirm this is intended.
16 | ScalesChecker(Program &g, TermMapOptional &scales,
17 | TermMap &types)
18 | : program_(g), scales_(g), types_(types) {}
19 |
// Computes and stores scales_[term].
// Throws std::runtime_error for an Input with scale 0 and std::logic_error
// for any other zero scale or for Add/Sub operands of unequal scale.
20 | void operator()(const Term::Ptr &term) {
21 | // Must only be used with forward pass traversal
22 | if (types_[term] == Type::Raw) {
23 | return;
24 | }
25 | auto &operands = term->getOperands();
26 |
27 | // Source terms (Input, Encode): the scale is read from the term itself
28 | if (term->op == Op::Input || term->op == Op::Encode) {
29 | scales_[term] = term->get();
30 | if (scales_.at(term) == 0) {
31 | if (term->op == Op::Input) {
32 | throw std::runtime_error("Program has an input with 0 scale");
33 | } else {
34 | throw std::logic_error("Compiled program results in a 0 scale term");
35 | }
36 | }
37 | } else if (term->op == Op::Mul) {
// Multiplication: the scale is the sum of the operand scales.
38 | assert(term->numOperands() == 2);
39 | std::uint32_t scale = 0;
40 | for (auto &operand : operands) {
41 | scale += scales_.at(operand);
42 | }
43 | if (scale == 0) {
44 | throw std::logic_error("Compiled program results in a 0 scale term");
45 | }
46 | scales_[term] = scale;
47 | } else if (term->op == Op::Rescale) {
// Rescale: the scale is the operand scale minus the divisor.
48 | assert(term->numOperands() == 1);
49 | auto divisor = term->get();
50 | auto operandScale = scales_.at(term->operandAt(0));
// NOTE(review): the difference is stored in a std::uint32_t. If the divisor
// can exceed the operand scale (their types are `auto` here and assumed
// unsigned), the subtraction would wrap to a large value and pass the zero
// check below - confirm such programs cannot reach this point.
51 | std::uint32_t scale = operandScale - divisor;
52 | if (scale == 0) {
53 | throw std::logic_error("Compiled program results in a 0 scale term");
54 | }
55 | scales_[term] = scale;
56 |
57 | } else if (isAdditionOp(term->op)) {
// Add/Sub: every operand must have the same scale as the first one.
58 | std::uint32_t scale = 0;
59 | for (auto &operand : operands) {
60 | if (scale == 0) {
61 | scale = scales_.at(operand);
62 | } else {
63 | if (scale != scales_.at(operand)) {
64 | throw std::logic_error("Addition or subtraction in program has "
65 | "operands of non-equal scale");
66 | }
67 | }
68 | }
69 | if (scale == 0) {
70 | throw std::logic_error("Compiled program results in a 0 scale term");
71 | }
72 | scales_[term] = scale;
73 | } else {
// Every other op takes the scale of its first operand.
74 | auto scale = scales_.at(term->operandAt(0));
75 | if (scale == 0) {
76 | throw std::logic_error("Compiled program results in a 0 scale term");
77 | }
78 | scales_[term] = scale;
79 | }
80 | }
81 |
// Nothing is released: the computed scales are kept for all terms.
82 | void free(const Term::Ptr &term) {
83 | // No-op
84 | }
85 |
86 | private:
87 | Program &program_;
// Scales computed by this checker (held by value, not the caller's map).
88 | TermMapOptional scales_;
89 | TermMap &types_;
90 |
91 | bool isAdditionOp(const Op &op_code) {
92 | return ((op_code == Op::Add) || (op_code == Op::Sub));
93 | }
94 | };
95 |
96 | } // namespace eva
97 |
--------------------------------------------------------------------------------
/eva/ckks/seal_lowering.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
// Forward-pass rewriter that replaces a subtraction whose left operand is not
// a ciphertext and whose right operand is a ciphertext by
// Add(left, Negate(right)).
11 | class SEALLowering {
12 | Program &program;
13 | TermMap &type;
14 |
15 | public:
// Keeps references to the program and to the caller-owned type map.
16 | SEALLowering(Program &g, TermMap &type) : program(g), type(type) {}
17 |
// Rewrites `term` if it matches the pattern above; all other terms are left
// unchanged. The type map is not updated for the two newly created terms.
18 | void
19 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
20 |
21 | // SEAL does not support plaintext subtraction with a plaintext on the left
22 | // hand side, so lower to a negation and addition.
23 | if (term->op == Op::Sub && type[term->operandAt(0)] != Type::Cipher &&
24 | type[term->operandAt(1)] == Type::Cipher) {
25 | auto negation = program.makeTerm(Op::Negate, {term->operandAt(1)});
26 | auto addition = program.makeTerm(Op::Add, {term->operandAt(0), negation});
// Every use of the original Sub term now refers to the new Add term.
27 | term->replaceAllUsesWith(addition);
28 | }
29 | }
30 | };
31 |
32 | } // namespace eva
33 |
--------------------------------------------------------------------------------
/eva/common/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
# Adds reference_executor.cpp as a private source file of the eva target.
4 | target_sources(eva PRIVATE
5 | reference_executor.cpp
6 | )
7 |
--------------------------------------------------------------------------------
/eva/common/constant_folder.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
// Forward-pass rewriter that evaluates terms whose operands are all
// Op::Constant and replaces them with a new dense constant.
11 | class ConstantFolder {
12 | Program &program;
13 | TermMapOptional &scale;
// Reusable buffers handed to expand() for the first and second operand.
14 | std::vector scratch1, scratch2;
15 |
16 | bool isRescaleOp(const Op &op_code) { return (op_code == Op::Rescale); }
17 |
18 | bool isMultiplicationOp(const Op &op_code) { return (op_code == Op::Mul); }
19 |
20 | bool isAdditionOp(const Op &op_code) {
21 | return ((op_code == Op::Add) || (op_code == Op::Sub));
22 | }
23 |
// Creates a dense constant holding `output`, records `termScale` for it in
// the scale map and on the constant itself, and redirects every use of
// `term` to the new constant.
// NOTE(review): termScale is a double although callers pass values read from
// the scale map; confirm the conversion is intended.
24 | void replaceNodeWithConstant(Term::Ptr term,
25 | const std::vector &output,
26 | double termScale) {
27 | // TODO: optimize output representations
28 | auto constant = program.makeDenseConstant(output);
29 | scale[constant] = termScale;
30 | constant->set(scale[constant]);
31 |
32 | term->replaceAllUsesWith(constant);
33 | assert(term->numUses() == 0);
34 | }
35 |
// Folds output = args1 + args2 element-wise. Both operands are expanded with
// program.getVecSize(); the result gets the larger of the two operand scales.
36 | void add(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) {
37 | auto &input1 = args1->get()->expand(
38 | scratch1, program.getVecSize());
39 | auto &input2 = args2->get()->expand(
40 | scratch2, program.getVecSize());
41 |
42 | std::vector outputValue(input1.size());
43 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) {
44 | outputValue[i] = input1[i] + input2[i];
45 | }
46 |
47 | replaceNodeWithConstant(output, outputValue,
48 | std::max(scale[args1], scale[args2]));
49 | }
50 |
// Folds output = args1 - args2 element-wise; scale handling as in add().
51 | void sub(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) {
52 | auto &input1 = args1->get()->expand(
53 | scratch1, program.getVecSize());
54 | auto &input2 = args2->get()->expand(
55 | scratch2, program.getVecSize());
56 |
57 | std::vector outputValue(input1.size());
58 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) {
59 | outputValue[i] = input1[i] - input2[i];
60 | }
61 |
62 | replaceNodeWithConstant(output, outputValue,
63 | std::max(scale[args1], scale[args2]));
64 | }
65 |
// Folds output = args1 * args2 element-wise.
// NOTE(review): the folded constant takes the max of the operand scales,
// whereas ScalesChecker sums operand scales for Op::Mul - confirm that max
// is intended for folded constants.
66 | void mul(Term::Ptr output, const Term::Ptr &args1, const Term::Ptr &args2) {
67 | auto &input1 = args1->get()->expand(
68 | scratch1, program.getVecSize());
69 | auto &input2 = args2->get()->expand(
70 | scratch2, program.getVecSize());
71 |
72 | std::vector outputValue(input1.size());
73 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) {
74 | outputValue[i] = input1[i] * input2[i];
75 | }
76 |
77 | replaceNodeWithConstant(output, outputValue,
78 | std::max(scale[args1], scale[args2]));
79 | }
80 |
// Folds a left rotation by `shift` slots: result[i] = input[i + shift], with
// the first `shift` input elements wrapping around to the end.
81 | void leftRotate(Term::Ptr output, const Term::Ptr &args1,
82 | std::int32_t shift) {
83 | auto &input1 = args1->get()->expand(
84 | scratch1, program.getVecSize());
85 |
// Bring shift into [0, size) by repeated subtraction/addition.
// NOTE(review): this assumes input1 is non-empty; with size 0 and a positive
// shift the first loop would never terminate.
86 | while (shift > 0 && shift >= input1.size())
87 | shift -= input1.size();
88 | while (shift < 0)
89 | shift += input1.size();
90 |
91 | std::vector outputValue(input1.size());
92 | for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) {
93 | outputValue[i] = input1[i + shift];
94 | }
95 | for (std::uint64_t i = 0; i < shift; ++i) {
96 | outputValue[outputValue.size() - shift + i] = input1[i];
97 | }
98 |
99 | replaceNodeWithConstant(output, outputValue, scale[args1]);
100 | }
101 |
// Folds a right rotation by `shift` slots: result[i + shift] = input[i], with
// the last `shift` input elements wrapping around to the front. Shift
// normalization is the same as in leftRotate.
102 | void rightRotate(Term::Ptr output, const Term::Ptr &args1,
103 | std::int32_t shift) {
104 | auto &input1 = args1->get()->expand(
105 | scratch1, program.getVecSize());
106 |
107 | while (shift > 0 && shift >= input1.size())
108 | shift -= input1.size();
109 | while (shift < 0)
110 | shift += input1.size();
111 |
112 | std::vector outputValue(input1.size());
113 | for (std::uint64_t i = 0; i < (outputValue.size() - shift); ++i) {
114 | outputValue[i + shift] = input1[i];
115 | }
116 | for (std::uint64_t i = 0; i < shift; ++i) {
117 | outputValue[i] = input1[outputValue.size() - shift + i];
118 | }
119 |
120 | replaceNodeWithConstant(output, outputValue, scale[args1]);
121 | }
122 |
// Folds element-wise negation; the result keeps the operand's scale.
123 | void negate(Term::Ptr output, const Term::Ptr &args1) {
124 | auto &input1 = args1->get()->expand(
125 | scratch1, program.getVecSize());
126 |
127 | std::vector outputValue(input1.size());
128 | for (std::uint64_t i = 0; i < outputValue.size(); ++i) {
129 | outputValue[i] = -input1[i];
130 | }
131 |
132 | replaceNodeWithConstant(output, outputValue, scale[args1]);
133 | }
134 |
135 | public:
// Keeps references to the program and to the caller-owned scale map.
136 | ConstantFolder(Program &g, TermMapOptional &scale)
137 | : program(g), scale(scale) {}
138 |
// Folds `term` if it has operands and all of them are Op::Constant.
// Output and Encode terms are left as they are. Throws std::logic_error for
// Relinearize, ModSwitch and Rescale, and for any op not listed in the
// switch.
139 | void
140 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
141 | auto &args = term->getOperands();
142 | if (args.size() == 0) return; // inputs
143 |
// Bail out as soon as one operand is not a constant.
144 | for (auto &arg : args) {
145 | if (arg->op != Op::Constant) return;
146 | }
147 |
148 | auto op_code = term->op;
149 | switch (op_code) {
150 | case Op::Add:
151 | assert(args.size() == 2);
152 | add(term, args[0], args[1]);
153 | break;
154 | case Op::Sub:
155 | assert(args.size() == 2);
156 | sub(term, args[0], args[1]);
157 | break;
158 | case Op::Mul:
159 | assert(args.size() == 2);
160 | mul(term, args[0], args[1]);
161 | break;
162 | case Op::RotateLeftConst:
163 | assert(args.size() == 1);
164 | leftRotate(term, args[0], term->get());
165 | break;
166 | case Op::RotateRightConst:
167 | assert(args.size() == 1);
168 | rightRotate(term, args[0], term->get());
169 | break;
170 | case Op::Negate:
171 | assert(args.size() == 1);
172 | negate(term, args[0]);
173 | break;
174 | case Op::Output:
175 | [[fallthrough]];
176 | case Op::Encode:
177 | break;
178 | case Op::Relinearize:
179 | [[fallthrough]];
180 | case Op::ModSwitch:
181 | [[fallthrough]];
182 | case Op::Rescale:
183 | throw std::logic_error("Encountered HE specific operation " +
184 | getOpName(op_code) +
185 | " in unencrypted computation");
186 | default:
187 | throw std::logic_error("Unhandled op " + getOpName(op_code));
188 | }
189 | }
190 | };
191 |
192 | } // namespace eva
193 |
--------------------------------------------------------------------------------
/eva/common/multicore_program_traversal.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include "eva/util/galois.h"
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | namespace eva {
19 |
// Traverses a Program with galois::for_each, calling an evaluator on every
// term once all of its predecessors have been evaluated and calling
// eval.free() on a term once all of its dependents have been evaluated.
20 | class MulticoreProgramTraversal {
21 | public:
22 | MulticoreProgramTraversal(Program &g) : program_(g) {}
23 |
// Visits terms from the sources towards the uses. `eval` must provide
// operator()(term) and free(term).
// NOTE(review): the counter element type of the two maps is not visible
// here; the "== 1" / "== 0" hand-offs below rely on the increments and
// decrements being atomic if the loops run concurrently - confirm.
24 | template void forwardPass(Evaluator &eval) {
25 | TermMap predecessors(program_);
26 | TermMap successors(program_);
27 |
28 | // Add the source terms
29 | galois::InsertBag readyNodes;
30 | for (auto source : program_.getSources()) {
31 | readyNodes.push_back(source);
32 | }
33 |
// First loop: count, for every reachable term, how many uses it has
// (successors) and how many operand edges lead into it (predecessors).
34 | // Enumerate predecessors and successors
35 | galois::for_each(
36 | galois::iterate(readyNodes),
37 | [&](const Term::Ptr &term, auto &ctx) {
38 | // For each term, iterate over its uses
39 | for (auto &use : term->getUses()) {
40 | // Increment the number of successors
41 | ++successors[term];
42 |
43 | // Increment the number of predecessors
44 | if ((++predecessors[use]) == 1) {
45 | // Only first predecessor will push so each use is added once
46 | ctx.push_back(use);
47 | }
48 | }
49 | },
50 | galois::wl>(),
51 | galois::no_stats(),
52 | galois::loopname("ForwardCountPredecessorsSuccessors"));
53 |
// Second loop: evaluate terms, counting the two maps back down.
54 | // Traverse the program
55 | galois::for_each(
56 | galois::iterate(readyNodes),
57 | [&](const Term::Ptr &term, auto &ctx) {
58 | // Process the current term
59 | eval(term);
60 |
61 | // Free operands if their successors are done
62 | for (auto &operand : term->getOperands()) {
63 | if ((--successors[operand]) == 0) {
64 | // Only last successor will free
65 | eval.free(operand);
66 | }
67 | }
68 |
69 | // Execute (ready) uses if their predecessors are done
70 | for (auto &use : term->getUses()) {
71 | if ((--predecessors[use]) == 0) {
72 | // Only last predecessor will push
73 | ctx.push_back(use);
74 | }
75 | }
76 | },
77 | galois::wl>(),
78 | galois::no_stats(), galois::loopname("ForwardTraversal"));
79 |
80 | // TODO: Reinstate these checks
81 | // for (auto& predecessor : predecessors) assert(predecessor == 0);
82 | // for (auto& successor : successors) assert(successor == 0);
83 | }
84 |
// Mirror image of forwardPass: starts at the sinks, walks towards the
// operands, and frees a use once all of its operands have been evaluated.
85 | template void backwardPass(Evaluator &eval) {
86 | TermMap predecessors(program_);
87 | TermMap successors(program_);
88 |
89 | // Add the sink terms
90 | galois::InsertBag readyNodes;
91 | for (auto &sink : program_.getSinks()) {
92 | readyNodes.push_back(sink);
93 | }
94 |
95 | // Enumerate predecessors and successors
96 | galois::for_each(
97 | galois::iterate(readyNodes),
98 | [&](const Term::Ptr &term, auto &ctx) {
99 | // For each term, iterate over its operands
100 | for (auto &operand : term->getOperands()) {
101 | // Increment the number of predecessors
102 | ++predecessors[term];
103 |
104 | // Increment the number of successors for the operand
105 | if ((++successors[operand]) == 1) {
106 | // Only first successor will push so each operand is added once
107 | ctx.push_back(operand);
108 | }
109 | }
110 | },
111 | galois::wl>(),
112 | galois::no_stats(),
113 | galois::loopname("BackwardCountPredecessorsSuccessors"));
114 |
115 | // Traverse the program
116 | galois::for_each(
117 | galois::iterate(readyNodes),
118 | [&](const Term::Ptr &term, auto &ctx) {
119 | // Process the current term
120 | eval(term);
121 |
122 | // Free uses if their predecessors are done
123 | for (auto &use : term->getUses()) {
124 | if ((--predecessors[use]) == 0) {
125 | // Only last predecessor will free
126 | eval.free(use);
127 | }
128 | }
129 |
130 | // Execute (ready) operands if their successors are done
131 | for (auto &operand : term->getOperands()) {
132 | if ((--successors[operand]) == 0) {
133 | // Only last successor will push
134 | ctx.push_back(operand);
135 | }
136 | }
137 | },
138 | galois::wl>(),
139 | galois::no_stats(), galois::loopname("BackwardTraversal"));
140 |
141 | // TODO: Reinstate these checks
142 | // for (auto& predecessor : predecessors) assert(predecessor == 0);
143 | // for (auto& successor : successors) assert(successor == 0);
144 | }
145 |
146 | private:
147 | Program &program_;
// Default-constructed GaloisGuard held for the lifetime of the traversal
// object; what it sets up is defined elsewhere (eva/util/galois.h is
// included by this file).
148 | GaloisGuard galoisGuard_;
149 | };
150 |
151 | } // namespace eva
152 |
--------------------------------------------------------------------------------
/eva/common/program_traversal.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include "eva/util/logging.h"
9 | #include
10 | #include
11 |
12 | namespace eva {
13 |
14 | /*
15 | Implements efficient forward and backward traversals of Program in the
16 | presence of modifications during traversal.
17 | The rewriter is called for each term in the Program exactly once.
18 | Rewriters must not modify the Program in such a way that terms that are
19 | not uses/operands (for forward/backward traversal, respectively) of the
20 | current term are enabled. With such modifications the whole program is
21 | not guaranteed to be traversed.
22 | */
// Single-threaded worklist traversal that tolerates rewrites of the program
// while it runs; forwardPass and backwardPass are the entry points.
23 | class ProgramTraversal {
24 | Program &program;
25 |
// ready[t]: t has been pushed onto the worklist at some point.
// processed[t]: the rewriter has already been called on t.
26 | TermMap ready;
27 | TermMap processed;
28 |
// True when every predecessor of `term` has been processed. Predecessors are
// the operands for a forward traversal and the uses for a backward one.
29 | template bool arePredecessorsDone(const Term::Ptr &term) {
30 | for (auto &operand : isForward ? term->getOperands() : term->getUses()) {
31 | if (!processed[operand]) return false;
32 | }
33 | return true;
34 | }
35 |
// Core loop shared by both directions. Starts from the sources (forward) or
// sinks (backward) and calls `rewrite` on each term popped from the
// worklist.
36 | template
37 | void traverse(Rewriter &&rewrite) {
38 | processed.clear();
39 | ready.clear();
40 |
41 | std::vector readyNodes =
42 | isForward ? program.getSources() : program.getSinks();
43 | for (auto &term : readyNodes) {
44 | ready[term] = true;
45 | }
46 | // Used for remembering uses/operands before rewrite is called. Using a
47 | // vector here is fine because duplicates in the list are handled
48 | // gracefully.
49 | std::vector checkList;
50 |
51 | while (readyNodes.size() != 0) {
// The worklist is used as a stack: the most recently pushed term is
// processed next.
52 | // Pop term to transform
53 | auto term = readyNodes.back();
54 | readyNodes.pop_back();
55 |
56 | // If this term is removed, we will lose uses/operands of this term.
57 | // Remember them here for checking readyness after the rewrite.
58 | checkList.clear();
59 | for (auto &succ : isForward ? term->getUses() : term->getOperands()) {
60 | checkList.push_back(succ);
61 | }
62 |
63 | log(Verbosity::Trace, "Processing term with index=%lu", term->index);
64 | rewrite(term);
65 | processed[term] = true;
66 |
67 | // If transform adds new sources/sinks add them to ready terms.
68 | for (auto &leaf : isForward ? program.getSources() : program.getSinks()) {
69 | if (!ready[leaf]) {
70 | readyNodes.push_back(leaf);
71 | ready[leaf] = true;
72 | }
73 | }
74 |
75 | // Also check current uses/operands in case any new ones were added.
76 | for (auto &succ : isForward ? term->getUses() : term->getOperands()) {
77 | checkList.push_back(succ);
78 | }
79 |
80 | // Push and mark uses/operands that are ready to be processed.
81 | for (auto &succ : checkList) {
82 | if (!ready[succ] && arePredecessorsDone(succ)) {
83 | readyNodes.push_back(succ);
84 | ready[succ] = true;
85 | }
86 | }
87 | }
88 | }
89 |
90 | public:
// NOTE(review): `ready` is declared before `processed`, but the initializer
// list names `processed` first. Members are initialized in declaration
// order; both are built from `g` only, so the result is the same, but
// compilers may warn (-Wreorder).
91 | ProgramTraversal(Program &g) : program(g), processed(g), ready(g) {}
92 |
// Runs `rewrite` over the program from sources towards sinks.
93 | template void forwardPass(Rewriter &&rewrite) {
94 | traverse(std::forward(rewrite));
95 | }
96 |
// Runs `rewrite` over the program from sinks towards sources.
97 | template void backwardPass(Rewriter &&rewrite) {
98 | traverse(std::forward(rewrite));
99 | }
100 | };
101 |
102 | } // namespace eva
103 |
--------------------------------------------------------------------------------
/eva/common/reduction_balancer.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 |
10 | namespace eva {
11 |
12 | /*
13 | This pass combines nodes to reduce the depth of the tree.
14 | Suppose you have expression tree
15 |
16 | *
17 | / \
18 | * t5(c) => *
19 | / \ / \ \
20 | t1(c) t2(c) t1 t2 t5
21 |
22 |
23 | Before combining, the pass checks that a node has exactly one use and that
24 | the node and its use have the same op; if so, the two are merged into a
25 | single node that has the operands of both.
26 | This pass produces the flat form of an expression so that it can later be
27 | expanded into a balanced form.
28 | For example (a * (b * (c * d))) => (a * b * c * d) => (a * b) * (c * d)
29 | */
// Forward-pass rewriter that merges an Add or Mul term into its single use
// when that use has the same op, so the use receives the term's operands
// directly.
30 | class ReductionCombiner {
// Stored by the constructor; not read by any method of this class.
31 | Program &program;
32 |
33 | bool isReductionOp(const Op &op_code) {
34 | return ((op_code == Op::Add) || (op_code == Op::Mul));
35 | }
36 |
37 | public:
38 | ReductionCombiner(Program &g) : program(g) {}
39 |
// Does nothing unless `term` is internal (per Term::isInternal), is an Add
// or Mul, and has exactly one use with the same op.
40 | void
41 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
42 | if (!term->isInternal() || !isReductionOp(term->op)) return;
43 |
// `uses` is taken by value, so the edits below do not affect this list.
44 | auto uses = term->getUses();
45 | if (uses.size() == 1) {
46 | auto &use = uses[0];
47 | if (use->op == term->op) {
48 | // combine term and its use
// Each successful eraseOperand call removes `term` from the use's operands
// (presumably one occurrence per call - confirm in Term); the term's own
// operands are appended once per successful call.
49 | while (use->eraseOperand(term)) {
50 | for (auto &operand : term->getOperands()) {
51 | // add term's operands to use's operands
52 | use->addOperand(operand);
53 | }
54 | }
55 | }
56 | }
57 | }
58 | };
59 |
// Rewriter that turns an Add or Mul term with more than two operands into a
// tree of two-operand terms, after ordering the operands by kind and by an
// estimated scale.
60 | class ReductionLogExpander {
61 | Program &program;
62 | TermMap &type;
// Scale estimate computed by this pass itself (held by value).
63 | TermMapOptional scale;
// Scratch containers reused across calls; all are cleared at the end of an
// expansion.
64 | std::vector operands, nextOperands;
65 | std::map> sortedOperands;
66 |
67 | bool isReductionOp(const Op &op_code) {
68 | return ((op_code == Op::Add) || (op_code == Op::Mul));
69 | }
70 |
71 | public:
72 | ReductionLogExpander(Program &g, TermMap &type)
73 | : program(g), type(type), scale(g) {}
74 |
// Records the scale estimate for `term` and, if it is an Add/Mul with more
// than two operands, replaces its operand list by two operands that root a
// pairwise reduction tree. Throws std::logic_error on Rescale or ModSwitch.
75 | void operator()(Term::Ptr &term) {
76 | if (term->op == Op::Rescale || term->op == Op::ModSwitch) {
77 | throw std::logic_error("Rescale or ModSwitch encountered, but "
78 | "ReductionLogExpander uses scale as"
79 | " a proxy for level and assumes rescaling has not "
80 | "been performed yet.");
81 | }
82 |
83 | // Calculate the scales that we would get without any rescaling. Terms at a
84 | // similar scale will likely end up having the same level in typical
85 | // rescaling policies, which helps the sorting group terms of the same level
86 | // together.
// No operands: scale comes from the term's attribute. Mul: sum of operand
// scales. Everything else: maximum operand scale.
// NOTE(review): both lambdas take the running value as `auto &`. Since
// C++20 std::accumulate passes the accumulator as an rvalue, which cannot
// bind to a non-const lvalue reference - confirm the targeted standard.
87 | if (term->numOperands() == 0) {
88 | scale[term] = term->get();
89 | } else if (term->op == Op::Mul) {
90 | scale[term] = std::accumulate(
91 | term->getOperands().begin(), term->getOperands().end(), 0,
92 | [&](auto &sum, auto &operand) { return sum + scale.at(operand); });
93 | } else {
94 | scale[term] = std::accumulate(term->getOperands().begin(),
95 | term->getOperands().end(), 0,
96 | [&](auto &max, auto &operand) {
97 | return std::max(max, scale.at(operand));
98 | });
99 | }
100 |
101 | if (isReductionOp(term->op) && term->numOperands() > 2) {
102 | // We sort operands into constants, plaintext and raw, then ciphertexts by
103 | // scale. This helps avoid unnecessary accumulation of scale.
// Bucket key: 0 for operands that are neither Plain, Raw nor Cipher, 1 for
// Plain or Raw, 2 + scale for Cipher. std::map iterates its keys in
// ascending order, which yields the sorted sequence.
104 | for (auto &operand : term->getOperands()) {
105 | auto order = 0;
106 | if (type[operand] == Type::Plain || type[operand] == Type::Raw) {
107 | order = 1;
108 | } else if (type[operand] == Type::Cipher) {
109 | order = 2 + scale.at(operand);
110 | }
111 | sortedOperands[order].push_back(operand);
112 | }
113 | for (auto &op : sortedOperands) {
114 | operands.insert(operands.end(), op.second.begin(), op.second.end());
115 | }
116 |
117 | // Expand the sorted operands into a balanced reduction tree by pairing
118 | // adjacent operands until only one remains.
// Each round pairs neighbours into new terms of the same op and carries an
// unpaired last operand over unchanged; rounds stop once two operands are
// left, and those become the operands of `term`.
119 | assert(operands.size() >= 2);
120 | while (operands.size() > 2) {
121 | std::size_t i = 0;
122 | while ((i + 1) < operands.size()) {
123 | auto &leftOperand = operands[i];
124 | auto &rightOperand = operands[i + 1];
125 | auto newTerm =
126 | program.makeTerm(term->op, {leftOperand, rightOperand});
127 | nextOperands.push_back(newTerm);
128 | i += 2;
129 | }
130 | if (i < operands.size()) {
131 | assert((i + 1) == operands.size());
132 | nextOperands.push_back(operands[i]);
133 | }
134 | operands = nextOperands;
135 | nextOperands.clear();
136 | }
137 |
138 | assert(operands.size() == 2);
139 | term->setOperands(operands);
140 |
// Reset the scratch containers for the next term. The terms created above
// get no entries in `type` or `scale` here.
141 | operands.clear();
142 | nextOperands.clear();
143 | sortedOperands.clear();
144 | }
145 | }
146 | };
147 |
148 | } // namespace eva
149 |
--------------------------------------------------------------------------------
/eva/common/reference_executor.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #include "eva/common/reference_executor.h"
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | using namespace std;
11 |
12 | namespace eva {
13 |
// Writes into `output` the stored value of `args` rotated left by `shift`
// slots: the elements from index `shift` onwards come first, followed by the
// first `shift` elements. Negative and oversized shifts are first normalized
// into [0, size).
// NOTE(review): assumes the stored vector is non-empty; with size 0 and a
// positive shift the first loop below would not terminate.
14 | void ReferenceExecutor::leftRotate(vector &output,
15 | const Term::Ptr &args, int32_t shift) {
16 | auto &input = terms_.at(args);
17 |
18 | // Reserve enough space for output
19 | output.clear();
20 | output.reserve(input.size());
21 |
22 | while (shift > 0 && shift >= input.size())
23 | shift -= input.size();
24 | while (shift < 0)
25 | shift += input.size();
26 |
27 | // Shift left and copy to output
28 | copy_n(input.cbegin() + shift, input.size() - shift, back_inserter(output));
29 | copy_n(input.cbegin(), shift, back_inserter(output));
30 | }
31 |
// Writes into `output` the stored value of `args` rotated right by `shift`
// slots: the last `shift` elements come first, followed by the remaining
// leading elements. Negative and oversized shifts are first normalized into
// [0, size) by the two loops below.
32 | void ReferenceExecutor::rightRotate(vector &output,
33 | const Term::Ptr &args, int32_t shift) {
34 | auto &input = terms_.at(args);
35 |
36 | // Reserve enough space for output
37 | output.clear();
38 | output.reserve(input.size());
39 |
40 | while (shift > 0 && shift >= input.size())
41 | shift -= input.size();
42 | while (shift < 0)
43 | shift += input.size();
44 |
45 | // Shift right and copy to output
46 | copy_n(input.cend() - shift, shift, back_inserter(output));
47 | copy_n(input.cbegin(), input.size() - shift, back_inserter(output));
48 | }
49 |
// Writes into `output` the element-wise negation of the stored value of
// `args`. `output` is cleared first, so it ends up with the same length as
// the input.
50 | void ReferenceExecutor::negate(vector &output, const Term::Ptr &args) {
51 | auto &input = terms_.at(args);
52 |
53 | // Reserve enough space for output
54 | output.clear();
55 | output.reserve(input.size());
56 | transform(input.cbegin(), input.cend(), back_inserter(output),
57 | std::negate());
58 | }
59 |
// Computes the plain (unencrypted) value of `term` from the stored values of
// its operands and stores it in terms_[term].
60 | void ReferenceExecutor::operator()(const Term::Ptr &term) {
61 | // Must only be used with forward pass traversal
62 | auto &output = terms_[term];
63 |
64 | auto op = term->op;
// NOTE(review): `auto` without `&` copies the operand list on every call.
// Other visitors in this code base bind getOperands() with `auto &`, which
// suggests it returns a reference - consider binding by const reference.
65 | auto args = term->getOperands();
66 |
67 | switch (op) {
68 | case Op::Input:
69 | // Nothing to do for inputs
70 | break;
71 | case Op::Constant:
72 | // A constant (vector) is expanded to the number of slots (vecSize_ here)
73 | term->get()->expandTo(output, vecSize_);
74 | break;
75 | case Op::Add:
76 | assert(args.size() == 2);
77 | binOp>(output, args[0], args[1]);
78 | break;
79 | case Op::Sub:
80 | assert(args.size() == 2);
81 | binOp>(output, args[0], args[1]);
82 | break;
83 | case Op::Mul:
84 | assert(args.size() == 2);
85 | binOp>(output, args[0], args[1]);
86 | break;
87 | case Op::RotateLeftConst:
88 | assert(args.size() == 1);
89 | leftRotate(output, args[0], term->get());
90 | break;
91 | case Op::RotateRightConst:
92 | assert(args.size() == 1);
93 | rightRotate(output, args[0], term->get());
94 | break;
95 | case Op::Negate:
96 | assert(args.size() == 1);
97 | negate(output, args[0]);
98 | break;
// Encode, Output, Relinearize, ModSwitch and Rescale do not change the
// plain value: the single operand's value is copied through.
99 | case Op::Encode:
100 | [[fallthrough]];
101 | case Op::Output:
102 | [[fallthrough]];
103 | case Op::Relinearize:
104 | [[fallthrough]];
105 | case Op::ModSwitch:
106 | [[fallthrough]];
107 | case Op::Rescale:
108 | // Copy argument value for outputs
109 | assert(args.size() == 1);
110 | output = terms_[args[0]];
111 | break;
// NOTE(review): with NDEBUG defined this assert compiles away and an
// unhandled op silently leaves `output` as it was.
112 | default:
113 | assert(false);
114 | }
115 | }
116 |
117 | } // namespace eva
118 |
--------------------------------------------------------------------------------
/eva/common/reference_executor.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/common/valuation.h"
7 | #include "eva/ir/program.h"
8 | #include "eva/ir/term_map.h"
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | namespace eva {
15 |
16 | // Executes unencrypted computation
// Visitor that evaluates a Program on plain vectors. Values are kept per
// term in terms_; inputs are supplied with setInputs() and results are read
// back with getOutputs().
17 | class ReferenceExecutor {
18 | public:
// Captures the program and its vector size; terms_ is created for `g`.
19 | ReferenceExecutor(Program &g)
20 | : program_(g), vecSize_(g.getVecSize()), terms_(g) {}
21 |
// Copying is disabled (copy constructor and copy assignment are deleted).
// NOTE(review): the `©` on the next line looks like a mis-decoded `&copy`
// parameter - restore it before compiling.
22 | ReferenceExecutor(const ReferenceExecutor ©) = delete;
23 |
24 | ReferenceExecutor &operator=(const ReferenceExecutor &assign) = delete;
25 |
// Copies each entry of `inputs` into the term returned by
// program_.getInput(name). Throws std::runtime_error if a value's length
// differs from vecSize_; the value has already been stored at that point.
26 | template
27 | void setInputs(const std::unordered_map &inputs) {
28 | for (auto &in : inputs) {
29 | auto term = program_.getInput(in.first);
30 | terms_[term] = in.second; // TODO: can we avoid this copy?
31 | if (terms_[term].size() != vecSize_) {
32 | throw std::runtime_error(
33 | "The length of all inputs must be the same as program's vector "
34 | "size. Input " +
35 | in.first + " has length " + std::to_string(terms_[term].size()) +
36 | ", but vector size is " + std::to_string(vecSize_));
37 | }
38 | }
39 | }
40 |
// Evaluates one term; defined in reference_executor.cpp.
41 | void operator()(const Term::Ptr &term);
42 |
// Clears the stored value of `term`, except for Output terms, whose values
// are kept so that getOutputs() can read them.
43 | void free(const Term::Ptr &term) {
44 | if (term->op == Op::Output) return;
45 | terms_[term].clear();
46 | }
47 |
// Copies the stored value of every program output into `outputs`, keyed by
// the first element of each program_.getOutputs() entry.
48 | void getOutputs(Valuation &outputs) {
49 | for (auto &out : program_.getOutputs()) {
50 | outputs[out.first] = terms_[out.second];
51 | }
52 | }
53 |
54 | private:
55 | Program &program_;
56 | std::uint64_t vecSize_;
// Value computed (or supplied) for each term.
57 | TermMapOptional> terms_;
58 |
// Applies the binary functor `Op` element-wise to the stored values of
// args1 and args2 and writes the result to `out`. Equal lengths are checked
// by assert only.
59 | template
60 | void binOp(std::vector &out, const Term::Ptr &args1,
61 | const Term::Ptr &args2) {
62 | auto &in1 = terms_.at(args1);
63 | auto &in2 = terms_.at(args2);
64 | assert(in1.size() == in2.size());
65 |
66 | out.clear();
67 | out.reserve(in1.size());
68 | transform(in1.cbegin(), in1.cend(), in2.cbegin(), back_inserter(out), Op());
69 | }
70 |
// The three helpers below are defined in reference_executor.cpp.
71 | void leftRotate(std::vector &output, const Term::Ptr &args,
72 | std::int32_t shift);
73 |
74 | void rightRotate(std::vector &output, const Term::Ptr &args,
75 | std::int32_t shift);
76 |
77 | void negate(std::vector &output, const Term::Ptr &args);
78 | };
79 |
80 | } // namespace eva
81 |
--------------------------------------------------------------------------------
/eva/common/rotation_keys_selector.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 | #include
9 | #include
10 | #include
11 | #include
12 |
13 | namespace eva {
14 |
15 | class RotationKeysSelector {
16 | public:
17 | RotationKeysSelector(Program &g, const TermMap &type)
18 | : program_(g), type(type) {}
19 |
20 | void operator()(const Term::Ptr &term) {
21 | auto op = term->op;
22 |
23 | // Nothing to do if this is not a rotation
24 | if (!isLeftRotationOp(op) && !isRightRotationOp(op)) return;
25 |
26 | // No rotation keys needed for raw computation
27 | if (type[term] == Type::Raw) return;
28 |
29 | // Add the rotation count
30 | auto rotation = term->get();
31 | keys_.insert(isRightRotationOp(op) ? -rotation : rotation);
32 | }
33 |
34 | void free(const Term::Ptr &term) {
35 | // No-op
36 | }
37 |
38 | auto getRotationKeys() {
39 | // Return the set of rotations needed
40 | return keys_;
41 | }
42 |
43 | private:
44 | Program &program_;
45 | const TermMap &type;
46 | std::set keys_;
47 |
48 | bool isLeftRotationOp(const Op &op_code) {
49 | return (op_code == Op::RotateLeftConst);
50 | }
51 |
52 | bool isRightRotationOp(const Op &op_code) {
53 | return (op_code == Op::RotateRightConst);
54 | }
55 | };
56 |
57 | } // namespace eva
58 |
--------------------------------------------------------------------------------
/eva/common/type_deducer.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/program.h"
7 | #include "eva/ir/term_map.h"
8 |
9 | namespace eva {
10 |
11 | class TypeDeducer {
12 | Program &program;
13 | TermMap &types;
14 |
15 | public:
16 | TypeDeducer(Program &g, TermMap &types) : program(g), types(types) {}
17 |
18 | void
19 | operator()(Term::Ptr &term) { // must only be used with forward pass traversal
20 | auto &operands = term->getOperands();
21 | if (operands.size() > 0) { // not an input/root
22 | Type inferred = Type::Raw; // Plain if not Cipher
23 | for (auto &operand : operands) {
24 | if (types[operand] == Type::Cipher)
25 | inferred = Type::Cipher; // Cipher if any operand is Cipher
26 | }
27 | if (term->op == Op::Encode) {
28 | types[term] = Type::Plain;
29 | } else {
30 | types[term] = inferred;
31 | }
32 | } else if (term->op == Op::Constant) {
33 | types[term] = Type::Raw;
34 | } else {
35 | types[term] = term->get();
36 | }
37 | }
38 | };
39 |
40 | } // namespace eva
41 |
--------------------------------------------------------------------------------
/eva/common/valuation.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 |
4 | #pragma once
5 |
6 | #include
7 | #include
8 | #include
9 |
10 | namespace eva {
11 |
// Maps the name of a program input or output to its unencrypted vector value.
using Valuation = std::unordered_map<std::string, std::vector<double>>;
13 |
14 | }
15 |
--------------------------------------------------------------------------------
/eva/eva.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #include "eva/eva.h"
5 | #include "eva/common/program_traversal.h"
6 | #include "eva/common/reference_executor.h"
7 | #include "eva/common/valuation.h"
8 |
9 | namespace eva {
10 |
11 | Valuation evaluate(Program &program, const Valuation &inputs) {
12 | Valuation outputs;
13 | ProgramTraversal programTraverse(program);
14 | ReferenceExecutor ge(program);
15 |
16 | ge.setInputs(inputs);
17 | programTraverse.forwardPass(ge);
18 | ge.getOutputs(outputs);
19 |
20 | return outputs;
21 | }
22 |
23 | } // namespace eva
24 |
--------------------------------------------------------------------------------
/eva/eva.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ckks/ckks_compiler.h"
7 | #include "eva/ir/program.h"
8 | #include "eva/seal/seal.h"
9 | #include "eva/serialization/save_load.h"
10 | #include "eva/version.h"
11 |
12 | namespace eva {
13 |
// Evaluates program on the given unencrypted inputs with the reference
// executor and returns the program's outputs (defined in eva/eva.cpp).
Valuation evaluate(Program &program, const Valuation &inputs);
15 |
16 | }
17 |
--------------------------------------------------------------------------------
/eva/ir/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT license.
3 |
# Add the IR implementation files to the eva target as private sources
target_sources(eva PRIVATE
    term.cpp
    program.cpp
    attribute_list.cpp
    attributes.cpp
)
10 |
--------------------------------------------------------------------------------
/eva/ir/attribute_list.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #include "eva/ir/attribute_list.h"
5 | #include "eva/ir/attributes.h"
6 | #include "eva/util/overloaded.h"
7 | #include
8 |
9 | using namespace std;
10 |
11 | namespace eva {
12 |
13 | bool AttributeList::has(AttributeKey k) const {
14 | const AttributeList *curr = this;
15 | while (true) {
16 | if (curr->key < k) {
17 | if (curr->tail) {
18 | curr = curr->tail.get();
19 | } else {
20 | return false;
21 | }
22 | } else {
23 | return curr->key == k;
24 | }
25 | }
26 | }
27 |
28 | const AttributeValue &AttributeList::get(AttributeKey k) const {
29 | const AttributeList *curr = this;
30 | while (true) {
31 | if (curr->key == k) {
32 | return curr->value;
33 | } else if (curr->key < k && curr->tail) {
34 | curr = curr->tail.get();
35 | } else {
36 | throw out_of_range("Attribute not in list: " + getAttributeName(k));
37 | }
38 | }
39 | }
40 |
41 | void AttributeList::set(AttributeKey k, AttributeValue v) {
42 | if (this->key == 0) {
43 | this->key = k;
44 | this->value = move(v);
45 | } else {
46 | AttributeList *curr = this;
47 | AttributeList *prev = nullptr;
48 | while (true) {
49 | if (curr->key < k) {
50 | if (curr->tail) {
51 | prev = curr;
52 | curr = curr->tail.get();
53 | } else { // Insert at end
54 | // AttributeList constructor is private
55 | curr->tail = unique_ptr{new AttributeList(k, move(v))};
56 | return;
57 | }
58 | } else if (curr->key > k) {
59 | if (prev) { // Insert between
60 | // AttributeList constructor is private
61 | auto newList =
62 | unique_ptr{new AttributeList(k, move(v))};
63 | newList->tail = move(prev->tail);
64 | prev->tail = move(newList);
65 | } else { // Insert at beginning
66 | // AttributeList constructor is private
67 | curr->tail =
68 | unique_ptr{new AttributeList(move(*curr))};
69 | curr->key = k;
70 | curr->value = move(v);
71 | }
72 | return;
73 | } else {
74 | assert(curr->key == k);
75 | curr->value = move(v);
76 | return;
77 | }
78 | }
79 | }
80 | }
81 |
82 | void AttributeList::assignAttributesFrom(const AttributeList &other) {
83 | if (this->key != 0) {
84 | this->key = 0;
85 | this->value = std::monostate();
86 | this->tail = nullptr;
87 | }
88 | if (other.key == 0) {
89 | return;
90 | }
91 | AttributeList *lhs = this;
92 | const AttributeList *rhs = &other;
93 | while (true) {
94 | lhs->key = rhs->key;
95 | lhs->value = rhs->value;
96 | if (rhs->tail) {
97 | rhs = rhs->tail.get();
98 | lhs->tail = std::make_unique();
99 | lhs = lhs->tail.get();
100 | } else {
101 | return;
102 | }
103 | }
104 | }
105 |
106 | } // namespace eva
107 |
--------------------------------------------------------------------------------
/eva/ir/attribute_list.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/constant_value.h"
7 | #include "eva/ir/types.h"
8 | #include "eva/serialization/eva.pb.h"
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 |
17 | namespace eva {
18 |
19 | using AttributeValue = std::variant>;
21 |
// Compile-time check whether T is one of the alternatives of a std::variant:
// IsInVariant<T, std::variant<Ts...>>::value is true iff T is the same type
// as some Ts.
template <class T, class Variant> struct IsInVariant;
template <class T, class... Ts>
struct IsInVariant<T, std::variant<Ts...>>
    : std::bool_constant<(... || std::is_same<T, Ts>{})> {};
26 |
27 | using AttributeKey = std::uint8_t;
28 |
29 | template struct Attribute {
30 | static_assert(IsInVariant::value,
31 | "Attribute type not in AttributeValue");
32 | static_assert(Key > 0, "Keys must be strictly positive");
33 | static_assert(Key <= std::numeric_limits::max(),
34 | "Key larger than current AttributeKey type");
35 |
36 | using Value = T;
37 | static constexpr AttributeKey key = Key;
38 |
39 | static bool isValid(AttributeKey k, const AttributeValue &v) {
40 | return k == Key && std::holds_alternative(v);
41 | }
42 | };
43 |
44 | class AttributeList {
45 | public:
46 | AttributeList() : key(0), tail(nullptr) {}
47 |
48 | // This function is defined in eva/serialization/eva_serialization.cpp
49 | void loadAttribute(const msg::Attribute &msg);
50 |
51 | // This function is defined in eva/serialization/eva_serialization.cpp
52 | void serializeAttributes(std::function addMsg) const;
53 |
54 | template bool has() const { return has(TAttr::key); }
55 |
56 | template const typename TAttr::Value &get() const {
57 | return std::get(get(TAttr::key));
58 | }
59 |
60 | template void set(typename TAttr::Value value) {
61 | set(TAttr::key, std::move(value));
62 | }
63 |
64 | void assignAttributesFrom(const AttributeList &other);
65 |
66 | private:
67 | AttributeKey key;
68 | AttributeValue value;
69 | std::unique_ptr tail;
70 |
71 | AttributeList(AttributeKey k, AttributeValue v)
72 | : key(k), value(std::move(v)) {}
73 |
74 | bool has(AttributeKey k) const;
75 | const AttributeValue &get(AttributeKey k) const;
76 | void set(AttributeKey k, AttributeValue v);
77 | };
78 |
79 | } // namespace eva
80 |
--------------------------------------------------------------------------------
/eva/ir/attributes.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #include "eva/ir/attributes.h"
5 | #include
6 |
7 | using namespace std;
8 |
9 | namespace eva {
10 |
11 | #define X(name, type) name::isValid(k, v) ||
12 | bool isValidAttribute(AttributeKey k, const AttributeValue &v) {
13 | return EVA_ATTRIBUTES false;
14 | }
15 | #undef X
16 |
17 | #define X(name, type) \
18 | case detail::name##Index: \
19 | return #name;
20 | string getAttributeName(AttributeKey k) {
21 | switch (k) {
22 | EVA_ATTRIBUTES
23 | default:
24 | throw runtime_error("Unknown attribute key");
25 | }
26 | }
27 | #undef X
28 |
29 | } // namespace eva
30 |
--------------------------------------------------------------------------------
/eva/ir/attributes.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/ir/attribute_list.h"
7 | #include
8 | #include
9 |
10 | namespace eva {
11 |
// X-macro listing every attribute as X(name, value type). The position in
// this list determines the attribute's key (see detail::AttributeIndex), so
// new attributes should only be appended.
#define EVA_ATTRIBUTES                                                         \
  X(RescaleDivisorAttribute, std::uint32_t)                                    \
  X(RotationAttribute, std::int32_t)                                           \
  X(ConstantValueAttribute, std::shared_ptr<ConstantValue>)                    \
  X(TypeAttribute, Type)                                                       \
  X(RangeAttribute, std::uint32_t)                                             \
  X(EncodeAtScaleAttribute, std::uint32_t)                                     \
  X(EncodeAtLevelAttribute, std::uint32_t)
20 |
namespace detail {
// Numeric index of each attribute. Index 0 is reserved because key 0 marks
// an empty AttributeList; the X macro then gives the attributes consecutive
// indices 1, 2, ... in EVA_ATTRIBUTES order.
enum AttributeIndex {
  RESERVE_EMPTY_ATTRIBUTE_KEY = 0,
#define X(name, type) name##Index,
  EVA_ATTRIBUTES
#undef X
};
} // namespace detail
29 |
30 | #define X(name, type) using name = Attribute;
31 | EVA_ATTRIBUTES
32 | #undef X
33 |
// Returns true if some attribute in EVA_ATTRIBUTES has key k and v holds
// that attribute's value type.
bool isValidAttribute(AttributeKey k, const AttributeValue &v);

// Returns the name of the attribute with key k; throws std::runtime_error
// for an unknown key.
std::string getAttributeName(AttributeKey k);
37 |
38 | } // namespace eva
39 |
--------------------------------------------------------------------------------
/eva/ir/constant_value.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation. All rights reserved.
2 | // Licensed under the MIT license.
3 |
4 | #pragma once
5 |
6 | #include "eva/serialization/eva.pb.h"
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | namespace eva {
15 |
16 | class ConstantValue {
17 | public:
18 | ConstantValue(std::size_t size) : size(size) {}
19 | virtual ~ConstantValue() {}
20 | virtual const std::vector &expand(std::vector &scratch,
21 | std::size_t slots) const = 0;
22 | virtual void expandTo(std::vector &result,
23 | std::size_t slots) const = 0;
24 | virtual bool isZero() const = 0;
25 | virtual void serialize(msg::ConstantValue &msg) const = 0;
26 |
27 | protected:
28 | std::size_t size;
29 |
30 | void validateSlots(std::size_t slots) const {
31 | if (slots < size) {
32 | throw std::runtime_error("Slots must be at least size of constant");
33 | }
34 | if (slots % size != 0) {
35 | throw std::runtime_error("Size must exactly divide slots");
36 | }
37 | }
38 | };
39 |
40 | class DenseConstantValue : public ConstantValue {
41 | public:
42 | DenseConstantValue(std::size_t size, std::vector values)
43 | : ConstantValue(size), values(values) {
44 | if (size % values.size() != 0) {
45 | throw std::runtime_error(
46 | "DenseConstantValue size must exactly divide size");
47 | }
48 | }
49 |
50 | const std::vector &expand(std::vector &scratch,
51 | std::size_t slots) const override {
52 | validateSlots(slots);
53 | if (values.size() == slots) {
54 | return values;
55 | } else {
56 | scratch.clear();
57 | for (int r = slots / values.size(); r > 0; --r) {
58 | scratch.insert(scratch.end(), values.begin(), values.end());
59 | }
60 | return scratch;
61 | }
62 | }
63 |
64 | void expandTo(std::vector &result, std::size_t slots) const override {
65 | validateSlots(slots);
66 | result.clear();
67 | result.reserve(slots);
68 | for (int r = slots / values.size(); r > 0; --r) {
69 | result.insert(result.end(), values.begin(), values.end());
70 | }
71 | }
72 |
73 | bool isZero() const override {
74 | for (double value : values) {
75 | if (value != 0) return false;
76 | }
77 | return true;
78 | }
79 |
80 | void serialize(msg::ConstantValue &msg) const override {
81 | msg.set_size(size);
82 | auto valuesMsg = msg.mutable_values();
83 | valuesMsg->Reserve(values.size());
84 | for (const auto &value : values) {
85 | valuesMsg->Add(value);
86 | }
87 | }
88 |
89 | private:
90 | std::vector values;
91 | };
92 |
93 | class SparseConstantValue : public ConstantValue {
94 | public:
95 | SparseConstantValue(std::size_t size,
96 | std::vector> values)
97 | : ConstantValue(size), values(values) {}
98 |
99 | const std::vector &expand(std::vector &scratch,
100 | std::size_t slots) const override {
101 | validateSlots(slots);
102 | scratch.clear();
103 | scratch.resize(slots);
104 | for (auto &entry : values) {
105 | for (int i = 0; i < slots; i += values.size()) {
106 | scratch.at(entry.first + i) = entry.second;
107 | }
108 | }
109 | return scratch;
110 | }
111 |
112 | void expandTo(std::vector &result, std::size_t slots) const override {
113 | validateSlots(slots);
114 | result.clear();
115 | result.resize(slots);
116 | for (auto &entry : values) {
117 | for (int i = 0; i < slots; i += values.size()) {
118 | result.at(entry.first + i) = entry.second;
119 | }
120 | }
121 | }
122 |
123 | bool isZero() const override {
124 | // TODO: this assumes no repeated indices
125 | for (auto entry : values) {
126 | if (entry.second != 0) return false;
127 | }
128 | return true;
129 | }
130 |
131 | void serialize(msg::ConstantValue &msg) const override {
132 | msg.set_size(size);
133 | for (const auto &pair : values) {
134 | msg.add_sparse_indices(pair.first);
135 | msg.add_values(pair.second);
136 | }
137 | }
138 |
139 | private:
140 | std::vector> values;
141 | };
142 |
143 | std::unique_ptr