├── .gitmodules ├── CMakeLists.txt ├── LICENSE.md ├── README.md ├── data.tar.gz ├── script ├── build.sh ├── demo_lenet.sh └── demo_vgg.sh └── src ├── CMakeLists.txt ├── circuit.cpp ├── circuit.h ├── global_var.hpp ├── main_demo_lenet.cpp ├── main_demo_vgg.cpp ├── models.cpp ├── models.hpp ├── neuralNetwork.cpp ├── neuralNetwork.hpp ├── polynomial.cpp ├── polynomial.h ├── prover.cpp ├── prover.hpp ├── utils.cpp ├── utils.hpp ├── verifier.cpp └── verifier.hpp /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rd/hyrax-bls12-381"] 2 | path = 3rd/hyrax-bls12-381 3 | url = git@github.com:TAMUCrypto/hyrax-bls12-381.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(zkCNN) 3 | set(CMAKE_CXX_STANDARD 14) 4 | 5 | link_directories(3rd/hyrax-bls12-381) 6 | 7 | include_directories(src) 8 | include_directories(3rd) 9 | include_directories(3rd/hyrax-bls12-381/3rd/mcl/include) 10 | 11 | add_subdirectory(src) 12 | add_subdirectory(3rd/hyrax-bls12-381) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 TAMUCrypto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zkCNN 2 | 3 | ## Introduction 4 | 5 | This is the implementation of [this paper](https://eprint.iacr.org/2021/673), which is a GKR-based snark for CNN reference, containing some common CNN models such as LeNet5, vgg11 and vgg16. **Currently this version doesn't add complete zero-knowledge property.** 6 | 7 | 8 | 9 | ## Requirement 10 | 11 | - C++14 12 | - cmake >= 3.10 13 | - GMP library 14 | 15 | 16 | 17 | ## Input Format 18 | 19 | To run the program, the command is 20 | ```bash 21 | # run_file is the name of executable file. 22 | # in_file is the file containing the data and weight matrix, please refer to the details below. 23 | # config_file is the file containing config (scale and zero-point) of the model. In this implementation, 24 | # we don't use this file because we compute the scale and zero-point directly from the data in each 25 | # layer, so it's okay to put an empty file here. 26 | # ou_file is the prediction result of this picture. 27 | # exp_result is the experiment results filled in the table. For the definition of the table header, 28 | # please refer to the file `src/global_var.hpp`. 29 | # pic_cnt is the number of pictures to be predicted. For details please refer to the following section. 
30 | 31 | ${run_file} ${in_file} ${config_file} ${ou_file} ${pic_cnt} > ${exp_result} 32 | ``` 33 | 34 | ### The format of `in_file` 35 | 36 | In the current code, we only allow one picture and one matrix to be put in this file. If `pic_cnt` > 1, then the code will internally duplicate the data for corresponding times. Thus to test different pictures, you might need to adjust the code of loading the input. 37 | 38 | #### Data Part 39 | 40 | This part is for a picture data, a vector reshaped from its original matrix by 41 | 42 | ![formula1](https://render.githubusercontent.com/render/math?math=ch_{in}%20%5Ccdot%20h\times%20w) 43 | 44 | where ![formula2](https://render.githubusercontent.com/render/math?math=ch_{in}) is the number of channel, ![formula3](https://render.githubusercontent.com/render/math?math=h) is the height, ![formula4](https://render.githubusercontent.com/render/math?math=w) is the width. 45 | 46 | #### Weight Part 47 | 48 | This part is the set of parameters in the neural network, which contains 49 | 50 | - convolution kernel of size ![formula10](https://render.githubusercontent.com/render/math?math=ch_{out}%20\times%20ch_{in}%20\times%20m%20\times%20m) 51 | 52 | where ![formula11](https://render.githubusercontent.com/render/math?math=ch_{out}) and ![formula12](https://render.githubusercontent.com/render/math?math=ch_{in}) are the number of output and input channels, ![formula13](https://render.githubusercontent.com/render/math?math=m) is the sideness of the kernel (here we only support square kernel). 53 | 54 | - convolution bias of size ![formula16](https://render.githubusercontent.com/render/math?math=ch_{out}). 55 | 56 | - fully-connected kernel of size ![formula14](https://render.githubusercontent.com/render/math?math=ch_{out}\times%20ch_{in}). 57 | 58 | - fully-connected bias of size ![formula15](https://render.githubusercontent.com/render/math?math=ch_{out}). 
59 | 60 | ### The format of `config_file` 61 | Typically this is a file to record scale and zero-point for the data in each layer. However, in our current implementation, those are computed directly from the those data. Therefore, you can just leave it blank. 62 | 63 | ## Experiment Script 64 | ### Clone the repo 65 | To run the code, make sure you clone with 66 | ``` bash 67 | git clone --recurse-submodules git@github.com:TAMUCrypto/zkCNN.git 68 | ``` 69 | since the polynomial commitment is included as a submodule. 70 | 71 | ### Run a demo of LeNet5 72 | The script to run LeNet5 model (please run the script in ``script/`` directory). 73 | ``` bash 74 | ./demo_lenet.sh 75 | ``` 76 | 77 | - The input data is in ``data/lenet5.mnist.relu.max/``. 78 | - The experiment evaluation is ``output/single/demo-result-lenet5.txt``. 79 | - The inference result is ``output/single/lenet5.mnist.relu.max-1-infer.csv``. 80 | 81 | 82 | ### Run a demo of vgg11 83 | The script to run vgg11 model (please run the script in ``script/`` directory). 84 | ``` bash 85 | ./demo_vgg.sh 86 | ``` 87 | 88 | - The input data is in ``data/vgg11/``. 89 | - The experiment evaluation is ``output/single/demo-result.txt``. 90 | - The inference result is ``output/single/vgg11.cifar.relu-1-infer.csv``. 91 | 92 | ## Polynomial Commitment 93 | 94 | Here we implement a [hyrax polynomial commitment](https://eprint.iacr.org/2017/1132.pdf) based on BLS12-381 elliptic curve. It is a submodule and someone who is interested can refer to this repo [hyrax-bls12-381](https://github.com/TAMUCrypto/hyrax-bls12-381). 95 | 96 | ## Reference 97 | - [zkCNN: Zero knowledge proofs for convolutional neural network predictions and accuracy](https://doi.org/10.1145/3460120.3485379). 98 | Liu, T., Xie, X., & Zhang, Y. (CCS 2021). 99 | 100 | - [Doubly-efficient zksnarks without trusted setup](https://doi.org/10.1109/SP.2018.00060). Wahby, R. S., Tzialla, I., Shelat, A., Thaler, J., & Walfish, M. 
(S&P 2018) 101 | 102 | - [Hyrax](https://github.com/hyraxZK/hyraxZK.git) 103 | 104 | - [mcl](https://github.com/herumi/mcl) 105 | -------------------------------------------------------------------------------- /data.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TAMUCrypto/zkCNN/bdc3309255154447e5d72ac813e3f00f88cd55e3/data.tar.gz -------------------------------------------------------------------------------- /script/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd .. 3 | mkdir -p cmake-build-release 4 | cd cmake-build-release 5 | cmake -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - Unix Makefiles" .. 6 | make 7 | cd .. 8 | 9 | if [ ! -d "./data" ] 10 | then 11 | tar -xzvf data.tar.gz 12 | fi 13 | cd script 14 | -------------------------------------------------------------------------------- /script/demo_lenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | ./build.sh 6 | /usr/bin/cmake --build ../cmake-build-release --target demo_lenet_run -- -j 6 7 | 8 | run_file=../cmake-build-release/src/demo_lenet_run 9 | out_file=../output/single/demo-result-lenet5.txt 10 | 11 | mkdir -p ../output/single 12 | mkdir -p ../log/single 13 | 14 | lenet_i=../data/lenet5.mnist.relu.max/lenet5.mnist.relu.max-1-images-weights-qint8.csv 15 | lenet_c=../data/lenet5.mnist.relu.max/lenet5.mnist.relu.max-1-scale-zeropoint-uint8.csv 16 | lenet_o=../output/single/lenet5.mnist.relu.max-1-infer.csv 17 | 18 | ${run_file} ${lenet_i} ${lenet_c} ${lenet_o} 1 > ${out_file} -------------------------------------------------------------------------------- /script/demo_vgg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | ./build.sh 6 | /usr/bin/cmake --build ../cmake-build-release --target demo_vgg_run -- -j 6 7 | 8 
| run_file=../cmake-build-release/src/demo_vgg_run 9 | out_file=../output/single/demo-result-vgg11.txt 10 | 11 | mkdir -p ../output/single 12 | mkdir -p ../log/single 13 | 14 | vgg11_i=../data/vgg11/vgg11.cifar.relu-1-images-weights-qint8.csv 15 | vgg11_c=../data/vgg11/vgg11.cifar.relu-1-scale-zeropoint-uint8.csv 16 | vgg11_o=../output/single/vgg11.cifar.relu-1-infer.csv 17 | vgg11_n=../data/vgg11/vgg11-config.csv 18 | 19 | ${run_file} ${vgg11_i} ${vgg11_c} ${vgg11_o} ${vgg11_n} 1 > ${out_file} -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | aux_source_directory(. conv_src) 2 | list(FILTER conv_src EXCLUDE REGEX "main*") 3 | 4 | add_library(cnn_lib ${conv_src}) 5 | 6 | add_executable(demo_vgg_run main_demo_vgg.cpp) 7 | target_link_libraries(demo_vgg_run cnn_lib hyrax_lib mcl mclbn384_256) 8 | 9 | add_executable(demo_lenet_run main_demo_lenet.cpp) 10 | target_link_libraries(demo_lenet_run cnn_lib hyrax_lib mcl mclbn384_256) -------------------------------------------------------------------------------- /src/circuit.cpp: -------------------------------------------------------------------------------- 1 | #include "circuit.h" 2 | #include "utils.hpp" 3 | 4 | void layeredCircuit::initSubset() { 5 | cerr << "begin subset init." 
<< endl; 6 | vector visited_uidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer 7 | vector subset_uidx(circuit[0].size); // the subset index of the i-th layer, j-th gate 8 | vector visited_vidx(circuit[0].size); // whether the i-th layer, j-th gate has been visited in the current layer 9 | vector subset_vidx(circuit[0].size); // the subset index of the i-th layer, j-th gate 10 | 11 | for (u8 i = 1; i < size; ++i) { 12 | auto &cur = circuit[i], &lst = circuit[i - 1]; 13 | bool has_pre_layer_u = circuit[i].ty == layerType::FFT || circuit[i].ty == layerType::IFFT; 14 | bool has_pre_layer_v = false; 15 | 16 | for (auto &gate: cur.uni_gates) { 17 | if (!gate.lu) { 18 | if (visited_uidx[gate.u] != i) { 19 | visited_uidx[gate.u] = i; 20 | subset_uidx[gate.u] = cur.size_u[0]; 21 | cur.ori_id_u.push_back(gate.u); 22 | ++cur.size_u[0]; 23 | } 24 | gate.u = subset_uidx[gate.u]; 25 | } 26 | has_pre_layer_u |= (gate.lu != 0); 27 | } 28 | 29 | for (auto &gate: cur.bin_gates) { 30 | if (!gate.getLayerIdU(i)) { 31 | if (visited_uidx[gate.u] != i) { 32 | visited_uidx[gate.u] = i; 33 | subset_uidx[gate.u] = cur.size_u[0]; 34 | cur.ori_id_u.push_back(gate.u); 35 | ++cur.size_u[0]; 36 | } 37 | gate.u = subset_uidx[gate.u]; 38 | } 39 | if (!gate.getLayerIdV(i)) { 40 | if (visited_vidx[gate.v] != i) { 41 | visited_vidx[gate.v] = i; 42 | subset_vidx[gate.v] = cur.size_v[0]; 43 | cur.ori_id_v.push_back(gate.v); 44 | ++cur.size_v[0]; 45 | } 46 | gate.v = subset_vidx[gate.v]; 47 | } 48 | has_pre_layer_u |= (gate.getLayerIdU(i) != 0); 49 | has_pre_layer_v |= (gate.getLayerIdV(i) != 0); 50 | } 51 | 52 | cur.bit_length_u[0] = ceilPow2BitLength(cur.size_u[0]); 53 | cur.bit_length_v[0] = ceilPow2BitLength(cur.size_v[0]); 54 | 55 | if (has_pre_layer_u) switch (cur.ty) { 56 | case layerType::FFT: 57 | cur.size_u[1] = 1ULL << cur.fft_bit_length - 1; 58 | cur.bit_length_u[1] = cur.fft_bit_length - 1; 59 | break; 60 | case layerType::IFFT: 61 | 
cur.size_u[1] = 1ULL << cur.fft_bit_length; 62 | cur.bit_length_u[1] = cur.fft_bit_length; 63 | break; 64 | default: 65 | cur.size_u[1] = lst.size ; 66 | cur.bit_length_u[1] = lst.bit_length; 67 | break; 68 | } else { 69 | cur.size_u[1] = 0; 70 | cur.bit_length_u[1] = -1; 71 | } 72 | 73 | if (has_pre_layer_v) { 74 | if (cur.ty == layerType::DOT_PROD) { 75 | cur.size_v[1] = lst.size >> cur.fft_bit_length; 76 | cur.bit_length_v[1] = lst.bit_length - cur.fft_bit_length; 77 | } else { 78 | cur.size_v[1] = lst.size; 79 | cur.bit_length_v[1] = lst.bit_length; 80 | } 81 | } else { 82 | cur.size_v[1] = 0; 83 | cur.bit_length_v[1] = -1; 84 | } 85 | cur.updateSize(); 86 | } 87 | cerr << "begin subset finish." << endl; 88 | } 89 | 90 | void layeredCircuit::init(u8 q_bit_size, u8 _layer_sz) { 91 | two_mul.resize((q_bit_size + 1) << 1); 92 | two_mul[0] = F_ONE; 93 | two_mul[q_bit_size + 1] = -F_ONE; 94 | for (int i = 1; i <= q_bit_size; ++i) { 95 | two_mul[i] = two_mul[i - 1] + two_mul[i - 1]; 96 | two_mul[i + q_bit_size + 1] = -two_mul[i]; 97 | } 98 | size = _layer_sz; 99 | circuit.resize(size); 100 | } 101 | -------------------------------------------------------------------------------- /src/circuit.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "global_var.hpp" 10 | 11 | using std::cerr; 12 | using std::endl; 13 | using std::vector; 14 | 15 | struct uniGate { 16 | u32 g, u; 17 | u8 lu, sc; 18 | uniGate(u32 _g, u32 _u, u8 _lu, u8 _sc) : 19 | g(_g), u(_u), lu(_lu), sc(_sc) { 20 | // cerr << "uni: " << g << ' ' << u << ' ' << lu <<' ' << sc.real << endl; 21 | } 22 | }; 23 | 24 | struct binGate { 25 | u32 g, u, v; 26 | u8 sc, l; 27 | binGate(u32 _g, u32 _u, u32 _v, u8 _sc, u8 _l): 28 | g(_g), u(_u), v(_v), sc(_sc), l(_l) { 29 | // cerr << "bin: " << g << ' ' << u << ' ' << lu << ' ' << v << ' ' << lu << ' ' << sc.real << endl; 30 | 
} 31 | [[nodiscard]] u8 getLayerIdU(u8 layer_id) const { return !l ? 0 : layer_id - 1; } 32 | [[nodiscard]] u8 getLayerIdV(u8 layer_id) const { return !(l & 1) ? 0 : layer_id - 1; } 33 | }; 34 | 35 | enum class layerType { 36 | INPUT, FFT, IFFT, ADD_BIAS, RELU, Sqr, OPT_AVG_POOL, MAX_POOL, AVG_POOL, DOT_PROD, PADDING, FCONN, NCONV, NCONV_MUL, NCONV_ADD 37 | }; 38 | 39 | class layer { 40 | public: 41 | layerType ty; 42 | u32 size{}, size_u[2]{}, size_v[2]{}; 43 | i8 bit_length_u[2]{}, bit_length_v[2]{}, bit_length{}; 44 | i8 max_bl_u{}, max_bl_v{}; 45 | 46 | bool need_phase2; 47 | 48 | // bit decomp related 49 | u32 zero_start_id; 50 | 51 | std::vector uni_gates; 52 | std::vector bin_gates; 53 | 54 | vector ori_id_u, ori_id_v; 55 | i8 fft_bit_length; 56 | 57 | // iFFT or avg pooling. 58 | F scale; 59 | 60 | layer() { 61 | bit_length_u[0] = bit_length_v[0] = -1; 62 | size_u[0] = size_v[0] = 0; 63 | bit_length_u[1] = bit_length_v[1] = -1; 64 | size_u[1] = size_v[1] = 0; 65 | need_phase2 = false; 66 | zero_start_id = 0; 67 | fft_bit_length = -1; 68 | scale = F_ONE; 69 | } 70 | 71 | void updateSize() { 72 | max_bl_u = std::max(bit_length_u[0], bit_length_u[1]); 73 | max_bl_v = 0; 74 | if (!need_phase2) return; 75 | 76 | max_bl_v = std::max(bit_length_v[0], bit_length_v[1]); 77 | } 78 | }; 79 | 80 | class layeredCircuit { 81 | public: 82 | vector circuit; 83 | u8 size; 84 | vector two_mul; 85 | 86 | void init(u8 q_bit_size, u8 _layer_sz); 87 | void initSubset(); 88 | }; 89 | 90 | -------------------------------------------------------------------------------- /src/global_var.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 5/4/2021. 
3 | // 4 | 5 | #include 6 | #include 7 | 8 | #ifndef ZKCNN_GLOBAL_VAR_HPP 9 | #define ZKCNN_GLOBAL_VAR_HPP 10 | 11 | // the output format 12 | #define MO_INFO_OUT_ID 0 13 | #define PSIZE_OUT_ID 1 14 | #define KSIZE_OUT_ID 2 15 | #define PCNT_OUT_ID 3 16 | #define CONV_TY_OUT_ID 4 17 | #define QS_OUT_ID 5 18 | #define WS_OUT_ID 6 19 | #define PT_OUT_ID 7 20 | #define VT_OUT_ID 8 21 | #define PS_OUT_ID 9 22 | #define POLY_PT_OUT_ID 10 23 | #define POLY_VT_OUT_ID 11 24 | #define POLY_PS_OUT_ID 12 25 | #define TOT_PT_OUT_ID 13 26 | #define TOT_VT_OUT_ID 14 27 | #define TOT_PS_OUT_ID 15 28 | 29 | using std::cerr; 30 | using std::endl; 31 | using std::vector; 32 | using std::string; 33 | using std::max; 34 | using std::min; 35 | using std::ifstream; 36 | using std::ofstream; 37 | using std::ostream; 38 | using std::pair; 39 | using std::make_pair; 40 | 41 | extern vector output_tb; 42 | 43 | #define F Fr 44 | #define G G1 45 | #define F_ONE (Fr::one()) 46 | #define F_ZERO (Fr(0)) 47 | 48 | #define F_BYTE_SIZE (Fr::getByteSize()) 49 | 50 | template 51 | string to_string_wp(const T a_value, const int n = 4) { 52 | std::ostringstream out; 53 | out.precision(n); 54 | out << std::fixed << a_value; 55 | return out.str(); 56 | } 57 | 58 | #endif //ZKCNN_GLOBAL_VAR_HPP 59 | -------------------------------------------------------------------------------- /src/main_demo_lenet.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 4/12/2021. 
3 | // 4 | 5 | #include "circuit.h" 6 | #include "neuralNetwork.hpp" 7 | #include "verifier.hpp" 8 | #include "models.hpp" 9 | #include "global_var.hpp" 10 | 11 | // the arguments' format 12 | #define INPUT_FILE_ID 1 // the input filename 13 | #define CONFIG_FILE_ID 2 // the config filename 14 | #define OUTPUT_FILE_ID 3 // the input filename 15 | #define PIC_CNT 4 // the number of picture paralleled 16 | 17 | vector output_tb(16, ""); 18 | 19 | int main(int argc, char **argv) { 20 | initPairing(mcl::BLS12_381); 21 | 22 | char i_filename[500], c_filename[500], o_filename[500]; 23 | 24 | strcpy(i_filename, argv[INPUT_FILE_ID]); 25 | strcpy(c_filename, argv[CONFIG_FILE_ID]); 26 | strcpy(o_filename, argv[OUTPUT_FILE_ID]); 27 | 28 | int pic_cnt = atoi(argv[PIC_CNT]); 29 | 30 | output_tb[MO_INFO_OUT_ID] ="lenet (relu)"; 31 | output_tb[PCNT_OUT_ID] = std::to_string(pic_cnt); 32 | 33 | prover p; 34 | lenet nn(32, 32, 1, pic_cnt, MAX, i_filename, c_filename, o_filename); 35 | nn.create(p, false); 36 | verifier v(&p, p.C); 37 | v.verify(); 38 | 39 | for (auto &s: output_tb) printf("%s, ", s.c_str()); 40 | puts(""); 41 | } -------------------------------------------------------------------------------- /src/main_demo_vgg.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 4/12/2021. 
3 | // 4 | 5 | #include "circuit.h" 6 | #include "neuralNetwork.hpp" 7 | #include "verifier.hpp" 8 | #include "models.hpp" 9 | #include "global_var.hpp" 10 | 11 | // the arguments' format 12 | #define INPUT_FILE_ID 1 // the input filename 13 | #define CONFIG_FILE_ID 2 // the config filename 14 | #define OUTPUT_FILE_ID 3 // the input filename 15 | #define NETWORK_FILE_ID 4 // the configuration of vgg 16 | #define PIC_CNT 5 // the number of picture paralleled 17 | 18 | vector output_tb(16, ""); 19 | 20 | int main(int argc, char **argv) { 21 | initPairing(mcl::BLS12_381); 22 | 23 | char i_filename[500], c_filename[500], o_filename[500], n_filename[500]; 24 | 25 | strcpy(i_filename, argv[INPUT_FILE_ID]); 26 | strcpy(c_filename, argv[CONFIG_FILE_ID]); 27 | strcpy(o_filename, argv[OUTPUT_FILE_ID]); 28 | strcpy(n_filename, argv[NETWORK_FILE_ID]); 29 | 30 | int pic_cnt = atoi(argv[PIC_CNT]); 31 | 32 | output_tb[MO_INFO_OUT_ID] ="vgg (relu)"; 33 | output_tb[PCNT_OUT_ID] = std::to_string(pic_cnt); 34 | 35 | prover p; 36 | vgg nn(32, 32, 3, pic_cnt, i_filename, c_filename, o_filename, n_filename); 37 | nn.create(p, false); 38 | verifier v(&p, p.C); 39 | v.verify(); 40 | 41 | for (auto &s: output_tb) printf("%s, ", s.c_str()); 42 | puts(""); 43 | } -------------------------------------------------------------------------------- /src/models.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/16/2021. 
3 | // 4 | 5 | #include 6 | #include 7 | #include "models.hpp" 8 | #include "utils.hpp" 9 | #undef USE_VIRGO 10 | 11 | 12 | vgg::vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, 13 | const string &c_filename, const std::string &o_filename, const std::string &n_filename): 14 | neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { 15 | assert(psize_x == psize_y); 16 | conv_section.resize(5); 17 | 18 | ifstream config_in(n_filename); 19 | string con; 20 | i64 kernel_size = 3, ch_in = pic_channel, ch_out, new_nx = pic_size_x, new_ny = pic_size_y; 21 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? FFT : NAIVE_FAST; 22 | 23 | int idx = 0; 24 | while (config_in >> con) { 25 | if (con[0] != 'M' && con[0] != 'A') { 26 | ch_out = stoi(con, nullptr, 10); 27 | conv_section[idx].emplace_back(conv_ty, ch_out, ch_in, kernel_size); 28 | ch_in = ch_out; 29 | } else { 30 | ++idx; 31 | pool.emplace_back(con[0] == 'M' ? MAX : AVG, 2, 1); 32 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 33 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 34 | } 35 | } 36 | 37 | assert(pic_size_x == 32); 38 | full_conn.emplace_back(512, new_nx * new_ny * ch_in); 39 | full_conn.emplace_back(512, 512); 40 | full_conn.emplace_back(10, 512); 41 | } 42 | 43 | vgg16::vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, const std::string &i_filename, 44 | const string &c_filename, const std::string &o_filename) 45 | : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { 46 | assert(psize_x == psize_y); 47 | conv_section.resize(5); 48 | 49 | i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y; 50 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? 
FFT : NAIVE_FAST; 51 | 52 | conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size); 53 | conv_section[0].emplace_back(conv_ty, start, start, kernel_size); 54 | pool.emplace_back(pool_ty, 2, 1); 55 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 56 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 57 | 58 | conv_section[1].emplace_back(conv_ty, start << 1, start, kernel_size); 59 | conv_section[1].emplace_back(conv_ty, start << 1, start << 1, kernel_size); 60 | pool.emplace_back(pool_ty, 2, 1); 61 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 62 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 63 | 64 | conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size); 65 | conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); 66 | conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); 67 | pool.emplace_back(pool_ty, 2, 1); 68 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 69 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 70 | 71 | conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3); 72 | conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); 73 | conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); 74 | pool.emplace_back(pool_ty, 2, 1); 75 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 76 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 77 | 78 | conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); 79 | conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); 80 | conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); 81 | 82 | pool.emplace_back(pool_ty, 2, 1); 83 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 84 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 85 | 86 | if (pic_size_x == 224) { 87 | 
full_conn.emplace_back(4096, new_nx * new_ny * (start << 3)); 88 | full_conn.emplace_back(4096, 4096); 89 | full_conn.emplace_back(1000, 4096); 90 | } else { 91 | assert(pic_size_x == 32); 92 | full_conn.emplace_back(512, new_nx * new_ny * (start << 3)); 93 | full_conn.emplace_back(512, 512); 94 | full_conn.emplace_back(10, 512); 95 | } 96 | } 97 | 98 | vgg11::vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, const std::string &i_filename, 99 | const string &c_filename, const std::string &o_filename) 100 | : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { 101 | assert(psize_x == psize_y); 102 | conv_section.resize(5); 103 | 104 | i64 start = 64, kernel_size = 3, new_nx = pic_size_x, new_ny = pic_size_y; 105 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? FFT : NAIVE_FAST; 106 | 107 | conv_section[0].emplace_back(conv_ty, start, pic_channel, kernel_size); 108 | pool.emplace_back(pool_ty, 2, 1); 109 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 110 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 111 | 112 | conv_section[1].emplace_back(conv_ty, start << 1, start, kernel_size); 113 | pool.emplace_back(pool_ty, 2, 1); 114 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 115 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 116 | 117 | conv_section[2].emplace_back(conv_ty, start << 2, start << 1, kernel_size); 118 | conv_section[2].emplace_back(conv_ty, start << 2, start << 2, kernel_size); 119 | pool.emplace_back(pool_ty, 2, 1); 120 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 121 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 122 | 123 | conv_section[3].emplace_back(conv_ty, start << 3, start << 2, 3); 124 | conv_section[3].emplace_back(conv_ty, start << 3, start << 3, 3); 125 | pool.emplace_back(pool_ty, 2, 1); 126 | new_nx = ((new_nx - 
pool.back().size) >> pool.back().stride_bl) + 1; 127 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 128 | 129 | conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); 130 | conv_section[4].emplace_back(conv_ty, start << 3, start << 3, 3); 131 | 132 | pool.emplace_back(pool_ty, 2, 1); 133 | new_nx = ((new_nx - pool.back().size) >> pool.back().stride_bl) + 1; 134 | new_ny = ((new_ny - pool.back().size) >> pool.back().stride_bl) + 1; 135 | 136 | if (pic_size_x == 224) { 137 | full_conn.emplace_back(4096, new_nx * new_ny * (start << 3)); 138 | full_conn.emplace_back(4096, 4096); 139 | full_conn.emplace_back(1000, 4096); 140 | } else { 141 | assert(pic_size_x == 32); 142 | full_conn.emplace_back(512, new_nx * new_ny * (start << 3)); 143 | full_conn.emplace_back(512, 512); 144 | full_conn.emplace_back(10, 512); 145 | } 146 | } 147 | 148 | ccnn::ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty) : 149 | neuralNetwork(psize_x, psize_y, pchannel, pparallel, "", "", "") { 150 | conv_section.resize(1); 151 | 152 | i64 kernel_size = 2; 153 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? 
FFT : NAIVE_FAST; 154 | conv_section[0].emplace_back(conv_ty, 2, pchannel, kernel_size, 0, 0); 155 | pool.emplace_back(pool_ty, 2, 1); 156 | 157 | // conv_section[1].emplace_back(FFT, 64, 4, 3); 158 | // conv_section[1].emplace_back(NAIVE, 64, 64, 3); 159 | // pool.emplace_back(pool_ty, 2, 1); 160 | 161 | // conv_section[0].emplace_back(FFT, 2, pic_channel, 3); 162 | // conv_section[1].emplace_back(NAIVE, 1, 2, 3); 163 | // pool.emplace_back(pool_ty, 2, 1); 164 | } 165 | 166 | lenet::lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, const std::string &i_filename, 167 | const string &c_filename, const std::string &o_filename) 168 | : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { 169 | conv_section.emplace_back(); 170 | 171 | i64 kernel_size = 5; 172 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? FFT : NAIVE_FAST; 173 | 174 | if (psize_x == 28 && psize_y == 28) 175 | conv_section[0].emplace_back(conv_ty, 6, pchannel, kernel_size, 0, 2); 176 | else conv_section[0].emplace_back(conv_ty, 6, pchannel, kernel_size, 0, 0); 177 | pool.emplace_back(pool_ty, 2, 1); 178 | 179 | conv_section.emplace_back(); 180 | conv_section[1].emplace_back(conv_ty, 16, 6, kernel_size, 0, 0); 181 | pool.emplace_back(pool_ty, 2, 1); 182 | 183 | full_conn.emplace_back(120, 400); 184 | full_conn.emplace_back(84, 120); 185 | full_conn.emplace_back(10, 84); 186 | } 187 | 188 | lenetCifar::lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, 189 | const std::string &i_filename, const string &c_filename, const std::string &o_filename) 190 | : neuralNetwork(psize_x, psize_y, pchannel, pparallel, i_filename, c_filename, o_filename) { 191 | conv_section.resize(3); 192 | 193 | i64 kernel_size = 5; 194 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? 
FFT : NAIVE_FAST; 195 | 196 | conv_section[0].emplace_back(conv_ty, 6, pchannel, kernel_size, 0, 0); 197 | pool.emplace_back(pool_ty, 2, 1); 198 | 199 | conv_section[1].emplace_back(conv_ty, 16, 6, kernel_size, 0, 0); 200 | pool.emplace_back(pool_ty, 2, 1); 201 | 202 | conv_section[2].emplace_back(conv_ty, 120, 16, kernel_size, 0, 0); 203 | 204 | full_conn.emplace_back(84, 120); 205 | full_conn.emplace_back(10, 84); 206 | } 207 | 208 | void singleConv::createConv(prover &p) { 209 | initParamConv(); 210 | p.C.init(Q_BIT_SIZE, SIZE); 211 | 212 | p.val.resize(SIZE); 213 | val = p.val.begin(); 214 | two_mul = p.C.two_mul.begin(); 215 | 216 | i64 layer_id = 0; 217 | inputLayer(p.C.circuit[layer_id++]); 218 | 219 | new_nx_in = pic_size_x; 220 | new_ny_in = pic_size_y; 221 | pool_ty = NONE; 222 | for (i64 i = 0; i < conv_section.size(); ++i) { 223 | auto &sec = conv_section[i]; 224 | for (i64 j = 0; j < sec.size(); ++j) { 225 | auto &conv = sec[j]; 226 | refreshConvParam(new_nx_in, new_ny_in, conv); 227 | 228 | switch (conv.ty) { 229 | case FFT: 230 | paddingLayer(p.C.circuit[layer_id], layer_id, conv.weight_start_id); 231 | fftLayer(p.C.circuit[layer_id], layer_id); 232 | dotProdLayer(p.C.circuit[layer_id], layer_id); 233 | ifftLayer(p.C.circuit[layer_id], layer_id); 234 | break; 235 | case NAIVE_FAST: 236 | naiveConvLayerFast(p.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id); 237 | break; 238 | default: 239 | naiveConvLayerMul(p.C.circuit[layer_id], layer_id, conv.weight_start_id); 240 | naiveConvLayerAdd(p.C.circuit[layer_id], layer_id, conv.bias_start_id); 241 | } 242 | } 243 | } 244 | p.C.initSubset(); 245 | // for (i64 i = 0; i < SIZE; ++i) { 246 | // cerr << i << "(" << p.C.circuit[i].zero_start_id << ", " << p.C.circuit[i].size << "):\t"; 247 | // for (i64 j = 0; j < std::min(100u, p.C.circuit[i].size); ++j) 248 | // cerr << p.val[i][j] << ' '; 249 | // cerr << endl; 250 | // bool flag = false; 251 | // for (i64 j = 0; j < 
p.C.circuit[i].size; ++j) 252 | // if (p.val[i][j] != F_ZERO) flag = true; 253 | // if (flag) cerr << "not all zero: " << i << endl; 254 | // for (i64 j = p.C.circuit[i].zero_start_id; j < p.C.circuit[i].size; ++j) 255 | // if (p.val[i][j] != F_ZERO) { cerr << "WRONG! " << i << ' ' << j << ' ' << p.val[i][j] << endl; exit(EXIT_FAILURE); } 256 | // } 257 | cerr << "finish creating circuit." << endl; 258 | } 259 | 260 | void singleConv::initParamConv() { 261 | i64 conv_layer_cnt = 0; 262 | total_in_size = 0; 263 | total_para_size = total_relu_in_size = total_ave_in_size = total_max_in_size = 0; 264 | 265 | // data 266 | i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel; 267 | 268 | new_nx_in = pic_size_x; 269 | new_ny_in = pic_size_y; 270 | for (i64 i = 0; i < conv_section.size(); ++i) { 271 | auto &sec = conv_section[i]; 272 | for (i64 j = 0; j < sec.size(); ++j) { 273 | refreshConvParam(new_nx_in, new_ny_in, sec[j]); 274 | conv_layer_cnt += sec[j].ty == FFT ? FFT_SIZE - 1 : sec[j].ty == NAIVE ? 
NCONV_SIZE : NCONV_FAST_SIZE; 275 | // conv_kernel 276 | sec[j].weight_start_id = pos; 277 | pos += sqr(m) * channel_in * channel_out; 278 | total_para_size += sqr(m) * channel_in * channel_out; 279 | sec[j].bias_start_id = -1; 280 | } 281 | } 282 | total_in_size = pos; 283 | 284 | SIZE = 1 + conv_layer_cnt; 285 | cerr << "SIZE: " << SIZE << endl; 286 | } 287 | 288 | vector singleConv::getFFTAns(const vector &output) { 289 | vector res; 290 | res.resize(nx_out * ny_out * channel_out * pic_channel * pic_parallel); 291 | 292 | i64 lst_fft_lenh = getFFTLen() >> 1; 293 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 294 | for (i64 p = 0; p < pic_parallel; ++p) 295 | for (i64 co = 0; co < channel_out; ++co) 296 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 297 | for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { 298 | i64 idx = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out); 299 | i64 i = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lst_fft_lenh); 300 | res[idx] = output[i]; 301 | } 302 | return res; 303 | } 304 | 305 | double singleConv::calcRawFFT() { 306 | auto in = val[0].begin(); 307 | auto conv = val[0].begin() + conv_section[0][0].weight_start_id; 308 | auto bias = val[0].begin() + conv_section[0][0].bias_start_id; 309 | 310 | timer tm; 311 | int logn = ceilPow2BitLength(nx_padded_in * ny_padded_in) + 1; 312 | vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO); 313 | vector arr1(1 << logn, F_ZERO); 314 | vector arr2(arr1.size(), F_ZERO); 315 | 316 | assert(pic_parallel == 1 && pic_channel == 1); 317 | // data matrix 318 | i64 L = -padding; 319 | i64 Rx = nx_in + padding, Ry = ny_in + padding; 320 | 321 | tm.start(); 322 | for (i64 x = L; x < Rx; ++x) 323 | for (i64 y = L; y < Ry; ++y) 324 | if (check(x, y, nx_in, ny_in)) { 325 | i64 g = matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in); 326 | i64 u = matIdx(x, y, ny_in); 327 | arr1[g] = in[u]; 328 | } 
329 | 330 | // kernel matrix 331 | for (i64 x = 0; x < nx_padded_in; ++x) 332 | for (i64 y = 0; y < ny_padded_in; ++y) 333 | if (check(x, y, m, m)) { 334 | i64 g = matIdx(x, y, ny_padded_in); 335 | i64 u = matIdx(x, y, m); 336 | arr2[g] = conv[u]; 337 | } 338 | 339 | fft(arr1, logn, false); 340 | fft(arr2, logn, false); 341 | for (i64 i = 0; i < arr1.size(); ++i) 342 | arr1[i] = arr1[i] * arr2[i]; 343 | fft(arr1, logn, true); 344 | reverse(arr1.begin(), arr1.end()); 345 | 346 | tm.stop(); 347 | return tm.elapse_sec(); 348 | } 349 | 350 | double singleConv::calcRawNaive() { 351 | auto in = val[0].begin(); 352 | auto conv = val[0].begin() + conv_section[0][0].weight_start_id; 353 | auto bias = val[0].begin() + conv_section[0][0].bias_start_id; 354 | 355 | timer tm; 356 | vector res(nx_out * ny_out * channel_out * pic_parallel, F_ZERO); 357 | tm.start(); 358 | for (i64 p = 0; p < pic_parallel; ++p) 359 | for (i64 i = 0; i < channel_out; ++i) 360 | for (i64 j = 0; j < channel_in; ++j) 361 | for (i64 x = -padding; x + m <= nx_in + padding; x += (1 << log_stride)) 362 | for (i64 y = -padding; y + m <= ny_in + padding; y += (1 << log_stride)) { 363 | i64 idx = tesIdx(p, i, (x + padding) >> log_stride, (y + padding) >> log_stride, channel_out, nx_out, ny_out); 364 | if (j == 0) res[idx] = res[idx] + bias[i]; 365 | for (i64 tx = x; tx < x + m; ++tx) 366 | for (i64 ty = y; ty < y + m; ++ty) 367 | if (check(tx, ty, nx_in, ny_in)) { 368 | i64 u = tesIdx(p, j, tx, ty, channel_in, nx_in, ny_in); 369 | i64 v = tesIdx(i, j, tx - x, ty - y, channel_in, m, m); 370 | res.at(idx) = res.at(idx) + in[u] * conv[v]; 371 | } 372 | } 373 | tm.stop(); 374 | return tm.elapse_sec(); 375 | } 376 | -------------------------------------------------------------------------------- /src/models.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/16/2021. 
3 | // 4 | 5 | #ifndef ZKCNN_VGG_HPP 6 | #define ZKCNN_VGG_HPP 7 | 8 | #include "neuralNetwork.hpp" 9 | 10 | class vgg: public neuralNetwork { 11 | 12 | public: 13 | explicit vgg(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const std::string &i_filename, const string &c_filename, const std::string &o_filename, const std::string &n_filename); 14 | 15 | }; 16 | 17 | class vgg16: public neuralNetwork { 18 | 19 | public: 20 | explicit vgg16(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, 21 | const std::string &i_filename, 22 | const string &c_filename, const std::string &o_filename); 23 | 24 | }; 25 | 26 | class vgg11: public neuralNetwork { 27 | 28 | public: 29 | explicit vgg11(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, 30 | const std::string &i_filename, 31 | const string &c_filename, const std::string &o_filename); 32 | 33 | }; 34 | 35 | class lenet: public neuralNetwork { 36 | public: 37 | explicit lenet(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, 38 | const std::string &i_filename, 39 | const string &c_filename, const std::string &o_filename); 40 | }; 41 | 42 | class lenetCifar: public neuralNetwork { 43 | public: 44 | explicit lenetCifar(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, poolType pool_ty, 45 | const std::string &i_filename, const string &c_filename, const std::string &o_filename); 46 | }; 47 | 48 | class ccnn: public neuralNetwork { 49 | public: 50 | explicit ccnn(i64 psize_x, i64 psize_y, i64 pparallel, i64 pchannel, poolType pool_ty); 51 | }; 52 | 53 | class singleConv: public neuralNetwork { 54 | public: 55 | explicit singleConv(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, i64 kernel_size, i64 channel_out, 56 | i64 log_stride, i64 padding, convType conv_ty); 57 | 58 | void createConv(prover &p); 59 | 60 | void initParamConv(); 61 | 62 | vector getFFTAns(const vector &output); 63 | 64 | double calcRawFFT(); 65 | 66 | double 
calcRawNaive(); 67 | }; 68 | 69 | #endif //ZKCNN_VGG_HPP 70 | -------------------------------------------------------------------------------- /src/neuralNetwork.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/16/2021. 3 | // 4 | 5 | #include "neuralNetwork.hpp" 6 | #include "utils.hpp" 7 | #include "global_var.hpp" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using std::cerr; 14 | using std::endl; 15 | using std::max; 16 | using std::ifstream; 17 | using std::ofstream; 18 | 19 | ifstream in; 20 | ifstream conf; 21 | ofstream out; 22 | 23 | neuralNetwork::neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, 24 | const string &c_filename, const string &o_filename) : 25 | pic_size_x(psize_x), pic_size_y(psize_y), pic_channel(pchannel), pic_parallel(pparallel), 26 | SIZE(0), NCONV_FAST_SIZE(1), NCONV_SIZE(2), FFT_SIZE(5), 27 | AVE_POOL_SIZE(1), FC_SIZE(1), RELU_SIZE(1), act_ty(RELU_ACT) { 28 | 29 | in.open(i_filename); 30 | if (!in.is_open()) 31 | fprintf(stderr, "Can't find the input file!!!\n"); 32 | conf.open(c_filename); 33 | if (!conf.is_open()) 34 | fprintf(stderr, "Can't find the config file!!!\n"); 35 | 36 | if (!o_filename.empty()) out.open(o_filename); 37 | } 38 | 39 | neuralNetwork::neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size, 40 | i64 start_channel, poolType pool_ty) 41 | : neuralNetwork(psize, psize, pchannel, pparallel, "", "", "") { 42 | pool_bl = 2; 43 | pool_stride_bl = pool_bl >> 1; 44 | conv_section.resize(sec_size); 45 | 46 | convType conv_ty = kernel_size > 3 || pparallel > 1 ? FFT : NAIVE_FAST; 47 | i64 start = start_channel; 48 | for (i64 i = 0; i < sec_size; ++i) { 49 | conv_section[i].emplace_back(conv_ty, start << i, i ? 
(start << (i - 1)) : pic_channel, kernel_size); 50 | conv_section[i].emplace_back(conv_ty, start << i, start << i, kernel_size); 51 | pool.emplace_back(pool_ty, 2, 1); 52 | } 53 | 54 | i64 new_nx = (pic_size_x >> pool_stride_bl * conv_section.size()); 55 | i64 new_ny = (pic_size_y >> pool_stride_bl * conv_section.size()); 56 | for (i64 i = 0; i < fc_size; ++i) 57 | full_conn.emplace_back(i == fc_size - 1 ? 1000 : 4096, i ? 4096 : new_nx * new_ny * (start << (sec_size - 1))); 58 | } 59 | 60 | void neuralNetwork::create(prover &pr, bool only_compute) { 61 | assert(pool.size() >= conv_section.size() - 1); 62 | 63 | initParam(); 64 | pr.C.init(Q_BIT_SIZE, SIZE); 65 | 66 | pr.val.resize(SIZE); 67 | val = pr.val.begin(); 68 | two_mul = pr.C.two_mul.begin(); 69 | 70 | i64 layer_id = 0; 71 | inputLayer(pr.C.circuit[layer_id++]); 72 | 73 | new_nx_in = pic_size_x; 74 | new_ny_in = pic_size_y; 75 | for (i64 i = 0; i < conv_section.size(); ++i) { 76 | auto &sec = conv_section[i]; 77 | for (i64 j = 0; j < sec.size(); ++j) { 78 | auto &conv = sec[j]; 79 | refreshConvParam(new_nx_in, new_ny_in, conv); 80 | pool_ty = i < pool.size() && j == sec.size() - 1 ? 
pool[i].ty : NONE; 81 | x_bit = x_next_bit; 82 | switch (conv.ty) { 83 | case FFT: 84 | paddingLayer(pr.C.circuit[layer_id], layer_id, conv.weight_start_id); 85 | fftLayer(pr.C.circuit[layer_id], layer_id); 86 | dotProdLayer(pr.C.circuit[layer_id], layer_id); 87 | ifftLayer(pr.C.circuit[layer_id], layer_id); 88 | addBiasLayer(pr.C.circuit[layer_id], layer_id, conv.bias_start_id); 89 | break; 90 | case NAIVE_FAST: 91 | naiveConvLayerFast(pr.C.circuit[layer_id], layer_id, conv.weight_start_id, conv.bias_start_id); 92 | break; 93 | default: 94 | naiveConvLayerMul(pr.C.circuit[layer_id], layer_id, conv.weight_start_id); 95 | naiveConvLayerAdd(pr.C.circuit[layer_id], layer_id, conv.bias_start_id); 96 | } 97 | 98 | // update the scale bit 99 | x_next_bit = getNextBit(layer_id - 1); 100 | T = x_bit + w_bit - x_next_bit; 101 | Q_MAX = Q + T; 102 | if (pool_ty != MAX) 103 | reluActConvLayer(pr.C.circuit[layer_id], layer_id); 104 | } 105 | 106 | if (i >= pool.size()) continue; 107 | calcSizeAfterPool(pool[i]); 108 | switch (pool[i].ty) { 109 | case AVG: avgPoolingLayer(pr.C.circuit[layer_id], layer_id); break; 110 | case MAX: maxPoolingLayer(pr.C, layer_id, pool[i].dcmp_start_id, pool[i].max_start_id, 111 | pool[i].max_dcmp_start_id); break; 112 | } 113 | } 114 | 115 | pool_ty = NONE; 116 | for (int i = 0; i < full_conn.size(); ++i) { 117 | auto &fc = full_conn[i]; 118 | refreshFCParam(fc); 119 | x_bit = x_next_bit; 120 | fullyConnLayer(pr.C.circuit[layer_id], layer_id, fc.weight_start_id, fc.bias_start_id); 121 | if (i == full_conn.size() - 1) break; 122 | 123 | // update the scale bit 124 | x_next_bit = getNextBit(layer_id - 1); 125 | T = x_bit + w_bit - x_next_bit; 126 | Q_MAX = Q + T; 127 | reluActFconLayer(pr.C.circuit[layer_id], layer_id); 128 | } 129 | 130 | assert(SIZE == layer_id); 131 | 132 | total_in_size += total_max_in_size + total_ave_in_size + total_relu_in_size; 133 | initLayer(pr.C.circuit[0], total_in_size, layerType::INPUT); 134 | assert(total_in_size == 
pr.val[0].size()); 135 | 136 | printInfer(pr); 137 | // printLayerValues(pr); 138 | 139 | if (only_compute) return; 140 | pr.C.initSubset(); 141 | cerr << "finish creating circuit." << endl; 142 | } 143 | 144 | void neuralNetwork::inputLayer(layer &circuit) { 145 | initLayer(circuit, total_in_size, layerType::INPUT); 146 | 147 | for (i64 i = 0; i < total_in_size; ++i) 148 | circuit.uni_gates.emplace_back(i, 0, 0, 0); 149 | 150 | calcInputLayer(circuit); 151 | printLayerInfo(circuit, 0); 152 | } 153 | 154 | void 155 | neuralNetwork::paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id) { 156 | i64 lenh = getFFTLen() >> 1; 157 | i64 size = lenh * channel_in * (pic_parallel + channel_out); 158 | initLayer(circuit, size, layerType::PADDING); 159 | circuit.fft_bit_length = getFFTBitLen(); 160 | 161 | // data matrix 162 | i64 L = -padding; 163 | i64 Rx = nx_in + padding, Ry = ny_in + padding; 164 | 165 | for (i64 p = 0; p < pic_parallel; ++p) 166 | for (i64 ci = 0; ci < channel_in; ++ci) 167 | for (i64 x = L; x < Rx; ++x) 168 | for (i64 y = L; y < Ry; ++y) 169 | if (check(x, y, nx_in, ny_in)) { 170 | i64 g = cubIdx(p, ci, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_in, lenh); 171 | i64 u = tesIdx(p, ci, x, y, channel_in, nx_in, ny_in); 172 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); 173 | } 174 | 175 | // kernel matrix 176 | i64 first = pic_parallel * channel_in * lenh; 177 | for (i64 co = 0; co < channel_out; ++co) 178 | for (i64 ci = 0; ci < channel_in; ++ci) 179 | for (i64 x = 0; x < nx_padded_in; ++x) 180 | for (i64 y = 0; y < ny_padded_in; ++y) 181 | if (check(x, y, m, m)) { 182 | i64 g = first + cubIdx(co, ci, matIdx(x, y, ny_padded_in), channel_in, lenh) ; 183 | i64 u = first_conv_id + tesIdx(co, ci, x, y, channel_in, m, m); 184 | circuit.uni_gates.emplace_back(g, u, 0, 0); 185 | } 186 | 187 | readConvWeight(first_conv_id); 188 | calcNormalLayer(circuit, layer_id); 189 | printLayerInfo(circuit, layer_id++); 190 | } 191 | 192 | void 
neuralNetwork::fftLayer(layer &circuit, i64 &layer_id) { 193 | i64 size = getFFTLen() * channel_in * (pic_parallel + channel_out); 194 | initLayer(circuit, size, layerType::FFT); 195 | circuit.fft_bit_length = getFFTBitLen(); 196 | 197 | calcFFTLayer(circuit, layer_id); 198 | printLayerInfo(circuit, layer_id++); 199 | } 200 | 201 | void neuralNetwork::dotProdLayer(layer &circuit, i64 &layer_id) { 202 | i64 len = getFFTLen(); 203 | i64 size = len * channel_out * pic_parallel; 204 | initLayer(circuit, size, layerType::DOT_PROD); 205 | circuit.need_phase2 = true; 206 | circuit.fft_bit_length = getFFTBitLen(); 207 | 208 | for (i64 p = 0; p < pic_parallel; ++p) 209 | for (i64 co = 0; co < channel_out; ++co) 210 | for (i64 ci = 0; ci < channel_in; ++ci) { 211 | i64 g = matIdx(p, co, channel_out); 212 | i64 u = matIdx(p, ci, channel_in); 213 | i64 v = matIdx(pic_parallel + co, ci, channel_in); 214 | circuit.bin_gates.emplace_back(g, u, v, 0, 1); 215 | } 216 | 217 | calcDotProdLayer(circuit, layer_id); 218 | printLayerInfo(circuit, layer_id++); 219 | } 220 | 221 | void neuralNetwork::ifftLayer(layer &circuit, i64 &layer_id) { 222 | i64 len = getFFTLen(), lenh = len >> 1; 223 | i64 size = lenh * channel_out * pic_parallel; 224 | initLayer(circuit, size, layerType::IFFT); 225 | circuit.fft_bit_length = getFFTBitLen(); 226 | F::inv(circuit.scale, F(1ULL << circuit.fft_bit_length)); 227 | 228 | calcFFTLayer(circuit, layer_id); 229 | printLayerInfo(circuit, layer_id++); 230 | } 231 | 232 | void neuralNetwork::addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id) { 233 | i64 len = getFFTLen(); 234 | i64 size = nx_out * ny_out * channel_out * pic_parallel; 235 | initLayer(circuit, size, layerType::ADD_BIAS); 236 | 237 | i64 lenh = len >> 1; 238 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 239 | for (i64 p = 0; p < pic_parallel; ++p) 240 | for (i64 co = 0; co < channel_out; ++co) 241 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 242 | for (i64 
y = L; y + m <= Ry; y += (1 << log_stride)) { 243 | i64 u = cubIdx(p, co, matIdx(Rx - x - 1, Ry - y - 1, ny_padded_in), channel_out, lenh); 244 | i64 g = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out); 245 | circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); 246 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); 247 | } 248 | 249 | readBias(first_bias_id); 250 | calcNormalLayer(circuit, layer_id); 251 | printLayerInfo(circuit, layer_id++); 252 | } 253 | 254 | void neuralNetwork::naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id) { 255 | i64 size = nx_out * ny_out * channel_out * pic_parallel; 256 | initLayer(circuit, size, layerType::NCONV); 257 | circuit.need_phase2 = true; 258 | 259 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 260 | i64 mat_in_size = nx_in * ny_in; 261 | i64 m_sqr = sqr(m); 262 | for (i64 p = 0; p < pic_parallel; ++p) 263 | for (i64 co = 0; co < channel_out; ++co) 264 | for (i64 ci = 0; ci < channel_in; ++ci) 265 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 266 | for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { 267 | i64 g = tesIdx(p, co, ((x - L) >> log_stride), ((y - L) >> log_stride), channel_out, nx_out, ny_out); 268 | if (ci == 0 && ~first_bias_id) circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); 269 | for (i64 tx = x; tx < x + m; ++tx) 270 | for (i64 ty = y; ty < y + m; ++ty) 271 | if (check(tx, ty, nx_in, ny_in)) { 272 | i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in); 273 | i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m); 274 | circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1)); 275 | } 276 | } 277 | 278 | readConvWeight(first_conv_id); 279 | if (~first_bias_id) readBias(first_bias_id); 280 | calcNormalLayer(circuit, layer_id); 281 | printLayerInfo(circuit, layer_id++); 282 | } 283 | 284 | void neuralNetwork::naiveConvLayerMul(layer &circuit, i64 
&layer_id, i64 first_conv_id) { 285 | i64 mat_out_size = nx_out * ny_out; 286 | i64 mat_in_size = nx_in * ny_in; 287 | i64 m_sqr = sqr(m); 288 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 289 | 290 | i64 g = 0; 291 | for (i64 p = 0; p < pic_parallel; ++p) 292 | for (i64 co = 0; co < channel_out; ++co) 293 | for (i64 ci = 0; ci < channel_in; ++ci) 294 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 295 | for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) 296 | for (i64 tx = x; tx < x + m; ++tx) 297 | for (i64 ty = y; ty < y + m; ++ty) 298 | if (check(tx, ty, nx_in, ny_in)) { 299 | i64 u = tesIdx(p, ci, tx, ty, channel_in, nx_in, ny_in); 300 | i64 v = first_conv_id + tesIdx(co, ci, tx - x, ty - y, channel_in, m, m); 301 | circuit.bin_gates.emplace_back(g++, u, v, 0, 2 * (u8) (layer_id > 1)); 302 | } 303 | 304 | initLayer(circuit, g, layerType::NCONV_MUL); 305 | circuit.need_phase2 = true; 306 | readConvWeight(first_conv_id); 307 | calcNormalLayer(circuit, layer_id); 308 | printLayerInfo(circuit, layer_id++); 309 | } 310 | 311 | void neuralNetwork::naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id) { 312 | i64 size = nx_out * ny_out * channel_out * pic_parallel; 313 | initLayer(circuit, size, layerType::NCONV_ADD); 314 | 315 | i64 mat_in_size = nx_in * ny_in; 316 | i64 m_sqr = sqr(m); 317 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 318 | 319 | i64 u = 0; 320 | for (i64 p = 0; p < pic_parallel; ++p) 321 | for (i64 co = 0; co < channel_out; ++co) 322 | for (i64 ci = 0; ci < channel_in; ++ci) 323 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 324 | for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { 325 | i64 g = tesIdx(p, co, ((x - L) >> log_stride),( (y - L) >> log_stride), channel_out, nx_out, ny_out); 326 | i64 cnt = 0; 327 | if (ci == 0 && ~first_bias_id) { 328 | circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); 329 | ++cnt; 330 | } 331 | for (i64 tx = x; tx < x + m; ++tx) 332 | 
for (i64 ty = y; ty < y + m; ++ty) 333 | if (check(tx, ty, nx_in, ny_in)) { 334 | circuit.uni_gates.emplace_back(g, u++, layer_id - 1, 0); 335 | ++cnt; 336 | } 337 | } 338 | 339 | if (~first_bias_id) readBias(first_bias_id); 340 | calcNormalLayer(circuit, layer_id); 341 | printLayerInfo(circuit, layer_id++); 342 | } 343 | 344 | void neuralNetwork::reluActConvLayer(layer &circuit, i64 &layer_id) { 345 | i64 mat_out_size = nx_out * ny_out; 346 | i64 size = 1L * mat_out_size * channel_out * (2 + Q_MAX) * pic_parallel; 347 | i64 block_len = mat_out_size * channel_out * pic_parallel; 348 | 349 | i64 dcmp_cnt = block_len * Q_MAX; 350 | i64 first_dcmp_id = val[0].size(); 351 | val[0].resize(val[0].size() + dcmp_cnt); 352 | total_relu_in_size += dcmp_cnt; 353 | 354 | initLayer(circuit, size, layerType::RELU); 355 | circuit.need_phase2 = true; 356 | 357 | circuit.zero_start_id = block_len; 358 | 359 | for (i64 g = 0; g < block_len; ++g) { 360 | i64 sign_u = first_dcmp_id + g * Q_MAX; 361 | for (i64 s = 1; s < Q; ++s) { 362 | i64 v = sign_u + s; 363 | circuit.uni_gates.emplace_back(g, v, 0, Q - 1 - s); 364 | circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0); 365 | } 366 | } 367 | 368 | i64 len = getFFTLen(); 369 | i64 lenh = len >> 1; 370 | i64 L = -padding, Rx = nx_in + padding, Ry = ny_in + padding; 371 | for (i64 p = 0; p < pic_parallel; ++p) 372 | for (i64 co = 0; co < channel_out; ++ co) 373 | for (i64 x = L; x + m <= Rx; x += (1 << log_stride)) 374 | for (i64 y = L; y + m <= Ry; y += (1 << log_stride)) { 375 | i64 u = tesIdx(p, co, (x - L) >> log_stride, (y - L) >> log_stride, channel_out, nx_out, ny_out); 376 | i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX; 377 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1); 378 | circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1)); 379 | prepareSignBit(layer_id - 1, u, sign_v); 380 | for (i64 s = 1; s < Q_MAX; ++s) { 381 | i64 v = sign_v + s; 382 | 
circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1); 383 | prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1); 384 | } 385 | } 386 | 387 | for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) { 388 | i64 u = first_dcmp_id + g - (block_len << 1); 389 | circuit.bin_gates.emplace_back(g, u, u, 0, 0); 390 | circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); 391 | } 392 | 393 | calcNormalLayer(circuit, layer_id); 394 | printLayerInfo(circuit, layer_id++); 395 | } 396 | 397 | void neuralNetwork::reluActFconLayer(layer &circuit, i64 &layer_id) { 398 | i64 block_len = channel_out * pic_parallel; 399 | i64 size = block_len * (2 + Q_MAX); 400 | initLayer(circuit, size, layerType::RELU); 401 | circuit.zero_start_id = block_len; 402 | circuit.need_phase2 = true; 403 | 404 | i64 dcmp_cnt = block_len * Q_MAX; 405 | i64 first_dcmp_id = val[0].size(); 406 | val[0].resize(val[0].size() + dcmp_cnt); 407 | total_relu_in_size += dcmp_cnt; 408 | 409 | for (i64 g = 0; g < block_len; ++g) { 410 | i64 sign_u = first_dcmp_id + g * Q_MAX; 411 | for (i64 s = 1; s < Q; ++s) { 412 | i64 v = sign_u + s; 413 | circuit.uni_gates.emplace_back(g, v, 0, (Q - s - 1)); 414 | circuit.bin_gates.emplace_back(g, sign_u, v, Q - s + Q_BIT_SIZE, 0); 415 | } 416 | } 417 | 418 | for (i64 u = 0; u < block_len; ++u) { 419 | i64 g = block_len + u, sign_v = first_dcmp_id + u * Q_MAX; 420 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, Q_BIT_SIZE + 1); 421 | circuit.bin_gates.emplace_back(g, u, sign_v, 1, 2 * (u8) (layer_id > 1)); 422 | prepareSignBit(layer_id - 1, u, sign_v); 423 | 424 | for (i64 s = 1; s < Q_MAX; ++s) { 425 | i64 v = sign_v + s; 426 | circuit.uni_gates.emplace_back(g, v, 0, Q_MAX - s - 1); 427 | prepareDecmpBit(layer_id - 1, u, v, Q_MAX - s - 1); 428 | } 429 | } 430 | 431 | for (i64 g = block_len << 1; g < (block_len << 1) + block_len * Q_MAX; ++g) { 432 | i64 u = first_dcmp_id + g - (block_len << 1); 433 | circuit.bin_gates.emplace_back(g, u, u, 0, 0); 434 | 
circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); 435 | } 436 | 437 | calcNormalLayer(circuit, layer_id); 438 | printLayerInfo(circuit, layer_id++); 439 | } 440 | 441 | void neuralNetwork::avgPoolingLayer(layer &circuit, i64 &layer_id) { 442 | i64 mat_out_size = nx_out * ny_out; 443 | i64 zero_start_id = new_nx_in * new_ny_in * channel_out * pic_parallel; 444 | i64 size = zero_start_id + getPoolDecmpSize(); 445 | u8 dpool_bl = pool_bl << 1; 446 | i64 pool_sz_sqr = sqr(pool_sz); 447 | initLayer(circuit, size, layerType::AVG_POOL); 448 | F::inv(circuit.scale, pool_sz_sqr); 449 | circuit.zero_start_id = zero_start_id; 450 | circuit.need_phase2 = true; 451 | 452 | i64 first_gate_id = val[0].size(); 453 | val[0].resize(val[0].size() + zero_start_id * dpool_bl); 454 | total_ave_in_size += zero_start_id * dpool_bl; 455 | 456 | // [0 .. zero_start_id] 457 | // [zero_start_id .. zero_start_id + (g = 0..channel_out * mat_new_size) * dpool_bl + rm_i .. channel_out * mat_new_size * (1 + dpool_bl)] 458 | for (i64 p = 0; p < pic_parallel; ++p) 459 | for (i64 co = 0; co < channel_out; ++co) 460 | for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride) 461 | for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) { 462 | i64 g = tesIdx(p, co, (x >> pool_stride_bl), (y >> pool_stride_bl), channel_out, new_nx_in, new_ny_in); 463 | F data = F_ZERO; 464 | for (i64 tx = x; tx < x + pool_sz; ++tx) 465 | for (i64 ty = y; ty < y + pool_sz; ++ty) { 466 | i64 u = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out); 467 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); 468 | data = data + val[layer_id - 1][u]; 469 | } 470 | 471 | for (i64 rm_i = 0; rm_i < dpool_bl; ++rm_i) { 472 | i64 idx = matIdx(g, rm_i, dpool_bl), u = first_gate_id + idx, g_bit = zero_start_id + idx; 473 | circuit.uni_gates.emplace_back(g, u, 0, dpool_bl - rm_i + Q_BIT_SIZE); 474 | prepareFieldBit(F(data), u, dpool_bl - rm_i - 1); 475 | 476 | // check bit 477 | circuit.bin_gates.emplace_back(g_bit, u, u, 0, 
0); 478 | circuit.uni_gates.emplace_back(g_bit, u, 0, Q_BIT_SIZE + 1); 479 | } 480 | } 481 | 482 | calcNormalLayer(circuit, layer_id); 483 | printLayerInfo(circuit, layer_id++); 484 | } 485 | 486 | void 487 | neuralNetwork::maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id, 488 | i64 first_max_dcmp_id) { 489 | i64 mat_out_size = nx_out * ny_out; 490 | i64 tot_out_size = mat_out_size * channel_out * pic_parallel; 491 | i64 mat_new_size = new_nx_in * new_ny_in; 492 | i64 tot_new_size = mat_new_size * channel_out * pic_parallel; 493 | i64 pool_sz_sqr = sqr(pool_sz); 494 | 495 | i64 dcmp_cnt = getPoolDecmpSize(); 496 | first_dcmp_id = val[0].size(); 497 | val[0].resize(val[0].size() + dcmp_cnt); 498 | total_max_in_size += dcmp_cnt; 499 | 500 | i64 max_cnt = tot_new_size; 501 | first_max_id = val[0].size(); 502 | val[0].resize(val[0].size() + max_cnt); 503 | total_max_in_size += max_cnt; 504 | 505 | i64 max_dcmp_cnt = tot_new_size * (Q_MAX - 1); 506 | first_max_dcmp_id = val[0].size(); 507 | val[0].resize(val[0].size() + max_dcmp_cnt); 508 | total_max_in_size += max_dcmp_cnt; 509 | 510 | // 0: max - everyone & max - (max bits) == 0 511 | // [0..tot_new_size * sqr(pool_sz)][tot_new_size * sqr(pool_sz)..tot_new_size * sqr(pool_sz) + tot_new_size] 512 | i64 size_0 = tot_new_size * pool_sz_sqr + tot_new_size; 513 | layer &circuit = C.circuit[layer_id]; 514 | initLayer(circuit, size_0, layerType::MAX_POOL); 515 | circuit.zero_start_id = tot_new_size * pool_sz_sqr; 516 | i64 fft_len = getFFTLen(), fft_lenh = fft_len >> 1; 517 | for (i64 p = 0; p < pic_parallel; ++p) 518 | for (i64 co = 0; co < channel_out; ++co) { 519 | for (i64 x = 0; x + pool_sz <= nx_out; x += pool_stride) 520 | for (i64 y = 0; y + pool_sz <= ny_out; y += pool_stride) { 521 | i64 i_max = tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in); 522 | i64 u_max = first_max_id + i_max; 523 | for (i64 tx = x; tx < x + pool_sz; ++tx) 524 | 
for (i64 ty = y; ty < y + pool_sz; ++ty) { 525 | i64 g = cubIdx(tesIdx(p, co, x >> pool_stride_bl, y >> pool_stride_bl, channel_out, new_nx_in, new_ny_in), tx - x, ty - y, pool_sz, pool_sz); 526 | i64 u_g = tesIdx(p, co, tx, ty, channel_out, nx_out, ny_out); 527 | circuit.uni_gates.emplace_back(g, u_max, 0, 0); 528 | circuit.uni_gates.emplace_back(g, u_g, layer_id - 1, Q_BIT_SIZE + 1); 529 | prepareMax(layer_id - 1, u_g, u_max); 530 | } 531 | } 532 | } 533 | 534 | for (i64 i_new = 0; i_new < tot_new_size; ++i_new) { 535 | i64 g_new = circuit.zero_start_id + i_new; 536 | i64 u_new = first_max_id + i_new; 537 | circuit.uni_gates.emplace_back(g_new, u_new, 0, Q_BIT_SIZE + 1); 538 | for (i64 i_new_bit = 0; i_new_bit < Q_MAX - 1; ++i_new_bit) { 539 | i64 u_new_bit = first_max_dcmp_id + matIdx(i_new, i_new_bit, Q_MAX - 1); 540 | circuit.uni_gates.emplace_back(g_new, u_new_bit, 0, Q_MAX - 2 - i_new_bit); 541 | prepareDecmpBit(0, u_new, u_new_bit, Q_MAX - 2 - i_new_bit); 542 | } 543 | } 544 | calcNormalLayer(circuit, layer_id); 545 | printLayerInfo(circuit, layer_id++); 546 | 547 | // 1: (max - someone)^2 & max - everyone - ((max - everyone) bits) == 0 548 | // [0..tot_new_size * (sqr(pool_sz) + 1 >> 1)][tot_new_size * (sqr(pool_sz) + 1 >> 1)..tot_new_size * (sqr(pool_sz) + 1 >> 1) + tot_new_size * sqr(pool_sz)] 549 | // 2: (max - someone)^4 550 | // ?: (max - someone)^(2^?) 
551 | // [0..(((tot_out_size + 1) / 2 + 1) / 2...+ 1) / 2] 552 | // f: new tensor & (max - someone)^(pool_sz^2 + 1) & all (include minus and max) bits check 553 | // [0..tot_new_size] 554 | // [tot_new_size..tot_new_size * 2] 555 | // [tot_new_size * 2..tot_new_size * (Q + 1)] 556 | // [tot_new_size * (Q + 1) 557 | // ..tot_new_size * (Q + 1) + (g = 0..tot_out_size) * (Q - 1) + bit_i 558 | // ..tot_new_size * (Q + 1) + tot_new_size * (pool_sz^2) * (Q - 1)] 559 | i64 contain_max_ly = 1, ksize = pool_sz_sqr; 560 | while (!(ksize & 1)) { ksize >>= 1; ++contain_max_ly; } 561 | ksize = pool_sz_sqr; 562 | 563 | for (int i = 1; i < pool_layer_cnt; ++i) { 564 | layer &circuit = C.circuit[layer_id]; 565 | i64 size = tot_new_size * ( ((ksize + 1 )>> 1) + (i64) (i == 1) * ksize ) + 566 | (i64) (i == pool_layer_cnt - 1) * tot_new_size * Q_MAX + 567 | (i64) (i == pool_layer_cnt - 1) * tot_new_size * pool_sz_sqr * (Q_MAX - 1); 568 | initLayer(circuit, size, layerType::MAX_POOL); 569 | circuit.need_phase2 = true; 570 | 571 | // new tensor 572 | i64 before_mul = 0; 573 | if (i == pool_layer_cnt - 1) { 574 | before_mul = tot_new_size; 575 | for (i64 g = 0; g < tot_new_size; ++g) 576 | for (i64 j = 0; j < Q - 1; ++j) { 577 | i64 u = first_max_dcmp_id + matIdx(g, j, Q_MAX - 1); 578 | circuit.uni_gates.emplace_back(g, u, 0, Q - 2 - j); 579 | } 580 | } 581 | 582 | // multiplications of subtraction 583 | for (i64 cnt = 0; cnt < tot_new_size; ++cnt) { 584 | i64 v_max = first_max_id + cnt; 585 | for (i64 j = 0; (j << 1) < ksize; ++j) { 586 | i64 idx = matIdx(cnt, j, (ksize + 1) >> 1); 587 | i64 g = before_mul + idx; 588 | i64 u = matIdx(cnt, (j << 1), ksize); 589 | if ((j << 1 | 1) < ksize) { 590 | i64 v = matIdx(cnt, (j << 1 | 1), ksize); 591 | circuit.bin_gates.emplace_back(g, u, v, 0, layer_id > 1); 592 | } else if (i == contain_max_ly) 593 | circuit.bin_gates.emplace_back(g, u, v_max, 0, 2 * (u8) (layer_id > 1)); 594 | else 595 | circuit.uni_gates.emplace_back(g, u, layer_id - 1, 0); 
596 | } 597 | } 598 | 599 | if (i == 1) { 600 | i64 minus_cnt = tot_new_size * ksize; 601 | i64 minus_new_cnt = tot_new_size * ((ksize + 1) >> 1); 602 | circuit.zero_start_id = minus_new_cnt; 603 | for (i64 v = 0; v < minus_cnt; ++v) { 604 | i64 g = minus_new_cnt + v; 605 | circuit.uni_gates.emplace_back(g, v, layer_id - 1, Q_BIT_SIZE + 1); 606 | for (i64 bit_j = 0; bit_j < Q_MAX - 1; ++bit_j) { 607 | i64 u = first_dcmp_id + matIdx(v, bit_j, Q_MAX - 1); 608 | circuit.uni_gates.emplace_back(g, u, 0, Q_MAX - 2 - bit_j); 609 | prepareDecmpBit(layer_id - 1, v, u, Q_MAX - 2 - bit_j); 610 | } 611 | } 612 | } else if (i == pool_layer_cnt - 1) { 613 | i64 minus_cnt = tot_new_size * pool_sz_sqr; 614 | circuit.zero_start_id = before_mul; 615 | for (i64 j = 0; j < minus_cnt; ++j) { 616 | i64 g = before_mul + tot_new_size + j; 617 | i64 u = first_dcmp_id + j; 618 | circuit.bin_gates.emplace_back(g, u, u, 0, 0); 619 | circuit.uni_gates.emplace_back(g, u, 0, Q_BIT_SIZE + 1); 620 | } 621 | } 622 | ksize = (ksize + 1) >> 1; 623 | calcNormalLayer(circuit, layer_id); 624 | printLayerInfo(circuit, layer_id++); 625 | } 626 | 627 | } 628 | 629 | void neuralNetwork::fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id) { 630 | i64 size = channel_out * pic_parallel; 631 | initLayer(circuit, size, layerType::FCONN); 632 | circuit.need_phase2 = true; 633 | 634 | for (i64 p = 0; p < pic_parallel; ++p) 635 | for (i64 co = 0; co < channel_out; ++co) { 636 | i64 g = matIdx(p, co, channel_out); 637 | circuit.uni_gates.emplace_back(g, first_bias_id + co, 0, 0); 638 | for (i64 ci = 0; ci < channel_in; ++ci) { 639 | i64 u = matIdx(p, ci, channel_in); 640 | i64 v = first_fc_id + matIdx(co, ci, channel_in); 641 | circuit.bin_gates.emplace_back(g, u, v, 0, 2 * (u8) (layer_id > 1)); 642 | } 643 | } 644 | 645 | readFconWeight(first_fc_id); 646 | readBias(first_bias_id); 647 | calcNormalLayer(circuit, layer_id); 648 | printLayerInfo(circuit, layer_id++); 649 | } 650 | 651 | 
void 652 | neuralNetwork::refreshConvParam(i64 new_nx, i64 new_ny, const convKernel &conv) { 653 | nx_in = new_nx; 654 | ny_in = new_ny; 655 | padding = conv.padding; 656 | nx_padded_in = nx_in + (conv.padding * 2); 657 | ny_padded_in = ny_in + (conv.padding * 2); 658 | 659 | m = conv.size; 660 | channel_in = conv.channel_in; 661 | channel_out = conv.channel_out; 662 | log_stride = conv.stride_bl; 663 | 664 | nx_out = ((nx_padded_in - m) >> log_stride) + 1; 665 | ny_out = ((ny_padded_in - m) >> log_stride) + 1; 666 | 667 | new_nx_in = nx_out; 668 | new_ny_in = ny_out; 669 | conv_layer_cnt = conv.ty == FFT ? FFT_SIZE : conv.ty == NAIVE ? NCONV_SIZE : NCONV_FAST_SIZE; 670 | } 671 | 672 | void neuralNetwork::refreshFCParam(const fconKernel &fc) { 673 | nx_in = nx_out = m = 1; 674 | ny_in = ny_out = 1; 675 | channel_in = fc.channel_in; 676 | channel_out = fc.channel_out; 677 | } 678 | 679 | i64 neuralNetwork::getFFTLen() const { 680 | return 1L << getFFTBitLen(); 681 | } 682 | 683 | i8 neuralNetwork::getFFTBitLen() const { 684 | return ceilPow2BitLength( (u32)nx_padded_in * ny_padded_in ) + 1; 685 | } 686 | 687 | // input: [data] 688 | // [[conv_kernel || relu_conv_bit_decmp]{sec.size()}[max_pool]{if maxPool}[pool_bit_decmp]]{conv_section.size()} 689 | // [fc_kernel || relu_fc_bit_decmp] 690 | void neuralNetwork::initParam() { 691 | act_layer_cnt = RELU_SIZE; 692 | i64 total_conv_layer_cnt = 0, total_pool_layer_cnt = 0; 693 | total_in_size = 0; 694 | total_para_size = 0; 695 | total_relu_in_size = 0; 696 | total_ave_in_size = 0; 697 | total_max_in_size = 0; 698 | 699 | // data 700 | i64 pos = pic_size_x * pic_size_y * pic_channel * pic_parallel; 701 | 702 | new_nx_in = pic_size_x; 703 | new_ny_in = pic_size_y; 704 | for (i64 i = 0; i < conv_section.size(); ++i) { 705 | auto &sec = conv_section[i]; 706 | for (i64 j = 0; j < sec.size(); ++j) { 707 | refreshConvParam(new_nx_in, new_ny_in, sec[j]); 708 | // conv_kernel 709 | sec[j].weight_start_id = pos; 710 | u32 
para_size = sqr(m) * channel_in * channel_out; 711 | pos += para_size; 712 | total_para_size += para_size; 713 | fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size); 714 | 715 | sec[j].bias_start_id = pos; 716 | pos += channel_out; 717 | total_para_size += channel_out; 718 | fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size); 719 | } 720 | 721 | total_conv_layer_cnt += sec.size() * (conv_layer_cnt + act_layer_cnt); 722 | 723 | if (i >= pool.size()) continue; 724 | calcSizeAfterPool(pool[i]); 725 | total_pool_layer_cnt += pool_layer_cnt; 726 | if (pool[i].ty == MAX) 727 | if (act_ty == RELU_ACT) total_conv_layer_cnt -= act_layer_cnt; 728 | } 729 | 730 | for (int i = 0; i < full_conn.size(); ++i) { 731 | auto &fc = full_conn[i]; 732 | refreshFCParam(fc); 733 | // fc_kernel 734 | fc.weight_start_id = pos; 735 | u32 para_size = channel_out * channel_in; 736 | pos += para_size; 737 | total_para_size += para_size; 738 | fprintf(stderr, "kernel weight: %11d%11lld\n", para_size, total_para_size); 739 | fc.bias_start_id = pos; 740 | pos += channel_out; 741 | total_para_size += channel_out; 742 | fprintf(stderr, "bias weight: %11lld%11lld\n", channel_out, total_para_size); 743 | if (i == full_conn.size() - 1) break; 744 | } 745 | total_in_size = pos; 746 | 747 | SIZE = 1 + total_conv_layer_cnt + total_pool_layer_cnt + (FC_SIZE + RELU_SIZE) * full_conn.size(); 748 | if (!full_conn.empty()) SIZE -= RELU_SIZE; 749 | cerr << "SIZE: " << SIZE << endl; 750 | } 751 | 752 | void neuralNetwork::printLayerInfo(const layer &circuit, i64 layer_id) { 753 | // fprintf(stderr, "+ %2lld " , layer_id); 754 | // switch (circuit.ty) { 755 | // case layerType::INPUT: fprintf(stderr, "inputLayer "); break; 756 | // case layerType::PADDING: fprintf(stderr, "paddingLayer "); break; 757 | // case layerType::FFT: fprintf(stderr, "fftLayer "); break; 758 | // case layerType::DOT_PROD: fprintf(stderr, "dotProdLayer "); break; 759 | // case 
layerType::IFFT: fprintf(stderr, "ifftLayer "); break; 760 | // case layerType::ADD_BIAS: fprintf(stderr, "addBiasLayer "); break; 761 | // case layerType::RELU: fprintf(stderr, "reluActLayer "); break; 762 | // case layerType::Sqr: fprintf(stderr, "squareActLayer "); break; 763 | // case layerType::OPT_AVG_POOL: fprintf(stderr, "avgOptPoolingLayer "); break; 764 | // case layerType::AVG_POOL: fprintf(stderr, "avgPoolingLayer "); break; 765 | // case layerType::MAX_POOL: fprintf(stderr, "maxPoolingLayer "); break; 766 | // case layerType::FCONN: fprintf(stderr, "fullyConnLayer "); break; 767 | // case layerType::NCONV: fprintf(stderr, "naiveConvFast "); break; 768 | // case layerType::NCONV_MUL: fprintf(stderr, "naiveConvMul "); break; 769 | // case layerType::NCONV_ADD: fprintf(stderr, "naiveConvAdd "); break; 770 | //m 771 | // } 772 | // fprintf(stderr, "%11u (2^%2d)\n", circuit.size, (int) circuit.bit_length); 773 | } 774 | 775 | void neuralNetwork::printWitnessInfo(const layer &circuit) const { 776 | assert(circuit.size == total_in_size); 777 | u32 total_data_in_size = total_in_size - total_relu_in_size - total_ave_in_size - total_max_in_size; 778 | fprintf(stderr,"%u (2^%2d) = %u (%.2f%% data) + %lld (%.2f%% relu) + %lld (%.2f%% ave) + %lld (%.2f%% max), ", 779 | circuit.size, circuit.bit_length, total_data_in_size, 100.0 * total_data_in_size / (double) total_in_size, 780 | total_relu_in_size, 100.0 * total_relu_in_size / (double) total_in_size, 781 | total_ave_in_size, 100.0 * total_ave_in_size / (double) total_in_size, 782 | total_max_in_size, 100.0 * total_max_in_size / (double) total_in_size); 783 | output_tb[WS_OUT_ID] = std::to_string(circuit.size) + "(2^" + std::to_string(ceilPow2BitLength(circuit.size)) + ")"; 784 | } 785 | 786 | i64 neuralNetwork::getPoolDecmpSize() const { 787 | switch (pool_ty) { 788 | case AVG: return new_nx_in * new_ny_in * (pool_bl << 1) * channel_out * pic_parallel; 789 | case MAX: return new_nx_in * new_ny_in * sqr(pool_sz) * 
channel_out * pic_parallel * (Q_MAX - 1); 790 | default: 791 | assert(false); 792 | } 793 | } 794 | 795 | void neuralNetwork::calcSizeAfterPool(const poolKernel &p) { 796 | pool_sz = p.size; 797 | pool_bl = ceilPow2BitLength(pool_sz); 798 | pool_stride_bl = p.stride_bl; 799 | pool_stride = 1 << p.stride_bl; 800 | pool_layer_cnt = p.ty == MAX ? 1 + ceilPow2BitLength(sqr(p.size) + 1) : AVE_POOL_SIZE; 801 | new_nx_in = ((nx_out - pool_sz) >> pool_stride_bl) + 1; 802 | new_ny_in = ((ny_out - pool_sz) >> pool_stride_bl) + 1; 803 | } 804 | 805 | void neuralNetwork::calcInputLayer(layer &circuit) { 806 | val[0].resize(circuit.size); 807 | 808 | assert(val[0].size() == total_in_size); 809 | auto val_0 = val[0].begin(); 810 | 811 | double num, mx = -10000, mn = 10000; 812 | vector input_dat; 813 | for (i64 ci = 0; ci < pic_channel; ++ci) 814 | for (i64 x = 0; x < pic_size_x; ++x) 815 | for (i64 y = 0; y < pic_size_y; ++y) { 816 | in >> num; 817 | input_dat.push_back(num); 818 | mx = max(mx, num); 819 | mn = min(mn, num); 820 | } 821 | 822 | // (mx - mn) * 2^i <= 2^Q - 1 823 | // quant_shr = i 824 | x_next_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); 825 | if ((int) ((mx - mn) * exp2(x_next_bit)) > (1 << (Q - 1)) - 1) --x_next_bit; 826 | 827 | for (i64 p = 0; p < pic_parallel; ++p) { 828 | i64 i = 0; 829 | for (i64 ci = 0; ci < pic_channel; ++ci) 830 | for (i64 x = 0; x < pic_size_x; ++x) 831 | for (i64 y = 0; y < pic_size_y; ++y) 832 | *val_0++ = F((i64)(input_dat[i++] * exp2(x_next_bit))); 833 | } 834 | for (; val_0 < val[0].begin() + circuit.size; ++val_0) val_0 -> clear(); 835 | } 836 | 837 | 838 | void neuralNetwork::readConvWeight(i64 first_conv_id) { 839 | auto val_0 = val[0].begin() + first_conv_id; 840 | 841 | double num, mx = -10000, mn = 10000; 842 | vector input_dat; 843 | for (i64 co = 0; co < channel_out; ++co) 844 | for (i64 ci = 0; ci < channel_in; ++ci) 845 | for (i64 x = 0; x < m; ++x) 846 | for (i64 y = 0; y < m; ++y) { 847 | in >> num; 
848 | input_dat.push_back(num); 849 | mx = max(mx, num); 850 | mn = min(mn, num); 851 | } 852 | 853 | // (mx - mn) * 2^i <= 2^Q - 1 854 | // quant_shr = i 855 | w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); 856 | if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit; 857 | 858 | for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit))); 859 | 860 | } 861 | 862 | void neuralNetwork::readBias(i64 first_bias_id) { 863 | auto val_0 = val[0].begin() + first_bias_id; 864 | 865 | double num, mx = -10000, mn = 10000; 866 | vector input_dat; 867 | for (i64 co = 0; co < channel_out; ++co) { 868 | in >> num; 869 | input_dat.push_back(num); 870 | mx = max(mx, num); 871 | mn = min(mn, num); 872 | } 873 | 874 | for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit + x_bit))); 875 | 876 | } 877 | 878 | void neuralNetwork::readFconWeight(i64 first_fc_id) { 879 | double num, mx = -10000, mn = 10000; 880 | auto val_0 = val[0].begin() + first_fc_id; 881 | 882 | vector input_dat; 883 | for (i64 co = 0; co < channel_out; ++co) 884 | for (i64 ci = 0; ci < channel_in; ++ci) { 885 | in >> num; 886 | input_dat.push_back(num); 887 | mx = max(mx, num); 888 | mn = min(mn, num); 889 | } 890 | 891 | // (mx - mn) * 2^i <= 2^Q - 1 892 | // quant_shr = i 893 | w_bit = (int) (log( ((1 << (Q - 1)) - 1) / (mx - mn) ) / log(2)); 894 | if ((int) ((mx - mn) * exp2(w_bit)) > (1 << (Q - 1)) - 1) --w_bit; 895 | 896 | for (double i : input_dat) *val_0++ = F((i64) (i * exp2(w_bit))); 897 | } 898 | 899 | void neuralNetwork::prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift) { 900 | auto data = abs(val[layer_id].at(idx).getInt64()); 901 | val[0].at(dcmp_id) = (data >> bit_shift) & 1; 902 | } 903 | 904 | void neuralNetwork::prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift) { 905 | auto tmp = abs(data.getInt64()); 906 | val[0].at(dcmp_id) = (tmp >> bit_shift) & 1; 907 | } 908 | 909 | void neuralNetwork::prepareSignBit(i64 layer_id, i64 idx, 
i64 dcmp_id) { 910 | val[0].at(dcmp_id) = val[layer_id].at(idx).isNegative() ? F_ONE : F_ZERO; 911 | } 912 | 913 | void neuralNetwork::prepareMax(i64 layer_id, i64 idx, i64 max_id) { 914 | auto data = val[layer_id].at(idx).isNegative() ? F_ZERO : val[layer_id].at(idx); 915 | if (data > val[0].at(max_id)) val[0].at(max_id) = data; 916 | } 917 | 918 | void neuralNetwork::calcNormalLayer(const layer &circuit, i64 layer_id) { 919 | val[layer_id].resize(circuit.size); 920 | for (auto &x: val[layer_id]) x.clear(); 921 | 922 | for (auto &gate: circuit.uni_gates) { 923 | val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[gate.lu].at(gate.u) * two_mul[gate.sc]; 924 | } 925 | 926 | 927 | for (auto &gate: circuit.bin_gates) { 928 | u8 bin_lu = gate.getLayerIdU(layer_id), bin_lv = gate.getLayerIdV(layer_id); 929 | val[layer_id].at(gate.g) = val[layer_id].at(gate.g) + val[bin_lu].at(gate.u) * val[bin_lv][gate.v] * two_mul[gate.sc]; 930 | } 931 | 932 | F mx_val = F_ZERO, mn_val = F_ZERO; 933 | for (i64 g = 0; g < circuit.size; ++g) 934 | val[layer_id].at(g) = val[layer_id].at(g) * circuit.scale; 935 | } 936 | 937 | void neuralNetwork::calcDotProdLayer(const layer &circuit, i64 layer_id) { 938 | val[layer_id].resize(circuit.size); 939 | for (int i = 0; i < circuit.size; ++i) val[layer_id][i].clear(); 940 | 941 | char fft_bit = circuit.fft_bit_length; 942 | u32 fft_len = 1 << fft_bit; 943 | u8 l = layer_id - 1; 944 | for (auto &gate: circuit.bin_gates) 945 | for (int s = 0; s < fft_len; ++s) 946 | val[layer_id][gate.g << fft_bit | s] = val[layer_id][gate.g << fft_bit | s] + 947 | val[l][gate.u << fft_bit | s] * val[l][gate.v << fft_bit | s]; 948 | } 949 | 950 | void neuralNetwork::calcFFTLayer(const layer &circuit, i64 layer_id) { 951 | i64 fft_len = 1ULL << circuit.fft_bit_length; 952 | i64 fft_lenh = fft_len >> 1; 953 | val[layer_id].resize(circuit.size); 954 | std::vector arr(fft_len, F_ZERO); 955 | if (circuit.ty == layerType::FFT) for (i64 c = 0, d = 0; d < 
circuit.size; c += fft_lenh, d += fft_len) { 956 | for (i64 j = c; j < c + fft_lenh; ++j) arr[j - c] = val[layer_id - 1].at(j); 957 | for (i64 j = fft_lenh; j < fft_len; ++j) arr[j].clear(); 958 | fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT); 959 | for (i64 j = d; j < d + fft_len; ++j) val[layer_id].at(j) = arr[j - d]; 960 | } else for (u32 c = 0, d = 0; c < circuit.size; c += fft_lenh, d += fft_len) { 961 | for (i64 j = d; j < d + fft_len; ++j) arr[j - d] = val[layer_id - 1].at(j); 962 | fft(arr, circuit.fft_bit_length, circuit.ty == layerType::IFFT); 963 | for (i64 j = c; j < c + fft_lenh; ++j) val[layer_id].at(j) = arr[j - c]; 964 | } 965 | } 966 | 967 | int neuralNetwork::getNextBit(int layer_id) { 968 | F mx = F_ZERO, mn = F_ZERO; 969 | for (const auto &x: val[layer_id]) { 970 | if (!x.isNegative()) mx = max(mx, x); 971 | else mn = max(mn, -x); 972 | } 973 | i64 x = (mx + mn).getInt64(); 974 | double real_scale = x / exp2(x_bit + w_bit); 975 | int res = (int) log2( ((1 << (Q - 1)) - 1) / real_scale ); 976 | return res; 977 | } 978 | 979 | void neuralNetwork::printLayerValues(prover &pr) { 980 | for (i64 i = 0; i < SIZE; ++i) { 981 | // if (pr.C.circuit[i].ty == layerType::FCONN || pr.C.circuit[i].ty == layerType::ADD_BIAS || i && i < SIZE - 1 && pr.C.circuit[i + 1].ty == layerType::PADDING) { 982 | cerr << i << "(" << pr.C.circuit[i].zero_start_id << ", " << pr.C.circuit[i].size << "):\t"; 983 | for (i64 j = 0; j < std::min(200u, pr.C.circuit[i].size); ++j) 984 | if (!pr.val[i][j].isZero()) cerr << pr.val[i][j] << ' '; 985 | cerr << endl; 986 | for (i64 j = pr.C.circuit[i].zero_start_id; j < pr.C.circuit[i].size; ++j) 987 | if (pr.val[i].at(j) != F_ZERO) { 988 | cerr << "WRONG! 
" << i << ' ' << j << ' ' << (-pr.val[i][j] * F_ONE) << endl; 989 | exit(EXIT_FAILURE); 990 | } 991 | } 992 | } 993 | 994 | void neuralNetwork::printInfer(prover &pr) { 995 | // output the inference result with the size of (pic_parallel x n_class) 996 | if (out.is_open()) { 997 | int n_class = full_conn.back().channel_out; 998 | for (int p = 0; p < pic_parallel; ++p) { 999 | int k = -1; 1000 | F v; 1001 | for (int c = 0; c < n_class; ++c) { 1002 | auto tmp = val[SIZE - 1].at(matIdx(p, c, n_class)); 1003 | if (!tmp.isNegative() && (k == -1 || v < tmp)) { 1004 | k = c; 1005 | v = tmp; 1006 | } 1007 | } 1008 | out << k << endl; 1009 | 1010 | // output one-hot 1011 | // for (int c = 0; c < n_class; ++c) out << (k == c) << ' '; 1012 | // out << endl; 1013 | } 1014 | } 1015 | out.close(); 1016 | printWitnessInfo(pr.C.circuit[0]); 1017 | } -------------------------------------------------------------------------------- /src/neuralNetwork.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/16/2021. 
3 | // 4 | 5 | #ifndef ZKCNN_NEURALNETWORK_HPP 6 | #define ZKCNN_NEURALNETWORK_HPP 7 | 8 | #include 9 | #include 10 | #include "circuit.h" 11 | #include "prover.hpp" 12 | 13 | using std::vector; 14 | using std::tuple; 15 | using std::pair; 16 | 17 | enum convType { 18 | FFT, NAIVE, NAIVE_FAST 19 | }; 20 | 21 | struct convKernel { 22 | convType ty; 23 | i64 channel_out, channel_in, size, stride_bl, padding, weight_start_id, bias_start_id; 24 | convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size, i64 _log_stride, i64 _padding) : 25 | ty(_ty), channel_out(_channel_out), channel_in(_channel_in), size(_size), stride_bl(_log_stride), padding(_padding) { 26 | } 27 | 28 | convKernel(convType _ty, i64 _channel_out, i64 _channel_in, i64 _size) : 29 | convKernel(_ty, _channel_out, _channel_in, _size, 0, _size >> 1) { 30 | } 31 | }; 32 | 33 | struct fconKernel { 34 | i64 channel_out, channel_in, weight_start_id, bias_start_id; 35 | fconKernel(i64 _channel_out, i64 _channel_in): 36 | channel_out(_channel_out), channel_in(_channel_in) {} 37 | }; 38 | 39 | enum poolType { 40 | AVG, MAX, NONE 41 | }; 42 | 43 | enum actType { 44 | RELU_ACT 45 | }; 46 | 47 | struct poolKernel { 48 | poolType ty; 49 | i64 size, stride_bl, dcmp_start_id, max_start_id, max_dcmp_start_id; 50 | poolKernel(poolType _ty, i64 _size, i64 _log_stride): 51 | ty(_ty), size(_size), stride_bl(_log_stride) {} 52 | }; 53 | 54 | 55 | class neuralNetwork { 56 | public: 57 | explicit neuralNetwork(i64 psize_x, i64 psize_y, i64 pchannel, i64 pparallel, const string &i_filename, 58 | const string &c_filename, const string &o_filename); 59 | 60 | neuralNetwork(i64 psize, i64 pchannel, i64 pparallel, i64 kernel_size, i64 sec_size, i64 fc_size, 61 | i64 start_channel, poolType pool_ty); 62 | 63 | void create(prover &pr, bool only_compute); 64 | 65 | protected: 66 | 67 | void initParam(); 68 | 69 | int getNextBit(int layer_id); 70 | 71 | void refreshConvParam(i64 new_nx, i64 new_ny, const convKernel 
&conv); 72 | 73 | void calcSizeAfterPool(const poolKernel &p); 74 | 75 | void refreshFCParam(const fconKernel &fc); 76 | 77 | [[nodiscard]] i64 getFFTLen() const; 78 | 79 | [[nodiscard]] i8 getFFTBitLen() const; 80 | 81 | [[nodiscard]] i64 getPoolDecmpSize() const; 82 | 83 | void prepareDecmpBit(i64 layer_id, i64 idx, i64 dcmp_id, i64 bit_shift); 84 | 85 | void prepareFieldBit(const F &data, i64 dcmp_id, i64 bit_shift); 86 | 87 | void prepareSignBit(i64 layer_id, i64 idx, i64 dcmp_id); 88 | 89 | void prepareMax(i64 layer_id, i64 idx, i64 max_id); 90 | 91 | void calcInputLayer(layer &circuit); 92 | 93 | void calcNormalLayer(const layer &circuit, i64 layer_id); 94 | 95 | void calcDotProdLayer(const layer &circuit, i64 layer_id); 96 | 97 | void calcFFTLayer(const layer &circuit, i64 layer_id); 98 | 99 | vector> conv_section; 100 | vector pool; 101 | poolType pool_ty; 102 | i64 pool_bl, pool_sz; 103 | i64 pool_stride_bl, pool_stride; 104 | i64 pool_layer_cnt, act_layer_cnt, conv_layer_cnt; 105 | actType act_ty; 106 | 107 | vector full_conn; 108 | 109 | i64 pic_size_x, pic_size_y, pic_channel, pic_parallel; 110 | i64 SIZE; 111 | const i64 NCONV_FAST_SIZE, NCONV_SIZE, FFT_SIZE, AVE_POOL_SIZE, FC_SIZE, RELU_SIZE; 112 | i64 T; 113 | const i64 Q = 9; 114 | i64 Q_MAX; 115 | const i64 Q_BIT_SIZE = 220; 116 | 117 | i64 nx_in, nx_out, ny_in, ny_out, m, channel_in, channel_out, log_stride, padding; 118 | i64 new_nx_in, new_ny_in; 119 | i64 nx_padded_in, ny_padded_in; 120 | i64 total_in_size, total_para_size, total_relu_in_size, total_ave_in_size, total_max_in_size; 121 | int x_bit, w_bit, x_next_bit; 122 | 123 | vector>::iterator val; 124 | vector::iterator two_mul; 125 | 126 | void inputLayer(layer &circuit); 127 | 128 | void paddingLayer(layer &circuit, i64 &layer_id, i64 first_conv_id); 129 | 130 | void fftLayer(layer &circuit, i64 &layer_id); 131 | 132 | void dotProdLayer(layer &circuit, i64 &layer_id); 133 | 134 | void ifftLayer(layer &circuit, i64 &layer_id); 135 | 136 | 
void addBiasLayer(layer &circuit, i64 &layer_id, i64 first_bias_id); 137 | 138 | void naiveConvLayerFast(layer &circuit, i64 &layer_id, i64 first_conv_id, i64 first_bias_id); 139 | 140 | void naiveConvLayerMul(layer &circuit, i64 &layer_id, i64 first_conv_id); 141 | 142 | void naiveConvLayerAdd(layer &circuit, i64 &layer_id, i64 first_bias_id); 143 | 144 | void reluActConvLayer(layer &circuit, i64 &layer_id); 145 | 146 | void reluActFconLayer(layer &circuit, i64 &layer_id); 147 | 148 | void avgPoolingLayer(layer &circuit, i64 &layer_id); 149 | 150 | void 151 | maxPoolingLayer(layeredCircuit &C, i64 &layer_id, i64 first_dcmp_id, i64 first_max_id, i64 first_max_dcmp_id); 152 | 153 | void fullyConnLayer(layer &circuit, i64 &layer_id, i64 first_fc_id, i64 first_bias_id); 154 | 155 | static void printLayerInfo(const layer &circuit, i64 layer_id); 156 | 157 | void readBias(i64 first_bias_id); 158 | 159 | void readConvWeight(i64 first_conv_id); 160 | 161 | void readFconWeight(i64 first_fc_id); 162 | 163 | void printWitnessInfo(const layer &circuit) const; 164 | 165 | void printLayerValues(prover &pr); 166 | 167 | void printInfer(prover &pr); 168 | }; 169 | 170 | 171 | #endif //ZKCNN_NEURALNETWORK_HPP 172 | -------------------------------------------------------------------------------- /src/polynomial.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "polynomial.h" 3 | 4 | quintuple_poly::quintuple_poly() { a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); f.clear();} 5 | quintuple_poly::quintuple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee, const F &ff) { 6 | a = aa; 7 | b = bb; 8 | c = cc; 9 | d = dd; 10 | e = ee; 11 | f = ff; 12 | } 13 | 14 | quintuple_poly quintuple_poly::operator + (const quintuple_poly &x) const { 15 | return quintuple_poly(a + x.a, b + x.b, c + x.c, d + x.d, e + x.e, f + x.f); 16 | } 17 | 18 | F quintuple_poly::eval(const F &x) const { 19 | return (((((a * 
x) + b) * x + c) * x + d) * x + e) * x + f; 20 | } 21 | 22 | void quintuple_poly::clear() { 23 | a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); f.clear(); 24 | } 25 | 26 | quadruple_poly::quadruple_poly() {a.clear(); b.clear(); c.clear(); d.clear(); e.clear();} 27 | quadruple_poly::quadruple_poly(const F &aa, const F &bb, const F &cc, const F &dd, const F &ee) { 28 | a = aa; 29 | b = bb; 30 | c = cc; 31 | d = dd; 32 | e = ee; 33 | } 34 | 35 | quadruple_poly quadruple_poly::operator + (const quadruple_poly &x) const { 36 | return quadruple_poly(a + x.a, b + x.b, c + x.c, d + x.d, e + x.e); 37 | } 38 | 39 | F quadruple_poly::eval(const F &x) const { 40 | return ((((a * x) + b) * x + c) * x + d) * x + e; 41 | } 42 | 43 | void quadruple_poly::clear() { 44 | a.clear(); b.clear(); c.clear(); d.clear(); e.clear(); 45 | } 46 | 47 | cubic_poly::cubic_poly() {a.clear(); b.clear(); c.clear(); d.clear();} 48 | cubic_poly::cubic_poly(const F &aa, const F &bb, const F &cc, const F &dd) { 49 | a = aa; 50 | b = bb; 51 | c = cc; 52 | d = dd; 53 | } 54 | 55 | cubic_poly cubic_poly::operator + (const cubic_poly &x) const { 56 | return cubic_poly(a + x.a, b + x.b, c + x.c, d + x.d); 57 | } 58 | 59 | F cubic_poly::eval(const F &x) const { 60 | return (((a * x) + b) * x + c) * x + d; 61 | } 62 | 63 | quadratic_poly::quadratic_poly() {a.clear(); b.clear(); c.clear();} 64 | quadratic_poly::quadratic_poly(const F &aa, const F &bb, const F &cc) { 65 | a = aa; 66 | b = bb; 67 | c = cc; 68 | } 69 | 70 | quadratic_poly quadratic_poly::operator + (const quadratic_poly &x) const { 71 | return quadratic_poly(a + x.a, b + x.b, c + x.c); 72 | } 73 | 74 | quadratic_poly quadratic_poly::operator+(const linear_poly &x) const { 75 | return quadratic_poly(a, b + x.a, c + x.b); 76 | } 77 | 78 | cubic_poly quadratic_poly::operator * (const linear_poly &x) const { 79 | return cubic_poly(a * x.a, a * x.b + b * x.a, b * x.b + c * x.a, c * x.b); 80 | } 81 | 82 | cubic_poly cubic_poly::operator * 
(const F &x) const { 83 | return cubic_poly(a * x, b * x, c * x, d * x); 84 | } 85 | 86 | void cubic_poly::clear() { 87 | a.clear(); b.clear(); c.clear(); d.clear(); 88 | } 89 | 90 | quadratic_poly quadratic_poly::operator*(const F &x) const { 91 | return quadratic_poly(a * x, b * x, c * x); 92 | } 93 | 94 | F quadratic_poly::eval(const F &x) const { 95 | return ((a * x) + b) * x + c; 96 | } 97 | 98 | void quadratic_poly::clear() { 99 | a.clear(); b.clear(); c.clear(); 100 | } 101 | 102 | linear_poly::linear_poly() {a.clear(); b.clear();} 103 | linear_poly::linear_poly(const F &aa, const F &bb) { 104 | a = aa; 105 | b = bb; 106 | } 107 | linear_poly::linear_poly(const F &x) { 108 | a.clear(); 109 | b = x; 110 | } 111 | 112 | linear_poly linear_poly::operator + (const linear_poly &x) const { 113 | return linear_poly(a + x.a, b + x.b); 114 | } 115 | 116 | quadratic_poly linear_poly::operator * (const linear_poly &x) const { 117 | return quadratic_poly(a * x.a, a * x.b + b * x.a, b * x.b); 118 | } 119 | 120 | linear_poly linear_poly::operator*(const F &x) const { 121 | return linear_poly(a * x, b * x); 122 | } 123 | 124 | F linear_poly::eval(const F &x) const { 125 | return a * x + b; 126 | } 127 | 128 | void linear_poly::clear() { 129 | a.clear(); b.clear(); 130 | } 131 | -------------------------------------------------------------------------------- /src/polynomial.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "global_var.hpp" 6 | 7 | class linear_poly; 8 | 9 | //ax^3 + bx^2 + cx + d 10 | class cubic_poly { 11 | public: 12 | F a, b, c, d; 13 | cubic_poly(); 14 | cubic_poly(const F &, const F &, const F &, const F &); 15 | cubic_poly operator + (const cubic_poly &) const; 16 | cubic_poly operator * (const F &) const; 17 | F eval(const F &) const; 18 | void clear(); 19 | }; 20 | 21 | //ax^2 + bx + c 22 | class quadratic_poly { 23 | public: 24 | F a, b, c; 25 | 
quadratic_poly(); 26 | quadratic_poly(const F &, const F &, const F &); 27 | quadratic_poly operator + (const quadratic_poly &) const; 28 | quadratic_poly operator + (const linear_poly &) const; 29 | cubic_poly operator * (const linear_poly &) const; 30 | quadratic_poly operator * (const F &) const; 31 | F eval(const F &) const; 32 | void clear(); 33 | }; 34 | 35 | 36 | //ax + b 37 | class linear_poly { 38 | public: 39 | F a, b; 40 | linear_poly(); 41 | linear_poly(const F &, const F &); 42 | linear_poly(const F &); 43 | linear_poly operator + (const linear_poly &) const; 44 | quadratic_poly operator * (const linear_poly &) const; 45 | linear_poly operator * (const F &) const; 46 | F eval(const F &) const; 47 | void clear(); 48 | }; 49 | 50 | 51 | 52 | //ax^4 + bx^3 + cx^2 + dx + e 53 | class quadruple_poly { 54 | public: 55 | F a, b, c, d, e; 56 | quadruple_poly(); 57 | quadruple_poly(const F &, const F &, const F &, const F &, const F &); 58 | quadruple_poly operator + (const quadruple_poly &) const; 59 | F eval(const F &) const; 60 | void clear(); 61 | }; 62 | 63 | //ax^5 + bx^4 + cx^3 + dx^2 + ex + f 64 | class quintuple_poly { 65 | public: 66 | F a, b, c, d, e, f; 67 | quintuple_poly(); 68 | quintuple_poly(const F &, const F &, const F &, const F &, const F &, const F &); 69 | quintuple_poly operator + (const quintuple_poly &) const; 70 | F eval(const F &) const; 71 | void clear(); 72 | }; -------------------------------------------------------------------------------- /src/prover.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 
3 | // 4 | 5 | #include "prover.hpp" 6 | #include 7 | #include 8 | 9 | static vector beta_gs, beta_u; 10 | 11 | using std::unique_ptr; 12 | 13 | linear_poly interpolate(const F &zero_v, const F &one_v) { 14 | return {one_v - zero_v, zero_v}; 15 | } 16 | 17 | void prover::init() { 18 | proof_size = 0; 19 | r_u.resize(C.size + 1); 20 | r_v.resize(C.size + 1); 21 | } 22 | 23 | /** 24 | * This is to initialize all process. 25 | * 26 | * @param the random point to be evaluated at the output layer 27 | */ 28 | void prover::sumcheckInitAll(const vector::const_iterator &r_0_from_v) { 29 | sumcheck_id = C.size; 30 | i8 last_bl = C.circuit[sumcheck_id - 1].bit_length; 31 | r_u[sumcheck_id].resize(last_bl); 32 | 33 | prove_timer.start(); 34 | for (int i = 0; i < last_bl; ++i) r_u[sumcheck_id][i] = r_0_from_v[i]; 35 | prove_timer.stop(); 36 | } 37 | 38 | /** 39 | * This is to initialize before the process of a single layer. 40 | * 41 | * @param the random combination coefficiants for multiple reduction points 42 | */ 43 | void prover::sumcheckInit(const F &alpha_0, const F &beta_0) { 44 | prove_timer.start(); 45 | auto &cur = C.circuit[sumcheck_id]; 46 | alpha = alpha_0; 47 | beta = beta_0; 48 | r_0 = r_u[sumcheck_id].begin(); 49 | r_1 = r_v[sumcheck_id].begin(); 50 | --sumcheck_id; 51 | prove_timer.stop(); 52 | } 53 | 54 | /** 55 | * This is to initialize before the phase 1 of a single inner production layer. 
56 | */ 57 | void prover::sumcheckDotProdInitPhase1() { 58 | fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id); 59 | 60 | auto &cur = C.circuit[sumcheck_id]; 61 | i8 fft_bl = cur.fft_bit_length; 62 | i8 cnt_bl = cur.bit_length - fft_bl; 63 | total[0] = 1ULL << fft_bl; 64 | total[1] = 1ULL << cur.bit_length_u[1]; 65 | total_size[1] = cur.size_u[1]; 66 | u32 fft_len = total[0]; 67 | 68 | r_u[sumcheck_id].resize(cur.max_bl_u); 69 | V_mult[0].resize(total[1]); 70 | V_mult[1].resize(total[1]); 71 | mult_array[1].resize(total[0]); 72 | beta_gs.resize(1ULL << fft_bl); 73 | 74 | prove_timer.start(); 75 | 76 | initBetaTable(beta_gs, fft_bl, r_0, F_ONE); 77 | 78 | for (u32 t = 0; t < fft_len; ++t) 79 | mult_array[1][t] = beta_gs[t]; 80 | for (u32 u = 0; u < total[1]; ++u) { 81 | V_mult[0][u].clear(); 82 | if (u >= cur.size_u[1]) V_mult[1][u].clear(); 83 | else V_mult[1][u] = val[sumcheck_id - 1][u]; 84 | } 85 | 86 | for (auto &gate: cur.bin_gates) 87 | for (u32 t = 0; t < fft_len; ++t) { 88 | u32 idx_u = gate.u << fft_bl | t; 89 | u32 idx_v = gate.v << fft_bl | t; 90 | V_mult[0][idx_u] = V_mult[0][idx_u] + beta_g[gate.g] * val[sumcheck_id - 1][idx_v]; 91 | } 92 | 93 | round = 0; 94 | prove_timer.stop(); 95 | } 96 | 97 | /** 98 | * This is the one-step reduction within a sumcheck process of a single inner production layer. 
99 | * 100 | * @param the random point of the reduction of the previous step 101 | * @return the reducted cubic degree polynomial of the current variable from prover to verifier 102 | */ 103 | cubic_poly prover::sumcheckDotProdUpdate1(const F &previous_random) { 104 | prove_timer.start(); 105 | 106 | if (round) r_u[sumcheck_id].at(round - 1) = previous_random; 107 | ++round; 108 | 109 | auto &tmp_mult = mult_array[1]; 110 | auto &tmp_v0 = V_mult[0], &tmp_v1 = V_mult[1]; 111 | 112 | if (total[0] == 1) 113 | tmp_mult[0] = tmp_mult[0].eval(previous_random); 114 | else for (u32 i = 0; i < (total[0] >> 1); ++i) { 115 | u32 g0 = i << 1, g1 = i << 1 | 1; 116 | tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random)); 117 | } 118 | total[0] >>= 1; 119 | 120 | cubic_poly ret; 121 | for (u32 i = 0; i < (total[1] >> 1); ++i) { 122 | u32 g0 = i << 1, g1 = i << 1 | 1; 123 | if (g0 >= total_size[1]) { 124 | tmp_v0[i].clear(); 125 | tmp_v1[i].clear(); 126 | continue; 127 | } 128 | if (g1 >= total_size[1]) { 129 | tmp_v0[g1].clear(); 130 | tmp_v1[g1].clear(); 131 | } 132 | tmp_v0[i] = interpolate(tmp_v0[g0].eval(previous_random), tmp_v0[g1].eval(previous_random)); 133 | tmp_v1[i] = interpolate(tmp_v1[g0].eval(previous_random), tmp_v1[g1].eval(previous_random)); 134 | if (total[0]) ret = ret + tmp_mult[i & total[0] - 1] * tmp_v1[i] * tmp_v0[i]; 135 | else ret = ret + tmp_mult[0] * tmp_v1[i] * tmp_v0[i]; 136 | } 137 | proof_size += F_BYTE_SIZE * (3 + (!ret.a.isZero())); 138 | 139 | total[1] >>= 1; 140 | total_size[1] = (total_size[1] + 1) >> 1; 141 | 142 | prove_timer.stop(); 143 | return ret; 144 | } 145 | 146 | void prover::sumcheckDotProdFinalize1(const F &previous_random, F &claim_1) { 147 | prove_timer.start(); 148 | r_u[sumcheck_id].at(round - 1) = previous_random; 149 | claim_1 = V_mult[1][0].eval(previous_random); 150 | V_u1 = V_mult[1][0].eval(previous_random) * mult_array[1][0].eval(previous_random); 151 | prove_timer.stop(); 152 | 
proof_size += F_BYTE_SIZE * 1; 153 | } 154 | 155 | void prover::sumcheckInitPhase1(const F &relu_rou_0) { 156 | fprintf(stderr, "sumcheck level %d, phase1 init start\n", sumcheck_id); 157 | 158 | auto &cur = C.circuit[sumcheck_id]; 159 | total[0] = ~cur.bit_length_u[0] ? 1ULL << cur.bit_length_u[0] : 0; 160 | total_size[0] = cur.size_u[0]; 161 | total[1] = ~cur.bit_length_u[1] ? 1ULL << cur.bit_length_u[1] : 0; 162 | total_size[1] = cur.size_u[1]; 163 | 164 | r_u[sumcheck_id].resize(cur.max_bl_u); 165 | V_mult[0].resize(total[0]); 166 | V_mult[1].resize(total[1]); 167 | mult_array[0].resize(total[0]); 168 | mult_array[1].resize(total[1]); 169 | beta_g.resize(1ULL << cur.bit_length); 170 | if (cur.ty == layerType::PADDING) beta_gs.resize(1ULL << cur.fft_bit_length); 171 | if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) 172 | beta_gs.resize(total[1]); 173 | 174 | prove_timer.start(); 175 | 176 | relu_rou = relu_rou_0; 177 | add_term.clear(); 178 | for (int b = 0; b < 2; ++b) 179 | for (u32 u = 0; u < total[b]; ++u) 180 | mult_array[b][u].clear(); 181 | 182 | if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) { 183 | i8 fft_bl = cur.fft_bit_length; 184 | i8 fft_blh = cur.fft_bit_length - 1; 185 | i8 cnt_bl = cur.ty == layerType::FFT ? cur.bit_length - fft_bl : cur.bit_length - fft_blh; 186 | u32 cnt_len = cur.size >> (cur.ty == layerType::FFT ? 
fft_bl : fft_blh); 187 | if (cur.ty == layerType::FFT) 188 | initBetaTable(beta_g, cnt_bl, r_0 + fft_bl, r_1, alpha, beta); 189 | else initBetaTable(beta_g, cnt_bl, r_0 + fft_blh, alpha); 190 | for (u32 u = 0, l = sumcheck_id - 1; u < total[1]; ++u) { 191 | V_mult[1][u].clear(); 192 | if (u >= cur.size_u[1]) continue; 193 | for (u32 g = 0; g < cnt_len; ++g) { 194 | u32 idx = g << cur.max_bl_u | u; 195 | V_mult[1][u] = V_mult[1][u] + val[l][idx] * beta_g[g]; 196 | } 197 | } 198 | 199 | beta_gs.resize(total[1]); 200 | phiGInit(beta_gs, r_0, cur.scale, fft_bl, cur.ty == layerType::IFFT); 201 | for (u32 u = 0; u < total[1] ; ++u) { 202 | mult_array[1][u] = beta_gs[u]; 203 | } 204 | } else { 205 | for (int b = 0; b < 2; ++b) { 206 | auto dep = !b ? 0 : sumcheck_id - 1; 207 | for (u32 u = 0; u < total[b]; ++u) { 208 | if (u >= cur.size_u[b]) 209 | V_mult[b][u].clear(); 210 | else V_mult[b][u] = getCirValue(dep, cur.ori_id_u, u); 211 | } 212 | } 213 | 214 | if (cur.ty == layerType::PADDING) { 215 | i8 fft_blh = cur.fft_bit_length - 1; 216 | u32 fft_lenh = 1ULL << fft_blh; 217 | initBetaTable(beta_gs, fft_blh, r_0, F_ONE); 218 | for (long g = (1L << cur.bit_length) - 1; g >= 0; --g) 219 | beta_g[g] = beta_g[g >> fft_blh] * beta_gs[g & fft_lenh - 1]; 220 | } else initBetaTable(beta_g, cur.bit_length, r_0, r_1, alpha * cur.scale, beta * cur.scale); 221 | if (cur.zero_start_id < cur.size) 222 | for (u32 g = cur.zero_start_id; g < 1ULL << cur.bit_length; ++g) beta_g[g] = beta_g[g] * relu_rou; 223 | 224 | for (auto &gate: cur.uni_gates) { 225 | bool idx = gate.lu != 0; 226 | mult_array[idx][gate.u] = mult_array[idx][gate.u] + beta_g[gate.g] * C.two_mul[gate.sc]; 227 | } 228 | 229 | for (auto &gate: cur.bin_gates) { 230 | bool idx = gate.getLayerIdU(sumcheck_id) != 0; 231 | auto val_lv = getCirValue(gate.getLayerIdV(sumcheck_id), cur.ori_id_v, gate.v); 232 | mult_array[idx][gate.u] = mult_array[idx][gate.u] + val_lv * beta_g[gate.g] * C.two_mul[gate.sc]; 233 | } 234 | } 235 | 
236 | round = 0; 237 | prove_timer.stop(); 238 | fprintf(stderr, "sumcheck level %d, phase1 init finished\n", sumcheck_id); 239 | } 240 | 241 | void prover::sumcheckInitPhase2() { 242 | fprintf(stderr, "sumcheck level %d, phase2 init start\n", sumcheck_id); 243 | 244 | auto &cur = C.circuit[sumcheck_id]; 245 | total[0] = ~cur.bit_length_v[0] ? 1ULL << cur.bit_length_v[0] : 0; 246 | total_size[0] = cur.size_v[0]; 247 | total[1] = ~cur.bit_length_v[1] ? 1ULL << cur.bit_length_v[1] : 0; 248 | total_size[1] = cur.size_v[1]; 249 | i8 fft_bl = cur.fft_bit_length; 250 | i8 cnt_bl = cur.max_bl_v; 251 | 252 | r_v[sumcheck_id].resize(cur.max_bl_v); 253 | 254 | V_mult[0].resize(total[0]); 255 | V_mult[1].resize(total[1]); 256 | mult_array[0].resize(total[0]); 257 | mult_array[1].resize(total[1]); 258 | 259 | if (cur.ty == layerType::DOT_PROD) { 260 | beta_u.resize(1ULL << cnt_bl); 261 | beta_gs.resize(1ULL << fft_bl); 262 | } else beta_u.resize(1ULL << cur.max_bl_u); 263 | 264 | prove_timer.start(); 265 | 266 | add_term.clear(); 267 | for (int b = 0; b < 2; ++b) { 268 | for (u32 v = 0; v < total[b]; ++v) 269 | mult_array[b][v].clear(); 270 | } 271 | 272 | if (cur.ty == layerType::DOT_PROD) { 273 | u32 fft_len = 1ULL << cur.fft_bit_length; 274 | initBetaTable(beta_u, cnt_bl, r_u[sumcheck_id].begin() + fft_bl, F_ONE); 275 | initBetaTable(beta_gs, fft_bl, r_u[sumcheck_id].begin(), F_ONE); 276 | 277 | for (u32 v = 0; v < total[1]; ++v) { 278 | V_mult[1][v].clear(); 279 | if (v >= cur.size_v[1]) continue; 280 | for (u32 t = 0; t < fft_len; ++t) { 281 | u32 idx_v = (v << fft_bl) | t; 282 | V_mult[1][v] = V_mult[1][v] + val[sumcheck_id - 1][idx_v] * beta_gs[t]; 283 | } 284 | } 285 | 286 | for (auto &gate: cur.bin_gates) 287 | mult_array[1][gate.v] = 288 | mult_array[1][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u1; 289 | } else { 290 | initBetaTable(beta_u, cur.max_bl_u, r_u[sumcheck_id].begin(), F_ONE); 291 | for (int b = 0; b < 2; ++b) { 292 | auto dep = !b ? 
0 : sumcheck_id - 1; 293 | for (u32 v = 0; v < total[b]; ++v) { 294 | V_mult[b][v] = v >= cur.size_v[b] ? F_ZERO : getCirValue(dep, cur.ori_id_v, v); 295 | } 296 | } 297 | for (auto &gate: cur.uni_gates) { 298 | auto V_u = !gate.lu ? V_u0 : V_u1; 299 | add_term = add_term + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc]; 300 | } 301 | for (auto &gate: cur.bin_gates) { 302 | bool idx = gate.getLayerIdV(sumcheck_id); 303 | auto V_u = !gate.getLayerIdU(sumcheck_id) ? V_u0 : V_u1; 304 | mult_array[idx][gate.v] = mult_array[idx][gate.v] + beta_g[gate.g] * beta_u[gate.u] * V_u * C.two_mul[gate.sc]; 305 | } 306 | } 307 | 308 | round = 0; 309 | prove_timer.stop(); 310 | } 311 | 312 | void prover::sumcheckLiuInit(const vector &s_u, const vector &s_v) { 313 | sumcheck_id = 0; 314 | total[1] = (1ULL << C.circuit[sumcheck_id].bit_length); 315 | total_size[1] = C.circuit[sumcheck_id].size; 316 | 317 | r_u[0].resize(C.circuit[0].bit_length); 318 | mult_array[1].resize(total[1]); 319 | V_mult[1].resize(total[1]); 320 | 321 | i8 max_bl = 0; 322 | for (int i = sumcheck_id + 1; i < C.size; ++i) 323 | max_bl = max(max_bl, max(C.circuit[i].bit_length_u[0], C.circuit[i].bit_length_v[0])); 324 | beta_g.resize(1ULL << max_bl); 325 | 326 | prove_timer.start(); 327 | add_term.clear(); 328 | 329 | for (u32 g = 0; g < total[1]; ++g) { 330 | mult_array[1][g].clear(); 331 | V_mult[1][g] = (g < total_size[1]) ? 
val[0][g] : F_ZERO; 332 | } 333 | 334 | for (u8 i = sumcheck_id + 1; i < C.size; ++i) { 335 | i8 bit_length_i = C.circuit[i].bit_length_u[0]; 336 | u32 size_i = C.circuit[i].size_u[0]; 337 | if (~bit_length_i) { 338 | initBetaTable(beta_g, bit_length_i, r_u[i].begin(), s_u[i - 1]); 339 | for (u32 hu = 0; hu < size_i; ++hu) { 340 | u32 u = C.circuit[i].ori_id_u[hu]; 341 | mult_array[1][u] = mult_array[1][u] + beta_g[hu]; 342 | } 343 | } 344 | 345 | bit_length_i = C.circuit[i].bit_length_v[0]; 346 | size_i = C.circuit[i].size_v[0]; 347 | if (~bit_length_i) { 348 | initBetaTable(beta_g, bit_length_i, r_v[i].begin(), s_v[i - 1]); 349 | for (u32 hv = 0; hv < size_i; ++hv) { 350 | u32 v = C.circuit[i].ori_id_v[hv]; 351 | mult_array[1][v] = mult_array[1][v] + beta_g[hv]; 352 | } 353 | } 354 | } 355 | 356 | round = 0; 357 | prove_timer.stop(); 358 | } 359 | 360 | quadratic_poly prover::sumcheckUpdate1(const F &previous_random) { 361 | return sumcheckUpdate(previous_random, r_u[sumcheck_id]); 362 | } 363 | 364 | quadratic_poly prover::sumcheckUpdate2(const F &previous_random) { 365 | return sumcheckUpdate(previous_random, r_v[sumcheck_id]); 366 | } 367 | 368 | quadratic_poly prover::sumcheckUpdate(const F &previous_random, vector &r_arr) { 369 | prove_timer.start(); 370 | 371 | if (round) r_arr.at(round - 1) = previous_random; 372 | ++round; 373 | quadratic_poly ret; 374 | 375 | add_term = add_term * (F_ONE - previous_random); 376 | for (int b = 0; b < 2; ++b) 377 | ret = ret + sumcheckUpdateEach(previous_random, b); 378 | ret = ret + quadratic_poly(F_ZERO, -add_term, add_term); 379 | 380 | prove_timer.stop(); 381 | proof_size += F_BYTE_SIZE * 3; 382 | return ret; 383 | } 384 | 385 | quadratic_poly prover::sumcheckLiuUpdate(const F &previous_random) { 386 | prove_timer.start(); 387 | ++round; 388 | 389 | auto ret = sumcheckUpdateEach(previous_random, true); 390 | 391 | prove_timer.stop(); 392 | proof_size += F_BYTE_SIZE * 3; 393 | return ret; 394 | } 395 | 396 | 
quadratic_poly prover::sumcheckUpdateEach(const F &previous_random, bool idx) { 397 | auto &tmp_mult = mult_array[idx]; 398 | auto &tmp_v = V_mult[idx]; 399 | 400 | if (total[idx] == 1) { 401 | tmp_v[0] = tmp_v[0].eval(previous_random); 402 | tmp_mult[0] = tmp_mult[0].eval(previous_random); 403 | add_term = add_term + tmp_v[0].b * tmp_mult[0].b; 404 | } 405 | 406 | quadratic_poly ret; 407 | for (u32 i = 0; i < (total[idx] >> 1); ++i) { 408 | u32 g0 = i << 1, g1 = i << 1 | 1; 409 | if (g0 >= total_size[idx]) { 410 | tmp_v[i].clear(); 411 | tmp_mult[i].clear(); 412 | continue; 413 | } 414 | if (g1 >= total_size[idx]) { 415 | tmp_v[g1].clear(); 416 | tmp_mult[g1].clear(); 417 | } 418 | tmp_v[i] = interpolate(tmp_v[g0].eval(previous_random), tmp_v[g1].eval(previous_random)); 419 | tmp_mult[i] = interpolate(tmp_mult[g0].eval(previous_random), tmp_mult[g1].eval(previous_random)); 420 | ret = ret + tmp_mult[i] * tmp_v[i]; 421 | } 422 | total[idx] >>= 1; 423 | total_size[idx] = (total_size[idx] + 1) >> 1; 424 | 425 | return ret; 426 | } 427 | 428 | /** 429 | * This is to evaluate a multi-linear extension at a random point. 430 | * 431 | * @param the value of the array & random point & the size of the array & the size of the random point 432 | * @return sum of `values`, or 0.0 if `values` is empty. 
433 | */ 434 | F prover::Vres(const vector::const_iterator &r, u32 output_size, u8 r_size) { 435 | prove_timer.start(); 436 | 437 | vector output(output_size); 438 | for (u32 i = 0; i < output_size; ++i) 439 | output[i] = val[C.size - 1][i]; 440 | u32 whole = 1ULL << r_size; 441 | for (u8 i = 0; i < r_size; ++i) { 442 | for (u32 j = 0; j < (whole >> 1); ++j) { 443 | if (j > 0) 444 | output[j].clear(); 445 | if ((j << 1) < output_size) 446 | output[j] = output[j << 1] * (F_ONE - r[i]); 447 | if ((j << 1 | 1) < output_size) 448 | output[j] = output[j] + output[j << 1 | 1] * (r[i]); 449 | } 450 | whole >>= 1; 451 | } 452 | F res = output[0]; 453 | 454 | prove_timer.stop(); 455 | proof_size += F_BYTE_SIZE; 456 | return res; 457 | } 458 | 459 | void prover::sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1) { 460 | prove_timer.start(); 461 | r_u[sumcheck_id].at(round - 1) = previous_random; 462 | V_u0 = claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[0]) ? V_mult[0][0].b : F_ZERO; 463 | V_u1 = claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_u[1]) ? V_mult[1][0].b : F_ZERO; 464 | prove_timer.stop(); 465 | 466 | mult_array[0].clear(); 467 | mult_array[1].clear(); 468 | V_mult[0].clear(); 469 | V_mult[1].clear(); 470 | proof_size += F_BYTE_SIZE * 2; 471 | } 472 | 473 | void prover::sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1) { 474 | prove_timer.start(); 475 | r_v[sumcheck_id].at(round - 1) = previous_random; 476 | claim_0 = total[0] ? V_mult[0][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[0]) ? V_mult[0][0].b : F_ZERO; 477 | claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : (~C.circuit[sumcheck_id].bit_length_v[1]) ? 
V_mult[1][0].b : F_ZERO; 478 | prove_timer.stop(); 479 | 480 | mult_array[0].clear(); 481 | mult_array[1].clear(); 482 | V_mult[0].clear(); 483 | V_mult[1].clear(); 484 | proof_size += F_BYTE_SIZE * 2; 485 | } 486 | 487 | void prover::sumcheckLiuFinalize(const F &previous_random, F &claim_1) { 488 | prove_timer.start(); 489 | r_u[sumcheck_id].at(round - 1) = previous_random; 490 | claim_1 = total[1] ? V_mult[1][0].eval(previous_random) : V_mult[1][0].b; 491 | prove_timer.stop(); 492 | proof_size += F_BYTE_SIZE; 493 | 494 | mult_array[1].clear(); 495 | V_mult[1].clear(); 496 | beta_g.clear(); 497 | } 498 | 499 | F prover::getCirValue(u8 layer_id, const vector &ori, u32 u) { 500 | return !layer_id ? val[0][ori[u]] : val[layer_id][u]; 501 | } 502 | 503 | hyrax_bls12_381::polyProver &prover::commitInput(const vector &gens) { 504 | if (C.circuit[0].size != (1ULL << C.circuit[0].bit_length)) { 505 | val[0].resize(1ULL << C.circuit[0].bit_length); 506 | for (int i = C.circuit[0].size; i < val[0].size(); ++i) 507 | val[0][i].clear(); 508 | } 509 | poly_p = std::make_unique(val[0], gens); 510 | return *poly_p; 511 | } -------------------------------------------------------------------------------- /src/prover.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 
3 | // 4 | 5 | #ifndef ZKCNN_PROVER_HPP 6 | #define ZKCNN_PROVER_HPP 7 | 8 | #include "global_var.hpp" 9 | #include "circuit.h" 10 | #include "polynomial.h" 11 | 12 | using std::unique_ptr; 13 | 14 | class neuralNetwork; 15 | class singleConv; 16 | class prover { 17 | public: 18 | void init(); 19 | 20 | void sumcheckInitAll(const vector::const_iterator &r_0_from_v); 21 | void sumcheckInit(const F &alpha_0, const F &beta_0); 22 | void sumcheckDotProdInitPhase1(); 23 | void sumcheckInitPhase1(const F &relu_rou_0); 24 | void sumcheckInitPhase2(); 25 | 26 | cubic_poly sumcheckDotProdUpdate1(const F &previous_random); 27 | quadratic_poly sumcheckUpdate1(const F &previous_random); 28 | quadratic_poly sumcheckUpdate2(const F &previous_random); 29 | 30 | F Vres(const vector::const_iterator &r, u32 output_size, u8 r_size); 31 | 32 | void sumcheckDotProdFinalize1(const F &previous_random, F &claim_1); 33 | void sumcheckFinalize1(const F &previous_random, F &claim_0, F &claim_1); 34 | void sumcheckFinalize2(const F &previous_random, F &claim_0, F &claim_1); 35 | void sumcheckLiuFinalize(const F &previous_random, F &claim_1); 36 | 37 | void sumcheckLiuInit(const vector &s_u, const vector &s_v); 38 | quadratic_poly sumcheckLiuUpdate(const F &previous_random); 39 | 40 | hyrax_bls12_381::polyProver &commitInput(const vector &gens); 41 | 42 | timer prove_timer; 43 | double proveTime() const { return prove_timer.elapse_sec(); } 44 | double proofSize() const { return (double) proof_size / 1024.0; } 45 | double polyProverTime() const { return poly_p -> getPT(); } 46 | double polyProofSize() const { return poly_p -> getPS(); } 47 | 48 | layeredCircuit C; 49 | vector> val; // the output of each gate 50 | private: 51 | quadratic_poly sumcheckUpdateEach(const F &previous_random, bool idx); 52 | quadratic_poly sumcheckUpdate(const F &previous_random, vector &r_arr); 53 | F getCirValue(u8 layer_id, const vector &ori, u32 u); 54 | 55 | vector::iterator r_0, r_1; // current positions 56 | 
vector> r_u, r_v; // next positions 57 | 58 | vector beta_g; 59 | 60 | F add_term; 61 | vector mult_array[2]; 62 | vector V_mult[2]; 63 | 64 | F V_u0, V_u1; 65 | 66 | F alpha, beta, relu_rou; 67 | 68 | u64 proof_size; 69 | 70 | u32 total[2], total_size[2]; 71 | u8 round; // step within a sumcheck 72 | u8 sumcheck_id; // the level 73 | 74 | unique_ptr poly_p; 75 | 76 | friend neuralNetwork; 77 | friend singleConv; 78 | }; 79 | 80 | 81 | #endif //ZKCNN_PROVER_HPP 82 | -------------------------------------------------------------------------------- /src/utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | #include "utils.hpp" 9 | 10 | using std::cerr; 11 | using std::endl; 12 | using std::string; 13 | using std::cin; 14 | 15 | int ceilPow2BitLengthSigned(double n) { 16 | return (i8) ceil(log2(n)); 17 | } 18 | 19 | int floorPow2BitLengthSigned(double n) { 20 | return (i8) floor(log2(n)); 21 | } 22 | 23 | i8 ceilPow2BitLength(u32 n) { 24 | return n < 1e-9 ? 
-1 : (i8) ceil(log(n) / log(2.)); 25 | } 26 | 27 | i8 floorPow2BitLength(u32 n) { 28 | // cerr << n << ' ' << log(n) / log(2.)< &beta_f, vector &beta_s, const vector::const_iterator &r, const F &init, u32 first_half, u32 second_half) { 33 | beta_f.at(0) = init; 34 | beta_s.at(0) = F_ONE; 35 | 36 | for (u32 i = 0; i < first_half; ++i) { 37 | for (u32 j = 0; j < (1ULL << i); ++j) { 38 | auto tmp = beta_f.at(j) * r[i]; 39 | beta_f.at(j | (1ULL << i)) = tmp; 40 | beta_f.at(j) = beta_f[j] - tmp; 41 | } 42 | } 43 | 44 | for (u32 i = 0; i < second_half; ++i) { 45 | for (u32 j = 0; j < (1ULL << i); ++j) { 46 | auto tmp = beta_s[j] * r[(i + first_half)]; 47 | beta_s[j | (1ULL << i)] = tmp; 48 | beta_s[j] = beta_s[j] - tmp; 49 | } 50 | } 51 | } 52 | 53 | void phiPowInit(vector &phi_mul, int n, bool isIFFT) { 54 | u32 N = 1ULL << n; 55 | F phi = getRootOfUnit(n); 56 | if (isIFFT) F::inv(phi, phi); 57 | phi_mul[0] = F_ONE; 58 | for (u32 i = 1; i < N; ++i) phi_mul[i] = phi_mul[i - 1] * phi; 59 | } 60 | 61 | void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT) { 62 | vector phi_mul(1 << n); 63 | phiPowInit(phi_mul, n, isIFFT); 64 | 65 | if (isIFFT) { 66 | // cerr << "==" << endl; 67 | // cerr << "gLength: " << n << endl; 68 | // for (int i = 0; i < n - 1; ++i) { 69 | // cerr << rx[i]; 70 | // cerr << endl; 71 | // } 72 | phi_g[0] = phi_g[1] = scale; 73 | for (int i = 2; i <= n; ++i) 74 | for (u32 b = 0; b < (1ULL << (i - 1)); ++b) { 75 | u32 l = b, r = b ^ (1ULL << (i - 1)); 76 | int m = n - i; 77 | F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m]; 78 | phi_g[r] = phi_g[l] * (tmp1 - tmp2); 79 | phi_g[l] = phi_g[l] * (tmp1 + tmp2); 80 | } 81 | } else { 82 | // cerr << "==" << endl; 83 | // cerr << "gLength: " << n << endl; 84 | // for (int i = 0; i < n; ++i) { 85 | // cerr << rx[i]; 86 | // cerr << endl; 87 | // } 88 | phi_g[0] = scale; 89 | for (int i = 1; i < n; ++i) 90 | for (u32 b = 0; b < (1ULL << (i - 1)); ++b) { 91 | u32 
l = b, r = b ^ (1ULL << (i - 1)); 92 | int m = n - i; 93 | F tmp1 = F_ONE - rx[m], tmp2 = rx[m] * phi_mul[b << m]; 94 | phi_g[r] = phi_g[l] * (tmp1 - tmp2); 95 | phi_g[l] = phi_g[l] * (tmp1 + tmp2); 96 | } 97 | for (u32 b = 0; b < (1ULL << (n - 1)); ++b) { 98 | u32 l = b; 99 | F tmp1 = F_ONE - rx[0], tmp2 = rx[0] * phi_mul[b]; 100 | phi_g[l] = phi_g[l] * (tmp1 + tmp2); 101 | } 102 | } 103 | } 104 | 105 | void fft(vector &arr, int logn, bool flag) { 106 | // cerr << "fft: " << endl; 107 | // for (auto x: arr) cerr << x << ' '; 108 | // cerr << endl; 109 | static std::vector rev; 110 | static std::vector w; 111 | 112 | u32 len = 1ULL << logn; 113 | assert(arr.size() == len); 114 | 115 | rev.resize(len); 116 | w.resize(len); 117 | 118 | rev[0] = 0; 119 | for (u32 i = 1; i < len; ++i) 120 | rev[i] = rev[i >> 1] >> 1 | (i & 1) << (logn - 1); 121 | 122 | w[0] = F_ONE; 123 | w[1] = getRootOfUnit(logn); 124 | if (flag) F::inv(w[1], w[1]); 125 | for (u32 i = 2; i < len; ++i) w[i] = w[i - 1] * w[1]; 126 | 127 | for (u32 i = 0; i < len; ++i) 128 | if (rev[i] < i) std::swap(arr[i], arr[rev[i]]); 129 | 130 | for (u32 i = 2; i <= len; i <<= 1) 131 | for (u32 j = 0; j < len; j += i) 132 | for (u32 k = 0; k < (i >> 1); ++k) { 133 | auto u = arr[j + k]; 134 | auto v = arr[j + k + (i >> 1)] * w[len / i * k]; 135 | arr[j + k] = u + v; 136 | arr[j + k + (i >> 1)] = u - v; 137 | } 138 | 139 | if (flag) { 140 | F ilen; 141 | F::inv(ilen, len); 142 | for (u32 i = 0; i < len; ++i) 143 | arr[i] = arr[i] * ilen; 144 | } 145 | } 146 | 147 | void 148 | initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1, 149 | const F &alpha, const F &beta) { 150 | u8 first_half = gLength >> 1, second_half = gLength - first_half; 151 | u32 mask_fhalf = (1ULL << first_half) - 1; 152 | 153 | vector beta_f(1ULL << first_half), beta_s(1ULL << second_half); 154 | if (!beta.isZero()) { 155 | initHalfTable(beta_f, beta_s, r_1, beta, first_half, second_half); 
156 | for (u32 i = 0; i < (1ULL << gLength); ++i) 157 | beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half]; 158 | } else for (u32 i = 0; i < (1ULL << gLength); ++i) 159 | beta_g[i].clear(); 160 | 161 | if (alpha.isZero()) return; 162 | initHalfTable(beta_f, beta_s, r_0, alpha, first_half, second_half); 163 | for (u32 i = 0; i < (1ULL << gLength); ++i) 164 | beta_g[i] = beta_g[i] + beta_f[i & mask_fhalf] * beta_s[i >> first_half]; 165 | } 166 | 167 | 168 | void initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r, const F &init) { 169 | if (gLength == -1) return; 170 | int first_half = gLength >> 1, second_half = gLength - first_half; 171 | u32 mask_fhalf = (1ULL << first_half) - 1; 172 | vector beta_f(1ULL << first_half), beta_s(1ULL << second_half); 173 | 174 | if (!init.isZero()) { 175 | initHalfTable(beta_f, beta_s, r, init, first_half, second_half); 176 | for (u32 i = 0; i < (1ULL << gLength); ++i) 177 | beta_g[i] = beta_f[i & mask_fhalf] * beta_s[i >> first_half]; 178 | } else for (u32 i = 0; i < (1ULL << gLength); ++i) 179 | beta_g[i].clear(); 180 | } 181 | 182 | bool check(long x, long y, long nx, long ny) { 183 | return 0 <= x && x < nx && 0 <= y && y < ny; 184 | } 185 | // 186 | //F getData(u8 scale_bl) { 187 | // double x; 188 | // in >> x; 189 | // long y = round(x * (1L << scale_bl)); 190 | // return F(y); 191 | //} 192 | 193 | void initLayer(layer &circuit, long size, layerType ty) { 194 | circuit.size = circuit.zero_start_id = size; 195 | circuit.bit_length = ceilPow2BitLength(size); 196 | circuit.ty = ty; 197 | } 198 | 199 | long sqr(long x) { 200 | return x * x; 201 | } 202 | 203 | double byte2KB(size_t x) { return x / 1024.0; } 204 | 205 | double byte2MB(size_t x) { return x / 1024.0 / 1024.0; } 206 | 207 | double byte2GB(size_t x) { return x / 1024.0 / 1024.0 / 1024.0; } 208 | 209 | long matIdx(long x, long y, long n) { 210 | assert(y < n); 211 | return x * n + y; 212 | } 213 | 214 | long cubIdx(long x, long y, long 
z, long n, long m) { 215 | assert(y < n && z < m); 216 | return matIdx(matIdx(x, y, n), z, m); 217 | } 218 | 219 | long tesIdx(long w, long x, long y, long z, long n, long m, long l) { 220 | assert(x < n && y < m && z < l); 221 | return matIdx(cubIdx(w, x, y, n, m), z, l); 222 | } 223 | 224 | F getRootOfUnit(int n) { 225 | F res = -F_ONE; 226 | if (!n) return F_ONE; 227 | while (--n) { 228 | bool b = F::squareRoot(res, res); 229 | assert(b); 230 | } 231 | return res; 232 | } 233 | 234 | 235 | -------------------------------------------------------------------------------- /src/utils.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 3 | // 4 | 5 | #ifndef ZKCNN_UTILS_HPP 6 | #define ZKCNN_UTILS_HPP 7 | 8 | #include 9 | 10 | int ceilPow2BitLengthSigned(double n); 11 | int floorPow2BitLengthSigned(double n); 12 | 13 | char ceilPow2BitLength(u32 n); 14 | char floorPow2BitLength(u32 n); 15 | 16 | 17 | void fft(vector &arr, int logn, bool flag); 18 | 19 | void 20 | initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r_0, const vector::const_iterator &r_1, 21 | const F &alpha, const F &beta); 22 | 23 | void initPhiTable(F *phi_g, const layer &cur_layer, const F *r_0, const F *r_1, F alpha, F beta); 24 | 25 | void phiGInit(vector &phi_g, const vector::const_iterator &rx, const F &scale, int n, bool isIFFT); 26 | 27 | void initBetaTable(vector &beta_g, u8 gLength, const vector::const_iterator &r, const F &init); 28 | 29 | bool check(long x, long y, long nx, long ny); 30 | 31 | long matIdx(long x, long y, long n); 32 | 33 | long cubIdx(long x, long y, long z, long n, long m); 34 | 35 | long tesIdx(long w, long x, long y, long z, long n, long m, long l); 36 | 37 | void initLayer(layer &circuit, long size, layerType ty); 38 | 39 | long sqr(long x); 40 | 41 | double byte2KB(size_t x); 42 | 43 | double byte2MB(size_t x); 44 | 45 | double byte2GB(size_t x); 46 | 47 | F 
getRootOfUnit(int n); 48 | 49 | #endif //ZKCNN_UTILS_HPP 50 | -------------------------------------------------------------------------------- /src/verifier.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 3 | // 4 | 5 | #include "verifier.hpp" 6 | #include "global_var.hpp" 7 | #include 8 | #include 9 | #include 10 | 11 | vector beta_v; 12 | static vector beta_u, beta_gs; 13 | 14 | verifier::verifier(prover *pr, const layeredCircuit &cir): 15 | p(pr), C(cir) { 16 | final_claim_u0.resize(C.size + 2); 17 | final_claim_v0.resize(C.size + 2); 18 | 19 | r_u.resize(C.size + 2); 20 | r_v.resize(C.size + 2); 21 | // make the prover ready 22 | p->init(); 23 | } 24 | 25 | F verifier::getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1) { 26 | 27 | auto test_value = bin_value[0] * (claim_u0 * claim_v0) 28 | + bin_value[1] * (claim_u1 * claim_v1) 29 | + bin_value[2] * (claim_u1 * claim_v0) 30 | + uni_value[0] * claim_u0 31 | + uni_value[1] * claim_u1; 32 | 33 | return test_value; 34 | } 35 | 36 | void verifier::betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou) { 37 | i8 bl = C.circuit[depth].bit_length; 38 | i8 fft_bl = C.circuit[depth].fft_bit_length; 39 | i8 fft_blh = C.circuit[depth].fft_bit_length - 1; 40 | i8 cnt_bl = bl - fft_bl, cnt_bl2 = C.circuit[depth].max_bl_u - fft_bl; 41 | 42 | switch (C.circuit[depth].ty) { 43 | case layerType::FFT: 44 | case layerType::IFFT: 45 | beta_gs.resize(1ULL << fft_bl); 46 | phiGInit(beta_gs, r_0, C.circuit[depth].scale, fft_bl, C.circuit[depth].ty == layerType::IFFT); 47 | beta_u.resize(1ULL << C.circuit[depth].max_bl_u); 48 | initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); 49 | break; 50 | case layerType::PADDING: 51 | beta_g.resize(1ULL << bl); 52 | beta_gs.resize(1ULL << fft_blh); 53 | 
initBetaTable(beta_g, bl - fft_blh, r_u[depth + 2].begin() + fft_bl, r_v[depth + 2].begin(), alpha, beta); 54 | initBetaTable(beta_gs, fft_blh, r_0, F_ONE); 55 | for (u32 g = (1ULL << bl) - 1; g < (1ULL << bl); --g) 56 | beta_g[g] = beta_g[g >> fft_blh] * 57 | beta_gs[g & (1ULL << fft_blh) - 1]; 58 | beta_u.resize(1ULL << C.circuit[depth].max_bl_u); 59 | initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); 60 | break; 61 | case layerType::DOT_PROD: 62 | beta_g.resize(1ULL << cnt_bl); 63 | initBetaTable(beta_g, cnt_bl, r_u[depth + 2].begin() + fft_bl - 1, alpha); 64 | 65 | beta_u.resize(1ULL << cnt_bl2); 66 | initBetaTable(beta_u, cnt_bl2, r_u[depth].begin() + fft_bl, F_ONE); 67 | for (u32 i = 0; i < 1ULL << cnt_bl2; ++i) 68 | for (u32 j = 0; j < fft_bl; ++j) 69 | beta_u[i] = beta_u[i] * ((r_0[j] * r_u[depth][j]) + (F_ONE - r_0[j]) * (F_ONE - r_u[depth][j])); 70 | break; 71 | 72 | default: 73 | beta_g.resize(1ULL << bl); 74 | initBetaTable(beta_g, C.circuit[depth].bit_length, r_0, r_1, alpha * C.circuit[depth].scale, 75 | beta * C.circuit[depth].scale); 76 | if (C.circuit[depth].zero_start_id < C.circuit[depth].size) 77 | for (u32 g = C.circuit[depth].zero_start_id; g < 1ULL << C.circuit[depth].bit_length; ++g) 78 | beta_g[g] = beta_g[g] * relu_rou; 79 | beta_u.resize(1ULL << C.circuit[depth].max_bl_u); 80 | initBetaTable(beta_u, C.circuit[depth].max_bl_u, r_u[depth].begin(), F_ONE); 81 | } 82 | } 83 | 84 | void verifier::betaInitPhase2(u8 depth) { 85 | beta_v.resize(1ULL << C.circuit[depth].max_bl_v); 86 | initBetaTable(beta_v, C.circuit[depth].max_bl_v, r_v[depth].begin(), F_ONE); 87 | } 88 | 89 | void verifier::predicatePhase1(u8 layer_id) { 90 | auto &cur_layer = C.circuit[layer_id]; 91 | 92 | uni_value[0].clear(); 93 | uni_value[1].clear(); 94 | if (cur_layer.ty == layerType::FFT || cur_layer.ty == layerType::IFFT) 95 | for (u32 u = 0; u < 1ULL << cur_layer.max_bl_u; ++u) 96 | uni_value[1] = uni_value[1] + beta_gs[u] * beta_u[u]; 97 | 
else for (auto &gate: cur_layer.uni_gates) { 98 | bool idx = gate.lu; 99 | uni_value[idx] = uni_value[idx] + beta_g[gate.g] * beta_u[gate.u] * C.two_mul[gate.sc]; 100 | } 101 | bin_value[0] = bin_value[1] = bin_value[2] = F_ZERO; 102 | } 103 | 104 | void verifier::predicatePhase2(u8 layer_id) { 105 | uni_value[0] = uni_value[0] * beta_v[0]; 106 | uni_value[1] = uni_value[1] * beta_v[0]; 107 | 108 | auto &cur_layer = C.circuit[layer_id]; 109 | if (C.circuit[layer_id].ty == layerType::DOT_PROD) { 110 | for (auto &gate: cur_layer.bin_gates) 111 | bin_value[gate.l] = 112 | bin_value[gate.l] + 113 | beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v]; 114 | } else for (auto &gate: cur_layer.bin_gates) 115 | bin_value[gate.l] = bin_value[gate.l] + beta_g[gate.g] * beta_u[gate.u] * beta_v[gate.v] * C.two_mul[gate.sc]; 116 | } 117 | 118 | bool verifier::verify() { 119 | u8 logn = C.circuit[0].bit_length; 120 | u64 n_sqrt = 1ULL << (logn - (logn >> 1)); 121 | vector gens(n_sqrt); 122 | for (auto &x: gens) { 123 | Fr tmp; 124 | tmp.setByCSPRNG(); 125 | x = mcl::bn::getG1basePoint() * tmp; 126 | } 127 | 128 | poly_v = std::make_unique(p -> commitInput(gens), gens); 129 | return verifyInnerLayers() && verifyFirstLayer() && verifyInput(); 130 | } 131 | 132 | bool verifier::verifyInnerLayers() { 133 | total_timer.start(); 134 | total_slow_timer.start(); 135 | 136 | F alpha = F_ONE, beta = F_ZERO, relu_rou, final_claim_u1, final_claim_v1; 137 | r_u[C.size].resize(C.circuit[C.size - 1].bit_length); 138 | for (i8 i = 0; i < C.circuit[C.size - 1].bit_length; ++i) 139 | r_u[C.size][i].setByCSPRNG(); 140 | vector::const_iterator r_0 = r_u[C.size].begin(); 141 | vector::const_iterator r_1; 142 | 143 | total_timer.stop(); 144 | total_slow_timer.stop(); 145 | 146 | auto previousSum = p->Vres(r_0, C.circuit[C.size - 1].size, C.circuit[C.size - 1].bit_length); 147 | p -> sumcheckInitAll(r_0); 148 | 149 | for (u8 i = C.size - 1; i; --i) { 150 | auto &cur = C.circuit[i]; 151 | 
p->sumcheckInit(alpha, beta); 152 | total_timer.start(); 153 | total_slow_timer.start(); 154 | 155 | // phase 1 156 | r_u[i].resize(cur.max_bl_u); 157 | for (int j = 0; j < cur.max_bl_u; ++j) r_u[i][j].setByCSPRNG(); 158 | if (cur.zero_start_id < cur.size) 159 | relu_rou.setByCSPRNG(); 160 | else relu_rou = F_ONE; 161 | 162 | total_timer.stop(); 163 | total_slow_timer.stop(); 164 | if (cur.ty == layerType::DOT_PROD) 165 | p->sumcheckDotProdInitPhase1(); 166 | else p->sumcheckInitPhase1(relu_rou); 167 | 168 | F previousRandom = F_ZERO; 169 | for (i8 j = 0; j < cur.max_bl_u; ++j) { 170 | F cur_claim, nxt_claim; 171 | if (cur.ty == layerType::DOT_PROD) { 172 | cubic_poly poly = p->sumcheckDotProdUpdate1(previousRandom); 173 | total_timer.start(); 174 | total_slow_timer.start(); 175 | cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE); 176 | nxt_claim = poly.eval(r_u[i][j]); 177 | } else { 178 | quadratic_poly poly = p->sumcheckUpdate1(previousRandom); 179 | total_timer.start(); 180 | total_slow_timer.start(); 181 | cur_claim = poly.eval(F_ZERO) + poly.eval(F_ONE); 182 | nxt_claim = poly.eval(r_u[i][j]); 183 | } 184 | 185 | if (cur_claim != previousSum) { 186 | cerr << cur_claim << ' ' << previousSum << endl; 187 | fprintf(stderr, "Verification fail, phase1, circuit %d, current bit %d\n", i, j); 188 | return false; 189 | } 190 | previousRandom = r_u[i][j]; 191 | previousSum = nxt_claim; 192 | total_timer.stop(); 193 | total_slow_timer.stop(); 194 | } 195 | 196 | if (cur.ty == layerType::DOT_PROD) 197 | p->sumcheckDotProdFinalize1(previousRandom, final_claim_u1); 198 | else p->sumcheckFinalize1(previousRandom, final_claim_u0[i], final_claim_u1); 199 | 200 | total_slow_timer.start(); 201 | betaInitPhase1(i, alpha, beta, r_0, r_1, relu_rou); 202 | predicatePhase1(i); 203 | 204 | total_timer.start(); 205 | if (cur.need_phase2) { 206 | r_v[i].resize(cur.max_bl_v); 207 | for (int j = 0; j < cur.max_bl_v; ++j) r_v[i][j].setByCSPRNG(); 208 | 209 | total_timer.stop(); 210 | 
total_slow_timer.stop(); 211 | 212 | p->sumcheckInitPhase2(); 213 | previousRandom = F_ZERO; 214 | for (u32 j = 0; j < cur.max_bl_v; ++j) { 215 | quadratic_poly poly = p->sumcheckUpdate2(previousRandom); 216 | 217 | total_timer.start(); 218 | total_slow_timer.start(); 219 | if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) { 220 | fprintf(stderr, "Verification fail, phase2, circuit level %d, current bit %d, total is %d\n", i, j, 221 | cur.max_bl_v); 222 | return false; 223 | } 224 | 225 | previousRandom = r_v[i][j]; 226 | previousSum = poly.eval(previousRandom); 227 | total_timer.stop(); 228 | total_slow_timer.stop(); 229 | } 230 | p->sumcheckFinalize2(previousRandom, final_claim_v0[i], final_claim_v1); 231 | 232 | total_slow_timer.start(); 233 | betaInitPhase2(i); 234 | predicatePhase2(i); 235 | total_timer.start(); 236 | } 237 | F test_value = getFinalValue(final_claim_u0[i], final_claim_u1, final_claim_v0[i], final_claim_v1); 238 | 239 | if (previousSum != test_value) { 240 | std::cerr << test_value << ' ' << previousSum << std::endl; 241 | fprintf(stderr, "Verification fail, semi final, circuit level %d\n", i); 242 | return false; 243 | } else fprintf(stderr, "Verification Pass, semi final, circuit level %d\n", i); 244 | 245 | if (cur.ty == layerType::FFT || cur.ty == layerType::IFFT) 246 | previousSum = final_claim_u1; 247 | else { 248 | if (~cur.bit_length_u[1]) 249 | alpha.setByCSPRNG(); 250 | else alpha.clear(); 251 | if ((~cur.bit_length_v[1]) || cur.ty == layerType::FFT) 252 | beta.setByCSPRNG(); 253 | else beta.clear(); 254 | previousSum = alpha * final_claim_u1 + beta * final_claim_v1; 255 | } 256 | 257 | r_0 = r_u[i].begin(); 258 | r_1 = r_v[i].begin(); 259 | 260 | total_timer.stop(); 261 | total_slow_timer.stop(); 262 | beta_u.clear(); 263 | beta_v.clear(); 264 | } 265 | return true; 266 | } 267 | 268 | bool verifier::verifyFirstLayer() { 269 | total_slow_timer.start(); 270 | total_timer.start(); 271 | 272 | auto &cur = C.circuit[0]; 273 | 274 
/* continuation of verifier::verifyFirstLayer(); NOTE(review): template arguments were stripped by the dump's angle-bracket mangling -- sig_u/sig_v are presumably vector<F>, restore from the repository source. */ | vector sig_u(C.size - 1); 275 | for (int i = 0; i < C.size - 1; ++i) sig_u[i].setByCSPRNG(); 276 | vector sig_v(C.size - 1); 277 | for (int i = 0; i < C.size - 1; ++i) sig_v[i].setByCSPRNG(); 278 | r_u[0].resize(cur.bit_length); 279 | for (int i = 0; i < cur.bit_length; ++i) r_u[0][i].setByCSPRNG(); 280 | auto r_0 = r_u[0].begin(); 281 | 282 | /* random linear combination of every layer's claims about layer 0; ~x is nonzero iff x != -1, so layers with bit length -1 are skipped as absent */ F previousSum = F_ZERO; 283 | for (int i = 1; i < C.size; ++i) { 284 | if (~C.circuit[i].bit_length_u[0]) 285 | previousSum = previousSum + sig_u[i - 1] * final_claim_u0[i]; 286 | if (~C.circuit[i].bit_length_v[0]) 287 | previousSum = previousSum + sig_v[i - 1] * final_claim_v0[i]; 288 | } 289 | total_timer.stop(); 290 | total_slow_timer.stop(); 291 | 292 | /* combined (Liu-style) sumcheck over the input layer: each round checks poly(0) + poly(1) == previous claim, then fixes the round variable to the pre-drawn randomness r_0[j] */ p->sumcheckLiuInit(sig_u, sig_v); 293 | F previousRandom = F_ZERO; 294 | for (int j = 0; j < cur.bit_length; ++j) { 295 | auto poly = p -> sumcheckLiuUpdate(previousRandom); 296 | if (poly.eval(F_ZERO) + poly.eval(F_ONE) != previousSum) { 297 | fprintf(stderr, "Liu fail, circuit 0, current bit %d\n", j); 298 | return false; 299 | } 300 | previousRandom = r_0[j]; 301 | previousSum = poly.eval(previousRandom); 302 | } 303 | 304 | F gr = F_ZERO; 305 | p->sumcheckLiuFinalize(previousRandom, eval_in); 306 | 307 | beta_g.resize(1ULL << cur.bit_length); 308 | 309 | /* verifier locally recomputes gr as a sum of beta_g[ori_id] * beta_u/beta_v products over all layers; only total_slow_timer runs across this part (total_timer restarts later), so it is reported as the "(slow)" component */ total_slow_timer.start(); 310 | initBetaTable(beta_g, cur.bit_length, r_0, F_ONE); 311 | for (int i = 1; i < C.size; ++i) { 312 | if (~C.circuit[i].bit_length_u[0]) { 313 | beta_u.resize(1ULL << C.circuit[i].bit_length_u[0]); 314 | initBetaTable(beta_u, C.circuit[i].bit_length_u[0], r_u[i].begin(), sig_u[i - 1]); 315 | for (u32 j = 0; j < C.circuit[i].size_u[0]; ++j) 316 | gr = gr + beta_g[C.circuit[i].ori_id_u[j]] * beta_u[j]; 317 | } 318 | 319 | if (~C.circuit[i].bit_length_v[0]) { 320 | beta_v.resize(1ULL << C.circuit[i].bit_length_v[0]); 321 | initBetaTable(beta_v, C.circuit[i].bit_length_v[0], r_v[i].begin(), sig_v[i - 1]); 322 | for (u32 j = 0; j < C.circuit[i].size_v[0]; ++j) 323 | gr = gr + 
beta_g[C.circuit[i].ori_id_v[j]] * beta_v[j]; 324 | } 325 | } 326 | 327 | beta_u.clear(); 328 | beta_v.clear(); 329 | 330 | /* semi-final check for layer 0: the prover's claimed input evaluation (eval_in) times the locally computed beta product must equal the running sumcheck claim */ total_timer.start(); 331 | if (eval_in * gr != previousSum) { 332 | fprintf(stderr, "Liu fail, semi final, circuit 0.\n"); 333 | return false; 334 | } 335 | 336 | total_timer.stop(); 337 | total_slow_timer.stop(); 338 | /* record prover time, verifier time, and proof size into the results table (column ids defined in src/global_var.hpp per the README) */ output_tb[PT_OUT_ID] = to_string_wp(p->proveTime()); 339 | output_tb[VT_OUT_ID] = to_string_wp(verifierTime()); 340 | output_tb[PS_OUT_ID] = to_string_wp(p -> proofSize()); 341 | 342 | fprintf(stderr, "Verification pass\n"); 343 | fprintf(stderr, "Prove Time %lf\n", p->proveTime()); 344 | fprintf(stderr, "verify time %lf = %lf + %lf(slow)\n", verifierSlowTime(), verifierTime(), verifierSlowTime() - verifierTime()); 345 | fprintf(stderr, "proof size = %lf kb\n", p -> proofSize()); 346 | 347 | /* release verifier scratch tables; r_u is resized to 1 so that r_u[0] survives -- verifyInput() still needs it */ beta_g.clear(); 348 | beta_gs.clear(); 349 | beta_u.clear(); 350 | beta_v.clear(); 351 | r_u.resize(1); 352 | r_v.clear(); 353 | 354 | sig_u.clear(); 355 | sig_v.clear(); 356 | return true; 357 | } 358 | 359 | /* verifyInput: hands the final claim eval_in at point r_u[0] to the polynomial-commitment verifier (poly_v) for the committed input, then records commitment-phase and total prover/verifier/proof-size statistics. Returns false if the commitment opening fails. */ bool verifier::verifyInput() { 360 | if (!poly_v -> verify(r_u[0], eval_in)) { 361 | fprintf(stderr, "Verification fail, final input check fail.\n"); 362 | return false; 363 | } 364 | 365 | fprintf(stderr, "poly pt = %.5f, vt = %.5f, ps = %.5f\n", p -> polyProverTime(), poly_v -> getVT(), p -> polyProofSize()); 366 | output_tb[POLY_PT_OUT_ID] = to_string_wp(p -> polyProverTime()); 367 | output_tb[POLY_VT_OUT_ID] = to_string_wp(poly_v -> getVT()); 368 | output_tb[POLY_PS_OUT_ID] = to_string_wp(p -> polyProofSize()); 369 | output_tb[TOT_PT_OUT_ID] = to_string_wp(p -> polyProverTime() + p->proveTime()); 370 | output_tb[TOT_VT_OUT_ID] = to_string_wp(poly_v -> getVT() + verifierTime()); 371 | output_tb[TOT_PS_OUT_ID] = to_string_wp(p -> polyProofSize() + p -> proofSize()); 372 | return true; 373 | } 374 | -------------------------------------------------------------------------------- /src/verifier.hpp: 
-------------------------------------------------------------------------------- 1 | // 2 | // Created by 69029 on 3/9/2021. 3 | // 4 | 5 | #ifndef ZKCNN_CONVVERIFIER_HPP 6 | #define ZKCNN_CONVVERIFIER_HPP 7 | 8 | #include "prover.hpp" 9 | 10 | using std::unique_ptr; 11 | /* GKR-style verifier for the layered circuit C. Drives an interactive sumcheck against the prover p, presumably in the order verifyInnerLayers -> verifyFirstLayer -> verifyInput (verifyFirstLayer keeps r_u[0] alive for verifyInput). NOTE(review): angle-bracket template arguments were stripped by the dump -- e.g. "vector> r_u" was presumably vector<vector<F>> and "unique_ptr poly_v" carried a concrete verifier type; restore from the repository source. */ class verifier { 12 | public: 13 | /* raw (presumably non-owning) pointer to the prover under verification -- TODO(review): confirm ownership */ prover *p; 14 | const layeredCircuit &C; 15 | 16 | verifier(prover *pr, const layeredCircuit &cir); 17 | 18 | bool verify(); 19 | 20 | /* total_timer measures verifier-only checks; total_slow_timer additionally spans the local beta/predicate recomputation, reported as the "(slow)" component in verifier.cpp */ timer total_timer, total_slow_timer; 21 | double verifierTime() const { return total_timer.elapse_sec(); } 22 | double verifierSlowTime() const { return total_slow_timer.elapse_sec(); } 23 | 24 | private: 25 | /* per-layer sumcheck randomness (r_u, r_v) and the claims carried between layers */ vector> r_u, r_v; 26 | vector final_claim_u0, final_claim_v0; 27 | bool verifyInnerLayers(); 28 | bool verifyFirstLayer(); 29 | bool verifyInput(); 30 | 31 | /* scratch beta (equality-polynomial) table over the layer-0 point */ vector beta_g; 32 | void betaInitPhase1(u8 depth, const F &alpha, const F &beta, const vector::const_iterator &r_0, const vector::const_iterator &r_1, const F &relu_rou); 33 | void betaInitPhase2(u8 depth); 34 | 35 | F uni_value[2]; 36 | F bin_value[3]; 37 | void predicatePhase1(u8 layer_id); 38 | void predicatePhase2(u8 layer_id); 39 | 40 | F getFinalValue(const F &claim_u0, const F &claim_u1, const F &claim_v0, const F &claim_v1); 41 | 42 | /* claimed evaluation of the committed input polynomial; checked by poly_v in verifyInput() */ F eval_in; 43 | unique_ptr poly_v; 44 | }; 45 | 46 | 47 | #endif //ZKCNN_CONVVERIFIER_HPP 48 | --------------------------------------------------------------------------------