├── .github └── workflows │ ├── arm.yml │ └── x86.yml ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── emp-sh2pc-config.cmake ├── emp-sh2pc ├── emp-sh2pc.h ├── semihonest-mult.h ├── semihonest.h ├── sh_eva.h ├── sh_gen.h └── sh_party.h ├── run ├── test ├── CMakeLists.txt ├── HEBasedComputation │ ├── CMakeLists.txt │ ├── convlayer.cpp │ ├── convlayer.h │ ├── fclayer.cpp │ ├── fclayer.h │ ├── utils.cpp │ └── utils.h ├── LinearLayer │ ├── CMakeLists.txt │ ├── conv-field.cpp │ ├── conv-field.h │ ├── conv-new.cpp │ ├── conv-new.h │ ├── conv-protocol.cpp │ ├── conv-protocol.h │ ├── defines-HE.h │ ├── elemwise-prod-field.cpp │ ├── elemwise-prod-field.h │ ├── fc-field.cpp │ ├── fc-field.h │ ├── fc-protocol.cpp │ ├── fc-protocol.h │ ├── utils-HE.cpp │ └── utils-HE.h ├── msi_average.cpp ├── msi_convlayer.cpp ├── msi_linearlayer.cpp ├── msi_microbenchmark.cpp ├── msi_relu.cpp ├── msi_relu_final.cpp ├── msi_relu_integrate.cpp └── msi_relu_preprocess.cpp └── utils ├── utils.cpp └── utils.h /.github/workflows/arm.yml: -------------------------------------------------------------------------------- 1 | name: arm 2 | on: [push] 3 | 4 | jobs: 5 | build_arm: 6 | strategy: 7 | matrix: 8 | os: [ubuntu-latest] 9 | build_type: [Debug, Release] 10 | runs-on: [self-hosted] 11 | timeout-minutes: 30 12 | env: 13 | BUILD_TYPE: ${{matrix.build_type}} 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: install dependency 17 | run: | 18 | wget https://raw.githubusercontent.com/emp-toolkit/emp-readme/master/scripts/install.py 19 | python3 install.py -install -tool -ot 20 | cd emp-tool && cmake -DENABLE_FLOAT=On . 
&& sudo make install 21 | - name: Create Build Environment 22 | run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DENABLE_FLOAT=On -DUSE_RANDOM_DEVICE=On && make 23 | - name: Test 24 | shell: bash 25 | run: | 26 | make test 27 | -------------------------------------------------------------------------------- /.github/workflows/x86.yml: -------------------------------------------------------------------------------- 1 | name: x86 2 | on: [push] 3 | 4 | jobs: 5 | build_x86: 6 | strategy: 7 | matrix: 8 | os: [ubuntu-latest, macos-latest] 9 | build_type: [Debug, Release] 10 | runs-on: ${{ matrix.os }} 11 | timeout-minutes: 30 12 | env: 13 | BUILD_TYPE: ${{matrix.build_type}} 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: install dependency 17 | run: | 18 | wget https://raw.githubusercontent.com/emp-toolkit/emp-readme/master/scripts/install.py 19 | python3 install.py -install -tool -ot # python3 as in arm.yml; a bare python is not guaranteed on every runner image 20 | cd emp-tool && cmake -DENABLE_FLOAT=On . && sudo make install 21 | - name: Create Build Environment 22 | run: cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DENABLE_FLOAT=On -DUSE_RANDOM_DEVICE=On && make 23 | - name: Test 24 | shell: bash 25 | run: | 26 | make test 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Testing/ 2 | *.DS_Store 3 | *.txt.cpp 4 | *CTestTestfile.cmake 5 | *perf.data* 6 | CMakeCache.txt 7 | CMakeFiles/ 8 | Makefile 9 | bin/ 10 | cmake_install.cmake 11 | install_manifest.txt 12 | # Compiled Object files 13 | *.slo 14 | *.lo 15 | *.o 16 | *.obj 17 | 18 | # Precompiled Headers 19 | *.gch 20 | *.pch 21 | 22 | # Compiled Dynamic libraries 23 | *.so 24 | *.dylib 25 | *.dll 26 | 27 | # Fortran module files 28 | *.mod 29 | 30 | # Compiled Static libraries 31 | *.lai 32 | *.la 33 | *.a 34 | *.lib 35 | 36 | # Executables 37 | *.exe 38 | *.out 39 | *.app 40 | 41 | # Emacs 42 | *~ 43 | 44 | # CMake generated
.h files 45 | emp-ot/latticeInclude.h 46 | 47 | # Eigen source 48 | eigen/ 49 | 50 | # VOLE COT data file 51 | data/ 52 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "extern/SEAL"] 2 | path = extern/SEAL 3 | url = https://github.com/microsoft/SEAL.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.5) # 2.8.11 is rejected by CMake >= 4.0, which removed compatibility with minimum versions below 3.5 2 | project (emp-sh2pc) 3 | set(NAME "emp-sh2pc") 4 | 5 | find_path(CMAKE_FOLDER NAMES cmake/emp-tool-config.cmake) 6 | include(${CMAKE_FOLDER}/cmake/common.cmake) 7 | include(${CMAKE_FOLDER}/cmake/enable_rdseed.cmake) 8 | include(${CMAKE_FOLDER}/cmake/enable_float.cmake) 9 | 10 | set (CMAKE_PREFIX_PATH "../emp-ot/") 11 | find_package(emp-ot REQUIRED) 12 | 13 | include_directories(${EMP-OT_INCLUDE_DIRS}) 14 | 15 | # Installation 16 | install(FILES cmake/emp-sh2pc-config.cmake DESTINATION cmake/) 17 | install(DIRECTORY emp-sh2pc DESTINATION include/) 18 | 19 | ENABLE_TESTING() 20 | ADD_SUBDIRECTORY(test) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Xiao Wang (wangxiao@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 |
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | Enquiries about further applications and development opportunities are welcome. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIMC 2 | This repo contains the implementation of our scheme "SIMC: ML Inference Secure Against Malicious Clients at Semi-Honest Cost". The repository is built on \[[emp-toolkit/emp-sh2pc](https://github.com/emp-toolkit/emp-sh2pc)\]. 3 | 4 | Disclaimer: This repository is a proof-of-concept prototype. 5 | 6 | TODO: Update the licensing information. 7 | 8 | # Installation 9 | 1. Create the parent directory `msi-code`: 10 | ``` 11 | mkdir msi-code && cd msi-code 12 | ``` 13 | 2. To install Eigen3, run: 14 | ``` 15 | sudo apt-get update -y 16 | sudo apt-get install -y libeigen3-dev 17 | ``` 18 | 3. Follow the installation steps of \[[emp-toolkit/emp-sh2pc](https://github.com/emp-toolkit/emp-sh2pc)\]. 19 | 4. Clone this repo into the parent directory `msi-code`. 20 | 5. Install SEAL 3.6.4: 21 | a. Clone the \[[SEAL](https://github.com/microsoft/SEAL.git)\] repo into the parent directory `msi-code`. 22 | b. Execute: 23 | ``` 24 | cd SEAL 25 | git checkout 3.6.4 26 | mkdir build && cd build 27 | cmake ..
28 | make -j 29 | sudo make install 30 | ``` 31 | 32 | # Compilation 33 | 1. In `msi-code`, go to `emp-tool` and do `git checkout df363bf30b56c48a12c352845efa3a4d8f75b388`. 34 | 2. Next, go to `emp-ot` in `msi-code` and do `git checkout 3b21d6314cb1e7d8dbb9bb1f1ed80261738e4f4c`. 35 | 3. For multi-threading support, go to `emp-tool` and run the following: 36 | ``` 37 | cmake . -DTHREADING=ON 38 | make -j 39 | sudo make install 40 | ``` 41 | 4. Do the same for the emp-ot repository. 42 | 5. Finally, do the same in our (simc) repository. 43 | 44 | ## Run Neuralnet Benchmarks 45 | Run the following test files from path `msi-code/simc`: 46 | 1. Fully-connected Layer: In one terminal run `bin/test_msi_linearlayer 1 0.0.0.0 44 ` and in the other terminal run `bin/test_msi_linearlayer 2 44 `. 47 | 48 | 2. Convolution Layer: In one terminal run `bin/test_msi_convlayer 1 0.0.0.0 44 ` and in the other terminal run `bin/test_msi_convlayer 2 44 `. 49 | 50 | 3. Non-Linear Layer (ReLU): In one terminal run `bin/test_msi_relu_final 1 0.0.0.0 44 0 0 ` and in the other terminal run `bin/test_msi_relu_final 2 44 0 0 `. 51 | 52 | 4. Average Pool Layer: In one terminal run `bin/test_msi_average 1 0.0.0.0 44 ` and in the other terminal run `bin/test_msi_average 2 44 `. 53 | 54 | Here, the first parameter (1 or 2) denotes the ID of the participating party; the IP address parameter is the address of the server machine, and the dataset parameter is set to 1 for MNIST and 2 for CIFAR-10 (see the examples below).
55 | 56 | Examples: 57 | ``` 58 | Fully connected Layer: 59 | Terminal 1: bin/test_msi_linearlayer 1 0.0.0.0 31000 44 1 60 | Terminal 2: bin/test_msi_linearlayer 2 31000 44 1 61 | 62 | Convolution Layer: 63 | Terminal 1: bin/test_msi_convlayer 1 0.0.0.0 31000 44 1 64 | Terminal 2: bin/test_msi_convlayer 2 31000 44 1 65 | 66 | Non-Linear Layer (ReLU): 67 | Terminal 1: bin/test_msi_relu_final 1 0.0.0.0 31000 44 1 0 0 8 68 | Terminal 2: bin/test_msi_relu_final 2 31000 44 1 0 0 8 69 | 70 | Average Pool Layer: 71 | Terminal 1: bin/test_msi_average 1 0.0.0.0 31000 44 1 72 | Terminal 2: bin/test_msi_average 2 31000 44 1 73 | ``` 74 | 75 | ## Run Neuralnet Micro-benchmarks 76 | ``` 77 | Terminal 1: bin/test_msi_microbenchmark 1 0.0.0.0 31000 44 <#threads> 78 | Terminal 2: bin/test_msi_microbenchmark 2 31000 44 <#threads> 79 | ``` 80 | Input Parameters: 81 | 1. : IP Address of Server. 82 | 2. : 0 - ReLU6, 1 - ReLU. 83 | 3. : Number of ReLUs 84 | 4. <#threads>: Number of threads 85 | 86 | ``` 87 | if <=2, set <#threads>=1, 88 | else if <=4, set <#threads>=2, 89 | else if <=16, set <#threads>=4, 90 | else if >16, set <#threads>=8. 91 | ``` 92 | Note: For different system-configuration, different number of threads may provide best performance for given number of ReLUs. 93 | 94 | Example: 95 | ``` 96 | Terminal 1: bin/test_msi_microbenchmark 1 0.0.0.0 31000 44 0 16384 8 97 | Terminal 2: bin/test_msi_microbenchmark 2 31000 44 0 16384 8 98 | ``` 99 | 100 | ## Contact 101 | For any queries, kindly contact akashshah08@outlook.com. 
102 | -------------------------------------------------------------------------------- /cmake/emp-sh2pc-config.cmake: -------------------------------------------------------------------------------- 1 | find_package(emp-ot) 2 | 3 | find_path(EMP-SH2PC_INCLUDE_DIR emp-sh2pc/emp-sh2pc.h) 4 | 5 | include(FindPackageHandleStandardArgs) 6 | find_package_handle_standard_args(emp-sh2pc DEFAULT_MSG EMP-SH2PC_INCLUDE_DIR) 7 | 8 | if(EMP-SH2PC_FOUND) 9 | set(EMP-SH2PC_INCLUDE_DIRS ${EMP-SH2PC_INCLUDE_DIR} ${EMP-OT_INCLUDE_DIRS}) # find_path() already yields the directory that contains emp-sh2pc/, so no /include/ suffix 10 | set(EMP-SH2PC_LIBRARIES ${EMP-TOOL_LIBRARIES}) 11 | endif() 12 | -------------------------------------------------------------------------------- /emp-sh2pc/emp-sh2pc.h: -------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/semihonest.h" 2 | #include "emp-sh2pc/semihonest-mult.h" 3 | #include "emp-sh2pc/sh_party.h" 4 | #include "emp-sh2pc/sh_gen.h" 5 | #include "emp-sh2pc/sh_eva.h" 6 | -------------------------------------------------------------------------------- /emp-sh2pc/semihonest-mult.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP_SEMIHONEST_MULT_H__ 2 | #define EMP_SEMIHONEST_MULT_H__ 3 | #include "emp-sh2pc/sh_gen.h" 4 | #include "emp-sh2pc/sh_eva.h" 5 | 6 | namespace emp { 7 | 8 | template 9 | class SHGen{ // ALICE-side bundle: owns a HalfGateGen and the SemiHonestGen built on it for one IO channel 10 | public: 11 | IO *io; 12 | HalfGateGen *t; 13 | CircuitExecution *lcirc_exec; 14 | ProtocolExecution *lprot_exec; 15 | block delta_used; // copy of t->delta taken in the constructor 16 | // NOTE(review): the executors are owned through raw pointers, so copying an SHGen would double-delete them. 17 | SHGen(IO *ioObj, int batch_size = 1024*16) { // batch_size is currently unused 18 | io = ioObj; 19 | t = new HalfGateGen(io); 20 | lcirc_exec = t; 21 | lprot_exec = new SemiHonestGen(io, t); 22 | delta_used = t->delta; 23 | } 24 | 25 | void setup_execution_env(){ // install this instance's executors as CircuitExecution::circ_exec / ProtocolExecution::prot_exec 26 | CircuitExecution::circ_exec = lcirc_exec; 27 | ProtocolExecution::prot_exec = lprot_exec; 28 | } 29 | 30 | ~SHGen() { 31 | delete lcirc_exec; 32 | delete lprot_exec; 33 | } 34 | 35 | }; 36 | 37 | template 38 | class SHEval{ // BOB-side counterpart of SHGen: owns a HalfGateEva and its SemiHonestEva 39 | public: 40 |
IO *io; // was NetIO*: must be the template's IO type because the constructor assigns ioObj (an IO*) to it 41 | HalfGateEva *t; 42 | CircuitExecution *lcirc_exec; 43 | ProtocolExecution *lprot_exec; 44 | // NOTE(review): raw owning pointers as in SHGen; do not copy an SHEval. 45 | SHEval(IO *ioObj, int batch_size = 1024*16) { // batch_size is currently unused 46 | io = ioObj; 47 | t = new HalfGateEva(io); 48 | lcirc_exec = t; 49 | lprot_exec = new SemiHonestEva(io, t); 50 | } 51 | 52 | void setup_execution_env(){ 53 | CircuitExecution::circ_exec = lcirc_exec; 54 | ProtocolExecution::prot_exec = lprot_exec; 55 | } 56 | 57 | ~SHEval() { 58 | delete lcirc_exec; 59 | delete lprot_exec; 60 | } 61 | }; 62 | } 63 | #endif 64 | -------------------------------------------------------------------------------- /emp-sh2pc/semihonest.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP_SEMIHONEST_H__ 2 | #define EMP_SEMIHONEST_H__ 3 | #include "emp-sh2pc/sh_gen.h" 4 | #include "emp-sh2pc/sh_eva.h" 5 | 6 | namespace emp { 7 | inline block delta_used; // inline (C++17): a plain definition in a header is emitted by every including TU, giving a multiple-definition link error 8 | inline block delta_blocks[8]; // per-tid deltas stored by setup_semi_honest_mult on the ALICE side 9 | template 10 | inline SemiHonestParty* setup_semi_honest(IO* io, int party, int batch_size = 1024*16) { 11 | // Installs a fresh garbler (ALICE) or evaluator executor pair as CircuitExecution::circ_exec / ProtocolExecution::prot_exec; batch_size is unused. 12 | if(party == ALICE) { 13 | HalfGateGen * t = new HalfGateGen(io); 14 | CircuitExecution::circ_exec = t; 15 | ProtocolExecution::prot_exec = new SemiHonestGen(io, t); 16 | delta_used = t->delta; 17 | } else { 18 | HalfGateEva * t = new HalfGateEva(io); 19 | CircuitExecution::circ_exec = t; 20 | ProtocolExecution::prot_exec = new SemiHonestEva(io, t); 21 | } 22 | return (SemiHonestParty*)ProtocolExecution::prot_exec; 23 | } 24 | 25 | template 26 | inline SemiHonestParty* setup_semi_honest_mult(IO* io, int party, int tid, int batch_size = 1024*16) { 27 | // As setup_semi_honest, but ALICE records delta in delta_blocks[tid]; tid must be in [0, 8). 28 | if(party == ALICE) { 29 | HalfGateGen * t = new HalfGateGen(io); 30 | CircuitExecution::circ_exec = t; 31 | ProtocolExecution::prot_exec = new SemiHonestGen(io, t); 32 | delta_blocks[tid] = t->delta; 33 | } else { 34 | HalfGateEva * t = new HalfGateEva(io); 35 | CircuitExecution::circ_exec = t; 36 | ProtocolExecution::prot_exec = new SemiHonestEva(io, t); 37 | } 38 | return
(SemiHonestParty*)ProtocolExecution::prot_exec; 39 | } 40 | // Deletes the executors currently installed in CircuitExecution::circ_exec / ProtocolExecution::prot_exec (normally those created by setup_semi_honest or setup_semi_honest_mult). 41 | inline void finalize_semi_honest() { 42 | delete CircuitExecution::circ_exec; 43 | delete ProtocolExecution::prot_exec; 44 | } 45 | 46 | } 47 | #endif 48 | -------------------------------------------------------------------------------- /emp-sh2pc/sh_eva.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP_SEMIHONEST_EVA_H__ 2 | #define EMP_SEMIHONEST_EVA_H__ 3 | #include "emp-sh2pc/sh_party.h" 4 | 5 | namespace emp { 6 | template 7 | class SemiHonestEva: public SemiHonestParty { public: 8 | HalfGateEva * gc; 9 | PRG prg; 10 | SemiHonestEva(IO *io, HalfGateEva * gc): SemiHonestParty(io, BOB) { 11 | this->gc = gc; 12 | this->ot->setup_recv(); 13 | 14 | block seed; this->io->recv_block(&seed, 1); 15 | 16 | this->shared_prg.reseed(&seed); 17 | refill(); 18 | } 19 | // refill: draw fresh random choice bits into buff and receive the matching COT labels into buf. 20 | void refill() { 21 | prg.random_bool(this->buff, this->batch_size); 22 | this->ot->recv_cot(this->buf, this->buff, this->batch_size); 23 | this->top = 0; 24 | } 25 | // feed: ALICE's labels come from the PRG seeded with the garbler's seed; BOB's come from buffered COTs, and BOB sends (choice bit XOR input bit) so the garbler can correct its labels. 26 | void feed(block * label, int party, const bool* b, int length) { 27 | if(party == ALICE) { 28 | this->shared_prg.random_block(label, length); 29 | } else { 30 | if (length > this->batch_size) { 31 | this->ot->recv_cot(label, b, length); 32 | } else { 33 | bool * tmp = new bool[length]; 34 | if(length > this->batch_size - this->top) { 35 | memcpy(label, this->buf + this->top, (this->batch_size-this->top)*sizeof(block)); 36 | memcpy(tmp, this->buff + this->top, (this->batch_size-this->top)); 37 | int filled = this->batch_size - this->top; 38 | refill(); 39 | memcpy(label+filled, this->buf, (length - filled)*sizeof(block)); 40 | memcpy(tmp+ filled, this->buff, length - filled); 41 | this->top = length - filled; 42 | } else { 43 | memcpy(label, this->buf+this->top, length*sizeof(block)); 44 | memcpy(tmp, this->buff+this->top, length); 45 | this->top+=length; 46 | } 47 | 48 | for(int i = 0; i < length; ++i) 49 | tmp[i] = (tmp[i]
!= b[i]); 50 | this->io->send_data(tmp, length); 51 | 52 | delete[] tmp; 53 | } 54 | } 55 | } 56 | 57 | void reveal(bool * b, int party, const block * label, int length) { 58 | if (party == XOR) { 59 | for (int i = 0; i < length; ++i) 60 | b[i] = getLSB(label[i]); 61 | return; 62 | } 63 | for (int i = 0; i < length; ++i) { 64 | bool lsb = getLSB(label[i]), tmp; 65 | if (party == BOB or party == PUBLIC) { 66 | this->io->recv_data(&tmp, 1); 67 | b[i] = (tmp != lsb); 68 | } else if (party == ALICE) { 69 | this->io->send_data(&lsb, 1); 70 | b[i] = false; 71 | } 72 | } 73 | if(party == PUBLIC) 74 | this->io->send_data(b, length); 75 | } 76 | 77 | }; 78 | } 79 | 80 | #endif // EMP_SEMIHONEST_EVA_H__ 81 | -------------------------------------------------------------------------------- /emp-sh2pc/sh_gen.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP_SEMIHONEST_GEN_H__ 2 | #define EMP_SEMIHONEST_GEN_H__ 3 | #include "emp-sh2pc/sh_party.h" 4 | 5 | namespace emp { 6 | 7 | template 8 | class SemiHonestGen: public SemiHonestParty { public: 9 | HalfGateGen * gc; 10 | SemiHonestGen(IO* io, HalfGateGen* gc): SemiHonestParty(io, ALICE) { 11 | this->gc = gc; 12 | bool delta_bool[128]; 13 | block_to_bool(delta_bool, gc->delta); 14 | this->ot->setup_send(delta_bool); 15 | block seed; 16 | PRG prg; 17 | prg.random_block(&seed, 1); 18 | this->io->send_block(&seed, 1); 19 | this->shared_prg.reseed(&seed); 20 | refill(); 21 | } 22 | // refill: fill buf with a fresh batch of sender-side COT labels. 23 | void refill() { 24 | this->ot->send_cot(this->buf, this->batch_size); 25 | this->top = 0; 26 | } 27 | 28 | void feed(block * label, int party, const bool* b, int length) { 29 | if(party == ALICE) { 30 | this->shared_prg.random_block(label, length); 31 | for (int i = 0; i < length; ++i) { 32 | if(b[i]) 33 | label[i] = label[i] ^ gc->delta; 34 | } 35 | } else { 36 | if (length > this->batch_size) { 37 | this->ot->send_cot(label, length); 38 | } else { 39 | bool * tmp = new bool[length]; 40 |
if(length > this->batch_size - this->top) { 41 | memcpy(label, this->buf + this->top, (this->batch_size-this->top)*sizeof(block)); 42 | int filled = this->batch_size - this->top; 43 | refill(); 44 | memcpy(label + filled, this->buf, (length - filled)*sizeof(block)); 45 | this->top = (length - filled); 46 | } else { 47 | memcpy(label, this->buf+this->top, length*sizeof(block)); 48 | this->top+=length; 49 | } 50 | // The evaluator sent (choice bit XOR input bit) per position; XOR delta into each label where it is set. 51 | this->io->recv_data(tmp, length); 52 | for (int i = 0; i < length; ++i) 53 | if(tmp[i]) 54 | label[i] = label[i] ^ gc->delta; 55 | delete[] tmp; 56 | } 57 | } 58 | } 59 | 60 | void reveal(bool* b, int party, const block * label, int length) { 61 | if (party == XOR) { 62 | for (int i = 0; i < length; ++i) 63 | b[i] = getLSB(label[i]); 64 | return; 65 | } 66 | for (int i = 0; i < length; ++i) { 67 | bool lsb = getLSB(label[i]); 68 | if (party == BOB or party == PUBLIC) { 69 | this->io->send_data(&lsb, 1); 70 | b[i] = false; 71 | } else if(party == ALICE) { 72 | bool tmp; 73 | this->io->recv_data(&tmp, 1); 74 | b[i] = (tmp != lsb); 75 | } 76 | } 77 | if(party == PUBLIC) 78 | this->io->recv_data(b, length); 79 | } 80 | }; 81 | } 82 | #endif // EMP_SEMIHONEST_GEN_H__ 83 | -------------------------------------------------------------------------------- /emp-sh2pc/sh_party.h: -------------------------------------------------------------------------------- 1 | #ifndef EMP_SH_PARTY_H__ 2 | #define EMP_SH_PARTY_H__ 3 | #include "emp-tool/emp-tool.h" 4 | #include "emp-ot/emp-ot.h" 5 | 6 | namespace emp { 7 | 8 | template 9 | class SemiHonestParty: public ProtocolExecution { public: 10 | IO* io = nullptr; 11 | IKNP * ot = nullptr; 12 | PRG shared_prg; 13 | // buf/buff hold one batch of COT output (labels / choice bits), refilled by the derived class's refill(); top is the next unused index. 14 | block * buf = nullptr; 15 | bool * buff = nullptr; 16 | int top = 0; 17 | int batch_size = 1024*16; 18 | 19 | SemiHonestParty(IO * io, int party) : ProtocolExecution(party) { 20 | this->io = io; 21 | ot = new IKNP(io, true); 22 | buf = new block[batch_size]; 23 | buff = new bool[batch_size]; 24 | } 25 | void
set_batch_size(int size) { 26 | delete[] buf; 27 | delete[] buff; 28 | batch_size = size; 29 | buf = new block[batch_size]; 30 | buff = new bool[batch_size]; top = batch_size; // the new buffers hold no COT output yet: mark them exhausted so the next feed() refills before reading (and top can no longer exceed batch_size) 31 | } 32 | 33 | ~SemiHonestParty() { 34 | delete[] buf; 35 | delete[] buff; 36 | delete ot; 37 | } 38 | }; 39 | } 40 | #endif 41 | -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ "$1" == "-p1" ] 3 | then 4 | shift 5 | perf record $1 1 12345 & (sleep 0.1; $1 2 12345) 6 | elif [ "$1" == "-p2" ] 7 | then 8 | shift 9 | (sleep 0.1; $1 1 12345) & (perf record $1 2 12345) 10 | 11 | elif [ "$1" == "-m1" ] 12 | then 13 | shift 14 | valgrind --leak-check=full $1 1 12345 & $1 2 12345 15 | elif [ "$1" == "-m2" ] 16 | then 17 | shift 18 | $1 1 12345 & valgrind --leak-check=full $1 2 12345 19 | elif [ "$1" == "-t1" ] 20 | then 21 | shift 22 | time $1 1 12345 & $1 2 12345 23 | elif [ "$1" == "-t2" ] 24 | then 25 | shift 26 | $1 1 12345 & time $1 2 12345 27 | 28 | else 29 | (sleep 0.05; $1 1 12345) & $1 2 12345 30 | fi 31 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(SEAL 3.6.4 REQUIRED) 2 | include_directories(${SEAL_INCLUDE_DIRS}) 3 | 4 | ADD_SUBDIRECTORY(LinearLayer) 5 | 6 | #Testing macro 7 | macro (add_test_executable_with_lib _name libs) 8 | add_executable(test_${_name} "${_name}.cpp") 9 | target_link_libraries(test_${_name} ${EMP-OT_LIBRARIES} SEAL::seal LinearLayer) 10 | endmacro() 11 | 12 | macro (add_test_case _name) 13 | add_test_executable_with_lib(${_name} "") 14 | add_test(NAME ${_name} COMMAND "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/test_${_name}" WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/") 15 | endmacro() 16 | 17 | macro (add_test_case_with_run _name) 18 | add_test_executable_with_lib(${_name} "") 19 | add_test(NAME
${_name} COMMAND "./run" "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/test_${_name}" WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/") 20 | endmacro() 21 | 22 | # Test cases 23 | add_test_case_with_run(msi_relu) 24 | add_test_case_with_run(msi_relu_final) 25 | add_test_case_with_run(msi_linearlayer) 26 | add_test_case_with_run(msi_convlayer) 27 | add_test_case_with_run(msi_microbenchmark) 28 | add_test_case_with_run(msi_relu_preprocess) 29 | add_test_case_with_run(msi_relu_integrate) 30 | add_test_case_with_run(msi_average) 31 | -------------------------------------------------------------------------------- /test/HEBasedComputation/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(HE-CodeBase 2 | utils.cpp 3 | fclayer.cpp 4 | convlayer.cpp 5 | ) 6 | 7 | target_link_libraries( HE-CodeBase 8 | ${EMP-OT_LIBRARIES} SEAL::seal 9 | ) 10 | -------------------------------------------------------------------------------- /test/HEBasedComputation/convlayer.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/HEBasedComputation/convlayer.cpp -------------------------------------------------------------------------------- /test/HEBasedComputation/convlayer.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/HEBasedComputation/convlayer.h -------------------------------------------------------------------------------- /test/HEBasedComputation/fclayer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Original Author: ryanleh 3 | Modified Work Copyright (c) 2020 Microsoft Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | Modified by Deevashwer Rathee 24 | */ 25 | 26 | #include "fclayer.h" 27 | 28 | using namespace std; 29 | using namespace seal; 30 | 31 | /* Helper function for rounding to the next power of 2 32 | * Credit: https://stackoverflow.com/questions/466204/rounding-up-to-next-power-of-2 */ 33 | inline int next_pow2(int val) { 34 | return pow(2, ceil(log(val)/log(2))); 35 | } 36 | 37 | Ciphertext preprocess_vec(const uint64_t *input, const FCMetadata &data, 38 | Encryptor &encryptor, BatchEncoder &batch_encoder) { 39 | // Create copies of the input vector to fill the ciphertext appropiately. 
40 | // Pack using powers of two for easy rotations later 41 | vector pod_matrix(data.slot_count, 0ULL); 42 | uint64_t size_pow2 = next_pow2(data.image_size); 43 | for (int col = 0; col < data.image_size; col++) { 44 | for (int idx = 0; idx < data.pack_num; idx++) { 45 | pod_matrix[col + size_pow2 * idx] = input[col]; 46 | } 47 | } 48 | 49 | Ciphertext ciphertext; 50 | Plaintext tmp; 51 | batch_encoder.encode(pod_matrix, tmp); 52 | encryptor.encrypt(tmp, ciphertext); 53 | return ciphertext; 54 | } 55 | -------------------------------------------------------------------------------- /test/HEBasedComputation/fclayer.h: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | struct FCMetadata { 4 | int slot_count; 5 | int32_t pack_num; 6 | int32_t inp_ct; 7 | // Filter is a matrix 8 | int32_t filter_h; 9 | int32_t filter_w; 10 | int32_t filter_size; 11 | // Image is a vector 12 | int32_t image_size; 13 | }; 14 | 15 | seal::Ciphertext preprocess_vec(const uint64_t *input, const FCMetadata &data, 16 | seal::Encryptor &encryptor, 17 | seal::BatchEncoder &batch_encoder); 18 | /* 19 | std::vector 20 | preprocess_matrix(const uint64_t *const *matrix, const FCMetadata &data, 21 | seal::BatchEncoder &batch_encoder); 22 | 23 | seal::Ciphertext fc_preprocess_noise(const uint64_t *secret_share, 24 | const FCMetadata &data, 25 | seal::Encryptor &encryptor, 26 | seal::BatchEncoder &batch_encoder); 27 | 28 | seal::Ciphertext fc_online(seal::Ciphertext &ct, 29 | std::vector &enc_mat, 30 | const FCMetadata &data, seal::Evaluator &evaluator, 31 | seal::GaloisKeys &gal_keys, seal::Ciphertext &zero, 32 | seal::Ciphertext &enc_noise); 33 | 34 | uint64_t *fc_postprocess(seal::Ciphertext &result, const FCMetadata &data, 35 | seal::BatchEncoder &batch_encoder, 36 | seal::Decryptor &decryptor); 37 | */ 38 | /* 39 | class FCField { 40 | public: 41 | int party; 42 | sci::NetIO *io; 43 | FCMetadata data; 44 | std::shared_ptr context; 45 | 
seal::Encryptor *encryptor; 46 | seal::Decryptor *decryptor; 47 | seal::Evaluator *evaluator; 48 | seal::BatchEncoder *encoder; 49 | seal::GaloisKeys *gal_keys; 50 | seal::Ciphertext *zero; 51 | size_t slot_count; 52 | 53 | FCField(int party, sci::NetIO *io); 54 | 55 | ~FCField(); 56 | 57 | void configure(); 58 | 59 | std::vector ideal_functionality(uint64_t *vec, uint64_t **matrix); 60 | 61 | void matrix_multiplication(int32_t num_rows, int32_t common_dim, 62 | int32_t num_cols, 63 | std::vector> &A, 64 | std::vector> &B, 65 | std::vector> &C, 66 | bool verify_output = false, bool verbose = false); 67 | 68 | void verify(std::vector *vec, std::vector *matrix, 69 | std::vector> &C); 70 | }; 71 | #endif 72 | */ 73 | -------------------------------------------------------------------------------- /test/HEBasedComputation/utils.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "utils.h" 3 | 4 | 5 | void key_generator(int party, NetIO* io, SEALContext context, Encryptor *&encryptor, Decryptor *&decryptor, Evaluator *&evaluator, BatchEncoder *&encoder, GaloisKeys *&galois_keys, Ciphertext *&zero) { 6 | encoder = new BatchEncoder(context); 7 | evaluator = new Evaluator(context); 8 | 9 | if(party == BOB) { 10 | //Client generates HE Keys 11 | KeyGenerator keygen(context); 12 | SecretKey secret_key = keygen.secret_key(); 13 | PublicKey public_key; 14 | keygen.create_public_key(public_key); 15 | 16 | GaloisKeys galois_keys; 17 | keygen.create_galois_keys(galois_keys); 18 | 19 | stringstream os; 20 | public_key.save(os); 21 | uint64_t pk_size = os.tellp(); 22 | galois_keys.save(os); 23 | uint64_t gk_size = (uint64_t)os.tellp() - pk_size; 24 | 25 | //Send keys to server 26 | string keys_ser = os.str(); 27 | io->send_data(&pk_size, sizeof(uint64_t)); 28 | io->send_data(&gk_size, sizeof(uint64_t)); 29 | io->send_data(keys_ser.c_str(), pk_size + gk_size); 30 | 31 | encryptor = new Encryptor(context, public_key); 32 | decryptor 
= new Decryptor(context, secret_key); 33 | } else { 34 | //Receive keys from client 35 | uint64_t pk_size; 36 | uint64_t gk_size; 37 | io->recv_data(&pk_size, sizeof(uint64_t)); 38 | io->recv_data(&gk_size, sizeof(uint64_t)); 39 | char *key_share = new char[pk_size + gk_size]; 40 | io->recv_data(key_share, pk_size + gk_size); 41 | //Load keys from received data 42 | stringstream is; 43 | PublicKey public_key; 44 | is.write(key_share, pk_size); 45 | public_key.load(context, is); 46 | galois_keys = new GaloisKeys(); 47 | is.write(key_share + pk_size, gk_size); 48 | galois_keys->load(context, is); 49 | delete[] key_share; 50 | 51 | encryptor = new Encryptor(context, public_key); 52 | vector pod_matrix(POLY_MOD_DEGREE, 0ULL); 53 | 54 | Plaintext tmp; 55 | encoder->encode(pod_matrix, tmp); 56 | zero = new Ciphertext; 57 | encryptor->encrypt(tmp, *zero); 58 | } 59 | } 60 | 61 | void free_keys(int party, Encryptor *&encryptor, Decryptor *&decryptor, Evaluator *&evaluator, BatchEncoder *&encoder, GaloisKeys *&gal_keys, Ciphertext *&zero) { 62 | delete encoder; 63 | delete evaluator; 64 | delete encryptor; 65 | if (party == BOB) { 66 | delete decryptor; 67 | } 68 | else // party ==ALICE 69 | { 70 | delete gal_keys; 71 | delete zero; 72 | } 73 | } 74 | 75 | void send_ciphertext(NetIO *io, Ciphertext &ct) { 76 | stringstream os; 77 | uint64_t ct_size; 78 | ct.save(os); 79 | ct_size = os.tellp(); 80 | string ct_ser = os.str(); 81 | io->send_data(&ct_size, sizeof(uint64_t)); 82 | io->send_data(ct_ser.c_str(), ct_ser.size()); 83 | } 84 | 85 | void recv_ciphertext(NetIO *io, SEALContext context, Ciphertext &ct) { 86 | stringstream is; 87 | uint64_t ct_size; 88 | io->recv_data(&ct_size, sizeof(uint64_t)); 89 | char *c_enc_result = new char[ct_size]; 90 | io->recv_data(c_enc_result, ct_size); 91 | is.write(c_enc_result, ct_size); 92 | ct.unsafe_load(context, is); 93 | delete[] c_enc_result; 94 | } 95 | -------------------------------------------------------------------------------- 
/test/HEBasedComputation/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "emp-sh2pc/emp-sh2pc.h" 3 | #include "seal/seal.h" 4 | /* NOTE(review): these using-directives leak into every file that includes this header; they are kept because the declarations below use the unqualified names. */ 5 | using namespace std; 6 | using namespace seal; 7 | using namespace emp; 8 | 9 | //44 bit prime 10 | const uint64_t PLAINTEXT_MODULUS = 17592060215297; 11 | const uint64_t POLY_MOD_DEGREE = 8192; 12 | /* NOTE(review): key_generator produces relin_keys but free_keys has no matching parameter; confirm who releases it. */ 13 | void key_generator(int party, NetIO* io, SEALContext context, Encryptor *&encryptor, Decryptor *&decryptor, Evaluator *&evaluator, BatchEncoder *&encoder, GaloisKeys *&galois_keys, RelinKeys *& relin_keys, Ciphertext *&zero); 14 | void free_keys(int party, seal::Encryptor *&encryptor, seal::Decryptor *&decryptor, seal::Evaluator *&evaluator, seal::BatchEncoder *&encoder, seal::GaloisKeys *&gal_keys, seal::Ciphertext *&zero); 15 | /* Ciphertext transfer over io: a uint64_t byte count followed by the serialized ciphertext. */ 16 | void send_ciphertext(NetIO *io, Ciphertext &ct); 17 | void recv_ciphertext(NetIO *io, SEALContext context, Ciphertext &ct); 18 | -------------------------------------------------------------------------------- /test/LinearLayer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenMP REQUIRED) 2 | find_package(Eigen3 REQUIRED) 3 | 4 | add_library(LinearLayer 5 | #conv-field.cpp 6 | fc-field.cpp 7 | conv-new.cpp 8 | #elemwise-prod-field.cpp 9 | utils-HE.cpp 10 | ) 11 | 12 | target_link_libraries(LinearLayer 13 | PUBLIC 14 | ${EMP-OT_LIBRARIES} 15 | SEAL::seal 16 | OpenMP::OpenMP_CXX 17 | Eigen3::Eigen 18 | ) 19 | -------------------------------------------------------------------------------- /test/LinearLayer/conv-field.h: -------------------------------------------------------------------------------- 1 | /* 2 | Original Author: ryanleh 3 | Modified Work Copyright (c) 2020 Microsoft Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | Modified by Deevashwer Rathee 24 | */ 25 | /* NOTE(review): conv-new.h uses this same CONV_FIELD_H__ guard, so whichever of the two headers is included first hides the other. */ 26 | #ifndef CONV_FIELD_H__ 27 | #define CONV_FIELD_H__ 28 | 29 | #include "utils-HE.h" 30 | #include 31 | 32 | // This is to keep compatibility for im2col implementations 33 | typedef Eigen::Matrix 34 | Channel; 35 | typedef std::vector Image; 36 | typedef std::vector Filters; 37 | /* Packing and geometry parameters for one convolution; see the per-field comments. */ 38 | struct ConvMetadata { 39 | int slot_count; 40 | // Number of plaintext slots in a half ciphertext 41 | // (since ciphertexts are a two column matrix) 42 | int32_t pack_num; 43 | // Number of Channels that can fit in a half ciphertext 44 | int32_t chans_per_half; 45 | // Number of input ciphertexts for convolution 46 | int32_t inp_ct; 47 | // Number of output ciphertexts 48 | int32_t out_ct; 49 | // Image and Filters metadata 50 | int32_t image_h; 51 | int32_t image_w; 52 | size_t image_size; 53 | int32_t inp_chans; 54 | int32_t filter_h; 55 | int32_t filter_w; 56 | int32_t filter_size; 57 | int32_t out_chans; 58 | // How many total ciphertext halves the input and output take up 59 | int32_t inp_halves; 60 | int32_t out_halves; 61 | /*
How many Channels are in the last output or input half */ 62 | int32_t out_in_last; 63 | int32_t inp_in_last; 64 | // The modulo used when deciding which output channels to pack into a mask 65 | int32_t out_mod; 66 | // How many permutations of ciphertexts are needed to generate all 67 | // intermediate rotation sets 68 | int32_t half_perms; 69 | bool last_repeats; 70 | int32_t repeat_chans; 71 | /* The number of rotations for each ciphertext half */ 72 | int32_t half_rots; 73 | int32_t last_rots; 74 | // Total number of convolutions needed to generate all 75 | // intermediate rotations sets 76 | int32_t convs; 77 | int32_t stride_h; 78 | int32_t stride_w; 79 | int32_t output_h; 80 | int32_t output_w; 81 | int32_t pad_t; 82 | int32_t pad_b; 83 | int32_t pad_r; 84 | int32_t pad_l; 85 | bool pad_valid; 86 | }; 87 | 88 | /* Use casting to do two conditionals instead of one - check if a > 0 and a < b 89 | */ 90 | inline bool condition_check(int a, int b) { 91 | return static_cast(a) < static_cast(b); 92 | } 93 | /* NOTE(review): the cast target is elided in this listing; if it is unsigned (as the comment above implies), a == 0 also passes, so the check is 0 <= a < b rather than a > 0. */ 94 | Image pad_image(ConvMetadata data, Image &image); 95 | 96 | void i2c(Image &image, Channel &column, const int filter_h, const int filter_w, 97 | const int stride_h, const int stride_w, const int output_h, 98 | const int output_w); 99 | 100 | std::vector 101 | HE_preprocess_noise(const uint64_t *const *secret_share, 102 | const ConvMetadata &data, seal::Encryptor &encryptor, 103 | seal::BatchEncoder &batch_encoder, 104 | seal::Evaluator &evaluator); 105 | std::vector HE_preprocess_noise_plain(const uint64_t* const* secret_share, const ConvMetadata &data, 106 | seal::BatchEncoder &batch_encoder); 107 | 108 | std::vector> preprocess_image_OP(Image &image, 109 | ConvMetadata data); 110 | template std::vector filter_rotations_dash(T &input, const ConvMetadata &data, seal::Evaluator *evaluator = NULL, 111 | seal::GaloisKeys *gal_keys = NULL); 112 | 113 | std::vector> 114 | filter_rotations(std::vector &input, const ConvMetadata &data, 115 | seal::Evaluator
*evaluator = NULL, 116 | seal::GaloisKeys *gal_keys = NULL); 117 | 118 | std::vector pt_rotate(int slot_count, int rotation, std::vector &vec); 119 | std::vector HE_encrypt(std::vector> &pt, 120 | const ConvMetadata &data, 121 | seal::Encryptor &encryptor, 122 | seal::BatchEncoder &batch_encoder); 123 | 124 | std::vector> HE_encrypt_rotations(std::vector>> &pt, 125 | const ConvMetadata &data, 126 | seal::Encryptor &encryptor, 127 | seal::BatchEncoder &batch_encoder); 128 | 129 | std::vector>> 130 | HE_preprocess_filters_OP(Filters &filters, const ConvMetadata &data, 131 | seal::BatchEncoder &batch_encoder); 132 | std::vector>> HE_preprocess_filters(const uint64_t* const* const* filters, 133 | const ConvMetadata &data, seal::BatchEncoder &batch_encoder); 134 | std::vector 135 | HE_conv_OP(std::vector>> &masks, 136 | std::vector> &rotations, 137 | const ConvMetadata &data, seal::Evaluator &evaluator, 138 | seal::Ciphertext &zero); 139 | std::vector> HE_conv(std::vector>> &masks, 140 | std::vector> &rotations, const ConvMetadata &data, seal::Evaluator &evaluator, 141 | seal::RelinKeys &relin_keys, seal::Ciphertext &zero); 142 | 143 | std::vector 144 | HE_output_rotations(std::vector &convs, 145 | const ConvMetadata &data, seal::Evaluator &evaluator, 146 | seal::GaloisKeys &gal_keys, seal::Ciphertext &zero, 147 | std::vector &enc_noise); 148 | 149 | std::vector HE_output_rotations_dash(std::vector> convs, 150 | const ConvMetadata &data, seal::Evaluator &evaluator, seal::GaloisKeys &gal_keys, 151 | seal::Ciphertext &zero); 152 | 153 | uint64_t **HE_decrypt(std::vector &enc_result, 154 | const ConvMetadata &data, seal::Decryptor &decryptor, 155 | seal::BatchEncoder &batch_encoder); 156 | /* Per-party HE state for the convolution protocol; each SEAL member is a two-element array (NOTE(review): confirm what the two indices select, e.g. two parameter sets). */ 157 | class ConvField { 158 | public: 159 | int party; 160 | emp::NetIO *io; 161 | std::shared_ptr context[2]; 162 | seal::Encryptor *encryptor[2]; 163 | seal::Decryptor *decryptor[2]; 164 | seal::Evaluator *evaluator[2]; 165 | seal::BatchEncoder *encoder[2]; 166 | seal::GaloisKeys
*gal_keys[2]; 167 | seal::RelinKeys *relin_keys[2]; 168 | seal::Ciphertext *zero[2]; 169 | size_t slot_count; 170 | ConvMetadata data; 171 | 172 | ConvField(int party, emp::NetIO *io); 173 | 174 | ~ConvField(); 175 | 176 | void configure(); 177 | void configure_1(); 178 | 179 | Image ideal_functionality(Image &image, Filters &filters); 180 | 181 | void non_strided_conv(int32_t H, int32_t W, int32_t CI, int32_t FH, 182 | int32_t FW, int32_t CO, Image *image, Filters *filters, 183 | std::vector>> &outArr, 184 | bool verbose = false); 185 | 186 | void convolution_first( 187 | int32_t H, int32_t W, int32_t CI, int32_t FH, int32_t FW, 188 | int32_t CO, int32_t strideH, int32_t strideW, bool pad_valid, 189 | std::vector>>> &inputArr, 190 | std::vector>>> &filterArr, 191 | bool verify_output = false, bool verbose = false); 192 | 193 | void convolution_gen( 194 | int32_t H, int32_t W, int32_t CI, int32_t FH, int32_t FW, 195 | int32_t CO, int32_t strideH, int32_t strideW, bool pad_valid, 196 | std::vector>>> &inputArr, 197 | std::vector>>> &inputMacArr, 198 | std::vector>>> &filterArr, 199 | seal::Modulus mod, 200 | bool verify_output = false, bool verbose = false); 201 | 202 | uint64_t** conv_preprocess_invert(std::vector &r_mac_ct, const ConvMetadata &data, seal::Decryptor &decryptor, 203 | seal::BatchEncoder &batch_encoder); 204 | void convolution( 205 | int32_t N, int32_t H, int32_t W, int32_t CI, int32_t FH, int32_t FW, 206 | int32_t CO, int32_t zPadHLeft, int32_t zPadHRight, int32_t zPadWLeft, 207 | int32_t zPadWRight, int32_t strideH, int32_t strideW, 208 | std::vector>>> &inputArr, 209 | std::vector>>> &filterArr, 210 | std::vector>>> &outArr, 211 | bool verify_output = false, bool verbose = false); 212 | void verify(int H, int W, int CI, int CO, Image &image, Filters *filters, 213 | std::vector>>> &outArr); 214 | }; 215 | 216 | #endif /* CONV_FIELD_H__ */ 217 | -------------------------------------------------------------------------------- /test/LinearLayer/conv-new.h: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | /* 2 | Original Author: ryanleh 3 | Modified Work Copyright (c) 2020 Microsoft Research 4 | Modified Work Copyright (c) 2021 Microsoft Research 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE.
23 | 24 | Modified by Deevashwer Rathee 25 | Modifed by Akash Shah 26 | */ 27 | /* NOTE(review): this is the same include guard as conv-field.h, so the two headers cannot both be active in one translation unit; whichever is included first wins. */ 28 | #ifndef CONV_FIELD_H__ 29 | #define CONV_FIELD_H__ 30 | 31 | #include "utils-HE.h" 32 | #include 33 | 34 | // This is to keep compatibility for im2col implementations 35 | typedef Eigen::Matrix 36 | Channel; 37 | typedef std::vector Image; 38 | typedef std::vector Filters; 39 | 40 | struct ConvMetadata { 41 | int slot_count; 42 | // Number of plaintext slots in a half ciphertext 43 | // (since ciphertexts are a two column matrix) 44 | int32_t pack_num; 45 | // Number of Channels that can fit in a half ciphertext 46 | int32_t chans_per_half; 47 | // Number of input ciphertexts for convolution 48 | int32_t inp_ct; 49 | // Number of output ciphertexts 50 | int32_t out_ct; 51 | // Image and Filters metadata 52 | int32_t image_h; 53 | int32_t image_w; 54 | size_t image_size; 55 | int32_t inp_chans; 56 | int32_t filter_h; 57 | int32_t filter_w; 58 | int32_t filter_size; 59 | int32_t out_chans; 60 | // How many total ciphertext halves the input and output take up 61 | int32_t inp_halves; 62 | int32_t out_halves; 63 | /* How many Channels are in the last output or input half */ 64 | int32_t out_in_last; 65 | int32_t inp_in_last; 66 | // The modulo used when deciding which output channels to pack into a mask 67 | int32_t out_mod; 68 | // How many permutations of ciphertexts are needed to generate all 69 | // intermediate rotation sets 70 | int32_t half_perms; 71 | bool last_repeats; 72 | int32_t repeat_chans; 73 | /* The number of rotations for each ciphertext half */ 74 | int32_t half_rots; 75 | int32_t last_rots; 76 | // Total number of convolutions needed to generate all 77 | // intermediate rotations sets 78 | int32_t convs; 79 | int32_t stride_h; 80 | int32_t stride_w; 81 | int32_t output_h; 82 | int32_t output_w; 83 | int32_t pad_t; 84 | int32_t pad_b; 85 | int32_t pad_r; 86 | int32_t pad_l; 87 | bool pad_valid; 88 | }; 89 | 90 | /* Use casting to do two conditionals instead of one -
check if a > 0 and a < b 91 | */ 92 | inline bool condition_check(int a, int b) { 93 | return static_cast(a) < static_cast(b); 94 | } 95 | /* NOTE(review): the cast target is elided in this listing; if it is unsigned (as the comment above implies), a == 0 also passes, so the check is 0 <= a < b rather than a > 0. */ 96 | Image pad_image(ConvMetadata data, Image &image); 97 | 98 | void i2c(Image &image, Channel &column, const int filter_h, const int filter_w, 99 | const int stride_h, const int stride_w, const int output_h, 100 | const int output_w); 101 | 102 | std::vector 103 | HE_preprocess_noise(const uint64_t *const *secret_share, 104 | const ConvMetadata &data, seal::Encryptor &encryptor, 105 | seal::BatchEncoder &batch_encoder, 106 | seal::Evaluator &evaluator); 107 | std::vector HE_preprocess_noise_plain(const uint64_t* const* secret_share, const ConvMetadata &data, 108 | seal::BatchEncoder &batch_encoder); 109 | 110 | std::vector> preprocess_image_OP(Image &image, 111 | ConvMetadata data); 112 | template std::vector filter_rotations_dash(T &input, const ConvMetadata &data, seal::Evaluator *evaluator = NULL, 113 | seal::GaloisKeys *gal_keys = NULL); 114 | 115 | std::vector> 116 | filter_rotations(std::vector &input, const ConvMetadata &data, 117 | seal::Evaluator *evaluator = NULL, 118 | seal::GaloisKeys *gal_keys = NULL); 119 | 120 | std::vector pt_rotate(int slot_count, int rotation, std::vector &vec); 121 | std::vector HE_encrypt(std::vector> &pt, 122 | const ConvMetadata &data, 123 | seal::Encryptor &encryptor, 124 | seal::BatchEncoder &batch_encoder); 125 | 126 | std::vector> HE_encrypt_rotations(std::vector>> &pt, 127 | const ConvMetadata &data, 128 | seal::Encryptor &encryptor, 129 | seal::BatchEncoder &batch_encoder); 130 | 131 | std::vector>> 132 | HE_preprocess_filters_OP(Filters &filters, const ConvMetadata &data, 133 | seal::BatchEncoder &batch_encoder); 134 | std::vector>> HE_preprocess_filters(const uint64_t* const* const* filters, 135 | const ConvMetadata &data, seal::BatchEncoder &batch_encoder); 136 | std::vector 137 | HE_conv_OP(std::vector>> &masks, 138 | std::vector> &rotations, 139 | const ConvMetadata &data,
seal::Evaluator &evaluator, 140 | seal::Ciphertext &zero); 141 | std::vector> HE_conv(std::vector>> &masks, 142 | std::vector> &rotations, const ConvMetadata &data, seal::Evaluator &evaluator, 143 | seal::RelinKeys &relin_keys, seal::Ciphertext &zero); 144 | 145 | std::vector 146 | HE_output_rotations(std::vector &convs, 147 | const ConvMetadata &data, seal::Evaluator &evaluator, 148 | seal::GaloisKeys &gal_keys, seal::Ciphertext &zero, 149 | std::vector &enc_noise); 150 | 151 | std::vector HE_output_rotations_dash(std::vector> convs, 152 | const ConvMetadata &data, seal::Evaluator &evaluator, seal::GaloisKeys &gal_keys, 153 | seal::Ciphertext &zero); 154 | 155 | uint64_t **HE_decrypt(std::vector &enc_result, 156 | const ConvMetadata &data, seal::Decryptor &decryptor, 157 | seal::BatchEncoder &batch_encoder); 158 | /* Per-party HE state for the convolution protocol; each SEAL member is a two-element array (NOTE(review): confirm what the two indices select). */ 159 | class ConvField { 160 | public: 161 | int party; 162 | emp::NetIO *io; 163 | std::shared_ptr context[2]; 164 | seal::Encryptor *encryptor[2]; 165 | seal::Decryptor *decryptor[2]; 166 | seal::Evaluator *evaluator[2]; 167 | seal::BatchEncoder *encoder[2]; 168 | seal::GaloisKeys *gal_keys[2]; 169 | seal::RelinKeys *relin_keys[2]; 170 | seal::Ciphertext *zero[2]; 171 | size_t slot_count; 172 | ConvMetadata data; 173 | 174 | ConvField(int party, emp::NetIO *io); 175 | 176 | ~ConvField(); 177 | 178 | void configure(); 179 | void configure_1(); 180 | 181 | Image ideal_functionality(Image &image, Filters &filters); 182 | 183 | void non_strided_conv(int32_t H, int32_t W, int32_t CI, int32_t FH, 184 | int32_t FW, int32_t CO, Image *image, Filters *filters, 185 | std::vector>> &outArr, 186 | bool verbose = false); 187 | 188 | void convolution_first( 189 | int32_t H, int32_t W, int32_t CI, int32_t FH, int32_t FW, 190 | int32_t CO, int32_t strideH, int32_t strideW, bool pad_valid, 191 | std::vector>>> &inputArr, 192 | std::vector>>> &filterArr, 193 | bool verify_output = false, bool verbose = false); 194 | 195 | void convolution_gen( 196 | int32_t H, int32_t W,
int32_t CI, int32_t FH, int32_t FW, 197 | int32_t CO, int32_t strideH, int32_t strideW, bool pad_valid, 198 | std::vector>>> &inputArr, 199 | std::vector>>> &inputMacArr, 200 | std::vector>>> &filterArr, 201 | seal::Modulus mod, 202 | bool verify_output = false, bool verbose = fFalse); 203 | 204 | uint64_t** conv_preprocess_invert(std::vector &r_mac_ct, const ConvMetadata &data, seal::Decryptor &decryptor, 205 | seal::BatchEncoder &batch_encoder); 206 | void convolution( 207 | int32_t N, int32_t H, int32_t W, int32_t CI, int32_t FH, int32_t FW, 208 | int32_t CO, int32_t zPadHLeft, int32_t zPadHRight, int32_t zPadWLeft, 209 | int32_t zPadWRight, int32_t strideH, int32_t strideW, 210 | std::vector>>> &inputArr, 211 | std::vector>>> &filterArr, 212 | std::vector>>> &outArr, 213 | bool verify_output = false, bool verbose = false); 214 | void verify(int H, int W, int CI, int CO, Image &image, Filters *filters, 215 | std::vector>>> &outArr); 216 | }; 217 | 218 | #endif 219 | -------------------------------------------------------------------------------- /test/LinearLayer/conv-protocol.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/LinearLayer/conv-protocol.cpp -------------------------------------------------------------------------------- /test/LinearLayer/conv-protocol.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/LinearLayer/conv-protocol.h -------------------------------------------------------------------------------- /test/LinearLayer/defines-HE.h: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | */ 21 | 22 | #ifndef DEFINES_HE_H__ 23 | #define DEFINES_HE_H__ 24 | 25 | #include 26 | #include 27 | //#define HE_DEBUG 28 | 29 | // The following are UBUNTU/LINUX, and MacOS ONLY terminal color codes.
30 | #define RESET "\033[0m" 31 | #define BLACK "\033[30m" /* Black */ 32 | #define RED "\033[31m" /* Red */ 33 | #define GREEN "\033[32m" /* Green */ 34 | #define YELLOW "\033[33m" /* Yellow */ 35 | #define BLUE "\033[34m" /* Blue */ 36 | #define MAGENTA "\033[35m" /* Magenta */ 37 | #define CYAN "\033[36m" /* Cyan */ 38 | #define WHITE "\033[37m" /* White */ 39 | #define BOLDBLACK "\033[1m\033[30m" /* Bold Black */ 40 | #define BOLDRED "\033[1m\033[31m" /* Bold Red */ 41 | #define BOLDGREEN "\033[1m\033[32m" /* Bold Green */ 42 | #define BOLDYELLOW "\033[1m\033[33m" /* Bold Yellow */ 43 | #define BOLDBLUE "\033[1m\033[34m" /* Bold Blue */ 44 | #define BOLDMAGENTA "\033[1m\033[35m" /* Bold Magenta */ 45 | #define BOLDCYAN "\033[1m\033[36m" /* Bold Cyan */ 46 | #define BOLDWHITE "\033[1m\033[37m" /* Bold White */ 47 | 48 | extern uint64_t prime_mod; 49 | extern int32_t bitlength; 50 | extern int32_t num_threads; 51 | 52 | const uint64_t PLAINTEXT_MODULUS = 17592060215297; 53 | const uint64_t POLY_MOD_DEGREE = 8192; 54 | 55 | const uint64_t POLY_MOD_DEGREE_LARGE = 32768; 56 | const int32_t SMUDGING_BITLEN = 108 - bitlength; 57 | /* NOTE(review): SMUDGING_BITLEN is a namespace-scope const initialized from the non-constant extern bitlength, so each translation unit computes it once during static initialization; later assignments to bitlength are not reflected. */ 58 | /* Rounds val up to the next power of two (a power of two is returned unchanged; val <= 0 yields 0). 59 | * Uses integer doubling rather than pow(2, ceil(log(val) / log(2))), whose floating-point quotient 60 | * may round to just above an exact integer for a power of two and so double the result. */ 61 | inline int next_pow2(int val) { if (val <= 0) return 0; unsigned int p = 1; while (p < static_cast<unsigned int>(val)) p <<= 1; return static_cast<int>(p); } 62 | 63 | #endif // DEFINES_HE_H__ 64 | -------------------------------------------------------------------------------- /test/LinearLayer/elemwise-prod-field.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | */ 21 | 22 | #include "elemwise-prod-field.h" 23 | 24 | using namespace std; 25 | using namespace seal; 26 | using namespace emp; 27 | /* Stores party and io, uses POLY_MOD_DEGREE slots, and creates this party's SEAL objects via generate_new_keys. */ 28 | ElemWiseProdField::ElemWiseProdField(int party, NetIO *io) { 29 | this->party = party; 30 | this->io = io; 31 | this->slot_count = POLY_MOD_DEGREE; 32 | 33 | generate_new_keys(party, io, slot_count, context, encryptor, decryptor, 34 | evaluator, encoder, gal_keys, zero); 35 | } 36 | /* Releases the SEAL objects through free_keys. */ 37 | ElemWiseProdField::~ElemWiseProdField() { 38 | free_keys(party, encryptor, decryptor, evaluator, encoder, gal_keys, zero); 39 | } 40 | /* Plaintext reference: result[i] = multArr[i] * inArr[i] in uint64_t arithmetic (no reduction mod prime_mod). */ 41 | vector 42 | ElemWiseProdField::ideal_functionality(vector &inArr, 43 | vector &multArr) { 44 | vector result(inArr.size(), 0ULL); 45 | 46 | for (size_t i = 0; i < inArr.size(); i++) { 47 | result[i] = multArr[i] * inArr[i]; 48 | } 49 | return result; 50 | } 51 | /* Element-wise product of a shared input with ALICE's multArr. BOB encrypts its share of inArr, sends it, and decrypts the reply into outputArr. ALICE multiplies the received ciphertexts by multArr, adds encryptions of fresh random shares, floods the noise (flood_ciphertext), sends the result back, and keeps inArr * multArr - share (mod prime_mod) as its output. */ 52 | void ElemWiseProdField::elemwise_product(int32_t size, vector &inArr, 53 | vector &multArr, 54 | vector &outputArr, 55 | bool verify_output, bool verbose) { 56 | int num_ct = (size + slot_count - 1) / slot_count; /* exact ceil(size / slot_count); the previous float division rounds sizes above 2^24 and could drop the last ciphertext */ 57 | 58 | if (party == BOB) { 59 | vector ct(num_ct); 60 | for (int i = 0; i < num_ct; i++) { 61 | int offset = i * slot_count;
62 | vector tmp_vec(slot_count, 0); 63 | Plaintext tmp_pt; 64 | for (int j = 0; j < slot_count && j + offset < size; j++) { 65 | tmp_vec[j] = neg_mod((int64_t)inArr[j + offset], (int64_t)prime_mod); 66 | } 67 | encoder->encode(tmp_vec, tmp_pt); 68 | encryptor->encrypt(tmp_pt, ct[i]); 69 | } 70 | send_encrypted_vector(io, ct); 71 | 72 | vector enc_result(num_ct); 73 | recv_encrypted_vector(io, context, enc_result); 74 | for (int i = 0; i < num_ct; i++) { 75 | int offset = i * slot_count; 76 | vector tmp_vec(slot_count, 0); 77 | Plaintext tmp_pt; 78 | decryptor->decrypt(enc_result[i], tmp_pt); 79 | encoder->decode(tmp_pt, tmp_vec); 80 | for (int j = 0; j < slot_count && j + offset < size; j++) { 81 | outputArr[j + offset] = tmp_vec[j]; 82 | } 83 | } 84 | if (verify_output) 85 | verify(inArr, nullptr, outputArr); 86 | } else // party == ALICE 87 | { 88 | vector multArr_pt(num_ct); 89 | for (int i = 0; i < num_ct; i++) { 90 | int offset = i * slot_count; 91 | vector<uint64_t> tmp_vec(slot_count, 0); 92 | for (int j = 0; j < slot_count && j + offset < size; j++) { 93 | tmp_vec[j] = neg_mod((int64_t)multArr[j + offset], (int64_t)prime_mod); 94 | } 95 | encoder->encode(tmp_vec, multArr_pt[i]); 96 | } 97 | 98 | PRG prg; 99 | vector<Ciphertext> enc_noise(num_ct); 100 | vector<vector<uint64_t>> secret_share(num_ct, 101 | vector<uint64_t>(slot_count, 0)); 102 | for (int i = 0; i < num_ct; i++) { 103 | Plaintext tmp_pt; 104 | random_mod_p(prg, secret_share[i].data(), slot_count, prime_mod); 105 | encoder->encode(secret_share[i], tmp_pt); 106 | encryptor->encrypt(tmp_pt, enc_noise[i]); 107 | } 108 | 109 | vector<Ciphertext> ct(num_ct); 110 | recv_encrypted_vector(io, context, ct); 111 | 112 | vector<Ciphertext> enc_result(num_ct); 113 | for (int i = 0; i < num_ct; i++) { 114 | #ifdef HE_DEBUG 115 | if (!i) 116 | PRINT_NOISE_BUDGET(decryptor, ct[i], "before product"); 117 | #endif 118 | 119 | if (multArr_pt[i].is_zero()) { 120 | enc_result[i] = *zero; 121 | } else {
122 | evaluator->multiply_plain(ct[i], multArr_pt[i], enc_result[i]); 123 | } 124 | evaluator->add_inplace(enc_result[i], enc_noise[i]); 125 | 126 | #ifdef HE_DEBUG 127 | if (!i) 128 | PRINT_NOISE_BUDGET(decryptor, enc_result[i], "after product"); 129 | #endif 130 | 131 | evaluator->mod_switch_to_next_inplace(enc_result[i]); 132 | 133 | #ifdef HE_DEBUG 134 | if (!i) 135 | PRINT_NOISE_BUDGET(decryptor, enc_result[i], "after mod-switch"); 136 | #endif 137 | 138 | parms_id_type parms_id = enc_result[i].parms_id(); 139 | shared_ptr<const SEALContext::ContextData> context_data = 140 | context->get_context_data(parms_id); 141 | flood_ciphertext(enc_result[i], context_data, SMUDGING_BITLEN); 142 | 143 | #ifdef HE_DEBUG 144 | if (!i) 145 | PRINT_NOISE_BUDGET(decryptor, enc_result[i], "after noise flooding"); 146 | #endif 147 | 148 | evaluator->mod_switch_to_next_inplace(enc_result[i]); 149 | 150 | #ifdef HE_DEBUG 151 | if (!i) 152 | PRINT_NOISE_BUDGET(decryptor, enc_result[i], "after mod-switch"); 153 | #endif 154 | } 155 | send_encrypted_vector(io, enc_result); 156 | 157 | auto result = ideal_functionality(inArr, multArr); 158 | 159 | for (int i = 0; i < num_ct; i++) { 160 | int offset = i * slot_count; 161 | for (int j = 0; j < slot_count && j + offset < size; j++) { 162 | outputArr[j + offset] = 163 | neg_mod((int64_t)result[j + offset] - (int64_t)secret_share[i][j], 164 | (int64_t)prime_mod); 165 | } 166 | } 167 | if (verify_output) 168 | verify(inArr, &multArr, outputArr); 169 | } 170 | } 171 | /* Debug check: BOB sends its input and output shares; ALICE adds them to its own shares mod prime_mod and compares the reconstructed output with ideal_functionality on the reconstructed input. */ 172 | void ElemWiseProdField::verify(vector<uint64_t> &inArr, 173 | vector<uint64_t> *multArr, 174 | vector<uint64_t> &outArr) { 175 | if (party == BOB) { 176 | io->send_data(inArr.data(), inArr.size() * sizeof(uint64_t)); 177 | io->send_data(outArr.data(), outArr.size() * sizeof(uint64_t)); 178 | } else // party == ALICE 179 | { 180 | vector<uint64_t> inArr_0(inArr.size()); 181 | io->recv_data(inArr_0.data(), inArr.size() * sizeof(uint64_t)); 182 | for (size_t i = 0; i <
inArr.size(); i++) { 183 | inArr_0[i] = (inArr[i] + inArr_0[i]) % prime_mod; 184 | } 185 | 186 | auto result = ideal_functionality(inArr_0, *multArr); 187 | 188 | vector<uint64_t> outArr_0(outArr.size()); 189 | io->recv_data(outArr_0.data(), outArr.size() * sizeof(uint64_t)); 190 | for (size_t i = 0; i < outArr.size(); i++) { 191 | outArr_0[i] = (outArr[i] + outArr_0[i]) % prime_mod; 192 | } 193 | bool pass = true; 194 | for (size_t i = 0; i < outArr.size(); i++) { 195 | if (neg_mod(result[i], (int64_t)prime_mod) != (int64_t)outArr_0[i]) { 196 | pass = false; 197 | } 198 | } 199 | if (pass) 200 | cout << GREEN << "[Server] Successful Operation" << RESET << endl; 201 | else { 202 | cout << RED << "[Server] Failed Operation" << RESET << endl; 203 | cout << RED << "WARNING: The implementation assumes that the computation" 204 | << endl; 205 | cout << "performed locally by the server (on the model and its input " 206 | "share)" 207 | << endl; 208 | cout << "fits in a 64-bit integer. The failed operation could be a result" 209 | << endl; 210 | cout << "of overflowing the bound."
<< RESET << endl; 211 | } 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /test/LinearLayer/elemwise-prod-field.h: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE.
20 | */ 21 | 22 | #ifndef ELEMWISEPROD_FIELD_H__ 23 | #define ELEMWISEPROD_FIELD_H__ 24 | 25 | #include "utils-HE.h" 26 | /* Element-wise product protocol state: party id, channel, and the SEAL objects created by the constructor and released by the destructor. */ 27 | class ElemWiseProdField { 28 | public: 29 | int party; 30 | emp::NetIO *io; 31 | std::shared_ptr<seal::SEALContext> context; 32 | seal::Encryptor *encryptor; 33 | seal::Decryptor *decryptor; 34 | seal::Evaluator *evaluator; 35 | seal::BatchEncoder *encoder; 36 | seal::GaloisKeys *gal_keys; 37 | seal::Ciphertext *zero; 38 | int slot_count; 39 | 40 | ElemWiseProdField(int party, emp::NetIO *io); 41 | 42 | ~ElemWiseProdField(); 43 | 44 | std::vector<uint64_t> ideal_functionality(std::vector<uint64_t> &inArr, 45 | std::vector<uint64_t> &multArr); 46 | 47 | void elemwise_product(int32_t size, std::vector<uint64_t> &inArr, 48 | std::vector<uint64_t> &multArr, 49 | std::vector<uint64_t> &outputArr, 50 | bool verify_output = false, bool verbose = false); 51 | 52 | void verify(std::vector<uint64_t> &inArr, std::vector<uint64_t> *multArr, 53 | std::vector<uint64_t> &outArr); 54 | }; 55 | 56 | #endif /* ELEMWISEPROD_FIELD_H__ */ 57 | -------------------------------------------------------------------------------- /test/LinearLayer/fc-field.h: -------------------------------------------------------------------------------- 1 | /* 2 | Original Author: ryanleh 3 | Modified Work Copyright (c) 2020 Microsoft Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | Modified by Deevashwer Rathee 24 | */ 25 | 26 | #ifndef FC_FIELD_H__ 27 | #define FC_FIELD_H__ 28 | 29 | #include "utils-HE.h" 30 | /* Dimensions and packing parameters for one fully-connected (matrix times vector) evaluation. */ 31 | struct FCMetadata { 32 | int slot_count; 33 | int32_t pack_num; 34 | int32_t inp_ct; 35 | int32_t inp_ct_1; 36 | // Filter is a matrix 37 | int32_t filter_h; 38 | int32_t filter_w; 39 | int32_t filter_size; 40 | // Image is a vector 41 | int32_t image_size; 42 | }; 43 | 44 | seal::Ciphertext preprocess_vec(const uint64_t *input, const FCMetadata &data, 45 | seal::Encryptor &encryptor, 46 | seal::BatchEncoder &batch_encoder); 47 | seal::Plaintext preprocess_vec_plain(const uint64_t *input, const FCMetadata &data, 48 | seal::BatchEncoder &batch_encoder); 49 | std::vector<seal::Plaintext> 50 | preprocess_matrix(const uint64_t *const *matrix, const FCMetadata &data, 51 | seal::BatchEncoder &batch_encoder); 52 | 53 | seal::Plaintext fc_preprocess_noise(const uint64_t *secret_share, 54 | const FCMetadata &data, 55 | seal::BatchEncoder &batch_encoder); 56 | 57 | seal::Ciphertext fc_online(seal::Ciphertext &ct, 58 | std::vector<seal::Plaintext> &enc_mat, 59 | const FCMetadata &data, seal::Evaluator &evaluator, 60 | seal::GaloisKeys &gal_keys, seal::RelinKeys &relin_keys, seal::Ciphertext &zero); 61 | 62 | uint64_t *fc_postprocess(seal::Ciphertext &result, const FCMetadata &data, 63 | seal::BatchEncoder &batch_encoder, 64 | seal::Decryptor &decryptor); 65 | 66 | uint64_t *fc_postprocess_mac(seal::Ciphertext &result, const FCMetadata
&data, 67 | seal::BatchEncoder &batch_encoder, 68 | seal::Decryptor &decryptor); 69 | /* Per-party HE state and entry points for the fully-connected layer protocol. */ 70 | class FCField { 71 | public: 72 | int party; 73 | emp::NetIO *io; 74 | FCMetadata data; 75 | std::shared_ptr<seal::SEALContext> context; 76 | seal::Encryptor *encryptor; 77 | seal::Decryptor *decryptor; 78 | seal::Evaluator *evaluator; 79 | seal::BatchEncoder *encoder; 80 | seal::GaloisKeys *gal_keys; 81 | seal::RelinKeys *relin_keys; 82 | seal::Ciphertext *zero; 83 | size_t slot_count; 84 | 85 | FCField(int party, emp::NetIO *io); 86 | 87 | ~FCField(); 88 | 89 | void configure(); 90 | 91 | std::vector<uint64_t> ideal_functionality(uint64_t *vec, uint64_t **matrix, seal::Modulus mod); 92 | 93 | void matrix_multiplication(int32_t num_rows, int32_t common_dim, 94 | int32_t num_cols, 95 | std::vector<std::vector<uint64_t>> &A, 96 | std::vector<std::vector<uint64_t>> &B, 97 | std::vector<std::vector<uint64_t>> &C, 98 | bool verify_output = false, bool verbose = false); 99 | void matrix_multiplication_first(int32_t num_rows, int32_t common_dim, 100 | int32_t num_cols, 101 | std::vector<std::vector<uint64_t>> &inputs, 102 | std::vector<std::vector<uint64_t>> &op_shares, 103 | std::vector<std::vector<uint64_t>> &mac_op_shares, 104 | seal::Modulus mod, 105 | bool verify_output = false, bool verbose = false); 106 | 107 | void matrix_multiplication_gen(int32_t num_rows, int32_t common_dim, 108 | int32_t num_cols, 109 | std::vector<std::vector<uint64_t>> &matrix, 110 | std::vector<std::vector<uint64_t>> &input_share, 111 | std::vector<std::vector<uint64_t>> &mac_input_share, 112 | seal::Modulus mod, 113 | bool verify_output = false, bool verbose = false); 114 | void verify(std::vector<uint64_t> *vec, std::vector<uint64_t *> *matrix, 115 | std::vector<std::vector<uint64_t>> &C); 116 | }; 117 | #endif 118 | -------------------------------------------------------------------------------- /test/LinearLayer/fc-protocol.cpp:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/LinearLayer/fc-protocol.cpp -------------------------------------------------------------------------------- /test/LinearLayer/fc-protocol.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/test/LinearLayer/fc-protocol.h -------------------------------------------------------------------------------- /test/LinearLayer/utils-HE.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | */ 21 | 22 | #include "utils-HE.h" 23 | #include "seal/util/polyarithsmallmod.h" 24 | 25 | using namespace std; 26 | using namespace emp; 27 | using namespace seal; 28 | using namespace seal::util; 29 | 30 | 31 | /* Helper function for performing modulo with possibly negative numbers */ 32 | /*inline int8_t neg_mod(int8_t val, int8_t mod) { 33 | return ((val % mod) + mod) % mod; 34 | }*/ 35 | 36 | void generate_new_keys(int party, NetIO *io, int slot_count, 37 | shared_ptr<SEALContext> &context_, 38 | Encryptor *&encryptor_, Decryptor *&decryptor_, 39 | Evaluator *&evaluator_, BatchEncoder *&encoder_, 40 | GaloisKeys *&gal_keys_, RelinKeys *& relin_keys_, Ciphertext *&zero_, 41 | bool verbose) { 42 | // auto context = SEALContext::Create(parms, true, sec_level_type::none); 43 | EncryptionParameters parms(scheme_type::bfv); 44 | parms.set_poly_modulus_degree(slot_count); 45 | parms.set_coeff_modulus(CoeffModulus::BFVDefault(slot_count)); 46 | parms.set_plain_modulus(PLAINTEXT_MODULUS); 47 | 48 | context_ = shared_ptr<SEALContext>(new SEALContext(parms, false, sec_level_type::none)); 49 | 50 | encoder_ = new BatchEncoder(*context_); 51 | evaluator_ = new Evaluator(*context_); 52 | if (party == BOB) { 53 | KeyGenerator keygen(*context_); 54 | PublicKey pub_key; 55 | keygen.create_public_key(pub_key); 56 | auto sec_key = keygen.secret_key(); 57 | GaloisKeys gal_keys_; 58 | keygen.create_galois_keys(gal_keys_); 59 | RelinKeys relin_keys_; 60 | keygen.create_relin_keys(relin_keys_); 61 | 62 | stringstream os; 63 | pub_key.save(os); 64 | uint64_t pk_size = os.tellp(); 65 | gal_keys_.save(os); 66 | uint64_t gk_size = (uint64_t)os.tellp() - pk_size; 67 | relin_keys_.save(os); 68 | uint64_t rk_size = (uint64_t)os.tellp() - (pk_size + gk_size); 69 | 70 | string keys_ser = os.str(); 71 | io->send_data(&pk_size, sizeof(uint64_t)); 72 | io->send_data(&gk_size, sizeof(uint64_t)); 73 | io->send_data(&rk_size, sizeof(uint64_t)); 74 | io->send_data(keys_ser.c_str(), pk_size + 
gk_size + rk_size); 75 | 76 | #ifdef HE_DEBUG 77 | stringstream os_sk; 78 | sec_key.save(os_sk); 79 | uint64_t sk_size = os_sk.tellp(); 80 | string keys_ser_sk = os_sk.str(); 81 | io->send_data(&sk_size, sizeof(uint64_t)); 82 | io->send_data(keys_ser_sk.c_str(), sk_size); 83 | #endif 84 | encryptor_ = new Encryptor(*context_, pub_key); 85 | decryptor_ = new Decryptor(*context_, sec_key); 86 | } else // party == ALICE 87 | { 88 | uint64_t pk_size; 89 | uint64_t gk_size; 90 | uint64_t rk_size; 91 | io->recv_data(&pk_size, sizeof(uint64_t)); 92 | io->recv_data(&gk_size, sizeof(uint64_t)); 93 | io->recv_data(&rk_size, sizeof(uint64_t)); 94 | 95 | char *key_share = new char[pk_size + gk_size + rk_size]; 96 | io->recv_data(key_share, pk_size + gk_size + rk_size); 97 | stringstream is; 98 | PublicKey pub_key; 99 | is.write(key_share, pk_size); 100 | pub_key.load(*context_, is); 101 | gal_keys_ = new GaloisKeys(); 102 | is.write(key_share + pk_size, gk_size); 103 | gal_keys_->load(*context_, is); 104 | relin_keys_ = new RelinKeys(); 105 | is.write(key_share + pk_size + gk_size, rk_size); 106 | relin_keys_->load(*context_, is); 107 | delete[] key_share; 108 | 109 | #ifdef HE_DEBUG 110 | uint64_t sk_size; 111 | io->recv_data(&sk_size, sizeof(uint64_t)); 112 | char *key_share_sk = new char[sk_size]; 113 | io->recv_data(key_share_sk, sk_size); 114 | stringstream is_sk; 115 | SecretKey sec_key; 116 | is_sk.write(key_share_sk, sk_size); 117 | sec_key.load(*context_, is_sk); 118 | delete[] key_share_sk; 119 | decryptor_ = new Decryptor(*context_, sec_key); 120 | #endif 121 | encryptor_ = new Encryptor(*context_, pub_key); 122 | vector<uint64_t> pod_matrix(slot_count, 0ULL); 123 | Plaintext tmp; 124 | encoder_->encode(pod_matrix, tmp); 125 | zero_ = new Ciphertext; 126 | encryptor_->encrypt(tmp, *zero_); 127 | } 128 | if (verbose) 129 | cout << "Keys Generated (slot_count: " << slot_count << ")" << endl; 130 | } 131 | 132 | void free_keys(int party, Encryptor *&encryptor_, 
Decryptor *&decryptor_, 133 | Evaluator *&evaluator_, BatchEncoder *&encoder_, 134 | GaloisKeys *&gal_keys_, RelinKeys *&relin_keys_, Ciphertext *&zero_) { 135 | delete encoder_; 136 | delete evaluator_; 137 | delete encryptor_; 138 | if (party == BOB) { 139 | delete decryptor_; 140 | } else // party ==ALICE 141 | { 142 | #ifdef HE_DEBUG 143 | delete decryptor_; 144 | #endif 145 | delete gal_keys_; 146 | delete relin_keys_; 147 | delete zero_; 148 | } 149 | } 150 | 151 | void send_encrypted_vector(NetIO *io, vector<Ciphertext> &ct_vec) { 152 | assert(ct_vec.size() > 0); 153 | stringstream os; 154 | uint64_t ct_size[ct_vec.size()]; 155 | uint64_t prev_size=0; 156 | for (size_t ct = 0; ct < ct_vec.size(); ct++) { 157 | ct_vec[ct].save(os); 158 | ct_size[ct] = os.tellp() - prev_size; 159 | prev_size += ct_size[ct]; 160 | } 161 | 162 | string ct_ser = os.str(); 163 | for(int i=0; i<ct_vec.size(); i++) { 164 | io->send_data(&ct_size[i], sizeof(uint64_t)); 165 | } 166 | io->send_data(ct_ser.c_str(), ct_ser.size()); 167 | } 168 | 169 | void recv_encrypted_vector(NetIO *io, shared_ptr<SEALContext> &context_, vector<Ciphertext> &ct_vec) { 170 | assert(ct_vec.size() > 0); 171 | stringstream is; 172 | uint64_t ct_size[ct_vec.size()]; 173 | uint64_t total_size=0; 174 | for(int i=0; i<ct_vec.size(); i++) { 175 | io->recv_data(&ct_size[i], sizeof(uint64_t)); 176 | total_size += ct_size[i]; 177 | } 178 | 179 | char *c_enc_result = new char[total_size]; 180 | io->recv_data(c_enc_result, total_size); 181 | uint64_t prev_size=0; 182 | for (size_t ct = 0; ct < ct_vec.size(); ct++) { 183 | is.write(c_enc_result + prev_size, ct_size[ct]); 184 | prev_size += ct_size[ct]; 185 | ct_vec[ct].unsafe_load(*context_, is); 186 | } 187 | delete[] c_enc_result; 188 | } 189 | 190 | void send_ciphertext(NetIO *io, Ciphertext &ct) { 191 | stringstream os; 192 | uint64_t ct_size; 193 | ct.save(os); 194 | ct_size = os.tellp(); 195 | string ct_ser = os.str(); 196 | io->send_data(&ct_size, 
sizeof(uint64_t)); 197 | io->send_data(ct_ser.c_str(), ct_ser.size()); 198 | } 199 | 200 | void recv_ciphertext(NetIO *io, shared_ptr<SEALContext> &context_, Ciphertext &ct) { 201 | stringstream is; 202 | uint64_t ct_size; 203 | io->recv_data(&ct_size, sizeof(uint64_t)); 204 | char *c_enc_result = new char[ct_size]; 205 | io->recv_data(c_enc_result, ct_size); 206 | is.write(c_enc_result, ct_size); 207 | ct.unsafe_load(*context_, is); 208 | delete[] c_enc_result; 209 | } 210 | 211 | void set_poly_coeffs_uniform( 212 | uint64_t *poly, uint32_t bitlen, shared_ptr<UniformRandomGenerator> random, 213 | shared_ptr<const SEALContext::ContextData> &context_data) { 214 | assert(bitlen < 128 && bitlen > 0); 215 | auto &parms = context_data->parms(); 216 | auto &coeff_modulus = parms.coeff_modulus(); 217 | size_t coeff_count = parms.poly_modulus_degree(); 218 | size_t coeff_mod_count = coeff_modulus.size(); 219 | uint64_t bitlen_mask = (1ULL << (bitlen % 64)) - 1; 220 | 221 | RandomToStandardAdapter engine(random); 222 | for (size_t i = 0; i < coeff_count; i++) { 223 | if (bitlen < 64) { 224 | uint64_t noise = (uint64_t(engine()) << 32) | engine(); 225 | noise &= bitlen_mask; 226 | for (size_t j = 0; j < coeff_mod_count; j++) { 227 | poly[i + (j * coeff_count)] = 228 | barrett_reduce_64(noise, coeff_modulus[j]); 229 | } 230 | } else { 231 | uint64_t noise[2]; // LSB || MSB 232 | for (int j = 0; j < 2; j++) { 233 | noise[0] = (uint64_t(engine()) << 32) | engine(); 234 | noise[1] = (uint64_t(engine()) << 32) | engine(); 235 | } 236 | noise[1] &= bitlen_mask; 237 | for (size_t j = 0; j < coeff_mod_count; j++) { 238 | poly[i + (j * coeff_count)] = 239 | barrett_reduce_128(noise, coeff_modulus[j]); 240 | } 241 | } 242 | } 243 | } 244 | 245 | void flood_ciphertext(Ciphertext &ct, 246 | shared_ptr<const SEALContext::ContextData> &context_data, 247 | uint32_t noise_len, MemoryPoolHandle pool) { 248 | 249 | auto &parms = context_data->parms(); 250 | auto &coeff_modulus = 
parms.coeff_modulus(); 251 | size_t coeff_count = parms.poly_modulus_degree(); 252 | size_t coeff_mod_count = coeff_modulus.size(); 253 | 254 | auto noise(allocate_poly(coeff_count, coeff_mod_count, pool)); 255 | shared_ptr<UniformRandomGenerator> random(parms.random_generator()->create()); 256 | 257 | set_poly_coeffs_uniform(noise.get(), noise_len, random, context_data); 258 | for (size_t i = 0; i < coeff_mod_count; i++) { 259 | add_poly_coeffmod(noise.get() + (i * coeff_count), 260 | ct.data() + (i * coeff_count), coeff_count, 261 | coeff_modulus[i], ct.data() + (i * coeff_count)); 262 | } 263 | 264 | set_poly_coeffs_uniform(noise.get(), noise_len, random, context_data); 265 | for (size_t i = 0; i < coeff_mod_count; i++) { 266 | add_poly_coeffmod(noise.get() + (i * coeff_count), 267 | ct.data(1) + (i * coeff_count), coeff_count, 268 | coeff_modulus[i], ct.data(1) + (i * coeff_count)); 269 | } 270 | } 271 | 272 | void random_mod_p(PRG &prg, uint64_t *arr, uint64_t size, uint64_t prime_mod) { 273 | uint64_t boundary = (((-1 * prime_mod) / prime_mod) + 1) * 274 | prime_mod; // prime_mod*floor((2^l)/prime_mod) 275 | int tries_before_resampling = 2; 276 | uint64_t size_total = tries_before_resampling * size; 277 | uint64_t *randomness = new uint64_t[size_total]; 278 | uint64_t rptr = 0, arrptr = 0; 279 | while (arrptr < size) { 280 | prg.random_data(randomness, sizeof(uint64_t) * size_total); 281 | rptr = 0; 282 | for (; (arrptr < size) && (rptr < size_total); arrptr++, rptr++) { 283 | while (randomness[rptr] > boundary) { 284 | rptr++; 285 | if (rptr >= size_total) { 286 | prg.random_data(randomness, sizeof(uint64_t) * size_total); 287 | rptr = 0; 288 | } 289 | } 290 | arr[arrptr] = randomness[rptr] % prime_mod; 291 | } 292 | } 293 | delete[] randomness; 294 | } 295 | 296 | uint64_t mod_mult(uint64_t a, uint64_t b, seal::Modulus mod) { 297 | unsigned long long temp_result[2]; 298 | seal::util::multiply_uint64(a, b, temp_result); 299 | uint64_t result = 
seal::util::barrett_reduce_128(temp_result, mod); 300 | return result; 301 | } 302 | -------------------------------------------------------------------------------- /test/LinearLayer/utils-HE.h: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | */ 21 | 22 | #ifndef UTILS_HE_H__ 23 | #define UTILS_HE_H__ 24 | 25 | #include "defines-HE.h" 26 | #include "seal/seal.h" 27 | #include "emp-tool/emp-tool.h" 28 | #include "emp-ot/emp-ot.h" 29 | 30 | // Taken from https://github.com/mc2-project/delphi/blob/master/rust/protocols-sys/c++/src/lib/conv2d.h 31 | /* Helper function for performing modulo with possibly negative numbers */ 32 | inline int64_t neg_mod(int64_t val, int64_t mod) { 33 | return ((val % mod) + mod) % mod; 34 | } 35 | 36 | #define PRINT_NOISE_BUDGET(decryptor, ct, print_msg) \ 37 | if (verbose) \ 38 | std::cout << "[Server] Noise Budget " << print_msg << ": " << YELLOW \ 39 | << decryptor->invariant_noise_budget(ct) << " bits" << RESET \ 40 | << std::endl 41 | 42 | void generate_new_keys(int party, emp::NetIO *io, int slot_count, 43 | std::shared_ptr<seal::SEALContext> &context_, 44 | seal::Encryptor *&encryptor_, 45 | seal::Decryptor *&decryptor_, 46 | seal::Evaluator *&evaluator_, 47 | seal::BatchEncoder *&encoder_, 48 | seal::GaloisKeys *&gal_keys_, 49 | seal::RelinKeys *& relin_keys_, 50 | seal::Ciphertext *&zero_, 51 | bool verbose = false); 52 | 53 | void free_keys(int party, seal::Encryptor *&encryptor_, 54 | seal::Decryptor *&decryptor_, seal::Evaluator *&evaluator_, 55 | seal::BatchEncoder *&encoder_, seal::GaloisKeys *&gal_keys_, 56 | seal::RelinKeys *&relin_keys_, 57 | seal::Ciphertext *&zero_); 58 | 59 | void send_encrypted_vector(emp::NetIO *io, 60 | std::vector<seal::Ciphertext> &ct_vec); 61 | 62 | void recv_encrypted_vector(emp::NetIO *io, std::shared_ptr<seal::SEALContext> &context_, 63 | std::vector<seal::Ciphertext> &ct_vec); 64 | 65 | void send_ciphertext(emp::NetIO *io, seal::Ciphertext &ct); 66 | 67 | void recv_ciphertext(emp::NetIO *io, std::shared_ptr<seal::SEALContext> &context_, seal::Ciphertext &ct); 68 | 69 | void set_poly_coeffs_uniform( 70 | uint64_t *poly, uint32_t bitlen, 71 | std::shared_ptr<seal::UniformRandomGenerator> random, 72 | std::shared_ptr<const 
seal::SEALContext::ContextData> &context_data); 73 | 74 | void flood_ciphertext( 75 | seal::Ciphertext &ct, 76 | std::shared_ptr<const seal::SEALContext::ContextData> &context_data, 77 | uint32_t noise_len, 78 | seal::MemoryPoolHandle pool = seal::MemoryManager::GetPool()); 79 | 80 | void random_mod_p(PRG &prg, uint64_t *arr, uint64_t size, uint64_t prime_mod); 81 | 82 | uint64_t mod_mult(uint64_t a, uint64_t b, seal::Modulus mod); 83 | #endif // UTILS_HE_H__ 84 | -------------------------------------------------------------------------------- /test/msi_average.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/emp-sh2pc.h" 2 | #include <cmath> 3 | 4 | #include "seal/util/uintarith.h" 5 | #include "seal/util/uintarithsmallmod.h" 6 | #include "LinearLayer/defines-HE.h" 7 | #include <thread> 8 | #include "LinearLayer/utils-HE.h" 9 | #define MAX_THREADS 8 10 | using namespace emp; 11 | using namespace std; 12 | 13 | enum neural_net { 14 | NONE, 15 | MINIONN, 16 | CIFAR10 17 | }; 18 | 19 | struct dimension { 20 | int N; 21 | int l; 22 | int b; 23 | int d; 24 | }; 25 | 26 | neural_net choice_nn; 27 | neural_net def_nn = NONE; 28 | 29 | string address; 30 | 31 | uint64_t prime_val = 17592060215297; 32 | seal::Modulus mod(prime_val); 33 | 34 | uint64_t prime_mod; 35 | uint64_t moduloMask; 36 | uint64_t moduloMidPt; 37 | uint64_t avg_pool_const = 64; 38 | 39 | uint64_t mac_key; 40 | PRG prg; 41 | string benchmark; 42 | 43 | NetIO *ioArr[MAX_THREADS]; 44 | 45 | uint64_t prime_field; 46 | 47 | int l = 44; 48 | 49 | typedef std::vector<uint64_t> uint64_1D; 50 | 51 | template <typename T> vector<T> make_vector(size_t size) { 52 | return std::vector<T>(size); 53 | } 54 | 55 | template <typename T> T *make_array(size_t s1) { return new T[s1]; } 56 | 57 | template <typename T> T *make_array(size_t s1, size_t s2) { 58 | return new T[s1 * s2]; 59 | } 60 | 61 | template <typename T> T *make_array(size_t s1, size_t s2, size_t 
s3) { 62 | return new T[s1 * s2 * s3]; 63 | } 64 | 65 | template <typename T> 66 | T *make_array(size_t s1, size_t s2, size_t s3, size_t s4) { 67 | return new T[s1 * s2 * s3 * s4]; 68 | } 69 | 70 | template <typename T> 71 | T *make_array(size_t s1, size_t s2, size_t s3, size_t s4, size_t s5) { 72 | return new T[s1 * s2 * s3 * s4 * s5]; 73 | } 74 | 75 | uint64_t mod_mult(uint64_t a, uint64_t b) { 76 | unsigned long long temp_result[2]; 77 | seal::util::multiply_uint64(a, b, temp_result); 78 | 79 | /*uint64_t input[2]; 80 | input[0] = res[0]; 81 | input[1] = res[1];*/ 82 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 83 | return result; 84 | } 85 | 86 | void div_floor(int64_t a, int64_t b, int64_t &quot, int64_t &rem) { 87 | assert(b > 0); 88 | int64_t q = a / b; 89 | int64_t r = a % b; 90 | int64_t corr = ((r != 0) && (r < 0)); 91 | quot = q - corr; 92 | rem = (r + b) % b; 93 | } 94 | 95 | inline int64_t getSignedVal(uint64_t x) { 96 | assert(x < prime_mod); 97 | int64_t sx = x; 98 | if (x >= moduloMidPt) 99 | sx = x - prime_mod; 100 | return sx; 101 | } 102 | 103 | inline uint64_t getRingElt(int64_t x) { return ((uint64_t)x) & moduloMask; } 104 | 105 | inline uint64_t PublicAdd(uint64_t x, uint64_t y) { 106 | assert((x < prime_mod) && (y < prime_mod)); 107 | return (x + y) & moduloMask; 108 | } 109 | 110 | inline uint64_t PublicSub(uint64_t x, uint64_t y) { 111 | assert((x < prime_mod) && (y < prime_mod)); 112 | return (x - y) & moduloMask; 113 | } 114 | 115 | inline uint64_t PublicMult(uint64_t x, uint64_t y) { 116 | assert((x < prime_mod) && (y < prime_mod)); 117 | return (x * y) & moduloMask; // This works because its a two-power ring 118 | } 119 | 120 | inline bool PublicGT(uint64_t x, uint64_t y) { 121 | int64_t sx = getSignedVal(x); 122 | int64_t sy = getSignedVal(y); 123 | return (sx > sy); 124 | } 125 | 126 | inline bool PublicGTE(uint64_t x, uint64_t y) { 127 | int64_t sx = getSignedVal(x); 128 | int64_t sy = getSignedVal(y); 129 | 
return (sx >= sy); 130 | } 131 | 132 | inline bool PublicLT(uint64_t x, uint64_t y) { 133 | int64_t sx = getSignedVal(x); 134 | int64_t sy = getSignedVal(y); 135 | return (sx < sy); 136 | } 137 | 138 | inline bool PublicLTE(uint64_t x, uint64_t y) { 139 | int64_t sx = getSignedVal(x); 140 | int64_t sy = getSignedVal(y); 141 | return (sx <= sy); 142 | } 143 | 144 | uint64_t PublicDiv(uint64_t x, uint64_t y) { 145 | int64_t sx = getSignedVal(x); 146 | int64_t sy = getSignedVal(y); 147 | int64_t q, r; 148 | div_floor(sx, sy, q, r); 149 | return getRingElt(q); 150 | } 151 | 152 | uint64_t PublicMod(uint64_t x, uint64_t y) { 153 | int64_t sx = getSignedVal(x); 154 | int64_t sy = getSignedVal(y); 155 | int64_t q, r; 156 | div_floor(sx, sy, q, r); 157 | return r; 158 | } 159 | 160 | inline uint64_t PublicRShiftA(uint64_t x, uint64_t y) { 161 | assert((x < prime_mod) && (y < prime_mod)); 162 | int64_t sx = getSignedVal(x); 163 | int64_t ans = sx >> y; 164 | return getRingElt(ans); 165 | } 166 | 167 | inline uint64_t PublicRShiftL(uint64_t x, uint64_t y) { 168 | assert((x < prime_mod) && (y < prime_mod)); 169 | return (x >> y); 170 | } 171 | 172 | inline uint64_t PublicLShift(uint64_t x, uint64_t y) { 173 | assert((x < prime_mod) && (y < prime_mod)); 174 | return (x << y) & moduloMask; 175 | } 176 | 177 | void AvgPool_pt(uint64_t N, uint64_t H, uint64_t W, uint64_t C, uint64_t ksizeH, 178 | uint64_t ksizeW, uint64_t zPadHLeft, uint64_t zPadHRight, 179 | uint64_t zPadWLeft, uint64_t zPadWRight, uint64_t strideH, 180 | uint64_t strideW, uint64_t N1, uint64_t imgH, uint64_t imgW, 181 | uint64_t C1, 182 | std::vector<std::vector<std::vector<uint64_1D>>> &inArr, 183 | std::vector<std::vector<std::vector<uint64_1D>>> &outArr) { 184 | uint64_t rows = (PublicMult((PublicMult((PublicMult(N, C)), H)), W)); 185 | 186 | auto filterAvg = make_vector<uint64_t>(rows); 187 | 188 | uint64_t rowIdx = (int32_t)0; 189 | for (uint64_t n = (int32_t)0; n < N; n++) { 190 | for (uint64_t c = 
(int32_t)0; c < C; c++) { 191 | 192 | uint64_t leftTopCornerH = (PublicSub((int32_t)0, zPadHLeft)); 193 | 194 | uint64_t extremeRightBottomCornerH = 195 | (PublicAdd((PublicSub(imgH, (int32_t)1)), zPadHRight)); 196 | 197 | uint64_t ctH = (int32_t)0; 198 | while ((PublicLTE( 199 | (PublicSub((PublicAdd(leftTopCornerH, ksizeH)), (int32_t)1)), 200 | extremeRightBottomCornerH))) { 201 | 202 | uint64_t leftTopCornerW = (PublicSub((int32_t)0, zPadWLeft)); 203 | 204 | uint64_t extremeRightBottomCornerW = 205 | (PublicAdd((PublicSub(imgW, (int32_t)1)), zPadWRight)); 206 | 207 | uint64_t ctW = (int32_t)0; 208 | while ((PublicLTE( 209 | (PublicSub((PublicAdd(leftTopCornerW, ksizeW)), (int32_t)1)), 210 | extremeRightBottomCornerW))) { 211 | 212 | uint64_t curFilterSum = (int64_t)0; 213 | for (uint64_t fh = (int32_t)0; fh < ksizeH; fh++) { 214 | for (uint64_t fw = (int32_t)0; fw < ksizeW; fw++) { 215 | 216 | uint64_t curPosH = (PublicAdd(leftTopCornerH, fh)); 217 | 218 | uint64_t curPosW = (PublicAdd(leftTopCornerW, fw)); 219 | 220 | uint64_t temp = (int64_t)0; 221 | if ((((PublicLT(curPosH, (int32_t)0)) || 222 | (PublicGTE(curPosH, imgH))) || 223 | ((PublicLT(curPosW, (int32_t)0)) || 224 | (PublicGTE(curPosW, imgW))))) { 225 | temp = (int64_t)0; 226 | } else { 227 | temp = inArr[n][curPosH][curPosW][c]; 228 | } 229 | curFilterSum = (PublicAdd(curFilterSum, temp)); 230 | } 231 | } 232 | 233 | uint64_t ksizeH64 = ksizeH; 234 | 235 | uint64_t ksizeW64 = ksizeW; 236 | 237 | uint64_t filterSz64 = (PublicMult(ksizeH64, ksizeW64)); 238 | 239 | uint64_t curFilterAvg = (PublicDiv(curFilterSum, filterSz64)); 240 | filterAvg[rowIdx] = curFilterAvg; 241 | rowIdx = (PublicAdd(rowIdx, (int32_t)1)); 242 | leftTopCornerW = (PublicAdd(leftTopCornerW, strideW)); 243 | ctW = (PublicAdd(ctW, (int32_t)1)); 244 | } 245 | 246 | leftTopCornerH = (PublicAdd(leftTopCornerH, strideH)); 247 | ctH = (PublicAdd(ctH, (int32_t)1)); 248 | } 249 | } 250 | } 251 | for (uint64_t n = (int32_t)0; n < N; n++) { 
252 | for (uint64_t c = (int32_t)0; c < C; c++) { 253 | for (uint64_t h = (int32_t)0; h < H; h++) { 254 | for (uint64_t w = (int32_t)0; w < W; w++) { 255 | outArr[n][h][w][c] = filterAvg[(PublicAdd( 256 | (PublicAdd( 257 | (PublicAdd( 258 | (PublicMult((PublicMult((PublicMult(n, C)), H)), W)), 259 | (PublicMult((PublicMult(c, H)), W)))), 260 | (PublicMult(h, W)))), 261 | w))]; 262 | } 263 | } 264 | } 265 | } 266 | } 267 | 268 | void parse_arguments(int argc, char**arg, int *party, int *port, int *bitlen) { 269 | *party = atoi (arg[1]); 270 | address = arg[2]; 271 | *port = atoi (arg[3]); 272 | if(argc < 5) { 273 | *bitlen = l; 274 | } else { 275 | *bitlen = atoi(arg[4]); 276 | } 277 | 278 | if(argc < 6) { 279 | choice_nn =def_nn; 280 | } else { 281 | choice_nn = neural_net(atoi (arg[5])); 282 | } 283 | 284 | if(choice_nn == MINIONN) { 285 | benchmark = "mnist"; 286 | } else { 287 | benchmark = "cifar10"; 288 | } 289 | 290 | prime_mod = (*bitlen == 64 ? 0ULL : 1ULL << *bitlen); 291 | moduloMask = prime_mod - 1; 292 | moduloMidPt = prime_mod / 2; 293 | } 294 | 295 | int main(int argc, char** argv) { 296 | 297 | srand(time(NULL)); 298 | int port, party, nrelu, bitlen; 299 | //Parse input arguments and configure parameters 300 | parse_arguments(argc, argv, &party, &port, &bitlen); 301 | 302 | cout<<"Executing Average-pool Layers ..."<<endl; 303 | cout << "=====================Configuration======================" << endl; 304 | cout<<"Role: "<< party<<" - IP Address: "<< address <<" - Port: "<<port<<" - Benchmark: "<<benchmark<<" - Bitlength: "<<bitlen<<endl; 305 | cout << "========================================================" << endl; 306 | 307 | 308 | ioArr[0] = new NetIO(party==ALICE ? 
nullptr : address.c_str(), port); 309 | 310 | //Prepare and share inputs 311 | uint64_t layers_count=2; 312 | 313 | uint64_t *inputs[layers_count], *inputs_mac[layers_count], *outputs[layers_count], *outputs_mac[layers_count]; 314 | uint64_t *prepared_input[layers_count], *prepared_input_mac[layers_count]; 315 | uint64_t *send_inputs[layers_count], *send_input_mac[layers_count]; 316 | 317 | dimension input_dim[layers_count], output_dim[layers_count]; 318 | 319 | if(choice_nn == MINIONN) { 320 | input_dim[0].N = 1; input_dim[0].l = 24, input_dim[0].b = 24, input_dim[0].d = 16; 321 | input_dim[1].N = 1; input_dim[1].l = 8, input_dim[1].b = 8, input_dim[1].d = 16; 322 | 323 | output_dim[0].N = 1; output_dim[0].l = 12, output_dim[0].b = 12, output_dim[0].d = 16; 324 | output_dim[1].N = 1; output_dim[1].l = 4, output_dim[1].b = 4, output_dim[1].d = 16; 325 | } else { 326 | input_dim[0].N = 1; input_dim[0].l = 32, input_dim[0].b = 32, input_dim[0].d = 64; 327 | input_dim[1].N = 1; input_dim[1].l = 16, input_dim[1].b = 16, input_dim[1].d = 64; 328 | 329 | output_dim[0].N = 1; output_dim[0].l = 16, output_dim[0].b = 16, output_dim[0].d = 64; 330 | output_dim[1].N = 1; output_dim[1].l = 16, output_dim[1].b = 16, output_dim[1].d = 64; 331 | } 332 | 333 | for(int i=0; i< layers_count; i++) { 334 | inputs[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 335 | inputs_mac[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 336 | 337 | outputs[i] = make_array<uint64_t>((int32_t)output_dim[i].N, (int32_t)output_dim[i].l, (int32_t)output_dim[i].b, (int32_t)output_dim[i].d); 338 | outputs_mac[i] = make_array<uint64_t>((int32_t)output_dim[i].N, (int32_t)output_dim[i].l, (int32_t)output_dim[i].b, (int32_t)output_dim[i].d); 339 | } 340 | 341 | std::random_device rd; 342 | std::mt19937_64 eng(rd()); 343 | 
std::uniform_int_distribution<uint64_t> distr; 344 | 345 | if(party == ALICE) { 346 | prg.random_data(&mac_key, 8); 347 | mac_key %= prime_val; 348 | 349 | for(int i=0; i<layers_count; i++) { 350 | prepared_input[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 351 | prepared_input_mac[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 352 | send_inputs[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 353 | send_input_mac[i] = make_array<uint64_t>((int32_t)input_dim[i].N, (int32_t)input_dim[i].l, (int32_t)input_dim[i].b, (int32_t)input_dim[i].d); 354 | int arr_size = input_dim[i].N * input_dim[i].l * input_dim[i].b * input_dim[i].d; 355 | random_mod_p(prg, prepared_input[i], arr_size, prime_val); 356 | for(int j=0; j< arr_size; j++) { 357 | prepared_input_mac[i][j] = mod_mult(mac_key,prepared_input[i][j]); 358 | } 359 | random_mod_p(prg, inputs[i], arr_size, prime_val); 360 | random_mod_p(prg, inputs_mac[i], arr_size, prime_val); 361 | for(int j=0; j<arr_size; j++) { 362 | send_inputs[i][j] = (prepared_input[i][j] - inputs[i][j])%prime_val; 363 | send_input_mac[i][j] = (prepared_input_mac[i][j] - inputs_mac[i][j])%prime_val; 364 | } 365 | ioArr[0]->send_data(send_inputs[i], sizeof(uint64_t) * arr_size); 366 | ioArr[0]->send_data(send_input_mac[i], sizeof(uint64_t) * arr_size); 367 | } 368 | } else { 369 | for(int i=0; i<layers_count; i++) { 370 | int arr_size = input_dim[i].N * input_dim[i].l * input_dim[i].b * input_dim[i].d; 371 | ioArr[0]->recv_data(inputs[i], sizeof(uint64_t)* arr_size); 372 | ioArr[0]->recv_data(inputs_mac[i], sizeof(uint64_t)* arr_size); 373 | } 374 | } 375 | 376 | 377 | //Performance Result 378 | std::vector<std::vector<std::vector<std::vector<uint64_t>>>> inVec[layers_count]; 379 | 
std::vector<std::vector<std::vector<std::vector<uint64_t>>>> inVecMac[layers_count]; 380 | std::vector<std::vector<std::vector<std::vector<uint64_t>>>> outVec[layers_count]; 381 | std::vector<std::vector<std::vector<std::vector<uint64_t>>>> outVecMac[layers_count]; 382 | 383 | for(int i=0; i<layers_count; i++) { 384 | inVec[i].resize(input_dim[i].N, std::vector<std::vector<std::vector<uint64_t>>>( 385 | input_dim[i].l, std::vector<std::vector<uint64_t>>( 386 | input_dim[i].b, std::vector<uint64_t>(input_dim[i].d, 0)))); 387 | inVecMac[i].resize(input_dim[i].N, std::vector<std::vector<std::vector<uint64_t>>>( 388 | input_dim[i].l, std::vector<std::vector<uint64_t>>( 389 | input_dim[i].b, std::vector<uint64_t>(input_dim[i].d, 0)))); 390 | outVec[i].resize(output_dim[i].N, std::vector<std::vector<std::vector<uint64_t>>>( 391 | output_dim[i].l, std::vector<std::vector<uint64_t>>( 392 | output_dim[i].b, std::vector<uint64_t>(output_dim[i].d, 0)))); 393 | outVecMac[i].resize(output_dim[i].N, std::vector<std::vector<std::vector<uint64_t>>>( 394 | output_dim[i].l, std::vector<std::vector<uint64_t>>( 395 | output_dim[i].b, std::vector<uint64_t>(output_dim[i].d, 0)))); 396 | for(int j=0; j<input_dim[i].N; j++) { 397 | for(int k=0; k<input_dim[i].l; k++) { 398 | for(int l=0; l<input_dim[i].b; l++) { 399 | for(int m=0; m<input_dim[i].d; m++) { 400 | inVec[i][j][k][l][m] = inputs[i][(j) * (input_dim[i].l) * (input_dim[i].b) * (input_dim[i].d) + (k) * (input_dim[i].b) * (input_dim[i].d) + (l) * (input_dim[i].d) + (m)]; 401 | inVecMac[i][j][k][l][m] = (*((inputs_mac[i]) + (j) * (input_dim[i].l) * (input_dim[i].b) * (input_dim[i].d) + (k) * (input_dim[i].b) * (input_dim[i].d) + (l) * (input_dim[i].d) + (m))); 402 | } 403 | } 404 | } 405 | } 406 | } 407 | 408 | 409 | auto start = clock_start(); 410 | if(choice_nn == MINIONN) { 411 | AvgPool_pt((int32_t) output_dim[0].N, (int32_t) output_dim[0].l, (int32_t) output_dim[0].b, (int32_t) output_dim[0].d, (int32_t) 2, 412 | (int32_t) 2, 
(int32_t) 0, (int32_t) 0, 413 | (int32_t) 0, (int32_t) 0, (int32_t) 2, 414 | (int32_t) 2, (int32_t) input_dim[0].N, (int32_t) input_dim[0].l, (int32_t) input_dim[0].b, 415 | (int32_t) input_dim[0].d, 416 | inVec[0], outVec[0]); 417 | 418 | AvgPool_pt((int32_t) output_dim[0].N, (int32_t) output_dim[0].l, (int32_t) output_dim[0].b, (int32_t) output_dim[0].d, (int32_t) 2, 419 | (int32_t) 2, (int32_t) 0, (int32_t) 0, 420 | (int32_t) 0, (int32_t) 0, (int32_t) 2, 421 | (int32_t) 2, (int32_t) input_dim[0].N, (int32_t) input_dim[0].l, (int32_t) input_dim[0].b, 422 | (int32_t) input_dim[0].d, 423 | inVecMac[0], outVecMac[0]); 424 | AvgPool_pt((int32_t) output_dim[1].N, (int32_t) output_dim[1].l, (int32_t) output_dim[1].b, (int32_t) output_dim[1].d, (int32_t) 2, 425 | (int32_t) 2, (int32_t) 0, (int32_t) 0, 426 | (int32_t) 0, (int32_t) 0, (int32_t) 2, 427 | (int32_t) 2, (int32_t) input_dim[1].N, (int32_t) input_dim[1].l, (int32_t) input_dim[1].b, 428 | (int32_t) input_dim[1].d, 429 | inVec[1], outVecMac[1]); 430 | AvgPool_pt((int32_t) output_dim[1].N, (int32_t) output_dim[1].l, (int32_t) output_dim[1].b, (int32_t) output_dim[1].d, (int32_t) 2, 431 | (int32_t) 2, (int32_t) 0, (int32_t) 0, 432 | (int32_t) 0, (int32_t) 0, (int32_t) 2, 433 | (int32_t) 2, (int32_t) input_dim[1].N, (int32_t) input_dim[1].l, (int32_t) input_dim[1].b, 434 | (int32_t) input_dim[1].d, 435 | inVecMac[1], outVecMac[1]); 436 | } else { 437 | AvgPool_pt((int32_t) output_dim[0].N, (int32_t) output_dim[0].l, (int32_t) output_dim[0].b, (int32_t) output_dim[0].d, (int32_t) 2, 438 | (int32_t) 2, (int32_t) 0, (int32_t) 0, 439 | (int32_t) 0, (int32_t) 0, (int32_t) 2, 440 | (int32_t) 2, (int32_t) input_dim[0].N, (int32_t) input_dim[0].l, (int32_t) input_dim[0].b, 441 | (int32_t) input_dim[0].d, 442 | inVec[0], outVec[0]); 443 | 444 | AvgPool_pt((int32_t) output_dim[0].N, (int32_t) output_dim[0].l, (int32_t) output_dim[0].b, (int32_t) output_dim[0].d, (int32_t) 2, 445 | (int32_t) 2, (int32_t) 0, (int32_t) 0, 446 | 
(int32_t) 0, (int32_t) 0, (int32_t) 2, 447 | (int32_t) 2, (int32_t) input_dim[0].N, (int32_t) input_dim[0].l, (int32_t) input_dim[0].b, 448 | (int32_t) input_dim[0].d, 449 | inVecMac[0], outVecMac[0]); 450 | AvgPool_pt((int32_t) output_dim[1].N, (int32_t) output_dim[1].l, (int32_t) output_dim[1].b, (int32_t) output_dim[1].d, (int32_t) 2, 451 | (int32_t) 2, (int32_t) 0, (int32_t) 1, 452 | (int32_t) 0, (int32_t) 1, (int32_t) 1, 453 | (int32_t) 1, (int32_t) input_dim[1].N, (int32_t) input_dim[1].l, (int32_t) input_dim[1].b, 454 | (int32_t) input_dim[1].d, 455 | inVec[1], outVec[1]); 456 | AvgPool_pt((int32_t) output_dim[1].N, (int32_t) output_dim[1].l, (int32_t) output_dim[1].b, (int32_t) output_dim[1].d, (int32_t) 2, 457 | (int32_t) 2, (int32_t) 0, (int32_t) 1, 458 | (int32_t) 0, (int32_t) 1, (int32_t) 1, 459 | (int32_t) 1, (int32_t) input_dim[1].N, (int32_t) input_dim[1].l, (int32_t) input_dim[1].b, 460 | (int32_t) input_dim[1].d, 461 | inVecMac[1], outVecMac[1]); 462 | } 463 | long long t = time_from(start); 464 | cout << "######################Performance#######################" <<endl; 465 | cout<<"Time Taken: "<<t<<" mus"<<endl; 466 | cout<<"Sent Data (MB): "<<0<<endl; 467 | cout << "########################################################" <<endl; 468 | 469 | } 470 | -------------------------------------------------------------------------------- /test/msi_convlayer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Authors: Deevashwer Rathee 3 | Copyright: 4 | Copyright (c) 2020 Microsoft Research 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, 
subject to the following conditions: 11 | The above copyright notice and this permission notice shall be included in all 12 | copies or substantial portions of the Software. 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | */ 21 | 22 | #include "LinearLayer/conv-field.h" 23 | #include "LinearLayer/defines-HE.h" 24 | 25 | using namespace std; 26 | using namespace seal; 27 | using namespace emp; 28 | 29 | uint64_t prime_mod = PLAINTEXT_MODULUS; 30 | 31 | enum neural_net { 32 | NONE, 33 | MINIONN, 34 | CIFAR10 35 | }; 36 | neural_net choice_nn; 37 | neural_net def_nn = NONE; 38 | 39 | long long total_time = 0; 40 | 41 | int party = 0; 42 | int bitlength = 44; 43 | int num_threads = 8; 44 | int port = 8000; 45 | string address = "127.0.0.1"; 46 | int image_h = 28; 47 | int image_w = 28; 48 | int inp_chans = 1; 49 | int out_chans = 16; 50 | int filter_h = 5; 51 | int filter_w = 5; 52 | int stride_h = 1; 53 | int stride_w = 1; 54 | 55 | int stride = 1; 56 | int filter_precision = 12; 57 | int pad_l = 0; 58 | int pad_r = 0; 59 | string benchmark; 60 | 61 | seal::Modulus mod(prime_mod); 62 | 63 | void Conv(ConvField &he_conv, int32_t H, int32_t CI, int32_t FH, int32_t CO, 64 | int32_t zPadHLeft, int32_t zPadHRight, int32_t strideH) { 65 | int newH = 1 + (H + zPadHLeft + zPadHRight - FH) / strideH; 66 | int N = 1; 67 | int W = H; 68 | int FW = FH; 69 | int zPadWLeft = zPadHLeft; 70 | int zPadWRight = zPadHRight; 71 | int strideW = strideH; 72 | int newW = newH; 73 | vector<vector<vector<vector<uint64_t>>>> inputArr(N); 74 | 
vector<vector<vector<vector<uint64_t>>>> filterArr(FH); 75 | vector<vector<vector<vector<uint64_t>>>> outArr(N); 76 | 77 | PRG prg; 78 | for (int i = 0; i < N; i++) { 79 | outArr[i].resize(newH); 80 | for (int j = 0; j < newH; j++) { 81 | outArr[i][j].resize(newW); 82 | for (int k = 0; k < newW; k++) { 83 | outArr[i][j][k].resize(CO); 84 | } 85 | } 86 | } 87 | if (party == ALICE) { 88 | for (int i = 0; i < FH; i++) { 89 | filterArr[i].resize(FW); 90 | for (int j = 0; j < FW; j++) { 91 | filterArr[i][j].resize(CI); 92 | for (int k = 0; k < CI; k++) { 93 | filterArr[i][j][k].resize(CO); 94 | prg.random_data(filterArr[i][j][k].data(), CO * sizeof(uint64_t)); 95 | for (int h = 0; h < CO; h++) { 96 | filterArr[i][j][k][h] = 97 | ((int64_t)filterArr[i][j][k][h]) >> (64 - filter_precision); 98 | } 99 | } 100 | } 101 | } 102 | } 103 | for (int i = 0; i < N; i++) { 104 | inputArr[i].resize(H); 105 | for (int j = 0; j < H; j++) { 106 | inputArr[i][j].resize(W); 107 | for (int k = 0; k < W; k++) { 108 | inputArr[i][j][k].resize(CI); 109 | random_mod_p(prg, inputArr[i][j][k].data(), CI, prime_mod); 110 | } 111 | } 112 | } 113 | uint64_t comm_start = he_conv.io->counter; 114 | 115 | he_conv.convolution(N, H, W, CI, FH, FW, CO, zPadHLeft, zPadHRight, zPadWLeft, 116 | zPadWRight, strideH, strideW, inputArr, filterArr, outArr, 117 | true, true); 118 | uint64_t comm_end = he_conv.io->counter; 119 | cout << "Total Comm: " << (comm_end - comm_start) / (1.0 * (1ULL << 20)) 120 | << endl; 121 | } 122 | 123 | void Conv_First(ConvField &he_conv, int32_t H, int32_t W, int32_t FH, int32_t FW, int32_t CI, int32_t CO, int32_t strideH, 124 | int32_t strideW, bool pad_valid) { 125 | int N = 1; 126 | vector<vector<vector<vector<uint64_t>>>> inputArr(N); 127 | vector<vector<vector<vector<uint64_t>>>> filterArr(FH); 128 | 129 | PRG prg; 130 | cout<<"Party "<< party<<endl; 131 | if (party == ALICE) { 132 | for (int i = 0; i < FH; i++) { 133 | filterArr[i].resize(FW); 134 | for (int j = 0; j < FW; 
j++) { 135 | filterArr[i][j].resize(CI); 136 | for (int k = 0; k < CI; k++) { 137 | filterArr[i][j][k].resize(CO); 138 | random_mod_p(prg, filterArr[i][j][k].data(), CO, prime_mod); 139 | } 140 | } 141 | } 142 | } else { 143 | for (int i = 0; i < N; i++) { 144 | inputArr[i].resize(H); 145 | for (int j = 0; j < H; j++) { 146 | inputArr[i][j].resize(W); 147 | for (int k = 0; k < W; k++) { 148 | inputArr[i][j][k].resize(CI); 149 | random_mod_p(prg, inputArr[i][j][k].data(), CI, prime_mod); 150 | } 151 | } 152 | } 153 | } 154 | 155 | auto start = clock_start(); 156 | 157 | he_conv.convolution_first(H, W, CI, FH, FW, CO, strideH, strideW, pad_valid, inputArr, filterArr, 158 | false, true); 159 | long long t = time_from(start); 160 | total_time += t; 161 | } 162 | 163 | void Conv_Gen(ConvField &he_conv, int32_t H, int32_t W, int32_t FH, int32_t FW, int32_t CI, int32_t CO, int32_t strideH, 164 | int32_t strideW, bool pad_valid) { 165 | int N=1; 166 | vector<vector<vector<vector<uint64_t>>>> inputArr(N); 167 | vector<vector<vector<vector<uint64_t>>>> inputMacArr(N); 168 | 169 | vector<vector<vector<vector<uint64_t>>>> filterArr(FH); 170 | PRG prg; 171 | 172 | for (int i = 0; i < N; i++) { 173 | inputArr[i].resize(H); 174 | for (int j = 0; j < H; j++) { 175 | inputArr[i][j].resize(W); 176 | for (int k = 0; k < W; k++) { 177 | inputArr[i][j][k].resize(CI); 178 | random_mod_p(prg, inputArr[i][j][k].data(), CI, prime_mod); 179 | } 180 | } 181 | } 182 | 183 | for (int i = 0; i < N; i++) { 184 | inputMacArr[i].resize(H); 185 | for (int j = 0; j < H; j++) { 186 | inputMacArr[i][j].resize(W); 187 | for (int k = 0; k < W; k++) { 188 | inputMacArr[i][j][k].resize(CI); 189 | random_mod_p(prg, inputMacArr[i][j][k].data(), CI, prime_mod); 190 | } 191 | } 192 | } 193 | 194 | if (party == ALICE) { 195 | for (int i = 0; i < FH; i++) { 196 | filterArr[i].resize(FW); 197 | for (int j = 0; j < FW; j++) { 198 | filterArr[i][j].resize(CI); 199 | for (int k = 0; k < CI; k++) { 200 | 
filterArr[i][j][k].resize(CO); 201 | random_mod_p(prg, filterArr[i][j][k].data(), CO, prime_mod); 202 | } 203 | } 204 | } 205 | } 206 | auto start = clock_start(); 207 | he_conv.convolution_gen(H, W, CI, FH, FW, CO, strideH, strideW, pad_valid, inputArr, inputMacArr, filterArr, 208 | mod, false, true); 209 | long long t = time_from(start); 210 | total_time += t; 211 | } 212 | 213 | void parse_arguments(int argc, char**arg, int *party, int *port) { 214 | *party = atoi (arg[1]); 215 | address = arg[2]; 216 | *port = atoi (arg[3]); 217 | if(argc < 6) { 218 | choice_nn = def_nn; 219 | } else { 220 | choice_nn = neural_net(atoi (arg[5])); 221 | } 222 | 223 | if(choice_nn == MINIONN) 224 | benchmark = "mnist"; 225 | else 226 | benchmark = "cifar10"; 227 | } 228 | 229 | int main(int argc, char **argv) { 230 | parse_arguments(argc, argv, &party, &port); 231 | 232 | cout<<"Executing Convolution Layers ..."<<endl; 233 | cout << "=====================Configuration======================" << endl; 234 | cout<<"Role: "<< party<<" - IP Address: "<< address <<" - Port: "<<port<<" - Benchmark: "<<benchmark<<" - Bitlength: "<<bitlength<<endl; 235 | cout << "========================================================" << endl; 236 | 237 | bool pad_valid; 238 | 239 | NetIO *io = new NetIO(party == 1 ? 
nullptr : address.c_str(), port); 240 | uint64_t comm_sent = 0; 241 | uint64_t start_comm = io->counter; 242 | 243 | ConvField he_conv(party, io); 244 | auto start = clock_start(); 245 | start_comm = io->counter; 246 | long long t = time_from(start); 247 | 248 | total_time += t; 249 | if(choice_nn==MINIONN) { 250 | image_h = 28; 251 | image_w = 28; 252 | filter_h = 5; 253 | filter_w = 5; 254 | stride_h = 1; 255 | stride_w = 1; 256 | inp_chans = 1; 257 | out_chans = 16; 258 | pad_valid = false; 259 | Conv_First(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 260 | 261 | image_h = 12; 262 | image_w = 12; 263 | filter_h = 5; 264 | filter_w = 5; 265 | stride_h = 1; 266 | stride_w = 1; 267 | inp_chans = 16; 268 | out_chans = 16; 269 | pad_valid = false; 270 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 271 | } else { 272 | //Layer 1 273 | image_h = 32; 274 | image_w = 32; 275 | filter_h = 3; 276 | filter_w = 3; 277 | stride_h = 1; 278 | stride_w = 1; 279 | inp_chans = 3; 280 | out_chans = 64; 281 | pad_l = 1; 282 | pad_r = 1; 283 | pad_valid = true; 284 | Conv_First(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 285 | 286 | //Layer 2 287 | image_h = 32; 288 | image_w = 32; 289 | filter_h = 3; 290 | filter_w = 3; 291 | stride_h = 1; 292 | stride_w = 1; 293 | inp_chans = 64; 294 | out_chans = 64; 295 | pad_l = 1; 296 | pad_r = 1; 297 | pad_valid = true; 298 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 299 | 300 | //Layer 3 301 | image_h = 16; 302 | image_w = 16; 303 | filter_h = 3; 304 | filter_w = 3; 305 | stride_h = 1; 306 | stride_w = 1; 307 | inp_chans = 64; 308 | out_chans = 64; 309 | pad_l = 1; 310 | pad_r = 1; 311 | pad_valid = true; 312 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, 
stride_w, pad_valid); 313 | 314 | //Layer 4 315 | image_h = 16; 316 | image_w = 16; 317 | filter_h = 3; 318 | filter_w = 3; 319 | stride_h = 1; 320 | stride_w = 1; 321 | inp_chans = 64; 322 | out_chans = 64; 323 | pad_l = 1; 324 | pad_r = 1; 325 | pad_valid = true; 326 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 327 | 328 | //Layer 5 329 | image_h = 8; 330 | image_w = 8; 331 | filter_h = 3; 332 | filter_w = 3; 333 | stride_h = 1; 334 | stride_w = 1; 335 | inp_chans = 64; 336 | out_chans = 64; 337 | pad_l = 1; 338 | pad_r = 1; 339 | pad_valid = true; 340 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 341 | 342 | //Layer 6 343 | image_h = 8; 344 | image_w = 8; 345 | filter_h = 1; 346 | filter_w = 1; 347 | stride_h = 1; 348 | stride_w = 1; 349 | inp_chans = 64; 350 | out_chans = 64; 351 | //no padding 352 | pad_valid = false; 353 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 354 | 355 | //Layer 7 356 | image_h = 8; 357 | image_w = 8; 358 | filter_h = 1; 359 | filter_w = 1; 360 | stride_h = 1; 361 | stride_w = 1; 362 | inp_chans = 64; 363 | out_chans = 16; 364 | //no padding 365 | pad_valid = false; 366 | Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, pad_valid); 367 | 368 | } 369 | 370 | //Conv_First(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, true); 371 | //Conv_Gen(he_conv, image_h, image_w, filter_h, filter_w, inp_chans, out_chans, stride_h, stride_w, true); 372 | cout << "######################Performance#######################" <<endl; 373 | cout<<"Time Taken: "<<total_time<<" mus"<<endl; 374 | //Calculate Communication 375 | comm_sent = (io->counter-start_comm)>>20; 376 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 377 | cout << 
"########################################################" <<endl; 378 | 379 | io->flush(); 380 | return 0; 381 | } 382 | -------------------------------------------------------------------------------- /test/msi_linearlayer.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "LinearLayer/fc-field.h" 4 | #include "LinearLayer/defines-HE.h" 5 | 6 | using namespace std; 7 | using namespace emp; 8 | using namespace seal; 9 | 10 | enum neural_net { 11 | NONE, 12 | MINIONN, 13 | CIFAR10 14 | }; 15 | neural_net choice_nn; 16 | neural_net def_nn = NONE; 17 | 18 | long long total_time = 0; 19 | 20 | int32_t bitlength = 44; 21 | uint64_t prime_mod = PLAINTEXT_MODULUS; 22 | int party = 0; 23 | int num_threads = 8; 24 | int port = 8000; 25 | string address = "127.0.0.1"; 26 | int num_rows = 512; 27 | int common_dim = 1024; 28 | int filter_precision = 15; 29 | string benchmark; 30 | 31 | seal::Modulus mod(prime_mod); 32 | 33 | void LinearLayerFirstFC(FCField &he_fc, int32_t num_rows, int32_t common_dim) { 34 | int num_cols = 1; 35 | 36 | //Setup Input objects 37 | vector<vector<uint64_t>> inputs; 38 | 39 | //Setup Output shares objects 40 | vector<vector<uint64_t>> op_shares(num_rows); 41 | vector<vector<uint64_t>> mac_op_shares(num_rows); 42 | for (int i = 0; i < num_rows; i++) { 43 | op_shares[i].resize(num_cols); 44 | mac_op_shares[i].resize(num_cols); 45 | } 46 | PRG prg; 47 | 48 | //Prepare Dummy Inputs 49 | if(party == ALICE) { 50 | inputs.resize(num_rows, vector<uint64_t>(common_dim, 0)); 51 | //Create input matrix 52 | for(int i=0; i< num_rows; i++) { 53 | random_mod_p(prg, inputs[i].data(), common_dim, prime_mod); 54 | } 55 | } else { 56 | //Create input vector 57 | inputs.resize(common_dim, vector<uint64_t>(num_cols, 0)); 58 | for(int i=0; i<common_dim; i++) { 59 | random_mod_p(prg, inputs[i].data(), num_cols, prime_mod); 60 | } 61 | } 62 | he_fc.matrix_multiplication_first(num_rows, common_dim, 
num_cols, inputs, op_shares, mac_op_shares, mod, false, 63 | false); 64 | } 65 | 66 | void LinearLayerFC(FCField &he_fc, int32_t num_rows, int32_t common_dim) { 67 | int num_cols = 1; 68 | 69 | //Setup Input objects 70 | vector<vector<uint64_t>> matrix; 71 | vector<vector<uint64_t>> input_share; 72 | vector<vector<uint64_t>> mac_input_share; 73 | 74 | PRG prg; 75 | //Prepare Dummy Inputs 76 | input_share.resize(common_dim, vector<uint64_t>(num_cols, 0)); 77 | for(int i=0; i<common_dim; i++) { 78 | random_mod_p(prg, input_share[i].data(), num_cols, prime_mod); 79 | } 80 | 81 | mac_input_share.resize(common_dim, vector<uint64_t>(num_cols, 0)); 82 | for(int i=0; i<common_dim; i++) { 83 | random_mod_p(prg, mac_input_share[i].data(), num_cols, prime_mod); 84 | } 85 | 86 | if(party == ALICE) { 87 | matrix.resize(num_rows, vector<uint64_t>(common_dim, 0)); 88 | //Create input matrix 89 | for(int i=0; i< num_rows; i++) { 90 | random_mod_p(prg, matrix[i].data(), common_dim, prime_mod); 91 | } 92 | } 93 | auto start = clock_start(); 94 | he_fc.matrix_multiplication_gen(num_rows, common_dim, num_cols, matrix, input_share, mac_input_share, mod, false, 95 | false); 96 | long long t = time_from(start); 97 | total_time += t; 98 | } 99 | 100 | void parse_arguments(int argc, char**arg, int *party, int *port) { 101 | *party = atoi (arg[1]); 102 | address = arg[2]; 103 | *port = atoi (arg[3]); 104 | 105 | if(argc < 6) { 106 | choice_nn = def_nn; 107 | } else { 108 | choice_nn = neural_net(atoi (arg[5])); 109 | } 110 | 111 | if(choice_nn == MINIONN) 112 | benchmark = "mnist"; 113 | else 114 | benchmark = "cifar10"; 115 | } 116 | 117 | int main(int argc, char** argv){ 118 | parse_arguments(argc, argv, &party, &port); 119 | cout<<"Executing Fully-Connected Layers ..."<<endl; 120 | cout << "=====================Configuration======================" << endl; 121 | cout<<"Role: "<< party<<" - IP Address: "<< address <<" - Port: "<<port<<" - Benchmark: "<<benchmark<<" - Bitlength: 
"<<bitlength<<endl; 122 | cout << "========================================================" << endl; 123 | 124 | NetIO * io = new NetIO(party==ALICE ? nullptr : address.c_str(), port); 125 | 126 | int slot_count = POLY_MOD_DEGREE; 127 | shared_ptr<SEALContext> context; 128 | Encryptor *encryptor; 129 | Decryptor *decryptor; 130 | Evaluator *evaluator; 131 | BatchEncoder *encoder; 132 | GaloisKeys *galois_keys; 133 | Ciphertext *zero; 134 | 135 | //Generate Keys 136 | //generate_new_keys(party, io, slot_count, context, encryptor, decryptor, evaluator, encoder, galois_keys, zero); 137 | /*if(party == ALICE) { 138 | //Generate Server's random input 139 | 140 | } else { 141 | //Generate Client's random input 142 | for (int i = 0; i < common_dim; i++) { 143 | B[i].resize(1); 144 | random_mod_p(prg, B[i].data(), num_cols, prime_mod); 145 | } 146 | }*/ 147 | uint64_t comm_sent = 0; 148 | uint64_t start_comm = io->counter; 149 | auto start = clock_start(); 150 | start_comm = io->counter; 151 | FCField he_fc(party, io); 152 | long long t = time_from(start); 153 | total_time += t; 154 | if(choice_nn==MINIONN) { 155 | num_rows = 100; 156 | common_dim = 256; 157 | LinearLayerFC(he_fc, num_rows, common_dim); 158 | num_rows = 10; 159 | common_dim = 100; 160 | LinearLayerFC(he_fc, num_rows, common_dim); 161 | } else { 162 | num_rows = 10; 163 | common_dim = 1024; 164 | LinearLayerFC(he_fc, num_rows, common_dim); 165 | } 166 | cout << "######################Performance#######################" <<endl; 167 | cout<<"Time Taken: "<<total_time<<" mus"<<endl; 168 | //Calculate Communication 169 | comm_sent = (io->counter-start_comm)>>20; 170 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 171 | cout << "########################################################" <<endl; 172 | //LinearLayerFirstFC(he_fc, num_rows, common_dim); 173 | } 174 | -------------------------------------------------------------------------------- /test/msi_microbenchmark.cpp: 
-------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/emp-sh2pc.h" 2 | #include <cmath> 3 | 4 | #include "seal/util/uintarith.h" 5 | #include "seal/util/uintarithsmallmod.h" 6 | #include <thread> 7 | #define MAX_THREADS 8 8 | using namespace emp; 9 | using namespace std; 10 | 11 | 12 | int num_threads = 8; 13 | 14 | //Slackoverflow Code For bit-wise shift 15 | #define SHL128(v, n) \ 16 | ({ \ 17 | __m128i v1, v2; \ 18 | \ 19 | if ((n) >= 64) \ 20 | { \ 21 | v1 = _mm_slli_si128(v, 8); \ 22 | v1 = _mm_slli_epi64(v1, (n) - 64); \ 23 | } \ 24 | else \ 25 | { \ 26 | v1 = _mm_slli_epi64(v, n); \ 27 | v2 = _mm_slli_si128(v, 8); \ 28 | v2 = _mm_srli_epi64(v2, 64 - (n)); \ 29 | v1 = _mm_or_si128(v1, v2); \ 30 | } \ 31 | v1; \ 32 | }) 33 | 34 | enum neural_net { 35 | NONE, 36 | MINIONN, 37 | CIFAR10 38 | }; 39 | 40 | neural_net choice_nn; 41 | int choose_relu; 42 | uint64_t start_comm[MAX_THREADS]; 43 | uint64_t comm_sent = 0; 44 | NetIO *ioArr[MAX_THREADS]; 45 | uint64_t prime_mod = 17592060215297; 46 | seal::Modulus mod(prime_mod); 47 | int port = 32000, def_nrelu = 1<<20, l = 44; 48 | neural_net def_nn = NONE; 49 | string address; 50 | bool run_all = false; 51 | uint64_t mac_key; 52 | PRG prg; 53 | 54 | bool verify = false; 55 | int MINIONN_RELUS[] = { 16*576, 16*64, 100*1 56 | }; 57 | 58 | int CIFAR10_RELUS[] = { 64*1024, 64*1024, 64*256, 64*256, 64*64, 64*64, 16*64 59 | }; 60 | 61 | uint64_t mod_shift(uint64_t a, uint64_t b, uint64_t prime_mod) { 62 | __m128i temp, stemp; 63 | memcpy(&temp, &a, 8); 64 | stemp = SHL128(temp, b); 65 | 66 | uint64_t input[2]; 67 | input[0] = stemp[0]; 68 | input[1] = stemp[1]; 69 | 70 | uint64_t result = seal::util::barrett_reduce_128(input, mod); 71 | 72 | return result; 73 | } 74 | 75 | uint64_t mod_mult(uint64_t a, uint64_t b) { 76 | unsigned long long temp_result[2]; 77 | seal::util::multiply_uint64(a, b, temp_result); 78 | 79 | /*uint64_t input[2]; 80 | input[0] = res[0]; 81 | 
input[1] = res[1];*/ 82 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 83 | return result; 84 | } 85 | 86 | 87 | //Referred SCI OT repo's logic to pack ot messages 88 | void pack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 89 | uint64_t beg_idx = 0; 90 | uint64_t end_idx = 0; 91 | uint64_t beg_blk = 0; 92 | uint64_t end_blk = 0; 93 | uint64_t temp_blk = 0; 94 | uint64_t mask = (1ULL << bitlen) - 1; 95 | uint64_t pack_blk_size = 64; 96 | 97 | if (bitlen == 64) 98 | mask = -1; 99 | 100 | for (int i = 0; i < pack_size; i++) { 101 | pack_table[i] = 0; 102 | } 103 | 104 | for (int i = 0; i < batch_size; i++) { 105 | beg_idx = i * bitlen; 106 | end_idx = beg_idx + bitlen; 107 | end_idx -= 1; 108 | beg_blk = beg_idx / pack_blk_size; 109 | end_blk = end_idx / pack_blk_size; 110 | 111 | if (beg_blk == end_blk) { 112 | pack_table[beg_blk] ^= (ciphertexts[i] & mask) << (beg_idx % pack_blk_size); 113 | } else { 114 | temp_blk = (ciphertexts[i] & mask); 115 | pack_table[beg_blk] ^= (temp_blk) << (beg_idx % pack_blk_size); 116 | pack_table[end_blk] ^= (temp_blk) >> (pack_blk_size - (beg_idx % pack_blk_size)); 117 | } 118 | } 119 | } 120 | 121 | //Referred SCI OT repo's logic to unpack ot messages 122 | void unpack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 123 | uint64_t beg_idx = 0; 124 | uint64_t end_idx = 0; 125 | uint64_t beg_blk = 0; 126 | uint64_t end_blk = 0; 127 | uint64_t temp_blk = 0; 128 | uint64_t mask = (1ULL << bitlen) - 1; 129 | uint64_t pack_blk_size = 64; 130 | 131 | for (int i = 0; i < batch_size; i++) { 132 | beg_idx = i * bitlen; 133 | end_idx = beg_idx + bitlen - 1; 134 | beg_blk = beg_idx / pack_blk_size; 135 | end_blk = end_idx / pack_blk_size; 136 | 137 | if (beg_blk == end_blk) { 138 | ciphertexts[i] = (pack_table[beg_blk] >> (beg_idx % pack_blk_size)) & mask; 139 | } else { 140 | ciphertexts[i] = 0; 141 | 
ciphertexts[i] ^= (pack_table[beg_blk] >> (beg_idx % pack_blk_size)); 142 | ciphertexts[i] ^= (pack_table[end_blk] << (pack_blk_size - (beg_idx % pack_blk_size))); 143 | ciphertexts[i] = ciphertexts[i] & mask; 144 | } 145 | } 146 | } 147 | 148 | void create_ciphertexts(Integer *garbled_data, block label_delta, uint64_t *ciphertexts, uint64_t* server_shares, int bitlen, int nrelu, uint64_t alpha, int l_idx) { 149 | uint64_t delta_int; 150 | memcpy(&delta_int, &label_delta[l_idx], 8); 151 | 152 | uint64_t mask = (1ULL << bitlen) - 1; 153 | uint64_t label_temp; 154 | uint64_t **random_val = (uint64_t **)malloc(nrelu*sizeof(uint64_t*)); 155 | uint8_t pnp, cpnp; 156 | for(int i=0; i<nrelu; i++) { 157 | random_val[i] = (uint64_t *)malloc(bitlen*sizeof(uint64_t)); 158 | } 159 | 160 | for(int i=0; i<nrelu; i++) { 161 | server_shares[i] = 0; 162 | for(int j=0; j<bitlen; j++) { 163 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 164 | prg.random_data(&random_val[i][j], 8); 165 | random_val[i][j] %= prime_mod; 166 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 167 | cpnp = 1 - pnp; 168 | ciphertexts[(i*bitlen+j)*2+pnp] = (random_val[i][j])^(label_temp & mask); 169 | ciphertexts[(i*bitlen+j)*2+cpnp] = ((random_val[i][j]+alpha)%prime_mod)^((label_temp^delta_int) & mask); 170 | if(i==0) { 171 | uint64_t l1 = label_temp & mask, l2 = label_temp^delta_int & mask; 172 | } 173 | server_shares[i] = (server_shares[i] + mod_shift(random_val[i][j],j,prime_mod))%prime_mod; 174 | } 175 | server_shares[i] = prime_mod - server_shares[i]; 176 | } 177 | } 178 | 179 | void decrypt_ciphertexts(Integer *garbled_data, uint64_t *ciphertexts, uint64_t* client_shares, int bitlen, int nrelu, int l_idx) { 180 | uint64_t label_temp; 181 | uint8_t pnp; 182 | uint64_t random_val; 183 | 184 | uint64_t mask = (1ULL << bitlen) - 1; 185 | 186 | for(int i=0; i<nrelu; i++) { 187 | client_shares[i] = 0; 188 | for(int j=0; j< bitlen; j++) { 189 | memcpy(&label_temp, 
&garbled_data[i].bits[j].bit[l_idx], 8); 190 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 191 | random_val = ciphertexts[(i*bitlen+j)*2+pnp]^(label_temp & mask); 192 | client_shares[i] = (client_shares[i] + mod_shift(random_val,j,prime_mod))%prime_mod; 193 | } 194 | } 195 | } 196 | 197 | void msi_relu_6(int party, NetIO* io, uint64_t* inputs, int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 198 | //Public prime values 199 | Integer p(bitlen + 1, prime_mod, PUBLIC); 200 | Integer p_mod2(bitlen, prime_mod/2, PUBLIC); 201 | Integer zero(bitlen, 0, PUBLIC); 202 | Integer six(bitlen, 6, PUBLIC); 203 | 204 | //Assign Inputs 205 | Integer *X = new Integer[nrelu]; 206 | for(int i = 0; i < nrelu; ++i) 207 | X[i] = Integer(bitlen+1, inputs[i], ALICE); 208 | Integer *Y = new Integer[nrelu]; 209 | for(int i = 0; i < nrelu; ++i) 210 | Y[i] = Integer(bitlen+1, inputs[i], BOB); 211 | 212 | Integer *S = new Integer[nrelu]; 213 | Integer *T = new Integer[nrelu]; 214 | 215 | //Check if Bob's share is < p 216 | Bit res[nrelu]; 217 | for(int i=0; i < nrelu; ++i) 218 | res[i] = Y[i] > p; 219 | 220 | for(int i=0; i < nrelu; ++i) { 221 | //Perform mod p 222 | Integer s0 = X[i]; 223 | //s0.resize(s0.size()+1); 224 | 225 | Integer s1 = Y[i]; 226 | //s1.resize(s1.size()+1); 227 | 228 | Integer sum = s0 + s1; 229 | 230 | Integer mod_p_val = sum - p; 231 | 232 | Bit borrow_bit = mod_p_val[mod_p_val.size()-1]; 233 | 234 | Integer s = mod_p_val.select(borrow_bit, sum); 235 | 236 | S[i] = s; 237 | 238 | //Perform RELU 239 | Integer p2_minus_s = p_mod2-s; 240 | 241 | Bit is_negative = p2_minus_s[p2_minus_s.size()-1]; 242 | 243 | Integer relu_s = s.select(is_negative, zero); 244 | 245 | Integer six_minus_res = six - relu_s; 246 | Bit is_greater_than_six = six_minus_res[six_minus_res.size()-1]; 247 | 248 | Integer res = relu_s.select(is_greater_than_six, six); 249 | 250 | T[i] = res; 251 | } 252 | 253 | int pack_size = ceil(nrelu*bitlen*bitlen*2.0/(8*sizeof(uint64_t))); 
254 | int batch_size = nrelu*bitlen*2; 255 | 256 | uint64_t *ip_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 257 | 258 | uint64_t *op_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 259 | 260 | uint64_t *op_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 261 | 262 | uint64_t *ip_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 263 | uint64_t *op_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 264 | uint64_t *opm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 265 | 266 | if(party == ALICE) { 267 | 268 | create_ciphertexts(S, delta_used, ip_cts, ip_ss, bitlen, nrelu, mac_key, 1); 269 | create_ciphertexts(T, delta_used, op_cts, op_ss, bitlen, nrelu, 1, 0); 270 | create_ciphertexts(T, delta_used, op_mcts, op_mss, bitlen, nrelu, mac_key, 1); 271 | 272 | pack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 273 | pack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 274 | pack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 275 | 276 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 277 | io->send_data(ip_pack_table, sizeof(uint64_t) * pack_size); 278 | io->send_data(op_pack_table, sizeof(uint64_t) * pack_size); 279 | io->send_data(opm_pack_table, sizeof(uint64_t) * pack_size); 280 | } else { 281 | io->recv_data(ip_pack_table, sizeof(uint64_t) * pack_size); 282 | io->recv_data(op_pack_table, sizeof(uint64_t) * pack_size); 283 | io->recv_data(opm_pack_table, sizeof(uint64_t) * pack_size); 284 | 285 | unpack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 286 | unpack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 287 | unpack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 288 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 289 | 290 | decrypt_ciphertexts(S, ip_cts, ip_ss, bitlen, nrelu, 1); 291 | decrypt_ciphertexts(T, op_cts, 
op_ss, bitlen, nrelu, 0); 292 | decrypt_ciphertexts(T, op_mcts, op_mss, bitlen, nrelu, 1); 293 | } 294 | } 295 | 296 | void msi_relu(int party, NetIO* io, uint64_t inputs[], int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 297 | uint64_t comm_sent; 298 | //Public prime values 299 | Integer p(bitlen + 1, prime_mod, PUBLIC); 300 | Integer p_mod2(bitlen, prime_mod/2, PUBLIC); 301 | Integer zero(bitlen, 0, PUBLIC); 302 | 303 | //Assign Inputs 304 | Integer *X = new Integer[nrelu]; 305 | for(int i = 0; i < nrelu; ++i) 306 | X[i] = Integer(bitlen+1, inputs[i], ALICE); 307 | Integer *Y = new Integer[nrelu]; 308 | for(int i = 0; i < nrelu; ++i) 309 | Y[i] = Integer(bitlen+1, inputs[i], BOB); 310 | 311 | Integer *S = new Integer[nrelu]; 312 | Integer *T = new Integer[nrelu]; 313 | 314 | //Check if Bob's share is < p 315 | Bit res[nrelu]; 316 | for(int i=0; i < nrelu; ++i) 317 | res[i] = Y[i] > p; 318 | 319 | for(int i=0; i < nrelu; ++i) { 320 | //Perform mod p 321 | Integer s0 = X[i]; 322 | //s0.resize(s0.size()+1); 323 | 324 | Integer s1 = Y[i]; 325 | //s1.resize(s1.size()+1); 326 | 327 | Integer sum = s0 + s1; 328 | 329 | Integer mod_p_val = sum - p; 330 | 331 | Bit borrow_bit = mod_p_val[mod_p_val.size()-1]; 332 | 333 | Integer s = mod_p_val.select(borrow_bit, sum); 334 | 335 | S[i] = s; 336 | 337 | //Perform RELU 338 | Integer p2_minus_s = p_mod2-s; 339 | 340 | Bit is_negative = p2_minus_s[p2_minus_s.size()-1]; 341 | 342 | Integer relu_s = s.select(is_negative, zero); 343 | 344 | T[i] = relu_s; 345 | } 346 | 347 | int pack_size = ceil(nrelu*bitlen*bitlen*2.0/(8*sizeof(uint64_t))); 348 | int batch_size = nrelu*bitlen*2; 349 | 350 | uint64_t *ip_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 351 | 352 | uint64_t *op_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 353 | 354 | uint64_t *op_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 355 | 356 | uint64_t *ip_pack_table = (uint64_t 
*)malloc(pack_size*sizeof(uint64_t)); 357 | uint64_t *op_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 358 | uint64_t *opm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 359 | 360 | if(party == ALICE) { 361 | 362 | create_ciphertexts(S, delta_used, ip_cts, ip_ss, bitlen, nrelu, mac_key, 1); 363 | create_ciphertexts(T, delta_used, op_cts, op_ss, bitlen, nrelu, 1, 0); 364 | create_ciphertexts(T, delta_used, op_mcts, op_mss, bitlen, nrelu, mac_key, 1); 365 | 366 | pack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 367 | pack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 368 | pack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 369 | 370 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 371 | io->send_data(ip_pack_table, sizeof(uint64_t) * pack_size); 372 | io->send_data(op_pack_table, sizeof(uint64_t) * pack_size); 373 | io->send_data(opm_pack_table, sizeof(uint64_t) * pack_size); 374 | } else { 375 | io->recv_data(ip_pack_table, sizeof(uint64_t) * pack_size); 376 | io->recv_data(op_pack_table, sizeof(uint64_t) * pack_size); 377 | io->recv_data(opm_pack_table, sizeof(uint64_t) * pack_size); 378 | 379 | unpack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 380 | unpack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 381 | unpack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 382 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 383 | 384 | decrypt_ciphertexts(S, ip_cts, ip_ss, bitlen, nrelu, 1); 385 | decrypt_ciphertexts(T, op_cts, op_ss, bitlen, nrelu, 0); 386 | decrypt_ciphertexts(T, op_mcts, op_mss, bitlen, nrelu, 1); 387 | } 388 | } 389 | 390 | void parse_arguments(int argc, char**arg, int *party, int *port, int *bitlen, int *nrelu) { 391 | *party = atoi (arg[1]); 392 | address = arg[2]; 393 | *port = atoi (arg[3]); 394 | if(argc < 5) { 395 | *bitlen = l; 396 | } 
else { 397 | *bitlen = atoi(arg[4]); 398 | } 399 | 400 | if(argc < 6) { 401 | choose_relu = 0; 402 | } else { 403 | choose_relu = atoi(arg[5]); 404 | } 405 | 406 | if(argc < 7) { 407 | *nrelu = def_nrelu; 408 | } else { 409 | *nrelu = atoi(arg[6]); 410 | } 411 | 412 | if(argc < 8) { 413 | num_threads = 8; 414 | } else { 415 | num_threads = atoi(arg[7]); 416 | } 417 | } 418 | 419 | void thread_process(int tid, int party, int choice, uint64_t* inputs, int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 420 | setup_semi_honest(ioArr[tid], party); 421 | 422 | uint64_t nr_per_thread = nrelu/num_threads; 423 | uint64_t r = nrelu % num_threads; 424 | uint64_t actual_per_thread; 425 | if(tid == num_threads-1) 426 | actual_per_thread = nr_per_thread + r; 427 | else 428 | actual_per_thread = nr_per_thread; 429 | 430 | uint64_t offset = tid*nr_per_thread; 431 | 432 | if(choice == 0) { 433 | msi_relu_6(party, ioArr[tid], inputs+offset, actual_per_thread, bitlen, ip_ss+offset, op_ss+offset, op_mss+offset); 434 | } else { 435 | msi_relu(party, ioArr[tid], inputs+offset, actual_per_thread, bitlen, ip_ss+offset, op_ss+offset, op_mss+offset); 436 | } 437 | 438 | ioArr[tid]->flush(); 439 | finalize_semi_honest(); 440 | } 441 | 442 | int main(int argc, char** argv) { 443 | srand(time(NULL)); 444 | int port, party, nrelu, bitlen; 445 | //Parse input arguments and configure parameters 446 | parse_arguments(argc, argv, &party, &port, &bitlen, &nrelu); 447 | cout<<"Running Microbenchmarks ..."<<endl; 448 | cout << "=====================Configuration======================" << endl; 449 | cout<<"Party Id: "<< party<<" - Server IP Address: "<< address <<" - Port: "<<port<<" - NRelu: "<<nrelu<<" - Bitlen: "<<bitlen<<" - Choice RELU: "<<choose_relu<<" - #Threads: "<<num_threads<<endl; 450 | cout << "========================================================" << endl; 451 | //Prepare Inputs 452 | std::random_device rd; 453 | std::mt19937_64 eng(rd()); 454 | 
std::uniform_int_distribution<uint64_t> distr; 455 | 456 | uint64_t* inputs=(uint64_t *)malloc(nrelu*sizeof(uint64_t)); 457 | for(int i = 0; i < nrelu; ++i) 458 | inputs[i] = distr(eng)%prime_mod; 459 | 460 | uint64_t *ip_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 461 | uint64_t *op_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 462 | uint64_t *op_mss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 463 | 464 | for(int i=0; i <num_threads; i++) { 465 | ioArr[i] = new NetIO(party==ALICE ? nullptr : address.c_str(), port+i); 466 | } 467 | 468 | //Communication Initialization 469 | for(int i=0; i<num_threads; i++) 470 | start_comm[i] = ioArr[i]->counter; 471 | 472 | //Time Begin 473 | auto start = clock_start(); 474 | 475 | if(party == ALICE) { 476 | prg.random_data(&mac_key, 8); 477 | mac_key %= prime_mod; 478 | } 479 | 480 | std::thread relu_threads[num_threads]; 481 | for(int i=0; i<num_threads; i++) { 482 | relu_threads[i] = std::thread(thread_process, i, party, choose_relu, inputs, nrelu, bitlen, ip_ss, op_ss, op_mss); 483 | } 484 | 485 | //Join 486 | for(int i=0; i<num_threads; i++) { 487 | relu_threads[i].join(); 488 | } 489 | //Time End 490 | long long t = time_from(start); 491 | cout << "######################Performance#######################" <<endl; 492 | cout<<"Time Taken: "<<t<<" mus"<<endl; 493 | //Calculate Communication 494 | comm_sent = 0; 495 | for(int i=0; i<num_threads; i++) { 496 | comm_sent += (ioArr[i]->counter-start_comm[i]); 497 | } 498 | 499 | cout<<"Sent Data (Bytes): "<<comm_sent<<endl; 500 | comm_sent = comm_sent>>10; 501 | cout<<"Sent Data (KB): "<<comm_sent<<endl; 502 | comm_sent = comm_sent>>10; 503 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 504 | cout << "########################################################" <<endl; 505 | 506 | 507 | //cout<<"nrelu: "<<nrelu<<endl; 508 | //Test Protocol 509 | if(verify) { 510 | ioArr[0] = new NetIO(party==ALICE ? 
nullptr : address.c_str(), port); 511 | if(party == BOB) { 512 | ioArr[0]->send_data(inputs, sizeof(uint64_t) * nrelu); 513 | ioArr[0]->send_data(ip_ss, sizeof(uint64_t) * nrelu); 514 | ioArr[0]->send_data(op_ss, sizeof(uint64_t) * nrelu); 515 | ioArr[0]->send_data(op_mss, sizeof(uint64_t) * nrelu); 516 | } else { 517 | uint64_t inputs_1[nrelu]; 518 | uint64_t inputs_res[nrelu]; 519 | uint64_t relu_res[nrelu]; 520 | 521 | uint64_t *ip_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 522 | uint64_t *op_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 523 | uint64_t *op_mssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 524 | 525 | ioArr[0]->recv_data(inputs_1, sizeof(uint64_t) * nrelu); 526 | ioArr[0]->recv_data(ip_ssc, sizeof(uint64_t) * nrelu); 527 | ioArr[0]->recv_data(op_ssc, sizeof(uint64_t) * nrelu); 528 | ioArr[0]->recv_data(op_mssc, sizeof(uint64_t) * nrelu); 529 | for(int i=0; i< nrelu; i++) { 530 | inputs_res[i] = (inputs_1[i] + inputs[i])%prime_mod; 531 | if(inputs_res[i] > prime_mod/2) { 532 | relu_res[i] = 0; 533 | } else { 534 | if(inputs_res[i] > 6) 535 | relu_res[i] = 6; 536 | else 537 | relu_res[i] = inputs_res[i]; 538 | } 539 | } 540 | 541 | uint64_t ip_shares, ip_corr, op_shares, op_corr, opm_shares, opm_corr; 542 | uint64_t ctr_ip, ctr_op, ctr_opm=0; 543 | 544 | for(int i=0; i<nrelu; i++) { 545 | ip_shares = (ip_ss[i]+ip_ssc[i])%prime_mod; 546 | ip_corr = mod_mult(mac_key,inputs_res[i]); 547 | if(ip_shares == ip_corr) 548 | ctr_ip++; 549 | else { 550 | cout<<"Index: "<<i<<endl; 551 | break; 552 | } 553 | 554 | op_shares = (op_ss[i]+op_ssc[i])%prime_mod; 555 | if(op_shares == relu_res[i]) 556 | ctr_op++; 557 | 558 | opm_shares = (op_mss[i] + op_mssc[i])%prime_mod; 559 | opm_corr = mod_mult(mac_key,relu_res[i]); 560 | if(opm_shares == opm_corr) 561 | ctr_opm++; 562 | } 563 | cout << "**********************Verification**********************" <<endl; 564 | cout<<"Correct Input Macs: "<< ctr_ip<<endl; 565 | cout<<"Correct Outputs: "<< ctr_op<<endl; 
566 | cout<<"Correct Output Macs: "<< ctr_opm<<endl; 567 | cout << "********************************************************" <<endl; 568 | } 569 | } 570 | //Performance Result 571 | 572 | } 573 | -------------------------------------------------------------------------------- /test/msi_relu.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/emp-sh2pc.h" 2 | #include <cmath> 3 | 4 | #include "seal/util/uintarith.h" 5 | #include "seal/util/uintarithsmallmod.h" 6 | using namespace emp; 7 | using namespace std; 8 | 9 | //Slackoverflow Code For bit-wise shift 10 | #define SHL128(v, n) \ 11 | ({ \ 12 | __m128i v1, v2; \ 13 | \ 14 | if ((n) >= 64) \ 15 | { \ 16 | v1 = _mm_slli_si128(v, 8); \ 17 | v1 = _mm_slli_epi64(v1, (n) - 64); \ 18 | } \ 19 | else \ 20 | { \ 21 | v1 = _mm_slli_epi64(v, n); \ 22 | v2 = _mm_slli_si128(v, 8); \ 23 | v2 = _mm_srli_epi64(v2, 64 - (n)); \ 24 | v1 = _mm_or_si128(v1, v2); \ 25 | } \ 26 | v1; \ 27 | }) 28 | 29 | enum neural_net { 30 | NONE, 31 | MINIONN, 32 | CIFAR10 33 | }; 34 | 35 | uint64_t start_comm; 36 | uint64_t prime_mod = 17592060215297; 37 | seal::Modulus mod(prime_mod); 38 | int port = 32000, def_nrelu = 1<<20, l = 44; 39 | neural_net def_nn = NONE; 40 | string address; 41 | 42 | uint64_t mac_key; 43 | PRG prg; 44 | 45 | bool verify = true; 46 | int MINIONN_RELUS[] = { 16*576, 16*64, 100*1 47 | }; 48 | 49 | int CIFAR10_RELUS[] = { 64*1024, 64*1024, 64*256, 64*256, 64*64, 64*64, 16*64 50 | }; 51 | 52 | uint64_t mod_shift(uint64_t a, uint64_t b, uint64_t prime_mod) { 53 | __m128i temp, stemp; 54 | memcpy(&temp, &a, 8); 55 | stemp = SHL128(temp, b); 56 | 57 | uint64_t input[2]; 58 | input[0] = stemp[0]; 59 | input[1] = stemp[1]; 60 | 61 | uint64_t result = seal::util::barrett_reduce_128(input, mod); 62 | 63 | return result; 64 | } 65 | 66 | uint64_t mod_mult(uint64_t a, uint64_t b) { 67 | unsigned long long temp_result[2]; 68 | seal::util::multiply_uint64(a, b, temp_result); 
69 | 70 | /*uint64_t input[2]; 71 | input[0] = res[0]; 72 | input[1] = res[1];*/ 73 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 74 | return result; 75 | } 76 | 77 | 78 | //Referred SCI OT repo's logic to pack ot messages 79 | void pack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 80 | uint64_t beg_idx = 0; 81 | uint64_t end_idx = 0; 82 | uint64_t beg_blk = 0; 83 | uint64_t end_blk = 0; 84 | uint64_t temp_blk = 0; 85 | uint64_t mask = (1ULL << bitlen) - 1; 86 | uint64_t pack_blk_size = 64; 87 | 88 | if (bitlen == 64) 89 | mask = -1; 90 | 91 | for (int i = 0; i < pack_size; i++) { 92 | pack_table[i] = 0; 93 | } 94 | 95 | for (int i = 0; i < batch_size; i++) { 96 | beg_idx = i * bitlen; 97 | end_idx = beg_idx + bitlen; 98 | end_idx -= 1; 99 | beg_blk = beg_idx / pack_blk_size; 100 | end_blk = end_idx / pack_blk_size; 101 | 102 | if (beg_blk == end_blk) { 103 | pack_table[beg_blk] ^= (ciphertexts[i] & mask) << (beg_idx % pack_blk_size); 104 | } else { 105 | temp_blk = (ciphertexts[i] & mask); 106 | pack_table[beg_blk] ^= (temp_blk) << (beg_idx % pack_blk_size); 107 | pack_table[end_blk] ^= (temp_blk) >> (pack_blk_size - (beg_idx % pack_blk_size)); 108 | } 109 | } 110 | } 111 | 112 | //Referred SCI OT repo's logic to unpack ot messages 113 | void unpack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 114 | uint64_t beg_idx = 0; 115 | uint64_t end_idx = 0; 116 | uint64_t beg_blk = 0; 117 | uint64_t end_blk = 0; 118 | uint64_t temp_blk = 0; 119 | uint64_t mask = (1ULL << bitlen) - 1; 120 | uint64_t pack_blk_size = 64; 121 | 122 | for (int i = 0; i < batch_size; i++) { 123 | beg_idx = i * bitlen; 124 | end_idx = beg_idx + bitlen - 1; 125 | beg_blk = beg_idx / pack_blk_size; 126 | end_blk = end_idx / pack_blk_size; 127 | 128 | if (beg_blk == end_blk) { 129 | ciphertexts[i] = (pack_table[beg_blk] >> (beg_idx % pack_blk_size)) & 
mask; 130 | } else { 131 | ciphertexts[i] = 0; 132 | ciphertexts[i] ^= (pack_table[beg_blk] >> (beg_idx % pack_blk_size)); 133 | ciphertexts[i] ^= (pack_table[end_blk] << (pack_blk_size - (beg_idx % pack_blk_size))); 134 | ciphertexts[i] = ciphertexts[i] & mask; 135 | } 136 | } 137 | } 138 | 139 | void create_ciphertexts(Integer *garbled_data, block label_delta, uint64_t *ciphertexts, uint64_t* server_shares, int bitlen, int nrelu, uint64_t alpha, int l_idx) { 140 | uint64_t delta_int; 141 | memcpy(&delta_int, &label_delta[l_idx], 8); 142 | 143 | uint64_t mask = (1ULL << bitlen) - 1; 144 | uint64_t label_temp; 145 | uint64_t **random_val = (uint64_t **)malloc(nrelu*sizeof(uint64_t*)); 146 | uint8_t pnp, cpnp; 147 | for(int i=0; i<nrelu; i++) { 148 | random_val[i] = (uint64_t *)malloc(bitlen*sizeof(uint64_t)); 149 | } 150 | 151 | for(int i=0; i<nrelu; i++) { 152 | server_shares[i] = 0; 153 | for(int j=0; j<bitlen; j++) { 154 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 155 | prg.random_data(&random_val[i][j], 8); 156 | random_val[i][j] %= prime_mod; 157 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 158 | cpnp = 1 - pnp; 159 | ciphertexts[(i*bitlen+j)*2+pnp] = (random_val[i][j])^(label_temp & mask); 160 | ciphertexts[(i*bitlen+j)*2+cpnp] = ((random_val[i][j]+alpha)%prime_mod)^((label_temp^delta_int) & mask); 161 | if(i==0) { 162 | uint64_t l1 = label_temp & mask, l2 = label_temp^delta_int & mask; 163 | } 164 | server_shares[i] = (server_shares[i] + mod_shift(random_val[i][j],j,prime_mod))%prime_mod; 165 | } 166 | server_shares[i] = prime_mod - server_shares[i]; 167 | } 168 | } 169 | 170 | void decrypt_ciphertexts(Integer *garbled_data, uint64_t *ciphertexts, uint64_t* client_shares, int bitlen, int nrelu, int l_idx) { 171 | uint64_t label_temp; 172 | uint8_t pnp; 173 | uint64_t random_val; 174 | 175 | uint64_t mask = (1ULL << bitlen) - 1; 176 | 177 | for(int i=0; i<nrelu; i++) { 178 | client_shares[i] = 0; 179 | for(int j=0; j< bitlen; j++) { 180 | 
memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 181 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 182 | random_val = ciphertexts[(i*bitlen+j)*2+pnp]^(label_temp & mask); 183 | client_shares[i] = (client_shares[i] + mod_shift(random_val,j,prime_mod))%prime_mod; 184 | } 185 | } 186 | } 187 | 188 | void msi_relu(int party, NetIO* io, uint64_t inputs[], int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 189 | uint64_t comm_sent; 190 | //Public prime values 191 | Integer p(bitlen + 1, prime_mod, PUBLIC); 192 | Integer p_mod2(bitlen, prime_mod/2, PUBLIC); 193 | Integer zero(bitlen, 0, PUBLIC); 194 | 195 | //Assign Inputs 196 | Integer *X = new Integer[nrelu]; 197 | for(int i = 0; i < nrelu; ++i) 198 | X[i] = Integer(bitlen+1, inputs[i], ALICE); 199 | Integer *Y = new Integer[nrelu]; 200 | for(int i = 0; i < nrelu; ++i) 201 | Y[i] = Integer(bitlen+1, inputs[i], BOB); 202 | 203 | Integer *S = new Integer[nrelu]; 204 | Integer *T = new Integer[nrelu]; 205 | 206 | //Check if Bob's share is < p 207 | Bit res[nrelu]; 208 | for(int i=0; i < nrelu; ++i) 209 | res[i] = Y[i] > p; 210 | 211 | for(int i=0; i < nrelu; ++i) { 212 | //Perform mod p 213 | Integer s0 = X[i]; 214 | //s0.resize(s0.size()+1); 215 | 216 | Integer s1 = Y[i]; 217 | //s1.resize(s1.size()+1); 218 | 219 | Integer sum = s0 + s1; 220 | 221 | Integer mod_p_val = sum - p; 222 | 223 | Bit borrow_bit = mod_p_val[mod_p_val.size()-1]; 224 | 225 | Integer s = mod_p_val.select(borrow_bit, sum); 226 | 227 | S[i] = s; 228 | 229 | //Perform RELU 230 | Integer p2_minus_s = p_mod2-s; 231 | 232 | Bit is_negative = p2_minus_s[p2_minus_s.size()-1]; 233 | 234 | Integer relu_s = s.select(is_negative, zero); 235 | 236 | T[i] = relu_s; 237 | } 238 | 239 | int pack_size = ceil(nrelu*bitlen*bitlen*2.0/(8*sizeof(uint64_t))); 240 | int batch_size = nrelu*bitlen*2; 241 | 242 | uint64_t *ip_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 243 | 244 | uint64_t *op_cts = (uint64_t 
*)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 245 | 246 | uint64_t *op_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 247 | 248 | uint64_t *ip_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 249 | uint64_t *op_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 250 | uint64_t *opm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 251 | 252 | if(party == ALICE) { 253 | 254 | create_ciphertexts(S, delta_used, ip_cts, ip_ss, bitlen, nrelu, mac_key, 1); 255 | create_ciphertexts(T, delta_used, op_cts, op_ss, bitlen, nrelu, 1, 0); 256 | create_ciphertexts(T, delta_used, op_mcts, op_mss, bitlen, nrelu, mac_key, 1); 257 | 258 | pack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 259 | pack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 260 | pack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 261 | 262 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 263 | io->send_data(ip_pack_table, sizeof(uint64_t) * pack_size); 264 | io->send_data(op_pack_table, sizeof(uint64_t) * pack_size); 265 | io->send_data(opm_pack_table, sizeof(uint64_t) * pack_size); 266 | } else { 267 | io->recv_data(ip_pack_table, sizeof(uint64_t) * pack_size); 268 | io->recv_data(op_pack_table, sizeof(uint64_t) * pack_size); 269 | io->recv_data(opm_pack_table, sizeof(uint64_t) * pack_size); 270 | 271 | unpack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 272 | unpack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 273 | unpack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 274 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 275 | 276 | decrypt_ciphertexts(S, ip_cts, ip_ss, bitlen, nrelu, 1); 277 | decrypt_ciphertexts(T, op_cts, op_ss, bitlen, nrelu, 0); 278 | decrypt_ciphertexts(T, op_mcts, op_mss, bitlen, nrelu, 1); 279 | } 280 | } 281 | 282 | void parse_arguments(int argc, char**arg, int 
*party, int *port, int *bitlen, int *nrelu) { 283 | neural_net choice_nn; 284 | *party = atoi (arg[1]); 285 | address = arg[2]; 286 | *port = atoi (arg[3]); 287 | if(argc < 5) { 288 | *bitlen = l; 289 | } else { 290 | *bitlen = atoi(arg[4]); 291 | } 292 | if(argc < 6) { 293 | choice_nn =def_nn; 294 | } else { 295 | choice_nn = neural_net(atoi (arg[5])); 296 | } 297 | 298 | switch(choice_nn) { 299 | case NONE: { 300 | if(argc < 7) { 301 | *nrelu = def_nrelu; 302 | } else { 303 | *nrelu = atoi(arg[6]); 304 | } 305 | } 306 | break; 307 | case MINIONN: { 308 | *nrelu = 0; 309 | int len = *(&MINIONN_RELUS+1)-MINIONN_RELUS; 310 | for(int i=0; i< len; i++) { 311 | *nrelu += MINIONN_RELUS[i]; 312 | } 313 | } 314 | break; 315 | case CIFAR10: { 316 | *nrelu = 0; 317 | int len = *(&CIFAR10_RELUS+1)-CIFAR10_RELUS; 318 | for(int i=0; i< len; i++) { 319 | *nrelu += CIFAR10_RELUS[i]; 320 | } 321 | } 322 | } 323 | } 324 | 325 | int main(int argc, char** argv) { 326 | srand(time(NULL)); 327 | int port, party, nrelu, bitlen; 328 | //Parse input arguments and configure parameters 329 | parse_arguments(argc, argv, &party, &port, &bitlen, &nrelu); 330 | cout << "=====================Configuration======================" << endl; 331 | cout<<"Party Id: "<< party<<", Server IP Address: "<< address <<", Port: "<<port<<", NRelu: "<<nrelu<<", Bitlen: "<<bitlen<<endl; 332 | cout << "========================================================" << endl; 333 | //Prepare Inputs 334 | std::random_device rd; 335 | std::mt19937_64 eng(rd()); 336 | std::uniform_int_distribution<uint64_t> distr; 337 | 338 | uint64_t inputs[nrelu]; 339 | for(int i = 0; i < nrelu; ++i) 340 | inputs[i] = distr(eng)%prime_mod; 341 | 342 | NetIO * io = new NetIO(party==ALICE ? 
nullptr : address.c_str(), port); 343 | 344 | uint64_t *ip_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 345 | uint64_t *op_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 346 | uint64_t *op_mss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 347 | 348 | //Communication Initialization 349 | uint64_t comm_sent = 0; 350 | start_comm = io->counter; 351 | //Time Begin 352 | auto start = clock_start(); 353 | //Setup 354 | setup_semi_honest(io, party); 355 | if(party == ALICE) { 356 | prg.random_data(&mac_key, 8); 357 | mac_key %= prime_mod; 358 | } 359 | 360 | //Garbled Circuit 361 | msi_relu(party, io, inputs, nrelu, bitlen, ip_ss, op_ss, op_mss); 362 | //Time End 363 | long long t = time_from(start); 364 | cout << "######################Performance#######################" <<endl; 365 | cout<<"Time Taken: "<<t<<" mus"<<endl; 366 | //Calculate Communication 367 | comm_sent = (io->counter-start_comm)>>20; 368 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 369 | cout << "########################################################" <<endl; 370 | finalize_semi_honest(); 371 | 372 | //Test Protocol 373 | if(verify) { 374 | if(party == BOB) { 375 | io->send_data(inputs, sizeof(uint64_t) * nrelu); 376 | io->send_data(ip_ss, sizeof(uint64_t) * nrelu); 377 | io->send_data(op_ss, sizeof(uint64_t) * nrelu); 378 | io->send_data(op_mss, sizeof(uint64_t) * nrelu); 379 | } else { 380 | uint64_t inputs_1[nrelu]; 381 | uint64_t inputs_res[nrelu]; 382 | uint64_t relu_res[nrelu]; 383 | 384 | uint64_t *ip_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 385 | uint64_t *op_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 386 | uint64_t *op_mssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 387 | 388 | io->recv_data(inputs_1, sizeof(uint64_t) * nrelu); 389 | io->recv_data(ip_ssc, sizeof(uint64_t) * nrelu); 390 | io->recv_data(op_ssc, sizeof(uint64_t) * nrelu); 391 | io->recv_data(op_mssc, sizeof(uint64_t) * nrelu); 392 | for(int i=0; i< nrelu; i++) { 393 | inputs_res[i] = (inputs_1[i] + 
inputs[i])%prime_mod; 394 | if(inputs_res[i] > prime_mod/2) { 395 | relu_res[i] = 0; 396 | } else { 397 | relu_res[i] = inputs_res[i]; 398 | } 399 | } 400 | 401 | uint64_t ip_shares, ip_corr, op_shares, op_corr, opm_shares, opm_corr; 402 | uint64_t ctr_ip, ctr_op, ctr_opm=0; 403 | 404 | for(int i=0; i<nrelu; i++) { 405 | ip_shares = (ip_ss[i]+ip_ssc[i])%prime_mod; 406 | ip_corr = mod_mult(mac_key,inputs_res[i]); 407 | if(ip_shares == ip_corr) 408 | ctr_ip++; 409 | 410 | op_shares = (op_ss[i]+op_ssc[i])%prime_mod; 411 | if(op_shares == relu_res[i]) 412 | ctr_op++; 413 | 414 | opm_shares = (op_mss[i] + op_mssc[i])%prime_mod; 415 | opm_corr = mod_mult(mac_key,relu_res[i]); 416 | if(opm_shares == opm_corr) 417 | ctr_opm++; 418 | } 419 | cout << "**********************Verification**********************" <<endl; 420 | cout<<"Correct Input Macs: "<< ctr_ip<<endl; 421 | cout<<"Correct Outputs: "<< ctr_op<<endl; 422 | cout<<"Correct Output Macs: "<< ctr_opm<<endl; 423 | cout << "********************************************************" <<endl; 424 | } 425 | } 426 | //Performance Result 427 | 428 | } 429 | -------------------------------------------------------------------------------- /test/msi_relu_final.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/emp-sh2pc.h" 2 | #include <cmath> 3 | 4 | #include "seal/util/uintarith.h" 5 | #include "seal/util/uintarithsmallmod.h" 6 | #include <thread> 7 | #define MAX_THREADS 8 8 | using namespace emp; 9 | using namespace std; 10 | 11 | 12 | int num_threads = 8; 13 | 14 | //Slackoverflow Code For bit-wise shift 15 | #define SHL128(v, n) \ 16 | ({ \ 17 | __m128i v1, v2; \ 18 | \ 19 | if ((n) >= 64) \ 20 | { \ 21 | v1 = _mm_slli_si128(v, 8); \ 22 | v1 = _mm_slli_epi64(v1, (n) - 64); \ 23 | } \ 24 | else \ 25 | { \ 26 | v1 = _mm_slli_epi64(v, n); \ 27 | v2 = _mm_slli_si128(v, 8); \ 28 | v2 = _mm_srli_epi64(v2, 64 - (n)); \ 29 | v1 = _mm_or_si128(v1, v2); \ 30 | } \ 31 | v1; \ 
32 | }) 33 | 34 | enum neural_net { 35 | NONE, 36 | MINIONN, 37 | CIFAR10 38 | }; 39 | 40 | neural_net choice_nn; 41 | uint64_t start_comm[MAX_THREADS]; 42 | uint64_t comm_sent = 0; 43 | NetIO *ioArr[MAX_THREADS]; 44 | uint64_t prime_mod = 17592060215297; 45 | seal::Modulus mod(prime_mod); 46 | int port = 32000, def_nrelu = 1<<20, l = 44; 47 | neural_net def_nn = NONE; 48 | string address; 49 | bool run_all = false; 50 | uint64_t mac_key; 51 | PRG prg; 52 | string benchmark; 53 | 54 | bool verify = false; 55 | int MINIONN_RELUS[] = { 16*576, 16*64, 100*1 56 | }; 57 | 58 | int CIFAR10_RELUS[] = { 64*1024, 64*1024, 64*256, 64*256, 64*64, 64*64, 16*64 59 | }; 60 | 61 | uint64_t mod_shift(uint64_t a, uint64_t b, uint64_t prime_mod) { 62 | __m128i temp, stemp; 63 | memcpy(&temp, &a, 8); 64 | stemp = SHL128(temp, b); 65 | 66 | uint64_t input[2]; 67 | input[0] = stemp[0]; 68 | input[1] = stemp[1]; 69 | 70 | uint64_t result = seal::util::barrett_reduce_128(input, mod); 71 | 72 | return result; 73 | } 74 | 75 | uint64_t mod_mult(uint64_t a, uint64_t b) { 76 | unsigned long long temp_result[2]; 77 | seal::util::multiply_uint64(a, b, temp_result); 78 | 79 | /*uint64_t input[2]; 80 | input[0] = res[0]; 81 | input[1] = res[1];*/ 82 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 83 | return result; 84 | } 85 | 86 | 87 | //Referred SCI OT repo's logic to pack ot messages 88 | void pack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 89 | uint64_t beg_idx = 0; 90 | uint64_t end_idx = 0; 91 | uint64_t beg_blk = 0; 92 | uint64_t end_blk = 0; 93 | uint64_t temp_blk = 0; 94 | uint64_t mask = (1ULL << bitlen) - 1; 95 | uint64_t pack_blk_size = 64; 96 | 97 | if (bitlen == 64) 98 | mask = -1; 99 | 100 | for (int i = 0; i < pack_size; i++) { 101 | pack_table[i] = 0; 102 | } 103 | 104 | for (int i = 0; i < batch_size; i++) { 105 | beg_idx = i * bitlen; 106 | end_idx = beg_idx + bitlen; 107 | end_idx -= 1; 108 
| beg_blk = beg_idx / pack_blk_size; 109 | end_blk = end_idx / pack_blk_size; 110 | 111 | if (beg_blk == end_blk) { 112 | pack_table[beg_blk] ^= (ciphertexts[i] & mask) << (beg_idx % pack_blk_size); 113 | } else { 114 | temp_blk = (ciphertexts[i] & mask); 115 | pack_table[beg_blk] ^= (temp_blk) << (beg_idx % pack_blk_size); 116 | pack_table[end_blk] ^= (temp_blk) >> (pack_blk_size - (beg_idx % pack_blk_size)); 117 | } 118 | } 119 | } 120 | 121 | //Referred SCI OT repo's logic to unpack ot messages 122 | void unpack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 123 | uint64_t beg_idx = 0; 124 | uint64_t end_idx = 0; 125 | uint64_t beg_blk = 0; 126 | uint64_t end_blk = 0; 127 | uint64_t temp_blk = 0; 128 | uint64_t mask = (1ULL << bitlen) - 1; 129 | uint64_t pack_blk_size = 64; 130 | 131 | for (int i = 0; i < batch_size; i++) { 132 | beg_idx = i * bitlen; 133 | end_idx = beg_idx + bitlen - 1; 134 | beg_blk = beg_idx / pack_blk_size; 135 | end_blk = end_idx / pack_blk_size; 136 | 137 | if (beg_blk == end_blk) { 138 | ciphertexts[i] = (pack_table[beg_blk] >> (beg_idx % pack_blk_size)) & mask; 139 | } else { 140 | ciphertexts[i] = 0; 141 | ciphertexts[i] ^= (pack_table[beg_blk] >> (beg_idx % pack_blk_size)); 142 | ciphertexts[i] ^= (pack_table[end_blk] << (pack_blk_size - (beg_idx % pack_blk_size))); 143 | ciphertexts[i] = ciphertexts[i] & mask; 144 | } 145 | } 146 | } 147 | 148 | void create_ciphertexts(Integer *garbled_data, block label_delta, uint64_t *ciphertexts, uint64_t* server_shares, int bitlen, int nrelu, uint64_t alpha, int l_idx, bool apply_prg) { 149 | uint64_t delta_int; 150 | memcpy(&delta_int, &label_delta[l_idx], 8); 151 | 152 | uint64_t mask = (1ULL << bitlen) - 1; 153 | block seed, label_block_0, label_block_1; 154 | uint64_t label_temp_0, label_temp_1; 155 | uint64_t **random_val = (uint64_t **)malloc(nrelu*sizeof(uint64_t*)); 156 | uint8_t pnp, cpnp; 157 | for(int i=0; i<nrelu; i++) { 158 
| random_val[i] = (uint64_t *)malloc(bitlen*sizeof(uint64_t)); 159 | } 160 | 161 | for(int i=0; i<nrelu; i++) { 162 | server_shares[i] = 0; 163 | for(int j=0; j<bitlen; j++) { 164 | if(apply_prg) { 165 | memcpy(&seed, &garbled_data[i].bits[j].bit, 16); 166 | PRG prg0(&seed); 167 | prg0.random_data(&label_block_0, 16); 168 | seed = garbled_data[i].bits[j].bit^label_delta; 169 | PRG prg1(&seed); 170 | prg1.random_data(&label_block_1, 16); 171 | 172 | memcpy(&label_temp_0, &label_block_0[l_idx], 8); 173 | memcpy(&label_temp_1, &label_block_1[l_idx], 8); 174 | 175 | } else { 176 | memcpy(&label_temp_0, &garbled_data[i].bits[j].bit[l_idx], 8); 177 | label_temp_1 = label_temp_0^delta_int; 178 | } 179 | 180 | prg.random_data(&random_val[i][j], 8); 181 | random_val[i][j] %= prime_mod; 182 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 183 | cpnp = 1 - pnp; 184 | ciphertexts[(i*bitlen+j)*2+pnp] = (random_val[i][j])^(label_temp_0 & mask); 185 | ciphertexts[(i*bitlen+j)*2+cpnp] = ((random_val[i][j]+alpha)%prime_mod)^(label_temp_1 & mask); 186 | server_shares[i] = (server_shares[i] + mod_shift(random_val[i][j],j,prime_mod))%prime_mod; 187 | } 188 | server_shares[i] = prime_mod - server_shares[i]; 189 | } 190 | } 191 | 192 | void decrypt_ciphertexts(Integer *garbled_data, uint64_t *ciphertexts, uint64_t* client_shares, int bitlen, int nrelu, int l_idx, bool apply_prg) { 193 | uint64_t label_temp; 194 | block label_block; 195 | uint8_t pnp; 196 | uint64_t random_val; 197 | 198 | uint64_t mask = (1ULL << bitlen) - 1; 199 | 200 | for(int i=0; i<nrelu; i++) { 201 | client_shares[i] = 0; 202 | for(int j=0; j< bitlen; j++) { 203 | if(apply_prg) { 204 | PRG prg(&garbled_data[i].bits[j].bit); 205 | prg.random_data(&label_block, 16); 206 | 207 | memcpy(&label_temp, &label_block[l_idx], 8); 208 | } else { 209 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 210 | } 211 | 212 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 213 | random_val = 
ciphertexts[(i*bitlen+j)*2+pnp]^(label_temp & mask); 214 | client_shares[i] = (client_shares[i] + mod_shift(random_val,j,prime_mod))%prime_mod; 215 | } 216 | } 217 | } 218 | 219 | void msi_relu_6(int party, int tid, NetIO* io, uint64_t* inputs, int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 220 | //Public prime values 221 | Integer p(bitlen + 1, prime_mod, PUBLIC); 222 | Integer p_mod2(bitlen, prime_mod/2, PUBLIC); 223 | Integer zero(bitlen, 0, PUBLIC); 224 | Integer six(bitlen, 6, PUBLIC); 225 | 226 | //Assign Inputs 227 | Integer *X = new Integer[nrelu]; 228 | for(int i = 0; i < nrelu; ++i) 229 | X[i] = Integer(bitlen+1, inputs[i], ALICE); 230 | Integer *Y = new Integer[nrelu]; 231 | for(int i = 0; i < nrelu; ++i) 232 | Y[i] = Integer(bitlen+1, inputs[i], BOB); 233 | 234 | Integer *S = new Integer[nrelu]; 235 | Integer *U = new Integer[nrelu]; 236 | Integer *T = new Integer[nrelu]; 237 | 238 | //Check if Bob's share is < p 239 | Bit res[nrelu]; 240 | for(int i=0; i < nrelu; ++i) 241 | res[i] = Y[i] > p; 242 | 243 | for(int i=0; i < nrelu; ++i) { 244 | //Perform mod p 245 | Integer s0 = X[i]; 246 | //s0.resize(s0.size()+1); 247 | 248 | Integer s1 = Y[i]; 249 | //s1.resize(s1.size()+1); 250 | 251 | Integer sum = s0 + s1; 252 | 253 | Integer mod_p_val = sum - p; 254 | 255 | Bit borrow_bit = mod_p_val[mod_p_val.size()-1]; 256 | 257 | Integer s = mod_p_val.select(borrow_bit, sum); 258 | 259 | S[i] = s; 260 | 261 | //Perform RELU 262 | Integer p2_minus_s = p_mod2-s; 263 | 264 | Bit is_negative = p2_minus_s[p2_minus_s.size()-1]; 265 | 266 | Integer relu_s = s.select(is_negative, zero); 267 | 268 | Integer six_minus_res = six - relu_s; 269 | Bit is_greater_than_six = six_minus_res[six_minus_res.size()-1]; 270 | 271 | Integer res = relu_s.select(is_greater_than_six, six); 272 | 273 | T[i] = res; 274 | } 275 | 276 | int pack_size = ceil(nrelu*bitlen*bitlen*2.0/(8*sizeof(uint64_t))); 277 | int batch_size = nrelu*bitlen*2; 278 | 279 | uint64_t 
*ip_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 280 | 281 | uint64_t *op_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 282 | 283 | uint64_t *op_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 284 | 285 | uint64_t *ip_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 286 | uint64_t *op_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 287 | uint64_t *opm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 288 | 289 | if(party == ALICE) { 290 | 291 | create_ciphertexts(S, delta_blocks[tid], ip_cts, ip_ss, bitlen, nrelu, mac_key, 1, true); 292 | create_ciphertexts(T, delta_blocks[tid], op_cts, op_ss, bitlen, nrelu, 1, 0, false); 293 | create_ciphertexts(T, delta_blocks[tid], op_mcts, op_mss, bitlen, nrelu, mac_key, 1, false); 294 | 295 | pack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 296 | pack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 297 | pack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 298 | 299 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 300 | io->send_data(ip_pack_table, sizeof(uint64_t) * pack_size); 301 | io->send_data(op_pack_table, sizeof(uint64_t) * pack_size); 302 | io->send_data(opm_pack_table, sizeof(uint64_t) * pack_size); 303 | } else { 304 | io->recv_data(ip_pack_table, sizeof(uint64_t) * pack_size); 305 | io->recv_data(op_pack_table, sizeof(uint64_t) * pack_size); 306 | io->recv_data(opm_pack_table, sizeof(uint64_t) * pack_size); 307 | 308 | unpack_decryption_table(ip_pack_table, ip_cts, pack_size, batch_size, bitlen); 309 | unpack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 310 | unpack_decryption_table(opm_pack_table, op_mcts, pack_size, batch_size, bitlen); 311 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 312 | 313 | decrypt_ciphertexts(S, ip_cts, ip_ss, bitlen, nrelu, 1, true); 314 | decrypt_ciphertexts(T, op_cts, op_ss, bitlen, 
nrelu, 0, false); 315 | decrypt_ciphertexts(T, op_mcts, op_mss, bitlen, nrelu, 1, false); 316 | } 317 | } 318 | 319 | void parse_arguments(int argc, char**arg, int *party, int *port, int *bitlen, int *nrelu) { 320 | *party = atoi (arg[1]); 321 | address = arg[2]; 322 | *port = atoi (arg[3]); 323 | if(argc < 5) { 324 | *bitlen = l; 325 | } else { 326 | *bitlen = atoi(arg[4]); 327 | } 328 | if(argc < 6) { 329 | choice_nn =def_nn; 330 | } else { 331 | choice_nn = neural_net(atoi (arg[5])); 332 | } 333 | 334 | switch(choice_nn) { 335 | case NONE: { 336 | if(argc < 7) { 337 | *nrelu = def_nrelu; 338 | } else { 339 | *nrelu = atoi(arg[6]); 340 | } 341 | benchmark = "non-selected"; 342 | } 343 | break; 344 | case MINIONN: { 345 | *nrelu = 0; 346 | int len = *(&MINIONN_RELUS+1)-MINIONN_RELUS; 347 | for(int i=0; i< len; i++) { 348 | *nrelu += MINIONN_RELUS[i]; 349 | } 350 | benchmark = "mnist"; 351 | } 352 | break; 353 | case CIFAR10: { 354 | *nrelu = 0; 355 | int len = *(&CIFAR10_RELUS+1)-CIFAR10_RELUS; 356 | for(int i=0; i< len; i++) { 357 | *nrelu += CIFAR10_RELUS[i]; 358 | } 359 | benchmark = "cifar10"; 360 | } 361 | } 362 | 363 | if(argc < 8) { 364 | run_all = false; 365 | } else { 366 | run_all = (bool)(atoi(arg[7])); 367 | } 368 | 369 | if(argc < 9) { 370 | // 371 | } else { 372 | num_threads = atoi(arg[8]); 373 | } 374 | } 375 | 376 | void thread_process(int tid, int party, uint64_t* inputs, int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 377 | uint64_t *ptr = inputs+4608; 378 | setup_semi_honest_mult(ioArr[tid], party, tid); 379 | int prev_ctr=0; 380 | int len = *(&MINIONN_RELUS+1)-MINIONN_RELUS; 381 | for(int i=0; i<len; i++) { 382 | uint64_t num_relu_layer = MINIONN_RELUS[i]; 383 | uint64_t nr_per_thread = num_relu_layer/num_threads; 384 | uint64_t r = num_relu_layer % num_threads; 385 | uint64_t actual_per_thread; 386 | if(tid ==num_threads-1) 387 | actual_per_thread = nr_per_thread + r; 388 | else 389 | actual_per_thread = 
nr_per_thread; 390 | uint64_t offset = prev_ctr + tid*nr_per_thread; 391 | //cout<<"Thread id: "<<tid<<", Offset: "<<offset<<", NR Threads: "<<nr_per_thread<<"Actual Threads: "<<actual_per_thread<<endl; 392 | //cout<<"Thread id:"<<tid<<", First Value (Out): "<<*(inputs+offset)<<endl; 393 | msi_relu_6(party, tid, ioArr[tid], inputs+offset, actual_per_thread, bitlen, ip_ss+offset, op_ss+offset, op_mss+offset); 394 | prev_ctr += MINIONN_RELUS[i]; 395 | } 396 | ioArr[tid]->flush(); 397 | finalize_semi_honest(); 398 | } 399 | 400 | void thread_process_1(int tid, int party, uint64_t* inputs, int nrelu, int bitlen, uint64_t* ip_ss, uint64_t* op_ss, uint64_t* op_mss) { 401 | uint64_t *ptr = inputs+4608; 402 | setup_semi_honest_mult(ioArr[tid], party, tid); 403 | int prev_ctr=0; 404 | int len = *(&CIFAR10_RELUS+1)-CIFAR10_RELUS; 405 | for(int i=0; i<len; i++) { 406 | uint64_t num_relu_layer = CIFAR10_RELUS[i]; 407 | uint64_t nr_per_thread = num_relu_layer/num_threads; 408 | uint64_t r = num_relu_layer % num_threads; 409 | uint64_t actual_per_thread; 410 | if(tid ==num_threads-1) 411 | actual_per_thread = nr_per_thread + r; 412 | else 413 | actual_per_thread = nr_per_thread; 414 | uint64_t offset = prev_ctr + tid*nr_per_thread; 415 | //cout<<"Thread id: "<<tid<<", Offset: "<<offset<<", NR Threads: "<<nr_per_thread<<"Actual Threads: "<<actual_per_thread<<endl; 416 | //cout<<"Thread id:"<<tid<<", First Value (Out): "<<*(inputs+offset)<<endl; 417 | msi_relu_6(party, tid, ioArr[tid], inputs+offset, actual_per_thread, bitlen, ip_ss+offset, op_ss+offset, op_mss+offset); 418 | prev_ctr += CIFAR10_RELUS[i]; 419 | } 420 | ioArr[tid]->flush(); 421 | finalize_semi_honest(); 422 | } 423 | 424 | 425 | 426 | int main(int argc, char** argv) { 427 | srand(time(NULL)); 428 | int port, party, nrelu, bitlen; 429 | //Parse input arguments and configure parameters 430 | parse_arguments(argc, argv, &party, &port, &bitlen, &nrelu); 431 | cout<<"Executing Non-linear Layers ..."<<endl; 432 | cout << 
"=====================Configuration======================" << endl; 433 | cout<<"Role: "<< party<<" - IP Address: "<< address <<" - Port: "<<port<<" - Benchmark: "<<benchmark<<" - Bitlength: "<<bitlen<<endl; 434 | cout << "========================================================" << endl; 435 | //Prepare Inputs 436 | std::random_device rd; 437 | std::mt19937_64 eng(rd()); 438 | std::uniform_int_distribution<uint64_t> distr; 439 | 440 | uint64_t* inputs=(uint64_t *)malloc(nrelu*sizeof(uint64_t)); 441 | for(int i = 0; i < nrelu; ++i) 442 | inputs[i] = distr(eng)%prime_mod; 443 | 444 | uint64_t *ip_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 445 | uint64_t *op_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 446 | uint64_t *op_mss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 447 | 448 | for(int i=0; i <num_threads; i++) { 449 | ioArr[i] = new NetIO(party==ALICE ? nullptr : address.c_str(), port+i); 450 | } 451 | 452 | //Communication Initialization 453 | for(int i=0; i<num_threads; i++) 454 | start_comm[i] = ioArr[i]->counter; 455 | 456 | //Time Begin 457 | auto start = clock_start(); 458 | 459 | if(party == ALICE) { 460 | prg.random_data(&mac_key, 8); 461 | mac_key %= prime_mod; 462 | } 463 | 464 | std::thread relu_threads[num_threads]; 465 | for(int i=0; i<num_threads; i++) { 466 | if(choice_nn == MINIONN) { 467 | relu_threads[i] = std::thread(thread_process, i, party, inputs, nrelu, bitlen, ip_ss, op_ss, op_mss); 468 | } else { 469 | relu_threads[i] = std::thread(thread_process_1, i, party, inputs, nrelu, bitlen, ip_ss, op_ss, op_mss); 470 | } 471 | } 472 | 473 | //Join 474 | for(int i=0; i<num_threads; i++) { 475 | relu_threads[i].join(); 476 | } 477 | //Time End 478 | long long t = time_from(start); 479 | cout << "######################Performance#######################" <<endl; 480 | cout<<"Time Taken: "<<t<<" mus"<<endl; 481 | //Calculate Communication 482 | comm_sent = 0; 483 | for(int i=0; i<num_threads; i++) { 484 | comm_sent += 
(ioArr[i]->counter-start_comm[i]); 485 | } 486 | comm_sent = comm_sent>>20; 487 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 488 | cout << "########################################################" <<endl; 489 | 490 | //Test Protocol 491 | if(verify) { 492 | ioArr[0] = new NetIO(party==ALICE ? nullptr : address.c_str(), port); 493 | cout<<"nrelu: "<<nrelu<<endl; 494 | if(party == BOB) { 495 | ioArr[0]->send_data(inputs, sizeof(uint64_t) * nrelu); 496 | ioArr[0]->send_data(ip_ss, sizeof(uint64_t) * nrelu); 497 | ioArr[0]->send_data(op_ss, sizeof(uint64_t) * nrelu); 498 | ioArr[0]->send_data(op_mss, sizeof(uint64_t) * nrelu); 499 | } else { 500 | uint64_t inputs_1[nrelu]; 501 | uint64_t inputs_res[nrelu]; 502 | uint64_t relu_res[nrelu]; 503 | 504 | uint64_t *ip_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 505 | uint64_t *op_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 506 | uint64_t *op_mssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 507 | 508 | ioArr[0]->recv_data(inputs_1, sizeof(uint64_t) * nrelu); 509 | ioArr[0]->recv_data(ip_ssc, sizeof(uint64_t) * nrelu); 510 | ioArr[0]->recv_data(op_ssc, sizeof(uint64_t) * nrelu); 511 | ioArr[0]->recv_data(op_mssc, sizeof(uint64_t) * nrelu); 512 | for(int i=0; i< nrelu; i++) { 513 | inputs_res[i] = (inputs_1[i] + inputs[i])%prime_mod; 514 | if(inputs_res[i] > prime_mod/2) { 515 | relu_res[i] = 0; 516 | } else { 517 | if(inputs_res[i] > 6) 518 | relu_res[i] = 6; 519 | else 520 | relu_res[i] = inputs_res[i]; 521 | } 522 | } 523 | 524 | uint64_t ip_shares, ip_corr, op_shares, op_corr, opm_shares, opm_corr; 525 | uint64_t ctr_ip, ctr_op, ctr_opm=0; 526 | 527 | for(int i=0; i<nrelu; i++) { 528 | ip_shares = (ip_ss[i]+ip_ssc[i])%prime_mod; 529 | ip_corr = mod_mult(mac_key,inputs_res[i]); 530 | if(ip_shares == ip_corr) 531 | ctr_ip++; 532 | else { 533 | cout<<"Index: "<<i<<endl; 534 | break; 535 | } 536 | 537 | op_shares = (op_ss[i]+op_ssc[i])%prime_mod; 538 | if(op_shares == relu_res[i]) 539 | ctr_op++; 540 | 541 | 
opm_shares = (op_mss[i] + op_mssc[i])%prime_mod; 542 | opm_corr = mod_mult(mac_key,relu_res[i]); 543 | if(opm_shares == opm_corr) 544 | ctr_opm++; 545 | } 546 | cout << "**********************Verification**********************" <<endl; 547 | cout<<"Correct Input Macs: "<< ctr_ip<<endl; 548 | cout<<"Correct Outputs: "<< ctr_op<<endl; 549 | cout<<"Correct Output Macs: "<< ctr_opm<<endl; 550 | cout << "********************************************************" <<endl; 551 | } 552 | } 553 | //Performance Result 554 | 555 | } 556 | -------------------------------------------------------------------------------- /test/msi_relu_preprocess.cpp: -------------------------------------------------------------------------------- 1 | #include "emp-sh2pc/emp-sh2pc.h" 2 | #include <cmath> 3 | 4 | #include "seal/util/uintarith.h" 5 | #include "seal/util/uintarithsmallmod.h" 6 | #include <thread> 7 | #define MAX_THREADS 8 8 | using namespace emp; 9 | using namespace std; 10 | 11 | 12 | int num_threads = 8; 13 | 14 | //Slackoverflow Code For bit-wise shift 15 | #define SHL128(v, n) \ 16 | ({ \ 17 | __m128i v1, v2; \ 18 | \ 19 | if ((n) >= 64) \ 20 | { \ 21 | v1 = _mm_slli_si128(v, 8); \ 22 | v1 = _mm_slli_epi64(v1, (n) - 64); \ 23 | } \ 24 | else \ 25 | { \ 26 | v1 = _mm_slli_epi64(v, n); \ 27 | v2 = _mm_slli_si128(v, 8); \ 28 | v2 = _mm_srli_epi64(v2, 64 - (n)); \ 29 | v1 = _mm_or_si128(v1, v2); \ 30 | } \ 31 | v1; \ 32 | }) 33 | 34 | enum neural_net { 35 | NONE, 36 | MINIONN, 37 | CIFAR10 38 | }; 39 | 40 | neural_net choice_nn; 41 | uint64_t start_comm[MAX_THREADS]; 42 | uint64_t comm_sent = 0; 43 | NetIO *ioArr[MAX_THREADS]; 44 | uint64_t prime_mod = 17592060215297; 45 | seal::Modulus mod(prime_mod); 46 | int port = 32000, def_nrelu = 1<<20, l = 44; 47 | neural_net def_nn = NONE; 48 | string address; 49 | bool run_all = false; 50 | uint64_t mac_key; 51 | PRG prg; 52 | 53 | bool verify = false; 54 | int MINIONN_RELUS[] = { 16*576, 16*64, 100*1 55 | }; 56 | 57 | int CIFAR10_RELUS[] = 
{ 64*1024, 64*1024, 64*256, 64*256, 64*64, 64*64, 16*64 58 | }; 59 | 60 | uint64_t mod_shift(uint64_t a, uint64_t b, uint64_t prime_mod) { 61 | __m128i temp, stemp; 62 | memcpy(&temp, &a, 8); 63 | stemp = SHL128(temp, b); 64 | 65 | uint64_t input[2]; 66 | input[0] = stemp[0]; 67 | input[1] = stemp[1]; 68 | 69 | uint64_t result = seal::util::barrett_reduce_128(input, mod); 70 | 71 | return result; 72 | } 73 | 74 | uint64_t mod_mult(uint64_t a, uint64_t b) { 75 | unsigned long long temp_result[2]; 76 | seal::util::multiply_uint64(a, b, temp_result); 77 | 78 | /*uint64_t input[2]; 79 | input[0] = res[0]; 80 | input[1] = res[1];*/ 81 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 82 | return result; 83 | } 84 | 85 | 86 | //Referred SCI OT repo's logic to pack ot messages 87 | void pack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 88 | uint64_t beg_idx = 0; 89 | uint64_t end_idx = 0; 90 | uint64_t beg_blk = 0; 91 | uint64_t end_blk = 0; 92 | uint64_t temp_blk = 0; 93 | uint64_t mask = (1ULL << bitlen) - 1; 94 | uint64_t pack_blk_size = 64; 95 | 96 | if (bitlen == 64) 97 | mask = -1; 98 | 99 | for (int i = 0; i < pack_size; i++) { 100 | pack_table[i] = 0; 101 | } 102 | 103 | for (int i = 0; i < batch_size; i++) { 104 | beg_idx = i * bitlen; 105 | end_idx = beg_idx + bitlen; 106 | end_idx -= 1; 107 | beg_blk = beg_idx / pack_blk_size; 108 | end_blk = end_idx / pack_blk_size; 109 | 110 | if (beg_blk == end_blk) { 111 | pack_table[beg_blk] ^= (ciphertexts[i] & mask) << (beg_idx % pack_blk_size); 112 | } else { 113 | temp_blk = (ciphertexts[i] & mask); 114 | pack_table[beg_blk] ^= (temp_blk) << (beg_idx % pack_blk_size); 115 | pack_table[end_blk] ^= (temp_blk) >> (pack_blk_size - (beg_idx % pack_blk_size)); 116 | } 117 | } 118 | } 119 | 120 | //Referred SCI OT repo's logic to unpack ot messages 121 | void unpack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int 
batch_size, int bitlen) { 122 | uint64_t beg_idx = 0; 123 | uint64_t end_idx = 0; 124 | uint64_t beg_blk = 0; 125 | uint64_t end_blk = 0; 126 | uint64_t temp_blk = 0; 127 | uint64_t mask = (1ULL << bitlen) - 1; 128 | uint64_t pack_blk_size = 64; 129 | 130 | for (int i = 0; i < batch_size; i++) { 131 | beg_idx = i * bitlen; 132 | end_idx = beg_idx + bitlen - 1; 133 | beg_blk = beg_idx / pack_blk_size; 134 | end_blk = end_idx / pack_blk_size; 135 | 136 | if (beg_blk == end_blk) { 137 | ciphertexts[i] = (pack_table[beg_blk] >> (beg_idx % pack_blk_size)) & mask; 138 | } else { 139 | ciphertexts[i] = 0; 140 | ciphertexts[i] ^= (pack_table[beg_blk] >> (beg_idx % pack_blk_size)); 141 | ciphertexts[i] ^= (pack_table[end_blk] << (pack_blk_size - (beg_idx % pack_blk_size))); 142 | ciphertexts[i] = ciphertexts[i] & mask; 143 | } 144 | } 145 | } 146 | 147 | void create_ciphertexts(Integer *garbled_data, block label_delta, uint64_t *ciphertexts, uint64_t* server_shares, int bitlen, int nrelu, uint64_t alpha, int l_idx) { 148 | uint64_t delta_int; 149 | memcpy(&delta_int, &label_delta[l_idx], 8); 150 | 151 | uint64_t mask = (1ULL << bitlen) - 1; 152 | uint64_t label_temp; 153 | uint64_t **random_val = (uint64_t **)malloc(nrelu*sizeof(uint64_t*)); 154 | uint8_t pnp, cpnp; 155 | for(int i=0; i<nrelu; i++) { 156 | random_val[i] = (uint64_t *)malloc(bitlen*sizeof(uint64_t)); 157 | } 158 | 159 | for(int i=0; i<nrelu; i++) { 160 | server_shares[i] = 0; 161 | for(int j=0; j<bitlen; j++) { 162 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 163 | prg.random_data(&random_val[i][j], 8); 164 | random_val[i][j] %= prime_mod; 165 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 166 | cpnp = 1 - pnp; 167 | ciphertexts[(i*bitlen+j)*2+pnp] = (random_val[i][j])^(label_temp & mask); 168 | ciphertexts[(i*bitlen+j)*2+cpnp] = ((random_val[i][j]+alpha)%prime_mod)^((label_temp^delta_int) & mask); 169 | if(i==0) { 170 | uint64_t l1 = label_temp & mask, l2 = label_temp^delta_int & mask; 171 
| } 172 | server_shares[i] = (server_shares[i] + mod_shift(random_val[i][j],j,prime_mod))%prime_mod; 173 | } 174 | server_shares[i] = prime_mod - server_shares[i]; 175 | } 176 | } 177 | 178 | void decrypt_ciphertexts(Integer *garbled_data, uint64_t *ciphertexts, uint64_t* client_shares, int bitlen, int nrelu, int l_idx) { 179 | uint64_t label_temp; 180 | uint8_t pnp; 181 | uint64_t random_val; 182 | 183 | uint64_t mask = (1ULL << bitlen) - 1; 184 | 185 | for(int i=0; i<nrelu; i++) { 186 | client_shares[i] = 0; 187 | for(int j=0; j< bitlen; j++) { 188 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 189 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 190 | random_val = ciphertexts[(i*bitlen+j)*2+pnp]^(label_temp & mask); 191 | client_shares[i] = (client_shares[i] + mod_shift(random_val,j,prime_mod))%prime_mod; 192 | } 193 | } 194 | } 195 | 196 | void msi_relu_6(int party, NetIO* io, uint64_t* inputs, uint64_t* rand_vals, int nrelu, int bitlen, uint64_t* ip_mss, uint64_t* rand_mss, uint64_t* op_ss) { 197 | //Public prime values 198 | Integer p(bitlen + 1, prime_mod, PUBLIC); 199 | Integer p_mod2(bitlen, prime_mod/2, PUBLIC); 200 | Integer zero(bitlen, 0, PUBLIC); 201 | Integer six(bitlen, 6, PUBLIC); 202 | 203 | //Assign Inputs 204 | Integer *X = new Integer[nrelu]; 205 | for(int i = 0; i < nrelu; ++i) 206 | X[i] = Integer(bitlen+1, inputs[i], ALICE); 207 | 208 | Integer *Y = new Integer[nrelu]; 209 | for(int i = 0; i < nrelu; ++i) 210 | Y[i] = Integer(bitlen+1, inputs[i], BOB); 211 | 212 | Integer *Z = new Integer[nrelu]; 213 | for(int i=0; i < nrelu; i++) { 214 | Z[i] = Integer(bitlen+1, rand_vals, ALICE); 215 | } 216 | 217 | Integer *R = new Integer[nrelu]; 218 | for(int i=0; i<nrelu; i++) { 219 | R[i] = Integer(bitlen+1, rand_vals, BOB); 220 | } 221 | 222 | Integer *S = new Integer[nrelu]; 223 | Integer *XR = new Integer[nrelu]; 224 | Integer *T = new Integer[nrelu]; 225 | 226 | //Check if Bob's share is < p 227 | Bit *res = new Bit[nrelu]; 228 | 
for(int i=0; i < nrelu; ++i) 229 | res[i] = Y[i] > p; 230 | 231 | Bit *rres = new Bit[nrelu]; 232 | for(int i=0; i < nrelu; ++i) 233 | rres[i] = R[i] > p; 234 | 235 | for(int i=0; i < nrelu; ++i) { 236 | //Perform mod p 237 | Integer s0 = X[i]; 238 | //s0.resize(s0.size()+1); 239 | 240 | Integer s1 = Y[i]; 241 | //s1.resize(s1.size()+1); 242 | 243 | Integer sum = s0 + s1; 244 | 245 | Integer mod_p_val = sum - p; 246 | 247 | Bit borrow_bit = mod_p_val[mod_p_val.size()-1]; 248 | 249 | Integer s = mod_p_val.select(borrow_bit, sum); 250 | 251 | S[i] = s; 252 | 253 | //Perform RELU 254 | Integer p2_minus_s = p_mod2-s; 255 | 256 | Bit is_negative = p2_minus_s[p2_minus_s.size()-1]; 257 | 258 | Integer relu_s = s.select(is_negative, zero); 259 | 260 | Integer six_minus_res = six - relu_s; 261 | Bit is_greater_than_six = six_minus_res[six_minus_res.size()-1]; 262 | 263 | Integer relu_res = relu_s.select(is_greater_than_six, six); 264 | 265 | Integer intermediate_share = relu_res + R[i]; 266 | Integer intermediate_share_m = intermediate_share - p; 267 | 268 | borrow_bit = intermediate_share_m[intermediate_share_m.size()-1]; 269 | Integer intermediate_share_modp = intermediate_share_m.select(borrow_bit, intermediate_share); 270 | 271 | Integer res = intermediate_share_modp + Z[i]; 272 | Integer res_m = res - p; 273 | borrow_bit = res_m[res_m.size()-1]; 274 | Integer res_modp = res_m.select(borrow_bit, res); 275 | 276 | T[i] = res_modp; 277 | } 278 | 279 | int pack_size = ceil(nrelu*bitlen*bitlen*2.0/(8*sizeof(uint64_t))); 280 | int batch_size = nrelu*bitlen*2; 281 | 282 | uint64_t *ip_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 283 | 284 | uint64_t *r_mcts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 285 | 286 | uint64_t *op_cts = (uint64_t *)malloc(nrelu*bitlen*2*sizeof(uint64_t)); 287 | 288 | uint64_t *ipm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 289 | uint64_t *rm_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 290 
| uint64_t *op_pack_table = (uint64_t *)malloc(pack_size*sizeof(uint64_t)); 291 | 292 | if(party == ALICE) { 293 | 294 | create_ciphertexts(S, delta_used, ip_mcts, ip_mss, bitlen, nrelu, mac_key, 1); 295 | create_ciphertexts(R, delta_used, r_mcts, rand_mss, bitlen, nrelu, mac_key, 1); 296 | create_ciphertexts(T, delta_used, op_cts, op_ss, bitlen, nrelu, 1, 1); 297 | 298 | pack_decryption_table(ipm_pack_table, ip_mcts, pack_size, batch_size, bitlen); 299 | pack_decryption_table(rm_pack_table, r_mcts, pack_size, batch_size, bitlen); 300 | pack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 301 | 302 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 303 | io->send_data(ipm_pack_table, sizeof(uint64_t) * pack_size); 304 | io->send_data(rm_pack_table, sizeof(uint64_t) * pack_size); 305 | io->send_data(op_pack_table, sizeof(uint64_t) * pack_size); 306 | } else { 307 | io->recv_data(ipm_pack_table, sizeof(uint64_t) * pack_size); 308 | io->recv_data(rm_pack_table, sizeof(uint64_t) * pack_size); 309 | io->recv_data(op_pack_table, sizeof(uint64_t) * pack_size); 310 | 311 | unpack_decryption_table(ipm_pack_table, ip_mcts, pack_size, batch_size, bitlen); 312 | unpack_decryption_table(rm_pack_table, r_mcts, pack_size, batch_size, bitlen); 313 | unpack_decryption_table(op_pack_table, op_cts, pack_size, batch_size, bitlen); 314 | //cout<<"First element (meth):"<<ip_pack_table[0]<<endl; 315 | 316 | decrypt_ciphertexts(S, ip_mcts, ip_mss, bitlen, nrelu, 1); 317 | decrypt_ciphertexts(T, r_mcts, rand_mss, bitlen, nrelu, 0); 318 | decrypt_ciphertexts(T, op_cts, op_ss, bitlen, nrelu, 1); 319 | } 320 | } 321 | 322 | void parse_arguments(int argc, char**arg, int *party, int *port, int *bitlen, int *nrelu) { 323 | *party = atoi (arg[1]); 324 | address = arg[2]; 325 | *port = atoi (arg[3]); 326 | if(argc < 5) { 327 | *bitlen = l; 328 | } else { 329 | *bitlen = atoi(arg[4]); 330 | } 331 | if(argc < 6) { 332 | choice_nn =def_nn; 333 | } else { 334 | 
choice_nn = neural_net(atoi (arg[5])); 335 | } 336 | 337 | switch(choice_nn) { 338 | case NONE: { 339 | if(argc < 7) { 340 | *nrelu = def_nrelu; 341 | } else { 342 | *nrelu = atoi(arg[6]); 343 | } 344 | } 345 | break; 346 | case MINIONN: { 347 | *nrelu = 0; 348 | int len = *(&MINIONN_RELUS+1)-MINIONN_RELUS; 349 | for(int i=0; i< len; i++) { 350 | *nrelu += MINIONN_RELUS[i]; 351 | } 352 | } 353 | break; 354 | case CIFAR10: { 355 | *nrelu = 0; 356 | int len = *(&CIFAR10_RELUS+1)-CIFAR10_RELUS; 357 | for(int i=0; i< len; i++) { 358 | *nrelu += CIFAR10_RELUS[i]; 359 | } 360 | } 361 | } 362 | 363 | if(argc < 8) { 364 | run_all = false; 365 | } else { 366 | run_all = (bool)(atoi(arg[7])); 367 | } 368 | 369 | if(argc < 9) { 370 | // 371 | } else { 372 | num_threads = atoi(arg[8]); 373 | } 374 | } 375 | 376 | void thread_process(int tid, int party, uint64_t* inputs, uint64_t* rand_vals, int nrelu, int bitlen, uint64_t* ip_mss, uint64_t* rand_mss, uint64_t* op_ss) { 377 | setup_semi_honest(ioArr[tid], party); 378 | int prev_ctr=0; 379 | int len = *(&MINIONN_RELUS+1)-MINIONN_RELUS; 380 | for(int i=0; i<len; i++) { 381 | uint64_t num_relu_layer = MINIONN_RELUS[i]; 382 | uint64_t nr_per_thread = num_relu_layer/num_threads; 383 | uint64_t r = num_relu_layer % num_threads; 384 | uint64_t actual_per_thread; 385 | if(tid ==num_threads-1) 386 | actual_per_thread = nr_per_thread + r; 387 | else 388 | actual_per_thread = nr_per_thread; 389 | uint64_t offset = prev_ctr + tid*nr_per_thread; 390 | msi_relu_6(party, ioArr[tid], inputs+offset, rand_vals+offset, actual_per_thread, bitlen, ip_mss+offset, rand_mss+offset, op_ss+offset); 391 | prev_ctr += MINIONN_RELUS[i]; 392 | } 393 | ioArr[tid]->flush(); 394 | finalize_semi_honest(); 395 | } 396 | 397 | void thread_process_1(int tid, int party, uint64_t* inputs, uint64_t* rand_vals, int nrelu, int bitlen, uint64_t* ip_mss, uint64_t* rand_mss, uint64_t* op_ss) { 398 | setup_semi_honest(ioArr[tid], party); 399 | int prev_ctr=0; 400 | int len 
= *(&CIFAR10_RELUS+1)-CIFAR10_RELUS; 401 | for(int i=0; i<len; i++) { 402 | uint64_t num_relu_layer = CIFAR10_RELUS[i]; 403 | uint64_t nr_per_thread = num_relu_layer/num_threads; 404 | uint64_t r = num_relu_layer % num_threads; 405 | uint64_t actual_per_thread; 406 | if(tid ==num_threads-1) 407 | actual_per_thread = nr_per_thread + r; 408 | else 409 | actual_per_thread = nr_per_thread; 410 | uint64_t offset = prev_ctr + tid*nr_per_thread; 411 | //cout<<"Thread id: "<<tid<<", Offset: "<<offset<<", NR Threads: "<<nr_per_thread<<"Actual Threads: "<<actual_per_thread<<endl; 412 | //cout<<"Thread id:"<<tid<<", First Value (Out): "<<*(inputs+offset)<<endl; 413 | msi_relu_6(party, ioArr[tid], inputs+offset, rand_vals+offset, actual_per_thread, bitlen, ip_mss+offset, rand_mss+offset, op_ss+offset); 414 | prev_ctr += CIFAR10_RELUS[i]; 415 | } 416 | ioArr[tid]->flush(); 417 | finalize_semi_honest(); 418 | } 419 | 420 | 421 | 422 | int main(int argc, char** argv) { 423 | srand(time(NULL)); 424 | int port, party, nrelu, bitlen; 425 | //Parse input arguments and configure parameters 426 | parse_arguments(argc, argv, &party, &port, &bitlen, &nrelu); 427 | cout << "=====================Configuration======================" << endl; 428 | cout<<"Party Id: "<< party<<", Server IP Address: "<< address <<", Port: "<<port<<", NRelu: "<<nrelu<<", Bitlen: "<<bitlen<<endl; 429 | cout << "========================================================" << endl; 430 | //Prepare Inputs 431 | std::random_device rd; 432 | std::mt19937_64 eng(rd()); 433 | std::uniform_int_distribution<uint64_t> distr; 434 | 435 | uint64_t* inputs=(uint64_t *)malloc(nrelu*sizeof(uint64_t)); 436 | for(int i = 0; i < nrelu; ++i) 437 | inputs[i] = distr(eng)%prime_mod; 438 | 439 | uint64_t* rand_vals = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 440 | for(int i = 0; i < nrelu; ++i) 441 | rand_vals[i] = distr(eng)%prime_mod; 442 | 443 | uint64_t *ip_mss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 444 | uint64_t *rand_mss 
= (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 445 | uint64_t *op_ss = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 446 | 447 | for(int i=0; i <num_threads; i++) { 448 | ioArr[i] = new NetIO(party==ALICE ? nullptr : address.c_str(), port+i); 449 | } 450 | 451 | //Communication Initialization 452 | for(int i=0; i<num_threads; i++) 453 | start_comm[i] = ioArr[i]->counter; 454 | 455 | //Time Begin 456 | auto start = clock_start(); 457 | 458 | if(party == ALICE) { 459 | prg.random_data(&mac_key, 8); 460 | mac_key %= prime_mod; 461 | } 462 | 463 | std::thread relu_threads[num_threads]; 464 | for(int i=0; i<num_threads; i++) { 465 | if(choice_nn == MINIONN) { 466 | relu_threads[i] = std::thread(thread_process, i, party, inputs, rand_vals, nrelu, bitlen, ip_mss, rand_mss, op_ss); 467 | } else { 468 | relu_threads[i] = std::thread(thread_process_1, i, party, inputs, rand_vals, nrelu, bitlen, ip_mss, rand_mss, op_ss); 469 | } 470 | } 471 | 472 | //Join 473 | for(int i=0; i<num_threads; i++) { 474 | relu_threads[i].join(); 475 | } 476 | //Time End 477 | long long t = time_from(start); 478 | cout << "######################Performance#######################" <<endl; 479 | cout<<"Time Taken: "<<t<<" mus"<<endl; 480 | //Calculate Communication 481 | comm_sent = 0; 482 | for(int i=0; i<num_threads; i++) { 483 | comm_sent += (ioArr[i]->counter-start_comm[i]); 484 | } 485 | comm_sent = comm_sent>>20; 486 | cout<<"Sent Data (MB): "<<comm_sent<<endl; 487 | cout << "########################################################" <<endl; 488 | /* 489 | ioArr[0] = new NetIO(party==ALICE ? 
nullptr : address.c_str(), port); 490 | cout<<"nrelu: "<<nrelu<<endl; 491 | //Test Protocol 492 | if(verify) { 493 | if(party == BOB) { 494 | ioArr[0]->send_data(inputs, sizeof(uint64_t) * nrelu); 495 | ioArr[0]->send_data(ip_ss, sizeof(uint64_t) * nrelu); 496 | ioArr[0]->send_data(op_ss, sizeof(uint64_t) * nrelu); 497 | ioArr[0]->send_data(op_mss, sizeof(uint64_t) * nrelu); 498 | } else { 499 | uint64_t inputs_1[nrelu]; 500 | uint64_t inputs_res[nrelu]; 501 | uint64_t relu_res[nrelu]; 502 | 503 | uint64_t *ip_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 504 | uint64_t *op_ssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 505 | uint64_t *op_mssc = (uint64_t *)malloc(nrelu*sizeof(uint64_t)); 506 | 507 | ioArr[0]->recv_data(inputs_1, sizeof(uint64_t) * nrelu); 508 | ioArr[0]->recv_data(ip_ssc, sizeof(uint64_t) * nrelu); 509 | ioArr[0]->recv_data(op_ssc, sizeof(uint64_t) * nrelu); 510 | ioArr[0]->recv_data(op_mssc, sizeof(uint64_t) * nrelu); 511 | for(int i=0; i< nrelu; i++) { 512 | inputs_res[i] = (inputs_1[i] + inputs[i])%prime_mod; 513 | if(inputs_res[i] > prime_mod/2) { 514 | relu_res[i] = 0; 515 | } else { 516 | if(inputs_res[i] > 6) 517 | relu_res[i] = 6; 518 | else 519 | relu_res[i] = inputs_res[i]; 520 | } 521 | } 522 | 523 | uint64_t ip_shares, ip_corr, op_shares, op_corr, opm_shares, opm_corr; 524 | uint64_t ctr_ip, ctr_op, ctr_opm=0; 525 | 526 | for(int i=0; i<nrelu; i++) { 527 | ip_shares = (ip_ss[i]+ip_ssc[i])%prime_mod; 528 | ip_corr = mod_mult(mac_key,inputs_res[i]); 529 | if(ip_shares == ip_corr) 530 | ctr_ip++; 531 | else { 532 | cout<<"Index: "<<i<<endl; 533 | break; 534 | } 535 | 536 | op_shares = (op_ss[i]+op_ssc[i])%prime_mod; 537 | if(op_shares == relu_res[i]) 538 | ctr_op++; 539 | 540 | opm_shares = (op_mss[i] + op_mssc[i])%prime_mod; 541 | opm_corr = mod_mult(mac_key,relu_res[i]); 542 | if(opm_shares == opm_corr) 543 | ctr_opm++; 544 | } 545 | cout << "**********************Verification**********************" <<endl; 546 | cout<<"Correct 
Input Macs: "<< ctr_ip<<endl; 547 | cout<<"Correct Outputs: "<< ctr_op<<endl; 548 | cout<<"Correct Output Macs: "<< ctr_opm<<endl; 549 | cout << "********************************************************" <<endl; 550 | } 551 | }*/ 552 | //Performance Result 553 | 554 | } 555 | -------------------------------------------------------------------------------- /utils/utils.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shahakash28/simc/2a5fd092b52427cc9cac55b36ec50ae43ecee6be/utils/utils.cpp -------------------------------------------------------------------------------- /utils/utils.h: -------------------------------------------------------------------------------- 1 | #define MAX_THREADS 8 2 | 3 | //Slackoverflow Code For bit-wise shift 4 | #define SHL128(v, n) \ 5 | ({ \ 6 | __m128i v1, v2; \ 7 | \ 8 | if ((n) >= 64) \ 9 | { \ 10 | v1 = _mm_slli_si128(v, 8); \ 11 | v1 = _mm_slli_epi64(v1, (n) - 64); \ 12 | } \ 13 | else \ 14 | { \ 15 | v1 = _mm_slli_epi64(v, n); \ 16 | v2 = _mm_slli_si128(v, 8); \ 17 | v2 = _mm_srli_epi64(v2, 64 - (n)); \ 18 | v1 = _mm_or_si128(v1, v2); \ 19 | } \ 20 | v1; \ 21 | }) 22 | 23 | uint64_t mod_shift(uint64_t a, uint64_t b, uint64_t prime_mod) { 24 | __m128i temp, stemp; 25 | memcpy(&temp, &a, 8); 26 | stemp = SHL128(temp, b); 27 | 28 | uint64_t input[2]; 29 | input[0] = stemp[0]; 30 | input[1] = stemp[1]; 31 | 32 | uint64_t result = seal::util::barrett_reduce_128(input, mod); 33 | 34 | return result; 35 | } 36 | 37 | uint64_t mod_mult(uint64_t a, uint64_t b) { 38 | unsigned long long temp_result[2]; 39 | seal::util::multiply_uint64(a, b, temp_result); 40 | 41 | /*uint64_t input[2]; 42 | input[0] = res[0]; 43 | input[1] = res[1];*/ 44 | uint64_t result = seal::util::barrett_reduce_128(temp_result, mod); 45 | return result; 46 | } 47 | 48 | //Referred SCI OT repo's logic to pack ot messages 49 | void pack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, 
int pack_size, int batch_size, int bitlen) { 50 | uint64_t beg_idx = 0; 51 | uint64_t end_idx = 0; 52 | uint64_t beg_blk = 0; 53 | uint64_t end_blk = 0; 54 | uint64_t temp_blk = 0; 55 | uint64_t mask = (1ULL << bitlen) - 1; 56 | uint64_t pack_blk_size = 64; 57 | 58 | if (bitlen == 64) 59 | mask = -1; 60 | 61 | for (int i = 0; i < pack_size; i++) { 62 | pack_table[i] = 0; 63 | } 64 | 65 | for (int i = 0; i < batch_size; i++) { 66 | beg_idx = i * bitlen; 67 | end_idx = beg_idx + bitlen; 68 | end_idx -= 1; 69 | beg_blk = beg_idx / pack_blk_size; 70 | end_blk = end_idx / pack_blk_size; 71 | 72 | if (beg_blk == end_blk) { 73 | pack_table[beg_blk] ^= (ciphertexts[i] & mask) << (beg_idx % pack_blk_size); 74 | } else { 75 | temp_blk = (ciphertexts[i] & mask); 76 | pack_table[beg_blk] ^= (temp_blk) << (beg_idx % pack_blk_size); 77 | pack_table[end_blk] ^= (temp_blk) >> (pack_blk_size - (beg_idx % pack_blk_size)); 78 | } 79 | } 80 | } 81 | 82 | //Referred SCI OT repo's logic to unpack ot messages 83 | void unpack_decryption_table(uint64_t *pack_table, uint64_t *ciphertexts, int pack_size, int batch_size, int bitlen) { 84 | uint64_t beg_idx = 0; 85 | uint64_t end_idx = 0; 86 | uint64_t beg_blk = 0; 87 | uint64_t end_blk = 0; 88 | uint64_t temp_blk = 0; 89 | uint64_t mask = (1ULL << bitlen) - 1; 90 | uint64_t pack_blk_size = 64; 91 | 92 | for (int i = 0; i < batch_size; i++) { 93 | beg_idx = i * bitlen; 94 | end_idx = beg_idx + bitlen - 1; 95 | beg_blk = beg_idx / pack_blk_size; 96 | end_blk = end_idx / pack_blk_size; 97 | 98 | if (beg_blk == end_blk) { 99 | ciphertexts[i] = (pack_table[beg_blk] >> (beg_idx % pack_blk_size)) & mask; 100 | } else { 101 | ciphertexts[i] = 0; 102 | ciphertexts[i] ^= (pack_table[beg_blk] >> (beg_idx % pack_blk_size)); 103 | ciphertexts[i] ^= (pack_table[end_blk] << (pack_blk_size - (beg_idx % pack_blk_size))); 104 | ciphertexts[i] = ciphertexts[i] & mask; 105 | } 106 | } 107 | } 108 | 109 | void create_ciphertexts(Integer *garbled_data, block 
label_delta, uint64_t *ciphertexts, uint64_t* server_shares, int bitlen, int nrelu, uint64_t alpha, int l_idx) { 110 | uint64_t delta_int; 111 | memcpy(&delta_int, &label_delta[l_idx], 8); 112 | 113 | uint64_t mask = (1ULL << bitlen) - 1; 114 | uint64_t label_temp; 115 | uint64_t **random_val = (uint64_t **)malloc(nrelu*sizeof(uint64_t*)); 116 | uint8_t pnp, cpnp; 117 | for(int i=0; i<nrelu; i++) { 118 | random_val[i] = (uint64_t *)malloc(bitlen*sizeof(uint64_t)); 119 | } 120 | 121 | for(int i=0; i<nrelu; i++) { 122 | server_shares[i] = 0; 123 | for(int j=0; j<bitlen; j++) { 124 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 125 | prg.random_data(&random_val[i][j], 8); 126 | random_val[i][j] %= prime_mod; 127 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 128 | cpnp = 1 - pnp; 129 | ciphertexts[(i*bitlen+j)*2+pnp] = (random_val[i][j])^(label_temp & mask); 130 | ciphertexts[(i*bitlen+j)*2+cpnp] = ((random_val[i][j]+alpha)%prime_mod)^((label_temp^delta_int) & mask); 131 | if(i==0) { 132 | uint64_t l1 = label_temp & mask, l2 = label_temp^delta_int & mask; 133 | } 134 | server_shares[i] = (server_shares[i] + mod_shift(random_val[i][j],j,prime_mod))%prime_mod; 135 | } 136 | server_shares[i] = prime_mod - server_shares[i]; 137 | } 138 | } 139 | 140 | void decrypt_ciphertexts(Integer *garbled_data, uint64_t *ciphertexts, uint64_t* client_shares, int bitlen, int nrelu, int l_idx) { 141 | uint64_t label_temp; 142 | uint8_t pnp; 143 | uint64_t random_val; 144 | 145 | uint64_t mask = (1ULL << bitlen) - 1; 146 | 147 | for(int i=0; i<nrelu; i++) { 148 | client_shares[i] = 0; 149 | for(int j=0; j< bitlen; j++) { 150 | memcpy(&label_temp, &garbled_data[i].bits[j].bit[l_idx], 8); 151 | pnp = (garbled_data[i].bits[j].bit[0]) & 1; 152 | random_val = ciphertexts[(i*bitlen+j)*2+pnp]^(label_temp & mask); 153 | client_shares[i] = (client_shares[i] + mod_shift(random_val,j,prime_mod))%prime_mod; 154 | } 155 | } 156 | } 157 | 
--------------------------------------------------------------------------------