├── README.md
├── compare_final.py
├── conv.go
├── eval.go
├── main.go
├── rot_util.go
├── test.go
├── test_BL.go
└── test_run
/README.md:
--------------------------------------------------------------------------------
1 | # Source code for **Optimized Privacy-Preserving CNN Inference with Fully Homomorphic Encryption**
2 |
3 | ## Requirements
4 | 1. Go 1.16.6 or higher
5 | - Install command with apt-get:
6 | ```console
7 | apt-get install golang-go
8 | ```
9 | 2. Go packages (Go Cryptography \& Lattigo (fork))
10 | - After installing Go, install the packages with the following commands:
11 | ```console
12 | go get -u golang.org/x/crypto/...
13 | go get -u github.com/dwkim606/test_lattigo
14 | ```
15 | **CAUTION**: For Lattigo, you must install the forked version with the above command (instead of the latest [one](https://github.com/tuneinsight/lattigo))
16 |
17 | 3. Python3 with numpy package (required only for checking the precision of CNN classifier)
18 |
19 | ## Running the Test
20 |
21 | 0. Dataset Preparation: **Necessary** to run the tests
22 |
23 | - Download the data file from the link: https://drive.google.com/drive/folders/1zLTzJ58E_CDtqvnPv8t9YtgkDaHouWWn?usp=sharing
24 | - Move all folders (Resnet_enc_results, Resnet_plain_data, Resnet_weights, test_conv_data) to the same directory as the source code.
25 |
26 |
27 | The following tests are available:
28 |
29 | 1. Convolutions: run Baseline and Ours for various numbers of batches
30 | - Arguments (in the order of input):
31 | - conv
32 | - (3,5,7): width of kernel
33 | - (0,1,2,3): set number of batches to 4, 16, 64, 256, respectively.
34 | - (1 to 10): number of test runs
35 | - Example command: convolution with kernel width 3, number of batches 16, and 5 test runs
36 | ```console
37 | go run *.go conv 3 1 5
38 | ```
39 |
40 | 2. Convolutions followed by ReLU evaluation (and Bootstrapping): run Baseline and Ours
41 | - Arguments:
42 | - convReLU
43 | - other parts are the same as Convolutions
44 | - Example command: convolution with kernel width 5, number of batches 4, and 3 test runs, then ReLU evaluation with Bootstrapping
45 | ```console
46 | go run *.go convReLU 5 0 3
47 | ```
48 | 3. 20-layer CNN evaluation with our method on the CIFAR10/CIFAR100 datasets
49 | - Arguments:
50 | - resnet
51 | - (3,5,7): width of kernel
52 | - (8,14,20): number of layers of CNN
53 | - (1,2,3): wideness factor
54 | - (1 to 100 or 1 to 1000): number of tests
55 | - (true, false): true -> CIFAR100, false -> CIFAR10
56 | - **List of available arguments**: (given in the paper; other arguments require appropriate weights for CNN)
57 | - resnet 3 20 1 (1 to 1000) (true/false)
58 | - resnet 5 20 1 (1 to 100) (true/false)
59 | - resnet 7 20 1 (1 to 100) (true/false)
60 | - resnet 5 8 3 (1 to 1000) (true/false)
61 | - resnet 3 14 3 (1 to 1000) (true/false)
62 | - resnet 3 20 3 (1 to 1000) (true/false)
63 | - Example command: run CNN with kernel width 3, number of layers 20, wideness factor 1, on 10 test inputs from the CIFAR10 dataset
64 | ```console
65 | go run *.go resnet 3 20 1 10 false
66 | ```
67 | **CAUTION**: The CNN evaluation test requires at most roughly 100GB of memory.
68 | - To check the precision of encrypted inference, run compare_final.py with python3 and the arguments (kernel width, depth, wideness factor, CIFAR100 or not)
69 | - Example command: check the precision of CNN inference with kernel width 3, depth 20, and wideness factor 1 on the CIFAR10 dataset
70 | ```console
71 | python3 compare_final.py 3 20 1 false
72 | ```
73 | **CAUTION**: Precision check requires the encrypted inference to be performed beforehand.
74 |
75 | ## MISC.
76 | One can also build an executable for a remote Linux server with an x86-64 (amd64) CPU via the following command.
77 | (Adjust GOOS and GOARCH appropriately for other operating systems and CPUs.)
78 | ```console
79 | env GOOS=linux GOARCH=amd64 go build -o test_run
80 | ```
81 | It will generate an executable titled "test_run", which can run on the server with appropriate arguments.
82 | (e.g., for convolution)
83 | ```console
84 | ./test_run conv 3 1 5
85 | ```
86 | This way, the tests can be run on a server without Go installed, provided the executable was compiled on another machine that has Go.
87 | The executable must be in the same directory as data folders (Resnet_enc_results, Resnet_plain_data, Resnet_weights, test_conv_data).
88 |
--------------------------------------------------------------------------------
/compare_final.py:
--------------------------------------------------------------------------------
import os
import sys
from statistics import mean, stdev

import numpy as np

try:  # distutils was removed from the standard library in Python 3.12
    from distutils.util import strtobool
except ImportError:
    strtobool = None  # kept only for backward compatibility on Python < 3.12
6 |
# for resnet, compare plain with enc
def compare_results(ker, depth, wid, cf100):
    """Compare encrypted ResNet inference results with plaintext predictions.

    Loads the plaintext predictions and ground-truth labels of the model
    selected by (ker, depth, wid, cf100). Then, for every encrypted result file
    that exists, it counts how often the encrypted argmax matches the plaintext
    argmax and the true label. The three counts are printed, followed by the
    logits of every sample where encrypted and plaintext predictions disagree.

    Args:
        ker: kernel width (part of the directory and file names).
        depth: number of CNN layers (part of the directory names).
        wid: wideness factor (part of the directory names).
        cf100: truthy selects CIFAR100 (100 classes), falsy CIFAR10 (10 classes).
    """
    model_tag = 'crop_ker' + str(ker) + '_d' + str(depth) + '_wid' + str(wid)
    if cf100:
        dataset_tag, num_classes = 'cf100_', 100
    else:
        dataset_tag, num_classes = '', 10
    plain_folder_dir = 'Resnet_plain_data/' + dataset_tag + model_tag
    enc_result_dir = 'Resnet_enc_results/results_' + dataset_tag + model_tag + '/'

    # Only the width-1 models with kernel width 5 or 7 ship 100 samples.
    max_num_samples = 100 if (wid == 1 and ker != 3) else 1000

    plain_pred = np.reshape(
        np.loadtxt(os.path.join(plain_folder_dir, 'plain_prediction_' + str(max_num_samples) + '.csv')),
        [max_num_samples, num_classes])
    true_pred = np.reshape(
        np.loadtxt(os.path.join(plain_folder_dir, 'test_labels_' + str(max_num_samples) + '.csv')),
        [max_num_samples])

    agree = enc_correct = plain_correct = total = 0
    mismatches = {}
    result_prefix = enc_result_dir + 'class_result_ker' + str(ker) + '_'

    for sample in range(max_num_samples):
        result_path = result_prefix + str(sample) + '.csv'
        if not os.path.exists(result_path):
            continue  # this sample was not evaluated under encryption
        raw = np.loadtxt(result_path)
        total += 1

        enc_logits = raw[:num_classes]
        enc_label = np.argmax(enc_logits)
        plain_label = np.argmax(plain_pred[sample])
        if enc_label == plain_label:
            agree += 1
        else:
            mismatches[str(sample)] = [enc_logits, plain_pred[sample], true_pred[sample]]
        if enc_label == true_pred[sample]:
            enc_correct += 1
        if plain_label == true_pred[sample]:
            plain_correct += 1

    print("Plain precision: ", plain_correct, "/", total)
    print("Enc precision: ", enc_correct, "/", total)
    print("plain vs enc accordance: ", agree, "/", total)
    print("\n wrong results: \n")
    for key, (enc_logits, plain_logits, label) in mismatches.items():
        print(key, "-th iter.")
        print("enc: ", enc_logits, "argmax: ", np.argmax(enc_logits))
        print("plain: ", plain_logits, "argmax: ", np.argmax(plain_logits), "\n")
        print("true: ", label, " \n")
77 |
78 |
### main ###

def str_to_bool(val):
    """Parse a command-line truth value.

    Accepts the same spellings as the former ``distutils.util.strtobool``
    (removed together with distutils in Python 3.12): y/yes/t/true/on/1 map to
    True and n/no/f/false/off/0 map to False. Matching is case-insensitive and
    ignores surrounding whitespace.

    Raises:
        ValueError: if ``val`` is not one of the accepted spellings.
    """
    norm = val.strip().lower()
    if norm in ("y", "yes", "t", "true", "on", "1"):
        return True
    if norm in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (val,))


if __name__ == "__main__":
    # Arguments: kernel width, depth, wideness factor, CIFAR100 flag.
    if len(sys.argv) != 5:
        sys.exit("usage: python3 compare_final.py <ker> <depth> <wide> <cf100: true/false>")
    ker = int(sys.argv[1])
    depth = int(sys.argv[2])
    wide = int(sys.argv[3])
    cf100 = str_to_bool(sys.argv[4])

    print("ker: ", ker, " depth: ", depth, " wide: ", wide, " cf100? ", bool(cf100))
    compare_results(ker, depth, wide, cf100)
89 |
90 |
--------------------------------------------------------------------------------
/conv.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "time"
7 |
8 | "github.com/dwkim606/test_lattigo/ckks"
9 | "github.com/dwkim606/test_lattigo/rlwe"
10 | )
11 |
// reverseBits returns the lowest bitwid bits of num in reversed order,
// e.g. reverseBits(0b001, 3) == 0b100. Bits of num at position bitwid or
// higher are discarded.
func reverseBits(num uint32, bitwid int) uint32 {
	// Left-align the bitwid-bit field, then mirror the full 32-bit word.
	src := num << (32 - bitwid)
	var out uint32
	for n := 0; n < 32; n++ {
		out = out<<1 | src&1
		src >>= 1
	}
	return out
}
25 |
26 | // output the slice with bitreversed order from input
27 | func reverseOrder(input []float64, bitwid int) []float64 {
28 | out := make([]float64, len(input))
29 | for i := range out {
30 | out[i] = input[reverseBits(uint32(i), bitwid)]
31 | }
32 |
33 | return out
34 | }
35 |
// reshape_conv_out extracts the top-left prt_wid x prt_wid window of the
// first out_num channels from the output of our conv algorithm.
//
// result is channel-interleaved: result[ch + stride*(row*in_wid + col)], with
// in_wid = 2*prt_wid (the output includes padding) and
// stride = len(result)/(in_wid*in_wid). The returned slice uses the same
// interleaving with out_num channels: out[ch + out_num*(row*prt_wid + col)].
func reshape_conv_out(result []float64, prt_wid, out_num int) []float64 {
	out := make([]float64, prt_wid*prt_wid*out_num)
	inWid := 2 * prt_wid
	stride := len(result) / (inWid * inWid)

	for row := 0; row < prt_wid; row++ {
		for col := 0; col < prt_wid; col++ {
			srcBase := stride * (row*inWid + col)
			dstBase := out_num * (row*prt_wid + col)
			for ch := 0; ch < out_num; ch++ {
				out[dstBase+ch] = result[srcBase+ch]
			}
		}
	}
	return out
}
53 |
// reshape_input_BL converts a channel-interleaved input (as exported from
// Python: input[(row*in_wid+col)*channels + ch]) into channel-major order,
// out[ch*in_wid*in_wid + row*in_wid + col]. It stores each value as a complex
// number with zero imaginary part. channels = len(input)/(in_wid*in_wid).
// Only used by the baseline (BL) test.
func reshape_input_BL(input []float64, in_wid int) (out []complex128) {
	plane := in_wid * in_wid
	out = make([]complex128, len(input))
	channels := len(input) / plane

	for pix := 0; pix < plane; pix++ {
		for ch := 0; ch < channels; ch++ {
			out[ch*plane+pix] = complex(input[pix*channels+ch], 0)
		}
	}
	return out
}
73 |
// reshape_ker_BL converts a flat kernel exported from Python (H, W, in, out
// order) for the baseline (BL) test. It returns max_ker_rs[i][j][r][c], a
// max_bat x max_bat grid for each kernel tap (i, j). The tap weight from input
// channel ib to output channel ob, multiplied by BN_a[ob] (batch-norm scale),
// is placed at [norm*ib][norm*ob]; every other entry is zero.
//
// norm == 1: channels are densely packed; norm == 4: channels occupy every
// 4th slot, i.e. batches look like (1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0).
//
// trans == true reads the kernel as a transposed convolution. Taps are
// flipped in both spatial axes, the source input channel is 4*ib+pos, and
// only ib < inB/4 is filled (the rest stay zero).
// NOTE(author): transposed conv may still need an extra rearrangement.
func reshape_ker_BL(input, BN_a []float64, ker_wid, inB, outB, max_bat, pos, norm int, trans bool) (max_ker_rs [][][][]float64) {
	// weight returns the BN-scaled weight for tap (i, j), input ib, output ob.
	weight := func(i, j, ib, ob int) float64 {
		if !trans {
			return input[ob+ib*outB+j*outB*inB+i*outB*inB*ker_wid] * BN_a[ob]
		}
		if ib >= inB/4 {
			return 0
		}
		flipI, flipJ := ker_wid-i-1, ker_wid-j-1
		return input[(4*ib+pos)+ob*inB+flipJ*outB*inB+flipI*outB*inB*ker_wid] * BN_a[ob]
	}

	max_ker_rs = make([][][][]float64, ker_wid)
	for i := range max_ker_rs {
		max_ker_rs[i] = make([][][]float64, ker_wid)
		for j := range max_ker_rs[i] {
			grid := make([][]float64, max_bat)
			for r := range grid {
				grid[r] = make([]float64, max_bat)
			}
			for ib := 0; ib < inB; ib++ {
				for ob := 0; ob < outB; ob++ {
					grid[norm*ib][norm*ob] = weight(i, j, ib, ob)
				}
			}
			max_ker_rs[i][j] = grid
		}
	}
	return max_ker_rs
}
117 |
118 | // rotate input ciphertext and outputs rotated ciphertexts
119 | // always assume Odd ker_wid
120 | func preConv_BL(evaluator ckks.Evaluator, ct_in *ckks.Ciphertext, in_wid, ker_wid int) (ct_in_rots []*ckks.Ciphertext) {
121 | ker_size := ker_wid * ker_wid
122 | ct_in_rots = make([]*ckks.Ciphertext, ker_size)
123 |
124 | st := -(ker_wid / 2)
125 | end := ker_wid / 2
126 |
127 | var rotations []int
128 | for i := st; i <= end; i++ {
129 | for j := st; j <= end; j++ {
130 | rotations = append(rotations, i*in_wid+j)
131 | }
132 | }
133 | ct_rots_test := evaluator.RotateHoisted(ct_in, rotations)
134 | k := 0
135 | for i := st; i <= end; i++ {
136 | for j := st; j <= end; j++ {
137 | ct_in_rots[k] = ct_rots_test[i*in_wid+j]
138 | k++
139 | }
140 | }
141 |
142 | return ct_in_rots
143 | }
144 |
145 | // eval Convolution for the part of output: need to sum this up with rotations
146 | func postConv_BL(param ckks.Parameters, encoder ckks.Encoder, evaluator ckks.Evaluator, ct_in_rots []*ckks.Ciphertext, in_wid, ker_wid, rot, pad int, max_ker_rs [][][][]float64) (ct_out *ckks.Ciphertext) {
147 |
148 | max_batch := param.Slots() / (in_wid * in_wid)
149 | postKer := make([]complex128, param.Slots())
150 | pl_tmp := ckks.NewPlaintext(param, ct_in_rots[0].Level(), param.Scale())
151 |
152 | iter := 0
153 | for i := 0; i < ker_wid; i++ {
154 | for j := 0; j < ker_wid; j++ {
155 | for k := 0; k < max_batch; k++ {
156 | for ki := 0; ki < (in_wid - pad); ki++ { // position of input
157 | for kj := 0; kj < (in_wid - pad); kj++ {
158 | postKer[k*in_wid*in_wid+ki*in_wid+kj] = complex(max_ker_rs[i][j][k][(k-rot+max_batch)%max_batch], 0)
159 | if (((ki + i - (ker_wid / 2)) < 0) || ((ki + i - (ker_wid / 2)) >= (in_wid - pad))) || (((kj + j - (ker_wid / 2)) < 0) || ((kj + j - (ker_wid / 2)) >= (in_wid - pad))) {
160 | postKer[k*in_wid*in_wid+ki*in_wid+kj] = complex(0, 0)
161 | }
162 | }
163 | }
164 | }
165 | encoder.Encode(pl_tmp, postKer, param.LogSlots())
166 | encoder.ToNTT(pl_tmp)
167 | if (i == 0) && (j == 0) {
168 | ct_out = evaluator.MulNew(ct_in_rots[iter], pl_tmp)
169 | } else {
170 | ct_tmp := evaluator.MulNew(ct_in_rots[iter], pl_tmp)
171 | evaluator.Add(ct_out, ct_tmp, ct_out)
172 | }
173 | iter++
174 | }
175 | }
176 |
177 | return ct_out
178 | }
179 |
// reshape_ker reorders a flat kernel exported from Python into one row per
// output channel. ker_out has shape (out_batch, in_batch*k_sz) with
// in_batch = len(ker_in)/(k_sz*out_batch). Row o holds the k_sz taps for
// input channel 0, then the k_sz taps for input channel 1, and so on.
//
// Normal:     ker_out[o][c*k_sz+t]          = ker_in[o + c*out_batch + t*out_batch*in_batch]
// Transposed: ker_out[o][c*k_sz+(k_sz-1-t)] = ker_in[c + o*in_batch  + t*out_batch*in_batch]
// (the tap order is reversed and the channel roles in the source index swap).
func reshape_ker(ker_in []float64, k_sz, out_batch int, trans bool) (ker_out [][]float64) {
	ker_out = make([][]float64, out_batch)
	in_batch := len(ker_in) / (k_sz * out_batch)
	tapStride := out_batch * in_batch

	for o := range ker_out {
		row := make([]float64, k_sz*in_batch)
		for c := 0; c < in_batch; c++ {
			for t := 0; t < k_sz; t++ {
				if trans {
					row[c*k_sz+(k_sz-t-1)] = ker_in[c+o*in_batch+t*tapStride]
				} else {
					row[c*k_sz+t] = ker_in[o+c*out_batch+t*tapStride]
				}
			}
		}
		ker_out[o] = row
	}
	return ker_out
}
203 |
// encode_ker_final builds the coefficient vector encoding the kernels of
// output channel i (row i of ker_in, laid out as by reshape_ker). Multiplying
// it with the packed input polynomial evaluates that channel's convolution.
//
// in_wid and in_batch describe the input to be convolved, padding included;
// the vector length is vec_size = in_wid*in_wid*in_batch. pos selects which
// block of ker_wid*ker_wid*in_batch entries of ker_in[i] is used.
//
// Tap k of input channel j is written, with channel and tap order reversed, to
// coefficient (in_wid*(k/ker_wid) + k%ker_wid)*in_batch + j. This makes
// channel 0 and channel in_batch-1 accumulate at position in_batch-1.
// The vector is then shifted left by
// adj = (in_batch-1) + in_batch*(in_wid+1)*(ker_wid-1)/2 positions. Wrapped
// coefficients are negated (negacyclic shift), so the convolution result
// appears at position 0.
func encode_ker_final(ker_in [][]float64, pos, i, in_wid, in_batch, ker_wid int) []float64 {
	vec_size := in_wid * in_wid * in_batch
	k_sz := ker_wid * ker_wid
	bias := pos * k_sz * in_batch

	placed := make([]float64, vec_size)
	for j := 0; j < in_batch; j++ { // j-th input channel
		for k := 0; k < k_sz; k++ {
			placed[(in_wid*(k/ker_wid)+k%ker_wid)*in_batch+j] = ker_in[i][(in_batch-1-j)*k_sz+(k_sz-1-k)+bias]
		}
	}

	// Negacyclic left shift by adj. Coefficient t takes the value from
	// position t+adj and is negated once for every wrap past vec_size. Unlike
	// an in-place shuffle, this stays in range for any adj >= 0, including
	// 2*adj > vec_size.
	adj := (in_batch - 1) + in_batch*(in_wid+1)*(ker_wid-1)/2
	output := make([]float64, vec_size)
	for t := range output {
		src := t + adj
		v := placed[src%vec_size]
		if (src/vec_size)%2 == 1 {
			v = -v
		}
		output[t] = v
	}
	return output
}
238 |
239 | // Generate the logN # of plaintexts idx[i] = X^(2^i) and GaloisKeys for each
240 | // Required for Packing
241 | func gen_idxNlogs(ECD_LV int, keygen rlwe.KeyGenerator, sk *rlwe.SecretKey, encoder ckks.Encoder, params ckks.Parameters) (idx []*ckks.Plaintext, pack_eval ckks.Evaluator) {
242 | logN := params.LogN()
243 | N := params.N()
244 | gals := []uint64{}
245 | idx = make([]*ckks.Plaintext, logN)
246 | coeffs := make([]float64, N)
247 |
248 | for i := 0; i < logN; i++ {
249 | coeffs[1< 1; i /= 2 {
282 | logStep++
283 | }
284 | j := params.LogN() - logStep
285 |
286 | for step >= norm {
287 | for i := 0; i < step; i += norm {
288 | tmp1 = pack_eval.MulNew(ctxts[i+step], idx[logStep])
289 | tmp2 = pack_eval.SubNew(ctxts[i], tmp1)
290 | pack_eval.Add(ctxts[i], tmp1, tmp1)
291 | pack_eval.RotateGal(tmp2, (1< set the same as python
487 | func prep_Ker(params ckks.Parameters, encoder ckks.Encoder, ker_in, BN_a []float64, in_wid, ker_wid, real_ib, real_ob, norm, ECD_LV, pos int, trans bool) (pl_ker []*ckks.Plaintext) {
488 | max_bat := params.N() / (in_wid * in_wid)
489 | ker_size := ker_wid * ker_wid
490 | ker_rs := reshape_ker(ker_in, ker_size, real_ob, trans) // ker1[i][j] = j-th kernel for i-th output
491 |
492 | for i := 0; i < real_ob; i++ { // apply batch normalization
493 | for j := range ker_rs[i] {
494 | ker_rs[i][j] = ker_rs[i][j] * BN_a[i]
495 | }
496 | }
497 |
498 | max_ker_rs := make([][]float64, max_bat) // overloading ker_rs to the case with max_batch
499 | for i := 0; i < max_bat; i++ {
500 | max_ker_rs[i] = make([]float64, max_bat*ker_size)
501 | }
502 | for i := 0; i < real_ob; i++ {
503 | for j := 0; j < real_ib; j++ {
504 | for k := 0; k < ker_size; k++ {
505 | max_ker_rs[norm*i][norm*j*ker_size+k] = ker_rs[i][j*ker_size+k]
506 | }
507 | }
508 | }
509 |
510 | pl_ker = make([]*ckks.Plaintext, max_bat)
511 | for i := 0; i < max_bat; i++ {
512 | pl_ker[i] = ckks.NewPlaintext(params, ECD_LV, params.Scale())
513 | encoder.EncodeCoeffs(encode_ker_final(max_ker_rs, pos, i, in_wid, max_bat, ker_wid), pl_ker[i])
514 | encoder.ToNTT(pl_ker[i])
515 | }
516 |
517 | return pl_ker
518 | }
519 |
520 | // Eval Conv, then Pack
521 | // The ciphertexts must be packed into full (without vacant position)
522 | func conv_then_pack(params ckks.Parameters, pack_evaluator ckks.Evaluator, ctxt_in *ckks.Ciphertext, pl_ker []*ckks.Plaintext, plain_idx []*ckks.Plaintext, max_ob, norm, ECD_LV int, out_scale float64) *ckks.Ciphertext {
523 | start := time.Now()
524 | ctxt_out := make([]*ckks.Ciphertext, max_ob)
525 | for i := 0; i < max_ob; i++ {
526 | if i%norm == 0 {
527 | ctxt_out[i] = pack_evaluator.MulNew(ctxt_in, pl_ker[i])
528 | pack_evaluator.SetScale(ctxt_out[i], out_scale/(float64(max_ob/norm)))
529 | }
530 | // pack_evaluator.Rescale(ctxt_out[i], float64(1<<10), ctxt_out[i])
531 | }
532 | mt := time.Since(start)
533 | fmt.Println("\t mult time: ", mt)
534 | ctxt_result := pack_ctxts(pack_evaluator, ctxt_out, max_ob, max_ob/norm, plain_idx, params)
535 | fmt.Println("\t Pack time: ", time.Since(start)-mt)
536 |
537 | // fmt.Println("Result Scale: ", math.Log2(ctxt_result.Scale))
538 | // fmt.Println("Result LV: ", ctxt_result.Level())
539 | // fmt.Printf("Done in %s \n", time.Since(start))
540 |
541 | if (out_scale != ctxt_result.Scale) || (0 != ctxt_result.Level()) {
542 | panic("LV or scale after conv then pack, inconsistent")
543 | }
544 |
545 | return ctxt_result
546 | }
547 |
--------------------------------------------------------------------------------
/eval.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "time"
7 |
8 | "github.com/dwkim606/test_lattigo/ckks"
9 | )
10 |
// set_Variables derives test parameters for our convolutions (not the BL ones).
//
// It returns:
//   - logN: the smallest exponent with 1<<logN >= batch*in_wid*in_wid.
//   - kp_wid: the width of valid data kept after the convolution of the
//     given kind.
//   - out_batch: the number of output batches.
//   - trans: whether the kind is a transposed convolution.
//
// kp_wid is capped at in_wid - (ker_wid-1)/2 for "Conv", the "StrConv*"
// variants and "TransConv". When the cap is exceeded, the function prints the
// limit and panics. An unknown kind also panics.
func set_Variables(batch, raw_in_wid, in_wid, ker_wid int, kind string) (kp_wid, out_batch, logN int, trans bool) {
	total := batch * in_wid * in_wid
	logN = 0
	for 1<<logN < total {
		logN++
	}

	limit := in_wid - ((ker_wid - 1) / 2) // largest kp_wid that fits
	enforce := func(shown int) {
		if kp_wid > limit {
			fmt.Println("max raw_in_wid: ", shown)
			panic("too large raw_in_wid.")
		}
	}

	switch kind {
	case "Conv":
		kp_wid, out_batch = raw_in_wid, batch
		enforce(limit)
	case "StrConv", "StrConv_fast", "StrConv_odd":
		kp_wid, out_batch = 2*(in_wid/2-ker_wid/2), batch
		enforce(limit)
	case "StrConv_inside":
		kp_wid, out_batch = in_wid/2-ker_wid/2, batch
	case "TransConv":
		trans = true
		kp_wid, out_batch = 2*raw_in_wid, batch/4
		enforce(limit / 2)
	default:
		panic("Wrong kinds!")
	}
	return kp_wid, out_batch, logN, trans
}
55 |
56 | // apply rotation for strided conv (compress) or transposed conv (extend)
57 | // the same rotation for all batches; use BSGS to reduce rotations
58 | // assume that input batches are well-ordered. compress: (0,4) (1,5) (2,6) (3,7) to (0,1,2,...6,7) extend: (0,2,4,6,1,3,5,7) to (0,1) (2,3) (4,5) (6,7)
59 | // rotation for each batch position (0 to 3) is applied after or before compress or extend, resp.
60 | // total rotation = 2*in_wid*4 + (4-1); depth = 2
61 | func evalRot_BL(cont *context, ct_input *ckks.Ciphertext, in_wid, pos int, trans bool) (ct_res *ckks.Ciphertext) {
62 | if trans {
63 | in_size := in_wid * in_wid
64 | cont.evaluator.Rotate(ct_input, pos*in_size, ct_input)
65 | ct_res = bsgs_ctxt(cont.evaluator, cont.encoder, ct_input, cont.m_idx[in_wid][0], cont.r_idx[in_wid][0], cont.params)
66 | } else {
67 | out_size := in_wid * in_wid / 4
68 | ct_res = bsgs_ctxt(cont.evaluator, cont.encoder, ct_input, cont.m_idx[in_wid][0], cont.r_idx[in_wid][0], cont.params)
69 | cont.evaluator.Rotate(ct_res, -pos*out_size, ct_res)
70 | }
71 | return
72 | }
73 |
74 | // Eval Conv only, always assume max batch
75 | // in_wid must be Po2 (also include padding), includes kernel preparation
76 | // norm == 1 : normal case, norm == 4 : in & out batch is (1,0,0,0,2,0,0,0,3,0,0,0,4,0,0,0)
77 | // for test, use pack_evaluator optimizer
78 | func evalConv_BN_BL_test(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, in_wid, ker_wid, real_ib, real_ob, pos, norm, pad int, trans, printResult bool) (ct_res *ckks.Ciphertext) {
79 | in_size := in_wid * in_wid
80 | out_size := in_size
81 | max_batch := cont.N / (2 * in_size)
82 |
83 | // fmt.Println()
84 | // fmt.Println("=============== (KER) PREPARATION ===============")
85 | // fmt.Println()
86 | start := time.Now()
87 | max_ker_rs := reshape_ker_BL(ker_in, bn_a, ker_wid, real_ib, real_ob, max_batch, pos, norm, trans)
88 | scale_exp := cont.params.Scale() * cont.params.Scale()
89 | if trans {
90 | scale_exp = cont.params.Scale() * cont.params.Scale() * cont.params.Scale()
91 | }
92 | bn_b_slots := make([]complex128, cont.N/2)
93 | for i, elt := range bn_b {
94 | for j := 0; j < in_wid-pad; j++ {
95 | for k := 0; k < in_wid-pad; k++ {
96 | bn_b_slots[j+k*in_wid+norm*out_size*i] = complex(elt, 0)
97 | }
98 | }
99 | }
100 |
101 | pl_bn_b := ckks.NewPlaintext(cont.params, cont.ECD_LV, scale_exp)
102 | cont.encoder.EncodeNTT(pl_bn_b, bn_b_slots, cont.logN-1)
103 | fmt.Printf("Plaintext (kernel) preparation, Done in %s \n", time.Since(start))
104 |
105 | // fmt.Println()
106 | // fmt.Println("=============== EVALUATION ===============")
107 | // fmt.Println()
108 | start = time.Now()
109 | ct_inputs_rots := preConv_BL(cont.pack_evaluator, ct_input, in_wid, ker_wid)
110 | fmt.Printf("preConv done in %s \n", time.Since(start))
111 |
112 | var rot_iters int
113 | if norm*real_ob == max_batch {
114 | rot_iters = real_ob
115 | } else {
116 | rot_iters = max_batch
117 | }
118 | for i := 0; i < rot_iters; i++ {
119 | ct_tmp := postConv_BL(cont.params, cont.encoder, cont.pack_evaluator, ct_inputs_rots, in_wid, ker_wid, norm*i, pad, max_ker_rs)
120 | if i == 0 {
121 | ct_res = ct_tmp
122 | } else {
123 | cont.evaluator.Add(ct_res, cont.pack_evaluator.RotateNew(ct_tmp, norm*i*out_size), ct_res)
124 | }
125 | }
126 |
127 | if ct_res.Scale != scale_exp {
128 | panic("Different scale between pl_bn_b and ctxt")
129 | }
130 | cont.evaluator.Add(ct_res, pl_bn_b, ct_res)
131 | fmt.Printf("Conv (with BN) Done in %s \n", time.Since(start))
132 |
133 | return ct_res
134 | }
135 |
136 | // reduce mean and final FC layer (in_batch -> 16)
137 | // assume that ct_input has batch (1,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,4,0,0,0,0,0,0,0)
138 | // ker_fc is of size in_batch*10 and 1-dim from [64,10] shape
139 | func evalRMFC_BL(cont *context, ct_input *ckks.Ciphertext, ker_fc, bias []float64, printResult bool) (ct_res *ckks.Ciphertext) {
140 | rs_ker := make([][]float64, 64)
141 | for i := 0; i < 64; i++ {
142 | rs_ker[i] = make([]float64, 16)
143 | for j := 0; j < 10; j++ {
144 | rs_ker[i][j] = ker_fc[j+i*10] / 64.0 // we will add 64 elts instead of averaging them
145 | }
146 | }
147 |
148 | // sum 64 elements instead of averaging them
149 | ct_avg := ct_input
150 | for i := 1; i < 64; i *= 2 {
151 | ct_avg = cont.evaluator.AddNew(ct_avg, cont.evaluator.RotateNew(ct_avg, i))
152 | }
153 |
154 | for i := 0; i < 16; i++ {
155 | tmp := make([]complex128, cont.N/2)
156 | for j := 0; j < 64; j++ {
157 | tmp[j*64*8] = complex(rs_ker[j][(j%16+16-i)%16], 0)
158 | }
159 | pl_ker := cont.encoder.EncodeNTTAtLvlNew(ct_avg.Level(), tmp, cont.logN-1)
160 |
161 | if i == 0 {
162 | ct_res = cont.evaluator.MulNew(ct_avg, pl_ker)
163 | } else {
164 | ct_tmp := cont.evaluator.MulNew(ct_avg, pl_ker)
165 | cont.evaluator.Add(ct_res, cont.evaluator.RotateNew(ct_tmp, i*64*8), ct_res)
166 | }
167 | }
168 |
169 | // final rotations to add up (4 = 64/16)
170 | for i := 1; i < 4; i *= 2 {
171 | ct_res = cont.evaluator.AddNew(ct_res, cont.evaluator.RotateNew(ct_res, i*16*64*8))
172 | }
173 |
174 | tmp := make([]complex128, cont.N/2)
175 | for j := 0; j < 10; j++ {
176 | tmp[j*64*8] = complex(bias[j], 0)
177 | }
178 | pl_bias := cont.encoder.EncodeNTTAtLvlNew(ct_res.Level(), tmp, cont.logN-1)
179 | cont.evaluator.Add(ct_res, pl_bias, ct_res)
180 |
181 | return
182 | }
183 |
184 | // reduce mean and final FC layer
185 | // assume that ct_input has full batch (1,2,3,4,...)
186 | // ker_fc is of size in_batch*out and 1-dim from [in_batch,out] shape
187 | func evalRMFC_BL_img(cont *context, ct_input *ckks.Ciphertext, ker_fc []float64, in_batch, out_num, raw_in_wid int, printResult bool) (ct_res *ckks.Ciphertext) {
188 | rs_ker := make([][]float64, in_batch)
189 | for i := 0; i < in_batch; i++ {
190 | rs_ker[i] = make([]float64, out_num)
191 | for j := 0; j < out_num; j++ {
192 | rs_ker[i][j] = ker_fc[j+i*out_num] / float64(raw_in_wid*raw_in_wid) // we will add 64 elts instead of averaging them
193 | }
194 | }
195 |
196 | // sum 64 elements instead of averaging them (but only 49 elts are non-zero)
197 | ct_avg := ct_input
198 | for i := 1; i < 64; i *= 2 {
199 | ct_avg = cont.evaluator.AddNew(ct_avg, cont.evaluator.RotateNew(ct_avg, i))
200 | }
201 |
202 | for i := 0; i < in_batch; i++ {
203 | tmp := make([]complex128, cont.N/2)
204 | for j := 0; j < out_num; j++ {
205 | tmp[(i+j)%in_batch*64] = complex(rs_ker[(i+j)%in_batch][j], 0)
206 | }
207 | pl_ker := cont.encoder.EncodeNTTAtLvlNew(ct_avg.Level(), tmp, cont.logN-1)
208 |
209 | if i == 0 {
210 | ct_res = cont.evaluator.MulNew(ct_avg, pl_ker)
211 | } else {
212 | ct_tmp := cont.evaluator.MulNew(ct_avg, pl_ker)
213 | cont.evaluator.Add(ct_res, cont.evaluator.RotateNew(ct_tmp, i*64), ct_res)
214 | }
215 | }
216 |
217 | return
218 | }
219 |
220 | // Eval Conv only, always assume max batch
221 | // in_wid must be Po2 (also include padding),
222 | // include kernel preparation
223 | // norm = 2 : in&out batches are (1,0,2,0,3,0,...)
224 | func evalConv_BN(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, in_wid, ker_wid, real_ib, real_ob, norm int, out_scale float64, trans bool) (ct_res *ckks.Ciphertext) {
225 | max_batch := cont.N / (in_wid * in_wid)
226 |
227 | // fmt.Println()
228 | // fmt.Println("=============== (KER) PREPARATION ===============")
229 | // fmt.Println()
230 | start := time.Now()
231 | pl_ker := prep_Ker(cont.params, cont.encoder, ker_in, bn_a, in_wid, ker_wid, real_ib, real_ob, norm, cont.ECD_LV, 0, trans)
232 | // fmt.Printf("for prep_ker %s \n", time.Since(start))
233 | b_coeffs := make([]float64, cont.N)
234 | for i := range bn_b {
235 | for j := 0; j < in_wid*in_wid; j++ {
236 | b_coeffs[norm*i+j*max_batch] = bn_b[i]
237 | }
238 | }
239 | // scale_exp := ct_input.Scale * cont.params.Scale() * float64(max_batch/norm)
240 | pl_bn_b := ckks.NewPlaintext(cont.params, 0, out_scale)
241 | // pl_bn_b := ckks.NewPlaintext(cont.params, cont.ECD_LV, scale_exp) // contain plaintext values
242 | cont.encoder.EncodeCoeffs(b_coeffs, pl_bn_b)
243 | cont.encoder.ToNTT(pl_bn_b)
244 | fmt.Printf("Plaintext (kernel) preparation, Done in %s \n", time.Since(start))
245 |
246 | // fmt.Println()
247 | // fmt.Println("=============== EVALUATION ===============")
248 | // fmt.Println()
249 |
250 | start = time.Now()
251 | ct_res = conv_then_pack(cont.params, cont.pack_evaluator, ct_input, pl_ker, cont.pl_idx, max_batch, norm, cont.ECD_LV, out_scale)
252 | if (pl_bn_b.Scale != ct_res.Scale) || (ct_res.Level() != 0) {
253 | fmt.Println("plain scale: ", pl_bn_b.Scale)
254 | fmt.Println("ctxt scale: ", ct_res.Scale)
255 | fmt.Println("ctxt lv: ", ct_res.Level())
256 | panic("LV or scale after conv then pack, inconsistent")
257 | }
258 | cont.evaluator.Add(ct_res, pl_bn_b, ct_res) // for Batch Normalization (BN)
259 |
260 | fmt.Printf("Conv (with BN) Done in %s \n", time.Since(start))
261 |
262 | return ct_res
263 | }
264 |
265 | // Eval Conv, BN, relu with Boot
266 | // in_wid must be Po2 (also include padding)
267 | // stride = true: apply [1,2,2,1] stride; false: [1,1,1,1]
268 | // pack_pos: position to pack (0,1,2,3): only for strided case
269 | // real_ib, real_ob: real number of batches (less or equal than max_batch)
270 | // step: step of the output (used for conv_inside only)
271 | // log_sparse: 0 if no full slot, 1 if half slot, etc (maybe the same as norm[])
272 | func evalConv_BNRelu_new(cont *context, ct_input *ckks.Ciphertext, ker_in, bn_a, bn_b []float64, alpha, pow float64, in_wid, kp_wid, ker_wid, real_ib, real_ob, norm, pack_pos, step, iter, log_sparse int, kind string, fast_pack, debug bool) (ct_res *ckks.Ciphertext) {
273 | // iter := 2 // for full packing (contrary to half packing)
274 | var trans, stride, odd, inside bool
275 | odd = false
276 | trans = false
277 | stride = false
278 | inside = false
279 | sparse := false
280 | in_step := step
281 | modify_ker := false
282 | full := false
283 | switch kind {
284 | case "Conv_sparse": // sparse pack, normal conv
285 | sparse = true
286 | case "StrConv_sparse": // sparse pack, strided conv, 2 convs -> add them -> 1 boot
287 | modify_ker = true
288 | sparse = true
289 | stride = true
290 | case "StrConv_sparse_full": // sparse pack but full pack
291 | sparse = true
292 | modify_ker = true
293 | stride = true
294 | full = true
295 | case "Conv_inside":
296 | inside = true
297 | case "StrConv_inside":
298 | in_step = step / 2
299 | if step%2 != 0 {
300 | panic("step can not be divided by 2 (for strided conv)")
301 | }
302 | inside = true
303 | case "StrConv", "StrConv_fast":
304 | stride = true
305 | case "StrConv_odd":
306 | stride = true
307 | odd = true
308 | case "TransConv":
309 | trans = true
310 | case "Conv":
311 | default:
312 | panic("No kind!")
313 | }
314 |
315 | if odd { // multiply x^{offset} before conv to move input to appropriate position.
316 | odd_time := time.Now()
317 | var offset int // offset depends on the real_wid = in_wid-ker_wid/2 is even or not
318 | if (in_wid-ker_wid/2)%2 == 0 {
319 | offset = 0
320 | } else {
321 | offset = cont.N / (in_wid * in_wid) * (in_wid + 1)
322 | // offset = real_ib * norm * (in_wid + 1)
323 | }
324 | // fmt.Println("offset: ", offset)
325 | xi := make([]float64, cont.N)
326 | xi[offset] = 1.0
327 | xi_plain := ckks.NewPlaintext(cont.params, cont.ECD_LV, 1.0)
328 | cont.encoder.EncodeCoeffs(xi, xi_plain)
329 | cont.encoder.ToNTT(xi_plain)
330 | ct_input = cont.evaluator.MulNew(ct_input, xi_plain)
331 | fmt.Printf("for odd stride, offset time %s \n", time.Since(odd_time))
332 | }
333 |
334 | var ct_conv *ckks.Ciphertext
335 | if modify_ker {
336 | if !full {
337 | bn_a_0 := make([]float64, real_ib)
338 | bn_a_1 := make([]float64, real_ib)
339 | bn_b_0 := make([]float64, real_ib)
340 | bn_b_1 := make([]float64, real_ib)
341 | for i := range bn_b_0 {
342 | bn_a_0[i] = bn_a[2*i]
343 | bn_a_1[i] = bn_a[2*i+1]
344 | bn_b_0[i] = bn_b[2*i]
345 | bn_b_1[i] = bn_b[2*i+1]
346 | }
347 | ker_in_0 := make([]float64, len(ker_in)/2)
348 | ker_in_1 := make([]float64, len(ker_in)/2)
349 | for k := 0; k < ker_wid*ker_wid; k++ {
350 | for i := 0; i < real_ib; i++ {
351 | for j := 0; j < real_ob/2; j++ {
352 | ker_in_0[k*real_ib*real_ob/2+(i*real_ob/2+j)] = ker_in[k*real_ib*real_ob+(i*real_ob+2*j)] // [i][2*j]
353 | ker_in_1[k*real_ib*real_ob/2+(i*real_ob/2+j)] = ker_in[k*real_ib*real_ob+(i*real_ob+2*j+1)] // [i][2*j+1]
354 | }
355 | }
356 | }
357 | ct_result1 := evalConv_BN(cont, ct_input, ker_in_0, bn_a_0, bn_b_0, in_wid, ker_wid, real_ib, real_ob/2, norm/2, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
358 | ct_result2 := evalConv_BN(cont, ct_input, ker_in_1, bn_a_1, bn_b_1, in_wid, ker_wid, real_ib, real_ob/2, norm/2, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
359 |
360 | xi := make([]float64, cont.N)
361 | offset := norm / 4 // cont.N / norm
362 | xi[offset] = 1.0
363 | xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0)
364 | cont.encoder.EncodeCoeffs(xi, xi_plain)
365 | cont.encoder.ToNTT(xi_plain)
366 | ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain)
367 |
368 | ct_conv = cont.evaluator.AddNew(ct_result1, ct_result2)
369 |
370 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
371 | max_batch := cont.N / (in_wid * in_wid)
372 | // prt_mat_norm_step(res_tmp, max_batch, norm/4, step, 1, 4, false)
373 |
374 | for i := range xi {
375 | xi[i] = 0.0
376 | }
377 | if (in_wid-ker_wid/2)%2 != 0 {
378 | xi[0] = 1.0
379 | } else {
380 | // fmt.Println("offset nonzero!")
381 | offset = cont.N - (max_batch)*(in_wid+1)
382 | xi[offset] = -1.0
383 | }
384 | xi_plain = ckks.NewPlaintext(cont.params, ct_conv.Level(), 1.0)
385 | cont.encoder.EncodeCoeffs(xi, xi_plain)
386 | cont.encoder.ToNTT(xi_plain)
387 | ct_conv = cont.evaluator.MulNew(ct_conv, xi_plain)
388 | // res_tmp = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
389 | // fmt.Println("After offset: ")
390 | // prt_mat_norm_step(res_tmp, max_batch, norm/4, step, 1, 4, false)
391 | } else { // need to cover the case with full packing
392 | ct_conv = evalConv_BN(cont, ct_input, ker_in, bn_a, bn_b, in_wid, ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
393 |
394 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
395 | max_batch := cont.N / (in_wid * in_wid)
396 | // prt_mat_norm_step(res_tmp, max_batch, norm, step, 1, 4, false)
397 | xi := make([]float64, cont.N)
398 | for i := range xi {
399 | xi[i] = 0.0
400 | }
401 | var offset int
402 | if (in_wid-ker_wid/2)%2 != 0 {
403 | xi[0] = 1.0
404 | } else {
405 | // fmt.Println("offset nonzero!")
406 | offset = cont.N - (max_batch)*(in_wid+1)
407 | xi[offset] = -1.0
408 | }
409 | xi_plain := ckks.NewPlaintext(cont.params, ct_conv.Level(), 1.0)
410 | cont.encoder.EncodeCoeffs(xi, xi_plain)
411 | cont.encoder.ToNTT(xi_plain)
412 | ct_conv = cont.evaluator.MulNew(ct_conv, xi_plain)
413 | // res_tmp = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
414 | // fmt.Println("After offset: ")
415 | // prt_mat_norm_step(res_tmp, max_batch, norm, step, 1, 4, false)
416 | }
417 | } else {
418 | if inside {
419 | new_ker_wid := ker_wid*in_step - in_step + 1
420 | new_ker_in := make([]float64, len(ker_in)*new_ker_wid*new_ker_wid/(ker_wid*ker_wid))
421 |
422 | for i := 0; i < ker_wid; i++ {
423 | for j := 0; j < ker_wid; j++ {
424 | for ib := 0; ib < real_ib; ib++ {
425 | for ob := 0; ob < real_ob; ob++ {
426 | new_ker_in[in_step*i*new_ker_wid*real_ib*real_ob+(in_step*j)*real_ib*real_ob+ib*real_ob+ob] = ker_in[i*ker_wid*real_ib*real_ob+j*real_ib*real_ob+ib*real_ob+ob]
427 | }
428 | }
429 | }
430 | }
431 | ct_conv = evalConv_BN(cont, ct_input, new_ker_in, bn_a, bn_b, in_wid, new_ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
432 | } else {
433 | ct_conv = evalConv_BN(cont, ct_input, ker_in, bn_a, bn_b, in_wid, ker_wid, real_ib, real_ob, norm, math.Exp2(math.Round(math.Log2(float64(cont.params.Q()[0]))-(pow+8))), trans)
434 | }
435 | }
436 |
437 | ct_conv.Scale = ct_conv.Scale * math.Pow(2, pow)
438 |
439 | // Only for checking the correctness (for CtoS)
440 | var slot1, slot2 []complex128
441 | var cfs_preB []float64
442 | if debug {
443 | cfs_preB = cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_conv))
444 | }
445 | fmt.Println("Bootstrapping... Ours (until CtoS):")
446 | start := time.Now()
447 | ct_boots := make([]*ckks.Ciphertext, 2)
448 | switch log_sparse {
449 | case 0:
450 | ct_boots[0], ct_boots[1], _ = cont.btp.BootstrappConv_CtoS(ct_conv)
451 | case 1:
452 | ct_boots[0], ct_boots[1], _ = cont.btp2.BootstrappConv_CtoS(ct_conv)
453 | case 2:
454 | ct_boots[0], ct_boots[1], _ = cont.btp3.BootstrappConv_CtoS(ct_conv)
455 | case 3:
456 | ct_boots[0], ct_boots[1], _ = cont.btp4.BootstrappConv_CtoS(ct_conv)
457 | case 4:
458 | ct_boots[0], ct_boots[1], _ = cont.btp5.BootstrappConv_CtoS(ct_conv)
459 | default:
460 | panic("No cases for log_sparse")
461 | }
462 |
463 | fmt.Printf("Done in %s \n", time.Since(start))
464 | // fmt.Println("after Boot (CtoS): LV = ", ct_boots[0].Level(), " Scale = ", math.Log2(ct_boots[0].Scale))
465 |
466 | if debug {
467 | slot1, slot2 = debugCtoS(cont, cfs_preB, log_sparse)
468 | slot1 = printDebug(log_sparse, cont.params, ct_boots[0], slot1, cont.decryptor, cont.encoder) // Compare before & after CtoS
469 | slot2 = printDebug(log_sparse, cont.params, ct_boots[1], slot2, cont.decryptor, cont.encoder) // Compare before & after CtoS
470 | }
471 |
472 | start = time.Now()
473 | for ul := 0; ul < iter; ul++ { // up & low parts
474 | if ct_boots[ul] != nil {
475 | ct_boots[ul] = evalReLU(cont.params, cont.evaluator, ct_boots[ul], alpha)
476 | cont.evaluator.MulByPow2(ct_boots[ul], int(pow), ct_boots[ul])
477 | }
478 | }
479 | fmt.Printf("ReLU Done in %s \n", time.Since(start))
480 | start = time.Now()
481 |
482 | // Only for checking the correctness (for ReLU)
483 | var cfs_postB []float64
484 | if debug {
485 | fmt.Println("after Relu: ", math.Log2(ct_boots[0].Scale), "lv: ", ct_boots[0].Level())
486 | relu1, relu2 := debugReLU(cont, slot1, slot2, alpha, pow)
487 | relu1 = printDebug(log_sparse, cont.params, ct_boots[0], relu1, cont.decryptor, cont.encoder)
488 | relu2 = printDebug(log_sparse, cont.params, ct_boots[1], relu2, cont.decryptor, cont.encoder)
489 | cfs_postB = debugStoC(cont, relu1, relu2, in_wid, kp_wid, pack_pos, step, log_sparse, kind, fast_pack)
490 | }
491 |
492 | ct_keep := make([]*ckks.Ciphertext, iter) // for extend (rotation) of ctxt_in
493 | for ul := 0; ul < iter; ul++ {
494 | if trans {
495 | ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx[in_wid][ul], cont.params)
496 | } else if stride {
497 | if sparse { // we will use ext_double to reduce rotations; hence similar to fast_pack case
498 | if ct_boots[ul] != nil {
499 | if ul == 0 {
500 | ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx[in_wid][pack_pos], cont.r_idx[in_wid][pack_pos], cont.params)
501 | } else {
502 | ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx_l[in_wid][pack_pos], cont.r_idx_l[in_wid][pack_pos], cont.params)
503 | }
504 | } else {
505 | ct_keep[ul] = nil
506 | }
507 | } else {
508 | if fast_pack {
509 | if ul == 0 {
510 | ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx[in_wid][pack_pos], cont.r_idx[in_wid][pack_pos], cont.params)
511 | } else {
512 | ct_keep[ul] = ext_double_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.m_idx_l[in_wid][pack_pos], cont.r_idx_l[in_wid][pack_pos], cont.params)
513 | }
514 | } else {
515 | if ul == 0 {
516 | ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx[in_wid][pack_pos], cont.params)
517 | } else {
518 | ct_keep[ul] = ext_ctxt(cont.evaluator, cont.encoder, ct_boots[ul], cont.r_idx_l[in_wid][pack_pos], cont.params)
519 | }
520 | }
521 | }
522 | } else if inside {
523 | if ct_boots[ul] != nil {
524 | if sparse {
525 | ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[in_wid][ul])
526 | } else {
527 | ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[step][ul])
528 | }
529 | } else {
530 | ct_keep[ul] = nil
531 | }
532 | } else {
533 | if ct_boots[ul] != nil {
534 | ct_keep[ul] = keep_ctxt(cont.params, cont.evaluator, cont.encoder, ct_boots[ul], cont.ext_idx[in_wid][ul])
535 | } else {
536 | ct_keep[ul] = nil
537 | }
538 | }
539 | }
540 |
541 | if iter == 1 {
542 | ct_boots[1] = nil
543 | ct_res = cont.btp.BootstrappConv_StoC(ct_keep[0], ct_boots[1])
544 | if log_sparse != 0 {
545 | panic("we didn't implement this case")
546 | }
547 | } else {
548 | switch log_sparse {
549 | case 0:
550 | ct_res = cont.btp.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
551 | case 1:
552 | ct_res = cont.btp2.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
553 | case 2:
554 | ct_res = cont.btp3.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
555 | case 3:
556 | ct_res = cont.btp4.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
557 | case 4:
558 | ct_res = cont.btp5.BootstrappConv_StoC(ct_keep[0], ct_keep[1])
559 | default:
560 | panic("No cases for log_sparse")
561 | }
562 | }
563 |
564 | cont.evaluator.Rescale(ct_res, cont.params.Scale(), ct_res)
565 | fmt.Printf("Boot (StoC) Done in %s \n", time.Since(start))
566 |
567 | // Only for checking the correctness (for StoC)
568 | if debug {
569 | fmt.Println("Boot out: ")
570 | switch log_sparse {
571 | case 0:
572 | printDebugCfs(cont.params, ct_res, cfs_postB, cont.decryptor, cont.encoder)
573 | case 1:
574 | printDebugCfs(cont.params2, ct_res, cfs_postB, cont.decryptor, cont.encoder)
575 | case 2:
576 | printDebugCfs(cont.params3, ct_res, cfs_postB, cont.decryptor, cont.encoder)
577 | case 3:
578 | printDebugCfs(cont.params4, ct_res, cfs_postB, cont.decryptor, cont.encoder)
579 | case 4:
580 | printDebugCfs(cont.params5, ct_res, cfs_postB, cont.decryptor, cont.encoder)
581 | default:
582 | panic("No cases for log_sparse")
583 | }
584 | max_batch := cont.N / (in_wid * in_wid)
585 | res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_res))
586 | if inside {
587 | start := 1
588 | if ker_wid == 5 {
589 | start = step
590 | }
591 | prt_mat_norm_step(res_tmp, max_batch, norm, step, start, 3, false)
592 | } else {
593 | if stride {
594 | max_batch = 4 * cont.N / (in_wid * in_wid)
595 | start := 1
596 | if ker_wid == 5 {
597 | start = step
598 | }
599 | prt_mat_norm_step(res_tmp, max_batch, norm, step, start, 3, false)
600 | } else {
601 | prt_mat_norm(res_tmp, max_batch, norm, 3, false)
602 | }
603 | }
604 | }
605 |
606 | return ct_res
607 | }
608 |
609 | // log_spars = 0 -> full slot, 1 -> full/2 , ...
610 | func debugCtoS(cont *context, cfs_preB []float64, log_sparse int) (slot1, slot2 []complex128) {
611 | preB_cfs1 := make([]float64, cont.params.Slots())
612 | preB_cfs2 := make([]float64, cont.params.Slots())
613 | slot1 = make([]complex128, cont.params.Slots()/(1< 10) || (i_batch > 3) {
595 | panic("Too many tests (>10) or too many batch index (>3)")
596 | }
597 | case "convReLU":
598 | boot = true
599 | resnet = false
600 | if (num_tests > 10) || (i_batch > 3) {
601 | panic("Too many tests (>10) or too many batch index (>3)")
602 | }
603 | case "resnet":
604 | resnet = true
605 | default:
606 | panic("wrong test type")
607 | }
608 |
609 | if resnet {
610 | // // latest version for resnet crop cifar10
611 | ker_wid, _ := strconv.Atoi(os.Args[2])
612 | depth, _ := strconv.Atoi(os.Args[3])
613 | wide_case, _ := strconv.Atoi(os.Args[4])
614 | test_num, _ := strconv.Atoi(os.Args[5])
615 | cf100, _ := strconv.ParseBool(os.Args[6])
616 |
617 | debug := false // if turned on, it shows all intermediate input
618 | if wide_case == 1 {
619 | // test with small inputs
620 | testResNet_crop_sparse(0, test_num, ker_wid, depth, debug, cf100)
621 | // end test with small inputs
622 | // testResNet_crop_fast_in(0, test_num, ker_wid, depth, debug, cf100)
623 | } else if (wide_case == 2) || (wide_case == 3) {
624 | testResNet_crop_sparse_wide(0, test_num, ker_wid, depth, wide_case, debug, cf100)
625 | // testResNet_crop_sparse_wide_test(0, test_num, ker_wid, depth, wide_case, debug, cf100)
626 | // testResNet_crop_fast_wide_in(0, test_num, ker_wid, depth, wide_case, debug, cf100)
627 | } else {
628 | panic("Wrong wide case!")
629 | }
630 |
631 | } else {
632 | if boot {
633 | fmt.Println("Convolution followed by ReLU (& Bootstrapping) test start!")
634 | } else {
635 | fmt.Println("Convolution test start! (No Bootstrapping)")
636 | }
637 | fmt.Println("Ker: ", ker_wid, "batches: ", batchs[i_batch], "widths: ", widths[i_batch])
638 |
639 | fmt.Println("Base Line start.")
640 | testConv_BL_in(batchs[i_batch], widths[i_batch], ker_wid, num_tests, boot)
641 |
642 | fmt.Println("Ours start.")
643 | testConv_in(batchs[i_batch], widths[i_batch], ker_wid, num_tests, boot)
644 | }
645 | }
646 |
647 | func printDebugCfs(params ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []float64, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []float64) {
648 | total_size := make([]int, 15)
649 |
650 | valuesTest_pre := encoder.DecodeCoeffs(decryptor.DecryptNew(ciphertext))
651 | valuesTest = make([]float64, 2*params.Slots())
652 | step := len(valuesTest_pre) / len(valuesTest) // to cover cases with less slots (N/2 => 1, N/4 => 2, ...)
653 | for i := range valuesTest {
654 | valuesTest[i] = valuesTest_pre[i*step]
655 | }
656 |
657 | fmt.Println("len val Want:", len(valuesWant))
658 | fmt.Println("len val Test:", len(valuesTest))
659 |
660 | fmt.Println()
661 | fmt.Printf("Level: %d (logQ = %d)\n", ciphertext.Level(), params.LogQLvl(ciphertext.Level()))
662 | fmt.Printf("Scale: 2^%f\n", math.Log2(ciphertext.Scale))
663 | fmt.Printf("ValuesTest:")
664 | for i := range total_size {
665 | fmt.Printf("%6.10f, ", valuesTest[i])
666 | }
667 | fmt.Printf("... \n")
668 | fmt.Printf("ValuesWant:")
669 | for i := range total_size {
670 | fmt.Printf("%6.10f, ", valuesWant[i*step])
671 | }
672 | fmt.Printf("... \n")
673 |
674 | valuesTestC := make([]complex128, len(valuesTest))
675 | valuesWantC := make([]complex128, len(valuesWant)/step)
676 |
677 | for i := range valuesTestC {
678 | valuesTestC[i] = complex(valuesTest[i], 0)
679 | valuesWantC[i] = complex(valuesWant[i*step], 0)
680 | }
681 |
682 | precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWantC[:params.Slots()], valuesTestC[:params.Slots()], params.LogSlots(), 0)
683 |
684 | fmt.Println(precStats.String())
685 |
686 | precStats = ckks.GetPrecisionStats(params, encoder, nil, valuesWantC[params.Slots():], valuesTestC[params.Slots():], params.LogSlots(), 0)
687 |
688 | fmt.Println(precStats.String())
689 | fmt.Println()
690 |
691 | return
692 | }
693 |
694 | func printDebugCfsPlain(valuesTest, valuesWant []float64) {
695 | total_size := make([]int, 10)
696 |
697 | fmt.Printf("ValuesTest:")
698 | for i := range total_size {
699 | fmt.Printf("%6.10f, ", valuesTest[i])
700 | }
701 | fmt.Printf("... \n")
702 | fmt.Printf("ValuesWant:")
703 | for i := range total_size {
704 | fmt.Printf("%6.10f, ", valuesWant[i])
705 | }
706 | fmt.Printf("... \n")
707 |
708 | valuesWantC := make([]complex128, len(valuesWant))
709 | valuesTestC := make([]complex128, len(valuesTest))
710 | for i := range valuesWantC {
711 | valuesWantC[i] = complex(valuesWant[i], 0)
712 | valuesTestC[i] = complex(valuesTest[i], 0)
713 | }
714 | precStats := ckks.GetPrecisionStatsPlain(valuesWantC, valuesTestC, len(valuesWantC), 0)
715 | fmt.Println(precStats.String())
716 | fmt.Println()
717 | }
718 |
719 | // decrypt ciphertext then compare with valuesWant, then output the msgs to valuesTest
720 | // log_sparse = 0 -> full slot & TWO ciphertexts
721 | // log_sparse = 1 -> full/2 & ONE ciphertext
722 | func printDebug(log_sparse int, params ckks.Parameters, ciphertext *ckks.Ciphertext, valuesWant []complex128, decryptor ckks.Decryptor, encoder ckks.Encoder) (valuesTest []complex128) {
723 | total_size := make([]int, 15)
724 | if ciphertext == nil {
725 | return nil
726 | // valuesTest = make([]complex128, params.Slots()/(1< in_wid)) && ((j <= show) || (j+show > in_wid))) {
817 | fmt.Printf("(%d, %d): ", i, j)
818 | for b := 0; b < batch; b++ {
819 | tmp[b] = real(vec[in_wid*in_wid*b+(i-1)*in_wid+(j-1)])
820 | }
821 | prt_vec(tmp)
822 | }
823 | }
824 | }
825 | }
826 |
827 | // vec = arrgvec with batch batches, each batch is sqr-sized
828 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : print all
829 | func prt_mat(vec []float64, batch, show int) {
830 | mat_size := len(vec) / batch
831 | row := int(math.Sqrt(float64(mat_size)))
832 | j, k := 1, 1
833 | for i := 0; i < len(vec); i += batch {
834 | if (show == 0) || (((j <= show) || (j > row-show)) && ((k <= show) || (k > (row - show)))) {
835 | fmt.Printf("(%d, %d): ", j, k)
836 | prt_vec(vec[i : i+batch])
837 | }
838 | k++
839 | if k*k > mat_size {
840 | k = 1
841 | j++
842 | }
843 | }
844 | }
845 |
846 | // vec = arrgvec with batch batches, each batch is sqr-sized
847 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : print all
848 | func prt_mat_norm(vec []float64, batch, norm, show int, half bool) {
849 | mat_size := len(vec) / batch
850 | row := int(math.Sqrt(float64(mat_size)))
851 | if half {
852 | row = row / 2
853 | }
854 | tmp := make([]float64, batch/norm)
855 | j, k := 1, 1
856 | for i := 0; i < len(vec); i += batch {
857 | if (show == 0) || (((j <= show) || ((j > row-show) && (j <= row))) && ((k <= show) || ((k > row-show) && (k <= row)))) {
858 | fmt.Printf("(%d, %d): ", j, k)
859 | for idx := range tmp {
860 | tmp[idx] = vec[i+norm*idx]
861 | }
862 | prt_vec(tmp)
863 | }
864 | k++
865 | if k*k > mat_size {
866 | k = 1
867 | j++
868 | }
869 | }
870 | }
871 |
872 | // vec = arrgvec with batch batches, each batch is sqr-sized
873 | // print (i,j)-th position in [batches], only shows (show, show) entries show = 0 : print all
874 | // input is strided with steps, read from start
875 | func prt_mat_norm_step(vec []float64, batch, norm, step, start, show int, half bool) {
876 | mat_size := len(vec) / batch
877 | row := int(math.Sqrt(float64(mat_size)))
878 | if half {
879 | row = row / 2
880 | }
881 | tmp := make([]float64, batch/norm)
882 | j, k := 1, 1
883 | for i := 0; i < len(vec); i += batch {
884 | if (show == 0) || (((j <= show*step) || ((j > row-show*step) && (j <= row))) && ((k <= show*step) || ((k > row-show*step) && (k <= row)))) {
885 | if ((j-start)%step == 0) && ((k-start)%step == 0) {
886 | fmt.Printf("(%d, %d): ", (j-start)/step+1, (k-start)/step+1)
887 | for idx := range tmp {
888 | tmp[idx] = vec[i+norm*idx]
889 | }
890 | prt_vec(tmp)
891 | }
892 | }
893 | k += 1
894 | if k*k > mat_size {
895 | k = 1
896 | j += 1
897 | }
898 | }
899 | }
900 |
// prt_mat_one prints and returns the chunk of `batch` consecutive values stored for
// grid cell (sj, sk) (1-based) of the square grid laid out as in prt_mat.
// The returned slice aliases vec. It returns nil when the cell does not exist.
func prt_mat_one(vec []float64, batch, sj, sk int) (out []float64) {
	cells := len(vec) / batch
	r, c := 1, 1
	for off := 0; off < len(vec); off += batch {
		if r == sj && c == sk {
			chunk := vec[off : off+batch]
			fmt.Print(chunk)
			out = chunk
		}
		c++
		if c*c > cells {
			c = 1
			r++
		}
	}
	return out
}
917 | }
918 |
919 | // only (sj,sk) element in all batches
920 | func prt_mat_one_norm(vec []float64, batch, norm, sj, sk int) (out []float64) {
921 | mat_size := len(vec) / batch
922 | tmp := make([]float64, batch/norm)
923 | j, k := 1, 1
924 | for i := 0; i < len(vec); i += batch {
925 | if (j == sj) && (k == sk) {
926 | for idx := range tmp {
927 | tmp[idx] = vec[i+norm*idx]
928 | }
929 | prt_vec(tmp)
930 | out = tmp
931 | }
932 | k++
933 | if k*k > mat_size {
934 | k = 1
935 | j++
936 | }
937 | }
938 | return out
939 | }
940 |
// prt_mat_one_BL returns the real parts of out_num entries of vec. The entries lie
// 8 blocks apart, where each block holds len(vec)/max_bat values:
// layout (1,0,0,0,0,0,0,0,2,0,0,0,0,0,...) in block units.
func prt_mat_one_BL(vec []complex128, max_bat, out_num int) (out []float64) {
	stride := 8 * (len(vec) / max_bat)
	out = make([]float64, out_num)
	for k := range out {
		out[k] = real(vec[k*stride])
	}
	return out
}
952 |
// prt_mat_one_BL_img returns the real parts of the first entry of each of the
// first out_num blocks of vec. Each block holds len(vec)/max_bat values.
func prt_mat_one_BL_img(vec []complex128, max_bat, out_num int) (out []float64) {
	stride := len(vec) / max_bat
	out = make([]float64, out_num)
	for k := range out {
		out[k] = real(vec[k*stride])
	}
	return out
}
964 |
// check panics with e when e is non-nil and does nothing otherwise.
func check(e error) {
	if e == nil {
		return
	}
	panic(e)
}
970 |
// readTxt reads whitespace-separated floating-point numbers from name_file.
// If size is non-zero, it panics unless exactly size values were read.
// It panics if the file cannot be opened, if a token is not a valid float64,
// or if scanning fails (e.g. an I/O error or a token longer than the scanner
// buffer). Previously such failures silently yielded zeros or truncated data.
func readTxt(name_file string, size int) (input []float64) {
	file, err := os.Open(name_file)
	if err != nil {
		panic(err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	scanner.Split(bufio.ScanWords)
	for scanner.Scan() {
		val, err := strconv.ParseFloat(scanner.Text(), 64)
		if err != nil {
			panic(fmt.Errorf("reading %q: %w", name_file, err))
		}
		input = append(input, val)
	}
	if err := scanner.Err(); err != nil {
		panic(fmt.Errorf("reading %q: %w", name_file, err))
	}

	if (size != 0) && (len(input) != size) {
		panic("input size inconsistent!")
	}

	return input
}
991 |
// writeTxt writes each value of input to name_file, one per line. Values use the
// shortest exponent form that round-trips (strconv 'e' format, precision -1).
// Any existing file is truncated. Failures to create, write, flush or close the
// file terminate the program via log.Fatalf; flush/close errors were previously
// ignored, silently losing data.
func writeTxt(name_file string, input []float64) {
	file, err := os.OpenFile(name_file, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatalf("failed creating file: %s", err)
	}

	datawriter := bufio.NewWriter(file)
	var line []byte // reused formatting buffer
	for _, data := range input {
		line = strconv.AppendFloat(line[:0], data, 'e', -1, 64)
		line = append(line, '\n')
		// bufio.Writer errors are sticky; any write failure is reported by Flush below.
		_, _ = datawriter.Write(line)
	}

	if err := datawriter.Flush(); err != nil {
		file.Close()
		log.Fatalf("failed writing file %s: %s", name_file, err)
	}
	if err := file.Close(); err != nil {
		log.Fatalf("failed closing file %s: %s", name_file, err)
	}
}
1006 |
1007 | func prep_Input(input []float64, raw_in_wid, in_wid, N, norm int, trans, printResult bool) (out []float64) {
1008 | out = make([]float64, N)
1009 | batch := N / (in_wid * in_wid)
1010 | k := 0
1011 |
1012 | if trans {
1013 | for i := 0; i < in_wid/2; i++ {
1014 | for j := 0; j < in_wid/2; j++ {
1015 | for b := 0; b < batch/norm; b++ {
1016 | if (i < raw_in_wid) && (j < raw_in_wid) {
1017 | out[(2*i+1)*in_wid*batch+(2*j+1)*batch+b*norm] = input[k]
1018 | k++
1019 | }
1020 | }
1021 | }
1022 | }
1023 | } else {
1024 | for i := 0; i < in_wid; i++ {
1025 | for j := 0; j < in_wid; j++ {
1026 | for b := 0; b < batch/norm; b++ {
1027 | if (i < raw_in_wid) && (j < raw_in_wid) {
1028 | out[i*in_wid*batch+j*batch+b*norm] = input[k]
1029 | k++
1030 | }
1031 | }
1032 | }
1033 | }
1034 | }
1035 |
1036 | if printResult {
1037 | fmt.Println("Input matrix: ")
1038 | prt_mat(out, batch, 3)
1039 | }
1040 |
1041 | return out
1042 | }
1043 |
// removeDuplicateInt returns the distinct values of intSlice in order of first
// appearance. The result is always a non-nil slice, even for empty input.
func removeDuplicateInt(intSlice []int) []int {
	seen := make(map[int]struct{}, len(intSlice))
	unique := make([]int, 0, len(intSlice))
	for _, v := range intSlice {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		unique = append(unique, v)
	}
	return unique
}
1055 |
// post_process keeps only the valid raw_in_wid x raw_in_wid top-left positions of
// in_cfs. in_cfs is laid out as an in_wid x in_wid grid with len(in_cfs)/(in_wid*in_wid)
// consecutive values per position. The output uses the same position-major layout
// on the smaller grid.
func post_process(in_cfs []float64, raw_in_wid, in_wid int) []float64 {
	batch := len(in_cfs) / (in_wid * in_wid)
	kept := make([]float64, 0, raw_in_wid*raw_in_wid*batch)
	for r := 0; r < raw_in_wid; r++ {
		for c := 0; c < raw_in_wid; c++ {
			src := (r*in_wid + c) * batch
			for b := 0; b < batch; b++ {
				kept = append(kept, in_cfs[src+b])
			}
		}
	}
	return kept
}
1071 |
// post_trim_BL trims each batch plane of in_vals, an in_wid x in_wid grid stored
// batch-major, to its top-left raw_in_wid x raw_in_wid corner (e.g. 8*8 with 1 pad
// -> 7*7). It returns the real parts in batch-major order.
func post_trim_BL(in_vals []complex128, raw_in_wid, in_wid int) []float64 {
	batch := len(in_vals) / (in_wid * in_wid)
	trimmed := make([]float64, 0, raw_in_wid*raw_in_wid*batch)
	for b := 0; b < batch; b++ {
		for r := 0; r < raw_in_wid; r++ {
			rowStart := (b*in_wid + r) * in_wid
			for c := 0; c < raw_in_wid; c++ {
				trimmed = append(trimmed, real(in_vals[rowStart+c]))
			}
		}
	}
	return trimmed
}
1087 |
// post_process_BL converts in_vals from batch-major layout to position-major layout.
// In the input each batch holds a full raw_in_wid x raw_in_wid plane; in the output
// each grid position holds all batch values consecutively.
func post_process_BL(in_vals []float64, raw_in_wid int) []float64 {
	plane := raw_in_wid * raw_in_wid
	batch := len(in_vals) / plane
	out := make([]float64, plane*batch)
	for pos := 0; pos < plane; pos++ {
		for b := 0; b < batch; b++ {
			out[pos*batch+b] = in_vals[b*plane+pos]
		}
	}
	return out
}
1103 |
--------------------------------------------------------------------------------
/rot_util.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | )
7 |
// toFixed rounds num to `precision` decimal digits, with halves rounded away from zero.
func toFixed(num float64, precision int) float64 {
	scale := math.Pow(10, float64(precision))
	rounded := math.Round(num * scale)
	return rounded / scale
}
12 |
// arrgvec scatters input into output, starting at index pos, with a stride of
// len(output)/len(input) between consecutive elements.
func arrgvec(input []int, output []int, pos int) {
	stride := len(output) / len(input)
	for k := range input {
		output[pos+k*stride] = input[k]
	}
}
20 |
21 | func print_vec(title string, input []float64, in_wid int, pos int) {
22 | row := make([]float64, in_wid)
23 | step := len(input) / (in_wid * in_wid)
24 | fmt.Println(title, ": ")
25 | for j := 0; j < in_wid; j++ {
26 | for i := range row {
27 | row[i] = toFixed(input[(j*in_wid+i)*step+pos], 3)
28 | }
29 | fmt.Println(row)
30 | }
31 | fmt.Println()
32 |
33 | }
34 |
// lRot returns a left rotated by rotation positions, so element i of the result is
// a[(i+rotation) % len(a)].
//
// A non-positive rotation returns a itself, unchanged; callers such as
// comprs_vec_sparse rely on this and use rRot for right rotations. An empty slice
// is returned as-is. For positive rotations the result is a freshly allocated
// slice. The input slice and its backing array are never modified: the previous
// append-based loop could write past len(a) into the caller's array and ran in
// O(len(a)*rotation).
func lRot(a []float64, rotation int) []float64 {
	size := len(a)
	if rotation <= 0 || size == 0 {
		return a
	}
	shift := rotation % size
	out := make([]float64, size)
	n := copy(out, a[shift:])
	copy(out[n:], a[:shift])
	return out
}
45 |
46 | func rRot(a []float64, rotation int) []float64 {
47 | return lRot(a, len(a)-rotation)
48 | }
49 |
// addSlice returns the element-wise sum a[i] + b[i] for every index of a.
// b must be at least as long as a.
func addSlice(a []float64, b []float64) []float64 {
	sum := make([]float64, len(a))
	for i, x := range a {
		sum[i] = x + b[i]
	}
	return sum
}
57 |
58 | // (bit-reversed) input vector := (upper or lower part) of the total vector having in_wid * in_wid size elts
59 | // Keep only the kp_wid*kp_wid values
60 | // e.g., 1* // ** -> 10 // 00 (before bitreversed, pad = 1)
61 | // ul: up or low
62 | // assume N/2 sized input
63 | func keep_vec(input []float64, in_wid, kp_wid, ul int) []float64 {
64 | output := make([]float64, len(input))
65 |
66 | tmp := gen_keep_vec(len(input), in_wid, kp_wid, ul)
67 |
68 | for i := range output {
69 | output[i] = input[i] * float64(tmp[i])
70 | }
71 |
72 | return output
73 | }
74 |
75 | func keep_vec_stride(input []float64, in_wid, kp_wid, step, ul int, raw_in_wid_odd bool) []float64 {
76 | output := make([]float64, len(input))
77 |
78 | tmp := gen_keep_vec_stride(len(input), in_wid, kp_wid, step, ul, raw_in_wid_odd)
79 |
80 | for i := range output {
81 | output[i] = input[i] * float64(tmp[i])
82 | }
83 |
84 | return output
85 | }
86 |
87 | func keep_vec_sparse(input []float64, in_wid, kp_wid, log_sparse int) []float64 {
88 | output := make([]float64, len(input))
89 |
90 | tmp := gen_keep_vec_sparse(len(input), in_wid, kp_wid, log_sparse)
91 |
92 | for i := range output {
93 | output[i] = input[i] * float64(tmp[i])
94 | }
95 |
96 | return output
97 | }
98 |
99 | func comprs_vec_sparse(input []float64, in_wid, kp_wid, log_sparse, ul, pos int) []float64 {
100 |
101 | tmp1, tmp2 := gen_comprs_sparse(len(input), in_wid, kp_wid, log_sparse, ul, pos)
102 |
103 | mid_output := make([]float64, len(input))
104 | for i := range tmp1 {
105 | rot_input := make([]float64, len(input))
106 | for j := range rot_input {
107 | rot_input[j] = input[j] * float64(tmp1[i][j])
108 | }
109 | rot_input = lRot(rot_input, i)
110 | if i < 0 {
111 | rot_input = rRot(rot_input, -i)
112 | }
113 |
114 | for j := range mid_output {
115 | mid_output[j] = rot_input[j] + mid_output[j]
116 | }
117 | }
118 |
119 | output := make([]float64, len(input))
120 | for i := range tmp2 {
121 | rot_input := make([]float64, len(input))
122 | for j := range rot_input {
123 | rot_input[j] = mid_output[j] * float64(tmp2[i][j])
124 | }
125 | rot_input = lRot(rot_input, i)
126 | if i < 0 {
127 | rot_input = rRot(rot_input, -i)
128 | }
129 |
130 | for j := range output {
131 | output[j] = rot_input[j] + output[j]
132 | }
133 | }
134 |
135 | return output
136 | }
137 |
138 | // returns the idx for keep_vec
139 | // N: length of input (upper + lower)
140 | // ul = 0 -> upper part, ul = 1 -> lower part
141 | func gen_keep_vec(vec_size, in_wid, kp_wid, ul int) (idx []int) {
142 | logN := 0
143 | for ; (1 << logN) < (2 * vec_size); logN++ {
144 | }
145 | idx = make([]int, vec_size)
146 | batch := 2 * vec_size / (in_wid * in_wid)
147 | if kp_wid < in_wid/2 {
148 | panic("keep width too small. less than in_wid/2")
149 | }
150 |
151 | if ul == 0 {
152 | for i := 0; i < in_wid/2; i++ {
153 | for j := 0; j < kp_wid; j++ {
154 | for b := 0; b < batch; b++ {
155 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b), logN-1))
156 | idx[id] = 1
157 | }
158 | }
159 | }
160 | } else if ul == 1 {
161 | for i := 0; i < kp_wid-in_wid/2; i++ {
162 | for j := 0; j < kp_wid; j++ {
163 | for b := 0; b < batch; b++ {
164 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b), logN-1))
165 | idx[id] = 1
166 | }
167 | }
168 | }
169 | } else {
170 | panic("ul not 0 nor 1")
171 | }
172 |
173 | return idx
174 | }
175 |
176 | // returns the idx for keep_vec given that we have sparse pack with log_sparse (=0 if full, 1 if half sparse pack)
177 | // Always assume log_sparse >= 1 so that both up and down part is in one ciphertext
178 | // N: length of input (upper + lower), vec_size is always N/2!
179 | func gen_keep_vec_sparse(vec_size, in_wid, kp_wid, log_sparse int) (idx []int) {
180 | logN := 0
181 | for ; (1 << logN) < (2 * vec_size); logN++ {
182 | }
183 | idx = make([]int, vec_size)
184 | batch := 2 * vec_size / (in_wid * in_wid)
185 | sparsity := 1 << log_sparse
186 | if sparsity == 1 {
187 | panic("We do not support full packing in gen_keep_vec_sparse")
188 | }
189 | if kp_wid < in_wid/2 {
190 | panic("keep width too small. less than in_wid/2")
191 | }
192 |
193 | for i := 0; i < in_wid/2; i++ {
194 | for j := 0; j < kp_wid; j++ {
195 | for b := 0; b < batch/sparsity; b++ {
196 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b*sparsity), logN-1))
197 | idx[id] = 1
198 | }
199 | }
200 | }
201 | for i := 0; i < kp_wid-in_wid/2; i++ {
202 | for j := 0; j < kp_wid; j++ {
203 | for b := 0; b < batch/sparsity; b++ {
204 | id := int(reverseBits(uint32(in_wid*batch*i+batch*j+b*sparsity), logN-1)) + vec_size/sparsity
205 | idx[id] = 1
206 | }
207 | }
208 | }
209 |
210 | post_slot := 2 * len(idx) / sparsity
211 | for i := 0; i < post_slot; i++ {
212 | for j := 1; j < sparsity/2; j++ {
213 | idx[i+post_slot*j] = idx[i]
214 | }
215 | }
216 |
217 | return idx
218 | }
219 |
220 | // returns the idx for keep_vec // same as gen_keep_vec but keeps strided output only
221 | // N: length of input (upper + lower)
222 | // ul = 0 -> upper part, ul = 1 -> lower part
223 | // kp_wid: number of elements to keep (in each row)
224 | // step: distance between each element
225 | // raw_in_wid_odd: raw_in_wid is odd or not
226 | func gen_keep_vec_stride(vec_size, in_wid, kp_wid, step, ul int, raw_in_wid_odd bool) (idx []int) {
227 | logN := 0
228 | for ; (1 << logN) < (2 * vec_size); logN++ {
229 | }
230 | idx = make([]int, vec_size)
231 | batch := 2 * vec_size / (in_wid * in_wid)
232 |
233 | var init int
234 | if raw_in_wid_odd {
235 | init = 0
236 | } else {
237 | init = step - 1
238 | }
239 |
240 | if ul == 0 {
241 | for i := 0; i < kp_wid; i++ {
242 | if (init + i*step) < in_wid/2 {
243 | for j := 0; j < kp_wid; j++ {
244 | for b := 0; b < batch; b++ {
245 | id := int(reverseBits(uint32(in_wid*batch*(init+i*step)+batch*(j*step+init)+b), logN-1))
246 | idx[id] = 1
247 | }
248 | }
249 | }
250 | }
251 | } else if ul == 1 {
252 | for i := 0; i < kp_wid; i++ {
253 | if (init + i*step) >= in_wid/2 {
254 | for j := 0; j < kp_wid; j++ {
255 | for b := 0; b < batch; b++ {
256 | id := int(reverseBits(uint32(in_wid*batch*(init+i*step-in_wid/2)+batch*(j*step+init)+b), logN-1))
257 | idx[id] = 1
258 | }
259 | }
260 | }
261 | }
262 | } else {
263 | panic("ul not 0 nor 1")
264 | }
265 |
266 | return idx
267 | }
268 |
269 | // Assume N/2 input vector
270 | // reverse of extend_full (after strided conv -> normal)
271 | // in_wid = input wid including padding
272 | // kp_wid = keep wid
273 | // padding = true: only keeps valid elements // (output) e.g., 12 00 // 34 00 // 00 00 // 00 00
274 | // padding = false: keeps all elements // (output) e.g., 12 // 34
275 | // 0 <= pos < 4 determines to which part the output is positioned at the final output
276 | // ul : up (0) or low (1) part
277 | func comprs_full(input []float64, in_wid, kp_wid, pos, ul int) []float64 {
278 | output := make([]float64, len(input))
279 | batch := 2 * len(input) / (in_wid * in_wid)
280 | if kp_wid < in_wid/2 {
281 | panic("keep width too small. less than in_wid/2")
282 | }
283 | pos = int(reverseBits(uint32(pos), 2))
284 | padding := false
285 | min_wid := in_wid / 4
286 | if in_wid%4 != 0 {
287 | panic("input wid not divisible by 4")
288 | }
289 | if in_wid%2 != 0 {
290 | panic("input wid not divisible by 2")
291 | }
292 | log_in_wid := 0
293 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ {
294 | }
295 |
296 | if padding {
297 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j
298 | tmp := make([]float64, len(input))
299 | for b := 0; b < batch; b++ {
300 | for i := 0; i < min_wid; i++ {
301 | idx := 2*min_wid*in_wid*b + in_wid*j + i + min_wid*in_wid + min_wid
302 | tmp[idx] = input[idx]
303 | }
304 | }
305 | rot := -2*j*min_wid + 2*pos*min_wid*min_wid - min_wid*in_wid - min_wid
306 | output = addSlice(output, rRot(tmp, rot))
307 | }
308 | // // when we want to extract even positioned inputs
309 | // for j := 0; j < min_wid; j++ { // kinds of mov depends on j
310 | // tmp := make([]int, len(input))
311 | // for b := 0; b < batch; b++ {
312 | // for i := 0; i < min_wid; i++ {
313 | // idx := 2*min_wid*in_wid*b + in_wid*j + i
314 | // tmp[idx] = input[idx]
315 | // }
316 | // }
317 | // rot := -2*j*min_wid + 2*pos*min_wid*min_wid
318 | // output = addSlice(output, rRot(tmp, rot))
319 | // }
320 | } else {
321 | if ul == 0 {
322 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
323 | tmp := make([]float64, len(input))
324 | for b := 0; b < batch; b++ {
325 | for i := 0; i < min_wid; i++ {
326 | if reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid) {
327 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
328 | tmp[idx] = input[idx]
329 | }
330 | }
331 | }
332 | rot := -j*min_wid + 2*pos*min_wid*min_wid - min_wid - in_wid*min_wid
333 | output = addSlice(output, rRot(tmp, rot))
334 | }
335 | } else {
336 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
337 | tmp := make([]float64, len(input))
338 | for b := 0; b < batch; b++ {
339 | for i := 0; i < min_wid; i++ {
340 | if (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(3*min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) {
341 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
342 | tmp[idx] = input[idx]
343 | }
344 | }
345 | }
346 | rot := -j*min_wid + 2*pos*min_wid*min_wid - min_wid - in_wid*min_wid
347 | output = addSlice(output, rRot(tmp, rot))
348 | }
349 | }
350 | // // when we want to extract even positioned inputs
351 | // for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
352 | // tmp := make([]int, len(input))
353 | // for b := 0; b < batch; b++ {
354 | // for i := 0; i < min_wid; i++ {
355 | // idx := 2*min_wid*in_wid*b + 2*min_wid*j + i
356 | // tmp[idx] = input[idx]
357 | // }
358 | // }
359 | // rot := -j*min_wid + 2*pos*min_wid*min_wid
360 | // output = addSlice(output, rRot(tmp, rot))
361 | // }
362 | }
363 |
364 | return output
365 | }
366 |
367 | // Assume N/2 input vector
368 | // reverse of extend_full (after strided conv -> normal)
369 | // in_wid = input wid including padding
370 | // kp_wid = keep wid
371 | // 0 <= pos < 4 determines to which part the output is positioned at the final output
372 | // ul : up (0) or low (1) part
373 | func comprs_full_fast(input []float64, in_wid, kp_wid, pos, ul int) []float64 {
374 | mid_out := make([]float64, len(input))
375 | output := make([]float64, len(input))
376 | batch := 2 * len(input) / (in_wid * in_wid)
377 | if kp_wid < in_wid/2 {
378 | panic("keep width too small. less than in_wid/2")
379 | }
380 | pos = int(reverseBits(uint32(pos), 2))
381 | min_wid := in_wid / 4
382 | if in_wid%4 != 0 {
383 | panic("input wid not divisible by 4")
384 | }
385 | if in_wid%2 != 0 {
386 | panic("input wid not divisible by 2")
387 | }
388 | log_in_wid := 0
389 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ {
390 | }
391 |
392 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
393 | tmp := make([]float64, len(input))
394 | for b := 0; b < batch; b++ {
395 | for i := 0; i < min_wid; i++ {
396 | if (ul == 0) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) {
397 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
398 | tmp[idx] = input[idx]
399 | }
400 | if (ul == 1) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) {
401 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
402 | tmp[idx] = input[idx]
403 | }
404 | }
405 | }
406 | rot := -j*min_wid + 2*min_wid*min_wid - min_wid
407 | mid_out = addSlice(mid_out, rRot(tmp, rot))
408 | }
409 | for b := 0; b < batch; b++ {
410 | tmp := make([]float64, len(input))
411 | for j := 0; j < 2*min_wid; j++ {
412 | for i := 0; i < min_wid; i++ {
413 | idx := 2*min_wid*in_wid*b + 3*in_wid/2*min_wid + j*min_wid + i
414 | tmp[idx] = mid_out[idx]
415 | }
416 | }
417 | rot := -3*b*min_wid*in_wid/2 + pos*min_wid*in_wid/2*batch - 3*min_wid*in_wid/2
418 | output = addSlice(output, rRot(tmp, rot))
419 | }
420 |
421 | return output
422 | }
423 |
424 | // generate vectors for comprs_full (N/2 input)
425 | // returns the idx and rotations for each idx For comprs_full_hf
426 | // vec_size = slots, in_wid = real in_wid including padding,
427 | // CAUTION: rotation = -rotation (of comprs_full_hf)
428 | func gen_comprs_full(vec_size, in_wid, kp_wid, pos, ul int) (r_idx map[int][]int) {
429 | r_idx = make(map[int][]int)
430 | batch := 2 * vec_size / (in_wid * in_wid)
431 | if kp_wid < in_wid/2 {
432 | panic("keep width too small. less than in_wid/2")
433 | }
434 | pos = int(reverseBits(uint32(pos), 2))
435 | padding := false
436 | min_wid := in_wid / 4
437 | if in_wid%4 != 0 {
438 | panic("input wid not divisible by 4")
439 | }
440 | if in_wid%2 != 0 {
441 | panic("input wid not divisible by 2")
442 | }
443 | log_in_wid := 0
444 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ {
445 | }
446 |
447 | if padding {
448 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j
449 | tmp := make([]int, vec_size)
450 | for b := 0; b < batch; b++ {
451 | for i := 0; i < min_wid; i++ {
452 | idx := 2*min_wid*in_wid*b + in_wid*j + i + min_wid*in_wid + min_wid
453 | tmp[idx] = 1
454 | }
455 | }
456 | rot := 2*j*min_wid - 2*pos*min_wid*min_wid + min_wid*in_wid + min_wid
457 | r_idx[rot] = tmp
458 | }
459 | } else {
460 | if ul == 0 {
461 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
462 | tmp := make([]int, vec_size)
463 | for b := 0; b < batch; b++ {
464 | if reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid) {
465 | for i := 0; i < min_wid; i++ {
466 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
467 | tmp[idx] = 1
468 | }
469 | }
470 | }
471 | rot := j*min_wid - 2*pos*min_wid*min_wid + min_wid + in_wid*min_wid
472 | r_idx[rot] = tmp
473 | }
474 | } else {
475 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
476 | tmp := make([]int, vec_size)
477 | for b := 0; b < batch; b++ {
478 | for i := 0; i < min_wid; i++ {
479 | if (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(3*min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) {
480 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
481 | tmp[idx] = 1
482 | }
483 | }
484 | }
485 | rot := j*min_wid - 2*pos*min_wid*min_wid + min_wid + in_wid*min_wid
486 | r_idx[rot] = tmp
487 | }
488 | }
489 | }
490 |
491 | return r_idx
492 | }
493 |
494 | // generate vectors for comprs_full_fast (N/2 input)
495 | // returns the idx and rotations for each idx For comprs_full_hf
496 | // vec_size = slots, in_wid = real in_wid including padding,
497 | // CAUTION: rotation = -rotation (of comprs_full_hf)
498 | func gen_comprs_fast(vec_size, in_wid, kp_wid, pos, ul int) (m_idx, r_idx map[int][]int) {
499 | m_idx = make(map[int][]int)
500 | r_idx = make(map[int][]int)
501 | batch := 2 * vec_size / (in_wid * in_wid)
502 |
503 | if kp_wid < in_wid/2 {
504 | panic("keep width too small. less than in_wid/2")
505 | }
506 | pos = int(reverseBits(uint32(pos), 2))
507 | min_wid := in_wid / 4
508 | if in_wid%4 != 0 {
509 | panic("input wid not divisible by 4")
510 | }
511 | if in_wid%2 != 0 {
512 | panic("input wid not divisible by 2")
513 | }
514 | log_in_wid := 0
515 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ {
516 | }
517 |
518 | for j := 0; j < 2*min_wid; j++ { // kinds of mov depends on j
519 | tmp := make([]int, vec_size)
520 | for b := 0; b < batch; b++ {
521 | for i := 0; i < min_wid; i++ {
522 | if (ul == 0) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) {
523 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
524 | tmp[idx] = 1
525 | }
526 | if (ul == 1) && (reverseBits(uint32(in_wid/2+j), log_in_wid) < uint32(kp_wid)) && (reverseBits(uint32(min_wid+i), log_in_wid-1) < uint32(kp_wid-in_wid/2)) {
527 | idx := 2*min_wid*in_wid*b + 2*min_wid*j + i + in_wid*min_wid + min_wid
528 | tmp[idx] = 1
529 | }
530 | }
531 | }
532 | rot := j*min_wid - 2*min_wid*min_wid + min_wid
533 | m_idx[rot] = tmp
534 | }
535 | for b := 0; b < batch; b++ { // kinds of mov depends on b
536 | tmp := make([]int, vec_size)
537 | for j := 0; j < 2*min_wid; j++ {
538 | for i := 0; i < min_wid; i++ {
539 | idx := 2*min_wid*in_wid*b + 3*in_wid/2*min_wid + j*min_wid + i
540 | tmp[idx] = 1
541 | }
542 | }
543 | rot := 3*b*min_wid*in_wid/2 - pos*min_wid*in_wid/2*batch + 3*min_wid*in_wid/2
544 | r_idx[rot] = tmp
545 | }
546 |
547 | return m_idx, r_idx
548 | }
549 |
550 | // generate vectors for comprs_full_fast (N/2 input)
551 | // returns the idx and rotations for each idx For comprs_full_hf
552 | // vec_size = full slots = N/2, in_wid = real in_wid including padding,
553 | // CAUTION: rotation = -rotation (of comprs_full_hf)
554 | // log_sparse: 0 => full slots, 1 => half slots, Of the INPUT
555 | // ul: 0(up), 1(low)
556 | // pos: position after pack [only for full packing case]
557 | func gen_comprs_sparse(vec_size, in_wid, kp_wid, log_sparse, ul, pos int) (m_idx, r_idx map[int][]int) {
558 | m_idx = make(map[int][]int)
559 | r_idx = make(map[int][]int)
560 | batch := 2 * vec_size / (in_wid * in_wid * (1 << log_sparse))
561 |
562 | // if kp_wid < in_wid/2 {
563 | // panic("keep width too small. less than in_wid/2")
564 | // }
565 | // pos = int(reverseBits(uint32(pos), 2))
566 | min_wid := in_wid / 2
567 | if in_wid%2 != 0 {
568 | panic("input wid not divisible by 2")
569 | }
570 | log_in_wid := 0
571 | for ; (1 << log_in_wid) < in_wid; log_in_wid++ {
572 | }
573 |
574 | if log_sparse != 0 {
575 | if pos != 0 {
576 | panic("No pos != 0 cases for log_sparse != 0")
577 | }
578 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j
579 | tmp := make([]int, vec_size)
580 | for b := 0; b < batch; b++ {
581 | for i := 0; i < min_wid/2; i++ {
582 | for k := 0; k < 2; k++ {
583 | if (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && ((reverseBits(uint32(i), log_in_wid-2) + uint32(k)*uint32(min_wid)/2) < uint32(kp_wid)) {
584 | idx := k*in_wid*min_wid*batch + in_wid*in_wid*b/2 + in_wid*j/2 + i
585 | tmp[idx] = 1
586 | }
587 | }
588 | }
589 | }
590 | // repeatedly write tmp elements for log_sparse > 1 cases.
591 | for i := 0; i < vec_size/(1<<(log_sparse-1)); i++ {
592 | for k := 1; k < (1 << (log_sparse - 1)); k++ {
593 | tmp[i+k*vec_size/(1<<(log_sparse-1))] = tmp[i]
594 | }
595 | }
596 | rot := j * min_wid / 2
597 | m_idx[rot] = tmp
598 | }
599 |
600 | for b := 0; b < batch; b++ { // kinds of mov depends on b
601 | tmp := make([]int, vec_size)
602 | for j := 0; j < min_wid; j++ {
603 | for i := 0; i < min_wid/2; i++ {
604 | for k := 0; k < 2; k++ {
605 | idx := k*in_wid*min_wid*batch + b*in_wid*in_wid/2 + j*min_wid/2 + i
606 | tmp[idx] = 1
607 | }
608 | }
609 | }
610 | // repeatedly write tmp elements for log_sparse > 1 cases.
611 | for i := 0; i < vec_size/(1<<(log_sparse-1)); i++ {
612 | for k := 1; k < (1 << (log_sparse - 1)); k++ {
613 | tmp[i+k*vec_size/(1<<(log_sparse-1))] = tmp[i]
614 | }
615 | }
616 | rot := 3 * b * min_wid * min_wid / 2
617 | r_idx[rot] = tmp
618 | }
619 | } else {
620 | if batch > 8*min_wid {
621 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b
622 | for bk := 0; bk < 8; bk++ {
623 | tmp := make([]int, vec_size)
624 | for b := 0; b < batch/8; b++ {
625 | for i := 0; i < min_wid/2; i++ {
626 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) {
627 | idx := 8*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i
628 | tmp[idx] = 1
629 | }
630 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) {
631 | idx := 8*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i
632 | tmp[idx] = 1
633 | }
634 | }
635 | }
636 | rot := j*min_wid/2 + 7*bk*min_wid*min_wid/2
637 | m_idx[rot] = tmp
638 | }
639 | }
640 |
641 | for b := 0; b < batch/8; b++ { // kinds of mov depends on b
642 | tmp := make([]int, vec_size)
643 | for bk := 0; bk < 8; bk++ {
644 | for j := 0; j < min_wid; j++ {
645 | for i := 0; i < min_wid/2; i++ {
646 | idx := 8*b*in_wid*min_wid + bk*min_wid*min_wid/2 + j*min_wid/2 + i
647 | tmp[idx] = 1
648 | }
649 | }
650 | }
651 | rot := 3*b*8*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2
652 | r_idx[rot] = tmp
653 | }
654 | } else if batch > 4*min_wid { //we may move 4*j for optimizations
655 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b
656 | for bk := 0; bk < 4; bk++ {
657 | tmp := make([]int, vec_size)
658 | for b := 0; b < batch/4; b++ {
659 | for i := 0; i < min_wid/2; i++ {
660 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) {
661 | idx := 4*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i
662 | tmp[idx] = 1
663 | }
664 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) {
665 | idx := 4*in_wid*min_wid*b + bk*min_wid*in_wid + min_wid*j + i
666 | tmp[idx] = 1
667 | }
668 | }
669 | }
670 | rot := j*min_wid/2 + 3*bk*min_wid*min_wid/2
671 | m_idx[rot] = tmp
672 | }
673 | }
674 |
675 | for b := 0; b < batch/4; b++ { // kinds of mov depends on b
676 | tmp := make([]int, vec_size)
677 | for bk := 0; bk < 4; bk++ {
678 | for j := 0; j < min_wid; j++ {
679 | for i := 0; i < min_wid/2; i++ {
680 | idx := 4*b*in_wid*min_wid + bk*min_wid*min_wid/2 + j*min_wid/2 + i
681 | tmp[idx] = 1
682 | }
683 | }
684 | }
685 | rot := 3*b*4*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2
686 | r_idx[rot] = tmp
687 | }
688 | } else {
689 | for j := 0; j < min_wid; j++ { // kinds of mov depends on j and b
690 | tmp := make([]int, vec_size)
691 | for b := 0; b < batch; b++ {
692 | for i := 0; i < min_wid/2; i++ {
693 | if (ul == 0) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2) < uint32(kp_wid)) {
694 | idx := in_wid*min_wid*b + min_wid*j + i
695 | tmp[idx] = 1
696 | }
697 | if (ul == 1) && (reverseBits(uint32(j), log_in_wid-1) < uint32(kp_wid)) && (reverseBits(uint32(i), log_in_wid-2)+uint32(min_wid/2) < uint32(kp_wid)) {
698 | idx := in_wid*min_wid*b + min_wid*j + i
699 | tmp[idx] = 1
700 | }
701 | }
702 | }
703 | rot := j * min_wid / 2
704 | m_idx[rot] = tmp
705 | }
706 |
707 | for b := 0; b < batch; b++ { // kinds of mov depends on b
708 | tmp := make([]int, vec_size)
709 | for j := 0; j < min_wid; j++ {
710 | for i := 0; i < min_wid/2; i++ {
711 | idx := b*in_wid*min_wid + j*min_wid/2 + i
712 | tmp[idx] = 1
713 | }
714 | }
715 | rot := 3*b*min_wid*min_wid/2 - int(reverseBits(uint32(pos), 2))*batch*min_wid*min_wid/2
716 | r_idx[rot] = tmp
717 | }
718 | }
719 | }
720 |
721 | return m_idx, r_idx
722 | }
723 |
--------------------------------------------------------------------------------
/test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "strconv"
6 | "time"
7 |
8 | "github.com/dwkim606/test_lattigo/ckks"
9 | )
10 |
11 | // Fast Conv without boot, Assume full batch with Po2 in_wid & N
12 | // Normal Conv without output modification (e.g., trimming or expanding)
13 | // Assume that the input is 0 padded according to kernel size: only in_wid - (ker_wid-1)/2 elements in row and columns are nonzero
14 | // Also support non-full batching case
15 | func testConv_in(in_batch, in_wid, ker_wid, total_test_num int, boot bool) {
16 | kind := "Conv"
17 | printResult := false
18 | raw_in_batch := in_batch // same as python
19 | raw_in_wid := in_wid - ker_wid/2 // same as python
20 | norm := in_batch / raw_in_batch
21 | test_dir := "test_conv_data/"
22 | pow := 4.0
23 |
24 | // set basic variables for above input variables
25 | kp_wid, out_batch, logN, trans := set_Variables(in_batch, raw_in_wid, in_wid, ker_wid, kind)
26 | raw_out_batch := out_batch / norm
27 |
28 | // generate Context: params, Keys, rotations, general plaintexts
29 | cont := newContext(logN, ker_wid, []int{in_wid}, []int{kp_wid}, boot, kind)
30 | fmt.Println("vec size: log2 = ", cont.logN)
31 | fmt.Println("raw input width: ", raw_in_wid)
32 | fmt.Println("kernel width: ", ker_wid)
33 | fmt.Println("num raw batches in & out: ", raw_in_batch, ", ", raw_out_batch)
34 |
35 | for test_iter := 0; test_iter < total_test_num; test_iter++ {
36 | fmt.Println(test_iter+1, "-th iter...start")
37 | raw_input := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_in_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*raw_in_batch)
38 | ker_in := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_ker_"+strconv.Itoa(test_iter)+".csv", raw_in_batch*raw_out_batch*ker_wid*ker_wid)
39 | bn_a := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bna_"+strconv.Itoa(test_iter)+".csv", raw_out_batch)
40 | bn_b := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bnb_"+strconv.Itoa(test_iter)+".csv", raw_out_batch)
41 |
42 | // input encryption
43 | input := prep_Input(raw_input, raw_in_wid, in_wid, cont.N, norm, trans, printResult)
44 | start := time.Now()
45 | plain_tmp := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
46 | cont.encoder.EncodeCoeffs(input, plain_tmp)
47 | ctxt_input := cont.encryptor.EncryptNew(plain_tmp)
48 | fmt.Printf("Encryption done in %s \n", time.Since(start))
49 |
50 | // Kernel Prep & Conv (+BN) Evaluation
51 | var ct_result *ckks.Ciphertext
52 | if boot {
53 | ct_result = evalConv_BNRelu_new(cont, ctxt_input, ker_in, bn_a, bn_b, 0.0, pow, in_wid, kp_wid, ker_wid, raw_in_batch, raw_out_batch, norm, 0, 0, 2, 0, kind, false, false)
54 | } else {
55 | ct_result = evalConv_BN(cont, ctxt_input, ker_in, bn_a, bn_b, in_wid, ker_wid, raw_in_batch, raw_out_batch, norm, float64(1<<30), trans)
56 | }
57 |
58 | start = time.Now()
59 | cont.decryptor.Decrypt(ct_result, plain_tmp)
60 | cfs_tmp := cont.encoder.DecodeCoeffs(plain_tmp)
61 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
62 |
63 | test_out := post_process(cfs_tmp, raw_in_wid, in_wid)
64 | var real_out []float64
65 | if boot {
66 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_reluout_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*raw_in_batch)
67 | } else {
68 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_out_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*raw_in_batch)
69 | }
70 |
71 | printDebugCfsPlain(test_out, real_out)
72 | }
73 |
74 | }
75 |
76 | func testResNet_crop_sparse(st, end, ker_wid, depth int, debug, cf100 bool) {
77 | // init_batch fixed to 16
78 | ker_name := "ker" + strconv.Itoa(ker_wid)
79 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/" // !! NEED to remove "_test"
80 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
81 | fc_out := 10 // 100 for cifar100
82 | init_pow := 6.0 // covers [-2^pow, 2^pow] values at ReLU evaluation
83 | mid_pow := 6.0
84 | final_pow := 6.0
85 | if cf100 {
86 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
87 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
88 | fc_out = 100 // 100 for cifar100
89 | if ker_wid == 3 {
90 | final_pow = 7.0
91 | } else if ker_wid == 5 {
92 | final_pow = 6.0
93 | } else {
94 | final_pow = 5.0
95 | }
96 | init_pow = 5.0
97 | mid_pow = 5.0
98 | }
99 |
100 | var num_blcs [3]int
101 | switch depth {
102 | case 20:
103 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5
104 | case 14:
105 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3
106 | case 8:
107 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1
108 | default:
109 | panic("wrong depth (not in 8, 14, 20)!")
110 | }
111 | real_batch := []int{16, 32, 64} // same as python (small for faster eval test) !! NEEDS to be changed for real test input {16, 32, 64}
112 | norm := []int{4, 8, 16} // only use 1/norm batches among full batches (i.e., sparse packing)
113 | step := []int{1, 1, 1} // non-one only when it is for inside
114 |
115 | logN := 16 // !! NEEDS to be modified to 16
116 | alpha := 0.0
117 | in_wids := []int{32, 16, 8} // before cropping
118 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python
119 | fast_pack := true
120 | ker_size := ker_wid * ker_wid
121 | max_batch := make([]int, len(real_batch)) // the max batch
122 | for i := range max_batch {
123 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i])
124 | }
125 |
126 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, "Resnet_crop_sparse")
127 |
128 | for iter := st; iter < end; iter++ {
129 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid)
130 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
131 | // image := make([]float64, in_wids[0]*in_wids[0]*3)
132 | // for i := range image {
133 | // image[i] = 1.0 - 1.0*float64(i)/float64(len(image))
134 | // }
135 | if cf100 {
136 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
137 | }
138 | input := make([]float64, cont.N)
139 | k := 0
140 | for i := 0; i < in_wids[0]; i++ {
141 | for j := 0; j < in_wids[0]; j++ {
142 | for b := 0; b < 3; b++ {
143 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) {
144 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k] // sparse pack the input
145 | }
146 | k++
147 | }
148 | }
149 | }
150 | fmt.Println("Input: ")
151 | prt_mat_norm(input, max_batch[0], norm[0], 3, false)
152 | fmt.Println("vec size: ", cont.N)
153 | fmt.Println("input width: ", raw_in_wids)
154 | fmt.Println("kernel width: ", ker_wid)
155 | fmt.Println("num batches: ", real_batch)
156 |
157 | enc_start := time.Now()
158 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
159 | cont.encoder.EncodeCoeffs(input, pl_input)
160 | ct_input := cont.encryptor.EncryptNew(pl_input)
161 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start))
162 |
163 | timings := make([]float64, 6)
164 | begin_start := time.Now()
165 | start := time.Now()
166 |
167 | // ResNet Block 1
168 | pow := init_pow
169 | ct_layer := ct_input
170 | for i := 1; i <= num_blcs[0]; i++ {
171 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", real_batch[0])
172 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", real_batch[0])
173 | // bn_a := make([]float64, real_batch[0])
174 | // bn_b := make([]float64, real_batch[0])
175 | // for i := range bn_a {
176 | // bn_a[i] = 0.2
177 | // bn_b[i] = 0.0
178 | // }
179 | ker_in_batch := 3
180 | if i != 1 {
181 | ker_in_batch = real_batch[0]
182 | }
183 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", ker_in_batch*real_batch[0]*ker_size)
184 | // ker_in := make([]float64, ker_in_batch*real_batch[0]*ker_size)
185 | // for i := range ker_in {
186 | // ker_in[i] = 0.05 * float64(i) / float64(len(ker_in))
187 | // }
188 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, ker_in_batch, real_batch[0], norm[0], 0, step[0], 2, 2, "Conv_sparse", fast_pack, debug)
189 | pow = mid_pow
190 | fmt.Println("Block1, Layer ", i, "done!")
191 | }
192 | fmt.Println("Block1 done.") // !!!! HERE is DONE
193 | timings[0] = time.Since(start).Seconds()
194 | start = time.Now()
195 |
196 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
197 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1])
198 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1])
199 | // ker_in12 := make([]float64, real_batch[0]*real_batch[1]*ker_size)
200 | // for i := range ker_in12 {
201 | // ker_in12[i] = 0.05 * float64(i) / float64(len(ker_in12))
202 | // }
203 | // bn_a := make([]float64, real_batch[1])
204 | // bn_b := make([]float64, real_batch[1])
205 | // for i := range bn_a {
206 | // bn_a[i] = 0.1
207 | // bn_b[i] = 0.0
208 | // }
209 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1], norm[1], 0, step[1], 2, 1, "StrConv_sparse", fast_pack, debug)
210 | fmt.Println("Block1 to 2 done!")
211 | timings[1] = time.Since(start).Seconds()
212 | start = time.Now()
213 |
214 | // ResNet Block 2
215 | for i := 1; i <= num_blcs[1]; i++ {
216 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1])
217 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1])
218 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
219 | // bn_a2 := make([]float64, real_batch[1])
220 | // bn_b2 := make([]float64, real_batch[1])
221 | // ker_in2 := make([]float64, real_batch[1]*real_batch[1]*ker_size)
222 | // for i := range bn_a2 {
223 | // bn_a2[i] = 0.1
224 | // bn_b2[i] = 0.0
225 | // }
226 | // for i := range ker_in2 {
227 | // ker_in2[i] = 0.05 * float64(i) / float64(len(ker_in2))
228 | // }
229 |
230 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 3, "Conv_sparse", fast_pack, debug)
231 | fmt.Println("Block2, Layer ", i, "done!")
232 | }
233 | fmt.Println("Block2 done.")
234 | timings[2] = time.Since(start).Seconds()
235 | start = time.Now()
236 |
237 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size)
238 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2])
239 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2])
240 | // bn_a3 := make([]float64, real_batch[2])
241 | // bn_b3 := make([]float64, real_batch[2])
242 | // ker_in23 := make([]float64, real_batch[1]*real_batch[2]*ker_size)
243 | // for i := range bn_a3 {
244 | // bn_a3[i] = 0.1
245 | // bn_b3[i] = 0.0
246 | // }
247 | // for i := range ker_in23 {
248 | // ker_in23[i] = 0.05 * float64(i) / float64(len(ker_in23))
249 | // }
250 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[1], real_batch[2], norm[2], 0, step[2], 2, 2, "StrConv_sparse", fast_pack, debug)
251 | fmt.Println("Block2 to 3 done!")
252 | timings[3] = time.Since(start).Seconds()
253 | start = time.Now()
254 |
255 | // ResNet Block 3
256 | for i := 1; i <= num_blcs[2]; i++ {
257 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2])
258 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2])
259 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size)
260 | // bn_a3 := make([]float64, real_batch[2])
261 | // bn_b3 := make([]float64, real_batch[2])
262 | // ker_in3 := make([]float64, real_batch[2]*real_batch[2]*ker_size)
263 | // for i := range bn_a3 {
264 | // bn_a3[i] = 0.1
265 | // bn_b3[i] = 0.0
266 | // }
267 | // for i := range ker_in3 {
268 | // ker_in3[i] = 0.1 * float64(i) / float64(len(ker_in3))
269 | // }
270 |
271 | if i == num_blcs[2] {
272 | pow = final_pow
273 | }
274 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[2], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 4, "Conv_sparse", fast_pack, debug)
275 | fmt.Println("Block3, Layer ", i, "done!")
276 | }
277 | fmt.Println("Block3 done.")
278 | timings[4] = time.Since(start).Seconds()
279 | start = time.Now()
280 |
281 | ker_inf_wid := raw_in_wids[2]
282 | if ker_inf_wid%2 == 0 {
283 | ker_inf_wid++
284 | }
285 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out)
286 | // ker_inf := make([]float64, real_batch[2]*fc_out)
287 | // for i := range ker_inf {
288 | // ker_inf[i] = 0.1 * float64(i)
289 | // }
290 | var ct_result, ct_result2 *ckks.Ciphertext
291 | if cf100 {
292 | ker_inf_1 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2)
293 | ker_inf_2 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2)
294 | for i := 0; i < fc_out/2; i++ {
295 | for j := 0; j < real_batch[2]; j++ {
296 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
297 | ker_inf_1[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i]
298 | ker_inf_2[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i+fc_out/2]
299 | }
300 | }
301 | }
302 | bn_af := make([]float64, fc_out/2)
303 | for i := range bn_af {
304 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
305 | }
306 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
307 | bn_bf_1 := make([]float64, fc_out/2)
308 | bn_bf_2 := make([]float64, fc_out/2)
309 | for i := range bn_bf_1 {
310 | bn_bf_1[i] = bn_bf[i]
311 | bn_bf_2[i] = bn_bf[i+fc_out/2]
312 | }
313 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_1, bn_af, bn_bf_1, in_wids[2], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false)
314 | ct_result2 = evalConv_BN(cont, ct_layer, ker_inf_2, bn_af, bn_bf_2, in_wids[2], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false)
315 | fmt.Println("Final FC done.")
316 | timings[5] = time.Since(start).Seconds()
317 | start = time.Now()
318 | } else {
319 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out)
320 | for i := range ker_inf {
321 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
322 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i]
323 | }
324 | }
325 | bn_af := make([]float64, fc_out)
326 | for i := range bn_af {
327 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
328 | }
329 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
330 | // bn_bf := make([]float64, fc_out)
331 | // for i := range bn_bf {
332 | // bn_bf[i] = 1 * float64(i)
333 | // }
334 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[2], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false)
335 | fmt.Println("Final FC done.")
336 | timings[5] = time.Since(start).Seconds()
337 | start = time.Now()
338 | }
339 |
340 | fmt.Println()
341 | fmt.Println("=============== DECRYPTION ===============")
342 | fmt.Println()
343 | if cf100 {
344 | cont.decryptor.Decrypt(ct_result, pl_input)
345 | res_tmp1 := cont.encoder.DecodeCoeffs(pl_input)
346 | cont.decryptor.Decrypt(ct_result2, pl_input)
347 | res_tmp2 := cont.encoder.DecodeCoeffs(pl_input)
348 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
349 | res_out := append(prt_mat_one_norm(res_tmp1, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2], prt_mat_one_norm(res_tmp2, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2]...)
350 | fmt.Println("\n result: ", res_out)
351 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out)
352 | } else {
353 | cont.decryptor.Decrypt(ct_result, pl_input)
354 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
355 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
356 | res_out := prt_mat_one_norm(res_tmp, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)
357 | fmt.Println("\n result: ", res_out[:fc_out])
358 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out])
359 | }
360 |
361 | fmt.Println("Blc1: ", timings[0], " sec")
362 | fmt.Println("Blc1->2: ", timings[1], " sec")
363 | fmt.Println("Blc2: ", timings[2], " sec")
364 | fmt.Println("Blc2->3: ", timings[3], " sec")
365 | fmt.Println("Blc3: ", timings[4], " sec")
366 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec")
367 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
368 | }
369 |
370 | }
371 |
372 | func testResNet_crop_fast_in(st, end, ker_wid, depth int, debug, cf100 bool) {
373 | // init_batch fixed to 16
374 | ker_name := "ker" + strconv.Itoa(ker_wid)
375 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
376 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
377 | fc_out := 10 // 100 for cifar100
378 | init_pow := 6.0 // covers [-2^pow, 2^pow] values at ReLU evaluation
379 | mid_pow := 6.0
380 | final_pow := 6.0
381 | if cf100 {
382 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
383 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid1/"
384 | fc_out = 100 // 100 for cifar100
385 | if ker_wid == 3 {
386 | final_pow = 7.0
387 | } else if ker_wid == 5 {
388 | final_pow = 6.0
389 | } else {
390 | final_pow = 5.0
391 | }
392 | init_pow = 5.0
393 | mid_pow = 5.0
394 | }
395 |
396 | var num_blcs [3]int
397 | switch depth {
398 | case 20:
399 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5
400 | case 14:
401 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3
402 | case 8:
403 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1
404 | default:
405 | panic("wrong depth (not in 8, 14, 20)!")
406 | }
407 | real_batch := []int{16, 32, 64} // same as python
408 | norm := []int{4, 2, 1} // only use 1/norm batches
409 | step := []int{1, 2, 4}
410 | prt_start := []int{1, 1, 1}
411 | if ker_wid == 5 {
412 | prt_start[0] = 1
413 | prt_start[1] = 2
414 | prt_start[2] = 4
415 | }
416 |
417 | logN := 16
418 | alpha := 0.0
419 | in_wids := []int{32, 16, 8} // before cropping
420 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python
421 | fast_pack := true
422 | ker_size := ker_wid * ker_wid
423 | max_batch := make([]int, len(real_batch)) // the max batch
424 | for i := range max_batch {
425 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i])
426 | }
427 |
428 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, "Resnet_crop_fast")
429 |
430 | for iter := st; iter < end; iter++ {
431 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid)
432 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
433 | if cf100 {
434 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid1/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
435 | }
436 | input := make([]float64, cont.N)
437 | k := 0
438 | for i := 0; i < in_wids[0]; i++ {
439 | for j := 0; j < in_wids[0]; j++ {
440 | for b := 0; b < 3; b++ {
441 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) {
442 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k]
443 | }
444 | k++
445 | }
446 | }
447 | }
448 | fmt.Println("Input: ")
449 | prt_mat_norm(input, max_batch[0], norm[0], 1, false)
450 | fmt.Println("vec size: ", cont.N)
451 | fmt.Println("input width: ", raw_in_wids)
452 | fmt.Println("kernel width: ", ker_wid)
453 | fmt.Println("num batches: ", real_batch)
454 |
455 | enc_start := time.Now()
456 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
457 | cont.encoder.EncodeCoeffs(input, pl_input)
458 | ct_input := cont.encryptor.EncryptNew(pl_input)
459 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start))
460 |
461 | timings := make([]float64, 6)
462 | begin_start := time.Now()
463 | start := time.Now()
464 |
465 | // ResNet Block 1
466 | pow := init_pow
467 | ct_layer := ct_input
468 | for i := 1; i <= num_blcs[0]; i++ {
469 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", real_batch[0])
470 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", real_batch[0])
471 | ker_in_batch := 3
472 | if i != 1 {
473 | ker_in_batch = real_batch[0]
474 | }
475 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", ker_in_batch*real_batch[0]*ker_size)
476 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, ker_in_batch, real_batch[0], norm[0], 0, step[0], 2, 0, "Conv_inside", fast_pack, debug)
477 | pow = mid_pow
478 | fmt.Println("Block1, Layer ", i, "done!")
479 | }
480 | fmt.Println("Block1 done.")
481 | timings[0] = time.Since(start).Seconds()
482 | start = time.Now()
483 |
484 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
485 | ker_in12_new := make([]float64, 2*real_batch[0]*real_batch[1]*ker_size)
486 | for k := 0; k < ker_size; k++ {
487 | for i := 0; i < real_batch[0]; i++ {
488 | for j := 0; j < real_batch[1]; j++ {
489 | ker_in12_new[k*2*real_batch[0]*real_batch[1]+2*i*real_batch[1]+j] = ker_in12[k*real_batch[0]*real_batch[1]+i*real_batch[1]+j]
490 | }
491 | }
492 | }
493 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1])
494 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1])
495 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12_new, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "StrConv_inside", fast_pack, debug)
496 | fmt.Println("Block1 to 2 done!")
497 | if debug {
498 | max_bat := cont.N / (in_wids[0] * in_wids[0])
499 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer))
500 | prt_mat_norm_step(res_ttmp, max_bat, norm[1], step[1], prt_start[1], 3, false)
501 | }
502 | timings[1] = time.Since(start).Seconds()
503 | start = time.Now()
504 |
505 | // ResNet Block 2
506 | for i := 1; i <= num_blcs[1]; i++ {
507 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1])
508 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1])
509 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
510 |
511 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "Conv_inside", fast_pack, debug)
512 | fmt.Println("Block2, Layer ", i, "done!")
513 | }
514 | fmt.Println("Block2 done.")
515 | timings[2] = time.Since(start).Seconds()
516 | start = time.Now()
517 |
518 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size)
519 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2])
520 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2])
521 | ker_in23_new := make([]float64, 2*real_batch[1]*real_batch[2]*ker_size)
522 | for k := 0; k < ker_size; k++ {
523 | for i := 0; i < real_batch[1]; i++ {
524 | for j := 0; j < real_batch[2]; j++ {
525 | ker_in23_new[k*2*real_batch[1]*real_batch[2]+2*i*real_batch[2]+j] = ker_in23[k*real_batch[1]*real_batch[2]+i*real_batch[2]+j]
526 | }
527 | }
528 | }
529 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23_new, bn_a3, bn_b3, alpha, pow, in_wids[0], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "StrConv_inside", fast_pack, debug)
530 | fmt.Println("Block2 to 3 done!")
531 | if debug {
532 | max_bat := cont.N / (in_wids[0] * in_wids[0])
533 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer))
534 | prt_mat_norm_step(res_ttmp, max_bat, norm[2], step[2], prt_start[2], 3, false)
535 | }
536 | timings[3] = time.Since(start).Seconds()
537 | start = time.Now()
538 |
539 | // ResNet Block 3
540 | for i := 1; i <= num_blcs[2]; i++ {
541 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2])
542 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2])
543 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size)
544 |
545 | if i == num_blcs[2] {
546 | pow = final_pow
547 | }
548 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[0], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "Conv_inside", fast_pack, debug)
549 | fmt.Println("Block3, Layer ", i, "done!")
550 | }
551 | fmt.Println("Block3 done.")
552 | timings[4] = time.Since(start).Seconds()
553 | start = time.Now()
554 |
555 | ker_inf_wid := raw_in_wids[0]
556 | if ker_inf_wid%2 == 0 {
557 | ker_inf_wid++
558 | }
559 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out)
560 | var ct_result, ct_result2 *ckks.Ciphertext
561 | if cf100 {
562 | ker_inf_1 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2)
563 | ker_inf_2 := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out/2)
564 | for i := 0; i < fc_out/2; i++ {
565 | for j := 0; j < real_batch[2]; j++ {
566 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
567 | ker_inf_1[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i]
568 | ker_inf_2[j*fc_out/2+i+b*real_batch[2]*fc_out/2] = ker_inf[j*fc_out+i+fc_out/2]
569 | }
570 | }
571 | }
572 | bn_af := make([]float64, fc_out/2)
573 | for i := range bn_af {
574 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
575 | }
576 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
577 | bn_bf_1 := make([]float64, fc_out/2)
578 | bn_bf_2 := make([]float64, fc_out/2)
579 | for i := range bn_bf_1 {
580 | bn_bf_1[i] = bn_bf[i]
581 | bn_bf_2[i] = bn_bf[i+fc_out/2]
582 | }
583 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_1, bn_af, bn_bf_1, in_wids[0], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false)
584 | ct_result2 = evalConv_BN(cont, ct_layer, ker_inf_2, bn_af, bn_bf_2, in_wids[0], ker_inf_wid, real_batch[2], fc_out/2, norm[2], float64(1<<30), false)
585 | fmt.Println("Final FC done.")
586 | timings[5] = time.Since(start).Seconds()
587 | start = time.Now()
588 | } else {
589 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out)
590 | for i := range ker_inf {
591 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
592 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i]
593 | }
594 | }
595 | bn_af := make([]float64, fc_out)
596 | for i := range bn_af {
597 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
598 | }
599 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
600 | ct_result = evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[0], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false)
601 | fmt.Println("Final FC done.")
602 | timings[5] = time.Since(start).Seconds()
603 | start = time.Now()
604 | }
605 |
606 | fmt.Println()
607 | fmt.Println("=============== DECRYPTION ===============")
608 | fmt.Println()
609 | if cf100 {
610 | cont.decryptor.Decrypt(ct_result, pl_input)
611 | res_tmp1 := cont.encoder.DecodeCoeffs(pl_input)
612 | cont.decryptor.Decrypt(ct_result2, pl_input)
613 | res_tmp2 := cont.encoder.DecodeCoeffs(pl_input)
614 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
615 | res_out := append(prt_mat_one_norm(res_tmp1, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2], prt_mat_one_norm(res_tmp2, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)[:fc_out/2]...)
616 | fmt.Println("\n result: ", res_out)
617 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out)
618 | } else {
619 | cont.decryptor.Decrypt(ct_result, pl_input)
620 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
621 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
622 | res_out := prt_mat_one_norm(res_tmp, max_batch[0], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)
623 | // fmt.Print(res_out)
624 | fmt.Println("\n result: ", res_out[:fc_out])
625 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out])
626 | }
627 |
628 | fmt.Println("Blc1: ", timings[0], " sec")
629 | fmt.Println("Blc1->2: ", timings[1], " sec")
630 | fmt.Println("Blc2: ", timings[2], " sec")
631 | fmt.Println("Blc2->3: ", timings[3], " sec")
632 | fmt.Println("Blc3: ", timings[4], " sec")
633 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec")
634 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
635 | }
636 | }
637 |
638 | func testResNet_crop_sparse_wide(st, end, ker_wid, depth, wide_case int, debug, cf100 bool) {
639 | // init_batch fixed to 16
640 | ker_name := "ker" + strconv.Itoa(ker_wid)
641 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
642 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
643 | fc_out := 10
644 |
645 | init_pow := 5.0
646 | mid_pow := 5.0 // needs to be 5.0 in k3 d20 w3 for best performance
647 | final_pow := 5.0
648 | if ker_wid == 5 {
649 | init_pow = 6.0
650 | mid_pow = 6.0
651 | final_pow = 6.0
652 | }
653 |
654 | if cf100 {
655 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
656 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
657 | fc_out = 100
658 | final_pow = 7.0
659 | init_pow = 5.0
660 | mid_pow = 5.0
661 | if (ker_wid == 5) && (depth == 8) {
662 | init_pow = 6.0
663 | final_pow = 6.0
664 | }
665 | }
666 |
667 | init_batch := 16
668 |
669 | var num_blcs [3]int
670 | switch depth {
671 | case 20:
672 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5
673 | case 14:
674 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3
675 | case 8:
676 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1
677 | default:
678 | panic("wrong depth case (not in 8,14,20)!")
679 | }
680 | real_batch := []int{32, 64, 128} // same as python
681 | norm := []int{2, 4, 8} // only use 1/norm batches
682 | log_sparse := []int{1, 2, 3}
683 | step := []int{1, 1, 1}
684 | kind := "Resnet_crop_sparse_wide2"
685 |
686 | if wide_case == 3 {
687 | real_batch = []int{48, 96, 192}
688 | norm = []int{1, 2, 4}
689 | log_sparse = []int{0, 1, 2}
690 | kind = "Resnet_crop_sparse_wide3"
691 | } else if wide_case != 2 {
692 | panic("wrong wide_case (2 nor 3)!")
693 | }
694 |
695 | logN := 16
696 | alpha := 0.0
697 | in_wids := []int{32, 16, 8} // before cropping
698 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python
699 | fast_pack := true
700 | ker_size := ker_wid * ker_wid
701 | max_batch := make([]int, len(real_batch)) // the max batch
702 | for i := range max_batch {
703 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i])
704 | }
705 |
706 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, kind)
707 |
708 | for iter := st; iter < end; iter++ {
709 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid)
710 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
711 |
712 | if cf100 {
713 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
714 | }
715 | input := make([]float64, cont.N)
716 | k := 0
717 | for i := 0; i < in_wids[0]; i++ {
718 | for j := 0; j < in_wids[0]; j++ {
719 | for b := 0; b < 3; b++ {
720 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) {
721 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k]
722 | }
723 | k++
724 | }
725 | }
726 | }
727 | fmt.Println("Input: ")
728 | prt_mat_norm(input, max_batch[0], norm[0], 3, false)
729 | fmt.Println("vec size: ", cont.N)
730 | fmt.Println("input width: ", raw_in_wids)
731 | fmt.Println("kernel width: ", ker_wid)
732 | fmt.Println("num batches: ", real_batch)
733 |
734 | enc_start := time.Now()
735 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
736 | cont.encoder.EncodeCoeffs(input, pl_input)
737 | ct_input := cont.encryptor.EncryptNew(pl_input)
738 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start))
739 | enc_start = time.Now()
740 |
741 | timings := make([]float64, 6)
742 | begin_start := time.Now()
743 | start := time.Now()
744 |
745 | // ResNet Block 1
746 | pow := init_pow
747 | ct_layer := ct_input
748 | for i := 1; i <= num_blcs[0]; i++ {
749 | if i == 5 {
750 | pow = mid_pow
751 | }
752 | var bn_batch int
753 | if i == 1 {
754 | bn_batch = init_batch
755 | } else {
756 | bn_batch = real_batch[0]
757 | }
758 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", bn_batch)
759 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", bn_batch)
760 |
761 | if i == 1 {
762 | ker_in := readTxt(weight_dir+"w0-conv.csv", 3*init_batch*ker_size)
763 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, 3, init_batch, norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug)
764 | // pow = mid_pow
765 | } else if i == 2 {
766 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", init_batch*real_batch[0]*ker_size)
767 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, init_batch, real_batch[0], norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug)
768 | } else {
769 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size)
770 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[0], 2, log_sparse[0], "Conv_sparse", fast_pack, debug)
771 | }
772 | fmt.Println("Block1, Layer ", i, "done!")
773 | }
774 | fmt.Println("Block1 done.")
775 | timings[0] = time.Since(start).Seconds()
776 | start = time.Now()
777 |
778 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
779 | ker_in12_0 := make([]float64, len(ker_in12)/2)
780 | ker_in12_1 := make([]float64, len(ker_in12)/2)
781 | if wide_case == 3 {
782 | for k := 0; k < ker_size; k++ {
783 | for i := 0; i < real_batch[0]; i++ {
784 | for j := 0; j < real_batch[1]/2; j++ {
785 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j]
786 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1]
787 | }
788 | }
789 | }
790 | }
791 |
792 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1])
793 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1])
794 |
795 | if wide_case == 2 {
796 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1], norm[1], 0, step[1], 2, log_sparse[0]-1, "StrConv_sparse", fast_pack, debug)
797 | } else if wide_case == 3 {
798 | bn_a_0 := make([]float64, real_batch[1]/2)
799 | bn_a_1 := make([]float64, real_batch[1]/2)
800 | bn_b_0 := make([]float64, real_batch[1]/2)
801 | bn_b_1 := make([]float64, real_batch[1]/2)
802 | for i := range bn_b_0 {
803 | bn_a_0[i] = bn_a[2*i]
804 | bn_a_1[i] = bn_a[2*i+1]
805 | bn_b_0[i] = bn_b[2*i]
806 | bn_b_1[i] = bn_b[2*i+1]
807 | }
808 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a_0, bn_b_0, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, step[1], 2, 0, "StrConv_sparse_full", fast_pack, debug)
809 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a_1, bn_b_1, alpha, pow, in_wids[0], raw_in_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, step[1], 2, 0, "StrConv_sparse_full", fast_pack, debug)
810 |
811 | xi := make([]float64, cont.N)
812 | xi[2] = 1.0
813 | xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0)
814 | cont.encoder.EncodeCoeffs(xi, xi_plain)
815 | cont.encoder.ToNTT(xi_plain)
816 | ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain)
817 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2)
818 | }
819 | fmt.Println("Block1 to 2 done!")
820 | timings[1] = time.Since(start).Seconds()
821 | start = time.Now()
822 |
823 | // ResNet Block 2
824 | for i := 1; i <= num_blcs[1]; i++ {
825 | if i == 5 {
826 | pow = init_pow
827 | }
828 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1])
829 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1])
830 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
831 |
832 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, log_sparse[1], "Conv_sparse", fast_pack, debug)
833 | fmt.Println("Block2, Layer ", i, "done!")
834 | }
835 | fmt.Println("Block2 done.")
836 | timings[2] = time.Since(start).Seconds()
837 | start = time.Now()
838 |
839 | pow = mid_pow
840 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size)
841 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2])
842 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2])
843 |
844 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[1], real_batch[2], norm[2], 0, step[2], 2, log_sparse[1]-1, "StrConv_sparse", fast_pack, debug)
845 | fmt.Println("Block2 to 3 done!")
846 | timings[3] = time.Since(start).Seconds()
847 | start = time.Now()
848 |
849 | // ResNet Block 3
850 | for i := 1; i <= num_blcs[2]; i++ {
851 | if i == 3 {
852 | pow = init_pow
853 | }
854 | if i == 5 {
855 | pow = mid_pow
856 | }
857 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2])
858 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2])
859 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size)
860 |
861 | if i == num_blcs[2] {
862 | pow = final_pow
863 | }
864 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[2], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, log_sparse[2], "Conv_sparse", fast_pack, debug)
865 | fmt.Println("Block3, Layer ", i, "done!")
866 | }
867 | fmt.Println("Block3 done.")
868 | timings[4] = time.Since(start).Seconds()
869 | start = time.Now()
870 |
871 | ker_inf_wid := raw_in_wids[2]
872 | if ker_inf_wid%2 == 0 {
873 | ker_inf_wid++
874 | }
875 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out)
876 |
877 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out)
878 | for i := range ker_inf {
879 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
880 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i]
881 | }
882 | }
883 | bn_af := make([]float64, fc_out)
884 | for i := range bn_af {
885 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
886 | }
887 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
888 |
889 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[2], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false)
890 | fmt.Println("Final FC done.")
891 | timings[5] = time.Since(start).Seconds()
892 | start = time.Now()
893 |
894 | fmt.Println()
895 | fmt.Println("=============== DECRYPTION ===============")
896 | fmt.Println()
897 | cont.decryptor.Decrypt(ct_result, pl_input)
898 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
899 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
900 | res_out := prt_mat_one_norm(res_tmp, max_batch[2], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)
901 | fmt.Println("\n result: ", res_out[:fc_out])
902 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out])
903 |
904 | fmt.Println("Blc1: ", timings[0], " sec")
905 | fmt.Println("Blc1->2: ", timings[1], " sec")
906 | fmt.Println("Blc2: ", timings[2], " sec")
907 | fmt.Println("Blc2->3: ", timings[3], " sec")
908 | fmt.Println("Blc3: ", timings[4], " sec")
909 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec")
910 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
911 | }
912 | }
913 |
914 | func testResNet_crop_fast_wide_in(st, end, ker_wid, depth, wide_case int, debug, cf100 bool) {
915 | // init_batch fixed to 16
916 | ker_name := "ker" + strconv.Itoa(ker_wid)
917 | weight_dir := "Resnet_weights/weights_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
918 | out_dir := "Resnet_enc_results/results_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
919 | fc_out := 10 // 100 for cifar100
920 |
921 | init_pow := 5.0
922 | mid_pow := 5.0 // needs to be 5.0 in k3 d20 w3 for best performance
923 | final_pow := 5.0
924 | if ker_wid == 5 {
925 | init_pow = 6.0
926 | mid_pow = 6.0
927 | final_pow = 6.0
928 | }
929 |
930 | if cf100 {
931 | weight_dir = "Resnet_weights/weights_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
932 | out_dir = "Resnet_enc_results/results_cf100_crop_" + ker_name + "_d" + strconv.Itoa(depth) + "_wid" + strconv.Itoa(wide_case) + "/"
933 | fc_out = 100 // 100 for cifar100
934 | final_pow = 7.0
935 | init_pow = 5.0
936 | mid_pow = 5.0
937 | if (ker_wid == 5) && (depth == 8) {
938 | init_pow = 6.0
939 | final_pow = 6.0
940 | }
941 | }
942 |
943 | init_batch := 16 // needs to be modified to 16
944 |
945 | var num_blcs [3]int
946 | switch depth {
947 | case 20:
948 | num_blcs[0], num_blcs[1], num_blcs[2] = 7, 5, 5
949 | case 14:
950 | num_blcs[0], num_blcs[1], num_blcs[2] = 5, 3, 3
951 | case 8:
952 | num_blcs[0], num_blcs[1], num_blcs[2] = 3, 1, 1
953 | default:
954 | panic("wrong depth case (not in 8,14,20)!")
955 | }
956 | real_batch := []int{32, 64, 128} // same as python
957 | norm := []int{2, 4, 2} // only use 1/norm batches
958 | step := []int{1, 1, 2}
959 | prt_start := []int{1, 1, 1}
960 | kind := "Resnet_crop_fast_wide2"
961 | if ker_wid == 5 {
962 | prt_start[0] = 1
963 | prt_start[1] = 1
964 | prt_start[2] = 2
965 | }
966 | if wide_case == 3 {
967 | real_batch = []int{48, 96, 192}
968 | norm = []int{1, 2, 1}
969 | kind = "Resnet_crop_fast_wide3"
970 | } else if wide_case != 2 {
971 | panic("wrong wide_case (2 nor 3)!")
972 | }
973 |
974 | logN := 16
975 | alpha := 0.0
976 | in_wids := []int{32, 16, 8} // before cropping
977 | raw_in_wids := []int{32 - ker_wid/2, 16 - ker_wid/2, 8 - ker_wid/2} // same as python
978 | fast_pack := true
979 | ker_size := ker_wid * ker_wid
980 | max_batch := make([]int, len(real_batch)) // the max batch
981 | for i := range max_batch {
982 | max_batch[i] = (1 << logN) / (in_wids[i] * in_wids[i])
983 | }
984 |
985 | cont := newContext(logN, ker_wid, in_wids, raw_in_wids, true, kind)
986 |
987 | for iter := st; iter < end; iter++ {
988 | fmt.Println("Running ", iter, "-th iter... ker size: ", ker_wid)
989 | image := readTxt("Resnet_plain_data/crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
990 | if cf100 {
991 | image = readTxt("Resnet_plain_data/cf100_crop_ker"+strconv.Itoa(ker_wid)+"_d"+strconv.Itoa(depth)+"_wid"+strconv.Itoa(wide_case)+"/test_image_"+strconv.Itoa(iter)+".csv", in_wids[0]*in_wids[0]*3)
992 | }
993 | input := make([]float64, cont.N)
994 | k := 0
995 | for i := 0; i < in_wids[0]; i++ {
996 | for j := 0; j < in_wids[0]; j++ {
997 | for b := 0; b < 3; b++ {
998 | if (i < raw_in_wids[0]) && (j < raw_in_wids[0]) {
999 | input[i*in_wids[0]*max_batch[0]+j*max_batch[0]+b*norm[0]] = image[k]
1000 | }
1001 | k++
1002 | }
1003 | }
1004 | }
1005 | fmt.Println("Input: ")
1006 | prt_mat_norm(input, max_batch[0], norm[0], 1, false)
1007 | fmt.Println("vec size: ", cont.N)
1008 | fmt.Println("input width: ", raw_in_wids)
1009 | fmt.Println("kernel width: ", ker_wid)
1010 | fmt.Println("num batches: ", real_batch)
1011 |
1012 | enc_start := time.Now()
1013 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
1014 | cont.encoder.EncodeCoeffs(input, pl_input)
1015 | ct_input := cont.encryptor.EncryptNew(pl_input)
1016 | fmt.Printf("Encryption done in %s \n", time.Since(enc_start))
1017 | enc_start = time.Now()
1018 |
1019 | timings := make([]float64, 6)
1020 | begin_start := time.Now()
1021 | start := time.Now()
1022 |
1023 | // ResNet Block 1
1024 | pow := init_pow
1025 | ct_layer := ct_input
1026 | for i := 1; i <= num_blcs[0]; i++ {
1027 | if i == 5 {
1028 | pow = mid_pow
1029 | }
1030 | var bn_batch int
1031 | if i == 1 {
1032 | bn_batch = init_batch
1033 | } else {
1034 | bn_batch = real_batch[0]
1035 | }
1036 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-a.csv", bn_batch)
1037 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-b.csv", bn_batch)
1038 | if i == 1 {
1039 | ker_in := readTxt(weight_dir+"w0-conv.csv", 3*init_batch*ker_size)
1040 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, 3, init_batch, norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug)
1041 | // pow = mid_pow
1042 | } else if i == 2 {
1043 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", init_batch*real_batch[0]*ker_size)
1044 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, init_batch, real_batch[0], norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug)
1045 | } else {
1046 | ker_in := readTxt(weight_dir+"w"+strconv.Itoa(i-1)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size)
1047 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in, bn_a, bn_b, alpha, pow, in_wids[0], raw_in_wids[0], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[0], 2, 0, "Conv", fast_pack, debug)
1048 | }
1049 | fmt.Println("Block1, Layer ", i, "done!")
1050 | }
1051 | fmt.Println("Block1 done.")
1052 | timings[0] = time.Since(start).Seconds()
1053 | start = time.Now()
1054 |
1055 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
1056 | ker_in12_new := make([]float64, 2*real_batch[0]*real_batch[1]*ker_size)
1057 | ker_in12_0 := make([]float64, len(ker_in12)/2)
1058 | ker_in12_1 := make([]float64, len(ker_in12)/2)
1059 | if wide_case == 2 {
1060 | for k := 0; k < ker_size; k++ {
1061 | for i := 0; i < real_batch[0]; i++ {
1062 | for j := 0; j < real_batch[1]; j++ {
1063 | ker_in12_new[k*2*real_batch[0]*real_batch[1]+2*i*real_batch[1]+j] = ker_in12[k*real_batch[0]*real_batch[1]+i*real_batch[1]+j]
1064 | }
1065 | }
1066 | }
1067 | } else if wide_case == 3 {
1068 | for k := 0; k < ker_size; k++ {
1069 | for i := 0; i < real_batch[0]; i++ {
1070 | for j := 0; j < real_batch[1]/2; j++ {
1071 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j]
1072 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1]
1073 | }
1074 | }
1075 | }
1076 | }
1077 |
1078 | bn_a := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-a.csv", real_batch[1])
1079 | bn_b := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0])+"-b.csv", real_batch[1])
1080 |
1081 | if wide_case == 2 {
1082 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in12_new, bn_a, bn_b, alpha, pow, in_wids[0], 2*raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[0]/2, 0, step[1], 2, 0, "StrConv_odd", fast_pack, debug)
1083 | } else if wide_case == 3 {
1084 | bn_a_0 := make([]float64, real_batch[1]/2)
1085 | bn_a_1 := make([]float64, real_batch[1]/2)
1086 | bn_b_0 := make([]float64, real_batch[1]/2)
1087 | bn_b_1 := make([]float64, real_batch[1]/2)
1088 | for i := range bn_b_0 {
1089 | bn_a_0[i] = bn_a[2*i]
1090 | bn_a_1[i] = bn_a[2*i+1]
1091 | bn_b_0[i] = bn_b[2*i]
1092 | bn_b_1[i] = bn_b[2*i+1]
1093 | }
1094 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a_0, bn_b_0, alpha, pow, in_wids[0], 2*raw_in_wids[1], ker_wid, real_batch[0], real_batch[0], norm[0], 0, step[1], 2, 0, "StrConv_odd", fast_pack, debug)
1095 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a_1, bn_b_1, alpha, pow, in_wids[0], 2*raw_in_wids[1], ker_wid, real_batch[0], real_batch[0], norm[0], 2, step[1], 2, 0, "StrConv_odd", fast_pack, debug)
1096 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2)
1097 | }
1098 | fmt.Println("Block1 to 2 done!")
1099 | if debug {
1100 | max_bat := cont.N / (in_wids[1] * in_wids[1])
1101 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer))
1102 | prt_mat_norm_step(res_ttmp, max_bat, norm[1], step[1], prt_start[1], 3, false)
1103 | }
1104 | timings[1] = time.Since(start).Seconds()
1105 | start = time.Now()
1106 |
1107 | // ResNet Block 2
1108 | for i := 1; i <= num_blcs[1]; i++ {
1109 | if i == 5 {
1110 | pow = init_pow
1111 | }
1112 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-a.csv", real_batch[1])
1113 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-b.csv", real_batch[1])
1114 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+i)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
1115 |
1116 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], raw_in_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, step[1], 2, 0, "Conv_inside", fast_pack, debug)
1117 | fmt.Println("Block2, Layer ", i, "done!")
1118 | }
1119 | fmt.Println("Block2 done.")
1120 | timings[2] = time.Since(start).Seconds()
1121 | start = time.Now()
1122 |
1123 | pow = mid_pow
1124 | ker_in23 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-conv.csv", real_batch[1]*real_batch[2]*ker_size)
1125 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-a.csv", real_batch[2])
1126 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+1)+"-b.csv", real_batch[2])
1127 | ker_in23_new := make([]float64, 2*real_batch[1]*real_batch[2]*ker_size)
1128 | for k := 0; k < ker_size; k++ {
1129 | for i := 0; i < real_batch[1]; i++ {
1130 | for j := 0; j < real_batch[2]; j++ {
1131 | ker_in23_new[k*2*real_batch[1]*real_batch[2]+2*i*real_batch[2]+j] = ker_in23[k*real_batch[1]*real_batch[2]+i*real_batch[2]+j]
1132 | }
1133 | }
1134 | }
1135 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in23_new, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "StrConv_inside", fast_pack, debug)
1136 | fmt.Println("Block2 to 3 done!")
1137 | if debug {
1138 | max_bat := cont.N / (in_wids[1] * in_wids[1])
1139 | res_ttmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_layer))
1140 | prt_mat_norm_step(res_ttmp, max_bat, norm[2], step[2], prt_start[2], 3, false)
1141 | }
1142 | timings[3] = time.Since(start).Seconds()
1143 | start = time.Now()
1144 |
1145 | // ResNet Block 3
1146 | for i := 1; i <= num_blcs[2]; i++ {
1147 | if i == 3 {
1148 | pow = init_pow
1149 | }
1150 | if i == 5 {
1151 | pow = mid_pow
1152 | }
1153 | bn_a3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-a.csv", real_batch[2])
1154 | bn_b3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-b.csv", real_batch[2])
1155 | ker_in3 := readTxt(weight_dir+"w"+strconv.Itoa(num_blcs[0]+num_blcs[1]+i+1)+"-conv.csv", real_batch[2]*real_batch[2]*ker_size)
1156 |
1157 | if i == num_blcs[2] {
1158 | pow = final_pow
1159 | }
1160 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in3, bn_a3, bn_b3, alpha, pow, in_wids[1], raw_in_wids[2], ker_wid, real_batch[2], real_batch[2], norm[2], 0, step[2], 2, 0, "Conv_inside", fast_pack, debug)
1161 | fmt.Println("Block3, Layer ", i, "done!")
1162 | }
1163 | fmt.Println("Block3 done.")
1164 | timings[4] = time.Since(start).Seconds()
1165 | start = time.Now()
1166 |
1167 | ker_inf_wid := raw_in_wids[1]
1168 | if ker_inf_wid%2 == 0 {
1169 | ker_inf_wid++
1170 | }
1171 | ker_inf := readTxt(weight_dir+"final-fckernel.csv", real_batch[2]*fc_out)
1172 | ker_inf_ := make([]float64, ker_inf_wid*ker_inf_wid*real_batch[2]*fc_out)
1173 | for i := range ker_inf {
1174 | for b := 0; b < ker_inf_wid*ker_inf_wid; b++ {
1175 | ker_inf_[i+b*real_batch[2]*fc_out] = ker_inf[i]
1176 | }
1177 | }
1178 | bn_af := make([]float64, fc_out)
1179 | for i := range bn_af {
1180 | bn_af[i] = 1.0 / float64(raw_in_wids[2]*raw_in_wids[2]) // for reduce mean on raw_in_wids[2]**2 elements
1181 | }
1182 | bn_bf := readTxt(weight_dir+"final-fcbias.csv", fc_out)
1183 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], ker_inf_wid, real_batch[2], fc_out, norm[2], float64(1<<30), false)
1184 | fmt.Println("Final FC done.")
1185 | timings[5] = time.Since(start).Seconds()
1186 | start = time.Now()
1187 |
1188 | fmt.Println()
1189 | fmt.Println("=============== DECRYPTION ===============")
1190 | fmt.Println()
1191 | cont.decryptor.Decrypt(ct_result, pl_input)
1192 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
1193 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
1194 | res_out := prt_mat_one_norm(res_tmp, max_batch[1], norm[2], ker_inf_wid/2+1, ker_inf_wid/2+1)
1195 | // fmt.Print(res_out)
1196 | fmt.Println("\n result: ", res_out[:fc_out])
1197 | writeTxt(out_dir+"class_result_"+ker_name+"_"+strconv.Itoa(iter)+".csv", res_out[:fc_out])
1198 |
1199 | fmt.Println("Blc1: ", timings[0], " sec")
1200 | fmt.Println("Blc1->2: ", timings[1], " sec")
1201 | fmt.Println("Blc2: ", timings[2], " sec")
1202 | fmt.Println("Blc2->3: ", timings[3], " sec")
1203 | fmt.Println("Blc3: ", timings[4], " sec")
1204 | fmt.Println("Final (reduce_mean & FC): ", timings[5], " sec")
1205 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
1206 | }
1207 | }
1208 |
1209 | func testImagenet_final_fast_in(st, end, ker_wid int) {
1210 | // We use full packing: i.e., in_wid**2 element is contained in po2_in_wid**2 sized block <-> half padding of Resnet
1211 | // So ReLU, keep or rot, StoC done on both the 1st & 2nd part of the CtoS ciphertexts
1212 | ker_name := "ker" + strconv.Itoa(ker_wid) // "ker5"
1213 | weight_dir := "weight_imgnet_" + ker_name + "_h5/"
1214 | logN := 16
1215 | raw_in_wids := []int{14, 7} // same as python
1216 | real_batch := []int{256, 512} // same as python
1217 | iter := 2
1218 | in_wids := make([]int, len(raw_in_wids))
1219 | kp_wids := make([]int, len(raw_in_wids))
1220 | var num_blc1, num_blc2 int
1221 | if ker_name == "ker3" {
1222 | in_wids[0] = 16
1223 | in_wids[1] = 8
1224 | kp_wids[0] = 14
1225 | kp_wids[1] = 7
1226 | num_blc1 = 3
1227 | num_blc2 = 3
1228 | } else if ker_name == "ker5" {
1229 | in_wids[0] = 16
1230 | in_wids[1] = 8
1231 | kp_wids[0] = 14
1232 | kp_wids[1] = 6
1233 | num_blc1 = 3
1234 | num_blc2 = 3
1235 | } else {
1236 | panic("strange ker name!")
1237 | }
1238 | cont := newContext(logN, ker_wid, in_wids, kp_wids, true, "Imagenet_final_fast")
1239 |
1240 | ker_size := ker_wid * ker_wid
1241 | max_batch := make([]int, len(real_batch)) // the max batch
1242 | for i := range max_batch {
1243 | max_batch[i] = cont.N / (in_wids[i] * in_wids[i])
1244 | }
1245 | alpha := 0.0 // 0.3 => leakyrelu
1246 | init_pow := 6.0
1247 | mid_pow := 5.0
1248 | final_pow := 6.0
1249 |
1250 | // ker5_iter := []int{804, 886, 901, 956}
1251 | // {3, 29, 87, 254, 357,
1252 | // 399, 435, 455, 475, 476,
1253 | // 518, 540, 545, 571, 631,
1254 | // 657, 699, 711, 748, 790,
1255 | // 804, 886, 901, 956}
1256 |
1257 | // for _, name_iter := range ker5_iter {
1258 | for name_iter := st; name_iter < end; name_iter++ {
1259 | weight_num := 10
1260 | norm := 1
1261 | fmt.Println("Start ", name_iter, "-th iter..")
1262 |
1263 | raw_input := readTxt(ker_name+"_data/test_image_"+strconv.Itoa(name_iter)+".csv", raw_in_wids[0]*raw_in_wids[0]*real_batch[0])
1264 | input := make([]float64, in_wids[0]*in_wids[0]*real_batch[0])
1265 | for i := 0; i < raw_in_wids[0]; i++ {
1266 | for j := 0; j < raw_in_wids[0]; j++ {
1267 | for b := 0; b < real_batch[0]; b++ {
1268 | input[i*in_wids[0]*real_batch[0]+j*real_batch[0]+b] = raw_input[i*raw_in_wids[0]*real_batch[0]+j*real_batch[0]+b]
1269 | }
1270 | }
1271 | }
1272 | fmt.Println("Input: ")
1273 | prt_mat(input, max_batch[0], 1)
1274 | fmt.Println("vec size: ", cont.N)
1275 | fmt.Println("input width: ", raw_in_wids)
1276 | fmt.Println("kernel width: ", ker_wid)
1277 | fmt.Println("num batches: ", real_batch)
1278 |
1279 | start := time.Now()
1280 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
1281 | cont.encoder.EncodeCoeffs(input, pl_input)
1282 | ct_input := cont.encryptor.EncryptNew(pl_input)
1283 | fmt.Printf("Encryption done in %s \n", time.Since(start))
1284 |
1285 | timings := make([]float64, 4)
1286 | begin_start := time.Now()
1287 | new_start := time.Now()
1288 |
1289 | // Block 1
1290 | pow := init_pow
1291 | ct_layer := ct_input
1292 | for i := 1; i <= num_blc1; i++ {
1293 | ker_in1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size)
1294 | weight_num++
1295 | bn_a1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[0])
1296 | bn_b1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[0])
1297 | if i == num_blc1 {
1298 | pow = mid_pow
1299 | }
1300 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in1, bn_a1, bn_b1, alpha, pow, in_wids[0], kp_wids[0], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "Conv", false, false)
1301 | fmt.Println("Block1, Layer ", i, "done!")
1302 | }
1303 | fmt.Println("Block1 done!")
1304 | timings[0] = time.Since(new_start).Seconds()
1305 | new_start = time.Now()
1306 |
1307 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
1308 | weight_num++
1309 | bn_a12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1])
1310 | bn_b12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1])
1311 |
1312 | ker_in12_0 := make([]float64, len(ker_in12)/2)
1313 | ker_in12_1 := make([]float64, len(ker_in12)/2)
1314 | for k := 0; k < ker_size; k++ {
1315 | for i := 0; i < real_batch[0]; i++ {
1316 | for j := 0; j < real_batch[1]/2; j++ {
1317 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+j)] // [i][j]
1318 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+real_batch[1]/2+j)] // [i][j+B/2]
1319 | }
1320 | }
1321 | }
1322 | bn_a12_0 := make([]float64, real_batch[1]/2)
1323 | bn_a12_1 := make([]float64, real_batch[1]/2)
1324 | bn_b12_0 := make([]float64, real_batch[1]/2)
1325 | bn_b12_1 := make([]float64, real_batch[1]/2)
1326 | for i := range bn_b12_0 {
1327 | bn_a12_0[i] = bn_a12[i]
1328 | bn_a12_1[i] = bn_a12[i+real_batch[1]/2]
1329 | bn_b12_0[i] = bn_b12[i]
1330 | bn_b12_1[i] = bn_b12[i+real_batch[1]/2]
1331 | }
1332 |
1333 | // block1 to block 2
1334 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "StrConv", false, false)
1335 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 1, 0, iter, 0, "StrConv", false, false)
1336 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2)
1337 | fmt.Println("Block1 to 2 done!")
1338 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_result))
1339 | // prt_mat_norm(res_tmp, max_batch[1], 1, 4, false)
1340 | timings[1] = time.Since(new_start).Seconds()
1341 | new_start = time.Now()
1342 |
1343 | // Block 2
1344 | for i := 1; i <= num_blc2; i++ {
1345 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
1346 | weight_num++
1347 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1])
1348 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1])
1349 | if i == num_blc2 {
1350 | pow = final_pow
1351 | }
1352 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], kp_wids[1], ker_wid, real_batch[1], real_batch[1], norm, 0, 0, iter, 0, "Conv", false, false)
1353 | fmt.Println("Block2, Layer ", i, "done!")
1354 | }
1355 | fmt.Println("Block2 done!")
1356 | timings[2] = time.Since(new_start).Seconds()
1357 | new_start = time.Now()
1358 |
1359 | // RMFC
1360 | fin_out_batch := 1000
1361 | ker_inf := readTxt(weight_dir+"fc.csv", real_batch[1]*fin_out_batch)
1362 | bn_af := make([]float64, real_batch[1]*2)
1363 | if ker_wid == 3 {
1364 | for i := range bn_af {
1365 | bn_af[i] = 1.0 / (7 * 7) // for reduce mean on 8*8 elements
1366 | }
1367 | } else {
1368 | for i := range bn_af {
1369 | bn_af[i] = 1.0 / (6 * 6) // for reduce mean on 8*8 elements
1370 | }
1371 | }
1372 | bn_bf := make([]float64, real_batch[1]*2)
1373 | for i := range bn_bf {
1374 | bn_bf[i] = 0.0 //10.0 * float64(i)
1375 | }
1376 | ker_inf_ := make([]float64, 7*7*real_batch[1]*fin_out_batch)
1377 | for b := 0; b < 7*7; b++ {
1378 | for i := 0; i < real_batch[1]; i++ {
1379 | for j := 0; j < fin_out_batch; j++ {
1380 | ker_inf_[b*real_batch[1]*fin_out_batch+i*fin_out_batch+j] = ker_inf[i*fin_out_batch+j]
1381 | }
1382 | }
1383 | }
1384 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], 7, real_batch[1], 1000, 1, float64(1<<30), false)
1385 | timings[3] = time.Since(new_start).Seconds()
1386 | new_start = time.Now()
1387 |
1388 | cont.decryptor.Decrypt(ct_result, pl_input)
1389 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
1390 | fmt.Printf("Decryption done in %s \n", time.Since(new_start))
1391 | final_result := prt_mat_one_norm(res_tmp, max_batch[1], 1, 4, 4)
1392 | writeTxt(ker_name+"_enc_result/enc_result_"+strconv.Itoa(name_iter)+".csv", final_result[:1000])
1393 |
1394 | fmt.Println("Blc1: ", timings[0], " sec")
1395 | fmt.Println("Blc1->2: ", timings[1], " sec")
1396 | fmt.Println("Blc2: ", timings[2], " sec")
1397 | fmt.Println("Final (reduce_mean & FC): ", timings[3], " sec")
1398 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
1399 | }
1400 | }
1401 |
1402 | func testImagenet_sparse(st, end, ker_wid int) {
1403 | // We use full packing: i.e., in_wid**2 element is contained in po2_in_wid**2 sized block <-> half padding of Resnet
1404 | // So ReLU, keep or rot, StoC done on both the 1st & 2nd part of the CtoS ciphertexts
1405 | debug := false
1406 | ker_name := "ker" + strconv.Itoa(ker_wid) // "ker5"
1407 | weight_dir := "weight_imgnet_" + ker_name + "_h5/"
1408 | logN := 16
1409 | raw_in_wids := []int{14, 7} // same as python
1410 | real_batch := []int{256, 512} // same as python
1411 | log_sparse := []int{0, 1}
1412 | norm := []int{1, 2}
1413 | iter := 2
1414 | in_wids := make([]int, len(raw_in_wids))
1415 | kp_wids := make([]int, len(raw_in_wids))
1416 | var num_blc1, num_blc2 int
1417 | if ker_name == "ker3" {
1418 | in_wids[0] = 16
1419 | in_wids[1] = 8
1420 | kp_wids[0] = 14
1421 | kp_wids[1] = 7
1422 | num_blc1 = 3
1423 | num_blc2 = 3
1424 | } else if ker_name == "ker5" {
1425 | in_wids[0] = 16
1426 | in_wids[1] = 8
1427 | kp_wids[0] = 14
1428 | kp_wids[1] = 6
1429 | num_blc1 = 3
1430 | num_blc2 = 3
1431 | } else {
1432 | panic("strange ker name!")
1433 | }
1434 | cont := newContext(logN, ker_wid, in_wids, kp_wids, true, "Imagenet_sparse")
1435 |
1436 | ker_size := ker_wid * ker_wid
1437 | max_batch := make([]int, len(real_batch)) // the max batch
1438 | for i := range max_batch {
1439 | max_batch[i] = cont.N / (in_wids[i] * in_wids[i])
1440 | }
1441 | alpha := 0.0 // 0.3 => leakyrelu
1442 | init_pow := 6.0
1443 | mid_pow := 5.0
1444 | final_pow := 6.0
1445 |
1446 | for name_iter := st; name_iter < end; name_iter++ {
1447 | weight_num := 10
1448 | fmt.Println("Start ", name_iter, "-th iter..")
1449 |
1450 | raw_input := readTxt(ker_name+"_data/test_image_"+strconv.Itoa(name_iter)+".csv", raw_in_wids[0]*raw_in_wids[0]*real_batch[0])
1451 | input := make([]float64, in_wids[0]*in_wids[0]*real_batch[0])
1452 | for i := 0; i < raw_in_wids[0]; i++ {
1453 | for j := 0; j < raw_in_wids[0]; j++ {
1454 | for b := 0; b < real_batch[0]; b++ {
1455 | input[i*in_wids[0]*real_batch[0]+j*real_batch[0]+b] = raw_input[i*raw_in_wids[0]*real_batch[0]+j*real_batch[0]+b]
1456 | }
1457 | }
1458 | }
1459 | fmt.Println("Input: ")
1460 | prt_mat(input, max_batch[0], 1)
1461 | fmt.Println("vec size: ", cont.N)
1462 | fmt.Println("input width: ", raw_in_wids)
1463 | fmt.Println("kernel width: ", ker_wid)
1464 | fmt.Println("num batches: ", real_batch)
1465 |
1466 | start := time.Now()
1467 | pl_input := ckks.NewPlaintext(cont.params, cont.ECD_LV, cont.params.Scale()) // contain plaintext values
1468 | cont.encoder.EncodeCoeffs(input, pl_input)
1469 | ct_input := cont.encryptor.EncryptNew(pl_input)
1470 | fmt.Printf("Encryption done in %s \n", time.Since(start))
1471 |
1472 | timings := make([]float64, 4)
1473 | begin_start := time.Now()
1474 | new_start := time.Now()
1475 |
1476 | // Block 1
1477 | pow := init_pow
1478 | ct_layer := ct_input
1479 | for i := 1; i <= num_blc1; i++ {
1480 | ker_in1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[0]*ker_size)
1481 | weight_num++
1482 | bn_a1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[0])
1483 | bn_b1 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[0])
1484 | if i == num_blc1 {
1485 | pow = mid_pow
1486 | }
1487 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in1, bn_a1, bn_b1, alpha, pow, in_wids[0], kp_wids[0], ker_wid, real_batch[0], real_batch[0], norm[0], 0, 1, iter, log_sparse[0], "Conv_sparse", true, debug)
1488 | fmt.Println("Block1, Layer ", i, "done!")
1489 | }
1490 | fmt.Println("Block1 done!")
1491 | timings[0] = time.Since(new_start).Seconds()
1492 | new_start = time.Now()
1493 |
1494 | ker_in12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[0]*real_batch[1]*ker_size)
1495 | weight_num++
1496 | bn_a12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1])
1497 | bn_b12 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1])
1498 |
1499 | ker_in12_0 := make([]float64, len(ker_in12)/2)
1500 | ker_in12_1 := make([]float64, len(ker_in12)/2)
1501 | for k := 0; k < ker_size; k++ {
1502 | for i := 0; i < real_batch[0]; i++ {
1503 | for j := 0; j < real_batch[1]/2; j++ {
1504 | // ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+j)] // [i][j]
1505 | // ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+real_batch[1]/2+j)] // [i][j+B/2]
1506 | ker_in12_0[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j)] // [i][2*j]
1507 | ker_in12_1[k*real_batch[0]*real_batch[1]/2+(i*real_batch[1]/2+j)] = ker_in12[k*real_batch[0]*real_batch[1]+(i*real_batch[1]+2*j+1)] // [i][2*j+1]
1508 |
1509 | }
1510 | }
1511 | }
1512 | bn_a12_0 := make([]float64, real_batch[1]/2)
1513 | bn_a12_1 := make([]float64, real_batch[1]/2)
1514 | bn_b12_0 := make([]float64, real_batch[1]/2)
1515 | bn_b12_1 := make([]float64, real_batch[1]/2)
1516 | for i := range bn_b12_0 {
1517 | // bn_a12_0[i] = bn_a12[i]
1518 | // bn_a12_1[i] = bn_a12[i+real_batch[1]/2]
1519 | // bn_b12_0[i] = bn_b12[i]
1520 | // bn_b12_1[i] = bn_b12[i+real_batch[1]/2]
1521 | bn_a12_0[i] = bn_a12[2*i]
1522 | bn_a12_1[i] = bn_a12[2*i+1]
1523 | bn_b12_0[i] = bn_b12[2*i]
1524 | bn_b12_1[i] = bn_b12[2*i+1]
1525 | }
1526 |
1527 | // block1 to block 2
1528 | // ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 0, 0, iter, 0, "StrConv", false, false)
1529 | // ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], 2*kp_wids[1], ker_wid, real_batch[0], real_batch[0], norm, 1, 0, iter, 0, "StrConv", false, false)
1530 | // ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2)
1531 |
1532 | ct_result1 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_0, bn_a12_0, bn_b12_0, alpha, pow, in_wids[0], kp_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, 1, 2, 0, "StrConv_sparse_full", true, debug)
1533 | ct_result2 := evalConv_BNRelu_new(cont, ct_layer, ker_in12_1, bn_a12_1, bn_b12_1, alpha, pow, in_wids[0], kp_wids[1], ker_wid, real_batch[0], real_batch[1]/2, norm[0], 0, 1, 2, 0, "StrConv_sparse_full", true, debug)
1534 |
1535 | xi := make([]float64, cont.N)
1536 | xi[2] = 1.0
1537 | xi_plain := ckks.NewPlaintext(cont.params, ct_result2.Level(), 1.0)
1538 | cont.encoder.EncodeCoeffs(xi, xi_plain)
1539 | cont.encoder.ToNTT(xi_plain)
1540 | ct_result2 = cont.evaluator.MulNew(ct_result2, xi_plain)
1541 | ct_layer = cont.evaluator.AddNew(ct_result1, ct_result2)
1542 |
1543 | fmt.Println("Block1 to 2 done!")
1544 | // res_tmp := cont.encoder.DecodeCoeffs(cont.decryptor.DecryptNew(ct_result))
1545 | // prt_mat_norm(res_tmp, max_batch[1], 1, 4, false)
1546 | timings[1] = time.Since(new_start).Seconds()
1547 | new_start = time.Now()
1548 |
1549 | // Block 2
1550 | for i := 1; i <= num_blc2; i++ {
1551 | ker_in2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-conv.csv", real_batch[1]*real_batch[1]*ker_size)
1552 | weight_num++
1553 | bn_a2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-a.csv", real_batch[1])
1554 | bn_b2 := readTxt(weight_dir+"w"+strconv.Itoa(weight_num)+"-b.csv", real_batch[1])
1555 | if i == num_blc2 {
1556 | pow = final_pow
1557 | }
1558 | ct_layer = evalConv_BNRelu_new(cont, ct_layer, ker_in2, bn_a2, bn_b2, alpha, pow, in_wids[1], kp_wids[1], ker_wid, real_batch[1], real_batch[1], norm[1], 0, 1, iter, log_sparse[1], "Conv_sparse", true, debug)
1559 | fmt.Println("Block2, Layer ", i, "done!")
1560 | }
1561 | fmt.Println("Block2 done!")
1562 | timings[2] = time.Since(new_start).Seconds()
1563 | new_start = time.Now()
1564 |
1565 | // RMFC
1566 | fin_out_batch := 1000
1567 | ker_inf := readTxt(weight_dir+"fc.csv", real_batch[1]*fin_out_batch)
1568 | bn_af := make([]float64, real_batch[1]*2)
1569 | if ker_wid == 3 {
1570 | for i := range bn_af {
1571 | bn_af[i] = 1.0 / (7 * 7) // for reduce mean on 8*8 elements
1572 | }
1573 | } else {
1574 | for i := range bn_af {
1575 | bn_af[i] = 1.0 / (6 * 6) // for reduce mean on 8*8 elements
1576 | }
1577 | }
1578 | bn_bf := make([]float64, real_batch[1]*2)
1579 | for i := range bn_bf {
1580 | bn_bf[i] = 0.0 //10.0 * float64(i)
1581 | }
1582 | ker_inf_ := make([]float64, 7*7*real_batch[1]*fin_out_batch)
1583 | for b := 0; b < 7*7; b++ {
1584 | for i := 0; i < real_batch[1]; i++ {
1585 | for j := 0; j < fin_out_batch; j++ {
1586 | ker_inf_[b*real_batch[1]*fin_out_batch+i*fin_out_batch+j] = ker_inf[i*fin_out_batch+j]
1587 | }
1588 | }
1589 | }
1590 | ct_result := evalConv_BN(cont, ct_layer, ker_inf_, bn_af, bn_bf, in_wids[1], 7, real_batch[1], fin_out_batch, 1, float64(1<<30), false)
1591 | fmt.Println("Final FC done.")
1592 | timings[3] = time.Since(new_start).Seconds()
1593 | new_start = time.Now()
1594 |
1595 | cont.decryptor.Decrypt(ct_result, pl_input)
1596 | res_tmp := cont.encoder.DecodeCoeffs(pl_input)
1597 | fmt.Printf("Decryption done in %s \n", time.Since(new_start))
1598 | final_result := prt_mat_one_norm(res_tmp, max_batch[1], 1, 4, 4)
1599 | writeTxt(ker_name+"_enc_result/enc_result_"+strconv.Itoa(name_iter)+".csv", final_result[:1000])
1600 |
1601 | fmt.Println("Blc1: ", timings[0], " sec")
1602 | fmt.Println("Blc1->2: ", timings[1], " sec")
1603 | fmt.Println("Blc2: ", timings[2], " sec")
1604 | fmt.Println("Final (reduce_mean & FC): ", timings[3], " sec")
1605 | fmt.Printf("Total done in %s \n", time.Since(begin_start))
1606 | }
1607 | }
1608 |
--------------------------------------------------------------------------------
/test_BL.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "strconv"
7 | "time"
8 |
9 | "github.com/dwkim606/test_lattigo/ckks"
10 | )
11 |
12 | // BaseLine Conv without boot, Assume full batch with Po2 in_wid & N
13 | // Normal Conv without output modification (e.g., trimming or expanding)
14 | // Input does not need padding
15 | // use imaginary part to save before boot
16 | func testConv_BL_in(real_batch, in_wid, ker_wid, total_test_num int, boot bool) {
17 | in_kind := "Conv"
18 | test_dir := "test_conv_data/"
19 | if (in_kind != "TransConv") && (in_kind != "Conv") && (in_kind != "StrConv") {
20 | panic("Wrong in_kind!")
21 | }
22 | pad := ker_wid / 2
23 | raw_in_wid := in_wid - pad // = in_wid
24 |
25 | in_size := in_wid * in_wid
26 | ker_size := ker_wid * ker_wid
27 | slots := real_batch / 2 * in_size
28 | log_slots := 0
29 | for ; (1 << log_slots) < slots; log_slots++ {
30 | }
31 | out_batch := real_batch
32 | if in_kind == "TransConv" {
33 | out_batch = real_batch / 4
34 | }
35 | kp_wid := 0
36 | kind := "BL_" + in_kind
37 | in_batch := real_batch
38 |
39 | // generate Context: params, Keys, rotations, general plaintexts
40 | cont := newContext(log_slots+1, ker_wid, []int{in_wid}, []int{kp_wid}, boot, kind)
41 | fmt.Println("vec size: log2 = ", cont.logN)
42 | fmt.Println("raw input width: ", raw_in_wid)
43 | fmt.Println("kernel width: ", ker_wid)
44 | fmt.Println("num batches in & out: ", real_batch, ", ", out_batch)
45 |
46 | for test_iter := 0; test_iter < total_test_num; test_iter++ {
47 | fmt.Println(test_iter+1, "-th iter...start")
48 | input := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_in_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
49 | ker_in := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_ker_"+strconv.Itoa(test_iter)+".csv", in_batch*in_batch*ker_wid*ker_wid)
50 | bn_a := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bna_"+strconv.Itoa(test_iter)+".csv", in_batch)
51 | bn_b := readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_bnb_"+strconv.Itoa(test_iter)+".csv", in_batch)
52 | var real_out []float64
53 | if boot {
54 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_reluout_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
55 | } else {
56 | real_out = readTxt(test_dir+"test_conv"+strconv.Itoa(ker_wid)+"_batch_"+strconv.Itoa(in_batch)+"_out_"+strconv.Itoa(test_iter)+".csv", raw_in_wid*raw_in_wid*in_batch)
57 | }
58 |
59 | pad_input1 := make([]float64, in_wid*in_wid*real_batch/2)
60 | pad_input2 := make([]float64, in_wid*in_wid*real_batch/2)
61 | for i := 0; i < raw_in_wid; i++ {
62 | for j := 0; j < raw_in_wid; j++ {
63 | for b := 0; b < real_batch/2; b++ {
64 | pad_input1[b+j*real_batch/2+i*real_batch/2*in_wid] = input[b+j*real_batch+i*real_batch*raw_in_wid]
65 | pad_input2[b+j*real_batch/2+i*real_batch/2*in_wid] = input[b+real_batch/2+j*real_batch+i*real_batch*raw_in_wid]
66 | }
67 | }
68 | }
69 |
70 | bn_a_sep := make([][]float64, 2)
71 | bn_b_sep := make([][]float64, 2)
72 | zeros := make([]float64, real_batch/2)
73 | for out := 0; out < 2; out++ {
74 | bn_a_sep[out] = make([]float64, real_batch/2)
75 | bn_b_sep[out] = make([]float64, real_batch/2)
76 | for i := 0; i < real_batch/2; i++ {
77 | bn_a_sep[out][i] = bn_a[i+out*real_batch/2]
78 | bn_b_sep[out][i] = bn_b[i+out*real_batch/2]
79 | }
80 | }
81 |
82 | ker_in_sep := make([][][]float64, 2) // number of output ctxts
83 | for out := 0; out < 2; out++ {
84 | ker_in_sep[out] = make([][]float64, 2) // number of input ctxts
85 | for in := 0; in < 2; in++ {
86 | ker_in_sep[out][in] = make([]float64, len(ker_in)/(2*2))
87 | for k := 0; k < ker_size; k++ {
88 | for i := 0; i < real_batch/2; i++ { // in
89 | for j := 0; j < real_batch/2; j++ { // out
90 | ker_in_sep[out][in][k*real_batch*real_batch/4+i*real_batch/2+j] =
91 | ker_in[k*real_batch*real_batch+(i+in*real_batch/2)*real_batch+out*real_batch/2+j] // [i][4*j]
92 | }
93 | }
94 | }
95 | }
96 | }
97 |
98 | input1_rs := reshape_input_BL(pad_input1, in_wid)
99 | input2_rs := reshape_input_BL(pad_input2, in_wid)
100 | start := time.Now()
101 | ct_input1 := cont.encryptor.EncryptNew(cont.encoder.EncodeAtLvlNew(cont.ECD_LV, input1_rs, cont.logN-1))
102 | ct_input2 := cont.encryptor.EncryptNew(cont.encoder.EncodeAtLvlNew(cont.ECD_LV, input2_rs, cont.logN-1))
103 | fmt.Printf("Encryption done in %s \n", time.Since(start))
104 |
105 | start_eval := time.Now()
106 | ct_res := make([]*ckks.Ciphertext, 2)
107 | for pos := 0; pos < 2; pos++ {
108 | ct_res[pos] = cont.evaluator.AddNew(evalConv_BN_BL_test(cont, ct_input1, ker_in_sep[pos][0], bn_a_sep[pos], bn_b_sep[pos], in_wid, ker_wid, real_batch/2, real_batch/2, 0, 1, pad, false, false),
109 | evalConv_BN_BL_test(cont, ct_input2, ker_in_sep[pos][1], bn_a_sep[pos], zeros, in_wid, ker_wid, real_batch/2, real_batch/2, 0, 1, pad, false, false))
110 | }
111 | fmt.Printf("Evaluation total done in %s \n", time.Since(start_eval))
112 |
113 | if boot {
114 | img_eval := time.Now()
115 | for pos := 0; pos < 2; pos++ {
116 | ct_res[pos] = cont.evaluator.AddNew(cont.pack_evaluator.ConjugateNew(ct_res[pos]), ct_res[pos])
117 | if pos == 1 {
118 | ct_res[pos] = cont.evaluator.MultByiNew(ct_res[pos])
119 | }
120 | // fmt.Println("before Boot: LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
121 | }
122 | ct_res[0] = cont.evaluator.AddNew(ct_res[0], ct_res[1])
123 | img_eval_part := time.Since(img_eval)
124 |
125 | pos := 0
126 | alpha := 0.0
127 | pow := 4.0
128 | ct_res[pos].Scale = ct_res[pos].Scale * math.Pow(2, pow+2) // +1 for 1/2 * (a + conj(a)), +1 for 1/2 after boot
129 | // vals_preB := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[pos]), cont.logN-1)
130 | fmt.Println("\n ========= Bootstrapping... (original) ========= ")
131 | start_boot := time.Now()
132 | // fmt.Println("initial (before boot): LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
133 | ct_boot := cont.btp.Bootstrapp(ct_res[pos])
134 | fmt.Printf("Boot Done in %s \n", time.Since(start_boot))
135 |
136 | // fmt.Println("after Boot: LV = ", ct_boot.Level(), " Scale = ", math.Log2(ct_boot.Scale))
137 |
138 | // Only for checking the correctness (for Boot)
139 | // vals_postB := printDebug(cont.params, ct_boot, vals_preB, cont.decryptor, cont.encoder)
140 | // vals_relu := make([]complex128, len(vals_postB))
141 | // for i, elt := range vals_postB {
142 | // vals_relu[i] = complex((math.Max(0, real(elt))+math.Min(0, real(elt)*alpha))*math.Pow(2, pow), 0)
143 | // }
144 | img_eval = time.Now()
145 | pl_scale := ckks.NewPlaintext(cont.params, ct_boot.Level(), math.Pow(2, 30)*float64(cont.params.Q()[14])*float64(cont.params.Q()[13])/ct_boot.Scale)
146 | val_scale := make([]complex128, cont.N/2)
147 | for i := range val_scale {
148 | val_scale[i] = complex(1.0, 0) // val_scale[i] = complex(1.0/math.Pow(2, pow), 0)
149 | }
150 | cont.encoder.EncodeNTT(pl_scale, val_scale, cont.logN-1)
151 | cont.evaluator.Mul(ct_boot, pl_scale, ct_boot)
152 | cont.evaluator.Rescale(ct_boot, cont.params.Scale(), ct_boot)
153 |
154 | ct_iboot := cont.pack_evaluator.ConjugateNew(ct_boot)
155 | ct_res[0] = cont.evaluator.AddNew(ct_boot, ct_iboot)
156 | ct_res[1] = cont.evaluator.DivByiNew(cont.evaluator.SubNew(ct_boot, ct_iboot))
157 | fmt.Printf("Imaginary packing and unpacking done in %s \n", img_eval_part+time.Since(img_eval))
158 |
159 | start = time.Now()
160 | for pos := 0; pos < 2; pos++ {
161 | // fmt.Println("before relu: LV = ", ct_res[pos].Level(), " Scale = ", math.Log2(ct_res[pos].Scale))
162 | // fmt.Println("after Rescale: LV = ", ct_boot.Level(), " Scale = 2^", math.Log2(ct_boot.Scale))
163 | ct_res[pos] = evalReLU(cont.params, cont.evaluator, ct_res[pos], alpha)
164 | cont.evaluator.MulByPow2(ct_res[pos], int(pow), ct_res[pos])
165 | cont.evaluator.SetScale(ct_res[pos], cont.params.Scale())
166 | // printDebug(cont.params, ct_res[pos], vals_relu, cont.decryptor, cont.encoder)
167 | }
168 | fmt.Printf("Relu Done in %s \n", time.Since(start))
169 | }
170 |
171 | // fmt.Println()
172 | // fmt.Println("=============== DECRYPTION ===============")
173 | // fmt.Println()
174 | start = time.Now()
175 | vals_tmp1 := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[0]), log_slots)
176 | vals_tmp2 := cont.encoder.Decode(cont.decryptor.DecryptNew(ct_res[1]), log_slots)
177 | fmt.Printf("Decryption Done in %s \n", time.Since(start))
178 |
179 | test_out := post_trim_BL(vals_tmp1, raw_in_wid, in_wid)
180 | test_out = append(test_out, post_trim_BL(vals_tmp2, raw_in_wid, in_wid)...)
181 | test_out = post_process_BL(test_out, raw_in_wid)
182 |
183 | printDebugCfsPlain(test_out, real_out)
184 | }
185 | }
186 |
--------------------------------------------------------------------------------
/test_run:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dwkim606/optimal_conv/c230bae4fee377ffd9e92b4887752e6cffd0205a/test_run
--------------------------------------------------------------------------------