├── cmake └── emp-aby-config.cmake ├── emp-aby ├── emp-aby.h ├── triple-providers │ ├── bit-triple.h │ ├── mp-bit-triple.h │ ├── bit-triple.hpp │ └── mp-bit-triple.hpp ├── io │ ├── mp_io_channel.h │ ├── util.hpp │ ├── multi-io-base.hpp │ └── multi-io.hpp ├── utils.h ├── converter │ ├── b2aconverter.h │ └── a2bconverter.h ├── wire.h ├── simd_interface │ ├── simd.h │ ├── arithmetic-circ.h │ └── mp-simd-exec.h ├── lut.h └── mp-circuit.hpp ├── run ├── .gitignore ├── .github └── workflows │ └── x86.yml ├── LICENSE ├── .clang-format ├── README.md ├── test ├── CMakeLists.txt ├── and.txt ├── positive.cpp ├── circuit.cpp ├── mp_circuit.cpp ├── mp_bit_triple.cpp ├── lut.cpp ├── positive.txt ├── b2aconverter.cpp ├── bit_triple.cpp ├── arithmetic_circ.cpp ├── positive2.txt ├── a2bconverter.cpp ├── adder64.txt ├── he.cpp ├── simd_exec.cpp └── mp_simd_exec.cpp ├── local_test.py └── CMakeLists.txt /cmake/emp-aby-config.cmake: -------------------------------------------------------------------------------- 1 | find_package(emp-ot) 2 | # Re-export emp-ot's variables (the ones this project's own CMakeLists consume) so emp-ot's include dirs are exported too 3 | find_path(EMP-ABY_INCLUDE_DIR emp-aby/emp-aby.h) 4 | 5 | include(FindPackageHandleStandardArgs) 6 | 7 | find_package_handle_standard_args(emp-aby DEFAULT_MSG EMP-ABY_INCLUDE_DIR) 8 | 9 | if(EMP-ABY_FOUND) 10 | set(EMP-ABY_INCLUDE_DIRS ${EMP-OT_INCLUDE_DIRS} ${EMP-ABY_INCLUDE_DIR}) 11 | set(EMP-ABY_LIBRARIES ${EMP-OT_LIBRARIES}) 12 | endif() 13 | -------------------------------------------------------------------------------- /emp-aby/emp-aby.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/utils.h" 4 | #include "emp-aby/io/mp_io_channel.h" 5 | 6 | #include "emp-aby/triple-providers/bit-triple.h" 7 | #include "emp-aby/triple-providers/mp-bit-triple.h" 8 | 9 | #include "emp-aby/simd_interface/arithmetic-circ.h" 10 | #include "emp-aby/simd_interface/mp-simd-exec.h" 11 | #include "emp-aby/simd_interface/simd_exec.h" 12 | 13 | #include "emp-aby/he_interface.hpp" 14 | #include "emp-aby/lut.h"
15 | 16 | #include "emp-aby/converter/b2aconverter.h" 17 | #include "emp-aby/converter/a2bconverter.h" 18 | -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ "$1" == "-p1" ] 3 | then 4 | shift 5 | perf record $1 1 12345 & (sleep 0.1; $1 2 12345) 6 | elif [ "$1" == "-p2" ] 7 | then 8 | shift 9 | (sleep 0.1; $1 1 12345) & (perf record $1 2 12345) 10 | 11 | elif [ "$1" == "-m1" ] 12 | then 13 | shift 14 | valgrind --leak-check=full $1 1 12345 & $1 2 12345 15 | elif [ "$1" == "-m2" ] 16 | then 17 | shift 18 | $1 1 12345 & valgrind --leak-check=full $1 2 12345 19 | elif [ "$1" == "-t1" ] 20 | then 21 | shift 22 | time $1 1 12345 & $1 2 12345 23 | elif [ "$1" == "-t2" ] 24 | then 25 | shift 26 | $1 1 12345 & time $1 2 12345 27 | 28 | else 29 | (sleep 0.05; $1 1 12345 ${@:2}) & $1 2 12345 ${@:2} 30 | fi 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Testing/ 2 | *.DS_Store 3 | *.txt.cpp 4 | *CTestTestfile.cmake 5 | .exrc 6 | build/ 7 | CMakeCache.txt 8 | CMakeFiles/ 9 | Makefile 10 | bin/ 11 | cmake_install.cmake 12 | install_manifest.txt 13 | # Compiled Object files 14 | *.slo 15 | *.lo 16 | *.o 17 | *.obj 18 | *.hex 19 | 20 | # Precompiled Headers 21 | *.gch 22 | *.pch 23 | 24 | # Compiled Dynamic libraries 25 | *.so 26 | *.dylib 27 | *.dll 28 | 29 | # Fortran module files 30 | *.mod 31 | 32 | # Compiled Static libraries 33 | *.lai 34 | *.la 35 | *.a 36 | *.lib 37 | 38 | # Executables 39 | *.exe 40 | *.out 41 | *.app 42 | 43 | #COT data files 44 | data/ 45 | 46 | #Editor 47 | .idea 48 | .ninja_deps 49 | .ninja_log 50 | build.ninja 51 | .vscode 52 | .cmake 53 | 54 | ip_file.txt 55 | --------------------------------------------------------------------------------
/emp-aby/triple-providers/bit-triple.h: -------------------------------------------------------------------------------- 1 | #ifndef BIT_TRIPLE_H__ 2 | #define BIT_TRIPLE_H__ 3 | #include "emp-ot/emp-ot.h" 4 | #include "emp-aby/utils.h" 5 | 6 | namespace emp { 7 | 8 | template 9 | class BitTripleProvider { 10 | public: 11 | BitTripleProvider(int party, int threads, IO** ios); 12 | 13 | void get_triple(block* a, block* b, block* c); //length should be the #blocks 14 | void get_triple(bool* a, bool* b, bool* c); // length ... # bools 15 | 16 | ~BitTripleProvider(); 17 | 18 | static long long int BUFFER_SZ; 19 | 20 | private: 21 | int party, threads; 22 | IO** ios; 23 | FerretCOT*ot0 = nullptr, *ot1 = nullptr; 24 | CRH crh; 25 | PRG prg; 26 | block delta; 27 | vector r0, r1, scratch, A_hat, A_star, B_hat; 28 | vector a_bool, b_bool, c_bool; 29 | void compute_rcots(bool* b); 30 | }; 31 | 32 | #include "bit-triple.hpp" 33 | } // namespace emp 34 | 35 | #endif //BIT_TRIPLE_H__ 36 | -------------------------------------------------------------------------------- /.github/workflows/x86.yml: -------------------------------------------------------------------------------- 1 | name: x86 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | build_x86: 6 | strategy: 7 | matrix: 8 | os: [ubuntu-latest] 9 | build_type: [Debug, Release] 10 | runs-on: ${{ matrix.os }} 11 | timeout-minutes: 60 12 | env: 13 | BUILD_TYPE: ${{matrix.build_type}} 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: install dependency 17 | run: | 18 | wget https://raw.githubusercontent.com/emp-toolkit/emp-readme/master/scripts/install.py 19 | python3 install.py --deps --tool --ot 20 | cd 21 | git clone https://github.com/openfheorg/openfhe-development.git --branch v1.0.4 22 | cd openfhe-development && mkdir build && cd build 23 | cmake .. 
&& make -j8 && sudo make install 24 | cd 25 | - name: Create Build Environment 26 | run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DUSE_RANDOM_DEVICE=On && make 27 | - name: Test 28 | shell: bash 29 | run: | 30 | mkdir data 31 | make test 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Xiao Wang (wangxiao@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | Enquiries about further applications and development opportunities are welcome. 
24 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # https://clang.llvm.org/docs/ClangFormatStyleOptions.html 2 | # Manual run to reformat a file "foo.cpp": clang-format -i --style=file foo.cpp 3 | ####################################################################### 4 | BasedOnStyle: Google 5 | ####################################################################### 6 | Language: Cpp 7 | # DisableFormat: true 8 | AccessModifierOffset: -4 9 | # AlignArrayOfStructures: Right (enable this if clang-format-13 or newer is installed) 10 | AlignConsecutiveMacros: true 11 | AlignConsecutiveAssignments: true 12 | #AlignConsecutiveDeclarations: true 13 | AlignEscapedNewlines: Left 14 | AlignOperands: true 15 | AllowShortFunctionsOnASingleLine: Empty 16 | AllowShortIfStatementsOnASingleLine: Never 17 | AllowShortLambdasOnASingleLine: Inline 18 | AllowShortLoopsOnASingleLine: false 19 | BreakBeforeBraces: Custom 20 | BraceWrapping: 21 | BeforeCatch: true 22 | BeforeElse: true 23 | BreakBeforeTernaryOperators: false 24 | BreakStringLiterals: false 25 | ColumnLimit: 120 26 | DerivePointerAlignment: false 27 | #EmptyLineBeforeAccessModifier: true (enable this if clang-format-13 or newer is installed) 28 | IndentPPDirectives: BeforeHash 29 | IndentWidth: 4 30 | PointerAlignment: Left 31 | ReflowComments: false 32 | SortIncludes: false 33 | UseTab: Never 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scalable Mixed-Mode MPC 2 | 3 | # Installation 4 | 5 | ### Dependencies 6 | 7 | 1. `wget https://raw.githubusercontent.com/emp-toolkit/emp-readme/master/scripts/install.py` 8 | 2. `python3 install.py --deps --tool --ot` (the same flags the CI workflow uses) 9 | 1. By default it will build for Release. 
`-DCMAKE_BUILD_TYPE=[Release|Debug]` option is also available. 10 | 2. No sudo? Change [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/v2.8.8/cmake.html#variable%3aCMAKE_INSTALL_PREFIX). 11 | 3. On Mac [homebrew](https://brew.sh/) is needed for installation. 12 | 3. Install openFHE 13 | ```console 14 | git clone https://github.com/openfheorg/openfhe-development.git --branch v1.0.4 15 | cd openfhe-development && mkdir build && cd build 16 | cmake .. && make -j8 && sudo make install 17 | ``` 18 | ### Build this project 19 | 20 | 1. Clone the repository and build using cmake. 21 | 22 | ```console 23 | git clone https://github.com/radhika1601/ScalableMixedModeMPC.git 24 | cd ScalableMixedModeMPC 25 | cmake . 26 | make 27 | ``` 28 | 29 | # Running tests 30 | 31 | To run the tests across multiple servers you can pass the ip configuration as a txt file with the following format. 32 | 33 | ``` 34 | 35 | 36 | ``` 37 | 38 | ### Question 39 | Please send email to radhikaradhika2028@u.northwestern.edu 40 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Testing macros. add_test_executable_with_lib links EMP-OT_LIBRARIES plus any extra libraries passed in `libs`. 2 | macro (add_test_executable_with_lib _name libs) 3 | add_executable(test_${_name} "${_name}.cpp") 4 | target_link_libraries(test_${_name} ${EMP-OT_LIBRARIES} ${libs}) 5 | endmacro() 6 | 7 | macro (add_test_case _name) 8 | add_test_executable_with_lib(${_name} "") 9 | add_test(NAME ${_name} COMMAND "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/test_${_name}" WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/") 10 | endmacro() 11 | 12 | macro (add_test_case_with_runarg _name _arg) 13 | add_test_executable_with_lib(${_name} "") 14 | add_test(NAME ${_name} COMMAND "./run" "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/test_${_name}" "${_arg}" WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/") 15 | endmacro() 16 | 17 | macro (add_test_case_with_run _name) 18 | add_test_case_with_runarg(${_name} "") 19 | endmacro() 20 | 
21 | 22 | # Test cases 23 | add_test_case_with_run(bit_triple) 24 | add_test_case_with_run(simd_exec) 25 | add_test_case_with_runarg(circuit "test/relu.txt") # NOTE(review): test/relu.txt does not appear in the repository tree — confirm it is checked in 26 | add_test_case_with_runarg(positive "test/positive.txt") 27 | add_test_case_with_runarg(mp_bit_triple "2") 28 | add_test_case_with_runarg(mp_simd_exec "2") 29 | add_test_case_with_runarg(he "2") 30 | add_test_case_with_runarg(lut "2") 31 | add_test_case_with_runarg(b2aconverter "2") 32 | add_test_case_with_runarg(a2bconverter "2") 33 | add_test_case_with_runarg(arithmetic_circ "2") 34 | add_test_case_with_runarg(mp_circuit "emp-aby/modsum.txt 2") # NOTE(review): emp-aby/modsum.txt does not appear in the repository tree — confirm it is checked in 35 | -------------------------------------------------------------------------------- /test/and.txt: -------------------------------------------------------------------------------- 1 | 65 193 2 | 64 64 0 3 | 4 | 2 1 0 64 128 AND 5 | 2 1 1 65 129 AND 6 | 2 1 2 66 130 AND 7 | 2 1 3 67 131 AND 8 | 2 1 4 68 132 AND 9 | 2 1 5 69 133 AND 10 | 2 1 6 70 134 AND 11 | 2 1 7 71 135 AND 12 | 2 1 8 72 136 AND 13 | 2 1 9 73 137 AND 14 | 2 1 10 74 138 AND 15 | 2 1 11 75 139 AND 16 | 2 1 12 76 140 AND 17 | 2 1 13 77 141 AND 18 | 2 1 14 78 142 AND 19 | 2 1 15 79 143 AND 20 | 2 1 16 80 144 AND 21 | 2 1 17 81 145 AND 22 | 2 1 18 82 146 AND 23 | 2 1 19 83 147 AND 24 | 2 1 20 84 148 AND 25 | 2 1 21 85 149 AND 26 | 2 1 22 86 150 AND 27 | 2 1 23 87 151 AND 28 | 2 1 24 88 152 AND 29 | 2 1 25 89 153 AND 30 | 2 1 26 90 154 AND 31 | 2 1 27 91 155 AND 32 | 2 1 28 92 156 AND 33 | 2 1 29 93 157 AND 34 | 2 1 30 94 158 AND 35 | 2 1 31 95 159 AND 36 | 2 1 32 96 160 AND 37 | 2 1 33 97 161 AND 38 | 2 1 34 98 162 AND 39 | 2 1 35 99 163 AND 40 | 2 1 36 100 164 AND 41 | 2 1 37 101 165 AND 42 | 2 1 38 102 166 AND 43 | 2 1 39 103 167 AND 44 | 2 1 40 104 168 AND 45 | 2 1 41 105 169 AND 46 | 2 1 42 106 170 AND 47 | 2 1 43 107 171 AND 48 | 2 1 44 108 172 AND 49 | 2 1 45 109 173 AND 50 | 2 1 46 110 174 AND 51 | 2 1 47 111 175 AND 52 | 2 1 48 112 176 AND 53 | 2 1 49 113 177 AND 54 | 2 1 50 114 178 AND 55 | 2 1 
51 115 179 AND 56 | 2 1 52 116 180 AND 57 | 2 1 53 117 181 AND 58 | 2 1 54 118 182 AND 59 | 2 1 55 119 183 AND 60 | 2 1 56 120 184 AND 61 | 2 1 57 121 185 AND 62 | 2 1 58 122 186 AND 63 | 2 1 59 123 187 AND 64 | 2 1 60 124 188 AND 65 | 2 1 61 125 189 AND 66 | 2 1 62 126 190 AND 67 | 2 1 63 127 191 AND 68 | 2 1 0 0 192 XOR 69 | -------------------------------------------------------------------------------- /local_test.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('-n', '--num_parties', action='store', dest='n', type = int, required = True) 6 | parser.add_argument('-c', '--command', action='store', dest='c', type = str, required = True) 7 | parser.add_argument('-s', '--start_party', action='store', dest='s', type = int, required = True) 8 | parser.add_argument('-e', '--end_party', action='store', dest='e', type = int, required = True) 9 | parser.add_argument('-p', '--num_processes', action='store', dest='p', type = int, required = False) 10 | args = parser.parse_args() 11 | 12 | def test(n): 13 | s = "" 14 | for i in range(args.s, args.e): 15 | # if i == 5: continue 16 | if args.c == 'yao': 17 | s += "./bin/test_"+ args.c +" " + ( i + 1).__str__() + " 12345 test/and.txt " + n.__str__() + " ip_file.txt" 18 | if args.c == 'mp_circuit': 19 | s += "./bin/test_"+ args.c +" " + ( i + 1).__str__() + " 12345 ../emp-aby/modsum.txt " + n.__str__() 20 | else: 21 | s += "./bin/test_"+ args.c +" " + ( i + 1).__str__() + " 12345 " + n.__str__() + " ip_file.txt" 22 | if i != args.e-1: 23 | s += " & " 24 | print(s) 25 | return s 26 | 27 | sleep_time_batch = 70 28 | #test(args.n) 29 | if args.p: 30 | s = "" 31 | sleep_time = 0 32 | for i in range(args.p): 33 | s += "sleep " + (sleep_time * sleep_time_batch).__str__() + "; " + test(args.n) 34 | sleep_time += 1 35 | if i != args.p-1: 36 | s += " & " 37 | 38 | print(s) 39 | 
subprocess.Popen(s, shell="True") 40 | else: 41 | subprocess.Popen(test(args.n), shell="True") 42 | -------------------------------------------------------------------------------- /emp-aby/io/mp_io_channel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace emp { 6 | enum MESSAGE_TYPE : uint8_t { NORM_MSG = 0, BOOT_REQ_MSG = 1, BOOT_RSP_MSG = 2, TERMINATE_MSG = 3 }; 7 | template 8 | class MPIOChannel { 9 | public: 10 | virtual void send_data(int dst, const void* data, int len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG) = 0; 11 | virtual void recv_data(int src, void* data, int len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG) = 0; 12 | virtual void* recv_data(int src, int& len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG) = 0; 13 | virtual void send_bool(int dst, bool* data, int length, int j = 0) = 0; 14 | virtual void recv_bool(int src, bool* data, int length, int j = 0) = 0; 15 | virtual void send_block(int dst, const block* data, int length, int j = 0) = 0; 16 | virtual void recv_block(int src, block* data, int length, int j = 0) = 0; 17 | virtual void sync() = 0; 18 | virtual void flush(int idx = 0, int j = 0) = 0; 19 | virtual T*& get(size_t idx, bool b = false) = 0; 20 | virtual ~MPIOChannel() = 0; 21 | virtual int get_total_bytes_sent() = 0; 22 | }; 23 | 24 | template 25 | MPIOChannel::~MPIOChannel() {} 26 | } // namespace emp 27 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.0) 2 | project (emp-aby) 3 | set(NAME "emp-aby") 4 | 5 | set(CMAKE_BUILD_TYPE Debug) 6 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}-O3 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -Wno-sign-compare") 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenFHE_CXX_FLAGS} -O3 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer 
-DMATHBACKEND=4") 8 | 9 | if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") 10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") 11 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DLIBS_USE_LIBUV -DLIBS_USE_OPENSSL") 12 | endif() 13 | 14 | set(CMAKE_CXX_STANDARD 17) 15 | 16 | find_path(CMAKE_FOLDER NAMES cmake/emp-tool-config.cmake) 17 | include(${CMAKE_FOLDER}/cmake/emp-base.cmake) 18 | 19 | find_package(OpenFHE REQUIRED) 20 | # NOTE(review): OpenFHE_CXX_FLAGS is expanded on line 7, before this find_package(OpenFHE) call, so it is presumably empty there — confirm whether those flags are meant to be applied 21 | include_directories(${OPENMP_INCLUDES}) 22 | include_directories(${OpenFHE_INCLUDE}) 23 | include_directories(${OpenFHE_INCLUDE}/third-party/include) 24 | include_directories(${OpenFHE_INCLUDE}/core) 25 | include_directories(${OpenFHE_INCLUDE}/pke) 26 | link_directories(${OpenFHE_LIBDIR}) 27 | link_directories(${OPENMP_LIBRARIES}) 28 | if(BUILD_STATIC) 29 | set( CMAKE_EXE_LINKER_FLAGS "${OpenFHE_EXE_LINKER_FLAGS} -static") 30 | link_libraries( ${OpenFHE_STATIC_LIBRARIES} ) 31 | else() 32 | set( CMAKE_EXE_LINKER_FLAGS ${OpenFHE_EXE_LINKER_FLAGS} ) 33 | link_libraries( ${OpenFHE_SHARED_LIBRARIES} ) 34 | endif() 35 | 36 | find_package(emp-ot REQUIRED) 37 | include_directories(${EMP-OT_INCLUDE_DIRS}) 38 | # target_link_libraries(emp-aby ${OPENFHE_LIBS}) 39 | 40 | # Installation 41 | install(FILES cmake/emp-aby-config.cmake DESTINATION cmake/) 42 | install(DIRECTORY emp-aby DESTINATION include/) 43 | 44 | ENABLE_TESTING() 45 | ADD_SUBDIRECTORY(test) 46 | 47 | 48 | -------------------------------------------------------------------------------- /emp-aby/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "emp-tool/emp-tool.h" 3 | namespace emp { 4 | 5 | inline void andBlocks_arr(block* res, const block* x, const block* y, int nblocks) { 6 | const block* dest = nblocks + x; 7 | for (; x != dest;) 8 | *(res++) = *(x++) & *(y++); 9 | } 10 | 11 | inline void andBlocks_arr(block* res, const block* x, block y, int nblocks) { 12 | const block* dest = nblocks + x; 13 | for (; x != dest;) 14 | 
*(res++) = *(x++) & y; 15 | } 16 | 17 | inline void xorBools_arr(bool* res, const bool* x, const bool* y, int nbools) { 18 | const bool* dest = nbools + x; 19 | for (; x != dest;) 20 | *(res++) = *(x++) ^ *(y++); 21 | } 22 | 23 | inline void andBools_arr(bool* res, const bool* x, const bool* y, int nbools) { 24 | const bool* dest = nbools + x; 25 | // Don't change & to && 26 | for (; x != dest;) 27 | *(res++) = *(x++) & *(y++); 28 | } 29 | 30 | // Packs bool_length bools (one per byte) into 128-bit blocks: bool j becomes bit (j % 128) of a[j / 128], bits 0-63 in the low half. 31 | // Writes (bool_length + 127) / 128 blocks; the last one is zero-padded. Writes nothing when bool_length <= 0. 32 | inline void bool_to_block_arr(block* a, bool* data, int bool_length) { 33 | if (bool_length <= 0) 34 | return; 35 | const int nblocks = (bool_length + 127) / 128; 36 | for (int b = 0; b < nblocks; ++b) { 37 | uint64_t half[2] = {0, 0}; // half[0] = low 64 bits, half[1] = high 64 bits 38 | const bool* src = data + b * 128; 39 | int count = bool_length - b * 128; 40 | if (count > 128) 41 | count = 128; 42 | int k = 0; 43 | for (; k + 8 <= count; k += 8) { 44 | // memcpy rather than a uint64_t* cast: avoids strict-aliasing and misaligned-load UB 45 | uint64_t chunk; 46 | memcpy(&chunk, src + k, sizeof(chunk)); 47 | uint64_t mask = 0x0101010101010101ULL; 48 | uint64_t bits = 0; 49 | #if defined(__BMI2__) 50 | bits = _pext_u64(chunk, mask); 51 | #else 52 | // https://github.com/Forceflow/libmorton/issues/6 53 | for (uint64_t bb = 1; mask != 0; bb += bb) { 54 | if (chunk & mask & -mask) { 55 | bits |= bb; 56 | } 57 | mask &= (mask - 1); 58 | } 59 | #endif 60 | half[k / 64] |= bits << (k % 64); 61 | } 62 | // Bools past the last full group of 8 are packed one at a time 63 | for (; k < count; ++k) { 64 | if (src[k]) { 65 | half[k / 64] |= 1ULL << (k % 64); 66 | } 67 | } 68 | a[b] = makeBlock(half[1], half[0]); 69 | } 70 | } 71 | } // namespace emp 72 | -------------------------------------------------------------------------------- /test/positive.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/mp-circuit.hpp" 2 | #include 3 | 4 | int party, port; 5 | 6 | const static int threads = 1; 7 | 8 | unsigned char mpc_main(long long INPUT_A, long long INPUT_B) { 9 | long long value = (long long)(((unsigned long long)INPUT_A + (unsigned long long)INPUT_B) & ((1ULL << 61) - 1)); 10 | return value > ((1ULL << 60) - 1) ? 
1 : 0; 11 | } 12 | 13 | void test_circuit(const char* file, int party, NetIO** ios, const int num, SIMDCircExec* simd_circ) { 14 | PRG prg; 15 | Circuit>* circuit = new Circuit>(file, party, simd_circ); 16 | bool* in; 17 | if (party == ALICE) { 18 | in = new bool[(circuit->n1) * num]; 19 | prg.random_bool(in, (circuit->n1) * num); 20 | } 21 | else { 22 | in = new bool[(circuit->n2) * num]; 23 | prg.random_bool(in, (circuit->n2) * num); 24 | } 25 | 26 | bool* out = new bool[(circuit->n3) * num]; 27 | auto start = clock_start(); 28 | for (int j = 0; j < 50; ++j) { 29 | circuit->compute(out, in, num); 30 | } 31 | long long t = time_from(start); 32 | 33 | NetIO* io = ios[0]; 34 | if (party == ALICE) { 35 | io->send_bool(in, (circuit->n1) * num); 36 | io->send_bool(out, (circuit->n3) * num); 37 | } 38 | else { 39 | bool* in1 = new bool[(circuit->n1) * num]; 40 | bool* out1 = new bool[(circuit->n3) * num]; 41 | io->recv_bool(in1, (circuit->n1) * num); 42 | io->recv_bool(out1, (circuit->n3) * num); 43 | xorBools_arr(out, out1, out, (circuit->n3) * num); 44 | for (int i = 0; i < num; ++i) { 45 | long long A = bool_to_int(in1 + i * circuit->n1); 46 | long long B = bool_to_int(in + i * circuit->n2); 47 | bool output = out[i]; 48 | bool sim_output = mpc_main(A, B); 49 | if (output != sim_output) { 50 | error("Test Failed!"); 51 | } 52 | } 53 | } 54 | 55 | std::cout << "Compute Done: " << t << " us" << std::endl; 56 | } 57 | int main(int argc, char** argv) { 58 | parse_party_and_port(argv, &party, &port); 59 | const char* file = argv[3]; 60 | vector ios; 61 | for (int i = 0; i < threads; ++i) 62 | ios.push_back(new NetIO(party == ALICE ? 
nullptr : "127.0.0.1", port)); 63 | 64 | SIMDCircExec* simd_circ = new SIMDCircExec(party, threads, ios.data()); 65 | test_circuit(file, party, ios.data(), 10000, simd_circ); 66 | delete simd_circ; 67 | } 68 | -------------------------------------------------------------------------------- /test/circuit.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/mp-circuit.hpp" 2 | #include 3 | 4 | int party, port; 5 | 6 | const static int threads = 1; 7 | 8 | long long mpc_main(long long INPUT_A, long long INPUT_B, long long INPUT_B2) { 9 | long long value = (INPUT_A + INPUT_B) & ((1LL << 61) - 1); 10 | long long relu = value < ((1LL << 60) - 1) ? value : 0; 11 | return relu + ((1LL << 61) - 1) - INPUT_B2; 12 | } 13 | 14 | void test_circuit(const char* file, int party, NetIO** ios, const int num, SIMDCircExec* simd_circ) { 15 | PRG prg; 16 | Circuit>* circuit = new Circuit>(file, party, simd_circ); 17 | 18 | bool* in = new bool[(circuit->n1 + circuit->n2) * num]; 19 | 20 | bool* out = new bool[(circuit->n3) * num]; 21 | 22 | for (int i = 0; i < 5; ++i) { 23 | prg.random_bool(in, (circuit->n1 + circuit->n2) * num); 24 | 25 | auto start = clock_start(); 26 | circuit->compute(out, in, num); 27 | long long t = time_from(start); 28 | 29 | NetIO* io = ios[0]; 30 | if (party == ALICE) { 31 | io->send_bool(in, (circuit->n1) * num); 32 | io->send_bool(out, (circuit->n3) * num); 33 | } 34 | else { 35 | bool* in1 = new bool[(circuit->n1) * num]; 36 | bool* out1 = new bool[(circuit->n3) * num]; 37 | io->recv_bool(in1, (circuit->n1) * num); 38 | io->recv_bool(out1, (circuit->n3) * num); 39 | xorBools_arr(out, out1, out, (circuit->n3) * num); 40 | for (int i = 0; i < num; ++i) { 41 | long long A = bool_to_int(in1 + i * circuit->n1); 42 | long long B = bool_to_int(in + i * circuit->n2); 43 | long long B2 = bool_to_int(in + i * circuit->n2 + 64); 44 | long long output = bool_to_int(out + i * circuit->n3); 45 | long long sim_output = 
mpc_main(A, B, B2); 46 | if (output != sim_output) { 47 | error("Test Failed!"); 48 | } 49 | } 50 | } 51 | std::cout << "Compute Done: " << t << " us" << std::endl; 52 | } 53 | delete[] in; 54 | delete[] out; 55 | } 56 | 57 | int main(int argc, char** argv) { 58 | parse_party_and_port(argv, &party, &port); 59 | const char* file = argv[3]; 60 | vector ios; 61 | for (int i = 0; i < threads; ++i) 62 | ios.push_back(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port)); 63 | 64 | SIMDCircExec* simd_circ = new SIMDCircExec(party, threads, ios.data()); 65 | test_circuit(file, party, ios.data(), 10000, simd_circ); 66 | std::cout << "num and gates " << simd_circ->num_and_gates << " " << simd_circ->depth << "\n"; 67 | delete simd_circ; 68 | for (auto io : ios) 69 | delete io; 70 | } 71 | -------------------------------------------------------------------------------- /emp-aby/converter/b2aconverter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/lut.h" 4 | #include "emp-aby/mp-circuit.hpp" 5 | #include "emp-aby/simd_interface/arithmetic-circ.h" 6 | 7 | namespace emp { 8 | 9 | template 10 | class B2AConverter { 11 | private: 12 | LUT* bit_to_a; 13 | MPIOChannel* io; 14 | int num_party, party; 15 | ThreadPool* pool; 16 | HE* he; 17 | 18 | public: 19 | B2AConverter(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int pool_size); 20 | 21 | /** 22 | * @brief This function converts bit vector for boolean shares to arithmetic shares 23 | * 24 | * @param out Output arithmetic share array 25 | * @param in Input array of boolean shares 26 | * @param length length of boolean array 27 | * @param l length of each bit-vector 28 | */ 29 | void convert(int64_t* out, bool* in, size_t length, int l); 30 | }; 31 | 32 | template 33 | B2AConverter::B2AConverter(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int pool_size) { 34 | int64_t table[2] = {0, 1}; 35 | this->pool = 
pool; 36 | this->io = io; 37 | this->num_party = num_party; 38 | this->party = party; 39 | this->he = he; 40 | // std::cout << "go in lut constructor \n"; 41 | this->bit_to_a = new LUT(num_party, party, io, pool, he, table, pool_size); 42 | } 43 | 44 | template 45 | void B2AConverter::convert(int64_t* out, bool* in, size_t length, int l) { 46 | if (length % l != 0) 47 | error("Length of boolean array is not divisible by length of each bit vector."); 48 | size_t n = length / l; 49 | int64_t* in_ashare = new int64_t[length]; 50 | this->bit_to_a->lookup(in_ashare, in, length); 51 | memset(out, 0, n * sizeof(int64_t)); 52 | 53 | size_t threads = pool->size(); 54 | size_t num_steps = ceil((double)n / (double)threads); 55 | vector> res; 56 | for (size_t t = 0; t < threads; ++t) { 57 | res.push_back(pool->enqueue([this, n, t, num_steps, in_ashare, out, l]() { 58 | for (size_t step = 0; step < num_steps; ++step) { 59 | size_t j = t * num_steps + step; 60 | if (j >= n) 61 | break; 62 | 63 | for (size_t i = 0; i < l; ++i) { 64 | uint64_t x; 65 | in_ashare[j * l + i] = (in_ashare[j * l + i] % he->q + he->q) % he->q; 66 | x = in_ashare[j * l + i]; 67 | x = x << i; 68 | x %= he->q; 69 | out[j] = (out[j] + x) % he->q; 70 | } 71 | } 72 | })); 73 | } 74 | for (auto& fut : res) 75 | fut.get(); 76 | res.clear(); 77 | } 78 | 79 | } // namespace emp 80 | -------------------------------------------------------------------------------- /emp-aby/wire.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/simd_interface/simd_exec.h" 4 | #include 5 | namespace emp { 6 | 7 | #define INPUT 0 8 | #define AND 1 9 | #define XOR 2 10 | #define INV 3 11 | //#define __debug__ 12 | class Wire { 13 | public: 14 | int type; 15 | block* value; 16 | bool* rem_value; 17 | uint32_t level; 18 | Wire* in1 = nullptr; 19 | Wire* in2 = nullptr; 20 | bool set = false; 21 | int num_required = 0; 22 | int num_used = 0; 23 | 24 | Wire(int 
type) { 25 | if (type == INPUT) { 26 | this->type = type; 27 | this->level = 0; 28 | } 29 | else { 30 | this->type = type; 31 | } 32 | } 33 | 34 | Wire(int type, Wire* in1) { 35 | if (type == INV) { 36 | this->type = type; 37 | this->in1 = in1; 38 | this->level = in1->level; 39 | this->in1->num_required = this->in1->num_required + 1; 40 | } 41 | else if (type == INPUT) 42 | error("Input wire has no imcoming!"); 43 | else 44 | error("AND and XOR require 2 inputs!"); 45 | } 46 | 47 | Wire(int type, Wire* in1, Wire* in2) { 48 | this->type = type; 49 | this->in1 = in1; 50 | this->in2 = in2; 51 | this->in1->num_required = this->in1->num_required + 1; 52 | this->in2->num_required = this->in2->num_required + 1; 53 | switch (this->type) { 54 | case INPUT: 55 | case INV: 56 | error("Cannot set second incoming wire!"); 57 | case AND: 58 | this->level = std::max(this->in1->level, this->in2->level) + 1; 59 | break; 60 | case XOR: 61 | this->level = std::max(this->in1->level, this->in2->level); 62 | break; 63 | } 64 | } 65 | void initialise_value(int n, int m = 0) { 66 | if (this->set == true && this->num_required > 0) 67 | error("Wire being initialised twice!"); 68 | this->value = (block*)malloc(n * sizeof(block)); 69 | memset(this->value, 0, n * sizeof(block)); 70 | this->rem_value = (bool*)malloc(m * sizeof(bool)); 71 | memset(this->rem_value, 0, m * sizeof(bool)); 72 | } 73 | void set_value(block* value, int n, bool* rem_value, int m) { 74 | memcpy(this->value, value, n * sizeof(block)); 75 | memcpy(this->rem_value, rem_value, m * sizeof(bool)); 76 | if (this->num_required > 0) 77 | this->set = true; 78 | } 79 | void reset_value() { 80 | this->num_used = this->num_used + 1; 81 | if (this->num_used >= this->num_required) { 82 | free(this->value); 83 | this->set = false; 84 | this->num_used = 0; 85 | free(this->rem_value); 86 | } 87 | } 88 | }; 89 | 90 | } // namespace emp -------------------------------------------------------------------------------- /emp-aby/io/util.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "emp-tool/emp-tool.h" 5 | 6 | namespace emp { 7 | void accept_base_connections(int expected_connections, int bind_port, string bind_address, 8 | std::map& socket_map, int party, int num_party) { 9 | int reuse = 1; 10 | struct sockaddr_in dest; 11 | struct sockaddr_in serv; 12 | 13 | memset(&serv, 0, sizeof(serv)); 14 | serv.sin_family = AF_INET; 15 | serv.sin_addr.s_addr = inet_addr(bind_address.c_str()); /* set our address to any interface */ 16 | serv.sin_port = htons(bind_port); /* set the server port number */ 17 | socklen_t socksize = sizeof(struct sockaddr_in); 18 | int mysocket = socket(AF_INET, SOCK_STREAM, 0); 19 | 20 | setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse)); 21 | if (bind(mysocket, (struct sockaddr*)&serv, sizeof(struct sockaddr)) < 0) { 22 | error("error: bind"); 23 | } 24 | 25 | if (listen(mysocket, expected_connections) < 0) { 26 | error("error: listen"); 27 | } 28 | for (int i = 0; i < expected_connections; ++i) { 29 | int consocket = accept(mysocket, (struct sockaddr*)&dest, &socksize); 30 | 31 | const int one = 1; 32 | setsockopt(consocket, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); 33 | int p; 34 | read(consocket, &p, sizeof(p)); 35 | if (p > num_party) { 36 | error("connected to wrong party"); 37 | } 38 | if (socket_map.find(p) != socket_map.end()) { 39 | error("already connected to this party"); 40 | } 41 | if (send(consocket, &party, sizeof(party), 0) < 0) { 42 | error("error: send"); 43 | } 44 | socket_map.emplace(std::make_pair(p, consocket)); 45 | } 46 | } 47 | 48 | int request_base_connection(int p, std::vector>& net_config, int party, 49 | int num_party) { 50 | struct sockaddr_in dest; 51 | memset(&dest, 0, sizeof(dest)); 52 | dest.sin_family = AF_INET; 53 | dest.sin_addr.s_addr = inet_addr(net_config[p - 1].first.c_str()); 54 | dest.sin_port = htons(net_config[p - 
1].second); 55 | 56 | int consocket; 57 | 58 | while (1) { 59 | consocket = socket(AF_INET, SOCK_STREAM, 0); 60 | 61 | if (connect(consocket, (struct sockaddr*)&dest, sizeof(struct sockaddr)) == 0) { 62 | break; 63 | } 64 | 65 | close(consocket); 66 | usleep(1000); 67 | } 68 | const int one = 1; 69 | setsockopt(consocket, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); 70 | send(consocket, &party, sizeof(party), 0); 71 | int new_p; 72 | read(consocket, &new_p, sizeof(new_p)); 73 | if (new_p != p) { 74 | std::cout << party << ": " << new_p << " != " << p << std::endl; 75 | error("connected to wrong party"); 76 | } 77 | 78 | return consocket; 79 | } 80 | 81 | } // namespace emp -------------------------------------------------------------------------------- /emp-aby/io/multi-io-base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "emp-aby/io/util.hpp" 6 | 7 | namespace emp { 8 | class MultiIOBase : public IOChannel { 9 | private: 10 | public: 11 | int consocket = -1; 12 | std::shared_mutex sock_mutex; 13 | bool continue_comm = false; 14 | std::deque> recv_msg_queue[3]; 15 | std::condition_variable recv_condition_vars[3]; 16 | std::mutex recv_mutex[3]; 17 | 18 | MultiIOBase(int consocket, bool quiet = false) : consocket(consocket), continue_comm(true) { 19 | set_nodelay(); 20 | if (!quiet) 21 | std::cout << "connected\n"; 22 | } 23 | 24 | void sync() {} 25 | void flush() {} 26 | 27 | ~MultiIOBase() { 28 | close(consocket); 29 | } 30 | 31 | void set_nodelay() { 32 | const int one = 1; 33 | setsockopt(consocket, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); 34 | } 35 | 36 | void set_delay() { 37 | const int zero = 0; 38 | setsockopt(consocket, IPPROTO_TCP, TCP_NODELAY, &zero, sizeof(zero)); 39 | } 40 | 41 | void send_msg(const void* data, int len, MESSAGE_TYPE msg_type = NORM_MSG) { 42 | char* meta_buff = (char*)malloc(5); 43 | meta_buff[0] = msg_type; 44 | memcpy(meta_buff + 
1, &(len), 4); 45 | std::shared_lock lock(sock_mutex); 46 | this->send_data(meta_buff, 5); 47 | if(msg_type != TERMINATE_MSG) 48 | this->send_data(data, len); 49 | lock.unlock(); 50 | free(meta_buff); 51 | } 52 | 53 | bool recv_msg() { 54 | char* meta_buff = (char*)malloc(5); 55 | std::shared_lock lock(sock_mutex); 56 | 57 | // add select here 58 | // select(consocket, ); 59 | 60 | this->recv_data(meta_buff, 5); 61 | MESSAGE_TYPE recv_type = static_cast(meta_buff[0]); 62 | if (recv_type == TERMINATE_MSG) { 63 | lock.unlock(); 64 | this->continue_comm = false; 65 | return false; 66 | } 67 | 68 | int len; 69 | memcpy(&(len), meta_buff + 1, 4); 70 | // std::cout << "received length" << len << std::endl; 71 | void* data = malloc(len); 72 | this->recv_data(data, len); 73 | lock.unlock(); 74 | 75 | std::unique_lock que_lock(recv_mutex[recv_type]); 76 | recv_msg_queue[recv_type].push_back(std::pair(len, data)); 77 | que_lock.unlock(); 78 | recv_condition_vars[recv_type].notify_one(); 79 | return true; 80 | } 81 | 82 | void send_data_internal(const void* data, size_t len) { 83 | size_t sent = 0; 84 | while (sent < len) { 85 | size_t res = write(consocket, (char*)data + sent, len - sent); 86 | if (res > 0) 87 | sent += res; 88 | else 89 | error("net_send_data\n"); 90 | } 91 | } 92 | 93 | void recv_data_internal(void* data, size_t len) { 94 | size_t recvd = 0; 95 | while (recvd < len) { 96 | size_t res = read(consocket, (char*)data + recvd, len - recvd); 97 | if (res > 0) 98 | recvd += res; 99 | else 100 | error("net_recv_data\n"); 101 | } 102 | } 103 | }; 104 | 105 | } // namespace emp 106 | -------------------------------------------------------------------------------- /test/mp_circuit.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/mp-circuit.hpp" 2 | #include "emp-aby/io/multi-io.hpp" 3 | 4 | #include 5 | 6 | int party, port, num_party; 7 | 8 | const static int threads = 4; 9 | int64_t mod = ((1UL << 32) - (1UL 
<< 30) + 1); 10 | 11 | uint32_t mpc_main(uint32_t INPUT_A, uint32_t INPUT_B) { 12 | uint64_t x = (uint64_t)INPUT_A + (uint64_t)INPUT_B; 13 | 14 | if (x > mod) 15 | x = x - mod; 16 | return (uint32_t)x; 17 | } 18 | 19 | template 20 | void test_circuit(const char* file, int party, MPIOChannel* io, const int num, MPSIMDCircExec* simd_circ, 21 | Circuit>* circuit) { 22 | PRG prg; 23 | 24 | bool* in = new bool[(circuit->n1 + circuit->n2) * num]; 25 | prg.random_bool(in, (circuit->n1 + circuit->n2) * num); 26 | 27 | bool* out = new bool[(circuit->n3) * num]; 28 | circuit->template compute(out, in, num); 29 | io->sync(); 30 | if (party == ALICE) { 31 | bool* in_tmp = new bool[(circuit->n1 + circuit->n2) * num]; 32 | bool* out_tmp = new bool[(circuit->n3) * num]; 33 | for (int i = 2; i <= num_party; ++i) { 34 | io->recv_bool(i, in_tmp, (circuit->n1 + circuit->n2) * num); 35 | io->recv_bool(i, out_tmp, (circuit->n3) * num); 36 | xorBools_arr(in, in, in_tmp, (circuit->n1 + circuit->n2) * num); 37 | xorBools_arr(out, out, out_tmp, circuit->n3 * num); 38 | io->flush(i); 39 | } 40 | for (int i = 0; i < num; ++i) { 41 | uint32_t A = bool_to_int(in + i * circuit->n1); 42 | uint32_t B = bool_to_int(in + num * circuit->n1 + i * circuit->n2); 43 | uint32_t output = bool_to_int(out + i * circuit->n3); 44 | uint32_t sim_output = mpc_main(A, B); 45 | if ((output + mod) % mod != (sim_output + mod) % mod) { 46 | std::cout << A << " " << B << " " << output << " " << sim_output << "\n"; 47 | error("Test Failed!"); 48 | } 49 | } 50 | std::cout << "Test passed" << std::endl; 51 | } 52 | else { 53 | io->send_bool(ALICE, in, (circuit->n1 + circuit->n2) * num); 54 | io->send_bool(ALICE, out, (circuit->n3) * num); 55 | io->flush(ALICE); 56 | } 57 | } 58 | 59 | int main(int argc, char** argv) { 60 | if (argc < 5) { 61 | std::cout << "Format: test_mp_bit_triple PartyID Port Circuit_file num_parties" << std::endl; 62 | exit(0); 63 | } 64 | parse_party_and_port(argv, &party, &port); 65 | const char* 
file = argv[3]; 66 | num_party = atoi(argv[4]); 67 | 68 | std::vector> net_config; 69 | 70 | if (argc == 6) { 71 | const char* file = argv[5]; 72 | FILE* f = fopen(file, "r"); 73 | for (int i = 0; i < num_party; ++i) { 74 | char* c = (char*)malloc(15 * sizeof(char)); 75 | uint p; 76 | fscanf(f, "%s %d\n", c, &p); 77 | std::string s(c); 78 | net_config.push_back(std::make_pair(s, p)); 79 | fflush(f); 80 | } 81 | fclose(f); 82 | } 83 | else { 84 | for (int i = 0; i < num_party; ++i) { 85 | std::string s = "127.0.0.1"; 86 | uint p = (port + 4 * num_party * i); 87 | net_config.push_back(std::make_pair(s, p)); 88 | } 89 | } 90 | 91 | MultiIO* io = new MultiIO(party, num_party, net_config); 92 | std::cout << party << " connected \n"; 93 | ThreadPool pool(threads); 94 | 95 | io->setup_ot_ios(); 96 | io->flush(); 97 | MPSIMDCircExec* simd_circ = new MPSIMDCircExec(num_party, party, &pool, io); 98 | std::cout << "SIMD CIRC setup \n"; 99 | Circuit>* circuit = new Circuit>(file, party, simd_circ); 100 | std::cout << "Circuit setup done \n"; 101 | 102 | test_circuit(file, party, io, 5000, simd_circ, circuit); 103 | test_circuit(file, party, io, 10000, simd_circ, circuit); 104 | delete io; 105 | } 106 | -------------------------------------------------------------------------------- /emp-aby/simd_interface/simd.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP__SIMD_CIRCUIT_EXECUTION_H__ 2 | #define EMP__SIMD_CIRCUIT_EXECUTION_H__ 3 | #include "emp-aby/triple-providers/bit-triple.h" 4 | 5 | namespace emp { 6 | 7 | template 8 | class SIMDCircuitExecution { 9 | public: 10 | #ifndef THREADING 11 | static SIMDCircuitExecution* simd_circ_exec; 12 | #else 13 | static __thread SIMDCircuitExecution* simd_circ_exec; 14 | #endif 15 | virtual BTP* getBtp() = 0; 16 | virtual void and_gate(block* out1, block* in1, block* in2, size_t length) = 0; 17 | virtual void xor_gate(block* out1, block* in1, block* in2, size_t length) = 0; 18 | virtual 
void not_gate(block* out1, block* in1, size_t length) = 0; 19 | virtual void and_gate(bool* out, bool* in1, bool* in2, size_t bool_length, block* block_out, block* block_in1, 20 | block* block_in2, size_t length) = 0; 21 | // virtual void mux_gate(block** out1, block** in1, block** in2, bool** select, int width, size_t length); 22 | virtual void and_gate(bool* out1, bool* in1, bool* in2, size_t length) = 0; 23 | virtual void xor_gate(bool* out1, bool* in1, bool* in2, size_t length) = 0; 24 | virtual void not_gate(bool* out1, bool* in1, size_t length) = 0; 25 | 26 | virtual ~SIMDCircuitExecution() {} 27 | 28 | protected: 29 | 30 | template 31 | inline void and_helper(T*& a, T*& b, T*& c, size_t length, bool& delete_array, T* bit_triple_a, T* bit_triple_b, 32 | T* bit_triple_c, size_t num_triples_pool, size_t& num_triples) { 33 | BTP* btp = getBtp(); 34 | if (length > num_triples_pool) { 35 | a = new T[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 36 | b = new T[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 37 | c = new T[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 38 | for (uint i = 0; i < (length + num_triples_pool - 1) / num_triples_pool; ++i) 39 | btp->get_triple(a + i * num_triples_pool, b + i * num_triples_pool, c + i * num_triples_pool); 40 | size_t tocp = 41 | min((length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, num_triples); 42 | memcpy(bit_triple_a, a + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 43 | tocp); 44 | memcpy(bit_triple_b, b + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 45 | tocp); 46 | memcpy(bit_triple_c, c + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 47 | tocp); 48 | num_triples = 0; 49 | delete_array = true; 50 | } 51 | else if (length > num_triples_pool - num_triples) { // buffer is not long enough 52 | a = new T[length]; 53 
| b = new T[length]; 54 | c = new T[length]; 55 | delete_array = true; 56 | memcpy(a, bit_triple_a + num_triples, sizeof(T) * (num_triples_pool - num_triples)); 57 | memcpy(b, bit_triple_b + num_triples, sizeof(T) * (num_triples_pool - num_triples)); 58 | memcpy(c, bit_triple_c + num_triples, sizeof(T) * (num_triples_pool - num_triples)); 59 | btp->get_triple(bit_triple_a, bit_triple_b, bit_triple_c); 60 | memcpy(a + num_triples_pool - num_triples, bit_triple_a, 61 | sizeof(T) * (length - (num_triples_pool - num_triples))); 62 | memcpy(b + num_triples_pool - num_triples, bit_triple_b, 63 | sizeof(T) * (length - (num_triples_pool - num_triples))); 64 | memcpy(c + num_triples_pool - num_triples, bit_triple_c, 65 | sizeof(T) * (length - (num_triples_pool - num_triples))); 66 | num_triples = length - (num_triples_pool - num_triples); 67 | } 68 | else { 69 | a = bit_triple_a + num_triples; 70 | b = bit_triple_b + num_triples; 71 | c = bit_triple_c + num_triples; 72 | num_triples += length; 73 | } 74 | } 75 | 76 | private: 77 | }; 78 | 79 | } // namespace emp 80 | 81 | #endif // EMP__SIMD_CIRCUIT_EXECUTION_H__ -------------------------------------------------------------------------------- /emp-aby/triple-providers/mp-bit-triple.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-tool/emp-tool.h" 4 | #include "emp-ot/emp-ot.h" 5 | #include "emp-aby/io/multi-io.hpp" 6 | #include "emp-aby/utils.h" 7 | 8 | namespace emp { 9 | template 10 | class MPBitTripleProvider { 11 | public: 12 | MPBitTripleProvider( 13 | int num_party, int party, ThreadPool* pool, MPIOChannel* io, 14 | int buffer_length = ((ferret_b13.n - ferret_b13.k - ferret_b13.t * ferret_b13.log_bin_sz_pre - 128) / 128) * 15 | 128); 16 | 17 | void get_triple(block* a, block* b, block* c); //length should be the #blocks 18 | void get_triple(bool* a, bool* b, bool* c); // length ... 
# bools 19 | ~MPBitTripleProvider(); 20 | 21 | int BUFFER_SZ; 22 | 23 | private: 24 | vector*> cot_sender; 25 | vector*> cot_receiver; 26 | block ch[2]; 27 | 28 | void seed_gen() { 29 | vector> res; 30 | prg.random_block(&seed, 1); 31 | this->io->flush(); 32 | if (party == ALICE) { 33 | for (int i = 2; i <= num_party; ++i) { 34 | res.push_back(pool->enqueue([this, i] { this->io->send_block(i, &seed, 1); })); 35 | } 36 | } 37 | else { 38 | io->recv_block(ALICE, &seed, 1); 39 | } 40 | for (auto& v : res) 41 | v.get(); 42 | res.clear(); 43 | this->io->flush(); 44 | } 45 | 46 | block gen_delta() { 47 | block delta_i; 48 | 49 | prg.random_block(&delta_i, 1); 50 | block one = makeBlock(0xFFFFFFFFFFFFFFFFLL, 0xFFFFFFFFFFFFFFFELL); 51 | delta_i = delta_i & one; 52 | delta_i = delta_i ^ 0x1; 53 | 54 | return delta_i; 55 | } 56 | 57 | block seed; 58 | int party, threads; 59 | ThreadPool* pool = nullptr; 60 | MPIOChannel* io = nullptr; 61 | PRG prg; 62 | vector a_bool, b_bool, c_bool; 63 | block **key_star, **xor_mac, **xor_key, **key, **mac; 64 | bool *sent, *received; 65 | CRH crh; 66 | block delta; 67 | int num_party; 68 | bool b_set = false; 69 | void send(int send_to, const bool* a, block* s, int thread_idx) { 70 | if (send_to == party - 1) 71 | return; 72 | if (sent[send_to]) 73 | return; 74 | sent[send_to] = true; 75 | cot_sender[send_to]->rcot_inplace(this->key[thread_idx], ferret_b13.n, seed); 76 | for (int i = 0; i < BUFFER_SZ; i += 128) { 77 | xorBlocks_arr(key_star[thread_idx] + i, key[thread_idx] + i, delta, 128); 78 | crh.Hn(key[thread_idx] + i, key[thread_idx] + i, 128); 79 | crh.Hn(key_star[thread_idx] + i, key_star[thread_idx] + i, 128); 80 | } 81 | io->get(send_to + 1, party > (send_to + 1))->flush(); 82 | for (int j = 0; j < BUFFER_SZ; ++j) { 83 | s[j] = key[thread_idx][j] ^ key_star[thread_idx][j] ^ ch[a[j]]; 84 | } 85 | io->get(send_to + 1, party > (send_to + 1))->send_block(s, BUFFER_SZ); 86 | io->get(send_to + 1, party > (send_to + 1))->flush(); 87 | 
xorBlocks_arr(xor_key[thread_idx], xor_key[thread_idx], key[thread_idx], BUFFER_SZ); 88 | } 89 | 90 | void recv(int receive_from, bool* b, block* w, block* s, int thread_idx) { 91 | if (receive_from == party - 1) 92 | return; 93 | if (received[receive_from]) 94 | return; 95 | received[receive_from] = true; 96 | cot_receiver[receive_from]->rcot_inplace(this->mac[thread_idx], ferret_b13.n, seed); 97 | if (!b_set) { 98 | for (int i = 0; i < BUFFER_SZ; ++i) 99 | b[i] = getLSB(mac[thread_idx][i]); 100 | 101 | b_set = true; 102 | } 103 | io->get(receive_from + 1, party < (receive_from + 1))->flush(); 104 | io->get(receive_from + 1, party < (receive_from + 1))->recv_block(s, BUFFER_SZ); 105 | io->get(receive_from + 1, party < (receive_from + 1))->flush(); 106 | 107 | for (int i = 0; i < BUFFER_SZ; i += 128) { 108 | crh.Hn(mac[thread_idx] + i, mac[thread_idx] + i, 128); 109 | xorBlocks_arr(w + i, s + i, w + i, 128); 110 | xorBlocks_arr(xor_mac[thread_idx] + i, xor_mac[thread_idx] + i, mac[thread_idx] + i, 128); 111 | } 112 | } 113 | }; 114 | 115 | #include "emp-aby/triple-providers/mp-bit-triple.hpp" 116 | 117 | } // namespace emp 118 | -------------------------------------------------------------------------------- /test/mp_bit_triple.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/triple-providers/mp-bit-triple.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | using namespace emp; 4 | 5 | #include 6 | 7 | int party, port; 8 | 9 | const static int threads = 4; 10 | 11 | int num_party; 12 | 13 | template 14 | double test_bool_triple(MPBitTripleProvider* mp_bit_triple_provider, MPIOChannel* io, int length = 10) { 15 | length = mp_bit_triple_provider->BUFFER_SZ; 16 | bool *a = new bool[length], *b = new bool[length], *c = new bool[length]; 17 | auto start = clock_start(); 18 | 19 | mp_bit_triple_provider->get_triple(a, b, c); 20 | double t = time_from(start); 21 | io->sync(); 22 | 23 | if (party == ALICE) { 24 | bool 
*a0 = new bool[length], *b0 = new bool[length], *c0 = new bool[length]; 25 | 26 | for (int i = 2; i <= num_party; ++i) { 27 | io->recv_bool(i, a0, length); 28 | io->recv_bool(i, b0, length); 29 | io->recv_bool(i, c0, length); 30 | xorBools_arr(a, a, a0, length); 31 | xorBools_arr(b, b, b0, length); 32 | xorBools_arr(c, c, c0, length); 33 | } 34 | 35 | andBools_arr(c0, a, b, length); 36 | 37 | for (int i = 0; i < length; ++i) { 38 | if (c[i] != c0[i]) { 39 | std::cout << i << " " << a[i] << " " << b[i] << " " << c[i] << " " << c0[i] << std::endl; 40 | error("Bool Triple failed!"); 41 | } 42 | } 43 | 44 | // std::cout << "Bool Triple Passed" << std::endl; 45 | } 46 | else { 47 | io->send_bool(ALICE, a, length); 48 | io->send_bool(ALICE, b, length); 49 | io->send_bool(ALICE, c, length); 50 | } 51 | io->flush(); 52 | return t; 53 | } 54 | 55 | template 56 | double test_block_triple(MPBitTripleProvider* mp_bit_triple_provider, MPIOChannel* io) { 57 | int length = mp_bit_triple_provider->BUFFER_SZ / 128; 58 | block *a = new block[length], *b = new block[length], *c = new block[length]; 59 | auto start = clock_start(); 60 | mp_bit_triple_provider->get_triple(a, b, c); 61 | double t = time_from(start); 62 | io->sync(); 63 | 64 | if (party == ALICE) { 65 | block *a0 = new block[length], *b0 = new block[length], *c0 = new block[length]; 66 | 67 | for (int i = 2; i <= num_party; ++i) { 68 | io->recv_block(i, a0, length); 69 | io->recv_block(i, b0, length); 70 | io->recv_block(i, c0, length); 71 | xorBlocks_arr(a, a, a0, length); 72 | xorBlocks_arr(b, b, b0, length); 73 | xorBlocks_arr(c, c, c0, length); 74 | } 75 | 76 | andBlocks_arr(c0, a, b, length); 77 | 78 | if (!cmpBlock(c, c0, length)) { 79 | error("block Triple failed!"); 80 | } 81 | 82 | // std::cout << "block Triple Passed" << std::endl; 83 | } 84 | else { 85 | io->send_block(ALICE, a, length); 86 | io->send_block(ALICE, b, length); 87 | io->send_block(ALICE, c, length); 88 | } 89 | io->flush(); 90 | return t; 91 | } 
92 | 93 | int main(int argc, char** argv) { 94 | if (argc < 4) { 95 | std::cout << "Format: test_mp_bit_triple PartyID Port num_parties" << std::endl; 96 | exit(0); 97 | } 98 | parse_party_and_port(argv, &party, &port); 99 | num_party = atoi(argv[3]); 100 | 101 | std::vector> net_config; 102 | 103 | if (argc == 5) { 104 | const char* file = argv[4]; 105 | FILE* f = fopen(file, "r"); 106 | for (int i = 0; i < num_party; ++i) { 107 | char* c = (char*)malloc(15 * sizeof(char)); 108 | uint p; 109 | fscanf(f, "%s %d\n", c, &p); 110 | std::string s(c); 111 | net_config.push_back(std::make_pair(s, p)); 112 | fflush(f); 113 | } 114 | fclose(f); 115 | } 116 | else { 117 | for (int i = 0; i < num_party; ++i) { 118 | std::string s = "127.0.0.1"; 119 | uint p = (port + 4 * num_party * i); 120 | net_config.push_back(std::make_pair(s, p)); 121 | } 122 | } 123 | 124 | MultiIO* io = new MultiIO(party, num_party, net_config); 125 | std::cout << party << " connected \n"; 126 | 127 | ThreadPool pool(threads); 128 | io->setup_ot_ios(); 129 | io->flush(); 130 | auto start = clock_start(); 131 | 132 | MPBitTripleProvider* mp_bit_triple_provider = 133 | new MPBitTripleProvider(num_party, party, &pool, io); 134 | double timeused = time_from(start); 135 | std::cout << party << "\tsetup\t" << timeused / 1000 << "ms" << std::endl; 136 | // const double buffer_size = pow(128, 3); 137 | 138 | std::cout << party << " BOOL TRIPLE GENERATION\t" << test_bool_triple(mp_bit_triple_provider, io) / (1000) << " ms" 139 | << std::endl; 140 | std::cout << party << " BLOCK TRIPLE GENERATION\t" << test_block_triple(mp_bit_triple_provider, io) / (1000) 141 | << " ms" << std::endl; 142 | 143 | delete io; 144 | } 145 | -------------------------------------------------------------------------------- /test/lut.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/lut.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | 4 | using namespace emp; 5 | 6 | #include 7 | 
8 | int party, port; 9 | 10 | const static int threads = 4; 11 | 12 | int num_party; 13 | 14 | template 15 | void test_generate_shares(HE* he, LUT* lut, MPIOChannel* io, int n = 100) { 16 | n = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 17 | int64_t* lut_share = new int64_t[2 * n]; 18 | bool* rotation = new bool[n]; 19 | int64_t table[2] = {0, 1}; 20 | lut->generate_shares(lut_share, rotation, n, table); 21 | if (party == ALICE) { 22 | bool* tmp_rot = new bool[n]; 23 | int64_t* tmp = new int64_t[2 * n]; 24 | 25 | for (int i = 2; i <= num_party; ++i) { 26 | io->recv_bool(i, tmp_rot, n); 27 | io->recv_data(i, (int64_t*)tmp, 2 * n * sizeof(int64_t)); 28 | 29 | for (int j = 0; j < 2 * n; ++j) { 30 | lut_share[j] = (he->q + lut_share[j] + tmp[j]) % he->q; 31 | } 32 | xorBools_arr(rotation, tmp_rot, rotation, n); 33 | } 34 | for (int i = 0; i < n; ++i) { 35 | // std::cout << rotation[i] << " " << lut_share[2 * i] << " " << lut_share[2 * i + 1] << "\n"; 36 | if (table[0] != lut_share[2 * i + rotation[i] * 1]) 37 | error("Failed lut share generation!"); 38 | if (table[1] != lut_share[2 * i + (!rotation[i]) * 1]) 39 | error("Failed lut share generation!"); 40 | } 41 | std::cout << "LUT Shares generated!" 
<< std::endl; 42 | } 43 | else { 44 | io->send_bool(ALICE, rotation, n); 45 | io->send_data(ALICE, lut_share, 2 * n * sizeof(int64_t)); 46 | } 47 | } 48 | 49 | template 50 | void test_lookup(MPIOChannel* io, LUT* lut, HE* he, int64_t t[2], int n = 100) { 51 | bool* in = new bool[n]; 52 | int64_t* out = new int64_t[n]; 53 | PRG prg; 54 | prg.random_bool(in, n); 55 | auto start = clock_start(); 56 | lut->lookup(out, in, n); 57 | long long timeused = time_from(start); 58 | 59 | if (party == ALICE) { 60 | bool* tmp = new bool[n]; 61 | int64_t* tmp_out = new int64_t[n]; 62 | for (int i = 2; i <= num_party; ++i) { 63 | io->recv_bool(i, tmp, n); 64 | io->recv_data(i, tmp_out, n * sizeof(int64_t)); 65 | xorBools_arr(in, in, tmp, n); 66 | 67 | for (size_t j = 0; j < n; ++j) { 68 | out[j] = (he->q + out[j] + tmp_out[j]) % he->q; 69 | } 70 | } 71 | 72 | for (int i = 0; i < n; ++i) { 73 | if (t[(int)in[i]] != out[i]) { 74 | std::cout << t[(int)in[i]] << " " << out[i] << "\n"; 75 | error("Lookup failed!"); 76 | } 77 | } 78 | // std::cout << "Lookup test passed " << n << std::endl; 79 | } 80 | else { 81 | io->send_bool(ALICE, in, n); 82 | io->send_data(ALICE, out, n * sizeof(int64_t)); 83 | io->flush(ALICE); 84 | } 85 | std::cout << party << " Compute Done " << n << ": " << timeused << " micro sec" << std::endl; 86 | } 87 | 88 | int main(int argc, char** argv) { 89 | if (argc < 4) { 90 | std::cout << "Format: lut PartyID port num_parties" << std::endl; 91 | exit(0); 92 | } 93 | parse_party_and_port(argv, &party, &port); 94 | num_party = atoi(argv[3]); 95 | 96 | std::vector> net_config; 97 | 98 | if (argc == 5) { 99 | const char* file = argv[4]; 100 | FILE* f = fopen(file, "r"); 101 | for (int i = 0; i < num_party; ++i) { 102 | char* c = (char*)malloc(15 * sizeof(char)); 103 | uint p; 104 | fscanf(f, "%s %d\n", c, &p); 105 | std::string s(c); 106 | net_config.push_back(std::make_pair(s, p)); 107 | fflush(f); 108 | } 109 | fclose(f); 110 | } 111 | else { 112 | for (int i = 0; i < 
num_party; ++i) { 113 | std::string s = "127.0.0.1"; 114 | uint p = (port + 4 * num_party * i); 115 | net_config.push_back(std::make_pair(s, p)); 116 | } 117 | } 118 | 119 | ThreadPool pool(threads); 120 | 121 | MultiIO* io = new MultiIO(party, num_party, net_config); 122 | // io->extra_io(); 123 | std::cout << "io setup" << std::endl; 124 | 125 | auto start = clock_start(); 126 | const long long int modulus = (1L << 32) - (1 << 30) + 1; 127 | 128 | HE* he = new HE(num_party, io, &pool, party, modulus); 129 | he->rotation_keygen(); 130 | int n = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 131 | // std::cout << "log2 q = " << log2(he->cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble()) 132 | // << std::endl; 133 | int64_t table[2] = {0, 1}; 134 | 135 | LUT* lut = new LUT(num_party, party, io, &pool, he, table); 136 | long long timeused = time_from(start); 137 | std::cout << party << "\tsetup\t" << timeused / 1000 << "ms" << std::endl; 138 | 139 | test_lookup(io, lut, he, table, 1 * n); 140 | test_generate_shares(he, lut, io, n); 141 | delete io; 142 | } 143 | -------------------------------------------------------------------------------- /emp-aby/triple-providers/bit-triple.hpp: -------------------------------------------------------------------------------- 1 | template 2 | long long int BitTripleProvider::BUFFER_SZ = 3 | ((ferret_b13.n - ferret_b13.k - ferret_b13.t * ferret_b13.log_bin_sz - 128) / 128) * 128; 4 | 5 | template 6 | void BitTripleProvider::compute_rcots(bool* b) { 7 | auto t = clock_start(); 8 | ot0->rcot_inplace(r0.data(), ferret_b13.n); 9 | ot1->rcot_inplace(r1.data(), ferret_b13.n); 10 | // duplex_rcot_inplace(ot0, ot1, r0.data(),r1.data(), N_REG, N_REG); 11 | // std::cout << "1: " << time_from(t) << "\n"; 12 | t = clock_start(); 13 | if (party == ALICE) { 14 | for (int i = 0; i < BUFFER_SZ; i += 128) { 15 | this->crh.Hn(A_hat.data() + i, r0.data() + i, 128, scratch.data()); 16 | 
xorBlocks_arr(A_star.data() + i, r0.data() + i, this->delta, 128); 17 | this->crh.Hn(A_star.data() + i, A_star.data() + i, 128, scratch.data()); 18 | this->crh.Hn(B_hat.data() + i, r1.data() + i, 128, scratch.data()); 19 | for (int j = 0; j < 128; ++j) 20 | b[j + i] = getLSB(r1[j + i]); 21 | } 22 | } 23 | else { // BOB 24 | for (int i = 0; i < BUFFER_SZ; i += 128) { 25 | this->crh.Hn(B_hat.data() + i, r0.data() + i, 128, scratch.data()); 26 | for (int j = 0; j < 128; ++j) 27 | b[j + i] = getLSB(r0[j + i]); 28 | this->crh.Hn(A_hat.data() + i, r1.data() + i, 128, scratch.data()); 29 | xorBlocks_arr(A_star.data() + i, r1.data() + i, this->delta, 128); 30 | this->crh.Hn(A_star.data() + i, A_star.data() + i, 128, scratch.data()); 31 | } 32 | } 33 | // std::cout << "2: " << time_from(t) << "\n"; 34 | } 35 | 36 | template 37 | BitTripleProvider::BitTripleProvider(int party, int threads, IO** ios) { 38 | this->party = party; 39 | this->threads = threads; 40 | this->ios = ios; 41 | static std::string ot0_pre_ot_filename = (party == ALICE ? "./data/ALICE_sender" : "./data/BOB_receiver"); 42 | this->ot0 = new FerretCOT(party, threads, this->ios, false, true, ferret_b13, ot0_pre_ot_filename); 43 | int rev_party = party == ALICE ? BOB : ALICE; 44 | static std::string ot1_pre_ot_filename = (party == ALICE ? "./data/ALICE_receiver" : "./data/BOB_sender"); 45 | this->ot1 = new FerretCOT(rev_party, threads, this->ios, false, true, ferret_b13, ot1_pre_ot_filename); 46 | BUFFER_SZ = ((ferret_b13.n - ferret_b13.k - ferret_b13.t * ferret_b13.log_bin_sz - 128) / 128) * 128; 47 | this->delta = (party == ALICE) ? 
this->ot0->Delta : this->ot1->Delta; 48 | r0.resize(emp::ferret_b13.n); 49 | r1.resize(emp::ferret_b13.n); 50 | scratch.resize(BUFFER_SZ); 51 | A_hat.resize(BUFFER_SZ); 52 | A_star.resize(BUFFER_SZ); 53 | B_hat.resize(BUFFER_SZ); 54 | a_bool.resize(BUFFER_SZ); 55 | b_bool.resize(BUFFER_SZ); 56 | c_bool.resize(BUFFER_SZ); 57 | } 58 | 59 | template 60 | BitTripleProvider::~BitTripleProvider() { 61 | delete ot0; 62 | delete ot1; 63 | } 64 | 65 | template 66 | void BitTripleProvider::get_triple(block* a, block* b, block* c) { 67 | this->get_triple((bool*)a_bool.data(), (bool*)b_bool.data(), (bool*)c_bool.data()); 68 | bool_to_block_arr(a, (bool*)a_bool.data(), BUFFER_SZ); 69 | bool_to_block_arr(b, (bool*)b_bool.data(), BUFFER_SZ); 70 | bool_to_block_arr(c, (bool*)c_bool.data(), BUFFER_SZ); 71 | } 72 | 73 | template 74 | void BitTripleProvider::get_triple(bool* a, bool* b, bool* c) { 75 | this->compute_rcots(b); 76 | // auto t = clock_start(); 77 | for (int i = 0; i < BUFFER_SZ; ++i) { 78 | a[i] = getLSB(A_hat[i]) ^ getLSB(A_star[i]); 79 | c[i] = (a[i] & b[i]) ^ getLSB(B_hat[i]) ^ getLSB(A_hat[i]); 80 | } 81 | // std::cout<<"3: "<< time_from(t)<<"\n"; 82 | } 83 | 84 | // bit-vector multiplication. 
AKA multiplexier 85 | // \Xor C[i,...,i+wdith-1] = (\xor b[i])(\Xor A [i,...,i+width-1]) for i in[0, length) 86 | /* template 87 | void BitTripleProvider::get_mux_triple(block* A, bool* b, block* C, int width, int length) { 88 | int num = length * width; 89 | block *A_hat = new block[num], *A_star = new block[num]; 90 | block* B_hat = new block[num]; 91 | bool* b_all = new bool[num]; 92 | this->compute_rcots(b_all); 93 | block* r = new block[num]; 94 | for (int i = 0; i < num; ++i) 95 | if (b_all[i]) 96 | r[i] = A_star[i]; 97 | else 98 | r[i] = A_hat[i]; 99 | 100 | xorBlocks_arr(A, A_hat, A_star, num); 101 | xorBlocks_arr(C, B_hat, r, num); 102 | 103 | bool* xors = new bool[num]; 104 | 105 | for (int i = 0; i < length; ++i) { 106 | for (int j = 0; j < width; ++j) { 107 | xors[i * width + j] = b_all[i * width] ^ b_all[i * width + j]; 108 | } 109 | } 110 | if (party == ALICE) { 111 | this->ios[0]->send_bool(xors, num); 112 | this->ios[0]->recv_bool(xors, num); 113 | } 114 | else if (party == BOB) { 115 | bool* xors_a = new bool[num]; 116 | this->ios[0]->recv_bool(xors_a, num); 117 | xorBools_arr(xors, xors, xors_a, num); 118 | delete[] xors_a; 119 | this->ios[0]->send_bool(xors, num); 120 | } 121 | this->ios[0]->sync(); 122 | for (int i = 0; i < length; ++i) { 123 | b[i] = b_all[i * width]; 124 | for (int j = 0; j < width; ++j) { 125 | if (xors[i * width + j]) { 126 | C[i * width + j] = C[i * width + j] ^ A[i * width + j]; 127 | } 128 | } 129 | } 130 | 131 | delete[] A_hat; 132 | delete[] B_hat; 133 | delete[] b_all; 134 | delete[] A_star; 135 | delete[] xors; 136 | delete[] r; 137 | } 138 | */ 139 | -------------------------------------------------------------------------------- /test/positive.txt: -------------------------------------------------------------------------------- 1 | 239 367 2 | 64 64 1 3 | 4 | 2 1 0 64 128 AND 5 | 2 1 128 1 129 XOR 6 | 2 1 128 65 130 XOR 7 | 2 1 129 130 131 AND 8 | 2 1 131 128 132 XOR 9 | 2 1 132 2 133 XOR 10 | 2 1 132 66 134 XOR 11 
| 2 1 133 134 135 AND 12 | 2 1 135 132 136 XOR 13 | 2 1 136 3 137 XOR 14 | 2 1 136 67 138 XOR 15 | 2 1 137 138 139 AND 16 | 2 1 139 136 140 XOR 17 | 2 1 140 4 141 XOR 18 | 2 1 140 68 142 XOR 19 | 2 1 141 142 143 AND 20 | 2 1 143 140 144 XOR 21 | 2 1 144 5 145 XOR 22 | 2 1 144 69 146 XOR 23 | 2 1 145 146 147 AND 24 | 2 1 147 144 148 XOR 25 | 2 1 148 6 149 XOR 26 | 2 1 148 70 150 XOR 27 | 2 1 149 150 151 AND 28 | 2 1 151 148 152 XOR 29 | 2 1 152 7 153 XOR 30 | 2 1 152 71 154 XOR 31 | 2 1 153 154 155 AND 32 | 2 1 155 152 156 XOR 33 | 2 1 156 8 157 XOR 34 | 2 1 156 72 158 XOR 35 | 2 1 157 158 159 AND 36 | 2 1 159 156 160 XOR 37 | 2 1 160 9 161 XOR 38 | 2 1 160 73 162 XOR 39 | 2 1 161 162 163 AND 40 | 2 1 163 160 164 XOR 41 | 2 1 164 10 165 XOR 42 | 2 1 164 74 166 XOR 43 | 2 1 165 166 167 AND 44 | 2 1 167 164 168 XOR 45 | 2 1 168 11 169 XOR 46 | 2 1 168 75 170 XOR 47 | 2 1 169 170 171 AND 48 | 2 1 171 168 172 XOR 49 | 2 1 172 12 173 XOR 50 | 2 1 172 76 174 XOR 51 | 2 1 173 174 175 AND 52 | 2 1 175 172 176 XOR 53 | 2 1 176 13 177 XOR 54 | 2 1 176 77 178 XOR 55 | 2 1 177 178 179 AND 56 | 2 1 179 176 180 XOR 57 | 2 1 180 14 181 XOR 58 | 2 1 180 78 182 XOR 59 | 2 1 181 182 183 AND 60 | 2 1 183 180 184 XOR 61 | 2 1 184 15 185 XOR 62 | 2 1 184 79 186 XOR 63 | 2 1 185 186 187 AND 64 | 2 1 187 184 188 XOR 65 | 2 1 188 16 189 XOR 66 | 2 1 188 80 190 XOR 67 | 2 1 189 190 191 AND 68 | 2 1 191 188 192 XOR 69 | 2 1 192 17 193 XOR 70 | 2 1 192 81 194 XOR 71 | 2 1 193 194 195 AND 72 | 2 1 195 192 196 XOR 73 | 2 1 196 18 197 XOR 74 | 2 1 196 82 198 XOR 75 | 2 1 197 198 199 AND 76 | 2 1 199 196 200 XOR 77 | 2 1 200 19 201 XOR 78 | 2 1 200 83 202 XOR 79 | 2 1 201 202 203 AND 80 | 2 1 203 200 204 XOR 81 | 2 1 204 20 205 XOR 82 | 2 1 204 84 206 XOR 83 | 2 1 205 206 207 AND 84 | 2 1 207 204 208 XOR 85 | 2 1 208 21 209 XOR 86 | 2 1 208 85 210 XOR 87 | 2 1 209 210 211 AND 88 | 2 1 211 208 212 XOR 89 | 2 1 212 22 213 XOR 90 | 2 1 212 86 214 XOR 91 | 2 1 213 214 215 AND 92 | 2 1 215 212 216 XOR 
93 | 2 1 216 23 217 XOR 94 | 2 1 216 87 218 XOR 95 | 2 1 217 218 219 AND 96 | 2 1 219 216 220 XOR 97 | 2 1 220 24 221 XOR 98 | 2 1 220 88 222 XOR 99 | 2 1 221 222 223 AND 100 | 2 1 223 220 224 XOR 101 | 2 1 224 25 225 XOR 102 | 2 1 224 89 226 XOR 103 | 2 1 225 226 227 AND 104 | 2 1 227 224 228 XOR 105 | 2 1 228 26 229 XOR 106 | 2 1 228 90 230 XOR 107 | 2 1 229 230 231 AND 108 | 2 1 231 228 232 XOR 109 | 2 1 232 27 233 XOR 110 | 2 1 232 91 234 XOR 111 | 2 1 233 234 235 AND 112 | 2 1 235 232 236 XOR 113 | 2 1 236 28 237 XOR 114 | 2 1 236 92 238 XOR 115 | 2 1 237 238 239 AND 116 | 2 1 239 236 240 XOR 117 | 2 1 240 29 241 XOR 118 | 2 1 240 93 242 XOR 119 | 2 1 241 242 243 AND 120 | 2 1 243 240 244 XOR 121 | 2 1 244 30 245 XOR 122 | 2 1 244 94 246 XOR 123 | 2 1 245 246 247 AND 124 | 2 1 247 244 248 XOR 125 | 2 1 248 31 249 XOR 126 | 2 1 248 95 250 XOR 127 | 2 1 249 250 251 AND 128 | 2 1 251 248 252 XOR 129 | 2 1 252 32 253 XOR 130 | 2 1 252 96 254 XOR 131 | 2 1 253 254 255 AND 132 | 2 1 255 252 256 XOR 133 | 2 1 256 33 257 XOR 134 | 2 1 256 97 258 XOR 135 | 2 1 257 258 259 AND 136 | 2 1 259 256 260 XOR 137 | 2 1 260 34 261 XOR 138 | 2 1 260 98 262 XOR 139 | 2 1 261 262 263 AND 140 | 2 1 263 260 264 XOR 141 | 2 1 264 35 265 XOR 142 | 2 1 264 99 266 XOR 143 | 2 1 265 266 267 AND 144 | 2 1 267 264 268 XOR 145 | 2 1 268 36 269 XOR 146 | 2 1 268 100 270 XOR 147 | 2 1 269 270 271 AND 148 | 2 1 271 268 272 XOR 149 | 2 1 272 37 273 XOR 150 | 2 1 272 101 274 XOR 151 | 2 1 273 274 275 AND 152 | 2 1 275 272 276 XOR 153 | 2 1 276 38 277 XOR 154 | 2 1 276 102 278 XOR 155 | 2 1 277 278 279 AND 156 | 2 1 279 276 280 XOR 157 | 2 1 280 39 281 XOR 158 | 2 1 280 103 282 XOR 159 | 2 1 281 282 283 AND 160 | 2 1 283 280 284 XOR 161 | 2 1 284 40 285 XOR 162 | 2 1 284 104 286 XOR 163 | 2 1 285 286 287 AND 164 | 2 1 287 284 288 XOR 165 | 2 1 288 41 289 XOR 166 | 2 1 288 105 290 XOR 167 | 2 1 289 290 291 AND 168 | 2 1 291 288 292 XOR 169 | 2 1 292 42 293 XOR 170 | 2 1 292 106 294 XOR 171 | 2 1 
293 294 295 AND 172 | 2 1 295 292 296 XOR 173 | 2 1 296 43 297 XOR 174 | 2 1 296 107 298 XOR 175 | 2 1 297 298 299 AND 176 | 2 1 299 296 300 XOR 177 | 2 1 300 44 301 XOR 178 | 2 1 300 108 302 XOR 179 | 2 1 301 302 303 AND 180 | 2 1 303 300 304 XOR 181 | 2 1 304 45 305 XOR 182 | 2 1 304 109 306 XOR 183 | 2 1 305 306 307 AND 184 | 2 1 307 304 308 XOR 185 | 2 1 308 46 309 XOR 186 | 2 1 308 110 310 XOR 187 | 2 1 309 310 311 AND 188 | 2 1 311 308 312 XOR 189 | 2 1 312 47 313 XOR 190 | 2 1 312 111 314 XOR 191 | 2 1 313 314 315 AND 192 | 2 1 315 312 316 XOR 193 | 2 1 316 48 317 XOR 194 | 2 1 316 112 318 XOR 195 | 2 1 317 318 319 AND 196 | 2 1 319 316 320 XOR 197 | 2 1 320 49 321 XOR 198 | 2 1 320 113 322 XOR 199 | 2 1 321 322 323 AND 200 | 2 1 323 320 324 XOR 201 | 2 1 324 50 325 XOR 202 | 2 1 324 114 326 XOR 203 | 2 1 325 326 327 AND 204 | 2 1 327 324 328 XOR 205 | 2 1 328 51 329 XOR 206 | 2 1 328 115 330 XOR 207 | 2 1 329 330 331 AND 208 | 2 1 331 328 332 XOR 209 | 2 1 332 52 333 XOR 210 | 2 1 332 116 334 XOR 211 | 2 1 333 334 335 AND 212 | 2 1 335 332 336 XOR 213 | 2 1 336 53 337 XOR 214 | 2 1 336 117 338 XOR 215 | 2 1 337 338 339 AND 216 | 2 1 339 336 340 XOR 217 | 2 1 340 54 341 XOR 218 | 2 1 340 118 342 XOR 219 | 2 1 341 342 343 AND 220 | 2 1 343 340 344 XOR 221 | 2 1 344 55 345 XOR 222 | 2 1 344 119 346 XOR 223 | 2 1 345 346 347 AND 224 | 2 1 347 344 348 XOR 225 | 2 1 348 56 349 XOR 226 | 2 1 348 120 350 XOR 227 | 2 1 349 350 351 AND 228 | 2 1 351 348 352 XOR 229 | 2 1 352 57 353 XOR 230 | 2 1 352 121 354 XOR 231 | 2 1 353 354 355 AND 232 | 2 1 355 352 356 XOR 233 | 2 1 356 58 357 XOR 234 | 2 1 356 122 358 XOR 235 | 2 1 357 358 359 AND 236 | 2 1 359 356 360 XOR 237 | 2 1 360 59 361 XOR 238 | 2 1 360 123 362 XOR 239 | 2 1 361 362 363 AND 240 | 2 1 363 360 364 XOR 241 | 2 1 364 124 365 XOR 242 | 2 1 365 60 366 XOR 243 | -------------------------------------------------------------------------------- /test/b2aconverter.cpp: 
-------------------------------------------------------------------------------- 1 | #include "emp-aby/converter/b2aconverter.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | 4 | using namespace emp; 5 | 6 | #include 7 | 8 | int party, port; 9 | 10 | const static int threads = 4; 11 | 12 | int num_party; 13 | 14 | template 15 | void check(MPIOChannel* io, bool* b, int64_t* a, int n, long long q, int l, string msg, bool mod = true) { 16 | int length = l * n; 17 | // io->sync(); 18 | if (party == ALICE) { 19 | bool* tmp_b = new bool[length]; 20 | int64_t* tmp_a = new int64_t[n]; 21 | 22 | for (int i = 2; i <= num_party; ++i) { 23 | io->recv_bool(i, tmp_b, length); 24 | xorBools_arr(b, b, tmp_b, length); 25 | 26 | io->recv_data(i, tmp_a, n * sizeof(int64_t)); 27 | io->flush(i); 28 | for (int j = 0; j < n; ++j) { 29 | a[j] = (a[j] + tmp_a[j]) % q; 30 | } 31 | } 32 | int64_t* check = new int64_t[n]; 33 | memset(check, 0, n * sizeof(int64_t)); 34 | // std::cout << l << " " << q << " " << n << std::endl; 35 | for (int i = 0; i < l; ++i) { 36 | for (int j = 0; j < n; ++j) { 37 | if (b[j * l + i]) { 38 | check[j] += (1L << i); 39 | } 40 | if (mod) 41 | check[j] %= q; 42 | } 43 | } 44 | for (int i = 0; i < n; ++i) 45 | if (check[i] != a[i]) { 46 | std::cout << i << " " << check[i] << " " << a[i] << std::endl; 47 | std::cout << msg << " "; 48 | error("Test failed!"); 49 | } 50 | 51 | // std::cout << msg << " Test passed" << std::endl; 52 | 53 | delete[] tmp_a; 54 | delete[] tmp_b; 55 | delete[] check; 56 | } 57 | else { 58 | io->send_bool(ALICE, b, length); 59 | io->send_data(ALICE, a, n * sizeof(int64_t)); 60 | io->flush(ALICE); 61 | } 62 | } 63 | 64 | template 65 | void test_b2a(MPIOChannel* io, B2AConverter* converter, HE* he, int n = 100, double comm_offset = 0) { 66 | PRG prg; 67 | int l = ceil(log2(he->q)); 68 | int length = l * n; 69 | bool* boolean = new bool[length]; 70 | prg.random_bool(boolean, length); 71 | int64_t* arithmetic = new int64_t[n]; 72 | auto start = 
clock_start(); 73 | converter->convert(arithmetic, boolean, length, l); 74 | double timeused = time_from(start); 75 | double online_comm = io->get_total_bytes_sent() - comm_offset; 76 | check(io, boolean, arithmetic, n, he->q, l, "B2A"); 77 | std::cout << party << " Online time B2A conversion per 32-bit: " << timeused / (n * 1000) << " ms" << std::endl; 78 | std::cout << party << " Online comm B2A conversion per 32-bit: " << online_comm / (n) << " KB" << std::endl; 79 | } 80 | 81 | int main(int argc, char** argv) { 82 | if (argc < 4) { 83 | std::cout << "Format: b2aconverter PartyID port num_parties" << std::endl; 84 | exit(0); 85 | } 86 | parse_party_and_port(argv, &party, &port); 87 | num_party = atoi(argv[3]); 88 | 89 | std::vector> net_config; 90 | 91 | if (argc == 5) { 92 | const char* file = argv[4]; 93 | FILE* f = fopen(file, "r"); 94 | for (int i = 0; i < num_party; ++i) { 95 | char* c = (char*)malloc(15 * sizeof(char)); 96 | uint p; 97 | fscanf(f, "%s %d\n", c, &p); 98 | std::string s(c); 99 | net_config.push_back(std::make_pair(s, p)); 100 | fflush(f); 101 | } 102 | fclose(f); 103 | } 104 | else { 105 | for (int i = 0; i < num_party; ++i) { 106 | std::string s = "127.0.0.1"; 107 | uint p = (port + 4 * num_party * i); 108 | net_config.push_back(std::make_pair(s, p)); 109 | } 110 | } 111 | 112 | MultiIO* io = new MultiIO(party, num_party, net_config); 113 | std::cout << "io setup" << std::endl; 114 | ThreadPool* pool = new ThreadPool(threads); 115 | const long long int modulus = (1L << 32) - (1L << 30) + 1; 116 | HE* he = new HE(num_party, io, pool, party, modulus); 117 | he->multiplication_keygen(); 118 | he->rotation_keygen(); 119 | std::cout << party << " p = " << he->cc->GetCryptoParameters()->GetPlaintextModulus() << std::endl; 120 | std::cout << party << " n = " << he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2 121 | << std::endl; 122 | std::cout << party 123 | << " log2 q = " << 
log2(he->cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble()) 124 | << std::endl; 125 | // io->flush(); 126 | int pool_size = std::max(20 / num_party - 1, 0) * num_party + num_party; 127 | if (num_party > 16) 128 | pool_size = 16; 129 | std::cout << "pool size " << pool_size << std::endl; 130 | 131 | // int pool_size = 20; 132 | // std::cout << "go in b2a constructor \n"; 133 | auto start = clock_start(); 134 | B2AConverter* converter = new B2AConverter(num_party, party, io, pool, he, pool_size); 135 | double timeused = time_from(start); 136 | double offline_comm = io->get_total_bytes_sent(); 137 | int n = (he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2); 138 | 139 | std::cout << party << "\tB2A offline time\t" << timeused / ((pool_size * n / 32) * 1000) << " ms" << std::endl; 140 | std::cout << party << "\tB2A offline comm\t" << offline_comm / ((pool_size * n / 32)) << " KB" << std::endl; 141 | test_b2a(io, converter, he, (min(20, pool_size) * n) / 32, offline_comm); 142 | delete he; 143 | delete io; 144 | } 145 | -------------------------------------------------------------------------------- /test/bit_triple.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/triple-providers/bit-triple.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | 4 | #include 5 | int party, port; 6 | 7 | const static int threads = 1; 8 | 9 | template 10 | double test_get_block_triple(BitTripleProvider* bt, IO* io) { 11 | const int length = BitTripleProvider::BUFFER_SZ / 128; 12 | block* a = new block[length]; 13 | block* b = new block[length]; 14 | block* c = new block[length]; 15 | auto start = clock_start(); 16 | bt->get_triple(a, b, c); 17 | long long t = time_from(start); 18 | if (party == ALICE) { 19 | io->send_block(a, length); 20 | io->send_block(b, length); 21 | io->send_block(c, length); 22 | } 23 | else { 24 | block* a0 = new block[length]; 25 | block* b0 = new 
block[length]; 26 | block* c0 = new block[length]; 27 | io->recv_block(a0, length); 28 | io->recv_block(b0, length); 29 | io->recv_block(c0, length); 30 | 31 | block* lhs = new block[length]; 32 | block* rhs = new block[length]; 33 | xorBlocks_arr(lhs, c0, c, length); 34 | block* a_xor = new block[length]; 35 | block* b_xor = new block[length]; 36 | xorBlocks_arr(a_xor, a0, a, length); 37 | xorBlocks_arr(b_xor, b0, b, length); 38 | andBlocks_arr(rhs, a_xor, b_xor, length); 39 | if (!cmpBlock(lhs, rhs, length)) { 40 | std::cout << "Bit Triple Failed \n"; 41 | error("Bit Triple Failed"); 42 | } 43 | 44 | delete[] a0; 45 | delete[] b0; 46 | delete[] c0; 47 | delete[] lhs; 48 | delete[] rhs; 49 | delete[] a_xor; 50 | delete[] b_xor; 51 | } 52 | delete[] a; 53 | delete[] b; 54 | delete[] c; 55 | return t; 56 | } 57 | 58 | template 59 | double test_get_bool_triple(BitTripleProvider* bt, IO* io) { 60 | const int length = BitTripleProvider::BUFFER_SZ; 61 | bool* a = new bool[length]; 62 | bool* b = new bool[length]; 63 | bool* c = new bool[length]; 64 | memset(a, 0, length); 65 | memset(b, 0, length); 66 | memset(c, 0, length); 67 | auto start = clock_start(); 68 | bt->get_triple(a, b, c); 69 | long long t = time_from(start); 70 | if (party == ALICE) { 71 | io->send_bool(a, length); 72 | io->send_bool(b, length); 73 | io->send_bool(c, length); 74 | } 75 | else { 76 | bool* a0 = new bool[length]; 77 | bool* b0 = new bool[length]; 78 | bool* c0 = new bool[length]; 79 | io->recv_bool(a0, length); 80 | io->recv_bool(b0, length); 81 | io->recv_bool(c0, length); 82 | 83 | bool* lhs = new bool[length]; 84 | bool* rhs = new bool[length]; 85 | for (int i = 0; i < length; ++i) { 86 | lhs[i] = c0[i] ^ c[i]; 87 | rhs[i] = (a0[i] ^ a[i]) & (b0[i] ^ b[i]); 88 | if (lhs[i] != rhs[i]) { 89 | std::cout << c0[i] << c[i] << std::endl; 90 | std::cout << a0[i] << a[i] << std::endl; 91 | std::cout << b0[i] << b[i] << std::endl; 92 | std::cout << "Bool Triple Failed \n"; 93 | } 94 | } 95 | 96 | 
delete[] a0; 97 | delete[] b0; 98 | delete[] c0; 99 | delete[] lhs; 100 | delete[] rhs; 101 | } 102 | delete[] a; 103 | delete[] b; 104 | delete[] c; 105 | return t; 106 | } 107 | 108 | /*double test_get_mux_triple(BitTripleProvider *bt, int length = 10, int width = 12) 109 | { 110 | 111 | block *A = new block[length * width], *C = new block[length * width]; 112 | bool *b = new bool[length]; 113 | auto start = clock_start(); 114 | bt->get_mux_triple(A, b, C, width, length); 115 | long long t = time_from(start); 116 | IO *io = new IO(party == ALICE ? nullptr : "127.0.0.1", port + length); 117 | if (party == ALICE) 118 | { 119 | io->send_block(A, length * width); 120 | io->send_bool(b, length); 121 | io->send_block(C, length * width); 122 | } 123 | else if (party == BOB) 124 | { 125 | block *A0 = new block[length * width], *C0 = new block[length * width]; 126 | bool *b0 = new bool[length]; 127 | 128 | io->recv_block(A0, length * width); 129 | io->recv_bool(b0, length); 130 | io->recv_block(C0, length * width); 131 | 132 | xorBlocks_arr(A, A0, A, length * width); 133 | xorBlocks_arr(C, C0, C, length * width); 134 | 135 | for (int i = 0; i < length; ++i) 136 | { 137 | if (b[i] ^ b0[i]) 138 | { 139 | if (!cmpBlock(A + i * width, C + i * width, width)) 140 | error("Mux Triple Failed"); 141 | } 142 | else 143 | { 144 | for (int j = 0; j < width; ++j) 145 | if (!cmpBlock(C + i * width + j, &zero_block, 1)) 146 | error("Mux Triple Failed"); 147 | } 148 | } 149 | delete[] A0; 150 | delete[] b0; 151 | delete[] C0; 152 | } 153 | // std::cout << "Mux triple passed. 
\t"; 154 | delete io; 155 | delete[] A; 156 | delete[] b; 157 | delete[] C; 158 | return t; 159 | }*/ 160 | 161 | int main(int argc, char** argv) { 162 | parse_party_and_port(argv, &party, &port); 163 | 164 | std::vector> net_config; 165 | 166 | for (int i = 0; i < 2; ++i) { 167 | std::string s = "127.0.0.1"; 168 | uint p = (port + 4 * 2 * i); 169 | net_config.push_back(std::make_pair(s, p)); 170 | } 171 | auto start = clock_start(); 172 | vector ios; 173 | 174 | for (int i = 0; i < threads; ++i) 175 | ios.push_back(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port)); 176 | BitTripleProvider* bt = new BitTripleProvider(party, threads, ios.data()); 177 | double timeused = time_from(start); 178 | const double buffer_size = BitTripleProvider::BUFFER_SZ; 179 | std::cout << party << "\tsetup\t" << timeused / 1000 << "ms" << std::endl; 180 | std::cout << "BLOCK TRIPLE GENERATION\t" << test_get_block_triple(bt, ios[0]) / buffer_size * 128 181 | << " ns per triple" << std::endl; 182 | std::cout << "BOOL TRIPLE GENERATION\t" << test_get_bool_triple(bt, ios[0]) / buffer_size << " ns per triple" 183 | << std::endl; 184 | //std::cout << "(10000, 100) MUX TRIPLE GENERATION\t" << test_get_mux_triple(bt, 10000, 100)/1000 << "ms" << std::endl; 185 | delete bt; 186 | for (auto io : ios) 187 | delete io; 188 | } 189 | -------------------------------------------------------------------------------- /test/arithmetic_circ.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/simd_interface/arithmetic-circ.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | 4 | using namespace emp; 5 | 6 | #include 7 | 8 | int party, port; 9 | 10 | const static int threads = 4; 11 | 12 | int num_party; 13 | 14 | template 15 | void test_triple(MPIOChannel* io, HE* he, ArithmeticCirc* circ) { 16 | int n = circ->num_triples_pool; 17 | int64_t *triple_a, *triple_b, *triple_c; 18 | 19 | triple_a = new int64_t[n]; 20 | triple_b = new int64_t[n]; 21 | 
triple_c = new int64_t[n]; 22 | 23 | circ->get_triples(triple_a, triple_b, triple_c); 24 | 25 | if (party == ALICE) { 26 | int64_t *a, *b, *c; 27 | 28 | a = new int64_t[n]; 29 | b = new int64_t[n]; 30 | c = new int64_t[n]; 31 | for (int i = 2; i <= num_party; ++i) { 32 | io->recv_data(i, a, n * sizeof(int64_t)); 33 | io->recv_data(i, b, n * sizeof(int64_t)); 34 | io->recv_data(i, c, n * sizeof(int64_t)); 35 | 36 | for (int j = 0; j < n; ++j) { 37 | triple_a[j] = (he->q + triple_a[j] + a[j]) % he->q; 38 | triple_b[j] = (he->q + triple_b[j] + b[j]) % he->q; 39 | triple_c[j] = (he->q + triple_c[j] + c[j]) % he->q; 40 | } 41 | } 42 | for (int i = 0; i < n; ++i) { 43 | uint64_t x = ((uint64_t)triple_a[i] * (uint64_t)triple_b[i]); 44 | if (triple_c[i] != x % he->q) { 45 | std::cout << i << " " << triple_a[i] << " " << triple_b[i] << " " << triple_c[i] << " " << x % he->q 46 | << std::endl; 47 | error("Arithmetic triple failed!"); 48 | } 49 | } 50 | std::cout << "Arithmetic triple test passed!" 
<< std::endl; 51 | } 52 | else { 53 | io->send_data(ALICE, triple_a, n * sizeof(int64_t)); 54 | io->send_data(ALICE, triple_b, n * sizeof(int64_t)); 55 | io->send_data(ALICE, triple_c, n * sizeof(int64_t)); 56 | } 57 | io->flush(); 58 | } 59 | 60 | template 61 | void test_mul(MPIOChannel* io, HE* he, ArithmeticCirc* circ, int length = 1000) { 62 | int64_t *in1 = new int64_t[length], *in2 = new int64_t[length], *out = new int64_t[length]; 63 | PRG prg; 64 | prg.random_data(in1, length * sizeof(int64_t)); 65 | prg.random_data(in2, length * sizeof(int64_t)); 66 | for (size_t i = 0; i < length; ++i) { 67 | in1[i] %= he->q; 68 | in2[i] %= he->q; 69 | in1[i] = (he->q + in1[i]) % he->q; 70 | in2[i] = (he->q + in2[i]) % he->q; 71 | } 72 | auto start = clock_start(); 73 | circ->mult(out, in1, in2, length); 74 | double timeused = time_from(start); 75 | if (party == ALICE) { 76 | int64_t *a, *b, *c; 77 | 78 | a = new int64_t[length]; 79 | b = new int64_t[length]; 80 | c = new int64_t[length]; 81 | for (int i = 2; i <= num_party; ++i) { 82 | io->recv_data(i, a, length * sizeof(int64_t)); 83 | io->recv_data(i, b, length * sizeof(int64_t)); 84 | io->recv_data(i, c, length * sizeof(int64_t)); 85 | 86 | for (int j = 0; j < length; ++j) { 87 | in1[j] = (he->q + in1[j] + a[j]) % he->q; 88 | in2[j] = (he->q + in2[j] + b[j]) % he->q; 89 | out[j] = (he->q + out[j] + c[j]) % he->q; 90 | } 91 | } 92 | for (int i = 0; i < length; ++i) { 93 | uint64_t x = ((uint64_t)in1[i] * (uint64_t)in2[i]); 94 | if (out[i] != x % he->q) { 95 | std::cout << i << " " << out[i] << " " << in1[i] << " " << in2[i] << std::endl; 96 | error("Arithmetic multiplication failed!"); 97 | } 98 | } 99 | } 100 | else { 101 | io->send_data(ALICE, in1, length * sizeof(int64_t)); 102 | io->send_data(ALICE, in2, length * sizeof(int64_t)); 103 | io->send_data(ALICE, out, length * sizeof(int64_t)); 104 | } 105 | std::cout << "Arith. 
Mult Online:\t" << timeused / (1000 * length) << " ms" << std::endl; 106 | } 107 | 108 | int main(int argc, char** argv) { 109 | if (argc < 4) { 110 | std::cout << "Format: test_mp_bit_triple PartyID Port num_parties" << std::endl; 111 | exit(0); 112 | } 113 | parse_party_and_port(argv, &party, &port); 114 | num_party = atoi(argv[3]); 115 | 116 | std::vector> net_config; 117 | 118 | if (argc == 5) { 119 | const char* file = argv[4]; 120 | FILE* f = fopen(file, "r"); 121 | for (int i = 0; i < num_party; ++i) { 122 | char* c = (char*)malloc(15 * sizeof(char)); 123 | uint p; 124 | fscanf(f, "%s %d\n", c, &p); 125 | std::string s(c); 126 | net_config.push_back(std::make_pair(s, p)); 127 | fflush(f); 128 | } 129 | fclose(f); 130 | } 131 | else { 132 | for (int i = 0; i < num_party; ++i) { 133 | std::string s = "127.0.0.1"; 134 | uint p = (port + 4 * num_party * i); 135 | net_config.push_back(std::make_pair(s, p)); 136 | } 137 | } 138 | 139 | MultiIO* io = new MultiIO(party, num_party, net_config); 140 | std::cout << party << " connected \n"; 141 | ThreadPool pool(threads); 142 | io->flush(); 143 | std::cout << "io setup" << std::endl; 144 | const long long int modulus = (1L << 32) - (1 << 30) + 1; 145 | 146 | HE* he = new HE(num_party, io, &pool, party, modulus, 1, true, false, true); 147 | he->multiplication_keygen(); 148 | 149 | std::cout << "p = " << he->cc->GetCryptoParameters()->GetPlaintextModulus() << std::endl; 150 | std::cout << "n = " << he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2 << std::endl; 151 | std::cout << "log2 q = " << log2(he->cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble()) 152 | << std::endl; 153 | 154 | auto start = clock_start(); 155 | ArithmeticCirc* circ = new ArithmeticCirc(num_party, party, io, he); 156 | double timeused = time_from(start); 157 | std::cout << "Arith triple gen\t" << timeused / (1000 * circ->num_triples_pool) << " ms" << std::endl; 158 | test_mul(io, he, circ, 
circ->num_triples_pool); 159 | // delete he; 160 | delete io; 161 | } 162 | -------------------------------------------------------------------------------- /emp-aby/triple-providers/mp-bit-triple.hpp: -------------------------------------------------------------------------------- 1 | template 2 | MPBitTripleProvider::MPBitTripleProvider(int num_party, int party, ThreadPool* pool, MPIOChannel* io, 3 | const int buffer_length) { 4 | this->party = party; 5 | this->threads = pool->size(); 6 | if (this->threads % 2) 7 | error("MPBitTripleProvider needs even number of threads!"); 8 | this->io = io; 9 | this->num_party = num_party; 10 | this->BUFFER_SZ = buffer_length; 11 | this->pool = pool; 12 | a_bool.resize(BUFFER_SZ); 13 | b_bool.resize(BUFFER_SZ); 14 | c_bool.resize(BUFFER_SZ); 15 | sent = new bool[num_party]; 16 | received = new bool[num_party]; 17 | mac = (block**)(malloc((threads / 2) * sizeof(block*))); 18 | key = (block**)(malloc((threads / 2) * sizeof(block*))); 19 | key_star = (block**)(malloc((threads / 2) * sizeof(block*))); 20 | xor_mac = (block**)(malloc((threads / 2) * sizeof(block*))); 21 | xor_key = (block**)(malloc((threads / 2) * sizeof(block*))); 22 | for (int i = 0; i < threads / 2; ++i) { 23 | mac[i] = new block[emp::ferret_b13.n]; 24 | key[i] = new block[emp::ferret_b13.n]; 25 | xor_mac[i] = new block[BUFFER_SZ]; 26 | xor_key[i] = new block[BUFFER_SZ]; 27 | key_star[i] = new block[BUFFER_SZ]; 28 | } 29 | delta = gen_delta(); 30 | cot_sender.resize((num_party) * sizeof(FerretCOT*)); 31 | cot_receiver.resize((num_party) * sizeof(FerretCOT*)); 32 | // Create the semi-honest Ferret instances. 
33 | for (int i = 0; i < num_party; ++i) { 34 | if (party != i + 1) { 35 | cot_sender[i] = new FerretCOT(ALICE, 1, &(io->get(i + 1, party > (i + 1))), false, false); 36 | cot_receiver[i] = new FerretCOT(BOB, 1, &(io->get(i + 1, party < (i + 1))), false, false); 37 | } 38 | } 39 | seed_gen(); 40 | int l = ferret_b13.log_bin_sz * ferret_b13.t + ferret_b13.k + 128; 41 | bool* pre_choice = new bool[l]; 42 | prg.random_bool(pre_choice, l); 43 | vector> res; 44 | for (int i = 0; i < num_party; ++i) 45 | if (party < (i + 1)) { 46 | res.push_back(pool->enqueue([this, io, i, pre_choice]() { 47 | cot_sender[i]->setup(this->delta, 48 | "./data/" + std::to_string(this->party) + "to" + std::to_string(i + 1) + ".txt", 49 | pre_choice, seed); 50 | io->flush(i + 1); 51 | })); 52 | res.push_back(pool->enqueue([this, io, i, pre_choice]() { 53 | cot_receiver[i]->setup( 54 | "./data/" + std::to_string(this->party) + "from" + std::to_string(i + 1) + ".txt", pre_choice, 55 | seed); 56 | io->flush(i + 1); 57 | })); 58 | } 59 | else if (party > (i + 1)) { 60 | res.push_back(pool->enqueue([this, io, i, pre_choice]() { 61 | cot_receiver[i]->setup( 62 | "./data/" + std::to_string(this->party) + "from" + std::to_string(i + 1) + ".txt", pre_choice, 63 | seed); 64 | io->flush(i + 1); 65 | })); 66 | res.push_back(pool->enqueue([this, io, i, pre_choice]() { 67 | cot_sender[i]->setup(this->delta, 68 | "./data/" + std::to_string(this->party) + "to" + std::to_string(i + 1) + ".txt", 69 | pre_choice, seed); 70 | io->flush(i + 1); 71 | })); 72 | } 73 | 74 | for (auto& v : res) 75 | v.get(); 76 | res.clear(); 77 | 78 | delete[] pre_choice; 79 | } 80 | 81 | template 82 | MPBitTripleProvider::~MPBitTripleProvider() { 83 | delete io; 84 | delete pool; 85 | delete[] mac; 86 | delete[] key; 87 | delete[] key_star; 88 | delete[] xor_mac; 89 | delete[] xor_key; 90 | delete[] sent; 91 | delete[] received; 92 | for (int i = 0; i < num_party; ++i) { 93 | if (i + 1 == party) 94 | continue; 95 | delete 
cot_sender[i]; 96 | delete cot_receiver[i]; 97 | } 98 | } 99 | 100 | template 101 | void MPBitTripleProvider::get_triple(block* a, block* b, block* c) { 102 | this->get_triple((bool*)a_bool.data(), (bool*)b_bool.data(), (bool*)c_bool.data()); 103 | bool_to_block_arr(a, (bool*)a_bool.data(), BUFFER_SZ); 104 | bool_to_block_arr(b, (bool*)b_bool.data(), BUFFER_SZ); 105 | bool_to_block_arr(c, (bool*)c_bool.data(), BUFFER_SZ); 106 | } 107 | 108 | template 109 | void MPBitTripleProvider::get_triple(bool* a, bool* b, bool* c) { 110 | prg.random_bool(a, BUFFER_SZ); 111 | memset(sent, 0, num_party * sizeof(bool)); 112 | memset(received, 0, num_party * sizeof(bool)); 113 | ch[0] = zero_block; 114 | ch[1] = all_one_block; 115 | block **w, **s; 116 | w = (block**)malloc((threads / 2) * sizeof(block*)); 117 | s = (block**)malloc(threads * sizeof(block*)); 118 | for (int i = 0; i < threads / 2; ++i) { 119 | w[i] = new block[BUFFER_SZ]; 120 | memset(w[i], 0, BUFFER_SZ * sizeof(block)); 121 | memset(xor_mac[i], 0, BUFFER_SZ * sizeof(block)); 122 | memset(xor_key[i], 0, BUFFER_SZ * sizeof(block)); 123 | } 124 | for (int i = 0; i < threads; ++i) { 125 | s[i] = new block[BUFFER_SZ]; 126 | } 127 | seed_gen(); 128 | int num_steps = ceil((double)(num_party - 1) / ((double)threads / 2)); 129 | vector> res; 130 | for (int i = 0; i < threads / 2; ++i) { 131 | res.push_back(pool->enqueue([this, i, num_steps, a, s] { 132 | for (int step = 1; step <= num_steps; ++step) { 133 | int send_to = ((party - 1) + step + num_steps * i) % num_party; 134 | send(send_to, a, s[i], i); 135 | } 136 | })); 137 | } 138 | for (int i = 0; i < threads / 2; ++i) { 139 | res.push_back(pool->enqueue([this, i, num_steps, b, w, s] { 140 | for (int step = 1; step <= num_steps; ++step) { 141 | int receive_from = ((party - 1) + num_party - step - num_steps * i) % num_party; 142 | recv(receive_from, b, w[i], s[i + threads / 2], i); 143 | } 144 | })); 145 | } 146 | 147 | for (auto& v : res) 148 | v.get(); 149 | 
res.clear(); 150 | 151 | for (int i = 1; i < threads / 2; ++i) { 152 | xorBlocks_arr(w[0], w[0], w[i], BUFFER_SZ); 153 | } 154 | 155 | for (int i = 0; i < BUFFER_SZ; ++i) { 156 | w[0][i] = w[0][i] & ch[b[i]]; 157 | } 158 | 159 | for (int i = 0; i < threads / 2; ++i) { 160 | xorBlocks_arr(w[0], xor_mac[i], w[0], BUFFER_SZ); 161 | xorBlocks_arr(w[0], xor_key[i], w[0], BUFFER_SZ); 162 | } 163 | for (int i = 0; i < BUFFER_SZ; ++i) { 164 | c[i] = (a[i] & b[i]) ^ getLSB(w[0][i]); 165 | } 166 | 167 | free(s); 168 | free(w); 169 | b_set = false; 170 | } 171 | -------------------------------------------------------------------------------- /test/positive2.txt: -------------------------------------------------------------------------------- 1 | 293 421 2 | 64 64 1 3 | 4 | 2 1 60 124 128 XOR 5 | 2 1 59 123 129 AND 6 | 2 1 59 123 130 XOR 7 | 2 1 58 122 131 AND 8 | 2 1 58 122 132 XOR 9 | 2 1 57 121 133 AND 10 | 2 1 132 133 134 AND 11 | 2 1 131 134 135 XOR 12 | 2 1 57 121 136 XOR 13 | 2 1 132 136 137 AND 14 | 2 1 56 120 138 AND 15 | 2 1 56 120 139 XOR 16 | 2 1 55 119 140 AND 17 | 2 1 139 140 141 AND 18 | 2 1 138 141 142 XOR 19 | 2 1 137 142 143 AND 20 | 2 1 135 143 144 XOR 21 | 2 1 130 144 145 AND 22 | 2 1 129 145 146 XOR 23 | 2 1 55 119 147 XOR 24 | 2 1 139 147 148 AND 25 | 2 1 137 148 149 AND 26 | 2 1 130 149 150 AND 27 | 2 1 54 118 151 AND 28 | 2 1 54 118 152 XOR 29 | 2 1 53 117 153 AND 30 | 2 1 152 153 154 AND 31 | 2 1 151 154 155 XOR 32 | 2 1 53 117 156 XOR 33 | 2 1 152 156 157 AND 34 | 2 1 52 116 158 AND 35 | 2 1 52 116 159 XOR 36 | 2 1 51 115 160 AND 37 | 2 1 159 160 161 AND 38 | 2 1 158 161 162 XOR 39 | 2 1 157 162 163 AND 40 | 2 1 155 163 164 XOR 41 | 2 1 51 115 165 XOR 42 | 2 1 159 165 166 AND 43 | 2 1 157 166 167 AND 44 | 2 1 50 114 168 AND 45 | 2 1 50 114 169 XOR 46 | 2 1 49 113 170 AND 47 | 2 1 169 170 171 AND 48 | 2 1 168 171 172 XOR 49 | 2 1 49 113 173 XOR 50 | 2 1 169 173 174 AND 51 | 2 1 48 112 175 AND 52 | 2 1 48 112 176 XOR 53 | 2 1 47 111 177 AND 54 | 2 1 176 
177 178 AND 55 | 2 1 175 178 179 XOR 56 | 2 1 174 179 180 AND 57 | 2 1 172 180 181 XOR 58 | 2 1 167 181 182 AND 59 | 2 1 164 182 183 XOR 60 | 2 1 150 183 184 AND 61 | 2 1 146 184 185 XOR 62 | 2 1 47 111 186 XOR 63 | 2 1 176 186 187 AND 64 | 2 1 174 187 188 AND 65 | 2 1 167 188 189 AND 66 | 2 1 150 189 190 AND 67 | 2 1 46 110 191 AND 68 | 2 1 46 110 192 XOR 69 | 2 1 45 109 193 AND 70 | 2 1 192 193 194 AND 71 | 2 1 191 194 195 XOR 72 | 2 1 45 109 196 XOR 73 | 2 1 192 196 197 AND 74 | 2 1 44 108 198 AND 75 | 2 1 44 108 199 XOR 76 | 2 1 43 107 200 AND 77 | 2 1 199 200 201 AND 78 | 2 1 198 201 202 XOR 79 | 2 1 197 202 203 AND 80 | 2 1 195 203 204 XOR 81 | 2 1 43 107 205 XOR 82 | 2 1 199 205 206 AND 83 | 2 1 197 206 207 AND 84 | 2 1 42 106 208 AND 85 | 2 1 42 106 209 XOR 86 | 2 1 41 105 210 AND 87 | 2 1 209 210 211 AND 88 | 2 1 208 211 212 XOR 89 | 2 1 41 105 213 XOR 90 | 2 1 209 213 214 AND 91 | 2 1 40 104 215 AND 92 | 2 1 40 104 216 XOR 93 | 2 1 39 103 217 AND 94 | 2 1 216 217 218 AND 95 | 2 1 215 218 219 XOR 96 | 2 1 214 219 220 AND 97 | 2 1 212 220 221 XOR 98 | 2 1 207 221 222 AND 99 | 2 1 204 222 223 XOR 100 | 2 1 39 103 224 XOR 101 | 2 1 216 224 225 AND 102 | 2 1 214 225 226 AND 103 | 2 1 207 226 227 AND 104 | 2 1 38 102 228 AND 105 | 2 1 38 102 229 XOR 106 | 2 1 37 101 230 AND 107 | 2 1 229 230 231 AND 108 | 2 1 228 231 232 XOR 109 | 2 1 37 101 233 XOR 110 | 2 1 229 233 234 AND 111 | 2 1 36 100 235 AND 112 | 2 1 36 100 236 XOR 113 | 2 1 35 99 237 AND 114 | 2 1 236 237 238 AND 115 | 2 1 235 238 239 XOR 116 | 2 1 234 239 240 AND 117 | 2 1 232 240 241 XOR 118 | 2 1 35 99 242 XOR 119 | 2 1 236 242 243 AND 120 | 2 1 234 243 244 AND 121 | 2 1 34 98 245 AND 122 | 2 1 34 98 246 XOR 123 | 2 1 33 97 247 AND 124 | 2 1 246 247 248 AND 125 | 2 1 245 248 249 XOR 126 | 2 1 33 97 250 XOR 127 | 2 1 246 250 251 AND 128 | 2 1 32 96 252 AND 129 | 2 1 32 96 253 XOR 130 | 2 1 31 95 254 AND 131 | 2 1 253 254 255 AND 132 | 2 1 252 255 256 XOR 133 | 2 1 251 256 257 AND 134 | 2 1 249 257 
258 XOR 135 | 2 1 244 258 259 AND 136 | 2 1 241 259 260 XOR 137 | 2 1 227 260 261 AND 138 | 2 1 223 261 262 XOR 139 | 2 1 190 262 263 AND 140 | 2 1 185 263 264 XOR 141 | 2 1 31 95 265 XOR 142 | 2 1 253 265 266 AND 143 | 2 1 251 266 267 AND 144 | 2 1 244 267 268 AND 145 | 2 1 227 268 269 AND 146 | 2 1 190 269 270 AND 147 | 2 1 30 94 271 AND 148 | 2 1 30 94 272 XOR 149 | 2 1 29 93 273 AND 150 | 2 1 272 273 274 AND 151 | 2 1 271 274 275 XOR 152 | 2 1 29 93 276 XOR 153 | 2 1 272 276 277 AND 154 | 2 1 28 92 278 AND 155 | 2 1 28 92 279 XOR 156 | 2 1 27 91 280 AND 157 | 2 1 279 280 281 AND 158 | 2 1 278 281 282 XOR 159 | 2 1 277 282 283 AND 160 | 2 1 275 283 284 XOR 161 | 2 1 27 91 285 XOR 162 | 2 1 279 285 286 AND 163 | 2 1 277 286 287 AND 164 | 2 1 26 90 288 AND 165 | 2 1 26 90 289 XOR 166 | 2 1 25 89 290 AND 167 | 2 1 289 290 291 AND 168 | 2 1 288 291 292 XOR 169 | 2 1 25 89 293 XOR 170 | 2 1 289 293 294 AND 171 | 2 1 24 88 295 AND 172 | 2 1 24 88 296 XOR 173 | 2 1 23 87 297 AND 174 | 2 1 296 297 298 AND 175 | 2 1 295 298 299 XOR 176 | 2 1 294 299 300 AND 177 | 2 1 292 300 301 XOR 178 | 2 1 287 301 302 AND 179 | 2 1 284 302 303 XOR 180 | 2 1 23 87 304 XOR 181 | 2 1 296 304 305 AND 182 | 2 1 294 305 306 AND 183 | 2 1 287 306 307 AND 184 | 2 1 22 86 308 AND 185 | 2 1 22 86 309 XOR 186 | 2 1 21 85 310 AND 187 | 2 1 309 310 311 AND 188 | 2 1 308 311 312 XOR 189 | 2 1 21 85 313 XOR 190 | 2 1 309 313 314 AND 191 | 2 1 20 84 315 AND 192 | 2 1 20 84 316 XOR 193 | 2 1 19 83 317 AND 194 | 2 1 316 317 318 AND 195 | 2 1 315 318 319 XOR 196 | 2 1 314 319 320 AND 197 | 2 1 312 320 321 XOR 198 | 2 1 19 83 322 XOR 199 | 2 1 316 322 323 AND 200 | 2 1 314 323 324 AND 201 | 2 1 18 82 325 AND 202 | 2 1 18 82 326 XOR 203 | 2 1 17 81 327 AND 204 | 2 1 326 327 328 AND 205 | 2 1 325 328 329 XOR 206 | 2 1 17 81 330 XOR 207 | 2 1 326 330 331 AND 208 | 2 1 16 80 332 AND 209 | 2 1 16 80 333 XOR 210 | 2 1 15 79 334 AND 211 | 2 1 333 334 335 AND 212 | 2 1 332 335 336 XOR 213 | 2 1 331 336 337 AND 
214 | 2 1 329 337 338 XOR 215 | 2 1 324 338 339 AND 216 | 2 1 321 339 340 XOR 217 | 2 1 307 340 341 AND 218 | 2 1 303 341 342 XOR 219 | 2 1 15 79 343 XOR 220 | 2 1 333 343 344 AND 221 | 2 1 331 344 345 AND 222 | 2 1 324 345 346 AND 223 | 2 1 307 346 347 AND 224 | 2 1 14 78 348 AND 225 | 2 1 14 78 349 XOR 226 | 2 1 13 77 350 AND 227 | 2 1 349 350 351 AND 228 | 2 1 348 351 352 XOR 229 | 2 1 13 77 353 XOR 230 | 2 1 349 353 354 AND 231 | 2 1 12 76 355 AND 232 | 2 1 12 76 356 XOR 233 | 2 1 11 75 357 AND 234 | 2 1 356 357 358 AND 235 | 2 1 355 358 359 XOR 236 | 2 1 354 359 360 AND 237 | 2 1 352 360 361 XOR 238 | 2 1 11 75 362 XOR 239 | 2 1 356 362 363 AND 240 | 2 1 354 363 364 AND 241 | 2 1 10 74 365 AND 242 | 2 1 10 74 366 XOR 243 | 2 1 9 73 367 AND 244 | 2 1 366 367 368 AND 245 | 2 1 365 368 369 XOR 246 | 2 1 9 73 370 XOR 247 | 2 1 366 370 371 AND 248 | 2 1 8 72 372 AND 249 | 2 1 8 72 373 XOR 250 | 2 1 7 71 374 AND 251 | 2 1 373 374 375 AND 252 | 2 1 372 375 376 XOR 253 | 2 1 371 376 377 AND 254 | 2 1 369 377 378 XOR 255 | 2 1 364 378 379 AND 256 | 2 1 361 379 380 XOR 257 | 2 1 7 71 381 XOR 258 | 2 1 373 381 382 AND 259 | 2 1 371 382 383 AND 260 | 2 1 364 383 384 AND 261 | 2 1 6 70 385 AND 262 | 2 1 6 70 386 XOR 263 | 2 1 5 69 387 AND 264 | 2 1 386 387 388 AND 265 | 2 1 385 388 389 XOR 266 | 2 1 5 69 390 XOR 267 | 2 1 386 390 391 AND 268 | 2 1 4 68 392 AND 269 | 2 1 4 68 393 XOR 270 | 2 1 3 67 394 AND 271 | 2 1 393 394 395 AND 272 | 2 1 392 395 396 XOR 273 | 2 1 391 396 397 AND 274 | 2 1 389 397 398 XOR 275 | 2 1 3 67 399 XOR 276 | 2 1 393 399 400 AND 277 | 2 1 391 400 401 AND 278 | 2 1 2 66 402 AND 279 | 2 1 2 66 403 XOR 280 | 2 1 1 65 404 AND 281 | 2 1 403 404 405 AND 282 | 2 1 402 405 406 XOR 283 | 2 1 1 65 407 XOR 284 | 2 1 403 407 408 AND 285 | 2 1 0 64 409 AND 286 | 2 1 408 409 410 AND 287 | 2 1 406 410 411 XOR 288 | 2 1 401 411 412 AND 289 | 2 1 398 412 413 XOR 290 | 2 1 384 413 414 AND 291 | 2 1 380 414 415 XOR 292 | 2 1 347 415 416 AND 293 | 2 1 342 416 417 
XOR 294 | 2 1 270 417 418 AND 295 | 2 1 264 418 419 XOR 296 | 2 1 128 419 420 XOR 297 | -------------------------------------------------------------------------------- /test/a2bconverter.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/converter/a2bconverter.h" 2 | #include "emp-aby/io/multi-io.hpp" 3 | using namespace emp; 4 | 5 | #include 6 | 7 | int party, port; 8 | 9 | const static int threads = 4; 10 | 11 | int num_party; 12 | 13 | template 14 | void check(MPIOChannel* io, bool* b, int64_t* a, int n, long long q, int l, string msg, bool mod = true) { 15 | int length = l * n; 16 | // io->sync(); 17 | if (party == ALICE) { 18 | bool* tmp_b = new bool[length]; 19 | int64_t* tmp_a = new int64_t[n]; 20 | 21 | for (int i = 2; i <= num_party; ++i) { 22 | io->recv_bool(i, tmp_b, length); 23 | xorBools_arr(b, b, tmp_b, length); 24 | 25 | io->recv_data(i, tmp_a, n * sizeof(int64_t)); 26 | io->flush(i); 27 | for (int j = 0; j < n; ++j) { 28 | a[j] = (a[j] + tmp_a[j]) % q; 29 | } 30 | } 31 | int64_t* check = new int64_t[n]; 32 | memset(check, 0, n * sizeof(int64_t)); 33 | // std::cout << l << " " << q << " " << n << std::endl; 34 | for (int i = 0; i < l; ++i) { 35 | for (int j = 0; j < n; ++j) { 36 | if (b[j * l + i]) { 37 | check[j] += (1L << i); 38 | } 39 | if (mod) 40 | check[j] %= q; 41 | } 42 | } 43 | for (int i = 0; i < n; ++i) 44 | if (check[i] != a[i]) { 45 | std::cout << i << " " << check[i] << " " << a[i] << std::endl; 46 | std::cout << msg << " "; 47 | error("Test failed!"); 48 | } 49 | 50 | // std::cout << msg << " Test passed" << std::endl; 51 | 52 | delete[] tmp_a; 53 | delete[] tmp_b; 54 | delete[] check; 55 | } 56 | else { 57 | io->send_bool(ALICE, b, length); 58 | io->send_data(ALICE, a, n * sizeof(int64_t)); 59 | io->flush(ALICE); 60 | } 61 | } 62 | 63 | template 64 | void test_a2b(MPIOChannel* io, A2BConverter* converter, HE* he, size_t num = 1, double comm_offset = 0) { 65 | PRG prg; 66 | 
const long long int q = he->q; 67 | const int l = ceil(log2(q)); 68 | int64_t* a = new int64_t[num]; 69 | prg.random_data(a, num * sizeof(int64_t)); 70 | for (int i = 0; i < num; ++i) { 71 | a[i] = ((a[i] % he->q) + he->q) % he->q; 72 | } 73 | 74 | bool* b = new bool[num * l]; 75 | auto start = clock_start(); 76 | converter->convert(b, a, num); 77 | double timeused = time_from(start); 78 | double online_comm = io->get_total_bytes_sent() - comm_offset; 79 | check(io, b, a, num, q, l, "A2B"); 80 | std::cout << party << " A2B online time conversion per 32-bit: " << timeused / (num * 1000) << " ms\t" << std::endl; 81 | std::cout << party << " A2B online comm conversion per 32-bit: " << online_comm / (num) << " KB\t" << std::endl; 82 | } 83 | 84 | template 85 | void test_rand_ab_shares(MPIOChannel* io, A2BConverter* converter, HE* he, size_t length = 1) { 86 | const int l = ceil(log2(he->q)); 87 | bool* b = new bool[length * l]; 88 | int64_t* a = new int64_t[length]; 89 | 90 | int waste = converter->rand_ab_shares(a, b, length); 91 | check(io, b, a, length - waste, he->q, l, "Random shares gen", false); 92 | } 93 | 94 | int main(int argc, char** argv) { 95 | if (argc < 4) { 96 | std::cout << "Format: a2bconverter PartyID port num_parties" << std::endl; 97 | exit(0); 98 | } 99 | parse_party_and_port(argv, &party, &port); 100 | num_party = atoi(argv[3]); 101 | 102 | std::vector> net_config; 103 | 104 | if (argc == 5) { 105 | const char* file = argv[4]; 106 | FILE* f = fopen(file, "r"); 107 | for (int i = 0; i < num_party; ++i) { 108 | char* c = (char*)malloc(15 * sizeof(char)); 109 | uint p; 110 | fscanf(f, "%s %d\n", c, &p); 111 | std::string s(c); 112 | net_config.push_back(std::make_pair(s, p)); 113 | fflush(f); 114 | } 115 | fclose(f); 116 | } 117 | else { 118 | for (int i = 0; i < num_party; ++i) { 119 | std::string s = "127.0.0.1"; 120 | uint p = (port + 4 * num_party * i); 121 | net_config.push_back(std::make_pair(s, p)); 122 | } 123 | } 124 | 125 | ThreadPool 
pool(threads); 126 | 127 | MultiIO* io = new MultiIO(party, num_party, net_config); 128 | // io->extra_io(); 129 | io->setup_ot_ios(); 130 | std::cout << "io setup" << std::endl; 131 | 132 | size_t BUFFER_SZ = ((ferret_b13.n - ferret_b13.k - ferret_b13.t * ferret_b13.log_bin_sz_pre - 128) / 128) * 128; 133 | auto start = clock_start(); 134 | MPSIMDCircExec* simd_circ = new MPSIMDCircExec(num_party, party, &pool, io); 135 | double timeused = time_from(start); 136 | double triple_comm = io->get_total_bytes_sent(); 137 | // std::cout << party << "\tTriple Generation time\t" << timeused / (2 * BUFFER_SZ * 1000) << " ms\t" << std::endl; 138 | // std::cout << party << "\tTriple Generation comm\t" << triple_comm / (2 * BUFFER_SZ * 1000) << " KB\t" << std::endl; 139 | double per_triple_time = timeused / (2 * BUFFER_SZ * 1000); 140 | double per_triple_comm = triple_comm / (2 * BUFFER_SZ); 141 | const long long int modulus = (1L << 32) - (1L << 30) + 1; 142 | HE* he = new HE(num_party, io, &pool, party, modulus); 143 | he->multiplication_keygen(); 144 | he->rotation_keygen(); 145 | std::cout << "p = " << he->cc->GetCryptoParameters()->GetPlaintextModulus() << std::endl; 146 | std::cout << "n = " << he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2 << std::endl; 147 | std::cout << "log2 q = " << log2(he->cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble()) 148 | << std::endl; 149 | int pool_size = std::max(20 / num_party - 1, 0) * num_party + num_party; 150 | if (num_party > 16) 151 | pool_size = 16; 152 | std::cout << "pool size " << pool_size << std::endl; 153 | start = clock_start(); 154 | A2BConverter* converter = 155 | new A2BConverter(num_party, party, io, &pool, he, simd_circ, pool_size); 156 | timeused = time_from(start); 157 | double offline_comm = io->get_total_bytes_sent(); 158 | int offline_pool = (converter->ab_share_pool - converter->num_rejected); 159 | std::cout << party << "\tA2B offline\t" << timeused / 
(offline_pool * 1000) + (per_triple_time * 325) << " ms\t" 160 | << std::endl; 161 | std::cout << party << "\tA2B offline comm\t" 162 | << ((offline_comm - triple_comm) / (offline_pool)) + (per_triple_comm * 325) << " KB\t" << std::endl; 163 | 164 | int online_pool = min(converter->ab_share_pool - converter->num_rejected, BUFFER_SZ / 325); 165 | test_a2b(io, converter, he, online_pool, offline_comm); 166 | delete he; 167 | delete io; 168 | } 169 | -------------------------------------------------------------------------------- /test/adder64.txt: -------------------------------------------------------------------------------- 1 | 376 504 2 | 64 64 64 3 | 4 | 2 1 63 127 376 XOR 5 | 2 1 62 126 375 XOR 6 | 2 1 61 125 374 XOR 7 | 2 1 60 124 373 XOR 8 | 2 1 59 123 372 XOR 9 | 2 1 58 122 371 XOR 10 | 2 1 57 121 370 XOR 11 | 2 1 56 120 369 XOR 12 | 2 1 55 119 368 XOR 13 | 2 1 54 118 367 XOR 14 | 2 1 53 117 366 XOR 15 | 2 1 52 116 365 XOR 16 | 2 1 51 115 364 XOR 17 | 2 1 50 114 363 XOR 18 | 2 1 49 113 362 XOR 19 | 2 1 48 112 361 XOR 20 | 2 1 47 111 360 XOR 21 | 2 1 46 110 359 XOR 22 | 2 1 45 109 358 XOR 23 | 2 1 44 108 357 XOR 24 | 2 1 43 107 356 XOR 25 | 2 1 42 106 355 XOR 26 | 2 1 41 105 354 XOR 27 | 2 1 40 104 353 XOR 28 | 2 1 39 103 352 XOR 29 | 2 1 38 102 351 XOR 30 | 2 1 37 101 350 XOR 31 | 2 1 36 100 349 XOR 32 | 2 1 35 99 348 XOR 33 | 2 1 34 98 347 XOR 34 | 2 1 33 97 346 XOR 35 | 2 1 32 96 345 XOR 36 | 2 1 31 95 344 XOR 37 | 2 1 30 94 343 XOR 38 | 2 1 29 93 342 XOR 39 | 2 1 28 92 341 XOR 40 | 2 1 27 91 340 XOR 41 | 2 1 26 90 339 XOR 42 | 2 1 25 89 338 XOR 43 | 2 1 24 88 337 XOR 44 | 2 1 23 87 336 XOR 45 | 2 1 22 86 335 XOR 46 | 2 1 21 85 334 XOR 47 | 2 1 20 84 333 XOR 48 | 2 1 19 83 332 XOR 49 | 2 1 18 82 331 XOR 50 | 2 1 17 81 330 XOR 51 | 2 1 16 80 329 XOR 52 | 2 1 15 79 328 XOR 53 | 2 1 14 78 327 XOR 54 | 2 1 13 77 326 XOR 55 | 2 1 12 76 325 XOR 56 | 2 1 11 75 324 XOR 57 | 2 1 10 74 323 XOR 58 | 2 1 9 73 322 XOR 59 | 2 1 8 72 321 XOR 60 | 2 1 7 71 320 XOR 61 | 2 
1 6 70 319 XOR 62 | 2 1 5 69 318 XOR 63 | 2 1 4 68 317 XOR 64 | 2 1 3 67 316 XOR 65 | 2 1 2 66 315 XOR 66 | 2 1 1 65 314 XOR 67 | 2 1 0 64 440 XOR 68 | 2 1 0 64 377 AND 69 | 2 1 65 377 129 XOR 70 | 2 1 1 377 128 XOR 71 | 2 1 128 129 130 AND 72 | 2 1 130 377 378 XOR 73 | 2 1 66 378 132 XOR 74 | 2 1 2 378 131 XOR 75 | 2 1 131 132 133 AND 76 | 2 1 133 378 379 XOR 77 | 2 1 67 379 135 XOR 78 | 2 1 3 379 134 XOR 79 | 2 1 134 135 136 AND 80 | 2 1 136 379 380 XOR 81 | 2 1 68 380 138 XOR 82 | 2 1 4 380 137 XOR 83 | 2 1 137 138 139 AND 84 | 2 1 139 380 381 XOR 85 | 2 1 69 381 141 XOR 86 | 2 1 5 381 140 XOR 87 | 2 1 140 141 142 AND 88 | 2 1 142 381 382 XOR 89 | 2 1 70 382 144 XOR 90 | 2 1 6 382 143 XOR 91 | 2 1 143 144 145 AND 92 | 2 1 145 382 383 XOR 93 | 2 1 71 383 147 XOR 94 | 2 1 7 383 146 XOR 95 | 2 1 146 147 148 AND 96 | 2 1 148 383 384 XOR 97 | 2 1 72 384 150 XOR 98 | 2 1 8 384 149 XOR 99 | 2 1 149 150 151 AND 100 | 2 1 151 384 385 XOR 101 | 2 1 73 385 153 XOR 102 | 2 1 9 385 152 XOR 103 | 2 1 152 153 154 AND 104 | 2 1 154 385 386 XOR 105 | 2 1 74 386 156 XOR 106 | 2 1 10 386 155 XOR 107 | 2 1 155 156 157 AND 108 | 2 1 157 386 387 XOR 109 | 2 1 75 387 159 XOR 110 | 2 1 11 387 158 XOR 111 | 2 1 158 159 160 AND 112 | 2 1 160 387 388 XOR 113 | 2 1 76 388 162 XOR 114 | 2 1 12 388 161 XOR 115 | 2 1 161 162 163 AND 116 | 2 1 163 388 389 XOR 117 | 2 1 77 389 165 XOR 118 | 2 1 13 389 164 XOR 119 | 2 1 164 165 166 AND 120 | 2 1 166 389 390 XOR 121 | 2 1 78 390 168 XOR 122 | 2 1 14 390 167 XOR 123 | 2 1 167 168 169 AND 124 | 2 1 169 390 391 XOR 125 | 2 1 79 391 171 XOR 126 | 2 1 15 391 170 XOR 127 | 2 1 170 171 172 AND 128 | 2 1 172 391 392 XOR 129 | 2 1 80 392 174 XOR 130 | 2 1 16 392 173 XOR 131 | 2 1 173 174 175 AND 132 | 2 1 175 392 393 XOR 133 | 2 1 81 393 177 XOR 134 | 2 1 17 393 176 XOR 135 | 2 1 176 177 178 AND 136 | 2 1 178 393 394 XOR 137 | 2 1 82 394 180 XOR 138 | 2 1 18 394 179 XOR 139 | 2 1 179 180 181 AND 140 | 2 1 181 394 395 XOR 141 | 2 1 83 395 183 XOR 142 | 2 1 
19 395 182 XOR 143 | 2 1 182 183 184 AND 144 | 2 1 184 395 396 XOR 145 | 2 1 84 396 186 XOR 146 | 2 1 20 396 185 XOR 147 | 2 1 185 186 187 AND 148 | 2 1 187 396 397 XOR 149 | 2 1 85 397 189 XOR 150 | 2 1 21 397 188 XOR 151 | 2 1 188 189 190 AND 152 | 2 1 190 397 398 XOR 153 | 2 1 86 398 192 XOR 154 | 2 1 22 398 191 XOR 155 | 2 1 191 192 193 AND 156 | 2 1 193 398 399 XOR 157 | 2 1 87 399 195 XOR 158 | 2 1 23 399 194 XOR 159 | 2 1 194 195 196 AND 160 | 2 1 196 399 400 XOR 161 | 2 1 88 400 198 XOR 162 | 2 1 24 400 197 XOR 163 | 2 1 197 198 199 AND 164 | 2 1 199 400 401 XOR 165 | 2 1 89 401 201 XOR 166 | 2 1 25 401 200 XOR 167 | 2 1 200 201 202 AND 168 | 2 1 202 401 402 XOR 169 | 2 1 90 402 204 XOR 170 | 2 1 26 402 203 XOR 171 | 2 1 203 204 205 AND 172 | 2 1 205 402 403 XOR 173 | 2 1 91 403 207 XOR 174 | 2 1 27 403 206 XOR 175 | 2 1 206 207 208 AND 176 | 2 1 208 403 404 XOR 177 | 2 1 341 404 468 XOR 178 | 2 1 92 404 210 XOR 179 | 2 1 28 404 209 XOR 180 | 2 1 209 210 211 AND 181 | 2 1 211 404 405 XOR 182 | 2 1 342 405 469 XOR 183 | 2 1 340 403 467 XOR 184 | 2 1 93 405 213 XOR 185 | 2 1 29 405 212 XOR 186 | 2 1 212 213 214 AND 187 | 2 1 214 405 406 XOR 188 | 2 1 343 406 470 XOR 189 | 2 1 339 402 466 XOR 190 | 2 1 94 406 216 XOR 191 | 2 1 30 406 215 XOR 192 | 2 1 215 216 217 AND 193 | 2 1 217 406 407 XOR 194 | 2 1 338 401 465 XOR 195 | 2 1 31 407 218 XOR 196 | 2 1 344 407 471 XOR 197 | 2 1 337 400 464 XOR 198 | 2 1 95 407 219 XOR 199 | 2 1 218 219 220 AND 200 | 2 1 220 407 408 XOR 201 | 2 1 345 408 472 XOR 202 | 2 1 336 399 463 XOR 203 | 2 1 96 408 222 XOR 204 | 2 1 32 408 221 XOR 205 | 2 1 221 222 223 AND 206 | 2 1 223 408 409 XOR 207 | 2 1 346 409 473 XOR 208 | 2 1 335 398 462 XOR 209 | 2 1 97 409 225 XOR 210 | 2 1 33 409 224 XOR 211 | 2 1 224 225 226 AND 212 | 2 1 226 409 410 XOR 213 | 2 1 347 410 474 XOR 214 | 2 1 334 397 461 XOR 215 | 2 1 98 410 228 XOR 216 | 2 1 34 410 227 XOR 217 | 2 1 227 228 229 AND 218 | 2 1 229 410 411 XOR 219 | 2 1 333 396 460 XOR 220 | 2 1 35 
411 230 XOR 221 | 2 1 348 411 475 XOR 222 | 2 1 332 395 459 XOR 223 | 2 1 99 411 231 XOR 224 | 2 1 230 231 232 AND 225 | 2 1 232 411 412 XOR 226 | 2 1 349 412 476 XOR 227 | 2 1 331 394 458 XOR 228 | 2 1 100 412 234 XOR 229 | 2 1 36 412 233 XOR 230 | 2 1 233 234 235 AND 231 | 2 1 235 412 413 XOR 232 | 2 1 350 413 477 XOR 233 | 2 1 330 393 457 XOR 234 | 2 1 101 413 237 XOR 235 | 2 1 37 413 236 XOR 236 | 2 1 236 237 238 AND 237 | 2 1 238 413 414 XOR 238 | 2 1 351 414 478 XOR 239 | 2 1 329 392 456 XOR 240 | 2 1 102 414 240 XOR 241 | 2 1 38 414 239 XOR 242 | 2 1 239 240 241 AND 243 | 2 1 241 414 415 XOR 244 | 2 1 328 391 455 XOR 245 | 2 1 39 415 242 XOR 246 | 2 1 352 415 479 XOR 247 | 2 1 327 390 454 XOR 248 | 2 1 103 415 243 XOR 249 | 2 1 242 243 244 AND 250 | 2 1 244 415 416 XOR 251 | 2 1 353 416 480 XOR 252 | 2 1 326 389 453 XOR 253 | 2 1 104 416 246 XOR 254 | 2 1 40 416 245 XOR 255 | 2 1 245 246 247 AND 256 | 2 1 247 416 417 XOR 257 | 2 1 354 417 481 XOR 258 | 2 1 325 388 452 XOR 259 | 2 1 105 417 249 XOR 260 | 2 1 41 417 248 XOR 261 | 2 1 248 249 250 AND 262 | 2 1 250 417 418 XOR 263 | 2 1 355 418 482 XOR 264 | 2 1 324 387 451 XOR 265 | 2 1 106 418 252 XOR 266 | 2 1 42 418 251 XOR 267 | 2 1 251 252 253 AND 268 | 2 1 253 418 419 XOR 269 | 2 1 323 386 450 XOR 270 | 2 1 43 419 254 XOR 271 | 2 1 356 419 483 XOR 272 | 2 1 322 385 449 XOR 273 | 2 1 107 419 255 XOR 274 | 2 1 254 255 256 AND 275 | 2 1 256 419 420 XOR 276 | 2 1 357 420 484 XOR 277 | 2 1 321 384 448 XOR 278 | 2 1 108 420 258 XOR 279 | 2 1 44 420 257 XOR 280 | 2 1 257 258 259 AND 281 | 2 1 259 420 421 XOR 282 | 2 1 358 421 485 XOR 283 | 2 1 320 383 447 XOR 284 | 2 1 109 421 261 XOR 285 | 2 1 45 421 260 XOR 286 | 2 1 260 261 262 AND 287 | 2 1 262 421 422 XOR 288 | 2 1 359 422 486 XOR 289 | 2 1 319 382 446 XOR 290 | 2 1 110 422 264 XOR 291 | 2 1 46 422 263 XOR 292 | 2 1 263 264 265 AND 293 | 2 1 265 422 423 XOR 294 | 2 1 318 381 445 XOR 295 | 2 1 47 423 266 XOR 296 | 2 1 360 423 487 XOR 297 | 2 1 317 380 444 
XOR 298 | 2 1 111 423 267 XOR 299 | 2 1 266 267 268 AND 300 | 2 1 268 423 424 XOR 301 | 2 1 361 424 488 XOR 302 | 2 1 316 379 443 XOR 303 | 2 1 112 424 270 XOR 304 | 2 1 48 424 269 XOR 305 | 2 1 269 270 271 AND 306 | 2 1 271 424 425 XOR 307 | 2 1 362 425 489 XOR 308 | 2 1 315 378 442 XOR 309 | 2 1 113 425 273 XOR 310 | 2 1 49 425 272 XOR 311 | 2 1 272 273 274 AND 312 | 2 1 274 425 426 XOR 313 | 2 1 363 426 490 XOR 314 | 2 1 314 377 441 XOR 315 | 2 1 114 426 276 XOR 316 | 2 1 50 426 275 XOR 317 | 2 1 275 276 277 AND 318 | 2 1 277 426 427 XOR 319 | 2 1 115 427 279 XOR 320 | 2 1 51 427 278 XOR 321 | 2 1 278 279 280 AND 322 | 2 1 280 427 428 XOR 323 | 2 1 116 428 282 XOR 324 | 2 1 52 428 281 XOR 325 | 2 1 281 282 283 AND 326 | 2 1 283 428 429 XOR 327 | 2 1 117 429 285 XOR 328 | 2 1 53 429 284 XOR 329 | 2 1 284 285 286 AND 330 | 2 1 286 429 430 XOR 331 | 2 1 118 430 288 XOR 332 | 2 1 54 430 287 XOR 333 | 2 1 287 288 289 AND 334 | 2 1 289 430 431 XOR 335 | 2 1 119 431 291 XOR 336 | 2 1 55 431 290 XOR 337 | 2 1 290 291 292 AND 338 | 2 1 292 431 432 XOR 339 | 2 1 120 432 294 XOR 340 | 2 1 56 432 293 XOR 341 | 2 1 293 294 295 AND 342 | 2 1 295 432 433 XOR 343 | 2 1 370 433 497 XOR 344 | 2 1 121 433 297 XOR 345 | 2 1 57 433 296 XOR 346 | 2 1 296 297 298 AND 347 | 2 1 298 433 434 XOR 348 | 2 1 371 434 498 XOR 349 | 2 1 369 432 496 XOR 350 | 2 1 122 434 300 XOR 351 | 2 1 58 434 299 XOR 352 | 2 1 299 300 301 AND 353 | 2 1 301 434 435 XOR 354 | 2 1 372 435 499 XOR 355 | 2 1 368 431 495 XOR 356 | 2 1 123 435 303 XOR 357 | 2 1 59 435 302 XOR 358 | 2 1 302 303 304 AND 359 | 2 1 304 435 436 XOR 360 | 2 1 367 430 494 XOR 361 | 2 1 60 436 305 XOR 362 | 2 1 373 436 500 XOR 363 | 2 1 366 429 493 XOR 364 | 2 1 124 436 306 XOR 365 | 2 1 305 306 307 AND 366 | 2 1 307 436 437 XOR 367 | 2 1 374 437 501 XOR 368 | 2 1 365 428 492 XOR 369 | 2 1 125 437 309 XOR 370 | 2 1 61 437 308 XOR 371 | 2 1 308 309 310 AND 372 | 2 1 310 437 438 XOR 373 | 2 1 375 438 502 XOR 374 | 2 1 364 427 491 XOR 375 | 2 
1 126 438 312 XOR 376 | 2 1 62 438 311 XOR 377 | 2 1 311 312 313 AND 378 | 2 1 313 438 439 XOR 379 | 2 1 376 439 503 XOR 380 | -------------------------------------------------------------------------------- /test/he.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/he_interface.hpp" 2 | #include "emp-aby/io/multi-io.hpp" 3 | using namespace emp; 4 | 5 | #include 6 | #include 7 | int party, port; 8 | 9 | const static int threads = 4; 10 | 11 | int num_party; 12 | 13 | template 14 | void test_decrypt(HE* he, MPIOChannel* io) { 15 | int n = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 16 | 17 | if (party == ALICE) { 18 | std::vector original(n); 19 | PRG prg; 20 | prg.random_data(original.data(), n * sizeof(int64_t)); 21 | for (int i = 0; i < n; ++i) { 22 | original[i] = ((original[i] % he->q) + he->q) % 2; 23 | } 24 | 25 | lbcrypto::Plaintext plaintext = he->cc->MakePackedPlaintext(original); 26 | 27 | auto ciphertext = he->cc->Encrypt(he->pk, plaintext); 28 | 29 | he->serialize_sendall(ciphertext); 30 | auto partial_dec = he->decrypt_partial({ciphertext}); 31 | 32 | std::vector> partialCiphertextVec; 33 | partialCiphertextVec.push_back(partial_dec[0]); 34 | for (int i = 2; i <= num_party; ++i) { 35 | he->deserialize_recv(partial_dec, i); 36 | partialCiphertextVec.push_back(partial_dec[0]); 37 | } 38 | 39 | lbcrypto::Plaintext new_ptxt; 40 | he->cc->MultipartyDecryptFusion(partialCiphertextVec, &new_ptxt); 41 | // new_ptxt->SetLength(plaintext->GetLength()); 42 | 43 | vector dec_vals = new_ptxt->GetPackedValue(); 44 | for (int i = 0; i < n; ++i) { 45 | if (dec_vals[i] != original[i]) 46 | if ((dec_vals[i] + he->q) % he->q != original[i]) { 47 | std::cout << (dec_vals[i] + he->q) % he->q << " " << original[i] << "\n"; 48 | error("Decryption not working!"); 49 | } 50 | } 51 | std::cout << "Decryption test passed!" 
<< std::endl; 52 | } 53 | else { 54 | lbcrypto::Ciphertext ciphertext; 55 | he->deserialize_recv(ciphertext, ALICE); 56 | auto partial_dec = he->decrypt_partial({ciphertext}); 57 | 58 | he->serialize_send(partial_dec, ALICE); 59 | } 60 | } 61 | 62 | template 63 | void test_rotation(HE* he, MPIOChannel* io) { 64 | int n = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 65 | if (party == ALICE) { 66 | std::vector original(n); 67 | PRG prg; 68 | prg.random_data(original.data(), n * sizeof(int64_t)); 69 | for (int i = 0; i < n; ++i) { 70 | original[i] = ((original[i] % he->q) + he->q) % 2; 71 | } 72 | 73 | lbcrypto::Plaintext plaintext = he->cc->MakePackedPlaintext(original); 74 | auto ciphertext = he->cc->Encrypt(he->pk, plaintext); 75 | auto permutedCiphertext = he->cc->EvalRotate(ciphertext, 1); 76 | std::cout << "send to all\n"; 77 | he->serialize_sendall(permutedCiphertext); 78 | 79 | int64_t* share = new int64_t[n]; 80 | std::vector> vec = {ciphertext}; 81 | he->enc_to_share(vec, share, n); 82 | 83 | for (int i = 2; i <= he->num_party; ++i) { 84 | std::vector tmp; 85 | tmp.resize(n); 86 | io->recv_data(i, (int64_t*)tmp.data(), tmp.size() * sizeof(int64_t)); 87 | for (int j = 0; j < n; ++j) { 88 | share[j] = (he->q + share[j] + tmp[j]) % he->q; 89 | } 90 | } 91 | std::cout << "original\t"; 92 | for (int i = 0; i < 5; ++i) { 93 | std::cout << original[i] << " "; 94 | } 95 | std::cout << "\n"; 96 | 97 | std::cout << "rotated\t"; 98 | for (int i = 0; i < 5; ++i) { 99 | std::cout << share[i] << " "; 100 | } 101 | std::cout << "\n"; 102 | std::cout << "Rotation test passed!" 
<< std::endl; 103 | } 104 | else { 105 | lbcrypto::Ciphertext ciphertext; 106 | he->deserialize_recv(ciphertext, ALICE); 107 | std::cout << "receive ciphertext\n"; 108 | int64_t* share = new int64_t[n]; 109 | std::vector> vec = {ciphertext}; 110 | he->enc_to_share(vec, share, n); 111 | io->send_data(ALICE, (int64_t*)share, n * sizeof(int64_t)); 112 | } 113 | } 114 | 115 | template 116 | void test_enc_to_share(HE* he, MPIOChannel* io, int n = 100) { 117 | PRG prg; 118 | int batch_size = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 119 | n = batch_size; 120 | int64_t* original = new int64_t[n]; 121 | prg.random_data(original, n * sizeof(int64_t)); 122 | for (int i = 0; i < n; ++i) { 123 | original[i] = ((original[i] % he->q) + he->q) % he->q; 124 | } 125 | 126 | std::vector> ciphertext; 127 | 128 | if (party == ALICE) { 129 | for (int i = 0; i < ceil((double)n / (double)batch_size); ++i) { 130 | vector tmp; 131 | if ((i + 1) * batch_size <= n) { 132 | tmp.resize(batch_size); 133 | memcpy(tmp.data(), original + i * batch_size, batch_size * sizeof(int64_t)); 134 | } 135 | else { 136 | tmp.resize(n % batch_size); 137 | memcpy(tmp.data(), original + i * batch_size, (n % batch_size) * sizeof(int64_t)); 138 | } 139 | 140 | auto plaintext = he->cc->MakePackedPlaintext(tmp); 141 | // std::cout << "Plaintext: " << plaintext << std::endl; 142 | ciphertext.push_back(he->cc->Encrypt(he->pk, plaintext)); 143 | } 144 | he->serialize_sendall(ciphertext); 145 | 146 | int64_t* share = new int64_t[n]; 147 | he->enc_to_share(ciphertext, share, n); 148 | 149 | for (int i = 2; i <= he->num_party; ++i) { 150 | std::vector tmp; 151 | tmp.resize(n); 152 | io->recv_data(i, (int64_t*)tmp.data(), tmp.size() * sizeof(int64_t)); 153 | for (int j = 0; j < n; ++j) { 154 | share[j] = (he->q + share[j] + tmp[j]) % he->q; 155 | } 156 | } 157 | 158 | for (int i = 0; i < n; ++i) { 159 | if (original[i] != share[i]) { 160 | std::cout << i << " " << original[i] << " " << 
share[i] << std::endl; 161 | error("enc_to_share Failed!"); 162 | } 163 | } 164 | std::cout << "enc_to_share test passed!" << std::endl; 165 | } 166 | else { 167 | he->deserialize_recv(ciphertext, ALICE); 168 | 169 | he->enc_to_share(ciphertext, original, n); 170 | io->send_data(ALICE, (int64_t*)original, n * sizeof(int64_t)); 171 | } 172 | } 173 | 174 | int main(int argc, char** argv) { 175 | if (argc < 4) { 176 | std::cout << "Format: test_mp_bit_triple PartyID Port num_parties" << std::endl; 177 | exit(0); 178 | } 179 | parse_party_and_port(argv, &party, &port); 180 | num_party = atoi(argv[3]); 181 | 182 | std::vector> net_config; 183 | 184 | if (argc == 5) { 185 | const char* file = argv[4]; 186 | FILE* f = fopen(file, "r"); 187 | for (int i = 0; i < num_party; ++i) { 188 | char* c = (char*)malloc(15 * sizeof(char)); 189 | uint p; 190 | fscanf(f, "%s %d\n", c, &p); 191 | std::string s(c); 192 | net_config.push_back(std::make_pair(s, p)); 193 | fflush(f); 194 | } 195 | fclose(f); 196 | } 197 | else { 198 | for (int i = 0; i < num_party; ++i) { 199 | std::string s = "127.0.0.1"; 200 | uint p = (port + 4 * num_party * i); 201 | net_config.push_back(std::make_pair(s, p)); 202 | } 203 | } 204 | 205 | MultiIO* io = new MultiIO(party, num_party, net_config); 206 | std::cout << party << " connected \n"; 207 | ThreadPool pool(threads); 208 | 209 | io->flush(); 210 | const long long int modulus = (1L << 32) - (1 << 30) + 1; 211 | HE* he = new HE(num_party, io, &pool, party, modulus); 212 | he->multiplication_keygen(); 213 | he->rotation_keygen(); 214 | 215 | std::cout << party << " p = " << he->cc->GetCryptoParameters()->GetPlaintextModulus() << std::endl; 216 | std::cout << party << " n = " << he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2 217 | << std::endl; 218 | std::cout << party 219 | << " log2 q = " << log2(he->cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble()) 220 | << std::endl; 221 | 222 | vector> res; 
223 | // uint batch_size = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 224 | test_decrypt(he, io); 225 | test_enc_to_share(he, io); 226 | test_rotation(he, io); 227 | delete io; 228 | } 229 | -------------------------------------------------------------------------------- /emp-aby/simd_interface/arithmetic-circ.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/he_interface.hpp" 4 | 5 | namespace emp { 6 | 7 | template 8 | class ArithmeticCirc { 9 | private: 10 | PRG prg; 11 | MPIOChannel* io; 12 | HE* he; 13 | int num_party, party; 14 | int64_t *triple_a, *triple_b, *triple_c; 15 | 16 | public: 17 | size_t num_triples_pool, num_triples = 0; 18 | void get_triples(int64_t* triple_a, int64_t* triple_b, int64_t* triple_c); 19 | ArithmeticCirc(int num_party, int party, MPIOChannel* io, HE* he); 20 | ~ArithmeticCirc(); 21 | void sum(int64_t* out, int64_t* in1, int64_t* in2, size_t length); 22 | void sub(int64_t* out, int64_t* in1, int64_t* in2, size_t length); 23 | void mult(int64_t* out, int64_t* in1, int64_t* in2, size_t length); 24 | }; 25 | 26 | template 27 | ArithmeticCirc::ArithmeticCirc(int num_party, int party, MPIOChannel* io, HE* he) { 28 | this->num_party = num_party; 29 | this->party = party; 30 | this->io = io; 31 | this->he = he; 32 | this->num_triples_pool = 20 * (he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2); 33 | 34 | triple_a = new int64_t[num_triples_pool]; 35 | triple_b = new int64_t[num_triples_pool]; 36 | triple_c = new int64_t[num_triples_pool]; 37 | 38 | this->get_triples(triple_a, triple_b, triple_c); 39 | } 40 | 41 | template 42 | ArithmeticCirc::~ArithmeticCirc() { 43 | delete[] triple_a; 44 | delete[] triple_b; 45 | delete[] triple_c; 46 | } 47 | 48 | template 49 | void ArithmeticCirc::sum(int64_t* out, int64_t* in1, int64_t* in2, size_t length) { 50 | for (int i = 0; i < length; ++i) 51 | out[i] = 
(in1[i] + in2[i]) % he->q; 52 | } 53 | 54 | template 55 | void ArithmeticCirc::sub(int64_t* out, int64_t* in1, int64_t* in2, size_t length) { 56 | for (int i = 0; i < length; ++i) 57 | out[i] = (in1[i] - in2[i]) % he->q; 58 | } 59 | 60 | template 61 | void ArithmeticCirc::get_triples(int64_t* triple_a, int64_t* triple_b, int64_t* triple_c) { 62 | int batch_size = (he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2); 63 | prg.random_data(triple_a, num_triples_pool * sizeof(int64_t)); 64 | prg.random_data(triple_b, num_triples_pool * sizeof(int64_t)); 65 | for (size_t i = 0; i < num_triples_pool; ++i) { 66 | triple_a[i] %= he->q; 67 | triple_b[i] %= he->q; 68 | triple_a[i] = (he->q + triple_a[i]) % he->q; 69 | triple_b[i] = (he->q + triple_b[i]) % he->q; 70 | } 71 | std::vector> a, b, c; 72 | for (int i = 0; i < num_triples_pool / batch_size; ++i) { 73 | lbcrypto::Plaintext p_a, p_b; 74 | vector tmp_a, tmp_b; 75 | tmp_a.resize(batch_size); 76 | tmp_b.resize(batch_size); 77 | memcpy(tmp_a.data(), triple_a + i * batch_size, batch_size * sizeof(int64_t)); 78 | memcpy(tmp_b.data(), triple_b + i * batch_size, batch_size * sizeof(int64_t)); 79 | 80 | p_a = he->cc->MakePackedPlaintext(tmp_a); 81 | p_b = he->cc->MakePackedPlaintext(tmp_b); 82 | 83 | a.push_back(he->cc->Encrypt(he->pk, p_a)); 84 | b.push_back(he->cc->Encrypt(he->pk, p_b)); 85 | } 86 | if (party == ALICE) { 87 | std::vector> tmp_a, tmp_b; 88 | for (int i = 2; i <= num_party; ++i) { 89 | he->deserialize_recv(tmp_a, i); 90 | he->deserialize_recv(tmp_b, i); 91 | for (int j = 0; j < num_triples_pool / batch_size; ++j) { 92 | he->cc->EvalAddInPlace(a[j], tmp_a[j]); 93 | he->cc->EvalAddInPlace(b[j], tmp_b[j]); 94 | he->cc->ModReduceInPlace(a[j]); 95 | he->cc->ModReduceInPlace(b[j]); 96 | } 97 | } 98 | 99 | for (int i = 0; i < num_triples_pool / batch_size; ++i) { 100 | auto tmp = he->cc->EvalMult(a[i], b[i]); 101 | c.push_back(he->cc->ModReduce(tmp)); 102 | } 103 | he->serialize_sendall(c); 
104 | } 105 | else { 106 | he->serialize_send(a, ALICE); 107 | he->serialize_send(b, ALICE); 108 | he->deserialize_recv(c, ALICE); 109 | } 110 | he->enc_to_share(c, triple_c, num_triples_pool); 111 | } 112 | 113 | template 114 | void ArithmeticCirc::mult(int64_t* out, int64_t* in1, int64_t* in2, size_t length) { 115 | bool delete_array = false; 116 | int64_t *a, *b, *c; 117 | // std::cout << "In mult \n"; 118 | if (length > num_triples_pool) { 119 | a = new int64_t[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 120 | b = new int64_t[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 121 | c = new int64_t[(length + num_triples_pool - 1) / num_triples_pool * num_triples_pool]; 122 | for (uint i = 0; i < (length + num_triples_pool - 1) / num_triples_pool; ++i) 123 | this->get_triples(a + i * num_triples_pool, b + i * num_triples_pool, c + i * num_triples_pool); 124 | size_t tocp = min((length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, num_triples); 125 | memcpy(triple_a, a + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 126 | tocp * sizeof(int64_t)); 127 | memcpy(triple_b, b + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 128 | tocp * sizeof(int64_t)); 129 | memcpy(triple_c, c + (length + num_triples_pool - 1) / num_triples_pool * num_triples_pool - length, 130 | tocp * sizeof(int64_t)); 131 | num_triples = 0; 132 | delete_array = true; 133 | } 134 | else if (length > num_triples_pool - num_triples) { 135 | a = new int64_t[length]; 136 | b = new int64_t[length]; 137 | c = new int64_t[length]; 138 | delete_array = true; 139 | memcpy(a, triple_a + num_triples, (num_triples_pool - num_triples) * sizeof(int64_t)); 140 | memcpy(c, triple_c + num_triples, (num_triples_pool - num_triples) * sizeof(int64_t)); 141 | memcpy(b, triple_b + num_triples, (num_triples_pool - num_triples) * sizeof(int64_t)); 142 | get_triples(triple_a, 
triple_b, triple_c); 143 | memcpy(a + num_triples_pool - num_triples, triple_a, 144 | (length - (num_triples_pool - num_triples)) * sizeof(int64_t)); 145 | memcpy(b + num_triples_pool - num_triples, triple_b, 146 | (length - (num_triples_pool - num_triples)) * sizeof(int64_t)); 147 | memcpy(c + num_triples_pool - num_triples, triple_c, 148 | (length - (num_triples_pool - num_triples)) * sizeof(int64_t)); 149 | num_triples = length - (num_triples_pool - num_triples); 150 | } 151 | else { 152 | a = triple_a + num_triples; 153 | b = triple_b + num_triples; 154 | c = triple_c + num_triples; 155 | num_triples += length; 156 | } 157 | 158 | int64_t *d = new int64_t[length], *e = new int64_t[length]; 159 | 160 | for (int i = 0; i < length; ++i) { 161 | d[i] = (he->q + in1[i] - a[i]) % he->q; 162 | e[i] = (he->q + in2[i] - b[i]) % he->q; 163 | c[i] = (he->q + c[i]) % he->q; 164 | } 165 | 166 | // io->sync(); 167 | if (party == ALICE) { 168 | int64_t *d0 = new int64_t[length], *e0 = new int64_t[length]; 169 | 170 | for (int i = 2; i <= num_party; ++i) { 171 | io->recv_data(i, d0, length * sizeof(int64_t)); 172 | io->recv_data(i, e0, length * sizeof(int64_t)); 173 | for (int j = 0; j < length; ++j) { 174 | d[j] = (d[j] + d0[j]) % he->q; 175 | e[j] = (e[j] + e0[j]) % he->q; 176 | } 177 | } 178 | 179 | for (int i = 2; i <= num_party; ++i) { 180 | io->send_data(i, d, length * sizeof(int64_t)); 181 | io->send_data(i, e, length * sizeof(int64_t)); 182 | io->flush(i); 183 | } 184 | 185 | delete[] d0; 186 | delete[] e0; 187 | } 188 | else { 189 | io->send_data(ALICE, d, length * sizeof(int64_t)); 190 | io->send_data(ALICE, e, length * sizeof(int64_t)); 191 | io->flush(ALICE); 192 | io->recv_data(ALICE, d, length * sizeof(int64_t)); 193 | io->recv_data(ALICE, e, length * sizeof(int64_t)); 194 | } 195 | io->flush(); 196 | 197 | for (uint i = 0; i < length; ++i) { 198 | long long int x = ((uint64_t)((uint64_t)e[i] * (uint64_t)a[i])) % he->q; 199 | long long int y = 
((uint64_t)((uint64_t)d[i] * (uint64_t)b[i])) % he->q; 200 | 201 | out[i] = (c[i] + x + y) % he->q; 202 | } 203 | if (party == ALICE) { 204 | for (uint i = 0; i < length; ++i) { 205 | long long int x = ((uint64_t)((uint64_t)d[i] * (uint64_t)e[i])) % he->q; 206 | out[i] = (out[i] + x + he->q) % he->q; 207 | } 208 | } 209 | delete[] d; 210 | delete[] e; 211 | if (delete_array) { 212 | delete[] a; 213 | delete[] b; 214 | delete[] c; 215 | } 216 | } 217 | } // namespace emp 218 | -------------------------------------------------------------------------------- /emp-aby/io/multi-io.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/io/mp_io_channel.h" 4 | #include "emp-aby/io/multi-io-base.hpp" 5 | #include 6 | 7 | namespace emp { 8 | class MultiIO : public MPIOChannel { 9 | private: 10 | /* data */ 11 | public: 12 | int party, num_party; 13 | int bind_port; 14 | string bind_address; 15 | std::vector> net_config; 16 | std::map ios; 17 | std::map ot_ios[2]; 18 | bool continue_comm; 19 | std::future background_recv_fut; 20 | MultiIO(int party, int num_party, std::vector>& net_config); 21 | ~MultiIO(); 22 | 23 | void setup_ot_ios(); 24 | void send_data(int dst, const void* data, int len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG); 25 | void recv_data(int src, void* data, int len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG); 26 | void* recv_data(int src, int& len, int j = 0, MESSAGE_TYPE msg_type = NORM_MSG); 27 | 28 | int get_total_bytes_sent(); 29 | 30 | void flush(int idx = 0, int j = 0) {} 31 | void sync() {} 32 | 33 | // optimise later to send aligned data 34 | void send_bool(int dst, bool* data, int length, int j = 0) { 35 | send_data(dst, data, length, j, NORM_MSG); 36 | } 37 | 38 | // optimise later to recv aligned data 39 | void recv_bool(int src, bool* data, int length, int j = 0) { 40 | recv_data(src, data, length, j, NORM_MSG); 41 | } 42 | 43 | void send_block(int dst, const block* 
data, int length, int j = 0) { 44 | send_data(dst, data, length * sizeof(block), j, NORM_MSG); 45 | } 46 | 47 | void recv_block(int src, block* data, int length, int j = 0) { 48 | recv_data(src, data, length * sizeof(block), j, NORM_MSG); 49 | } 50 | 51 | MultiIOBase*& get(size_t idx, bool b = false) { 52 | if (b) 53 | return ot_ios[0][idx]; 54 | else 55 | return ot_ios[1][idx]; 56 | } 57 | 58 | void background_recv(); 59 | }; 60 | 61 | MultiIO::MultiIO(int party, int num_party, std::vector>& net_config) 62 | : party(party), 63 | num_party(num_party), 64 | bind_port(net_config[party - 1].second), 65 | bind_address(net_config[party - 1].first), 66 | net_config(net_config) { 67 | std::map socket_map; 68 | accept_base_connections(num_party - party, net_config[party - 1].second, net_config[party - 1].first, socket_map, 69 | party, num_party); 70 | for (uint p = party - 1; p > 0; --p) { 71 | int consocket = request_base_connection(p, net_config, party, num_party); 72 | socket_map.emplace(std::pair(p, consocket)); 73 | } 74 | for (auto& sock : socket_map) { 75 | MultiIOBase* io = new MultiIOBase(sock.second, true); 76 | ios.emplace(std::pair(sock.first, io)); 77 | } 78 | socket_map.clear(); 79 | continue_comm = true; 80 | background_recv_fut = std::async([this]() { 81 | this->background_recv(); 82 | }); 83 | } 84 | 85 | void MultiIO::setup_ot_ios() { 86 | std::map socket_map; 87 | 88 | for (auto& conf : net_config) { 89 | conf.second++; 90 | } 91 | accept_base_connections(num_party - party, net_config[party - 1].second, net_config[party - 1].first, socket_map, 92 | party, num_party); 93 | for (uint p = party - 1; p > 0; --p) { 94 | int consocket = request_base_connection(p, net_config, party, num_party); 95 | socket_map.emplace(std::pair(p, consocket)); 96 | } 97 | for (auto& sock : socket_map) { 98 | MultiIOBase* io = new MultiIOBase(sock.second, true); 99 | ot_ios[0].emplace(std::pair(sock.first, io)); 100 | } 101 | socket_map.clear(); 102 | for (auto& conf : 
net_config) { 103 | conf.second++; 104 | } 105 | accept_base_connections(num_party - party, net_config[party - 1].second, net_config[party - 1].first, socket_map, 106 | party, num_party); 107 | for (uint p = party - 1; p > 0; --p) { 108 | int consocket = request_base_connection(p, net_config, party, num_party); 109 | socket_map.emplace(std::pair(p, consocket)); 110 | } 111 | for (auto& sock : socket_map) { 112 | MultiIOBase* io = new MultiIOBase(sock.second, true); 113 | ot_ios[1].emplace(std::pair(sock.first, io)); 114 | } 115 | socket_map.clear(); 116 | } 117 | 118 | MultiIO::~MultiIO() { 119 | for (auto& io : ios) { 120 | io.second->send_msg(nullptr, 0, TERMINATE_MSG); 121 | } 122 | bool c = false; 123 | for (auto& io : ios) { 124 | c |= io.second->continue_comm; 125 | } 126 | continue_comm = c; 127 | background_recv_fut.get(); 128 | for (auto& io : ios) { 129 | io.second->~MultiIOBase(); 130 | } 131 | for (auto& io : ot_ios[0]) { 132 | io.second->~MultiIOBase(); 133 | } 134 | for (auto& io : ot_ios[1]) { 135 | io.second->~MultiIOBase(); 136 | } 137 | net_config.clear(); 138 | ios.clear(); 139 | } 140 | 141 | void MultiIO::send_data(int dst, const void* data, int len, int j, MESSAGE_TYPE msg_type) { 142 | if (dst != 0 && dst != party) { 143 | MultiIOBase* io = ios[dst]; 144 | io->send_msg(data, len, msg_type); 145 | } 146 | else { 147 | error("sending to invalid party"); 148 | } 149 | } 150 | 151 | void MultiIO::recv_data(int src, void* data, int len, int j, MESSAGE_TYPE msg_type) { 152 | if (src != 0 && src != party) { 153 | MultiIOBase* io = ios[src]; 154 | bool received = false; 155 | while (!received) { 156 | std::unique_lock lock(io->recv_mutex[msg_type]); 157 | if (!io->recv_msg_queue[msg_type].empty()) { 158 | received = true; 159 | int recv_len = io->recv_msg_queue[msg_type].front().first; 160 | if (len != recv_len) { 161 | std::cout << "\n" << party << " " << src << " lengths " << len << " " << recv_len << "\n"; 162 | error("unequal length"); 163 | } 
164 | memcpy(data, io->recv_msg_queue[msg_type].front().second, len); 165 | io->recv_msg_queue[msg_type].pop_front(); 166 | lock.unlock(); 167 | return; 168 | } 169 | else { 170 | io->recv_condition_vars[msg_type].wait( 171 | lock, [io, msg_type] { return !io->recv_msg_queue[msg_type].empty(); }); 172 | } 173 | } 174 | } 175 | else { 176 | error("receive called for invalid party"); 177 | } 178 | } 179 | 180 | void* MultiIO::recv_data(int src, int& len, int j, MESSAGE_TYPE msg_type) { 181 | void* data = nullptr; 182 | if (src != 0 and src != party) { 183 | MultiIOBase* io = ios[src]; 184 | bool received = false; 185 | while (!received) { 186 | std::unique_lock lock(io->recv_mutex[msg_type]); 187 | if (!io->recv_msg_queue[msg_type].empty()) { 188 | received = true; 189 | int recv_len = io->recv_msg_queue[msg_type].front().first; 190 | len = recv_len; 191 | data = io->recv_msg_queue[msg_type].front().second; 192 | io->recv_msg_queue[msg_type].pop_front(); 193 | lock.unlock(); 194 | return data; 195 | } 196 | else { 197 | io->recv_condition_vars[msg_type].wait( 198 | lock, [io, msg_type] { return !io->recv_msg_queue[msg_type].empty(); }); 199 | } 200 | } 201 | } 202 | else { 203 | error("receive called for invalid party"); 204 | } 205 | return data; 206 | } 207 | 208 | int MultiIO::get_total_bytes_sent() { 209 | double kb = 0; 210 | for (auto& io : this->ios) { 211 | kb += (double)(io.second->counter) / 1000; 212 | } 213 | for (auto& io : this->ot_ios[0]) { 214 | kb += (double)(io.second->counter) / 1000; 215 | } 216 | for (auto& io : this->ot_ios[1]) { 217 | kb += (double)(io.second->counter) / 1000; 218 | } 219 | return kb; 220 | } 221 | 222 | void MultiIO::background_recv() { 223 | struct pollfd* pfds; 224 | pfds = (struct pollfd*)calloc(num_party - 1, sizeof(struct pollfd)); 225 | std::map socket_party; 226 | nfds_t i = 0; 227 | for (auto& io : ios) { 228 | pfds[i].fd = io.second->consocket; 229 | pfds[i].events = POLLIN; 230 | 
socket_party.emplace(io.second->consocket, io.first); 231 | ++i; 232 | } 233 | 234 | while (continue_comm) { 235 | int ready = 0; 236 | while (ready < 1) { 237 | if (!continue_comm) 238 | return; 239 | ready = poll(pfds, num_party - 1, -1); 240 | if (ready == -1) { 241 | error("error: poll"); 242 | } 243 | } 244 | bool c = true; 245 | for (nfds_t i = 0; i < num_party - 1; ++i) { 246 | if (pfds[i].revents != 0) { 247 | if (pfds[i].revents & POLLIN) { 248 | int p = socket_party[pfds[i].fd]; 249 | c = ios[p]->recv_msg(); 250 | } 251 | } 252 | } 253 | if (!c) { 254 | for (auto& io : ios) { 255 | c |= io.second->continue_comm; 256 | } 257 | continue_comm = c; 258 | } 259 | } 260 | } 261 | 262 | } // namespace emp 263 | -------------------------------------------------------------------------------- /emp-aby/converter/a2bconverter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/lut.h" 4 | #include "emp-aby/mp-circuit.hpp" 5 | #include "emp-aby/simd_interface/arithmetic-circ.h" 6 | 7 | namespace emp { 8 | 9 | template 10 | class A2BConverter { 11 | private: 12 | LUT* bit_to_a; 13 | MPIOChannel* io; 14 | int num_party, party; 15 | ThreadPool* pool; 16 | HE* he; 17 | PRG prg; 18 | Circuit>* circuit; 19 | ArithmeticCirc* arithmetic_circ; 20 | bool* b_share; 21 | int64_t* a_share; 22 | 23 | public: 24 | size_t ab_share_pool, num_used_shares = 0, num_rejected = 0; 25 | A2BConverter(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, 26 | MPSIMDCircExec* simd_circ, int pool_size = 20); 27 | 28 | void convert(bool* out, int64_t* in, size_t length); 29 | 30 | size_t rand_ab_shares(int64_t* a_share, bool* b_share, const size_t length); 31 | }; 32 | 33 | template 34 | A2BConverter::A2BConverter(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, 35 | MPSIMDCircExec* simd_circ, int pool_size) { 36 | int64_t table[2] = {0, 1}; 37 | this->pool = pool; 38 | this->io = io; 39 
| this->num_party = num_party; 40 | this->party = party; 41 | this->he = he; 42 | this->bit_to_a = new LUT(num_party, party, io, pool, he, table, pool_size); 43 | // std::cout << "lut setup done \n"; 44 | this->circuit = new Circuit>("emp-aby/modsum.txt", party, simd_circ); 45 | this->arithmetic_circ = new ArithmeticCirc(num_party, party, io, he); 46 | // std::cout << "arithmetic circ setup done \n"; 47 | int l = ceil(log2(he->q)); 48 | this->ab_share_pool = pool_size * (he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2) / l; 49 | this->a_share = new int64_t[ab_share_pool]; 50 | this->b_share = new bool[ab_share_pool * (int)ceil(log2(he->q))]; 51 | this->num_rejected = this->rand_ab_shares(this->a_share, this->b_share, this->ab_share_pool); 52 | // std::cout << "random ab shares setup done \n"; 53 | } 54 | 55 | template 56 | size_t A2BConverter::rand_ab_shares(int64_t* a_share, bool* b_share, const size_t length) { 57 | int64_t l = ceil(log2(he->q)); 58 | bool* r_b = new bool[length * l]; 59 | int64_t* r_a = new int64_t[length * l]; 60 | prg.random_bool(r_b, length * l); 61 | 62 | this->bit_to_a->lookup(r_a, r_b, length * l); 63 | // std::cout << "lookup done\n"; 64 | for (int i = 0; i < length * l; ++i) { 65 | r_a[i] = (r_a[i] % he->q + he->q) % he->q; 66 | } 67 | 68 | bool* bits = new bool[l]; 69 | int x = he->q - 1; 70 | for (int i = 0; i < l; ++i) { 71 | bits[i] = ((x & 1) == 1); 72 | x >>= 1; 73 | } 74 | int64_t* zero_sum = new int64_t[length]; 75 | memset(zero_sum, 0, length * sizeof(length)); 76 | for (int i = 0; i < l; ++i) { 77 | if (!bits[i]) { 78 | for (int j = 0; j < length; ++j) { 79 | zero_sum[j] = (zero_sum[j] + r_a[j * l + i]) % he->q; 80 | } 81 | } 82 | } 83 | int64_t* tmp = new int64_t[length]; 84 | for (int i = 0; i < l; ++i) { 85 | if (bits[i]) { 86 | for (int j = 0; j < length; ++j) { 87 | tmp[j] = r_a[j * l + i]; 88 | } 89 | this->arithmetic_circ->mult(zero_sum, zero_sum, tmp, length); 90 | } 91 | } 92 | if (party == 
ALICE) { 93 | for (int i = 2; i <= num_party; ++i) { 94 | io->recv_data(i, tmp, length * sizeof(int64_t)); 95 | for (int j = 0; j < length; ++j) 96 | zero_sum[j] = (zero_sum[j] + tmp[j]) % he->q; 97 | } 98 | 99 | for (int i = 2; i <= num_party; ++i) { 100 | io->send_data(i, zero_sum, length * sizeof(int64_t)); 101 | io->flush(i); 102 | } 103 | } 104 | else { 105 | io->send_data(ALICE, zero_sum, length * sizeof(int64_t)); 106 | io->flush(ALICE); 107 | 108 | io->recv_data(ALICE, zero_sum, length * sizeof(int64_t)); 109 | } 110 | memset(a_share, 0, length * sizeof(int64_t)); 111 | 112 | size_t waste = 0; 113 | uint64_t y = 0; 114 | for (int i = 0; i < length; ++i) { 115 | if (zero_sum[i] == 0) { 116 | for (int j = 0; j < l; ++j) { 117 | y = r_a[i * l + j]; 118 | y = y << j; 119 | y %= he->q; 120 | a_share[i - waste] = (a_share[i - waste] + y) % he->q; 121 | b_share[(i - waste) * l + j] = r_b[i * l + j]; 122 | } 123 | } 124 | else 125 | waste += 1; 126 | } 127 | 128 | delete[] tmp; 129 | return waste; 130 | } 131 | 132 | template 133 | void A2BConverter::convert(bool* out, int64_t* in, size_t length) { 134 | int64_t l = ceil(log2(he->q)); 135 | bool* r_b; 136 | int64_t* r_a; 137 | bool delete_array = false; 138 | 139 | if (length > ab_share_pool - num_used_shares - num_rejected) { 140 | r_a = new int64_t[length]; 141 | r_b = new bool[length * l]; 142 | memcpy(r_a, this->a_share + num_used_shares, 143 | (ab_share_pool - num_used_shares - num_rejected) * sizeof(int64_t)); 144 | memcpy(r_b, this->b_share + num_used_shares * l, (ab_share_pool - num_used_shares - num_rejected) * l); 145 | int accepted = ab_share_pool - num_used_shares - num_rejected; 146 | while (accepted < length) { 147 | num_rejected = this->rand_ab_shares(a_share, b_share, ab_share_pool); 148 | uint tocp = min(ab_share_pool - num_rejected, length - accepted); 149 | memcpy(r_a + accepted, this->a_share, tocp * sizeof(int64_t)); 150 | memcpy(r_b + accepted * l, this->b_share, tocp * l); 151 | accepted += 
ab_share_pool - num_rejected; 152 | } 153 | num_used_shares = ab_share_pool - (accepted - length) - num_rejected; 154 | delete_array = true; 155 | } 156 | else { 157 | r_a = a_share + num_used_shares; 158 | r_b = b_share + num_used_shares * l; 159 | num_used_shares += length; 160 | } 161 | // std::cout << "AB shares generated" << std::endl; 162 | bool* y_b = new bool[length * (circuit->n1 + circuit->n2)]; 163 | memset(y_b, 0, length * (circuit->n1 + circuit->n2)); // x-r (length * n), r (length * n) 164 | 165 | for (size_t i = 0; i < length; ++i) { 166 | for (int j = 0; j < l; ++j) { 167 | y_b[i * circuit->n2 + j + length * circuit->n1] = r_b[i * l + j]; 168 | } 169 | } 170 | 171 | for (size_t i = 0; i < length; ++i) { 172 | r_a[i] = (in[i] - r_a[i]) % he->q; 173 | r_a[i] = (he->q + r_a[i]) % he->q; 174 | } 175 | 176 | // std::cout << "Do in - r_a" << std::endl; 177 | if (party == ALICE) { 178 | if (pool->size() == 1) { 179 | int64_t* tmp = new int64_t[length]; 180 | for (int i = 2; i <= num_party; ++i) { 181 | io->recv_data(i, tmp, length * sizeof(int64_t)); 182 | io->flush(i); 183 | for (size_t j = 0; j < length; ++j) 184 | r_a[j] = (tmp[j] + r_a[j]) % he->q; 185 | } 186 | } 187 | else { 188 | int threads = pool->size(); 189 | int64_t* tmp[threads]; 190 | for (int i = 0; i < threads; ++i) { 191 | tmp[i] = new int64_t[length]; 192 | memset(tmp[i], 0, length * sizeof(int64_t)); 193 | } 194 | vector> res; 195 | int num_steps = ceil((double)(num_party) / (double)threads); 196 | for (int i = 0; i < threads; ++i) { 197 | res.push_back(pool->enqueue([this, i, num_steps, length, t = tmp[i]]() { 198 | int64_t* tmp_i = new int64_t[length]; 199 | 200 | for (int j = 0; j < num_steps; ++j) { 201 | if (i * num_steps + j + 1 > num_party) 202 | break; 203 | if (i * num_steps + j + 1 == party) 204 | continue; 205 | 206 | io->recv_data(i * num_steps + j + 1, tmp_i, length * sizeof(int64_t)); 207 | io->flush(i * num_steps + j + 1); 208 | for (size_t k = 0; k < length; ++k) 209 | 
t[k] = (t[k] + tmp_i[k]) % he->q; 210 | } 211 | delete[] tmp_i; 212 | })); 213 | } 214 | 215 | for (auto& v : res) 216 | v.get(); 217 | res.clear(); 218 | 219 | for (int i = 0; i < threads; ++i) { 220 | for (size_t j = 0; j < length; ++j) 221 | r_a[j] = (tmp[i][j] + r_a[j]) % he->q; 222 | delete[] tmp[i]; 223 | } 224 | } 225 | 226 | for (size_t i = 0; i < length; ++i) { 227 | int64_t x = r_a[i]; 228 | for (int j = 0; j < circuit->n1; ++j) { 229 | y_b[i * circuit->n1 + j] = ((x & 1) == 1); 230 | x >>= 1; 231 | } 232 | } 233 | 234 | // delete[] tmp; 235 | } 236 | else { 237 | this->io->send_data(ALICE, r_a, length * sizeof(int64_t)); 238 | io->flush(ALICE); 239 | } 240 | // std::cout << "Circuit to be computed \n"; 241 | // std::cout << length *circuit->n3 << std::endl; 242 | bool* tmp_out = new bool[length * circuit->n3]; 243 | circuit->template compute(tmp_out, y_b, length); 244 | // std::cout << "Circuit computed \n"; 245 | 246 | for (int i = 0; i < length; ++i) { 247 | for (int j = 0; j < l; ++j) { 248 | out[i * l + j] = tmp_out[i * circuit->n3 + j]; 249 | } 250 | } 251 | // std::cout << "Circuit computed \n"; 252 | 253 | delete[] y_b; 254 | delete[] tmp_out; 255 | if (delete_array) { 256 | delete[] r_a; 257 | delete[] r_b; 258 | } 259 | } 260 | 261 | } //namespace emp 262 | -------------------------------------------------------------------------------- /test/simd_exec.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/simd_interface/simd_exec.h" 2 | int party, port; 3 | 4 | const static int threads = 1; 5 | 6 | double test_and(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 7 | bool *A = new bool[length], *B = new bool[length], *C = new bool[length]; 8 | PRG prg; 9 | prg.random_bool(A, length); 10 | prg.random_bool(B, length); 11 | auto start = clock_start(); 12 | simd_circ->and_gate(C, A, B, length); 13 | long long t = time_from(start); 14 | io->flush(); 15 | if (party == ALICE) { 16 | 
io->send_bool(A, length); 17 | io->send_bool(B, length); 18 | io->send_bool(C, length); 19 | } 20 | else if (party == BOB) { 21 | bool *A0 = new bool[length], *B0 = new bool[length], *C0 = new bool[length]; 22 | io->recv_bool(A0, length); 23 | io->recv_bool(B0, length); 24 | io->recv_bool(C0, length); 25 | for (int i = 0; i < length; ++i) { 26 | if ((C0[i] ^ C[i]) != ((A0[i] ^ A[i]) & (B0[i] ^ B[i]))) { 27 | error(" Bool AND Failed!"); 28 | }; 29 | } 30 | delete[] A0; 31 | delete[] B0; 32 | delete[] C0; 33 | } 34 | io->sync(); 35 | delete[] A; 36 | delete[] B; 37 | delete[] C; 38 | return t; 39 | // std::cout << "Bool AND passed \n"; 40 | } 41 | 42 | double test_block_and(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 43 | block *A = new block[length], *B = new block[length], *C = new block[length]; 44 | PRG prg; 45 | prg.random_block(A, length); 46 | prg.random_block(B, length); 47 | 48 | auto start = clock_start(); 49 | simd_circ->and_gate(C, A, B, length); 50 | long long t = time_from(start); 51 | io->flush(); 52 | if (party == ALICE) { 53 | io->send_block(A, length); 54 | io->send_block(B, length); 55 | io->send_block(C, length); 56 | } 57 | else if (party == BOB) { 58 | block *A0 = new block[length], *B0 = new block[length], *C0 = new block[length]; 59 | io->recv_block(A0, length); 60 | io->recv_block(B0, length); 61 | io->recv_block(C0, length); 62 | xorBlocks_arr(C, C0, C, length); 63 | xorBlocks_arr(A, A0, A, length); 64 | xorBlocks_arr(B, B0, B, length); 65 | andBlocks_arr(A, A, B, length); 66 | if (!cmpBlock(A, C, length)) { 67 | error("Block AND Failed"); 68 | } 69 | delete[] A0; 70 | delete[] B0; 71 | delete[] C0; 72 | } 73 | io->sync(); 74 | delete[] A; 75 | delete[] B; 76 | delete[] C; 77 | return t; 78 | // std::cout << "Block AND passed \n"; 79 | } 80 | 81 | double test_not(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 82 | bool *A = new bool[length], *B = new bool[length]; 83 | PRG prg; 84 | prg.random_bool(A, length); 85 | 
auto start = clock_start(); 86 | simd_circ->not_gate(B, A, length); 87 | long long t = time_from(start); 88 | io->flush(); 89 | if (party == ALICE) { 90 | io->send_bool(A, length); 91 | io->send_bool(B, length); 92 | } 93 | else if (party == BOB) { 94 | bool *A0 = new bool[length], *B0 = new bool[length]; 95 | io->recv_bool(A0, length); 96 | io->recv_bool(B0, length); 97 | 98 | xorBools_arr(A, A, A0, length); 99 | xorBools_arr(B, B, B0, length); 100 | 101 | for (int i = 0; i < length; ++i) { 102 | assert(A[i] == (1 ^ B[i])); 103 | } 104 | delete[] A0; 105 | delete[] B0; 106 | } 107 | io->sync(); 108 | delete[] A; 109 | delete[] B; 110 | return t; 111 | // std::cout << "Bool NOT gate test Passed \n"; 112 | } 113 | double test_block_not(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 114 | block *A = new block[length], *B = new block[length]; 115 | PRG prg; 116 | prg.random_block(A, length); 117 | auto start = clock_start(); 118 | simd_circ->not_gate(B, A, length); 119 | long long t = time_from(start); 120 | if (party == ALICE) { 121 | io->send_block(A, length); 122 | io->send_block(B, length); 123 | } 124 | else if (party == BOB) { 125 | block *A0 = new block[length], *B0 = new block[length]; 126 | io->recv_block(A0, length); 127 | io->recv_block(B0, length); 128 | 129 | xorBlocks_arr(A, A, A0, length); 130 | xorBlocks_arr(B, B, B0, length); 131 | xorBlocks_arr(A, A, all_one_block, length); 132 | 133 | if (!cmpBlock(A, B, length)) { 134 | error("Block NOT Failed"); 135 | } 136 | delete[] A0; 137 | delete[] B0; 138 | } 139 | delete[] A; 140 | delete[] B; 141 | return t; 142 | // std::cout << "Block NOT Gate tests Passed \n"; 143 | } 144 | double test_xor(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 145 | bool *A = new bool[length], *B = new bool[length], *C = new bool[length]; 146 | PRG prg; 147 | prg.random_bool(A, length); 148 | prg.random_bool(B, length); 149 | 150 | auto start = clock_start(); 151 | simd_circ->xor_gate(C, A, B, length); 152 | 
long long t = time_from(start); 153 | io->flush(); 154 | if (party == ALICE) { 155 | io->send_bool(A, length); 156 | io->send_bool(B, length); 157 | io->send_bool(C, length); 158 | } 159 | else if (party == BOB) { 160 | bool *A0 = new bool[length], *B0 = new bool[length], *C0 = new bool[length]; 161 | io->recv_bool(A0, length); 162 | io->recv_bool(B0, length); 163 | io->recv_bool(C0, length); 164 | for (int i = 0; i < length; ++i) { 165 | if ((C0[i] ^ C[i]) != ((A0[i] ^ A[i]) ^ (B0[i] ^ B[i]))) { 166 | error("Bool XOR Failed!"); 167 | }; 168 | } 169 | delete[] A0; 170 | delete[] B0; 171 | delete[] C0; 172 | } 173 | io->sync(); 174 | delete[] A; 175 | delete[] B; 176 | delete[] C; 177 | return t; 178 | // std::cout << "Bool XOR passed \n"; 179 | } 180 | double test_block_xor(SIMDCircExec* simd_circ, NetIO* io, int length = 2000) { 181 | block *A = new block[length], *B = new block[length], *C = new block[length]; 182 | PRG prg; 183 | prg.random_block(A, length); 184 | prg.random_block(B, length); 185 | 186 | auto start = clock_start(); 187 | simd_circ->xor_gate(C, A, B, length); 188 | long long t = time_from(start); 189 | io->flush(); 190 | if (party == ALICE) { 191 | io->send_block(A, length); 192 | io->send_block(B, length); 193 | io->send_block(C, length); 194 | } 195 | else if (party == BOB) { 196 | block *A0 = new block[length], *B0 = new block[length], *C0 = new block[length]; 197 | io->recv_block(A0, length); 198 | io->recv_block(B0, length); 199 | io->recv_block(C0, length); 200 | xorBlocks_arr(C, C0, C, length); 201 | xorBlocks_arr(A, A0, A, length); 202 | xorBlocks_arr(B, B0, B, length); 203 | xorBlocks_arr(A, A, B, length); 204 | if (!cmpBlock(A, C, length)) { 205 | error("Block XOR Failed"); 206 | } 207 | delete[] A0; 208 | delete[] B0; 209 | delete[] C0; 210 | } 211 | io->sync(); 212 | delete[] A; 213 | delete[] B; 214 | delete[] C; 215 | return t; 216 | // std::cout << "Block XOR passed \n"; 217 | } 218 | 219 | // double test_mux(SIMDCircExec 
*simd_circ, NetIO *io, int length = 2000, 220 | // int width = 120) { 221 | // block *A = new block[length * width], *B = new block[length * width], 222 | // *C = new block[length * width]; 223 | // bool *sel = new bool[length]; 224 | // PRG prg; 225 | // prg.random_block(A, length * width); 226 | // prg.random_block(B, length * width); 227 | // prg.random_bool(sel, length); 228 | 229 | // auto start = clock_start(); 230 | // simd_circ->mux_gate(C, A, B, sel, width, length); 231 | // long long t = time_from(start); 232 | // io->flush(); 233 | // if (party == ALICE) { 234 | // io->send_block(A, length * width); 235 | // io->send_block(B, length * width); 236 | // io->send_block(C, length * width); 237 | // io->send_bool(sel, length); 238 | // } else if (party == BOB) { 239 | // block *A0 = new block[length * width], *B0 = new block[length * width], 240 | // *C0 = new block[length * width]; 241 | // bool *sel0 = new bool[length]; 242 | 243 | // io->recv_block(A0, length * width); 244 | // io->recv_block(B0, length * width); 245 | // io->recv_block(C0, length * width); 246 | // io->recv_bool(sel0, length); 247 | 248 | // xorBools_arr(sel, sel, sel0, length); 249 | // xorBlocks_arr(A, A, A0, length * width); 250 | // xorBlocks_arr(B, B, B0, length * width); 251 | // xorBlocks_arr(C, C, C0, length * width); 252 | // for (int i = 0; i < length; ++i) { 253 | // if (sel[i]) { 254 | // if (!cmpBlock(B + i * width, C + i * width, width)) { 255 | // error("MUX Failed"); 256 | // } 257 | // } else { 258 | // if (!cmpBlock(A + i * width, C + i * width, width)) { 259 | // error("MUX Failed"); 260 | // } 261 | // } 262 | // } 263 | // delete[] A0; 264 | // delete[] B0; 265 | // delete[] C0; 266 | // delete[] sel0; 267 | // } 268 | // io->sync(); 269 | // delete[] A; 270 | // delete[] B; 271 | // delete[] C; 272 | // delete[] sel; 273 | // return t; 274 | // } 275 | 276 | int main(int argc, char** argv) { 277 | parse_party_and_port(argv, &party, &port); 278 | vector ios; 279 | for 
(int i = 0; i < threads; ++i) 280 | ios.push_back(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port)); 281 | 282 | auto start = clock_start(); 283 | SIMDCircExec* simd_circ = new SIMDCircExec(party, threads, ios.data()); 284 | double timeused = time_from(start); 285 | std::cout << party << "\tsetup\t" << timeused / 1000 << "ms" << std::endl; 286 | NetIO* io = ios[0]; 287 | 288 | // std::cout << "BOOL MUX EVALUATION\t" << test_mux(simd_circ, io)/1000 << 289 | // "ms" << std::endl; 290 | std::cout << "BOOL AND EVALUATION\t" << test_and(simd_circ, io) / 1000 << "ms" << std::endl; 291 | std::cout << "BOOL XOR EVALUATION\t" << test_xor(simd_circ, io) / 1000 << "ms" << std::endl; 292 | std::cout << "BOOL NOT EVALUATION\t" << test_not(simd_circ, io) / 1000 << "ms" << std::endl; 293 | 294 | std::cout << "BLOCK AND EVALUATION\t" << test_block_and(simd_circ, io) / 1000 << "ms" << std::endl; 295 | std::cout << "BLOCK XOR EVALUATION\t" << test_block_xor(simd_circ, io) / 1000 << "ms" << std::endl; 296 | std::cout << "BLOCK NOT EVALUATION\t" << test_block_not(simd_circ, io) / 1000 << "ms" << std::endl; 297 | 298 | delete simd_circ; 299 | delete io; 300 | } -------------------------------------------------------------------------------- /emp-aby/lut.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/he_interface.hpp" 4 | #include "emp-aby/utils.h" 5 | 6 | #include 7 | #include 8 | namespace emp { 9 | 10 | template 11 | class LUT { 12 | private: 13 | int rotated_pool_size; 14 | int num_used = 0; 15 | ThreadPool* pool; 16 | bool* rotation; 17 | // int64_t* lut_share; 18 | HE* he; 19 | MPIOChannel* io; 20 | PRG prg; 21 | 22 | void shuffle(lbcrypto::Ciphertext& c, bool* rotation, size_t batch_size, size_t i); 23 | 24 | public: 25 | int64_t* lut_share; 26 | int num_party; 27 | int party; 28 | int64_t table[2]; 29 | LUT(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int rot_pool_size 
= 20); 30 | LUT(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int64_t table[2], 31 | int rot_pool_size = 20); 32 | ~LUT(); 33 | void generate_shares(int64_t* lut_share, bool* rotation, int num_shares, int64_t table[2]); 34 | void lookup(int64_t* out, bool* in, size_t length); 35 | }; 36 | 37 | template 38 | LUT::LUT(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int rot_pool_size) { 39 | this->io = io; 40 | this->party = party; 41 | this->num_party = num_party; 42 | this->pool = pool; 43 | 44 | this->he = he; 45 | this->rotated_pool_size = 46 | rot_pool_size * (he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2); 47 | } 48 | 49 | template 50 | LUT::LUT(int num_party, int party, MPIOChannel* io, ThreadPool* pool, HE* he, int64_t table[2], 51 | int rot_pool_size) 52 | : LUT(num_party, party, io, pool, he, rot_pool_size) { 53 | this->table[0] = table[0]; 54 | this->table[1] = table[1]; 55 | this->rotation = new bool[rotated_pool_size]; 56 | this->lut_share = new int64_t[2 * rotated_pool_size]; 57 | // std::cout << "go in gen shares \n"; 58 | this->generate_shares(this->lut_share, this->rotation, this->rotated_pool_size, this->table); 59 | } 60 | 61 | template 62 | void LUT::shuffle(lbcrypto::Ciphertext& c, bool* rotation, size_t batch_size, size_t rot_idx) { 63 | lbcrypto::Ciphertext rot_1, rot_2; 64 | rot_1 = he->cc->EvalRotate(c, 1); 65 | rot_2 = he->cc->EvalRotate(c, -1); 66 | vector tmp; 67 | vector mult1, mult2, mult3; 68 | mult1.resize(batch_size); 69 | mult2.resize(batch_size); 70 | mult3.resize(batch_size); 71 | for (int j = 0; j < batch_size / 2; ++j) { 72 | if (rotation[(rot_idx)*batch_size / 2 + j] == true) { 73 | mult1[2 * j] = 1; 74 | mult2[2 * j + 1] = 1; 75 | } 76 | else { 77 | mult3[2 * j] = 1; 78 | mult3[2 * j + 1] = 1; 79 | } 80 | } 81 | 82 | auto plain1 = he->cc->MakePackedPlaintext(mult3); 83 | // auto tmp1 = he->cc->Encrypt(he->pk, plain1); 84 | auto tmp1 = he->cc->EvalMult(c, 
plain1); 85 | c = tmp1; 86 | 87 | plain1 = he->cc->MakePackedPlaintext(mult1); 88 | // tmp1 = he->cc->Encrypt(he->pk, plain1); 89 | tmp1 = he->cc->EvalMult(rot_1, plain1); 90 | he->cc->EvalAddInPlace(c, tmp1); 91 | 92 | plain1 = he->cc->MakePackedPlaintext(mult2); 93 | // tmp1 = he->cc->Encrypt(he->pk, plain1); 94 | tmp1 = he->cc->EvalMult(rot_2, plain1); 95 | he->cc->EvalAddInPlace(c, tmp1); 96 | } 97 | 98 | template 99 | void LUT::generate_shares(int64_t* lut_share, bool* rotation, int num_shares, int64_t* table) { 100 | int batch_size = he->cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2; 101 | int n = 2 * num_shares; 102 | prg.random_bool((bool*)rotation, num_shares); 103 | std::vector> ciphertext; 104 | vector> res; 105 | res.push_back(pool->enqueue([this, n, rotation, batch_size, &ciphertext, table, lut_share]() { 106 | // auto start = clock_start(); 107 | for (int i = 0; i < ceil((double)n / (double)batch_size); ++i) { 108 | lbcrypto::Ciphertext c; 109 | 110 | // for (int j = 1; j < party; ++j) { 111 | if (((party - 1) >= he->mult_depth) && (((party - 1) % he->mult_depth) == 0)) { 112 | he->bootstrap(c, party - 1, party); 113 | } 114 | // } 115 | if (party == ALICE) { 116 | vector tmp; 117 | tmp.resize(batch_size); 118 | for (int j = 0; j < batch_size / 2; ++j) { 119 | if (rotation[(i * batch_size) / 2 + j] == true) { 120 | tmp[2 * j] = table[1]; 121 | tmp[2 * j + 1] = table[0]; 122 | } 123 | else { 124 | tmp[2 * j] = table[0]; 125 | tmp[2 * j + 1] = table[1]; 126 | } 127 | } 128 | 129 | lbcrypto::Plaintext plaintext = he->cc->MakePackedPlaintext(tmp); 130 | c = he->cc->Encrypt(he->pk, plaintext); 131 | } 132 | else { 133 | if (!(party > he->mult_depth) || !((party - 1) % he->mult_depth == 0)) { 134 | he->deserialize_recv(c, party - 1); 135 | } 136 | shuffle(c, rotation, batch_size, i); 137 | // Refresh ciphertext 138 | const std::vector& cv = c->GetElements(); 139 | const auto cryptoParams = std::dynamic_pointer_cast>( 140 | 
he->pk->GetCryptoParameters()); 141 | const auto ns = cryptoParams->GetNoiseScale(); 142 | 143 | lbcrypto::DCRTPoly::DggType dgg(NOISE_FLOODING::MP_SD); 144 | lbcrypto::DCRTPoly e(dgg, cv[0].GetParams(), Format::EVALUATION); 145 | lbcrypto::DCRTPoly b = cv[0] + ns * e; 146 | c->SetElements({std::move(b), std::move(cv[1])}); 147 | } 148 | 149 | if (party != num_party) { 150 | if ((party < he->mult_depth) || (party % he->mult_depth != 0)) { 151 | he->serialize_send(c, party + 1); 152 | } 153 | else { 154 | he->bootstrap(c, party, party + 1); 155 | } 156 | } 157 | if (party == num_party){ 158 | ciphertext.push_back(c); 159 | he->serialize_sendall(ciphertext, 1, BOOT_REQ_MSG); 160 | he->enc_to_share(c, lut_share + i * batch_size, PACKED_ENCODING); 161 | ciphertext.clear(); 162 | } 163 | } 164 | // double timeused = time_from(start); 165 | // std::cout << party << "\tprocessing time\t" << timeused / 1000 << std::endl; 166 | })); 167 | 168 | for (int j = 1; j < num_party; ++j) { 169 | if ((j != party) && (j != party - 1)) { 170 | if ((j >= he->mult_depth) && ((j % he->mult_depth) == 0)) { 171 | res.push_back(pool->enqueue([this, j, batch_size, n]() { 172 | for (int i = 0; i < ceil((double)n / (double)batch_size); ++i) { 173 | lbcrypto::Ciphertext c; 174 | he->bootstrap(c, j, j + 1); 175 | } 176 | })); 177 | } 178 | } 179 | } 180 | 181 | if(party != num_party){ 182 | res.push_back(pool->enqueue([this, batch_size, n, lut_share]() { 183 | std::vector> c; 184 | for (int i = 0; i < ceil((double)n / (double)batch_size); ++i) { 185 | he->deserialize_recv(c, num_party, 1, BOOT_REQ_MSG); 186 | he->enc_to_share(c[0], lut_share + i * batch_size, PACKED_ENCODING); 187 | } 188 | })); 189 | } 190 | 191 | for (auto& v : res) 192 | v.get(); 193 | res.clear(); 194 | 195 | } 196 | 197 | template 198 | void LUT::lookup(int64_t* out, bool* in, size_t length) { 199 | bool* r = nullptr; 200 | int64_t* t = nullptr; 201 | bool delete_array = false; 202 | if (length > rotated_pool_size) { 203 | 
r = new bool[(length + rotated_pool_size - 1) / rotated_pool_size * rotated_pool_size]; 204 | t = new int64_t[(length + rotated_pool_size - 1) / rotated_pool_size * rotated_pool_size * 2]; 205 | delete_array = true; 206 | for (uint i = 0; i < (length + rotated_pool_size - 1) / rotated_pool_size; ++i) 207 | this->generate_shares(t + 2 * i * rotated_pool_size, r + i * rotated_pool_size, rotated_pool_size, 208 | this->table); 209 | size_t tocp = std::min((int)((length + rotated_pool_size - 1) / rotated_pool_size * rotated_pool_size - length), 210 | num_used); 211 | memcpy(this->rotation, r + (length + rotated_pool_size - 1) / rotated_pool_size * rotated_pool_size - length, 212 | tocp); 213 | memcpy(this->lut_share, 214 | t + 2 * ((length + rotated_pool_size - 1) / rotated_pool_size * rotated_pool_size - length), 215 | 2 * tocp * sizeof(int64_t)); 216 | num_used = 0; 217 | delete_array = true; 218 | } 219 | else if (length > rotated_pool_size - num_used) { 220 | r = new bool[length]; 221 | t = new int64_t[2 * length]; 222 | delete_array = true; 223 | memcpy(r, rotation + num_used, rotated_pool_size - num_used); 224 | memcpy(t, lut_share + 2 * num_used, 2 * (rotated_pool_size - num_used) * sizeof(int64_t)); 225 | this->generate_shares(this->lut_share, this->rotation, this->rotated_pool_size, this->table); 226 | memcpy(r + rotated_pool_size - num_used, this->rotation, length - (rotated_pool_size - num_used)); 227 | memcpy(t + 2 * (rotated_pool_size - num_used), this->lut_share, 228 | 2 * (length - (rotated_pool_size - num_used)) * sizeof(int64_t)); 229 | num_used = length - (rotated_pool_size - num_used); 230 | } 231 | else { 232 | r = rotation + num_used; 233 | t = lut_share + 2 * num_used; 234 | num_used += length; 235 | } 236 | bool* e = new bool[length]; 237 | xorBools_arr(e, r, in, length); 238 | if (party == ALICE) { 239 | bool* tmp = new bool[length]; 240 | for (uint i = 2; i <= num_party; ++i) { 241 | io->recv_bool(i, tmp, length); 242 | xorBools_arr(e, e, tmp, 
length); 243 | } 244 | vector> res; 245 | for (uint i = 2; i <= num_party; ++i) { 246 | res.push_back(pool->enqueue([this, i, e, length]() { 247 | this->io->send_bool(i, e, length); 248 | io->flush(i); 249 | })); 250 | } 251 | for (auto& v : res) 252 | v.get(); 253 | res.clear(); 254 | 255 | delete[] tmp; 256 | } 257 | else { 258 | io->send_bool(ALICE, e, length); 259 | io->flush(ALICE); 260 | io->recv_bool(ALICE, e, length); 261 | } 262 | for (int i = 0; i < length; ++i) { 263 | out[i] = t[i * 2 + (int)e[i]]; 264 | } 265 | 266 | delete[] e; 267 | if (delete_array) { 268 | delete[] r; 269 | delete[] t; 270 | } 271 | } 272 | 273 | template 274 | LUT::~LUT() { 275 | delete[] rotation; 276 | delete[] lut_share; 277 | } 278 | 279 | } // namespace emp 280 | -------------------------------------------------------------------------------- /test/mp_simd_exec.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-aby/simd_interface/mp-simd-exec.h" 2 | #include "fstream" 3 | #include "emp-aby/io/multi-io.hpp" 4 | 5 | int party, port; 6 | 7 | const static int threads = 4; 8 | int num_party; 9 | 10 | template 11 | double test_and(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 2000) { 12 | bool *A = new bool[length], *B = new bool[length], *C = new bool[length]; 13 | PRG prg; 14 | prg.random_bool(A, length); 15 | prg.random_bool(B, length); 16 | auto start = clock_start(); 17 | // std::cout << "start multiplying" << std::endl; 18 | simd_circ->and_gate(C, A, B, length); 19 | // for(int i = 0; i < length; ++i){ 20 | // std::cout << i << " A: " << A[i] << " B: " << B[i] << " C: " << C[i] << std::endl; 21 | // } 22 | long long t = time_from(start); 23 | io->flush(); 24 | // std::cout << "Now checking" << std::endl; 25 | io->sync(); 26 | if (party == ALICE) { 27 | bool *A0 = new bool[length], *B0 = new bool[length], *C0 = new bool[length]; 28 | for (int i = 2; i <= num_party; ++i) { 29 | io->recv_bool(i, A0, length); 30 | 
io->recv_bool(i, B0, length); 31 | io->recv_bool(i, C0, length); 32 | xorBools_arr(A, A, A0, length); 33 | xorBools_arr(B, B, B0, length); 34 | xorBools_arr(C, C, C0, length); 35 | } 36 | /* A, B, C now hold the XOR of all parties' shares */ for (int i = 0; i < length; ++i) { 37 | if ((A[i] & B[i]) != C[i]) { 38 | std::cout << i << " "; 39 | error(" Bool AND Failed!"); 40 | } 41 | } 42 | delete[] A0; 43 | delete[] B0; 44 | delete[] C0; 45 | } 46 | else { 47 | io->send_bool(ALICE, A, length); 48 | io->send_bool(ALICE, B, length); 49 | io->send_bool(ALICE, C, length); 50 | } 51 | io->flush(); 52 | delete[] A; 53 | delete[] B; 54 | delete[] C; 55 | return t; 56 | // std::cout << "Bool AND passed \n"; 57 | } 58 | 59 | /* test_block_and: 128-bit block variant of test_and; ALICE reconstructs A, B, C, computes A & B with andBlocks_arr and requires cmpBlock(A, C). Returns the and_gate time. */ template 60 | double test_block_and(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 2000) { 61 | block *A = new block[length], *B = new block[length], *C = new block[length]; 62 | PRG prg; 63 | prg.random_block(A, length); 64 | prg.random_block(B, length); 65 | 66 | auto start = clock_start(); 67 | simd_circ->and_gate(C, A, B, length); 68 | long long t = time_from(start); 69 | io->flush(); 70 | // std::cout << "Now checking" << std::endl; 71 | if (party == ALICE) { 72 | block *A0 = new block[length], *B0 = new block[length], *C0 = new block[length]; 73 | for (int i = 2; i <= num_party; ++i) { 74 | io->recv_block(i, A0, length); 75 | io->recv_block(i, B0, length); 76 | io->recv_block(i, C0, length); 77 | xorBlocks_arr(A, A, A0, length); 78 | xorBlocks_arr(B, B, B0, length); 79 | xorBlocks_arr(C, C, C0, length); 80 | } 81 | andBlocks_arr(A, A, B, length); 82 | 83 | if (!cmpBlock(A, C, length)) { 84 | error(" Bool AND Failed!"); 85 | } 86 | 87 | delete[] A0; 88 | delete[] B0; 89 | delete[] C0; 90 | } 91 | else { 92 | io->send_block(ALICE, A, length); 93 | io->send_block(ALICE, B, length); 94 | io->send_block(ALICE, C, length); 95 | } 96 | io->flush(); 97 | delete[] A; 98 | delete[] B; 99 | delete[] C; 100 | return t; 101 | // std::cout << "Bool AND passed \n"; 102 | } 103 | 104 | /* test_not: checks not_gate on bool shares; ALICE requires the reconstructed A to equal !B. Returns the not_gate time. */ template 105 | double
test_not(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 2000) { 106 | bool *A = new bool[length], *B = new bool[length]; 107 | PRG prg; 108 | prg.random_bool(A, length); 109 | auto start = clock_start(); 110 | simd_circ->not_gate(B, A, length); 111 | long long t = time_from(start); 112 | io->flush(); 113 | io->sync(); 114 | if (party == ALICE) { 115 | bool *A0 = new bool[length], *B0 = new bool[length], *C0 = new bool[length]; 116 | for (int i = 2; i <= num_party; ++i) { 117 | io->recv_bool(i, A0, length); 118 | io->recv_bool(i, B0, length); 119 | xorBools_arr(A, A, A0, length); 120 | xorBools_arr(B, B, B0, length); 121 | } 122 | for (int i = 0; i < length; ++i) { 123 | if (A[i] != !B[i]) { 124 | std::cout << i << " "; 125 | error(" Bool NOT Failed!"); 126 | } 127 | } 128 | delete[] A0; 129 | delete[] B0; 130 | delete[] C0; 131 | } 132 | else { 133 | io->send_bool(ALICE, A, length); 134 | io->send_bool(ALICE, B, length); 135 | } 136 | io->flush(); 137 | delete[] A; 138 | delete[] B; 139 | return t; 140 | // std::cout << "Bool NOT gate test Passed \n"; 141 | } 142 | 143 | template 144 | double test_block_not(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 2000) { 145 | block *A = new block[length], *B = new block[length]; 146 | PRG prg; 147 | prg.random_block(A, length); 148 | auto start = clock_start(); 149 | simd_circ->not_gate(B, A, length); 150 | long long t = time_from(start); 151 | if (party == ALICE) { 152 | block *A0 = new block[length], *B0 = new block[length]; 153 | for (int i = 2; i <= num_party; ++i) { 154 | io->recv_block(i, A0, length); 155 | io->recv_block(i, B0, length); 156 | xorBlocks_arr(A, A, A0, length); 157 | xorBlocks_arr(B, B, B0, length); 158 | } 159 | // andBlocks_arr(A, A, B, length); 160 | xorBlocks_arr(A, A, all_one_block, length); 161 | if (!cmpBlock(A, B, length)) { 162 | error(" Bool AND Failed!"); 163 | } 164 | 165 | delete[] A0; 166 | delete[] B0; 167 | } 168 | else { 169 | io->send_block(ALICE, A, length); 170 | 
io->send_block(ALICE, B, length); 171 | } 172 | io->flush(); 173 | delete[] A; 174 | delete[] B; 175 | return t; 176 | // std::cout << "Block NOT Gate tests Passed \n"; 177 | } 178 | 179 | template 180 | double test_xor(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 10000) { 181 | bool *A = new bool[length], *B = new bool[length], *C = new bool[length]; 182 | PRG prg; 183 | prg.random_bool(A, length); 184 | prg.random_bool(B, length); 185 | 186 | auto start = clock_start(); 187 | simd_circ->xor_gate(C, A, B, length); 188 | long long t = time_from(start); 189 | io->flush(); 190 | io->sync(); 191 | if (party == ALICE) { 192 | bool *A0 = new bool[length], *B0 = new bool[length], *C0 = new bool[length]; 193 | for (int i = 2; i <= num_party; ++i) { 194 | io->recv_bool(i, A0, length); 195 | io->recv_bool(i, B0, length); 196 | io->recv_bool(i, C0, length); 197 | xorBools_arr(A, A, A0, length); 198 | xorBools_arr(B, B, B0, length); 199 | xorBools_arr(C, C, C0, length); 200 | } 201 | for (int i = 0; i < length; ++i) { 202 | if ((A[i] ^ B[i]) != C[i]) { 203 | std::cout << i << " "; 204 | error(" Bool AND Failed!"); 205 | } 206 | } 207 | delete[] A0; 208 | delete[] B0; 209 | delete[] C0; 210 | } 211 | else { 212 | io->send_bool(ALICE, A, length); 213 | io->send_bool(ALICE, B, length); 214 | io->send_bool(ALICE, C, length); 215 | } 216 | io->flush(); 217 | io->sync(); 218 | delete[] A; 219 | delete[] B; 220 | delete[] C; 221 | return t; 222 | // std::cout << "Bool XOR passed \n"; 223 | } 224 | 225 | template 226 | double test_block_xor(MPSIMDCircExec* simd_circ, MPIOChannel* io, int length = 2000) { 227 | block *A = new block[length], *B = new block[length], *C = new block[length]; 228 | PRG prg; 229 | prg.random_block(A, length); 230 | prg.random_block(B, length); 231 | 232 | auto start = clock_start(); 233 | simd_circ->xor_gate(C, A, B, length); 234 | long long t = time_from(start); 235 | io->flush(); 236 | if (party == ALICE) { 237 | block *A0 = new block[length], 
*B0 = new block[length], *C0 = new block[length]; 238 | for (int i = 2; i <= num_party; ++i) { 239 | io->recv_block(i, A0, length); 240 | io->recv_block(i, B0, length); 241 | io->recv_block(i, C0, length); 242 | xorBlocks_arr(A, A, A0, length); 243 | xorBlocks_arr(B, B, B0, length); 244 | xorBlocks_arr(C, C, C0, length); 245 | } 246 | xorBlocks_arr(A, A, B, length); 247 | 248 | if (!cmpBlock(A, C, length)) { 249 | error(" Bool AND Failed!"); 250 | } 251 | 252 | delete[] A0; 253 | delete[] B0; 254 | delete[] C0; 255 | } 256 | else { 257 | io->send_block(ALICE, A, length); 258 | io->send_block(ALICE, B, length); 259 | io->send_block(ALICE, C, length); 260 | } 261 | io->flush(); 262 | delete[] A; 263 | delete[] B; 264 | delete[] C; 265 | return t; 266 | // std::cout << "Block XOR passed \n"; 267 | } 268 | 269 | int main(int argc, char** argv) { 270 | if (argc < 4) { 271 | std::cout << "Format: test_mp_bit_triple PartyID Port num_parties" << std::endl; 272 | exit(0); 273 | } 274 | parse_party_and_port(argv, &party, &port); 275 | num_party = atoi(argv[3]); 276 | 277 | std::vector> net_config; 278 | 279 | if (argc == 5) { 280 | const char* file = argv[4]; 281 | FILE* f = fopen(file, "r"); 282 | for (int i = 0; i < num_party; ++i) { 283 | char* c = (char*)malloc(15 * sizeof(char)); 284 | uint p; 285 | fscanf(f, "%s %d\n", c, &p); 286 | std::string s(c); 287 | net_config.push_back(std::make_pair(s, p)); 288 | fflush(f); 289 | } 290 | fclose(f); 291 | } 292 | else { 293 | for (int i = 0; i < num_party; ++i) { 294 | std::string s = "127.0.0.1"; 295 | uint p = (port + 4 * num_party * i); 296 | net_config.push_back(std::make_pair(s, p)); 297 | } 298 | } 299 | 300 | MultiIO* io = new MultiIO(party, num_party, net_config); 301 | std::cout << party << " connected \n"; 302 | 303 | ThreadPool pool(threads); 304 | io->setup_ot_ios(); 305 | io->flush(); 306 | std::cout << "io setup" << std::endl; 307 | long long int buffer_length =((ferret_b13.n - ferret_b13.k - ferret_b13.t * 
ferret_b13.log_bin_sz_pre - 128) / 128) * 128; 308 | auto start = clock_start(); 309 | MPSIMDCircExec* simd_circ = new MPSIMDCircExec(num_party, party, &pool, io); 310 | double timeused = time_from(start); 311 | /* NOTE(review): other time_from() results in this code are printed with a us suffix; confirm the ms label and the 2000 * buffer_length divisor used here. */ std::cout << party << "\tsetup\t" << timeused / (2000 * buffer_length) << " ms" << std::endl; 312 | // std::cout << party << " BOOL AND EVALUATION\t" << test_and(simd_circ, io, buffer_length) / 1000 << "ms" 313 | // << std::endl; 314 | // std::cout << party << " BOOL XOR EVALUATION\t" << test_xor(simd_circ, io) / 1000 << "ms" 315 | // << std::endl; 316 | // std::cout << party << " BOOL NOT EVALUATION\t" << test_not(simd_circ, io) / 1000 << "ms" 317 | // << std::endl; 318 | 319 | std::cout << party << " BLOCK AND EVALUATION\t" 320 | << test_block_and(simd_circ, io, buffer_length / 128) / (buffer_length) << " us" << std::endl; 321 | // std::cout << party << " BLOCK XOR EVALUATION\t" << test_block_xor(simd_circ, io) / 1000 322 | // << "ms" << std::endl; 323 | // std::cout << party << " BLOCK NOT EVALUATION\t" << test_block_not(simd_circ, io) / 1000 324 | // << "ms" << std::endl; 325 | // delete simd_circ; 326 | /* NOTE(review): simd_circ is never deleted because the delete statement is commented out; if restored it must run before delete io, since MPSIMDCircExec keeps the io pointer. */ delete io; 327 | } 328 | -------------------------------------------------------------------------------- /emp-aby/simd_interface/mp-simd-exec.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "simd.h" 4 | #include "emp-aby/triple-providers/mp-bit-triple.h" 5 | 6 | namespace emp { 7 | 8 | /* MPSIMDCircExec: multi-party boolean SIMD gate executor over XOR shares. AND gates consume bit/block triples pre-generated by an MPBitTripleProvider and open the masked inputs through ALICE, who collects from and answers parties 2..num_party; XOR and NOT are local. */ template 9 | class MPSIMDCircExec : SIMDCircuitExecution> { 10 | private: 11 | MPBitTripleProvider* btp = nullptr; 12 | MPIOChannel* io; 13 | ThreadPool* pool; 14 | bool *bit_triple_a, *bit_triple_b, *bit_triple_c; 15 | block *block_triple_a, *block_triple_b, *block_triple_c; 16 | /* pool cursors handed to and_helper (presumably the number of triples already consumed; confirm in simd.h) */ size_t num_triples = 0, num_block_triples = 0; 17 | size_t num_triples_pool; 18 | size_t num_block_triples_pool; 19 | int num_party; 20 | /* accumulated and_gate time, printed by the destructor */ double total_time = 0; 21 | /* reusable scratch for block-AND openings, grown on demand by and_gate */ vector d, d1, e, e1; 22 | 23 | public: 24 | int cur_party;
25 | 26 | MPSIMDCircExec(int num_party, int party, ThreadPool* pool, MPIOChannel* io) { 27 | this->cur_party = party; 28 | this->num_party = num_party; 29 | this->pool = pool; 30 | btp = new MPBitTripleProvider(num_party, party, pool, io); 31 | num_triples_pool = btp->BUFFER_SZ; 32 | num_block_triples_pool = btp->BUFFER_SZ / 128; 33 | bit_triple_a = new bool[num_triples_pool]; 34 | bit_triple_b = new bool[num_triples_pool]; 35 | bit_triple_c = new bool[num_triples_pool]; 36 | btp->get_triple(bit_triple_a, bit_triple_b, bit_triple_c); 37 | block_triple_a = new block[num_block_triples_pool]; 38 | block_triple_b = new block[num_block_triples_pool]; 39 | block_triple_c = new block[num_block_triples_pool]; 40 | btp->get_triple(block_triple_a, block_triple_b, block_triple_c); 41 | this->io = io; 42 | } 43 | 44 | ~MPSIMDCircExec() { 45 | delete btp; 46 | delete[] bit_triple_a; 47 | delete[] bit_triple_b; 48 | delete[] bit_triple_c; 49 | 50 | delete[] block_triple_a; 51 | delete[] block_triple_b; 52 | delete[] block_triple_c; 53 | std::cout << "Total time in SIMD: " << total_time << " us\n"; 54 | }; 55 | 56 | MPBitTripleProvider * getBtp(){ 57 | return btp; 58 | } 59 | 60 | void and_gate(bool* out1, bool* in1, bool* in2, size_t length) { 61 | auto t = clock_start(); 62 | bool *a = nullptr, *b = nullptr, *c = nullptr; 63 | bool delete_array = false; 64 | this->template and_helper(a, b, c, length, delete_array, bit_triple_a, bit_triple_b, bit_triple_c, num_triples_pool, 65 | num_triples); 66 | bool *d = new bool[length], *e = new bool[length]; 67 | 68 | for (uint i = 0; i < length; ++i) { 69 | d[i] = in1[i] ^ a[i]; 70 | e[i] = in2[i] ^ b[i]; 71 | } 72 | 73 | if (cur_party == ALICE) { 74 | bool *d0 = new bool[length], *e0 = new bool[length]; 75 | 76 | for (int i = 2; i <= num_party; ++i) { 77 | io->recv_bool(i, d0, length); 78 | io->recv_bool(i, e0, length); 79 | xorBools_arr(d, d, d0, length); 80 | xorBools_arr(e, e, e0, length); 81 | } 82 | 83 | vector> res; 84 | for (int i 
= 2; i <= num_party; ++i) { 85 | res.push_back(pool->enqueue([this, i, d, e, length]() { 86 | this->io->send_bool(i, d, length); 87 | this->io->send_bool(i, e, length); 88 | })); 89 | } 90 | 91 | for (auto& v : res) 92 | v.get(); 93 | res.clear(); 94 | 95 | delete[] d0; 96 | delete[] e0; 97 | } 98 | else { 99 | io->send_bool(ALICE, d, length); 100 | io->send_bool(ALICE, e, length); 101 | 102 | io->recv_bool(ALICE, d, length); 103 | io->recv_bool(ALICE, e, length); 104 | } 105 | io->flush(); 106 | /* only ALICE adds the public d & e term, so the XOR of all parties' outputs equals in1 & in2 when c = a & b */ if (cur_party == ALICE) { 107 | for (uint i = 0; i < length; ++i) 108 | out1[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i] ^ (d[i] & e[i]); 109 | } 110 | else { 111 | for (uint i = 0; i < length; ++i) 112 | out1[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i]; 113 | } 114 | delete[] d; 115 | delete[] e; 116 | /* delete_array is set by and_helper, presumably when a, b, c were allocated rather than taken from the pool */ if (delete_array) { 117 | delete[] a; 118 | delete[] b; 119 | delete[] c; 120 | } 121 | total_time += time_from(t); 122 | } 123 | 124 | /* AND on block shares: same masking and opening as the bool overload, but using block triples, the reusable d/d1/e/e1 member vectors, and sequential sends from ALICE. */ void and_gate(block* out1, block* in1, block* in2, size_t length) { 125 | auto t = clock_start(); 126 | block *a = nullptr, *b = nullptr, *c = nullptr; 127 | bool delete_array = false; 128 | this->template and_helper(a, b, c, length, delete_array, block_triple_a, block_triple_b, block_triple_c, 129 | num_block_triples_pool, num_block_triples); 130 | /* grow the scratch vectors to at least length blocks */ if (length > d.size()) { 131 | d.resize(length); 132 | d1.resize(length); 133 | e.resize(length); 134 | e1.resize(length); 135 | } 136 | 137 | for (uint i = 0; i < length; ++i) { 138 | d[i] = in1[i] ^ a[i]; 139 | e[i] = in2[i] ^ b[i]; 140 | } 141 | 142 | if (cur_party == ALICE) { 143 | for (int i = 2; i <= num_party; ++i) { 144 | io->recv_block(i, d1.data(), length); 145 | io->recv_block(i, e1.data(), length); 146 | xorBlocks_arr(d.data(), d.data(), d1.data(), length); 147 | xorBlocks_arr(e.data(), e.data(), e1.data(), length); 148 | } 149 | 150 | for (int i = 2; i <= num_party; ++i) { 151 | io->send_block(i, d.data(), length); 152 | io->send_block(i, e.data(), length); 153 | } 154 | } 155 | else {
156 | io->send_block(ALICE, d.data(), length); 157 | io->send_block(ALICE, e.data(), length); 158 | 159 | io->recv_block(ALICE, d.data(), length); 160 | io->recv_block(ALICE, e.data(), length); 161 | } 162 | io->flush(); 163 | 164 | /* ALICE alone adds the public d & e term */ if (cur_party == ALICE) { 165 | for (uint i = 0; i < length; ++i) 166 | out1[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i] ^ (d[i] & e[i]); 167 | } 168 | else { 169 | for (uint i = 0; i < length; ++i) 170 | out1[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i]; 171 | } 172 | 173 | if (delete_array) { 174 | delete[] a; 175 | delete[] b; 176 | delete[] c; 177 | } 178 | total_time += time_from(t); 179 | } 180 | 181 | /* Fused AND over bool_length bool shares and length block shares, opened in the same exchange with ALICE: the local d/e arrays hold the bool masks, this->d / this->e hold the block masks. */ void and_gate(bool* out, bool* in1, bool* in2, size_t bool_length, block* block_out, block* block_in1, 182 | block* block_in2, size_t length) { 183 | auto t = clock_start(); 184 | bool *a = nullptr, *b = nullptr, *c = nullptr; 185 | bool delete_array = false; 186 | this->template and_helper(a, b, c, bool_length, delete_array, bit_triple_a, bit_triple_b, bit_triple_c, num_triples_pool, 187 | num_triples); 188 | bool *d = new bool[bool_length], *e = new bool[bool_length]; 189 | 190 | for (uint i = 0; i < bool_length; ++i) { 191 | d[i] = in1[i] ^ a[i]; 192 | e[i] = in2[i] ^ b[i]; 193 | } 194 | 195 | /* block triples come from their own pool with their own delete flag */ block *block_a = nullptr, *block_b = nullptr, *block_c = nullptr; 196 | bool delete_block_array = false; 197 | this->template and_helper(block_a, block_b, block_c, length, delete_block_array, block_triple_a, block_triple_b, 198 | block_triple_c, num_block_triples_pool, num_block_triples); 199 | 200 | if (length > this->d.size()) { 201 | this->d.resize(length); 202 | this->d1.resize(length); 203 | this->e.resize(length); 204 | this->e1.resize(length); 205 | } 206 | 207 | for (uint i = 0; i < length; ++i) { 208 | this->d[i] = block_in1[i] ^ block_a[i]; 209 | this->e[i] = block_in2[i] ^ block_b[i]; 210 | } 211 | 212 | if (cur_party == ALICE) { 213 | bool *d0 = new bool[bool_length], *e0 = new bool[bool_length]; 214 | 215 | for (int i = 2; i <= 
num_party; ++i) { 216 | io->recv_bool(i, d0, bool_length); 217 | io->recv_bool(i, e0, bool_length); 218 | xorBools_arr(d, d, d0, bool_length); 219 | xorBools_arr(e, e, e0, bool_length); 220 | io->recv_block(i, this->d1.data(), length); 221 | io->recv_block(i, this->e1.data(), length); 222 | xorBlocks_arr(this->d.data(), this->d.data(), this->d1.data(), length); 223 | xorBlocks_arr(this->e.data(), this->e.data(), this->e1.data(), length); 224 | } 225 | 226 | /* ALICE returns the opened masks to every party from the thread pool; the lambda copies the local bool pointers d, e and reads the block masks through this */ vector> res; 227 | for (int i = 2; i <= num_party; ++i) { 228 | res.push_back(pool->enqueue([this, i, d, e, bool_length, length]() { 229 | this->io->send_bool(i, d, bool_length); 230 | this->io->send_bool(i, e, bool_length); 231 | io->send_block(i, this->d.data(), length); 232 | io->send_block(i, this->e.data(), length); 233 | })); 234 | } 235 | 236 | for (auto& v : res) 237 | v.get(); 238 | res.clear(); 239 | 240 | delete[] d0; 241 | delete[] e0; 242 | } 243 | else { 244 | io->send_bool(ALICE, d, bool_length); 245 | io->send_bool(ALICE, e, bool_length); 246 | io->send_block(ALICE, this->d.data(), length); 247 | io->send_block(ALICE, this->e.data(), length); 248 | 249 | io->recv_bool(ALICE, d, bool_length); 250 | io->recv_bool(ALICE, e, bool_length); 251 | io->recv_block(ALICE, this->d.data(), length); 252 | io->recv_block(ALICE, this->e.data(), length); 253 | } 254 | io->flush(); 255 | /* as in the single-type overloads, only ALICE adds the d & e terms */ if (cur_party == ALICE) { 256 | for (uint i = 0; i < bool_length; ++i) 257 | out[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i] ^ (d[i] & e[i]); 258 | for (uint i = 0; i < length; ++i) 259 | block_out[i] = 260 | (this->d[i] & block_b[i]) ^ (this->e[i] & block_a[i]) ^ block_c[i] ^ (this->d[i] & this->e[i]); 261 | } 262 | else { 263 | for (uint i = 0; i < bool_length; ++i) 264 | out[i] = (d[i] & b[i]) ^ (e[i] & a[i]) ^ c[i]; 265 | for (uint i = 0; i < length; ++i) 266 | block_out[i] = (this->d[i] & block_b[i]) ^ (this->e[i] & block_a[i]) ^ block_c[i]; 267 | } 268 | delete[] d; 269 | delete[] e; 270 | /* delete_array refers to the bool triple arrays a, b, c */ if (delete_array) { 271 | delete[] 
a; 272 | delete[] b; 273 | delete[] c; 274 | } 275 | total_time += time_from(t); 276 | } 277 | 278 | void xor_gate(bool* out1, bool* in1, bool* in2, size_t length) { 279 | //auto t = clock_start(); 280 | xorBools_arr(out1, in1, in2, length); 281 | //total_time += time_from(t); 282 | } 283 | 284 | void xor_gate(block* out1, block* in1, block* in2, size_t length) { 285 | //auto t = clock_start(); 286 | for (uint i = 0; i < length; ++i) { 287 | out1[i] = in1[i] ^ in2[i]; 288 | } 289 | //total_time += time_from(t); 290 | } 291 | 292 | void not_gate(bool* out1, bool* in1, size_t length) { 293 | //auto t = clock_start(); 294 | bool bit_to_xor = (cur_party == ALICE); 295 | for (uint i = 0; i < length; ++i) { 296 | (out1[i]) = (in1[i]) ^ bit_to_xor; 297 | } 298 | //total_time += time_from(t); 299 | } 300 | 301 | void not_gate(block* out1, block* in1, size_t length) { 302 | // auto t = clock_start(); 303 | block bit_to_xor = (cur_party == ALICE) ? all_one_block : zero_block; 304 | for (uint i = 0; i < length; ++i) { 305 | out1[i] = in1[i] ^ bit_to_xor; 306 | } 307 | // total_time += time_from(t); 308 | } 309 | }; 310 | 311 | } // namespace emp 312 | -------------------------------------------------------------------------------- /emp-aby/mp-circuit.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "emp-aby/wire.h" 4 | #include "emp-aby/simd_interface/mp-simd-exec.h" 5 | 6 | #include 7 | #define _debug 8 | namespace emp { 9 | 10 | template 11 | class Circuit { 12 | public: 13 | uint num_gates, num_wires, n1, n2, n3, party; 14 | vector> level_map; 15 | vector circuit; 16 | SIMDCirc* simd_circ; 17 | ~Circuit() {} 18 | Circuit(FILE* f) { 19 | this->from_file(f); 20 | } 21 | 22 | Circuit(const char* file, int party, SIMDCirc* simd_circ) { 23 | this->party = party; 24 | this->from_file(file); 25 | this->simd_circ = simd_circ; 26 | } 27 | 28 | void from_file(const char* file) { 29 | FILE* f = fopen(file, "r"); 
30 | if (f == nullptr) { 31 | std::cout << file << "\n"; 32 | error("could not open file\n"); 33 | } 34 | this->from_file(f); 35 | fclose(f); 36 | } 37 | 38 | void add_input_wires() { 39 | for (uint i = 0; i < (n1 + n2); ++i) { 40 | Wire* w = new Wire(INPUT); 41 | circuit[i] = w; 42 | insert_level_map(w, i); 43 | } 44 | } 45 | 46 | void from_file(FILE* f) { 47 | circuit.clear(); 48 | level_map.clear(); 49 | int tmp = 0, in1 = 0, in2 = 0, out = 0, type = 0; 50 | fscanf(f, "%d%d\n", &num_gates, &num_wires); 51 | for (uint i = 0; i < num_wires; ++i) { 52 | Wire* w = nullptr; 53 | circuit.push_back(w); 54 | } 55 | fflush(f); 56 | fscanf(f, "%d%d%d\n", &n1, &n2, &n3); 57 | fflush(f); 58 | fscanf(f, "\n"); 59 | fflush(f); 60 | this->add_input_wires(); 61 | char str[10]; 62 | Wire* new_wire; 63 | for (uint i = 0; i < num_gates; ++i) { 64 | fscanf(f, "%d", &tmp); 65 | if (tmp == 2) { 66 | fscanf(f, "%d%d%d%d%s", &tmp, &in1, &in2, &out, str); 67 | fflush(f); 68 | if (str[0] == 'A') 69 | type = AND; 70 | else if (str[0] == 'X') 71 | type = XOR; 72 | Wire* in1_wire = circuit[in1]; 73 | Wire* in2_wire = circuit[in2]; 74 | new_wire = new Wire(type, in1_wire, in2_wire); 75 | circuit[out] = new_wire; 76 | this->insert_level_map(new_wire, out); 77 | } 78 | else if (tmp == 1) { 79 | (void)fscanf(f, "%d%d%d%s", &tmp, &in1, &out, str); 80 | type = INV; 81 | Wire* in1_wire = circuit[in1]; 82 | new_wire = new Wire(type, in1_wire); 83 | circuit[out] = new_wire; 84 | this->insert_level_map(new_wire, out); 85 | } 86 | } 87 | } 88 | 89 | void insert_level_map(Wire* new_wire, int out) { 90 | if (level_map.size() < new_wire->level) { 91 | error("Missing level!"); 92 | } 93 | else if (level_map.size() == new_wire->level) { 94 | level_map.push_back(std::vector(1, out)); 95 | } 96 | else { 97 | level_map[new_wire->level].push_back(out); 98 | } 99 | } 100 | 101 | template 102 | void compute(bool* out, bool* in, uint num, bool shared = false) { 103 | #ifdef __debug 104 | auto t = clock_start(); 
105 | #endif 106 | /* instances are packed 128 per block in w->value; the num % 128 leftover instances go to w->rem_value */ uint block_num = num / 128, bool_num = num % 128; 107 | Wire* w; 108 | /* BOB with non-shared inputs (and the is_same condition): in[] holds only the n2 second-group inputs, in[j * n2 + i]; the first n1 wires are allocated and marked set without loading values */ if (!shared && std::is_same>::value && (party == BOB)) { 109 | for (uint i = 0; i < n1; ++i) { 110 | w = circuit[i]; 111 | w->initialise_value(block_num, bool_num); 112 | w->set = true; 113 | } 114 | for (uint i = 0; i < n2; ++i) { 115 | w = circuit[i + n1]; 116 | w->initialise_value(block_num, bool_num); 117 | bool dummy[128] = {false}; 118 | uint j; 119 | for (j = 0; j < num; ++j) { 120 | dummy[j % 128] = in[i + j * n2]; 121 | if (j % 128 == 127) { 122 | w->value[j / 128] = bool_to_block(dummy); 123 | } 124 | } 125 | if (j % 128 != 0) { 126 | for (uint k = 0; k < bool_num; ++k) { 127 | w->rem_value[k] = dummy[k]; 128 | } 129 | } 130 | w->set = true; 131 | } 132 | } 133 | /* otherwise the first n1 wires are loaded from in[j * n1 + i] */ else { 134 | for (uint i = 0; i < n1; ++i) { 135 | w = circuit[i]; 136 | w->initialise_value(block_num, bool_num); 137 | bool dummy[128] = {false}; 138 | uint j; 139 | for (j = 0; j < num; ++j) { 140 | dummy[j % 128] = in[i + j * n1]; 141 | if (j % 128 == 127) { 142 | w->value[j / 128] = bool_to_block(dummy); 143 | } 144 | } 145 | if (j % 128 != 0) { 146 | for (uint k = 0; k < bool_num; ++k) { 147 | w->rem_value[k] = dummy[k]; 148 | } 149 | } 150 | w->set = true; 151 | } 152 | /* ALICE with non-shared inputs: the n2 second-group wires are allocated and marked set without loading values */ if (!shared && std::is_same>::value && (party == ALICE)) { 153 | for (uint i = 0; i < n2; ++i) { 154 | w = circuit[i + n1]; 155 | w->initialise_value(block_num, bool_num); 156 | w->set = true; 157 | } 158 | } 159 | /* second-group wires read from in[n1 * num + j * n2 + i] */ else { 160 | for (uint i = 0; i < n2; ++i) { 161 | w = circuit[i + n1]; 162 | w->initialise_value(block_num, bool_num); 163 | bool dummy[128] = {false}; 164 | uint j; 165 | for (j = 0; j < num; ++j) { 166 | dummy[j % 128] = in[i + j * n2 + n1 * num]; 167 | if (j % 128 == 127) { 168 | w->value[j / 128] = bool_to_block(dummy); 169 | } 170 | } 171 | if (j % 128 != 0) { 172 | for (uint k = 0; k < bool_num; ++k) { 173 | w->rem_value[k] = dummy[k]; 174 | } 175 | } 176 | w->set = true; 177 | } 178 | } 179 | } 180 | #ifdef __debug
181 | std::cout << "Inputs: " << time_from(t) << " us\n"; 182 | #endif 183 | for (uint i = 0; i < level_map.size(); ++i) { 184 | #ifdef __debug 185 | t = clock_start(); 186 | #endif 187 | block *in1_and = nullptr, *in2_and = nullptr, *out_and = nullptr; 188 | bool *bool_in1_and = nullptr, *bool_in2_and = nullptr, *bool_out_and = nullptr; 189 | int num_ands = 0; 190 | for (auto it = level_map[i].begin(); it != level_map[i].end(); ++it) { 191 | Wire* w = circuit[*it]; 192 | switch (w->type) { 193 | case AND: 194 | if (w->in1->set == false || w->in2->set == false) { 195 | error("AND Unset wire being used"); 196 | } 197 | in1_and = (block*)realloc(in1_and, block_num * (num_ands + 1) * sizeof(block)); 198 | in2_and = (block*)realloc(in2_and, block_num * (num_ands + 1) * sizeof(block)); 199 | bool_in1_and = (bool*)realloc(bool_in1_and, bool_num * (num_ands + 1) * sizeof(bool)); 200 | bool_in2_and = (bool*)realloc(bool_in2_and, bool_num * (num_ands + 1) * sizeof(bool)); 201 | memcpy(in1_and + num_ands * block_num, w->in1->value, block_num * sizeof(block)); 202 | memcpy(in2_and + num_ands * block_num, w->in2->value, block_num * sizeof(block)); 203 | memcpy(bool_in1_and + num_ands * bool_num, w->in1->rem_value, bool_num * sizeof(bool)); 204 | memcpy(bool_in2_and + num_ands * bool_num, w->in2->rem_value, bool_num * sizeof(bool)); 205 | num_ands++; 206 | if ((w->set == true) && (w->num_required > 0)) { 207 | std::cout << *it << " " << w->num_required << std::endl; 208 | error("Set wire is being set again"); 209 | } 210 | break; 211 | default: 212 | break; 213 | } 214 | } 215 | #ifdef __debug 216 | std::cout << "level " << i << ": before AND " << time_from(t) << "\t"; 217 | t = clock_start(); 218 | #endif 219 | if (num_ands != 0) { 220 | out_and = (block*)malloc(block_num * num_ands * sizeof(block)); 221 | bool_out_and = (bool*)malloc(bool_num * num_ands * sizeof(bool)); 222 | simd_circ->and_gate(bool_out_and, bool_in1_and, bool_in2_and, num_ands * bool_num, out_and, in1_and, 
223 | in2_and, num_ands * block_num); 224 | } 225 | #ifdef __debug 226 | std::cout << "AND: " << time_from(t) << "\t"; 227 | t = clock_start(); 228 | #endif 229 | int num_and_rev = 0; 230 | for (auto it = level_map[i].begin(); it != level_map[i].end(); ++it) { 231 | Wire* w = circuit[*it]; 232 | switch (w->type) { 233 | case AND: 234 | w->initialise_value(block_num, bool_num); 235 | w->set_value(out_and + (num_and_rev * block_num), block_num, 236 | bool_out_and + (num_and_rev * bool_num), bool_num); 237 | num_and_rev++; 238 | w->in1->reset_value(); 239 | w->in2->reset_value(); 240 | break; 241 | case XOR: 242 | w->initialise_value(block_num, bool_num); 243 | if (w->in1->set == false || w->in2->set == false) 244 | error("XOR Unset wire being used"); 245 | 246 | if ((w->set == true) && (w->num_required > 0)) { 247 | std::cout << *it << " " << w->num_required << std::endl; 248 | error("Set wire is being set again"); 249 | } 250 | simd_circ->xor_gate((w->value), (w->in1->value), (w->in2->value), block_num); 251 | simd_circ->xor_gate(w->rem_value, w->in1->rem_value, w->in2->rem_value, bool_num); 252 | w->set = true; 253 | w->in1->reset_value(); 254 | w->in2->reset_value(); 255 | break; 256 | case INV: 257 | w->initialise_value(block_num, bool_num); 258 | if (w->in1->set == false) 259 | error("INV Unset wire being used"); 260 | 261 | if ((w->set == true) && (w->num_required > 0)) { 262 | std::cout << *it << " " << w->num_required << std::endl; 263 | error("Set wire is being set again"); 264 | } 265 | simd_circ->not_gate((w->value), (w->in1->value), block_num); 266 | simd_circ->not_gate(w->rem_value, w->in1->rem_value, bool_num); 267 | w->set = true; 268 | w->in1->reset_value(); 269 | break; 270 | default: 271 | break; 272 | } 273 | } 274 | #ifdef __debug 275 | std::cout << "after AND: " << time_from(t) << "\n"; 276 | #endif 277 | } 278 | #ifdef __debug 279 | t = clock_start(); 280 | #endif 281 | for (uint i = 0; i < n3; ++i) { 282 | Wire* w = circuit[num_wires - n3 + i]; 
283 | /* unpack this output wire: full 128-instance blocks first, then the bool_num leftover instances from rem_value; out is instance-major with stride n3 */ for (uint circ_index = 0; circ_index < block_num; ++circ_index) { 284 | bool dummy[128] = {0}; 285 | block_to_bool(dummy, w->value[circ_index]); 286 | for (uint j = 0; j < 128; ++j) { 287 | if ((circ_index * 128 + j) < num) { 288 | out[i + (circ_index * 128 + j) * n3] = dummy[j]; 289 | } 290 | } 291 | } 292 | for (uint circ_index = 0; circ_index < bool_num; ++circ_index) 293 | out[i + (circ_index + block_num * 128) * n3] = w->rem_value[circ_index]; 294 | } 295 | #ifdef __debug 296 | std::cout << "output:" << time_from(t) << " us\n"; 297 | #endif 298 | } 299 | }; 300 | } // namespace emp --------------------------------------------------------------------------------