├── .bazelrc ├── .clang-format ├── .github └── workflows │ ├── scripts │ ├── lint_bazel.sh │ ├── lint_cpp.sh │ └── run_tests_core.sh │ └── tests.yml ├── .gitignore ├── CONTRIBUTORS.md ├── LICENSE ├── README.md ├── WORKSPACE ├── pir ├── BUILD ├── cpp │ ├── BUILD │ ├── README.md │ ├── benchmark.cpp │ ├── client.cpp │ ├── client.h │ ├── client_test.cpp │ ├── context.cpp │ ├── context.h │ ├── correctness_test.cpp │ ├── ct_reencoder.cpp │ ├── ct_reencoder.h │ ├── ct_reencoder_test.cpp │ ├── database.cpp │ ├── database.h │ ├── database_test.cpp │ ├── parameters.cpp │ ├── parameters.h │ ├── parameters_test.cpp │ ├── serialization.cpp │ ├── serialization.h │ ├── serialization_test.cpp │ ├── server.cpp │ ├── server.h │ ├── server_test.cpp │ ├── status_asserts.h │ ├── string_encoder.cpp │ ├── string_encoder.h │ ├── string_encoder_test.cpp │ ├── test_base.cpp │ ├── test_base.h │ ├── utils.cpp │ ├── utils.h │ └── utils_test.cpp ├── deps.bzl ├── preload.bzl └── proto │ ├── BUILD │ └── payload.proto └── third_party ├── BUILD └── seal.BUILD /.bazelrc: -------------------------------------------------------------------------------- 1 | build --cxxopt='-std=c++17' 2 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | AccessModifierOffset: -1 4 | AlignAfterOpenBracket: Align 5 | AlignConsecutiveMacros: false 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Left 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllArgumentsOnNextLine: true 12 | AllowAllConstructorInitializersOnNextLine: true 13 | AllowAllParametersOfDeclarationOnNextLine: true 14 | AllowShortBlocksOnASingleLine: false 15 | AllowShortCaseLabelsOnASingleLine: false 16 | AllowShortFunctionsOnASingleLine: All 17 | AllowShortLambdasOnASingleLine: All 18 | AllowShortIfStatementsOnASingleLine: 
WithoutElse 19 | AllowShortLoopsOnASingleLine: true 20 | AlwaysBreakAfterDefinitionReturnType: None 21 | AlwaysBreakAfterReturnType: None 22 | AlwaysBreakBeforeMultilineStrings: true 23 | AlwaysBreakTemplateDeclarations: Yes 24 | BinPackArguments: true 25 | BinPackParameters: true 26 | BraceWrapping: 27 | AfterCaseLabel: false 28 | AfterClass: false 29 | AfterControlStatement: false 30 | AfterEnum: false 31 | AfterFunction: false 32 | AfterNamespace: false 33 | AfterObjCDeclaration: false 34 | AfterStruct: false 35 | AfterUnion: false 36 | AfterExternBlock: false 37 | BeforeCatch: false 38 | BeforeElse: false 39 | IndentBraces: false 40 | SplitEmptyFunction: true 41 | SplitEmptyRecord: true 42 | SplitEmptyNamespace: true 43 | BreakBeforeBinaryOperators: None 44 | BreakBeforeBraces: Attach 45 | BreakBeforeInheritanceComma: false 46 | BreakInheritanceList: BeforeColon 47 | BreakBeforeTernaryOperators: true 48 | BreakConstructorInitializersBeforeComma: false 49 | BreakConstructorInitializers: BeforeColon 50 | BreakAfterJavaFieldAnnotations: false 51 | BreakStringLiterals: true 52 | ColumnLimit: 80 53 | CommentPragmas: '^ IWYU pragma:' 54 | CompactNamespaces: false 55 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 56 | ConstructorInitializerIndentWidth: 4 57 | ContinuationIndentWidth: 4 58 | Cpp11BracedListStyle: true 59 | DerivePointerAlignment: true 60 | DisableFormat: false 61 | ExperimentalAutoDetectBinPacking: false 62 | FixNamespaceComments: true 63 | ForEachMacros: 64 | - foreach 65 | - Q_FOREACH 66 | - BOOST_FOREACH 67 | IncludeBlocks: Regroup 68 | IncludeCategories: 69 | - Regex: '^' 70 | Priority: 2 71 | - Regex: '^<.*\.h>' 72 | Priority: 1 73 | - Regex: '^<.*' 74 | Priority: 2 75 | - Regex: '.*' 76 | Priority: 3 77 | IncludeIsMainRegex: '([-_](test|unittest))?$' 78 | IndentCaseLabels: true 79 | IndentPPDirectives: None 80 | IndentWidth: 4 81 | IndentWrappedFunctionNames: false 82 | JavaScriptQuotes: Leave 83 | JavaScriptWrapImports: true 84 | 
KeepEmptyLinesAtTheStartOfBlocks: false 85 | MacroBlockBegin: '' 86 | MacroBlockEnd: '' 87 | MaxEmptyLinesToKeep: 1 88 | NamespaceIndentation: None 89 | ObjCBinPackProtocolList: Never 90 | ObjCBlockIndentWidth: 2 91 | ObjCSpaceAfterProperty: false 92 | ObjCSpaceBeforeProtocolList: true 93 | PenaltyBreakAssignment: 2 94 | PenaltyBreakBeforeFirstCallParameter: 1 95 | PenaltyBreakComment: 300 96 | PenaltyBreakFirstLessLess: 120 97 | PenaltyBreakString: 1000 98 | PenaltyBreakTemplateDeclaration: 10 99 | PenaltyExcessCharacter: 1000000 100 | PenaltyReturnTypeOnItsOwnLine: 200 101 | PointerAlignment: Left 102 | RawStringFormats: 103 | - Language: Cpp 104 | Delimiters: 105 | - cc 106 | - CC 107 | - cpp 108 | - Cpp 109 | - CPP 110 | - 'c++' 111 | - 'C++' 112 | CanonicalDelimiter: '' 113 | BasedOnStyle: google 114 | - Language: TextProto 115 | Delimiters: 116 | - pb 117 | - PB 118 | - proto 119 | - PROTO 120 | EnclosingFunctions: 121 | - EqualsProto 122 | - EquivToProto 123 | - PARSE_PARTIAL_TEXT_PROTO 124 | - PARSE_TEST_PROTO 125 | - PARSE_TEXT_PROTO 126 | - ParseTextOrDie 127 | - ParseTextProtoOrDie 128 | CanonicalDelimiter: '' 129 | BasedOnStyle: google 130 | ReflowComments: true 131 | SortIncludes: true 132 | SortUsingDeclarations: true 133 | SpaceAfterCStyleCast: false 134 | SpaceAfterLogicalNot: false 135 | SpaceAfterTemplateKeyword: true 136 | SpaceBeforeAssignmentOperators: true 137 | SpaceBeforeCpp11BracedList: false 138 | SpaceBeforeCtorInitializerColon: true 139 | SpaceBeforeInheritanceColon: true 140 | SpaceBeforeParens: ControlStatements 141 | SpaceBeforeRangeBasedForLoopColon: true 142 | SpaceInEmptyParentheses: false 143 | SpacesBeforeTrailingComments: 2 144 | SpacesInAngles: false 145 | SpacesInContainerLiterals: true 146 | SpacesInCStyleCastParentheses: false 147 | SpacesInParentheses: false 148 | SpacesInSquareBrackets: false 149 | Standard: Auto 150 | StatementMacros: 151 | - Q_UNUSED 152 | - QT_REQUIRE_VERSION 153 | TabWidth: 8 154 | UseTab: Never 155 | 
... 156 | 157 | -------------------------------------------------------------------------------- /.github/workflows/scripts/lint_bazel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | # Format all BUILD files under pir/ in place with buildifier. 5 | find ./pir/ \( -iname BUILD \) | xargs buildifier 6 | if [ $? -ne 0 ] 7 | then 8 | exit 1 9 | fi 10 | 11 | # Print changes. 12 | git diff 13 | # Already well formatted if 'git diff' doesn't output anything. 14 | ! ( git diff | grep -q ^ ) || exit 1 15 | -------------------------------------------------------------------------------- /.github/workflows/scripts/lint_cpp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | # Format all .cpp and .h files under pir/ in place; the -iname patterns are quoted so the shell cannot glob-expand them before find sees them. 5 | find ./pir/ \( -iname '*.h' -o -iname '*.cpp' \) | xargs clang-format -i -style='google' 6 | if [ $? -ne 0 ] 7 | then 8 | exit 1 9 | fi 10 | 11 | # Print changes. 12 | git diff 13 | # Already well formatted if 'git diff' doesn't output anything. 14 | ! ( git diff | grep -q ^ ) || exit 1 15 | -------------------------------------------------------------------------------- /.github/workflows/scripts/run_tests_core.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | # C++ 5 | bazel test --test_output=all //pir/cpp/...
6 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | 8 | jobs: 9 | Core: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [macos-latest, ubuntu-20.04] 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Run tests 17 | timeout-minutes: 30 18 | run: .github/workflows/scripts/run_tests_core.sh 19 | - name: CPP Lint 20 | run: .github/workflows/scripts/lint_cpp.sh 21 | if: ${{ matrix.os == 'ubuntu-20.04' }} 22 | 23 | Bazel: 24 | runs-on: ubuntu-20.04 25 | steps: 26 | - name: Install golang 27 | uses: actions/setup-go@v2 28 | with: 29 | go-version: 1.13.3 30 | id: go 31 | - name: Install buildifier 32 | run: go get github.com/bazelbuild/buildtools/buildifier 33 | - uses: actions/checkout@v2 34 | - name: Bazel Lint 35 | run: .github/workflows/scripts/lint_bazel.sh 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Bazel 2 | bazel-* 3 | 4 | # CLion 5 | .clwb 6 | 7 | # VScode 8 | .vscode 9 | 10 | # npm 11 | node_modules/ 12 | 13 | # JS code coverage 14 | coverage/ 15 | 16 | # TS output 17 | tsc-out/ 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # Prerequisites 25 | *.d 26 | 27 | # Compiled Object files 28 | *.slo 29 | *.lo 30 | *.o 31 | *.obj 32 | 33 | # Precompiled Headers 34 | *.gch 35 | *.pch 36 | 37 | # Compiled Dynamic libraries 38 | *.so 39 | *.dylib 40 | *.dll 41 | 42 | # Fortran module files 43 | *.mod 44 | *.smod 45 | 46 | # Compiled Static libraries 47 | *.lai 48 | *.la 49 | *.a 50 | *.lib 51 | 52 | # Executables 53 | *.exe 54 | *.out 55 | *.app 56 | 
-------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | [Bogdan Cebere](https://github.com/bcebere) 2 | 3 | [Kareem Shehata](https://github.com/kshehata) 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![om-logo](https://github.com/OpenMined/design-assets/blob/master/logos/OM/horizontal-primary-trans.png) 2 | 3 | [![Tests](https://github.com/OpenMined/PIR/workflows/Tests/badge.svg?branch=master&event=push)](https://github.com/OpenMined/PIR/actions?query=workflow%3ATests+branch%3Amaster+event%3Apush) 4 | ![License](https://img.shields.io/github/license/OpenMined/PIR) 5 | ![OpenCollective](https://img.shields.io/opencollective/all/openmined) 6 | 7 | 8 | # Private Information Retrieval 9 | 10 | ## Requirements 11 | 12 | There are requirements for the entire project which each language shares. There also could be requirements for each target language: 13 | 14 | ### Global Requirements 15 | 16 | These are the common requirements across all target languages of this project. 
17 | 18 | - A compiler such as clang, gcc, or msvc 19 | - [Bazel](https://bazel.build) 20 | 21 | ## Compiling and Running 22 | 23 | The repository uses a folder structure to isolate the supported targets from one another: 24 | 25 | ``` 26 | pir/<language>/ 27 | ``` 28 | 29 | ### C++ 30 | 31 | See the [C++ README.md](pir/cpp/README.md) 32 | 33 | ## Usage 34 | 35 | To use this library in another Bazel project, add the following to your WORKSPACE file: 36 | 37 | ``` 38 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") 39 | 40 | git_repository( 41 | name = "org_openmined_pir", 42 | remote = "https://github.com/OpenMined/PIR", 43 | branch = "master", 44 | init_submodules = True, 45 | ) 46 | 47 | load("@org_openmined_pir//pir:preload.bzl", "pir_preload") 48 | 49 | pir_preload() 50 | 51 | load("@org_openmined_pir//pir:deps.bzl", "pir_deps") 52 | 53 | pir_deps() 54 | 55 | ``` 56 | 57 | ## Contributing 58 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 59 | 60 | Please make sure to update tests as appropriate. 61 | 62 | ## Contributors 63 | 64 | See [CONTRIBUTORS.md](CONTRIBUTORS.md).
65 | 66 | ## License 67 | [Apache License 2.0](https://choosealicense.com/licenses/apache-2.0/) 68 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "org_openmined_pir") 2 | 3 | load("//pir:preload.bzl", "pir_preload") 4 | 5 | pir_preload() 6 | 7 | load("//pir:deps.bzl", "pir_deps") 8 | 9 | pir_deps() 10 | -------------------------------------------------------------------------------- /pir/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMined/PIR/e039662808a244a9042b4bee17213a7eba5388eb/pir/BUILD -------------------------------------------------------------------------------- /pir/cpp/BUILD: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | package(default_visibility = ["//visibility:public"]) 18 | 19 | PIR_DEFAULT_INCLUDES = ["."] 20 | 21 | PIR_DEFAULT_COPTS = ["-std=c++17"] 22 | 23 | cc_library( 24 | name = "pir", 25 | srcs = [ 26 | "client.cpp", 27 | "context.cpp", 28 | "context.h", 29 | "ct_reencoder.cpp", 30 | "ct_reencoder.h", 31 | "database.cpp", 32 | "database.h", 33 | "parameters.cpp", 34 | "parameters.h", 35 | "serialization.cpp", 36 | "serialization.h", 37 | "server.cpp", 38 | "status_asserts.h", 39 | "string_encoder.cpp", 40 | "string_encoder.h", 41 | "utils.cpp", 42 | "utils.h", 43 | ], 44 | hdrs = [ 45 | "client.h", 46 | "server.h", 47 | ], 48 | copts = PIR_DEFAULT_COPTS, 49 | includes = PIR_DEFAULT_INCLUDES, 50 | deps = [ 51 | "//pir/proto:payload_cc_proto", 52 | "@com_google_absl//absl/memory", 53 | "@com_google_absl//absl/status:statusor", 54 | "@com_google_absl//absl/strings", 55 | "@com_google_absl//absl/types:optional", 56 | "@com_google_absl//absl/types:span", 57 | "@com_microsoft_seal//:seal", 58 | ], 59 | ) 60 | 61 | cc_test( 62 | name = "pir_test", 63 | srcs = [ 64 | "client_test.cpp", 65 | "correctness_test.cpp", 66 | "ct_reencoder_test.cpp", 67 | "database_test.cpp", 68 | "parameters_test.cpp", 69 | "serialization_test.cpp", 70 | "server_test.cpp", 71 | "status_asserts.h", 72 | "string_encoder_test.cpp", 73 | "test_base.cpp", 74 | "test_base.h", 75 | "utils_test.cpp", 76 | ], 77 | copts = PIR_DEFAULT_COPTS, 78 | includes = PIR_DEFAULT_INCLUDES, 79 | linkstatic = True, 80 | deps = [ 81 | ":pir", 82 | "@com_google_absl//absl/status:statusor", 83 | "@com_google_googletest//:gtest", 84 | "@com_google_googletest//:gtest_main", 85 | ], 86 | ) 87 | 88 | cc_binary( 89 | name = "benchmark", 90 | srcs = [ 91 | "benchmark.cpp", 92 | "status_asserts.h", 93 | "test_base.cpp", 94 | "test_base.h", 95 | ], 96 | copts = PIR_DEFAULT_COPTS, 97 | includes = PIR_DEFAULT_INCLUDES, 98 | linkstatic = True, 99 | deps = [ 100 | ":pir", 101 | "@com_google_absl//absl/status:statusor", 
102 | "@com_google_benchmark//:benchmark_main", 103 | "@com_google_googletest//:gtest", 104 | ], 105 | ) 106 | -------------------------------------------------------------------------------- /pir/cpp/README.md: -------------------------------------------------------------------------------- 1 | # PIR - C++ 2 | 3 | ## Build and test 4 | 5 | 6 | Build all libraries with or without optimizations, or build a specific module 7 | 8 | ``` 9 | # Build everything using the fastbuild optimization configuration 10 | bazel build //pir/cpp/... 11 | 12 | # With a specific optimization flag '-c opt' 13 | bazel build -c opt //pir/cpp/... 14 | 15 | ``` 16 | 17 | Build and run tests 18 | 19 | ``` 20 | bazel test //pir/cpp/... 21 | ``` 22 | 23 | Build and run benchmarks 24 | 25 | ``` 26 | bazel run -c opt //pir/cpp:benchmark 27 | ``` 28 | -------------------------------------------------------------------------------- /pir/cpp/benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "benchmark/benchmark.h" 2 | 3 | #include 4 | 5 | #include "gmock/gmock.h" 6 | #include "gtest/gtest.h" 7 | #include "pir/cpp/client.h" 8 | #include "pir/cpp/server.h" 9 | #include "pir/cpp/status_asserts.h" 10 | #include "pir/cpp/test_base.h" 11 | #include "seal/seal.h" 12 | 13 | namespace pir { 14 | 15 | using namespace ::testing; 16 | 17 | constexpr bool USE_CIPHERTEXT_MULTIPLICATION = false; 18 | constexpr uint32_t ITEM_SIZE = 288; 19 | constexpr uint32_t DIMENSIONS = 2; 20 | constexpr uint32_t POLY_MOD_DEGREE = 4096; 21 | constexpr uint32_t PLAIN_MOD_BITS = 24; 22 | constexpr uint32_t BITS_PER_COEFF = 0; 23 | constexpr uint32_t QUERIES_PER_REQUEST = 1; 24 | 25 | using std::cout; 26 | using std::endl; 27 | 28 | class PIRFixture : public benchmark::Fixture, public PIRTestingBase { 29 | public: 30 | void SetUpDb(const ::benchmark::State& state) { 31 | SetUpParams(state.range(0), ITEM_SIZE, DIMENSIONS, POLY_MOD_DEGREE, 32 | PLAIN_MOD_BITS, 
BITS_PER_COEFF, USE_CIPHERTEXT_MULTIPLICATION); 33 | GenerateDB(); 34 | SetUpSealTools(); 35 | 36 | client_ = *(PIRClient::Create(pir_params_)); 37 | server_ = *(PIRServer::Create(pir_db_, pir_params_)); 38 | ASSERT_THAT(client_, NotNull()); 39 | ASSERT_THAT(server_, NotNull()); 40 | } 41 | 42 | vector GenerateRandomIndices() { 43 | static auto prng = 44 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 45 | vector result(QUERIES_PER_REQUEST, 0); 46 | for (auto& i : result) { 47 | i = prng->generate() % (db_size_); 48 | } 49 | return result; 50 | } 51 | 52 | unique_ptr client_; 53 | unique_ptr server_; 54 | }; 55 | 56 | BENCHMARK_DEFINE_F(PIRFixture, SetupDb)(benchmark::State& st) { 57 | for (auto _ : st) { 58 | SetUpDb(st); 59 | } 60 | } 61 | 62 | BENCHMARK_DEFINE_F(PIRFixture, ClientCreateRequest)(benchmark::State& st) { 63 | SetUpDb(st); 64 | for (auto _ : st) { 65 | auto indices = GenerateRandomIndices(); 66 | ASSIGN_OR_FAIL(auto request, client_->CreateRequest(indices)); 67 | ::benchmark::DoNotOptimize(request); 68 | } 69 | } 70 | 71 | BENCHMARK_DEFINE_F(PIRFixture, ServerProcessRequest)(benchmark::State& st) { 72 | SetUpDb(st); 73 | auto indices = GenerateRandomIndices(); 74 | ASSIGN_OR_FAIL(auto request, client_->CreateRequest(indices)); 75 | for (auto _ : st) { 76 | ASSIGN_OR_FAIL(auto response, server_->ProcessRequest(request)); 77 | ::benchmark::DoNotOptimize(response); 78 | } 79 | } 80 | 81 | BENCHMARK_DEFINE_F(PIRFixture, ClientProcessResponse)(benchmark::State& st) { 82 | SetUpDb(st); 83 | auto indices = GenerateRandomIndices(); 84 | ASSIGN_OR_FAIL(auto request, client_->CreateRequest(indices)); 85 | ASSIGN_OR_FAIL(auto response, server_->ProcessRequest(request)); 86 | 87 | for (auto _ : st) { 88 | ASSIGN_OR_FAIL(auto results, client_->ProcessResponse(indices, response)); 89 | ASSERT_EQ(results.size(), indices.size()); 90 | for (size_t i = 0; i < results.size(); ++i) { 91 | ASSERT_EQ(results[i], string_db_[indices[i]]) << "i = " << 
i; 92 | } 93 | } 94 | } 95 | 96 | BENCHMARK_REGISTER_F(PIRFixture, SetupDb) 97 | ->RangeMultiplier(2) 98 | ->Range(1 << 8, 1 << 16); 99 | BENCHMARK_REGISTER_F(PIRFixture, ClientCreateRequest) 100 | ->RangeMultiplier(2) 101 | ->Range(1 << 8, 1 << 16); 102 | BENCHMARK_REGISTER_F(PIRFixture, ServerProcessRequest) 103 | ->RangeMultiplier(2) 104 | ->Range(1 << 8, 1 << 16); 105 | BENCHMARK_REGISTER_F(PIRFixture, ClientProcessResponse) 106 | ->RangeMultiplier(2) 107 | ->Range(1 << 8, 1 << 16); 108 | 109 | } // namespace pir 110 | -------------------------------------------------------------------------------- /pir/cpp/client.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | #include "pir/cpp/client.h" 17 | 18 | #include "pir/cpp/ct_reencoder.h" 19 | #include "pir/cpp/database.h" 20 | #include "pir/cpp/status_asserts.h" 21 | #include "pir/cpp/string_encoder.h" 22 | #include "pir/cpp/utils.h" 23 | #include "seal/seal.h" 24 | 25 | namespace pir { 26 | 27 | using absl::InternalError; 28 | using absl::InvalidArgumentError; 29 | using absl::StatusOr; 30 | using ::seal::Ciphertext; 31 | using ::seal::GaloisKeys; 32 | using ::seal::Plaintext; 33 | using ::seal::RelinKeys; 34 | 35 | PIRClient::PIRClient(std::unique_ptr context) 36 | : context_(std::move(context)) {} 37 | 38 | Status PIRClient::initialize() { 39 | ASSIGN_OR_RETURN(db_, PIRDatabase::Create(context_->Params())); 40 | try { 41 | auto sealctx = context_->SEALContext(); 42 | keygen_ = std::make_unique(sealctx); 43 | encryptor_ = 44 | std::make_shared(sealctx, keygen_->public_key()); 45 | decryptor_ = 46 | std::make_shared(sealctx, keygen_->secret_key()); 47 | auto gal_keys = keygen_->galois_keys(generate_galois_elts( 48 | context_->EncryptionParams().poly_modulus_degree())); 49 | auto relin_keys = keygen_->relin_keys(); 50 | request_proto_ = std::make_unique(); 51 | RETURN_IF_ERROR( 52 | SEALSerialize<>(gal_keys, request_proto_->mutable_galois_keys())); 53 | RETURN_IF_ERROR( 54 | SEALSerialize<>(relin_keys, request_proto_->mutable_relin_keys())); 55 | } catch (const std::exception& ex) { 56 | return InternalError(ex.what()); 57 | } 58 | return absl::OkStatus(); 59 | } 60 | 61 | StatusOr> PIRClient::Create( 62 | shared_ptr params) { 63 | ASSIGN_OR_RETURN(auto context, PIRContext::Create(params)); 64 | auto client = absl::WrapUnique(new PIRClient(std::move(context))); 65 | RETURN_IF_ERROR(client->initialize()); 66 | return client; 67 | } 68 | 69 | StatusOr InvertMod(uint64_t m, const seal::Modulus& mod) { 70 | if (mod.uint64_count() > 1) { 71 | return InternalError("Modulus too big to invert"); 72 | } 73 | uint64_t inverse; 74 | if (!seal::util::try_invert_uint_mod(m, 
mod.value(), inverse)) { 75 | return InternalError("Could not invert value"); 76 | } 77 | return inverse; 78 | } 79 | 80 | StatusOr PIRClient::CreateRequest( 81 | const std::vector& indexes) const { 82 | vector> queries(indexes.size()); 83 | for (size_t i = 0; i < indexes.size(); ++i) { 84 | RETURN_IF_ERROR(createQueryFor(indexes[i], queries[i])); 85 | } 86 | 87 | Request request_proto(*request_proto_); 88 | RETURN_IF_ERROR(SaveRequest(queries, &request_proto)); 89 | return request_proto; 90 | } 91 | 92 | Status PIRClient::createQueryFor(size_t desired_index, 93 | vector& query) const { 94 | if (desired_index >= context_->Params()->num_items()) { 95 | return InvalidArgumentError("invalid index " + 96 | std::to_string(desired_index)); 97 | } 98 | auto plain_mod = context_->EncryptionParams().plain_modulus(); 99 | const auto poly_modulus_degree = 100 | context_->EncryptionParams().poly_modulus_degree(); 101 | 102 | auto dims = std::vector(context_->Params()->dimensions().begin(), 103 | context_->Params()->dimensions().end()); 104 | auto indices = db_->calculate_indices(desired_index); 105 | 106 | const size_t dim_sum = context_->DimensionsSum(); 107 | 108 | size_t offset = 0; 109 | query.resize(dim_sum / poly_modulus_degree + 1); 110 | Plaintext pt(poly_modulus_degree); 111 | for (size_t c = 0; c < query.size(); ++c) { 112 | pt.set_zero(); 113 | 114 | while (!indices.empty()) { 115 | if (indices[0] + offset >= poly_modulus_degree) { 116 | // no more slots in this poly 117 | indices[0] -= (poly_modulus_degree - offset); 118 | dims[0] -= (poly_modulus_degree - offset); 119 | offset = 0; 120 | break; 121 | } 122 | uint64_t m = (c < query.size() - 1) 123 | ? 
poly_modulus_degree 124 | : next_power_two(dim_sum % poly_modulus_degree); 125 | ASSIGN_OR_RETURN(pt[indices[0] + offset], InvertMod(m, plain_mod)); 126 | offset += dims[0]; 127 | indices.erase(indices.begin()); 128 | dims.erase(dims.begin()); 129 | 130 | if (offset >= poly_modulus_degree) { 131 | offset -= poly_modulus_degree; 132 | break; 133 | } 134 | } 135 | 136 | try { 137 | encryptor_->encrypt(pt, query[c]); 138 | } catch (const std::exception& e) { 139 | return InternalError(e.what()); 140 | } 141 | } 142 | 143 | return absl::OkStatus(); 144 | } 145 | 146 | StatusOr> PIRClient::ProcessResponseInteger( 147 | const Response& response_proto) const { 148 | vector result; 149 | result.reserve(response_proto.reply_size()); 150 | for (const auto& r : response_proto.reply()) { 151 | ASSIGN_OR_RETURN(auto result_pt, ProcessReply(r)); 152 | try { 153 | result.push_back(context_->Encoder()->decode_int64(result_pt)); 154 | } catch (const std::exception& e) { 155 | return InternalError(e.what()); 156 | } 157 | } 158 | return result; 159 | } 160 | 161 | StatusOr> PIRClient::ProcessResponse( 162 | const std::vector& indexes, 163 | const Response& response_proto) const { 164 | if (indexes.size() != response_proto.reply_size()) { 165 | return InvalidArgumentError( 166 | "Number of indexes must match number of replies"); 167 | } 168 | 169 | StringEncoder encoder(context_->SEALContext()); 170 | if (context_->Params()->bits_per_coeff() > 0) { 171 | encoder.set_bits_per_coeff(context_->Params()->bits_per_coeff()); 172 | } 173 | vector result; 174 | result.reserve(response_proto.reply_size()); 175 | 176 | for (size_t i = 0; i < indexes.size(); ++i) { 177 | ASSIGN_OR_RETURN(auto result_pt, ProcessReply(response_proto.reply(i))); 178 | 179 | ASSIGN_OR_RETURN( 180 | auto v, encoder.decode(result_pt, context_->Params()->bytes_per_item(), 181 | db_->calculate_item_offset(indexes[i]))); 182 | result.push_back(v); 183 | } 184 | return result; 185 | } 186 | 187 | StatusOr 
PIRClient::ProcessReply( 188 | const Ciphertexts& reply_proto) const { 189 | if (context_->Params()->use_ciphertext_multiplication()) { 190 | return ProcessReplyCiphertextMult(reply_proto); 191 | } else { 192 | return ProcessReplyCiphertextDecomp(reply_proto); 193 | } 194 | } 195 | 196 | StatusOr<Plaintext> PIRClient::ProcessReplyCiphertextMult( 197 | const Ciphertexts& reply_proto) const { 198 | ASSIGN_OR_RETURN(auto reply_cts, 199 | LoadCiphertexts(context_->SEALContext(), reply_proto)); 200 | if (reply_cts.size() != 1) { 201 | return InvalidArgumentError( 202 | "Number of ciphertexts in reply must be 1 when using CT " 203 | "multiplication"); 204 | } 205 | 206 | const auto poly_modulus_degree = 207 | context_->EncryptionParams().poly_modulus_degree(); 208 | seal::Plaintext pt(poly_modulus_degree, 0); 209 | 210 | try { 211 | decryptor_->decrypt(reply_cts[0], pt); 212 | } catch (const std::exception& e) { 213 | return InternalError(e.what()); 214 | } 215 | 216 | return pt; 217 | } 218 | 219 | StatusOr<Plaintext> PIRClient::ProcessReplyCiphertextDecomp( 220 | const Ciphertexts& reply_proto) const { 221 | ASSIGN_OR_RETURN(auto ct_reencoder, 222 | CiphertextReencoder::Create(context_->SEALContext())); 223 | // TODO: this should use the original CT size 224 | const size_t exp_ratio = ct_reencoder->ExpansionRatio() * 2; 225 | const size_t num_dims = context_->Params()->dimensions_size(); 226 | const size_t num_ct_per_reply = ipow(exp_ratio, num_dims - 1); 227 | 228 | ASSIGN_OR_RETURN(auto reply_cts, 229 | LoadCiphertexts(context_->SEALContext(), reply_proto)); 230 | if (reply_cts.size() != num_ct_per_reply) { 231 | return InvalidArgumentError( 232 | "Number of ciphertexts in reply does not match expected"); 233 | } 234 | vector<Plaintext> reply_pts; 235 | 236 | for (size_t d = 0; d < num_dims; ++d) { 237 | reply_pts.resize(reply_cts.size()); 238 | try { 239 | for (size_t i = 0; i < reply_cts.size(); ++i) { 240 | decryptor_->decrypt(reply_cts[i], reply_pts[i]); 241 | } 
242 | } catch (const std::exception& e) { 243 | return InternalError(e.what()); 244 | } 245 | 246 | if (reply_pts.size() <= 1) break; 247 | 248 | reply_cts.resize(reply_cts.size() / exp_ratio); 249 | for (size_t i = 0; i < reply_cts.size(); ++i) { 250 | reply_cts[i] = ct_reencoder->Decode(reply_pts.begin() + i * exp_ratio, 2); 251 | } 252 | } 253 | 254 | return reply_pts[0]; 255 | } 256 | } // namespace pir 257 | -------------------------------------------------------------------------------- /pir/cpp/client.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #ifndef PIR_CLIENT_H_ 18 | #define PIR_CLIENT_H_ 19 | 20 | #include <string> 21 | 22 | #include "absl/status/statusor.h" 23 | #include "pir/cpp/context.h" 24 | #include "pir/cpp/database.h" 25 | #include "pir/cpp/serialization.h" 26 | 27 | namespace pir { 28 | 29 | using absl::StatusOr; 30 | 31 | class PIRClient { 32 | friend class PIRClientTest; 33 | 34 | public: 35 | /** 36 | * Creates and returns a new client instance, from existing parameters 37 | * @param[in] params PIR parameters 38 | * @returns InvalidArgument if the parameters cannot be loaded 39 | **/ 40 | static StatusOr<std::unique_ptr<PIRClient>> Create( 41 | shared_ptr<PIRParameters> params); 42 | /** 43 | * Creates a new request to query the database for the given index. Note that 44 | * if more than one dimension is specified in context, then the request 45 | * generated will include multiple selection vectors concatenated into one set 46 | * of ciphertexts. It is expected that the server will first expand the 47 | * request ciphertexts, and then split them into vectors by the dimensions 48 | * given in context. 49 | * @param[in] desiredIndex Expected database value from an index 50 | * @returns InvalidArgument if the index is invalid or if the encryption fails 51 | **/ 52 | StatusOr<Request> CreateRequest( 53 | const std::vector<std::size_t>& /*indexes*/) const; 54 | 55 | /** 56 | * Extracts database value from server response message. Needs the indices 57 | * from the original request since multiple values may be packed into each 58 | * reply ciphertext. 59 | * @param[in] indexes Original indices when request was created. 60 | * @param[in] response Server response. 61 | * @returns List of resulting strings of DB values, or an error. 62 | */ 63 | StatusOr<std::vector<std::string>> ProcessResponse( 64 | const std::vector<std::size_t>& indexes, const Response& response) const; 65 | 66 | /** 67 | * Extracts server response as an integer encoded in the plaintext. 
68 | * Should only be used for testing. 69 | * @param[in] response Server output 70 | * @returns InvalidArgument if the decryption fails 71 | **/ 72 | StatusOr<std::vector<int64_t>> ProcessResponseInteger( 73 | const Response& response) const; 74 | 75 | PIRClient() = delete; 76 | 77 | private: 78 | PIRClient(std::unique_ptr<PIRContext>); 79 | Status initialize(); 80 | Status createQueryFor(size_t desired_index, vector<Ciphertext>& query) const; 81 | 82 | std::unique_ptr<PIRContext> context_; 83 | std::shared_ptr<PIRDatabase> db_; 84 | 85 | std::unique_ptr<seal::KeyGenerator> keygen_; 86 | std::shared_ptr<seal::Encryptor> encryptor_; 87 | std::shared_ptr<seal::Decryptor> decryptor_; 88 | std::unique_ptr<Request> request_proto_; 89 | 90 | StatusOr<seal::Plaintext> ProcessReply(const Ciphertexts& reply_proto) const; 91 | StatusOr<seal::Plaintext> ProcessReplyCiphertextMult( 92 | const Ciphertexts& reply_proto) const; 93 | StatusOr<seal::Plaintext> ProcessReplyCiphertextDecomp( 94 | const Ciphertexts& reply_proto) const; 95 | }; 96 | 97 | } // namespace pir 98 | 99 | #endif // PIR_CLIENT_H_ 100 | -------------------------------------------------------------------------------- /pir/cpp/client_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/client.h" 18 | 19 | #include "gmock/gmock.h" 20 | #include "gtest/gtest.h" 21 | #include "pir/cpp/ct_reencoder.h" 22 | #include "pir/cpp/server.h" 23 | #include "pir/cpp/status_asserts.h" 24 | #include "pir/cpp/string_encoder.h" 25 | #include "pir/cpp/utils.h" 26 | 27 | namespace pir { 28 | 29 | using namespace seal; 30 | using std::get; 31 | using std::make_tuple; 32 | using std::make_unique; 33 | using std::shared_ptr; 34 | using std::tuple; 35 | using std::unique_ptr; 36 | using std::vector; 37 | using namespace ::testing; 38 | 39 | constexpr uint32_t POLY_MODULUS_DEGREE = 4096; 40 | 41 | class PIRClientTest : public ::testing::Test { 42 | protected: 43 | void SetUp() { SetUpDB(100); } 44 | 45 | void SetUpDB(size_t dbsize, size_t dimensions = 1, size_t elem_size = 0, 46 | bool use_ciphertext_multiplication = false) { 47 | db_size_ = dbsize; 48 | encryption_params_ = GenerateEncryptionParams(POLY_MODULUS_DEGREE, 16); 49 | pir_params_ = 50 | *(CreatePIRParameters(dbsize, elem_size, dimensions, encryption_params_, 51 | use_ciphertext_multiplication)); 52 | client_ = *(PIRClient::Create(pir_params_)); 53 | 54 | ASSERT_TRUE(client_ != nullptr); 55 | } 56 | 57 | const auto& Context() { return client_->context_; } 58 | std::shared_ptr<seal::Decryptor> Decryptor() { return client_->decryptor_; } 59 | std::shared_ptr<seal::Encryptor> Encryptor() { return client_->encryptor_; } 60 | 61 | size_t db_size_; 62 | shared_ptr<PIRParameters> pir_params_; 63 | EncryptionParameters encryption_params_; 64 | std::unique_ptr<PIRClient> client_; 65 | }; 66 | 67 | TEST_F(PIRClientTest, TestCreateRequest) { 68 | const size_t desired_index = 5; 69 | const vector<size_t> indices = {desired_index}; 70 | 71 | ASSIGN_OR_FAIL(auto req_proto, client_->CreateRequest(indices)); 72 | ASSERT_EQ(req_proto.query_size(), 1); 73 | ASSIGN_OR_FAIL(auto req, 74 | LoadCiphertexts(Context()->SEALContext(), req_proto.query(0))); 75 | 76 | Plaintext pt; 77 | 
ASSERT_EQ(req.size(), 1); 78 | EXPECT_THAT(req_proto.galois_keys(), Not(IsEmpty())); 79 | Decryptor()->decrypt(req[0], pt); 80 | 81 | const auto plain_mod = encryption_params_.plain_modulus().value(); 82 | EXPECT_EQ((pt[desired_index] * next_power_two(db_size_)) % plain_mod, 1); 83 | for (size_t i = 0; i < pt.coeff_count(); ++i) { 84 | if (i != desired_index) { 85 | EXPECT_EQ(pt[i], 0); 86 | } 87 | } 88 | } 89 | 90 | TEST_F(PIRClientTest, TestCreateRequestD2) { 91 | SetUpDB(84, 2); 92 | const size_t desired_index = 42; 93 | const vector<size_t> indices = {desired_index}; 94 | 95 | const size_t num_rows = 10; 96 | const size_t num_cols = 9; 97 | const size_t total_s_items = num_rows + num_cols; 98 | ASSERT_THAT(Context()->Params()->dimensions(), 99 | ElementsAre(num_rows, num_cols)); 100 | 101 | ASSIGN_OR_FAIL(auto request_proto, client_->CreateRequest(indices)); 102 | ASSERT_EQ(request_proto.query_size(), 1); 103 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(Context()->SEALContext(), 104 | request_proto.query(0))); 105 | Plaintext pt; 106 | ASSERT_EQ(request.size(), 1); 107 | EXPECT_THAT(request_proto.galois_keys(), Not(IsEmpty())); 108 | EXPECT_THAT(request_proto.relin_keys(), Not(IsEmpty())); 109 | 110 | Decryptor()->decrypt(request[0], pt); 111 | 112 | const size_t expected_row = 4; 113 | const size_t expected_col = 6; 114 | const auto plain_mod = encryption_params_.plain_modulus().value(); 115 | // NB: both row and column selection vectors are packed into the same CT 116 | EXPECT_EQ((pt[expected_row] * next_power_two(total_s_items)) % plain_mod, 1); 117 | EXPECT_EQ( 118 | (pt[num_rows + expected_col] * next_power_two(total_s_items)) % plain_mod, 119 | 1); 120 | for (size_t i = 0; i < pt.coeff_count(); ++i) { 121 | if (i != expected_row && i != (num_rows + expected_col)) { 122 | EXPECT_EQ(pt[i], 0) << "i = " << i; 123 | } 124 | } 125 | } 126 | 127 | TEST_F(PIRClientTest, TestCreateRequestD3) { 128 | SetUpDB(82, 3); 129 | const size_t desired_index = 42; 130 | 
const vector<size_t> indices = {desired_index}; 131 | 132 | const size_t num_rows = 5; 133 | const size_t num_cols = 5; 134 | const size_t num_depth = 4; 135 | const size_t total_s_items = num_rows + num_cols + num_depth; 136 | ASSERT_THAT(Context()->Params()->dimensions(), 137 | ElementsAre(num_rows, num_cols, num_depth)); 138 | 139 | ASSIGN_OR_FAIL(auto request_proto, client_->CreateRequest(indices)); 140 | ASSERT_EQ(request_proto.query_size(), 1); 141 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(Context()->SEALContext(), 142 | request_proto.query(0))); 143 | Plaintext pt; 144 | ASSERT_EQ(request.size(), 1); 145 | EXPECT_THAT(request_proto.galois_keys(), Not(IsEmpty())); 146 | EXPECT_THAT(request_proto.relin_keys(), Not(IsEmpty())); 147 | Decryptor()->decrypt(request[0], pt); 148 | 149 | const size_t expected_row = 2; 150 | const size_t expected_col = 0; 151 | const size_t expected_depth = 2; 152 | const auto plain_mod = encryption_params_.plain_modulus().value(); 153 | EXPECT_EQ((pt[expected_row] * next_power_two(total_s_items)) % plain_mod, 1); 154 | EXPECT_EQ( 155 | (pt[num_rows + expected_col] * next_power_two(total_s_items)) % plain_mod, 156 | 1); 157 | EXPECT_EQ((pt[num_rows + num_cols + expected_depth] * 158 | next_power_two(total_s_items)) % 159 | plain_mod, 160 | 1); 161 | for (size_t i = 0; i < pt.coeff_count(); ++i) { 162 | if (i != expected_row && i != (num_rows + expected_col) && 163 | i != (num_rows + num_cols + expected_depth)) { 164 | EXPECT_EQ(pt[i], 0) << "i = " << i; 165 | } 166 | } 167 | } 168 | 169 | TEST_F(PIRClientTest, TestCreateRequestMultiDimMultiCT1) { 170 | SetUpDB(20000000, 2); 171 | const size_t desired_index = 12345679; 172 | const vector<size_t> indices = {desired_index}; 173 | const size_t num_rows = 4473; 174 | const size_t num_cols = 4472; 175 | ASSERT_THAT(Context()->Params()->dimensions(), 176 | ElementsAre(num_rows, num_cols)); 177 | 178 | ASSIGN_OR_FAIL(auto request_proto, client_->CreateRequest(indices)); 179 | 
ASSERT_EQ(request_proto.query_size(), 1); 180 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(Context()->SEALContext(), 181 | request_proto.query(0))); 182 | ASSERT_EQ(request.size(), 3); 183 | EXPECT_THAT(request_proto.galois_keys(), Not(IsEmpty())); 184 | EXPECT_THAT(request_proto.relin_keys(), Not(IsEmpty())); 185 | 186 | const size_t expected_row = 2760; 187 | const size_t expected_col = 2959; 188 | const auto plain_mod = encryption_params_.plain_modulus().value(); 189 | 190 | vector<Plaintext> pts(request.size()); 191 | for (size_t i = 0; i < pts.size(); ++i) { 192 | Decryptor()->decrypt(request[i], pts[i]); 193 | } 194 | 195 | // first plaintext should be all zero except for row value 196 | EXPECT_EQ((pts[0][expected_row] * POLY_MODULUS_DEGREE) % plain_mod, 1); 197 | for (size_t i = 0; i < pts[0].coeff_count(); ++i) { 198 | if (i != expected_row) { 199 | EXPECT_EQ(pts[0][i], 0) << "i = " << i; 200 | } 201 | } 202 | 203 | // second plaintext should be all zero except for col value with offset from 204 | // values for row value 205 | const size_t expected_index = expected_col + num_rows - POLY_MODULUS_DEGREE; 206 | EXPECT_EQ((pts[1][expected_index] * POLY_MODULUS_DEGREE) % plain_mod, 1); 207 | for (size_t i = 0; i < pts[1].coeff_count(); ++i) { 208 | if (i != expected_index) { 209 | EXPECT_EQ(pts[1][i], 0) << "i = " << i; 210 | } 211 | } 212 | 213 | // third plaintext should be all zeros 214 | for (size_t i = 0; i < pts[2].coeff_count(); ++i) { 215 | EXPECT_EQ(pts[2][i], 0) << "i = " << i; 216 | } 217 | } 218 | 219 | TEST_F(PIRClientTest, TestCreateRequestMultiDimMultiCT2) { 220 | SetUpDB(20000000, 2); 221 | const size_t desired_index = 12346679; 222 | const vector<size_t> indices = {desired_index}; 223 | const size_t num_rows = 4473; 224 | const size_t num_cols = 4472; 225 | ASSERT_THAT(Context()->Params()->dimensions(), 226 | ElementsAre(num_rows, num_cols)); 227 | 228 | ASSIGN_OR_FAIL(auto request_proto, client_->CreateRequest(indices)); 229 | 
ASSERT_EQ(request_proto.query_size(), 1); 230 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(Context()->SEALContext(), 231 | request_proto.query(0))); 232 | ASSERT_EQ(request.size(), 3); 233 | EXPECT_THAT(request_proto.galois_keys(), Not(IsEmpty())); 234 | EXPECT_THAT(request_proto.relin_keys(), Not(IsEmpty())); 235 | 236 | const size_t expected_row = 2760; 237 | const size_t expected_col = 3959; 238 | const auto plain_mod = encryption_params_.plain_modulus().value(); 239 | 240 | vector<Plaintext> pts(request.size()); 241 | for (size_t i = 0; i < pts.size(); ++i) { 242 | Decryptor()->decrypt(request[i], pts[i]); 243 | } 244 | 245 | // first plaintext should be all zero except for row value 246 | EXPECT_EQ((pts[0][expected_row] * POLY_MODULUS_DEGREE) % plain_mod, 1); 247 | for (size_t i = 0; i < pts[0].coeff_count(); ++i) { 248 | if (i != expected_row) { 249 | EXPECT_EQ(pts[0][i], 0) << "i = " << i; 250 | } 251 | } 252 | 253 | // second plaintext should be all zeros 254 | for (size_t i = 0; i < pts[1].coeff_count(); ++i) { 255 | EXPECT_EQ(pts[1][i], 0) << "i = " << i; 256 | } 257 | // third plaintext should be all zero except for col value with offset 258 | const size_t expected_index = 259 | expected_col + num_rows - 2 * POLY_MODULUS_DEGREE; 260 | const size_t m = next_power_two((num_rows + num_cols) % POLY_MODULUS_DEGREE); 261 | EXPECT_EQ((pts[2][expected_index] * m) % plain_mod, 1); 262 | for (size_t i = 0; i < pts[2].coeff_count(); ++i) { 263 | if (i != expected_index) { 264 | EXPECT_EQ(pts[2][i], 0) << "i = " << i; 265 | } 266 | } 267 | } 268 | 269 | TEST_F(PIRClientTest, TestCreateRequest_InvalidIndex) { 270 | auto request_or = client_->CreateRequest({db_size_ + 1}); 271 | ASSERT_EQ(request_or.status().code(), absl::StatusCode::kInvalidArgument); 272 | } 273 | 274 | class CreateRequestTest : public PIRClientTest, 275 | public testing::WithParamInterface< 276 | tuple<size_t, vector<size_t>, uint64_t>> {}; 277 | 278 | TEST_P(CreateRequestTest, TestCreateRequest) 
{ 279 | const auto dbsize = get<0>(GetParam()); 280 | vector<size_t> indices = get<1>(GetParam()); 281 | SetUpDB(dbsize); 282 | 283 | const auto poly_modulus_degree = encryption_params_.poly_modulus_degree(); 284 | const auto plain_mod = encryption_params_.plain_modulus().value(); 285 | 286 | ASSIGN_OR_FAIL(auto request, client_->CreateRequest(indices)); 287 | ASSERT_EQ(request.query_size(), indices.size()); 288 | EXPECT_THAT(request.galois_keys(), Not(IsEmpty())); 289 | 290 | auto m = get<2>(GetParam()); 291 | 292 | for (size_t idx = 0; idx < indices.size(); ++idx) { 293 | ASSIGN_OR_FAIL(auto query, LoadCiphertexts(Context()->SEALContext(), 294 | request.query(idx))); 295 | ASSERT_EQ(query.size(), dbsize / poly_modulus_degree + 1); 296 | size_t desired_index = indices[idx]; 297 | 298 | for (const auto& ct : query) { 299 | Plaintext pt; 300 | Decryptor()->decrypt(ct, pt); 301 | 302 | if (desired_index < 0 || 303 | static_cast<size_t>(desired_index) >= poly_modulus_degree) { 304 | desired_index -= poly_modulus_degree; 305 | for (size_t i = 0; i < pt.coeff_count(); ++i) { 306 | EXPECT_EQ(pt[i], 0); 307 | } 308 | } else { 309 | EXPECT_EQ((pt[desired_index] * m) % plain_mod, 1); 310 | for (size_t i = 0; i < pt.coeff_count(); ++i) { 311 | if (i != desired_index) { 312 | EXPECT_EQ(pt[i], 0); 313 | } 314 | } 315 | desired_index = -1; 316 | } 317 | } 318 | } 319 | } 320 | 321 | INSTANTIATE_TEST_SUITE_P( 322 | Requests, CreateRequestTest, 323 | testing::Values( 324 | make_tuple(10000, vector<size_t>({5005}), POLY_MODULUS_DEGREE), 325 | make_tuple(10000, vector<size_t>({0}), POLY_MODULUS_DEGREE), 326 | make_tuple(10000, vector<size_t>({1}), POLY_MODULUS_DEGREE), 327 | make_tuple(10000, vector<size_t>({3333}), POLY_MODULUS_DEGREE), 328 | make_tuple(10000, vector<size_t>({4095}), POLY_MODULUS_DEGREE), 329 | make_tuple(10000, vector<size_t>({4096}), POLY_MODULUS_DEGREE), 330 | make_tuple(10000, vector<size_t>({4097}), POLY_MODULUS_DEGREE), 331 | make_tuple(10000, 
vector<size_t>({8191}), POLY_MODULUS_DEGREE), 332 | make_tuple(10000, vector<size_t>({8192}), 2048), 333 | make_tuple(10000, vector<size_t>({8193}), 2048), 334 | make_tuple(10000, vector<size_t>({9007}), 2048), 335 | make_tuple(10000, vector<size_t>({9999}), 2048), 336 | make_tuple(4096, vector<size_t>({0}), 4096), 337 | make_tuple(4096, vector<size_t>({4095}), 4096), 338 | make_tuple(16384, vector<size_t>({12288}), 4096), 339 | make_tuple(16384, vector<size_t>({12289}), 4096), 340 | make_tuple(16384, vector<size_t>({16383}), 4096), 341 | 342 | make_tuple(10000, vector<size_t>({5005}), POLY_MODULUS_DEGREE), 343 | make_tuple(10000, vector<size_t>({0}), POLY_MODULUS_DEGREE), 344 | make_tuple(10000, vector<size_t>({8191}), POLY_MODULUS_DEGREE), 345 | make_tuple(10000, vector<size_t>({0, 8191}), POLY_MODULUS_DEGREE), 346 | make_tuple(10000, vector<size_t>({0, 5005, 8191}), POLY_MODULUS_DEGREE), 347 | make_tuple(10000, vector<size_t>({0, 1, 2, 3, 4, 5}), 348 | POLY_MODULUS_DEGREE))); 349 | 350 | class ProcessResponseTest 351 | : public PIRClientTest, 352 | public testing::WithParamInterface<tuple< 353 | size_t, size_t, size_t, size_t, vector<size_t>, vector<size_t>>> { 354 | protected: 355 | void SetUpForCTMultiply(bool use_ciphertext_multiplication) { 356 | const auto dbsize = get<0>(GetParam()); 357 | d_ = get<1>(GetParam()); 358 | elem_size_ = get<2>(GetParam()); 359 | pt_size_ = get<3>(GetParam()); 360 | desired_indices_ = get<4>(GetParam()); 361 | result_offsets_ = get<5>(GetParam()); 362 | ASSERT_EQ(desired_indices_.size(), result_offsets_.size()); 363 | 364 | SetUpDB(dbsize, d_, elem_size_, use_ciphertext_multiplication); 365 | 366 | prng_ = seal::UniformRandomGeneratorFactory::DefaultFactory()->create({99}); 367 | 368 | ASSIGN_OR_FAIL(ct_reencoder_, 369 | CiphertextReencoder::Create(Context()->SEALContext())); 370 | } 371 | 372 | vector<Ciphertext> DecompCT(Ciphertext input_ct) { 373 | vector<Ciphertext> result_cts({input_ct}); 374 | for (size_t d = 0; d < d_ - 
1; ++d) { 375 | vector<Ciphertext> inter_cts(result_cts.size() * 376 | ct_reencoder_->ExpansionRatio() * 377 | input_ct.size()); 378 | auto inter_cts_iter = inter_cts.begin(); 379 | for (const auto& ct : result_cts) { 380 | auto pts = ct_reencoder_->Encode(ct); 381 | for (const auto& pt : pts) { 382 | Encryptor()->encrypt(pt, *(inter_cts_iter++)); 383 | } 384 | } 385 | result_cts = inter_cts; 386 | } 387 | return result_cts; 388 | } 389 | 390 | size_t d_; 391 | size_t elem_size_; 392 | size_t pt_size_; 393 | vector<size_t> desired_indices_; 394 | vector<size_t> result_offsets_; 395 | 396 | shared_ptr<UniformRandomGenerator> prng_; 397 | unique_ptr<CiphertextReencoder> ct_reencoder_; 398 | }; 399 | 400 | TEST_P(ProcessResponseTest, TestProcessResponse) { 401 | SetUpForCTMultiply(false); 402 | vector<string> values(desired_indices_.size(), string(pt_size_, 0)); 403 | for (size_t i = 0; i < values.size(); ++i) { 404 | prng_->generate(values[i].size(), 405 | reinterpret_cast<seal::SEAL_BYTE*>(values[i].data())); 406 | } 407 | 408 | StringEncoder encoder(Context()->SEALContext()); 409 | Response response; 410 | for (auto& value : values) { 411 | Plaintext pt; 412 | encoder.encode(value, pt); 413 | Ciphertext ct; 414 | Encryptor()->encrypt(pt, ct); 415 | SaveCiphertexts(DecompCT(ct), response.add_reply()); 416 | } 417 | 418 | ASSIGN_OR_FAIL(auto result, 419 | client_->ProcessResponse(desired_indices_, response)); 420 | 421 | ASSERT_EQ(result.size(), values.size()); 422 | for (size_t i = 0; i < result.size(); ++i) { 423 | EXPECT_EQ(result[i], values[i].substr(result_offsets_[i], elem_size_)) 424 | << "i = " << i; 425 | } 426 | } 427 | 428 | TEST_P(ProcessResponseTest, TestProcessResponseCTMultiply) { 429 | SetUpForCTMultiply(true); 430 | vector<string> values(desired_indices_.size(), string(pt_size_, 0)); 431 | for (size_t i = 0; i < values.size(); ++i) { 432 | prng_->generate(values[i].size(), 433 | reinterpret_cast<seal::SEAL_BYTE*>(values[i].data())); 434 | } 435 | 436 
| StringEncoder encoder(Context()->SEALContext()); 437 | Response response; 438 | for (auto& value : values) { 439 | Plaintext pt; 440 | encoder.encode(value, pt); 441 | Ciphertext ct; 442 | Encryptor()->encrypt(pt, ct); 443 | SaveCiphertexts({ct}, response.add_reply()); 444 | } 445 | 446 | ASSIGN_OR_FAIL(auto result, 447 | client_->ProcessResponse(desired_indices_, response)); 448 | 449 | ASSERT_EQ(result.size(), values.size()); 450 | for (size_t i = 0; i < result.size(); ++i) { 451 | EXPECT_EQ(result[i], values[i].substr(result_offsets_[i], elem_size_)) 452 | << "i = " << i; 453 | } 454 | } 455 | 456 | TEST_P(ProcessResponseTest, TestProcessResponseInteger) { 457 | SetUpForCTMultiply(false); 458 | vector<int64_t> values(desired_indices_.size()); 459 | for (size_t i = 0; i < values.size(); ++i) { 460 | prng_->generate(sizeof(values[i]), 461 | reinterpret_cast<seal::SEAL_BYTE*>(&values[i])); 462 | } 463 | 464 | Response response; 465 | for (auto& value : values) { 466 | Plaintext pt; 467 | Context()->Encoder()->encode(value, pt); 468 | Ciphertext ct; 469 | Encryptor()->encrypt(pt, ct); 470 | auto decomp_cts = DecompCT(ct); 471 | SaveCiphertexts(decomp_cts, response.add_reply()); 472 | } 473 | 474 | ASSIGN_OR_FAIL(auto result, client_->ProcessResponseInteger(response)); 475 | ASSERT_EQ(result.size(), values.size()); 476 | for (size_t i = 0; i < result.size(); ++i) { 477 | ASSERT_EQ(result[i], values[i]); 478 | } 479 | } 480 | 481 | TEST_P(ProcessResponseTest, TestProcessResponseIntegerCTMultiply) { 482 | SetUpForCTMultiply(true); 483 | vector<int64_t> values(desired_indices_.size()); 484 | for (size_t i = 0; i < values.size(); ++i) { 485 | prng_->generate(sizeof(values[i]), 486 | reinterpret_cast<seal::SEAL_BYTE*>(&values[i])); 487 | } 488 | 489 | Response response; 490 | for (auto& value : values) { 491 | Plaintext pt; 492 | Context()->Encoder()->encode(value, pt); 493 | Ciphertext ct; 494 | Encryptor()->encrypt(pt, ct); 495 | SaveCiphertexts({ct}, 
response.add_reply()); 496 | } 497 | 498 | ASSIGN_OR_FAIL(auto result, client_->ProcessResponseInteger(response)); 499 | ASSERT_EQ(result.size(), values.size()); 500 | for (size_t i = 0; i < result.size(); ++i) { 501 | ASSERT_EQ(result[i], values[i]); 502 | } 503 | } 504 | 505 | INSTANTIATE_TEST_SUITE_P( 506 | PIRClientProcessResponsesTests, ProcessResponseTest, 507 | testing::Values( 508 | make_tuple(1000, 1, 64, 7680, vector<size_t>({720, 777, 839}), 509 | vector<size_t>({0, 3648, 7616})), 510 | make_tuple(1000, 2, 64, 7680, vector<size_t>({720, 777, 839}), 511 | vector<size_t>({0, 3648, 7616})), 512 | make_tuple(1000, 3, 64, 7680, vector<size_t>({720, 777, 839}), 513 | vector<size_t>({0, 3648, 7616})), 514 | make_tuple(1000, 4, 64, 7680, vector<size_t>({720, 777, 839}), 515 | vector<size_t>({0, 3648, 7616})))); 516 | 517 | } // namespace pir 518 | -------------------------------------------------------------------------------- /pir/cpp/context.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | #include "pir/cpp/context.h" 17 | 18 | #include "pir/cpp/serialization.h" 19 | #include "pir/cpp/status_asserts.h" 20 | #include "seal/seal.h" 21 | 22 | namespace pir { 23 | 24 | using absl::InternalError; 25 | using absl::InvalidArgumentError; 26 | using absl::StatusOr; 27 | using seal::EncryptionParameters; 28 | 29 | PIRContext::PIRContext(shared_ptr<PIRParameters> params, 30 | const EncryptionParameters& enc_params, 31 | shared_ptr<seal::SEALContext> context) 32 | : parameters_(params), encryption_params_(enc_params), context_(context) { 33 | encoder_ = std::make_shared<seal::IntegerEncoder>(this->context_); 34 | evaluator_ = std::make_shared<seal::Evaluator>(context_); 35 | } 36 | 37 | StatusOr<std::unique_ptr<PIRContext>> PIRContext::Create( 38 | shared_ptr<PIRParameters> params) { 39 | ASSIGN_OR_RETURN(auto enc_params, SEALDeserialize<EncryptionParameters>( 40 | params->encryption_parameters())); 41 | 42 | try { 43 | auto context = seal::SEALContext::Create(enc_params); 44 | return absl::WrapUnique(new PIRContext(params, enc_params, context)); 45 | } catch (const std::exception& e) { 46 | return InvalidArgumentError(e.what()); 47 | } 48 | 49 | return InternalError("this should never happen"); 50 | } 51 | 52 | } // namespace pir 53 | -------------------------------------------------------------------------------- /pir/cpp/context.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_CONTEXT_H_ 18 | #define PIR_CONTEXT_H_ 19 | 20 | #include "absl/status/statusor.h" 21 | #include "pir/cpp/parameters.h" 22 | #include "seal/seal.h" 23 | 24 | namespace pir { 25 | 26 | using absl::StatusOr; 27 | 28 | using ::std::optional; 29 | using ::std::shared_ptr; 30 | using ::std::string; 31 | using ::std::vector; 32 | 33 | using seal::EncryptionParameters; 34 | 35 | class PIRContext { 36 | public: 37 | /** 38 | * Creates a new context 39 | * @param[in] params PIR parameters 40 | * @returns InvalidArgument if the SEAL parameter deserialization fails 41 | **/ 42 | static StatusOr<std::unique_ptr<PIRContext>> Create( 43 | shared_ptr<PIRParameters> /*params*/); 44 | /** 45 | * Returns an Evaluator instance. 46 | **/ 47 | std::shared_ptr<seal::Evaluator>& Evaluator() { return evaluator_; } 48 | /** 49 | * Returns the SEAL context. 50 | **/ 51 | std::shared_ptr<seal::SEALContext>& SEALContext() { return context_; } 52 | /** 53 | * Returns the PIR parameters protobuffer. 54 | **/ 55 | shared_ptr<PIRParameters> Params() { return parameters_; } 56 | /** 57 | * Returns the dimensions sum. 58 | **/ 59 | size_t DimensionsSum() { 60 | return std::accumulate(Params()->dimensions().begin(), 61 | Params()->dimensions().end(), 0); 62 | } 63 | /** 64 | * Returns the encryption parameters used to create SEAL context. 
65 | **/ 66 | const EncryptionParameters& EncryptionParams() { return encryption_params_; } 67 | 68 | /** 69 | * Returns the encoder 70 | **/ 71 | std::shared_ptr<seal::IntegerEncoder>& Encoder() { return encoder_; } 72 | 73 | private: 74 | PIRContext(shared_ptr<PIRParameters> /*params*/, 75 | const EncryptionParameters& /*enc_params*/, 76 | shared_ptr<seal::SEALContext> /*seal_context*/); 77 | 78 | shared_ptr<PIRParameters> parameters_; 79 | EncryptionParameters encryption_params_; 80 | shared_ptr<seal::SEALContext> context_; 81 | shared_ptr<seal::Evaluator> evaluator_; 82 | shared_ptr<seal::IntegerEncoder> encoder_; 83 | }; 84 | 85 | } // namespace pir 86 | 87 | #endif // PIR_CONTEXT_H_ 88 | -------------------------------------------------------------------------------- /pir/cpp/correctness_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include <algorithm> 18 | #include <iostream> 19 | #include <vector> 20 | 21 | #include "gmock/gmock.h" 22 | #include "gtest/gtest.h" 23 | #include "pir/cpp/client.h" 24 | #include "pir/cpp/server.h" 25 | #include "pir/cpp/status_asserts.h" 26 | #include "pir/cpp/test_base.h" 27 | #include "pir/cpp/utils.h" 28 | 29 | namespace pir { 30 | // namespace { 31 | 32 | using std::cout; 33 | using std::endl; 34 | using std::get; 35 | using std::make_tuple; 36 | using std::make_unique; 37 | using std::shared_ptr; 38 | using std::string; 39 | using std::tuple; 40 | using std::unique_ptr; 41 | using std::vector; 42 | 43 | using seal::Ciphertext; 44 | using seal::GaloisKeys; 45 | using seal::Plaintext; 46 | using seal::RelinKeys; 47 | 48 | using namespace seal; 49 | using namespace ::testing; 50 | using std::int64_t; 51 | using std::vector; 52 | 53 | class PIRCorrectnessTest 54 | : public ::testing::TestWithParam< 55 | tuple<bool, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, 56 | uint32_t, vector<size_t>>>, 57 | public PIRTestingBase { 58 | protected: 59 | void SetUp() { 60 | const auto use_ciphertext_multiplication = get<0>(GetParam()); 61 | const auto poly_modulus_degree = get<1>(GetParam()); 62 | const auto plain_mod_bits = get<2>(GetParam()); 63 | const auto elem_size = get<3>(GetParam()); 64 | const auto bits_per_coeff = get<4>(GetParam()); 65 | const auto dbsize = get<5>(GetParam()); 66 | const auto d = get<6>(GetParam()); 67 | 68 | SetUpParams(dbsize, elem_size, d, poly_modulus_degree, plain_mod_bits, 69 | bits_per_coeff, use_ciphertext_multiplication); 70 | GenerateDB(); 71 | 72 | client_ = *(PIRClient::Create(pir_params_)); 73 | server_ = *(PIRServer::Create(pir_db_, pir_params_)); 74 | ASSERT_THAT(client_, NotNull()); 75 | ASSERT_THAT(server_, NotNull()); 76 | } 77 | 78 | unique_ptr<PIRClient> client_; 79 | unique_ptr<PIRServer> server_; 80 | }; 81 | 82 | TEST_P(PIRCorrectnessTest, TestCorrectness) { 83 | const auto desired_indices = 
get<7>(GetParam()); 84 | ASSIGN_OR_FAIL(auto request, client_->CreateRequest(desired_indices)); 85 | ASSIGN_OR_FAIL(auto response, server_->ProcessRequest(request)); 86 | ASSIGN_OR_FAIL(auto results, 87 | client_->ProcessResponse(desired_indices, response)); 88 | 89 | ASSERT_EQ(results.size(), desired_indices.size()); 90 | for (size_t i = 0; i < results.size(); ++i) { 91 | ASSERT_EQ(results[i], string_db_[desired_indices[i]]) << "i = " << i; 92 | } 93 | } 94 | 95 | INSTANTIATE_TEST_SUITE_P( 96 | CorrectnessTest, PIRCorrectnessTest, 97 | testing::Values( 98 | make_tuple(true, 4096, 24, 0, 0, 10, 1, vector<size_t>({0})), 99 | make_tuple(true, 4096, 16, 0, 10, 9, 2, vector<size_t>({1, 5})), 100 | make_tuple(true, 4096, 16, 0, 6, 500, 2, vector<size_t>({9, 125})), 101 | make_tuple(true, 8192, 42, 0, 0, 87, 2, vector<size_t>({5, 33, 86})), 102 | make_tuple(true, 4096, 16, 64, 10, 1200, 1, 103 | vector<size_t>({0, 80, 81, 123, 777, 1199})), 104 | make_tuple(true, 4096, 16, 289, 10, 1200, 1, 105 | vector<size_t>({0, 47, 777, 1199})), 106 | 107 | make_tuple(false, 4096, 24, 0, 0, 10, 1, vector<size_t>({0})), 108 | make_tuple(false, 4096, 24, 0, 10, 9, 2, vector<size_t>({1, 5})), 109 | make_tuple(false, 4096, 24, 0, 6, 500, 2, vector<size_t>({9, 125})), 110 | make_tuple(false, 4096, 24, 64, 10, 1200, 1, 111 | vector<size_t>({0, 80, 81, 123, 777, 1199})), 112 | make_tuple(false, 4096, 24, 289, 10, 1200, 1, 113 | vector<size_t>({0, 47, 777, 1199})))); 114 | 115 | //} // namespace 116 | } // namespace pir 117 | -------------------------------------------------------------------------------- /pir/cpp/ct_reencoder.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | #include "pir/cpp/ct_reencoder.h" 17 | 18 | #include "pir/cpp/serialization.h" 19 | #include "pir/cpp/status_asserts.h" 20 | #include "seal/seal.h" 21 | 22 | namespace pir { 23 | 24 | StatusOr<std::unique_ptr<CiphertextReencoder>> CiphertextReencoder::Create( 25 | shared_ptr<SEALContext> context) { 26 | return absl::WrapUnique(new CiphertextReencoder(context)); 27 | } 28 | 29 | uint32_t CiphertextReencoder::ExpansionRatio() const { 30 | uint32_t expansion_ratio = 0; 31 | const auto params = context_->first_context_data()->parms(); 32 | uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value()); 33 | for (size_t i = 0; i < params.coeff_modulus().size(); ++i) { 34 | double coeff_bit_size = log2(params.coeff_modulus()[i].value()); 35 | expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff); 36 | } 37 | return expansion_ratio; 38 | } 39 | 40 | vector<Plaintext> CiphertextReencoder::Encode(const Ciphertext& ct) { 41 | const auto params = context_->first_context_data()->parms(); 42 | const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value()); 43 | const auto coeff_count = params.poly_modulus_degree(); 44 | const auto coeff_mod_count = params.coeff_modulus().size(); 45 | const uint64_t pt_bitmask = (1 << pt_bits_per_coeff) - 1; 46 | 47 | vector<Plaintext> result(ExpansionRatio() * ct.size()); 48 | auto pt_iter = result.begin(); 49 | for (size_t poly_index = 0; poly_index < ct.size(); ++poly_index) { 50 | for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count; 51 | 
++coeff_mod_index) { 52 | const double coeff_bit_size = 53 | log2(params.coeff_modulus()[coeff_mod_index].value()); 54 | const size_t local_expansion_ratio = 55 | ceil(coeff_bit_size / pt_bits_per_coeff); 56 | size_t shift = 0; 57 | for (size_t i = 0; i < local_expansion_ratio; ++i) { 58 | pt_iter->resize(coeff_count); 59 | for (size_t c = 0; c < coeff_count; ++c) { 60 | (*pt_iter)[c] = 61 | (ct.data(poly_index)[coeff_mod_index * coeff_count + c] >> 62 | shift) & 63 | pt_bitmask; 64 | } 65 | ++pt_iter; 66 | shift += pt_bits_per_coeff; 67 | } 68 | } 69 | } 70 | return result; 71 | } 72 | 73 | Ciphertext CiphertextReencoder::Decode(const vector<Plaintext>& pts) { 74 | return Decode(pts.begin(), pts.size() / ExpansionRatio()); 75 | } 76 | 77 | Ciphertext CiphertextReencoder::Decode( 78 | vector<Plaintext>::const_iterator pt_iter, const size_t ct_poly_count) { 79 | const auto params = context_->first_context_data()->parms(); 80 | const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value()); 81 | const auto coeff_count = params.poly_modulus_degree(); 82 | const auto coeff_mod_count = params.coeff_modulus().size(); 83 | // size_t pt_count = 0; 84 | // TODO: should check here if numbers match 85 | 86 | Ciphertext ct(context_); 87 | ct.resize(ct_poly_count); 88 | for (size_t poly_index = 0; poly_index < ct_poly_count; ++poly_index) { 89 | for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count; 90 | ++coeff_mod_index) { 91 | const double coeff_bit_size = 92 | log2(params.coeff_modulus()[coeff_mod_index].value()); 93 | const size_t local_expansion_ratio = 94 | ceil(coeff_bit_size / pt_bits_per_coeff); 95 | size_t shift = 0; 96 | for (size_t i = 0; i < local_expansion_ratio; ++i) { 97 | for (size_t c = 0; c < pt_iter->coeff_count(); ++c) { 98 | if (shift == 0) { 99 | ct.data(poly_index)[coeff_mod_index * coeff_count + c] = 100 | (*pt_iter)[c]; 101 | } else { 102 | ct.data(poly_index)[coeff_mod_index * coeff_count + c] += 103 | ((*pt_iter)[c] << shift); 
104 | } 105 | } 106 | ++pt_iter; 107 | shift += pt_bits_per_coeff; 108 | } 109 | } 110 | } 111 | return ct; 112 | } 113 | 114 | } // namespace pir 115 | -------------------------------------------------------------------------------- /pir/cpp/ct_reencoder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_CT_REENCODER_H_ 18 | #define PIR_CT_REENCODER_H_ 19 | 20 | #include "absl/status/statusor.h" 21 | #include "seal/seal.h" 22 | 23 | namespace pir { 24 | 25 | using absl::StatusOr; 26 | 27 | using seal::Ciphertext; 28 | using seal::Plaintext; 29 | using seal::SEALContext; 30 | using ::std::shared_ptr; 31 | using ::std::vector; 32 | 33 | class CiphertextReencoder { 34 | public: 35 | static StatusOr<std::unique_ptr<CiphertextReencoder>> Create( 36 | shared_ptr<SEALContext> /*params*/); 37 | 38 | uint32_t ExpansionRatio() const; 39 | 40 | /** 41 | * Reencode a ciphertext as a set of plaintexts. 42 | * @param[in] ct Ciphertext to reencode. 43 | * @returns Vector of plaintexts created by decomposing CT. 44 | */ 45 | vector<Plaintext> Encode(const Ciphertext& ct); 46 | 47 | /** 48 | * Recompose a ciphertext from a set of plaintexts. 49 | * @param[in] pts Vector of plaintexts to decode. 50 | * @returns Ciphertext recomposed from plaintexts. 
51 | */ 52 | Ciphertext Decode(const vector<Plaintext>& pts); 53 | 54 | Ciphertext Decode(vector<Plaintext>::const_iterator pt_iter, 55 | const size_t ct_poly_count); 56 | 57 | private: 58 | CiphertextReencoder(shared_ptr<SEALContext> context) : context_(context) {} 59 | 60 | shared_ptr<SEALContext> context_; 61 | }; 62 | 63 | } // namespace pir 64 | 65 | #endif // PIR_CT_REENCODER_H_ 66 | -------------------------------------------------------------------------------- /pir/cpp/ct_reencoder_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/ct_reencoder.h" 18 | 19 | #include <memory> 20 | 21 | #include "gmock/gmock.h" 22 | #include "gtest/gtest.h" 23 | #include "pir/cpp/parameters.h" 24 | #include "pir/cpp/status_asserts.h" 25 | #include "pir/cpp/string_encoder.h" 26 | 27 | namespace pir { 28 | namespace { 29 | 30 | using std::cout; 31 | using std::endl; 32 | using std::make_unique; 33 | using std::unique_ptr; 34 | using std::vector; 35 | 36 | using namespace seal; 37 | using namespace ::testing; 38 | 39 | constexpr size_t POLY_MODULUS_DEGREE = 4096; 40 | 41 | string generate_string(size_t size) { 42 | static auto prng = 43 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 44 | string result(size, 0); 45 | prng->generate(size, reinterpret_cast<seal::SEAL_BYTE*>(result.data())); 46 | return result; 47 | } 48 | 49 | class CiphertextReencoderTest : public ::testing::Test { 50 | protected: 51 | void SetUp() { 52 | auto params = GenerateEncryptionParams(POLY_MODULUS_DEGREE); 53 | seal_context_ = seal::SEALContext::Create(params); 54 | if (!seal_context_->parameters_set()) { 55 | FAIL() << "Error setting encryption parameters: " 56 | << seal_context_->parameter_error_message(); 57 | } 58 | keygen_ = make_unique<KeyGenerator>(seal_context_); 59 | encryptor_ = make_unique<Encryptor>(seal_context_, keygen_->public_key()); 60 | decryptor_ = make_unique<Decryptor>(seal_context_, keygen_->secret_key()); 61 | encoder_ = make_unique<StringEncoder>(seal_context_); 62 | ct_reencoder_ = *(CiphertextReencoder::Create(seal_context_)); 63 | } 64 | 65 | string GenerateSampleString() { 66 | return generate_string(encoder_->max_bytes_per_plaintext()); 67 | } 68 | 69 | shared_ptr<SEALContext> seal_context_; 70 | unique_ptr<CiphertextReencoder> ct_reencoder_; 71 | unique_ptr<StringEncoder> encoder_; 72 | unique_ptr<KeyGenerator> keygen_; 73 | unique_ptr<Encryptor> encryptor_; 74 | unique_ptr<Decryptor> decryptor_; 75 | }; 76 | 77 | TEST_F(CiphertextReencoderTest, 
TextExpansionRatio) { 78 | EXPECT_EQ(ct_reencoder_->ExpansionRatio(), 4); 79 | } 80 | 81 | TEST_F(CiphertextReencoderTest, TestEncodeDecode) { 82 | string value = GenerateSampleString(); 83 | Plaintext pt; 84 | encoder_->encode(value, pt); 85 | Ciphertext ct; 86 | encryptor_->encrypt(pt, ct); 87 | auto pt_decomp = ct_reencoder_->Encode(ct); 88 | ASSERT_EQ(pt_decomp.size(), ct.size() * ct_reencoder_->ExpansionRatio()); 89 | auto result_ct = ct_reencoder_->Decode(pt_decomp); 90 | Plaintext result_pt; 91 | decryptor_->decrypt(result_ct, result_pt); 92 | EXPECT_EQ(result_pt, pt); 93 | ASSIGN_OR_FAIL(auto result, encoder_->decode(result_pt)); 94 | EXPECT_EQ(result, value); 95 | } 96 | 97 | TEST_F(CiphertextReencoderTest, TestRecursion) { 98 | string value = GenerateSampleString(); 99 | Plaintext pt; 100 | encoder_->encode(value, pt); 101 | 102 | // level 1 103 | Ciphertext ct; 104 | encryptor_->encrypt(pt, ct); 105 | auto pt_lvl_1 = ct_reencoder_->Encode(ct); 106 | size_t exp_ratio = ct.size() * ct_reencoder_->ExpansionRatio(); 107 | ASSERT_EQ(pt_lvl_1.size(), exp_ratio); 108 | 109 | // level 2 110 | vector<Plaintext> pt_lvl_2; 111 | pt_lvl_2.reserve(pt_lvl_1.size() * exp_ratio); 112 | for (size_t i = 0; i < pt_lvl_1.size(); ++i) { 113 | Ciphertext ct; 114 | encryptor_->encrypt(pt_lvl_1[i], ct); 115 | auto pts = ct_reencoder_->Encode(ct); 116 | pt_lvl_2.insert(pt_lvl_2.end(), pts.begin(), pts.end()); 117 | } 118 | ASSERT_EQ(pt_lvl_2.size(), exp_ratio * exp_ratio); 119 | 120 | // decode level 2 121 | vector<Plaintext> result_pt_lvl_1(exp_ratio); 122 | for (size_t i = 0; i < exp_ratio; ++i) { 123 | auto result_ct = 124 | ct_reencoder_->Decode(pt_lvl_2.begin() + (i * exp_ratio), ct.size()); 125 | decryptor_->decrypt(result_ct, result_pt_lvl_1[i]); 126 | } 127 | 128 | // decode level 1 129 | auto result_ct = ct_reencoder_->Decode(result_pt_lvl_1); 130 | Plaintext result_pt; 131 | decryptor_->decrypt(result_ct, result_pt); 132 | EXPECT_EQ(result_pt, pt); 133 | 
ASSIGN_OR_FAIL(auto result, encoder_->decode(result_pt)); 134 | EXPECT_EQ(result, value); 135 | } 136 | 137 | TEST_F(CiphertextReencoderTest, TestEncryptDecrypt) { 138 | string value = GenerateSampleString(); 139 | Plaintext pt; 140 | encoder_->encode(value, pt); 141 | Ciphertext ct; 142 | encryptor_->encrypt(pt, ct); 143 | auto pt_decomp = ct_reencoder_->Encode(ct); 144 | ASSERT_EQ(pt_decomp.size(), ct.size() * ct_reencoder_->ExpansionRatio()); 145 | 146 | vector<Ciphertext> cts(pt_decomp.size()); 147 | for (size_t i = 0; i < cts.size(); ++i) { 148 | encryptor_->encrypt(pt_decomp[i], cts[i]); 149 | } 150 | 151 | vector<Plaintext> pts(pt_decomp.size()); 152 | for (size_t i = 0; i < cts.size(); ++i) { 153 | decryptor_->decrypt(cts[i], pts[i]); 154 | } 155 | 156 | auto result_ct = ct_reencoder_->Decode(pts); 157 | Plaintext result_pt; 158 | decryptor_->decrypt(result_ct, result_pt); 159 | EXPECT_EQ(result_pt, pt); 160 | ASSIGN_OR_FAIL(auto result, encoder_->decode(result_pt)); 161 | EXPECT_EQ(result, value); 162 | } 163 | 164 | TEST_F(CiphertextReencoderTest, TestMultOneEnc) { 165 | string value = GenerateSampleString(); 166 | Plaintext pt; 167 | encoder_->encode(value, pt); 168 | Ciphertext ct; 169 | encryptor_->encrypt(pt, ct); 170 | auto pt_decomp = ct_reencoder_->Encode(ct); 171 | ASSERT_EQ(pt_decomp.size(), ct.size() * ct_reencoder_->ExpansionRatio()); 172 | 173 | Plaintext one_pt(1); 174 | one_pt[0] = 1; 175 | Ciphertext one_ct; 176 | encryptor_->encrypt(one_pt, one_ct); 177 | 178 | Evaluator eval(seal_context_); 179 | vector<Ciphertext> cts(pt_decomp.size()); 180 | for (size_t i = 0; i < cts.size(); ++i) { 181 | eval.multiply_plain(one_ct, pt_decomp[i], cts[i]); 182 | } 183 | 184 | vector<Plaintext> pts(pt_decomp.size()); 185 | for (size_t i = 0; i < cts.size(); ++i) { 186 | decryptor_->decrypt(cts[i], pts[i]); 187 | } 188 | 189 | auto result_ct = ct_reencoder_->Decode(pts); 190 | Plaintext result_pt; 191 | decryptor_->decrypt(result_ct, result_pt); 192 | 
ASSIGN_OR_FAIL(auto result, encoder_->decode(result_pt)); 193 | EXPECT_EQ(result, value); 194 | } 195 | 196 | } // namespace 197 | } // namespace pir 198 | -------------------------------------------------------------------------------- /pir/cpp/database.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | #include "pir/cpp/database.h" 17 | 18 | #include <iostream> 19 | #include <memory> 20 | 21 | #include "pir/cpp/ct_reencoder.h" 22 | #include "pir/cpp/status_asserts.h" 23 | #include "pir/cpp/string_encoder.h" 24 | #include "pir/cpp/utils.h" 25 | #include "seal/seal.h" 26 | 27 | namespace pir { 28 | 29 | using absl::InternalError; 30 | using absl::InvalidArgumentError; 31 | using absl::StatusOr; 32 | using google::protobuf::RepeatedField; 33 | using seal::Ciphertext; 34 | using seal::Evaluator; 35 | using seal::Plaintext; 36 | using std::unique_ptr; 37 | using std::vector; 38 | 39 | StatusOr<shared_ptr<PIRDatabase>> PIRDatabase::Create( 40 | shared_ptr<PIRParameters> params) { 41 | ASSIGN_OR_RETURN(auto context, PIRContext::Create(params)); 42 | return std::make_shared<PIRDatabase>(std::move(context)); 43 | } 44 | StatusOr<shared_ptr<PIRDatabase>> PIRDatabase::Create( 45 | const vector<std::int64_t>& rawdb, shared_ptr<PIRParameters> params) { 46 | ASSIGN_OR_RETURN(auto context, PIRContext::Create(params)); 47 | auto pir_db = std::make_shared<PIRDatabase>(std::move(context)); 48 | RETURN_IF_ERROR(pir_db->populate(rawdb)); 49 | return std::move(pir_db); 50 | } 51 | 52 | StatusOr<shared_ptr<PIRDatabase>> PIRDatabase::Create( 53 | const vector<string>& rawdb, shared_ptr<PIRParameters> params) { 54 | ASSIGN_OR_RETURN(auto context, PIRContext::Create(params)); 55 | auto pir_db = std::make_shared<PIRDatabase>(std::move(context)); 56 | RETURN_IF_ERROR(pir_db->populate(rawdb)); 57 | return std::move(pir_db); 58 | } 59 | 60 | Status PIRDatabase::populate(const vector<std::int64_t>& rawdb) { 61 | if (rawdb.size() != context_->Params()->num_items()) { 62 | return InvalidArgumentError( 63 | "Database size " + std::to_string(rawdb.size()) + 64 | " does not match params value " + 65 | std::to_string(context_->Params()->num_items())); 66 | } 67 | 68 | auto evaluator = std::make_unique<seal::Evaluator>(context_->SEALContext()); 69 | db_.resize(rawdb.size()); 70 | for 
(size_t idx = 0; idx < rawdb.size(); ++idx) { 71 | try { 72 | context_->Encoder()->encode(rawdb[idx], db_[idx]); 73 | if (!context_->Params()->use_ciphertext_multiplication()) { 74 | evaluator->transform_to_ntt_inplace( 75 | db_[idx], context_->SEALContext()->first_parms_id()); 76 | } 77 | } catch (std::exception& e) { 78 | return InvalidArgumentError(e.what()); 79 | } 80 | } 81 | return absl::OkStatus(); 82 | } 83 | 84 | Status PIRDatabase::populate(const vector<string>& rawdb) { 85 | if (rawdb.size() != context_->Params()->num_items()) { 86 | return InvalidArgumentError( 87 | "Database size " + std::to_string(rawdb.size()) + 88 | " does not match params value " + 89 | std::to_string(context_->Params()->num_items())); 90 | } 91 | 92 | const auto items_per_pt = context_->Params()->items_per_plaintext(); 93 | db_.resize(context_->Params()->num_pt()); 94 | auto encoder = std::make_unique<StringEncoder>(context_->SEALContext()); 95 | auto evaluator = std::make_unique<seal::Evaluator>(context_->SEALContext()); 96 | if (context_->Params()->bits_per_coeff() > 0) { 97 | encoder->set_bits_per_coeff(context_->Params()->bits_per_coeff()); 98 | } 99 | auto raw_it = rawdb.begin(); 100 | for (size_t i = 0; i < db_.size(); ++i) { 101 | auto end_it = std::min(raw_it + items_per_pt, rawdb.end()); 102 | RETURN_IF_ERROR(encoder->encode(raw_it, end_it, db_[i])); 103 | if (!context_->Params()->use_ciphertext_multiplication()) { 104 | evaluator->transform_to_ntt_inplace( 105 | db_[i], context_->SEALContext()->first_parms_id()); 106 | } 107 | raw_it += items_per_pt; 108 | } 109 | return absl::OkStatus(); 110 | } 111 | 112 | /** 113 | * Helper class to make the recursive multiplication operation on the 114 | * multi-dimensional representation of the database easier. Encapsulates all of 115 | * the variables needed to do the multiplication, and keeps track of the 116 | * database iterator to separate it from the database itself. 
117 | */ 118 | class DatabaseMultiplier { 119 | public: 120 | /** 121 | * Create a multiplier for the given scenario. 122 | * @param[in] database Database against which to multiply. 123 | * @param[in] selection_vector multi-dimensional selection vector 124 | * @param[in] evaluator Evaluator to use for homomorphic operations. 125 | * @param[in] relin_keys If not nullptr, relinearization will be done after 126 | * every homomorphic multiplication. 127 | * @param[in] decryptor If not nullptr, outputs to cout the noise budget 128 | * remaining after every homomorphic operation. 129 | */ 130 | DatabaseMultiplier(const vector<Plaintext>& database, 131 | vector<Ciphertext>& selection_vector, 132 | shared_ptr<Evaluator> evaluator, 133 | unique_ptr<CiphertextReencoder> ct_reencoder, 134 | std::shared_ptr<seal::SEALContext> seal_context, 135 | const seal::RelinKeys* const relin_keys, 136 | seal::Decryptor* const decryptor) 137 | : database_(database), 138 | selection_vector_(selection_vector), 139 | evaluator_(evaluator), 140 | ct_reencoder_(std::move(ct_reencoder)), 141 | seal_context_(seal_context), 142 | exp_ratio_(ct_reencoder_ == nullptr ? 1 143 | : ct_reencoder_->ExpansionRatio()), 144 | relin_keys_(relin_keys), 145 | decryptor_(decryptor) {} 146 | 147 | /** 148 | * Do the multiplication using the given dimension sizes. 149 | */ 150 | vector<Ciphertext> multiply(const RepeatedField<uint32_t>& dimensions) { 151 | database_it_ = database_.begin(); 152 | return multiply(dimensions, selection_vector_.begin(), 0); 153 | } 154 | 155 | private: 156 | /** 157 | * Recursive function to do the dot product of each dimension with the db. 158 | * Calls itself to move down dimensions until you get to the bottom dimension. 159 | * Bottom dimension just does a dot product with the DB, and returns the 160 | * result. Upper levels then take those results, and dot product again with 161 | * the selection vector, until you get back to the top. 
NB: Database iterator 162 | * is kept at the class level so that we move through the database one element 163 | * at a time. 164 | * 165 | * @param[in] dimensions List of remaining demainsion sizes. 166 | * @param[in] selection_vector_it Iterator into the start of the selection 167 | * vector for the current depth. 168 | * @param[in] depth Current depth. 169 | */ 170 | vector<Ciphertext> multiply(const RepeatedField<uint32_t>& dimensions, 171 | vector<Ciphertext>::iterator selection_vector_it, 172 | size_t depth) { 173 | const size_t this_dimension = dimensions[0]; 174 | auto remaining_dimensions = 175 | RepeatedField<uint32_t>(dimensions.begin() + 1, dimensions.end()); 176 | 177 | string depth_string(depth, ' '); 178 | 179 | vector<Ciphertext> result; 180 | bool first_pass = true; 181 | for (size_t i = 0; i < this_dimension; ++i) { 182 | // make sure we don't go past end of DB 183 | if (database_it_ == database_.end()) break; 184 | vector<Ciphertext> temp_ct; 185 | if (remaining_dimensions.empty()) { 186 | // base case: have to multiply against DB 187 | temp_ct.resize(1); 188 | if (ct_reencoder_ != nullptr && 189 | !(selection_vector_it + i)->is_ntt_form()) { 190 | evaluator_->transform_to_ntt_inplace(*(selection_vector_it + i)); 191 | } 192 | evaluator_->multiply_plain(*(selection_vector_it + i), 193 | *(database_it_++), temp_ct[0]); 194 | print_noise(depth, "base", temp_ct[0], i); 195 | 196 | } else { 197 | auto lower_result = 198 | multiply(remaining_dimensions, selection_vector_it + this_dimension, 199 | depth + 1); 200 | print_noise(depth, "recurse", lower_result[0], i); 201 | 202 | if (ct_reencoder_ == nullptr) { 203 | temp_ct.resize(1); 204 | evaluator_->multiply(lower_result[0], *(selection_vector_it + i), 205 | temp_ct[0]); 206 | print_noise(depth, "mult", temp_ct[0], i); 207 | 208 | if (relin_keys_ != nullptr) { 209 | evaluator_->relinearize_inplace(temp_ct[0], *relin_keys_); 210 | print_noise(depth, "relin", temp_ct[0], i); 211 | } 212 | 213 | } else { 
214 | // TODO: check that all CT are size 2 215 | temp_ct.resize(lower_result.size() * exp_ratio_ * 2); 216 | auto temp_ct_it = temp_ct.begin(); 217 | for (const auto& ct : lower_result) { 218 | auto pt_decomp = ct_reencoder_->Encode(ct); 219 | size_t k = 0; 220 | for (auto pt : pt_decomp) { 221 | if (!(selection_vector_it + i)->is_ntt_form()) { 222 | evaluator_->transform_to_ntt_inplace( 223 | *(selection_vector_it + i)); 224 | } 225 | if (!pt.is_ntt_form()) { 226 | evaluator_->transform_to_ntt_inplace( 227 | pt, seal_context_->first_parms_id()); 228 | } 229 | evaluator_->multiply_plain(*(selection_vector_it + i), pt, 230 | *temp_ct_it); 231 | print_noise(depth, "mult", *temp_ct_it, k++); 232 | ++temp_ct_it; 233 | } 234 | } 235 | } 236 | } 237 | 238 | if (first_pass) { 239 | result = temp_ct; 240 | first_pass = false; 241 | print_noise(depth, "first_pass", result[0], i); 242 | } else { 243 | for (size_t j = 0; j < result.size(); ++j) { 244 | evaluator_->add_inplace(result[j], temp_ct[j]); 245 | print_noise(depth, "result", result[j], i); 246 | } 247 | } 248 | } 249 | 250 | for (auto& ct : result) { 251 | if (ct.is_ntt_form()) { 252 | evaluator_->transform_from_ntt_inplace(ct); 253 | } 254 | } 255 | 256 | print_noise(depth, "final", result[0]); 257 | return result; 258 | } 259 | 260 | void print_noise(size_t depth, const string& desc, const Ciphertext& ct, 261 | std::optional<size_t> i_opt = {}) { 262 | if (decryptor_ != nullptr) { 263 | std::cout << string(depth, ' '); 264 | if (i_opt) { 265 | std::cout << "i = " << (*i_opt) << " "; 266 | } 267 | std::cout << desc << " noise budget " 268 | << decryptor_->invariant_noise_budget(ct) << std::endl; 269 | } 270 | } 271 | 272 | const vector<Plaintext>& database_; 273 | vector<Ciphertext>& selection_vector_; 274 | shared_ptr<Evaluator> evaluator_; 275 | unique_ptr<CiphertextReencoder> ct_reencoder_; 276 | std::shared_ptr<seal::SEALContext> seal_context_; 277 | const size_t exp_ratio_; 278 | 279 | // If not null, 
relinearization keys are applied after each HE op 280 | const seal::RelinKeys* const relin_keys_; 281 | 282 | // If not null, used to get invariant noise budget after each HE op 283 | seal::Decryptor* const decryptor_; 284 | 285 | // Current location as we move through the database. 286 | // Needs to be kept here, as lower levels of recursion move forward. 287 | vector<Plaintext>::const_iterator database_it_; 288 | }; 289 | 290 | StatusOr<vector<Ciphertext>> PIRDatabase::multiply( 291 | vector<Ciphertext>& selection_vector, 292 | const seal::RelinKeys* const relin_keys, 293 | seal::Decryptor* const decryptor) const { 294 | auto& dimensions = context_->Params()->dimensions(); 295 | const size_t dim_sum = context_->DimensionsSum(); 296 | 297 | if (selection_vector.size() != dim_sum) { 298 | return InvalidArgumentError( 299 | "Selection vector size does not match dimensions"); 300 | } 301 | 302 | unique_ptr<CiphertextReencoder> ct_reencoder = nullptr; 303 | if (!context_->Params()->use_ciphertext_multiplication()) { 304 | ASSIGN_OR_RETURN(ct_reencoder, 305 | CiphertextReencoder::Create(context_->SEALContext())); 306 | } 307 | 308 | try { 309 | DatabaseMultiplier dbm(db_, selection_vector, context_->Evaluator(), 310 | std::move(ct_reencoder), context_->SEALContext(), 311 | relin_keys, decryptor); 312 | return dbm.multiply(dimensions); 313 | } catch (std::exception& e) { 314 | return InternalError(e.what()); 315 | } 316 | } 317 | 318 | vector<uint32_t> PIRDatabase::calculate_indices(uint32_t index) { 319 | uint32_t pt_index = index / context_->Params()->items_per_plaintext(); 320 | vector<uint32_t> results(context_->Params()->dimensions_size(), 0); 321 | for (int i = results.size() - 1; i >= 0; --i) { 322 | results[i] = pt_index % context_->Params()->dimensions(i); 323 | pt_index = pt_index / context_->Params()->dimensions(i); 324 | } 325 | return results; 326 | } 327 | 328 | size_t PIRDatabase::calculate_item_offset(uint32_t index) { 329 | uint32_t pt_index = index / 
context_->Params()->items_per_plaintext(); 330 | return (index - (pt_index * context_->Params()->items_per_plaintext())) * 331 | context_->Params()->bytes_per_item(); 332 | } 333 | 334 | vector<uint32_t> PIRDatabase::calculate_dimensions(uint32_t db_size, 335 | uint32_t num_dimensions) { 336 | vector<uint32_t> results; 337 | for (int i = num_dimensions; i > 0; --i) { 338 | results.push_back(std::ceil(std::pow(db_size, 1.0 / i))); 339 | db_size = std::ceil(static_cast<double>(db_size) / results.back()); 340 | } 341 | return results; 342 | } 343 | 344 | } // namespace pir 345 | -------------------------------------------------------------------------------- /pir/cpp/database.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_DATABASE_H_ 18 | #define PIR_DATABASE_H_ 19 | 20 | #include <string> 21 | #include <vector> 22 | 23 | #include "absl/status/statusor.h" 24 | #include "pir/cpp/context.h" 25 | #include "seal/seal.h" 26 | 27 | namespace pir { 28 | 29 | using absl::Status; 30 | using absl::StatusOr; 31 | using std::shared_ptr; 32 | using std::vector; 33 | 34 | /** 35 | * Representation of a PIR database, helpful for both server and client. 
Server 36 | * uses this class to process responses by multiplying a selection vector 37 | * against database values in multi-dimensional format, while the client uses it 38 | * without loading the backing data to calculate indices and offsets. 39 | */ 40 | class PIRDatabase { 41 | public: 42 | /** 43 | * Creates and returns an empty PIR database with the params used to generate 44 | * a context. 45 | * @param[in] PIR parameters 46 | **/ 47 | static StatusOr<shared_ptr<PIRDatabase>> Create( 48 | shared_ptr<PIRParameters> params); 49 | 50 | /** 51 | * Shortcut to create and return a new PIR database instance using a vector of 52 | *integers encoded one per database plaintext using IntegerEncoder. Only 53 | *really used for testing, not intended for actual PIR use. 54 | * @param[in] db Vector of integers to encode into database of plaintexts 55 | * @param[in] PIR parameters 56 | **/ 57 | static StatusOr<shared_ptr<PIRDatabase>> Create( 58 | const vector<std::int64_t>& /*database*/, 59 | shared_ptr<PIRParameters> params); 60 | 61 | /** 62 | * Shortcut to create and return a new PIR database instance using the values 63 | *given. Values are packed into the database as per the parameters given. 64 | * @param[in] db Database to load 65 | * @param[in] PIR parameters 66 | **/ 67 | static StatusOr<shared_ptr<PIRDatabase>> Create( 68 | const vector<string>& /*database*/, shared_ptr<PIRParameters> params); 69 | 70 | /** 71 | * Populate the database plaintexts from a list of integers. Only really used 72 | * for testing. 73 | */ 74 | Status populate(const vector<std::int64_t>& /*database*/); 75 | 76 | /** 77 | * Populate the database plaintexts from a list of strings. Items must match 78 | * the settings in the context or InvalidArgumentError will be returned. 79 | */ 80 | Status populate(const vector<string>& /*database*/); 81 | 82 | /** 83 | * Multiplies the database represented as a multi-dimensional hypercube with 84 | * a selection vector. 
Selection vector is split into sub vectors based on 85 | * dimensions fetched from PIRParameters in the current context. 86 | * @param[in] selection_vector Selection vector to multiply against 87 | * @returns Ciphertext resulting from multiplication, or error 88 | */ 89 | StatusOr<std::vector<seal::Ciphertext>> multiply( 90 | std::vector<seal::Ciphertext>& selection_vector, 91 | const seal::RelinKeys* const relin_keys = nullptr, 92 | seal::Decryptor* const decryptor = nullptr) const; 93 | 94 | /** 95 | * Database size. 96 | **/ 97 | std::size_t size() const { return db_.size(); } 98 | 99 | /** 100 | * Helper function to calculate indices within the multi-dimensional 101 | * representation of the database for a given index in the flat 102 | * representation. 103 | * @param[in] dims The dimensions to use in multi-dimensional rep. 104 | * @param[in] index Index in the flat representation. 105 | * @returns Vector of indices. 106 | */ 107 | vector<uint32_t> calculate_indices(uint32_t index); 108 | 109 | /** 110 | * Calculate the offset of an item within a plaintext. 111 | * @param[in] index Item index in the database 112 | * @returns Offset in bytes from start of the plaintext that contains item. 113 | */ 114 | size_t calculate_item_offset(uint32_t index); 115 | 116 | /** 117 | * Helper function to calculate the multi-dimensional representation of the 118 | * database 119 | * @param[in] db_size, The database size. 120 | * @param[in] num_dimensions The mumber of dimensions. 121 | * @returns Vector of dimensions. 
122 | */ 123 | static vector<uint32_t> calculate_dimensions(uint32_t db_size, 124 | uint32_t num_dimensions); 125 | 126 | PIRDatabase(std::unique_ptr<PIRContext> context) 127 | : context_(std::move(context)) {} 128 | 129 | private: 130 | vector<seal::Plaintext> db_; 131 | std::unique_ptr<PIRContext> context_; 132 | }; 133 | 134 | } // namespace pir 135 | 136 | #endif // PIR_DATABASE_H_ 137 | -------------------------------------------------------------------------------- /pir/cpp/database_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include <algorithm> 18 | #include <iostream> 19 | #include <vector> 20 | 21 | #include "gmock/gmock.h" 22 | #include "gtest/gtest.h" 23 | #include "pir/cpp/client.h" 24 | #include "pir/cpp/ct_reencoder.h" 25 | #include "pir/cpp/server.h" 26 | #include "pir/cpp/status_asserts.h" 27 | #include "pir/cpp/string_encoder.h" 28 | #include "pir/cpp/test_base.h" 29 | #include "pir/cpp/utils.h" 30 | 31 | namespace pir { 32 | namespace { 33 | 34 | using std::cout; 35 | using std::endl; 36 | using std::get; 37 | using std::make_tuple; 38 | using std::make_unique; 39 | using std::shared_ptr; 40 | using std::string; 41 | using std::tuple; 42 | using std::unique_ptr; 43 | using std::vector; 44 | 45 | using seal::Ciphertext; 46 | using seal::GaloisKeys; 47 | using seal::Plaintext; 48 | 49 | using namespace seal; 50 | using namespace ::testing; 51 | using std::int64_t; 52 | using std::vector; 53 | 54 | constexpr uint32_t POLY_MODULUS_DEGREE = 4096; 55 | 56 | class PIRDatabaseTestBase : public PIRTestingBase { 57 | protected: 58 | void SetUpDBImpl(size_t dbsize, size_t dimensions, 59 | uint32_t poly_modulus_degree, uint32_t plain_mod_bit_size, 60 | bool use_ciphertext_multiplication) { 61 | SetUpParams(dbsize, 0, dimensions, poly_modulus_degree, plain_mod_bit_size, 62 | 0, use_ciphertext_multiplication); 63 | GenerateIntDB(); 64 | SetUpSealTools(); 65 | encoder_ = make_unique<seal::IntegerEncoder>(seal_context_); 66 | } 67 | 68 | void SetUpStringDBImpl(size_t dbsize, size_t dimensions = 1, 69 | uint32_t poly_modulus_degree = POLY_MODULUS_DEGREE, 70 | uint32_t plain_mod_bit_size = 20, size_t elem_size = 0, 71 | bool use_ciphertext_multiplication = false) { 72 | SetUpParams(dbsize, elem_size, dimensions, poly_modulus_degree, 73 | plain_mod_bit_size, 0, use_ciphertext_multiplication); 74 | GenerateDB(); 75 | SetUpSealTools(); 76 | } 77 | 78 | void decode_result(vector<Ciphertext> result_cts, Plaintext& result_pt, 79 | size_t input_ct_size, size_t d, 80 | bool 
use_ciphertext_multiplication) { 81 | if (use_ciphertext_multiplication) { 82 | ASSERT_EQ(result_cts.size(), 1); 83 | decryptor_->decrypt(result_cts[0], result_pt); 84 | } else { 85 | decode_from_decomp(result_cts, result_pt, input_ct_size, d); 86 | } 87 | } 88 | void decode_from_decomp(vector<Ciphertext> result_cts, Plaintext& result_pt, 89 | size_t input_ct_size, size_t d) { 90 | ASSERT_GT(d, 0); 91 | if (d <= 1) { 92 | ASSERT_EQ(result_cts.size(), 1); 93 | decryptor_->decrypt(result_cts[0], result_pt); 94 | return; 95 | } 96 | 97 | ASSIGN_OR_FAIL(auto ct_reencoder, 98 | CiphertextReencoder::Create(seal_context_)); 99 | ASSERT_EQ(result_cts.size(), 100 | ipow(ct_reencoder->ExpansionRatio() * input_ct_size, d - 1)); 101 | 102 | auto result_pts = 103 | decode_recursion(result_cts, input_ct_size, d, ct_reencoder.get()); 104 | ASSERT_EQ(result_pts.size(), 1); 105 | result_pt = result_pts[0]; 106 | } 107 | 108 | vector<Plaintext> decode_recursion(vector<Ciphertext> cts, 109 | size_t input_ct_size, size_t d, 110 | CiphertextReencoder* ct_reencoder) { 111 | vector<Plaintext> pts(cts.size()); 112 | for (size_t i = 0; i < pts.size(); ++i) { 113 | decryptor_->decrypt(cts[i], pts[i]); 114 | } 115 | 116 | if (d <= 1) { 117 | return pts; 118 | } 119 | 120 | size_t expansion_ratio = ct_reencoder->ExpansionRatio() * input_ct_size; 121 | vector<Ciphertext> result_cts(cts.size() / expansion_ratio); 122 | for (size_t i = 0; i < result_cts.size(); ++i) { 123 | result_cts[i] = ct_reencoder->Decode(pts.begin() + (i * expansion_ratio), 124 | input_ct_size); 125 | } 126 | return decode_recursion(result_cts, input_ct_size, d - 1, ct_reencoder); 127 | } 128 | 129 | unique_ptr<seal::IntegerEncoder> encoder_; 130 | }; 131 | 132 | class PIRDatabaseTest : public PIRDatabaseTestBase, 133 | public ::testing::TestWithParam<bool> { 134 | protected: 135 | void SetUp() { SetUpDB(100); } 136 | 137 | void SetUpDB(size_t dbsize, size_t dimensions = 1, 138 | uint32_t poly_modulus_degree = 
POLY_MODULUS_DEGREE, 139 | uint32_t plain_mod_bit_size = 20) { 140 | bool use_ciphertext_multiplication = GetParam(); 141 | SetUpDBImpl(dbsize, dimensions, poly_modulus_degree, plain_mod_bit_size, 142 | use_ciphertext_multiplication); 143 | } 144 | 145 | void SetUpStringDB(size_t dbsize, size_t dimensions = 1, 146 | uint32_t poly_modulus_degree = POLY_MODULUS_DEGREE, 147 | uint32_t plain_mod_bit_size = 20, size_t elem_size = 0) { 148 | bool use_ciphertext_multiplication = GetParam(); 149 | SetUpStringDBImpl(dbsize, dimensions, poly_modulus_degree, 150 | plain_mod_bit_size, elem_size, 151 | use_ciphertext_multiplication); 152 | } 153 | }; 154 | 155 | TEST_P(PIRDatabaseTest, TestMultiply) { 156 | vector<int32_t> v(db_size_); 157 | std::generate(v.begin(), v.end(), 158 | [n = -db_size_ / 2]() mutable { return n; }); 159 | ASSERT_THAT(pir_db_->size(), Eq(v.size())); 160 | 161 | vector<Ciphertext> cts(v.size()); 162 | int64_t expected = 0; 163 | for (size_t i = 0; i < cts.size(); ++i) { 164 | Plaintext pt; 165 | encoder_->encode(v[i], pt); 166 | encryptor_->encrypt(pt, cts[i]); 167 | expected += v[i] * int_db_[i]; 168 | } 169 | 170 | ASSIGN_OR_FAIL(auto result_cts, pir_db_->multiply(cts, nullptr)); 171 | ASSERT_EQ(result_cts.size(), 1); 172 | 173 | Plaintext pt; 174 | decryptor_->decrypt(result_cts[0], pt); 175 | auto result = encoder_->decode_int64(pt); 176 | 177 | EXPECT_THAT(result, Eq(expected)); 178 | } 179 | 180 | TEST_P(PIRDatabaseTest, TestMultiplySelectionVectorTooSmall) { 181 | SetUpDB(100, 2); 182 | const uint32_t desired_index = 42; 183 | const auto dims = PIRDatabase::calculate_dimensions(db_size_, 2); 184 | const auto indices = pir_db_->calculate_indices(desired_index); 185 | 186 | vector<Ciphertext> cts; 187 | for (size_t d = 0; d < dims.size(); ++d) { 188 | for (size_t i = 0; i < dims[d]; ++i) { 189 | Ciphertext ct; 190 | encryptor_->encrypt_zero(ct); 191 | cts.push_back(ct); 192 | } 193 | } 194 | 195 | cts.resize(cts.size() - 1); 196 | auto results_or = 
pir_db_->multiply(cts); 197 | ASSERT_THAT(results_or.status().code(), 198 | Eq(absl::StatusCode::kInvalidArgument)); 199 | } 200 | 201 | TEST_P(PIRDatabaseTest, TestMultiplySelectionVectorTooBig) { 202 | SetUpDB(100, 2); 203 | const uint32_t desired_index = 42; 204 | const auto dims = PIRDatabase::calculate_dimensions(db_size_, 2); 205 | const auto indices = pir_db_->calculate_indices(desired_index); 206 | 207 | vector<Ciphertext> cts; 208 | for (size_t d = 0; d < dims.size(); ++d) { 209 | for (size_t i = 0; i < dims[d] + 1; ++i) { 210 | Ciphertext ct; 211 | encryptor_->encrypt_zero(ct); 212 | cts.push_back(ct); 213 | } 214 | } 215 | 216 | auto results_or = pir_db_->multiply(cts); 217 | ASSERT_THAT(results_or.status().code(), 218 | Eq(absl::StatusCode::kInvalidArgument)); 219 | } 220 | 221 | TEST_P(PIRDatabaseTest, TestMultiplyStringValues) { 222 | constexpr size_t db_size = 10; 223 | constexpr size_t desired_index = 7; 224 | 225 | SetUpStringDB(db_size, 1, POLY_MODULUS_DEGREE, 22); 226 | 227 | vector<Plaintext> selection_vector_pt(db_size); 228 | vector<Ciphertext> selection_vector_ct(db_size); 229 | for (size_t i = 0; i < db_size; ++i) { 230 | selection_vector_pt[i].resize(POLY_MODULUS_DEGREE); 231 | selection_vector_pt[i].set_zero(); 232 | if (i == desired_index) { 233 | selection_vector_pt[i][0] = 1; 234 | } 235 | encryptor_->encrypt(selection_vector_pt[i], selection_vector_ct[i]); 236 | } 237 | 238 | ASSIGN_OR_FAIL(auto result_cts, pir_db_->multiply(selection_vector_ct)); 239 | ASSERT_EQ(result_cts.size(), 1); 240 | 241 | Plaintext result_pt; 242 | decryptor_->decrypt(result_cts[0], result_pt); 243 | auto string_encoder = make_unique<StringEncoder>(seal_context_); 244 | ASSIGN_OR_FAIL(auto result, string_encoder->decode(result_pt)); 245 | 246 | EXPECT_THAT(result, Eq(string_db_[desired_index])); 247 | } 248 | 249 | vector<Ciphertext> create_selection_vector(const vector<uint32_t>& dims, 250 | const vector<uint32_t>& indices, 251 | Encryptor& encryptor) { 252 | 
vector<Ciphertext> cts; 253 | for (size_t d = 0; d < dims.size(); ++d) { 254 | for (size_t i = 0; i < dims[d]; ++i) { 255 | Ciphertext ct; 256 | if (i == indices[d]) { 257 | Plaintext pt(POLY_MODULUS_DEGREE); 258 | pt.set_zero(); 259 | pt[0] = 1; 260 | encryptor.encrypt(pt, ct); 261 | } else { 262 | encryptor.encrypt_zero(ct); 263 | } 264 | cts.push_back(ct); 265 | } 266 | } 267 | return cts; 268 | } 269 | 270 | TEST_P(PIRDatabaseTest, TestMultiplyStringValuesD2) { 271 | constexpr size_t d = 2; 272 | constexpr size_t db_size = 9; 273 | constexpr size_t desired_index = 5; 274 | 275 | SetUpStringDB(db_size, d, POLY_MODULUS_DEGREE, 16); 276 | 277 | const auto dims = PIRDatabase::calculate_dimensions(db_size, d); 278 | const auto indices = pir_db_->calculate_indices(desired_index); 279 | auto sv = create_selection_vector(dims, indices, *encryptor_); 280 | 281 | auto relin_keys = keygen_->relin_keys_local(); 282 | ASSIGN_OR_FAIL(auto result_cts, pir_db_->multiply(sv, &relin_keys)); 283 | Plaintext result_pt; 284 | decode_result(result_cts, result_pt, sv[0].size(), d, GetParam()); 285 | 286 | auto string_encoder = make_unique<StringEncoder>(seal_context_); 287 | ASSIGN_OR_FAIL(auto result, string_encoder->decode(result_pt)); 288 | 289 | EXPECT_THAT(result.substr(0, string_db_[desired_index].size()), 290 | Eq(string_db_[desired_index])); 291 | } 292 | 293 | TEST_P(PIRDatabaseTest, TestMultiplyMultipleValuesPerPT) { 294 | constexpr size_t d = 2; 295 | constexpr size_t db_size = 1000; 296 | constexpr size_t elem_size = 128; 297 | constexpr size_t desired_index = 754; 298 | 299 | SetUpStringDB(db_size, d, POLY_MODULUS_DEGREE, 16, elem_size); 300 | ASSERT_EQ(pir_db_->size(), pir_params_->num_pt()); 301 | ASSERT_EQ(pir_params_->bytes_per_item(), elem_size); 302 | 303 | const size_t items_per_pt = pir_params_->items_per_plaintext(); 304 | const size_t num_db_pt = ceil(static_cast<double>(db_size) / items_per_pt); 305 | const size_t desired_pt_index = desired_index / 
items_per_pt; 306 | const size_t desired_offset = 307 | (desired_index - desired_pt_index * items_per_pt) * elem_size; 308 | 309 | const auto dims = PIRDatabase::calculate_dimensions(num_db_pt, d); 310 | const auto indices = pir_db_->calculate_indices(desired_index); 311 | auto sv = create_selection_vector(dims, indices, *encryptor_); 312 | 313 | auto relin_keys = keygen_->relin_keys_local(); 314 | ASSIGN_OR_FAIL(auto result_cts, pir_db_->multiply(sv, &relin_keys)); 315 | 316 | Plaintext result_pt; 317 | decode_result(result_cts, result_pt, sv[0].size(), d, GetParam()); 318 | auto string_encoder = make_unique<StringEncoder>(seal_context_); 319 | ASSIGN_OR_FAIL(auto result, 320 | string_encoder->decode(result_pt, elem_size, desired_offset)); 321 | 322 | EXPECT_THAT(result, Eq(string_db_[desired_index])); 323 | } 324 | 325 | TEST_P(PIRDatabaseTest, TestCreateValueDoesntMatch) { 326 | SetUpParams(10, 9728, 1, 4096, 20, 19); 327 | 328 | auto prng = 329 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 330 | vector<string> db(db_size_); 331 | for (size_t i = 0; i < db_size_; ++i) { 332 | db[i].resize(9729); 333 | prng->generate(db[i].size(), reinterpret_cast<SEAL_BYTE*>(db[i].data())); 334 | } 335 | 336 | auto pir_db_or = PIRDatabase::Create(db, pir_params_); 337 | ASSERT_FALSE(pir_db_or.ok()); 338 | ASSERT_EQ(pir_db_or.status().code(), absl::StatusCode::kInvalidArgument); 339 | } 340 | 341 | INSTANTIATE_TEST_SUITE_P(PIRDatabaseTests, PIRDatabaseTest, 342 | testing::Values(false, true)); 343 | 344 | class MultiplyMultiDimTest 345 | : public PIRDatabaseTestBase, 346 | public testing::TestWithParam< 347 | tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>> { 348 | protected: 349 | void TestMultiply(bool use_ciphertext_multiplication) { 350 | const auto poly_modulus_degree = get<0>(GetParam()); 351 | const auto plain_mod_bits = get<1>(GetParam()); 352 | const auto dbsize = get<2>(GetParam()); 353 | const auto d = get<3>(GetParam()); 354 | const 
auto desired_index = get<4>(GetParam()); 355 | SetUpStringDBImpl(dbsize, d, poly_modulus_degree, plain_mod_bits, 0, 356 | use_ciphertext_multiplication); 357 | const size_t elem_size = pir_params_->bytes_per_item(); 358 | const auto dims = PIRDatabase::calculate_dimensions(dbsize, d); 359 | const auto indices = pir_db_->calculate_indices(desired_index); 360 | auto cts = create_selection_vector(dims, indices, *encryptor_); 361 | 362 | unique_ptr<RelinKeys> relin_keys; 363 | if (use_ciphertext_multiplication) { 364 | relin_keys = make_unique<RelinKeys>(keygen_->relin_keys_local()); 365 | } 366 | ASSIGN_OR_FAIL(auto result_cts, pir_db_->multiply(cts, relin_keys.get())); 367 | 368 | Plaintext result_pt; 369 | decode_result(result_cts, result_pt, cts[0].size(), d, 370 | use_ciphertext_multiplication); 371 | auto string_encoder = make_unique<StringEncoder>(seal_context_); 372 | ASSIGN_OR_FAIL(auto result, string_encoder->decode(result_pt, elem_size)); 373 | EXPECT_THAT(result, Eq(string_db_[desired_index])); 374 | } 375 | }; 376 | 377 | TEST_P(MultiplyMultiDimTest, CTDecomp) { TestMultiply(false); } 378 | 379 | TEST_P(MultiplyMultiDimTest, CTMultiply) { TestMultiply(true); } 380 | 381 | INSTANTIATE_TEST_SUITE_P(PIRDatabaseMultiplies, MultiplyMultiDimTest, 382 | testing::Values(make_tuple(4096, 16, 10, 1, 7), 383 | make_tuple(4096, 16, 16, 2, 11), 384 | make_tuple(4096, 16, 16, 2, 0), 385 | make_tuple(4096, 16, 16, 2, 15), 386 | make_tuple(4096, 16, 82, 2, 42), 387 | make_tuple(8192, 20, 27, 3, 2), 388 | make_tuple(8192, 20, 117, 3, 17))); 389 | 390 | class CalculateIndicesTest 391 | : public testing::TestWithParam< 392 | tuple<uint32_t, uint32_t, uint32_t, uint32_t, vector<uint32_t>>> {}; 393 | 394 | TEST_P(CalculateIndicesTest, IndicesExamples) { 395 | const auto num_items = get<0>(GetParam()); 396 | const auto size_per_item = get<1>(GetParam()); 397 | const auto d = get<2>(GetParam()); 398 | const auto desired_index = get<3>(GetParam()); 399 | const auto& 
expected_indices = get<4>(GetParam()); 400 | ASSIGN_OR_FAIL(const auto pir_params, 401 | CreatePIRParameters(num_items, size_per_item, d, 402 | GenerateEncryptionParams(4096, 16))); 403 | ASSIGN_OR_FAIL(auto pir_db, PIRDatabase::Create(pir_params)); 404 | ASSERT_THAT(expected_indices, SizeIs(d)); 405 | auto indices = pir_db->calculate_indices(desired_index); 406 | EXPECT_THAT(indices, ContainerEq(expected_indices)); 407 | } 408 | 409 | INSTANTIATE_TEST_SUITE_P( 410 | PIRDatabaseCalculateIndices, CalculateIndicesTest, 411 | Values(make_tuple(100, 0, 1, 42, vector<uint32_t>{42}), 412 | make_tuple(100, 0, 1, 7, vector<uint32_t>{7}), 413 | make_tuple(84, 0, 2, 7, vector<uint32_t>{0, 7}), 414 | make_tuple(87, 0, 2, 27, vector<uint32_t>{3, 0}), 415 | make_tuple(87, 0, 2, 42, vector<uint32_t>{4, 6}), 416 | make_tuple(87, 0, 2, 86, vector<uint32_t>{9, 5}), 417 | make_tuple(82, 0, 3, 3, vector<uint32_t>{0, 0, 3}), 418 | make_tuple(82, 0, 3, 20, vector<uint32_t>{1, 0, 0}), 419 | make_tuple(82, 0, 3, 75, vector<uint32_t>{3, 3, 3}), 420 | make_tuple(5000, 64, 1, 2222, vector<uint32_t>{18}), 421 | make_tuple(5000, 64, 1, 1200, vector<uint32_t>{10}))); 422 | 423 | class CalculateOffsetTest : public testing::TestWithParam< 424 | tuple<uint32_t, uint32_t, uint32_t, uint32_t>> { 425 | }; 426 | 427 | TEST_P(CalculateOffsetTest, OffsetExamples) { 428 | const auto num_items = get<0>(GetParam()); 429 | const auto size_per_item = get<1>(GetParam()); 430 | const auto desired_index = get<2>(GetParam()); 431 | const auto& expected_offset = get<3>(GetParam()); 432 | ASSIGN_OR_FAIL(const auto pir_params, 433 | CreatePIRParameters(num_items, size_per_item, 1, 434 | GenerateEncryptionParams(4096, 16))); 435 | ASSIGN_OR_FAIL(auto pir_db, PIRDatabase::Create(pir_params)); 436 | auto offset = pir_db->calculate_item_offset(desired_index); 437 | EXPECT_EQ(offset, expected_offset); 438 | } 439 | 440 | INSTANTIATE_TEST_SUITE_P(PIRDatabaseCalculateOffset, CalculateOffsetTest, 441 | 
Values(make_tuple(100, 0, 42, 0), 442 | make_tuple(1000, 64, 42, 2688), 443 | make_tuple(1000, 64, 960, 0), 444 | make_tuple(1000, 64, 999, 2496))); 445 | 446 | class CalculateDimensionsTest 447 | : public testing::TestWithParam< 448 | tuple<uint32_t, uint32_t, vector<uint32_t>>> {}; 449 | 450 | TEST_P(CalculateDimensionsTest, dimensionsExamples) { 451 | EXPECT_THAT( 452 | PIRDatabase::calculate_dimensions(get<0>(GetParam()), get<1>(GetParam())), 453 | ContainerEq(get<2>(GetParam()))); 454 | } 455 | 456 | INSTANTIATE_TEST_SUITE_P( 457 | CalculateDimensions, CalculateDimensionsTest, 458 | testing::Values(make_tuple(100, 1, vector<uint32_t>{100}), 459 | make_tuple(100, 2, vector<uint32_t>{10, 10}), 460 | make_tuple(82, 2, vector<uint32_t>{10, 9}), 461 | make_tuple(975, 2, vector<uint32_t>{32, 31}), 462 | make_tuple(1000, 3, vector<uint32_t>{10, 10, 10}), 463 | make_tuple(1001, 3, vector<uint32_t>{11, 10, 10}), 464 | make_tuple(1000001, 3, vector<uint32_t>{101, 100, 100}))); 465 | 466 | } // namespace 467 | } // namespace pir 468 | -------------------------------------------------------------------------------- /pir/cpp/parameters.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | #include "pir/cpp/parameters.h" 17 | 18 | #include "pir/cpp/database.h" 19 | #include "pir/cpp/serialization.h" 20 | #include "pir/cpp/status_asserts.h" 21 | #include "pir/cpp/string_encoder.h" 22 | #include "pir/cpp/utils.h" 23 | #include "seal/seal.h" 24 | 25 | namespace pir { 26 | 27 | using absl::InvalidArgumentError; 28 | using absl::StatusOr; 29 | using ::seal::EncryptionParameters; 30 | using ::std::make_shared; 31 | using ::std::shared_ptr; 32 | 33 | EncryptionParameters GenerateEncryptionParams(uint32_t poly_mod_degree, 34 | uint32_t plain_mod_bit_size) { 35 | return GenerateEncryptionParams( 36 | poly_mod_degree, 37 | seal::PlainModulus::Batching(poly_mod_degree, plain_mod_bit_size)); 38 | } 39 | 40 | EncryptionParameters GenerateEncryptionParams( 41 | std::optional<uint32_t> poly_mod_opt, std::optional<Modulus> plain_mod_opt, 42 | std::optional<std::vector<Modulus>> coeff_opt) { 43 | auto poly_modulus_degree = poly_mod_opt.value_or(DEFAULT_POLY_MODULUS_DEGREE); 44 | auto plain_modulus = plain_mod_opt.value_or( 45 | seal::PlainModulus::Batching(poly_modulus_degree, 20)); 46 | auto coeff = 47 | coeff_opt.value_or(seal::CoeffModulus::BFVDefault(poly_modulus_degree)); 48 | 49 | EncryptionParameters parms(seal::scheme_type::BFV); 50 | parms.set_poly_modulus_degree(poly_modulus_degree); 51 | parms.set_plain_modulus(plain_modulus); 52 | parms.set_coeff_modulus(coeff); 53 | return parms; 54 | } 55 | 56 | StatusOr<shared_ptr<PIRParameters>> CreatePIRParameters( 57 | size_t dbsize, size_t bytes_per_item, size_t dimensions, 58 | EncryptionParameters seal_params, bool use_ciphertext_multiplication, 59 | size_t bits_per_coeff) { 60 | // Make sure SEAL Parameter are valid 61 | auto seal_context = seal::SEALContext::Create(seal_params); 62 | if (!seal_context->parameters_set()) { 63 | return InvalidArgumentError( 64 | string("Error setting encryption parameters: ") + 65 | seal_context->parameter_error_message()); 66 | } 67 | StringEncoder 
encoder(seal_context); 68 | 69 | auto parameters = std::make_shared<PIRParameters>(); 70 | parameters->set_num_items(dbsize); 71 | parameters->set_use_ciphertext_multiplication(use_ciphertext_multiplication); 72 | 73 | if (bits_per_coeff > 0) { 74 | if (bits_per_coeff > encoder.bits_per_coeff()) { 75 | return InvalidArgumentError("Bits per coefficient greater than max"); 76 | } 77 | encoder.set_bits_per_coeff(bits_per_coeff); 78 | parameters->set_bits_per_coeff(bits_per_coeff); 79 | } 80 | 81 | if (bytes_per_item > 0) { 82 | parameters->set_bytes_per_item(bytes_per_item); 83 | parameters->set_items_per_plaintext( 84 | encoder.num_items_per_plaintext(bytes_per_item)); 85 | if (parameters->items_per_plaintext() <= 0) { 86 | return InvalidArgumentError("Cannot fit an item within one plaintext"); 87 | } 88 | size_t num_pt = dbsize / parameters->items_per_plaintext(); 89 | while (dbsize > num_pt * parameters->items_per_plaintext()) { 90 | ++num_pt; 91 | } 92 | parameters->set_num_pt(num_pt); 93 | } else { 94 | parameters->set_bytes_per_item(encoder.max_bytes_per_plaintext()); 95 | parameters->set_items_per_plaintext(1); 96 | parameters->set_num_pt(dbsize); 97 | } 98 | 99 | RETURN_IF_ERROR(SEALSerialize<EncryptionParameters>( 100 | seal_params, parameters->mutable_encryption_parameters())); 101 | 102 | for (auto& dim : 103 | PIRDatabase::calculate_dimensions(parameters->num_pt(), dimensions)) 104 | parameters->add_dimensions(dim); 105 | 106 | return parameters; 107 | } 108 | 109 | } // namespace pir 110 | -------------------------------------------------------------------------------- /pir/cpp/parameters.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_PARAMETERS_H_ 18 | #define PIR_PARAMETERS_H_ 19 | 20 | #include <vector> 21 | 22 | #include "absl/memory/memory.h" 23 | #include "absl/status/statusor.h" 24 | #include "pir/proto/payload.pb.h" 25 | #include "seal/seal.h" 26 | 27 | namespace pir { 28 | 29 | using ::std::optional; 30 | using ::std::shared_ptr; 31 | using ::std::size_t; 32 | using ::std::vector; 33 | 34 | using ::seal::EncryptionParameters; 35 | using ::seal::Modulus; 36 | 37 | using absl::InvalidArgumentError; 38 | using absl::StatusOr; 39 | 40 | constexpr uint32_t DEFAULT_POLY_MODULUS_DEGREE = 4096; 41 | 42 | /** 43 | * Helper function to generate encryption parameters. 44 | * @param[in] optional The polynomial modulus degree 45 | * @param[in] optional The plaintext modulus 46 | * @param[in] optional The coefficient modulus 47 | */ 48 | EncryptionParameters GenerateEncryptionParams( 49 | optional<uint32_t> poly_mod_opt = {}, optional<Modulus> plain_mod_opt = {}, 50 | optional<std::vector<Modulus>> coeff_opt = {}); 51 | 52 | /** 53 | * Shortcut to generate encryption parameters for a given poly modulus degree 54 | * and bit size of the plain modulus. 55 | */ 56 | EncryptionParameters GenerateEncryptionParams(uint32_t poly_mod_degree, 57 | uint32_t plain_mod_bit_size); 58 | 59 | /** 60 | * Helper function to create the PIRParameters 61 | * @param[in] dbsize The number of individual items in the database. 62 | * @param[in] bytes_per_item Size in bytes of each item in the database. 
63 | * @param[in] dimensions Number of dimensions in the database representation. 64 | * @param[in] enc_params SEAL Encryption Parameters to be used. 65 | * @param[in] bits_per_coeff If non-zero, number of bits to encode per plaintext 66 | * plaintext coefficient in the database. 67 | * @returns InvalidArgument if EncryptionParameters serialization fails. 68 | */ 69 | StatusOr<std::shared_ptr<PIRParameters>> CreatePIRParameters( 70 | size_t dbsize, size_t bytes_per_item, size_t dimensions = 1, 71 | EncryptionParameters enc_params = GenerateEncryptionParams(), 72 | bool use_ciphertext_multiplication = false, size_t bits_per_coeff = 0); 73 | } // namespace pir 74 | 75 | #endif // PIR_PARAMETERS_H_ 76 | -------------------------------------------------------------------------------- /pir/cpp/parameters_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/parameters.h" 18 | 19 | #include "gmock/gmock.h" 20 | #include "gtest/gtest.h" 21 | #include "pir/cpp/serialization.h" 22 | #include "pir/cpp/status_asserts.h" 23 | 24 | namespace pir { 25 | namespace { 26 | 27 | using std::cout; 28 | using std::endl; 29 | using std::get; 30 | using std::make_tuple; 31 | using std::make_unique; 32 | using std::shared_ptr; 33 | using std::string; 34 | using std::tuple; 35 | using std::unique_ptr; 36 | using std::vector; 37 | 38 | using seal::Ciphertext; 39 | using seal::GaloisKeys; 40 | using seal::Plaintext; 41 | 42 | using namespace seal; 43 | using namespace ::testing; 44 | using std::int64_t; 45 | using std::vector; // NOTE(review): duplicates the using-declaration on line 36. 46 | 47 | TEST(PIRParametersTest, SanityCheck) { 48 | // Make sure SEAL can actually be initialized and that the defaults are sane. 49 | ASSIGN_OR_FAIL(auto pir_params, CreatePIRParameters(1026, 256)); 50 | EXPECT_THAT(pir_params->num_items(), Eq(1026)); 51 | EXPECT_THAT(pir_params->num_pt(), Eq(27)); // 1026 items / 38 per plaintext 52 | EXPECT_THAT(pir_params->bytes_per_item(), Eq(256)); 53 | EXPECT_THAT(pir_params->items_per_plaintext(), Eq(38)); 54 | EXPECT_THAT(pir_params->dimensions(), ElementsAre(27)); // dimensions defaults to 1 55 | ASSIGN_OR_FAIL(auto encryptionParams, 56 | SEALDeserialize<EncryptionParameters>( 57 | pir_params->encryption_parameters())); 58 | auto context = seal::SEALContext::Create(encryptionParams); // round-tripped params must form a valid context 59 | EXPECT_THAT(context->parameters_set(), IsTrue()) 60 | << "Error setting encryption parameters: " 61 | << context->parameter_error_message(); 62 | } 63 | 64 | TEST(PIRParametersTest, CreateMultiDim) { // 3-dimensional layout of 19011 items of 500 bytes 65 | ASSIGN_OR_FAIL(auto pir_params, CreatePIRParameters(19011, 500, 3)); 66 | EXPECT_THAT(pir_params->num_items(), Eq(19011)); 67 | EXPECT_THAT(pir_params->num_pt(), Eq(1001)); // ceil(19011 / 19) 68 | EXPECT_THAT(pir_params->bytes_per_item(), Eq(500)); 69 | EXPECT_THAT(pir_params->items_per_plaintext(), Eq(19)); 70 | EXPECT_THAT(pir_params->dimensions(), ElementsAre(11, 10, 10)); // 11 * 10 * 10 = 1100 >= 1001 plaintexts 71 | ASSIGN_OR_FAIL(auto encryption_params, 72 |
SEALDeserialize<EncryptionParameters>( 73 | pir_params->encryption_parameters())); 74 | auto context = seal::SEALContext::Create(encryption_params); 75 | EXPECT_THAT(context->parameters_set(), IsTrue()) 76 | << "Error setting encryption parameters: " 77 | << context->parameter_error_message(); 78 | } 79 | 80 | TEST(PIRParametersTest, CreateAllParams) { // every CreatePIRParameters argument set explicitly 81 | ASSIGN_OR_FAIL(auto pir_params, 82 | CreatePIRParameters(77412, 777, 2, 83 | GenerateEncryptionParams(8192), true, 12)); // use_ciphertext_multiplication, bits_per_coeff 84 | EXPECT_THAT(pir_params->num_items(), Eq(77412)); 85 | EXPECT_THAT(pir_params->num_pt(), Eq(5161)); // ceil(77412 / 15) 86 | EXPECT_THAT(pir_params->bytes_per_item(), Eq(777)); 87 | EXPECT_THAT(pir_params->items_per_plaintext(), Eq(15)); 88 | EXPECT_THAT(pir_params->dimensions(), ElementsAre(72, 72)); // 72 * 72 = 5184 >= 5161 plaintexts 89 | EXPECT_THAT(pir_params->use_ciphertext_multiplication(), IsTrue()); 90 | EXPECT_THAT(pir_params->bits_per_coeff(), Eq(12)); 91 | ASSIGN_OR_FAIL(auto encryption_params, 92 | SEALDeserialize<EncryptionParameters>( 93 | pir_params->encryption_parameters())); 94 | auto context = seal::SEALContext::Create(encryption_params); 95 | EXPECT_THAT(context->parameters_set(), IsTrue()) 96 | << "Error setting encryption parameters: " 97 | << context->parameter_error_message(); 98 | } 99 | 100 | TEST(PIRParametersTest, EncryptionParamsSerialization) { 101 | // Use non-default parameters (poly modulus degree 8192) for the round trip. 102 | auto params = GenerateEncryptionParams(8192); 103 | std::string serial; 104 | ASSERT_OK(SEALSerialize<EncryptionParameters>(params, &serial)); 105 | ASSIGN_OR_FAIL(auto new_params, 106 | SEALDeserialize<EncryptionParameters>(serial)); 107 | ASSERT_THAT(new_params, Eq(params)); 108 | } 109 | 110 | } // namespace 111 | } // namespace pir 112 | -------------------------------------------------------------------------------- /pir/cpp/serialization.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the
Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | #include "pir/cpp/serialization.h" 17 | 18 | #include "pir/cpp/status_asserts.h" 19 | #include "pir/cpp/utils.h" 20 | #include "seal/seal.h" 21 | 22 | namespace pir { 23 | 24 | using absl::InvalidArgumentError; 25 | using absl::StatusOr; 26 | using seal::Ciphertext; 27 | using seal::GaloisKeys; 28 | using seal::RelinKeys; 29 | using std::string; 30 | using std::vector; 31 | 32 | StatusOr<vector<Ciphertext>> LoadCiphertexts( 33 | const std::shared_ptr<seal::SEALContext>& sealctx, 34 | const Ciphertexts& input) { 35 | vector<Ciphertext> output(input.ct_size()); 36 | for (int idx = 0; idx < input.ct_size(); ++idx) { 37 | ASSIGN_OR_RETURN(output[idx], 38 | SEALDeserialize<Ciphertext>(sealctx, input.ct(idx))); 39 | } 40 | 41 | return output; 42 | } 43 | 44 | Status SaveCiphertexts(const vector<Ciphertext>& ciphertexts, 45 | Ciphertexts* output) { 46 | if (output == nullptr) { 47 | return InvalidArgumentError("output nullptr"); 48 | } 49 | 50 | for (size_t idx = 0; idx < ciphertexts.size(); ++idx) { 51 | RETURN_IF_ERROR( 52 | SEALSerialize<Ciphertext>(ciphertexts[idx], output->add_ct())); 53 | } 54 | return absl::OkStatus(); 55 | } 56 | 57 | Status SaveRequest(const vector<vector<Ciphertext>>& cts, Request* request) { 58 | for (const auto& ct : cts) { 59 | RETURN_IF_ERROR(SaveCiphertexts(ct, request->add_query())); 60 | } 61 | return absl::OkStatus(); 62 | } 63 | 64 | Status SaveRequest(const 
vector<vector<Ciphertext>>& cts, 65 | const seal::GaloisKeys& galois_keys, 66 | const seal::RelinKeys& relin_keys, Request* request) { 67 | RETURN_IF_ERROR(SaveRequest(cts, request)); 68 | RETURN_IF_ERROR( 69 | SEALSerialize<GaloisKeys>(galois_keys, request->mutable_galois_keys())); 70 | RETURN_IF_ERROR( 71 | SEALSerialize<RelinKeys>(relin_keys, request->mutable_relin_keys())); 72 | return absl::OkStatus(); 73 | } 74 | 75 | }; // namespace pir 76 | -------------------------------------------------------------------------------- /pir/cpp/serialization.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_SERIALIZATION_H_ 18 | #define PIR_SERIALIZATION_H_ 19 | 20 | #include <string> 21 | 22 | #include "absl/status/statusor.h" 23 | #include "pir/proto/payload.pb.h" 24 | #include "seal/seal.h" 25 | 26 | namespace pir { 27 | 28 | using absl::InternalError; 29 | using absl::InvalidArgumentError; 30 | using absl::Status; 31 | using absl::StatusOr; 32 | using seal::Ciphertext; 33 | using seal::SEALContext; 34 | using std::shared_ptr; 35 | using std::string; 36 | using std::vector; 37 | 38 | /** 39 | * Decodes and loads a PIR Ciphertext. 40 | * @param[in] The SEAL context, for buffer allocations. 41 | * @param[in] The encoded ciphertext. 
42 | * @returns InvalidArgument if the decoding fails. 43 | **/ 44 | StatusOr<vector<Ciphertext>> LoadCiphertexts(const shared_ptr<SEALContext>& ctx, 45 | const Ciphertexts& encoded); 46 | 47 | /** 48 | * Saves the Ciphertexts to a protobuffer. 49 | * @returns InvalidArgument if the encoding fails 50 | **/ 51 | Status SaveCiphertexts(const vector<Ciphertext>& buff, Ciphertexts* output); 52 | 53 | /** 54 | * Shortcut to save response data to a protocol buffer based on a list of 55 | * Ciphertexts. It is assumed that Galois keys will be added elsewhere. 56 | * @param[in] cts The list of Ciphertexts in the query. 57 | * @param[out] request Point to the request protocol buffer to fill in. 58 | * @returns InvalidArgument if the encoding fails. 59 | */ 60 | Status SaveRequest(const vector<vector<Ciphertext>>& cts, Request* request); 61 | 62 | /** 63 | * Shortcut to save response data to a protocol buffer based on a list of 64 | * Ciphertexts, a set of GaloisKeys, and a set or relinearization keys. 65 | * @param[in] cts The list of Ciphertexts in the query. 66 | * @param[in] galois_keys The Galois Keys to encode in the protocol buffer. 67 | * @param[in] relin_keys The relinearization keys to encode. 68 | * @param[out] request Point to the request protocol buffer to fill in. 69 | * @returns InvalidArgument if the encoding fails. 70 | */ 71 | Status SaveRequest(const vector<vector<Ciphertext>>& cts, 72 | const seal::GaloisKeys& galois_keys, 73 | const seal::RelinKeys& relin_keys, Request* request); 74 | 75 | /** 76 | * Saves a SEAL object to a string. 77 | * Compatible SEAL types: Ciphertext, Plaintext, SecretKey, PublicKey, 78 | *GaloisKeys, RelinKeys. 79 | * @returns InternalError if the encoding fails. 
80 | **/ 81 | template <class T> 82 | Status SEALSerialize(const T& sealobj, string* output) { 83 | if (output == nullptr) { 84 | return InvalidArgumentError("output nullptr"); 85 | } 86 | std::stringstream stream; 87 | 88 | try { 89 | sealobj.save(stream); 90 | } catch (const std::exception& e) { 91 | return InternalError(e.what()); 92 | } 93 | 94 | *output = stream.str(); 95 | return absl::OkStatus(); 96 | } 97 | 98 | /** 99 | * Loads a SEAL object from a string. 100 | * Compatible SEAL types: Ciphertext, Plaintext, SecretKey, PublicKey, 101 | *GaloisKeys, RelinKeys. 102 | * @returns InvalidArgument if the decoding fails. 103 | **/ 104 | template <class T> 105 | StatusOr<T> SEALDeserialize(const shared_ptr<SEALContext>& sealctx, 106 | const string& in) { 107 | T out; 108 | 109 | try { 110 | std::stringstream stream; 111 | stream << in; 112 | out.load(sealctx, stream); 113 | } catch (const std::exception& e) { 114 | return InvalidArgumentError(e.what()); 115 | } 116 | 117 | return out; 118 | } 119 | 120 | /** 121 | * Loads a SEAL object from a string. 122 | * Compatible SEAL types: EncryptionParameters, Modulus, BigUInt, IntArray 123 | * @returns InvalidArgument if the decoding fails. 
124 | **/ 125 | template <class T> 126 | StatusOr<T> SEALDeserialize(const string& in) { 127 | T out; 128 | 129 | try { 130 | std::stringstream stream; 131 | stream << in; 132 | out.load(stream); 133 | } catch (const std::exception& e) { 134 | return InvalidArgumentError(e.what()); 135 | } 136 | 137 | return out; 138 | } 139 | 140 | } // namespace pir 141 | 142 | #endif // PIR_SERIALIZATION_H_ 143 | -------------------------------------------------------------------------------- /pir/cpp/serialization_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/serialization.h" 18 | 19 | #include "absl/status/statusor.h" 20 | #include "gmock/gmock.h" 21 | #include "gtest/gtest.h" 22 | #include "pir/cpp/context.h" 23 | #include "pir/cpp/status_asserts.h" 24 | #include "pir/cpp/utils.h" 25 | #include "seal/seal.h" 26 | 27 | namespace pir { 28 | 29 | using namespace seal; 30 | using std::get; 31 | using std::make_tuple; 32 | using std::make_unique; 33 | using std::size_t; 34 | using std::tuple; 35 | 36 | using ::testing::ElementsAreArray; 37 | 38 | class PIRSerializationTest : public ::testing::Test { // Fixture: PIR context plus an encryptor/decryptor sharing one key pair. 39 | protected: 40 | static constexpr size_t DB_SIZE = 100; 41 | static constexpr size_t ELEM_SIZE = 64; 42 | void SetUp() { SetUpDB(DB_SIZE); } 43 | 44 | void SetUpDB(size_t dbsize) { 45 | auto pir_params = *(CreatePIRParameters(dbsize, ELEM_SIZE)); 46 | context_ = std::move(*(PIRContext::Create(pir_params))); 47 | 48 | auto keygen = 49 | std::make_unique<seal::KeyGenerator>(context_->SEALContext()); 50 | encryptor_ = std::make_shared<seal::Encryptor>(context_->SEALContext(), 51 | keygen->public_key()); 52 | decryptor_ = std::make_shared<seal::Decryptor>(context_->SEALContext(), 53 | keygen->secret_key()); 54 | } 55 | 56 | std::shared_ptr<PIRContext> context_; 57 | std::shared_ptr<seal::Encryptor> encryptor_; 58 | std::shared_ptr<seal::Decryptor> decryptor_; 59 | }; 60 | 61 | TEST_F(PIRSerializationTest, TestResponseSerialization) { 62 | int64_t value = 987654321; 63 | Plaintext pt, reloaded_pt; 64 | context_->Encoder()->encode(value, pt); 65 | vector<Ciphertext> ct(1); 66 | encryptor_->encrypt(pt, ct[0]); 67 | 68 | Response response_proto; 69 | ASSERT_OK(SaveCiphertexts(ct, response_proto.add_reply())); 70 | 71 | ASSIGN_OR_FAIL(auto reloaded, LoadCiphertexts(context_->SEALContext(), 72 | response_proto.reply(0))); 73 | ASSERT_EQ(reloaded.size(), 1); 74 | decryptor_->decrypt(reloaded[0], reloaded_pt); 75 | EXPECT_THAT(reloaded_pt, testing::Eq(pt)); 76 | } 77 | 78 | TEST_F(PIRSerializationTest,
TestRequestSerialization_IndividualMethods) { 79 | int64_t value = 987654321; 80 | Plaintext pt, reloaded_pt; 81 | context_->Encoder()->encode(value, pt); 82 | vector<Ciphertext> ct(1); 83 | encryptor_->encrypt(pt, ct[0]); 84 | 85 | auto keygen = make_unique<KeyGenerator>(context_->SEALContext()); 86 | auto elts = generate_galois_elts(DEFAULT_POLY_MODULUS_DEGREE); 87 | GaloisKeys gal_keys = keygen->galois_keys_local(elts); 88 | RelinKeys relin_keys = keygen->relin_keys_local(); 89 | 90 | Request request_proto; 91 | ASSERT_OK(SaveCiphertexts(ct, request_proto.add_query())); 92 | ASSERT_OK(SEALSerialize<GaloisKeys>(gal_keys, request_proto.mutable_galois_keys())); 93 | ASSERT_OK(SEALSerialize<RelinKeys>(relin_keys, request_proto.mutable_relin_keys())); 94 | 95 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(context_->SEALContext(), 96 | request_proto.query(0))); 97 | ASSERT_EQ(request.size(), 1); 98 | decryptor_->decrypt(request[0], reloaded_pt); 99 | EXPECT_THAT(reloaded_pt, testing::Eq(pt)); 100 | ASSIGN_OR_FAIL(auto gal_keys_post, 101 | SEALDeserialize<GaloisKeys>(context_->SEALContext(), 102 | request_proto.galois_keys())); 103 | for (const auto& e : elts) { 104 | // Can't really test equality of the keys, so just check that they exist. 105 | ASSERT_TRUE(gal_keys_post.has_key(e)); 106 | } 107 | 108 | ASSIGN_OR_FAIL(auto relin_keys_post, 109 | SEALDeserialize<RelinKeys>(context_->SEALContext(), 110 | request_proto.relin_keys())); 111 | // Can't really check if the relin keys are valid. Just assume it's ok here.
112 | } 113 | 114 | TEST_F(PIRSerializationTest, TestRequestSerialization_Shortcut) { 115 | int64_t value = 987654321; 116 | Plaintext pt, reloaded_pt; 117 | context_->Encoder()->encode(value, pt); 118 | vector<Ciphertext> ct(1); 119 | encryptor_->encrypt(pt, ct[0]); 120 | 121 | Request request_proto; 122 | ASSERT_OK(SaveRequest({ct}, &request_proto)); 123 | 124 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(context_->SEALContext(), 125 | request_proto.query(0))); 126 | ASSERT_EQ(request.size(), 1); 127 | decryptor_->decrypt(request[0], reloaded_pt); 128 | EXPECT_THAT(reloaded_pt, testing::Eq(pt)); 129 | ASSERT_THAT(request_proto.galois_keys(), testing::IsEmpty()); 130 | ASSERT_THAT(request_proto.relin_keys(), testing::IsEmpty()); 131 | } 132 | 133 | TEST_F(PIRSerializationTest, TestRequestSerialization_ShortcutWithRelin) { 134 | int64_t value = 987654321; 135 | Plaintext pt, reloaded_pt; 136 | context_->Encoder()->encode(value, pt); 137 | vector<Ciphertext> ct(1); 138 | encryptor_->encrypt(pt, ct[0]); 139 | 140 | auto keygen = make_unique<KeyGenerator>(context_->SEALContext()); 141 | auto elts = generate_galois_elts(DEFAULT_POLY_MODULUS_DEGREE); 142 | GaloisKeys gal_keys = keygen->galois_keys_local(elts); 143 | RelinKeys relin_keys = keygen->relin_keys_local(); 144 | 145 | Request request_proto; 146 | ASSERT_OK(SaveRequest({ct}, gal_keys, relin_keys, &request_proto)); 147 | 148 | ASSIGN_OR_FAIL(auto request, LoadCiphertexts(context_->SEALContext(), 149 | request_proto.query(0))); 150 | ASSERT_EQ(request.size(), 1); 151 | decryptor_->decrypt(request[0], reloaded_pt); 152 | EXPECT_THAT(reloaded_pt, testing::Eq(pt)); 153 | ASSIGN_OR_FAIL(auto gal_keys_post, 154 | SEALDeserialize<GaloisKeys>(context_->SEALContext(), 155 | request_proto.galois_keys())); 156 | for (const auto& e : elts) { 157 | // Can't really test equality of the keys, so just check that they exist.
158 | ASSERT_TRUE(gal_keys_post.has_key(e)); 159 | } 160 | 161 | ASSIGN_OR_FAIL(auto relin_keys_post, 162 | SEALDeserialize<RelinKeys>(context_->SEALContext(), 163 | request_proto.relin_keys())); 164 | // Can't really check if the relin keys are valid. Just assume it's ok here. 165 | } 166 | 167 | TEST_F(PIRSerializationTest, TestEncryptionParamsSerialization) { 168 | auto params = GenerateEncryptionParams(); 169 | std::string serial; 170 | ASSERT_OK(SEALSerialize<EncryptionParameters>(params, &serial)); 171 | ASSIGN_OR_FAIL(auto decoded_params, 172 | SEALDeserialize<EncryptionParameters>(serial)); 173 | ASSERT_EQ(params.plain_modulus(), decoded_params.plain_modulus()); 174 | ASSERT_EQ(params.poly_modulus_degree(), decoded_params.poly_modulus_degree()); 175 | } 176 | } // namespace pir 177 | -------------------------------------------------------------------------------- /pir/cpp/server.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | #include "pir/cpp/server.h" 17 | 18 | #include "pir/cpp/status_asserts.h" 19 | #include "pir/cpp/utils.h" 20 | #include "seal/seal.h" 21 | #include "seal/util/polyarithsmallmod.h" 22 | 23 | namespace pir { 24 | 25 | using absl::Status; 26 | using absl::StatusOr; 27 | using ::seal::GaloisKeys; 28 | using ::seal::RelinKeys; 29 | using ::std::shared_ptr; 30 | 31 | PIRServer::PIRServer(std::unique_ptr<PIRContext> context, 32 | std::shared_ptr<PIRDatabase> db) 33 | : context_(std::move(context)), db_(std::move(db)) {} 34 | 35 | StatusOr<std::unique_ptr<PIRServer>> PIRServer::Create( 36 | std::shared_ptr<PIRDatabase> db, shared_ptr<PIRParameters> params) { 37 | if (!db || !params) return absl::InvalidArgumentError("null db or params"); 38 | if (params->num_pt() != db->size()) 39 | return absl::InvalidArgumentError("database size mismatch"); 40 | ASSIGN_OR_RETURN(auto context, PIRContext::Create(params)); 41 | return absl::WrapUnique(new PIRServer(std::move(context), std::move(db))); 42 | } 43 | 44 | StatusOr<Response> PIRServer::ProcessRequest(const Request& request) const { 45 | Response response; 46 | ASSIGN_OR_RETURN(auto galois_keys, 47 | SEALDeserialize<GaloisKeys>(context_->SEALContext(), 48 | request.galois_keys())); 49 | 50 | // Total number of selection-vector entries the query ciphertexts expand to. 51 | const size_t dim_sum = context_->DimensionsSum(); 52 | 53 | optional<RelinKeys> relin_keys; 54 | if (!request.relin_keys().empty()) { // relin keys are optional in a request 55 | ASSIGN_OR_RETURN(relin_keys, 56 | SEALDeserialize<RelinKeys>(context_->SEALContext(), 57 | request.relin_keys())); 58 | } 59 | 60 | for (const auto& query : request.query()) { 61 | RETURN_IF_ERROR(processQuery(query, galois_keys, relin_keys, dim_sum, 62 | response.add_reply())); 63 | } 64 | return response; 65 | } 66 | 67 | Status PIRServer::substitute_power_x_inplace( 68 | seal::Ciphertext& ct, uint32_t power, 69 | const seal::GaloisKeys& gal_keys) const { 70 | try { 71 | context_->Evaluator()->apply_galois_inplace(ct, power, gal_keys); 72 | } catch (const std::exception& e) { 73 | return absl::InternalError(e.what());
74 | } 75 | return absl::OkStatus(); 76 | } 77 | 78 | void PIRServer::multiply_inverse_power_of_x( 79 | const seal::Ciphertext& encrypted, uint32_t k, 80 | seal::Ciphertext& destination) const { 81 | // This has to get the actual params from the SEALContext. Using just the 82 | // params from PIR doesn't work. 83 | const auto& params = context_->SEALContext()->first_context_data()->parms(); 84 | const auto poly_modulus_degree = params.poly_modulus_degree(); 85 | const auto coeff_mod_count = params.coeff_modulus().size(); 86 | 87 | const auto two_n = poly_modulus_degree << 1; // x^(2N) == 1 mod (x^N + 1) 88 | const uint32_t index = (two_n - (k % two_n)) % two_n; // x^(-k) == x^(2N - k); reducing k first avoids underflow for k > 2N 89 | 90 | // have to make a copy here 91 | destination = encrypted; 92 | 93 | // Loop over polynomials in ciphertext 94 | for (size_t i = 0; i < encrypted.size(); i++) { 95 | // loop over each coefficient modulus (one RNS component per modulus) 96 | for (size_t j = 0; j < coeff_mod_count; j++) { 97 | seal::util::negacyclic_shift_poly_coeffmod( 98 | encrypted.data(i) + (j * poly_modulus_degree), poly_modulus_degree, 99 | index, params.coeff_modulus()[j], 100 | destination.data(i) + (j * poly_modulus_degree)); 101 | } 102 | } 103 | } 104 | 105 | StatusOr<std::vector<seal::Ciphertext>> PIRServer::oblivious_expansion( 106 | const seal::Ciphertext& ct, const size_t num_items, 107 | const seal::GaloisKeys& gal_keys) const { 108 | const auto poly_modulus_degree = 109 | context_->EncryptionParams().poly_modulus_degree(); 110 | 111 | if (num_items > poly_modulus_degree) { 112 | return absl::InvalidArgumentError( 113 | "Cannot expand more items from a CT than poly modulus degree"); 114 | } 115 | 116 | size_t logm = ceil_log2(num_items); 117 | std::vector<seal::Ciphertext> results(next_power_two(num_items)); 118 | results[0] = ct; 119 | 120 | for (size_t j = 0; j < logm; ++j) { 121 | const size_t two_power_j = size_t{1} << j; 122 | for (size_t k = 0; k < two_power_j; ++k) { 123 | auto c0 = results[k]; 124 | 125 | RETURN_IF_ERROR(substitute_power_x_inplace( 126 | c0, (poly_modulus_degree >> j) + 1, gal_keys));
127 | 128 | // This essentially produces what the paper calls c1 129 | multiply_inverse_power_of_x(results[k], two_power_j, 130 | results[k + two_power_j]); 131 | 132 | // Do the multiply by power of x after substitution operator to avoid 133 | // having to do the substitution operator a second time, since it's about 134 | // 20x slower. Except that now instead of multiplying by x^(-2^j) we have 135 | // to do the substitution first ourselves, producing 136 | // (x^(N/2^j + 1))^(-2^j) = 1/x^(2^j * (N/2^j + 1)) = 1/x^(N + 2^j) 137 | seal::Ciphertext c1; 138 | multiply_inverse_power_of_x(c0, poly_modulus_degree + two_power_j, c1); 139 | 140 | context_->Evaluator()->add_inplace(results[k], c0); 141 | context_->Evaluator()->add_inplace(results[k + two_power_j], c1); 142 | } 143 | } 144 | results.resize(num_items); 145 | return results; 146 | } 147 | 148 | StatusOr<std::vector<seal::Ciphertext>> PIRServer::oblivious_expansion( 149 | const std::vector<seal::Ciphertext>& cts, size_t total_items, 150 | const seal::GaloisKeys& gal_keys) const { 151 | size_t poly_modulus_degree = 152 | context_->EncryptionParams().poly_modulus_degree(); 153 | 154 | if (cts.size() != total_items / poly_modulus_degree + 1) { 155 | return absl::InvalidArgumentError( 156 | "Number of ciphertexts doesn't match number of items for oblivious " 157 | "expansion."); 158 | } 159 | 160 | std::vector<seal::Ciphertext> results; 161 | results.reserve(total_items); 162 | for (const auto& ct : cts) { 163 | const size_t items_in_ct = std::min(poly_modulus_degree, total_items); 164 | if (items_in_ct == 0) break; // only the extra trailing ct when total_items is a multiple of N 165 | ASSIGN_OR_RETURN(auto v, oblivious_expansion(ct, items_in_ct, gal_keys)); 166 | results.insert(results.end(), std::make_move_iterator(v.begin()), 167 | std::make_move_iterator(v.end())); 168 | total_items -= items_in_ct; // never wraps, unlike subtracting N on the last ct 169 | } 170 | return results; 171 | } 172 | 173 | Status PIRServer::processQuery(const Ciphertexts& query_proto, 174 | const GaloisKeys& galois_keys, 175 | const optional<RelinKeys>& relin_keys, 176 | const size_t& dim_sum,
177 | Ciphertexts* output) const { 178 | ASSIGN_OR_RETURN(auto query, 179 | LoadCiphertexts(context_->SEALContext(), query_proto)); 180 | 181 | ASSIGN_OR_RETURN(auto selection_vector, 182 | oblivious_expansion(query, dim_sum, galois_keys)); 183 | 184 | vector<seal::Ciphertext> results; 185 | if (relin_keys) { // forward the client's relin keys to the database multiply 186 | ASSIGN_OR_RETURN(results, 187 | db_->multiply(selection_vector, &relin_keys.value())); 188 | } else { 189 | ASSIGN_OR_RETURN(results, db_->multiply(selection_vector)); 190 | } 191 | 192 | RETURN_IF_ERROR(SaveCiphertexts(results, output)); 193 | 194 | return absl::OkStatus(); 195 | } 196 | 197 | } // namespace pir 198 | -------------------------------------------------------------------------------- /pir/cpp/server.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | 17 | #ifndef PIR_SERVER_H_ 18 | #define PIR_SERVER_H_ 19 | 20 | #include <vector> 21 | 22 | #include "absl/status/statusor.h" 23 | #include "pir/cpp/context.h" 24 | #include "pir/cpp/database.h" 25 | #include "pir/cpp/serialization.h" 26 | #include "seal/seal.h" 27 | 28 | namespace pir { 29 | 30 | using absl::Status; 31 | using absl::StatusOr; 32 | using ::seal::GaloisKeys; 33 | using ::seal::RelinKeys; 34 | 35 | class PIRServer { 36 | public: 37 | /** 38 | * Creates and returns a new server instance, holding a database. 39 | * @param[in] database PIRDatabase to load 40 | * @param[in] params PIR Parameters; params->num_pt() must equal database->size(). 41 | * @returns InvalidArgument if params->num_pt() != database->size(), or any error from PIRContext::Create. 42 | **/ 43 | static StatusOr<std::unique_ptr<PIRServer>> Create( 44 | std::shared_ptr<PIRDatabase> database, shared_ptr<PIRParameters> params); 45 | 46 | /** 47 | * Handles a client request. 48 | * @param[in] request The PIR Payload 49 | * @returns An error status if deserializing the request or any encrypted 50 | * operation fails. 51 | **/ 52 | StatusOr<Response> ProcessRequest(const Request& request) const; 53 | 54 | PIRServer() = delete; 55 | 56 | /** 57 | * Helper function to do the substitution operation on a ciphertext. If the 58 | * ciphertext is the encryption of polynomial p(x), then given power k, the 59 | * result will be p(x**k). 60 | * @param[in,out] ct Ciphertext to operate on; modified in place. 61 | * @param[in] power The power k to raise x to in the plaintext polynomial. 62 | * @param[in] gal_keys Galois keys used for automorphism. Must be generated by 63 | * whoever encrypted the ciphertext using keygen, and must include the power 64 | * being asked for. Returns InternalError if SEAL throws. 65 | */ 66 | Status substitute_power_x_inplace(seal::Ciphertext& ct, std::uint32_t power, 67 | const seal::GaloisKeys& gal_keys) const; 68 | 69 | /** 70 | * Helper function to multiply a ciphertext by a given power of 1/x. As a 71 | * result, the plaintext is also multiplied by the same power of 1/x.
For example, 72 | * if the ciphertext is the encryption of p(x) = 99x^5, and if the given k is 73 | * 3, then this results in p(x) * 1/x^3 = 99x^2. 74 | * @param[in] encrypted Ciphertext to take as input. 75 | * @param[in] k Power of 1/x to multiply. 76 | * @param[out] destination Output ciphertext after multiplying by the power of 1/x. 77 | */ 78 | void multiply_inverse_power_of_x(const seal::Ciphertext& encrypted, 79 | uint32_t k, 80 | seal::Ciphertext& destination) const; 81 | 82 | /** 83 | * Performs an oblivious expansion on an input ciphertext to a vector of 84 | * ciphertexts. If the input ciphertext is the encryption of a plaintext 85 | * polynomial of the form a0 + a1*x + a2*x^2 + ... + a(n-1)*x^(n-1), where n 86 | * is num_items below, then the output is a series of ciphertexts, where each is 87 | * the encryption of each term as the constant coefficient. In other words, 88 | * the output will be a vector of: [enc(a0), enc(a1), ..., enc(a(n-1))], where 89 | * enc(ai) is the encryption of a polynomial that only has ai in the constant 90 | * coefficient (all other terms are zero). 91 | * 92 | * The most common example of this is to expand a selection vector in PIR to 93 | * individual ciphertexts. The selection vector is just 0 in all slots except 94 | * for 1 in the desired slot. If there are 4 items, and the third item is the 95 | * one desired, the selection vector is [0, 0, 1, 0]. The client represents 96 | * this as a single polynomial of the form 1*x^2. The server uses this 97 | * expansion function to expand this to 4 ciphertexts: [enc(0), enc(0), 98 | * enc(1), enc(0)], which it then uses to produce the response. 99 | * 100 | * NB: Due to an optimization in PIR, this is left as (expanded_vector) * m, 101 | * where m is 2^ceil_log2(num_items) (the smallest power of 2 >= num_items). It is assumed 102 | * that the plaintext modulus will be changed to make this irrelevant. 103 | * 104 | * @param[in] ct The input ciphertext to expand.
105 | * @param[in] num_items The number of items to extract. 106 | * @param[in] gal_keys Galois keys supplied by the client. 107 | * @returns A vector of ciphertexts that are the expansion as described above, or InvalidArgument if num_items exceeds the poly modulus degree. 108 | */ 109 | StatusOr<std::vector<seal::Ciphertext>> oblivious_expansion( 110 | const seal::Ciphertext& ct, const size_t num_items, 111 | const seal::GaloisKeys& gal_keys) const; 112 | 113 | /** 114 | * Extension of oblivious_expansion to multiple ciphertexts. This allows 115 | * selection vectors that are larger than poly_modulus_degree to be used. The 116 | * output of the expansion of each ciphertext is concatenated to form the 117 | * results of this function. Each ciphertext that isn't the last in the vector 118 | * is assumed to contain exactly poly_modulus_degree items to be expanded. 119 | * cts.size() must equal total_items / poly_modulus_degree + 1, else InvalidArgument is returned. 120 | * @param[in] cts List of ciphertexts to use as input to the expansion 121 | * @param[in] total_items Total number of ciphertexts after expansion 122 | * @param[in] gal_keys Galois keys supplied by the client 123 | * @returns A vector of ciphertexts that are the expansion of all of the input 124 | * ciphertexts concatenated.
125 | */ 126 | StatusOr<std::vector<seal::Ciphertext>> oblivious_expansion( 127 | const std::vector<seal::Ciphertext>& cts, const size_t total_items, 128 | const seal::GaloisKeys& gal_keys) const; 129 | 130 | // Just for testing: get the context 131 | PIRContext* Context() { return context_.get(); } 132 | 133 | private: 134 | PIRServer(std::unique_ptr<PIRContext> /*context*/, 135 | std::shared_ptr<PIRDatabase> /*db*/); 136 | 137 | Status processQuery(const Ciphertexts& query, const GaloisKeys& galois_keys, 138 | const optional<RelinKeys>& relin_keys, 139 | const size_t& dim_sum, Ciphertexts* output) const; 140 | 141 | std::unique_ptr<PIRContext> context_; 142 | std::shared_ptr<PIRDatabase> db_; 143 | }; 144 | 145 | } // namespace pir 146 | 147 | #endif // PIR_SERVER_H_ 148 | -------------------------------------------------------------------------------- /pir/cpp/server_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | 17 | #include "pir/cpp/server.h" 18 | 19 | #include <algorithm> 20 | #include <iostream> 21 | #include <vector> 22 | 23 | #include "gmock/gmock.h" 24 | #include "gtest/gtest.h" 25 | #include "pir/cpp/client.h" 26 | #include "pir/cpp/ct_reencoder.h" 27 | #include "pir/cpp/status_asserts.h" 28 | #include "pir/cpp/test_base.h" 29 | #include "pir/cpp/utils.h" 30 | 31 | namespace pir { 32 | namespace { 33 | 34 | using std::cout; 35 | using std::endl; 36 | using std::get; 37 | using std::make_tuple; 38 | using std::make_unique; 39 | using std::shared_ptr; 40 | using std::string; 41 | using std::tuple; 42 | using std::unique_ptr; 43 | using std::vector; 44 | 45 | using seal::Ciphertext; 46 | using seal::GaloisKeys; 47 | using seal::Plaintext; 48 | using seal::RelinKeys; 49 | 50 | using namespace seal; 51 | using namespace ::testing; 52 | using std::int64_t; 53 | using std::vector; 54 | 55 | #ifdef TEST_DEBUG 56 | #define DEBUG_OUT(x) std::cout << x << std::endl 57 | #else 58 | #define DEBUG_OUT(x) 59 | #endif // TEST_DEBUG 60 | 61 | constexpr uint32_t POLY_MODULUS_DEGREE = 4096; 62 | constexpr uint32_t ELEM_SIZE = 7680; 63 | 64 | class PIRServerTestBase : public PIRTestingBase { 65 | protected: 66 | void SetUpDBImpl(size_t dbsize, size_t dimensions = 1, 67 | size_t elem_size = ELEM_SIZE, 68 | uint32_t plain_mod_bit_size = 20, 69 | bool use_ciphertext_multiplication = false) { 70 | SetUpParams(dbsize, elem_size, dimensions, POLY_MODULUS_DEGREE, 71 | plain_mod_bit_size, 0, use_ciphertext_multiplication); 72 | GenerateIntDB(); 73 | SetUpSealTools(); 74 | 75 | gal_keys_ = 76 | keygen_->galois_keys_local(generate_galois_elts(POLY_MODULUS_DEGREE)); 77 | relin_keys_ = keygen_->relin_keys_local(); 78 | 79 | server_ = *(PIRServer::Create(pir_db_, pir_params_)); 80 | ASSERT_THAT(server_, NotNull()); 81 | } 82 | 83 | unique_ptr<PIRServer> server_; 84 | GaloisKeys gal_keys_; 85 | RelinKeys relin_keys_; 86 | }; 87 | 88 | class PIRServerTest : public 
::testing::TestWithParam<bool>, 89 | public PIRServerTestBase { 90 | protected: 91 | void SetUp() { SetUpDB(10); } 92 | void SetUpDB(size_t dbsize, size_t dimensions = 1, 93 | size_t elem_size = ELEM_SIZE, uint32_t plain_mod_bit_size = 20) { 94 | SetUpDBImpl(dbsize, dimensions, elem_size, plain_mod_bit_size, GetParam()); 95 | } 96 | }; 97 | 98 | TEST_P(PIRServerTest, TestProcessRequest_SingleCT) { 99 | const size_t desired_index = 7; 100 | Plaintext pt(POLY_MODULUS_DEGREE); 101 | pt.set_zero(); 102 | pt[desired_index] = 1; 103 | 104 | vector<Ciphertext> query(1); 105 | encryptor_->encrypt(pt, query[0]); 106 | 107 | Request request_proto; 108 | SaveRequest({query}, gal_keys_, relin_keys_, &request_proto); 109 | 110 | ASSIGN_OR_FAIL(auto result_raw, server_->ProcessRequest(request_proto)); 111 | ASSERT_EQ(result_raw.reply_size(), 1); 112 | ASSIGN_OR_FAIL(auto result, LoadCiphertexts(server_->Context()->SEALContext(), 113 | result_raw.reply(0))); 114 | ASSERT_THAT(result, SizeIs(1)); 115 | 116 | Plaintext result_pt; 117 | decryptor_->decrypt(result[0], result_pt); 118 | auto encoder = server_->Context()->Encoder(); 119 | ASSERT_THAT(encoder->decode_int64(result_pt), 120 | Eq(int_db_[desired_index] * next_power_two(db_size_))); 121 | } 122 | 123 | TEST_P(PIRServerTest, TestProcessRequest_MultiCT) { 124 | SetUpDB(5000); 125 | const size_t desired_index = 4200; 126 | Plaintext pt(POLY_MODULUS_DEGREE); 127 | pt.set_zero(); 128 | 129 | vector<Ciphertext> query(2); 130 | encryptor_->encrypt(pt, query[0]); 131 | pt[desired_index - POLY_MODULUS_DEGREE] = 1; 132 | encryptor_->encrypt(pt, query[1]); 133 | 134 | Request request_proto; 135 | SaveRequest({query}, gal_keys_, relin_keys_, &request_proto); 136 | 137 | ASSIGN_OR_FAIL(auto result_raw, server_->ProcessRequest(request_proto)); 138 | ASSERT_EQ(result_raw.reply_size(), 1); 139 | ASSIGN_OR_FAIL(auto result, LoadCiphertexts(server_->Context()->SEALContext(), 140 | result_raw.reply(0))); 141 | ASSERT_THAT(result, SizeIs(1)); 
142 | 143 | Plaintext result_pt; 144 | decryptor_->decrypt(result[0], result_pt); 145 | auto encoder = server_->Context()->Encoder(); 146 | DEBUG_OUT("Expected DB value " << int_db_[desired_index]); 147 | DEBUG_OUT("Expected m " << next_power_two(db_size_ - POLY_MODULUS_DEGREE)); 148 | ASSERT_THAT(encoder->decode_int64(result_pt), 149 | Eq(int_db_[desired_index] * 150 | next_power_two(db_size_ - POLY_MODULUS_DEGREE))); 151 | } 152 | 153 | TEST_P(PIRServerTest, TestProcessBatchRequest) { 154 | const vector<size_t> indexes = {3, 4, 5}; 155 | vector<vector<Ciphertext>> queries(indexes.size()); 156 | 157 | for (size_t idx = 0; idx < indexes.size(); ++idx) { 158 | Plaintext pt(POLY_MODULUS_DEGREE); 159 | pt.set_zero(); 160 | pt[indexes[idx]] = 1; 161 | 162 | vector<Ciphertext> query(1); 163 | encryptor_->encrypt(pt, query[0]); 164 | queries[idx] = query; 165 | } 166 | 167 | Request request_proto; 168 | SaveRequest(queries, gal_keys_, relin_keys_, &request_proto); 169 | 170 | ASSIGN_OR_FAIL(auto response, server_->ProcessRequest(request_proto)); 171 | for (size_t idx = 0; idx < indexes.size(); ++idx) { 172 | ASSIGN_OR_FAIL(auto result, 173 | LoadCiphertexts(server_->Context()->SEALContext(), 174 | response.reply(idx))); 175 | ASSERT_THAT(result, SizeIs(1)); 176 | 177 | Plaintext result_pt; 178 | decryptor_->decrypt(result[0], result_pt); 179 | auto encoder = server_->Context()->Encoder(); 180 | ASSERT_THAT(encoder->decode_int64(result_pt), 181 | Eq(int_db_[indexes[idx]] * next_power_two(db_size_))); 182 | } 183 | } 184 | 185 | // Make sure that if we get a weird request from client nothing explodes. 
186 | TEST_P(PIRServerTest, TestProcessRequestZeroInput) { 187 | Plaintext pt(POLY_MODULUS_DEGREE); 188 | pt.set_zero(); 189 | 190 | vector<Ciphertext> query(1); 191 | encryptor_->encrypt(pt, query[0]); 192 | 193 | Request request_proto; 194 | SaveRequest({query}, gal_keys_, relin_keys_, &request_proto); 195 | 196 | ASSIGN_OR_FAIL(auto result_raw, server_->ProcessRequest(request_proto)); 197 | ASSERT_EQ(result_raw.reply_size(), 1); 198 | ASSIGN_OR_FAIL(auto result, LoadCiphertexts(server_->Context()->SEALContext(), 199 | result_raw.reply(0))); 200 | 201 | ASSERT_THAT(result, SizeIs(1)); 202 | 203 | Plaintext result_pt; 204 | decryptor_->decrypt(result[0], result_pt); 205 | auto encoder = server_->Context()->Encoder(); 206 | ASSERT_THAT(encoder->decode_int64(result_pt), 0); 207 | } 208 | 209 | TEST_P(PIRServerTest, TestProcessRequest_2Dim) { 210 | SetUpDB(82, 2); 211 | const size_t desired_index = 42; 212 | 213 | uint64_t m_inv; 214 | ASSERT_TRUE(seal::util::try_invert_uint_mod( 215 | next_power_two(server_->Context()->DimensionsSum()), 216 | server_->Context()->EncryptionParams().plain_modulus().value(), m_inv)); 217 | 218 | Plaintext pt(POLY_MODULUS_DEGREE); 219 | pt.set_zero(); 220 | // select 4th row 221 | pt[4] = m_inv; 222 | // select 6th column (after 10-item selection vector for rows) 223 | pt[16] = m_inv; 224 | 225 | vector<Ciphertext> query(1); 226 | encryptor_->encrypt(pt, query[0]); 227 | 228 | Request request_proto; 229 | SaveRequest({query}, gal_keys_, relin_keys_, &request_proto); 230 | 231 | ASSIGN_OR_FAIL(auto response, server_->ProcessRequest(request_proto)); 232 | ASSERT_EQ(response.reply_size(), 1); 233 | ASSIGN_OR_FAIL(auto reply, LoadCiphertexts(server_->Context()->SEALContext(), 234 | response.reply(0))); 235 | 236 | Plaintext result_pt; 237 | if (GetParam()) { 238 | // CT Multiplication 239 | ASSERT_THAT(reply, SizeIs(1)); 240 | EXPECT_THAT(reply[0].size(), Eq(2)) 241 | << "Ciphertext larger than expected. 
Were relin keys used?"; 242 | decryptor_->decrypt(reply[0], result_pt); 243 | 244 | } else { 245 | ASSIGN_OR_FAIL(auto ct_reencoder, CiphertextReencoder::Create( 246 | server_->Context()->SEALContext())); 247 | ASSERT_THAT(reply, 248 | SizeIs(ct_reencoder->ExpansionRatio() * query[0].size())); 249 | vector<Plaintext> reply_pts(reply.size()); 250 | for (size_t i = 0; i < reply_pts.size(); ++i) { 251 | decryptor_->decrypt(reply[i], reply_pts[i]); 252 | } 253 | auto result_ct = ct_reencoder->Decode(reply_pts); 254 | EXPECT_EQ(result_ct.size(), query[0].size()); 255 | decryptor_->decrypt(result_ct, result_pt); 256 | } 257 | 258 | auto encoder = server_->Context()->Encoder(); 259 | ASSERT_THAT(encoder->decode_int64(result_pt), Eq(int_db_[desired_index])); 260 | } 261 | 262 | INSTANTIATE_TEST_SUITE_P(PIRServerTests, PIRServerTest, 263 | testing::Values(false, true)); 264 | 265 | class SubstituteOperatorTest 266 | : public PIRServerTestBase, 267 | public testing::TestWithParam<tuple<string, uint32_t, string>> { 268 | void SetUp() { SetUpDBImpl(10); } 269 | }; 270 | 271 | TEST_P(SubstituteOperatorTest, SubstituteExamples) { 272 | Plaintext input_pt(get<0>(GetParam())); 273 | DEBUG_OUT("Input PT: " << input_pt.to_string()); 274 | 275 | Ciphertext ct; 276 | encryptor_->encrypt(input_pt, ct); 277 | 278 | auto k = get<1>(GetParam()); 279 | GaloisKeys gal_keys = keygen_->galois_keys_local(vector<uint32_t>({k})); 280 | server_->substitute_power_x_inplace(ct, k, gal_keys); 281 | 282 | Plaintext result_pt; 283 | decryptor_->decrypt(ct, result_pt); 284 | DEBUG_OUT("Result PT: " << result_pt.to_string()); 285 | 286 | Plaintext expected_pt(get<2>(GetParam())); 287 | DEBUG_OUT("Expected PT: " << expected_pt.to_string()); 288 | ASSERT_THAT(result_pt, Eq(expected_pt)); 289 | } 290 | 291 | INSTANTIATE_TEST_SUITE_P( 292 | Substitutions, SubstituteOperatorTest, 293 | testing::Values(make_tuple("42", 3, "42"), make_tuple("1x^1", 5, "1x^5"), 294 | make_tuple("6x^2", 3, "6x^6"), 295 | 
make_tuple("1x^1", POLY_MODULUS_DEGREE + 1, "FC000x^1"), 296 | make_tuple("1x^4", POLY_MODULUS_DEGREE + 1, "1x^4"), 297 | make_tuple("1x^8", POLY_MODULUS_DEGREE / 2 + 1, "1x^8"), 298 | make_tuple("1x^8", POLY_MODULUS_DEGREE / 4 + 1, "1x^8"), 299 | make_tuple("1x^8", POLY_MODULUS_DEGREE / 8 + 1, "FC000x^8"), 300 | make_tuple("77x^4095", 3, "77x^4093"), 301 | make_tuple("1x^4095", POLY_MODULUS_DEGREE + 1, 302 | "FC000x^4095"), 303 | make_tuple("4x^4 + 33x^3 + 222x^2 + 19x^1 + 42", 304 | POLY_MODULUS_DEGREE + 1, 305 | "4x^4 + FBFCEx^3 + 222x^2 + FBFE8x^1 + 42"))); 306 | 307 | class MultiplyInversePowerXTest 308 | : public PIRServerTestBase, 309 | public testing::TestWithParam<tuple<string, uint32_t, string>> { 310 | void SetUp() { SetUpDBImpl(10); } 311 | }; 312 | 313 | TEST_P(MultiplyInversePowerXTest, MultiplyInversePowerXExamples) { 314 | Plaintext input_pt(get<0>(GetParam())); 315 | DEBUG_OUT("Input PT: " << input_pt.to_string()); 316 | 317 | Ciphertext ct; 318 | encryptor_->encrypt(input_pt, ct); 319 | 320 | auto k = get<1>(GetParam()); 321 | Ciphertext result_ct; 322 | server_->multiply_inverse_power_of_x(ct, k, result_ct); 323 | 324 | Plaintext result_pt; 325 | decryptor_->decrypt(result_ct, result_pt); 326 | DEBUG_OUT("Result PT: " << result_pt.to_string()); 327 | 328 | Plaintext expected_pt(get<2>(GetParam())); 329 | DEBUG_OUT("Expected PT: " << expected_pt.to_string()); 330 | ASSERT_THAT(result_pt, Eq(expected_pt)); 331 | } 332 | 333 | INSTANTIATE_TEST_SUITE_P(InversePowersOfX, MultiplyInversePowerXTest, 334 | testing::Values(make_tuple("42x^1", 1, "42"), 335 | make_tuple("42x^42", 41, "42x^1"), 336 | make_tuple("1x^4 + 1x^3 + 1x^1", 1, 337 | "1x^3 + 1x^2 + 1"), 338 | make_tuple("1x^16 + 1x^12 + 1x^8", 4, 339 | "1x^12 + 1x^8 + 1x^4"))); 340 | 341 | class ObliviousExpansionTest 342 | : public PIRServerTestBase, 343 | public testing::TestWithParam<tuple<string, vector<string>>> { 344 | void SetUp() { SetUpDBImpl(10); } 345 | }; 346 | 347 | 
TEST_P(ObliviousExpansionTest, ObliviousExpansionExamples) { 348 | Plaintext input_pt(get<0>(GetParam())); 349 | DEBUG_OUT("Input PT: " << input_pt.to_string()); 350 | 351 | Ciphertext ct; 352 | encryptor_->encrypt(input_pt, ct); 353 | 354 | auto expected = get<1>(GetParam()); 355 | ASSIGN_OR_FAIL(auto results, 356 | server_->oblivious_expansion( 357 | ct, expected.size(), 358 | keygen_->galois_keys_local( 359 | generate_galois_elts(POLY_MODULUS_DEGREE)))); 360 | 361 | vector<Plaintext> results_pt(results.size()); 362 | for (size_t i = 0; i < results.size(); ++i) { 363 | decryptor_->decrypt(results[i], results_pt[i]); 364 | DEBUG_OUT("Result PT[" << i << "]: " << results_pt[i].to_string()); 365 | } 366 | 367 | vector<Plaintext> expected_pt(expected.size()); 368 | for (size_t i = 0; i < expected_pt.size(); ++i) { 369 | expected_pt[i] = Plaintext(expected[i]); 370 | DEBUG_OUT("Expected PT[" << i << "]: " << expected_pt[i].to_string()); 371 | } 372 | 373 | ASSERT_THAT(results_pt, ContainerEq(expected_pt)); 374 | } 375 | 376 | INSTANTIATE_TEST_SUITE_P( 377 | ObliviousExpansion, ObliviousExpansionTest, 378 | testing::Values(make_tuple("1", vector<string>({"2", "0"})), 379 | make_tuple("1x^1", vector<string>({"0", "2"})), 380 | make_tuple("3x^3 + 2x^2 + 1x^1 + 42", 381 | vector<string>({"108", "4", "8", "C"})), 382 | make_tuple("1x^5", vector<string>({"0", "0", "0", "0", "0", 383 | "8"})))); 384 | 385 | class ObliviousExpansionTestMultiCT 386 | : public PIRServerTestBase, 387 | public testing::TestWithParam<tuple<size_t, size_t, uint64_t>> { 388 | void SetUp() { SetUpDBImpl(10); } 389 | }; 390 | 391 | TEST_P(ObliviousExpansionTestMultiCT, MultiCTExamples) { 392 | const auto num_items = get<0>(GetParam()); 393 | const auto index = get<1>(GetParam()); 394 | const auto expected_value = get<2>(GetParam()); 395 | 396 | vector<Plaintext> input_pt(num_items / POLY_MODULUS_DEGREE + 1, 397 | Plaintext(POLY_MODULUS_DEGREE)); 398 | input_pt[index / POLY_MODULUS_DEGREE][index % 
POLY_MODULUS_DEGREE] = 1; 399 | vector<Ciphertext> input_ct(input_pt.size()); 400 | for (size_t i = 0; i < input_pt.size(); ++i) { 401 | DEBUG_OUT("Input PT[" << i << "]: " << input_pt[i].to_string()); 402 | encryptor_->encrypt(input_pt[i], input_ct[i]); 403 | } 404 | 405 | ASSIGN_OR_FAIL(auto results, 406 | server_->oblivious_expansion( 407 | input_ct, num_items, 408 | keygen_->galois_keys_local( 409 | generate_galois_elts(POLY_MODULUS_DEGREE)))); 410 | 411 | ASSERT_THAT(results, SizeIs(num_items)); 412 | for (size_t i = 0; i < results.size(); ++i) { 413 | Plaintext result_pt; 414 | decryptor_->decrypt(results[i], result_pt); 415 | const auto exp = (i == index) ? expected_value : 0; 416 | EXPECT_THAT(result_pt.coeff_count(), Eq(1)) 417 | << "i = " << i << ", pt = " << result_pt.to_string(); 418 | EXPECT_THAT(result_pt[0], Eq(exp)) 419 | << "i = " << i << ", pt = " << result_pt.to_string(); 420 | } 421 | } 422 | 423 | INSTANTIATE_TEST_SUITE_P( 424 | ObliviousExpansionMultiCT, ObliviousExpansionTestMultiCT, 425 | testing::Values(make_tuple(100, 42, 128), make_tuple(100, 0, 128), 426 | make_tuple(100, 99, 128), make_tuple(4096, 3007, 4096), 427 | make_tuple(5000, 4095, 4096), 428 | make_tuple(5000, 4200, 1024))); 429 | 430 | } // namespace 431 | } // namespace pir 432 | -------------------------------------------------------------------------------- /pir/cpp/status_asserts.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include "absl/memory/memory.h" 18 | #include "absl/status/statusor.h" 19 | #include "absl/strings/escaping.h" 20 | #include "absl/strings/str_cat.h" 21 | #include "absl/types/span.h" 22 | 23 | #define ASSERT_OK(expr) \ 24 | do { \ 25 | const Status _status = (expr); \ 26 | ASSERT_TRUE(_status.ok()) << "Error: " << _status.ToString(); \ 27 | } while (false) 28 | 29 | #define EXPECT_OK(expr) \ 30 | do { \ 31 | const Status _status = (expr); \ 32 | EXPECT_TRUE(_status.ok()) << "Error: " << _status.ToString(); \ 33 | } while (false) 34 | 35 | #define ASSIGN_OR_FAIL(lhs, rexpr) \ 36 | ASSIGN_OR_FAIL_IMPL_(CONCAT_NAME_(status_or_, __LINE__), lhs, rexpr) 37 | 38 | #define ASSIGN_OR_RETURN(lhs, rexpr) \ 39 | ASSIGN_OR_RETURN_IMPL_(CONCAT_NAME_(status_or_, __LINE__), lhs, rexpr) 40 | 41 | #define CONCAT_NAME_INNER_(x, y) x##y 42 | #define CONCAT_NAME_(x, y) CONCAT_NAME_INNER_(x, y) 43 | 44 | #define ASSIGN_OR_FAIL_IMPL_(statusor, lhs, rexpr) \ 45 | auto statusor = (rexpr); \ 46 | ASSERT_TRUE(statusor.ok()) << "Error: " << statusor.status().ToString(); \ 47 | lhs = std::move(*statusor); 48 | 49 | #define ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \ 50 | auto statusor = (rexpr); \ 51 | if (!statusor.ok()) return statusor.status(); \ 52 | lhs = std::move(*statusor); 53 | -------------------------------------------------------------------------------- /pir/cpp/string_encoder.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in 
CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include "pir/cpp/string_encoder.h" 18 | 19 | #include "pir/cpp/status_asserts.h" 20 | 21 | namespace pir { 22 | 23 | using absl::InvalidArgumentError; 24 | 25 | size_t StringEncoder::num_items_per_plaintext(size_t item_size) { 26 | return poly_modulus_degree_ * bits_per_coeff_ / item_size / 8; 27 | } 28 | 29 | size_t StringEncoder::max_bytes_per_plaintext() { 30 | return poly_modulus_degree_ * bits_per_coeff_ / 8; 31 | } 32 | 33 | /** 34 | * Helper class for encoding strings to PT coefficients and keeping track of 35 | * where we are. 
36 | */ 37 | class StringEncoderImpl { 38 | public: 39 | StringEncoderImpl(Plaintext& destination, size_t bits_per_coeff) 40 | : destination_(destination), 41 | bits_per_coeff_(bits_per_coeff), 42 | coeff_bits_(bits_per_coeff) {} 43 | 44 | StringEncoderImpl() = delete; 45 | 46 | void encode(const string& value); 47 | void terminate(); 48 | 49 | private: 50 | Plaintext& destination_; 51 | size_t bits_per_coeff_; 52 | 53 | // temporary variables for encoding 54 | size_t coeff_index_ = 0; 55 | size_t coeff_bits_; 56 | }; 57 | 58 | void StringEncoderImpl::encode(const string& value) { 59 | for (uint8_t c : value) { 60 | size_t remain_bits = 8; 61 | while (remain_bits > 0) { 62 | size_t n = std::min(coeff_bits_, remain_bits); 63 | destination_[coeff_index_] <<= n; 64 | destination_[coeff_index_] |= (c >> (8 - n)); 65 | c <<= n; 66 | coeff_bits_ -= n; 67 | remain_bits -= n; 68 | if (coeff_bits_ <= 0) { 69 | ++coeff_index_; 70 | coeff_bits_ = bits_per_coeff_; 71 | } 72 | } 73 | } 74 | } 75 | 76 | void StringEncoderImpl::terminate() { 77 | if (coeff_bits_ < bits_per_coeff_ && coeff_bits_ > 0) { 78 | destination_[coeff_index_] <<= coeff_bits_; 79 | } 80 | } 81 | StringEncoder::StringEncoder(shared_ptr<seal::SEALContext> context) 82 | : context_(context) { 83 | const auto params = context_->first_context_data()->parms(); 84 | poly_modulus_degree_ = params.poly_modulus_degree(); 85 | bits_per_coeff_ = log2(params.plain_modulus().value()); 86 | } 87 | 88 | StatusOr<size_t> StringEncoder::calc_num_coeff(size_t num_bytes) const { 89 | size_t num_coeff = ceil(static_cast<double>(num_bytes * 8) / bits_per_coeff_); 90 | if (num_coeff > poly_modulus_degree_) { 91 | return InvalidArgumentError( 92 | "Number of coefficients needed greater than poly modulus degree"); 93 | } 94 | return num_coeff; 95 | } 96 | 97 | Status StringEncoder::encode(const string& value, 98 | Plaintext& destination) const { 99 | ASSIGN_OR_RETURN(const auto num_coeff, calc_num_coeff(value.size())); 100 | 
destination.resize(num_coeff); 101 | destination.set_zero(); 102 | StringEncoderImpl impl(destination, bits_per_coeff_); 103 | impl.encode(value); 104 | impl.terminate(); 105 | return absl::OkStatus(); 106 | } 107 | 108 | Status StringEncoder::encode(vector<string>::const_iterator v, 109 | const vector<string>::const_iterator end, 110 | Plaintext& destination) const { 111 | size_t total_size = std::accumulate( 112 | v, end, 0, [](int a, const string& b) { return a + b.size(); }); 113 | ASSIGN_OR_RETURN(auto num_coeff, calc_num_coeff(total_size)); 114 | destination.resize(num_coeff); 115 | destination.set_zero(); 116 | StringEncoderImpl impl(destination, bits_per_coeff_); 117 | while (v != end) { 118 | impl.encode(*(v++)); 119 | } 120 | impl.terminate(); 121 | return absl::OkStatus(); 122 | } 123 | 124 | StatusOr<string> StringEncoder::decode(const Plaintext& pt, size_t length, 125 | size_t byte_offset) const { 126 | if ((byte_offset + length) > (pt.coeff_count() * bits_per_coeff_ / 8)) { 127 | return InvalidArgumentError( 128 | "Requested decode beyond end of data in polynomial"); 129 | } 130 | size_t start_coeff_index = byte_offset * 8 / bits_per_coeff_; 131 | size_t coeff_bits = 132 | ((start_coeff_index + 1) * bits_per_coeff_) - (byte_offset * 8); 133 | if (coeff_bits <= 0) { 134 | coeff_bits = bits_per_coeff_; 135 | } 136 | if (length <= 0) { 137 | length = pt.significant_coeff_count() * bits_per_coeff_ / 8; 138 | } 139 | string result(length, 0); 140 | size_t result_index = 0; 141 | size_t remain_bits = 8; 142 | for (size_t i = start_coeff_index; i < pt.coeff_count(); ++i) { 143 | while (coeff_bits > 0) { 144 | size_t n = std::min(coeff_bits, remain_bits); 145 | result[result_index] <<= n; 146 | result[result_index] |= (pt[i] >> (coeff_bits - n)); 147 | 148 | coeff_bits -= n; 149 | remain_bits -= n; 150 | if (remain_bits <= 0) { 151 | if (++result_index >= length) return result; 152 | remain_bits = 8; 153 | } 154 | } 155 | coeff_bits = bits_per_coeff_; 156 | } 
157 | return result; 158 | } 159 | 160 | } // namespace pir 161 | -------------------------------------------------------------------------------- /pir/cpp/string_encoder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #ifndef PIR_STRING_ENCODER_H_ 18 | #define PIR_STRING_ENCODER_H_ 19 | 20 | #include <string> 21 | 22 | #include "absl/status/statusor.h" 23 | #include "seal/seal.h" 24 | 25 | namespace pir { 26 | 27 | using absl::Status; 28 | using absl::StatusOr; 29 | using seal::Plaintext; 30 | using std::shared_ptr; 31 | using std::string; 32 | using std::vector; 33 | 34 | class StringEncoder { 35 | public: 36 | // just for now / testing. Change this to factory later 37 | StringEncoder(shared_ptr<seal::SEALContext> context); 38 | 39 | /** 40 | * Calculate the number of items that can be encoded into a single plaintext. 41 | * @param[in] item_size Size of each item in database 42 | * @returns Number of items per plaintext 43 | */ 44 | size_t num_items_per_plaintext(size_t item_size); 45 | 46 | /** 47 | * Calculate the maximum number of bytes that can be encded in a single pt. 48 | */ 49 | size_t max_bytes_per_plaintext(); 50 | 51 | /** 52 | * Encode a string of binary value into the destination using a 53 | * minimal amount of coefficients. 
54 | * @param[in] value String to encode 55 | * @param[out] destination Plaintext to populate with encoded value 56 | * @returns Invalid argument if string is too big for plaintext polynomial 57 | */ 58 | Status encode(const string& value, Plaintext& destination) const; 59 | 60 | /** 61 | * Encodes several strings into a plaintext using the 62 | * minimal amount of coefficients. 63 | * @param[in] v Iterator pointing to the start of values to encode 64 | * @param[in] end End of the sequence of values to 65 | * @param[out] destination Plaintext to populate with encoded value 66 | * @returns Invalid argument if total string length is too big for plaintext 67 | * polynomial 68 | */ 69 | Status encode(vector<string>::const_iterator v, 70 | const vector<string>::const_iterator end, 71 | Plaintext& destination) const; 72 | 73 | /** 74 | * Decode a plaintext assumed to be in packed form into a string. 75 | * @param[in] pt The plaintext value to decode from. 76 | * @param[in] length The length in bytes of the string to decode. If not 77 | * provided or set to zero, decodes the values from all significant 78 | * coefficients in plaintext polynomial. 79 | * @param[in] offset Offset in bytes from the start of the plaintext from 80 | * which to decode. 81 | * @returns String decoded or Error 82 | */ 83 | StatusOr<string> decode(const Plaintext& pt, size_t length = 0, 84 | size_t offset = 0) const; 85 | 86 | /** 87 | * Allows overriding number of bits to pack per coefficient. 88 | */ 89 | void set_bits_per_coeff(size_t bits_per_coeff) { 90 | bits_per_coeff_ = bits_per_coeff; 91 | } 92 | 93 | /** 94 | * Number of bits to use per coefficient. 
95 | */ 96 | size_t bits_per_coeff() { return bits_per_coeff_; } 97 | 98 | private: 99 | shared_ptr<seal::SEALContext> context_; 100 | size_t poly_modulus_degree_; 101 | size_t bits_per_coeff_; 102 | 103 | // Helper to calculate the number of coefficients needed to encode a number of 104 | // bytes of input in the current context, or InvalidArgumentError if the input 105 | // is too long. 106 | StatusOr<size_t> calc_num_coeff(size_t num_bytes) const; 107 | }; 108 | 109 | } // namespace pir 110 | 111 | #endif // PIR_STRING_ENCODER_H_ 112 | -------------------------------------------------------------------------------- /pir/cpp/string_encoder_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/string_encoder.h" 18 | 19 | #include <iostream> 20 | #include <memory> 21 | 22 | #include "gmock/gmock.h" 23 | #include "gtest/gtest.h" 24 | #include "pir/cpp/parameters.h" 25 | #include "pir/cpp/status_asserts.h" 26 | 27 | namespace pir { 28 | namespace { 29 | 30 | using std::cout; 31 | using std::endl; 32 | using std::make_unique; 33 | using std::unique_ptr; 34 | 35 | using namespace seal; 36 | using namespace testing; 37 | 38 | constexpr size_t POLY_MODULUS_DEGREE = 4096; 39 | 40 | class StringEncoderTest : public ::testing::Test { 41 | protected: 42 | void SetUp() { 43 | auto params = GenerateEncryptionParams(POLY_MODULUS_DEGREE); 44 | seal_context_ = seal::SEALContext::Create(params); 45 | if (!seal_context_->parameters_set()) { 46 | FAIL() << "Error setting encryption parameters: " 47 | << seal_context_->parameter_error_message(); 48 | } 49 | keygen_ = make_unique<KeyGenerator>(seal_context_); 50 | encryptor_ = make_unique<Encryptor>(seal_context_, keygen_->public_key()); 51 | evaluator_ = make_unique<Evaluator>(seal_context_); 52 | decryptor_ = make_unique<Decryptor>(seal_context_, keygen_->secret_key()); 53 | encoder_ = std::make_unique<StringEncoder>(seal_context_); 54 | } 55 | 56 | shared_ptr<SEALContext> seal_context_; 57 | unique_ptr<StringEncoder> encoder_; 58 | unique_ptr<KeyGenerator> keygen_; 59 | unique_ptr<Encryptor> encryptor_; 60 | unique_ptr<Evaluator> evaluator_; 61 | unique_ptr<Decryptor> decryptor_; 62 | }; 63 | 64 | TEST_F(StringEncoderTest, TestNumItemsPerPlaintext) { 65 | EXPECT_EQ(encoder_->num_items_per_plaintext(1), 9728); 66 | EXPECT_EQ(encoder_->num_items_per_plaintext(9728), 1); 67 | EXPECT_EQ(encoder_->num_items_per_plaintext(9729), 0); 68 | EXPECT_EQ(encoder_->num_items_per_plaintext(99999), 0); 69 | EXPECT_EQ(encoder_->num_items_per_plaintext(64), 152); 70 | EXPECT_EQ(encoder_->num_items_per_plaintext(288), 33); 71 | } 72 | 73 | TEST_F(StringEncoderTest, TestEncodeDecode) { 74 | string 
value("This is a string test for random VALUES@!#"); 75 | size_t num_coeff = ceil((value.size() * 8) / 19.0); 76 | Plaintext pt; 77 | EXPECT_OK(encoder_->encode(value, pt)); 78 | EXPECT_EQ(pt.coeff_count(), num_coeff); 79 | ASSIGN_OR_FAIL(auto result, encoder_->decode(pt)); 80 | ASSERT_GE(result.size(), value.size()); 81 | EXPECT_EQ(result.substr(0, value.size()), value); 82 | EXPECT_THAT(result.substr(value.size()), Each(0)); 83 | } 84 | 85 | TEST_F(StringEncoderTest, TestEncodeDecodePRN) { 86 | auto prng = 87 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 88 | string v(9728, 0); 89 | prng->generate(v.size(), reinterpret_cast<SEAL_BYTE *>(v.data())); 90 | Plaintext pt; 91 | EXPECT_OK(encoder_->encode(v, pt)); 92 | ASSIGN_OR_FAIL(auto result, encoder_->decode(pt)); 93 | ASSERT_GE(result.size(), v.size()); 94 | EXPECT_EQ(result.substr(0, v.size()), v); 95 | EXPECT_THAT(result.substr(v.size()), Each(0)); 96 | } 97 | 98 | TEST_F(StringEncoderTest, TestEncodeDecodeVector) { 99 | auto prng = 100 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 101 | vector<string> v(152); 102 | for (auto &s : v) { 103 | s.resize(64); 104 | prng->generate(s.size(), reinterpret_cast<SEAL_BYTE *>(s.data())); 105 | } 106 | 107 | for (const auto &s : v) { 108 | EXPECT_THAT(s, SizeIs(64)); 109 | EXPECT_THAT(s, Not(Each(0))); 110 | } 111 | 112 | Plaintext pt; 113 | EXPECT_OK(encoder_->encode(v.begin(), v.end(), pt)); 114 | size_t offset = 0; 115 | for (size_t i = 0; i < v.size(); ++i) { 116 | ASSIGN_OR_FAIL(auto result, encoder_->decode(pt, v[i].size(), offset)); 117 | offset += v[i].size(); 118 | EXPECT_THAT(result, StrEq(v[i])) << "i = " << i; 119 | } 120 | } 121 | 122 | TEST_F(StringEncoderTest, TestEncodeDecodeTooBig) { 123 | auto prng = 124 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 125 | string v(9729, 0); 126 | prng->generate(v.size(), reinterpret_cast<SEAL_BYTE *>(v.data())); 127 | Plaintext pt; 128 | auto status 
= encoder_->encode(v, pt); 129 | EXPECT_FALSE(status.ok()); 130 | EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); 131 | } 132 | 133 | TEST_F(StringEncoderTest, TestEncodeVectorTooBig) { 134 | auto prng = 135 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 136 | vector<string> v(141); 137 | for (auto &s : v) { 138 | s.resize(69); 139 | prng->generate(s.size(), reinterpret_cast<SEAL_BYTE *>(s.data())); 140 | } 141 | 142 | Plaintext pt; 143 | auto status = encoder_->encode(v.begin(), v.end(), pt); 144 | EXPECT_FALSE(status.ok()); 145 | EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); 146 | } 147 | 148 | TEST_F(StringEncoderTest, TestDecodeTooBig) { 149 | auto prng = 150 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 151 | string v(9728, 0); 152 | prng->generate(v.size(), reinterpret_cast<SEAL_BYTE *>(v.data())); 153 | Plaintext pt; 154 | EXPECT_OK(encoder_->encode(v, pt)); 155 | auto result_or = encoder_->decode(pt, 100, 9629); 156 | EXPECT_FALSE(result_or.status().ok()); 157 | EXPECT_EQ(result_or.status().code(), absl::StatusCode::kInvalidArgument); 158 | } 159 | 160 | TEST_F(StringEncoderTest, TestEncOp) { 161 | auto prng = 162 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({42}); 163 | string v(9728, 0); 164 | prng->generate(v.size(), reinterpret_cast<SEAL_BYTE *>(v.data())); 165 | Plaintext pt; 166 | EXPECT_OK(encoder_->encode(v, pt)); 167 | 168 | Plaintext selection_vector_pt(POLY_MODULUS_DEGREE); 169 | selection_vector_pt.set_zero(); 170 | selection_vector_pt[0] = 1; 171 | Ciphertext selection_vector_ct; 172 | encryptor_->encrypt(selection_vector_pt, selection_vector_ct); 173 | 174 | evaluator_->multiply_plain_inplace(selection_vector_ct, pt); 175 | 176 | Plaintext result_pt; 177 | decryptor_->decrypt(selection_vector_ct, result_pt); 178 | ASSIGN_OR_FAIL(auto result, encoder_->decode(result_pt)); 179 | ASSERT_GE(result.size(), v.size()); 180 | EXPECT_EQ(result.substr(0, 
v.size()), v); 181 | EXPECT_THAT(result.substr(v.size()), Each(0)); 182 | } 183 | 184 | class StringEncoderMaxBytesPerPlaintextTest 185 | : public testing::TestWithParam<tuple<uint32_t, uint32_t, uint32_t>> { 186 | protected: 187 | void SetUp() { 188 | auto params = 189 | GenerateEncryptionParams(get<0>(GetParam()), get<1>(GetParam())); 190 | seal_context_ = seal::SEALContext::Create(params); 191 | if (!seal_context_->parameters_set()) { 192 | FAIL() << "Error setting encryption parameters: " 193 | << seal_context_->parameter_error_message(); 194 | } 195 | encoder_ = std::make_unique<StringEncoder>(seal_context_); 196 | } 197 | 198 | shared_ptr<SEALContext> seal_context_; 199 | unique_ptr<StringEncoder> encoder_; 200 | }; 201 | 202 | TEST_P(StringEncoderMaxBytesPerPlaintextTest, Examples) { 203 | const auto exp = get<2>(GetParam()); 204 | EXPECT_EQ(encoder_->max_bytes_per_plaintext(), exp); 205 | } 206 | 207 | INSTANTIATE_TEST_SUITE_P(StringEncoderMaxBytesPerPlaintext, 208 | StringEncoderMaxBytesPerPlaintextTest, 209 | Values(make_tuple(4096, 20, 9728), 210 | make_tuple(4096, 16, 7680), 211 | make_tuple(8192, 20, 19456))); 212 | 213 | } // namespace 214 | } // namespace pir 215 | -------------------------------------------------------------------------------- /pir/cpp/test_base.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | #include "pir/cpp/test_base.h" 17 | 18 | #include "gtest/gtest.h" 19 | #include "pir/cpp/parameters.h" 20 | #include "pir/cpp/status_asserts.h" 21 | #include "pir/cpp/utils.h" 22 | 23 | namespace pir { 24 | 25 | using absl::make_unique; 26 | 27 | vector<string> generate_test_db(size_t db_size, size_t elem_size, 28 | uint64_t seed) { 29 | auto prng = 30 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({seed}); 31 | vector<string> db(db_size, string(elem_size, 0)); 32 | for (size_t i = 0; i < db_size; ++i) { 33 | prng->generate(db[i].size(), 34 | reinterpret_cast<seal::SEAL_BYTE*>(db[i].data())); 35 | } 36 | return db; 37 | } 38 | 39 | void PIRTestingBase::SetUpParams(size_t db_size, size_t elem_size, 40 | size_t dimensions, 41 | uint32_t poly_modulus_degree, 42 | uint32_t plain_mod_bit_size, 43 | uint32_t bits_per_coeff, 44 | bool use_ciphertext_multiplication) { 45 | db_size_ = db_size; 46 | 47 | auto encryption_params = 48 | GenerateEncryptionParams(poly_modulus_degree, plain_mod_bit_size); 49 | 50 | seal_context_ = seal::SEALContext::Create(encryption_params); 51 | if (!seal_context_->parameters_set()) { 52 | FAIL() << "Error setting encryption parameters: " 53 | << seal_context_->parameter_error_message(); 54 | } 55 | 56 | ASSIGN_OR_FAIL( 57 | pir_params_, 58 | CreatePIRParameters(db_size, elem_size, dimensions, encryption_params, 59 | use_ciphertext_multiplication, bits_per_coeff)); 60 | } 61 | 62 | void PIRTestingBase::GenerateDB(uint32_t seed) { 63 | string_db_ = generate_test_db(db_size_, pir_params_->bytes_per_item(), seed); 64 | ASSIGN_OR_FAIL(pir_db_, PIRDatabase::Create(string_db_, pir_params_)); 65 | } 66 | 67 | void PIRTestingBase::GenerateIntDB(uint32_t seed) { 68 | auto prng = 69 | seal::UniformRandomGeneratorFactory::DefaultFactory()->create({seed}); 70 | int_db_.resize(db_size_, 0); 71 | for (size_t i = 0; i 
< db_size_; ++i) { 72 | // can't use full size, or will run out of room on decode when multiplied by 73 | // selection vector 74 | prng->generate(sizeof(int_db_[i]) - 2, 75 | reinterpret_cast<seal::SEAL_BYTE*>(&int_db_[i])); 76 | } 77 | ASSIGN_OR_FAIL(pir_db_, PIRDatabase::Create(int_db_, pir_params_)); 78 | } 79 | 80 | void PIRTestingBase::SetUpSealTools() { 81 | keygen_ = make_unique<KeyGenerator>(seal_context_); 82 | encryptor_ = make_unique<Encryptor>(seal_context_, keygen_->public_key()); 83 | decryptor_ = make_unique<Decryptor>(seal_context_, keygen_->secret_key()); 84 | } 85 | 86 | } // namespace pir 87 | -------------------------------------------------------------------------------- /pir/cpp/test_base.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #ifndef PIR_TEST_UTILS_H 18 | #define PIR_TEST_UTILS_H 19 | 20 | #include <memory> 21 | #include <string> 22 | #include <vector> 23 | 24 | #include "pir/cpp/database.h" 25 | #include "pir/proto/payload.pb.h" 26 | #include "seal/seal.h" 27 | 28 | namespace pir { 29 | 30 | using std::int64_t; 31 | using std::shared_ptr; 32 | using std::string; 33 | using std::unique_ptr; 34 | using std::vector; 35 | 36 | using namespace seal; 37 | 38 | constexpr uint32_t POLY_MODULUS_DEGREE = 4096; 39 | 40 | // Utility function to generate a vector of testing data 41 | vector<string> generate_test_db(size_t db_size, size_t elem_size, 42 | uint64_t seed = 42); 43 | 44 | class PIRTestingBase { 45 | public: 46 | PIRTestingBase() {} 47 | virtual ~PIRTestingBase() {} 48 | 49 | protected: 50 | // Generate the EncryptParameters and PIRParameters and validate them. 51 | void SetUpParams(size_t db_size, size_t elem_size, size_t dimensions = 1, 52 | uint32_t poly_modulus_degree = POLY_MODULUS_DEGREE, 53 | uint32_t plain_mod_bit_size = 20, 54 | uint32_t bits_per_coeff = 0, 55 | bool use_ciphertext_multiplication = false); 56 | 57 | // Genrate a DB of random values 58 | void GenerateDB(uint32_t seed = 42); 59 | void GenerateIntDB(uint32_t seed = 42); 60 | 61 | void SetUpSealTools(); 62 | 63 | size_t db_size_; 64 | vector<string> string_db_; 65 | vector<int64_t> int_db_; 66 | shared_ptr<SEALContext> seal_context_; 67 | shared_ptr<PIRParameters> pir_params_; 68 | shared_ptr<PIRDatabase> pir_db_; 69 | unique_ptr<KeyGenerator> keygen_; 70 | unique_ptr<Encryptor> encryptor_; 71 | unique_ptr<Decryptor> decryptor_; 72 | }; 73 | } // namespace pir 74 | 75 | #endif // PIR_TEST_UTILS_H 76 | -------------------------------------------------------------------------------- /pir/cpp/utils.cpp: -------------------------------------------------------------------------------- 1 | #include "pir/cpp/utils.h" 2 | 3 | namespace pir { 4 | 5 | using std::vector; 6 | 7 | vector<uint32_t> 
generate_galois_elts(uint64_t N) { 8 | const size_t logN = ceil_log2(N); 9 | vector<uint32_t> galois_elts(logN); 10 | for (size_t i = 0; i < logN; ++i) { 11 | galois_elts[i] = (N >> i) + 1; 12 | } 13 | return galois_elts; 14 | } 15 | 16 | uint32_t log2(uint32_t v) { 17 | static const int MultiplyDeBruijnBitPosition[32] = { 18 | 0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 19 | 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31}; 20 | 21 | v |= v >> 1; 22 | v |= v >> 2; 23 | v |= v >> 4; 24 | v |= v >> 8; 25 | v |= v >> 16; 26 | 27 | return MultiplyDeBruijnBitPosition[(uint32_t)(v * 0x07C4ACDDU) >> 27]; 28 | } 29 | 30 | uint32_t ceil_log2(uint32_t v) { 31 | static const int MultiplyDeBruijnBitPosition[32] = { 32 | 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 33 | 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9}; 34 | 35 | --v; 36 | v |= v >> 1; 37 | v |= v >> 2; 38 | v |= v >> 4; 39 | v |= v >> 8; 40 | v |= v >> 16; 41 | ++v; 42 | 43 | return MultiplyDeBruijnBitPosition[(uint32_t)(v * 0x077CB531U) >> 27]; 44 | } 45 | 46 | } // namespace pir 47 | -------------------------------------------------------------------------------- /pir/cpp/utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #ifndef PIR_UTILS_H_ 18 | #define PIR_UTILS_H_ 19 | 20 | #include <cstdint> 21 | #include <vector> 22 | 23 | namespace pir { 24 | 25 | // Utility function to generate Galois elements needed for Oblivious Expansion. 26 | std::vector<uint32_t> generate_galois_elts(uint64_t N); 27 | 28 | // Utility function to find the next highest power of 2 of a given number. 29 | template <typename t> 30 | t next_power_two(t n) { 31 | if (n == 0) return 1; 32 | --n; 33 | for (size_t i = 1; i < sizeof(n) * 8; i = i << 1) { 34 | n |= n >> i; 35 | } 36 | return n + 1; 37 | } 38 | 39 | // Utility function to find the log base 2 of v rounded up. 40 | uint32_t ceil_log2(uint32_t v); 41 | 42 | // Utility function to find the log base 2 of v truncated. 43 | uint32_t log2(uint32_t v); 44 | 45 | // Utility function to calculate integer power 46 | inline size_t ipow(size_t base, size_t exp) { 47 | size_t result = 1; 48 | for (;;) { 49 | if (exp & 1) { 50 | result *= base; 51 | } 52 | exp >>= 1; 53 | if (!exp) break; 54 | base *= base; 55 | } 56 | return result; 57 | } 58 | 59 | } // namespace pir 60 | 61 | namespace private_join_and_compute { 62 | // Really don't know why this isn't included with private_join_and_compute 63 | 64 | // Run a command that returns a util::Status. If the called code returns an 65 | // error status, return that status up out of this method too. 66 | // 67 | // Example: 68 | // RETURN_IF_ERROR(DoThings(4)); 69 | #define RETURN_IF_ERROR(expr) \ 70 | do { \ 71 | /* Using _status below to avoid capture problems if expr is "status". 
*/ \ 72 | const Status _status = (expr); \ 73 | if (!_status.ok()) return _status; \ 74 | } while (0) 75 | 76 | } // namespace private_join_and_compute 77 | 78 | #endif // PIR_UTILS_H_ 79 | -------------------------------------------------------------------------------- /pir/cpp/utils_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include "pir/cpp/utils.h" 18 | 19 | #include "gtest/gtest.h" 20 | 21 | namespace pir { 22 | namespace { 23 | 24 | TEST(NextPowerTwoTest, NextPowerTwo) { 25 | EXPECT_EQ(next_power_two(0), 1); 26 | EXPECT_EQ(next_power_two(1), 1); 27 | EXPECT_EQ(next_power_two(2), 2); 28 | EXPECT_EQ(next_power_two(3), 4); 29 | EXPECT_EQ(next_power_two(8), 8); 30 | EXPECT_EQ(next_power_two(9), 16); 31 | EXPECT_EQ(next_power_two(1 << 16), 65536); 32 | EXPECT_EQ(next_power_two((1 << 16) + 1), 131072); 33 | EXPECT_EQ(next_power_two((1UL << 30) + 1), 2147483648); 34 | } 35 | 36 | TEST(CeilLog2Test, CeilLog2) { 37 | EXPECT_EQ(ceil_log2(1), 0); 38 | EXPECT_EQ(ceil_log2(2), 1); 39 | EXPECT_EQ(ceil_log2(3), 2); 40 | EXPECT_EQ(ceil_log2(8), 3); 41 | EXPECT_EQ(ceil_log2(15), 4); 42 | EXPECT_EQ(ceil_log2(16), 4); 43 | EXPECT_EQ(ceil_log2(17), 5); 44 | EXPECT_EQ(ceil_log2((1 << 16) - 1), 16); 45 | EXPECT_EQ(ceil_log2(1 << 16), 16); 46 | EXPECT_EQ(ceil_log2((1 << 16) + 1), 17); 47 | EXPECT_EQ(ceil_log2(1UL << 31), 31); 48 | } 49 | 50 | TEST(Log2Test, Log2) { 51 | EXPECT_EQ(log2(1), 0); 52 | EXPECT_EQ(log2(2), 1); 53 | EXPECT_EQ(log2(3), 1); 54 | EXPECT_EQ(log2(8), 3); 55 | EXPECT_EQ(log2(15), 3); 56 | EXPECT_EQ(log2(16), 4); 57 | EXPECT_EQ(log2(17), 4); 58 | EXPECT_EQ(log2((1 << 16) - 1), 15); 59 | EXPECT_EQ(log2(1 << 16), 16); 60 | EXPECT_EQ(log2((1 << 16) + 1), 16); 61 | EXPECT_EQ(log2((1UL << 31) - 1), 30); 62 | EXPECT_EQ(log2(1UL << 31), 31); 63 | } 64 | 65 | } // namespace 66 | } // namespace pir 67 | -------------------------------------------------------------------------------- /pir/deps.bzl: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2020 the authors listed in CONTRIBUTORS.md 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")

def _maybe(repo_rule, name, **kwargs):
    """Instantiates `repo_rule` as `name` unless that repository already exists.

    Because an existing declaration wins, a workspace can pin its own version
    of any of these repositories by declaring it before calling pir_deps().
    """
    if name not in native.existing_rules():
        repo_rule(name = name, **kwargs)

def pir_deps():
    """Declares the external repositories PIR depends on.

    Each repository is skipped when one with the same name has already been
    declared. Afterwards, registers the rules_proto dependencies and
    toolchains and the rules_foreign_cc dependencies.
    """
    _maybe(
        http_archive,
        name = "com_google_googletest",
        sha256 = "94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91",
        strip_prefix = "googletest-release-1.10.0",
        url = "https://github.com/google/googletest/archive/release-1.10.0.zip",
    )

    _maybe(
        http_archive,
        name = "com_google_benchmark",
        sha256 = "a9d41abe1bd45a707d39fdfd46c01b92e340923bc5972c0b54a48002a9a7cfa3",
        strip_prefix = "benchmark-8cead007830bdbe94b7cc259e873179d0ef84da6",
        url = "https://github.com/google/benchmark/archive/8cead007830bdbe94b7cc259e873179d0ef84da6.zip",
    )

    _maybe(
        http_archive,
        name = "com_google_absl",
        sha256 = "d29785bb94deaba45946d40bde5b356c66a4eb76505de0181ea9a23c46bc5ed4",
        strip_prefix = "abseil-cpp-592924480acf034aec0454160492a20bccdbdf3e",
        url = "https://github.com/abseil/abseil-cpp/archive/592924480acf034aec0454160492a20bccdbdf3e.zip",
    )

    _maybe(
        http_archive,
        name = "com_github_glog_glog",
        sha256 = "ec64c82f3c2cd5be25d18f52bcca2840c1b29cf3d109cd61149935838645817b",
        strip_prefix = "glog-381e349a5bc3fd858a84b80c48ac465ad79c4a71",
        urls = ["https://github.com/schoppmp/glog/archive/381e349a5bc3fd858a84b80c48ac465ad79c4a71.zip"],
    )

    _maybe(
        http_archive,
        name = "com_github_gflags_gflags",
        sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
        strip_prefix = "gflags-2.2.2",
        urls = [
            "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz",
        ],
    )

    _maybe(
        http_archive,
        name = "com_microsoft_seal",
        build_file = "//third_party:seal.BUILD",
        sha256 = "13674a39a48c0d1c6ff544521cf10ee539ce1af75c02bfbe093f7621869e3406",
        strip_prefix = "SEAL-3.5.6",
        urls = ["https://github.com/microsoft/SEAL/archive/v3.5.6.tar.gz"],
    )

    rules_proto_dependencies()

    rules_proto_toolchains()

    rules_foreign_cc_dependencies()
--------------------------------------------------------------------------------
/pir/preload.bzl:
--------------------------------------------------------------------------------
#
# Copyright 2020 the authors listed in CONTRIBUTORS.md
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Pinned commit of https://github.com/bazelbuild/rules_proto.
_RULES_PROTO_COMMIT = "f7a30f6f80006b591fa7c437fe5a951eb10bcbcf"

# Pinned commit of https://github.com/bazelbuild/rules_foreign_cc.
_RULES_FOREIGN_CC_COMMIT = "d54c78ab86b40770ee19f0949db9d74a831ab9f0"

def pir_preload():
    """Declares rules_proto and rules_foreign_cc unless they already exist.

    pir/deps.bzl loads .bzl files from both repositories, so this presumably
    has to be called from WORKSPACE before deps.bzl is loaded.
    """
    if "rules_proto" not in native.existing_rules():
        http_archive(
            name = "rules_proto",
            sha256 = "9fc210a34f0f9e7cc31598d109b5d069ef44911a82f507d5a88716db171615a8",
            strip_prefix = "rules_proto-" + _RULES_PROTO_COMMIT,
            urls = [
                "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/" + _RULES_PROTO_COMMIT + ".tar.gz",
                "https://github.com/bazelbuild/rules_proto/archive/" + _RULES_PROTO_COMMIT + ".tar.gz",
            ],
        )

    if "rules_foreign_cc" not in native.existing_rules():
        git_repository(
            name = "rules_foreign_cc",
            remote = "https://github.com/bazelbuild/rules_foreign_cc",
            init_submodules = True,
            commit = _RULES_FOREIGN_CC_COMMIT,
        )
--------------------------------------------------------------------------------
/pir/proto/BUILD:
--------------------------------------------------------------------------------
#
# Copyright 2020 the authors listed in CONTRIBUTORS.md
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# load() statements go first (buildifier: load-on-top).
load("@rules_proto//proto:defs.bzl", "proto_library")

package(default_visibility = ["//visibility:public"])

proto_library(
    name = "payload_proto",
    srcs = ["payload.proto"],
)

cc_proto_library(
    name = "payload_cc_proto",
    deps = [":payload_proto"],
)
--------------------------------------------------------------------------------
/pir/proto/payload.proto:
--------------------------------------------------------------------------------
//
// Copyright 2020 the authors listed in CONTRIBUTORS.md
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
syntax = "proto3";

package pir;

// A set of ciphertexts, used for queries or responses.
message Ciphertexts {
  repeated bytes ct = 1;
}

// Request sent from the client to the server. Includes 1 or more query
// ciphertexts, the Galois keys to be used, and (when needed) relinearization
// keys.
message Request {
  // Each query may have 1 or more ciphertexts.
  repeated Ciphertexts query = 1;

  // Galois keys, needed to compute substitution operation on ciphertexts.
  bytes galois_keys = 2;

  // Relinearization keys, only needed for recursion depths more than 1.
  bytes relin_keys = 3;
}

// Response to a query, a set of ciphertexts.
message Response {
  // Reply to query as a set of 1 or more serialized ciphertexts.
  repeated Ciphertexts reply = 1;
}

// Private information retrieval setup parameters.
// Note: fields are declared out of tag order (num_pt = 4 follows
// num_items = 1); tag numbers must not be changed for wire compatibility.
message PIRParameters {
  // Number of items in the database (NOT number of plaintexts)
  uint64 num_items = 1;

  // Number of plaintexts in the database
  uint64 num_pt = 4;

  // Size of each dimension in the database representation
  repeated uint32 dimensions = 2;

  // Serialized homomorphic encryption parameters
  bytes encryption_parameters = 3;

  // Number of bytes per database item
  uint32 bytes_per_item = 5;

  // Number of database items packed into each plaintext
  uint32 items_per_plaintext = 6;

  // Number of bits to pack into each plaintext coefficient
  uint32 bits_per_coeff = 7;

  // Set this to true to use CT multiplication instead of decomposition
  bool use_ciphertext_multiplication = 8;
}
--------------------------------------------------------------------------------
/third_party/BUILD:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMined/PIR/e039662808a244a9042b4bee17213a7eba5388eb/third_party/BUILD
--------------------------------------------------------------------------------
/third_party/seal.BUILD:
--------------------------------------------------------------------------------
load("@rules_foreign_cc//tools/build_defs:cmake.bzl", "cmake_external")

# Every file of the SEAL source archive, used as the CMake source tree.
filegroup(
    name = "src",
    srcs = glob(["**"]),
    visibility = ["//visibility:public"],
)

# Static SEAL 3.5 library built with CMake; tests, zlib and MSGSL disabled.
cmake_external(
    name = "seal",
    cmake_options = [
        "-DSEAL_USE_CXX17=ON",
        "-DSEAL_USE_INTRIN=ON",
        "-DSEAL_USE_MSGSL=OFF",
        "-DSEAL_USE_ZLIB=OFF",
        "-DSEAL_BUILD_TESTS=OFF",
        "-DBUILD_SHARED_LIBS=OFF",
        "-DCMAKE_BUILD_TYPE=Release",
    ],
    make_commands = [
        # NOTE(review): `-j` with no number places no limit on parallel jobs.
        "make -j",
        "make install",
    ],
    lib_source = ":src",
    install_prefix = "native/src",
    out_include_dir = "include/SEAL-3.5",
    static_libraries = ["libseal-3.5.a"],
    visibility = ["//visibility:public"],
)
--------------------------------------------------------------------------------