├── .gitignore ├── src ├── crypto │ ├── prp.hpp │ ├── hash.cpp │ ├── prg.hpp │ ├── ot │ │ ├── base.hpp │ │ ├── extension.hpp │ │ ├── correlated.hpp │ │ └── base.cpp │ ├── mitccrh.hpp │ ├── hash.hpp │ ├── group.hpp │ ├── group.cpp │ └── aes.hpp ├── programs │ ├── registry.cpp │ ├── registry.hpp │ ├── full_sort.cpp │ ├── real_sum.cpp │ ├── merge_sorted.cpp │ ├── real_cpir.cpp │ ├── matrix_vector_multiply.cpp │ ├── real_matrix_vector_multiply.cpp │ ├── password.cpp │ ├── matrix_multiply.cpp │ ├── aspirin_seq.cpp │ ├── real_statistics.cpp │ ├── util.hpp │ ├── binary_fc_layer.cpp │ ├── aspirin.cpp │ └── real_matrix_multiply.cpp ├── platform │ ├── misc.cpp │ ├── misc.hpp │ ├── memory.cpp │ ├── network.hpp │ ├── network.cpp │ └── filesystem.cpp ├── protocols │ ├── registry.cpp │ ├── ckks_constants.hpp │ ├── plaintext.cpp │ ├── tfhe.cpp │ ├── plaintext.hpp │ ├── ckks.cpp │ ├── registry.hpp │ ├── tfhe.hpp │ └── tfhe_scheme.hpp ├── executables │ ├── disassemble.cpp │ ├── mage.cpp │ └── planner.cpp ├── memprog │ ├── annotation.cpp │ ├── annotation.hpp │ └── pipeline.cpp ├── dsl │ └── util.hpp └── util │ ├── stats.hpp │ ├── progress.cpp │ ├── progress.hpp │ └── misc.hpp ├── tests ├── test_main.cpp ├── test_circbuffer.cpp └── test_prioqueue.cpp ├── Makefile ├── README.md └── install_deps.sh /.gitignore: -------------------------------------------------------------------------------- 1 | bin/* 2 | doxygen/* 3 | *.cir 4 | *.lin 5 | *.ann 6 | *.pln 7 | -------------------------------------------------------------------------------- /src/crypto/prp.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * This file is based heavily on the file utils/prp.h in EMP-toolkit. 3 | */ 4 | 5 | #ifndef MAGE_CRYPTO_PRP_HPP_ 6 | #define MAGE_CRYPTO_PRP_HPP_ 7 | 8 | #include 9 | #include "crypto/aes.hpp" 10 | 11 | namespace mage::crypto { 12 | /* Simple AES wrapper. 
*/ 13 | class PRP { 14 | public: 15 | PRP(const void* seed = fix_key) { 16 | this->aes_set_key(_mm_loadu_si128(reinterpret_cast(seed))); 17 | } 18 | 19 | PRP(const block& seed) { 20 | this->aes_set_key(seed); 21 | } 22 | 23 | void aes_set_key(const block& v) { 24 | AES_set_encrypt_key(v, &this->aes); 25 | } 26 | 27 | public: 28 | AES_KEY aes; 29 | }; 30 | } 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /tests/test_main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #define BOOST_TEST_DYN_LINK 23 | #define BOOST_TEST_MODULE "MAGE" 24 | #include "boost/test/unit_test.hpp" 25 | -------------------------------------------------------------------------------- /src/programs/registry.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "memprog/program.hpp" 23 | #include "programs/registry.hpp" 24 | 25 | namespace mage::programs { 26 | memprog::DefaultProgram* program_ptr = nullptr; 27 | } 28 | -------------------------------------------------------------------------------- /src/platform/misc.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2021 Sam Kumar 3 | * Copyright (C) 2021 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "platform/misc.hpp" 23 | #include 24 | #include 25 | #include 26 | 27 | namespace mage::platform { 28 | /* Writes the window size reported by the TIOCGWINSZ ioctl on standard output into ts; if the ioctl fails, ts is set to zero rows and zero columns. */ void get_terminal_size(TerminalSize& ts) { 29 | struct winsize ws; 30 | if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) == 0) { 31 | ts.num_rows = ws.ws_row; 32 | ts.num_cols = ws.ws_col; 33 | } else { /* ioctl failed: report a 0x0 terminal rather than leaving ts unset. */ 34 | ts.num_rows = 0; 35 | ts.num_cols = 0; 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/crypto/hash.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "crypto/hash.hpp" 23 | #include 24 | #include 25 | #include 26 | #include "crypto/block.hpp" 27 | 28 | namespace mage::crypto { 29 | /* Computes the SHA256 digest of the src_length bytes starting at src and writes it to into. */ void hash(const void* src, std::size_t src_length, std::uint8_t* into) { 30 | SHA256(static_cast(src), src_length, into); 31 | } 32 | 33 | /* Hashes the input into a local hash_length-byte buffer and returns the front of that buffer read back as a block. */ block hash_to_block(const void* src, std::size_t src_length) { 34 | /* Aligned to sizeof(block) so the buffer can be reinterpreted below. */ std::uint8_t into[hash_length] __attribute__((aligned(sizeof(block)))); 35 | hash(src, src_length, into); 36 | return *reinterpret_cast(&into[0]); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/programs/registry.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #ifndef MAGE_PROGRAMS_REGISTRY_HPP_ 23 | #define MAGE_PROGRAMS_REGISTRY_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include "memprog/program.hpp" 29 | #include "util/config.hpp" 30 | #include "util/registry.hpp" 31 | 32 | namespace mage::programs { 33 | struct ProgramOptions { 34 | const util::ConfigValue* worker_config; 35 | WorkerID num_workers; 36 | WorkerID worker_index; 37 | std::uint64_t problem_size; 38 | }; 39 | 40 | using RegisteredProgram = util::CallableRegistryEntry; 41 | using RegisterProgram = util::Register; 42 | 43 | extern memprog::DefaultProgram* program_ptr; 44 | } 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX = clang++ 2 | CXXFLAGS = -std=c++2a -Ofast -DNDEBUG -march=native -ggdb3 -pthread -I./src/ 3 | LDFLAGS = -pthread -laio -lssl -lcrypto -lyaml-cpp 4 | 5 | # Uncomment for tfhe support 6 | CXXFLAGS += -DTFHE 7 | LDFLAGS += -ltfhe-spqlios-fma 8 | 9 | # Uncomment for ckks support 10 | CXXFLAGS += -DCKKS -I/usr/local/include/SEAL-3.6 11 | LDFLAGS += -lseal 12 | 13 | MAGE_DIRS = src src/dsl src/platform src/crypto src/crypto/ot src/memprog src/engine src/programs src/protocols src/util 14 | MAGE_CPP_SOURCES = $(foreach dir,$(MAGE_DIRS),$(wildcard $(dir)/*.cpp)) 15 | MAGE_HEADERS = $(foreach dir,$(MAGE_DIRS),$(wildcard $(dir)/*.hpp)) 16 | 17 | MAGE_TEST_SOURCES = $(wildcard tests/*.cpp) 18 | MAGE_EXECUTABLE_SOURCES = $(wildcard src/executables/*.cpp) 19 | MAGE_EXECUTABLE_NAMES = $(foreach file,$(MAGE_EXECUTABLE_SOURCES),$(notdir $(basename $(file)))) 20 | 21 | BINDIR = bin 22 | MAGE_OBJECTS = $(addprefix $(BINDIR)/,$(MAGE_CPP_SOURCES:.cpp=.o)) 23 | TEST_OBJECTS = $(addprefix $(BINDIR)/,$(MAGE_TEST_SOURCES:.cpp=.o)) 24 | EXECUTABLES = $(addprefix $(BINDIR)/,$(MAGE_EXECUTABLE_NAMES)) 25 | 26 | .PHONY: clean 27 | 28 | default: $(EXECUTABLES) 29 | 30 | all: 
$(EXECUTABLES) tests 31 | 32 | tests: $(BINDIR)/test 33 | 34 | $(BINDIR)/test: $(MAGE_OBJECTS) $(TEST_OBJECTS) 35 | $(CXX) $(LDFLAGS) $+ -lboost_unit_test_framework -o $@ 36 | 37 | $(BINDIR)/tests/%.o: tests/%.cpp $(MAGE_HEADERS) 38 | mkdir -p $(dir $@) 39 | $(CXX) $(CXXFLAGS) -c $< -o $@ 40 | 41 | $(BINDIR)/%: $(MAGE_OBJECTS) $(BINDIR)/src/executables/%.o 42 | $(CXX) $(LDFLAGS) $+ -o $@ 43 | 44 | $(BINDIR)/src/%.o: src/%.cpp $(MAGE_HEADERS) 45 | mkdir -p $(dir $@) 46 | $(CXX) $(CXXFLAGS) -c $< -o $@ 47 | 48 | clean: 49 | rm -rf bin 50 | -------------------------------------------------------------------------------- /src/crypto/prg.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * This file is based heavily on the file utils/prg.h in EMP-toolkit. 3 | */ 4 | 5 | #ifndef MAGE_CRYPTO_PRG_HPP_ 6 | #define MAGE_CRYPTO_PRG_HPP_ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "crypto/aes.hpp" 14 | 15 | namespace mage::crypto { 16 | /* Pseudorandom generator implemented using AES-CTR. 
*/ 17 | class PRG { 18 | public: 19 | PRG(const void* seed = nullptr) : counter(0) { 20 | if (seed != nullptr) { 21 | const block* seed_block = reinterpret_cast(seed); 22 | this->set_seed(*seed_block); 23 | return; 24 | } 25 | block v; 26 | int rv = RAND_bytes(reinterpret_cast(&v), sizeof(v)); 27 | if (rv == 0) { 28 | ERR_print_errors_fp(stderr); 29 | std::abort(); 30 | } 31 | this->set_seed(v); 32 | } 33 | 34 | void set_seed(const block& key) { 35 | AES_set_encrypt_key(key, &this->aes); 36 | } 37 | 38 | void random_block(block* data, int count = 1) { 39 | int i; 40 | for (i = 0; i < count; i++) { 41 | data[i] = makeBlock(0LL, this->counter++); 42 | } 43 | for(i = 0; i < count - AES_BATCH_SIZE; i += AES_BATCH_SIZE) { 44 | AES_ecb_encrypt_blks(data + i, AES_BATCH_SIZE, &this->aes); 45 | } 46 | AES_ecb_encrypt_blks(data + i, std::min(count - i, AES_BATCH_SIZE), &this->aes); 47 | } 48 | 49 | private: 50 | std::uint64_t counter; 51 | AES_KEY aes; 52 | }; 53 | } 54 | 55 | #endif 56 | -------------------------------------------------------------------------------- /src/crypto/ot/base.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #ifndef MAGE_CRYPTO_OT_BASE_HPP_ 23 | #define MAGE_CRYPTO_OT_BASE_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include "crypto/block.hpp" 31 | #include "crypto/group.hpp" 32 | #include "crypto/hash.hpp" 33 | #include "crypto/prg.hpp" 34 | #include "util/filebuffer.hpp" 35 | 36 | namespace mage::crypto::ot { 37 | /* Basic oblivious transfer, on top of which we can implement OT extension. */ 38 | void base_send(const DDHGroup& g, util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const std::pair* choices, std::size_t num_choices); 39 | void base_choose(const DDHGroup& g, util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const bool* choices, block* results, std::size_t num_choices); 40 | } 41 | 42 | #endif 43 | -------------------------------------------------------------------------------- /src/platform/misc.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2021 Sam Kumar 3 | * Copyright (C) 2021 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | /** 23 | * @file platform/misc.hpp 24 | * @brief Miscellaneous system-level utilities. 
25 | */ 26 | 27 | #ifndef MAGE_PLATFORM_MISC_HPP_ 28 | #define MAGE_PLATFORM_MISC_HPP_ 29 | 30 | #include <cstdint> 31 | 32 | namespace mage::platform { 33 | /** 34 | * @brief Describes the size of the terminal window. 35 | */ 36 | struct TerminalSize { 37 | std::uint32_t num_rows; 38 | std::uint32_t num_cols; 39 | }; 40 | 41 | /** 42 | * @brief Populates @p ts with the terminal window size. 43 | * 44 | * If the terminal size could not be obtained (e.g., because the system 45 | * does not support the operation or because standard output is not 46 | * connected to a terminal window) then @p ts is populated with zero 47 | * rows and zero columns. 48 | * 49 | * @param[out] ts The structure to populate with the terminal window size. 50 | */ 51 | void get_terminal_size(TerminalSize& ts); 52 | } 53 | 54 | #endif 55 | -------------------------------------------------------------------------------- /src/crypto/mitccrh.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MAGE_CRYPTO_MITCCRH_HPP_ 2 | #define MAGE_CRYPTO_MITCCRH_HPP_ 3 | 4 | #include 5 | #include 6 | #include "crypto/aes_cpu.hpp" 7 | #include "crypto/block.hpp" 8 | 9 | namespace mage::crypto { 10 | class MiTCCRH { 11 | public: 12 | ROUND_KEYS key_schedule[KS_BATCH_N]; 13 | int key_used = KS_BATCH_N; 14 | block start_point; 15 | 16 | MiTCCRH() { 17 | } 18 | 19 | void setS(block sin) { 20 | this->start_point = sin; 21 | } 22 | 23 | void renew_ks(uint64_t gid) { 24 | switch (KS_BATCH_N) { 25 | case 2: 26 | AES_ks2_index(start_point, gid, key_schedule); break; 27 | case 4: 28 | AES_ks4_index(start_point, gid, key_schedule); break; 29 | case 8: 30 | AES_ks8_index(start_point, gid, key_schedule); break; 31 | default: 32 | std::abort(); 33 | } 34 | key_used = 0; 35 | } 36 | 37 | void k2_h2(block A, block B, block *H) { 38 | block keys[2], masks[2]; 39 | keys[0] = sigma(A); 40 | keys[1] = sigma(B); 41 | masks[0] = keys[0]; 42 | masks[1] = keys[1]; 43 | 44 | AES_ecb_ccr_ks2_enc2(keys, keys, &key_schedule[key_used]); 45 | key_used += 2; 46 | 47 | H[0] = 
xorBlocks(keys[0], masks[0]); 48 | H[1] = xorBlocks(keys[1], masks[1]); 49 | } 50 | 51 | void k2_h4(block A0, block A1, block B0, block B1, block *H) { 52 | block keys[4], masks[4]; 53 | keys[0] = sigma(A0); 54 | keys[1] = sigma(A1); 55 | keys[2] = sigma(B0); 56 | keys[3] = sigma(B1); 57 | memcpy(masks, keys, sizeof keys); 58 | 59 | AES_ecb_ccr_ks2_enc4(keys, keys, &key_schedule[key_used]); 60 | key_used += 2; 61 | 62 | H[0] = xorBlocks(keys[0], masks[0]); 63 | H[1] = xorBlocks(keys[1], masks[1]); 64 | H[2] = xorBlocks(keys[2], masks[2]); 65 | H[3] = xorBlocks(keys[3], masks[3]); 66 | } 67 | }; 68 | } 69 | 70 | #endif 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | MAGE: Memory-Aware Garbling Engine 2 | ================================== 3 | MAGE is an execution engine for secure computation protocols, such as secure multi-party computation (SMPC) and homomorphic encryption (HE). MAGE is designed to execute secure computation efficiently even when the computation does not fit in memory. 4 | 5 | The implementation of MAGE in this repository accompanies our OSDI 2021 paper: 6 | 7 | Sam Kumar, David E. Culler, and Raluca Ada Popa. MAGE: Nearly Zero-Cost Virtual Memory for Secure Computation. OSDI 2021. 8 | 9 | **WARNING: This implementation is a prototype designed for academic study and proof-of-concept use cases. It has not received code review and is *not* production-ready.** 10 | 11 | Secure computation is inherently *oblivious*—that is, it contains no data-dependent memory accesses. The reason for this lies in maintaining security: if memory accesses depended on the data, then an attacker could potentially analyze the memory access pattern and infer the contents of sensitive data. This would be a problem because the point of using secure computation is to compute on sensitive data without revealing the contents of that data. 
12 | 13 | MAGE leverages the *oblivious* nature of secure computation to manage memory efficiently. Specifically, MAGE introduces a planning phase in which it analyzes in advance the computation it is going to perform. The result of MAGE's planning phase is a *memory program*, an execution plan for performing the computation. The memory program can be understood as (roughly) a pre-processed execution trace (with all functions inlined and all loops unrolled), including preplanned data transfers between memory and storage. At runtime, MAGE uses the memory program to transfer data between memory and storage very efficiently, in effect providing virtual memory at a very low cost. 14 | 15 | How to Build and Use MAGE 16 | ------------------------- 17 | Instructions to build MAGE, and a tutorial for using it, are available on the [MAGE wiki](https://github.com/ucbrise/mage/wiki). 18 | 19 | To build documentation, run `doxygen` in the repository's root directory. 20 | 21 | License 22 | ------- 23 | The code in this repository is available under version 3 of the GNU General Public License (GPL). 24 | -------------------------------------------------------------------------------- /src/protocols/registry.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 
17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include <stdexcept> 29 | #include "protocols/registry.hpp" 30 | 31 | namespace mage::protocols { 32 | std::vector evaluator_synonyms = { "evaluator", "0", "bob" }; 33 | std::vector garbler_synonyms = { "garbler", "1", "alice" }; 34 | 35 | /* 36 | * Converts a textual party designation into a party ID. 37 | * 38 | * PARTY may be one of the evaluator or garbler synonyms above, or a 39 | * string that std::stoull consumes in full as a base-10 number. Any 40 | * other input, including a number too large for unsigned long long, 41 | * yields an empty optional instead of an escaping exception. 42 | */ 43 | std::optional parse_party_id(const std::string& party) { 44 | if (std::find(evaluator_synonyms.begin(), evaluator_synonyms.end(), party) != evaluator_synonyms.end()) { 45 | return evaluator_party_id; 46 | } else if (std::find(garbler_synonyms.begin(), garbler_synonyms.end(), party) != garbler_synonyms.end()) { 47 | return garbler_party_id; 48 | } else { 49 | std::size_t length = 0; 50 | unsigned long long party_id = 0; 51 | try { 52 | party_id = std::stoull(party, &length); 53 | } catch (const std::logic_error&) { 54 | /* std::stoull reports failure as std::invalid_argument or std::out_of_range. */ 55 | return {}; 56 | } 57 | if (length != party.length()) { 58 | return {}; 59 | } 60 | return static_cast(party_id); 61 | } 62 | } 63 | 64 | memprog::AllocationSize identity_physical_size(std::uint64_t logical_width, memprog::PlaceableType type) { 65 | return logical_width; 66 | } 67 | 68 | RegisterPlacementPlugin identity_plugin("identity_plugin", "Object's MAGE-virtual size is its logical width", identity_physical_size); 69 | } 70 | -------------------------------------------------------------------------------- /src/programs/full_sort.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 
12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::merge_sorted { 32 | /* 33 | * Emits the program that sorts the two parties' combined inputs. 34 | * 35 | * Declares a cyclically-sharded array of 2 * problem_size elements, marks 36 | * those whose for_each index is below problem_size as the garbler's input 37 | * and the rest as the evaluator's, sorts the array with parallel_sorter, 38 | * and marks every element as output. The timer brackets only the sort. 39 | */ 40 | template 41 | void create_full_sort_circuit(const ProgramOptions& args) { 42 | /* problem_size is a std::uint64_t, so storing the doubled length in an int would truncate large inputs. */ 43 | const std::size_t input_array_length = args.problem_size * 2; 44 | 45 | ClusterUtils utils; 46 | utils.self_id = args.worker_index; 47 | utils.num_proc = args.num_workers; 48 | 49 | ShardedArray> list(input_array_length, args.worker_index, args.num_workers, Layout::Cyclic); 50 | list.for_each([=](std::size_t i, auto& elem) { 51 | elem.data.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 52 | }); 53 | 54 | program_ptr->print_stats(); 55 | program_ptr->start_timer(); 56 | 57 | parallel_sorter(list); 58 | 59 | program_ptr->stop_timer(); 60 | program_ptr->print_stats(); 61 | 62 | list.for_each([=](std::size_t i, auto& elem) { 63 | elem.data.mark_output(); 64 | }); 65 | } 66 | 67 | RegisterProgram full_sort("full_sort", "Bitonic Sort (problem_size = number of elements per party)", create_full_sort_circuit<>); 68 | } 69 | -------------------------------------------------------------------------------- /src/protocols/ckks_constants.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_SCHEMES_CKKS_CONSTANTS_HPP_ 23 | #define MAGE_SCHEMES_CKKS_CONSTANTS_HPP_ 24 | 25 | #include 26 | #include 27 | 28 | namespace mage::protocols::ckks { 29 | inline double ckks_scale = std::pow(2.0, 40); 30 | 31 | constexpr std::uint64_t ckks_ciphertext_size(std::int32_t level, bool normalized) { 32 | if (normalized) { 33 | if (level == 0) { 34 | return 131689; 35 | } else if (level == 1) { 36 | return 263273; 37 | } else if (level == 2) { 38 | return 394857; 39 | } else { 40 | return UINT64_MAX; 41 | } 42 | } else { 43 | if (level == 0) { 44 | return UINT64_MAX; 45 | } else if (level == 1) { 46 | return 394857; 47 | } else if (level == 2) { 48 | return 592233; 49 | } else { 50 | return UINT64_MAX; 51 | } 52 | } 53 | } 54 | 55 | constexpr std::uint64_t ckks_plaintext_size(std::int32_t level) { 56 | if (level == 0) { 57 | return 65624; 58 | } else if (level == 1) { 59 | return 131160; 60 | } else if (level == 2) { 61 | return 196696; 62 | } else { 63 | return UINT64_MAX; 64 | } 65 | } 66 | } 67 | 68 | #endif 69 | -------------------------------------------------------------------------------- /tests/test_circbuffer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * 
All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #define BOOST_TEST_DYN_LINK 23 | #include "boost/test/unit_test.hpp" 24 | #include "boost/test/data/test_case.hpp" 25 | #include "boost/test/data/monomorphic.hpp" 26 | 27 | #include 28 | #include 29 | 30 | #include "util/circbuffer.hpp" 31 | 32 | namespace bdata = boost::unit_test::data; 33 | using mage::util::CircularBuffer; 34 | 35 | constexpr const std::size_t circbuf_capacity_shift = 6; 36 | constexpr const std::size_t circbuf_capacity = 1 << circbuf_capacity_shift; 37 | constexpr const std::uint64_t num_iterations = 100; 38 | 39 | BOOST_DATA_TEST_CASE(test_circbuffer_wrap, bdata::xrange(circbuf_capacity), step_size) { 40 | CircularBuffer cb(circbuf_capacity_shift); 41 | 42 | std::uint64_t counter = 0; 43 | for (std::uint64_t i = 0; i != num_iterations; i++) { 44 | std::vector x(step_size); 45 | for (std::uint64_t i = 0; i != step_size; i++) { 46 | x[i] = i; 47 | } 48 | cb.write_unchecked(x.data(), step_size); 49 | 50 | std::vector y(step_size); 51 | cb.read_unchecked(y.data(), step_size); 52 | 53 | BOOST_REQUIRE(x.size() == step_size); 54 | BOOST_REQUIRE(x.size() == y.size()); 55 | for (std::uint64_t k = 0; k != x.size(); k++) { 56 | BOOST_CHECK_MESSAGE(x[k] == y[k], "x[" << k << "] is " << x[k] << ", but y[" << k << "] is " << y[k]); 57 | } 58 | } 59 | 
} 60 | -------------------------------------------------------------------------------- /src/platform/memory.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace mage::platform { 28 | void* allocate_resident_memory(std::size_t num_bytes, bool lazy) { 29 | int flags = MAP_PRIVATE | MAP_ANONYMOUS; 30 | if (!lazy) { 31 | flags |= (MAP_NORESERVE | MAP_POPULATE); 32 | } 33 | void* region = mmap(NULL, num_bytes, PROT_READ | PROT_WRITE, flags, -1, 0); 34 | if (region == MAP_FAILED) { 35 | std::perror("allocate_resident_memory -> mmap"); 36 | std::abort(); 37 | } 38 | return region; 39 | } 40 | 41 | void deallocate_resident_memory(void* memory, std::size_t num_bytes) { 42 | if (munmap(memory, num_bytes) != 0) { 43 | std::perror("deallocate_resident_memory -> munmap"); 44 | std::abort(); 45 | } 46 | } 47 | 48 | void* map_file(int fd, std::size_t length, bool mutate) { 49 | void* region = mmap(NULL, length, PROT_READ | PROT_WRITE, mutate ? 
MAP_SHARED : MAP_PRIVATE, fd, 0); 50 | if (region == MAP_FAILED) { 51 | std::perror("map_file -> mmap"); 52 | std::abort(); 53 | } 54 | return region; 55 | } 56 | 57 | void unmap_file(void* memory, std::size_t length) { 58 | if (munmap(memory, length) != 0) { 59 | std::perror("unmap_file -> munmap"); 60 | std::abort(); 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/executables/disassemble.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include "instruction.hpp" 24 | #include "programfile.hpp" 25 | 26 | int main(int argc, char** argv) { 27 | if (argc != 2) { 28 | std::cout << "Usage: " << argv[0] << " file.memprog" << std::endl; 29 | return EXIT_FAILURE; 30 | } 31 | std::string filename(argv[1]); 32 | if (filename.ends_with(".memprog") || filename.ends_with(".repprog")) { 33 | mage::PhysProgramFileReader program(argv[1]); 34 | mage::InstructionNumber num_instructions = program.get_header().num_instructions; 35 | for (mage::InstructionNumber i = 0; i != num_instructions; i++) { 36 | mage::PackedPhysInstruction& phys = program.start_instruction(); 37 | std::cout << phys << std::endl; 38 | program.finish_instruction(phys.size()); 39 | } 40 | } else if (filename.ends_with(".prog")) { 41 | mage::VirtProgramFileReader program(argv[1]); 42 | mage::InstructionNumber num_instructions = program.get_header().num_instructions; 43 | for (mage::InstructionNumber i = 0; i != num_instructions; i++) { 44 | mage::PackedVirtInstruction& virt = program.start_instruction(); 45 | std::cout << virt << std::endl; 46 | program.finish_instruction(virt.size()); 47 | } 48 | } else { 49 | std::cout << "Error: could not infer bytecode type from file extension" << std::endl; 50 | return EXIT_FAILURE; 51 | } 52 | 53 | return EXIT_SUCCESS; 54 | } 55 | -------------------------------------------------------------------------------- /src/protocols/plaintext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 
12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include "engine/andxor.hpp" 24 | #include "protocols/plaintext.hpp" 25 | #include "protocols/registry.hpp" 26 | 27 | namespace mage::protocols::plaintext { 28 | void run_plaintext(const EngineOptions& args) { 29 | std::string file_base = args.problem_name + "_" + std::to_string(args.self_id); 30 | std::string prog_file = file_base + ".memprog"; 31 | std::string output_file = file_base + ".output"; 32 | std::string evaluator_input_file = file_base + "_evaluator.input"; 33 | std::string garbler_input_file = file_base + "_garbler.input"; 34 | 35 | std::chrono::time_point start; 36 | std::chrono::time_point end; 37 | 38 | util::Configuration& c = *args.config; 39 | PlaintextEvaluationEngine p(garbler_input_file.c_str(), evaluator_input_file.c_str(), output_file.c_str()); 40 | start = std::chrono::steady_clock::now(); 41 | engine::ANDXOREngine executor(args.cluster, c["parties"][args.party_id]["workers"][args.self_id], p, prog_file.c_str()); 42 | executor.execute_program(); 43 | end = std::chrono::steady_clock::now(); 44 | std::chrono::milliseconds ms = std::chrono::duration_cast(end - start); 45 | std::cerr << ms.count() << " ms" << std::endl; 46 | } 47 | 48 | RegisterProtocol plaintext("plaintext", "Plaintext simulation of halfgates (for testing)", run_plaintext, "identity_plugin"); 49 | } 50 | -------------------------------------------------------------------------------- /src/crypto/hash.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of 
California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_CRYPTO_HASH_HPP_ 23 | #define MAGE_CRYPTO_HASH_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include "crypto/block.hpp" 30 | 31 | namespace mage::crypto { 32 | constexpr const std::uint16_t hash_length = SHA256_DIGEST_LENGTH; 33 | void hash(const void* src, std::size_t src_length, std::uint8_t* into); 34 | block hash_to_block(const void* src, std::size_t src_length); 35 | 36 | class Hasher { 37 | public: 38 | static constexpr const std::uint32_t output_length = SHA256_DIGEST_LENGTH; 39 | 40 | Hasher() : active(true) { 41 | SHA256_Init(&this->ctx); 42 | } 43 | 44 | Hasher(const void* src, std::size_t src_length) : Hasher() { 45 | this->update(src, src_length); 46 | } 47 | 48 | void update(const void* src, std::size_t src_length) { 49 | assert(this->active); 50 | SHA256_Update(&this->ctx, src, src_length); 51 | } 52 | 53 | void output(std::uint8_t* into) { 54 | assert(this->active); 55 | this->active = false; 56 | SHA256_Final(into, &this->ctx); 57 | } 58 | 59 | block output_block() { 60 | std::uint8_t into[hash_length] __attribute__((aligned(sizeof(block)))); 61 | this->output(into); 62 | return *reinterpret_cast(&into[0]); 63 | } 64 | 65 | private: 66 | SHA256_CTX ctx; 67 | bool active; 68 | }; 
69 | } 70 | 71 | #endif 72 | -------------------------------------------------------------------------------- /src/crypto/group.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * This file is based heavily on the file utils/group.h in EMP-toolkit. 3 | */ 4 | 5 | #ifndef MAGE_CRYPTO_GROUP_HPP_ 6 | #define MAGE_CRYPTO_GROUP_HPP_ 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace mage::crypto { 14 | // class BigInt { 15 | // public: 16 | // BigInt(); 17 | // BigInt(const BigInt& other); 18 | // BigInt& operator =(const BigInt& other); 19 | // ~BigInt(); 20 | // 21 | // private: 22 | // BIGNUM* n; 23 | // }; 24 | 25 | class DDHGroupElement; 26 | class ScalarMod; 27 | 28 | class DDHGroup { 29 | friend class DDHGroupElement; 30 | friend class ScalarMod; 31 | 32 | public: 33 | DDHGroup(); 34 | ~DDHGroup(); 35 | 36 | private: 37 | BN_CTX* bn_ctx = nullptr; 38 | EC_GROUP* ec_group = nullptr; 39 | BIGNUM* order = nullptr; 40 | }; 41 | 42 | class ScalarMod { 43 | friend class DDHGroupElement; 44 | public: 45 | ScalarMod(const DDHGroup& g); 46 | ScalarMod(const ScalarMod& other); 47 | ~ScalarMod(); 48 | 49 | void set_random(); 50 | void multiply(const ScalarMod& a, const ScalarMod& b); 51 | 52 | private: 53 | const DDHGroup& group; 54 | BIGNUM* n = nullptr; 55 | }; 56 | 57 | class DDHGroupElement { 58 | friend class DDHGroup; 59 | 60 | public: 61 | DDHGroupElement(const DDHGroup& g); 62 | DDHGroupElement(const DDHGroupElement& other); 63 | DDHGroupElement& operator =(const DDHGroupElement& other); 64 | ~DDHGroupElement(); 65 | 66 | void marshal_uncompressed(std::uint8_t* buffer, std::size_t length) const; 67 | std::size_t marshalled_uncompressed_size() const; 68 | void unmarshal_uncompressed(const std::uint8_t* buffer, std::size_t length); 69 | 70 | void set_generator(); 71 | void add(const DDHGroupElement& a, const DDHGroupElement& __restrict b); 72 | void multiply_generator(const ScalarMod& __restrict 
m); 73 | void multiply_restrict(const DDHGroupElement& __restrict base, const ScalarMod& __restrict m); 74 | void invert(); 75 | bool operator ==(const DDHGroupElement& other); 76 | 77 | private: 78 | const DDHGroup& group; 79 | EC_POINT* point = nullptr; 80 | }; 81 | } 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /src/protocols/tfhe.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include "engine/andxor.hpp" 24 | #include "protocols/registry.hpp" 25 | #include "protocols/tfhe.hpp" 26 | 27 | namespace mage::protocols::tfhe { 28 | void run_tfhe(const EngineOptions& args) { 29 | std::string file_base = args.problem_name + "_" + std::to_string(args.self_id); 30 | std::string prog_file = file_base + ".memprog"; 31 | std::string output_file = file_base + ".output"; 32 | std::string evaluator_input_file = file_base + "_evaluator.input"; 33 | std::string garbler_input_file = file_base + "_garbler.input"; 34 | 35 | std::chrono::time_point start; 36 | std::chrono::time_point end; 37 | 38 | util::Configuration& c = *args.config; 39 | TFHEEngine p(garbler_input_file.c_str(), evaluator_input_file.c_str(), output_file.c_str()); 40 | start = std::chrono::steady_clock::now(); 41 | engine::ANDXOREngine executor(args.cluster, c["parties"][args.party_id]["workers"][args.self_id], p, prog_file.c_str()); 42 | executor.execute_program(); 43 | end = std::chrono::steady_clock::now(); 44 | std::chrono::milliseconds ms = std::chrono::duration_cast(end - start); 45 | std::cerr << ms.count() << " ms" << std::endl; 46 | } 47 | 48 | memprog::AllocationSize tfhe_physical_size(std::uint64_t logical_width, memprog::PlaceableType type) { 49 | return logical_width; 50 | } 51 | 52 | RegisterProtocol tfhe("tfhe", "Fast Fully Homomorphic Encryption over the Torus", run_tfhe, "identity_plugin"); 53 | } 54 | -------------------------------------------------------------------------------- /src/programs/real_sum.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::real_sum { 32 | void create_real_sum_circuit(const ProgramOptions& args) { 33 | int input_array_length = args.problem_size; 34 | 35 | ClusterUtils utils; 36 | utils.self_id = args.worker_index; 37 | utils.num_proc = args.num_workers; 38 | 39 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Blocked); 40 | inputs.for_each([=](std::size_t i, auto& input) { 41 | input.mark_input(); 42 | }); 43 | 44 | program_ptr->print_stats(); 45 | program_ptr->start_timer(); 46 | 47 | std::vector>& locals = inputs.get_locals(); 48 | 49 | LeveledBatch<0, true> local_result; 50 | if (locals.size() == 0) { 51 | local_result = LeveledBatch<0, true>(0); 52 | } else { 53 | local_result = std::move(locals[0]); 54 | for (std::size_t i = 1; i != locals.size(); i++) { 55 | local_result = local_result + locals[i]; 56 | } 57 | } 58 | 59 | std::optional> global_result = utils.reduce_aggregates>(0, local_result, [](LeveledBatch<0, true>& a, LeveledBatch<0, true>& b) -> LeveledBatch<0, true> { 60 | return a + b; 61 | }); 62 | 63 | 
program_ptr->stop_timer(); 64 | program_ptr->print_stats(); 65 | 66 | if (args.worker_index == 0) { 67 | global_result->mark_output(); 68 | } 69 | } 70 | 71 | RegisterProgram real_sum("real_sum", "Compute sum of array of real numbers (problem_size = number of elements)", create_real_sum_circuit); 72 | } 73 | -------------------------------------------------------------------------------- /install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Usage: install_deps.sh [--install-mage-deps] [--install-utils] [--setup-wan-tcp] -- every flag given on the command line is processed, in order. 3 | for flag in "$@" 4 | do 5 | case $flag in 6 | --install-mage-deps) 7 | # Install apt packages 8 | sudo apt update 9 | sudo apt install -y git build-essential clang cmake libssl-dev libaio-dev 10 | 11 | # Install yaml-cpp version 0.6.3 12 | wget https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-0.6.3.tar.gz 13 | tar zxf yaml-cpp-0.6.3.tar.gz 14 | pushd yaml-cpp-yaml-cpp-0.6.3 15 | mkdir build 16 | pushd build 17 | cmake -DYAML_BUILD_SHARED_LIBS=ON .. 18 | make -j 2 19 | sudo make install 20 | popd 21 | popd 22 | 23 | # Install tfhe 24 | git clone https://github.com/tfhe/tfhe 25 | pushd tfhe 26 | make -j 2 27 | sudo make install 28 | popd 29 | 30 | # Install SEAL 31 | wget https://github.com/microsoft/SEAL/archive/v3.6.1.tar.gz 32 | tar zxf v3.6.1.tar.gz 33 | pushd SEAL-3.6.1 34 | cmake -S . 
-B build -DSEAL_USE_ZLIB=OFF -DBUILD_SHARED_LIBS=ON 35 | cmake --build build -j 2 36 | sudo cmake --install build 37 | popd 38 | 39 | # Update shared libraries 40 | sudo ldconfig 41 | ;; 42 | --install-utils) 43 | # Other useful tools for experimentation 44 | sudo apt install -y tmux iperf3 python3 htop net-tools cgroup-tools 45 | 46 | # EMP-Toolkit dependencies 47 | sudo apt install -y cmake git build-essential libssl-dev libgmp-dev libboost-all-dev 48 | ;; 49 | --setup-wan-tcp) 50 | # Modify /etc/sysctl.conf to widen the TCP windows for WAN experiments 51 | echo "# Increase TCP buffer sizes for WAN experiments" | sudo tee -a /etc/sysctl.conf 52 | echo "net.core.rmem_max = 67108864" | sudo tee -a /etc/sysctl.conf 53 | echo "net.core.wmem_max = 67108864" | sudo tee -a /etc/sysctl.conf 54 | echo "net.ipv4.tcp_rmem = 4096 87380 33554432" | sudo tee -a /etc/sysctl.conf 55 | echo "net.ipv4.tcp_wmem = 4096 65536 33554432" | sudo tee -a /etc/sysctl.conf 56 | 57 | # Modify /etc/sysctl.conf to use frequent TCP keepalives so that WAN firewalls don't drop idle connections 58 | echo "# Use frequent TCP keepalives for WAN experiments" | sudo tee -a /etc/sysctl.conf 59 | echo "net.ipv4.tcp_keepalive_time = 240" | sudo tee -a /etc/sysctl.conf 60 | echo "net.ipv4.tcp_keepalive_intvl = 65" | sudo tee -a /etc/sysctl.conf 61 | echo "net.ipv4.tcp_keepalive_probes = 5" | sudo tee -a /etc/sysctl.conf 62 | ;; 63 | *) 64 | echo "Unknown command-line flag" "$flag" 65 | esac 66 | done 67 | -------------------------------------------------------------------------------- /src/programs/merge_sorted.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::merge_sorted { 32 | template 33 | void create_merge_sorted_circuit(const ProgramOptions& args) { 34 | int input_array_length = args.problem_size * 2; 35 | 36 | ClusterUtils utils; 37 | utils.self_id = args.worker_index; 38 | utils.num_proc = args.num_workers; 39 | 40 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Cyclic); 41 | inputs.for_each([=](std::size_t i, auto& input) { 42 | input.data.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 43 | }); 44 | 45 | program_ptr->print_stats(); 46 | program_ptr->start_timer(); 47 | 48 | /* 49 | * For malicious MPC, we would want to check that the input is sorted 50 | * as we would expect --- ascending for the first half, then 51 | * descending, so that the concatenation is a bitonic sequence. The 52 | * aspirin count circuit does this, but I'm skipping it here in order 53 | * to isolate the cost of merging the two sorted lists. (It also isn't 54 | * necessary for semi-honest MPC.) 
55 | */ 56 | 57 | // Sort inputs and switch to blocked layout 58 | parallel_bitonic_sorter(inputs); 59 | 60 | program_ptr->stop_timer(); 61 | program_ptr->print_stats(); 62 | 63 | // Output sorted list 64 | inputs.for_each([=](std::size_t i, auto& input) { 65 | input.data.mark_output(); 66 | }); 67 | } 68 | 69 | RegisterProgram merge_sorted("merge_sorted", "Merge Sorted Lists (problem_size = number of elements per party)", create_merge_sorted_circuit<>); 70 | } 71 | -------------------------------------------------------------------------------- /src/programs/real_cpir.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2021 Sam Kumar 3 | * Copyright (C) 2021 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | /* real_cpir: the array holds problem_size * problem_size plaintext values (element i is initialised to i + 1), viewed as sets of problem_size consecutive elements. A request vector of problem_size inputs is read, and for every set held by this worker the sum of request[j] * set[j] is renormalized and marked as output. */ 31 | namespace mage::programs::real_cpir { 32 | void create_real_cpir_circuit(const ProgramOptions& args) { 33 | int num_sets = args.problem_size; 34 | int set_size = args.problem_size; 35 | int input_array_length = num_sets * set_size; 36 | if (set_size <= 0) { std::cerr << "problem_size must be positive" << std::endl; std::abort(); } /* set_size is used as a modulus and as a loop stride below, so zero or negative values must be rejected up front */ 37 | ClusterUtils utils; 38 | utils.self_id = args.worker_index; 39 | utils.num_proc = args.num_workers; 40 | 41 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Blocked); 42 | inputs.for_each([=](std::size_t i, auto& input) { 43 | input = LeveledPlaintextBatch<1>(static_cast(i + 1)); 44 | }); 45 | 46 | program_ptr->print_stats(); 47 | program_ptr->start_timer(); 48 | 49 | std::vector>& locals = inputs.get_locals(); 50 | 51 | if ((locals.size() % set_size) != 0) { 52 | std::cerr << "Each worker must handle a whole number of sets" << std::endl; 53 | std::abort(); 54 | } 55 | 56 | std::vector> set_request(set_size); 57 | for (int i = 0; i != set_size; i++) { 58 | set_request[i].mark_input(); 59 | } 60 | 61 | for (std::size_t i = 0; i != locals.size(); i += set_size) { 62 | LeveledBatch<1, false> set_response = set_request[0].multiply_without_normalizing(locals[i]); 63 | for (std::size_t j = 1; j != static_cast<std::size_t>(set_size); j++) { 64 | set_response = set_response + set_request[j].multiply_without_normalizing(locals[i + j]); 65 | } 66 | set_response.renormalize().mark_output(); 67 | } 68 | 69 | program_ptr->stop_timer(); 70 | program_ptr->print_stats(); 71 | } 72 | 73 | RegisterProgram real_cpir("real_cpir", "Perform computational PIR on an array of real numbers (problem_size = square root of the number of elements)", create_real_cpir_circuit); 74 | } 75 | -------------------------------------------------------------------------------- 
/src/crypto/ot/extension.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_CRYPTO_OT_EXTENSION_HPP_ 23 | #define MAGE_CRYPTO_OT_EXTENSION_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include "crypto/block.hpp" 30 | #include "crypto/ot/base.hpp" 31 | #include "util/filebuffer.hpp" 32 | 33 | namespace mage::crypto::ot { 34 | /* 35 | * This implements OT Extension based on the protocol in Section 5.3 and 36 | * Protocol 5.2 of the following paper: 37 | * G. Asharov, Y. Lindell, T. Schneider, and M. Zohner. More Efficient 38 | * Oblivious Transfer and Extensions for Faster Secure Computation. CCS 39 | * 2013. 40 | */ 41 | 42 | /* Kappa is the symmetric security parameter. 
*/ 43 | constexpr const std::uint8_t extension_kappa = block_num_bits; 44 | 45 | class ExtensionSender { 46 | public: 47 | ExtensionSender(); 48 | void initialize(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out); 49 | 50 | void send(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const std::pair* choices, std::size_t num_choices); 51 | 52 | protected: 53 | void prepare_send(std::size_t num_choices, const block* u, block* q); 54 | void finish_send(const std::pair* choices, std::size_t num_choices, block* y, const block* qT); 55 | 56 | std::array prgs; 57 | block s; 58 | 59 | bool initialized; 60 | }; 61 | 62 | struct ExtChooserPRGs { 63 | PRG g0; 64 | PRG g1; 65 | }; 66 | 67 | class ExtensionChooser { 68 | public: 69 | ExtensionChooser(); 70 | void initialize(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out); 71 | 72 | void choose(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const block* choices, block* results, std::size_t num_choices); 73 | 74 | protected: 75 | void prepare_choose(const block* choices, std::size_t num_choices, block* u, block* t); 76 | void finish_choose(const block* choices, block* results, std::size_t num_choices, const block* y, const block* tT); 77 | 78 | std::array prgs; 79 | 80 | bool initialized; 81 | }; 82 | } 83 | 84 | #endif 85 | -------------------------------------------------------------------------------- /src/protocols/plaintext.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_PROTOCOLS_PLAINTEXT_HPP_ 23 | #define MAGE_PROTOCOLS_PLAINTEXT_HPP_ 24 | 25 | #include 26 | #include 27 | #include "util/binaryfile.hpp" 28 | 29 | namespace mage::protocols::plaintext { 30 | class PlaintextEvaluationEngine { 31 | public: 32 | using Wire = unsigned __int128; 33 | 34 | PlaintextEvaluationEngine(std::string garbler_input_file, std::string evaluator_input_file, std::string output_file) 35 | : garbler_input_reader(garbler_input_file.c_str()), evaluator_input_reader(evaluator_input_file.c_str()), output_writer(output_file.c_str()) { 36 | } 37 | 38 | void print_stats() { 39 | } 40 | 41 | void input(Wire* data, unsigned int length, bool garbler) { 42 | util::BinaryFileReader& input_reader = garbler ? 
this->garbler_input_reader : this->evaluator_input_reader; 43 | for (unsigned int i = 0; i != length; i++) { 44 | std::uint8_t bit = input_reader.read1(); 45 | data[i] = bit; 46 | } 47 | } 48 | 49 | void output(const Wire* data, unsigned int length) { 50 | for (unsigned int i = 0; i != length; i++) { 51 | std::uint8_t bit = static_cast(data[i]) & 0x1; 52 | this->output_writer.write1(bit); 53 | } 54 | } 55 | 56 | void op_and(Wire& output, const Wire& input1, const Wire& input2) { 57 | output = input1 & input2; 58 | } 59 | 60 | void op_xor(Wire& output, const Wire& input1, const Wire& input2) { 61 | output = input1 ^ input2; 62 | } 63 | 64 | void op_not(Wire& output, const Wire& input) { 65 | output = !input; 66 | } 67 | 68 | void op_xnor(Wire& output, const Wire& input1, const Wire& input2) { 69 | output = !(input1 ^ input2); 70 | } 71 | 72 | void op_copy(Wire& output, const Wire& input) { 73 | output = input; 74 | } 75 | 76 | void one(Wire& output) const { 77 | output = 1; 78 | } 79 | 80 | void zero(Wire& output) const { 81 | output = 0; 82 | } 83 | 84 | private: 85 | util::BinaryFileReader garbler_input_reader; 86 | util::BinaryFileReader evaluator_input_reader; 87 | util::BinaryFileWriter output_writer; 88 | }; 89 | } 90 | 91 | #endif 92 | -------------------------------------------------------------------------------- /src/protocols/ckks.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 
12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include 24 | #include "engine/addmultiply.hpp" 25 | #include "protocols/registry.hpp" 26 | #include "protocols/ckks.hpp" 27 | #include "protocols/ckks_constants.hpp" 28 | 29 | namespace mage::protocols::ckks { 30 | void run_ckks(const EngineOptions& args) { 31 | std::string file_base = args.problem_name + "_" + std::to_string(args.self_id); 32 | std::string prog_file = file_base + ".memprog"; 33 | std::string output_file = file_base + ".output"; 34 | std::string input_file = file_base + "_garbler.input"; 35 | 36 | std::chrono::time_point start; 37 | std::chrono::time_point end; 38 | 39 | util::Configuration& c = *args.config; 40 | { 41 | CKKSEngine p(input_file.c_str(), output_file.c_str()); 42 | engine::AddMultiplyEngine executor(args.cluster, c["parties"][args.party_id]["workers"][args.self_id], p, prog_file.c_str()); 43 | start = std::chrono::steady_clock::now(); 44 | executor.execute_program(); 45 | end = std::chrono::steady_clock::now(); 46 | } 47 | std::chrono::milliseconds ms = std::chrono::duration_cast(end - start); 48 | std::cout << ms.count() << " ms" << std::endl; 49 | } 50 | 51 | RegisterProtocol ckks("ckks", "Homomorphic Encryption for Arithmetic of Approximate Numbers", run_ckks, "ckks_plugin"); 52 | /* Placement callback registered as "ckks_plugin": returns the size computed by ckks_ciphertext_size()/ckks_plaintext_size() for the given logical width and placeable type, and throws memprog::InvalidPlacementException when no size was produced. */ 53 | memprog::AllocationSize ckks_physical_size(std::uint64_t logical_width, memprog::PlaceableType type) { 54 | std::uint64_t result = UINT64_MAX; /* sentinel: still UINT64_MAX after the switch if no case assigned a size */ 55 | switch (type) { 56 | case memprog::PlaceableType::Ciphertext: 57 | result = ckks_ciphertext_size(logical_width, true); 58 | break; 59 | case memprog::PlaceableType::Plaintext: 60 | result = 
ckks_plaintext_size(logical_width); 61 | break; 62 | case memprog::PlaceableType::DenormalizedCiphertext: 63 | result = ckks_ciphertext_size(logical_width, false); 64 | break; 65 | } 66 | if (result == UINT64_MAX) { 67 | throw memprog::InvalidPlacementException("ckks", logical_width, type); 68 | } 69 | return result; 70 | } 71 | 72 | RegisterPlacementPlugin ckks_plugin("ckks_plugin", "Object's MAGE-virtual size is the size of a CKKS ciphertext/plaintext in bytes", ckks_physical_size); 73 | } 74 | -------------------------------------------------------------------------------- /src/executables/mage.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "addr.hpp" 28 | #include "protocols/registry.hpp" 29 | #include "platform/network.hpp" 30 | #include "util/config.hpp" 31 | 32 | using namespace mage; 33 | using mage::protocols::RegisteredProtocol; 34 | using mage::protocols::EngineOptions; 35 | using mage::util::Registry; 36 | /* Entry point. Arguments: protocol, config.yaml, party_id, worker_id, program_name. Looks up the protocol, loads the configuration, parses the party and worker IDs, establishes cluster networking, then invokes the protocol. Returns EXIT_FAILURE if an argument is rejected or networking cannot be established. */ 37 | int main(int argc, char** argv) { 38 | if (argc != 6) { 39 | std::cerr << "Usage: " << argv[0] << " protocol config.yaml party_id worker_id program_name" << std::endl; 40 | return EXIT_FAILURE; 41 | } 42 | 43 | /* Parse the protocol name. */ 44 | 45 | std::string protocol_name(argv[1]); 46 | const RegisteredProtocol* prot_ptr = Registry::look_up_by_name(protocol_name); 47 | if (prot_ptr == nullptr) { 48 | std::cerr << protocol_name << " is not a valid protocol name. "; // lack of std::endl is intentional 49 | Registry::print_all("protocols", std::cerr); 50 | return EXIT_FAILURE; 51 | } 52 | 53 | /* Parse the config.yaml file. */ 54 | 55 | util::Configuration c(argv[2]); 56 | 57 | /* Parse the party ID. */ 58 | 59 | std::optional party_id = mage::protocols::parse_party_id(argv[3]); 60 | if (!party_id.has_value()) { 61 | std::cerr << "Invalid party_id (try \"garbler\", \"evaluator\", or an integer)" << std::endl; 62 | return EXIT_FAILURE; 63 | } 64 | 65 | /* Parse the worker ID. */ 66 | 67 | WorkerID self_id; 68 | std::istringstream self_id_stream(argv[4]); 69 | self_id_stream >> self_id; 70 | if (self_id_stream.fail()) { std::cerr << "Invalid worker_id" << std::endl; return EXIT_FAILURE; } /* a worker_id that cannot be parsed is rejected instead of continuing with whatever the failed extraction left in self_id */ 71 | /* Establish cluster networking. */ 72 | 73 | std::size_t buffer_size = 1 << 18; 74 | 75 | /* TODO: do this more systematically. */ 76 | if (protocol_name == "ckks") { 77 | buffer_size = 1 << 20; 78 | } 79 | 80 | auto cluster = std::make_shared(self_id, buffer_size); 81 | std::string err = cluster->establish(c["parties"][*party_id]); 82 | if (!err.empty()) { 83 | std::cerr << err << std::endl; 84 | return EXIT_FAILURE; 85 | } 86 | 87 | /* Dispatch to the protocol. 
*/ 88 | 89 | EngineOptions args = {}; 90 | args.config = &c; 91 | args.party_id = *party_id; 92 | args.self_id = self_id; 93 | args.cluster = cluster; 94 | args.problem_name = argv[5]; 95 | 96 | const RegisteredProtocol& protocol = *prot_ptr; 97 | protocol(args); 98 | 99 | return EXIT_SUCCESS; 100 | } 101 | -------------------------------------------------------------------------------- /src/programs/matrix_vector_multiply.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::matrix_vector_multiply { 32 | template 33 | std::vector> local_matrix_vector_multiply(Integer* matrix_a, std::size_t num_rows_a, Integer* vector_x, std::size_t num_cols_a_len_x) { 34 | std::vector> result(num_rows_a); 35 | for (std::size_t row_a = 0; row_a != num_rows_a; row_a++) { 36 | result[row_a] = dot_product(&matrix_a[row_a * num_cols_a_len_x], vector_x, num_cols_a_len_x); 37 | } 38 | return result; 39 | } 40 | 41 | template 42 | void create_matrix_vector_multiply_circuit(const ProgramOptions& args) { 43 | std::uint64_t vector_size = args.problem_size; 44 | std::uint64_t matrix_dimension = vector_size; 45 | std::uint64_t matrix_size = matrix_dimension * matrix_dimension; 46 | 47 | /* Blocked vector provided by the evaluator. */ 48 | ShardedArray> vector_x(vector_size, args.worker_index, args.num_workers, Layout::Blocked); 49 | vector_x.for_each([=](std::size_t i, auto& elem) { 50 | elem.mark_input(Party::Evaluator); 51 | }); 52 | 53 | /* Blocked row-major matrix provided by the garbler. */ 54 | std::vector> my_matrix_a(vector_x.get_locals().size() * matrix_dimension); 55 | for (auto& elem : my_matrix_a) { 56 | elem.mark_input(Party::Garbler); 57 | } 58 | 59 | program_ptr->print_stats(); 60 | program_ptr->start_timer(); 61 | 62 | /* Reconstruct the entire vector x for each worker. */ 63 | std::vector> my_vector_x = vector_x.materialize_global_array(true); 64 | 65 | /* Multiply my portion of the matrix by the entire vector. 
*/ 66 | std::vector> result = local_matrix_vector_multiply(my_matrix_a.data(), my_matrix_a.size() / my_vector_x.size(), my_vector_x.data(), my_vector_x.size()); 67 | 68 | program_ptr->stop_timer(); 69 | program_ptr->print_stats(); 70 | 71 | for (std::size_t i = 0; i != result.size(); i++) { 72 | result[i].mark_output(); 73 | } 74 | } 75 | 76 | RegisterProgram matrix_vector_multiply("matrix_vector_multiply", "Matrix-Vector Multiply (problem_size = number of elements in one side of matrix)", create_matrix_vector_multiply_circuit<>); 77 | } 78 | -------------------------------------------------------------------------------- /src/memprog/annotation.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "memprog/annotation.hpp" 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include "addr.hpp" 31 | #include "instruction.hpp" 32 | #include "programfile.hpp" 33 | #include "platform/memory.hpp" 34 | #include "util/filebuffer.hpp" 35 | #include "util/progress.hpp" 36 | 37 | namespace mage::memprog { 38 | std::uint64_t annotate_program(util::BufferedFileWriter& output, std::string program, PageShift page_shift, util::ProgressBar* progress_bar) { 39 | VirtProgramReverseFileReader instructions(program); 40 | instructions.set_progress_bar(progress_bar); 41 | InstructionNumber inum = instructions.get_header().num_instructions; 42 | 43 | std::unordered_map next_access; 44 | std::uint64_t max_working_set_size = 0; 45 | 46 | std::array vpns; 47 | do { 48 | inum--; 49 | 50 | std::size_t current_size; 51 | PackedVirtInstruction& current = instructions.read_instruction(current_size); 52 | Annotation& ann = output.start_write(); 53 | ann.header.num_pages = current.store_page_numbers(vpns.data(), page_shift); 54 | for (std::uint16_t i = 0; i != ann.header.num_pages; i++) { 55 | /* Re-profile the code if you modify this inner loop. */ 56 | auto iter = next_access.find(vpns[i]); 57 | if (iter == next_access.end()) { 58 | next_access.insert(std::make_pair(vpns[i], inum)); 59 | ann.slots[i].next_use = invalid_instr; 60 | } else { 61 | ann.slots[i].next_use = iter->second; 62 | iter->second = inum; 63 | } 64 | } 65 | output.finish_write(ann.size()); 66 | max_working_set_size = std::max(max_working_set_size, next_access.size()); 67 | 68 | if ((current.header.flags & FlagOutputPageFirstUse) != 0) { 69 | /* 70 | * Instruction format must be NoArgs, OneArg, TwoArgs, 71 | * ThreeArgs, or Constant in order to get this flag. 
72 | */ 73 | next_access.erase(pg_num(current.no_args.output, page_shift)); 74 | } 75 | } while (inum != 0); 76 | 77 | return max_working_set_size; 78 | } 79 | 80 | std::uint64_t annotate_program(std::string annotations, std::string program, PageShift page_shift, util::ProgressBar* progress_bar) { 81 | util::BufferedFileWriter output(annotations.c_str()); 82 | return annotate_program(output, program, page_shift, progress_bar); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/programs/real_matrix_vector_multiply.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include "dsl/array.hpp" 24 | #include "dsl/integer.hpp" 25 | #include "dsl/parallel.hpp" 26 | #include "dsl/sort.hpp" 27 | #include "programs/registry.hpp" 28 | #include "programs/util.hpp" 29 | 30 | using namespace mage::dsl; 31 | 32 | namespace mage::programs::real_matrix_vector_multiply { 33 | /* Multiplies the row-major matrix matrix_a (num_rows_a rows, each num_cols_a_len_x elements long) by vector_x, returning one real_dot_product result per row. */ template 34 | std::vector> local_matrix_vector_multiply(LeveledBatch* matrix_a, std::size_t num_rows_a, LeveledBatch* vector_x, std::size_t num_cols_a_len_x) { 35 | std::vector> result(num_rows_a); 36 | /* Start at row 0, not at the template parameter "level", so that every row of the result is computed. */ for (std::size_t row_a = 0; row_a != num_rows_a; row_a++) { 37 | result[row_a] = real_dot_product(&matrix_a[row_a * num_cols_a_len_x], vector_x, num_cols_a_len_x); 38 | } 39 | return result; 40 | } 41 | 42 | template 43 | void create_real_matrix_vector_multiply_circuit(const ProgramOptions& args) { 44 | std::uint64_t vector_size = args.problem_size; 45 | std::uint64_t matrix_dimension = vector_size; 46 | std::uint64_t matrix_size = matrix_dimension * matrix_dimension; 47 | 48 | /* Blocked vector provided by the evaluator. */ 49 | ShardedArray> vector_x(vector_size, args.worker_index, args.num_workers, Layout::Blocked); 50 | vector_x.for_each([=](std::size_t i, auto& elem) { 51 | elem.mark_input(); 52 | }); 53 | 54 | /* Blocked row-major matrix provided by the garbler. */ 55 | std::vector> my_matrix_a(vector_x.get_locals().size() * matrix_dimension); 56 | for (auto& elem : my_matrix_a) { 57 | elem.mark_input(); 58 | } 59 | 60 | program_ptr->print_stats(); 61 | program_ptr->start_timer(); 62 | 63 | /* Reconstruct the entire vector x for each worker. */ 64 | std::vector> my_vector_x = vector_x.materialize_global_array(true); 65 | 66 | /* Multiply my portion of the matrix by the entire vector. 
*/ 67 | std::vector> result = local_matrix_vector_multiply(my_matrix_a.data(), my_matrix_a.size() / my_vector_x.size(), my_vector_x.data(), my_vector_x.size()); 68 | 69 | program_ptr->stop_timer(); 70 | program_ptr->print_stats(); 71 | 72 | for (std::size_t i = 0; i != result.size(); i++) { 73 | result[i].mark_output(); 74 | } 75 | } 76 | 77 | RegisterProgram real_matrix_vector_multiply("real_matrix_vector_multiply", "Matrix-Vector Multiply with real numbers (problem_size = number of elements in one side of matrix)", create_real_matrix_vector_multiply_circuit<0>); 78 | } 79 | -------------------------------------------------------------------------------- /src/protocols/registry.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #ifndef MAGE_PROTOCOLS_REGISTRY_HPP_ 23 | #define MAGE_PROTOCOLS_REGISTRY_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include "addr.hpp" 32 | #include "engine/cluster.hpp" 33 | #include "memprog/placement.hpp" 34 | #include "util/config.hpp" 35 | #include "util/registry.hpp" 36 | 37 | namespace mage::protocols { 38 | struct EngineOptions { 39 | util::Configuration* config; 40 | PartyID party_id; 41 | WorkerID self_id; 42 | std::shared_ptr cluster; 43 | std::string problem_name; 44 | }; 45 | 46 | extern std::vector evaluator_synonyms; 47 | extern std::vector garbler_synonyms; 48 | 49 | std::optional parse_party_id(const std::string& party); 50 | 51 | class RegisteredPlacementPlugin : public util::BaseRegistryEntry { 52 | friend class util::Register; 53 | 54 | public: 55 | memprog::PlacementPlugin get_placement_plugin() const { 56 | return this->p; 57 | } 58 | 59 | private: 60 | RegisteredPlacementPlugin(std::string name, std::string desc, memprog::PlacementPlugin plugin) 61 | : util::BaseRegistryEntry(name, desc), p(plugin) { 62 | } 63 | 64 | memprog::PlacementPlugin p; 65 | }; 66 | 67 | using RegisterPlacementPlugin = util::Register; 68 | 69 | class RegisteredProtocol : public util::CallableRegistryEntry { 70 | friend class util::Register; 71 | 72 | public: 73 | const std::string& get_placement_plugin_name() const { 74 | return this->plugin_name; 75 | } 76 | 77 | memprog::PlacementPlugin get_placement_plugin() const { 78 | const std::string& name = this->get_placement_plugin_name(); 79 | const RegisteredPlacementPlugin* plugin_ptr = util::Registry::look_up_by_name(name); 80 | if (plugin_ptr == nullptr) { 81 | std::cerr << "Misconfigured build: protocol \"" << this->get_label() << "\" requires placement plugin \"" << name << "\"" << std::endl; 82 | std::abort(); 83 | } 84 | return plugin_ptr->get_placement_plugin(); 85 | } 86 | 87 | private: 88 | RegisteredProtocol(std::string name, 
std::string desc, std::function driver, const std::string& plugin) 89 | : util::CallableRegistryEntry(name, desc, driver), plugin_name(plugin) { 90 | } 91 | 92 | std::string plugin_name; 93 | }; 94 | 95 | using RegisterProtocol = util::Register; 96 | } 97 | 98 | #endif 99 | -------------------------------------------------------------------------------- /src/programs/password.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::password { 32 | template 33 | struct UserPassword { 34 | Integer value; 35 | 36 | IntSlice get_user_id() { 37 | return this->value.template slice(pw_bits); 38 | } 39 | 40 | IntSlice get_pw_hash() { 41 | return this->value.template slice(0); 42 | } 43 | 44 | static void comparator(UserPassword& arg0, UserPassword& arg1) { 45 | Bit predicate = arg0.get_user_id() > arg1.get_user_id(); 46 | Integer::swap_if(predicate, arg0.value, arg1.value); 47 | } 48 | 49 | void buffer_send(WorkerID to) { 50 | this->value.buffer_send(to); 51 | } 52 | 53 | static void finish_send(WorkerID to) { 54 | Integer::finish_send(to); 55 | } 56 | 57 | void post_receive(WorkerID from) { 58 | this->value.post_receive(from); 59 | } 60 | 61 | static void finish_receive(WorkerID from) { 62 | Integer::finish_receive(from); 63 | } 64 | }; 65 | 66 | template 67 | void create_password_circuit(const ProgramOptions& args) { 68 | int input_array_length = args.problem_size * 2; 69 | 70 | ClusterUtils utils; 71 | utils.self_id = args.worker_index; 72 | utils.num_proc = args.num_workers; 73 | 74 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Cyclic); 75 | inputs.for_each([=](std::size_t i, auto& input) { 76 | input.value.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 77 | }); 78 | 79 | /* Skip verifying the inputs, since this is semi-honest MPC. */ 80 | 81 | /* Merge the two sorted arrays, sorted by user but not password. */ 82 | parallel_bitonic_sorter(inputs); 83 | 84 | /* Do the PSI. 
*/ 85 | Integer zero, output; 86 | zero.mutate_to_constant(0); 87 | 88 | inputs.for_each_pair([&](std::size_t i, auto& a, auto& b) { 89 | Bit equals = (a.value == b.value); 90 | output = Integer::select(equals, a.get_user_id(), zero); 91 | output.mark_output(); 92 | }); 93 | } 94 | 95 | RegisterProgram password("password", "Password reuse query from the Senate paper", create_password_circuit<>); 96 | } 97 | -------------------------------------------------------------------------------- /src/programs/matrix_multiply.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::matrix_multiply { 32 | template 33 | std::vector> local_naive_matrix_multiply(Integer* matrix_a, std::size_t num_rows_a, Integer* matrix_b, std::size_t num_cols_b, std::size_t num_cols_a_rows_b) { 34 | std::vector> result(num_rows_a * num_cols_b); 35 | for (std::size_t row_a = 0; row_a != num_rows_a; row_a++) { 36 | for (std::size_t col_b = 0; col_b != num_cols_b; col_b++) { 37 | /* This goes in result at row row_a and column col_b. */ 38 | std::size_t i = row_a * num_cols_b + col_b; 39 | result[i] = dot_product(&matrix_a[row_a * num_cols_a_rows_b], &matrix_b[col_b * num_cols_a_rows_b], num_cols_a_rows_b); 40 | } 41 | } 42 | return result; 43 | } 44 | 45 | template 46 | void create_matrix_multiply_circuit(const ProgramOptions& args) { 47 | int matrix_dimension = args.problem_size; 48 | int matrix_size = matrix_dimension * matrix_dimension; 49 | 50 | /* Blocked row-major matrix provided by the garbler. */ 51 | ShardedArray> matrix_a(matrix_size, args.worker_index, args.num_workers, Layout::Blocked); 52 | matrix_a.for_each([=](std::size_t i, auto& elem) { 53 | elem.mark_input(Party::Garbler); 54 | }); 55 | 56 | /* Blocked column-major matrix provided by the evaluator. 
*/ 57 | ShardedArray> matrix_b(matrix_size, args.worker_index, args.num_workers, Layout::Blocked); 58 | matrix_b.for_each([=](std::size_t i, auto& elem) { 59 | elem.mark_input(Party::Evaluator); 60 | }); 61 | 62 | program_ptr->print_stats(); 63 | program_ptr->start_timer(); 64 | 65 | ClusterUtils utils; 66 | utils.self_id = args.worker_index; 67 | utils.num_proc = args.num_workers; 68 | auto [ my_matrix_a, my_matrix_b ] = utils.cross_product(matrix_a, matrix_b); 69 | 70 | std::vector> result = local_naive_matrix_multiply(my_matrix_a.data(), my_matrix_a.size() / matrix_dimension, my_matrix_b.data(), my_matrix_b.size() / matrix_dimension, matrix_dimension); 71 | 72 | program_ptr->stop_timer(); 73 | program_ptr->print_stats(); 74 | 75 | for (std::size_t i = 0; i != result.size(); i++) { 76 | result[i].mark_output(); 77 | } 78 | } 79 | 80 | RegisterProgram matrix_multiply("matrix_multiply", "Matrix Multiply (problem_size = number of elements in one side of matrix)", create_matrix_multiply_circuit<>); 81 | } 82 | -------------------------------------------------------------------------------- /src/programs/aspirin_seq.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 
17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::aspirin_seq { 32 | template 33 | struct Input { 34 | Integer patient_id_concat_timestamp; 35 | Bit diagnosis; // or aspirin prescription 36 | 37 | static void comparator(Input& arg0, Input& arg1) { 38 | Bit predicate = arg0.patient_id_concat_timestamp > arg1.patient_id_concat_timestamp; 39 | Integer::swap_if(predicate, arg0.patient_id_concat_timestamp, arg1.patient_id_concat_timestamp); 40 | Bit::swap_if(predicate, arg0.diagnosis, arg1.diagnosis); 41 | } 42 | }; 43 | 44 | template 45 | void create_aspirin_circuit(const ProgramOptions& args) { 46 | int input_array_length = args.problem_size * 2; 47 | std::vector> inputs; 48 | 49 | for (int i = 0; i != input_array_length; i++) { 50 | inputs.emplace_back(); 51 | inputs[i].patient_id_concat_timestamp.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 52 | inputs[i].diagnosis.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 53 | } 54 | 55 | // Verify the input first. 
56 | Bit order(1); 57 | for (int i = 0; i < args.problem_size - 1; i++) { 58 | Bit lte = inputs[i].patient_id_concat_timestamp <= inputs[i+1].patient_id_concat_timestamp; 59 | order = order & lte; 60 | } 61 | for (int i = args.problem_size; i < 2 * args.problem_size - 1; i++) { 62 | Bit gte = inputs[i].patient_id_concat_timestamp >= inputs[i+1].patient_id_concat_timestamp; 63 | order = order & gte; 64 | } 65 | order.mark_output(); 66 | 67 | // Merge the two arrays, sorted ascending by patient_id_concat_timestamp 68 | bitonic_sorter(inputs.data(), input_array_length); 69 | 70 | // Now, for each input, check if it and the next input have the same patient, but the first is a diagnosis and the second isn't. 71 | Integer total(0); 72 | for (int i = 0; i < input_array_length - 1; i++) { 73 | Bit add = inputs[i].diagnosis & ~inputs[i+1].diagnosis; 74 | IntSlice patient_id_i = inputs[i].patient_id_concat_timestamp.template slice(timestamp_bits); 75 | IntSlice patient_id_ip1 = inputs[i+1].patient_id_concat_timestamp.template slice(timestamp_bits); 76 | add = add & (patient_id_i == patient_id_ip1); 77 | Integer next = total.increment(); 78 | total = Integer::select(add, next, total); 79 | } 80 | 81 | total.mark_output(); 82 | } 83 | 84 | RegisterProgram aspirin_seq("aspirin_seq", "Aspirin Count where each worker computes the whole thing", create_aspirin_circuit<>); 85 | } 86 | -------------------------------------------------------------------------------- /src/crypto/ot/correlated.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_CRYPTO_OT_CORRELATED_HPP_ 23 | #define MAGE_CRYPTO_OT_CORRELATED_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include "crypto/block.hpp" 29 | #include "crypto/ot/extension.hpp" 30 | #include "util/filebuffer.hpp" 31 | #include "util/userpipe.hpp" 32 | 33 | namespace mage::crypto::ot { 34 | class CorrelatedExtensionSender : public ExtensionSender { 35 | public: 36 | void send(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, block delta, block* first_choices, std::size_t num_choices); 37 | 38 | protected: 39 | void finish_send(block delta, block* first_choices, std::size_t num_choices, block* y, const block* qT); 40 | }; 41 | 42 | class CorrelatedExtensionChooser : public ExtensionChooser { 43 | public: 44 | void choose(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const block* choices, block* results, std::size_t num_choices); 45 | 46 | protected: 47 | void finish_choose(const block* choices, block* results, std::size_t num_choices, const block* y, const block* tT); 48 | }; 49 | 50 | class PipelinedCorrelatedExtSender : private CorrelatedExtensionSender { 51 | public: 52 | PipelinedCorrelatedExtSender(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, block choice_delta, util::UserPipe& results, std::size_t 
batch_size, std::size_t max_depth); 53 | ~PipelinedCorrelatedExtSender(); 54 | 55 | /* WARNING: Do not call this concurrently from multiple threads. */ 56 | void submit_send(); 57 | 58 | private: 59 | void start_daemon(); 60 | 61 | util::BufferedFileReader& net_in; 62 | util::BufferedFileWriter& net_out; 63 | std::size_t num_choices; 64 | std::size_t depth; 65 | 66 | std::size_t num_row_blocks; 67 | std::size_t num_blocks; 68 | std::unique_ptr> pipeline; 69 | 70 | block delta; 71 | 72 | util::UserPipe& output; 73 | std::thread daemon; 74 | }; 75 | 76 | class PipelinedCorrelatedExtChooser : private CorrelatedExtensionChooser { 77 | public: 78 | PipelinedCorrelatedExtChooser(util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, util::UserPipe& results, std::size_t batch_size, std::size_t max_depth); 79 | ~PipelinedCorrelatedExtChooser(); 80 | 81 | void start_daemon(); 82 | 83 | /* WARNING: Do not call this concurrently from multiple threads. */ 84 | void submit_choose(const block* choices); 85 | 86 | private: 87 | util::BufferedFileReader& net_in; 88 | util::BufferedFileWriter& net_out; 89 | std::size_t num_choices; 90 | std::size_t depth; 91 | 92 | std::size_t num_row_blocks; 93 | std::size_t num_blocks; 94 | std::unique_ptr> pipeline; 95 | 96 | util::UserPipe& output; 97 | std::thread daemon; 98 | }; 99 | } 100 | 101 | #endif 102 | -------------------------------------------------------------------------------- /src/dsl/util.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 
12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | /** 23 | * @file dsl/util.hpp 24 | * @brief Utility functions to use with MAGE's DSLs. 25 | */ 26 | 27 | #ifndef MAGE_DSL_UTIL_HPP_ 28 | #define MAGE_DSL_UTIL_HPP_ 29 | 30 | #include 31 | #include "instruction.hpp" 32 | #include "dsl/integer.hpp" 33 | #include "util/misc.hpp" 34 | 35 | namespace mage::dsl { 36 | /** 37 | * @brief Computes the sum of a set of Integers, increasing the width as 38 | * needed to prevent overflow. 39 | * 40 | * @tparam output_bits The width of the sum, in bits. 41 | * @tparam bits The width of the Integers to add, in bits. 42 | * @tparam sliced True if the Integers to add are sliced, otherwise false. 43 | * @tparam Placer Type of placement module used by the Integers. 44 | * @tparam p Double pointer to the program object used by the Integers. 45 | * @param elements Pointer to the array of Integers whose sum to compute. 46 | * @param num_elements Length of the array whose sum to compute. 47 | * @return The sum, at the specified width. 
48 | */ 49 | template ** p> 50 | Integer reduce(Integer* elements, std::size_t num_elements) { 51 | if constexpr (bits == output_bits) { // we need this to end the template recursion 52 | assert(num_elements == 1); 53 | Integer result; 54 | result.mutate(*elements); 55 | elements->recycle(); 56 | return result; 57 | } else { 58 | if (num_elements == 1) { 59 | Integer result; 60 | result.mutate(*elements); 61 | elements->recycle(); 62 | return result; 63 | } 64 | std::vector> partial_sums(util::ceil_div(num_elements, 2).first); 65 | for (std::size_t i = 0; i < num_elements; i += 2) { 66 | if (i + 1 < num_elements) { 67 | partial_sums[i >> 1] = elements[i].add_with_carry(elements[i + 1]); 68 | elements[i].recycle(); 69 | elements[i + 1].recycle(); 70 | } else { 71 | partial_sums[i >> 1].mutate(elements[i]); 72 | elements[i].recycle(); 73 | } 74 | } 75 | return reduce(partial_sums.data(), partial_sums.size()); 76 | } 77 | } 78 | 79 | /** 80 | * @brief Sends any buffered data to the destination workers, and blocks 81 | * until any outstanding receive operations complete. 82 | * 83 | * It is expected that all workers call this function concurrently. 84 | * 85 | * @param self_id The ID of the worker in whose program this is being 86 | * called. 87 | * @param num_proc The total number of workers. 88 | */ 89 | template 90 | void communication_barrier(WorkerID self_id, WorkerID num_proc) { 91 | for (WorkerID w = 0; w != num_proc; w++) { 92 | if (w != self_id) { 93 | T::finish_send(w); 94 | } 95 | } 96 | for (WorkerID w = 0; w != num_proc; w++) { 97 | if (w != self_id) { 98 | T::finish_receive(w); 99 | } 100 | } 101 | } 102 | } 103 | 104 | #endif 105 | -------------------------------------------------------------------------------- /src/platform/network.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 
5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | /** 23 | * @file platform/network.hpp 24 | * @brief System-level utilities for communication using the network. 25 | */ 26 | 27 | #ifndef MAGE_PLATFORM_NETWORK_HPP_ 28 | #define MAGE_PLATFORM_NETWORK_HPP_ 29 | 30 | #include 31 | #include 32 | 33 | namespace mage::platform { 34 | /** 35 | * @brief Describes a network-related error. 36 | */ 37 | enum class NetworkError : std::uint8_t { 38 | Success, 39 | ConnectionRefused, 40 | TimedOut, 41 | }; 42 | 43 | /** 44 | * @brief Listens for incoming TCP connections on the specified port and 45 | * accepts the specified number of connections. 46 | * 47 | * If an error occurs, then the process is aborted. 48 | * 49 | * @param port The port on which to listen for incoming connections, 50 | * provided as a string. 51 | * @param[out] into An array into which to write file descriptors for the 52 | * accepted connections. 53 | * @param count The number of incoming connections to accept. 54 | */ 55 | void network_accept(const char* port, int* into, std::uint32_t count = 1); 56 | 57 | /** 58 | * @brief Creates the specified number of TCP connections to the endpoint 59 | * at the specified hostname and port. 
60 | * 61 | * If @p err is not @p nullptr, then it is treated as an array of error 62 | * conditions where the element at a particular index is populated 63 | * according to which error, if any occurred when establishing the 64 | * connection at that index. If it is populated with an error condition 65 | * (i.e., with something other than @p Success), then that element of 66 | * the @p into array is left uninitialized. If an error occurs that cannot 67 | * be described by a @p NetworkError, or if an error occurs and @p err is 68 | * @p nullptr, then the process is aborted. 69 | * 70 | * @param host The hostname of the TCP endpoint to which to connect. 71 | * @param port The port of the TCP endpoint to which to connect. 72 | * @param[out] into An array into which to write file descriptors for the 73 | * resulting connections. 74 | * @param[out] err An array into which an error condition for each 75 | * connection is written. 76 | * @param count The number of connections to establish with the specified 77 | * TCP endpoint. 78 | */ 79 | void network_connect(const char* host, const char* port, int* into, NetworkError* err, std::uint32_t count = 1); 80 | 81 | /** 82 | * @brief Closes a file descriptor corresponding to a TCP connection, 83 | * shutting down the connection. 84 | * 85 | * If an error occurs, then the process is aborted. 86 | * 87 | * @param socket The file descriptor to close. 88 | */ 89 | void network_close(int socket); 90 | 91 | /** 92 | * @brief Opens a pipe, placing the output file descriptor and input file 93 | * descriptor, in that order, into the specified array. 94 | * 95 | * If an error occurs, then the process is aborted. 96 | * 97 | * @param[out] into The array into which to place the pipe's two file 98 | * descriptors. 99 | */ 100 | void pipe_open(int* into); 101 | 102 | /** 103 | * @brief Closes a file descriptor corresponding to a pipe. 104 | * 105 | * If an error occurs, then the process is aborted. 
106 | * 107 | * @param fd The file descriptor to close. 108 | */ 109 | void pipe_close(int fd); 110 | } 111 | 112 | #endif 113 | -------------------------------------------------------------------------------- /src/programs/real_statistics.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::real_statistics { 32 | struct Stats { 33 | LeveledBatch<2, true> sum; 34 | LeveledBatch<1, true> sum_squares; 35 | 36 | void buffer_send(WorkerID to) { 37 | this->sum.buffer_send(to); 38 | this->sum_squares.buffer_send(to); 39 | } 40 | 41 | static void finish_send(WorkerID to) { 42 | LeveledBatch<0, true>::finish_send(to); 43 | } 44 | 45 | void post_receive(WorkerID from) { 46 | this->sum.post_receive(from); 47 | this->sum_squares.post_receive(from); 48 | } 49 | 50 | static void finish_receive(WorkerID from) { 51 | LeveledBatch<0, true>::finish_receive(from); 52 | } 53 | }; 54 | 55 | void create_real_statistics_circuit(const ProgramOptions& args) { 56 | int input_array_length = args.problem_size; 57 | 58 | ClusterUtils utils; 59 | utils.self_id = args.worker_index; 60 | utils.num_proc = args.num_workers; 61 | 62 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Blocked); 63 | inputs.for_each([=](std::size_t i, auto& input) { 64 | input.mark_input(); 65 | }); 66 | 67 | program_ptr->print_stats(); 68 | program_ptr->start_timer(); 69 | 70 | std::vector>& locals = inputs.get_locals(); 71 | 72 | Stats local; 73 | if (locals.size() == 0) { 74 | local.sum = LeveledBatch<2, true>(0); 75 | local.sum_squares = LeveledBatch<1, true>(0); 76 | } else { 77 | LeveledBatch<2, false> temp = locals[0].multiply_without_normalizing(locals[0]); 78 | local.sum = std::move(locals[0]); 79 | for (std::size_t i = 1; i != locals.size(); i++) { 80 | temp = temp + locals[i].multiply_without_normalizing(locals[i]); 81 | local.sum = local.sum + locals[i]; 82 | } 83 | local.sum_squares = temp.renormalize(); 84 | } 85 | 86 | std::optional global_stats = utils.reduce_aggregates(0, 
local, [](Stats& a, Stats& b) -> Stats { 87 | Stats result; 88 | result.sum = a.sum + b.sum; 89 | result.sum_squares = a.sum_squares + b.sum_squares; 90 | return result; 91 | }); 92 | 93 | if (args.worker_index == 0) { 94 | LeveledBatch<1, true> mean = global_stats->sum * LeveledPlaintextBatch<2>(1 / static_cast(input_array_length)); 95 | LeveledBatch<0, true> mean_squares = global_stats->sum_squares * LeveledPlaintextBatch<1>(1 / static_cast(input_array_length)); 96 | LeveledBatch<0, true> variance = mean_squares - (mean * mean); 97 | 98 | program_ptr->stop_timer(); 99 | program_ptr->print_stats(); 100 | 101 | mean.mark_output(); 102 | variance.mark_output(); 103 | } else { 104 | program_ptr->stop_timer(); 105 | program_ptr->print_stats(); 106 | } 107 | } 108 | 109 | RegisterProgram real_statistics("real_statistics", "Compute mean and variance of real numbers (problem_size = number of elements)", create_real_statistics_circuit); 110 | } 111 | -------------------------------------------------------------------------------- /src/programs/util.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
// (end of the license header that starts on the preceding dump line) */

#ifndef MAGE_PROGRAMS_UTIL_HPP_
#define MAGE_PROGRAMS_UTIL_HPP_

#include "dsl/leveledbatch.hpp"
#include "dsl/integer.hpp"

using namespace mage::dsl;

/*
 * NOTE(review): angle-bracketed template parameter and argument lists that
 * start with an identifier were stripped from this dump (bare "template"
 * keywords, "Integer*" where a width argument is expected, "slice(0)" where
 * a width argument is expected, ...). The code is reproduced as it appears;
 * restore the lists from the repository before building. Usages elsewhere in
 * this dump (Integer<1>, IntSlice<1>, LeveledBatch<2, true>,
 * LeveledPlaintextBatch<2>) show how many arguments each alias takes.
 */
namespace mage::programs {
    /* Defaults that every alias below binds: the placer and the global program pointer. */
    using DefaultPlacer = memprog::BinnedPlacer;
    constexpr Program** default_program = &program_ptr;

    /* Alias of mage::dsl::Integer; used with one argument (e.g. Integer<1>). */
    template
    using Integer = mage::dsl::Integer;

    /* Second alias of mage::dsl::Integer; presumably the sliced variant -- TODO confirm. */
    template
    using IntSlice = mage::dsl::Integer;

    using Bit = Integer<1>;
    using BitSlice = IntSlice<1>;

    /* Alias of mage::dsl::LeveledBatch; used with two arguments (e.g. LeveledBatch<2, true>). */
    template
    using LeveledBatch = mage::dsl::LeveledBatch;

    /* Alias of mage::dsl::LeveledPlaintextBatch; used with one argument. */
    template
    using LeveledPlaintextBatch = mage::dsl::LeveledPlaintextBatch;

    /**
     * @brief Computes the dot product of two integer vectors.
     *
     * @param vector_a First vector.
     * @param vector_b Second vector.
     * @param length Number of elements in each vector; must not be zero.
     * @return The sum of the element-wise products, twice as wide as the
     * inputs.
     */
    template
    Integer<2 * width> dot_product(Integer* vector_a, Integer* vector_b, std::size_t length) {
        assert(length != 0);
        Integer<2 * width> total = vector_a[0] * vector_b[0];
        for (std::size_t i = 1; i != length; i++) {
            total = total + (vector_a[i] * vector_b[i]);
        }
        return total;
    }

    /**
     * @brief Computes the dot product of two vectors of leveled batches using
     * multiply_without_normalizing, leaving the renormalization to the caller.
     *
     * @param vector_a First vector.
     * @param vector_b Second vector.
     * @param length Number of elements in each vector; must not be zero.
     * @return The accumulated sum of products.
     */
    template
    LeveledBatch real_dot_product_not_normalized(LeveledBatch* vector_a, LeveledBatch* vector_b, std::size_t length) {
        assert(length != 0);
        LeveledBatch total = vector_a[0].multiply_without_normalizing(vector_b[0]);
        for (std::size_t i = 1; i != length; i++) {
            total = total + vector_a[i].multiply_without_normalizing(vector_b[i]);
        }
        return total;
    }

    /**
     * @brief Computes the dot product of two vectors of leveled batches and
     * renormalizes the result once at the end.
     *
     * @param vector_a First vector.
     * @param vector_b Second vector.
     * @param length Number of elements in each vector; must not be zero.
     * @return The renormalized dot product.
     */
    template
    LeveledBatch real_dot_product(LeveledBatch* vector_a, LeveledBatch* vector_b, std::size_t length) {
        LeveledBatch total = real_dot_product_not_normalized(vector_a, vector_b, length);
        return total.renormalize();
    }

    /**
     * @brief A key/payload record stored in a single integer.
     *
     * get_key() slices from bit 0 and get_record() slices from bit key_width,
     * so the key occupies the low-order bits of @p data.
     */
    template
    struct Record {
        Integer data;

        /* Slice of data starting at bit 0 (the key). */
        IntSlice get_key() {
            return this->data.template slice(0);
        }

        /* Slice of data starting at bit key_width (the payload). */
        IntSlice get_record() {
            return this->data.template slice(key_width);
        }

        /* Swaps the two records when arg0's key is greater than arg1's key. */
        static void comparator(Record& arg0, Record& arg1) {
            IntSlice key0 = arg0.get_key();
            IntSlice key1 = arg1.get_key();
            Bit predicate = key0 > key1;
            Integer::swap_if(predicate, arg0.data, arg1.data);
        }

        /* The four functions below forward to the corresponding Integer operations. */
        void buffer_send(WorkerID to) {
            this->data.buffer_send(to);
        }

        static void finish_send(WorkerID to) {
            Integer::finish_send(to);
        }

        void post_receive(WorkerID from) {
            this->data.post_receive(from);
        }

        static void finish_receive(WorkerID from) {
            Integer::finish_receive(from);
        }
    };
}

#endif
// (dump boundary: the next file's header starts here and continues on the following dump line)
// --------------------------------------------------------------------------------
// /src/memprog/annotation.hpp:
// --------------------------------------------------------------------------------
// /*
//  * Copyright (C) 2020 Sam Kumar
//  * Copyright (C) 2020 University of California, Berkeley
//  * All rights reserved.
//  *
//  * This file is part of MAGE.
//  *
//  * MAGE is free software: you can redistribute it and/or modify
//  * it under the terms of the GNU General Public License as published by
//  * the Free Software Foundation, either version 3 of the License, or
//  * (at your option) any later version.
//  *
//  * MAGE is distributed in the hope that it will be useful,
//  * but WITHOUT ANY WARRANTY; without even the implied warranty of
//  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//  * GNU General Public License for more details.
//  *
//  * You should have received a copy of the GNU General Public License
//  * along with MAGE. If not, see .
//  */
//
// /**
//  * @file memprog/annotation.hpp
//  * @brief Annotation reverse pass for MAGE's planner
//  *
//  * The annotation reverse pass is needed to apply Belady's theoretically
//  * optimal paging algorithm (MIN) in the Replacement phase.
28 | */ 29 | 30 | #ifndef MAGE_MEMPROG_ANNOTATION_HPP_ 31 | #define MAGE_MEMPROG_ANNOTATION_HPP_ 32 | 33 | #include 34 | #include 35 | #include "memprog/program.hpp" 36 | #include "util/filebuffer.hpp" 37 | 38 | namespace mage::memprog { 39 | /** 40 | * @brief Structure describing the encoding of annotations. 41 | * 42 | * An annotation describes, for page accessed by an instruction, the 43 | * position of the next instruction that acesses that page, or 44 | * @p invalid_instr if no future instruction accesses that page. 45 | */ 46 | struct Annotation { 47 | struct { 48 | std::uint16_t num_pages; 49 | } __attribute__((packed)) header; 50 | struct { 51 | InstructionNumber next_use : instruction_number_bits; 52 | } __attribute__((packed)) slots[5]; 53 | 54 | /** 55 | * @brief Computes the size of this annotation based on its header. 56 | * 57 | * This is useful when reading annotations from a file. 58 | * 59 | * @return The size of this annotation. 60 | */ 61 | std::uint16_t size() const { 62 | return sizeof(Annotation::header) + this->header.num_pages * sizeof(Annotation::slots[0]); 63 | } 64 | 65 | /** 66 | * @brief Computes the address of the next annotation in the sequence, 67 | * assuming that annotations are packed together sequentially in 68 | * memory. 69 | * 70 | * This is particularly useful when a file containing annotations is 71 | * mapped into memory. 72 | * 73 | * @return The address of the next annotation in the sequence. 74 | */ 75 | Annotation* next() { 76 | std::uint8_t* self = reinterpret_cast(this); 77 | return reinterpret_cast(self + this->size()); 78 | } 79 | 80 | /** 81 | * @brief Computes the address of the next annotation in the sequence, 82 | * assuming that annotations are packed together sequentially in 83 | * memory. 84 | * 85 | * This is particularly useful when a file containing annotations is 86 | * mapped into memory. 87 | * 88 | * @return The address of the next annotation in the sequence. 
89 | */ 90 | const Annotation* next() const { 91 | const std::uint8_t* self = reinterpret_cast(this); 92 | return reinterpret_cast(self + this->size()); 93 | } 94 | } __attribute__((packed)); 95 | 96 | /** 97 | * @brief Computes annotations for a virtual bytecode. 98 | * 99 | * This involves iterating over the virtual bytecode in reverse order. 100 | * 101 | * @param annotations The file name to which the annotations should be 102 | * written. 103 | * @param program The file name containing the virtual bytecode to read. 104 | * This sequence of instructions should be reverse-iterable (e.g., written 105 | * with a BufferedFileWriter with backwards_readable == true). 106 | * @param page_shift Base-2 logarithm of the page size. 107 | * @param progress_bar Progress bar to use to show progress, or nullptr if 108 | * none should be used. 109 | */ 110 | std::uint64_t annotate_program(std::string annotations, std::string program, PageShift page_shift, util::ProgressBar* progress_bar = nullptr); 111 | } 112 | 113 | #endif 114 | -------------------------------------------------------------------------------- /src/programs/binary_fc_layer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 
17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::binary_fc_layer { 32 | template 33 | Bit binary_dot_product(Integer* vector_a, Integer* vector_b, std::size_t num_batches) { 34 | assert(num_batches != 0); 35 | std::vector> xnors(num_batches); 36 | std::vector xnor_bits; 37 | for (std::size_t i = 0; i != num_batches; i++) { 38 | xnors[i] = ~(vector_a[i] ^ vector_b[i]); 39 | for (BitWidth j = 0; j != batch_size; j++) { 40 | xnor_bits.emplace_back(xnors[i][j]); 41 | } 42 | } 43 | 44 | Integer<31> popcount = reduce<31, 1, true, DefaultPlacer, default_program>(xnor_bits.data(), xnor_bits.size()); 45 | Integer<32> two_p(0); 46 | two_p.slice<31>(1).mutate(popcount); 47 | 48 | return two_p >= Integer<32>(xnor_bits.size()); 49 | } 50 | 51 | template 52 | std::vector local_binary_layer(Integer* matrix_a, std::size_t num_rows_a, Integer* vector_x, std::size_t num_cols_a_len_x) { 53 | std::vector result(num_rows_a); 54 | for (std::size_t row_a = 0; row_a != num_rows_a; row_a++) { 55 | result[row_a] = binary_dot_product(&matrix_a[row_a * num_cols_a_len_x], vector_x, num_cols_a_len_x); 56 | } 57 | return result; 58 | } 59 | 60 | template 61 | void create_binary_fc_layer_circuit(const ProgramOptions& args) { 62 | std::uint64_t vector_size = args.problem_size; 63 | std::uint64_t matrix_dimension = vector_size; 64 | std::uint64_t matrix_size = matrix_dimension * matrix_dimension; 65 | 66 | if (vector_size % batch_size != 0) { 67 | std::cerr << "Problem size must be a multiple of the batch size" << std::endl; 68 | return; 69 | } 70 | 71 | /* Blocked vector provided by the evaluator. 
*/ 72 | ShardedArray> vector_x(vector_size / batch_size, args.worker_index, args.num_workers, Layout::Blocked); 73 | vector_x.for_each([=](std::size_t i, auto& elem) { 74 | elem.mark_input(Party::Evaluator); 75 | }); 76 | 77 | /* Blocked row-major matrix provided by the garbler. */ 78 | std::uint64_t num_columns = matrix_dimension / batch_size; 79 | std::uint64_t num_rows = matrix_dimension / args.num_workers; 80 | if (args.worker_index < matrix_dimension % args.num_workers) { 81 | num_rows += 1; 82 | } 83 | std::vector> my_matrix_a(num_rows * num_columns); 84 | for (auto& elem : my_matrix_a) { 85 | elem.mark_input(Party::Garbler); 86 | } 87 | 88 | program_ptr->print_stats(); 89 | program_ptr->start_timer(); 90 | 91 | /* Reconstruct the entire vector x for each worker. */ 92 | std::vector> my_vector_x = vector_x.materialize_global_array(true); 93 | 94 | std::vector result = local_binary_layer(my_matrix_a.data(), my_matrix_a.size() / my_vector_x.size(), my_vector_x.data(), my_vector_x.size()); 95 | 96 | program_ptr->stop_timer(); 97 | program_ptr->print_stats(); 98 | 99 | for (std::size_t i = 0; i != result.size(); i++) { 100 | result[i].mark_output(); 101 | } 102 | } 103 | 104 | RegisterProgram binary_fc_layer("binary_fc_layer", "Binary Matrix-Vector Multiply (problem_size = number of elements in one side of matrix)", create_binary_fc_layer_circuit<>); 105 | } 106 | -------------------------------------------------------------------------------- /src/protocols/tfhe.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_PROTOCOLS_TFHE_HPP_ 23 | #define MAGE_PROTOCOLS_TFHE_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include "engine/cluster.hpp" 32 | #include "platform/network.hpp" 33 | #include "protocols/tfhe_scheme.hpp" 34 | #include "util/binaryfile.hpp" 35 | #include "util/userpipe.hpp" 36 | 37 | namespace mage::protocols::tfhe { 38 | class TFHEEngine { 39 | public: 40 | using Wire = TFHEScheme::Wire; 41 | 42 | TFHEEngine(const char* garbler_input_file, const char* evaluator_input_file, const char* output_file) 43 | : garbler_input_reader(garbler_input_file, std::ios::binary), evaluator_input_reader(evaluator_input_file, std::ios::binary), output_writer(output_file, std::ios::binary) { 44 | { 45 | std::ifstream params_file("params", std::ios::binary); 46 | if (!params_file.is_open()) { 47 | std::cerr << "Could not open params" << std::endl; 48 | std::abort(); 49 | } 50 | this->tfhe.set_params(params_file); 51 | } 52 | { 53 | std::ifstream cloud_key_file("cloud.key", std::ios::binary); 54 | if (!cloud_key_file.is_open()) { 55 | std::cerr << "Could not open cloud.key" << std::endl; 56 | std::abort(); 57 | } 58 | this->tfhe.set_cloud_key(cloud_key_file); 59 | } 60 | } 61 | 62 | void print_stats() { 63 | } 64 | 65 | void input(Wire* data, unsigned int length, bool 
garbler) { 66 | std::ifstream* reader_ptr = garbler ? &this->garbler_input_reader : &this->evaluator_input_reader; 67 | reader_ptr->read(reinterpret_cast(data), length * sizeof(Wire)); 68 | if (reader_ptr->eof()) { 69 | std::cerr << "TFHE::input -> std::ifstream::read: end of file" << std::endl; 70 | std::abort(); 71 | } else if (reader_ptr->fail() || reader_ptr->bad()) { 72 | std::cerr << "TFHE::input -> std::ifstream::read: failure" << std::endl; 73 | std::abort(); 74 | } 75 | } 76 | 77 | void output(const Wire* data, unsigned int length) { 78 | this->output_writer.write(reinterpret_cast(data), length * sizeof(Wire)); 79 | if (this->output_writer.fail() || this->output_writer.bad()) { 80 | std::cerr << "TFHE::output -> std::ofstream::write: failure" << std::endl; 81 | std::abort(); 82 | } 83 | } 84 | 85 | void op_and(Wire& output, const Wire& input1, const Wire& input2) { 86 | this->tfhe.op_and(output, input1, input2); 87 | } 88 | 89 | void op_xor(Wire& output, const Wire& input1, const Wire& input2) { 90 | this->tfhe.op_xor(output, input1, input2); 91 | } 92 | 93 | void op_not(Wire& output, const Wire& input) { 94 | this->tfhe.op_not(output, input); 95 | } 96 | 97 | void op_xnor(Wire& output, const Wire& input1, const Wire& input2) { 98 | this->tfhe.op_xnor(output, input1, input2); 99 | } 100 | 101 | void op_copy(Wire& output, const Wire& input) { 102 | this->tfhe.op_copy(output, input); 103 | } 104 | 105 | void one(Wire& output) { 106 | this->tfhe.one(output); 107 | } 108 | 109 | void zero(Wire& output) { 110 | this->tfhe.zero(output); 111 | } 112 | 113 | private: 114 | TFHEScheme tfhe; 115 | 116 | std::ifstream garbler_input_reader; 117 | std::ifstream evaluator_input_reader; 118 | std::ofstream output_writer; 119 | }; 120 | } 121 | 122 | #endif 123 | -------------------------------------------------------------------------------- /src/util/stats.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 
// 2020 Sam Kumar
//  * Copyright (C) 2020 University of California, Berkeley
//  * All rights reserved.
//  *
//  * This file is part of MAGE.
//  *
//  * MAGE is free software: you can redistribute it and/or modify
//  * it under the terms of the GNU General Public License as published by
//  * the Free Software Foundation, either version 3 of the License, or
//  * (at your option) any later version.
//  *
//  * MAGE is distributed in the hope that it will be useful,
//  * but WITHOUT ANY WARRANTY; without even the implied warranty of
//  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//  * GNU General Public License for more details.
//  *
//  * You should have received a copy of the GNU General Public License
//  * along with MAGE. If not, see .
//  */

/**
 * @file util/stats.hpp
 * @brief Utilities for statistics collection.
 */

#ifndef MAGE_UTIL_STATS_HPP_
#define MAGE_UTIL_STATS_HPP_

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

namespace mage::util {
    /**
     * @brief Records statistics (min, mean, max, sum, count) of events.
     *
     * This is typically used to measure the latency of a type of event, both
     * of individual events and in aggregate, to understand its performance
     * impact.
     */
    class StreamStats {
        friend std::ostream& operator <<(std::ostream& out, const StreamStats& s);

    public:
        /**
         * @brief Creates a StreamStats object with the label ""
         * that will not automatically print out the measured statistics when
         * its destructor is called.
         *
         * NOTE(review): the label literal may have lost angle-bracketed text
         * in this dump; it is reproduced here exactly as it appears.
         */
        StreamStats() : StreamStats("") {
        }

        /**
         * @brief Creates a StreamStats object.
         *
         * @param name The label for this statistics collector, used when
         * printing out the measured statistics.
         * @param print_stats_on_exit If true, the statistics will be printed
         * when the destructor for this StreamStats is called.
         */
        StreamStats(std::string name, bool print_stats_on_exit = false)
            /* Initializers are listed in member declaration order, which is the order in which they run. */
            : stat_max(0), stat_sum(0), stat_min(0), stat_count(0),
              label(std::move(name)), print_on_exit(print_stats_on_exit) {
        }

        /**
         * @brief Prints out the measured statistics to standard output if
         * print_stats_on_exit was specified.
         */
        ~StreamStats() {
            if (this->print_on_exit) {
                std::cout << *this << std::endl;
            }
        }

        /**
         * @brief Set the label for this statistics collector, used when
         * printing out the measured statistics.
         *
         * Note that, unlike in the constructor, @p print_stats_on_exit
         * defaults to true here.
         *
         * @param label The label for this statistics collector.
         * @param print_stats_on_exit If true, the statistics will be printed
         * when the destructor for this StreamStats is called.
         */
        void set_label(const std::string& label, bool print_stats_on_exit = true) {
            this->label = label;
            this->print_on_exit = print_stats_on_exit;
        }

        /**
         * @brief Record an event.
         *
         * @param stat The value for the event (typically a latency
         * measurement), used to compute the min, mean, max, and sum.
         */
        void event(std::uint64_t stat) {
            if (this->stat_count == 0) {
                /* First event: it is the minimum, the maximum and the sum. */
                this->stat_max = stat;
                this->stat_sum = stat;
                this->stat_min = stat;
                this->stat_count = 1;
            } else {
                this->stat_max = std::max(this->stat_max, stat);
                this->stat_sum += stat;
                this->stat_min = std::min(this->stat_min, stat);
                this->stat_count++;
            }
        }

    private:
        std::uint64_t stat_max;
        std::uint64_t stat_sum;
        std::uint64_t stat_min;
        std::uint64_t stat_count;

        std::string label;
        bool print_on_exit;
    };

    /**
     * @brief Prints out the measured statistics for a StreamStats object.
     *
     * The average is an integer division of sum by count, and is printed as 0
     * when no event has been recorded.
     */
    inline std::ostream& operator <<(std::ostream& out, const StreamStats& s) {
        const std::uint64_t average = (s.stat_count == 0) ? 0 : s.stat_sum / s.stat_count;
        return out << s.label << ": ( min = " << s.stat_min << ", avg = " << average << ", max = " << s.stat_max << ", count = " << s.stat_count << ", sum = " << s.stat_sum << " )";
    }
}

#endif
// (dump boundary: the next file's header starts here and continues on the following dump line)
// --------------------------------------------------------------------------------
// /src/util/progress.cpp:
// --------------------------------------------------------------------------------
// /*
//  * Copyright (C) 2021 Sam Kumar
//  * Copyright (C) 2021 University of California, Berkeley
//  * All rights reserved.
//  *
//  * This file is part of MAGE.
//  *
//  * MAGE is free software: you can redistribute it and/or modify
//  * it under the terms of the GNU General Public License as published by
//  * the Free Software Foundation, either version 3 of the License, or
//  * (at your option) any later version.
//  *
//  * MAGE is distributed in the hope that it will be useful,
//  * but WITHOUT ANY WARRANTY; without even the implied warranty of
//  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//  * GNU General Public License for more details.
//  *
//  * You should have received a copy of the GNU General Public License
//  * along with MAGE. If not, see .
20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include "platform/misc.hpp" 30 | #include "util/misc.hpp" 31 | #include "util/progress.hpp" 32 | 33 | namespace mage::util { 34 | ProgressBar::ProgressBar(const std::string& label, std::uint64_t total_units) { 35 | this->reset(label, total_units); 36 | } 37 | 38 | void ProgressBar::set_label(const std::string& label) { 39 | this->display_name = label; 40 | this->current_width = 0; 41 | } 42 | 43 | void ProgressBar::reset(const std::string& label, std::uint64_t total_units) { 44 | this->set_label(label); 45 | this->reset(total_units); 46 | } 47 | 48 | void ProgressBar::reset(std::uint64_t total_units) { 49 | this->update_threshold = total_units; 50 | this->next_update = 0; 51 | this->current_count = 0; 52 | this->total_count = total_units; 53 | } 54 | 55 | void ProgressBar::erase() const { 56 | if (this->current_width != 0) { 57 | std::cout << '\r'; 58 | } 59 | } 60 | 61 | void ProgressBar::display() const { 62 | if (this->current_width != 0) { 63 | std::cout << this->bar << std::flush; 64 | } 65 | } 66 | 67 | void ProgressBar::finish(bool fill) { 68 | if (this->current_width != 0) { 69 | if (fill) { 70 | this->refresh(this->total_count); 71 | } 72 | std::cout << std::endl; 73 | } 74 | } 75 | 76 | void ProgressBar::update(std::uint64_t num_units) { 77 | std::uint32_t percentage = (100 * num_units) / this->total_count; 78 | char* percent_start = this->get_percent_start(); 79 | int written = std::snprintf(percent_start, 4, "%3" PRIu32, percentage); 80 | percent_start[written] = '%'; 81 | 82 | char* bar_start = this->get_bar_start(); 83 | std::uint32_t bar_length = (this->bar_capacity * num_units) / this->total_count; 84 | bar_length = std::min(bar_length, this->bar_capacity); 85 | std::uint32_t i; 86 | for (i = 0; i != bar_length; i++) { 87 | bar_start[i] = ProgressBar::bar_full; 88 | } 89 | for (; i != this->bar_capacity; i++) { 90 | bar_start[i] = 
ProgressBar::bar_empty; 91 | } 92 | } 93 | 94 | bool ProgressBar::reconstruct_bar_if_necessary() { 95 | platform::TerminalSize ts; 96 | platform::get_terminal_size(ts); 97 | if (ts.num_cols != this->current_width) { 98 | this->current_width = ts.num_cols; 99 | this->construct_bar(); 100 | } 101 | return this->current_width != 0; 102 | } 103 | 104 | void ProgressBar::construct_bar() { 105 | constexpr const char* preamble = ": [ 0%] ["; 106 | 107 | /* 1 is for \r at the beginning. */ 108 | this->bar_start = 1 + this->display_name.length() + std::strlen(preamble); 109 | 110 | /* +1 for the ']' at the end, but -1 for the \r (which doesn't actually take up space.) */ 111 | std::uint32_t bar_space = this->bar_start + 1 - 1; 112 | if (bar_space > this->current_width) { 113 | this->bar_capacity = 0; 114 | } else { 115 | this->bar_capacity = this->current_width - bar_space; 116 | } 117 | 118 | std::ostringstream buffer; 119 | buffer << '\r' << this->display_name << preamble; 120 | for (std::uint32_t i = 0; i != this->bar_capacity; i++) { 121 | buffer << ProgressBar::bar_empty; 122 | } 123 | buffer << ']'; 124 | this->bar = buffer.str(); 125 | 126 | this->update_threshold = this->total_count / std::max(this->bar_capacity, UINT32_C(100)); 127 | if (this->update_threshold == 0) { 128 | this->update_threshold = 1; 129 | } 130 | } 131 | 132 | char* ProgressBar::get_bar_start() { 133 | return &this->bar[this->bar_start]; 134 | } 135 | 136 | char* ProgressBar::get_percent_start() { 137 | return &this->bar[this->bar_start - 7]; 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/executables/planner.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 
7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include "addr.hpp" 27 | #include "memprog/pipeline.hpp" 28 | #include "programs/registry.hpp" 29 | #include "protocols/registry.hpp" 30 | #include "util/config.hpp" 31 | 32 | using mage::programs::ProgramOptions; 33 | using mage::programs::RegisteredProgram; 34 | using mage::protocols::RegisteredProtocol; 35 | using mage::protocols::RegisteredPlacementPlugin; 36 | using mage::util::Registry; 37 | 38 | int main(int argc, char** argv) { 39 | if (argc != 7) { 40 | std::cerr << "Usage: " << argv[0] << " program_name protocol/plugin config.yaml party_id worker_index input_size" << std::endl; 41 | Registry::print_all("programs", std::cerr); 42 | return EXIT_FAILURE; 43 | } 44 | 45 | std::string program_name(argv[1]); 46 | const RegisteredProgram* prog = Registry::look_up_by_name(program_name); 47 | if (prog == nullptr) { 48 | std::cerr << program_name << " is not a valid program name. 
"; // lack of std::endl is intentional 49 | Registry::print_all("programs", std::cerr); 50 | return EXIT_FAILURE; 51 | } 52 | 53 | std::string protocol(argv[2]); 54 | std::string plugin_name; 55 | mage::memprog::PlacementPlugin plugin; 56 | const RegisteredProtocol* prot = Registry::look_up_by_name(protocol); 57 | if (prot == nullptr) { 58 | const RegisteredPlacementPlugin* plug = Registry::look_up_by_name(protocol); 59 | if (plug == nullptr) { 60 | std::cerr << protocol << " is not a valid protocol name or plugin name. "; // lack of std::endl is intentional 61 | Registry::print_all("protocols", std::cerr); 62 | Registry::print_all("plugins", std::cerr); 63 | return EXIT_FAILURE; 64 | } 65 | plugin = plug->get_placement_plugin(); 66 | } else { 67 | plugin = prot->get_placement_plugin(); 68 | } 69 | 70 | mage::util::Configuration c(argv[3]); 71 | 72 | std::optional party_id = mage::protocols::parse_party_id(argv[4]); 73 | if (!party_id.has_value()) { 74 | std::cerr << "Invalid party_id (try \"garbler\", \"evaluator\", or an integer)" << std::endl; 75 | return EXIT_FAILURE; 76 | } 77 | 78 | mage::WorkerID num_workers = c["parties"][*party_id]["workers"].get_size(); 79 | 80 | errno = 0; 81 | mage::WorkerID index = std::strtoull(argv[5], nullptr, 10); 82 | if (errno != 0) { 83 | std::perror("Fourth argument (index)"); 84 | return EXIT_FAILURE; 85 | } 86 | if (index >= num_workers) { 87 | std::cerr << "Worker index is " << index << " but there are only " << num_workers << " workers" << std::endl; 88 | return EXIT_FAILURE; 89 | } 90 | 91 | errno = 0; 92 | std::uint64_t problem_size = std::strtoull(argv[6], nullptr, 10); 93 | if (errno != 0 || problem_size == 0) { 94 | std::cerr << "Bad fifth argument (input size)" << std::endl; 95 | return 1; 96 | } 97 | 98 | const mage::util::ConfigValue& w = c["parties"][*party_id]["workers"][index]; 99 | 100 | ProgramOptions args = {}; 101 | args.worker_config = &w; 102 | args.num_workers = num_workers; 103 | args.worker_index = 
index; 104 | args.problem_size = problem_size; 105 | 106 | std::string problem_name = program_name + "_" + std::to_string(problem_size) + "_" + std::to_string(index); 107 | 108 | mage::memprog::DefaultPipeline planner(problem_name, w); 109 | planner.set_verbose(true); 110 | planner.plan(&mage::programs::program_ptr, prot->get_placement_plugin(), [prog, &args]() { 111 | (*prog)(args); 112 | }); 113 | 114 | std::cout << std::endl; 115 | 116 | const mage::memprog::DefaultPipelineStats& stats = planner.get_stats(); 117 | 118 | std::cout << "Phase Times (ms): " << stats.placement_duration.count() << " " 119 | << stats.replacement_duration.count() << " " << stats.scheduling_duration.count() << std::endl; 120 | 121 | return EXIT_SUCCESS; 122 | } 123 | -------------------------------------------------------------------------------- /src/programs/aspirin.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "dsl/array.hpp" 23 | #include "dsl/integer.hpp" 24 | #include "dsl/parallel.hpp" 25 | #include "dsl/sort.hpp" 26 | #include "programs/registry.hpp" 27 | #include "programs/util.hpp" 28 | 29 | using namespace mage::dsl; 30 | 31 | namespace mage::programs::aspirin { 32 | template 33 | struct Input { 34 | Integer patient_id_concat_timestamp; 35 | Bit diagnosis; // or aspirin prescription 36 | 37 | static void comparator(Input& arg0, Input& arg1) { 38 | Bit predicate = arg0.patient_id_concat_timestamp > arg1.patient_id_concat_timestamp; 39 | Integer::swap_if(predicate, arg0.patient_id_concat_timestamp, arg1.patient_id_concat_timestamp); 40 | Bit::swap_if(predicate, arg0.diagnosis, arg1.diagnosis); 41 | } 42 | 43 | void buffer_send(WorkerID to) { 44 | this->patient_id_concat_timestamp.buffer_send(to); 45 | this->diagnosis.buffer_send(to); 46 | } 47 | 48 | static void finish_send(WorkerID to) { 49 | Integer::finish_send(to); 50 | } 51 | 52 | void post_receive(WorkerID from) { 53 | this->patient_id_concat_timestamp.post_receive(from); 54 | this->diagnosis.post_receive(from); 55 | } 56 | 57 | static void finish_receive(WorkerID from) { 58 | Integer::finish_receive(from); 59 | } 60 | }; 61 | 62 | template 63 | void create_parallel_aspirin_circuit(const ProgramOptions& args) { 64 | int input_array_length = args.problem_size * 2; 65 | 66 | ClusterUtils utils; 67 | utils.self_id = args.worker_index; 68 | utils.num_proc = args.num_workers; 69 | 70 | ShardedArray> inputs(input_array_length, args.worker_index, args.num_workers, Layout::Cyclic); 71 | inputs.for_each([=](std::size_t i, auto& input) { 72 | input.patient_id_concat_timestamp.mark_input(i < args.problem_size ? Party::Garbler : Party::Evaluator); 73 | input.diagnosis.mark_input(i < args.problem_size ? 
Party::Garbler : Party::Evaluator); 74 | }); 75 | 76 | // Verify that inputs are sorted 77 | 78 | Bit local_order(1); 79 | inputs.for_each_pair([&](std::size_t i, auto& first, auto& second) { 80 | if (i < args.problem_size - 1) { 81 | Bit lte = first.patient_id_concat_timestamp <= second.patient_id_concat_timestamp; 82 | local_order = local_order & lte; 83 | } else if (i >= args.problem_size) { 84 | Bit gte = first.patient_id_concat_timestamp >= second.patient_id_concat_timestamp; 85 | local_order = local_order & gte; 86 | } 87 | }); 88 | std::optional order = utils.reduce_aggregates(0, local_order, [](Bit& first, Bit& second) -> Bit { 89 | return first & second; 90 | }); 91 | if (args.worker_index == 0) { 92 | order.value().mark_output(); 93 | } 94 | 95 | // Sort inputs and switch to blocked layout 96 | parallel_bitonic_sorter(inputs); 97 | 98 | Integer local_total(0); 99 | inputs.for_each_pair([&local_total](std::size_t index, Input& first, Input& second) { 100 | Bit add = first.diagnosis & ~second.diagnosis; 101 | IntSlice patient_id_i = first.patient_id_concat_timestamp.template slice(timestamp_bits); 102 | IntSlice patient_id_ip1 = second.patient_id_concat_timestamp.template slice(timestamp_bits); 103 | add = add & (patient_id_i == patient_id_ip1); 104 | Integer next = local_total.increment(); 105 | local_total = Integer::select(add, next, local_total); 106 | }); 107 | 108 | std::optional> total = utils.reduce_aggregates>(0, local_total, [](Integer& first, Integer& second) -> Integer { 109 | return first + second; 110 | }); 111 | if (args.worker_index == 0) { 112 | total.value().mark_output(); 113 | } 114 | } 115 | 116 | RegisterProgram aspirin("aspirin", "Aspirin Count (problem_size = number of events per party)", create_parallel_aspirin_circuit<>); 117 | } 118 | -------------------------------------------------------------------------------- /src/crypto/group.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * 
This file is based heavily on the file utils/group_openssl.h in EMP-toolkit. 3 | */ 4 | 5 | #include "crypto/group.hpp" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace mage::crypto { 17 | static inline void openssl_fail() { 18 | ERR_print_errors_fp(stderr); 19 | std::abort(); 20 | } 21 | 22 | DDHGroup::DDHGroup() { 23 | this->ec_group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1); 24 | assert(this->ec_group != nullptr); 25 | 26 | this->bn_ctx = BN_CTX_new(); 27 | assert(this->bn_ctx != nullptr); 28 | 29 | EC_GROUP_precompute_mult(this->ec_group, this->bn_ctx); 30 | 31 | this->order = BN_new(); 32 | assert(this->order != nullptr); 33 | 34 | EC_GROUP_get_order(this->ec_group, this->order, this->bn_ctx); 35 | } 36 | 37 | DDHGroup::~DDHGroup() { 38 | assert(this->ec_group != nullptr); 39 | EC_GROUP_free(this->ec_group); 40 | 41 | assert(this->order != nullptr); 42 | BN_free(this->order); 43 | 44 | assert(this->bn_ctx != nullptr); 45 | BN_CTX_free(this->bn_ctx); 46 | } 47 | 48 | ScalarMod::ScalarMod(const DDHGroup& g) : group(g) { 49 | this->n = BN_new(); 50 | assert(this->n != nullptr); 51 | } 52 | 53 | ScalarMod::ScalarMod(const ScalarMod& other) : ScalarMod(other.group) { 54 | BIGNUM* rv = BN_copy(this->n, other.n); 55 | if (rv == nullptr) { 56 | openssl_fail(); 57 | } 58 | } 59 | 60 | ScalarMod::~ScalarMod() { 61 | assert(this->n != nullptr); 62 | BN_free(this->n); 63 | } 64 | 65 | void ScalarMod::set_random() { 66 | int rv = BN_rand_range(this->n, this->group.order); 67 | if (rv != 1) { 68 | openssl_fail(); 69 | } 70 | } 71 | 72 | void ScalarMod::multiply(const ScalarMod& a, const ScalarMod& b) { 73 | int rv = BN_mod_mul(this->n, a.n, b.n, this->group.order, this->group.bn_ctx); 74 | if (rv != 1) { 75 | openssl_fail(); 76 | } 77 | } 78 | 79 | DDHGroupElement::DDHGroupElement(const DDHGroup& g) : group(g) { 80 | this->point = EC_POINT_new(this->group.ec_group); 81 | 
assert(this->point != nullptr); 82 | } 83 | 84 | DDHGroupElement::DDHGroupElement(const DDHGroupElement& other) : DDHGroupElement(other.group) { 85 | int rv = EC_POINT_copy(this->point, other.point); 86 | if (rv != 1) { 87 | openssl_fail(); 88 | } 89 | } 90 | 91 | DDHGroupElement::~DDHGroupElement() { 92 | assert(this->point != nullptr); 93 | EC_POINT_free(this->point); 94 | } 95 | 96 | void DDHGroupElement::marshal_uncompressed(std::uint8_t* buffer, std::size_t length) const { 97 | int rv = EC_POINT_point2oct(this->group.ec_group, this->point, POINT_CONVERSION_UNCOMPRESSED, buffer, length, this->group.bn_ctx); 98 | assert(((std::size_t) rv) <= length); 99 | (void) rv; 100 | } 101 | 102 | std::size_t DDHGroupElement::marshalled_uncompressed_size() const { 103 | int rv = EC_POINT_point2oct(this->group.ec_group, this->point, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, this->group.bn_ctx); 104 | assert(rv != 0); 105 | return (std::size_t) rv; 106 | } 107 | 108 | void DDHGroupElement::unmarshal_uncompressed(const std::uint8_t* buffer, std::size_t length) { 109 | int rv = EC_POINT_oct2point(this->group.ec_group, this->point, buffer, length, this->group.bn_ctx); 110 | if (rv != 1) { 111 | openssl_fail(); 112 | } 113 | } 114 | 115 | void DDHGroupElement::set_generator() { 116 | int rv = EC_POINT_copy(this->point, EC_GROUP_get0_generator(this->group.ec_group)); 117 | if (rv != 1) { 118 | openssl_fail(); 119 | } 120 | } 121 | 122 | void DDHGroupElement::add(const DDHGroupElement& a, const DDHGroupElement& __restrict b) { 123 | int rv = EC_POINT_add(this->group.ec_group, this->point, a.point, b.point, this->group.bn_ctx); 124 | if (rv != 1) { 125 | openssl_fail(); 126 | } 127 | } 128 | 129 | void DDHGroupElement::multiply_generator(const ScalarMod& __restrict m) { 130 | int rv = EC_POINT_mul(this->group.ec_group, this->point, m.n, nullptr, nullptr, this->group.bn_ctx); 131 | if (rv != 1) { 132 | openssl_fail(); 133 | } 134 | } 135 | 136 | void 
DDHGroupElement::multiply_restrict(const DDHGroupElement& __restrict base, const ScalarMod& __restrict m) { 137 | int rv = EC_POINT_mul(this->group.ec_group, this->point, nullptr, base.point, m.n, this->group.bn_ctx); 138 | if (rv != 1) { 139 | openssl_fail(); 140 | } 141 | } 142 | 143 | void DDHGroupElement::invert() { 144 | int rv = EC_POINT_invert(this->group.ec_group, this->point, this->group.bn_ctx); 145 | if (rv != 1) { 146 | openssl_fail(); 147 | } 148 | } 149 | 150 | bool DDHGroupElement::operator ==(const DDHGroupElement& other) { 151 | int rv = EC_POINT_cmp(this->group.ec_group, this->point, other.point, this->group.bn_ctx); 152 | if (rv == -1) { 153 | openssl_fail(); 154 | } 155 | return rv == 0; 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /src/platform/network.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include "platform/network.hpp" 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | #include 35 | 36 | namespace mage::platform { 37 | void network_accept(const char* port, int* into, std::uint32_t count) { 38 | struct addrinfo hints = { 0 }; 39 | hints.ai_flags = AI_PASSIVE; 40 | hints.ai_family = AF_INET; 41 | hints.ai_socktype = SOCK_STREAM; 42 | 43 | struct addrinfo* info; 44 | int rv = getaddrinfo(NULL, port, &hints, &info); 45 | if (rv != 0) { 46 | std::cerr << "network_accept -> getaddrinfo: " << gai_strerror(rv) << std::endl; 47 | std::abort(); 48 | } 49 | 50 | int server_socket = socket(info->ai_family, info->ai_socktype, info->ai_protocol); 51 | if (server_socket == -1) { 52 | std::perror("network_accept -> socket"); 53 | std::abort(); 54 | } 55 | 56 | 57 | if (bind(server_socket, info->ai_addr, info->ai_addrlen) == -1) { 58 | std::perror("network_accept -> bind"); 59 | std::abort(); 60 | } 61 | 62 | freeaddrinfo(info); 63 | 64 | if (listen(server_socket, 0) == -1) { 65 | std::perror("network_accept -> listen"); 66 | std::abort(); 67 | } 68 | 69 | for (std::uint32_t i = 0; i != count; i++) { 70 | into[i] = accept(server_socket, NULL, NULL); 71 | if (into[i] == -1) { 72 | std::perror("network_accept -> accept"); 73 | std::abort(); 74 | } 75 | 76 | /* Maintain firewall state through idle periods. 
*/ 77 | int keepalive = 1; 78 | if (setsockopt(into[i], SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) == -1) { 79 | std::perror("network_connect -> setsockopt"); 80 | std::abort(); 81 | } 82 | } 83 | 84 | if (close(server_socket) == -1) { 85 | std::perror("network_accept -> close"); 86 | std::abort(); 87 | } 88 | } 89 | 90 | void network_connect(const char* host, const char* port, int* into, NetworkError* err, std::uint32_t count) { 91 | struct addrinfo hints = { 0 }; 92 | hints.ai_family = AF_INET; 93 | hints.ai_socktype = SOCK_STREAM; 94 | 95 | struct addrinfo* info; 96 | int rv = getaddrinfo(host, port, &hints, &info); 97 | if (rv != 0) { 98 | std::cerr << "network_connect -> getaddrinfo: " << gai_strerror(rv) << std::endl; 99 | std::abort(); 100 | } 101 | 102 | for (std::uint32_t i = 0; i != count; i++) { 103 | into[i] = socket(info->ai_family, info->ai_socktype, info->ai_protocol); 104 | if (into[i] == -1) { 105 | std::perror("network_connect -> socket"); 106 | std::abort(); 107 | } 108 | 109 | /* Maintain firewall state through idle periods. 
*/ 110 | int keepalive = 1; 111 | if (setsockopt(into[i], SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof(keepalive)) == -1) { 112 | std::perror("network_connect -> setsockopt"); 113 | std::abort(); 114 | } 115 | 116 | if (connect(into[i], info->ai_addr, info->ai_addrlen) == -1) { 117 | if (err != nullptr && errno == ECONNREFUSED) { 118 | err[i] = NetworkError::ConnectionRefused; 119 | } else if (err != nullptr && errno == ETIMEDOUT) { 120 | err[i] = NetworkError::TimedOut; 121 | } else { 122 | std::perror("network_connect -> connect"); 123 | std::abort(); 124 | } 125 | } else if (err != nullptr) { 126 | err[i] = NetworkError::Success; 127 | } 128 | } 129 | 130 | freeaddrinfo(info); 131 | } 132 | 133 | void network_close(int socket) { 134 | if (close(socket) == -1) { 135 | std::perror("network_close -> close"); 136 | std::abort(); 137 | } 138 | } 139 | 140 | void pipe_open(int* into) { 141 | if (pipe(into) != 0) { 142 | std::perror("pipe_open -> pipe"); 143 | std::abort(); 144 | } 145 | } 146 | 147 | void pipe_close(int fd) { 148 | if (close(fd) != 0) { 149 | std::perror("pipe_close -> close"); 150 | std::abort(); 151 | } 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/memprog/pipeline.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include "memprog/pipeline.hpp" 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "memprog/annotation.hpp" 28 | #include "memprog/placement.hpp" 29 | #include "memprog/program.hpp" 30 | #include "memprog/replacement.hpp" 31 | #include "memprog/scheduling.hpp" 32 | 33 | namespace mage::memprog { 34 | DefaultPipeline::DefaultPipeline(const std::string& name) : Pipeline(name), 35 | page_shift(12), num_pages(1 << 10), prefetch_buffer_size(256), prefetch_lookahead(10000), 36 | stats({}), verbose(false) { 37 | } 38 | 39 | DefaultPipeline::DefaultPipeline(const std::string& name, const util::ConfigValue& worker) : Pipeline(name) { 40 | this->read_config(worker); 41 | } 42 | 43 | void DefaultPipeline::set_verbose(bool be_verbose) { 44 | this->verbose = be_verbose; 45 | } 46 | 47 | void DefaultPipeline::read_config(const util::ConfigValue& worker) { 48 | this->page_shift = worker["page_shift"].as_int(); 49 | this->num_pages = worker["num_pages"].as_int(); 50 | this->prefetch_buffer_size = worker["prefetch_buffer_size"].as_int(); 51 | this->prefetch_lookahead = worker["prefetch_lookahead"].as_int(); 52 | } 53 | 54 | void DefaultPipeline::program(Program** p, PlacementPlugin plugin, std::function dsl_program, const std::string& prog_file) { 55 | Program program(prog_file, this->page_shift, plugin); 56 | *p = &program; 57 | dsl_program(); 58 | *p = nullptr; 59 | this->stats.num_instructions = program.num_instructions(); 60 | 61 | if (this->verbose) { 62 | std::cout << "Created program with " << program.num_instructions() << " instructions" << std::endl; 63 | } 64 | } 65 | 66 | void DefaultPipeline::allocate(const std::string& prog_file, const std::string& repprog_file) { 67 | this->progress_bar.set_label("Annotations Pass"); 68 | std::string ann_file = 
this->program_name + ".ann"; 69 | annotate_program(ann_file, prog_file, this->page_shift, &this->progress_bar); 70 | this->progress_bar.finish(); 71 | if (this->verbose) { 72 | std::cout << "Computed annotations" << std::endl; 73 | } 74 | 75 | this->progress_bar.set_label("Replacement Pass"); 76 | BeladyAllocator allocator(repprog_file, prog_file, ann_file, this->num_pages, this->page_shift); 77 | allocator.allocate(&this->progress_bar); 78 | this->progress_bar.finish(); 79 | this->stats.num_swapouts = allocator.get_num_swapouts(); 80 | this->stats.num_swapins = allocator.get_num_swapins(); 81 | if (this->verbose) { 82 | std::cout << "Finished replacement stage: " << allocator.get_num_swapouts() << " swapouts, " << allocator.get_num_swapins() << " swapins" << std::endl; 83 | } 84 | } 85 | 86 | void DefaultPipeline::schedule(const std::string& repprog_file, const std::string& memprog_file) { 87 | this->progress_bar.set_label("Scheduling Pass"); 88 | BackdatingScheduler scheduler(repprog_file, memprog_file, this->prefetch_lookahead, this->prefetch_buffer_size); 89 | scheduler.schedule(&this->progress_bar); 90 | this->progress_bar.finish(); 91 | this->stats.num_prefetch_alloc_failures = scheduler.get_num_allocation_failures(); 92 | this->stats.num_synchronous_swapins = scheduler.get_num_synchronous_swapins(); 93 | if (this->verbose) { 94 | std::cout << "Finished scheduling swaps: " << scheduler.get_num_allocation_failures() << " allocation failures, " << scheduler.get_num_synchronous_swapins() << " synchronous swapins" << std::endl; 95 | } 96 | } 97 | 98 | void DefaultPipeline::plan(Program** p, PlacementPlugin plugin, std::function program) { 99 | auto program_start = std::chrono::steady_clock::now(); 100 | this->program(p, plugin, program, this->program_name + ".prog"); 101 | auto program_end = std::chrono::steady_clock::now(); 102 | this->stats.placement_duration = std::chrono::duration_cast(program_end - program_start); 103 | 104 | auto replacement_start = 
std::chrono::steady_clock::now(); 105 | this->allocate(this->program_name + ".prog", this->program_name + ".repprog"); 106 | auto replacement_end = std::chrono::steady_clock::now(); 107 | this->stats.replacement_duration = std::chrono::duration_cast(replacement_end - replacement_start); 108 | 109 | auto scheduling_start = std::chrono::steady_clock::now(); 110 | this->schedule(this->program_name + ".repprog", this->program_name + ".memprog"); 111 | auto scheduling_end = std::chrono::steady_clock::now(); 112 | this->stats.scheduling_duration = std::chrono::duration_cast(scheduling_end - scheduling_start); 113 | } 114 | 115 | const DefaultPipelineStats& DefaultPipeline::get_stats() const { 116 | return this->stats; 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/crypto/ot/base.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include "crypto/block.hpp" 28 | #include "crypto/group.hpp" 29 | #include "crypto/hash.hpp" 30 | #include "util/filebuffer.hpp" 31 | 32 | namespace mage::crypto::ot { 33 | /* 34 | * This implements the protocol given in Sections 2.3 and 3 of the 35 | * following paper: 36 | * M. Naor and B. Pinkas. Efficient Oblivious Transfer Protocols. SODA 2001. 37 | * Here is a link, for convenience: http://www.pinkas.net/PAPERS/effot.ps. 38 | * 39 | * I generalized it to support n rounds instead of 1. 40 | */ 41 | 42 | static inline void send_group_element(util::BufferedFileWriter& network_out, const DDHGroupElement& elem) { 43 | std::size_t size = elem.marshalled_uncompressed_size(); 44 | network_out.write() = size; 45 | void* into = network_out.start_write(size); 46 | elem.marshal_uncompressed(static_cast(into), size); 47 | network_out.finish_write(size); 48 | } 49 | 50 | static inline void receive_group_element(util::BufferedFileReader& network_in, DDHGroupElement& elem) { 51 | std::size_t size = network_in.read(); 52 | const void* from = network_in.start_read(size); 53 | elem.unmarshal_uncompressed(static_cast(from), size); 54 | network_in.finish_read(size); 55 | } 56 | 57 | static inline block hash_group_element_to_block(const DDHGroupElement& elem) { 58 | std::size_t size = elem.marshalled_uncompressed_size(); 59 | std::uint8_t buffer[size] __attribute__((aligned(16))); 60 | elem.marshal_uncompressed(buffer, size); 61 | return hash_to_block(buffer, size); 62 | } 63 | 64 | struct BaseOTSenderDecision { 65 | BaseOTSenderDecision(const DDHGroup& g) : r(g), gr(g), pk0(g) { 66 | } 67 | ScalarMod r; 68 | DDHGroupElement gr; 69 | DDHGroupElement pk0; 70 | }; 71 | 72 | void base_send(const DDHGroup& g, util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const std::pair* choices, std::size_t num_choices) { 73 | DDHGroupElement c(g); 74 | c.set_generator(); 75 | 
send_group_element(network_out, c); 76 | 77 | network_out.flush(); 78 | 79 | std::vector decisions; 80 | decisions.reserve(num_choices); 81 | for (std::size_t i = 0; i != num_choices; i++) { 82 | decisions.emplace_back(g); 83 | BaseOTSenderDecision& decision = decisions[i]; 84 | receive_group_element(network_in, decision.pk0); 85 | } 86 | 87 | DDHGroupElement pk1(g); 88 | for (std::size_t i = 0; i != num_choices; i++) { 89 | BaseOTSenderDecision& decision = decisions[i]; 90 | decision.r.set_random(); 91 | decision.gr.multiply_generator(decision.r); 92 | send_group_element(network_out, decision.gr); 93 | decision.gr.multiply_restrict(decision.pk0, decision.r); 94 | decision.pk0.invert(); 95 | pk1.add(c, decision.pk0); 96 | 97 | block ciphertext0 = xorBlocks(hash_group_element_to_block(decision.gr), choices[i].first); 98 | block_store_unaligned(ciphertext0, &network_out.write()); 99 | decision.gr.multiply_restrict(pk1, decision.r); 100 | block ciphertext1 = xorBlocks(hash_group_element_to_block(decision.gr), choices[i].second); 101 | block_store_unaligned(ciphertext1, &network_out.write()); 102 | } 103 | 104 | network_out.flush(); 105 | } 106 | 107 | void base_choose(const DDHGroup& g, util::BufferedFileReader& network_in, util::BufferedFileWriter& network_out, const bool* choices, block* results, std::size_t num_choices) { 108 | DDHGroupElement c(g); 109 | receive_group_element(network_in, c); 110 | 111 | std::vector k; 112 | k.reserve(num_choices); 113 | DDHGroupElement key0(g); 114 | for (std::size_t i = 0; i != num_choices; i++) { 115 | k.emplace_back(g); 116 | k[i].set_random(); 117 | key0.multiply_generator(k[i]); 118 | if (choices[i]) { 119 | key0.invert(); 120 | key0.add(key0, c); 121 | } 122 | send_group_element(network_out, key0); 123 | } 124 | network_out.flush(); 125 | 126 | DDHGroupElement gr(g); 127 | for (std::size_t i = 0; i != num_choices; i++) { 128 | receive_group_element(network_in, gr); 129 | key0.multiply_restrict(gr, k[i]); 130 | void* data = 
network_in.start_read(2 * sizeof(block)); 131 | block* ciphertexts = static_cast(data); 132 | block ciphertext = block_load_unaligned(choices[i] ? &ciphertexts[1] : &ciphertexts[0]); 133 | results[i] = xorBlocks(ciphertext, hash_group_element_to_block(key0)); 134 | network_in.finish_read(2 * sizeof(block)); 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/util/progress.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2021 Sam Kumar 3 | * Copyright (C) 2021 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | /** 23 | * @file util/progress.hpp 24 | * @brief Terminal-based ASCII progress bar. 25 | */ 26 | 27 | #ifndef MAGE_UTIL_PROGRESS_HPP_ 28 | #define MAGE_UTIL_PROGRESS_HPP_ 29 | 30 | #include 31 | #include 32 | #include "util/misc.hpp" 33 | 34 | namespace mage::util { 35 | /** 36 | * @brief Represents a ASCII terminal-based progress bar. 37 | * 38 | * The width of the bar is determined automatically based on the width of 39 | * the terminal corresponding to standard output; if standard output does 40 | * not correspond to a terminal, then no progress bar is displayed. 
41 | */ 42 | class ProgressBar { 43 | public: 44 | /** 45 | * @brief Creates a progress bar with the specified label and units of 46 | * work. 47 | * 48 | * Initially, zero units of work are completed for the progress bar; 49 | * it will be printed out (displayed) on the first call to @p refresh. 50 | * 51 | * @param label A string with the label to print out to the left of 52 | * the progress bar. 53 | * @param total_units The total number of units of work in the task 54 | * whose progress is being measured. 55 | */ 56 | ProgressBar(const std::string& label = "", std::uint64_t total_units = 0); 57 | 58 | /** 59 | * @brief Sets the label on a progress bar. 60 | * 61 | * @param label A string with the label to print out to the left of 62 | * the progress bar. 63 | */ 64 | void set_label(const std::string& label); 65 | 66 | /** 67 | * @brief Re-initializes a progress bar, resetting it to the state of a 68 | * newly constructed progress bar. 69 | * 70 | * @param label A string with the label to print out to the left of 71 | * the progress bar. 72 | * @param total_units The total number of units of work in the task 73 | * whose progress is being measured. 74 | */ 75 | void reset(const std::string& label, std::uint64_t total_units); 76 | 77 | /** 78 | * @brief Re-initializes a progress bar, resetting it to the state of a 79 | * newly constructed progress bar. Retains the label from before. 80 | * 81 | * @param total_units The total number of units of work in the task 82 | * whose progress is being measured. 83 | */ 84 | void reset(std::uint64_t total_units); 85 | 86 | /** 87 | * @brief Erases the on-screen progress bar, replacing it with an 88 | * empty row in the terminal. 89 | */ 90 | void erase() const; 91 | 92 | /** 93 | * @brief Displays the progress bar, causing it to re-appear after a 94 | * previous call to @p erase(). 
95 | */ 96 | void display() const; 97 | 98 | /** 99 | * @brief Advances the terminal to the next line, leaving the progress 100 | * bar visible on the previous row. 101 | * 102 | * @param fill If true, the progress bar is updated to a "full" state 103 | * (task 100% completed) before advancing to the next line. 104 | */ 105 | void finish(bool fill = true); 106 | 107 | /** 108 | * @brief Increases the progress displayed on the progress bar by the 109 | * specified amount. 110 | * 111 | * @param num_units The number of additional units of work completed 112 | * for the task whose progress is being measured. 113 | */ 114 | void advance(std::uint64_t num_units) { 115 | this->refresh(this->current_count + num_units); 116 | } 117 | 118 | /** 119 | * @brief Increases the progress displayed on the progress bar to the 120 | * specified amount. 121 | * 122 | * @pre The argument @p num_units must be greater than or equal to the 123 | * value of @p num_units in previous calls to this function and less 124 | * than or equal to the value of @p total_units used to initialize this 125 | * progress bar, via the constructor or the @p reset function. 126 | * @param num_units The number of total units of work completed in the 127 | * task whose progress is being measured. 
128 | */ 129 | void refresh(std::uint64_t num_units) { 130 | this->current_count = num_units; 131 | if (num_units >= this->next_update) { 132 | if (this->reconstruct_bar_if_necessary()) { 133 | this->update(num_units); 134 | this->display(); 135 | } 136 | this->next_update = static_cast(util::ceil_div(num_units + 1, this->update_threshold).first) * this->update_threshold; 137 | if (this->next_update > this->total_count) { 138 | this->next_update = this->total_count; 139 | } 140 | } 141 | } 142 | 143 | private: 144 | void update(std::uint64_t num_units); 145 | 146 | bool reconstruct_bar_if_necessary(); 147 | void construct_bar(); 148 | char* get_bar_start(); 149 | char* get_percent_start(); 150 | 151 | std::uint64_t update_threshold; 152 | std::uint64_t next_update; 153 | std::uint64_t current_count; 154 | std::uint64_t total_count; 155 | std::uint32_t bar_start; 156 | std::uint32_t bar_capacity; 157 | std::uint32_t current_width; 158 | std::string bar; 159 | std::string display_name; 160 | 161 | static constexpr const char bar_full = '#'; 162 | static constexpr const char bar_empty = '.'; 163 | }; 164 | } 165 | 166 | #endif 167 | -------------------------------------------------------------------------------- /src/crypto/aes.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * This file was originally part of OpenSSL. The EMP-toolkit library used 3 | * this code for AES-NI, with some small modifications. The code below is 4 | * derived from that version of the file. 5 | */ 6 | 7 | /* crypto/aes/aes.h -*- mode:C; c-file-style: "eay" -*- */ 8 | /* ==================================================================== 9 | * Copyright (c) 1998-2002 The OpenSSL Project. All rights reserved. 10 | * 11 | * Redistribution and use in source and binary forms, with or without 12 | * modification, are permitted provided that the following conditions 13 | * are met: 14 | * 15 | * 1. 
Redistributions of source code must retain the above copyright 16 | * notice, this list of conditions and the following disclaimer. 17 | * 18 | * 2. Redistributions in binary form must reproduce the above copyright 19 | * notice, this list of conditions and the following disclaimer in 20 | * the documentation and/or other materials provided with the 21 | * distribution. 22 | * 23 | * 3. All advertising materials mentioning features or use of this 24 | * software must display the following acknowledgment: 25 | * "This product includes software developed by the OpenSSL Project 26 | * for use in the OpenSSL Toolkit. (http://www.openssl.org/)" 27 | * 28 | * 4. The names "OpenSSL Toolkit" and "OpenSSL Project" must not be used to 29 | * endorse or promote products derived from this software without 30 | * prior written permission. For written permission, please contact 31 | * openssl-core@openssl.org. 32 | * 33 | * 5. Products derived from this software may not be called "OpenSSL" 34 | * nor may "OpenSSL" appear in their names without prior written 35 | * permission of the OpenSSL Project. 36 | * 37 | * 6. Redistributions of any form whatsoever must retain the following 38 | * acknowledgment: 39 | * "This product includes software developed by the OpenSSL Project 40 | * for use in the OpenSSL Toolkit (http://www.openssl.org/)" 41 | * 42 | * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY 43 | * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 44 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 45 | * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE OpenSSL PROJECT OR 46 | * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 47 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 48 | * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 49 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 50 | * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 51 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 52 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 53 | * OF THE POSSIBILITY OF SUCH DAMAGE. 54 | * ==================================================================== 55 | * 56 | */ 57 | 58 | #ifndef MAGE_CRYPTO_AES_HPP_ 59 | #define MAGE_CRYPTO_AES_HPP_ 60 | 61 | #include 62 | #include "crypto/block.hpp" 63 | 64 | namespace mage::crypto { 65 | constexpr const int AES_BATCH_SIZE = 2048; 66 | 67 | typedef struct { block rd_key[11]; unsigned int rounds; } AES_KEY; 68 | 69 | #define EXPAND_ASSIST(v1,v2,v3,v4,shuff_const,aes_const) \ 70 | v2 = _mm_aeskeygenassist_si128(v4,aes_const); \ 71 | v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), \ 72 | _mm_castsi128_ps(v1), 16)); \ 73 | v1 = _mm_xor_si128(v1,v3); \ 74 | v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), \ 75 | _mm_castsi128_ps(v1), 140)); \ 76 | v1 = _mm_xor_si128(v1,v3); \ 77 | v2 = _mm_shuffle_epi32(v2,shuff_const); \ 78 | v1 = _mm_xor_si128(v1,v2) 79 | 80 | inline void 81 | __attribute__((target("aes,sse2"))) 82 | AES_set_encrypt_key(const block userkey, AES_KEY *key) 83 | { 84 | block x0, x1, x2; 85 | block *kp = key->rd_key; 86 | kp[0] = x0 = userkey; 87 | x2 = _mm_setzero_si128(); 88 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 1); 89 | kp[1] = x0; 90 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 2); 91 | kp[2] = x0; 92 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 4); 93 | kp[3] = x0; 94 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 8); 95 | kp[4] = x0; 96 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 16); 97 | kp[5] = x0; 98 | EXPAND_ASSIST(x0, x1, x2, 
x0, 255, 32); 99 | kp[6] = x0; 100 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 64); 101 | kp[7] = x0; 102 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 128); 103 | kp[8] = x0; 104 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 27); 105 | kp[9] = x0; 106 | EXPAND_ASSIST(x0, x1, x2, x0, 255, 54); 107 | kp[10] = x0; 108 | key->rounds = 10; 109 | } 110 | 111 | inline void 112 | __attribute__((target("aes,sse2"))) 113 | AES_ecb_encrypt_blks(block *blks, unsigned int nblks, const AES_KEY *key) 114 | { 115 | for (unsigned int i = 0; i < nblks; ++i) 116 | blks[i] = _mm_xor_si128(blks[i], key->rd_key[0]); 117 | for (unsigned int j = 1; j < key->rounds; ++j) 118 | for (unsigned int i = 0; i < nblks; ++i) 119 | blks[i] = _mm_aesenc_si128(blks[i], key->rd_key[j]); 120 | for (unsigned int i = 0; i < nblks; ++i) 121 | blks[i] = _mm_aesenclast_si128(blks[i], key->rd_key[key->rounds]); 122 | } 123 | 124 | inline void 125 | __attribute__((target("aes,sse2"))) 126 | AES_set_decrypt_key_fast(AES_KEY *dkey, const AES_KEY *ekey) 127 | { 128 | int j = 0; 129 | int i = ekey->rounds; 130 | #if (OCB_KEY_LEN == 0) 131 | dkey->rounds = i; 132 | #endif 133 | dkey->rd_key[i--] = ekey->rd_key[j++]; 134 | while (i) 135 | dkey->rd_key[i--] = _mm_aesimc_si128(ekey->rd_key[j++]); 136 | dkey->rd_key[i] = ekey->rd_key[j]; 137 | } 138 | 139 | inline void 140 | __attribute__((target("aes,sse2"))) 141 | AES_set_decrypt_key(block userkey, AES_KEY *key) 142 | { 143 | AES_KEY temp_key; 144 | AES_set_encrypt_key(userkey, &temp_key); 145 | AES_set_decrypt_key_fast(key, &temp_key); 146 | } 147 | 148 | inline void 149 | __attribute__((target("aes,sse2"))) 150 | AES_ecb_decrypt_blks(block *blks, unsigned nblks, const AES_KEY *key) 151 | { 152 | unsigned i, j, rnds = key->rounds; 153 | for (i = 0; i < nblks; ++i) 154 | blks[i] = _mm_xor_si128(blks[i], key->rd_key[0]); 155 | for (j = 1; j < rnds; ++j) 156 | for (i = 0; i < nblks; ++i) 157 | blks[i] = _mm_aesdec_si128(blks[i], key->rd_key[j]); 158 | for (i = 0; i < nblks; ++i) 159 | 
blks[i] = _mm_aesdeclast_si128(blks[i], key->rd_key[j]); 160 | } 161 | } 162 | 163 | #endif 164 | -------------------------------------------------------------------------------- /src/util/misc.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | /** 23 | * @file util/misc.hpp 24 | * @brief Miscellaneous utility functions. 25 | */ 26 | 27 | #ifndef MAGE_UTIL_MISC_HPP_ 28 | #define MAGE_UTIL_MISC_HPP_ 29 | 30 | #include 31 | #include 32 | #include 33 | 34 | namespace mage::util { 35 | /** 36 | * @brief Determines if the specified positive number is a power of two. 37 | * 38 | * @tparam Type of the specified number. 39 | * @param number The specified number. It must be positive. 40 | * @return True if the specified number is a power of two, otherwise false. 41 | */ 42 | template 43 | bool is_power_of_two(T number) { 44 | return (number & (number - 1)) == 0; 45 | } 46 | 47 | /** 48 | * @brief Counts the number of 1s in the binary representation of the 49 | * specified unsigned number. 50 | * 51 | * This quantity is called the "Hamming Weight" of the specified unsigned 52 | * number. 
53 | * 54 | * @param number The number in question. 55 | * @return The number of 1s in the binary representation of the specified 56 | * number. 57 | */ 58 | static inline std::uint8_t hamming_weight(std::uint64_t number) { 59 | std::uint8_t weight = 0; 60 | while (number != 0) { 61 | weight += (number & 0x1); 62 | number >>= 1; 63 | } 64 | return weight; 65 | } 66 | 67 | /** 68 | * @brief Checks if the number of 1s in the binary representation of the 69 | * specified unsigned number is even or odd. 70 | * 71 | * @param number The number in question. 72 | * @return True if the number of 1s in the binary representation of the 73 | * specified number is odd, otherwise false. 74 | */ 75 | static inline bool hamming_parity(std::uint64_t number) { 76 | return (hamming_weight(number) & 0x1) != 0x0; 77 | } 78 | 79 | /** 80 | * @brief Computes the base-2 logarithm of the specified number. 81 | * 82 | * @param The number whose logarithm to compute. 83 | * @return The smallest nonnegative number x such that (2 ^ x) >= number. 84 | */ 85 | static inline std::uint8_t log_base_2(std::uint64_t number) { 86 | std::uint8_t logarithm = 0; 87 | while ((UINT64_C(1) << logarithm) < number) { 88 | logarithm++; 89 | } 90 | return logarithm; 91 | } 92 | 93 | /** 94 | * @brief Computes the floor division of two signed numbers. 95 | * 96 | * This function finds the unique quotient q and remainder r, where 97 | * 0 <= r < divisor, such that dividend = (q * divisor) + r. 98 | * 99 | * @param dividend The number to divide (i.e., the numerator). 100 | * @param divisor The number to divide by (i.e., the denominator). 101 | * @return A pair whose first item is the quotient q and whose second item 102 | * is the remainder r. 
103 | */ 104 | static inline std::pair floor_div(std::int64_t dividend, std::int64_t divisor) { 105 | int64_t quotient = dividend / divisor; 106 | int64_t remainder = dividend % divisor; 107 | if (remainder < 0) { 108 | quotient -= 1; 109 | remainder += divisor; 110 | } 111 | return std::make_pair(quotient, remainder); 112 | } 113 | 114 | /** 115 | * @brief Computes the ceiling division of two signed numbers. 116 | * 117 | * This function finds the unique quotient q and remainder r, where 118 | * -divisor < r <= 0, such that dividend = (q * divisor) + r. 119 | * 120 | * @param dividend The number to divide (i.e., the numerator). 121 | * @param divisor The number to divide by (i.e., the denominator). 122 | * @return A pair whose first item is the quotient q and whose second item 123 | * is the remainder r. 124 | */ 125 | static inline std::pair ceil_div(std::int64_t dividend, std::int64_t divisor) { 126 | int64_t quotient = dividend / divisor; 127 | int64_t remainder = dividend % divisor; 128 | if (remainder > 0) { 129 | quotient += 1; 130 | remainder -= divisor; 131 | } 132 | return std::make_pair(quotient, remainder); 133 | } 134 | 135 | /** 136 | * @brief A specialization of std::streambuf supporting reads and writes 137 | * to an in-memory buffer. 138 | */ 139 | class MemoryBuffer : public std::streambuf { 140 | /** 141 | * @brief Creates a std::streambuf for reading and writing to an 142 | * in-memory buffer. 143 | * 144 | * @param buffer A pointer to the in-memory buffer to use. 145 | * @param length The size, in bytes, of the in-memory buffer. 146 | */ 147 | MemoryBuffer(void* buffer, std::size_t length) { 148 | char* base = static_cast(buffer); 149 | this->setg(base, base, base + length); 150 | this->setp(base, base + length); 151 | } 152 | }; 153 | 154 | /** 155 | * @brief A specialization of std::streambuf supporting reads from an 156 | * in-memory buffer. 
157 | */ 158 | class MemoryReadBuffer : public std::streambuf { 159 | public: 160 | /** 161 | * @brief Creates a std::streambuf for reading from an in-memory 162 | * buffer. 163 | * 164 | * @param buffer A pointer to the in-memory buffer to use. 165 | * @param length The size, in bytes, of the in-memory buffer. 166 | */ 167 | MemoryReadBuffer(const void* buffer, std::size_t length) { 168 | char* base = const_cast(static_cast(buffer)); 169 | this->setg(base, base, base + length); 170 | } 171 | }; 172 | 173 | /** 174 | * @brief A specialization of std::streambuf supporting writes to an 175 | * in-memory buffer. 176 | */ 177 | class MemoryWriteBuffer : public std::streambuf { 178 | public: 179 | /** 180 | * @brief Creates a std::streambuf for writing to an in-memory buffer. 181 | * 182 | * @param buffer A pointer to the in-memory buffer to use. 183 | * @param length The size, in bytes, of the in-memory buffer. 184 | */ 185 | MemoryWriteBuffer(void* buffer, std::size_t length) { 186 | char* base = static_cast(buffer); 187 | this->setp(base, base + length); 188 | } 189 | }; 190 | } 191 | 192 | #endif 193 | -------------------------------------------------------------------------------- /tests/test_prioqueue.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 
17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #define BOOST_TEST_DYN_LINK 23 | #include "boost/test/unit_test.hpp" 24 | #include "boost/test/data/test_case.hpp" 25 | #include "boost/test/data/monomorphic.hpp" 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include "util/prioqueue.hpp" 33 | 34 | namespace bdata = boost::unit_test::data; 35 | using mage::util::PriorityQueue; 36 | 37 | struct VectorSample { 38 | std::vector data; 39 | friend std::ostream& operator<<(std::ostream& out, const VectorSample& sample); 40 | }; 41 | 42 | std::ostream& operator<<(std::ostream& out, const VectorSample& sample) { 43 | out << "{ "; 44 | for (auto i = sample.data.begin(); i != sample.data.end(); i++) { 45 | out << *i << " "; 46 | } 47 | return out << "}"; 48 | } 49 | 50 | class RandomIntsDataset { 51 | public: 52 | using sample = VectorSample; 53 | 54 | enum { 55 | arity = 1 56 | }; 57 | 58 | struct iterator { 59 | public: 60 | iterator(unsigned int seed) : seedv(seed) { 61 | this->operator++(); 62 | } 63 | 64 | VectorSample operator*() const { 65 | VectorSample s; 66 | s.data = this->sample; 67 | return s; 68 | } 69 | 70 | void operator++() { 71 | int length = rand_r(&this->seedv) % 256; 72 | this->sample.resize(length); 73 | for (int i = 0; i != length; i++) { 74 | this->sample[i] = i + 1; 75 | } 76 | std::random_shuffle(sample.begin(), this->sample.end()); 77 | } 78 | private: 79 | std::vector sample; 80 | unsigned int seedv; 81 | }; 82 | 83 | RandomIntsDataset(bdata::size_t size, unsigned int seed = 12) 84 | : num_samples(size), seedv(seed) { 85 | } 86 | 87 | bdata::size_t size() const { 88 | return this->num_samples; 89 | } 90 | 91 | iterator begin() const { 92 | return iterator(this->seedv); 93 | } 94 | 95 | private: 96 | bdata::size_t num_samples; 97 | unsigned int seedv; 98 | }; 99 | 100 | namespace boost::unit_test::data::monomorphic { 101 | template <> 102 | 
struct is_dataset : boost::mpl::true_ {}; 103 | } 104 | 105 | VectorSample reverse = { .data = { 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 } }; 106 | 107 | BOOST_DATA_TEST_CASE(test_prioqueue_min, bdata::make(reverse) + RandomIntsDataset(99)) { 108 | std::vector numbers(sample.data); 109 | PriorityQueue pq; 110 | for (auto i = numbers.begin(); i != numbers.end(); i++) { 111 | pq.insert(*i, *i); 112 | } 113 | 114 | std::vector popped; 115 | while (!pq.empty()) { 116 | auto res = pq.remove_min(); 117 | BOOST_CHECK(res.first == res.second); 118 | popped.push_back(res.second); 119 | } 120 | 121 | std::vector sorted(numbers); 122 | std::sort(sorted.begin(), sorted.end()); 123 | BOOST_REQUIRE(sorted.size() == popped.size()); 124 | for (int i = 0; i != numbers.size(); i++) { 125 | BOOST_CHECK(sorted[i] == popped[i]); 126 | } 127 | } 128 | 129 | BOOST_DATA_TEST_CASE(test_prioqueue_second_min, bdata::make(reverse) + RandomIntsDataset(99)) { 130 | std::vector numbers(sample.data); 131 | if (numbers.size() == 0) { 132 | return; 133 | } 134 | 135 | PriorityQueue pq; 136 | for (auto i = numbers.begin(); i != numbers.end(); i++) { 137 | pq.insert(*i, *i); 138 | } 139 | 140 | std::vector popped; 141 | while (pq.size() != 1) { 142 | auto res = pq.remove_second_min(); 143 | BOOST_CHECK(res.first == res.second); 144 | popped.push_back(res.second); 145 | } 146 | 147 | std::vector sorted(numbers); 148 | std::sort(sorted.begin(), sorted.end()); 149 | BOOST_REQUIRE(sorted.size() == popped.size() + 1); 150 | for (int i = 0; i != popped.size(); i++) { 151 | BOOST_CHECK(sorted[i + 1] == popped[i]); 152 | } 153 | } 154 | 155 | BOOST_DATA_TEST_CASE(test_prioqueue_decrease_key, bdata::make(reverse) + RandomIntsDataset(99)) { 156 | std::vector numbers(sample.data); 157 | 158 | std::vector numbers2; 159 | for (int i = numbers.size() / 2; i != numbers.size(); i++) { 160 | numbers2.push_back(numbers[i]); 161 | } 162 | numbers.resize(numbers.size() / 2); 163 | 164 | for (int i = 0; 
i != numbers.size(); i++) { 165 | numbers[i] = std::min(numbers[i], numbers2[i]); 166 | } 167 | 168 | PriorityQueue pq; 169 | for (int i = 0; i != numbers.size(); i++) { 170 | pq.insert(numbers2[i], numbers[i]); 171 | } 172 | 173 | for (int i = 0; i != numbers.size(); i++) { 174 | pq.decrease_key(numbers[i], numbers[i]); 175 | } 176 | 177 | std::vector popped; 178 | while (!pq.empty()) { 179 | auto res = pq.remove_min(); 180 | BOOST_CHECK(res.first == res.second); 181 | popped.push_back(res.second); 182 | } 183 | 184 | std::vector sorted(numbers); 185 | std::sort(sorted.begin(), sorted.end()); 186 | BOOST_REQUIRE(sorted.size() == popped.size()); 187 | for (int i = 0; i != numbers.size(); i++) { 188 | BOOST_CHECK(sorted[i] == popped[i]); 189 | } 190 | } 191 | 192 | BOOST_DATA_TEST_CASE(test_prioqueue_increase_key, bdata::make(reverse) + RandomIntsDataset(99)) { 193 | std::vector numbers(sample.data); 194 | 195 | std::vector numbers2; 196 | for (int i = numbers.size() / 2; i != numbers.size(); i++) { 197 | numbers2.push_back(numbers[i]); 198 | } 199 | numbers.resize(numbers.size() / 2); 200 | 201 | for (int i = 0; i != numbers.size(); i++) { 202 | numbers[i] = std::max(numbers[i], numbers2[i]); 203 | } 204 | 205 | PriorityQueue pq; 206 | for (int i = 0; i != numbers.size(); i++) { 207 | pq.insert(numbers2[i], numbers[i]); 208 | } 209 | 210 | for (int i = 0; i != numbers.size(); i++) { 211 | pq.increase_key(numbers[i], numbers[i]); 212 | } 213 | 214 | std::vector popped; 215 | while (!pq.empty()) { 216 | auto res = pq.remove_min(); 217 | BOOST_CHECK(res.first == res.second); 218 | popped.push_back(res.second); 219 | } 220 | 221 | std::vector sorted(numbers); 222 | std::sort(sorted.begin(), sorted.end()); 223 | BOOST_REQUIRE(sorted.size() == popped.size()); 224 | for (int i = 0; i != numbers.size(); i++) { 225 | BOOST_CHECK(sorted[i] == popped[i]); 226 | } 227 | } 228 | -------------------------------------------------------------------------------- 
/src/protocols/tfhe_scheme.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #ifndef MAGE_PROTOCOLS_TFHE_SCHEME_HPP_ 23 | #define MAGE_PROTOCOLS_TFHE_SCHEME_HPP_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | namespace mage::protocols::tfhe { 33 | constexpr const std::size_t tfhe_ciphertext_size = 2536; 34 | constexpr const int tfhe_num_temp_ciphertexts = 3; 35 | 36 | struct TFHECiphertext { 37 | std::uint8_t data[tfhe_ciphertext_size]; 38 | }; 39 | 40 | class TFHECiphertextReadBuffer : public std::streambuf { 41 | public: 42 | TFHECiphertextReadBuffer(const TFHECiphertext* ciphertext) { 43 | char* base = reinterpret_cast(const_cast(&ciphertext->data[0])); 44 | this->setg(base, base, base + sizeof(ciphertext->data)); 45 | } 46 | }; 47 | 48 | class TFHECiphertextWriteBuffer : public std::streambuf { 49 | public: 50 | TFHECiphertextWriteBuffer(TFHECiphertext* ciphertext) { 51 | char* base = reinterpret_cast(&ciphertext->data[0]); 52 | this->setp(base, base + sizeof(ciphertext->data)); 53 | } 54 | }; 55 | 56 | class TFHEScheme { 57 | public: 58 | using Wire = 
TFHECiphertext; 59 | 60 | TFHEScheme() : params(nullptr), cloud_key(nullptr), ciphertexts(nullptr) { 61 | } 62 | 63 | virtual ~TFHEScheme() { 64 | this->clear_params(); 65 | this->clear_cloud_key(); 66 | this->clear_ciphertexts(); 67 | } 68 | 69 | void set_params(std::istream& params_stream) { 70 | this->clear_params(); 71 | this->clear_ciphertexts(); 72 | 73 | this->params = new_tfheGateBootstrappingParameterSet_fromStream(params_stream); 74 | if (this->params == nullptr) { 75 | std::cerr << "Out of memory (allocating TFHE params)" << std::endl; 76 | std::abort(); 77 | } 78 | 79 | this->ciphertexts = new_gate_bootstrapping_ciphertext_array(tfhe_num_temp_ciphertexts, this->params); 80 | if (this->ciphertexts == nullptr) { 81 | std::cerr << "Out of memory (allocating TFHE ciphertexts)" << std::endl; 82 | std::abort(); 83 | } 84 | } 85 | 86 | void set_cloud_key(std::istream& cloud_stream) { 87 | this->clear_cloud_key(); 88 | 89 | this->cloud_key = new_tfheGateBootstrappingCloudKeySet_fromStream(cloud_stream); 90 | if (this->cloud_key == nullptr) { 91 | std::cerr << "Out of memory (allocating TFHE cloud key)" << std::endl; 92 | std::abort(); 93 | } 94 | } 95 | 96 | void op_and(Wire& output, const Wire& input1, const Wire& input2) { 97 | this->load_ciphertexts(input1, input2); 98 | bootsAND(&this->ciphertexts[0], &this->ciphertexts[1], &this->ciphertexts[2], this->cloud_key); 99 | this->unload_ciphertext(output); 100 | } 101 | 102 | void op_xor(Wire& output, const Wire& input1, const Wire& input2) { 103 | this->load_ciphertexts(input1, input2); 104 | bootsXOR(&this->ciphertexts[0], &this->ciphertexts[1], &this->ciphertexts[2], this->cloud_key); 105 | this->unload_ciphertext(output); 106 | } 107 | 108 | void op_not(Wire& output, const Wire& input) { 109 | this->load_ciphertexts(input); 110 | bootsNOT(&this->ciphertexts[0], &this->ciphertexts[1], this->cloud_key); 111 | this->unload_ciphertext(output); 112 | } 113 | 114 | void op_xnor(Wire& output, const Wire& input1, 
const Wire& input2) { 115 | this->load_ciphertexts(input1, input2); 116 | bootsXNOR(&this->ciphertexts[0], &this->ciphertexts[1], &this->ciphertexts[2], this->cloud_key); 117 | this->unload_ciphertext(output); 118 | } 119 | 120 | void op_copy(Wire& output, const Wire& input) { 121 | output = input; // don't want to use copy gate --- will add more copies 122 | } 123 | 124 | void one(Wire& output) { 125 | bootsCONSTANT(&this->ciphertexts[0], 1, this->cloud_key); 126 | this->unload_ciphertext(output); 127 | } 128 | 129 | void zero(Wire& output) { 130 | bootsCONSTANT(&this->ciphertexts[0], 1, this->cloud_key); 131 | this->unload_ciphertext(output); 132 | } 133 | 134 | private: 135 | void clear_params() { 136 | if (this->params != nullptr) { 137 | delete_gate_bootstrapping_parameters(this->params); 138 | this->params = nullptr; 139 | } 140 | } 141 | 142 | void clear_cloud_key() { 143 | if (this->cloud_key != nullptr) { 144 | delete_gate_bootstrapping_cloud_keyset(this->cloud_key); 145 | this->cloud_key = nullptr; 146 | } 147 | } 148 | 149 | void clear_ciphertexts() { 150 | if (this->ciphertexts != nullptr) { 151 | delete_gate_bootstrapping_ciphertext_array(tfhe_num_temp_ciphertexts, this->ciphertexts); 152 | this->ciphertexts = nullptr; 153 | } 154 | } 155 | 156 | void write_ciphertext(Wire& into, LweSample* from) { 157 | TFHECiphertextWriteBuffer buffer(&into); 158 | std::ostream stream(&buffer); 159 | export_gate_bootstrapping_ciphertext_toStream(stream, from, this->params); 160 | } 161 | 162 | void read_ciphertext(LweSample* into, const Wire& from) { 163 | TFHECiphertextReadBuffer buffer(&from); 164 | std::istream stream(&buffer); 165 | import_gate_bootstrapping_ciphertext_fromStream(stream, into, this->params); 166 | } 167 | 168 | void load_ciphertexts(const Wire& input1, const Wire& input2) { 169 | this->read_ciphertext(&this->ciphertexts[1], input1); 170 | this->read_ciphertext(&this->ciphertexts[2], input2); 171 | } 172 | 173 | void load_ciphertexts(const Wire& 
input1) { 174 | this->read_ciphertext(&this->ciphertexts[1], input1); 175 | } 176 | 177 | void unload_ciphertext(Wire& output) { 178 | this->write_ciphertext(output, &this->ciphertexts[0]); 179 | } 180 | 181 | TFheGateBootstrappingParameterSet* params; 182 | TFheGateBootstrappingCloudKeySet* cloud_key; 183 | LweSample* ciphertexts; 184 | }; 185 | } 186 | 187 | #endif 188 | -------------------------------------------------------------------------------- /src/platform/filesystem.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 
20 | */ 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include "memory.hpp" 32 | 33 | namespace mage::platform { 34 | int create_file(const char* filename, std::uint64_t length, bool direct, bool unsparsify) { 35 | int flags = O_CREAT | O_RDWR | O_TRUNC; 36 | if (direct) { 37 | flags |= O_DIRECT; 38 | } 39 | int fd = open(filename, flags, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); 40 | if (fd == -1) { 41 | std::perror("create_file -> open"); 42 | std::abort(); 43 | } 44 | if (ftruncate(fd, (off_t) length) != 0) { 45 | std::perror("create_file -> ftruncate"); 46 | std::abort(); 47 | } 48 | if (unsparsify) { 49 | static constexpr const std::uint64_t buf_size = 4096; 50 | std::uint8_t* buf = allocate_resident_memory(buf_size); 51 | std::fill(buf, buf + buf_size, 0x00); 52 | std::uint64_t left = length; 53 | while (left != 0) { 54 | ssize_t rv = write(fd, buf, std::min(left, buf_size)); 55 | if (rv < 0) { 56 | std::perror("create_file -> write"); 57 | std::abort(); 58 | } 59 | left -= rv; 60 | } 61 | deallocate_resident_memory(buf, buf_size); 62 | } 63 | return fd; 64 | } 65 | 66 | int open_file(const char* filename, std::uint64_t* length, bool direct) { 67 | int flags = O_RDWR; 68 | if (direct) { 69 | flags |= O_DIRECT; 70 | } 71 | int fd = open(filename, flags); 72 | if (fd == -1) { 73 | std::perror("open_file -> open"); 74 | std::abort(); 75 | } 76 | if (length != nullptr) { 77 | off_t end = lseek(fd, 0, SEEK_END); 78 | if (end == (off_t) -1) { 79 | std::perror("open_file -> lseek"); 80 | std::abort(); 81 | } 82 | *length = (std::uint64_t) end; 83 | } 84 | return fd; 85 | } 86 | 87 | std::uint64_t length_file(int fd) { 88 | off_t pos = lseek(fd, 0, SEEK_CUR); 89 | if (pos == (off_t) -1) { 90 | std::perror("length_file -> lseek"); 91 | std::abort(); 92 | } 93 | off_t end = lseek(fd, 0, SEEK_END); 94 | if (end == (off_t) -1) { 95 | std::perror("length_file -> lseek"); 96 | 
std::abort(); 97 | } 98 | off_t rv = lseek(fd, pos, SEEK_SET); 99 | if (rv == (off_t) -1) { 100 | std::perror("length_file -> lseek"); 101 | std::abort(); 102 | } 103 | return end; 104 | } 105 | 106 | void write_to_file(int fd, const void* buffer, std::size_t length) { 107 | const std::uint8_t* data = reinterpret_cast(buffer); 108 | std::size_t processed = 0; 109 | while (processed != length) { 110 | ssize_t rv = write(fd, &data[processed], length - processed); 111 | if (rv <= 0) { 112 | if (rv < 0) { 113 | std::perror("write_to_file -> write"); 114 | } 115 | std::abort(); 116 | } 117 | processed += rv; 118 | } 119 | } 120 | 121 | void write_to_file_at(int fd, const void* buffer, std::size_t length, std::uint64_t offset) { 122 | const std::uint8_t* data = reinterpret_cast(buffer); 123 | std::size_t processed = 0; 124 | while (processed != length) { 125 | ssize_t rv = pwrite(fd, &data[processed], length - processed, (off_t) (offset + processed)); 126 | if (rv <= 0) { 127 | if (rv < 0) { 128 | std::perror("write_to_file_at -> pwrite"); 129 | } 130 | std::abort(); 131 | } 132 | processed += rv; 133 | } 134 | } 135 | 136 | std::size_t read_from_file(int fd, void* buffer, std::size_t length) { 137 | std::uint8_t* data = reinterpret_cast(buffer); 138 | std::size_t processed = 0; 139 | while (processed != length) { 140 | ssize_t rv = read(fd, &data[processed], length - processed); 141 | if (rv <= 0) { 142 | if (rv < 0) { 143 | std::perror("read_from_file -> read"); 144 | std::abort(); 145 | } 146 | break; 147 | } 148 | processed += rv; 149 | } 150 | return processed; 151 | } 152 | 153 | std::size_t read_from_file_at(int fd, void* buffer, std::size_t length, std::uint64_t offset) { 154 | std::uint8_t* data = reinterpret_cast(buffer); 155 | std::size_t processed = 0; 156 | while (processed != length) { 157 | ssize_t rv = pread(fd, &data[processed], length - processed, (off_t) (offset + processed)); 158 | if (rv <= 0) { 159 | if (rv < 0) { 160 | std::perror("read_from_file 
-> pread"); 161 | std::abort(); 162 | } 163 | break; 164 | } 165 | processed += rv; 166 | } 167 | return processed; 168 | } 169 | 170 | std::size_t read_available_from_file(int fd, void* buffer, std::size_t length) { 171 | std::uint8_t* data = reinterpret_cast(buffer); 172 | ssize_t rv = read(fd, data, length); 173 | if (rv < 0) { 174 | std::perror("read_from_file -> read"); 175 | std::abort(); 176 | } 177 | return rv; 178 | } 179 | 180 | void seek_file(int fd, std::int64_t amount, bool relative) { 181 | if (lseek(fd, (off_t) amount, relative ? SEEK_CUR : SEEK_SET) == -1) { 182 | std::perror("seek_file -> lseek"); 183 | std::abort(); 184 | } 185 | } 186 | 187 | void prefetch_from_file_at(int fd, std::uint64_t offset, std::size_t length) { 188 | if (readahead(fd, (off64_t) offset, length) == -1) { 189 | std::perror("prefetch_from_file -> readahead"); 190 | std::abort(); 191 | } 192 | } 193 | 194 | std::uint64_t tell_file(int fd) { 195 | off_t rv = lseek(fd, 0, SEEK_CUR); 196 | if (rv == -1) { 197 | std::perror("tell_file -> lseek"); 198 | std::abort(); 199 | } 200 | return rv; 201 | } 202 | 203 | void close_file(int fd) { 204 | if (close(fd) == -1) { 205 | std::perror("close_file -> close"); 206 | std::abort(); 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /src/programs/real_matrix_multiply.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2020 Sam Kumar 3 | * Copyright (C) 2020 University of California, Berkeley 4 | * All rights reserved. 5 | * 6 | * This file is part of MAGE. 7 | * 8 | * MAGE is free software: you can redistribute it and/or modify 9 | * it under the terms of the GNU General Public License as published by 10 | * the Free Software Foundation, either version 3 of the License, or 11 | * (at your option) any later version. 
12 | * 13 | * MAGE is distributed in the hope that it will be useful, 14 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | * GNU General Public License for more details. 17 | * 18 | * You should have received a copy of the GNU General Public License 19 | * along with MAGE. If not, see . 20 | */ 21 | 22 | #include 23 | #include 24 | #include "dsl/array.hpp" 25 | #include "dsl/integer.hpp" 26 | #include "dsl/parallel.hpp" 27 | #include "dsl/sort.hpp" 28 | #include "programs/registry.hpp" 29 | #include "programs/util.hpp" 30 | 31 | using namespace mage::dsl; 32 | 33 | namespace mage::programs::real_matrix_multiply { 34 | /* Computes all num_rows_a * num_cols_b entries of the product. Entry (row_a, col_b) is stored at result[row_a * num_cols_b + col_b] and is real_dot_product<0> over num_cols_a_rows_b elements starting at matrix_a[row_a * num_cols_a_rows_b] and matrix_b[col_b * num_cols_a_rows_b], so each operand is read in contiguous runs of num_cols_a_rows_b elements. NOTE(review): the template parameter list and the LeveledBatch arguments are not visible in this listing; confirm them against the repository. */ template 35 | std::vector> local_naive_matrix_multiply(LeveledBatch* matrix_a, std::size_t num_rows_a, LeveledBatch* matrix_b, std::size_t num_cols_b, std::size_t num_cols_a_rows_b) { 36 | std::vector> result(num_rows_a * num_cols_b); 37 | for (std::size_t row_a = 0; row_a != num_rows_a; row_a++) { 38 | for (std::size_t col_b = 0; col_b != num_cols_b; col_b++) { 39 | /* This goes in result at row row_a and column col_b. 
*/ 40 | std::size_t i = row_a * num_cols_b + col_b; 41 | result[i] = real_dot_product<0>(&matrix_a[row_a * num_cols_a_rows_b], &matrix_b[col_b * num_cols_a_rows_b], num_cols_a_rows_b); 42 | } 43 | } 44 | return result; 45 | } 46 | 47 | /* Tiled variant. Rows of matrix_a and runs of matrix_b are walked in steps of batch_dimension; for each output tile the shared dimension is also split into runs of at most batch_dimension elements, and the partial results of real_dot_product_not_normalized over those runs are accumulated in result_batch before being stored in result. The operands are indexed the same way as in local_naive_matrix_multiply. NOTE(review): the template parameter list and the LeveledBatch arguments are not visible in this listing. */ template 48 | std::vector> local_tiled_matrix_multiply(std::size_t batch_dimension, LeveledBatch* matrix_a, std::size_t num_rows_a, LeveledBatch* matrix_b, std::size_t num_cols_b, std::size_t num_cols_a_rows_b) { 49 | std::vector> result(num_rows_a * num_cols_b); 50 | // for (std::size_t i = 0; i != num_rows_a * num_cols_b; i++) { 51 | // result[i] = LeveledBatch(0); 52 | // } 53 | /* NOTE(review): the three *_batches counts below are computed but never read in this function. */ std::size_t num_rows_a_batches = util::ceil_div(num_rows_a, batch_dimension).first; 54 | std::size_t num_cols_b_batches = util::ceil_div(num_cols_b, batch_dimension).first; 55 | std::size_t num_cols_a_rows_b_batches = util::ceil_div(num_cols_a_rows_b, batch_dimension).first; 56 | 57 | for (std::size_t batch_row_a = 0; batch_row_a < num_rows_a; batch_row_a += batch_dimension) { 58 | for (std::size_t batch_col_b = 0; batch_col_b < num_cols_b; batch_col_b += batch_dimension) { 59 | /* One accumulator slot per entry of the current output tile, indexed by i_batch below. */ std::vector> result_batch(batch_dimension * batch_dimension); 60 | for (std::size_t batch_cols_a_rows_b = 0; batch_cols_a_rows_b < num_cols_a_rows_b; batch_cols_a_rows_b += batch_dimension) { 61 | /* Multiply the submatrices. */ 62 | for (std::size_t row_a = batch_row_a; row_a < num_rows_a && row_a < batch_row_a + batch_dimension; row_a++) { 63 | for (std::size_t col_b = batch_col_b; col_b < num_cols_b && col_b < batch_col_b + batch_dimension; col_b++) { 64 | /* This goes in result at row row_a and column col_b. 
*/ 65 | std::size_t i_batch = (row_a - batch_row_a) * batch_dimension + (col_b - batch_col_b); 66 | /* The last run along the shared dimension may be shorter than batch_dimension. */ std::size_t dot_product_size = std::min(batch_dimension, num_cols_a_rows_b - batch_cols_a_rows_b); 67 | LeveledBatch dot_product_result = real_dot_product_not_normalized(&matrix_a[row_a * num_cols_a_rows_b + batch_cols_a_rows_b], &matrix_b[col_b * num_cols_a_rows_b + batch_cols_a_rows_b], dot_product_size); 68 | /* If the accumulator slot reports valid(), add the partial product to it; otherwise move the partial product in. */ if (result_batch[i_batch].valid()) { 69 | result_batch[i_batch] = result_batch[i_batch] + dot_product_result; 70 | } else { 71 | result_batch[i_batch] = std::move(dot_product_result); 72 | } 73 | /* True only for the final run along the shared dimension: store the renormalize()d sum in result and recycle() the accumulator slot. */ if (batch_cols_a_rows_b + batch_dimension >= num_cols_a_rows_b) { 74 | std::size_t i = row_a * num_cols_b + col_b; 75 | result[i] = result_batch[i_batch].renormalize(); 76 | result_batch[i_batch].recycle(); 77 | } 78 | } 79 | } 80 | } 81 | } 82 | } 83 | return result; 84 | } 85 | 86 | /* Builds the program: marks a problem_size x problem_size matrix from each party as input, obtains the operands for this worker from ClusterUtils cross_product, multiplies them with the tiled or the naive routine (selected at compile time by tiled), and marks every element of the result as output. NOTE(review): the template parameter list is not visible in this listing; the instantiations registered at the end of this namespace pass two or three arguments. */ template 87 | void create_real_matrix_multiply_circuit(const ProgramOptions& args) { 88 | int matrix_dimension = args.problem_size; 89 | /* NOTE(review): this product is evaluated in int, so it overflows once problem_size squared no longer fits in int; confirm the supported range. */ int matrix_size = matrix_dimension * matrix_dimension; 90 | 91 | /* Blocked row-major matrix provided by the garbler. */ 92 | ShardedArray> matrix_a(matrix_size, args.worker_index, args.num_workers, Layout::Blocked); 93 | matrix_a.for_each([=](std::size_t i, auto& elem) { 94 | elem.mark_input(); 95 | }); 96 | 97 | /* Blocked column-major matrix provided by the evaluator. 
*/ 98 | ShardedArray> matrix_b(matrix_size, args.worker_index, args.num_workers, Layout::Blocked); 99 | matrix_b.for_each([=](std::size_t i, auto& elem) { 100 | elem.mark_input(); 101 | }); 102 | 103 | /* Only the work between start_timer() and stop_timer() is timed; marking inputs and outputs happens outside that window. */ program_ptr->print_stats(); 104 | program_ptr->start_timer(); 105 | 106 | ClusterUtils utils; 107 | utils.self_id = args.worker_index; 108 | utils.num_proc = args.num_workers; 109 | auto [ my_matrix_a, my_matrix_b ] = utils.cross_product(matrix_a, matrix_b); 110 | 111 | std::vector> result; 112 | if constexpr (tiled) { 113 | std::size_t tile_dimension; 114 | /* tile_size == 0 means the tile dimension is derived from the worker memory size (num_pages << page_shift) instead of being fixed. NOTE(review): the shift is evaluated in the return type of as_int(), which is not visible here; confirm it is wide enough. */ if constexpr (tile_size == 0) { 115 | std::int64_t memory_size = (*args.worker_config)["num_pages"].as_int() << (*args.worker_config)["page_shift"].as_int(); 116 | tile_dimension = static_cast((std::sqrt(memory_size) / 2048.0) + 1.0); 117 | } else { 118 | tile_dimension = tile_size; 119 | } 120 | /* The row and column counts passed below are the local element counts divided by matrix_dimension. */ result = local_tiled_matrix_multiply(tile_dimension, my_matrix_a.data(), my_matrix_a.size() / matrix_dimension, my_matrix_b.data(), my_matrix_b.size() / matrix_dimension, matrix_dimension); 121 | } else { 122 | result = local_naive_matrix_multiply(my_matrix_a.data(), my_matrix_a.size() / matrix_dimension, my_matrix_b.data(), my_matrix_b.size() / matrix_dimension, matrix_dimension); 123 | } 124 | 125 | program_ptr->stop_timer(); 126 | program_ptr->print_stats(); 127 | 128 | for (std::size_t i = 0; i != result.size(); i++) { 129 | result[i].mark_output(); 130 | } 131 | } 132 | 133 | /* Registered programs: one naive and three tiled instantiations; the tiled ones differ only in the optional third template argument (absent, 16, 64). */ RegisterProgram real_naive_matrix_multiply("real_naive_matrix_multiply", "Naive matrix multiply with real numbers (problem_size = number of elements in one side of matrix)", create_real_matrix_multiply_circuit<0, false>); 134 | RegisterProgram real_tiled_matrix_multiply("real_tiled_matrix_multiply", "Tiled matrix multiply with real numbers (problem_size = number of elements in one side of matrix)", create_real_matrix_multiply_circuit<0, true>); 135 | RegisterProgram real_tiled_16_matrix_multiply("real_tiled_16_matrix_multiply", "Tiled matrix 
multiply with real numbers (problem_size = number of elements in one side of matrix)", create_real_matrix_multiply_circuit<0, true, 16>); 136 | RegisterProgram real_tiled_64_matrix_multiply("real_tiled_64_matrix_multiply", "Tiled matrix multiply with real numbers (problem_size = number of elements in one side of matrix)", create_real_matrix_multiply_circuit<0, true, 64>); 137 | } 138 | --------------------------------------------------------------------------------