├── .gitignore ├── .gitmodules ├── LICENSE.md ├── Makefile ├── Makefile.src ├── README.md ├── scripts ├── get_prime.py ├── nbtheory.py ├── nn │ ├── LICENSE │ ├── README.md │ ├── cifar │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── model.py │ │ └── train.py │ ├── compress_model.py │ ├── mnist │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── model.py │ │ └── train.py │ ├── pretrained_models │ │ ├── cifar10_minionn-8bit.pth │ │ ├── cifar10_minionn-9bit.pth │ │ ├── cifar10_minionn.pth │ │ ├── mnist-8bit.pth │ │ ├── mnist.pth │ │ ├── mnist_cryptonets-10bit.pth │ │ ├── mnist_cryptonets-8bit.pth │ │ ├── mnist_cryptonets-9bit.pth │ │ ├── mnist_cryptonets.pth │ │ ├── mnist_deepsecure-8bit.pth │ │ ├── mnist_deepsecure-9bit.pth │ │ ├── mnist_deepsecure.pth │ │ ├── mnist_minionn-8bit.pth │ │ ├── mnist_minionn-9bit.pth │ │ ├── mnist_minionn.pth │ │ ├── mnist_secure_ml-10bit.pth │ │ ├── mnist_secure_ml-8bit.pth │ │ ├── mnist_secure_ml-9bit.pth │ │ └── mnist_secure_ml.pth │ └── utee │ │ ├── __init__.py │ │ ├── act.py │ │ ├── compress.py │ │ ├── misc.py │ │ ├── select.py │ │ └── selector.py └── throttle.sh ├── src ├── demo │ ├── ahe │ │ ├── conv1d-benchmark.cpp │ │ ├── conv2d-benchmark.cpp │ │ ├── conv2d-online.cpp │ │ ├── encoding-benchmark.cpp │ │ ├── fv-automorph-benchmark.cpp │ │ ├── fv-benchmark.cpp │ │ ├── fv-she-benchmark.cpp │ │ ├── gemm-benchmark.cpp │ │ ├── gemm-online.cpp │ │ ├── mat-mul-benchmark.cpp │ │ ├── mat-mul-online.cpp │ │ ├── square-benchmark.cpp │ │ ├── square-online.cpp │ │ └── transfrom-benchmark.cpp │ ├── network-benchmark.cpp │ └── tpc │ │ ├── act-gc-benchmark.cpp │ │ ├── aes-gc-benchmark.cpp │ │ ├── aesCircuit │ │ ├── gc-online.cpp │ │ └── ot-benchmark.cpp ├── lib │ ├── gc │ │ ├── aes.h │ │ ├── aescircuits.cpp │ │ ├── aescircuits.h │ │ ├── circuits.cpp │ │ ├── circuits.h │ │ ├── common.cpp │ │ ├── common.h │ │ ├── gates.cpp │ │ ├── gates.h │ │ ├── gazelle_circuits.cpp │ │ ├── gazelle_circuits.h │ │ ├── gc.cpp │ │ ├── gc.h │ │ ├── scd.nocpp │ │ ├── util.cpp 
│ │ └── util.h │ ├── math │ │ ├── automorph.cpp │ │ ├── automorph.h │ │ ├── bit_twiddle.cpp │ │ ├── bit_twiddle.h │ │ ├── discretegaussiangenerator.cpp │ │ ├── discretegaussiangenerator.h │ │ ├── distrgen.h │ │ ├── distributiongenerator.cpp │ │ ├── distributiongenerator.h │ │ ├── nbtheory.cpp │ │ ├── nbtheory.h │ │ ├── params.cpp │ │ ├── params.h │ │ ├── transfrm.cpp │ │ └── transfrm.h │ ├── ot │ │ ├── cot_recv.cpp │ │ ├── cot_recv.h │ │ ├── cot_send.cpp │ │ ├── cot_send.h │ │ ├── ot_ifc.h │ │ ├── sr_base_ot.cpp │ │ ├── sr_base_ot.h │ │ ├── tools.cpp │ │ └── tools.h │ ├── pke │ │ ├── conv1d.cpp │ │ ├── conv1d.h │ │ ├── conv2d.cpp │ │ ├── conv2d.h │ │ ├── encoding.cpp │ │ ├── encoding.h │ │ ├── fv.cpp │ │ ├── fv.h │ │ ├── gazelle.h │ │ ├── gemm.cpp │ │ ├── gemm.h │ │ ├── layers.cpp │ │ ├── layers.h │ │ ├── mat_mul.cpp │ │ ├── mat_mul.h │ │ ├── pke_types.h │ │ ├── square.cpp │ │ └── square.h │ └── utils │ │ ├── backend.h │ │ ├── debug.cpp │ │ ├── debug.h │ │ ├── network.cpp │ │ ├── network.h │ │ ├── test.cpp │ │ └── test.h └── unittest │ ├── Main_TestAll.cpp │ ├── UnitTestDistrGen.cpp │ ├── UnitTestFVAutomorph.cpp │ ├── UnitTestFVBase.cpp │ ├── UnitTestFVSHE.cpp │ ├── UnitTestNTT.cpp │ └── UnitTestTransform.cpp └── test └── include └── gtest ├── gtest-all.cc └── gtest.h /.gitignore: -------------------------------------------------------------------------------- 1 | bin/* 2 | src/bin/* 3 | 4 | __pycache__ 5 | log 6 | scripts/nn/dataset/ 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/cryptoTools"] 2 | path = third_party/cryptoTools 3 | url = https://github.com/ladnir/cryptoTools.git 4 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chiraag 
Juvekar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # This file is based on the PALISADE Makefile 2 | # Multi OS makefile (No Windows yet) 3 | 4 | UNAME_S := $(shell uname -s) 5 | HOST_SYSTEM = $(shell uname | cut -f 1 -d_) 6 | SYSTEM ?= $(HOST_SYSTEM) 7 | 8 | CPPSTD := -std=c++14 -fPIC 9 | 10 | ifeq ($(SYSTEM),Darwin) 11 | CC := /usr/local/opt/llvm/bin/clang++ $(CPPSTD) 12 | LDFLAGS += -L/usr/local/opt/llvm/lib -Wl,-rpath,/usr/local/opt/llvm/lib 13 | CPPFLAGS += -I/usr/local/opt/llvm/include -I/usr/local/opt/llvm/include/c++/v1/ 14 | LIBSUFFIX := .dylib 15 | LIBCMD := -dynamiclib -undefined suppress -flat_namespace 16 | else 17 | CC := g++ $(CPPSTD) 18 | # Suppress warnings from alignment of STL structs of blocks 19 | CPPFLAGS += -Wno-ignored-attributes 20 | LIBSUFFIX := .so 21 | LIBCMD := -fPIC -shared -Wl,--export-dynamic,-z,defs 22 | endif 23 | 24 | RDYNAMIC := -rdynamic 25 | 26 | COMPTHREADFLAG := -pthread 27 | LOADTHREADFLAG := -pthread 28 | 29 | # LDFLAGS += -lboost_system -lboost_thread -fsanitize=address 30 | LDFLAGS += -lboost_system -lboost_thread 31 | 32 | #main best performance configuration for parallel operation - cross-platform 33 | # CPPFLAGS += -maes -msse4 -g -Wall -Werror -fsanitize=address $(COMPTHREADFLAG) ##undefine for parallel best performance operation with debug 34 | # CPPFLAGS += -maes -msse4 -g -O3 -Wall -fno-omit-frame-pointer $(COMPTHREADFLAG) ##undefine for parallel best performance operation with debug 35 | CPPFLAGS += -maes -msse4 -O3 -Wall -fno-omit-frame-pointer $(COMPTHREADFLAG) ##undefine for parallel best performance operation 36 | 37 | TEST_LIB := $(LOADTHREADFLAG) 38 | 39 | #build and bin directory 40 | BUILDDIR := build 41 | BINDIR := bin 42 | 43 | #cryptoTools locations 44 | MIRACL_LIBDIR := third_party/cryptoTools/thirdparty/linux/miracl/miracl/source 45 | MIRACL_INCDIR := 
third_party/cryptoTools/thirdparty/linux/miracl 46 | LIBCMD += -L$(MIRACL_LIBDIR) 47 | 48 | # BOOST_LIBDIR := third_party/cryptoTools/thirdparty/linux/boost/stage/lib/ 49 | # BOOST_INCDIR := third_party/cryptoTools/thirdparty/linux/boost 50 | # LIBCMD += -L$(BOOST_LIBDIR) 51 | 52 | CT_LIBDIR := third_party/cryptoTools/lib 53 | CT_INCDIR := third_party/cryptoTools 54 | LIBCMD += -L$(CT_LIBDIR) 55 | 56 | #LDFLAGS += -lmiracl -lcryptoTools 57 | 58 | #build output folders 59 | EXTLIBDIR := bin/lib 60 | EXTTESTDIR := bin/unittest 61 | EXTDEMODIR := bin/demo 62 | 63 | # extensions for source and header files 64 | SRCEXT := cpp 65 | HDREXT := h 66 | 67 | $(objects) : %.o : %.cpp 68 | 69 | # External libraries 70 | #EXTLIB := -L$(EXTLIBDIR) $(TEST_LIB) -pg ## include profiling 71 | EXTLIB := -L$(EXTLIBDIR) $(TEST_LIB) ## no-profiling 72 | 73 | INC := -I src/lib -I test -I $(CT_INCDIR) -I $(MIRACL_INCDIR) 74 | # INC += -I $(BOOST_INCDIR) 75 | 76 | #the name of the shared object library 77 | CORELIB := libgazelle$(LIBSUFFIX) 78 | 79 | # run make for all components. you can run any individual component separately 80 | # by invoking "make alltargets" for example 81 | # each corresponding makefile will make the allxxxx target 82 | all: allcore 83 | 84 | alldemos: allcoredemos 85 | 86 | testall: testcore 87 | 88 | # clean up all components.
you can clean any individual component separately 89 | # by invoking "make cleantargets" for example 90 | # each corresponding makefile will make the cleanxxxx target 91 | .PHONY: clean all alldemos testall 92 | clean: cleancore 93 | @echo 'Cleaning top level autogenerated directories' 94 | $(RM) -f test/include/gtest/gtest-all.o 95 | $(RM) -rf bin 96 | 97 | include Makefile.src 98 | 99 | test/include/gtest/gtest-all.o: test/include/gtest/gtest-all.cc 100 | $(CC) -c $(CPPFLAGS) -o $@ $< 101 | -------------------------------------------------------------------------------- /Makefile.src: -------------------------------------------------------------------------------- 1 | # This file is based on the PALISADE Makefile 2 | 3 | CORESRCDIR := src 4 | COREBINDIR := src/bin 5 | CORETESTDIR := src/unittest 6 | COREDEMODIR := src/demo 7 | 8 | CORESOURCES := $(shell find $(CORESRCDIR)/lib -name '*.cpp' ! -name '*.pb.cpp') 9 | COREUNITSOURCES := $(wildcard $(CORESRCDIR)/unittest/*.cpp) 10 | COREDEMOSOURCES := $(shell find $(CORESRCDIR)/demo -name '*.cpp' !
-name '*.pb.cpp') 11 | 12 | COREOBJECTS := $(patsubst $(CORESRCDIR)/%,$(COREBINDIR)/%,$(patsubst %.cpp,%.o,$(CORESOURCES))) 13 | COREUNITOBJECTS := $(patsubst $(CORESRCDIR)/%,$(COREBINDIR)/%,$(patsubst %.cpp,%.o,$(COREUNITSOURCES))) 14 | COREUNITOBJECTS += test/include/gtest/gtest-all.o 15 | COREDEMOOBJECTS := $(patsubst $(CORESRCDIR)/%,$(COREBINDIR)/%,$(patsubst %.cpp,%.o,$(COREDEMOSOURCES))) 16 | 17 | TEST_TARGET := $(EXTTESTDIR)/tests$(EXESUFFIX) 18 | 19 | -include $(COREOBJECTS:.o=.d) 20 | -include $(COREUNITOBJECTS:.o=.d) 21 | -include $(COREDEMOOBJECTS:.o=.d) 22 | 23 | .PHONY: allcore allcoredemos 24 | allcore: $(EXTLIBDIR)/$(CORELIB) allcoredemos 25 | 26 | allcoredemos: $(EXTLIBDIR)/$(CORELIB) $(patsubst $(COREBINDIR)/demo/%,bin/demo/%,$(patsubst %.o,%$(EXESUFFIX),$(COREDEMOOBJECTS))) 27 | 28 | bin/demo/%$(EXESUFFIX): src/bin/demo/%.o $(EXTLIBDIR)/$(CORELIB) 29 | @mkdir -p $(@D) 30 | $(CC) -o $@ $^ $(EXTLIB) $(LDFLAGS) -lgazelle 31 | # $(CC) -o $@ $^ $(EXTLIB) $(LDFLAGS) -lgazelle -L$(BOOST_LIBDIR) 32 | 33 | #this builds the shared library out of the objects 34 | $(EXTLIBDIR)/$(CORELIB): $(COREOBJECTS) 35 | @echo " -- core:linking $@ from COREOBJECTS" 36 | mkdir -p $(EXTLIBDIR) 37 | $(CC) $(LIBCMD) -o $@ $(COREOBJECTS) $(TEST_LIB) -lcryptoTools -lmiracl -lboost_system -lboost_thread 38 | # $(CC) $(LIBCMD) -o $@ $(COREOBJECTS) $(TEST_LIB) -fsanitize=address -lcryptoTools -lmiracl -lboost_system -lboost_thread 39 | 40 | ### #this builds the individual objects that make up the library .
41 | .PRECIOUS: $(COREBINDIR)/% 42 | $(COREBINDIR)/%: 43 | @if [ "$(suffix $@)" = ".o" ] ; \ 44 | then \ 45 | mkdir -p $(@D) ;\ 46 | echo $(CC) $(CPPFLAGS) $(INC) -c -o $@ $(patsubst $(COREBINDIR)/%,$(CORESRCDIR)/%,$(patsubst %.o,%.cpp,$@)) ;\ 47 | $(CC) -MM $(CPPFLAGS) $(INC) $(patsubst $(COREBINDIR)/%,$(CORESRCDIR)/%,$(patsubst %.o,%.cpp,$@)) > $(patsubst %.o,%.d,$@) ;\ 48 | mv -f $(patsubst %.o,%.d,$@) $(patsubst %.o,%.d.tmp,$@) ;\ 49 | sed -e 's|.*\.o:|$(COREBINDIR)/$*:|' < $(patsubst %.o,%.d.tmp,$@) > $(patsubst %.o,%.d,$@) ;\ 50 | rm -f $(patsubst %.o,%.d.tmp,$@) ; \ 51 | $(CC) $(CPPFLAGS) $(INC) -c -o $@ $(patsubst $(COREBINDIR)/%,$(CORESRCDIR)/%,$(patsubst %.o,%.cpp,$@)) ;\ 52 | fi 53 | 54 | #this target is used to cleanup, it is called from the top Makefile 55 | .PHONY: cleancore 56 | 57 | cleancore: 58 | $(RM) -fr $(COREBINDIR) $(EXTLIBDIR)/$(CORELIB) `dirname $(TEST_TARGET)` bin/demo 59 | 60 | # this links test executable from objects in the test build directory 61 | $(TEST_TARGET): $(COREUNITOBJECTS) $(EXTLIBDIR)/$(CORELIB) 62 | @mkdir -p `dirname $(TEST_TARGET)` 63 | $(CC) -o $(TEST_TARGET) $^ $(EXTLIB) $(LDFLAGS) $(TEST_LIB) 64 | 65 | #used to run tests from make 66 | .PHONY: testcore 67 | testcore: $(TEST_TARGET) 68 | $(TEST_TARGET) 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gazelle 2 | Gazelle MPC Framework 3 | 4 | ## Install 5 | 6 | This was last tested on Ubuntu 16.04 LTS and 17.10 7 | ```bash 8 | # Install dependencies 9 | sudo apt-get install g++ nasm cmake libboost-all-dev 10 | 11 | # Clone this repo 12 | git clone https://github.com/chiraag/gazelle_mpc 13 | cd gazelle_mpc 14 | git submodule update --init --recursive 15 | 16 | # Compile miracl for OSU cryptotools 17 | cd third_party/cryptoTools/thirdparty/linux 18 | bash miracl.get 19 | cd miracl/miracl/source 20 | sed -i -e 's/g++ -c/g++ -c -fPIC/g'
linux64 21 | bash linux64 22 | 23 | # Compile cryptotools 24 | cd ../../../../../ 25 | cmake . 26 | make -j8 27 | 28 | # Compile gazelle 29 | cd ../../ 30 | make -j8 31 | ``` 32 | 33 | If you want to run to run the network conversion scripts you will 34 | need a python interpreter and pytorch. These scripts were tested with 35 | Anaconda3 on a machine that had a GPU. 36 | 37 | ## Running examples 38 | 39 | Have a look at the demo folder to see some examples. 40 | -------------------------------------------------------------------------------- /scripts/get_prime.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import primefac 4 | import numpy as np 5 | import nbtheory 6 | 7 | def get_p(pbits=20, m=4096, num_primes=10): 8 | r = pow(2, pbits, m) 9 | primes = [] 10 | for k in range(2**pbits//m): 11 | p = 2**pbits - r + m*k + 1 12 | if primefac._primefac.isprime(p): 13 | primes.append(p) 14 | if len(primes) == num_primes: 15 | break 16 | return primes 17 | 18 | def get_q(qbits=60, pbits=20, n=2048, r=1, num_primes=10): 19 | m = 2*n 20 | qbase = 2**qbits 21 | min_delta = qbase 22 | primes_p = get_p(pbits, m, num_primes) 23 | q, p = (None, None) 24 | # print primes_p 25 | for p_curr in primes_p: 26 | m_inv = primefac._primefac.modinv(m, p_curr) 27 | p_inv = primefac._primefac.modinv(p_curr, m) 28 | delta_p = (2**qbits - r) % p_curr 29 | delta = (-1*p_inv*p_curr + delta_p*m_inv*m) % (m*p_curr) 30 | q_curr = 2**qbits - delta 31 | if primefac._primefac.isprime(q_curr): 32 | if delta < min_delta: 33 | min_delta = delta 34 | q, p = q_curr, p_curr 35 | if q is not None: 36 | assert(p % m == 1), p % m 37 | assert((q - r) % p == 0), q % p 38 | assert(q % m == 1), q % m 39 | assert(primefac._primefac.isprime(q)) 40 | assert(primefac._primefac.isprime(p)) 41 | # print "delta", np.log2(delta) 42 | if np.log2(min_delta) < (qbits-6)/2: 43 | return q, p 44 | else: 45 | return None, None 46 | 47 | n = 
2048 48 | prime_table = {} 49 | for pbits in range(18, 21): 50 | print("Searching for pbits = %d" % pbits) 51 | for qbits in [60, 61, 59, 62, 58, 63]: 52 | for r in [1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7]: 53 | q, p = get_q(qbits, pbits, n, r, 16384) 54 | if q is not None: 55 | prime_table[pbits] = (q, p) 56 | print("Found") 57 | break 58 | if pbits in prime_table: 59 | break 60 | 61 | for pbits in prime_table: 62 | q, p = prime_table[pbits] 63 | zq = nbtheory.root_of_unity(q, n*2) 64 | zp = nbtheory.root_of_unity(p, n*2) 65 | print(q, p, np.log2(q), np.log2(p), q % p, zq, zp) 66 | -------------------------------------------------------------------------------- /scripts/nbtheory.py: -------------------------------------------------------------------------------- 1 | import primefac 2 | # import yafu 3 | import numpy as np 4 | import random 5 | 6 | factor_silent = True 7 | 8 | def round(x): 9 | return int(np.round(x)) 10 | 11 | def floor(x): 12 | return int(np.floor(x)) 13 | 14 | def ceil(x): 15 | return int(np.ceil(x)) 16 | 17 | def log2(x): 18 | return np.log2(x*1.0) 19 | 20 | def factor(n): 21 | if log2(n) > 128: 22 | return yafu.factor(n, silent=factor_silent) 23 | else: 24 | return primefac.factorint(n) 25 | 26 | def prime_form(bits, i, m): 27 | # Outputs p (may not be prime) s.t. p-1 | m 28 | return 2**bits - pow(2, bits, m) + i*m + 1 29 | 30 | def get_primes(bits, m, max_attempt=10000): 31 | # Outputs primes s.t. 
p-1 | m 32 | primes = {} 33 | for i in range(max_attempt): 34 | p = prime_form(bits, i, m) 35 | if primefac._primefac.isprime(p): 36 | primes[p] = i 37 | return primes 38 | 39 | def primitive_root(q, factors=None): 40 | if factors == None: 41 | factors = factor(q-1) 42 | g = 0 43 | while True: 44 | g = random.randint(0, q-1) # randint limits are inclusive 45 | if not primefac._primefac.gcd(g, q) == 1: 46 | continue 47 | 48 | is_primtive = True 49 | for p in factors: 50 | co_factor = (q-1)//p 51 | if pow(g, co_factor, q) == 1: 52 | is_primtive = False 53 | break 54 | 55 | if is_primtive: 56 | break 57 | return g 58 | 59 | def root_of_unity(q, m): 60 | assert((q-1) % m == 0) 61 | g = primitive_root(q) 62 | z = pow(g, (q-1)//m, q) 63 | assert(pow(z, m, q) == 1) # Necessary, sufficient if g is primitive 64 | return z 65 | 66 | def is_mult_generator(g, m, factors=None): 67 | if factors == None: 68 | factors = factor(m-1) 69 | for p in factors: 70 | co_factor = (m-1)//p 71 | if pow(g, co_factor, m) == 1: 72 | return False 73 | return True 74 | -------------------------------------------------------------------------------- /scripts/nn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chiraag Juvekar, Aaron Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/nn/README.md: -------------------------------------------------------------------------------- 1 | NN-Compress 2 | ----------- 3 | 4 | Simple scripts to compress and quantize NN models. These are based on this [repo](https://github.com/aaron-xichen/pytorch-playground). 5 | 6 | The main additions are: 7 | - Conversion to fixed point integer representation. 8 | - Support for Square Non-Linearities 9 | - Support for existing "Secure" networks implementations 10 | -------------------------------------------------------------------------------- /scripts/nn/cifar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/cifar/__init__.py -------------------------------------------------------------------------------- /scripts/nn/cifar/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from torch.utils.data import DataLoader 4 | import os 5 | 6 | def get10(batch_size, data_root='/home/chiraag/research/datasets', train=True, val=True, **kwargs): 7 | data_root = os.path.expanduser(os.path.join(data_root, 'cifar10-data')) 8 | num_workers = kwargs.setdefault('num_workers', 1) 9 | kwargs.pop('input_size', None) 10 | print("Building CIFAR-10 
data loader with {} workers".format(num_workers)) 11 | ds = [] 12 | if train: 13 | train_loader = torch.utils.data.DataLoader( 14 | datasets.CIFAR10( 15 | root=data_root, train=True, download=True, 16 | transform=transforms.Compose([ 17 | transforms.Pad(4), 18 | transforms.RandomCrop(32), 19 | transforms.RandomHorizontalFlip(), 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 22 | ])), 23 | batch_size=batch_size, shuffle=True, **kwargs) 24 | ds.append(train_loader) 25 | if val: 26 | test_loader = torch.utils.data.DataLoader( 27 | datasets.CIFAR10( 28 | root=data_root, train=False, download=True, 29 | transform=transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 32 | ])), 33 | batch_size=batch_size, shuffle=False, **kwargs) 34 | ds.append(test_loader) 35 | ds = ds[0] if len(ds) == 1 else ds 36 | return ds 37 | 38 | def get100(batch_size, data_root='/home/chiraag/research/datasets', train=True, val=True, **kwargs): 39 | data_root = os.path.expanduser(os.path.join(data_root, 'cifar100-data')) 40 | num_workers = kwargs.setdefault('num_workers', 1) 41 | kwargs.pop('input_size', None) 42 | print("Building CIFAR-100 data loader with {} workers".format(num_workers)) 43 | ds = [] 44 | if train: 45 | train_loader = torch.utils.data.DataLoader( 46 | datasets.CIFAR100( 47 | root=data_root, train=True, download=True, 48 | transform=transforms.Compose([ 49 | transforms.Pad(4), 50 | transforms.RandomCrop(32), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 54 | ])), 55 | batch_size=batch_size, shuffle=True, **kwargs) 56 | ds.append(train_loader) 57 | 58 | if val: 59 | test_loader = torch.utils.data.DataLoader( 60 | datasets.CIFAR100( 61 | root=data_root, train=False, download=True, 62 | transform=transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 
65 | ])), 66 | batch_size=batch_size, shuffle=False, **kwargs) 67 | ds.append(test_loader) 68 | ds = ds[0] if len(ds) == 1 else ds 69 | return ds 70 | 71 | -------------------------------------------------------------------------------- /scripts/nn/cifar/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | from IPython import embed 4 | from collections import OrderedDict 5 | 6 | from utee import misc 7 | print = misc.logger.info 8 | 9 | model_urls = { 10 | 'cifar10': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar10-d875770b.pth', 11 | 'cifar100': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar100-3a55a987.pth', 12 | 'cifar10_minionn': 'cifar10_minionn.pth', 13 | } 14 | 15 | class CIFAR(nn.Module): 16 | def __init__(self, features, n_channel, num_classes): 17 | super(CIFAR, self).__init__() 18 | assert isinstance(features, nn.Sequential), type(features) 19 | self.features = features 20 | self.classifier = nn.Sequential( 21 | nn.Linear(n_channel, num_classes) 22 | ) 23 | print(self.features) 24 | print(self.classifier) 25 | 26 | def forward(self, x): 27 | x = self.features(x) 28 | x = x.view(x.size(0), -1) 29 | x = self.classifier(x) 30 | return x 31 | 32 | class CIFARMiniONN(nn.Module): 33 | def __init__(self): 34 | super(CIFARMiniONN, self).__init__() 35 | 36 | self.features = nn.Sequential( 37 | nn.Conv2d(3, 64, 3, 1, 1), nn.ReLU(), 38 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(), 39 | nn.AvgPool2d(2), 40 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(), 41 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(), 42 | nn.AvgPool2d(2), 43 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(), 44 | nn.Conv2d(64, 64, 1), nn.ReLU(), 45 | nn.Conv2d(64, 16, 1), nn.ReLU(), 46 | ) 47 | self.classifier = nn.Sequential( 48 | nn.Linear(1024, 10) 49 | ) 50 | print(self.features) 51 | print(self.classifier) 52 | 53 | def forward(self, x): 54 | x = self.features(x) 55 | x = 
x.view(x.size(0), -1) 56 | x = self.classifier(x) 57 | return x 58 | 59 | def make_layers(cfg, batch_norm=False): 60 | layers = [] 61 | in_channels = 3 62 | for i, v in enumerate(cfg): 63 | if v == 'M': 64 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 65 | else: 66 | padding = v[1] if isinstance(v, tuple) else 1 67 | out_channels = v[0] if isinstance(v, tuple) else v 68 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding) 69 | if batch_norm: 70 | layers += [conv2d, nn.BatchNorm2d(out_channels, affine=False), nn.ReLU()] 71 | else: 72 | layers += [conv2d, nn.ReLU()] 73 | in_channels = out_channels 74 | return nn.Sequential(*layers) 75 | 76 | def cifar10(n_channel, pretrained=None): 77 | cfg = [n_channel, n_channel, 'M', 2*n_channel, 2*n_channel, 'M', 4*n_channel, 4*n_channel, 'M', (8*n_channel, 0), 'M'] 78 | layers = make_layers(cfg, batch_norm=True) 79 | model = CIFAR(layers, n_channel=8*n_channel, num_classes=10) 80 | if pretrained is not None: 81 | m = model_zoo.load_url(model_urls['cifar10']) 82 | state_dict = m.state_dict() if isinstance(m, nn.Module) else m 83 | assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict) 84 | model.load_state_dict(state_dict) 85 | return model 86 | 87 | def cifar100(n_channel, pretrained=None): 88 | cfg = [n_channel, n_channel, 'M', 2*n_channel, 2*n_channel, 'M', 4*n_channel, 4*n_channel, 'M', (8*n_channel, 0), 'M'] 89 | layers = make_layers(cfg, batch_norm=True) 90 | model = CIFAR(layers, n_channel=8*n_channel, num_classes=100) 91 | if pretrained is not None: 92 | m = model_zoo.load_url(model_urls['cifar100']) 93 | state_dict = m.state_dict() if isinstance(m, nn.Module) else m 94 | assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict) 95 | model.load_state_dict(state_dict) 96 | return model 97 | 98 | def get(model_name, model_dir, pretrained=False): 99 | if model_name == 'cifar10': 100 | model = cifar10(128, pretrained='log/cifar10/best-135.pth') 101 | elif model_name == 
'cifar100': 102 | model = cifar100(128, pretrained='log/cifar100/best-135.pth') 103 | elif model_name == 'cifar10_minionn': 104 | model = CIFARMiniONN() 105 | else: 106 | assert False, model_name 107 | 108 | if pretrained: 109 | m = model_zoo.load_url(model_urls[model_name], model_dir) 110 | state_dict = m.state_dict() if isinstance(m, nn.Module) else m 111 | assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict) 112 | model.load_state_dict(state_dict) 113 | return model 114 | 115 | if __name__ == '__main__': 116 | model = cifar10(128, pretrained='log/cifar10/best-135.pth') 117 | embed() 118 | 119 | -------------------------------------------------------------------------------- /scripts/nn/compress_model.py: -------------------------------------------------------------------------------- 1 | from utee import misc, compress 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | import os 5 | cudnn.benchmark =True 6 | from IPython import embed 7 | 8 | params = { 9 | 'dataset': 'mnist', 10 | 'model': 'mnist_cryptonets', 11 | 'batch_size': 100, 12 | 'seed': 117, 13 | 'model_dir': './pretrained_models', 14 | 'data_dir': 'dataset/', 15 | 'n_sample': 20, 16 | 'weight_bits': 6, 17 | 'act_bits': 9, 18 | 'overflow_rate': 0.0 19 | } 20 | 21 | misc.ensure_dir(params['model_dir']) 22 | params['model_dir'] = misc.expand_user(params['model_dir']) 23 | params['data_dir'] = misc.expand_user(params['data_dir']) 24 | 25 | print("================PARAMS==================") 26 | for k, v in params.items(): 27 | print('{}: {}'.format(k, v)) 28 | print("========================================") 29 | 30 | assert torch.cuda.is_available(), 'no cuda' 31 | torch.manual_seed(params['seed']) 32 | torch.cuda.manual_seed(params['seed']) 33 | 34 | # load model and dataset fetcher 35 | model_raw, ds_fetcher = misc.load_model(params['model'], params['dataset'], 36 | model_root=params['model_dir'], pretrained=True) 37 | model_raw.cuda() 38 | model_raw.eval() 39 | 40 | model_new 
= compress.CompressedModel(model_raw, input_scale=255, 41 | act_bits=params['act_bits'], weight_bits=params['weight_bits']) 42 | model_new = model_new.cuda() 43 | print(model_new) 44 | 45 | val_ds = ds_fetcher(params['batch_size'], data_root=params['data_dir'], train=False) 46 | acc1, acc5 = misc.eval_model(model_new, val_ds, ngpu=1, n_sample=params['n_sample'], is_imagenet=False) 47 | print("FP accuracy Top1: %g Top5: %g" % (acc1, acc5)) 48 | 49 | model_new.quantize_params() 50 | acc1, acc5 = misc.eval_model(model_new, val_ds, ngpu=1, n_sample=params['n_sample'], is_imagenet=False) 51 | print("Quant accuracy Top1: %g Top5: %g" % (acc1, acc5)) 52 | print(acc1, acc5) 53 | 54 | print(model_new) 55 | new_file = os.path.join(params['model_dir'], 56 | '{}-{}bit.pth'.format(params['model'], params['act_bits'])) 57 | misc.model_snapshot(model_new, new_file, old_file=None, verbose=True) 58 | 59 | #embed() 60 | -------------------------------------------------------------------------------- /scripts/nn/mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/mnist/__init__.py -------------------------------------------------------------------------------- /scripts/nn/mnist/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | from torchvision import datasets, transforms 4 | import os 5 | 6 | def get(batch_size, data_root, train=True, val=True, **kwargs): 7 | data_root = os.path.expanduser(os.path.join(data_root, 'mnist-data')) 8 | kwargs.pop('input_size', None) 9 | num_workers = kwargs.setdefault('num_workers', 1) 10 | print("Building MNIST data loader with {} workers".format(num_workers)) 11 | ds = [] 12 | if train: 13 | train_loader = torch.utils.data.DataLoader( 14 | datasets.MNIST(root=data_root, train=True, 
download=True, 15 | transform=transforms.Compose([ 16 | transforms.ToTensor() 17 | # , transforms.Normalize((0.1307,), (0.3081,)) 18 | ])), 19 | batch_size=batch_size, shuffle=True, **kwargs) 20 | ds.append(train_loader) 21 | if val: 22 | test_loader = torch.utils.data.DataLoader( 23 | datasets.MNIST(root=data_root, train=False, download=True, 24 | transform=transforms.Compose([ 25 | transforms.ToTensor() 26 | #, transforms.Normalize((0.1307,), (0.3081,)) 27 | ])), 28 | batch_size=batch_size, shuffle=True, **kwargs) 29 | ds.append(test_loader) 30 | ds = ds[0] if len(ds) == 1 else ds 31 | return ds 32 | 33 | -------------------------------------------------------------------------------- /scripts/nn/mnist/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | import torch.utils.model_zoo as model_zoo 4 | from utee import misc 5 | from utee import act 6 | print = misc.logger.info 7 | 8 | model_urls = { 9 | 'mnist': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist.pth', 10 | 'mnist_secure_ml': 'mnist_secure_ml.pth', 11 | 'mnist_cryptonets': 'mnist_cryptonets.pth', 12 | 'mnist_deepsecure': 'mnist_deepsecure.pth', 13 | 'mnist_minionn': 'mnist_minionn.pth', 14 | } 15 | 16 | class MNIST(nn.Module): 17 | def __init__(self): 18 | super(MNIST, self).__init__() 19 | 20 | self.features = nn.Sequential() 21 | self.classifier = nn.Sequential( 22 | nn.Linear(28*28, 256), nn.ReLU(), nn.Dropout(0.2), 23 | nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.2), 24 | nn.Linear(256, 10) 25 | ) 26 | 27 | print(self.features) 28 | print(self.classifier) 29 | 30 | def forward(self, x): 31 | x = self.features(x) 32 | x = x.view(x.size(0), -1) 33 | x = self.classifier(x) 34 | return x 35 | 36 | class MNISTCryptoNets(nn.Module): 37 | def __init__(self): 38 | super(MNISTCryptoNets, self).__init__() 39 | 40 | self.features = nn.Sequential( 41 | nn.Conv2d(1, 5, 5, 2, 2), act.Square() 42 
| ) 43 | 44 | self.classifier = nn.Sequential( 45 | nn.Linear(980, 100), act.Square(), # nn.Dropout(0.2), 46 | nn.Linear(100, 10) 47 | ) 48 | 49 | print(self.features) 50 | print(self.classifier) 51 | 52 | def forward(self, x): 53 | x = self.features(x) 54 | x = x.view(x.size(0), -1) 55 | # print("Num features", list(x.size())) 56 | x = self.classifier(x) 57 | return x 58 | 59 | class MNISTDeepSecure(nn.Module): 60 | def __init__(self): 61 | super(MNISTDeepSecure, self).__init__() 62 | 63 | self.features = nn.Sequential( 64 | nn.Conv2d(1, 5, 5, 2, 2), nn.ReLU() 65 | ) 66 | 67 | self.classifier = nn.Sequential( 68 | nn.Linear(980, 100), nn.ReLU(), # nn.Dropout(0.2), 69 | nn.Linear(100, 10) 70 | ) 71 | 72 | print(self.features) 73 | print(self.classifier) 74 | 75 | def forward(self, x): 76 | x = self.features(x) 77 | x = x.view(x.size(0), -1) 78 | # print("Num features", list(x.size())) 79 | x = self.classifier(x) 80 | return x 81 | 82 | class MNISTMiniONN(nn.Module): 83 | def __init__(self): 84 | super(MNISTMiniONN, self).__init__() 85 | 86 | self.features = nn.Sequential( 87 | nn.Conv2d(1, 16, 5), 88 | nn.ReLU(), nn.MaxPool2d(2), 89 | nn.Conv2d(16, 16, 5, 1), 90 | nn.ReLU(), nn.MaxPool2d(2) 91 | ) 92 | 93 | self.classifier = nn.Sequential( 94 | nn.Linear(256, 100), nn.ReLU(), # nn.Dropout(0.2), 95 | nn.Linear(100, 10) 96 | ) 97 | 98 | print(self.features) 99 | print(self.classifier) 100 | 101 | def forward(self, x): 102 | x = self.features(x) 103 | x = x.view(x.size(0), -1) 104 | x = self.classifier(x) 105 | return x 106 | 107 | class MNISTSecureML(nn.Module): 108 | def __init__(self): 109 | super(MNISTSecureML, self).__init__() 110 | 111 | self.features = nn.Sequential() 112 | self.classifier = nn.Sequential( 113 | nn.Linear(28*28, 128), act.Square(), nn.Dropout(0.2), 114 | nn.Linear(128, 128), act.Square(), nn.Dropout(0.2), 115 | nn.Linear(128, 10) 116 | ) 117 | 118 | print(self.features) 119 | print(self.classifier) 120 | 121 | def forward(self, x): 122 | x = 
self.features(x) 123 | x = x.view(x.size(0), -1) 124 | x = self.classifier(x) 125 | return x 126 | 127 | def get(model_name, model_dir, pretrained=False): 128 | if model_name == 'mnist': 129 | model = MNIST() 130 | elif model_name == 'mnist_secure_ml': 131 | model = MNISTSecureML() 132 | elif model_name == 'mnist_cryptonets': 133 | model = MNISTCryptoNets() 134 | elif model_name == 'mnist_deepsecure': 135 | model = MNISTDeepSecure() 136 | elif model_name == 'mnist_minionn': 137 | model = MNISTMiniONN() 138 | else: 139 | assert False, model_name 140 | 141 | if pretrained: 142 | m = model_zoo.load_url(model_urls[model_name], model_dir) 143 | state_dict = m.state_dict() if isinstance(m, nn.Module) else m 144 | assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict) 145 | model.load_state_dict(state_dict) 146 | return model 147 | 148 | -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/cifar10_minionn-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/cifar10_minionn-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/cifar10_minionn-9bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/cifar10_minionn-9bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/cifar10_minionn.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/cifar10_minionn.pth 
-------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_cryptonets-10bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_cryptonets-10bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_cryptonets-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_cryptonets-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_cryptonets-9bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_cryptonets-9bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_cryptonets.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_cryptonets.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_deepsecure-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_deepsecure-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_deepsecure-9bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_deepsecure-9bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_deepsecure.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_deepsecure.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_minionn-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_minionn-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_minionn-9bit.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_minionn-9bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_minionn.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_minionn.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_secure_ml-10bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_secure_ml-10bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_secure_ml-8bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_secure_ml-8bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_secure_ml-9bit.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_secure_ml-9bit.pth -------------------------------------------------------------------------------- /scripts/nn/pretrained_models/mnist_secure_ml.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/pretrained_models/mnist_secure_ml.pth 
-------------------------------------------------------------------------------- /scripts/nn/utee/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/scripts/nn/utee/__init__.py -------------------------------------------------------------------------------- /scripts/nn/utee/act.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Square(nn.Module): 5 | def __init__(self): 6 | super(Square, self).__init__() 7 | 8 | def forward(self, x): 9 | # unfortunately we don't have automatic broadcasting yet 10 | return torch.mul(x, x) 11 | -------------------------------------------------------------------------------- /scripts/nn/utee/select.py: -------------------------------------------------------------------------------- 1 | from utee import misc 2 | import os 3 | print = misc.logger.info 4 | from IPython import embed 5 | 6 | def load(model_name, dataset_name, model_root): 7 | if dataset_name == 'mnist': 8 | from mnist import dataset, model 9 | f = dataset.get 10 | m = model.get(model_name, model_root) 11 | elif dataset_name == 'svhn': 12 | from svhn import dataset, model 13 | f = dataset.get 14 | m = model.get(model_name, model_root) 15 | elif dataset_name == 'cifar10': 16 | from cifar import dataset, model 17 | f = dataset.get10 18 | m = model.get(model_name, model_root) 19 | elif dataset_name == 'cifar100': 20 | from cifar import dataset, model 21 | f = dataset.get100 22 | m = model.get(model_name, model_root) 23 | elif dataset_name == 'stl10': 24 | from stl10 import dataset, model 25 | f = dataset.get 26 | m = model.get(model_name, model_root) 27 | elif dataset_name == 'imagenet': 28 | from imagenet import dataset, model 29 | f = dataset.get 30 | m = model.get(model_name, model_root) 31 | else: 32 | print('Dataset not implemented') 33 | return 
f(**kwargs), model 34 | 35 | if __name__ == '__main__': 36 | m1 = alexnet() 37 | embed() 38 | -------------------------------------------------------------------------------- /scripts/nn/utee/selector.py: -------------------------------------------------------------------------------- 1 | from utee import misc 2 | import os 3 | from imagenet import dataset 4 | print = misc.logger.info 5 | from IPython import embed 6 | 7 | known_models = [ 8 | 'mnist', 'svhn', # 28x28 9 | 'cifar10', 'cifar100', # 32x32 10 | 'stl10', # 96x96 11 | 'alexnet', # 224x224 12 | 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', # 224x224 13 | 'resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152', # 224x224 14 | 'squeezenet_v0', 'squeezenet_v1', #224x224 15 | 'inception_v3', # 299x299 16 | ] 17 | 18 | def mnist(cuda=True, model_root=None): 19 | print("Building and initializing mnist parameters") 20 | from mnist import model, dataset 21 | m = model.mnist(pretrained=os.path.join(model_root, 'mnist.pth')) 22 | if cuda: 23 | m = m.cuda() 24 | return m, dataset.get, False 25 | 26 | def svhn(cuda=True, model_root=None): 27 | print("Building and initializing svhn parameters") 28 | from svhn import model, dataset 29 | m = model.svhn(32, pretrained=os.path.join(model_root, 'svhn.pth')) 30 | if cuda: 31 | m = m.cuda() 32 | return m, dataset.get, False 33 | 34 | def cifar10(cuda=True, model_root=None): 35 | print("Building and initializing cifar10 parameters") 36 | from cifar import model, dataset 37 | m = model.cifar10(128, pretrained=os.path.join(model_root, 'cifar10.pth')) 38 | if cuda: 39 | m = m.cuda() 40 | return m, dataset.get10, False 41 | 42 | def cifar100(cuda=True, model_root=None): 43 | print("Building and initializing cifar100 parameters") 44 | from cifar import model, dataset 45 | m = model.cifar100(128, pretrained=os.path.join(model_root, 'cifar100.pth')) 46 | if cuda: 47 | m = m.cuda() 48 | return m, dataset.get100, False 49 | 50 | def stl10(cuda=True, model_root=None): 51 | 
print("Building and initializing stl10 parameters") 52 | from stl10 import model, dataset 53 | m = model.stl10(32, pretrained=os.path.join(model_root, 'stl10.pth')) 54 | if cuda: 55 | m = m.cuda() 56 | return m, dataset.get, False 57 | 58 | def alexnet(cuda=True, model_root=None): 59 | print("Building and initializing alexnet parameters") 60 | from imagenet import alexnet as alx 61 | m = alx.alexnet(True, model_root) 62 | if cuda: 63 | m = m.cuda() 64 | return m, dataset.get, True 65 | 66 | def vgg16(cuda=True, model_root=None): 67 | print("Building and initializing vgg16 parameters") 68 | from imagenet import vgg 69 | m = vgg.vgg16(True, model_root) 70 | if cuda: 71 | m = m.cuda() 72 | return m, dataset.get, True 73 | 74 | def vgg16_bn(cuda=True, model_root=None): 75 | print("Building vgg16_bn parameters") 76 | from imagenet import vgg 77 | m = vgg.vgg19_bn(model_root) 78 | if cuda: 79 | m = m.cuda() 80 | return m, dataset.get, True 81 | 82 | def vgg19(cuda=True, model_root=None): 83 | print("Building and initializing vgg19 parameters") 84 | from imagenet import vgg 85 | m = vgg.vgg19(True, model_root) 86 | if cuda: 87 | m = m.cuda() 88 | return m, dataset.get, True 89 | 90 | def vgg19_bn(cuda=True, model_root=None): 91 | print("Building vgg19_bn parameters") 92 | from imagenet import vgg 93 | m = vgg.vgg19_bn(model_root) 94 | if cuda: 95 | m = m.cuda() 96 | return m, dataset.get, True 97 | 98 | def inception_v3(cuda=True, model_root=None): 99 | print("Building and initializing inception_v3 parameters") 100 | from imagenet import inception 101 | m = inception.inception_v3(True, model_root) 102 | if cuda: 103 | m = m.cuda() 104 | return m, dataset.get, True 105 | 106 | def resnet18(cuda=True, model_root=None): 107 | print("Building and initializing resnet-18 parameters") 108 | from imagenet import resnet 109 | m = resnet.resnet18(True, model_root) 110 | if cuda: 111 | m = m.cuda() 112 | return m, dataset.get, True 113 | 114 | def resnet34(cuda=True, 
model_root=None): 115 | print("Building and initializing resnet-34 parameters") 116 | from imagenet import resnet 117 | m = resnet.resnet34(True, model_root) 118 | if cuda: 119 | m = m.cuda() 120 | return m, dataset.get, True 121 | 122 | def resnet50(cuda=True, model_root=None): 123 | print("Building and initializing resnet-50 parameters") 124 | from imagenet import resnet 125 | m = resnet.resnet50(True, model_root) 126 | if cuda: 127 | m = m.cuda() 128 | return m, dataset.get, True 129 | 130 | def resnet101(cuda=True, model_root=None): 131 | print("Building and initializing resnet-101 parameters") 132 | from imagenet import resnet 133 | m = resnet.resnet101(True, model_root) 134 | if cuda: 135 | m = m.cuda() 136 | return m, dataset.get, True 137 | 138 | def resnet152(cuda=True, model_root=None): 139 | print("Building and initializing resnet-152 parameters") 140 | from imagenet import resnet 141 | m = resnet.resnet152(True, model_root) 142 | if cuda: 143 | m = m.cuda() 144 | return m, dataset.get, True 145 | 146 | def squeezenet_v0(cuda=True, model_root=None): 147 | print("Building and initializing squeezenet_v0 parameters") 148 | from imagenet import squeezenet 149 | m = squeezenet.squeezenet1_0(True, model_root) 150 | if cuda: 151 | m = m.cuda() 152 | return m, dataset.get, True 153 | 154 | def squeezenet_v1(cuda=True, model_root=None): 155 | print("Building and initializing squeezenet_v1 parameters") 156 | from imagenet import squeezenet 157 | m = squeezenet.squeezenet1_1(True, model_root) 158 | if cuda: 159 | m = m.cuda() 160 | return m, dataset.get, True 161 | 162 | def select(model_name, **kwargs): 163 | assert model_name in known_models, model_name 164 | kwargs.setdefault('model_root', os.path.expanduser('~/.torch/models')) 165 | return eval('{}'.format(model_name))(**kwargs) 166 | 167 | if __name__ == '__main__': 168 | m1 = alexnet() 169 | embed() 170 | 171 | 172 | -------------------------------------------------------------------------------- 
/scripts/throttle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Based on this gist: https://gist.github.com/trongthanh/1196596 4 | # Replace lo with eth0/wlan0 to limit speed from wide lan 5 | 6 | if [ $# -ne 1 ]; then 7 | echo "Usage: $0 [start|stop|report]" 8 | elif [ "$1" == "start" ]; then 9 | #Setup the rate control and delay 10 | sudo tc qdisc add dev lo root handle 1: htb default 12 11 | sudo tc class add dev lo parent 1:1 classid 1:12 htb rate 2.5Gbit # ceil 20Mbit 12 | sudo tc qdisc add dev lo parent 1:12 netem delay 0.1ms 13 | elif [ $1 == "stop" ]; then 14 | #Remove the rate control/delay 15 | sudo tc qdisc del dev lo root 16 | elif [ $1 == "report" ]; then 17 | #To see what is configured on an interface, do this 18 | sudo tc -s qdisc ls dev lo 19 | else 20 | echo "Usage: $0 [start|stop|report]" 21 | fi 22 | 23 | -------------------------------------------------------------------------------- /src/demo/ahe/conv1d-benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | NN-Layers-Benchmarking: This code benchmarks FC and Conv layers for a neural network 3 | 4 | List of Authors: 5 | Chiraag Juvekar, chiraag@mit.edu 6 | 7 | License Information: 8 | MIT License 9 | Copyright (c) 2017, Massachusetts Institute of Technology (MIT) 10 | 11 | */ 12 | 13 | #include 14 | #include 15 | #include "pke/gazelle.h" 16 | 17 | using namespace std; 18 | using namespace lbcrypto; 19 | 20 | int main() { 21 | std::cout << "NN Layers Benchmark (ms):" << std::endl; 22 | 23 | //------------------ Setup Parameters ------------------ 24 | ui64 nRep = 1; 25 | double start, stop; 26 | 27 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 28 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 29 | ftt_precompute(z, opt::q, opt::logn); 30 | ftt_precompute(z_p, opt::p, opt::logn); 31 | encoding_precompute(opt::p, opt::logn); 32 | precompute_automorph_index(opt::phim); 33 | 34 | 
DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 35 | 36 | FVParams slow_params { 37 | false, 38 | opt::q, opt::p, opt::logn, opt::phim, 39 | (opt::q/opt::p), 40 | OPTIMIZED, std::make_shared(dgg), 41 | 9 42 | }; 43 | 44 | FVParams fast_params = slow_params; 45 | fast_params.fast_modulli = true; 46 | 47 | FVParams test_params = fast_params; 48 | 49 | //------------------- Synthetic Data ------------------- 50 | uv64 vec = get_dgg_testvector(opt::phim, opt::p); 51 | 52 | ui32 filter_size = 5; 53 | auto filter_1d = get_dgg_testvector(filter_size, opt::p); 54 | 55 | //----------------------- KeyGen ----------------------- 56 | nRep = 1; 57 | auto kp = KeyGen(test_params); 58 | uv32 index_list; 59 | for (ui32 i = 1; i <(filter_size/2); i++){ 60 | index_list.push_back(i); 61 | } 62 | 63 | for(ui32 i=1; i<=(filter_size/2); i++){ 64 | index_list.push_back(opt::phim/2-i); 65 | } 66 | 67 | start = currentDateTime(); 68 | for(ui64 i=0; i < nRep; i++){ 69 | kp = KeyGen(test_params); 70 | EvalAutomorphismKeyGen(kp.sk, index_list, test_params); 71 | } 72 | stop = currentDateTime(); 73 | std::cout << " KeyGen ("<< index_list.size() <<" keys): " << (stop-start)/nRep << std::endl; 74 | 75 | //----------------- Preprocess Vector ------------------ 76 | nRep = 100; 77 | ui32 mat_window_size = 10; 78 | ui32 mat_num_windows = 2; 79 | uv64 pt = packed_encode(vec, opt::p, opt::logn); 80 | auto ct_vec = preprocess_vec(kp.sk, pt, mat_window_size, mat_num_windows, test_params); 81 | start = currentDateTime(); 82 | for(ui64 i=0; i < nRep; i++){ 83 | pt = packed_encode(vec, opt::p, opt::logn); 84 | ct_vec = preprocess_vec(kp.sk, pt, mat_window_size, mat_num_windows, test_params); 85 | } 86 | stop = currentDateTime(); 87 | std::cout << " Preprocess Vector ("<< mat_num_windows <<" windows): " << (stop-start)/nRep << std::endl; 88 | 89 | //----------------- Preprocess Filter ------------------ 90 | auto enc_filter = preprocess_filter_1d(filter_1d, mat_window_size, 
mat_num_windows, test_params); 91 | start = currentDateTime(); 92 | for(ui64 i=0; i < nRep; i++){ 93 | enc_filter = preprocess_filter_1d(filter_1d, mat_window_size, mat_num_windows, test_params); 94 | } 95 | stop = currentDateTime(); 96 | std::cout << " Preprocess Filter: " << (stop-start)/nRep << std::endl; 97 | 98 | //--------------------- Conv1D (Rot) -------------------- 99 | auto ct_conv_rot = conv_1d_rot(ct_vec, enc_filter.size(), test_params); 100 | start = currentDateTime(); 101 | for(ui64 i=0; i < nRep; i++){ 102 | ct_conv_rot = conv_1d_rot(ct_vec, enc_filter.size(), test_params); 103 | } 104 | stop = currentDateTime(); 105 | std::cout << " Conv1D: " << (stop-start)/nRep << std::endl; 106 | 107 | //--------------------- Conv1D (Mul) -------------------- 108 | auto ct_conv_mul = conv_1d_mul(ct_conv_rot, enc_filter, test_params); 109 | start = currentDateTime(); 110 | for(ui64 i=0; i < nRep; i++){ 111 | ct_conv_mul = conv_1d_mul(ct_conv_rot, enc_filter, test_params); 112 | } 113 | stop = currentDateTime(); 114 | std::cout << " Conv1D: " << (stop-start)/nRep << std::endl; 115 | 116 | //------------------------ Conv1D ---------------------- 117 | auto ct_conv_1d = conv_1d_online(ct_vec, enc_filter, test_params); 118 | start = currentDateTime(); 119 | for(ui64 i=0; i < nRep; i++){ 120 | ct_conv_1d = conv_1d_online(ct_vec, enc_filter, test_params); 121 | } 122 | stop = currentDateTime(); 123 | std::cout << " Conv1D: " << (stop-start)/nRep << std::endl; 124 | 125 | 126 | //----------------------- Check ------------------------ 127 | auto conv_1d = packed_decode(Decrypt(kp.sk, ct_conv_1d, test_params), opt::p, opt::logn); 128 | 129 | std::cout << std::endl; 130 | std::cout << "Margin ct: " << NoiseMargin(kp.sk, ct_vec[0], test_params) << std::endl; 131 | std::cout << "Margin conv_1d: " << NoiseMargin(kp.sk, ct_conv_1d, test_params) << std::endl; 132 | std::cout << std::endl; 133 | 134 | auto conv_1d_ref = conv_1d_pt(vec, filter_1d, opt::p); 135 | /* std::cout << 
vec_to_str(to_signed(vec, opt::p)) << std::endl; 136 | std::cout << vec_to_str(to_signed(filter_1d, opt::p)) << std::endl; 137 | std::cout << vec_to_str(to_signed(conv_1d_ref, opt::p)) << std::endl; 138 | std::cout << vec_to_str(to_signed(conv_1d, opt::p)) << std::endl;*/ 139 | // check_vec_eq(conv_1d_ref, conv_1d, "conv_1d mismatch:\n"); 140 | 141 | return 0; 142 | } 143 | 144 | -------------------------------------------------------------------------------- /src/demo/ahe/fv-automorph-benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | FV-Automorphism-Benchmarking: This code benchmarks automorphisms for FV 3 | 4 | List of Authors: 5 | Chiraag Juvekar, chiraag@mit.edu 6 | 7 | License Information: 8 | MIT License 9 | Copyright (c) 2017, Massachusetts Institute of Technology (MIT) 10 | 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace std; 19 | using namespace lbcrypto; 20 | 21 | int main() { 22 | std::cout << "FV Automorph Benchmark (ms):" << std::endl; 23 | 24 | //------------------ Setup Parameters ------------------ 25 | ui64 nRep = 100; 26 | double start, stop; 27 | 28 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 29 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 30 | ftt_precompute(z, opt::q, opt::logn); 31 | ftt_precompute(z_p, opt::p, opt::logn); 32 | encoding_precompute(opt::p, opt::logn); 33 | precompute_automorph_index(opt::phim); 34 | 35 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 36 | 37 | FVParams slow_params { 38 | false, 39 | opt::q, opt::p, opt::logn, opt::phim, 40 | (opt::q/opt::p), 41 | OPTIMIZED, std::make_shared(dgg), 42 | 20 43 | }; 44 | uv32 windows = {20, 10, 5}; 45 | 46 | FVParams test_params = slow_params; 47 | 48 | auto kp = KeyGen(test_params); 49 | uv64 v1 = get_dgg_testvector(opt::phim, opt::p); 50 | uv64 pt1 = packed_encode(v1, opt::p, opt::logn); 51 | auto ct1 = Encrypt(kp.sk, pt1, test_params); 52 | ui32 rot = 
2; 53 | 54 | for(ui32 nw=0; nw 14 | #include 15 | #include 16 | 17 | using namespace std; 18 | using namespace lbcrypto; 19 | 20 | 21 | int main() { 22 | std::cout << "FV Benchmark (ms):" << std::endl; 23 | 24 | //------------------ Setup Parameters ------------------ 25 | ui64 nRep = 1000; 26 | double start, stop; 27 | 28 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 29 | 30 | FVParams slow_params { 31 | false, 32 | opt::q, opt::p, opt::logn, opt::phim, 33 | (opt::q/opt::p), 34 | OPTIMIZED, std::make_shared(dgg), 35 | 20 36 | }; 37 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 38 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 39 | ftt_precompute(z, opt::q, opt::logn); 40 | ftt_precompute(z_p, opt::p, opt::logn); 41 | encoding_precompute(opt::p, opt::logn); 42 | //------------------------ Setup ----------------------- 43 | start = currentDateTime(); 44 | for(ui64 i=0; i < 100; i++){ 45 | z = RootOfUnity(opt::phim << 1, opt::q); 46 | z_p = RootOfUnity(opt::phim << 1, opt::p); 47 | ftt_precompute(z, opt::q, opt::logn); 48 | ftt_precompute(z_p, opt::p, opt::logn); 49 | encoding_precompute(opt::p, opt::logn); 50 | } 51 | stop = currentDateTime(); 52 | std::cout << " Setup: " << (stop-start)/100 << std::endl; 53 | 54 | FVParams test_params = slow_params; 55 | 56 | for(ui32 t=0; t<2; t++){ 57 | test_params.fast_modulli = !test_params.fast_modulli; 58 | 59 | uv64 v = get_dgg_testvector(opt::phim, opt::p); 60 | uv64 pt = packed_encode(v, opt::p, opt::logn); 61 | 62 | //----------------------- KeyGen ----------------------- 63 | auto kp = KeyGen(test_params); 64 | start = currentDateTime(); 65 | for(ui64 i=0; i < nRep; i++){ 66 | kp = KeyGen(test_params); 67 | } 68 | stop = currentDateTime(); 69 | std::cout << " KeyGen: " << (stop-start)/nRep << std::endl; 70 | //--------------------- PK-Encrypt---------------------- 71 | Ciphertext ct_pk(opt::phim); 72 | start = currentDateTime(); 73 | for(ui64 i=0; i < nRep; i++){ 74 | ct_pk = Encrypt(kp.pk, pt, 
test_params); 75 | } 76 | stop = currentDateTime(); 77 | std::cout << " PK-Encrypt: " << (stop-start)/nRep << std::endl; 78 | 79 | //--------------------- SK-Encrypt---------------------- 80 | Ciphertext ct_sk(opt::phim); 81 | start = currentDateTime(); 82 | for(ui64 i=0; i < nRep; i++){ 83 | pt = packed_encode(v, opt::p, opt::logn); 84 | ct_sk = Encrypt(kp.sk, pt, test_params); 85 | } 86 | stop = currentDateTime(); 87 | std::cout << " SK-Encrypt: " << (stop-start)/nRep << std::endl; 88 | 89 | //---------------------- Decrypt ----------------------- 90 | uv64 pt_pk(opt::phim), pt_sk(opt::phim), v_pk(opt::phim), v_sk(opt::phim); 91 | start = currentDateTime(); 92 | for(ui64 i=0; i < (nRep/2); i++){ 93 | pt_pk = Decrypt(kp.sk, ct_pk, test_params); 94 | v_pk = packed_decode(pt_pk, opt::p, opt::logn); 95 | pt_sk = Decrypt(kp.sk, ct_sk, test_params); 96 | v_sk = packed_decode(pt_sk, opt::p, opt::logn); 97 | } 98 | stop = currentDateTime(); 99 | std::cout << " Decrypt: " << (stop-start)/nRep << std::endl; 100 | 101 | check_vec_eq(v, v_pk, "pk enc-dec mismatch:\n"); 102 | check_vec_eq(v, v_sk, "sk enc-dec mismatch:\n"); 103 | } 104 | return 0; 105 | } 106 | 107 | -------------------------------------------------------------------------------- /src/demo/ahe/fv-she-benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | FV-SHE-Benchmarking: This code benchmarks add, sub, neg and mult-plain for FV 3 | 4 | List of Authors: 5 | Chiraag Juvekar, chiraag@mit.edu 6 | 7 | License Information: 8 | MIT License 9 | Copyright (c) 2017, Massachusetts Institute of Technology (MIT) 10 | 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace std; 19 | using namespace lbcrypto; 20 | 21 | int main() { 22 | std::cout << "FV SHE Benchmark (ms):" << std::endl; 23 | 24 | //------------------ Setup Parameters ------------------ 25 | ui64 nRep = 1000; 26 | double start, stop; 27 | 28 | ui64 z = 
RootOfUnity(opt::phim << 1, opt::q); 29 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 30 | ftt_precompute(z, opt::q, opt::logn); 31 | ftt_precompute(z_p, opt::p, opt::logn); 32 | encoding_precompute(opt::p, opt::logn); 33 | 34 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 35 | 36 | FVParams slow_params { 37 | false, 38 | opt::q, opt::p, opt::logn, opt::phim, 39 | (opt::q/opt::p), 40 | OPTIMIZED, std::make_shared(dgg), 41 | 20 42 | }; 43 | 44 | FVParams test_params = slow_params; 45 | 46 | auto kp = KeyGen(test_params); 47 | uv64 v1 = get_dgg_testvector(opt::phim, opt::p); 48 | uv64 v2 = get_dgg_testvector(opt::phim, opt::p); 49 | 50 | uv64 pt1 = packed_encode(v1, opt::p, opt::logn); 51 | uv64 pt2 = packed_encode(v2, opt::p, opt::logn); 52 | 53 | uv64 v1_f = packed_decode(pt1, opt::p, opt::logn); 54 | check_vec_eq(v1, v1_f, "Decode mismatch:\n"); 55 | 56 | uv64 pt_add(opt::phim); 57 | for(ui32 i=0; i 14 | #include 15 | #include 16 | #include "math/bit_twiddle.h" 17 | using namespace std; 18 | using namespace lbcrypto; 19 | 20 | int main() { 21 | std::cout << "NN Layers Benchmark (ms):" << std::endl; 22 | 23 | //------------------ Setup Parameters ------------------ 24 | ui64 nRep = 1; 25 | double start, stop; 26 | 27 | ui64 z = opt::z; 28 | ui64 z_p = opt::z_p; 29 | 30 | ftt_precompute(z, opt::q, opt::logn); 31 | ftt_precompute(z_p, opt::p, opt::logn); 32 | encoding_precompute(opt::p, opt::logn); 33 | precompute_automorph_index(opt::phim); 34 | 35 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 36 | 37 | FVParams slow_params { 38 | false, 39 | opt::q, opt::p, opt::logn, opt::phim, 40 | (opt::q/opt::p), 41 | OPTIMIZED, std::make_shared(dgg), 42 | 8 43 | }; 44 | 45 | FVParams fast_params = slow_params; 46 | fast_params.fast_modulli = true; 47 | 48 | FVParams test_params = fast_params; 49 | 50 | //------------------- Synthetic Data ------------------- 51 | ui32 num_rows = 64, num_cols = 128, window_size = 8; 52 | std::cin >> 
num_rows >> num_cols >> window_size; 53 | test_params.window_size = window_size; 54 | 55 | uv64 vec = get_dgg_testvector(num_cols, opt::p); 56 | 57 | std::vector mat(num_rows, uv64(num_cols)); 58 | for(ui32 row=0; row 14 | #include 15 | #include 16 | #include 17 | #include "math/bit_twiddle.h" 18 | using namespace std; 19 | using namespace lbcrypto; 20 | 21 | int main() { 22 | std::cout << "NN Layers Benchmark (ms):" << std::endl; 23 | 24 | //------------------ Setup Parameters ------------------ 25 | ui64 nRep = 1; 26 | double start, stop; 27 | 28 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 29 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 30 | ftt_precompute(z, opt::q, opt::logn); 31 | ftt_precompute(z_p, opt::p, opt::logn); 32 | encoding_precompute(opt::p, opt::logn); 33 | precompute_automorph_index(opt::phim); 34 | 35 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 36 | 37 | FVParams slow_params { 38 | false, 39 | opt::q, opt::p, opt::logn, opt::phim, 40 | (opt::q/opt::p), 41 | OPTIMIZED, std::make_shared(dgg), 42 | 8 43 | }; 44 | 45 | FVParams fast_params = slow_params; 46 | fast_params.fast_modulli = true; 47 | 48 | FVParams test_params = fast_params; 49 | 50 | //------------------- Synthetic Data ------------------- 51 | ui32 vec_size = 2048; 52 | std::cin >> vec_size; 53 | uv64 vec_c = get_dgg_testvector(vec_size, opt::p); 54 | uv64 vec_s = get_dgg_testvector(vec_size, opt::p); 55 | 56 | //----------------------- KeyGen ----------------------- 57 | nRep = 10; 58 | auto kp = KeyGen(test_params); 59 | 60 | start = currentDateTime(); 61 | for(ui64 i=0; i < nRep; i++){ 62 | kp = KeyGen(test_params); 63 | } 64 | stop = currentDateTime(); 65 | std::cout << " KeyGen: " << (stop-start)/nRep << std::endl; 66 | 67 | //----------------- Client Preprocess ------------------ 68 | nRep = 100; 69 | auto ct_vec = preprocess_client_share(kp.sk, vec_c, test_params); 70 | start = currentDateTime(); 71 | for(ui64 i=0; i < nRep; i++){ 72 | ct_vec = 
preprocess_client_share(kp.sk, vec_c, test_params); 73 | } 74 | stop = currentDateTime(); 75 | std::cout << " Preprocess Client: " << (stop-start)/nRep << std::endl; 76 | 77 | //----------------- Server Preprocess ----------------- 78 | std::vector pt_vec; 79 | uv64 vec_s_f; 80 | std::tie(pt_vec, vec_s_f) = preprocess_server_share(vec_s, test_params); 81 | start = currentDateTime(); 82 | for(ui64 i=0; i < nRep; i++){ 83 | std::tie(pt_vec, vec_s_f) = preprocess_server_share(vec_s, test_params); 84 | } 85 | stop = currentDateTime(); 86 | std::cout << " Preprocess Server: " << (stop-start)/nRep << std::endl; 87 | 88 | //---------------------- Square ----------------------- 89 | auto ct_c_f = square_online(ct_vec, pt_vec, test_params); 90 | start = currentDateTime(); 91 | for(ui64 i=0; i < nRep; i++){ 92 | ct_c_f = square_online(ct_vec, pt_vec, test_params); 93 | } 94 | stop = currentDateTime(); 95 | std::cout << " Multiply: " << (stop-start)/nRep << std::endl; 96 | 97 | //------------------- Post-Process --------------------- 98 | auto vec_c_f = postprocess_client_share(kp.sk, ct_c_f, vec_size, test_params); 99 | start = currentDateTime(); 100 | for(ui64 i=0; i < nRep; i++){ 101 | vec_c_f = postprocess_client_share(kp.sk, ct_c_f, vec_size, test_params); 102 | } 103 | stop = currentDateTime(); 104 | std::cout << " Post-Process: " << (stop-start)/nRep << std::endl; 105 | 106 | //--------------------- Square PT ---------------------- 107 | start = currentDateTime(); 108 | auto vec_c_f_ref = square_pt(vec_c, vec_s, vec_s_f, opt::p); 109 | for(ui64 i=0; i < nRep; i++){ 110 | vec_c_f_ref = square_pt(vec_c, vec_s, vec_s_f, opt::p); 111 | } 112 | stop = currentDateTime(); 113 | std::cout << " Multiply PT: " << (stop-start)/nRep << std::endl; 114 | 115 | //----------------------- Check ------------------------ 116 | // std::cout << std::endl; 117 | // std::cout << "Margin ct: " << NoiseMargin(kp.sk, ct_vec[0], test_params) << std::endl; 118 | // std::cout << "Margin prod: " << 
NoiseMargin(kp.sk, ct_prod, test_params) << std::endl; 119 | std::cout << std::endl; 120 | 121 | check_vec_eq(vec_c_f_ref, vec_c_f, "square mismatch:\n"); 122 | 123 | return 0; 124 | } 125 | 126 | -------------------------------------------------------------------------------- /src/demo/ahe/square-online.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * gc-online.cpp 3 | * 4 | * Created on: Nov 28, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include "math/bit_twiddle.h" 22 | 23 | using namespace lbcrypto; 24 | using namespace osuCrypto; 25 | 26 | std::string addr = "localhost"; 27 | ui32 vec_size = 2048, window_size = 9; 28 | ui32 num_rep = 100; 29 | 30 | void ahe_client(){ 31 | std::cout << "Client" << std::endl; 32 | 33 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 34 | FVParams test_params { 35 | true, 36 | opt::q, opt::p, opt::logn, opt::phim, 37 | (opt::q/opt::p), 38 | OPTIMIZED, std::make_shared(dgg), 39 | window_size 40 | }; 41 | 42 | // get up the networking 43 | IOService ios(0); 44 | Session sess(ios, addr, 1212, EpMode::Client); 45 | Channel chl = sess.addChannel(); 46 | 47 | Timer time; 48 | chl.resetStats(); 49 | time.setTimePoint("start"); 50 | // KeyGen 51 | auto kp = KeyGen(test_params); 52 | uv64 vec_c = get_dgg_testvector(vec_size, opt::p); 53 | 54 | 55 | std::cout 56 | << " Sent: " << chl.getTotalDataSent() << std::endl 57 | << " received: " << chl.getTotalDataRecv() << std::endl << std::endl; 58 | chl.resetStats(); 59 | 60 | time.setTimePoint("setup"); 61 | 62 | for(ui32 rep=0; rep(dgg), 99 | window_size 100 | }; 101 | 102 | // get up the networking 103 | IOService ios(0); 104 | Session sess(ios, addr, 1212, EpMode::Server); 105 | Channel chl = sess.addChannel(); 106 | 107 | Timer time; 108 | time.setTimePoint("start"); 109 | 
110 | uv64 vec_s = get_dgg_testvector(vec_size, opt::p); 111 | 112 | time.setTimePoint("setup"); 113 | for(ui32 rep=0; rep pt_vec; 115 | uv64 vec_s_f; 116 | std::tie(pt_vec, vec_s_f) = preprocess_server_share(vec_s, test_params); 117 | 118 | CTVec ct_vec(2, Ciphertext(opt::phim)); 119 | for(ui32 n=0; n> vec_size >> window_size; 142 | 143 | ftt_precompute(opt::z, opt::q, opt::logn); 144 | ftt_precompute(opt::z_p, opt::p, opt::logn); 145 | encoding_precompute(opt::p, opt::logn); 146 | precompute_automorph_index(opt::phim); 147 | 148 | if (argc == 1) 149 | { 150 | std::vector thrds(2); 151 | thrds[0] = std::thread([]() { ahe_server(); }); 152 | thrds[1] = std::thread([]() { ahe_client(); }); 153 | 154 | for (auto& thrd : thrds) 155 | thrd.join(); 156 | } 157 | else if(argc == 2) 158 | { 159 | int role = atoi(argv[1]); // 0: send, 1: recv 160 | role ? ahe_server() : ahe_client(); 161 | } 162 | else 163 | { 164 | std::cout << "this program takes a runtime argument.\n\n" 165 | << "to run the AES GC, run\n\n" 166 | << " gc-online [0|1]\n\n" 167 | << "the optional {0,1} argument specifies in which case the program will\n" 168 | << "run between two terminals, where each one was set to the opposite value. 
e.g.\n\n" 169 | << " gc-online 0\n\n" 170 | << " gc-online 1\n\n" 171 | << "These programs are fully networked and try to connect at localhost:1212.\n" 172 | << std::endl; 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/demo/ahe/transfrom-benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Transform-Benchmarking: This code benchmarks the FTT code 3 | 4 | List of Authors: 5 | Chiraag Juvekar, chiraag@mit.edu 6 | 7 | License Information: 8 | MIT License 9 | Copyright (c) 2017, Massachusetts Institute of Technology (MIT) 10 | 11 | */ 12 | 13 | #include 14 | #include 15 | #include 16 | #include "utils/debug.h" 17 | #include "utils/test.h" 18 | #include "math/params.h" 19 | #include "math/nbtheory.h" 20 | #include "math/distrgen.h" 21 | #include "math/transfrm.h" 22 | 23 | using namespace lbcrypto; 24 | 25 | 26 | int main() { 27 | std::cout << "Transform Benchmark (ms):" << std::endl; 28 | 29 | //------------------ Setup Parameters ------------------ 30 | ui64 nRep; 31 | double start, stop; 32 | 33 | uv64 x = get_dug_vector(opt::phim, opt::q); 34 | uv64 X, xx; 35 | 36 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 37 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 38 | ftt_precompute(z, opt::q, opt::logn); 39 | ftt_precompute(z_p, opt::p, opt::logn); 40 | X = ftt_fwd(x, opt::q, opt::logn); 41 | xx = ftt_inv(X, opt::q, opt::logn); 42 | 43 | check_vec_eq(x, xx, "ftt mismatch\n"); 44 | 45 | //-------------------- Baseline FTT -------------------- 46 | nRep = 1000; 47 | start = currentDateTime(); 48 | for(uint64_t n=0; n 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "utils/network.h" 17 | 18 | using namespace osuCrypto; 19 | 20 | std::string addr = "localhost"; 21 | 22 | void sender(){ 23 | setThreadName("Sender"); 24 | 25 | // get up the networking 26 | IOService ios(0); 27 | Session sess(ios, addr, 1212, 
EpMode::Client); 28 | Channel chl = sess.addChannel(); 29 | 30 | senderGetLatency(chl); 31 | 32 | chl.close(); 33 | sess.stop(); 34 | ios.stop(); 35 | return; 36 | } 37 | 38 | void receiver(){ 39 | setThreadName("Receiver"); 40 | 41 | // get up the networking 42 | IOService ios(0); 43 | Session sess(ios, addr, 1212, EpMode::Server); 44 | Channel chl = sess.addChannel(); 45 | 46 | recverGetLatency(chl); 47 | 48 | chl.close(); 49 | sess.stop(); 50 | ios.stop(); 51 | return; 52 | } 53 | 54 | int main(int argc, char** argv) { 55 | if (argc == 1) 56 | { 57 | std::vector thrds(2); 58 | thrds[0] = std::thread([]() { sender(); }); 59 | thrds[1] = std::thread([]() { receiver(); }); 60 | 61 | for (auto& thrd : thrds) 62 | thrd.join(); 63 | } 64 | else if(argc == 2) 65 | { 66 | int role = atoi(argv[1]); // 0: send, 1: recv 67 | role ? receiver() : sender(); 68 | } 69 | else 70 | { 71 | std::cout << "this program takes a runtime argument.\n\n" 72 | << "to run the AES GC, run\n\n" 73 | << " gc-online [0|1]\n\n" 74 | << "the optional {0,1} argument specifies in which case the program will\n" 75 | << "run between two terminals, where each one was set to the opposite value. e.g.\n\n" 76 | << " gc-online 0\n\n" 77 | << " gc-online 1\n\n" 78 | << "These programs are fully networked and try to connect at localhost:1212.\n" 79 | << std::endl; 80 | } 81 | } 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /src/demo/tpc/act-gc-benchmark.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 
8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "pke/gazelle.h" 29 | #include "gc/gc.h" 30 | #include "gc/util.h" 31 | #include "gc/gazelle_circuits.h" 32 | #include 33 | 34 | using namespace osuCrypto; 35 | using namespace lbcrypto; 36 | 37 | void buildCircuit(GarbledCircuit& gc, BuildContext& context, 38 | ui64 width, ui64 in_args, ui64 n_circ, ui64 p) { 39 | std::vector in(in_args, uv64(width)); 40 | uv64 out(width); 41 | uv64 s_p(width); 42 | uv64 s_p_2(width); 43 | 44 | int n = n_circ*width*in_args; 45 | int m = n_circ*width; 46 | 47 | startBuilding(&gc, &context, n, m, n_circ*1000); 48 | gc.n_c = n_circ*width; 49 | CONSTCircuit(&gc, &context, p, width, s_p); 50 | CONSTCircuit(&gc, &context, p/2, width, s_p_2); 51 | for(ui64 i=0; i din = std::vector(n_circ, uv64(in_args)); 79 | std::vector dref = std::vector(n_circ, uv64(out_args)); 80 | std::vector dout_pt = std::vector(n_circ, uv64(out_args)); 81 | std::vector dout = std::vector(n_circ, uv64(out_args)); 82 | 83 | for(ui32 n=0; n. 
16 | 17 | */ 18 | 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "gc/gc.h" 27 | #include "gc/util.h" 28 | #include "gc/aescircuits.h" 29 | #include 30 | 31 | using namespace osuCrypto; 32 | using namespace lbcrypto; 33 | 34 | std::string AES_CIRCUIT_FILE_NAME = "./aesCircuit"; 35 | 36 | unsigned long timedEval(GarbledCircuit *garbledCircuit, InputLabels& inputLabels) { 37 | 38 | int n = garbledCircuit->n; 39 | int m = garbledCircuit->m; 40 | ExtractedLabels extractedLabels(n); 41 | OutputLabels outputs(m); 42 | int j; 43 | InputMap inputs(n); 44 | unsigned long startTime, endTime; 45 | unsigned long sum = 0; 46 | for (j = 0; j < n; j++) { 47 | inputs[j] = rand() % 2; 48 | } 49 | extractLabels(extractedLabels, inputLabels, inputs); 50 | startTime = RDTSC; 51 | evaluate(garbledCircuit, extractedLabels, outputs); 52 | endTime = RDTSC; 53 | sum = endTime - startTime; 54 | return sum; 55 | 56 | } 57 | 58 | int main() { 59 | int rounds = 10; 60 | int n = 128 + (128 * rounds); 61 | int m = 128; 62 | 63 | GarbledCircuit aesCircuit; 64 | BuildContext context; 65 | buildAESCircuit(aesCircuit, context); 66 | 67 | InputLabels inputLabels(n); 68 | OutputMap outputMap(m); 69 | int i, j; 70 | 71 | int timeGarble[TIMES]; 72 | int timeEval[TIMES]; 73 | double timeGarbleMedians[TIMES]; 74 | double timeEvalMedians[TIMES]; 75 | garbleCircuit(&aesCircuit, inputLabels, outputMap); 76 | 77 | for (j = 0; j < TIMES; j++) { 78 | for (i = 0; i < TIMES; i++) { 79 | timeGarble[i] = garbleCircuit(&aesCircuit, inputLabels, outputMap); 80 | timeEval[i] = timedEval(&aesCircuit, inputLabels); 81 | } 82 | timeGarbleMedians[j] = ((double) median(timeGarble, TIMES)) 83 | / aesCircuit.q; 84 | timeEvalMedians[j] = ((double) median(timeEval, TIMES)) / aesCircuit.q; 85 | } 86 | double garblingTime = doubleMean(timeGarbleMedians, TIMES); 87 | double evalTime = doubleMean(timeEvalMedians, TIMES); 88 | std::cout << garblingTime << " " << evalTime << std::endl; 
89 | return 0; 90 | } 91 | -------------------------------------------------------------------------------- /src/demo/tpc/aesCircuit: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chiraag/gazelle_mpc/f4eb3bae09bf4897f2651946eac7dee17e094a6f/src/demo/tpc/aesCircuit -------------------------------------------------------------------------------- /src/demo/tpc/ot-benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | //using namespace std; 5 | #include 6 | using namespace osuCrypto; 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | u64 baseCount = 128; 21 | u64 numOTs = 1 << 24; 22 | 23 | void iknp_recv(void) 24 | { 25 | setThreadName("Receiver"); 26 | 27 | PRNG prng0(_mm_set_epi32(4253465, 3434565, 234435, 23987045)); 28 | BitVector choice(numOTs); 29 | choice.randomize(prng0); 30 | SRBaseOT send; 31 | IKNPReceiver r; 32 | std::vector> baseSend(baseCount); 33 | std::vector msgs(numOTs); 34 | 35 | // get up the networking 36 | std::string name = "n"; 37 | IOService ios(0); 38 | Session ep0(ios, "localhost", 1212, EpMode::Server, name); 39 | Channel chl = ep0.addChannel(name, name); 40 | 41 | send.send(baseSend, prng0, chl); 42 | r.setBaseOts(baseSend); 43 | 44 | r.receive(choice, msgs, prng0, chl); 45 | 46 | /*for(u64 n=0; n> msgs(numOTs); 69 | for(u64 n=0; n baseRecv(baseCount); 74 | BitVector baseChoice(baseCount); 75 | baseChoice.randomize(prng0); 76 | SRBaseOT base_ot; 77 | IKNPSender s; 78 | 79 | Timer time; 80 | 81 | 82 | time.setTimePoint("start"); 83 | base_ot.receive(baseChoice, baseRecv, prng0, chl); 84 | s.setBaseOts(baseRecv, baseChoice); 85 | 86 | time.setTimePoint("base"); 87 | s.send(msgs, prng0, chl); 88 | 89 | 90 | /*for(u64 n=0; n> numOTs; 106 | if (argc == 1) 107 | { 108 | std::vector thrds(2); 109 | 
thrds[0] = std::thread([]() { iknp_send(); }); 110 | thrds[1] = std::thread([]() { iknp_recv(); }); 111 | 112 | for (auto& thrd : thrds) 113 | thrd.join(); 114 | } 115 | else if(argc == 2) 116 | { 117 | int role = atoi(argv[1]); // 0: send, 1: recv 118 | role ? iknp_recv() : iknp_send(); 119 | } 120 | else 121 | { 122 | std::cout << "this program takes a runtime argument.\n\n" 123 | << "to run the IKNP passive secure 1-out-of-2 OT, run\n\n" 124 | << " frontend.exe [0|1]\n\n" 125 | << "the optional {0,1} argument specifies in which case the program will\n" 126 | << "run between two terminals, where each one was set to the opposite value. e.g.\n\n" 127 | << " frontend.exe 0\n\n" 128 | << " frontend.exe 1\n\n" 129 | << "These programs are fully networked and try to connect at localhost:1212.\n" 130 | << std::endl; 131 | } 132 | 133 | return 0; 134 | } 135 | -------------------------------------------------------------------------------- /src/lib/gc/aescircuits.h: -------------------------------------------------------------------------------- 1 | /* 2 | * aescircuits.h 3 | * 4 | * Created on: Dec 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #ifndef SRC_LIB_GC_AESCIRCUITS_H_ 9 | #define SRC_LIB_GC_AESCIRCUITS_H_ 10 | 11 | namespace lbcrypto { 12 | 13 | void SBOXNOTABLE(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 14 | void AddRoundKey(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 15 | void SubBytes(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 16 | void SubBytesTable(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 17 | void ShiftRows(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 18 | void MixColumns(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 19 | void MULTE_GF16(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64* inputs, ui64* 
outputs); 20 | void INV_GF16(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 21 | void AFFINE(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 22 | void SBOX(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 23 | void INVMAP(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 24 | void GF8MULCircuit(GarbledCircuit *garbledCircuit, BuildContext *garblingContext, ui64 n, ui64* inputs, ui64* outputs); 25 | 26 | void GF4MULCircuit(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 27 | void GF4SQCircuit(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 28 | void GF4SCLNCircuit(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 29 | void GF4SCLN2Circuit(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 30 | 31 | void NewSBOXCircuit(GarbledCircuit *gc, BuildContext *garblingContext, ui64* inputs, ui64* outputs); 32 | 33 | void buildAESCircuit(GarbledCircuit& garbledCircuit, BuildContext& garblingContext); 34 | 35 | } 36 | 37 | #endif /* SRC_LIB_GC_AESCIRCUITS_H_ */ 38 | -------------------------------------------------------------------------------- /src/lib/gc/circuits.h: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | #ifndef CIRCUITS_H_ 20 | #define CIRCUITS_H_ 21 | 22 | #include "gc.h" 23 | 24 | namespace lbcrypto { 25 | 26 | void CONSTCircuit(GarbledCircuit *gc, BuildContext *context, ui64 p, ui64 width, uv64& out); 27 | 28 | void ANDTreeCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in, ui64& out); 29 | void ORCircuit(GarbledCircuit *gc, BuildContext *context, ui64 n, ui64* inputs, ui64* outputs); 30 | 31 | void XORCircuit(GarbledCircuit *gc, BuildContext *context, ui64 n, ui64* inputs, ui64* outputs); 32 | void XORCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, uv64& out); 33 | void ANDCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, uv64& out); 34 | void ORCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, uv64& out); 35 | void NOTCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in, uv64& out); 36 | void MIXEDCircuit(GarbledCircuit *gc, BuildContext *context, ui64 n, ui64* inputs, ui64* outputs); 37 | 38 | void SHLCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in, ui64 shift, uv64& out); 39 | void SHRCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in, ui64 shift, uv64& out); 40 | 41 | void ADD32Circuit(GarbledCircuit *gc, BuildContext *context, ui64 a, ui64 b, ui64 cin, ui64& s, ui64& cout); 42 | void ADD22Circuit(GarbledCircuit *gc, BuildContext *context, ui64 a, ui64 b, ui64& s, ui64& cout); 43 | void SUB32Circuit(GarbledCircuit *gc,BuildContext *context, ui64 a, ui64 b, ui64 cin, ui64& s, ui64& cout); 44 | 45 | void INCCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in, uv64& out, ui64& carry); 46 | void ADDCircuit(GarbledCircuit *gc, BuildContext *context, 47 | const uv64& in_a, const uv64& in_b, uv64& out, ui64& 
carry); 48 | void SUBSlowCircuit(GarbledCircuit *gc, BuildContext *context, uv64& in_a, uv64& in_b, uv64& out, ui64& carry); 49 | void SUBCircuit(GarbledCircuit *gc, BuildContext *context, 50 | const uv64& in_a, const uv64& in_b, uv64& out, ui64& carry); 51 | 52 | void EQUCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, ui64& out); 53 | void LEQCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, ui64& out); 54 | void GEQCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, ui64& out); 55 | void LESCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, ui64& out); 56 | void GRECircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, ui64& out); 57 | 58 | void MUXCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in0, 59 | const uv64& in1, const ui64 sel, uv64& out); 60 | void MINCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, uv64& out); 61 | void MAXCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& in_a, const uv64& in_b, uv64& out); 62 | 63 | // ui64 MULCircuit(GarbledCircuit *gc, GarblingContext *context, ui64 n, ui64* inputs, ui64* outputs); 64 | 65 | void MultiXORCircuit(GarbledCircuit *gc, BuildContext *context, ui64 d, ui64 n, ui64* inputs, ui64* outputs); 66 | 67 | 68 | void EncoderCircuit(GarbledCircuit *gc, BuildContext *context, ui64* inputs, ui64* outputs, ui64 enc[]); 69 | void EncoderOneCircuit(GarbledCircuit *gc, BuildContext *context, ui64* inputs, ui64* outputs, ui64 enc[]); 70 | 71 | void RANDCircuit(GarbledCircuit *garbledCircuit, BuildContext *context, ui64 n, ui64* inputs, ui64* outputs, ui64 q, ui64 qf); 72 | 73 | void buildTestCircuit(GarbledCircuit& garbledCircuit, BuildContext& context); 74 | 75 | } 76 | 77 | #endif /* CIRCUITS_H_ */ 78 | 
-------------------------------------------------------------------------------- /src/lib/gc/common.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * common.cpp 3 | * 4 | * Created on: Nov 28, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include "common.h" 9 | #include 10 | -------------------------------------------------------------------------------- /src/lib/gc/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 
16 | 17 | */ 18 | 19 | 20 | #ifndef common 21 | #define common 22 | #include 23 | #include 24 | 25 | #include 26 | #include 27 | #include 28 | 29 | #include 30 | 31 | #include "aes.h" 32 | #include "utils/backend.h" 33 | 34 | namespace lbcrypto { 35 | 36 | #define xorBlocks(x,y) _mm_xor_si128(x,y) 37 | #define zero_block() _mm_setzero_si128() 38 | #define unequal_blocks(x,y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x,y)) != 0xffff) 39 | 40 | #define getLSB(x) (_mm_cvtsi128_si64(x)&1) 41 | #define makeBlock(X,Y) _mm_set_epi64((__m64)(X), (__m64)(Y)) 42 | #define getFromBlock(X,i) _mm_extract_epi64(X, i) 43 | 44 | #define DOUBLE(B) _mm_slli_epi64(B,1) 45 | 46 | #define SUCCESS 0 47 | #define FAILURE -1 48 | 49 | // #define STANDARD 50 | #define HALF_GATES 51 | 52 | #define FIXED_ZERO_GATE 0x00 53 | #define ANDGATE 0x08 54 | #define ORGATE 0x3e 55 | #define XORGATE 0x06 56 | #define XNORGATE 0x09 57 | #define NOTGATE 0x05 58 | #define FIXED_ONE_GATE 0x0f 59 | 60 | #ifdef STANDARD 61 | #define TABLE_SIZE 4 62 | #else 63 | #define TABLE_SIZE 2 64 | #endif 65 | 66 | #define TIMES 10 67 | #define RUNNING_TIME_ITER 100 68 | block randomBlock(); 69 | 70 | typedef struct { 71 | block label, label0, label1; 72 | } Wire; 73 | 74 | typedef struct { 75 | long input0, input1, output, type; 76 | } GarbledGate; 77 | 78 | typedef struct { 79 | block table[TABLE_SIZE]; 80 | } GarbledTable; 81 | 82 | typedef struct { 83 | int n, m, q, r; 84 | int n_c; 85 | block table_key; 86 | std::vector garbledGates; // Circuit topology 87 | std::vector outputs; // Indices of wires that are outputs 88 | std::vector garbledTable; // Tables 89 | std::vector wires; // Labels 90 | } GarbledCircuit; 91 | 92 | typedef struct { 93 | long wireIndex, gateIndex, tableIndex, outputIndex; 94 | } BuildContext; 95 | 96 | typedef std::vector> InputLabels; 97 | typedef std::vector ExtractedLabels; 98 | typedef std::vector OutputLabels; 99 | typedef osuCrypto::BitVector InputMap; 100 | typedef osuCrypto::BitVector 
OutputMap; 101 | 102 | } 103 | 104 | #endif 105 | -------------------------------------------------------------------------------- /src/lib/gc/gates.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | #include "gc.h" 20 | #include "gates.h" 21 | 22 | namespace lbcrypto { 23 | 24 | void genericGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& output, ui64 type) { 25 | GarbledGate *garbledGate = &(gc->garbledGates[context->gateIndex]); 26 | output = context->wireIndex; 27 | 28 | garbledGate->type = type; 29 | garbledGate->input0 = in0; 30 | garbledGate->input1 = in1; 31 | garbledGate->output = output; 32 | 33 | if(in0 >= output || in1 >= output){ 34 | std::cout << in0 << " " << in1 << " " << output << std::endl; 35 | throw std::logic_error("bad circuit"); 36 | } 37 | 38 | context->wireIndex++; 39 | context->gateIndex++; 40 | if(type != XORGATE && type != XNORGATE){ 41 | context->tableIndex++; 42 | } 43 | } 44 | 45 | } -------------------------------------------------------------------------------- /src/lib/gc/gates.h: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 
3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | #ifndef GATES_H_ 20 | #define GATES_H_ 21 | 22 | #include 23 | #include "gc.h" 24 | 25 | namespace lbcrypto { 26 | 27 | void genericGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& out, ui64 type); 28 | 29 | inline ui64 fixedZeroWire(GarbledCircuit *gc, BuildContext *garblingContext) { 30 | return gc->n; 31 | } 32 | 33 | inline ui64 fixedOneWire(GarbledCircuit *gc, BuildContext *garblingContext) { 34 | return (gc->n + 1); 35 | } 36 | 37 | inline void NOTGate(GarbledCircuit *gc, BuildContext *context, ui64 in, ui64& out) { 38 | return genericGate(gc, context, in, (gc->n + 1), out, XORGATE); 39 | } 40 | 41 | inline void ANDGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& out) { 42 | return genericGate(gc, context, in0, in1, out, ANDGATE); 43 | } 44 | 45 | inline void ORGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& out) { 46 | return genericGate(gc, context, in0, in1, out, ORGATE); 47 | } 48 | 49 | inline void XORGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& out) { 50 | return genericGate(gc, context, in0, in1, out, XORGATE); 51 | } 52 | 53 | inline void XNORGate(GarbledCircuit *gc, BuildContext *context, ui64 in0, ui64 in1, ui64& out) { 54 | return genericGate(gc, context, in0, in1, out, XNORGATE); 55 | } 56 | 57 | } 58 | 
59 | #endif /* GATES_H_ */ -------------------------------------------------------------------------------- /src/lib/gc/gazelle_circuits.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * gazelle_circuits.cpp 3 | * 4 | * Created on: Nov 30, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | 9 | #include "gazelle_circuits.h" 10 | 11 | namespace lbcrypto { 12 | 13 | void A2BCircuit(GarbledCircuit *gc, BuildContext *context, 14 | const uv64& s_p, const uv64& s_c_x, const uv64& s_s_x, uv64& s_x) { 15 | ui64 n = s_c_x.size(); 16 | uv64 s_x_0, s_x_1; 17 | ui64 carry, nonneg; 18 | ADDCircuit(gc, context, s_c_x, s_s_x, s_x_0, carry); 19 | SUBCircuit(gc, context, s_x_0, s_p, s_x_1, nonneg); 20 | ui64 s_x_sel; 21 | ORGate(gc, context, carry, nonneg, s_x_sel); 22 | MUXCircuit(gc, context, s_x_0, s_x_1, s_x_sel, s_x); 23 | s_x.resize(n); 24 | } 25 | 26 | 27 | void B2ACircuit(GarbledCircuit *gc, BuildContext *context, 28 | const uv64& s_p, const uv64& s_x, const uv64& s_s_x, uv64& s_c_x) { 29 | ui64 n = s_c_x.size(); 30 | uv64 s_c_x_0, s_c_x_1; 31 | ui64 carry, nonneg; 32 | ADDCircuit(gc, context, s_x, s_s_x, s_c_x_0, carry); 33 | SUBCircuit(gc, context, s_c_x_0, s_p, s_c_x_1, nonneg); 34 | ui64 s_c_x_sel; 35 | ORGate(gc, context, carry, nonneg, s_c_x_sel); 36 | MUXCircuit(gc, context, s_c_x_0, s_c_x_1, s_c_x_sel, s_c_x); 37 | s_c_x.resize(n); 38 | } 39 | 40 | void ReLUCircuit(GarbledCircuit *gc, BuildContext *context, const uv64& s_p, 41 | const uv64& s_p_2, const uv64& s_c_x, const uv64& s_s_x, const uv64& s_s_y, 42 | uv64& s_c_y) { 43 | uv64 s_x, s_y; 44 | A2BCircuit(gc, context, s_p, s_c_x, s_s_x, s_x); 45 | MAXCircuit(gc, context, s_x, s_p_2, s_y); 46 | B2ACircuit(gc, context, s_p, s_y, s_s_y, s_c_y); 47 | } 48 | 49 | void Pool2Circuit(GarbledCircuit *gc, BuildContext *context, const uv64& s_p, 50 | const uv64& s_p_2, const std::vector& s_c_x, 51 | const std::vector& s_s_x, const uv64& s_s_y, uv64& s_c_y) { 52 | std::vector s_x(4); 53 
| uv64 s_in = s_p_2; 54 | uv64 s_out; 55 | for(ui64 i=0; i<4; i++){ 56 | A2BCircuit(gc, context, s_p, s_c_x[i], s_s_x[i], s_x[i]); 57 | MAXCircuit(gc, context, s_x[i], s_in, s_out); 58 | s_in = s_out; 59 | } 60 | B2ACircuit(gc, context, s_p, s_out, s_s_y, s_c_y); 61 | } 62 | 63 | ui64 fill_vector(uv64& v, ui64 start){ 64 | ui64 count = start; 65 | for(ui64 i=0; i in(3, uv64(width)); 76 | uv64 out(width); 77 | uv64 s_p(width), s_p_2(width); 78 | 79 | int n = n_circ*width*3; 80 | int m = n_circ*width; 81 | 82 | startBuilding(&gc, &context, n, m, n_circ*1000); 83 | gc.n_c = n_circ*width; 84 | CONSTCircuit(&gc, &context, p, width, s_p); 85 | CONSTCircuit(&gc, &context, p/2, width, s_p_2); 86 | for(ui64 i=0; i c_x(4, uv64(width)); 101 | std::vector s_x(4, uv64(width)); 102 | uv64 s_y(width); 103 | uv64 c_y(width); 104 | 105 | int n = n_circ*width*9; 106 | int m = n_circ*width; 107 | 108 | startBuilding(&gc, &context, n, m, n_circ*2200); 109 | gc.n_c = 4*n_circ*width; 110 | uv64 s_p, s_p_2; 111 | CONSTCircuit(&gc, &context, p, width, s_p); 112 | CONSTCircuit(&gc, &context, p/2, width, s_p_2); 113 | for(ui64 i=0; i& s_c_x, 28 | const std::vector& s_s_x, const uv64& s_s_y, uv64& s_c_y); 29 | 30 | ui64 fill_vector(uv64& v, ui64 start); 31 | 32 | void buildRELULayer(GarbledCircuit& gc, BuildContext& context, 33 | ui64 width, ui64 n_circ, ui64 p); 34 | 35 | void buildPool2Layer(GarbledCircuit& gc, BuildContext& context, 36 | ui64 width, ui64 n_circ, ui64 p); 37 | 38 | void relu_ref(uv64& din, uv64& dref, ui64 mask, ui64 p); 39 | 40 | void pool2_ref(uv64& din, uv64& dref, ui64 mask, ui64 p); 41 | 42 | } 43 | 44 | #endif /* SRC_LIB_GC_GAZELLE_CIRCUITS_H_ */ 45 | -------------------------------------------------------------------------------- /src/lib/gc/gc.h: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 
3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | #ifndef justGarble 20 | #define justGarble 1 21 | #include "common.h" 22 | 23 | #include 24 | 25 | namespace lbcrypto { 26 | 27 | /* 28 | * The following are the functions involved in creating, garbling, and 29 | * evaluating circuits. The data-structures involved are not defined in 30 | * this header; they presumably come from common.h (included above) - confirm. 31 | */ 32 | 33 | // Start and finish building a circuit. In between these two steps, gates 34 | // and sub-circuits can be added in. See AESFullTest and LargeCircuitTest 35 | // for examples. Note that the data-structures involved are GarbledCircuit 36 | // and BuildContext, although these steps do not involve any garbling. 37 | // The reason for this is efficiency. Typically, garbleCircuit is called 38 | // right after finishBuilding. So, using a GarbledCircuit data-structure 39 | // here means that there is no need to create and initialize a new 40 | // data-structure just before calling garbleCircuit. 41 | int startBuilding(GarbledCircuit *gc, BuildContext *ctx, long n, long m, long q=50000); // NOTE(review): n/m/q presumably = #inputs, #outputs, gate capacity (JustGarble convention) - confirm 42 | void addOutputs(GarbledCircuit *garbledCircuit, BuildContext *ctx, uv64& outputs); 43 | int finishBuilding(GarbledCircuit *garbledCircuit, BuildContext *ctx); 44 | 45 | //Garble the circuit described in garbledCircuit.
For efficiency reasons, 46 | //we use the garbledCircuit data-structure for representing the input 47 | //circuit and the garbled output. The garbling process is non-destructive and 48 | //only affects the garbledTable member of the GarbledCircuit data-structure. 49 | //In other words, the same GarbledCircuit object can be reused multiple times, 50 | //to create multiple garbled circuit copies, 51 | //by just switching the garbledTable field to a new one. Also, the garbledTable 52 | //field is the only one that should be sent over the network in the case of an 53 | //MPC-type application, as the topology is expected to be available to the 54 | //receiver, and the gate-types are to be hidden away. 55 | //The inputLabels argument is expected to contain 2n fresh input labels, obtained 56 | //by calling createInputLabels. The outputMap is expected to be a 2m-block sized 57 | //empty array. 58 | long garbleCircuit(GarbledCircuit *garbledCircuit, InputLabels& inputLabels, 59 | OutputMap& outputMap); 60 | 61 | // A simple function that selects n input labels from 2n labels, using the 62 | // inputBits array where each element is a bit. 63 | void extractLabels(ExtractedLabels& extractedLabels, InputLabels& inputLabels, 64 | InputMap& inputBits); 65 | 66 | //Evaluate a garbled circuit, using n input labels in the Extracted Labels 67 | //to return m output labels. The garbled circuit might be generated either in 68 | //one piece, as the result of running garbleCircuit, or may be pieced together, 69 | // by building the circuit (startBuilding ... finishBuilding), and adding 70 | // garbledTable from another source, say, a network transmission.
71 | int evaluate(GarbledCircuit *garbledCircuit, ExtractedLabels& extractedLabels, 72 | OutputLabels& outputLabels); 73 | // evaluate_pt: presumably evaluates the circuit on plaintext bits (InputMap -> OutputMap) rather than labels - TODO confirm. 74 | int evaluate_pt(GarbledCircuit *garbledCircuit, InputMap& inputMap, 75 | OutputMap& outputMap); 76 | 77 | // A simple function that takes 2m output labels, m labels from evaluate, 78 | // and returns an m-bit output by matching the labels. If one or more of the 79 | // m evaluated labels do not match either of the two corresponding output labels, 80 | // then the function flags an error. 81 | void mapOutputs(OutputMap& outputMap, OutputLabels& outputLabels, OutputMap& extractedMap); 82 | 83 | 84 | void pack_inputs(std::vector& din, InputMap& inputMap, ui64 width); 85 | void unpack_outputs(OutputMap& outputMap, std::vector& dout, ui64 width); 86 | void print_results(std::vector& din, std::vector& dout_pt, 87 | std::vector& dout, std::vector& dref); 88 | 89 | // int writeCircuitToFile(GarbledCircuit *garbledCircuit, std::string fileName); 90 | // int readCircuitFromFile(GarbledCircuit *garbledCircuit, std::string fileName); 91 | 92 | } 93 | 94 | #endif 95 | -------------------------------------------------------------------------------- /src/lib/gc/scd.nocpp: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see .
16 | 17 | */ 18 | 19 | 20 | #include "justGarble.h" 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | long fsize(const char *filename) { 28 | struct stat st; 29 | 30 | if (stat(filename, &st) == 0) 31 | return st.st_size; 32 | 33 | return -1; 34 | } 35 | 36 | int writeCircuitToFile(GarbledCircuit *garbledCircuit, std::string fileName) { 37 | FILE *f = fopen(fileName.c_str(), "wb"); 38 | if (f == NULL) { 39 | printf("Write: Error in opening file.\n"); 40 | return FAILURE; 41 | } 42 | msgpack_sbuffer* buffer = msgpack_sbuffer_new(); 43 | msgpack_packer* pk = msgpack_packer_new(buffer, msgpack_sbuffer_write); 44 | msgpack_sbuffer_clear(buffer); 45 | int n = garbledCircuit->n; 46 | int q = garbledCircuit->q; 47 | int m = garbledCircuit->m; 48 | GarbledGate *garbledGate; 49 | msgpack_pack_array(pk, 3 + 3 * q + m); 50 | msgpack_pack_int(pk, n); 51 | msgpack_pack_int(pk, m); 52 | msgpack_pack_int(pk, q); 53 | int i; 54 | for (i = 0; i < q; i++) { 55 | garbledGate = &(garbledCircuit->garbledGates[i]); 56 | msgpack_pack_int(pk, garbledGate->input0); 57 | } 58 | for (i = 0; i < q; i++) { 59 | garbledGate = &(garbledCircuit->garbledGates[i]); 60 | msgpack_pack_int(pk, garbledGate->input1); 61 | } 62 | for (i = 0; i < q; i++) { 63 | garbledGate = &(garbledCircuit->garbledGates[i]); 64 | msgpack_pack_int(pk, garbledGate->type); 65 | } 66 | for (i = 0; i < m; i++) { 67 | msgpack_pack_int(pk, garbledCircuit->outputs[i]); 68 | 69 | } 70 | fwrite(buffer->data, (buffer->size), 1, f); 71 | fclose(f); 72 | return SUCCESS; 73 | } 74 | 75 | int readCircuitFromFile(GarbledCircuit *garbledCircuit, std::string fileName) { 76 | int fs = fsize(fileName.c_str()); 77 | FILE *f = fopen(fileName.c_str(), "rb"); 78 | if (f == NULL) { 79 | printf("READ:Error in opening file %s.\n", fileName.c_str()); 80 | return FAILURE; 81 | } 82 | msgpack_sbuffer* buffer = msgpack_sbuffer_new(); 83 | void *storage = malloc(fs); 84 | if (fread(storage, fs, 1, f) != 1) { 85 | printf("File not 
read completely.\n"); 86 | } 87 | fclose(f); 88 | buffer->data = (char *)storage; 89 | buffer->size = fs; 90 | msgpack_unpacked msg; 91 | msgpack_unpacked_init(&msg); 92 | msgpack_unpack_next(&msg, buffer->data, buffer->size, NULL); 93 | msgpack_object obj = msg.data; 94 | msgpack_object* p = obj.via.array.ptr; 95 | int n = (*p).via.i64; 96 | ++p; 97 | int m = (*p).via.i64; 98 | ++p; 99 | int q = (*p).via.i64; 100 | garbledCircuit->m = m; 101 | garbledCircuit->n = n; 102 | garbledCircuit->q = q; 103 | garbledCircuit->r = n+q+2; 104 | 105 | garbledCircuit->outputs.resize(m); 106 | garbledCircuit->garbledGates.resize(q); 107 | garbledCircuit->garbledTable.resize(q); 108 | garbledCircuit->wires.resize(garbledCircuit->r); 109 | 110 | int i; 111 | for (i = 0; i < q; i++) { 112 | garbledCircuit->garbledGates[i].output = n+i+1; 113 | } 114 | for (i = 0; i < q; i++) { 115 | ++p; 116 | garbledCircuit->garbledGates[i].input0 = (*p).via.i64; 117 | } 118 | for (i = 0; i < q; i++) { 119 | ++p; 120 | garbledCircuit->garbledGates[i].input1 = (*p).via.i64; 121 | } 122 | for (i = 0; i < q; i++) { 123 | ++p; 124 | garbledCircuit->garbledGates[i].type = (*p).via.i64; 125 | } 126 | for (i = 0; i < m; i++) { 127 | ++p; 128 | garbledCircuit->outputs[i] = (*p).via.i64; 129 | } 130 | return SUCCESS; 131 | } 132 | 133 | -------------------------------------------------------------------------------- /src/lib/gc/util.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | This file is part of JustGarble. 3 | 4 | JustGarble is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | 9 | JustGarble is distributed in the hope that it will be useful, 10 | but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 12 | GNU General Public License for more details. 13 | 14 | You should have received a copy of the GNU General Public License 15 | along with JustGarble. If not, see . 16 | 17 | */ 18 | 19 | #include "aes.h" 20 | #include "common.h" 21 | #include "util.h" 22 | #include "gc.h" 23 | #include 24 | #include 25 | #include 26 | 27 | namespace lbcrypto { 28 | 29 | // static __m128i cur_seed; 30 | block __current_rand_index; 31 | AES_KEY __rand_aes_key; 32 | 33 | void countToN(ui64 *a, ui64 n) { 34 | for (ui64 i = 0; i < n; i++) 35 | a[i] = i; 36 | } 37 | 38 | int dbgBlock(block a) { 39 | int *A = (int *) &a; 40 | int i; 41 | int out = 0; 42 | for (i = 0; i < 4; i++) 43 | out = out + (A[i] + 13432) * 23517; 44 | return out; 45 | } 46 | 47 | int compare(const void * a, const void * b) { 48 | return (*(const int*) a > *(const int*) b) - (*(const int*) a < *(const int*) b); // three-way compare without the signed overflow of a - b 49 | } 50 | 51 | int median(int *values, int n) { // sorts values in place; returns the middle element, or the mean of the two middle elements when n is even 52 | if (n <= 0) return 0; // empty input: avoid out-of-bounds reads 53 | qsort(values, n, sizeof(int), compare); 54 | if (n % 2 == 1) 55 | return values[n / 2]; // 0-based middle index (was (n + 1) / 2, one past the middle) 56 | else 57 | return (values[n / 2 - 1] + values[n / 2]) / 2; // the two middle elements (was n / 2 and n / 2 + 1) 58 | } 59 | 60 | double doubleMean(double *values, int n) { 61 | int i; 62 | double total = 0; 63 | for (i = 0; i < n; i++) 64 | total += values[i]; 65 | return total / n; 66 | } 67 | 68 | // This is only for testing and benchmark purposes. Use a more 69 | // secure seeding mechanism for actual use.
70 | int already_initialized = 0; 71 | void seedRandom() { 72 | if (!already_initialized) { 73 | already_initialized = 1; 74 | __current_rand_index = zero_block(); 75 | srand(time(NULL)); 76 | block cur_seed = _mm_set_epi32(rand(), rand(), rand(), rand()); 77 | AES_set_encrypt_key((unsigned char *) &cur_seed, 128, &__rand_aes_key); 78 | } 79 | } 80 | 81 | block randomBlock() { 82 | block out; 83 | const __m128i *sched = getRandContext(); 84 | randAESBlock(&out, sched); 85 | return out; 86 | } 87 | 88 | void print_block(block x){ 89 | ui64* x64 = (ui64*) &x; 90 | printf("%016llx%016llx", x64[1], x64[0]); 91 | } 92 | 93 | void print_gc(GarbledCircuit& gc){ 94 | std::cout << "n: " << gc.n <. 16 | 17 | */ 18 | 19 | #ifndef UTIL_H_ 20 | #define UTIL_H_ 21 | 22 | #include "common.h" 23 | #include "aes.h" 24 | #include 25 | 26 | namespace lbcrypto { 27 | 28 | void countToN(ui64 *a, ui64 N); 29 | int dbgBlock(block a); 30 | 31 | #define RDTSC ({unsigned long long res; unsigned hi, lo; __asm__ __volatile__ ("rdtsc" : "=a"(lo), "=d"(hi)); res = ( (unsigned long long)lo)|( ((unsigned long long)hi)<<32 );res;}) 32 | #define fbits( v, p) ((v & (1 << p))>>p) 33 | int getWords(char *line, char *words[], int maxwords); 34 | 35 | int median(int A[], int n); 36 | double doubleMean(double A[], int n); 37 | 38 | void seedRandom(void); 39 | block randomBlock(); 40 | void randAESBlock(block* out); 41 | 42 | // Compute AES in place. out is a block and sched is a pointer to an 43 | // expanded AES key. 
44 | #define inPlaceAES(out, sched) {int jx; out = _mm_xor_si128(out, sched[0]);\ 45 | for (jx = 1; jx < 10; jx++)\ 46 | out = _mm_aesenc_si128(out, sched[jx]);\ 47 | out = _mm_aesenclast_si128(out, sched[jx]);} 48 | 49 | extern block __current_rand_index; 50 | extern AES_KEY __rand_aes_key; 51 | 52 | // #define getRandContext() ((__m128i *) (__rand_aes_key.rd_key)); 53 | // #define randAESBlock(out,sched) {__current_rand_index++; *out = __current_rand_index;inPlaceAES(*out,sched);} 54 | static inline block* getRandContext(void) {return __rand_aes_key.rd_key;}; 55 | static inline void randAESBlock(block* out,const block* sched) {__current_rand_index = __current_rand_index+1; *out = __current_rand_index;inPlaceAES(*out,sched);} 56 | 57 | void print_block(block x); 58 | 59 | void print_gc(GarbledCircuit& gc); 60 | 61 | } 62 | 63 | #endif /* UTIL_H_ */ 64 | -------------------------------------------------------------------------------- /src/lib/math/automorph.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "automorph.h" 3 | #include 4 | 5 | #include 6 | 7 | namespace lbcrypto { 8 | 9 | std::map g_automorph_index; 10 | 11 | std::vector base_decompose(const uv64& coeff, const ui32 window_size, const ui32 num_windows){ 12 | ui32 phim = coeff.size(); 13 | std::vector decomposed(num_windows, uv64(phim)); 14 | 15 | ui64 mask = ((ui64)1 << window_size) - 1; // 64-bit shift: (1 << w) is an int and overflows for w >= 31 16 | for(ui32 j=0; j> window_size); 21 | } 22 | } 23 | 24 | return decomposed; 25 | } 26 | 27 | void precompute_automorph_index(const ui32 phim){ 28 | uv32 automorph_indices = uv32(phim); 29 | 30 | ui32 g = 1; 31 | ui32 phim_by_2 = phim >> 1; 32 | ui32 mask = (phim << 1) - 1; 33 | for(ui32 i=0; i base_decompose(const uv64& coeff, const ui32 window_size, const ui32 num_windows); 15 | 16 | void precompute_automorph_index(const ui32 phim); 17 | 18 | ui32 get_automorph_index(const ui32 rot, const ui32 phim); 19 | 20 | uv64 automorph(const uv64& input, const ui32 rot);
21 | 22 | uv64 automorph_pt(const uv64& input, const ui32 rot); 23 | } 24 | 25 | #endif /* LBCRYPTO_MATH_AUTOMORPH_H_ */ 26 | -------------------------------------------------------------------------------- /src/lib/math/bit_twiddle.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "math/bit_twiddle.h" 3 | 4 | namespace lbcrypto { 5 | 6 | // precomputed reverse of a byte 7 | 8 | inline static unsigned char reverse_byte(unsigned char x) 9 | { 10 | static const unsigned char table[] = { 11 | 0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 12 | 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0, 13 | 0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 14 | 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8, 15 | 0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 16 | 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4, 17 | 0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 18 | 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc, 19 | 0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 20 | 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2, 21 | 0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 22 | 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa, 23 | 0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 24 | 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6, 25 | 0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 26 | 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe, 27 | 0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 28 | 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1, 29 | 0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 30 | 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9, 31 | 0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 32 | 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5, 33 | 0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 34 | 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd, 35 | 0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 36 | 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3, 37 | 0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 38 | 0x1b, 0x9b, 
0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb, 39 | 0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 40 | 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7, 41 | 0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 42 | 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff, 43 | }; 44 | return table[x]; 45 | } 46 | 47 | /* Function to reverse bits of num */ 48 | ui32 ReverseBits(ui32 num, ui32 msb) 49 | { 50 | ui32 result; 51 | unsigned char * p = (unsigned char *) # 52 | unsigned char * q = (unsigned char *) &result; 53 | q[3] = reverse_byte(p[0]); 54 | q[2] = reverse_byte(p[1]); 55 | q[1] = reverse_byte(p[2]); 56 | q[0] = reverse_byte(p[3]); 57 | return (result) >> (32-msb); 58 | } 59 | 60 | ui32 log_pow2(ui32 v){ 61 | // Find log2 when v is a known power of two 62 | const ui32 b[] = {0xAAAAAAAA, 0xCCCCCCCC, 0xF0F0F0F0, 0xFF00FF00, 0xFFFF0000}; 63 | ui32 r = (v & b[0]) != 0; 64 | r |= ((v & b[1]) != 0) << 1; 65 | r |= ((v & b[2]) != 0) << 2; 66 | r |= ((v & b[3]) != 0) << 3; 67 | r |= ((v & b[4]) != 0) << 4; 68 | 69 | return r; 70 | } 71 | 72 | ui32 nxt_pow2(ui32 x){ 73 | x -= 1; 74 | 75 | // Fill LSB with ones 76 | x |= x >> 1; 77 | x |= x >> 2; 78 | x |= x >> 4; 79 | x |= x >> 8; 80 | x |= x >> 16; 81 | 82 | return x+1; 83 | } 84 | 85 | ui32 num_ones(ui32 x){ 86 | ui32 bits_set = 0; 87 | while(x != 0){ 88 | x = x & (x-1); 89 | bits_set++; 90 | } 91 | return bits_set; 92 | } 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/lib/math/bit_twiddle.h: -------------------------------------------------------------------------------- 1 | #ifndef LBCRYPTO_MATH_BIT_TWIDDLE_H 2 | #define LBCRYPTO_MATH_BIT_TWIDDLE_H 3 | 4 | #include "utils/backend.h" 5 | 6 | namespace lbcrypto { 7 | inline ui64 ones(ui32 n){ 8 | return ((ui64)1 << n)-1; 9 | } 10 | 11 | // Only works for positive num and den. 
(num cannot be zero) 12 | inline ui32 div_ceil(ui32 num, ui32 den){ 13 | return 1 + ((num - 1) / den); 14 | } 15 | 16 | /** 17 | * Method to reverse bits of num and return an unsigned int, for all bits up to an including the designated most significant bit. 18 | * 19 | * @param input an unsigned int 20 | * @param msb the most significant bit. All larger bits are disregarded. 21 | * 22 | * @return an unsigned integer that represents the reversed bits. 23 | */ 24 | ui32 ReverseBits(ui32 input, ui32 msb); 25 | 26 | /** 27 | * Get MSB of an unsigned 64 bit integer. 28 | * 29 | * @param x the input to find MSB of. 30 | * 31 | * @return the index of the MSB bit location. 32 | */ 33 | inline ui32 GetMSB64(uint64_t x) { 34 | if (x == 0) return 0; 35 | 36 | // hardware instructions for finding MSB are used are used; 37 | #if defined(_MSC_VER) 38 | // a wrapper for VC++ 39 | unsigned long msb; 40 | _BitScanReverse64(&msb, x); 41 | return msb + 1; 42 | #else 43 | // a wrapper for GCC 44 | return 64 - (sizeof(unsigned long) == 8 ? __builtin_clzl(x) : __builtin_clzll(x)); 45 | #endif 46 | } 47 | 48 | ui32 log_pow2(ui32 v); 49 | 50 | ui32 nxt_pow2(ui32 x); 51 | 52 | ui32 num_ones(ui32 x); 53 | 54 | } 55 | 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/lib/math/discretegaussiangenerator.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @file discretegaussiangenerator.cpp This code provides generation of gaussian distibutions of discrete values. 3 | * Discrete uniform generator relies on the built-in C++ generator for 32-bit unsigned integers defined in . 4 | * @author TPOC: palisade@njit.edu 5 | * 6 | * @copyright Copyright (c) 2017, New Jersey Institute of Technology (NJIT) 7 | * All rights reserved. 8 | * Redistribution and use in source and binary forms, with or without modification, 9 | * are permitted provided that the following conditions are met: 10 | * 1. 
Redistributions of source code must retain the above copyright notice, this 11 | * list of conditions and the following disclaimer. 12 | * 2. Redistributions in binary form must reproduce the above copyright notice, this 13 | * list of conditions and the following disclaimer in the documentation and/or other 14 | * materials provided with the distribution. 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 21 | * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 24 | * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | */ 27 | 28 | #include 29 | #include "discretegaussiangenerator.h" 30 | 31 | // #include 32 | 33 | namespace lbcrypto { 34 | 35 | DiscreteGaussianGenerator::DiscreteGaussianGenerator(double std) : DistributionGenerator() { 36 | m_std = std; 37 | 38 | m_vals.clear(); 39 | 40 | //weightDiscreteGaussian: m_a ends up as the normalized probability of 0, m_vals as the one-sided cumulative probabilities of 1..fin 41 | double acc = 1e-15; 42 | double variance = m_std * m_std; 43 | 44 | int fin = (int)ceil(m_std * sqrt(-2 * log(acc))); 45 | //this value of fin (M) corresponds to the limit for double precision 46 | // usually the bound of m_std * M is used, where M = 20 ..
40 - see DG14 for details 47 | // M = 20 corresponds to 1e-87 48 | //double mr = 20; // see DG14 for details 49 | //int fin = (int)ceil(m_std * mr); 50 | 51 | double cusum = 1.0; 52 | 53 | for (si32 x = 1; x <= fin; x++) { 54 | cusum = cusum + 2 * exp(-x * x / (variance * 2)); 55 | } 56 | 57 | m_a = 1 / cusum; 58 | 59 | //fin = (int)ceil(sqrt(-2 * variance * log(acc))); //not needed - same as above 60 | double temp; 61 | 62 | for (si32 i = 1; i <= fin; i++) { 63 | temp = m_a * exp(-((double)(i * i) / (2 * variance))); 64 | m_vals.push_back(temp); 65 | } 66 | 67 | // take cumulative summation 68 | for (ui32 i = 1; i < m_vals.size(); i++) { 69 | m_vals[i] += m_vals[i - 1]; 70 | } 71 | 72 | // for (ui32 i = 0; i &S, double search) const { 80 | //STL binary search implementation 81 | auto lower = std::lower_bound(S.begin(), S.end(), search); 82 | if (lower != S.end()) 83 | return lower - S.begin(); 84 | else 85 | throw std::runtime_error("DGG Inversion Sampling. FindInVector value not found: " + std::to_string(search)); 86 | } 87 | 88 | uv64 DiscreteGaussianGenerator::GenerateVector(const ui32 size, const ui64 &modulus) const { 89 | //we need to use the binary uniform generator rather than regular continuous distribution; see DG14 for details. NOTE(review): the line below does draw from std::uniform_real_distribution - confirm which is intended 90 | std::uniform_real_distribution distribution(0.0, 1.0); 91 | 92 | uv64 ans(size); 93 | auto& prng = get_prng(); 94 | for (ui32 i = 0; i < size; i++) { 95 | double seed = distribution(prng) - 0.5; 96 | if (std::abs(seed) <= m_a / 2) { 97 | ans[i] = ui64(0); 98 | } else{ // inversion sampling: |seed| - m_a/2 is looked up in the cumulative table m_vals; the sign of seed selects +(val+1) or modulus-(val+1) 99 | ui32 val = FindInVector(m_vals, (std::abs(seed) - m_a / 2)); 100 | if (seed > 0) { 101 | ans[i] = ui64(val+1); 102 | } else { 103 | ans[i] = ui64(modulus-val-1); 104 | } 105 | } 106 | } 107 | return ans; 108 | } 109 | 110 | } // namespace lbcrypto 111 | -------------------------------------------------------------------------------- /src/lib/math/discretegaussiangenerator.h: -------------------------------------------------------------------------------- 1
| /** 2 | * @file discretegaussiangenerator.h This code provides generation of gaussian distibutions of discrete values. 3 | * Discrete uniform generator relies on the built-in C++ generator for 32-bit unsigned integers defined in . 4 | * @author TPOC: palisade@njit.edu 5 | * 6 | * @copyright Copyright (c) 2017, New Jersey Institute of Technology (NJIT) 7 | * All rights reserved. 8 | * Redistribution and use in source and binary forms, with or without modification, 9 | * are permitted provided that the following conditions are met: 10 | * 1. Redistributions of source code must retain the above copyright notice, this 11 | * list of conditions and the following disclaimer. 12 | * 2. Redistributions in binary form must reproduce the above copyright notice, this 13 | * list of conditions and the following disclaimer in the documentation and/or other 14 | * materials provided with the distribution. 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 21 | * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 24 | * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | */ 27 | 28 | #ifndef LBCRYPTO_MATH_DISCRETEGAUSSIANGENERATOR_H_ 29 | #define LBCRYPTO_MATH_DISCRETEGAUSSIANGENERATOR_H_ 30 | 31 | #define _USE_MATH_DEFINES // added for Visual Studio support 32 | 33 | #include 34 | #include 35 | #include 36 | 37 | #include "utils/backend.h" 38 | #include "distributiongenerator.h" 39 | 40 | namespace lbcrypto { 41 | /** 42 | * @brief The class for Discrete Gaussian Distribution generator. 43 | */ 44 | class DiscreteGaussianGenerator : public DistributionGenerator { 45 | 46 | public: 47 | /** 48 | * @brief Basic constructor for specifying the distribution parameter (standard deviation). 49 | * (The modulus is supplied per call to GenerateVector, not here.) 50 | * @param std The standard deviation for this Gaussian Distribution. 51 | */ 52 | DiscreteGaussianGenerator (double std = 4.0); 53 | 54 | /** 55 | * @brief Generates a vector of random values within this Discrete Gaussian Distribution. Uses Peikert's inversion method. 56 | * 57 | * @param size The number of values to return. 58 | * @param modulus modulus of the polynomial ring. 59 | * @return The vector of values within this Discrete Gaussian Distribution. 60 | */ 61 | uv64 GenerateVector (ui32 size, const ui64 &modulus) const; 62 | 63 | private: 64 | ui32 FindInVector (const std::vector &S, double search) const; 65 | 66 | // Gyana to add precomputation methods and data members 67 | // all parameters are set as int because it is assumed that they are used for generating "small" polynomials only 68 | double m_a; 69 | 70 | std::vector m_vals; 71 | 72 | /** 73 | * The standard deviation of the distribution.
74 | */ 75 | double m_std; 76 | 77 | }; 78 | 79 | } // namespace lbcrypto 80 | 81 | #endif // LBCRYPTO_MATH_DISCRETEGAUSSIANGENERATOR_H_ 82 | -------------------------------------------------------------------------------- /src/lib/math/distrgen.h: -------------------------------------------------------------------------------- 1 | /* 2 | * distrgen.h 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | * 7 | */ 8 | 9 | #ifndef LBCRYPTO_MATH_DISTRGEN_H_ 10 | #define LBCRYPTO_MATH_DISTRGEN_H_ 11 | 12 | #define _USE_MATH_DEFINES 13 | #include "distributiongenerator.h" 14 | #include "discretegaussiangenerator.h" 15 | 16 | #endif // LBCRYPTO_MATH_DISTRGEN_H_ 17 | -------------------------------------------------------------------------------- /src/lib/math/distributiongenerator.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * distributiongenerator.cpp 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | * 7 | */ 8 | 9 | #include "distributiongenerator.h" 10 | #include "math/params.h" 11 | #include 12 | #include 13 | 14 | namespace lbcrypto { 15 | osuCrypto::PRNG aes128_engine::m_prng(_mm_setzero_si128(), 256); // NOTE(review): seeded with an all-zero block, so every run yields the same stream - fine for benchmarks, unsafe for real keys/noise; confirm 16 | 17 | aes128_engine& get_prng(){ 18 | // C++11 thread-safe static initialization 19 | // static thread_local std::mt19937_64 prng(std::random_device{}()); 20 | static thread_local aes128_engine prng; // NOTE(review): aes128_engine's only data member is the static m_prng, so all threads share one PRNG state - confirm osuCrypto::PRNG is safe to share 21 | return prng; 22 | } 23 | 24 | /* std::mt19937_64& get_prng(){ 25 | // C++11 thread-safe static initialization 26 | static thread_local std::mt19937_64 prng(std::random_device{}()); 27 | return prng; 28 | }*/ 29 | 30 | uv64 get_bug_vector (const ui32 size) { 31 | auto prng = get_prng(); 32 | auto distribution = std::uniform_int_distribution(0, 1); 33 | 34 | uv64 v(size); 35 | for (ui32 i = 0; i < size; i++) { 36 | v[i] = distribution(prng); 37 | } 38 | return v; 39 | } 40 | 41 | uv64 get_tug_vector (const ui32 size, const ui64 modulus) { 42 | auto prng = get_prng(); 43 | auto distribution =
std::uniform_int_distribution(-1,1); 44 | ui64 minus1 = modulus - 1; 45 | uv64 v(size); 46 | 47 | for(ui32 m=0; m(0, modulus-1); 66 | 67 | uv64 v(size); 68 | 69 | for (ui32 i = 0; i < size; i++) { 70 | v[i] = distribution(prng); 71 | } 72 | return v; 73 | 74 | } 75 | 76 | uv64 get_dug_vector_opt(const ui32 size) { 77 | auto prng = get_prng(); 78 | ui64 max = ((opt::q) << 4); // rejection bound: a multiple of q, so modq_full(rand) below stays uniform mod q (assumes 16q fits in 64 bits) 79 | 80 | uv64 v(size); 81 | 82 | for (ui32 i = 0; i < size;) { 83 | ui64 rand = prng(); 84 | if(rand < max){ 85 | v[i] = opt::modq_full(rand); 86 | i++; 87 | } 88 | } 89 | return v; 90 | } 91 | 92 | 93 | uv64 get_dgg_testvector(ui32 size, ui64 p, float std_dev){ 94 | std::normal_distribution distribution(0,std_dev); 95 | auto& prng = get_prng(); 96 | 97 | uv64 vec(size); 98 | for(ui32 i=0; i=0)? r : p+r; 101 | } 102 | return vec; 103 | } 104 | 105 | uv64 get_uniform_testvector(ui32 size, ui64 max){ 106 | std::uniform_int_distribution distribution(0, max); 107 | auto& prng = get_prng(); 108 | 109 | uv64 vec(size); 110 | for(ui32 i=0; i(); 118 | } 119 | 120 | 121 | 122 | } // namespace lbcrypto 123 | -------------------------------------------------------------------------------- /src/lib/math/distributiongenerator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * distributiongenerator.h 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | * 7 | */ 8 | 9 | #ifndef LBCRYPTO_MATH_DISTRIBUTIONGENERATOR_H_ 10 | #define LBCRYPTO_MATH_DISTRIBUTIONGENERATOR_H_ 11 | 12 | //used to define a thread-safe generator 13 | #if defined (_MSC_VER) // Visual studio 14 | //#define thread_local __declspec( thread ) 15 | #elif defined (__GCC__) // GCC. NOTE(review): GCC predefines __GNUC__, not __GCC__, so this branch is presumably never taken and C++11 thread_local is used - confirm before "fixing" the macro name 16 | #define thread_local __thread 17 | #endif 18 | 19 | #include "utils/backend.h" 20 | #include 21 | #include 22 | #include 23 | 24 | namespace lbcrypto { 25 | 26 | // AES Engine 27 | struct aes128_engine { 28 | private: 29 | static osuCrypto::PRNG m_prng; 30 | 31 | public: 32 | aes128_engine(){}; 33 |
~aes128_engine() {}; 34 | using result_type = uint64_t; 35 | constexpr static result_type min() { return 0; } 36 | constexpr static result_type max() { return -1; } 37 | 38 | result_type operator()(); 39 | }; 40 | 41 | // Return the calling thread's generator object (a function-local thread_local in distributiongenerator.cpp) 42 | aes128_engine &get_prng(); 43 | // std::mt19937_64 &get_prng(); 44 | 45 | uv64 get_bug_vector(const ui32 size); 46 | 47 | uv64 get_tug_vector(const ui32 size, const ui64 modulus); 48 | 49 | uv64 get_dug_vector(const ui32 size, const ui64 modulus); 50 | 51 | uv64 get_dug_vector_opt(const ui32 size); 52 | 53 | uv64 get_dgg_testvector(ui32 size, ui64 p, float std_dev = 40.0); 54 | 55 | uv64 get_uniform_testvector(ui32 size, ui64 max); 56 | 57 | /** 58 | * @brief Abstract class describing generator requirements. 59 | * 60 | * The Distribution Generator defines the methods that must be implemented by a real generator. 61 | * It was meant to hold the single PRNG; NOTE(review): the class below has no members and callers use get_prng() directly.
62 | * 63 | */ 64 | 65 | // Base class for Distribution Generator by type 66 | class DistributionGenerator { 67 | public: 68 | DistributionGenerator () {} 69 | virtual ~DistributionGenerator() {} 70 | }; 71 | 72 | } // namespace lbcrypto 73 | 74 | #endif // LBCRYPTO_MATH_DISTRIBUTIONGENERATOR_H_ 75 | -------------------------------------------------------------------------------- /src/lib/math/params.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * params.cpp 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include 9 | #include "math/params.h" 10 | 11 | namespace lbcrypto { 12 | namespace opt { 13 | ui32 logn = 11; 14 | ui32 phim = (1 << logn); 15 | 16 | /* 17 | //q (<60 bits), p (>18) bits 18 | ui64 q(1152921504346550273); 19 | ui64 z(236170385413442746); 20 | ui64 p(307201); 21 | ui64 z_p(227254); 22 | ui64 mu = ((ui64)1 << (2*19+2))/p; // [Adjusted for fast partial] 23 | */ 24 | 25 | 26 | //q (<60 bits), p (>19) bits 27 | ui64 q(1152921504499937281); 28 | ui64 z(246029739010950493); 29 | ui64 p(557057); 30 | ui64 z_p(201127); 31 | ui64 mu = ((ui64)1 << (2*20+2))/p; // [Adjusted for fast partial] 32 | 33 | 34 | /* //q (<60 bits), p (>20) bits 35 | ui64 q(1152921504414760961); 36 | ui64 z(1012134726195831682); 37 | ui64 p(1712129); 38 | ui64 z_p(290337); 39 | ui64 mu = ((ui64)1 << (2*21+2))/p; // [Adjusted for fast partial] 40 | ui64 mu_h = (mu >> 4); // [Adjusted for fast partial] 41 | ui64 mu_l = (mu % 16); // [Adjusted for fast partial] 42 | */ 43 | 44 | 45 | // ui64 z(824956925455712260); 46 | ui64 delta = ((ui64)1<<60)-q; 47 | ui64 delta16 = delta << 4; 48 | ui64 q4 = q << 2; 49 | ui64 delta2 = delta << 1; 50 | 51 | ui64 p2 = (p << 1); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/lib/math/params.h: -------------------------------------------------------------------------------- 1 | #ifndef LBCRYPTO_MATH_PARAMS_H 2 | #define 
LBCRYPTO_MATH_PARAMS_H 3 | 4 | #include "utils/backend.h" 5 | #include "math/bit_twiddle.h" 6 | 7 | namespace lbcrypto { 8 | namespace opt { 9 | extern ui32 logn; 10 | extern ui32 phim; 11 | 12 | extern ui64 q; 13 | extern ui64 z; 14 | extern ui64 p; 15 | extern ui64 z_p; 16 | extern ui64 mu; 17 | extern ui64 mu_h; 18 | extern ui64 mu_l; 19 | // extern ui64 z; 20 | extern ui64 delta; 21 | extern ui64 delta16; 22 | extern ui64 q4; 23 | extern ui64 delta2; 24 | extern ui64 p2; 25 | 26 | inline ui64 modp_part(ui64 a){ 27 | // The constant here is 2*ceil(log2(p))+2 28 | return (a - ((a*mu) >> 42)*p); 29 | } 30 | 31 | /*inline ui64 modp_part(ui64 a){ 32 | return (a - ((a*mu_h + ((a*mu_l) >> 4)) >> 40)*p); 33 | }*/ 34 | 35 | inline ui64 modp_full(ui64 a){ 36 | ui64 b = modp_part(a); 37 | return ((b >= p)? b-p: b); 38 | } 39 | 40 | inline ui64 modp_finalize(ui64 a){ 41 | return ((a >= p)? a-p: a); 42 | } 43 | 44 | inline ui64 modq_part(ui128 a){ 45 | ui128 b = (ui128)((ui64)a) + (a >> 64)*(ui128)delta16; // (64b) + (60+34=94b) = max(95b) 46 | ui64 c = (ui64)(b >> 61)*delta2 + ((ui64)b & ones(61)); // max (34+31=65b) + (61b) = 62b 47 | return c; 48 | } 49 | 50 | inline ui64 modq_full(ui128 a){ 51 | ui64 b = modq_part(a); 52 | while(b >= q){ 53 | b -= q; 54 | } 55 | return b; 56 | } 57 | 58 | inline ui64 modq_part(ui64 a){ 59 | return (a >> 60)*delta + (a & ones(60)); 60 | } 61 | 62 | inline ui64 modq_full(ui64 a){ 63 | ui64 b = modq_part(a); 64 | while(b >= q){ 65 | b -= q; 66 | } 67 | return b; 68 | } 69 | 70 | inline ui64 sub_modq_part(ui64 a, ui64 b){ 71 | return modq_part(a + q4 - b); 72 | } 73 | 74 | inline ui64 mul_modq_part(ui64 a, ui64 b){ 75 | ui128 c = (ui128)a*(ui128)b; // 124b number 76 | return modq_part(c); 77 | } 78 | 79 | inline ui64 lshift_modq_part(ui64 a, ui32 shift){ 80 | ui128 c = ((ui128)a << shift); // 124b number 81 | return modq_part(c); 82 | } 83 | } 84 | } 85 | 86 | 87 | #endif 88 | 
-------------------------------------------------------------------------------- /src/lib/math/transfrm.h: -------------------------------------------------------------------------------- 1 | /* 2 | * transform.h 3 | * 4 | * An initial draft of the transform code is inspired from my experiments with 5 | * the PALISADE transform code. 6 | * Created on: Aug 25, 2017 7 | * Author: chiraag 8 | * 9 | */ 10 | 11 | #ifndef LBCRYPTO_MATH_TRANSFRM_H 12 | #define LBCRYPTO_MATH_TRANSFRM_H 13 | 14 | 15 | #include "utils/backend.h" 16 | #include "nbtheory.h" 17 | //#include "../utils/utilities.h" 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | /** 26 | * @namespace lbcrypto 27 | * The namespace of lbcrypto 28 | */ 29 | namespace lbcrypto { 30 | 31 | uv64 ftt_fwd(const uv64& element, const ui64 modulus, const ui32 logn); 32 | 33 | uv64 ftt_inv(const uv64& element, const ui64 modulus, const ui32 logn); 34 | 35 | uv64 ftt_fwd_opt(const uv64& element); 36 | 37 | uv64 ftt_inv_opt(const uv64& element); 38 | 39 | uv64 ftt_fwd_opt_p(const uv64& element); 40 | 41 | uv64 ftt_inv_opt_p(const uv64& element); 42 | 43 | void ftt_precompute(const ui64 rootOfUnity, const ui64 modulus, const ui32 logn); 44 | 45 | void ftt_pre_compute(const uv64 &rootOfUnity, const uv64 &moduliiChain, const ui32 logn); 46 | 47 | } // namespace lbcrypto ends 48 | 49 | #endif 50 | -------------------------------------------------------------------------------- /src/lib/ot/cot_recv.h: -------------------------------------------------------------------------------- 1 | #ifndef OT_IKNPOTEXTRECV_H 2 | #define OT_IKNPOTEXTRECV_H 3 | 4 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 
5 | #include 6 | #include "ot_ifc.h" 7 | 8 | namespace osuCrypto 9 | { 10 | 11 | class IKNPReceiver : 12 | public OTExtReceiver 13 | { 14 | public: 15 | IKNPReceiver() 16 | :mHasBase(false) 17 | {} 18 | 19 | bool hasBaseOts() const override 20 | { 21 | return mHasBase; 22 | } 23 | 24 | bool mHasBase; 25 | std::array, gOtExtBaseOtCount> mGens; 26 | 27 | void setBaseOts( 28 | span> baseSendOts)override; 29 | std::unique_ptr split() override; 30 | 31 | 32 | void receive( 33 | const BitVector& choices, 34 | span messages, 35 | PRNG& prng, 36 | Channel& chl 37 | ) override; 38 | 39 | }; 40 | 41 | } 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /src/lib/ot/cot_send.h: -------------------------------------------------------------------------------- 1 | #ifndef OT_IKNPOTEXTSENDER_H 2 | #define OT_IKNPOTEXTSENDER_H 3 | 4 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 
5 | #include 6 | #include "ot_ifc.h" 7 | 8 | namespace osuCrypto { 9 | 10 | class IKNPSender : 11 | public OTExtSender 12 | { 13 | public: 14 | std::array mGens; 15 | BitVector mBaseChoiceBits; 16 | block mDelta; 17 | std::unique_ptr split() override; 18 | 19 | bool hasBaseOts() const override 20 | { 21 | return mBaseChoiceBits.size() > 0; 22 | } 23 | 24 | void setBaseOts( 25 | span baseRecvOts, 26 | const BitVector& choices) override; 27 | 28 | 29 | void send( 30 | span> in_data, 31 | PRNG& prng, 32 | Channel& chl/*, 33 | std::atomic& doneIdx*/) override; 34 | 35 | void set_delta(block delta); 36 | 37 | }; 38 | } 39 | 40 | #endif 41 | -------------------------------------------------------------------------------- /src/lib/ot/ot_ifc.h: -------------------------------------------------------------------------------- 1 | #ifndef OT_OTEXTIFC_H 2 | #define OT_OTEXTIFC_H 3 | 4 | 5 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #ifdef GetMessage 12 | #undef GetMessage 13 | #endif 14 | 15 | 16 | namespace osuCrypto 17 | { 18 | 19 | // The hard coded number of base OT that is expected by the OT Extension implementations. 20 | // This can be changed if the code is adequately adapted. 
21 | const u64 gOtExtBaseOtCount(128); 22 | const u64 gCommStepSize(512); 23 | const u64 gSuperBlkSize(8); 24 | 25 | 26 | class BaseOTReceiver 27 | { 28 | public: 29 | BaseOTReceiver() {} 30 | virtual ~BaseOTReceiver(){}; 31 | 32 | virtual void receive( 33 | const BitVector& choices, 34 | span messages, 35 | PRNG& prng, 36 | Channel& chl) = 0; 37 | 38 | }; 39 | 40 | class BaseOTSender 41 | { 42 | public: 43 | BaseOTSender() {} 44 | virtual ~BaseOTSender(){}; 45 | 46 | virtual void send( 47 | span> messages, 48 | PRNG& prng, 49 | Channel& chl) = 0; 50 | 51 | }; 52 | 53 | 54 | 55 | class OTExtReceiver 56 | { 57 | public: 58 | OTExtReceiver() {} 59 | virtual ~OTExtReceiver(){}; 60 | 61 | virtual void setBaseOts( 62 | span> baseSendOts) = 0; 63 | 64 | virtual bool hasBaseOts() const = 0; 65 | virtual std::unique_ptr split() = 0; 66 | 67 | virtual void receive( 68 | const BitVector& choices, 69 | span out_data, 70 | PRNG& prng, 71 | Channel& chl) = 0; 72 | 73 | }; 74 | 75 | class OTExtSender 76 | { 77 | public: 78 | OTExtSender() {}; 79 | virtual ~OTExtSender(){}; 80 | 81 | virtual bool hasBaseOts() const = 0; 82 | 83 | virtual void setBaseOts( 84 | span baseRecvOts, 85 | const BitVector& choices) = 0; 86 | 87 | virtual std::unique_ptr split() = 0; 88 | 89 | virtual void send( 90 | span> in_data, 91 | PRNG& prng, 92 | Channel& chl) = 0; 93 | }; 94 | 95 | } 96 | 97 | #endif 98 | -------------------------------------------------------------------------------- /src/lib/ot/sr_base_ot.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 
3 | 4 | #include 5 | #include 6 | #include "ot_ifc.h" 7 | 8 | namespace osuCrypto 9 | { 10 | 11 | class SRBaseOT : public BaseOTReceiver, public BaseOTSender 12 | { 13 | public: 14 | 15 | SRBaseOT(); 16 | ~SRBaseOT(); 17 | 18 | void receive( 19 | const BitVector& choices, 20 | span messages, 21 | PRNG& prng, 22 | Channel& chl, 23 | u64 numThreads); 24 | 25 | void send( 26 | span> messages, 27 | PRNG& prng, 28 | Channel& sock, 29 | u64 numThreads); 30 | 31 | void receive( 32 | const BitVector& choices, 33 | span messages, 34 | PRNG& prng, 35 | Channel& chl) override 36 | { 37 | receive(choices, messages, prng, chl, 2); 38 | } 39 | 40 | void send( 41 | span> messages, 42 | PRNG& prng, 43 | Channel& sock) override 44 | { 45 | send(messages, prng, sock, 2); 46 | } 47 | }; 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/lib/ot/tools.h: -------------------------------------------------------------------------------- 1 | #ifndef OT_TOOLS_H 2 | #define OT_TOOLS_H 3 | 4 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 
5 | 6 | #include 7 | #include 8 | #include 9 | namespace osuCrypto { 10 | 11 | void eklundh_transpose128(std::array& inOut); 12 | void sse_transpose128(std::array& inOut); 13 | void print(std::array& inOut); 14 | u8 getBit(std::array& inOut, u64 i, u64 j); 15 | 16 | void sse_transpose128x1024(std::array, 128>& inOut); 17 | 18 | void sse_transpose(const MatrixView& in, const MatrixView& out); 19 | //void sse_transpose_new(const MatrixView& in, const MatrixView& out); 20 | void sse_transpose(const MatrixView& in, const MatrixView& out); 21 | 22 | } 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /src/lib/pke/conv1d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * conv1d.cpp 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include "math/automorph.h" 9 | #include "pke/encoding.h" 10 | #include "pke/fv.h" 11 | 12 | #include "pke/conv1d.h" 13 | 14 | #include "utils/test.h" 15 | #include 16 | 17 | namespace lbcrypto { 18 | 19 | // TODO: Massive preprocessing savings possible by computing for all ones filter 20 | // scaling for the appropriate filter 21 | EncMat preprocess_filter_1d(const uv64& filter, const ui32 window_size, 22 | const ui32 num_windows, const FVParams& params){ 23 | // Create the diagonal rotation of the filter matrix 24 | ui32 filter_size = filter.size(); 25 | ui32 offset = (filter_size-1)/2; 26 | std::vector filter_mat(filter_size, uv64(params.phim, 0)); 27 | for(ui32 row=0; row= offset) && ( col + row < offset + params.phim)) { 30 | filter_mat[row][col] = filter[row]; 31 | } 32 | } 33 | } 34 | 35 | EncMat enc_filter(filter_size, std::vector(num_windows, uv64(params.phim))); 36 | for(ui32 row=0; row> 1)-1; 49 | 50 | CTMat ct_mat(filter_size, std::vector(ct_vec.size(), Ciphertext(params.phim))); 51 | for(ui32 w=0; w= offset) && ( y+f < phim + offset); 97 | ui64 in = ((not_edge) ? 
vec[(y+f-offset) & mask] : 0); 98 | conv[y] = (conv[y] + in*filter[f]) % p; 99 | } 100 | } 101 | 102 | return conv; 103 | } 104 | 105 | } 106 | 107 | 108 | -------------------------------------------------------------------------------- /src/lib/pke/conv1d.h: -------------------------------------------------------------------------------- 1 | /* 2 | * conv1d.h 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #ifndef SRC_LIB_PKE_CONV1D_H_ 9 | #define SRC_LIB_PKE_CONV1D_H_ 10 | 11 | #include "utils/backend.h" 12 | #include "pke/layers.h" 13 | #include "pke_types.h" 14 | 15 | namespace lbcrypto{ 16 | 17 | EncMat preprocess_filter_1d(const uv64& filter, const ui32 window_size, 18 | const ui32 num_windows, const FVParams& params); 19 | 20 | CTMat conv_1d_rot(const CTVec& ct_vec, const ui32& filter_size, const FVParams& params); 21 | 22 | Ciphertext conv_1d_mul(const CTMat& ct_mat, const EncMat& enc_filter, const FVParams& params); 23 | 24 | Ciphertext conv_1d_online(const CTVec& ct_vec, const EncMat& enc_filter, const FVParams& params); 25 | 26 | uv64 conv_1d_pt(const uv64& vec, const uv64& filter, const ui32 p); 27 | } 28 | 29 | 30 | 31 | #endif /* SRC_LIB_PKE_CONV1D_H_ */ 32 | -------------------------------------------------------------------------------- /src/lib/pke/conv2d.h: -------------------------------------------------------------------------------- 1 | /* 2 | * conv2d.h 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #ifndef SRC_LIB_PKE_CONV2D_H_ 9 | #define SRC_LIB_PKE_CONV2D_H_ 10 | 11 | #include "utils/backend.h" 12 | #include "pke/layers.h" 13 | #include "pke_types.h" 14 | 15 | namespace lbcrypto { 16 | 17 | CTMat preprocess_ifmap(const SecretKey& sk, const ConvLayer& pt, 18 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 19 | 20 | EncMat preprocess_filter(const Filter2D& filter, const ConvShape& shape, 21 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 
22 | 23 | CTVec conv_2d_online(const CTMat& ct_mat, const EncMat& enc_mat, 24 | const Filter2DShape& filter_shape, const ConvShape& in_shape, const FVParams& params); 25 | 26 | EncMat preprocess_filter_2stage(const Filter2D& filter, const ConvShape& shape, 27 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 28 | 29 | CTVec conv_2d_2stage_online(const CTMat& ct_mat, const EncMat& enc_mat, 30 | const Filter2DShape& filter_shape, const ConvShape& in_shape, const FVParams& params); 31 | 32 | 33 | ConvLayer postprocess_conv(const SecretKey& sk, const CTVec& ct_vec, 34 | const ConvShape& shape, const FVParams& params); 35 | 36 | ConvLayer conv_2d_pt(const ConvLayer& in, const Filter2D& filter, bool same, const ui32 p); 37 | 38 | bool check_conv(const ConvLayer& ofmap, const ConvLayer& ofmap_ref); 39 | } 40 | 41 | 42 | 43 | 44 | #endif /* SRC_LIB_PKE_CONV2D_H_ */ 45 | -------------------------------------------------------------------------------- /src/lib/pke/encoding.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | encoding.h: This code implements packed integer encoding 3 | 4 | List of Authors: 5 | Chiraag Juvekar, chiraag@mit.edu 6 | 7 | License Information: 8 | MIT License 9 | Copyright (c) 2017, Massachusetts Institute of Technology (MIT) 10 | 11 | */ 12 | 13 | #ifndef LBCRYPTO_CRYPTO_FV_C 14 | #define LBCRYPTO_CRYPTO_FV_C 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | #include "math/params.h" 21 | #include "math/transfrm.h" 22 | 23 | namespace lbcrypto { 24 | std::map g_to_ftt_map; 25 | std::map g_from_ftt_map; 26 | 27 | void encoding_precompute(const ui64& mod_p, const ui32& logn){ 28 | ui32 phim = (1 << logn); 29 | ui32 phim_by_2 = phim/2; 30 | ui32 mask = phim*2-1; 31 | 32 | // Create the permutations that interchange the automorphism and crt ordering 33 | // First we create the cyclic group generated by 5 and then adjoin the co-factor by multiplying by 3 34 | uv32 
to_ftt_perm(phim); 35 | uv32 from_ftt_perm(phim); 36 | 37 | ui32 curr_index = 1; 38 | for (ui32 i = 0; i < phim_by_2; i++) { 39 | to_ftt_perm[(curr_index - 1) / 2] = i; 40 | from_ftt_perm[i] = (curr_index - 1) / 2; 41 | 42 | ui32 cofactor_index = (curr_index * mask) & mask; 43 | to_ftt_perm[(cofactor_index - 1) / 2] = i + phim_by_2; 44 | from_ftt_perm[i + phim_by_2] = (cofactor_index - 1) / 2; 45 | 46 | curr_index = (curr_index * 5) & mask; 47 | } 48 | g_to_ftt_map[mod_p] = std::move(to_ftt_perm); 49 | g_from_ftt_map[mod_p] = std::move(from_ftt_perm); 50 | 51 | return; 52 | } 53 | 54 | uv64 packed_encode(const uv64& input, const ui64 mod_p, const ui32 logn){ 55 | ui32 phim = (1< 14 | using std::shared_ptr; 15 | 16 | #include "utils/backend.h" 17 | #include "math/transfrm.h" 18 | #include "pke_types.h" 19 | 20 | namespace lbcrypto { 21 | 22 | /** 23 | * @brief This is the parameters class for the FV encryption scheme. 24 | * 25 | * The FV scheme parameter guidelines are introduced here: 26 | * - Junfeng Fan and Frederik Vercauteren. Somewhat Practical Fully Homomorphic Encryption. Cryptology ePrint Archive, Report 2012/144. (https://eprint.iacr.org/2012/144.pdf) 27 | * 28 | * We used the optimized parameter selection from the designs here: 29 | * - Lepoint T., Naehrig M. (2014) A Comparison of the Homomorphic Encryption Schemes FV and YASHE. In: Pointcheval D., Vergnaud D. (eds) Progress in Cryptology – AFRICACRYPT 2014. AFRICACRYPT 2014. Lecture Notes in Computer Science, vol 8469. Springer, Cham. (https://eprint.iacr.org/2014/062.pdf) 30 | * 31 | * @tparam Element a ring element type. 
32 | */ 33 | struct FVParams { 34 | bool fast_modulli; 35 | 36 | ui64 q, p; 37 | ui32 logn, phim; 38 | 39 | // delta = floor(modq/modp) __NOT__ the delta from params 40 | ui64 delta; 41 | 42 | // specifies whether the keys are generated from discrete 43 | // Gaussian distribution or ternary distribution with the norm of unity 44 | MODE mode; 45 | shared_ptr dgg; 46 | 47 | ui32 window_size; 48 | }; 49 | 50 | extern std::map> g_rk_map; 51 | 52 | uv64 inline ToCoeff(const uv64& eval, const FVParams& params){ 53 | if(params.fast_modulli){ 54 | return ftt_inv_opt(eval); 55 | } else { 56 | return ftt_inv(eval, params.q, params.logn); 57 | } 58 | } 59 | 60 | uv64 inline ToEval(const uv64& coeff, const FVParams& params){ 61 | if(params.fast_modulli){ 62 | return ftt_fwd_opt(coeff); 63 | } else { 64 | return ftt_fwd(coeff, params.q, params.logn); 65 | } 66 | } 67 | 68 | uv64 NullEncrypt(uv64& pt, const FVParams& params); 69 | 70 | Ciphertext Encrypt(const PublicKey& pk, uv64& pt, const FVParams& params); 71 | 72 | Ciphertext Encrypt(const SecretKey& sk, uv64& pt, const FVParams& params); 73 | 74 | uv64 Decrypt(const SecretKey& sk, const Ciphertext& ct, const FVParams& params); 75 | 76 | sv64 Noise(const SecretKey& sk, const Ciphertext& ct, const FVParams& params); 77 | 78 | double NoiseMargin(const SecretKey& sk, const Ciphertext& ct, const FVParams& params); 79 | 80 | KeyPair KeyGen(const FVParams& params); 81 | 82 | Ciphertext EvalAdd(const Ciphertext& ct1, const Ciphertext& ct2, const FVParams& params); 83 | 84 | Ciphertext EvalAddPlain(const Ciphertext& ct, const uv64& pt, const FVParams& params); 85 | 86 | Ciphertext EvalSub(const Ciphertext& ct1, const Ciphertext& ct2, const FVParams& params); 87 | 88 | Ciphertext EvalSubPlain(const Ciphertext& ct, const uv64& pt, const FVParams& params); 89 | 90 | Ciphertext EvalMultPlain(const Ciphertext& ct, const uv64& pt, const FVParams& params); 91 | 92 | Ciphertext EvalNegate(const Ciphertext& ct, const FVParams& params); 93 | 
94 | std::vector HoistedDecompose(const Ciphertext& ct, const FVParams& params); 95 | 96 | Ciphertext KeySwitchDigits(const RelinKey& rk, const Ciphertext& ct, 97 | const std::vector digits_ct, const FVParams& params); 98 | 99 | RelinKey KeySwitchGen(const SecretKey& orig_sk, const SecretKey& new_sk, const FVParams& params); 100 | 101 | Ciphertext KeySwitch(const RelinKey& relin_key, const Ciphertext& ct, const FVParams& params); 102 | 103 | Ciphertext EvalAutomorphismDigits(const ui32 rot, const RelinKey& rk, const Ciphertext& ct, 104 | const std::vector& digits_ct, const FVParams& params); 105 | 106 | Ciphertext EvalAutomorphism(const ui32 rot, const Ciphertext& ct, const FVParams& params); 107 | 108 | shared_ptr GetAutomorphismKey(ui32 rot); 109 | 110 | void EvalAutomorphismKeyGen(const SecretKey& sk, const uv32& index_list, const FVParams& params); 111 | 112 | Ciphertext AddRandomNoise(const Ciphertext& ct, const FVParams& params); 113 | 114 | } // namespace lbcrypto ends 115 | #endif 116 | -------------------------------------------------------------------------------- /src/lib/pke/gazelle.h: -------------------------------------------------------------------------------- 1 | /* 2 | * gazelle.h 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | * 7 | */ 8 | 9 | #ifndef SRC_LIB_GAZELLE_H_ 10 | #define SRC_LIB_GAZELLE_H_ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "../utils/backend.h" 17 | #include "utils/debug.h" 18 | #include "utils/test.h" 19 | 20 | #include "math/params.h" 21 | #include "math/distrgen.h" 22 | #include "math/automorph.h" 23 | #include "math/transfrm.h" 24 | 25 | #include "pke/encoding.h" 26 | #include "pke/fv.h" 27 | #include "pke/layers.h" 28 | #include "pke/mat_mul.h" 29 | #include "pke/gemm.h" 30 | #include "pke/square.h" 31 | #include "pke/conv1d.h" 32 | #include "pke/conv2d.h" 33 | #include "pke/pke_types.h" 34 | 35 | #endif /* SRC_LIB_GAZELLE_H_ */ 36 | 
-------------------------------------------------------------------------------- /src/lib/pke/gemm.h: -------------------------------------------------------------------------------- 1 | /* 2 | * mat_mul.h 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #ifndef SRC_LIB_PKE_GEMM_H_ 9 | #define SRC_LIB_PKE_GEMM_H_ 10 | 11 | #include "utils/backend.h" 12 | #include "pke/layers.h" 13 | #include "pke_types.h" 14 | 15 | namespace lbcrypto{ 16 | CTMat preprocess_gemm_c(const SecretKey& sk, const std::vector& mat, 17 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 18 | 19 | EncMat preprocess_gemm_s(const std::vector& mat, const ui32 num_cols_c, 20 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 21 | 22 | CTVec gemm_online(const CTMat& ct_mat_c, const EncMat& enc_mat_s, 23 | const ui32 num_cols_c, const FVParams& params); 24 | 25 | CTVec gemm_phim_online(const CTMat& ct_mat_c, const std::vector& mat_s_t, 26 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 27 | 28 | std::vector postprocess_gemm(const SecretKey& sk, const CTVec& ct_prod, 29 | const ui32 num_rows, const ui32 num_cols, const FVParams& params); 30 | 31 | std::vector gemm_pt(const std::vector& mat_c, 32 | const std::vector& mat_s_t, const ui64 p); 33 | } 34 | 35 | 36 | 37 | 38 | #endif /* SRC_LIB_PKE_MAT_MUL_H_ */ 39 | -------------------------------------------------------------------------------- /src/lib/pke/layers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * layers.cpp 3 | * 4 | * Created on: Aug 28, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include "pke/fv.h" 9 | 10 | #include "pke/layers.h" 11 | 12 | namespace lbcrypto{ 13 | /* 14 | CTVec preprocess_vec(const SecretKey& sk, const uv64& pt, 15 | const ui32 window_size, const ui32 num_windows, const FVParams& params){ 16 | // Expand the input with multiples of the plaintext base 17 | std::vector 
pt_scaled(num_windows, uv64(params.phim)); 18 | for (ui32 w=0; w CTVec; 18 | typedef std::vector> CTMat; 19 | 20 | typedef std::vector> EncMat; 21 | 22 | struct Filter2DShape{ 23 | ui32 out_chn, in_chn, f_h, f_w; 24 | 25 | Filter2DShape(ui32 out_chn, ui32 in_chn, ui32 f_h, ui32 f_w) : 26 | out_chn(out_chn), in_chn(in_chn), f_h(f_h), f_w(f_w) {}; 27 | }; 28 | 29 | struct ConvShape{ 30 | ui32 chn, h, w; 31 | 32 | ConvShape(ui32 chn, ui32 h, ui32 w) : 33 | chn(chn), h(h), w(w) {}; 34 | }; 35 | 36 | struct Filter2D{ 37 | Filter2DShape shape; 38 | std::vector>> w; 39 | uv64 b; 40 | 41 | Filter2D(ui32 out_chn, ui32 in_chn, ui32 f_h, ui32 f_w) : 42 | shape(out_chn, in_chn, f_h, f_w), 43 | w(out_chn, std::vector>(in_chn, std::vector(f_h, uv64(f_w)))), 44 | b(out_chn) {}; 45 | }; 46 | 47 | struct ConvLayer{ 48 | ConvShape shape; 49 | std::vector> act; 50 | 51 | ConvLayer(ui32 chn, ui32 h, ui32 w) : 52 | shape(chn, h, w), act(chn, std::vector(h, uv64(w))) {}; 53 | }; 54 | 55 | } 56 | 57 | 58 | 59 | #endif /* SRC_LIB_PKE_LAYERS_H_ */ 60 | -------------------------------------------------------------------------------- /src/lib/pke/mat_mul.h: -------------------------------------------------------------------------------- 1 | /* 2 | * mat_mul.h 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #ifndef SRC_LIB_PKE_MAT_MUL_H_ 9 | #define SRC_LIB_PKE_MAT_MUL_H_ 10 | 11 | #include "utils/backend.h" 12 | #include "pke/layers.h" 13 | #include "pke_types.h" 14 | 15 | namespace lbcrypto{ 16 | CTVec preprocess_vec(const SecretKey& sk, const uv64& vec, 17 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 18 | 19 | EncMat preprocess_matrix(const std::vector& mat, 20 | const ui32 window_size, const ui32 num_windows, const FVParams& params); 21 | 22 | Ciphertext mat_mul_online(const CTVec& vec, const EncMat& enc_mat, 23 | const ui32 pack_factor, const FVParams& params); 24 | 25 | uv64 postprocess_prod(const SecretKey& sk, const Ciphertext& 
ct_prod, 26 | const ui32 vec_size, const ui32 num_rows, const FVParams& params); 27 | 28 | uv64 mat_mul_pt(const uv64& vec, const std::vector& mat, const ui64 p); 29 | } 30 | 31 | 32 | 33 | 34 | #endif /* SRC_LIB_PKE_MAT_MUL_H_ */ 35 | -------------------------------------------------------------------------------- /src/lib/pke/pke_types.h: -------------------------------------------------------------------------------- 1 | /* 2 | * pke_types.h 3 | * 4 | * Barebones data-structures 5 | * Created on: Aug 25, 2017 6 | * Author: chiraag 7 | * 8 | */ 9 | 10 | #ifndef LBCRYPTO_CRYPTO_PUBKEYLP_H 11 | #define LBCRYPTO_CRYPTO_PUBKEYLP_H 12 | 13 | #include 14 | #include "math/distrgen.h" 15 | 16 | 17 | namespace lbcrypto { 18 | struct Ciphertext { 19 | uv64 a; 20 | uv64 b; 21 | 22 | Ciphertext(ui32 size) : a(size), b(size) {}; 23 | }; 24 | 25 | struct PublicKey { 26 | uv64 a; 27 | uv64 b; 28 | 29 | PublicKey(ui32 size) : a(size), b(size) {}; 30 | }; 31 | 32 | struct RelinKey { 33 | std::vector a; 34 | std::vector b; 35 | 36 | RelinKey(ui32 size, ui32 windows) : a(windows, uv64(size)), b(windows, uv64(size)) {}; 37 | }; 38 | 39 | 40 | struct SecretKey { 41 | uv64 s; 42 | 43 | SecretKey(ui32 size) : s(size) {}; 44 | }; 45 | 46 | struct KeyPair { 47 | public: 48 | PublicKey pk; 49 | SecretKey sk; 50 | 51 | KeyPair(const PublicKey& pk, const SecretKey& sk) : pk(pk), sk(sk) {}; 52 | }; 53 | 54 | } 55 | #endif 56 | -------------------------------------------------------------------------------- /src/lib/pke/square.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * mat_mul.cpp 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include "math/bit_twiddle.h" 9 | #include "math/automorph.h" 10 | #include "math/params.h" 11 | #include "pke/encoding.h" 12 | #include "pke/fv.h" 13 | 14 | #include "pke/square.h" 15 | 16 | #include "utils/test.h" 17 | #include 18 | #include 19 | 20 | namespace lbcrypto{ 21 | 22 | CTVec 
preprocess_client_share(const SecretKey& sk, const uv64& vec, const FVParams& params){ 23 | std::vector pt(2, uv64(params.phim)); 24 | pt[0] = vec; 25 | for(ui32 n=0; n, uv64> preprocess_server_share(const uv64& vec, const FVParams& params){ 43 | std::vector pt(3, uv64(params.phim)); 44 | for(ui32 n=0; n ct_vec(3, uv64(params.phim)); 56 | for(ui32 i=0; i<3; i++){ 57 | auto pt_enc = packed_encode(pt[i], params.p, params.logn); 58 | if(i != 0){ 59 | for(ui32 n=0; n& pt_vec_s, const FVParams& params){ 70 | auto ct_share = EvalMultPlain(ct_vec_c[0], pt_vec_s[0], params); 71 | ct_share = EvalAdd(ct_share, ct_vec_c[1], params); 72 | ct_share = EvalAddPlain(ct_share, pt_vec_s[1], params); 73 | ct_share = EvalAddPlain(ct_share, pt_vec_s[2], params); 74 | return ct_share; 75 | } 76 | 77 | uv64 postprocess_client_share(const SecretKey& sk, const Ciphertext& ct, 78 | const ui32 vec_size, const FVParams& params){ 79 | auto pt = packed_decode(Decrypt(sk, ct, params), params.p, params.logn); 80 | uv64 vec(vec_size); 81 | for(ui32 n=0; n, uv64> preprocess_server_share(const uv64& vec, const FVParams& params); 19 | 20 | Ciphertext square_online(const CTVec& ct_vec_c, const std::vector& pt_vec_s, const FVParams& params); 21 | 22 | uv64 postprocess_client_share(const SecretKey& sk, const Ciphertext& ct, 23 | const ui32 vec_size, const FVParams& params); 24 | 25 | uv64 square_pt(const uv64& vec_c, const uv64& vec_s, const uv64& vec_s_f, const ui64 p); 26 | } 27 | 28 | 29 | 30 | 31 | #endif /* SRC_LIB_PKE_MAT_MUL_H_ */ 32 | -------------------------------------------------------------------------------- /src/lib/utils/backend.h: -------------------------------------------------------------------------------- 1 | /* 2 | * backend.h 3 | * 4 | * Created on: Aug 25, 2017 5 | * Author: chiraag 6 | * 7 | */ 8 | 9 | #ifndef LBCRYPTO_MATH_BACKEND_H 10 | #define LBCRYPTO_MATH_BACKEND_H 11 | 12 | #include 13 | #include 14 | 15 | /** 16 | * @namespace lbcrypto 17 | * The namespace of lbcrypto 18 
| */ 19 | namespace lbcrypto { 20 | typedef int32_t si32; 21 | typedef uint32_t ui32; 22 | typedef int64_t si64; 23 | typedef uint64_t ui64; 24 | typedef __uint128_t ui128; 25 | 26 | typedef std::vector sv32; 27 | typedef std::vector uv32; 28 | typedef std::vector sv64; 29 | typedef std::vector uv64; 30 | typedef std::vector uv128; 31 | 32 | 33 | /** 34 | * @brief Lists all modes for RLWE schemes, such as BGV and FV 35 | */ 36 | enum MODE { 37 | RLWE = 0, 38 | OPTIMIZED = 1 39 | }; 40 | 41 | } // namespace lbcrypto ends 42 | 43 | 44 | #endif 45 | -------------------------------------------------------------------------------- /src/lib/utils/debug.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @file debug.cpp This file contains macros and associated helper functions for quick cerr oriented debugging 3 | that can be quickly enabled and disabled. It also contains functions for timing code. 4 | * @author TPOC: palisade@njit.edu 5 | * 6 | * @copyright Copyright (c) 2017, New Jersey Institute of Technology (NJIT) 7 | * All rights reserved. 8 | * Redistribution and use in source and binary forms, with or without modification, 9 | * are permitted provided that the following conditions are met: 10 | * 1. Redistributions of source code must retain the above copyright notice, this 11 | * list of conditions and the following disclaimer. 12 | * 2. Redistributions in binary form must reproduce the above copyright notice, this 13 | * list of conditions and the following disclaimer in the documentation and/or other 14 | * materials provided with the distribution. 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 21 | * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 24 | * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | */ 27 | 28 | #include 29 | #include 30 | #include "time.h" 31 | #include 32 | #include "debug.h" 33 | 34 | 35 | double currentDateTime() 36 | { 37 | 38 | std::chrono::time_point now = std::chrono::system_clock::now(); 39 | 40 | time_t tnow = std::chrono::system_clock::to_time_t(now); 41 | tm *date = localtime(&tnow); //todo: dperecated use localtime_s 42 | date->tm_hour = 0; 43 | date->tm_min = 0; 44 | date->tm_sec = 0; 45 | 46 | auto midnight = std::chrono::system_clock::from_time_t(mktime(date)); 47 | 48 | return std::chrono::duration (now - midnight).count(); 49 | } 50 | 51 | 52 | -------------------------------------------------------------------------------- /src/lib/utils/network.cpp: -------------------------------------------------------------------------------- 1 | #include "network.h" 2 | 3 | using namespace osuCrypto; 4 | #include 5 | #include 6 | #include 7 | #define tryCount 2 8 | 9 | void senderGetLatency(Channel& chl) 10 | { 11 | 12 | u8 dummy[1]; 13 | 14 | chl.asyncSend(dummy, 1); 15 | 16 | 17 | 18 | chl.recv(dummy, 1); 19 | chl.asyncSend(dummy, 1); 20 | 21 | 22 | std::vector oneMbit((1 << 20) / 8); 23 | for (u64 i = 0; i < tryCount; ++i) 24 | { 25 | chl.recv(dummy, 1); 26 | 27 | for(u64 j =0; j < (1<<10); ++j) 28 | chl.asyncSend(oneMbit.data(), oneMbit.size()); 29 | } 30 | chl.recv(dummy, 1); 31 | 32 | } 33 | 34 | void recverGetLatency(Channel& chl) 35 | { 36 | 37 | u8 dummy[1]; 38 | 
chl.recv(dummy, 1); 39 | Timer timer; 40 | auto start = timer.setTimePoint(""); 41 | chl.asyncSend(dummy, 1); 42 | 43 | 44 | chl.recv(dummy, 1); 45 | 46 | auto mid = timer.setTimePoint(""); 47 | auto recvStart = mid; 48 | auto recvEnd = mid; 49 | 50 | auto rrt = mid - start; 51 | std::cout << "latency: " << std::chrono::duration_cast(rrt).count() << " ms" << std::endl; 52 | 53 | std::vector oneMbit((1 << 20) / 8); 54 | for (u64 i = 0; i < tryCount; ++i) 55 | { 56 | recvStart = timer.setTimePoint(""); 57 | chl.asyncSend(dummy, 1); 58 | 59 | for (u64 j = 0; j < (1 << 10); ++j) 60 | chl.recv(oneMbit); 61 | 62 | recvEnd = timer.setTimePoint(""); 63 | 64 | // nanoseconds per GegaBit 65 | auto uspGb = std::chrono::duration_cast(recvEnd - recvStart - rrt / 2).count(); 66 | 67 | // nanoseconds per second 68 | double usps = std::chrono::duration_cast(std::chrono::seconds(1)).count(); 69 | 70 | // MegaBits per second 71 | auto Mbps = usps / uspGb * (1 << 10); 72 | 73 | std::cout << "bandwidth: " << Mbps << " Mbps" << std::endl; 74 | } 75 | 76 | chl.asyncSend(dummy, 1); 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/lib/utils/network.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // This file and the associated implementation has been placed in the public domain, waiving all copyright. No restrictions are placed on its use. 
3 | 4 | 5 | 6 | #include 7 | void senderGetLatency(osuCrypto::Channel& chl); 8 | 9 | void recverGetLatency(osuCrypto::Channel& chl); 10 | -------------------------------------------------------------------------------- /src/lib/utils/test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * test.cpp 3 | * 4 | * Created on: Sep 1, 2017 5 | * Author: chiraag 6 | */ 7 | 8 | #include 9 | #include "utils/test.h" 10 | 11 | namespace lbcrypto { 12 | 13 | sv64 to_signed(uv64 v, ui64 p){ 14 | sv64 sv(v.size()); 15 | ui64 bound = p >> 1; 16 | for(ui32 i=0; i bound) ? -1*(si64)(p-v[i]): v[i]; 18 | } 19 | 20 | return sv; 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/lib/utils/test.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "backend.h" 7 | 8 | #ifndef LBCRYPTO_TEST_H 9 | #define LBCRYPTO_TEST_H 10 | 11 | namespace lbcrypto { 12 | 13 | sv64 to_signed(uv64 v, ui64 p); 14 | 15 | template 16 | std::string vec_to_str(std::vector v){ 17 | std::string str; 18 | for(ui32 i=0; i 26 | std::string mat_to_str(std::vector> m){ 27 | std::string str; 28 | for(ui32 j=0; j 39 | void check_vec_eq(std::vector v1, std::vector v2, 40 | const std::string& what){ 41 | if(v1 != v2){ 42 | std::cout << vec_to_str(v1) << std::endl; 43 | std::cout << vec_to_str(v2) << std::endl; 44 | throw std::logic_error(what); 45 | } 46 | 47 | return; 48 | } 49 | 50 | template 51 | void check_mat_eq( 52 | std::vector> m1, 53 | std::vector> m2, 54 | const std::string& what){ 55 | if(m1.size() != m2.size()){ 56 | std::cout << "Sizes: " << m1.size() << " " << m1.size() << std::endl; 57 | throw std::logic_error(what); 58 | // return m2.size(); 59 | } else { 60 | for(ui32 n=0; n 28 | 29 | //#include "../lib/lattice/dcrtpoly.h" 30 | #include "include/gtest/gtest.h" 31 | 32 | 33 | #include "math/backend.h" 34 
| //#include "math/nbtheory.h" 35 | //#include "lattice/elemparams.h" 36 | //#include "lattice/ilparams.h" 37 | //#include "lattice/ildcrtparams.h" 38 | //#include "lattice/ilelement.h" 39 | #include "math/distrgen.h" 40 | //#include "lattice/poly.h" 41 | //#include "utils/utilities.h" 42 | 43 | using namespace std; 44 | using namespace lbcrypto; 45 | 46 | int main(int argc, char **argv) { 47 | 48 | ::testing::InitGoogleTest(&argc, argv); 49 | 50 | // if there are no filters used, default to omitting VERY_LONG tests 51 | // otherwise we lose control over which tests we can run 52 | //::testing::GTEST_FLAG(filter) = "*CRT_polynomial_multiplication_small"; 53 | 54 | if (::testing::GTEST_FLAG(filter) == "*") { 55 | ::testing::GTEST_FLAG(filter) = "-*_VERY_LONG"; 56 | } 57 | int rv = RUN_ALL_TESTS(); 58 | 59 | std::cout << rv << ", press return to continue..." << std::endl; 60 | std::cin.get(); 61 | 62 | return 0; 63 | } 64 | 65 | -------------------------------------------------------------------------------- /src/unittest/UnitTestFVAutomorph.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @file 3 | * @author TPOC: palisade@njit.edu 4 | * 5 | * @copyright Copyright (c) 2017, New Jersey Institute of Technology (NJIT) 6 | * All rights reserved. 7 | * Redistribution and use in source and binary forms, with or without modification, 8 | * are permitted provided that the following conditions are met: 9 | * 1. Redistributions of source code must retain the above copyright notice, this 10 | * list of conditions and the following disclaimer. 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, this 12 | * list of conditions and the following disclaimer in the documentation and/or other 13 | * materials provided with the distribution. 
14 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 18 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 20 | * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 21 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 22 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 23 | * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | * 25 | */ 26 | /* 27 | This code tests the transform feature of the PALISADE lattice encryption library. 28 | */ 29 | 30 | #include "include/gtest/gtest.h" 31 | #include 32 | 33 | #include "../lib/pke/gazelle.h" 34 | 35 | using namespace std; 36 | using namespace lbcrypto; 37 | 38 | 39 | class UnitFVAutomorph : public ::testing::Test { 40 | protected: 41 | virtual void SetUp() { 42 | } 43 | 44 | virtual void TearDown() { 45 | // Code here will be called immediately after each test 46 | // (right before the destructor). 
47 | } 48 | }; 49 | 50 | TEST(UTFV_Automorph, Ref){ 51 | //------------------ Setup Parameters ------------------ 52 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 53 | ui64 z_p = RootOfUnity(opt::phim << 1, opt::p); 54 | ftt_precompute(z, opt::q, opt::logn); 55 | ftt_precompute(z_p, opt::p, opt::logn); 56 | encoding_precompute(opt::p, opt::logn); 57 | precompute_automorph_index(opt::phim); 58 | 59 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 60 | 61 | FVParams test_params { 62 | false, 63 | opt::q, opt::p, opt::logn, opt::phim, 64 | (opt::q/opt::p), 65 | OPTIMIZED, std::make_shared(dgg), 66 | 20 67 | }; 68 | 69 | auto kp = KeyGen(test_params); 70 | uv64 v1 = get_dgg_testvector(opt::phim, opt::p); 71 | uv64 pt1 = packed_encode(v1, opt::p, opt::logn); 72 | auto ct1 = Encrypt(kp.sk, pt1, test_params); 73 | ui32 rot = 4; 74 | 75 | uv32 index_list(opt::logn); 76 | ui32 index = 1; 77 | for(ui32 i=0; i(dgg), 112 | 20 113 | }; 114 | 115 | auto kp = KeyGen(test_params); 116 | uv64 v1 = get_dgg_testvector(opt::phim, opt::p); 117 | uv64 pt1 = packed_encode(v1, opt::p, opt::logn); 118 | auto ct1 = Encrypt(kp.sk, pt1, test_params); 119 | ui32 rot = 4; 120 | 121 | uv32 index_list(opt::logn); 122 | ui32 index = 1; 123 | for(ui32 i=0; i 32 | 33 | #include "../lib/pke/gazelle.h" 34 | 35 | using namespace std; 36 | using namespace lbcrypto; 37 | 38 | 39 | class UnitTestFVBase : public ::testing::Test { 40 | protected: 41 | virtual void SetUp() { 42 | } 43 | 44 | virtual void TearDown() { 45 | // Code here will be called immediately after each test 46 | // (right before the destructor). 
47 | } 48 | }; 49 | 50 | /*--------------------------------------- TESTING METHODS OF TRANSFORM --------------------------------------------*/ 51 | 52 | // TEST CASE TO TEST POLYNOMIAL MULTIPLICATION USING CHINESE REMAINDER THEOREM 53 | 54 | TEST(UTFV, Ref){ 55 | //------------------ Setup Parameters ------------------ 56 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 57 | 58 | FVParams test_params { 59 | false, 60 | opt::q, opt::p, opt::logn, opt::phim, 61 | (opt::q/opt::p), 62 | OPTIMIZED, std::make_shared(dgg) 63 | }; 64 | 65 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 66 | ftt_precompute(z, opt::q, opt::logn); 67 | 68 | uv64 pt = get_dug_vector(opt::phim, opt::p); 69 | 70 | //----------------------- KeyGen ----------------------- 71 | auto kp = KeyGen(test_params); 72 | kp = KeyGen(test_params); 73 | 74 | //--------------------- PK-Encrypt---------------------- 75 | Ciphertext ct_pk(opt::phim); 76 | ct_pk = Encrypt(kp.pk, pt, test_params); 77 | 78 | //--------------------- SK-Encrypt---------------------- 79 | Ciphertext ct_sk(opt::phim); 80 | ct_sk = Encrypt(kp.sk, pt, test_params); 81 | 82 | //---------------------- Decrypt ----------------------- 83 | uv64 pt_pk(opt::phim), pt_sk(opt::phim); 84 | pt_pk = Decrypt(kp.sk, ct_pk, test_params); 85 | pt_sk = Decrypt(kp.sk, ct_sk, test_params); 86 | 87 | EXPECT_EQ(pt, pt_pk); 88 | 89 | EXPECT_EQ(pt, pt_sk); 90 | 91 | } 92 | 93 | TEST(UTFV, Fast){ 94 | //------------------ Setup Parameters ------------------ 95 | DiscreteGaussianGenerator dgg = DiscreteGaussianGenerator(4.0); 96 | 97 | FVParams test_params { 98 | true, 99 | opt::q, opt::p, opt::logn, opt::phim, 100 | (opt::q/opt::p), 101 | OPTIMIZED, std::make_shared(dgg) 102 | }; 103 | 104 | ui64 z = RootOfUnity(opt::phim << 1, opt::q); 105 | ftt_precompute(z, opt::q, opt::logn); 106 | 107 | uv64 pt = get_dug_vector(opt::phim, opt::p); 108 | 109 | //----------------------- KeyGen ----------------------- 110 | auto kp = KeyGen(test_params); 
111 | kp = KeyGen(test_params); 112 | 113 | //--------------------- PK-Encrypt---------------------- 114 | Ciphertext ct_pk(opt::phim); 115 | ct_pk = Encrypt(kp.pk, pt, test_params); 116 | 117 | //--------------------- SK-Encrypt---------------------- 118 | Ciphertext ct_sk(opt::phim); 119 | ct_sk = Encrypt(kp.sk, pt, test_params); 120 | 121 | //---------------------- Decrypt ----------------------- 122 | uv64 pt_pk(opt::phim), pt_sk(opt::phim); 123 | pt_pk = Decrypt(kp.sk, ct_pk, test_params); 124 | pt_sk = Decrypt(kp.sk, ct_sk, test_params); 125 | 126 | EXPECT_EQ(pt, pt_pk); 127 | 128 | EXPECT_EQ(pt, pt_sk); 129 | 130 | } 131 | -------------------------------------------------------------------------------- /src/unittest/UnitTestTransform.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * @file 3 | * @author TPOC: palisade@njit.edu 4 | * 5 | * @copyright Copyright (c) 2017, New Jersey Institute of Technology (NJIT) 6 | * All rights reserved. 7 | * Redistribution and use in source and binary forms, with or without modification, 8 | * are permitted provided that the following conditions are met: 9 | * 1. Redistributions of source code must retain the above copyright notice, this 10 | * list of conditions and the following disclaimer. 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, this 12 | * list of conditions and the following disclaimer in the documentation and/or other 13 | * materials provided with the distribution. 14 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 18 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 20 | * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 21 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 22 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN 23 | * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | * 25 | */ 26 | /* 27 | This code tests the transform feature of the PALISADE lattice encryption library. 28 | */ 29 | 30 | #include "include/gtest/gtest.h" 31 | #include 32 | 33 | #include "math/backend.h" 34 | #include "../lib/math/transfrm.h" 35 | #include "math/nbtheory.h" 36 | 37 | using namespace std; 38 | using namespace lbcrypto; 39 | 40 | class UnitTestTransform : public ::testing::Test { 41 | protected: 42 | virtual void SetUp() { 43 | } 44 | 45 | virtual void TearDown() { 46 | // Code here will be called immediately after each test 47 | // (right before the destructor). 48 | } 49 | }; 50 | 51 | /*--------------------------------------- TESTING METHODS OF TRANSFORM --------------------------------------------*/ 52 | 53 | // TEST CASE TO TEST POLYNOMIAL MULTIPLICATION USING CHINESE REMAINDER THEOREM 54 | 55 | TEST(UTTransform, CRT_polynomial_multiplication){ 56 | 57 | ui64 modulus(113); //65537 58 | ui32 logn = 2; 59 | ui32 phim = (1 << logn); 60 | ui32 m = 2*phim; 61 | 62 | ui64 rootOfUnity = lbcrypto::RootOfUnity(m, modulus); 63 | 64 | uv64 a = {1,2,4,1}; 65 | uv64 b(a); 66 | 67 | ftt_precompute(rootOfUnity, modulus, 2); 68 | 69 | uv64 A = ftt_fwd(a, modulus, 2); 70 | uv64 B = ftt_fwd(b, modulus, 2); 71 | 72 | uv64 AB; 73 | for (ui32 i=0; i