├── README.md ├── compare_final.py ├── conv.go ├── eval.go ├── main.go ├── rot_util.go ├── test.go ├── test_BL.go └── test_run /README.md: -------------------------------------------------------------------------------- 1 | # Source code for **Optimized Privacy-Preserving CNN Inference with Fully Homomorphic Encryption** 2 | 3 | ## Requirements 4 | 1. Go 1.16.6 or higher 5 | - Install command with apt-get: 6 | ```console 7 | apt-get install golang-go 8 | ``` 9 | 2. Go packages (Go Cryptography \& Lattigo (fork)) 10 | - After installing Go, install the packages with the following commands: 11 | ```console 12 | go get -u golang.org/x/crypto/... 13 | go get -u github.com/dwkim606/test_lattigo 14 | ``` 15 | **CAUTION**: For Lattigo, we must install the forked version with the above command (instead of the latest [one](https://github.com/tuneinsight/lattigo)) 16 | 17 | 3. Python3 with numpy package (required only for checking the precision of the CNN classifier) 18 | 19 | ## Running the Test 20 | 21 | 0. Dataset Preparation: **Necessary** to run the tests 22 | 23 | - Download the data file from the link: https://drive.google.com/drive/folders/1zLTzJ58E_CDtqvnPv8t9YtgkDaHouWWn?usp=sharing 24 | - Move all folders (Resnet_enc_results, Resnet_plain_data, Resnet_weights, test_conv_data) to the same directory as the source code. 25 | 26 | 27 | The following tests are available: 28 | 29 | 1. Convolutions: run Baseline and Ours for various numbers of batches 30 | - Arguments (in the order of input): 31 | - conv 32 | - (3,5,7): width of kernel 33 | - (0,1,2,3): set number of batches to 4, 16, 64, 256, respectively. 34 | - (1 to 10): number of test runs 35 | - Example command: convolution with kernel width 3, number of batches 16, and 5 test runs 36 | ```console 37 | go run *.go conv 3 1 5 38 | ``` 39 | 40 | 2. 
Convolutions followed by ReLU evaluation (and Bootstrapping): run Baseline and Ours 41 | - Arguments: 42 | - convReLU 43 | - other parts are the same as Convolutions 44 | - Example command: convolution with kernel width 5, number of batches 4, and 3 test runs, then ReLU evaluation with Bootstrapping 45 | ```console 46 | go run *.go convReLU 5 0 3 47 | ``` 48 | 3. 20-layer CNN evaluation with our method on CIFAR10/CIFAR100 dataset 49 | - Arguments: 50 | - resnet 51 | - (3,5,7): width of kernel 52 | - (8,14,20): number of layers of CNN 53 | - (1,2,3): wideness factor 54 | - (1 to 100 or 1 to 1000): number of tests 55 | - (true, false): true -> CIFAR100, false -> CIFAR10 56 | - **List of available arguments**: (given in the paper; other arguments require appropriate weights for CNN) 57 | - resnet 3 20 1 (1 to 1000) (true/false) 58 | - resnet 5 20 1 (1 to 100) (true/false) 59 | - resnet 7 20 1 (1 to 100) (true/false) 60 | - resnet 5 8 3 (1 to 1000) (true/false) 61 | - resnet 3 14 3 (1 to 1000) (true/false) 62 | - resnet 3 20 3 (1 to 1000) (true/false) 63 | - Example command: run CNN with kernel width 3, number of layers 20, wideness factor 1, on 10 test inputs on CIFAR10 dataset 64 | ```console 65 | go run *.go resnet 3 20 1 10 false 66 | ``` 67 | **CAUTION**: The CNN evaluation test requires up to roughly 100GB of memory. 68 | - To check the precision of encrypted inference, run compare_final.py with python3, with arguments (width of kernel, depth, wideness, CIFAR100 or not) 69 | - Example command: check the precision of kernel 3, depth 20, wideness 1 CNN inference on CIFAR10 dataset 70 | ```console 71 | python3 compare_final.py 3 20 1 false 72 | ``` 73 | **CAUTION**: Precision check requires the encrypted inference to be performed beforehand. 74 | 75 | ## MISC. 76 | One can also generate an executable for a remote server with Linux and an AMD64 CPU via the following command. 
77 | (Modify the command appropriately for other OS and cpu) 78 | ```console 79 | env GOOS=linux GOARCH=amd64 go build -o test_run 80 | ``` 81 | It will generate an executable titled "test_run", which can run on the server with appropriate arguments. 82 | (e.g., for convolution) 83 | ```console 84 | ./test_run conv 3 1 5 85 | ``` 86 | With this command, one can run the test on any server without Go (given that executable is generated from other device with Go for compile). 87 | The executable must be in the same directory as data folders (Resnet_enc_results, Resnet_plain_data, Resnet_weights, test_conv_data). 88 | -------------------------------------------------------------------------------- /compare_final.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import os 4 | from distutils.util import strtobool 5 | from statistics import mean, stdev 6 | 7 | # for resnet, compare plain with enc 8 | def compare_results(ker, depth, wid, cf100): 9 | 10 | if cf100: 11 | plain_folder_dir = 'Resnet_plain_data/cf100_crop_ker'+str(ker)+'_d'+str(depth)+'_wid'+str(wid) 12 | enc_result_dir = 'Resnet_enc_results/results_cf100_crop_ker'+str(ker)+'_d'+str(depth)+'_wid'+str(wid)+'/' 13 | num_classes = 100 14 | else: 15 | plain_folder_dir = 'Resnet_plain_data/crop_ker'+str(ker)+'_d'+str(depth)+'_wid'+str(wid) 16 | enc_result_dir = 'Resnet_enc_results/results_crop_ker'+str(ker)+'_d'+str(depth)+'_wid'+str(wid)+'/' 17 | num_classes = 10 18 | 19 | if wid == 1: 20 | max_num_samples = 100 21 | if ker == 3: 22 | max_num_samples = 1000 23 | else: 24 | max_num_samples = 1000 25 | 26 | plain_pred_file = os.path.join(plain_folder_dir, 'plain_prediction_'+str(max_num_samples)+'.csv') 27 | true_pred_file = os.path.join(plain_folder_dir, 'test_labels_'+str(max_num_samples)+'.csv') 28 | plain_pred = np.reshape(np.loadtxt(plain_pred_file), [max_num_samples, num_classes]) 29 | true_pred = np.reshape(np.loadtxt(true_pred_file), 
[max_num_samples]) 30 | 31 | acc = 0 32 | true_acc = 0 33 | pl_true_acc = 0 34 | total = 0 35 | no_iters = [] 36 | wrong_result = {} 37 | os_path = enc_result_dir+'class_result_ker'+str(ker)+'_' 38 | 39 | for iter in range(max_num_samples): 40 | if os.path.exists(os_path+str(iter)+'.csv'): 41 | read = np.loadtxt(os_path+str(iter)+'.csv') 42 | total+=1 43 | else: 44 | no_iters.append(iter) 45 | continue 46 | 47 | res_np = read[:num_classes] #np.reshape(read, [-1])[:10] 48 | # print("enc: ", res_np, "argmax: ", np.argmax(res_np)) 49 | # print("plain: ", plain_pred[iter], "argmax: ", np.argmax(plain_pred[iter])) 50 | if (np.argmax(res_np) == np.argmax(plain_pred[iter])): 51 | acc += 1 52 | else: 53 | wrong_result[str(iter)] = [] 54 | wrong_result[str(iter)].insert(0, res_np) 55 | wrong_result[str(iter)].insert(1, plain_pred[iter]) 56 | wrong_result[str(iter)].insert(2, true_pred[iter]) 57 | if (np.argmax(res_np) == true_pred[iter]): 58 | true_acc += 1 59 | if (np.argmax(plain_pred[iter]) == true_pred[iter]): 60 | pl_true_acc += 1 61 | 62 | print("Plain precision: ", pl_true_acc, "/", total) 63 | print("Enc precision: ", true_acc, "/", total) 64 | print("plain vs enc accordance: ", acc, "/", total) 65 | # print("among ", max_num_samples, " samples.") 66 | #print("missing: ", no_iters) 67 | print("\n wrong results: \n") 68 | for i, result in wrong_result.items(): 69 | print(i, "-th iter.") 70 | print("enc: ", result[0], "argmax: ", np.argmax(result[0])) 71 | print("plain: ", result[1], "argmax: ", np.argmax(result[1]), "\n") 72 | print("true: ", result[2], " \n" ) 73 | 74 | # tf_images = tf.reshape(tf.constant(np.loadtxt('test_images_'+str(num_samples)+'.csv'), tf.float32), [num_samples, 32, 32, 3]) 75 | # pred = plain_resnet(tf_images) 76 | # print("enc == plain?", tf.argmax(tf.squeeze(conv, axis=[1,2]),1) == tf.argmax(pred[iter],1)) 77 | 78 | 79 | ### main ### 80 | 81 | #num_iter = 1000 82 | ker = int(sys.argv[1]) 83 | depth = int(sys.argv[2]) 84 | wide = 
int(sys.argv[3]) 85 | cf100 = strtobool(sys.argv[4]) 86 | 87 | print("ker: ", ker, " depth: ", depth, " wide: ", wide, " cf100? ", bool(cf100)) 88 | compare_results(ker, depth, wide, cf100) 89 | 90 | -------------------------------------------------------------------------------- /conv.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "time" 7 | 8 | "github.com/dwkim606/test_lattigo/ckks" 9 | "github.com/dwkim606/test_lattigo/rlwe" 10 | ) 11 | 12 | // output bitreversed input with bitwid 13 | func reverseBits(num uint32, bitwid int) uint32 { 14 | num = num << (32 - bitwid) 15 | 16 | var ret = uint32(0) 17 | var power = uint32(31) 18 | for num != 0 { 19 | ret += (num & 1) << power 20 | num = num >> 1 21 | power -= 1 22 | } 23 | return ret 24 | } 25 | 26 | // output the slice with bitreversed order from input 27 | func reverseOrder(input []float64, bitwid int) []float64 { 28 | out := make([]float64, len(input)) 29 | for i := range out { 30 | out[i] = input[reverseBits(uint32(i), bitwid)] 31 | } 32 | 33 | return out 34 | } 35 | 36 | // Extract upper left elements of size prt_wid * prt_wid from the input arrg vec with batch batches 37 | // prt_wid, batch all are those from the output of our conv algorithm (consider padding) 38 | func reshape_conv_out(result []float64, prt_wid, out_num int) []float64 { 39 | prt_out := make([]float64, prt_wid*prt_wid*out_num) 40 | in_wid := 2 * prt_wid 41 | batch := len(result) / (in_wid * in_wid) 42 | 43 | for i := 0; i < out_num; i++ { 44 | for j := 0; j < prt_wid; j++ { 45 | for k := 0; k < prt_wid; k++ { 46 | prt_out[i+out_num*(j*prt_wid+k)] = result[i+batch*(j*in_wid+k)] //[batch*(in_wid+1)*(ker_wid-1)+i+batch*(j*in_wid+k)] 47 | } 48 | } 49 | } 50 | 51 | return prt_out 52 | } 53 | 54 | // Reshape 1-D input (from python) into H,W,Batch format 55 | // i.e., 0, 1, 2, 3, -> 1st input of 0,1,2,3-th batch 56 | // only for BL test 57 | func 
reshape_input_BL(input []float64, in_wid int) (out []complex128) { 58 | out = make([]complex128, len(input)) 59 | batch := len(input) / (in_wid * in_wid) 60 | l := 0 61 | 62 | for i := 0; i < in_wid; i++ { 63 | for j := 0; j < in_wid; j++ { 64 | for k := 0; k < batch; k++ { 65 | out[i*in_wid+j+k*in_wid*in_wid] = complex(input[l], 0) 66 | l++ 67 | } 68 | } 69 | } 70 | 71 | return out 72 | } 73 | 74 | // Reshape 1-D kernel (from python) into (H,W,in,out) format, and applies BN, then into max ker 75 | // ker[i][j][ib][ob]: (i,j)-th elt of kernel for ib-th input to ob-th output 76 | // only for BL test // for transposed conv, we should add rearragement!! 77 | // norm == 1 : normal case, norm == 4 : in & out batch is (1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0) 78 | func reshape_ker_BL(input, BN_a []float64, ker_wid, inB, outB, max_bat, pos, norm int, trans bool) (max_ker_rs [][][][]float64) { 79 | ker_rs := make([][][][]float64, ker_wid) 80 | for i := 0; i < ker_wid; i++ { 81 | ker_rs[i] = make([][][]float64, ker_wid) 82 | for j := 0; j < ker_wid; j++ { 83 | ker_rs[i][j] = make([][]float64, inB) 84 | for ib := 0; ib < inB; ib++ { 85 | ker_rs[i][j][ib] = make([]float64, outB) 86 | for ob := 0; ob < outB; ob++ { 87 | if trans { 88 | if ib < (inB / 4) { 89 | ker_rs[i][j][ib][ob] = input[(4*ib+pos)+ob*inB+(ker_wid-j-1)*outB*inB+(ker_wid-i-1)*outB*inB*ker_wid] * BN_a[ob] // Apply BN 90 | } 91 | } else { 92 | ker_rs[i][j][ib][ob] = input[ob+ib*outB+j*outB*inB+i*outB*inB*ker_wid] * BN_a[ob] // Apply BN 93 | } 94 | } 95 | } 96 | } 97 | } 98 | // overload to max batch case 99 | max_ker_rs = make([][][][]float64, ker_wid) 100 | for i := 0; i < ker_wid; i++ { 101 | max_ker_rs[i] = make([][][]float64, ker_wid) 102 | for j := 0; j < ker_wid; j++ { 103 | max_ker_rs[i][j] = make([][]float64, max_bat) 104 | for ib := 0; ib < max_bat; ib++ { 105 | max_ker_rs[i][j][ib] = make([]float64, max_bat) 106 | } 107 | for ib := 0; ib < inB; ib++ { 108 | for ob := 0; ob < outB; ob++ { 109 | 
max_ker_rs[i][j][norm*ib][norm*ob] = ker_rs[i][j][ib][ob] 110 | } 111 | } 112 | } 113 | } 114 | 115 | return max_ker_rs 116 | } 117 | 118 | // rotate input ciphertext and outputs rotated ciphertexts 119 | // always assume Odd ker_wid 120 | func preConv_BL(evaluator ckks.Evaluator, ct_in *ckks.Ciphertext, in_wid, ker_wid int) (ct_in_rots []*ckks.Ciphertext) { 121 | ker_size := ker_wid * ker_wid 122 | ct_in_rots = make([]*ckks.Ciphertext, ker_size) 123 | 124 | st := -(ker_wid / 2) 125 | end := ker_wid / 2 126 | 127 | var rotations []int 128 | for i := st; i <= end; i++ { 129 | for j := st; j <= end; j++ { 130 | rotations = append(rotations, i*in_wid+j) 131 | } 132 | } 133 | ct_rots_test := evaluator.RotateHoisted(ct_in, rotations) 134 | k := 0 135 | for i := st; i <= end; i++ { 136 | for j := st; j <= end; j++ { 137 | ct_in_rots[k] = ct_rots_test[i*in_wid+j] 138 | k++ 139 | } 140 | } 141 | 142 | return ct_in_rots 143 | } 144 | 145 | // eval Convolution for the part of output: need to sum this up with rotations 146 | func postConv_BL(param ckks.Parameters, encoder ckks.Encoder, evaluator ckks.Evaluator, ct_in_rots []*ckks.Ciphertext, in_wid, ker_wid, rot, pad int, max_ker_rs [][][][]float64) (ct_out *ckks.Ciphertext) { 147 | 148 | max_batch := param.Slots() / (in_wid * in_wid) 149 | postKer := make([]complex128, param.Slots()) 150 | pl_tmp := ckks.NewPlaintext(param, ct_in_rots[0].Level(), param.Scale()) 151 | 152 | iter := 0 153 | for i := 0; i < ker_wid; i++ { 154 | for j := 0; j < ker_wid; j++ { 155 | for k := 0; k < max_batch; k++ { 156 | for ki := 0; ki < (in_wid - pad); ki++ { // position of input 157 | for kj := 0; kj < (in_wid - pad); kj++ { 158 | postKer[k*in_wid*in_wid+ki*in_wid+kj] = complex(max_ker_rs[i][j][k][(k-rot+max_batch)%max_batch], 0) 159 | if (((ki + i - (ker_wid / 2)) < 0) || ((ki + i - (ker_wid / 2)) >= (in_wid - pad))) || (((kj + j - (ker_wid / 2)) < 0) || ((kj + j - (ker_wid / 2)) >= (in_wid - pad))) { 160 | 
postKer[k*in_wid*in_wid+ki*in_wid+kj] = complex(0, 0) 161 | } 162 | } 163 | } 164 | } 165 | encoder.Encode(pl_tmp, postKer, param.LogSlots()) 166 | encoder.ToNTT(pl_tmp) 167 | if (i == 0) && (j == 0) { 168 | ct_out = evaluator.MulNew(ct_in_rots[iter], pl_tmp) 169 | } else { 170 | ct_tmp := evaluator.MulNew(ct_in_rots[iter], pl_tmp) 171 | evaluator.Add(ct_out, ct_tmp, ct_out) 172 | } 173 | iter++ 174 | } 175 | } 176 | 177 | return ct_out 178 | } 179 | 180 | // Reshape 1-D ker_in (from python) into batch number of ker_outs: ker_out[i][j] = j-th kernel (elements then batch order) for i-th output 181 | // i.e., ker_out is of the shape (out_batch, (in_batch * ker_size)) 182 | // ker_out[i] = [k1 for 1st input, ..., ,kk for 1st input, k1 for 2nd input, ...] 183 | // trans = true for transposed convolution (in trans convolution of python, we should rearrage ker_out carefully) 184 | func reshape_ker(ker_in []float64, k_sz, out_batch int, trans bool) (ker_out [][]float64) { 185 | ker_out = make([][]float64, out_batch) 186 | in_batch := len(ker_in) / (k_sz * out_batch) 187 | 188 | for i := 0; i < out_batch; i++ { 189 | ker_out[i] = make([]float64, k_sz*in_batch) 190 | for j := 0; j < in_batch; j++ { 191 | for k := 0; k < k_sz; k++ { 192 | if trans { 193 | ker_out[i][j*k_sz+(k_sz-k-1)] = ker_in[j+i*in_batch+k*out_batch*in_batch] 194 | } else { 195 | ker_out[i][j*k_sz+k] = ker_in[i+j*out_batch+k*out_batch*in_batch] 196 | // ker_out[i][j*k_sz+k] = ker_in[j+i*in_batch+k*out_batch*in_batch] 197 | } 198 | } 199 | } 200 | } 201 | return 202 | } 203 | 204 | // Encode ker_outs from reshape_ker into the i-th ker vector output 205 | // in_wid, in_batch is those for input (to be convolved) includng padding 206 | func encode_ker_final(ker_in [][]float64, pos, i, in_wid, in_batch, ker_wid int) []float64 { 207 | vec_size := in_wid * in_wid * in_batch 208 | output := make([]float64, vec_size) 209 | bias := pos * ker_wid * ker_wid * in_batch 210 | k_sz := ker_wid * ker_wid 211 | 212 | // 
allocate each kernel so that 0-th batch and B-1th batch adds together at B-1th position (B = in_batch) 213 | for j := 0; j < in_batch; j++ { // j-th input (batch) 214 | for k := 0; k < k_sz; k++ { 215 | // fmt.Println("ecd: ", j, k) 216 | output[(in_wid*(k/ker_wid)+k%ker_wid)*in_batch+j] = ker_in[i][(in_batch-1-j)*k_sz+(k_sz-1-k)+bias] // * scale; 217 | } 218 | } 219 | 220 | // move the kernel to left adj times, so that the result of "transposed" convolution appears at 0-th position 221 | // adj := (in_wid+1)*(ker_wid-3)/2 + (in_batch - 1) 222 | 223 | adj := (in_batch - 1) + (in_batch)*(in_wid+1)*(ker_wid-1)/2 224 | tmp := make([]float64, adj) 225 | for i := 0; i < adj; i++ { 226 | tmp[i] = output[vec_size-adj+i] 227 | output[vec_size-adj+i] = -output[i] 228 | } 229 | for i := 0; i < vec_size-2*adj; i++ { 230 | output[i] = output[i+adj] 231 | } 232 | for i := 0; i < adj; i++ { 233 | output[i+vec_size-2*adj] = tmp[i] 234 | } 235 | 236 | return output 237 | } 238 | 239 | // Generate the logN # of plaintexts idx[i] = X^(2^i) and GaloisKeys for each 240 | // Required for Packing 241 | func gen_idxNlogs(ECD_LV int, keygen rlwe.KeyGenerator, sk *rlwe.SecretKey, encoder ckks.Encoder, params ckks.Parameters) (idx []*ckks.Plaintext, pack_eval ckks.Evaluator) { 242 | logN := params.LogN() 243 | N := params.N() 244 | gals := []uint64{} 245 | idx = make([]*ckks.Plaintext, logN) 246 | coeffs := make([]float64, N) 247 | 248 | for i := 0; i < logN; i++ { 249 | coeffs[1< 1; i /= 2 { 282 | logStep++ 283 | } 284 | j := params.LogN() - logStep 285 | 286 | for step >= norm { 287 | for i := 0; i < step; i += norm { 288 | tmp1 = pack_eval.MulNew(ctxts[i+step], idx[logStep]) 289 | tmp2 = pack_eval.SubNew(ctxts[i], tmp1) 290 | pack_eval.Add(ctxts[i], tmp1, tmp1) 291 | pack_eval.RotateGal(tmp2, (1< set the same as python 487 | func prep_Ker(params ckks.Parameters, encoder ckks.Encoder, ker_in, BN_a []float64, in_wid, ker_wid, real_ib, real_ob, norm, ECD_LV, pos int, trans bool) (pl_ker 
[]*ckks.Plaintext) { 488 | max_bat := params.N() / (in_wid * in_wid) 489 | ker_size := ker_wid * ker_wid 490 | ker_rs := reshape_ker(ker_in, ker_size, real_ob, trans) // ker1[i][j] = j-th kernel for i-th output 491 | 492 | for i := 0; i < real_ob; i++ { // apply batch normalization 493 | for j := range ker_rs[i] { 494 | ker_rs[i][j] = ker_rs[i][j] * BN_a[i] 495 | } 496 | } 497 | 498 | max_ker_rs := make([][]float64, max_bat) // overloading ker_rs to the case with max_batch 499 | for i := 0; i < max_bat; i++ { 500 | max_ker_rs[i] = make([]float64, max_bat*ker_size) 501 | } 502 | for i := 0; i < real_ob; i++ { 503 | for j := 0; j < real_ib; j++ { 504 | for k := 0; k < ker_size; k++ { 505 | max_ker_rs[norm*i][norm*j*ker_size+k] = ker_rs[i][j*ker_size+k] 506 | } 507 | } 508 | } 509 | 510 | pl_ker = make([]*ckks.Plaintext, max_bat) 511 | for i := 0; i < max_bat; i++ { 512 | pl_ker[i] = ckks.NewPlaintext(params, ECD_LV, params.Scale()) 513 | encoder.EncodeCoeffs(encode_ker_final(max_ker_rs, pos, i, in_wid, max_bat, ker_wid), pl_ker[i]) 514 | encoder.ToNTT(pl_ker[i]) 515 | } 516 | 517 | return pl_ker 518 | } 519 | 520 | // Eval Conv, then Pack 521 | // The ciphertexts must be packed into full (without vacant position) 522 | func conv_then_pack(params ckks.Parameters, pack_evaluator ckks.Evaluator, ctxt_in *ckks.Ciphertext, pl_ker []*ckks.Plaintext, plain_idx []*ckks.Plaintext, max_ob, norm, ECD_LV int, out_scale float64) *ckks.Ciphertext { 523 | start := time.Now() 524 | ctxt_out := make([]*ckks.Ciphertext, max_ob) 525 | for i := 0; i < max_ob; i++ { 526 | if i%norm == 0 { 527 | ctxt_out[i] = pack_evaluator.MulNew(ctxt_in, pl_ker[i]) 528 | pack_evaluator.SetScale(ctxt_out[i], out_scale/(float64(max_ob/norm))) 529 | } 530 | // pack_evaluator.Rescale(ctxt_out[i], float64(1<<10), ctxt_out[i]) 531 | } 532 | mt := time.Since(start) 533 | fmt.Println("\t mult time: ", mt) 534 | ctxt_result := pack_ctxts(pack_evaluator, ctxt_out, max_ob, max_ob/norm, plain_idx, params) 535 | 
fmt.Println("\t Pack time: ", time.Since(start)-mt) 536 | 537 | // fmt.Println("Result Scale: ", math.Log2(ctxt_result.Scale)) 538 | // fmt.Println("Result LV: ", ctxt_result.Level()) 539 | // fmt.Printf("Done in %s \n", time.Since(start)) 540 | 541 | if (out_scale != ctxt_result.Scale) || (0 != ctxt_result.Level()) { 542 | panic("LV or scale after conv then pack, inconsistent") 543 | } 544 | 545 | return ctxt_result 546 | } 547 | -------------------------------------------------------------------------------- /eval.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "time" 7 | 8 | "github.com/dwkim606/test_lattigo/ckks" 9 | ) 10 | 11 | // take raw_in_wid then outputs appropriate kp_wid and out_batch 12 | // only for our convs Test (not for BL) 13 | func set_Variables(batch, raw_in_wid, in_wid, ker_wid int, kind string) (kp_wid, out_batch, logN int, trans bool) { 14 | N := batch * in_wid * in_wid 15 | logN = 0 16 | for ; (1 << logN) < N; logN++ { 17 | } 18 | max_kp_wid := in_wid - ((ker_wid - 1) / 2) // max possible size of raw_in_wid 19 | 20 | switch kind { 21 | case "Conv": 22 | trans = false 23 | kp_wid = raw_in_wid 24 | out_batch = batch 25 | if kp_wid > max_kp_wid { 26 | fmt.Println("max raw_in_wid: ", max_kp_wid) 27 | panic("too large raw_in_wid.") 28 | } 29 | case "StrConv", "StrConv_fast", "StrConv_odd": 30 | trans = false 31 | kp_wid = 2 * (in_wid/2 - ker_wid/2) 32 | out_batch = batch 33 | if kp_wid > max_kp_wid { 34 | fmt.Println("max raw_in_wid: ", max_kp_wid) 35 | panic("too large raw_in_wid.") 36 | } 37 | case "StrConv_inside": 38 | trans = false 39 | kp_wid = (in_wid/2 - ker_wid/2) 40 | out_batch = batch 41 | case "TransConv": 42 | trans = true 43 | kp_wid = 2 * raw_in_wid 44 | out_batch = batch / 4 45 | if kp_wid > max_kp_wid { 46 | fmt.Println("max raw_in_wid: ", max_kp_wid/2) 47 | panic("too large raw_in_wid.") 48 | } 49 | default: 50 | panic("Wrong kinds!") 
51 | } 52 | 53 | return 54 | } 55 | 56 | // apply rotation for strided conv (compress) or transposed conv (extend) 57 | // the same rotation for all batches; use BSGS to reduce rotations 58 | // assume that input batches are well-ordered. compress: (0,4) (1,5) (2,6) (3,7) to (0,1,2,...6,7) extend: (0,2,4,6,1,3,5,7) to (0,1) (2,3) (4,5) (6,7) 59 | // rotation for each batch position (0 to 3) is applied after or before compress or extend, resp. 60 | // total rotation = 2*in_wid*4 + (4-1); depth = 2 61 | func evalRot_BL(cont *context, ct_input *ckks.Ciphertext, in_wid, pos int, trans bool) (ct_res *ckks.Ciphertext) { 62 | if trans { 63 | in_size := in_wid * in_wid 64 | cont.evaluator.Rotate(ct_input, pos*in_size, ct_input) 65 | ct_res = bsgs_ctxt(cont.evaluator, cont.encoder, ct_input, cont.m_idx[in_wid][0], cont.r_idx[in_wid][0], cont.params) 66 | } else { 67 | out_size := in_wid * in_wid / 4 68 | ct_res = bsgs_ctxt(cont.evaluator, cont.encoder, ct_input, cont.m_idx[in_wid][0], cont.r_idx[in_wid][0], cont.params) 69 | cont.evaluator.Rotate(ct_res, -pos*out_size, ct_res) 70 | } 71 | return 72 | } 73 | 74 | // Eval Conv only, always assume max batch 75 | // in_wid must be Po2 (also include padding), includes kernel preparation 76 | // norm == 1 : normal case, norm == 4 : in & out batch is (1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0) 77 | // for test, use pack_evaluator optimizer 78 | func evalConv_BN_BL_test(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, in_wid, ker_wid, real_ib, real_ob, pos, norm, pad int, trans, printResult bool) (ct_res *ckks.Ciphertext) { 79 | in_size := in_wid * in_wid 80 | out_size := in_size 81 | max_batch := cont.N / (2 * in_size) 82 | 83 | // fmt.Println() 84 | // fmt.Println("=============== (KER) PREPARATION ===============") 85 | // fmt.Println() 86 | start := time.Now() 87 | max_ker_rs := reshape_ker_BL(ker_in, bn_a, ker_wid, real_ib, real_ob, max_batch, pos, norm, trans) 88 | scale_exp := cont.params.Scale() * 
cont.params.Scale() 89 | if trans { 90 | scale_exp = cont.params.Scale() * cont.params.Scale() * cont.params.Scale() 91 | } 92 | bn_b_slots := make([]complex128, cont.N/2) 93 | for i, elt := range bn_b { 94 | for j := 0; j < in_wid-pad; j++ { 95 | for k := 0; k < in_wid-pad; k++ { 96 | bn_b_slots[j+k*in_wid+norm*out_size*i] = complex(elt, 0) 97 | } 98 | } 99 | } 100 | 101 | pl_bn_b := ckks.NewPlaintext(cont.params, cont.ECD_LV, scale_exp) 102 | cont.encoder.EncodeNTT(pl_bn_b, bn_b_slots, cont.logN-1) 103 | fmt.Printf("Plaintext (kernel) preparation, Done in %s \n", time.Since(start)) 104 | 105 | // fmt.Println() 106 | // fmt.Println("=============== EVALUATION ===============") 107 | // fmt.Println() 108 | start = time.Now() 109 | ct_inputs_rots := preConv_BL(cont.pack_evaluator, ct_input, in_wid, ker_wid) 110 | fmt.Printf("preConv done in %s \n", time.Since(start)) 111 | 112 | var rot_iters int 113 | if norm*real_ob == max_batch { 114 | rot_iters = real_ob 115 | } else { 116 | rot_iters = max_batch 117 | } 118 | for i := 0; i < rot_iters; i++ { 119 | ct_tmp := postConv_BL(cont.params, cont.encoder, cont.pack_evaluator, ct_inputs_rots, in_wid, ker_wid, norm*i, pad, max_ker_rs) 120 | if i == 0 { 121 | ct_res = ct_tmp 122 | } else { 123 | cont.evaluator.Add(ct_res, cont.pack_evaluator.RotateNew(ct_tmp, norm*i*out_size), ct_res) 124 | } 125 | } 126 | 127 | if ct_res.Scale != scale_exp { 128 | panic("Different scale between pl_bn_b and ctxt") 129 | } 130 | cont.evaluator.Add(ct_res, pl_bn_b, ct_res) 131 | fmt.Printf("Conv (with BN) Done in %s \n", time.Since(start)) 132 | 133 | return ct_res 134 | } 135 | 136 | // reduce mean and final FC layer (in_batch -> 16) 137 | // assume that ct_input has batch (1,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0) 138 | // ker_fc is of size in_batch*10 and 1-dim from [64,10] shape 139 | func evalRMFC_BL(cont *context, ct_input *ckks.Ciphertext, ker_fc, bias []float64, printResult bool) (ct_res *ckks.Ciphertext) { 140 
| rs_ker := make([][]float64, 64) 141 | for i := 0; i < 64; i++ { 142 | rs_ker[i] = make([]float64, 16) 143 | for j := 0; j < 10; j++ { 144 | rs_ker[i][j] = ker_fc[j+i*10] / 64.0 // we will add 64 elts instead of averaging them 145 | } 146 | } 147 | 148 | // sum 64 elements instead of averaging them 149 | ct_avg := ct_input 150 | for i := 1; i < 64; i *= 2 { 151 | ct_avg = cont.evaluator.AddNew(ct_avg, cont.evaluator.RotateNew(ct_avg, i)) 152 | } 153 | 154 | for i := 0; i < 16; i++ { 155 | tmp := make([]complex128, cont.N/2) 156 | for j := 0; j < 64; j++ { 157 | tmp[j*64*8] = complex(rs_ker[j][(j%16+16-i)%16], 0) 158 | } 159 | pl_ker := cont.encoder.EncodeNTTAtLvlNew(ct_avg.Level(), tmp, cont.logN-1) 160 | 161 | if i == 0 { 162 | ct_res = cont.evaluator.MulNew(ct_avg, pl_ker) 163 | } else { 164 | ct_tmp := cont.evaluator.MulNew(ct_avg, pl_ker) 165 | cont.evaluator.Add(ct_res, cont.evaluator.RotateNew(ct_tmp, i*64*8), ct_res) 166 | } 167 | } 168 | 169 | // final rotations to add up (4 = 64/16) 170 | for i := 1; i < 4; i *= 2 { 171 | ct_res = cont.evaluator.AddNew(ct_res, cont.evaluator.RotateNew(ct_res, i*16*64*8)) 172 | } 173 | 174 | tmp := make([]complex128, cont.N/2) 175 | for j := 0; j < 10; j++ { 176 | tmp[j*64*8] = complex(bias[j], 0) 177 | } 178 | pl_bias := cont.encoder.EncodeNTTAtLvlNew(ct_res.Level(), tmp, cont.logN-1) 179 | cont.evaluator.Add(ct_res, pl_bias, ct_res) 180 | 181 | return 182 | } 183 | 184 | // reduce mean and final FC layer 185 | // assume that ct_input has full batch (1,2,3,4,...) 
186 | // ker_fc is of size in_batch*out and 1-dim from [in_batch,out] shape 187 | func evalRMFC_BL_img(cont *context, ct_input *ckks.Ciphertext, ker_fc []float64, in_batch, out_num, raw_in_wid int, printResult bool) (ct_res *ckks.Ciphertext) { 188 | rs_ker := make([][]float64, in_batch) 189 | for i := 0; i < in_batch; i++ { 190 | rs_ker[i] = make([]float64, out_num) 191 | for j := 0; j < out_num; j++ { 192 | rs_ker[i][j] = ker_fc[j+i*out_num] / float64(raw_in_wid*raw_in_wid) // we will add 64 elts instead of averaging them 193 | } 194 | } 195 | 196 | // sum 64 elements instead of averaging them (but only 49 elts are non-zero) 197 | ct_avg := ct_input 198 | for i := 1; i < 64; i *= 2 { 199 | ct_avg = cont.evaluator.AddNew(ct_avg, cont.evaluator.RotateNew(ct_avg, i)) 200 | } 201 | 202 | for i := 0; i < in_batch; i++ { 203 | tmp := make([]complex128, cont.N/2) 204 | for j := 0; j < out_num; j++ { 205 | tmp[(i+j)%in_batch*64] = complex(rs_ker[(i+j)%in_batch][j], 0) 206 | } 207 | pl_ker := cont.encoder.EncodeNTTAtLvlNew(ct_avg.Level(), tmp, cont.logN-1) 208 | 209 | if i == 0 { 210 | ct_res = cont.evaluator.MulNew(ct_avg, pl_ker) 211 | } else { 212 | ct_tmp := cont.evaluator.MulNew(ct_avg, pl_ker) 213 | cont.evaluator.Add(ct_res, cont.evaluator.RotateNew(ct_tmp, i*64), ct_res) 214 | } 215 | } 216 | 217 | return 218 | } 219 | 220 | // Eval Conv only, always assume max batch 221 | // in_wid must be Po2 (also include padding), 222 | // include kernel preparation 223 | // norm = 2 : in&out batches are (1,0,2,0,3,0,...) 
224 | func evalConv_BN(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, in_wid, ker_wid, real_ib, real_ob, norm int, out_scale float64, trans bool) (ct_res *ckks.Ciphertext) { 225 | max_batch := cont.N / (in_wid * in_wid) 226 | 227 | // fmt.Println() 228 | // fmt.Println("=============== (KER) PREPARATION ===============") 229 | // fmt.Println() 230 | start := time.Now() 231 | pl_ker := prep_Ker(cont.params, cont.encoder, ker_in, bn_a, in_wid, ker_wid, real_ib, real_ob, norm, cont.ECD_LV, 0, trans) 232 | // fmt.Printf("for prep_ker %s \n", time.Since(start)) 233 | b_coeffs := make([]float64, cont.N) 234 | for i := range bn_b { 235 | for j := 0; j < in_wid*in_wid; j++ { 236 | b_coeffs[norm*i+j*max_batch] = bn_b[i] 237 | } 238 | } 239 | // scale_exp := ct_input.Scale * cont.params.Scale() * float64(max_batch/norm) 240 | pl_bn_b := ckks.NewPlaintext(cont.params, 0, out_scale) 241 | // pl_bn_b := ckks.NewPlaintext(cont.params, cont.ECD_LV, scale_exp) // contain plaintext values 242 | cont.encoder.EncodeCoeffs(b_coeffs, pl_bn_b) 243 | cont.encoder.ToNTT(pl_bn_b) 244 | fmt.Printf("Plaintext (kernel) preparation, Done in %s \n", time.Since(start)) 245 | 246 | // fmt.Println() 247 | // fmt.Println("=============== EVALUATION ===============") 248 | // fmt.Println() 249 | 250 | start = time.Now() 251 | ct_res = conv_then_pack(cont.params, cont.pack_evaluator, ct_input, pl_ker, cont.pl_idx, max_batch, norm, cont.ECD_LV, out_scale) 252 | if (pl_bn_b.Scale != ct_res.Scale) || (ct_res.Level() != 0) { 253 | fmt.Println("plain scale: ", pl_bn_b.Scale) 254 | fmt.Println("ctxt scale: ", ct_res.Scale) 255 | fmt.Println("ctxt lv: ", ct_res.Level()) 256 | panic("LV or scale after conv then pack, inconsistent") 257 | } 258 | cont.evaluator.Add(ct_res, pl_bn_b, ct_res) // for Batch Normalization (BN) 259 | 260 | fmt.Printf("Conv (with BN) Done in %s \n", time.Since(start)) 261 | 262 | return ct_res 263 | } 264 | 265 | // Eval Conv, BN, relu with Boot 266 | // 
// evalConv_BNRelu_new evaluates Conv + BN, then bootstraps with a ReLU applied
// between the two bootstrapping halves (BootstrappConv_CtoS -> evalReLU ->
// BootstrappConv_StoC), and returns the rescaled result.
//
// in_wid must be Po2 (also include padding)
// stride = true: apply [1,2,2,1] stride; false: [1,1,1,1]
// pack_pos: position to pack (0,1,2,3): only for strided case
// real_ib, real_ob: real number of batches (less or equal than max_batch)
// step: step of the output (used for conv_inside only)
// log_sparse: 0 if no full slot, 1 if half slot, etc (maybe the same as norm[])
//
// kind selects the variant; anything outside the list handled by the switch
// below panics with "No kind!". iter is the number of ciphertext halves
// (up/low) that are processed; ct_boots has exactly two entries, so iter is
// presumably always 1 or 2 -- TODO confirm against callers.
// With debug set, intermediate values are decrypted and compared against
// plaintext re-computations (debugCtoS / debugReLU / debugStoC).
func evalConv_BNRelu_new(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, alpha, pow float64, in_wid, kp_wid, ker_wid, real_ib, real_ob, norm, pack_pos, step, iter, log_sparse int, kind string, fast_pack, debug bool) (ct_res *ckks.Ciphertext) {
	// iter := 2 // for full packing (contrary to half packing)

	// Translate `kind` into independent feature flags.
	var trans, stride, odd, inside bool
	odd = false
	trans = false
	stride = false
	inside = false
	sparse := false
	in_step := step // spacing used to dilate the kernel in the "inside" variants
	modify_ker := false
	full := false
	switch kind {
	case "Conv_sparse": // sparse pack, normal conv
		sparse = true
	case "StrConv_sparse": // sparse pack, strided conv, 2 convs -> add them -> 1 boot
		modify_ker = true
		sparse = true
		stride = true
	case "StrConv_sparse_full": // sparse pack but full pack
		sparse = true
		modify_ker = true
		stride = true
		full = true
	case "Conv_inside":
		inside = true
	case "StrConv_inside":
		in_step = step / 2
		if step%2 != 0 {
			panic("step can not be divided by 2 (for strided conv)")
		}
		inside = true
	case "StrConv", "StrConv_fast":
		stride = true
	case "StrConv_odd":
		stride = true
		odd = true
	case "TransConv":
		trans = true
	case "Conv":
	default:
		panic("No kind!")
	}

	if odd { // multiply x^{offset} before conv to move input to appropriate position.
		odd_time := time.Now()
		var offset int // offset depends on the real_wid = in_wid-ker_wid/2 is even or not
		if (in_wid-ker_wid/2)%2 == 0 {
			offset = 0
		} else {
			offset = cont.N / (in_wid * in_wid) * (in_wid + 1)
			// offset = real_ib * norm * (in_wid + 1)
		}
		// fmt.Println("offset: ", offset)
		// xi encodes the monomial x^{offset}; multiplying shifts the coefficients.
		xi := make([]float64, cont.N)
		xi[offset] = 1.0
		xi_plain := ckks.NewPlaintext(cont.params, cont.ECD_LV, 1.0)
		cont.encoder.EncodeCoeffs(xi, xi_plain)
		cont.encoder.ToNTT(xi_plain)
		ct_input = cont.evaluator.MulNew(ct_input, xi_plain)
		fmt.Printf("for odd stride, offset time %s \n", time.Since(odd_time))
	}

	// ---- Convolution + BN ----
	// Every evalConv_BN call below uses the same target scale
	// 2^round(log2(Q[0]) - (pow + 8)).
	var ct_conv *ckks.Ciphertext
	if modify_ker {
		if !full {
			// Split the BN parameters and the kernel into even / odd output
			// channels, run two convolutions with norm/2, and merge the results.
			// NOTE(review): the BN halves are sized by real_ib while the halved
			// kernel has real_ob/2 output channels; this presumably relies on
			// real_ob == 2*real_ib for strided layers -- confirm.
			bn_a_0 := make([]float64, real_ib)
			bn_a_1 := make([]float64, real_ib)
			bn_b_0 := make([]float64, real_ib)
			bn_b_1 := make([]float64, real_ib)
			for i := range bn_b_0 {
				bn_a_0[i] = bn_a[2*i]
				bn_a_1[i] = bn_a[2*i+1]
				bn_b_0[i] = bn_b[2*i]
				bn_b_1[i] = bn_b[2*i+1]
			}
			// ker_in is indexed as [k][i][j] with k = kernel tap, i = input
			// batch, j = output batch (see the index arithmetic below).
			ker_in_0 := make([]float64, len(ker_in)/2)
			ker_in_1 := make([]float64, len(ker_in)/2)
			for k := 0; k < ker_wid*ker_wid; k++ {
				for i := 0; i < real_ib; i++ {
					for j := 0; j < real_ob/2; j++ {
						ker_in_0[k*real_ib*real_ob/2+(i*real_ob/2+j)] = ker_in[k*real_ib*real_ob+(i*real_ob+2*j)]   // [i][2*j]
						ker_in_1[k*real_ib*real_ob/2+(i*real_ob/2+j)] = ker_in[k*real_ib*real_ob+(i*real_ob+2*j+1)] // [i][2*j+1]
					}
				}
			}
			ct_result1 := evalConv_BN(cont, ct_input, ker_in_0, bn_a_0, bn_b_0, in_wid, ker_wid, real_ib, real_ob/2, norm/2, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
			ct_result2 := evalConv_BN(cont, ct_input, ker_in_1, bn_a_1, bn_b_1, in_wid, ker_wid, real_ib, real_ob/2, norm/2, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)

			// Shift the odd-channel result by x^{norm/4} so both halves can be
			// added into one ciphertext.
			xi := make([]float64, cont.N)
			offset := norm / 4 // cont.N / norm
			xi[offset] = 1.0
			xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0)
			cont.encoder.EncodeCoeffs(xi, xi_plain)
			cont.encoder.ToNTT(xi_plain)
			ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain)

			ct_conv = cont.evaluator.AddNew(ct_result1, ct_result2)

			// res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
			max_batch := cont.N / (in_wid * in_wid)
			// prt_mat_norm_step(res_tmp, max_batch, norm/4, step, 1, 4, false)

			// Re-use xi for the final position fix: either the identity (x^0) or
			// -x^{N - max_batch*(in_wid+1)}.
			// NOTE(review): presumably this equals a shift by
			// -(max_batch*(in_wid+1)) in the negacyclic ring -- confirm.
			for i := range xi {
				xi[i] = 0.0
			}
			if (in_wid-ker_wid/2)%2 != 0 {
				xi[0] = 1.0
			} else {
				// fmt.Println("offset nonzero!")
				offset = cont.N - (max_batch)*(in_wid+1)
				xi[offset] = -1.0
			}
			xi_plain = ckks.NewPlaintext(cont.params, ct_conv.Level(), 1.0)
			cont.encoder.EncodeCoeffs(xi, xi_plain)
			cont.encoder.ToNTT(xi_plain)
			ct_conv = cont.evaluator.MulNew(ct_conv, xi_plain)
			// res_tmp = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
			// fmt.Println("After offset: ")
			// prt_mat_norm_step(res_tmp, max_batch, norm/4, step, 1, 4, false)
		} else { // need to cover the case with full packing
			// Single convolution with the unmodified kernel, followed by the
			// same position fix as in the branch above.
			ct_conv = evalConv_BN(cont, ct_input, ker_in, bn_a, bn_b, in_wid, ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)

			// res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
			max_batch := cont.N / (in_wid * in_wid)
			// prt_mat_norm_step(res_tmp, max_batch, norm, step, 1, 4, false)
			xi := make([]float64, cont.N)
			for i := range xi {
				xi[i] = 0.0
			}
			var offset int
			if (in_wid-ker_wid/2)%2 != 0 {
				xi[0] = 1.0
			} else {
				// fmt.Println("offset nonzero!")
				offset = cont.N - (max_batch)*(in_wid+1)
				xi[offset] = -1.0
			}
			xi_plain := ckks.NewPlaintext(cont.params, ct_conv.Level(), 1.0)
			cont.encoder.EncodeCoeffs(xi, xi_plain)
			cont.encoder.ToNTT(xi_plain)
			ct_conv = cont.evaluator.MulNew(ct_conv, xi_plain)
			// res_tmp = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
			// fmt.Println("After offset: ")
			// prt_mat_norm_step(res_tmp, max_batch, norm, step, 1, 4, false)
		}
	} else {
		if inside {
			// Dilate the kernel: tap (i, j) of the original kernel is placed at
			// (in_step*i, in_step*j) of a (ker_wid*in_step - in_step + 1)-wide
			// kernel; all other taps stay zero.
			new_ker_wid := ker_wid*in_step - in_step + 1
			new_ker_in := make([]float64, len(ker_in)*new_ker_wid*new_ker_wid/(ker_wid*ker_wid))

			for i := 0; i < ker_wid; i++ {
				for j := 0; j < ker_wid; j++ {
					for ib := 0; ib < real_ib; ib++ {
						for ob := 0; ob < real_ob; ob++ {
							new_ker_in[in_step*i*new_ker_wid*real_ib*real_ob+(in_step*j)*real_ib*real_ob+ib*real_ob+ob] = ker_in[i*ker_wid*real_ib*real_ob+j*real_ib*real_ob+ib*real_ob+ob]
						}
					}
				}
			}
			ct_conv = evalConv_BN(cont, ct_input, new_ker_in, bn_a, bn_b, in_wid, new_ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
		} else {
			ct_conv = evalConv_BN(cont, ct_input, ker_in, bn_a, bn_b, in_wid, ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
		}
	}

	// Only the scale bookkeeping changes here (the ciphertext data is untouched):
	// the scale is multiplied by 2^pow. The ReLU results are multiplied back by
	// 2^pow (MulByPow2) after evalReLU below.
	ct_conv.Scale = ct_conv.Scale * math.Pow(2, pow)

	// Only for checking the correctness (for CtoS)
	var slot1, slot2 []complex128
	var cfs_preB []float64
	if debug {
		cfs_preB = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
	}
	fmt.Println("Bootstrapping... Ours (until CtoS):")
	start := time.Now()
	// First bootstrapping half; log_sparse picks which bootstrapper of the
	// context is used (btp .. btp5).
	ct_boots := make([]*ckks.Ciphertext, 2)
	switch log_sparse {
	case 0:
		ct_boots[0], ct_boots[1], _ = cont.btp.BootstrappConv_CtoS(ct_conv)
	case 1:
		ct_boots[0], ct_boots[1], _ = cont.btp2.BootstrappConv_CtoS(ct_conv)
	case 2:
		ct_boots[0], ct_boots[1], _ = cont.btp3.BootstrappConv_CtoS(ct_conv)
	case 3:
		ct_boots[0], ct_boots[1], _ = cont.btp4.BootstrappConv_CtoS(ct_conv)
	case 4:
		ct_boots[0], ct_boots[1], _ = cont.btp5.BootstrappConv_CtoS(ct_conv)
	default:
		panic("No cases for log_sparse")
	}

	fmt.Printf("Done in %s \n", time.Since(start))
	// fmt.Println("after Boot (CtoS): LV = ", ct_boots[0].Level(), " Scale = ", math.Log2(ct_boots[0].Scale))

	if debug {
		slot1, slot2 = debugCtoS(cont, cfs_preB, log_sparse)
		slot1 = printDebug(log_sparse, cont.params, ct_boots[0], slot1, cont.decryptor, cont.encoder) // Compare before & after CtoS
		slot2 = printDebug(log_sparse, cont.params, ct_boots[1], slot2, cont.decryptor, cont.encoder) // Compare before & after CtoS
	}

	// ReLU on each non-nil half, then undo the 2^pow scale adjustment.
	start = time.Now()
	for ul := 0; ul < iter; ul++ { // up & low parts
		if ct_boots[ul] != nil {
			ct_boots[ul] = evalReLU(cont.params, cont.evaluator, ct_boots[ul], alpha)
			cont.evaluator.MulByPow2(ct_boots[ul], int(pow), ct_boots[ul])
		}
	}
	fmt.Printf("ReLU Done in %s \n", time.Since(start))
	start = time.Now()

	// Only for checking the correctness (for ReLU)
	var cfs_postB []float64
	if debug {
		// NOTE(review): ct_boots[0] is dereferenced here without the nil check
		// used in the loops around it.
		fmt.Println("after Relu: ", math.Log2(ct_boots[0].Scale), "lv: ", ct_boots[0].Level())
		relu1, relu2 := debugReLU(cont, slot1, slot2, alpha, pow)
		relu1 = printDebug(log_sparse, cont.params, ct_boots[0], relu1, cont.decryptor, cont.encoder)
		relu2 = printDebug(log_sparse, cont.params, ct_boots[1], relu2, cont.decryptor, cont.encoder)
		cfs_postB = debugStoC(cont, relu1, relu2, in_wid, kp_wid, pack_pos, step, log_sparse, kind, fast_pack)
	}

	// Post-process each half before the second bootstrapping half. Which helper
	// and which index table is used depends on the variant:
	//   trans            -> ext_ctxt with r_idx[in_wid][ul]
	//   stride           -> ext_double_ctxt (sparse or fast_pack) or ext_ctxt,
	//                       upper half uses m_idx/r_idx, lower half the *_l tables
	//   inside / default -> keep_ctxt with ext_idx
	ct_keep := make([]*ckks.Ciphertext, iter) // for extend (rotation) of ctxt_in
	for ul := 0; ul < iter; ul++ {
		if trans {
			ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx[in_wid][ul], cont.params)
		} else if stride {
			if sparse { // we will use ext_double to reduce rotations; hence similar to fast_pack case
				if ct_boots[ul] != nil {
					if ul == 0 {
						ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx[in_wid][pack_pos], cont.r_idx[in_wid][pack_pos], cont.params)
					} else {
						ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx_l[in_wid][pack_pos], cont.r_idx_l[in_wid][pack_pos], cont.params)
					}
				} else {
					ct_keep[ul] = nil
				}
			} else {
				if fast_pack {
					if ul == 0 {
						ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx[in_wid][pack_pos], cont.r_idx[in_wid][pack_pos], cont.params)
					} else {
						ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx_l[in_wid][pack_pos], cont.r_idx_l[in_wid][pack_pos], cont.params)
					}
				} else {
					if ul == 0 {
						ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx[in_wid][pack_pos], cont.params)
					} else {
						ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx_l[in_wid][pack_pos], cont.params)
					}
				}
			}
		} else if inside {
			if ct_boots[ul] != nil {
				// NOTE(review): the kind switch above never sets both `inside`
				// and `sparse`, so the sparse arm looks unreachable as written.
				if sparse {
					ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[in_wid][ul])
				} else {
					ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[step][ul])
				}
			} else {
				ct_keep[ul] = nil
			}
		} else {
			if ct_boots[ul] != nil {
				ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[in_wid][ul])
			} else {
				ct_keep[ul] = nil
			}
		}
	}

	// Second bootstrapping half.
	if iter == 1 {
		ct_boots[1] = nil
		ct_res = cont.btp.BootstrappConv_StoC(ct_keep[0], ct_boots[1])
		// NOTE(review): this guard runs only after the StoC call above has
		// already been made with cont.btp.
		if log_sparse != 0 {
			panic("we didn't implement this case")
		}
	} else {
		switch log_sparse {
		case 0:
			ct_res = cont.btp.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
		case 1:
			ct_res = cont.btp2.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
		case 2:
			ct_res = cont.btp3.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
		case 3:
			ct_res = cont.btp4.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
		case 4:
			ct_res = cont.btp5.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
		default:
			panic("No cases for log_sparse")
		}
	}

	cont.evaluator.Rescale(ct_res, cont.params.Scale(), ct_res)
	fmt.Printf("Boot (StoC) Done in %s \n", time.Since(start))

	// Only for checking the correctness (for StoC)
	if debug {
		fmt.Println("Boot out: ")
		// Compare against the plaintext reference using the parameter set that
		// matches log_sparse.
		switch log_sparse {
		case 0:
			printDebugCfs(cont.params, ct_res, cfs_postB, cont.decryptor, cont.encoder)
		case 1:
			printDebugCfs(cont.params2, ct_res, cfs_postB, cont.decryptor, cont.encoder)
		case 2:
			printDebugCfs(cont.params3, ct_res, cfs_postB, cont.decryptor, cont.encoder)
		case 3:
			printDebugCfs(cont.params4, ct_res, cfs_postB, cont.decryptor, cont.encoder)
		case 4:
			printDebugCfs(cont.params5, ct_res, cfs_postB, cont.decryptor, cont.encoder)
		default:
			panic("No cases for log_sparse")
		}
		max_batch := cont.N / (in_wid * in_wid)
		res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_res))
		if inside {
			start := 1 // first printed position; shadows the timer variable on purpose
			if ker_wid == 5 {
				start = step
			}
			prt_mat_norm_step(res_tmp, max_batch, norm, step, start, 3, false)
		} else {
			if stride {
				// The strided printout uses 4x the batch count.
				max_batch = 4 * cont.N / (in_wid * in_wid)
				start := 1
				if ker_wid == 5 {
					start = step
				}
				prt_mat_norm_step(res_tmp, max_batch, norm, step, start, 3, false)
			} else {
				prt_mat_norm(res_tmp, max_batch, norm, 3, false)
			}
		}
	}

	return ct_res
}
ker_wid, depth, wide_case, debug, cf100) 627 | } else { 628 | panic("Wrong wide case!") 629 | } 630 | 631 | } else { 632 | if boot { 633 | fmt.Println("Convolution followed by ReLU (& Bootstrapping) test start!") 634 | } else { 635 | fmt.Println("Convolution test start! (No Bootstrapping)") 636 | } 637 | fmt.Println("Ker: ", ker_wid, "batches: ", batchs[i_batch], "widths: ", widths[i_batch]) 638 | 639 | fmt.Println("Base Line start.") 640 | testConv_BL_in(batchs[i_batch], widths[i_batch], ker_wid, num_tests, boot) 641 | 642 | fmt.Println("Ours start.") 643 | testConv_in(batchs[i_batch], widths[i_batch], ker_wid, num_tests, boot) 644 | } 645 | } 646 | 647 | func printDebugCfs(params ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []float64, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []float64) { 648 | total_size := make([]int, 15) 649 | 650 | valuesTest_pre := encoder.DecodeCoeffs(decryptor.DecryptNew(ciphertext)) 651 | valuesTest = make([]float64, 2*params.Slots()) 652 | step := len(valuesTest_pre) / len(valuesTest) // to cover cases with less slots (N/2 => 1, N/4 => 2, ...) 653 | for i := range valuesTest { 654 | valuesTest[i] = valuesTest_pre[i*step] 655 | } 656 | 657 | fmt.Println("len val Want:", len(valuesWant)) 658 | fmt.Println("len val Test:", len(valuesTest)) 659 | 660 | fmt.Println() 661 | fmt.Printf("Level: %d (logQ = %d)\n", ciphertext.Level(), params.LogQLvl(ciphertext.Level())) 662 | fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale)) 663 | fmt.Printf("ValuesTest:") 664 | for i := range total_size { 665 | fmt.Printf("%6.10f, ", valuesTest[i]) 666 | } 667 | fmt.Printf("... \n") 668 | fmt.Printf("ValuesWant:") 669 | for i := range total_size { 670 | fmt.Printf("%6.10f, ", valuesWant[i*step]) 671 | } 672 | fmt.Printf("... 
\n") 673 | 674 | valuesTestC := make([]complex128, len(valuesTest)) 675 | valuesWantC := make([]complex128, len(valuesWant)/step) 676 | 677 | for i := range valuesTestC { 678 | valuesTestC[i] = complex(valuesTest[i], 0) 679 | valuesWantC[i] = complex(valuesWant[i*step], 0) 680 | } 681 | 682 | precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWantC[:params.Slots()], valuesTestC[:params.Slots()], params.LogSlots(), 0) 683 | 684 | fmt.Println(precStats.String()) 685 | 686 | precStats = ckks.GetPrecisionStats(params, encoder, nil, valuesWantC[params.Slots():], valuesTestC[params.Slots():], params.LogSlots(), 0) 687 | 688 | fmt.Println(precStats.String()) 689 | fmt.Println() 690 | 691 | return 692 | } 693 | 694 | func printDebugCfsPlain(valuesTest, valuesWant []float64) { 695 | total_size := make([]int, 10) 696 | 697 | fmt.Printf("ValuesTest:") 698 | for i := range total_size { 699 | fmt.Printf("%6.10f, ", valuesTest[i]) 700 | } 701 | fmt.Printf("... \n") 702 | fmt.Printf("ValuesWant:") 703 | for i := range total_size { 704 | fmt.Printf("%6.10f, ", valuesWant[i]) 705 | } 706 | fmt.Printf("... 
\n") 707 | 708 | valuesWantC := make([]complex128, len(valuesWant)) 709 | valuesTestC := make([]complex128, len(valuesTest)) 710 | for i := range valuesWantC { 711 | valuesWantC[i] = complex(valuesWant[i], 0) 712 | valuesTestC[i] = complex(valuesTest[i], 0) 713 | } 714 | precStats := ckks.GetPrecisionStatsPlain(valuesWantC, valuesTestC, len(valuesWantC), 0) 715 | fmt.Println(precStats.String()) 716 | fmt.Println() 717 | } 718 | 719 | // decrypt ciphertext then compare with valuesWant, then output the msgs to valuesTest 720 | // log_sparse = 0 -> full slot & TWO ciphertexts 721 | // log_sparse = 1 -> full/2 & ONE ciphertext 722 | func printDebug(log_sparse int, params ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []complex128, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []complex128) { 723 | total_size := make([]int, 15) 724 | if ciphertext == nil { 725 | return nil 726 | // valuesTest = make([]complex128, params.Slots()/(1< in_wid)) && ((j <= show) || (j+show > in_wid))) { 817 | fmt.Printf("(%d, %d): ", i, j) 818 | for b := 0; b < batch; b++ { 819 | tmp[b] = real(vec[in_wid*in_wid*b+(i-1)*in_wid+(j-1)]) 820 | } 821 | prt_vec(tmp) 822 | } 823 | } 824 | } 825 | } 826 | 827 | // vec = arrgvec with batch batches, each batch is sqr-sized 828 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : print all 829 | func prt_mat(vec []float64, batch, show int) { 830 | mat_size := len(vec) / batch 831 | row := int(math.Sqrt(float64(mat_size))) 832 | j, k := 1, 1 833 | for i := 0; i < len(vec); i += batch { 834 | if (show == 0) || (((j <= show) || (j > row-show)) && ((k <= show) || (k > (row - show)))) { 835 | fmt.Printf("(%d, %d): ", j, k) 836 | prt_vec(vec[i : i+batch]) 837 | } 838 | k++ 839 | if k*k > mat_size { 840 | k = 1 841 | j++ 842 | } 843 | } 844 | } 845 | 846 | // vec = arrgvec with batch batches, each batch is sqr-sized 847 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : 
print all 848 | func prt_mat_norm(vec []float64, batch, norm, show int, half bool) { 849 | mat_size := len(vec) / batch 850 | row := int(math.Sqrt(float64(mat_size))) 851 | if half { 852 | row = row / 2 853 | } 854 | tmp := make([]float64, batch/norm) 855 | j, k := 1, 1 856 | for i := 0; i < len(vec); i += batch { 857 | if (show == 0) || (((j <= show) || ((j > row-show) && (j <= row))) && ((k <= show) || ((k > row-show) && (k <= row)))) { 858 | fmt.Printf("(%d, %d): ", j, k) 859 | for idx := range tmp { 860 | tmp[idx] = vec[i+norm*idx] 861 | } 862 | prt_vec(tmp) 863 | } 864 | k++ 865 | if k*k > mat_size { 866 | k = 1 867 | j++ 868 | } 869 | } 870 | } 871 | 872 | // vec = arrgvec with batch batches, each batch is sqr-sized 873 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : print all 874 | // input is strided with steps, read from start 875 | func prt_mat_norm_step(vec []float64, batch, norm, step, start, show int, half bool) { 876 | mat_size := len(vec) / batch 877 | row := int(math.Sqrt(float64(mat_size))) 878 | if half { 879 | row = row / 2 880 | } 881 | tmp := make([]float64, batch/norm) 882 | j, k := 1, 1 883 | for i := 0; i < len(vec); i += batch { 884 | if (show == 0) || (((j <= show*step) || ((j > row-show*step) && (j <= row))) && ((k <= show*step) || ((k > row-show*step) && (k <= row)))) { 885 | if ((j-start)%step == 0) && ((k-start)%step == 0) { 886 | fmt.Printf("(%d, %d): ", (j-start)/step+1, (k-start)/step+1) 887 | for idx := range tmp { 888 | tmp[idx] = vec[i+norm*idx] 889 | } 890 | prt_vec(tmp) 891 | } 892 | } 893 | k += 1 894 | if k*k > mat_size { 895 | k = 1 896 | j += 1 897 | } 898 | } 899 | } 900 | 901 | // only (sj,sk) element in all batches 902 | func prt_mat_one(vec []float64, batch, sj, sk int) (out []float64) { 903 | mat_size := len(vec) / batch 904 | j, k := 1, 1 905 | for i := 0; i < len(vec); i += batch { 906 | if (j == sj) && (k == sk) { 907 | fmt.Print(vec[i : i+batch]) 908 | out = vec[i : i+batch] 
// prt_mat_one_BL extracts out_num real parts from vec, reading one value every
// 8*(len(vec)/max_bat) entries starting at index 0.
// (Original note: only out_num, (1,1) element in all batches
// (1,0,0,0,0,0,0,0,2,0,0,0,0,0,...).)
// It panics if max_bat is zero or if a read position falls outside vec.
func prt_mat_one_BL(vec []complex128, max_bat, out_num int) (out []float64) {
	stride := len(vec) / max_bat * 8
	out = make([]float64, 0, out_num)
	for pos, n := 0, 0; n < out_num; pos, n = pos+stride, n+1 {
		out = append(out, real(vec[pos]))
	}
	return out
}
// readTxt reads whitespace-separated floating point numbers from name_file.
//
// size is the expected number of values; pass 0 to accept any count.
// The function panics (consistent with the rest of this file, which treats
// I/O problems as fatal) when
//   - the file cannot be opened or read,
//   - a token is not a valid float64 (previously such tokens were silently
//     turned into 0.0, which corrupts weights/inputs without any notice),
//   - size != 0 and the number of values differs from size.
//
// The file is closed on every path, including the panicking ones.
func readTxt(name_file string, size int) (input []float64) {
	file, err := os.Open(name_file)
	if err != nil {
		panic(err)
	}
	defer file.Close()

	if size > 0 {
		input = make([]float64, 0, size) // avoid repeated growth for large weight files
	}

	scanner := bufio.NewScanner(file)
	scanner.Split(bufio.ScanWords)
	for scanner.Scan() {
		val, err := strconv.ParseFloat(scanner.Text(), 64)
		if err != nil {
			panic(fmt.Errorf("readTxt %q: value #%d: %w", name_file, len(input)+1, err))
		}
		input = append(input, val)
	}
	// Scan() returning false can also mean a read error or an over-long token.
	if err := scanner.Err(); err != nil {
		panic(fmt.Errorf("readTxt %q: %w", name_file, err))
	}

	if (size != 0) && (len(input) != size) {
		panic("input size inconsistent!")
	}

	return input
}
// removeDuplicateInt returns the distinct values of intSlice in order of first
// appearance. The result is always a non-nil slice, even for empty input.
func removeDuplicateInt(intSlice []int) []int {
	seen := make(map[int]bool)
	unique := []int{}
	for _, v := range intSlice {
		if seen[v] {
			continue
		}
		seen[v] = true
		unique = append(unique, v)
	}
	return unique
}
// lRot returns a rotated left by `rotation` positions, i.e.
// result[i] = a[(i+rotation) mod len(a)].
//
// Contract (unchanged for existing callers):
//   - rotation <= 0 is a no-op and returns a itself (callers such as
//     comprs_vec_sparse rely on this and handle negative shifts via rRot);
//   - rotation larger than len(a) wraps around.
//
// Compared with the previous step-by-step implementation this version
//   - runs in O(len(a)) regardless of rotation (rRot passes len(a)-rotation,
//     which exceeds len(a) for negative right rotations and used to cost
//     O(len(a) * rotation) element copies),
//   - never writes into spare capacity of the caller's backing array,
//   - returns an empty input unchanged instead of panicking.
func lRot(a []float64, rotation int) []float64 {
	size := len(a)
	if rotation <= 0 || size == 0 {
		return a
	}
	shift := rotation % size

	rotated := make([]float64, size)
	n := copy(rotated, a[shift:]) // tail of a moves to the front
	copy(rotated[n:], a[:shift])  // head of a wraps to the back
	return rotated
}
keep_vec(input []float64, in_wid, kp_wid, ul int) []float64 { 64 | output := make([]float64, len(input)) 65 | 66 | tmp := gen_keep_vec(len(input), in_wid, kp_wid, ul) 67 | 68 | for i := range output { 69 | output[i] = input[i] * float64(tmp[i]) 70 | } 71 | 72 | return output 73 | } 74 | 75 | func keep_vec_stride(input []float64, in_wid, kp_wid, step, ul int, raw_in_wid_odd bool) []float64 { 76 | output := make([]float64, len(input)) 77 | 78 | tmp := gen_keep_vec_stride(len(input), in_wid, kp_wid, step, ul, raw_in_wid_odd) 79 | 80 | for i := range output { 81 | output[i] = input[i] * float64(tmp[i]) 82 | } 83 | 84 | return output 85 | } 86 | 87 | func keep_vec_sparse(input []float64, in_wid, kp_wid, log_sparse int) []float64 { 88 | output := make([]float64, len(input)) 89 | 90 | tmp := gen_keep_vec_sparse(len(input), in_wid, kp_wid, log_sparse) 91 | 92 | for i := range output { 93 | output[i] = input[i] * float64(tmp[i]) 94 | } 95 | 96 | return output 97 | } 98 | 99 | func comprs_vec_sparse(input []float64, in_wid, kp_wid, log_sparse, ul, pos int) []float64 { 100 | 101 | tmp1, tmp2 := gen_comprs_sparse(len(input), in_wid, kp_wid, log_sparse, ul, pos) 102 | 103 | mid_output := make([]float64, len(input)) 104 | for i := range tmp1 { 105 | rot_input := make([]float64, len(input)) 106 | for j := range rot_input { 107 | rot_input[j] = input[j] * float64(tmp1[i][j]) 108 | } 109 | rot_input = lRot(rot_input, i) 110 | if i < 0 { 111 | rot_input = rRot(rot_input, -i) 112 | } 113 | 114 | for j := range mid_output { 115 | mid_output[j] = rot_input[j] + mid_output[j] 116 | } 117 | } 118 | 119 | output := make([]float64, len(input)) 120 | for i := range tmp2 { 121 | rot_input := make([]float64, len(input)) 122 | for j := range rot_input { 123 | rot_input[j] = mid_output[j] * float64(tmp2[i][j]) 124 | } 125 | rot_input = lRot(rot_input, i) 126 | if i < 0 { 127 | rot_input = rRot(rot_input, -i) 128 | } 129 | 130 | for j := range output { 131 | output[j] = rot_input[j] + output[j] 
132 | } 133 | } 134 | 135 | return output 136 | } 137 | 138 | // returns the idx for keep_vec 139 | // N: length of input (upper + lower) 140 | // ul = 0 -> upper part, ul = 1 -> lower part 141 | func gen_keep_vec(vec_size, in_wid, kp_wid, ul int) (idx []int) { 142 | logN := 0 143 | for ; (1 << logN) < (2 * vec_size); logN++ { 144 | } 145 | idx = make([]int, vec_size) 146 | batch := 2 * vec_size / (in_wid * in_wid) 147 | if kp_wid < in_wid/2 { 148 | panic("keep width too small. less than in_wid/2") 149 | } 150 | 151 | if ul == 0 { 152 | for i := 0; i < in_wid/2; i++ { 153 | for j := 0; j < kp_wid; j++ { 154 | for b := 0; b < batch; b++ { 155 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b), logN-1)) 156 | idx[id] = 1 157 | } 158 | } 159 | } 160 | } else if ul == 1 { 161 | for i := 0; i < kp_wid-in_wid/2; i++ { 162 | for j := 0; j < kp_wid; j++ { 163 | for b := 0; b < batch; b++ { 164 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b), logN-1)) 165 | idx[id] = 1 166 | } 167 | } 168 | } 169 | } else { 170 | panic("ul not 0 nor 1") 171 | } 172 | 173 | return idx 174 | } 175 | 176 | // returns the idx for keep_vec given that we have sparse pack with log_sparse (=0 if full, 1 if half sparse pack) 177 | // Always assume log_sparse >= 1 so that both up and down part is in one ciphertext 178 | // N: length of input (upper + lower), vec_size is always N/2! 179 | func gen_keep_vec_sparse(vec_size, in_wid, kp_wid, log_sparse int) (idx []int) { 180 | logN := 0 181 | for ; (1 << logN) < (2 * vec_size); logN++ { 182 | } 183 | idx = make([]int, vec_size) 184 | batch := 2 * vec_size / (in_wid * in_wid) 185 | sparsity := 1 << log_sparse 186 | if sparsity == 1 { 187 | panic("We do not support full packing in gen_keep_vec_sparse") 188 | } 189 | if kp_wid < in_wid/2 { 190 | panic("keep width too small. 
less than in_wid/2") 191 | } 192 | 193 | for i := 0; i < in_wid/2; i++ { 194 | for j := 0; j < kp_wid; j++ { 195 | for b := 0; b < batch/sparsity; b++ { 196 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b*sparsity), logN-1)) 197 | idx[id] = 1 198 | } 199 | } 200 | } 201 | for i := 0; i < kp_wid-in_wid/2; i++ { 202 | for j := 0; j < kp_wid; j++ { 203 | for b := 0; b < batch/sparsity; b++ { 204 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b*sparsity), logN-1)) + vec_size/sparsity 205 | idx[id] = 1 206 | } 207 | } 208 | } 209 | 210 | post_slot := 2 * len(idx) / sparsity 211 | for i := 0; i < post_slot; i++ { 212 | for j := 1; j < sparsity/2; j++ { 213 | idx[i+post_slot*j] = idx[i] 214 | } 215 | } 216 | 217 | return idx 218 | } 219 | 220 | // returns the idx for keep_vec // same as gen_keep_vec but keeps strided output only 221 | // N: length of input (upper + lower) 222 | // ul = 0 -> upper part, ul = 1 -> lower part 223 | // kp_wid: number of elements to keep (in each row) 224 | // step: distance between each element 225 | // raw_in_wid_odd: raw_in_wid is odd or not 226 | func gen_keep_vec_stride(vec_size, in_wid, kp_wid, step, ul int, raw_in_wid_odd bool) (idx []int) { 227 | logN := 0 228 | for ; (1 << logN) < (2 * vec_size); logN++ { 229 | } 230 | idx = make([]int, vec_size) 231 | batch := 2 * vec_size / (in_wid * in_wid) 232 | 233 | var init int 234 | if raw_in_wid_odd { 235 | init = 0 236 | } else { 237 | init = step - 1 238 | } 239 | 240 | if ul == 0 { 241 | for i := 0; i < kp_wid; i++ { 242 | if (init + i*step) < in_wid/2 { 243 | for j := 0; j < kp_wid; j++ { 244 | for b := 0; b < batch; b++ { 245 | id := int(reverseBits(uint32(in_wid*batch*(init+i*step)+batch*(j*step+init)+b), logN-1)) 246 | idx[id] = 1 247 | } 248 | } 249 | } 250 | } 251 | } else if ul == 1 { 252 | for i := 0; i < kp_wid; i++ { 253 | if (init + i*step) >= in_wid/2 { 254 | for j := 0; j < kp_wid; j++ { 255 | for b := 0; b < batch; b++ { 256 | id := 
int(reverseBits(uint32(in_wid*batch*(init+i*step-in_wid/2)+batch*(j*step+init)+b), logN-1)) 257 | idx[id] = 1 258 | } 259 | } 260 | } 261 | } 262 | } else { 263 | panic("ul not 0 nor 1") 264 | } 265 | 266 | return idx 267 | } 268 | 269 | // Assume N/2 input vector 270 | // reverse of extend_full (after strided conv -> normal) 271 | // in_wid = input wid including padding 272 | // kp_wid = keep wid 273 | // padding = true: only keeps valid elements // (output) e.g., 12 00 // 34 00 // 00 00 // 00 00 274 | // padding = false: keeps all elements // (output) e.g., 12 // 34 275 | // 0 <= pos < 4 determines to which part the output is positioned at the final output 276 | // ul : up (0) or low (1) part 277 | func comprs_full(input []float64, in_wid, kp_wid, pos, ul int) []float64 { 278 | output := make([]float64, len(input)) 279 | batch := 2 * len(input) / (in_wid * in_wid) 280 | if kp_wid < in_wid/2 { 281 | panic("keep width too small. less than in_wid/2") 282 | } 283 | pos = int(reverseBits(uint32(pos), 2)) 284 | padding := false 285 | min_wid := in_wid / 4 286 | if in_wid%4 != 0 { 287 | panic("input wid not divisible by 4") 288 | } 289 | if in_wid%2 != 0 { 290 | panic("input wid not divisible by 2") 291 | } 292 | log_in_wid := 0 293 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ { 294 | } 295 | 296 | if padding { 297 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j 298 | tmp := make([]float64, len(input)) 299 | for b := 0; b < batch; b++ { 300 | for i := 0; i < min_wid; i++ { 301 | idx := 2*min_wid*in_wid*b + in_wid*j + i + min_wid*in_wid + min_wid 302 | tmp[idx] = input[idx] 303 | } 304 | } 305 | rot := -2*j*min_wid + 2*pos*min_wid*min_wid - min_wid*in_wid - min_wid 306 | output = addSlice(output, rRot(tmp, rot)) 307 | } 308 | // // when we want to extract even positioned inputs 309 | // for j := 0; j < min_wid; j++ { // kinds of mov depends on j 310 | // tmp := make([]int, len(input)) 311 | // for b := 0; b < batch; b++ { 312 | // for i := 0; i < 
min_wid; i++ { 313 | // idx := 2*min_wid*in_wid*b + in_wid*j + i 314 | // tmp[idx] = input[idx] 315 | // } 316 | // } 317 | // rot := -2*j*min_wid + 2*pos*min_wid*min_wid 318 | // output = addSlice(output, rRot(tmp, rot)) 319 | // } 320 | } else { 321 | if ul == 0 { 322 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 323 | tmp := make([]float64, len(input)) 324 | for b := 0; b < batch; b++ { 325 | for i := 0; i < min_wid; i++ { 326 | if reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid) { 327 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 328 | tmp[idx] = input[idx] 329 | } 330 | } 331 | } 332 | rot := -j*min_wid + 2*pos*min_wid*min_wid - min_wid - in_wid*min_wid 333 | output = addSlice(output, rRot(tmp, rot)) 334 | } 335 | } else { 336 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 337 | tmp := make([]float64, len(input)) 338 | for b := 0; b < batch; b++ { 339 | for i := 0; i < min_wid; i++ { 340 | if (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(3*min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) { 341 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 342 | tmp[idx] = input[idx] 343 | } 344 | } 345 | } 346 | rot := -j*min_wid + 2*pos*min_wid*min_wid - min_wid - in_wid*min_wid 347 | output = addSlice(output, rRot(tmp, rot)) 348 | } 349 | } 350 | // // when we want to extract even positioned inputs 351 | // for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 352 | // tmp := make([]int, len(input)) 353 | // for b := 0; b < batch; b++ { 354 | // for i := 0; i < min_wid; i++ { 355 | // idx := 2*min_wid*in_wid*b + 2*min_wid*j + i 356 | // tmp[idx] = input[idx] 357 | // } 358 | // } 359 | // rot := -j*min_wid + 2*pos*min_wid*min_wid 360 | // output = addSlice(output, rRot(tmp, rot)) 361 | // } 362 | } 363 | 364 | return output 365 | } 366 | 367 | // Assume N/2 input vector 368 | // reverse of extend_full (after strided conv -> 
normal) 369 | // in_wid = input wid including padding 370 | // kp_wid = keep wid 371 | // 0 <= pos < 4 determines to which part the output is positioned at the final output 372 | // ul : up (0) or low (1) part 373 | func comprs_full_fast(input []float64, in_wid, kp_wid, pos, ul int) []float64 { 374 | mid_out := make([]float64, len(input)) 375 | output := make([]float64, len(input)) 376 | batch := 2 * len(input) / (in_wid * in_wid) 377 | if kp_wid < in_wid/2 { 378 | panic("keep width too small. less than in_wid/2") 379 | } 380 | pos = int(reverseBits(uint32(pos), 2)) 381 | min_wid := in_wid / 4 382 | if in_wid%4 != 0 { 383 | panic("input wid not divisible by 4") 384 | } 385 | if in_wid%2 != 0 { 386 | panic("input wid not divisible by 2") 387 | } 388 | log_in_wid := 0 389 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ { 390 | } 391 | 392 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 393 | tmp := make([]float64, len(input)) 394 | for b := 0; b < batch; b++ { 395 | for i := 0; i < min_wid; i++ { 396 | if (ul == 0) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) { 397 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 398 | tmp[idx] = input[idx] 399 | } 400 | if (ul == 1) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) { 401 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 402 | tmp[idx] = input[idx] 403 | } 404 | } 405 | } 406 | rot := -j*min_wid + 2*min_wid*min_wid - min_wid 407 | mid_out = addSlice(mid_out, rRot(tmp, rot)) 408 | } 409 | for b := 0; b < batch; b++ { 410 | tmp := make([]float64, len(input)) 411 | for j := 0; j < 2*min_wid; j++ { 412 | for i := 0; i < min_wid; i++ { 413 | idx := 2*min_wid*in_wid*b + 3*in_wid/2*min_wid + j*min_wid + i 414 | tmp[idx] = mid_out[idx] 415 | } 416 | } 417 | rot := -3*b*min_wid*in_wid/2 + pos*min_wid*in_wid/2*batch - 3*min_wid*in_wid/2 418 | output = 
addSlice(output, rRot(tmp, rot)) 419 | } 420 | 421 | return output 422 | } 423 | 424 | // generate vectors for comprs_full (N/2 input) 425 | // returns the idx and rotations for each idx For comprs_full_hf 426 | // vec_size = slots, in_wid = real in_wid including padding, 427 | // CAUTION: rotation = -rotation (of comprs_full_hf) 428 | func gen_comprs_full(vec_size, in_wid, kp_wid, pos, ul int) (r_idx map[int][]int) { 429 | r_idx = make(map[int][]int) 430 | batch := 2 * vec_size / (in_wid * in_wid) 431 | if kp_wid < in_wid/2 { 432 | panic("keep width too small. less than in_wid/2") 433 | } 434 | pos = int(reverseBits(uint32(pos), 2)) 435 | padding := false 436 | min_wid := in_wid / 4 437 | if in_wid%4 != 0 { 438 | panic("input wid not divisible by 4") 439 | } 440 | if in_wid%2 != 0 { 441 | panic("input wid not divisible by 2") 442 | } 443 | log_in_wid := 0 444 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ { 445 | } 446 | 447 | if padding { 448 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j 449 | tmp := make([]int, vec_size) 450 | for b := 0; b < batch; b++ { 451 | for i := 0; i < min_wid; i++ { 452 | idx := 2*min_wid*in_wid*b + in_wid*j + i + min_wid*in_wid + min_wid 453 | tmp[idx] = 1 454 | } 455 | } 456 | rot := 2*j*min_wid - 2*pos*min_wid*min_wid + min_wid*in_wid + min_wid 457 | r_idx[rot] = tmp 458 | } 459 | } else { 460 | if ul == 0 { 461 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 462 | tmp := make([]int, vec_size) 463 | for b := 0; b < batch; b++ { 464 | if reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid) { 465 | for i := 0; i < min_wid; i++ { 466 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 467 | tmp[idx] = 1 468 | } 469 | } 470 | } 471 | rot := j*min_wid - 2*pos*min_wid*min_wid + min_wid + in_wid*min_wid 472 | r_idx[rot] = tmp 473 | } 474 | } else { 475 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 476 | tmp := make([]int, vec_size) 477 | for b := 0; b < batch; 
b++ { 478 | for i := 0; i < min_wid; i++ { 479 | if (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(3*min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) { 480 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 481 | tmp[idx] = 1 482 | } 483 | } 484 | } 485 | rot := j*min_wid - 2*pos*min_wid*min_wid + min_wid + in_wid*min_wid 486 | r_idx[rot] = tmp 487 | } 488 | } 489 | } 490 | 491 | return r_idx 492 | } 493 | 494 | // generate vectors for comprs_full_fast (N/2 input) 495 | // returns the idx and rotations for each idx For comprs_full_hf 496 | // vec_size = slots, in_wid = real in_wid including padding, 497 | // CAUTION: rotation = -rotation (of comprs_full_hf) 498 | func gen_comprs_fast(vec_size, in_wid, kp_wid, pos, ul int) (m_idx, r_idx map[int][]int) { 499 | m_idx = make(map[int][]int) 500 | r_idx = make(map[int][]int) 501 | batch := 2 * vec_size / (in_wid * in_wid) 502 | 503 | if kp_wid < in_wid/2 { 504 | panic("keep width too small. 
less than in_wid/2") 505 | } 506 | pos = int(reverseBits(uint32(pos), 2)) 507 | min_wid := in_wid / 4 508 | if in_wid%4 != 0 { 509 | panic("input wid not divisible by 4") 510 | } 511 | if in_wid%2 != 0 { 512 | panic("input wid not divisible by 2") 513 | } 514 | log_in_wid := 0 515 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ { 516 | } 517 | 518 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j 519 | tmp := make([]int, vec_size) 520 | for b := 0; b < batch; b++ { 521 | for i := 0; i < min_wid; i++ { 522 | if (ul == 0) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) { 523 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 524 | tmp[idx] = 1 525 | } 526 | if (ul == 1) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) { 527 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid 528 | tmp[idx] = 1 529 | } 530 | } 531 | } 532 | rot := j*min_wid - 2*min_wid*min_wid + min_wid 533 | m_idx[rot] = tmp 534 | } 535 | for b := 0; b < batch; b++ { // kinds of mov depends on b 536 | tmp := make([]int, vec_size) 537 | for j := 0; j < 2*min_wid; j++ { 538 | for i := 0; i < min_wid; i++ { 539 | idx := 2*min_wid*in_wid*b + 3*in_wid/2*min_wid + j*min_wid + i 540 | tmp[idx] = 1 541 | } 542 | } 543 | rot := 3*b*min_wid*in_wid/2 - pos*min_wid*in_wid/2*batch + 3*min_wid*in_wid/2 544 | r_idx[rot] = tmp 545 | } 546 | 547 | return m_idx, r_idx 548 | } 549 | 550 | // generate vectors for comprs_full_fast (N/2 input) 551 | // returns the idx and rotations for each idx For comprs_full_hf 552 | // vec_size = full slots = N/2, in_wid = real in_wid including padding, 553 | // CAUTION: rotation = -rotation (of comprs_full_hf) 554 | // log_sparse: 0 => full slots, 1 => half slots, Of the INPUT 555 | // ul: 0(up), 1(low) 556 | // pos: position after pack [only for full packing case] 557 | func gen_comprs_sparse(vec_size, in_wid, kp_wid, 
log_sparse, ul, pos int) (m_idx, r_idx map[int][]int) { 558 | m_idx = make(map[int][]int) 559 | r_idx = make(map[int][]int) 560 | batch := 2 * vec_size / (in_wid * in_wid * (1 << log_sparse)) 561 | 562 | // if kp_wid < in_wid/2 { 563 | // panic("keep width too small. less than in_wid/2") 564 | // } 565 | // pos = int(reverseBits(uint32(pos), 2)) 566 | min_wid := in_wid / 2 567 | if in_wid%2 != 0 { 568 | panic("input wid not divisible by 2") 569 | } 570 | log_in_wid := 0 571 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ { 572 | } 573 | 574 | if log_sparse != 0 { 575 | if pos != 0 { 576 | panic("No pos != 0 cases for log_sparse != 0") 577 | } 578 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j 579 | tmp := make([]int, vec_size) 580 | for b := 0; b < batch; b++ { 581 | for i := 0; i < min_wid/2; i++ { 582 | for k := 0; k < 2; k++ { 583 | if (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && ((reverseBits(uint32(i), log_in_wid-2) + uint32(k)*uint32(min_wid)/2) < uint32(kp_wid)) { 584 | idx := k*in_wid*min_wid*batch + in_wid*in_wid*b/2 + in_wid*j/2 + i 585 | tmp[idx] = 1 586 | } 587 | } 588 | } 589 | } 590 | // repeatedly write tmp elements for log_sparse > 1 cases. 591 | for i := 0; i < vec_size/(1<<(log_sparse-1)); i++ { 592 | for k := 1; k < (1 << (log_sparse - 1)); k++ { 593 | tmp[i+k*vec_size/(1<<(log_sparse-1))] = tmp[i] 594 | } 595 | } 596 | rot := j * min_wid / 2 597 | m_idx[rot] = tmp 598 | } 599 | 600 | for b := 0; b < batch; b++ { // kinds of mov depends on b 601 | tmp := make([]int, vec_size) 602 | for j := 0; j < min_wid; j++ { 603 | for i := 0; i < min_wid/2; i++ { 604 | for k := 0; k < 2; k++ { 605 | idx := k*in_wid*min_wid*batch + b*in_wid*in_wid/2 + j*min_wid/2 + i 606 | tmp[idx] = 1 607 | } 608 | } 609 | } 610 | // repeatedly write tmp elements for log_sparse > 1 cases. 
611 | for i := 0; i < vec_size/(1<<(log_sparse-1)); i++ { 612 | for k := 1; k < (1 << (log_sparse - 1)); k++ { 613 | tmp[i+k*vec_size/(1<<(log_sparse-1))] = tmp[i] 614 | } 615 | } 616 | rot := 3 * b * min_wid * min_wid / 2 617 | r_idx[rot] = tmp 618 | } 619 | } else { 620 | if batch > 8*min_wid { 621 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b 622 | for bk := 0; bk < 8; bk++ { 623 | tmp := make([]int, vec_size) 624 | for b := 0; b < batch/8; b++ { 625 | for i := 0; i < min_wid/2; i++ { 626 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) { 627 | idx := 8*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i 628 | tmp[idx] = 1 629 | } 630 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) { 631 | idx := 8*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i 632 | tmp[idx] = 1 633 | } 634 | } 635 | } 636 | rot := j*min_wid/2 + 7*bk*min_wid*min_wid/2 637 | m_idx[rot] = tmp 638 | } 639 | } 640 | 641 | for b := 0; b < batch/8; b++ { // kinds of mov depends on b 642 | tmp := make([]int, vec_size) 643 | for bk := 0; bk < 8; bk++ { 644 | for j := 0; j < min_wid; j++ { 645 | for i := 0; i < min_wid/2; i++ { 646 | idx := 8*b*in_wid*min_wid + bk*min_wid*min_wid/2 + j*min_wid/2 + i 647 | tmp[idx] = 1 648 | } 649 | } 650 | } 651 | rot := 3*b*8*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2 652 | r_idx[rot] = tmp 653 | } 654 | } else if batch > 4*min_wid { //we may move 4*j for optimizations 655 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b 656 | for bk := 0; bk < 4; bk++ { 657 | tmp := make([]int, vec_size) 658 | for b := 0; b < batch/4; b++ { 659 | for i := 0; i < min_wid/2; i++ { 660 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) { 661 | 
idx := 4*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i 662 | tmp[idx] = 1 663 | } 664 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) { 665 | idx := 4*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i 666 | tmp[idx] = 1 667 | } 668 | } 669 | } 670 | rot := j*min_wid/2 + 3*bk*min_wid*min_wid/2 671 | m_idx[rot] = tmp 672 | } 673 | } 674 | 675 | for b := 0; b < batch/4; b++ { // kinds of mov depends on b 676 | tmp := make([]int, vec_size) 677 | for bk := 0; bk < 4; bk++ { 678 | for j := 0; j < min_wid; j++ { 679 | for i := 0; i < min_wid/2; i++ { 680 | idx := 4*b*in_wid*min_wid + bk*min_wid*min_wid/2 + j*min_wid/2 + i 681 | tmp[idx] = 1 682 | } 683 | } 684 | } 685 | rot := 3*b*4*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2 686 | r_idx[rot] = tmp 687 | } 688 | } else { 689 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b 690 | tmp := make([]int, vec_size) 691 | for b := 0; b < batch; b++ { 692 | for i := 0; i < min_wid/2; i++ { 693 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) { 694 | idx := in_wid*min_wid*b + min_wid*j + i 695 | tmp[idx] = 1 696 | } 697 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) { 698 | idx := in_wid*min_wid*b + min_wid*j + i 699 | tmp[idx] = 1 700 | } 701 | } 702 | } 703 | rot := j * min_wid / 2 704 | m_idx[rot] = tmp 705 | } 706 | 707 | for b := 0; b < batch; b++ { // kinds of mov depends on b 708 | tmp := make([]int, vec_size) 709 | for j := 0; j < min_wid; j++ { 710 | for i := 0; i < min_wid/2; i++ { 711 | idx := b*in_wid*min_wid + j*min_wid/2 + i 712 | tmp[idx] = 1 713 | } 714 | } 715 | rot := 3*b*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2 716 | r_idx[rot] = 
tmp 717 | } 718 | } 719 | } 720 | 721 | return m_idx, r_idx 722 | } 723 | -------------------------------------------------------------------------------- /test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "time" 7 | 8 | "github.com/dwkim606/test_lattigo/ckks" 9 | ) 10 | 11 | // Fast Conv without boot, Assume full batch with Po2 in_wid & N 12 | // Normal Conv without output modification (e.g., trimming or expanding) 13 | // Assume that the input is 0 padded according to kernel size: only in_wid - (ker_wid-1)/2 elements in row and columns are nonzero 14 | // Also support non-full batching case 15 | func testConv_in(in_batch, in_wid, ker_wid, total_test_num int, boot bool) { 16 | kind := "Conv" 17 | printResult := false 18 | raw_in_batch := in_batch // same as python 19 | raw_in_wid := in_wid - ker_wid/2 // same as python 20 | norm := in_batch / raw_in_batch 21 | test_dir := "test_conv_data/" 22 | pow := 4.0 23 | 24 | // set basic variables for above input variables 25 | kp_wid, out_batch, logN, trans := set_Variables(in_batch, raw_in_wid, in_wid, ker_wid, kind) 26 | raw_out_batch := out_batch / norm 27 | 28 | // generate Context: params, Keys, rotations, general plaintexts 29 | cont := newContext(logN, ker_wid, []int{in_wid}, []int{kp_wid}, boot, kind) 30 | fmt.Println("vec size: log2 = ", cont.logN) 31 | fmt.Println("raw input width: ", raw_in_wid) 32 | fmt.Println("kernel width: ", ker_wid) 33 | fmt.Println("num raw batches in & out: ", raw_in_batch, ", ", raw_out_batch) 34 | 35 | for test_iter := 0; test_iter < total_test_num; test_iter++ { 36 | fmt.Println(test_iter+1, "-th iter...start") 37 | raw_input := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_in_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*raw_in_batch) 38 | ker_in := 
readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_ker_"+strconv.Itoa(test_iter)+".csv", raw_in_batch*raw_out_batch*ker_wid*ker_wid) 39 | bn_a := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bna_"+strconv.Itoa(test_iter)+".csv", raw_out_batch) 40 | bn_b := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bnb_"+strconv.Itoa(test_iter)+".csv", raw_out_batch) 41 | 42 | // input encryption 43 | input := prep_Input(raw_input, raw_in_wid, in_wid, cont.N, norm, trans, printResult) 44 | start := time.Now() 45 | plain_tmp := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 46 | cont.encoder.EncodeCoeffs(input, plain_tmp) 47 | ctxt_input := cont.encryptor.EncryptNew(plain_tmp) 48 | fmt.Printf("Encryption done in %s \n", time.Since(start)) 49 | 50 | // Kernel Prep & Conv (+BN) Evaluation 51 | var ct_result *ckks.Ciphertext 52 | if boot { 53 | ct_result = evalConv_BNRelu_new(cont, ctxt_input, ker_in, bn_a, bn_b, 0.0, pow, in_wid, kp_wid, ker_wid, raw_in_batch, raw_out_batch, norm, 0, 0, 2, 0, kind, false, false) 54 | } else { 55 | ct_result = evalConv_BN(cont, ctxt_input, ker_in, bn_a, bn_b, in_wid, ker_wid, raw_in_batch, raw_out_batch, norm, float64(1<<30), trans) 56 | } 57 | 58 | start = time.Now() 59 | cont.decryptor.Decrypt(ct_result, plain_tmp) 60 | cfs_tmp := cont.encoder.DecodeCoeffs(plain_tmp) 61 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 62 | 63 | test_out := post_process(cfs_tmp, raw_in_wid, in_wid) 64 | var real_out []float64 65 | if boot { 66 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_reluout_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*raw_in_batch) 67 | } else { 68 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_out_"+strconv.Itoa(test_iter)+".csv", 
raw_in_wid*raw_in_wid*raw_in_batch) 69 | } 70 | 71 | printDebugCfsPlain(test_out, real_out) 72 | } 73 | 74 | } 75 | 76 | func testResNet_crop_sparse(st, end, ker_wid, depth int, debug, cf100 bool) { 77 | // init_batch fixed to 16 78 | ker_name := "ker" + strconv.Itoa(ker_wid) 79 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" // !! NEED to remove "_test" 80 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 81 | fc_out := 10 // 100 for cifar100 82 | init_pow := 6.0 // covers [-2^pow, 2^pow] values at ReLU evaluation 83 | mid_pow := 6.0 84 | final_pow := 6.0 85 | if cf100 { 86 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 87 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 88 | fc_out = 100 // 100 for cifar100 89 | if ker_wid == 3 { 90 | final_pow = 7.0 91 | } else if ker_wid == 5 { 92 | final_pow = 6.0 93 | } else { 94 | final_pow = 5.0 95 | } 96 | init_pow = 5.0 97 | mid_pow = 5.0 98 | } 99 | 100 | var num_blcs [3]int 101 | switch depth { 102 | case 20: 103 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5 104 | case 14: 105 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3 106 | case 8: 107 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1 108 | default: 109 | panic("wrong depth (not in 8, 14, 20)!") 110 | } 111 | real_batch := []int{16, 32, 64} // same as python (small for faster eval test) !! NEEDS to be changed for real test input {16, 32, 64} 112 | norm := []int{4, 8, 16} // only use 1/norm batches among full batches (i.e., sparse packing) 113 | step := []int{1, 1, 1} // non-one only when it is for inside 114 | 115 | logN := 16 // !! 
NEEDS to be modified to 16 116 | alpha := 0.0 117 | in_wids := []int{32, 16, 8} // before cropping 118 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python 119 | fast_pack := true 120 | ker_size := ker_wid * ker_wid 121 | max_batch := make([]int, len(real_batch)) // the max batch 122 | for i := range max_batch { 123 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i]) 124 | } 125 | 126 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, "Resnet_crop_sparse") 127 | 128 | for iter := st; iter < end; iter++ { 129 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid) 130 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 131 | // image := make([]float64, in_wids[0]*in_wids[0]*3) 132 | // for i := range image { 133 | // image[i] = 1.0 - 1.0*float64(i)/float64(len(image)) 134 | // } 135 | if cf100 { 136 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 137 | } 138 | input := make([]float64, cont.N) 139 | k := 0 140 | for i := 0; i < in_wids[0]; i++ { 141 | for j := 0; j < in_wids[0]; j++ { 142 | for b := 0; b < 3; b++ { 143 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) { 144 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k] // sparse pack the input 145 | } 146 | k++ 147 | } 148 | } 149 | } 150 | fmt.Println("Input: ") 151 | prt_mat_norm(input, max_batch[0], norm[0], 3, false) 152 | fmt.Println("vec size: ", cont.N) 153 | fmt.Println("input width: ", raw_in_wids) 154 | fmt.Println("kernel width: ", ker_wid) 155 | fmt.Println("num batches: ", real_batch) 156 | 157 | enc_start := time.Now() 158 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 159 | cont.encoder.EncodeCoeffs(input, pl_input) 160 | 
ct_input := cont.encryptor.EncryptNew(pl_input) 161 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start)) 162 | 163 | timings := make([]float64, 6) 164 | begin_start := time.Now() 165 | start := time.Now() 166 | 167 | // ResNet Block 1 168 | pow := init_pow 169 | ct_layer := ct_input 170 | for i := 1; i <= num_blcs[0]; i++ { 171 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", real_batch[0]) 172 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", real_batch[0]) 173 | // bn_a := make([]float64, real_batch[0]) 174 | // bn_b := make([]float64, real_batch[0]) 175 | // for i := range bn_a { 176 | // bn_a[i] = 0.2 177 | // bn_b[i] = 0.0 178 | // } 179 | ker_in_batch := 3 180 | if i != 1 { 181 | ker_in_batch = real_batch[0] 182 | } 183 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", ker_in_batch*real_batch[0]*ker_size) 184 | // ker_in := make([]float64, ker_in_batch*real_batch[0]*ker_size) 185 | // for i := range ker_in { 186 | // ker_in[i] = 0.05 * float64(i) / float64(len(ker_in)) 187 | // } 188 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, ker_in_batch, real_batch[0], norm[0], 0, step[0], 2, 2, "Conv_sparse", fast_pack, debug) 189 | pow = mid_pow 190 | fmt.Println("Block1, Layer ", i, "done!") 191 | } 192 | fmt.Println("Block1 done.") // !!!! 
HERE is DONE 193 | timings[0] = time.Since(start).Seconds() 194 | start = time.Now() 195 | 196 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size) 197 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1]) 198 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1]) 199 | // ker_in12 := make([]float64, real_batch[0]*real_batch[1]*ker_size) 200 | // for i := range ker_in12 { 201 | // ker_in12[i] = 0.05 * float64(i) / float64(len(ker_in12)) 202 | // } 203 | // bn_a := make([]float64, real_batch[1]) 204 | // bn_b := make([]float64, real_batch[1]) 205 | // for i := range bn_a { 206 | // bn_a[i] = 0.1 207 | // bn_b[i] = 0.0 208 | // } 209 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1], norm[1], 0, step[1], 2, 1, "StrConv_sparse", fast_pack, debug) 210 | fmt.Println("Block1 to 2 done!") 211 | timings[1] = time.Since(start).Seconds() 212 | start = time.Now() 213 | 214 | // ResNet Block 2 215 | for i := 1; i <= num_blcs[1]; i++ { 216 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1]) 217 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1]) 218 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 219 | // bn_a2 := make([]float64, real_batch[1]) 220 | // bn_b2 := make([]float64, real_batch[1]) 221 | // ker_in2 := make([]float64, real_batch[1]*real_batch[1]*ker_size) 222 | // for i := range bn_a2 { 223 | // bn_a2[i] = 0.1 224 | // bn_b2[i] = 0.0 225 | // } 226 | // for i := range ker_in2 { 227 | // ker_in2[i] = 0.05 * float64(i) / float64(len(ker_in2)) 228 | // } 229 | 230 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 
0, step[1], 2, 3, "Conv_sparse", fast_pack, debug) 231 | fmt.Println("Block2, Layer ", i, "done!") 232 | } 233 | fmt.Println("Block2 done.") 234 | timings[2] = time.Since(start).Seconds() 235 | start = time.Now() 236 | 237 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size) 238 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2]) 239 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2]) 240 | // bn_a3 := make([]float64, real_batch[2]) 241 | // bn_b3 := make([]float64, real_batch[2]) 242 | // ker_in23 := make([]float64, real_batch[1]*real_batch[2]*ker_size) 243 | // for i := range bn_a3 { 244 | // bn_a3[i] = 0.1 245 | // bn_b3[i] = 0.0 246 | // } 247 | // for i := range ker_in23 { 248 | // ker_in23[i] = 0.05 * float64(i) / float64(len(ker_in23)) 249 | // } 250 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[1], real_batch[2], norm[2], 0, step[2], 2, 2, "StrConv_sparse", fast_pack, debug) 251 | fmt.Println("Block2 to 3 done!") 252 | timings[3] = time.Since(start).Seconds() 253 | start = time.Now() 254 | 255 | // ResNet Block 3 256 | for i := 1; i <= num_blcs[2]; i++ { 257 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2]) 258 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2]) 259 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size) 260 | // bn_a3 := make([]float64, real_batch[2]) 261 | // bn_b3 := make([]float64, real_batch[2]) 262 | // ker_in3 := make([]float64, real_batch[2]*real_batch[2]*ker_size) 263 | // for i := range bn_a3 { 264 | // bn_a3[i] = 0.1 265 | // bn_b3[i] = 0.0 266 | // } 267 | // for i := range ker_in3 { 268 | // ker_in3[i] = 
0.1 * float64(i) / float64(len(ker_in3)) 269 | // } 270 | 271 | if i == num_blcs[2] { 272 | pow = final_pow 273 | } 274 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[2], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 4, "Conv_sparse", fast_pack, debug) 275 | fmt.Println("Block3, Layer ", i, "done!") 276 | } 277 | fmt.Println("Block3 done.") 278 | timings[4] = time.Since(start).Seconds() 279 | start = time.Now() 280 | 281 | ker_inf_wid := raw_in_wids[2] 282 | if ker_inf_wid%2 == 0 { 283 | ker_inf_wid++ 284 | } 285 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out) 286 | // ker_inf := make([]float64, real_batch[2]*fc_out) 287 | // for i := range ker_inf { 288 | // ker_inf[i] = 0.1 * float64(i) 289 | // } 290 | var ct_result, ct_result2 *ckks.Ciphertext 291 | if cf100 { 292 | ker_inf_1 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2) 293 | ker_inf_2 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2) 294 | for i := 0; i < fc_out/2; i++ { 295 | for j := 0; j < real_batch[2]; j++ { 296 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 297 | ker_inf_1[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i] 298 | ker_inf_2[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i+fc_out/2] 299 | } 300 | } 301 | } 302 | bn_af := make([]float64, fc_out/2) 303 | for i := range bn_af { 304 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 305 | } 306 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 307 | bn_bf_1 := make([]float64, fc_out/2) 308 | bn_bf_2 := make([]float64, fc_out/2) 309 | for i := range bn_bf_1 { 310 | bn_bf_1[i] = bn_bf[i] 311 | bn_bf_2[i] = bn_bf[i+fc_out/2] 312 | } 313 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_1, bn_af, bn_bf_1, in_wids[2], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false) 314 | ct_result2 = 
evalConv_BN(cont, ct_layer, ker_inf_2, bn_af, bn_bf_2, in_wids[2], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false) 315 | fmt.Println("Final FC done.") 316 | timings[5] = time.Since(start).Seconds() 317 | start = time.Now() 318 | } else { 319 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out) 320 | for i := range ker_inf { 321 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 322 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i] 323 | } 324 | } 325 | bn_af := make([]float64, fc_out) 326 | for i := range bn_af { 327 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 328 | } 329 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 330 | // bn_bf := make([]float64, fc_out) 331 | // for i := range bn_bf { 332 | // bn_bf[i] = 1 * float64(i) 333 | // } 334 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[2], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false) 335 | fmt.Println("Final FC done.") 336 | timings[5] = time.Since(start).Seconds() 337 | start = time.Now() 338 | } 339 | 340 | fmt.Println() 341 | fmt.Println("=============== DECRYPTION ===============") 342 | fmt.Println() 343 | if cf100 { 344 | cont.decryptor.Decrypt(ct_result, pl_input) 345 | res_tmp1 := cont.encoder.DecodeCoeffs(pl_input) 346 | cont.decryptor.Decrypt(ct_result2, pl_input) 347 | res_tmp2 := cont.encoder.DecodeCoeffs(pl_input) 348 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 349 | res_out := append(prt_mat_one_norm(res_tmp1, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2], prt_mat_one_norm(res_tmp2, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2]...) 
350 | fmt.Println("\n result: ", res_out) 351 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out) 352 | } else { 353 | cont.decryptor.Decrypt(ct_result, pl_input) 354 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 355 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 356 | res_out := prt_mat_one_norm(res_tmp, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1) 357 | fmt.Println("\n result: ", res_out[:fc_out]) 358 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out]) 359 | } 360 | 361 | fmt.Println("Blc1: ", timings[0], " sec") 362 | fmt.Println("Blc1->2: ", timings[1], " sec") 363 | fmt.Println("Blc2: ", timings[2], " sec") 364 | fmt.Println("Blc2->3: ", timings[3], " sec") 365 | fmt.Println("Blc3: ", timings[4], " sec") 366 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec") 367 | fmt.Printf("Total done in %s \n", time.Since(begin_start)) 368 | } 369 | 370 | } 371 | 372 | func testResNet_crop_fast_in(st, end, ker_wid, depth int, debug, cf100 bool) { 373 | // init_batch fixed to 16 374 | ker_name := "ker" + strconv.Itoa(ker_wid) 375 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 376 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 377 | fc_out := 10 // 100 for cifar100 378 | init_pow := 6.0 // covers [-2^pow, 2^pow] values at ReLU evaluation 379 | mid_pow := 6.0 380 | final_pow := 6.0 381 | if cf100 { 382 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 383 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" 384 | fc_out = 100 // 100 for cifar100 385 | if ker_wid == 3 { 386 | final_pow = 7.0 387 | } else if ker_wid == 5 { 388 | final_pow = 6.0 389 | } else { 390 | final_pow = 5.0 391 | } 392 | init_pow = 5.0 393 | mid_pow = 5.0 394 | } 395 | 396 | var num_blcs 
[3]int 397 | switch depth { 398 | case 20: 399 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5 400 | case 14: 401 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3 402 | case 8: 403 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1 404 | default: 405 | panic("wrong depth (not in 8, 14, 20)!") 406 | } 407 | real_batch := []int{16, 32, 64} // same as python 408 | norm := []int{4, 2, 1} // only use 1/norm batches 409 | step := []int{1, 2, 4} 410 | prt_start := []int{1, 1, 1} 411 | if ker_wid == 5 { 412 | prt_start[0] = 1 413 | prt_start[1] = 2 414 | prt_start[2] = 4 415 | } 416 | 417 | logN := 16 418 | alpha := 0.0 419 | in_wids := []int{32, 16, 8} // before cropping 420 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python 421 | fast_pack := true 422 | ker_size := ker_wid * ker_wid 423 | max_batch := make([]int, len(real_batch)) // the max batch 424 | for i := range max_batch { 425 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i]) 426 | } 427 | 428 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, "Resnet_crop_fast") 429 | 430 | for iter := st; iter < end; iter++ { 431 | fmt.Println("Running ", iter, "-th iter... 
ker size: ", ker_wid) 432 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 433 | if cf100 { 434 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 435 | } 436 | input := make([]float64, cont.N) 437 | k := 0 438 | for i := 0; i < in_wids[0]; i++ { 439 | for j := 0; j < in_wids[0]; j++ { 440 | for b := 0; b < 3; b++ { 441 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) { 442 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k] 443 | } 444 | k++ 445 | } 446 | } 447 | } 448 | fmt.Println("Input: ") 449 | prt_mat_norm(input, max_batch[0], norm[0], 1, false) 450 | fmt.Println("vec size: ", cont.N) 451 | fmt.Println("input width: ", raw_in_wids) 452 | fmt.Println("kernel width: ", ker_wid) 453 | fmt.Println("num batches: ", real_batch) 454 | 455 | enc_start := time.Now() 456 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 457 | cont.encoder.EncodeCoeffs(input, pl_input) 458 | ct_input := cont.encryptor.EncryptNew(pl_input) 459 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start)) 460 | 461 | timings := make([]float64, 6) 462 | begin_start := time.Now() 463 | start := time.Now() 464 | 465 | // ResNet Block 1 466 | pow := init_pow 467 | ct_layer := ct_input 468 | for i := 1; i <= num_blcs[0]; i++ { 469 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", real_batch[0]) 470 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", real_batch[0]) 471 | ker_in_batch := 3 472 | if i != 1 { 473 | ker_in_batch = real_batch[0] 474 | } 475 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", ker_in_batch*real_batch[0]*ker_size) 476 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], 
raw_in_wids[0], ker_wid, ker_in_batch, real_batch[0], norm[0], 0, step[0], 2, 0, "Conv_inside", fast_pack, debug) 477 | pow = mid_pow 478 | fmt.Println("Block1, Layer ", i, "done!") 479 | } 480 | fmt.Println("Block1 done.") 481 | timings[0] = time.Since(start).Seconds() 482 | start = time.Now() 483 | 484 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size) 485 | ker_in12_new := make([]float64, 2*real_batch[0]*real_batch[1]*ker_size) 486 | for k := 0; k < ker_size; k++ { 487 | for i := 0; i < real_batch[0]; i++ { 488 | for j := 0; j < real_batch[1]; j++ { 489 | ker_in12_new[k*2*real_batch[0]*real_batch[1]+2*i*real_batch[1]+j] = ker_in12[k*real_batch[0]*real_batch[1]+i*real_batch[1]+j] 490 | } 491 | } 492 | } 493 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1]) 494 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1]) 495 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12_new, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "StrConv_inside", fast_pack, debug) 496 | fmt.Println("Block1 to 2 done!") 497 | if debug { 498 | max_bat := cont.N / (in_wids[0] * in_wids[0]) 499 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer)) 500 | prt_mat_norm_step(res_ttmp, max_bat, norm[1], step[1], prt_start[1], 3, false) 501 | } 502 | timings[1] = time.Since(start).Seconds() 503 | start = time.Now() 504 | 505 | // ResNet Block 2 506 | for i := 1; i <= num_blcs[1]; i++ { 507 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1]) 508 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1]) 509 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 510 | 511 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, 
in_wids[0], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "Conv_inside", fast_pack, debug) 512 | fmt.Println("Block2, Layer ", i, "done!") 513 | } 514 | fmt.Println("Block2 done.") 515 | timings[2] = time.Since(start).Seconds() 516 | start = time.Now() 517 | 518 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size) 519 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2]) 520 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2]) 521 | ker_in23_new := make([]float64, 2*real_batch[1]*real_batch[2]*ker_size) 522 | for k := 0; k < ker_size; k++ { 523 | for i := 0; i < real_batch[1]; i++ { 524 | for j := 0; j < real_batch[2]; j++ { 525 | ker_in23_new[k*2*real_batch[1]*real_batch[2]+2*i*real_batch[2]+j] = ker_in23[k*real_batch[1]*real_batch[2]+i*real_batch[2]+j] 526 | } 527 | } 528 | } 529 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23_new, bn_a3, bn_b3, alpha, pow, in_wids[0], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "StrConv_inside", fast_pack, debug) 530 | fmt.Println("Block2 to 3 done!") 531 | if debug { 532 | max_bat := cont.N / (in_wids[0] * in_wids[0]) 533 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer)) 534 | prt_mat_norm_step(res_ttmp, max_bat, norm[2], step[2], prt_start[2], 3, false) 535 | } 536 | timings[3] = time.Since(start).Seconds() 537 | start = time.Now() 538 | 539 | // ResNet Block 3 540 | for i := 1; i <= num_blcs[2]; i++ { 541 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2]) 542 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2]) 543 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size) 544 | 545 | if i 
== num_blcs[2] { 546 | pow = final_pow 547 | } 548 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[0], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "Conv_inside", fast_pack, debug) 549 | fmt.Println("Block3, Layer ", i, "done!") 550 | } 551 | fmt.Println("Block3 done.") 552 | timings[4] = time.Since(start).Seconds() 553 | start = time.Now() 554 | 555 | ker_inf_wid := raw_in_wids[0] 556 | if ker_inf_wid%2 == 0 { 557 | ker_inf_wid++ 558 | } 559 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out) 560 | var ct_result, ct_result2 *ckks.Ciphertext 561 | if cf100 { 562 | ker_inf_1 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2) 563 | ker_inf_2 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2) 564 | for i := 0; i < fc_out/2; i++ { 565 | for j := 0; j < real_batch[2]; j++ { 566 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 567 | ker_inf_1[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i] 568 | ker_inf_2[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i+fc_out/2] 569 | } 570 | } 571 | } 572 | bn_af := make([]float64, fc_out/2) 573 | for i := range bn_af { 574 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 575 | } 576 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 577 | bn_bf_1 := make([]float64, fc_out/2) 578 | bn_bf_2 := make([]float64, fc_out/2) 579 | for i := range bn_bf_1 { 580 | bn_bf_1[i] = bn_bf[i] 581 | bn_bf_2[i] = bn_bf[i+fc_out/2] 582 | } 583 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_1, bn_af, bn_bf_1, in_wids[0], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false) 584 | ct_result2 = evalConv_BN(cont, ct_layer, ker_inf_2, bn_af, bn_bf_2, in_wids[0], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false) 585 | fmt.Println("Final FC done.") 586 | timings[5] = time.Since(start).Seconds() 587 | 
start = time.Now() 588 | } else { 589 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out) 590 | for i := range ker_inf { 591 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 592 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i] 593 | } 594 | } 595 | bn_af := make([]float64, fc_out) 596 | for i := range bn_af { 597 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 598 | } 599 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 600 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[0], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false) 601 | fmt.Println("Final FC done.") 602 | timings[5] = time.Since(start).Seconds() 603 | start = time.Now() 604 | } 605 | 606 | fmt.Println() 607 | fmt.Println("=============== DECRYPTION ===============") 608 | fmt.Println() 609 | if cf100 { 610 | cont.decryptor.Decrypt(ct_result, pl_input) 611 | res_tmp1 := cont.encoder.DecodeCoeffs(pl_input) 612 | cont.decryptor.Decrypt(ct_result2, pl_input) 613 | res_tmp2 := cont.encoder.DecodeCoeffs(pl_input) 614 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 615 | res_out := append(prt_mat_one_norm(res_tmp1, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2], prt_mat_one_norm(res_tmp2, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2]...) 
616 | fmt.Println("\n result: ", res_out) 617 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out) 618 | } else { 619 | cont.decryptor.Decrypt(ct_result, pl_input) 620 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 621 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 622 | res_out := prt_mat_one_norm(res_tmp, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1) 623 | // fmt.Print(res_out) 624 | fmt.Println("\n result: ", res_out[:fc_out]) 625 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out]) 626 | } 627 | 628 | fmt.Println("Blc1: ", timings[0], " sec") 629 | fmt.Println("Blc1->2: ", timings[1], " sec") 630 | fmt.Println("Blc2: ", timings[2], " sec") 631 | fmt.Println("Blc2->3: ", timings[3], " sec") 632 | fmt.Println("Blc3: ", timings[4], " sec") 633 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec") 634 | fmt.Printf("Total done in %s \n", time.Since(begin_start)) 635 | } 636 | } 637 | 638 | func testResNet_crop_sparse_wide(st, end, ker_wid, depth, wide_case int, debug, cf100 bool) { 639 | // init_batch fixed to 16 640 | ker_name := "ker" + strconv.Itoa(ker_wid) 641 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 642 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 643 | fc_out := 10 644 | 645 | init_pow := 5.0 646 | mid_pow := 5.0 // needs to be 5.0 in k3 d20 w3 for best performance 647 | final_pow := 5.0 648 | if ker_wid == 5 { 649 | init_pow = 6.0 650 | mid_pow = 6.0 651 | final_pow = 6.0 652 | } 653 | 654 | if cf100 { 655 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 656 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 657 | fc_out = 
100 658 | final_pow = 7.0 659 | init_pow = 5.0 660 | mid_pow = 5.0 661 | if (ker_wid == 5) && (depth == 8) { 662 | init_pow = 6.0 663 | final_pow = 6.0 664 | } 665 | } 666 | 667 | init_batch := 16 668 | 669 | var num_blcs [3]int 670 | switch depth { 671 | case 20: 672 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5 673 | case 14: 674 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3 675 | case 8: 676 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1 677 | default: 678 | panic("wrong depth case (not in 8,14,20)!") 679 | } 680 | real_batch := []int{32, 64, 128} // same as python 681 | norm := []int{2, 4, 8} // only use 1/norm batches 682 | log_sparse := []int{1, 2, 3} 683 | step := []int{1, 1, 1} 684 | kind := "Resnet_crop_sparse_wide2" 685 | 686 | if wide_case == 3 { 687 | real_batch = []int{48, 96, 192} 688 | norm = []int{1, 2, 4} 689 | log_sparse = []int{0, 1, 2} 690 | kind = "Resnet_crop_sparse_wide3" 691 | } else if wide_case != 2 { 692 | panic("wrong wide_case (2 nor 3)!") 693 | } 694 | 695 | logN := 16 696 | alpha := 0.0 697 | in_wids := []int{32, 16, 8} // before cropping 698 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python 699 | fast_pack := true 700 | ker_size := ker_wid * ker_wid 701 | max_batch := make([]int, len(real_batch)) // the max batch 702 | for i := range max_batch { 703 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i]) 704 | } 705 | 706 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, kind) 707 | 708 | for iter := st; iter < end; iter++ { 709 | fmt.Println("Running ", iter, "-th iter... 
ker size: ", ker_wid) 710 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 711 | 712 | if cf100 { 713 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 714 | } 715 | input := make([]float64, cont.N) 716 | k := 0 717 | for i := 0; i < in_wids[0]; i++ { 718 | for j := 0; j < in_wids[0]; j++ { 719 | for b := 0; b < 3; b++ { 720 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) { 721 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k] 722 | } 723 | k++ 724 | } 725 | } 726 | } 727 | fmt.Println("Input: ") 728 | prt_mat_norm(input, max_batch[0], norm[0], 3, false) 729 | fmt.Println("vec size: ", cont.N) 730 | fmt.Println("input width: ", raw_in_wids) 731 | fmt.Println("kernel width: ", ker_wid) 732 | fmt.Println("num batches: ", real_batch) 733 | 734 | enc_start := time.Now() 735 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 736 | cont.encoder.EncodeCoeffs(input, pl_input) 737 | ct_input := cont.encryptor.EncryptNew(pl_input) 738 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start)) 739 | enc_start = time.Now() 740 | 741 | timings := make([]float64, 6) 742 | begin_start := time.Now() 743 | start := time.Now() 744 | 745 | // ResNet Block 1 746 | pow := init_pow 747 | ct_layer := ct_input 748 | for i := 1; i <= num_blcs[0]; i++ { 749 | if i == 5 { 750 | pow = mid_pow 751 | } 752 | var bn_batch int 753 | if i == 1 { 754 | bn_batch = init_batch 755 | } else { 756 | bn_batch = real_batch[0] 757 | } 758 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", bn_batch) 759 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", bn_batch) 760 | 761 | if i == 1 { 762 | ker_in := 
readTxt(weight_dir+"w0-conv.csv", 3*init_batch*ker_size) 763 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, 3, init_batch, norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug) 764 | // pow = mid_pow 765 | } else if i == 2 { 766 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", init_batch*real_batch[0]*ker_size) 767 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, init_batch, real_batch[0], norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug) 768 | } else { 769 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size) 770 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug) 771 | } 772 | fmt.Println("Block1, Layer ", i, "done!") 773 | } 774 | fmt.Println("Block1 done.") 775 | timings[0] = time.Since(start).Seconds() 776 | start = time.Now() 777 | 778 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size) 779 | ker_in12_0 := make([]float64, len(ker_in12)/2) 780 | ker_in12_1 := make([]float64, len(ker_in12)/2) 781 | if wide_case == 3 { 782 | for k := 0; k < ker_size; k++ { 783 | for i := 0; i < real_batch[0]; i++ { 784 | for j := 0; j < real_batch[1]/2; j++ { 785 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j] 786 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1] 787 | } 788 | } 789 | } 790 | } 791 | 792 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1]) 793 | bn_b := 
readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1]) 794 | 795 | if wide_case == 2 { 796 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1], norm[1], 0, step[1], 2, log_sparse[0]-1, "StrConv_sparse", fast_pack, debug) 797 | } else if wide_case == 3 { 798 | bn_a_0 := make([]float64, real_batch[1]/2) 799 | bn_a_1 := make([]float64, real_batch[1]/2) 800 | bn_b_0 := make([]float64, real_batch[1]/2) 801 | bn_b_1 := make([]float64, real_batch[1]/2) 802 | for i := range bn_b_0 { 803 | bn_a_0[i] = bn_a[2*i] 804 | bn_a_1[i] = bn_a[2*i+1] 805 | bn_b_0[i] = bn_b[2*i] 806 | bn_b_1[i] = bn_b[2*i+1] 807 | } 808 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a_0, bn_b_0, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, step[1], 2, 0, "StrConv_sparse_full", fast_pack, debug) 809 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a_1, bn_b_1, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, step[1], 2, 0, "StrConv_sparse_full", fast_pack, debug) 810 | 811 | xi := make([]float64, cont.N) 812 | xi[2] = 1.0 813 | xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0) 814 | cont.encoder.EncodeCoeffs(xi, xi_plain) 815 | cont.encoder.ToNTT(xi_plain) 816 | ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain) 817 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2) 818 | } 819 | fmt.Println("Block1 to 2 done!") 820 | timings[1] = time.Since(start).Seconds() 821 | start = time.Now() 822 | 823 | // ResNet Block 2 824 | for i := 1; i <= num_blcs[1]; i++ { 825 | if i == 5 { 826 | pow = init_pow 827 | } 828 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1]) 829 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1]) 830 | ker_in2 := 
readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 831 | 832 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, log_sparse[1], "Conv_sparse", fast_pack, debug) 833 | fmt.Println("Block2, Layer ", i, "done!") 834 | } 835 | fmt.Println("Block2 done.") 836 | timings[2] = time.Since(start).Seconds() 837 | start = time.Now() 838 | 839 | pow = mid_pow 840 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size) 841 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2]) 842 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2]) 843 | 844 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[1], real_batch[2], norm[2], 0, step[2], 2, log_sparse[1]-1, "StrConv_sparse", fast_pack, debug) 845 | fmt.Println("Block2 to 3 done!") 846 | timings[3] = time.Since(start).Seconds() 847 | start = time.Now() 848 | 849 | // ResNet Block 3 850 | for i := 1; i <= num_blcs[2]; i++ { 851 | if i == 3 { 852 | pow = init_pow 853 | } 854 | if i == 5 { 855 | pow = mid_pow 856 | } 857 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2]) 858 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2]) 859 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size) 860 | 861 | if i == num_blcs[2] { 862 | pow = final_pow 863 | } 864 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[2], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, log_sparse[2], "Conv_sparse", 
fast_pack, debug) 865 | fmt.Println("Block3, Layer ", i, "done!") 866 | } 867 | fmt.Println("Block3 done.") 868 | timings[4] = time.Since(start).Seconds() 869 | start = time.Now() 870 | 871 | ker_inf_wid := raw_in_wids[2] 872 | if ker_inf_wid%2 == 0 { 873 | ker_inf_wid++ 874 | } 875 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out) 876 | 877 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out) 878 | for i := range ker_inf { 879 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 880 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i] 881 | } 882 | } 883 | bn_af := make([]float64, fc_out) 884 | for i := range bn_af { 885 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 886 | } 887 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 888 | 889 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[2], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false) 890 | fmt.Println("Final FC done.") 891 | timings[5] = time.Since(start).Seconds() 892 | start = time.Now() 893 | 894 | fmt.Println() 895 | fmt.Println("=============== DECRYPTION ===============") 896 | fmt.Println() 897 | cont.decryptor.Decrypt(ct_result, pl_input) 898 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 899 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 900 | res_out := prt_mat_one_norm(res_tmp, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1) 901 | fmt.Println("\n result: ", res_out[:fc_out]) 902 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out]) 903 | 904 | fmt.Println("Blc1: ", timings[0], " sec") 905 | fmt.Println("Blc1->2: ", timings[1], " sec") 906 | fmt.Println("Blc2: ", timings[2], " sec") 907 | fmt.Println("Blc2->3: ", timings[3], " sec") 908 | fmt.Println("Blc3: ", timings[4], " sec") 909 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec") 910 | fmt.Printf("Total done in %s \n", 
time.Since(begin_start)) 911 | } 912 | } 913 | 914 | func testResNet_crop_fast_wide_in(st, end, ker_wid, depth, wide_case int, debug, cf100 bool) { 915 | // init_batch fixed to 16 916 | ker_name := "ker" + strconv.Itoa(ker_wid) 917 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 918 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 919 | fc_out := 10 // 100 for cifar100 920 | 921 | init_pow := 5.0 922 | mid_pow := 5.0 // needs to be 5.0 in k3 d20 w3 for best performance 923 | final_pow := 5.0 924 | if ker_wid == 5 { 925 | init_pow = 6.0 926 | mid_pow = 6.0 927 | final_pow = 6.0 928 | } 929 | 930 | if cf100 { 931 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 932 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/" 933 | fc_out = 100 // 100 for cifar100 934 | final_pow = 7.0 935 | init_pow = 5.0 936 | mid_pow = 5.0 937 | if (ker_wid == 5) && (depth == 8) { 938 | init_pow = 6.0 939 | final_pow = 6.0 940 | } 941 | } 942 | 943 | init_batch := 16 // needs to be modified to 16 944 | 945 | var num_blcs [3]int 946 | switch depth { 947 | case 20: 948 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5 949 | case 14: 950 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3 951 | case 8: 952 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1 953 | default: 954 | panic("wrong depth case (not in 8,14,20)!") 955 | } 956 | real_batch := []int{32, 64, 128} // same as python 957 | norm := []int{2, 4, 2} // only use 1/norm batches 958 | step := []int{1, 1, 2} 959 | prt_start := []int{1, 1, 1} 960 | kind := "Resnet_crop_fast_wide2" 961 | if ker_wid == 5 { 962 | prt_start[0] = 1 963 | prt_start[1] = 1 964 | prt_start[2] = 2 965 | } 966 | if wide_case == 3 { 967 | 
real_batch = []int{48, 96, 192} 968 | norm = []int{1, 2, 1} 969 | kind = "Resnet_crop_fast_wide3" 970 | } else if wide_case != 2 { 971 | panic("wrong wide_case (2 nor 3)!") 972 | } 973 | 974 | logN := 16 975 | alpha := 0.0 976 | in_wids := []int{32, 16, 8} // before cropping 977 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python 978 | fast_pack := true 979 | ker_size := ker_wid * ker_wid 980 | max_batch := make([]int, len(real_batch)) // the max batch 981 | for i := range max_batch { 982 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i]) 983 | } 984 | 985 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, kind) 986 | 987 | for iter := st; iter < end; iter++ { 988 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid) 989 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 990 | if cf100 { 991 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3) 992 | } 993 | input := make([]float64, cont.N) 994 | k := 0 995 | for i := 0; i < in_wids[0]; i++ { 996 | for j := 0; j < in_wids[0]; j++ { 997 | for b := 0; b < 3; b++ { 998 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) { 999 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k] 1000 | } 1001 | k++ 1002 | } 1003 | } 1004 | } 1005 | fmt.Println("Input: ") 1006 | prt_mat_norm(input, max_batch[0], norm[0], 1, false) 1007 | fmt.Println("vec size: ", cont.N) 1008 | fmt.Println("input width: ", raw_in_wids) 1009 | fmt.Println("kernel width: ", ker_wid) 1010 | fmt.Println("num batches: ", real_batch) 1011 | 1012 | enc_start := time.Now() 1013 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 1014 | 
cont.encoder.EncodeCoeffs(input, pl_input) 1015 | ct_input := cont.encryptor.EncryptNew(pl_input) 1016 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start)) 1017 | enc_start = time.Now() 1018 | 1019 | timings := make([]float64, 6) 1020 | begin_start := time.Now() 1021 | start := time.Now() 1022 | 1023 | // ResNet Block 1 1024 | pow := init_pow 1025 | ct_layer := ct_input 1026 | for i := 1; i <= num_blcs[0]; i++ { 1027 | if i == 5 { 1028 | pow = mid_pow 1029 | } 1030 | var bn_batch int 1031 | if i == 1 { 1032 | bn_batch = init_batch 1033 | } else { 1034 | bn_batch = real_batch[0] 1035 | } 1036 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", bn_batch) 1037 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", bn_batch) 1038 | if i == 1 { 1039 | ker_in := readTxt(weight_dir+"w0-conv.csv", 3*init_batch*ker_size) 1040 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, 3, init_batch, norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug) 1041 | // pow = mid_pow 1042 | } else if i == 2 { 1043 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", init_batch*real_batch[0]*ker_size) 1044 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, init_batch, real_batch[0], norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug) 1045 | } else { 1046 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size) 1047 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug) 1048 | } 1049 | fmt.Println("Block1, Layer ", i, "done!") 1050 | } 1051 | fmt.Println("Block1 done.") 1052 | timings[0] = time.Since(start).Seconds() 1053 | start = time.Now() 1054 | 1055 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", 
real_batch[0]*real_batch[1]*ker_size) 1056 | ker_in12_new := make([]float64, 2*real_batch[0]*real_batch[1]*ker_size) 1057 | ker_in12_0 := make([]float64, len(ker_in12)/2) 1058 | ker_in12_1 := make([]float64, len(ker_in12)/2) 1059 | if wide_case == 2 { 1060 | for k := 0; k < ker_size; k++ { 1061 | for i := 0; i < real_batch[0]; i++ { 1062 | for j := 0; j < real_batch[1]; j++ { 1063 | ker_in12_new[k*2*real_batch[0]*real_batch[1]+2*i*real_batch[1]+j] = ker_in12[k*real_batch[0]*real_batch[1]+i*real_batch[1]+j] 1064 | } 1065 | } 1066 | } 1067 | } else if wide_case == 3 { 1068 | for k := 0; k < ker_size; k++ { 1069 | for i := 0; i < real_batch[0]; i++ { 1070 | for j := 0; j < real_batch[1]/2; j++ { 1071 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j] 1072 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1] 1073 | } 1074 | } 1075 | } 1076 | } 1077 | 1078 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1]) 1079 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1]) 1080 | 1081 | if wide_case == 2 { 1082 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12_new, bn_a, bn_b, alpha, pow, in_wids[0], 2*raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[0]/2, 0, step[1], 2, 0, "StrConv_odd", fast_pack, debug) 1083 | } else if wide_case == 3 { 1084 | bn_a_0 := make([]float64, real_batch[1]/2) 1085 | bn_a_1 := make([]float64, real_batch[1]/2) 1086 | bn_b_0 := make([]float64, real_batch[1]/2) 1087 | bn_b_1 := make([]float64, real_batch[1]/2) 1088 | for i := range bn_b_0 { 1089 | bn_a_0[i] = bn_a[2*i] 1090 | bn_a_1[i] = bn_a[2*i+1] 1091 | bn_b_0[i] = bn_b[2*i] 1092 | bn_b_1[i] = bn_b[2*i+1] 1093 | } 1094 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a_0, bn_b_0, alpha, pow, in_wids[0], 
2*raw_in_wids[1], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[1], 2, 0, "StrConv_odd", fast_pack, debug) 1095 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a_1, bn_b_1, alpha, pow, in_wids[0], 2*raw_in_wids[1], ker_wid, real_batch[0], real_batch[0], norm[0], 2, step[1], 2, 0, "StrConv_odd", fast_pack, debug) 1096 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2) 1097 | } 1098 | fmt.Println("Block1 to 2 done!") 1099 | if debug { 1100 | max_bat := cont.N / (in_wids[1] * in_wids[1]) 1101 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer)) 1102 | prt_mat_norm_step(res_ttmp, max_bat, norm[1], step[1], prt_start[1], 3, false) 1103 | } 1104 | timings[1] = time.Since(start).Seconds() 1105 | start = time.Now() 1106 | 1107 | // ResNet Block 2 1108 | for i := 1; i <= num_blcs[1]; i++ { 1109 | if i == 5 { 1110 | pow = init_pow 1111 | } 1112 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1]) 1113 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1]) 1114 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 1115 | 1116 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "Conv_inside", fast_pack, debug) 1117 | fmt.Println("Block2, Layer ", i, "done!") 1118 | } 1119 | fmt.Println("Block2 done.") 1120 | timings[2] = time.Since(start).Seconds() 1121 | start = time.Now() 1122 | 1123 | pow = mid_pow 1124 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size) 1125 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2]) 1126 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2]) 1127 | ker_in23_new := 
make([]float64, 2*real_batch[1]*real_batch[2]*ker_size) 1128 | for k := 0; k < ker_size; k++ { 1129 | for i := 0; i < real_batch[1]; i++ { 1130 | for j := 0; j < real_batch[2]; j++ { 1131 | ker_in23_new[k*2*real_batch[1]*real_batch[2]+2*i*real_batch[2]+j] = ker_in23[k*real_batch[1]*real_batch[2]+i*real_batch[2]+j] 1132 | } 1133 | } 1134 | } 1135 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23_new, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "StrConv_inside", fast_pack, debug) 1136 | fmt.Println("Block2 to 3 done!") 1137 | if debug { 1138 | max_bat := cont.N / (in_wids[1] * in_wids[1]) 1139 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer)) 1140 | prt_mat_norm_step(res_ttmp, max_bat, norm[2], step[2], prt_start[2], 3, false) 1141 | } 1142 | timings[3] = time.Since(start).Seconds() 1143 | start = time.Now() 1144 | 1145 | // ResNet Block 3 1146 | for i := 1; i <= num_blcs[2]; i++ { 1147 | if i == 3 { 1148 | pow = init_pow 1149 | } 1150 | if i == 5 { 1151 | pow = mid_pow 1152 | } 1153 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2]) 1154 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2]) 1155 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size) 1156 | 1157 | if i == num_blcs[2] { 1158 | pow = final_pow 1159 | } 1160 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "Conv_inside", fast_pack, debug) 1161 | fmt.Println("Block3, Layer ", i, "done!") 1162 | } 1163 | fmt.Println("Block3 done.") 1164 | timings[4] = time.Since(start).Seconds() 1165 | start = time.Now() 1166 | 1167 | ker_inf_wid := raw_in_wids[1] 1168 | if ker_inf_wid%2 == 0 { 1169 | ker_inf_wid++ 
1170 | } 1171 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out) 1172 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out) 1173 | for i := range ker_inf { 1174 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ { 1175 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i] 1176 | } 1177 | } 1178 | bn_af := make([]float64, fc_out) 1179 | for i := range bn_af { 1180 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements 1181 | } 1182 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out) 1183 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false) 1184 | fmt.Println("Final FC done.") 1185 | timings[5] = time.Since(start).Seconds() 1186 | start = time.Now() 1187 | 1188 | fmt.Println() 1189 | fmt.Println("=============== DECRYPTION ===============") 1190 | fmt.Println() 1191 | cont.decryptor.Decrypt(ct_result, pl_input) 1192 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 1193 | fmt.Printf("Decryption Done in %s \n", time.Since(start)) 1194 | res_out := prt_mat_one_norm(res_tmp, max_batch[1], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1) 1195 | // fmt.Print(res_out) 1196 | fmt.Println("\n result: ", res_out[:fc_out]) 1197 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out]) 1198 | 1199 | fmt.Println("Blc1: ", timings[0], " sec") 1200 | fmt.Println("Blc1->2: ", timings[1], " sec") 1201 | fmt.Println("Blc2: ", timings[2], " sec") 1202 | fmt.Println("Blc2->3: ", timings[3], " sec") 1203 | fmt.Println("Blc3: ", timings[4], " sec") 1204 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec") 1205 | fmt.Printf("Total done in %s \n", time.Since(begin_start)) 1206 | } 1207 | } 1208 | 1209 | func testImagenet_final_fast_in(st, end, ker_wid int) { 1210 | // We use full packing: i.e., in_wid**2 element is contained in po2_in_wid**2 sized block <-> half 
padding of Resnet 1211 | // So ReLU, keep or rot, StoC done on both the 1st & 2nd part of the CtoS ciphertexts 1212 | ker_name := "ker" + strconv.Itoa(ker_wid) // "ker5" 1213 | weight_dir := "weight_imgnet_" + ker_name + "_h5/" 1214 | logN := 16 1215 | raw_in_wids := []int{14, 7} // same as python 1216 | real_batch := []int{256, 512} // same as python 1217 | iter := 2 1218 | in_wids := make([]int, len(raw_in_wids)) 1219 | kp_wids := make([]int, len(raw_in_wids)) 1220 | var num_blc1, num_blc2 int 1221 | if ker_name == "ker3" { 1222 | in_wids[0] = 16 1223 | in_wids[1] = 8 1224 | kp_wids[0] = 14 1225 | kp_wids[1] = 7 1226 | num_blc1 = 3 1227 | num_blc2 = 3 1228 | } else if ker_name == "ker5" { 1229 | in_wids[0] = 16 1230 | in_wids[1] = 8 1231 | kp_wids[0] = 14 1232 | kp_wids[1] = 6 1233 | num_blc1 = 3 1234 | num_blc2 = 3 1235 | } else { 1236 | panic("strange ker name!") 1237 | } 1238 | cont := newContext(logN, ker_wid, in_wids, kp_wids, true, "Imagenet_final_fast") 1239 | 1240 | ker_size := ker_wid * ker_wid 1241 | max_batch := make([]int, len(real_batch)) // the max batch 1242 | for i := range max_batch { 1243 | max_batch[i] = cont.N / (in_wids[i] * in_wids[i]) 1244 | } 1245 | alpha := 0.0 // 0.3 => leakyrelu 1246 | init_pow := 6.0 1247 | mid_pow := 5.0 1248 | final_pow := 6.0 1249 | 1250 | // ker5_iter := []int{804, 886, 901, 956} 1251 | // {3, 29, 87, 254, 357, 1252 | // 399, 435, 455, 475, 476, 1253 | // 518, 540, 545, 571, 631, 1254 | // 657, 699, 711, 748, 790, 1255 | // 804, 886, 901, 956} 1256 | 1257 | // for _, name_iter := range ker5_iter { 1258 | for name_iter := st; name_iter < end; name_iter++ { 1259 | weight_num := 10 1260 | norm := 1 1261 | fmt.Println("Start ", name_iter, "-th iter..") 1262 | 1263 | raw_input := readTxt(ker_name+"_data/test_image_"+strconv.Itoa(name_iter)+".csv", raw_in_wids[0]*raw_in_wids[0]*real_batch[0]) 1264 | input := make([]float64, in_wids[0]*in_wids[0]*real_batch[0]) 1265 | for i := 0; i < raw_in_wids[0]; i++ { 1266 | for j := 
0; j < raw_in_wids[0]; j++ { 1267 | for b := 0; b < real_batch[0]; b++ { 1268 | input[i*in_wids[0]*real_batch[0]+j*real_batch[0]+b] = raw_input[i*raw_in_wids[0]*real_batch[0]+j*real_batch[0]+b] 1269 | } 1270 | } 1271 | } 1272 | fmt.Println("Input: ") 1273 | prt_mat(input, max_batch[0], 1) 1274 | fmt.Println("vec size: ", cont.N) 1275 | fmt.Println("input width: ", raw_in_wids) 1276 | fmt.Println("kernel width: ", ker_wid) 1277 | fmt.Println("num batches: ", real_batch) 1278 | 1279 | start := time.Now() 1280 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 1281 | cont.encoder.EncodeCoeffs(input, pl_input) 1282 | ct_input := cont.encryptor.EncryptNew(pl_input) 1283 | fmt.Printf("Encryption done in %s \n", time.Since(start)) 1284 | 1285 | timings := make([]float64, 4) 1286 | begin_start := time.Now() 1287 | new_start := time.Now() 1288 | 1289 | // Block 1 1290 | pow := init_pow 1291 | ct_layer := ct_input 1292 | for i := 1; i <= num_blc1; i++ { 1293 | ker_in1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size) 1294 | weight_num++ 1295 | bn_a1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[0]) 1296 | bn_b1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[0]) 1297 | if i == num_blc1 { 1298 | pow = mid_pow 1299 | } 1300 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in1, bn_a1, bn_b1, alpha, pow, in_wids[0], kp_wids[0], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "Conv", false, false) 1301 | fmt.Println("Block1, Layer ", i, "done!") 1302 | } 1303 | fmt.Println("Block1 done!") 1304 | timings[0] = time.Since(new_start).Seconds() 1305 | new_start = time.Now() 1306 | 1307 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[1]*ker_size) 1308 | weight_num++ 1309 | bn_a12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1]) 1310 | 
bn_b12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1]) 1311 | 1312 | ker_in12_0 := make([]float64, len(ker_in12)/2) 1313 | ker_in12_1 := make([]float64, len(ker_in12)/2) 1314 | for k := 0; k < ker_size; k++ { 1315 | for i := 0; i < real_batch[0]; i++ { 1316 | for j := 0; j < real_batch[1]/2; j++ { 1317 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+j)] // [i][j] 1318 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+real_batch[1]/2+j)] // [i][j+B/2] 1319 | } 1320 | } 1321 | } 1322 | bn_a12_0 := make([]float64, real_batch[1]/2) 1323 | bn_a12_1 := make([]float64, real_batch[1]/2) 1324 | bn_b12_0 := make([]float64, real_batch[1]/2) 1325 | bn_b12_1 := make([]float64, real_batch[1]/2) 1326 | for i := range bn_b12_0 { 1327 | bn_a12_0[i] = bn_a12[i] 1328 | bn_a12_1[i] = bn_a12[i+real_batch[1]/2] 1329 | bn_b12_0[i] = bn_b12[i] 1330 | bn_b12_1[i] = bn_b12[i+real_batch[1]/2] 1331 | } 1332 | 1333 | // block1 to block 2 1334 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "StrConv", false, false) 1335 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 1, 0, iter, 0, "StrConv", false, false) 1336 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2) 1337 | fmt.Println("Block1 to 2 done!") 1338 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_result)) 1339 | // prt_mat_norm(res_tmp, max_batch[1], 1, 4, false) 1340 | timings[1] = time.Since(new_start).Seconds() 1341 | new_start = time.Now() 1342 | 1343 | // Block 2 1344 | for i := 1; i <= num_blc2; i++ { 1345 | ker_in2 := 
readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 1346 | weight_num++ 1347 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1]) 1348 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1]) 1349 | if i == num_blc2 { 1350 | pow = final_pow 1351 | } 1352 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], kp_wids[1], ker_wid, real_batch[1], real_batch[1], norm, 0, 0, iter, 0, "Conv", false, false) 1353 | fmt.Println("Block2, Layer ", i, "done!") 1354 | } 1355 | fmt.Println("Block2 done!") 1356 | timings[2] = time.Since(new_start).Seconds() 1357 | new_start = time.Now() 1358 | 1359 | // RMFC 1360 | fin_out_batch := 1000 1361 | ker_inf := readTxt(weight_dir+"fc.csv", real_batch[1]*fin_out_batch) 1362 | bn_af := make([]float64, real_batch[1]*2) 1363 | if ker_wid == 3 { 1364 | for i := range bn_af { 1365 | bn_af[i] = 1.0 / (7 * 7) // for reduce mean on 8*8 elements 1366 | } 1367 | } else { 1368 | for i := range bn_af { 1369 | bn_af[i] = 1.0 / (6 * 6) // for reduce mean on 8*8 elements 1370 | } 1371 | } 1372 | bn_bf := make([]float64, real_batch[1]*2) 1373 | for i := range bn_bf { 1374 | bn_bf[i] = 0.0 //10.0 * float64(i) 1375 | } 1376 | ker_inf_ := make([]float64, 7*7*real_batch[1]*fin_out_batch) 1377 | for b := 0; b < 7*7; b++ { 1378 | for i := 0; i < real_batch[1]; i++ { 1379 | for j := 0; j < fin_out_batch; j++ { 1380 | ker_inf_[b*real_batch[1]*fin_out_batch+i*fin_out_batch+j] = ker_inf[i*fin_out_batch+j] 1381 | } 1382 | } 1383 | } 1384 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], 7, real_batch[1], 1000, 1, float64(1<<30), false) 1385 | timings[3] = time.Since(new_start).Seconds() 1386 | new_start = time.Now() 1387 | 1388 | cont.decryptor.Decrypt(ct_result, pl_input) 1389 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 1390 | fmt.Printf("Decryption done in %s \n", 
time.Since(new_start)) 1391 | final_result := prt_mat_one_norm(res_tmp, max_batch[1], 1, 4, 4) 1392 | writeTxt(ker_name+"_enc_result/enc_result_"+strconv.Itoa(name_iter)+".csv", final_result[:1000]) 1393 | 1394 | fmt.Println("Blc1: ", timings[0], " sec") 1395 | fmt.Println("Blc1->2: ", timings[1], " sec") 1396 | fmt.Println("Blc2: ", timings[2], " sec") 1397 | fmt.Println("Final (reduce_mean & FC): ", timings[3], " sec") 1398 | fmt.Printf("Total done in %s \n", time.Since(begin_start)) 1399 | } 1400 | } 1401 | 1402 | func testImagenet_sparse(st, end, ker_wid int) { 1403 | // We use full packing: i.e., in_wid**2 element is contained in po2_in_wid**2 sized block <-> half padding of Resnet 1404 | // So ReLU, keep or rot, StoC done on both the 1st & 2nd part of the CtoS ciphertexts 1405 | debug := false 1406 | ker_name := "ker" + strconv.Itoa(ker_wid) // "ker5" 1407 | weight_dir := "weight_imgnet_" + ker_name + "_h5/" 1408 | logN := 16 1409 | raw_in_wids := []int{14, 7} // same as python 1410 | real_batch := []int{256, 512} // same as python 1411 | log_sparse := []int{0, 1} 1412 | norm := []int{1, 2} 1413 | iter := 2 1414 | in_wids := make([]int, len(raw_in_wids)) 1415 | kp_wids := make([]int, len(raw_in_wids)) 1416 | var num_blc1, num_blc2 int 1417 | if ker_name == "ker3" { 1418 | in_wids[0] = 16 1419 | in_wids[1] = 8 1420 | kp_wids[0] = 14 1421 | kp_wids[1] = 7 1422 | num_blc1 = 3 1423 | num_blc2 = 3 1424 | } else if ker_name == "ker5" { 1425 | in_wids[0] = 16 1426 | in_wids[1] = 8 1427 | kp_wids[0] = 14 1428 | kp_wids[1] = 6 1429 | num_blc1 = 3 1430 | num_blc2 = 3 1431 | } else { 1432 | panic("strange ker name!") 1433 | } 1434 | cont := newContext(logN, ker_wid, in_wids, kp_wids, true, "Imagenet_sparse") 1435 | 1436 | ker_size := ker_wid * ker_wid 1437 | max_batch := make([]int, len(real_batch)) // the max batch 1438 | for i := range max_batch { 1439 | max_batch[i] = cont.N / (in_wids[i] * in_wids[i]) 1440 | } 1441 | alpha := 0.0 // 0.3 => leakyrelu 1442 | init_pow 
:= 6.0 1443 | mid_pow := 5.0 1444 | final_pow := 6.0 1445 | 1446 | for name_iter := st; name_iter < end; name_iter++ { 1447 | weight_num := 10 1448 | fmt.Println("Start ", name_iter, "-th iter..") 1449 | 1450 | raw_input := readTxt(ker_name+"_data/test_image_"+strconv.Itoa(name_iter)+".csv", raw_in_wids[0]*raw_in_wids[0]*real_batch[0]) 1451 | input := make([]float64, in_wids[0]*in_wids[0]*real_batch[0]) 1452 | for i := 0; i < raw_in_wids[0]; i++ { 1453 | for j := 0; j < raw_in_wids[0]; j++ { 1454 | for b := 0; b < real_batch[0]; b++ { 1455 | input[i*in_wids[0]*real_batch[0]+j*real_batch[0]+b] = raw_input[i*raw_in_wids[0]*real_batch[0]+j*real_batch[0]+b] 1456 | } 1457 | } 1458 | } 1459 | fmt.Println("Input: ") 1460 | prt_mat(input, max_batch[0], 1) 1461 | fmt.Println("vec size: ", cont.N) 1462 | fmt.Println("input width: ", raw_in_wids) 1463 | fmt.Println("kernel width: ", ker_wid) 1464 | fmt.Println("num batches: ", real_batch) 1465 | 1466 | start := time.Now() 1467 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values 1468 | cont.encoder.EncodeCoeffs(input, pl_input) 1469 | ct_input := cont.encryptor.EncryptNew(pl_input) 1470 | fmt.Printf("Encryption done in %s \n", time.Since(start)) 1471 | 1472 | timings := make([]float64, 4) 1473 | begin_start := time.Now() 1474 | new_start := time.Now() 1475 | 1476 | // Block 1 1477 | pow := init_pow 1478 | ct_layer := ct_input 1479 | for i := 1; i <= num_blc1; i++ { 1480 | ker_in1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size) 1481 | weight_num++ 1482 | bn_a1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[0]) 1483 | bn_b1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[0]) 1484 | if i == num_blc1 { 1485 | pow = mid_pow 1486 | } 1487 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in1, bn_a1, bn_b1, alpha, pow, in_wids[0], kp_wids[0], ker_wid, real_batch[0], 
real_batch[0], norm[0], 0, 1, iter, log_sparse[0], "Conv_sparse", true, debug) 1488 | fmt.Println("Block1, Layer ", i, "done!") 1489 | } 1490 | fmt.Println("Block1 done!") 1491 | timings[0] = time.Since(new_start).Seconds() 1492 | new_start = time.Now() 1493 | 1494 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[1]*ker_size) 1495 | weight_num++ 1496 | bn_a12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1]) 1497 | bn_b12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1]) 1498 | 1499 | ker_in12_0 := make([]float64, len(ker_in12)/2) 1500 | ker_in12_1 := make([]float64, len(ker_in12)/2) 1501 | for k := 0; k < ker_size; k++ { 1502 | for i := 0; i < real_batch[0]; i++ { 1503 | for j := 0; j < real_batch[1]/2; j++ { 1504 | // ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+j)] // [i][j] 1505 | // ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+real_batch[1]/2+j)] // [i][j+B/2] 1506 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j] 1507 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1] 1508 | 1509 | } 1510 | } 1511 | } 1512 | bn_a12_0 := make([]float64, real_batch[1]/2) 1513 | bn_a12_1 := make([]float64, real_batch[1]/2) 1514 | bn_b12_0 := make([]float64, real_batch[1]/2) 1515 | bn_b12_1 := make([]float64, real_batch[1]/2) 1516 | for i := range bn_b12_0 { 1517 | // bn_a12_0[i] = bn_a12[i] 1518 | // bn_a12_1[i] = bn_a12[i+real_batch[1]/2] 1519 | // bn_b12_0[i] = bn_b12[i] 1520 | // bn_b12_1[i] = bn_b12[i+real_batch[1]/2] 1521 | bn_a12_0[i] = bn_a12[2*i] 1522 | bn_a12_1[i] = bn_a12[2*i+1] 1523 | bn_b12_0[i] = bn_b12[2*i] 1524 | 
bn_b12_1[i] = bn_b12[2*i+1] 1525 | } 1526 | 1527 | // block1 to block 2 1528 | // ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "StrConv", false, false) 1529 | // ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 1, 0, iter, 0, "StrConv", false, false) 1530 | // ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2) 1531 | 1532 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], kp_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, 1, 2, 0, "StrConv_sparse_full", true, debug) 1533 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], kp_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, 1, 2, 0, "StrConv_sparse_full", true, debug) 1534 | 1535 | xi := make([]float64, cont.N) 1536 | xi[2] = 1.0 1537 | xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0) 1538 | cont.encoder.EncodeCoeffs(xi, xi_plain) 1539 | cont.encoder.ToNTT(xi_plain) 1540 | ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain) 1541 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2) 1542 | 1543 | fmt.Println("Block1 to 2 done!") 1544 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_result)) 1545 | // prt_mat_norm(res_tmp, max_batch[1], 1, 4, false) 1546 | timings[1] = time.Since(new_start).Seconds() 1547 | new_start = time.Now() 1548 | 1549 | // Block 2 1550 | for i := 1; i <= num_blc2; i++ { 1551 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size) 1552 | weight_num++ 1553 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1]) 1554 | bn_b2 := 
readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1]) 1555 | if i == num_blc2 { 1556 | pow = final_pow 1557 | } 1558 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], kp_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, 1, iter, log_sparse[1], "Conv_sparse", true, debug) 1559 | fmt.Println("Block2, Layer ", i, "done!") 1560 | } 1561 | fmt.Println("Block2 done!") 1562 | timings[2] = time.Since(new_start).Seconds() 1563 | new_start = time.Now() 1564 | 1565 | // RMFC 1566 | fin_out_batch := 1000 1567 | ker_inf := readTxt(weight_dir+"fc.csv", real_batch[1]*fin_out_batch) 1568 | bn_af := make([]float64, real_batch[1]*2) 1569 | if ker_wid == 3 { 1570 | for i := range bn_af { 1571 | bn_af[i] = 1.0 / (7 * 7) // for reduce mean on 8*8 elements 1572 | } 1573 | } else { 1574 | for i := range bn_af { 1575 | bn_af[i] = 1.0 / (6 * 6) // for reduce mean on 8*8 elements 1576 | } 1577 | } 1578 | bn_bf := make([]float64, real_batch[1]*2) 1579 | for i := range bn_bf { 1580 | bn_bf[i] = 0.0 //10.0 * float64(i) 1581 | } 1582 | ker_inf_ := make([]float64, 7*7*real_batch[1]*fin_out_batch) 1583 | for b := 0; b < 7*7; b++ { 1584 | for i := 0; i < real_batch[1]; i++ { 1585 | for j := 0; j < fin_out_batch; j++ { 1586 | ker_inf_[b*real_batch[1]*fin_out_batch+i*fin_out_batch+j] = ker_inf[i*fin_out_batch+j] 1587 | } 1588 | } 1589 | } 1590 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], 7, real_batch[1], fin_out_batch, 1, float64(1<<30), false) 1591 | fmt.Println("Final FC done.") 1592 | timings[3] = time.Since(new_start).Seconds() 1593 | new_start = time.Now() 1594 | 1595 | cont.decryptor.Decrypt(ct_result, pl_input) 1596 | res_tmp := cont.encoder.DecodeCoeffs(pl_input) 1597 | fmt.Printf("Decryption done in %s \n", time.Since(new_start)) 1598 | final_result := prt_mat_one_norm(res_tmp, max_batch[1], 1, 4, 4) 1599 | writeTxt(ker_name+"_enc_result/enc_result_"+strconv.Itoa(name_iter)+".csv", 
// testConv_BL_in runs the baseline ("BL") slot-encoded convolution test
// total_test_num times and compares each decrypted result with a reference
// read from test_conv_data/.
//
// Author's notes:
//   - BaseLine Conv without boot; assumes full batch with power-of-two in_wid & N.
//   - Normal Conv without output modification (e.g., trimming or expanding).
//   - Input does not need padding.
//   - The imaginary part is used to save (pack) data before boot.
//
// Parameters:
//   - real_batch: number of channels; they are split into two halves of
//     real_batch/2, each half living in its own ciphertext.
//   - in_wid: width of the (padded) input grid.
//   - ker_wid: kernel width; pad = ker_wid/2.
//   - total_test_num: number of test cases to run (files are indexed 0..n-1).
//   - boot: when true, the two result ciphertexts are packed into one complex
//     ciphertext, bootstrapped, unpacked and passed through ReLU, and the
//     "_reluout_" reference is used instead of "_out_".
func testConv_BL_in(real_batch, in_wid, ker_wid, total_test_num int, boot bool) {
	// in_kind is fixed to "Conv" in this function, so the guard below and the
	// "TransConv" branch further down never fire here.
	in_kind := "Conv"
	test_dir := "test_conv_data/"
	if (in_kind != "TransConv") && (in_kind != "Conv") && (in_kind != "StrConv") {
		panic("Wrong in_kind!")
	}
	pad := ker_wid / 2
	// Width of the data actually stored in the input files.
	// NOTE(review): the previous comment here read "= in_wid", which only
	// holds when pad == 0 — confirm the intended relation.
	raw_in_wid := in_wid - pad

	in_size := in_wid * in_wid
	ker_size := ker_wid * ker_wid
	slots := real_batch / 2 * in_size
	// Smallest log_slots with 2^log_slots >= slots.
	log_slots := 0
	for ; (1 << log_slots) < slots; log_slots++ {
	}
	out_batch := real_batch
	if in_kind == "TransConv" {
		out_batch = real_batch / 4
	}
	kp_wid := 0
	kind := "BL_" + in_kind
	in_batch := real_batch

	// generate Context: params, Keys, rotations, general plaintexts
	cont := newContext(log_slots+1, ker_wid, []int{in_wid}, []int{kp_wid}, boot, kind)
	fmt.Println("vec size: log2 = ", cont.logN)
	fmt.Println("raw input width: ", raw_in_wid)
	fmt.Println("kernel width: ", ker_wid)
	fmt.Println("num batches in & out: ", real_batch, ", ", out_batch)

	for test_iter := 0; test_iter < total_test_num; test_iter++ {
		fmt.Println(test_iter+1, "-th iter...start")
		// Test vectors for this iteration: input, kernel, BN scale (a), BN
		// shift (b) and the expected output. The second readTxt argument is
		// presumably the expected element count (readTxt is defined elsewhere).
		input := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_in_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
		ker_in := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_ker_"+strconv.Itoa(test_iter)+".csv", in_batch*in_batch*ker_wid*ker_wid)
		bn_a := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bna_"+strconv.Itoa(test_iter)+".csv", in_batch)
		bn_b := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bnb_"+strconv.Itoa(test_iter)+".csv", in_batch)
		var real_out []float64
		if boot {
			real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_reluout_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
		} else {
			real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_out_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
		}

		// Split the channels into two halves and place each raw_in_wid x
		// raw_in_wid image in the top-left of a zero-filled in_wid x in_wid
		// grid: pad_input1 gets channels [0, real_batch/2), pad_input2 gets
		// channels [real_batch/2, real_batch).
		pad_input1 := make([]float64, in_wid*in_wid*real_batch/2)
		pad_input2 := make([]float64, in_wid*in_wid*real_batch/2)
		for i := 0; i < raw_in_wid; i++ {
			for j := 0; j < raw_in_wid; j++ {
				for b := 0; b < real_batch/2; b++ {
					pad_input1[b+j*real_batch/2+i*real_batch/2*in_wid] = input[b+j*real_batch+i*real_batch*raw_in_wid]
					pad_input2[b+j*real_batch/2+i*real_batch/2*in_wid] = input[b+real_batch/2+j*real_batch+i*real_batch*raw_in_wid]
				}
			}
		}

		// Split the BN coefficients by output half. zeros is used as the shift
		// for the second partial convolution below, so bn_b is added only once
		// per output ciphertext.
		bn_a_sep := make([][]float64, 2)
		bn_b_sep := make([][]float64, 2)
		zeros := make([]float64, real_batch/2)
		for out := 0; out < 2; out++ {
			bn_a_sep[out] = make([]float64, real_batch/2)
			bn_b_sep[out] = make([]float64, real_batch/2)
			for i := 0; i < real_batch/2; i++ {
				bn_a_sep[out][i] = bn_a[i+out*real_batch/2]
				bn_b_sep[out][i] = bn_b[i+out*real_batch/2]
			}
		}

		// Cut the kernel into a 2x2 grid of sub-kernels indexed by
		// [output half][input half].
		ker_in_sep := make([][][]float64, 2) // number of output ctxts
		for out := 0; out < 2; out++ {
			ker_in_sep[out] = make([][]float64, 2) // number of input ctxts
			for in := 0; in < 2; in++ {
				ker_in_sep[out][in] = make([]float64, len(ker_in)/(2*2))
				for k := 0; k < ker_size; k++ {
					for i := 0; i < real_batch/2; i++ { // in
						for j := 0; j < real_batch/2; j++ { // out
							ker_in_sep[out][in][k*real_batch*real_batch/4+i*real_batch/2+j] =
								ker_in[k*real_batch*real_batch+(i+in*real_batch/2)*real_batch+out*real_batch/2+j] // source: [k][i + in*B/2][j + out*B/2]
						}
					}
				}
			}
		}

		input1_rs := reshape_input_BL(pad_input1, in_wid)
		input2_rs := reshape_input_BL(pad_input2, in_wid)
		start := time.Now()
		ct_input1 := cont.encryptor.EncryptNew(cont.encoder.EncodeAtLvlNew(cont.ECD_LV, input1_rs, cont.logN-1))
		ct_input2 := cont.encryptor.EncryptNew(cont.encoder.EncodeAtLvlNew(cont.ECD_LV, input2_rs, cont.logN-1))
		fmt.Printf("Encryption done in %s \n", time.Since(start))

		// Each output half is the sum of the partial convolutions over both
		// input halves.
		start_eval := time.Now()
		ct_res := make([]*ckks.Ciphertext, 2)
		for pos := 0; pos < 2; pos++ {
			ct_res[pos] = cont.evaluator.AddNew(evalConv_BN_BL_test(cont, ct_input1, ker_in_sep[pos][0], bn_a_sep[pos], bn_b_sep[pos], in_wid, ker_wid, real_batch/2, real_batch/2, 0, 1, pad, false, false),
				evalConv_BN_BL_test(cont, ct_input2, ker_in_sep[pos][1], bn_a_sep[pos], zeros, in_wid, ker_wid, real_batch/2, real_batch/2, 0, 1, pad, false, false))
		}
		fmt.Printf("Evaluation total done in %s \n", time.Since(start_eval))

		if boot {
			// Pack both results into one ciphertext: ct + conj(ct) for each
			// half, the second half additionally multiplied by i, then summed.
			img_eval := time.Now()
			for pos := 0; pos < 2; pos++ {
				ct_res[pos] = cont.evaluator.AddNew(cont.pack_evaluator.ConjugateNew(ct_res[pos]), ct_res[pos])
				if pos == 1 {
					ct_res[pos] = cont.evaluator.MultByiNew(ct_res[pos])
				}
				// fmt.Println("before Boot: LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
			}
			ct_res[0] = cont.evaluator.AddNew(ct_res[0], ct_res[1])
			img_eval_part := time.Since(img_eval)

			pos := 0
			alpha := 0.0
			pow := 4.0
			// Only the Scale field is changed here; no evaluator operation is run.
			ct_res[pos].Scale = ct_res[pos].Scale * math.Pow(2, pow+2) // +1 for 1/2 * (a + conj(a)), +1 for 1/2 after boot
			// vals_preB := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[pos]), cont.logN-1)
			fmt.Println("\n ========= Bootstrapping... (original) ========= ")
			start_boot := time.Now()
			// fmt.Println("initial (before boot): LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
			ct_boot := cont.btp.Bootstrapp(ct_res[pos])
			fmt.Printf("Boot Done in %s \n", time.Since(start_boot))

			// fmt.Println("after Boot: LV = ", ct_boot.Level(), " Scale = ", math.Log2(ct_boot.Scale))

			// Only for checking the correctness (for Boot)
			// vals_postB := printDebug(cont.params, ct_boot, vals_preB, cont.decryptor, cont.encoder)
			// vals_relu := make([]complex128, len(vals_postB))
			// for i, elt := range vals_postB {
			// 	vals_relu[i] = complex((math.Max(0, real(elt))+math.Min(0, real(elt)*alpha))*math.Pow(2, pow), 0)
			// }
			img_eval = time.Now()
			// Multiply by an all-ones plaintext whose scale is chosen from
			// ct_boot.Scale, then rescale to the default scale.
			// NOTE(review): the modulus indices 13 and 14 are hard-coded —
			// presumably tied to the level returned by bootstrapping for the
			// parameters built in newContext; confirm.
			pl_scale := ckks.NewPlaintext(cont.params, ct_boot.Level(), math.Pow(2, 30)*float64(cont.params.Q()[14])*float64(cont.params.Q()[13])/ct_boot.Scale)
			val_scale := make([]complex128, cont.N/2)
			for i := range val_scale {
				val_scale[i] = complex(1.0, 0) // val_scale[i] = complex(1.0/math.Pow(2, pow), 0)
			}
			cont.encoder.EncodeNTT(pl_scale, val_scale, cont.logN-1)
			cont.evaluator.Mul(ct_boot, pl_scale, ct_boot)
			cont.evaluator.Rescale(ct_boot, cont.params.Scale(), ct_boot)

			// Unpack: ct + conj(ct) goes to ct_res[0], (ct - conj(ct)) / i
			// goes to ct_res[1].
			ct_iboot := cont.pack_evaluator.ConjugateNew(ct_boot)
			ct_res[0] = cont.evaluator.AddNew(ct_boot, ct_iboot)
			ct_res[1] = cont.evaluator.DivByiNew(cont.evaluator.SubNew(ct_boot, ct_iboot))
			fmt.Printf("Imaginary packing and unpacking done in %s \n", img_eval_part+time.Since(img_eval))

			// ReLU on each half, then multiply by 2^pow and reset the scale
			// to the default.
			start = time.Now()
			for pos := 0; pos < 2; pos++ {
				// fmt.Println("before relu: LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
				// fmt.Println("after Rescale: LV = ", ct_boot.Level(), " Scale = 2^", math.Log2(ct_boot.Scale))
				ct_res[pos] = evalReLU(cont.params, cont.evaluator, ct_res[pos], alpha)
				cont.evaluator.MulByPow2(ct_res[pos], int(pow), ct_res[pos])
				cont.evaluator.SetScale(ct_res[pos], cont.params.Scale())
				// printDebug(cont.params, ct_res[pos], vals_relu, cont.decryptor, cont.encoder)
			}
			fmt.Printf("Relu Done in %s \n", time.Since(start))
		}

		// fmt.Println()
		// fmt.Println("=============== DECRYPTION ===============")
		// fmt.Println()
		start = time.Now()
		vals_tmp1 := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[0]), log_slots)
		vals_tmp2 := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[1]), log_slots)
		fmt.Printf("Decryption Done in %s \n", time.Since(start))

		// Trim each half, concatenate them, post-process, and compare with
		// the reference output.
		test_out := post_trim_BL(vals_tmp1, raw_in_wid, in_wid)
		test_out = append(test_out, post_trim_BL(vals_tmp2, raw_in_wid, in_wid)...)
		test_out = post_process_BL(test_out, raw_in_wid)

		printDebugCfsPlain(test_out, real_out)
	}
}