├── requirements.txt ├── src ├── mlp.h ├── mlp.c ├── convnet.h ├── mlp_params.h ├── convnet_params.h ├── convnet.c ├── nn.c ├── run_nn.py ├── nn.h ├── nn_math.h ├── nn_math.c └── convnet_params.c ├── scripts ├── test_convnet_c.py ├── test_mlp_c.py ├── create_convnet_c_params.py ├── train_mlp.py ├── train_convnet.py ├── create_mlp_c_params.py └── quantize_with_package.py ├── neural_nets.py ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.1 2 | torch==1.9.0 3 | -------------------------------------------------------------------------------- /src/mlp.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file mlp.h 3 | * @brief Function prototypes to create and run an MLP for inference 4 | * with only integers (8-bit integers and 32-bit integers 5 | * in fixed-point) 6 | * 7 | * @author Benjamin Fuhrer 8 | * 9 | *******************************************************************/ 10 | #ifndef MLP_H 11 | #define MLP_H 12 | 13 | void run_mlp(const int *x, const unsigned int N, unsigned int *class_indices); 14 | /** 15 | * @brief Function to run an mlp for classification 16 | * 17 | * @param x - NxK input matrix 18 | * @param N 19 | * @param class_indices - Nx1 vector for storing class index prediction 20 | */ 21 | 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /src/mlp.c: -------------------------------------------------------------------------------- 1 | 2 | #include "mlp_params.h" 3 | #include "nn.h" 4 | #include "nn_math.h" 5 | 6 | void run_mlp(const int *x, const unsigned int N, unsigned int *class_indices) 7 | { 8 | int out_input[N*H1]; 9 | linear_layer(x, layer_1_weight, out_input, layer_1_s_x, 10 | layer_1_s_w_inv, layer_1_s_x_inv, 11 | N, INPUT_DIM, H1, 1); 12 | int out_h1[N*H2]; 13 | linear_layer(out_input, 
layer_2_weight, out_h1, layer_2_s_x, 14 | layer_2_s_w_inv, layer_2_s_x_inv, 15 | N, H1, H2, 1); 16 | int output[N*OUTPUT_DIM]; 17 | linear_layer(out_h1, layer_3_weight, output, layer_3_s_x, 18 | layer_3_s_w_inv, layer_3_s_x_inv, 19 | N, H2, OUTPUT_DIM, 0); 20 | // get argmax 21 | argmax_over_cols(output, class_indices, N, OUTPUT_DIM); 22 | } 23 | 24 | -------------------------------------------------------------------------------- /src/convnet.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file convnet.h 3 | * @brief Function prototypes to create and run a convolutional neural network for inference 4 | * with only integers (8-bit integers and 32-bit integers 5 | * in fixed-point) 6 | * 7 | * @author Benjamin Fuhrer 8 | * 9 | *******************************************************************/ 10 | #ifndef CONVNET_H 11 | #define CONVNET_H 12 | 13 | #define BATCH_SIZE 1 // don't use larger batches to avoid stack overflow 14 | 15 | #include "convnet_params.h" 16 | 17 | void run_convnet(const int *x, unsigned int *class_indices); 18 | /** 19 | * @brief A function to run a pre-specified convolutional neural network with relu activation function and max pooling 20 | * 21 | * @param x - input tensor 22 | * @param class_indices - Nx1 vector for storing class index prediction, where N = batch size 23 | */ 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /src/mlp_params.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file mlp_params.h 3 | * @brief variable prototypes for model parameters and amax values 4 | * 5 | * 6 | * @author Benjamin Fuhrer 7 | * 8 | *******************************************************************/ 9 | #ifndef MLP_PARAMS 10 | #define MLP_PARAMS 11 | 12 | #define INPUT_DIM 
784 13 | #define H1 128 14 | #define H2 64 15 | #define OUTPUT_DIM 10 16 | #define FXP_VALUE 16 17 | #define BATCH_SIZE 10 18 | 19 | #include <stdint.h> // int8_t 20 | 21 | 22 | // quantization/dequantization constants 23 | extern const int layer_1_s_x; 24 | extern const int layer_1_s_x_inv; 25 | extern const int layer_1_s_w_inv[128]; 26 | extern const int layer_2_s_x; 27 | extern const int layer_2_s_x_inv; 28 | extern const int layer_2_s_w_inv[64]; 29 | extern const int layer_3_s_x; 30 | extern const int layer_3_s_x_inv; 31 | extern const int layer_3_s_w_inv[10]; 32 | // Layer quantized parameters 33 | extern const int8_t layer_1_weight[100352]; 34 | extern const int8_t layer_2_weight[8192]; 35 | extern const int8_t layer_3_weight[640]; 36 | 37 | #endif // end of MLP_PARAMS 38 | -------------------------------------------------------------------------------- /src/convnet_params.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file convnet_params.h 3 | * @brief variable prototypes for model parameters and amax values 4 | * 5 | * 6 | * @author Benjamin Fuhrer 7 | * 8 | *******************************************************************/ 9 | #ifndef CONVNET_PARAMS 10 | #define CONVNET_PARAMS 11 | 12 | #define INPUT_DIM 784 13 | #define H1 28 14 | #define W1 28 15 | #define H1_conv 26 16 | #define W1_conv 26 17 | #define H1_pool 13 18 | #define W1_pool 13 19 | #define H2_conv 11 20 | #define W2_conv 11 21 | #define H2_pool 5 22 | #define W2_pool 5 23 | #define C0 1 24 | #define C1 16 25 | #define C2 16 26 | #define OUTPUT_DIM 10 27 | 28 | #include <stdint.h> // int8_t 29 | 30 | 31 | // quantization/dequantization constants 32 | extern const int layer_1_s_x; 33 | extern const int layer_1_s_x_inv; 34 | extern const int layer_1_s_w_inv[16]; 35 | extern const int layer_2_s_x; 36 | extern const int layer_2_s_x_inv; 37 | extern const int layer_2_s_w_inv[16]; 38 | extern const int layer_3_s_x; 39 | extern
const int layer_3_s_x_inv; 40 | extern const int layer_3_s_w_inv[10]; 41 | // Layer quantized parameters 42 | extern const int8_t layer_1_weight[144]; 43 | extern const int8_t layer_2_weight[2304]; 44 | extern const int8_t layer_3_weight[4000]; 45 | 46 | #endif // end of CONVNET_PARAMS 47 | -------------------------------------------------------------------------------- /src/convnet.c: -------------------------------------------------------------------------------- 1 | 2 | #include "convnet.h" 3 | #include "nn.h" 4 | #include "nn_math.h" 5 | 6 | void run_convnet(const int *x, unsigned int *class_indices) 7 | { 8 | 9 | int out_conv1[BATCH_SIZE*C1*H1_conv*W1_conv]; 10 | 11 | conv2d_layer(x, layer_1_weight, out_conv1, layer_1_s_x, layer_1_s_w_inv, layer_1_s_x_inv, 12 | BATCH_SIZE, C0, C1, H1, W1, H1_conv, W1_conv, 13 | 3, 3, 1, 1); 14 | 15 | int out_pool1[BATCH_SIZE*C1*H1_pool*W1_pool]; 16 | pooling2d(out_conv1, out_pool1, BATCH_SIZE, C1, H1_conv, W1_conv, H1_pool, W1_pool, 2, 2, 2, 2); 17 | 18 | int out_conv2[BATCH_SIZE*C2*H2_conv*W2_conv]; 19 | conv2d_layer(out_pool1, layer_2_weight, out_conv2, layer_2_s_x, layer_2_s_w_inv, layer_2_s_x_inv, 20 | BATCH_SIZE, C1, C2, H1_pool, W1_pool, H2_conv, W2_conv, 21 | 3, 3, 1, 1); 22 | 23 | int out_pool2[BATCH_SIZE*C2*H2_pool*W2_pool]; // second pooling output: C2 x H2_pool x W2_pool per sample 24 | pooling2d(out_conv2, out_pool2, BATCH_SIZE, C2, H2_conv, W2_conv, H2_pool, W2_pool, 2, 2, 2, 2); 25 | 26 | int output[BATCH_SIZE*OUTPUT_DIM]; 27 | linear_layer(out_pool2, layer_3_weight, output, layer_3_s_x, 28 | layer_3_s_w_inv, layer_3_s_x_inv, 29 | BATCH_SIZE, C2*H2_pool*W2_pool, OUTPUT_DIM, 0); 30 | 31 | argmax_over_cols(output, class_indices, BATCH_SIZE, OUTPUT_DIM); 32 | } 33 | -------------------------------------------------------------------------------- /src/nn.c: -------------------------------------------------------------------------------- 1 | #include "nn.h" 2 | #include "nn_math.h" 3 | 4 | void linear_layer(const int *x, const int8_t *w, int *output, const int x_scale_factor,
5 | const int *w_scale_factor_inv, const int x_scale_factor_inv, 6 | const unsigned int N, const unsigned int K, const unsigned int M, 7 | const unsigned int hidden_layer) 8 | { 9 | int8_t x_q[N * K]; 10 | quantize(x, x_q, x_scale_factor, x_scale_factor_inv, N*K); 11 | 12 | mat_mult(x_q, w, output, N, K, M); 13 | 14 | dequantize_per_row(output, w_scale_factor_inv, x_scale_factor_inv, N, M); 15 | 16 | if (hidden_layer) 17 | relu(output, N*M); 18 | 19 | } 20 | 21 | void conv2d_layer(const int *x, const int8_t *w,int *output, const int x_scale_factor, const int *w_scale_factor_inv, const int x_scale_factor_inv, 22 | const unsigned int N, const unsigned int C_in, const unsigned int C_out, const int H, const int W, 23 | const int H_conv, const int W_conv, const int k_size_h, const int k_size_w, const int stride_h, const int stride_w) 24 | { 25 | int8_t x_q[N*C_in*H*W]; 26 | 27 | quantize(x, x_q, x_scale_factor, x_scale_factor_inv, N*C_in*H*W); 28 | 29 | conv2d(x_q, w, output, N, C_in, C_out, H, W, H_conv, W_conv, 30 | k_size_h, k_size_w, stride_h, stride_w); 31 | 32 | dequantize_per_channel(output, w_scale_factor_inv, x_scale_factor_inv, N, C_out, H_conv*W_conv); 33 | 34 | relu(output, N*C_out*H_conv*W_conv); 35 | 36 | 37 | } 38 | 39 | -------------------------------------------------------------------------------- /scripts/test_convnet_c.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for running inference of model in C using ctypes 3 | """ 4 | import argparse 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | from src.run_nn import load_c_lib, run_convnet 11 | from torch.utils.data import DataLoader 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description="Script for testing post-training quantization of a pre-trained model in C", 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | 
parser.add_argument('--batch_size', help='batch size', type=int, default=1) 17 | 18 | args = parser.parse_args() 19 | 20 | mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.1307,), (0.3081,)) 23 | ])) 24 | 25 | print(f'Evaluate model on test data') 26 | 27 | test_loader = DataLoader(mnist_testset, batch_size=args.batch_size, num_workers=1, shuffle=False) 28 | 29 | # load c library 30 | c_lib = load_c_lib(library='convnet.so') 31 | 32 | acc = 0 33 | for samples, labels in test_loader: 34 | samples = (samples * (2 ** 16)).round() # convert to fixed-point 16 35 | preds = run_convnet(samples, c_lib).astype(int) 36 | acc += (torch.from_numpy(preds) == labels).sum() 37 | 38 | print(f"Accuracy: {(acc / len(mnist_testset.data)) * 100.0:.2f}%") 39 | -------------------------------------------------------------------------------- /scripts/test_mlp_c.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for running inference of model in C using ctypes 3 | """ 4 | import argparse 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | from src.run_nn import load_c_lib, run_mlp 11 | from torch.utils.data import DataLoader 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description="Script for testing post-training quantization of a pre-trained model in C", 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('--batch_size', help='batch size', type=int, default=1) 17 | parser.add_argument('--data_dir', help='directory of folder containing the MNIST dataset', default='../data') 18 | 19 | args = parser.parse_args() 20 | 21 | mnist_testset = datasets.MNIST(root=args.data_dir, train=False, download=True, transform=transforms.Compose([ 22 | transforms.ToTensor(), 23 | 
transforms.Normalize((0.1307,), (0.3081,)) 24 | ])) 25 | 26 | print(f'Evaluating integer-only C model on test data') 27 | 28 | test_loader = DataLoader(mnist_testset, batch_size=args.batch_size, num_workers=1, shuffle=False) 29 | # load c library 30 | c_lib = load_c_lib(library='mlp.so') 31 | 32 | acc = 0 33 | for samples, labels in test_loader: 34 | samples = (samples * (2 ** 16)).round() # convert to fixed-point 16 35 | 36 | preds = run_mlp(samples, c_lib).astype(int) 37 | 38 | acc += (torch.from_numpy(preds) == labels).sum() 39 | 40 | print(f"Accuracy: {(acc / len(mnist_testset.data)) * 100.0:.2f}%") 41 | -------------------------------------------------------------------------------- /neural_nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module containing neural network architectures (MLP and ConvNet) 3 | """ 4 | import torch.nn as nn 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, in_dim, out_dim, hidden_sizes, activation=nn.ReLU): 8 | super(MLP, self).__init__() 9 | assert isinstance(hidden_sizes, list) and len(hidden_sizes) > 0 10 | layer_list = [nn.Linear(in_dim, hidden_sizes[0], bias=False)] 11 | for i in range(1, len(hidden_sizes)): 12 | layer_list.extend([activation(), 13 | nn.Linear(hidden_sizes[i-1], hidden_sizes[i], bias=False)] 14 | ) 15 | layer_list.extend([activation(), nn.Linear(hidden_sizes[-1], out_dim, bias=False)]) 16 | self.net = nn.Sequential(*layer_list) 17 | 18 | def forward(self, x): 19 | return self.net(x.flatten(start_dim=1)) 20 | 21 | class ConvNet(nn.Module): 22 | def __init__(self, out_dim, channel_sizes, activation=nn.ReLU): 23 | super(ConvNet, self).__init__() 24 | 25 | def get_output_dim(input_dim, kernel_size, stride): 26 | output_dim = (input_dim -(kernel_size-1) - 1) / stride 27 | return int(output_dim + 1) 28 | 29 | output_dim = get_output_dim(get_output_dim(28, kernel_size=3, stride=1), kernel_size=2, stride=2) 30 | output_dim = 
get_output_dim(get_output_dim(output_dim, kernel_size=3, stride=1), kernel_size=2, stride=2) 31 | 32 | layer_list = [nn.Conv2d(in_channels=1, out_channels=channel_sizes[0], kernel_size=(3, 3), bias=False), 33 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2,2)), 34 | activation(), 35 | nn.Conv2d(in_channels=channel_sizes[0], out_channels=channel_sizes[1], kernel_size=(3, 3), bias=False), 36 | nn.MaxPool2d(kernel_size=(2, 2), stride=(2,2)), 37 | activation(), 38 | nn.Flatten(), 39 | nn.Linear(output_dim*output_dim*channel_sizes[-1], out_dim, bias=False) 40 | ] 41 | 42 | self.net = nn.Sequential(*layer_list) 43 | 44 | 45 | def forward(self, x): 46 | return self.net(x) -------------------------------------------------------------------------------- /src/run_nn.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Module interfacing between the C files and python 4 | """ 5 | import ctypes 6 | import ctypes.util 7 | import os 8 | import sys 9 | 10 | import numpy as np 11 | 12 | 13 | def load_c_lib(library): 14 | """ 15 | Load C shared library 16 | :param library: 17 | :return: 18 | """ 19 | try: 20 | c_lib = ctypes.CDLL(f"{os.path.dirname(os.path.abspath(__file__))}/{library}") 21 | except OSError: 22 | print("Unable to load the requested C library") 23 | sys.exit() 24 | return c_lib 25 | 26 | def ensure_contiguous(array): 27 | """ 28 | Ensure that array is contiguous 29 | :param array: 30 | :return: 31 | """ 32 | return np.ascontiguousarray(array) if not array.flags['C_CONTIGUOUS'] else array 33 | 34 | 35 | def run_mlp(x, c_lib): 36 | """ 37 | Call 'run_mlp' function from C in Python 38 | :param x: 39 | :param c_lib: 40 | :return: 41 | """ 42 | N = len(x) 43 | x = x.flatten() 44 | x = ensure_contiguous(x.numpy()) 45 | x = x.astype(np.intc) 46 | # print(x) 47 | class_indices = ensure_contiguous(np.zeros(N, dtype=np.uintc)) 48 | 49 | c_int_p = ctypes.POINTER(ctypes.c_int) 50 | c_uint_p = ctypes.POINTER(ctypes.c_uint) 51 | 52 | 
c_run_mlp = c_lib.run_mlp 53 | c_run_mlp.argtypes = (c_int_p, ctypes.c_uint, c_uint_p) 54 | c_run_mlp.restype = None 55 | c_run_mlp(x.ctypes.data_as(c_int_p), ctypes.c_uint(N), 56 | class_indices.ctypes.data_as(c_uint_p) 57 | ) 58 | 59 | return np.ctypeslib.as_array(class_indices, N) 60 | 61 | 62 | def run_convnet(x, c_lib): 63 | """ 64 | Call 'run_mlp' function from C in Python 65 | :param x: 66 | :param c_lib: 67 | :return: 68 | """ 69 | N = len(x) 70 | x = x.flatten() 71 | x = ensure_contiguous(x.numpy()) 72 | x = x.astype(np.intc) 73 | # print(x) 74 | class_indices = ensure_contiguous(np.zeros(N, dtype=np.uintc)) 75 | 76 | c_int_p = ctypes.POINTER(ctypes.c_int) 77 | c_uint_p = ctypes.POINTER(ctypes.c_uint) 78 | 79 | c_run_convnet = c_lib.run_convnet 80 | c_run_convnet.argtypes = (c_int_p, c_uint_p) 81 | c_run_convnet.restype = None 82 | c_run_convnet(x.ctypes.data_as(c_int_p), class_indices.ctypes.data_as(c_uint_p)) 83 | 84 | return np.ctypeslib.as_array(class_indices, N) 85 | 86 | -------------------------------------------------------------------------------- /src/nn.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file nn.h 3 | * @brief Function prototypes for neural network layers 4 | * 5 | * 6 | * @author Benjamin Fuhrer 7 | * 8 | *******************************************************************/ 9 | #ifndef NN_H 10 | #define NN_H 11 | 12 | #include 13 | 14 | void linear_layer(const int *x, const int8_t *w, int *output, const int x_scale_factor, 15 | const int *w_scale_factor_inv, const int x_scale_factor_inv, 16 | const unsigned int N, const unsigned int K, 17 | const unsigned int M, const unsigned int not_output_layer); 18 | /** 19 | * @brief A neural network linear layer withthout bias Y = ReLU(XW) 20 | * x is quantized before multiplication with w and then dequantized per-row granulity prior to the activation function 21 | * 22 | * @param 
x - NxK input matrix 23 | * @param w - KxM layer weight matrix 24 | * @param output - NxM output matrix 25 | * @param x_scale_factor - fixed-point scale factor used to quantize x 26 | * @param w_scale_factor_inv - 1xM fixed-point inverse weight scale factors (one per output column) 27 | * @param x_scale_factor_inv - fixed-point inverse of x_scale_factor, used for dequantization 28 | * @param N,K,M - dimensions: x is NxK, w is KxM, output is NxM 29 | * 30 | * @param not_output_layer - non-zero for hidden layers (ReLU applied), zero for the output layer (no activation) 31 | * 32 | * @return Void 33 | */ 34 | 35 | void conv2d_layer(const int *x, const int8_t *w,int *output, const int x_scale_factor, 36 | const int *w_scale_factor_inv, const int x_scale_factor_inv, 37 | const unsigned int N, const unsigned int C_in, const unsigned int C_out, 38 | const int H, const int W, 39 | const int H_conv, const int W_conv, const int k_size_h, const int k_size_w, 40 | const int stride_h, const int stride_w); 41 | /** 42 | * @brief A neural network 2D convolutional layer with ReLU activation function 43 | * x is quantized before the convolution operation and then dequantized with per-channel granularity prior to the activation function 44 | * 45 | * @param x - (N, C_in, H, W) input tensor 46 | * @param w - (C_out, C_in, k_size_h, k_size_w) weight tensor 47 | * @param output - (N, C_out, H_conv, W_conv) output tensor 48 | * @param x_scale_factor - fixed-point scale factor used to quantize x 49 | * @param w_scale_factor_inv - C_out fixed-point inverse weight scale factors (one per output channel) 50 | * @param x_scale_factor_inv - fixed-point inverse of x_scale_factor, used for dequantization 51 | * @param N,C_in,C_out,H,W,H_conv,W_conv,k_size_h,k_size_w,stride_h,stride_w - batch, channel, spatial, kernel and stride sizes 52 | * @return Void 53 | */ 54 | 55 | 56 | #endif 57 | 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/
20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | *.DS_Store 131 | data/ 132 | *.iml 133 | *code-workspace 134 | *.ipynb 135 | *.json 136 | *.o 137 | *.dylib 138 | *.th 139 | -------------------------------------------------------------------------------- /src/nn_math.h: -------------------------------------------------------------------------------- 1 | /******************************************************************* 2 | @file nn_math.h 3 | * @brief Function prototypes for mathematical functions 4 | * 5 | * 6 | * @author Benjamin Fuhrer 7 | * 8 | *******************************************************************/ 9 | #ifndef NN_MATH_H 10 | #define NN_MATH_H 11 | 12 | #define MAX(a,b) (((a)>(b))?(a):(b)) 13 | #define MIN(a,b) (((a)<(b))?(a):(b)) 14 | 15 | #define NUM_BITS 8 16 | #define INT8_MAX_VALUE 127 17 | #define FXP_VALUE 16 18 | #define ROUND_CONST (1 << (FXP_VALUE - 1)) // = 0.5 in FXP_VALUE fixed point; added before right shifting so the result is rounded instead of truncated 19 | 20 | #include <stdint.h> 21 | 22 | void mat_mult(const int8_t *mat_l, const int8_t *mat_r, int *result, const unsigned int N, const unsigned int K, const unsigned int M); 23 | /** 24 | * @brief Calculates matrix multiplication as: Y = XW 25 | * 26 | * 27 | * @param mat_l - left matrix (X), size NxK 28 | * @param mat_r - right matrix (W), size KxM (no bias row; the layers are bias-free) 29 | * @param result - output matrix (Y), size NxM 30 | * @param N - number of rows in X 31 | * @param K - number
of columns/rows in X/W 32 | * @param M - number of columns in W 33 | * @return Void 34 | */ 35 | 36 | int get_output_dim(int input_dim, int kernel_size, int stride); 37 | 38 | void conv2d(const int8_t *x, const int8_t *w, int *y, int N, int C_in, int C_out, int H, int W, int H_new, int W_new, 39 | int k_size_h, int k_size_w, int stride_h, int stride_w); 40 | 41 | void pooling2d(int *x, int *y, int N, int C_out, int H, int W, int H_new, int W_new, 42 | int k_size_h, int k_size_w, int stride_h, int stride_w); 43 | 44 | void relu(int *tensor_in, const unsigned int size); 45 | /** 46 | * @brief ReLU activation function 47 | * 48 | * @param tensor_in - input tensor 49 | * @param size - size of flattened tensor 50 | * @return Void 51 | */ 52 | 53 | 54 | void quantize(const int *tensor_in, int8_t *tensor_q, const int scale_factor, 55 | const int scale_factor_inv, const unsigned int size); 56 | /** 57 | * @brief Scale quantization of a tensor by a single amax value 58 | * 59 | * @param tensor_in - input tensor 60 | * @param tensor_q - output quantized tensor 61 | * @param scale_factor - 127 / amax 62 | * @param scale_factor_inv - 1 / scale_factor 63 | * @param size - size of flattened tensor 64 | * @return Void 65 | */ 66 | 67 | void dequantize_per_row(int *mat_in, const int *scale_factor_w_inv, const int scale_factor_x_inv, const unsigned int N, const unsigned int M); 68 | /** 69 | * @brief Scale dequantization with per-row granulity 70 | * Each row is multiplied by the corresponding column amax value 71 | * offline calculate reciprocal(amax) so we can replace division by multiplication 72 | * 73 | * @param mat_in - NxM input matrix to dequantize 74 | * @param scale_factor_w_inv -1XM row vector of layer's weight matrix scale factor values 75 | * @param scale_factor_x_inv - input inverse scale factor 76 | * @param N 77 | * @param M 78 | * @return Void 79 | */ 80 | 81 | void dequantize_per_channel(int *tensor_in, const int *amax_w, const int amax_x, const unsigned int N, 
const unsigned int C, const unsigned int K); 82 | /** 83 | * @brief Scale dequantization with per-channel granularity 84 | * Each channel is scaled by its inverse weight scale factor (amax_w) together with the inverse input scale factor (amax_x) 85 | * the reciprocals are precomputed offline so division is replaced by multiplication 86 | * @param tensor_in - input tensor to dequantize (N, C, ...), updated in place 87 | * @param amax_w - 1xC vector of per-channel inverse weight scale factors 88 | * @param amax_x - inverse input scale factor 89 | * @param N - number of samples 90 | * @param C - number of channels 91 | * @param K - number of remaining flattened dimensions 92 | * @return Void 93 | */ 94 | 95 | void argmax_over_cols(const int *mat_in, unsigned int *indices, const unsigned int N, const unsigned int M); 96 | /** 97 | * @brief Calculate the argmax of each row of an NxM matrix (i.e. over its M columns) 98 | * 99 | * @param mat_in - NxM input matrix 100 | * @param indices - Nx1 output vector; indices[i] is the column index of the maximum of row i 101 | * @param N - number of rows 102 | * @param M - number of columns 103 | * @return Void 104 | */ 105 | 106 | 107 | #endif // 108 | 109 | -------------------------------------------------------------------------------- /scripts/create_convnet_c_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for writing param header and source files in C with weights and amax values calculated in Python 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import torch 8 | 9 | 10 | def get_output_dim(input_dim, kernel_size, stride): 11 | output_dim = (input_dim -(kernel_size-1) - 1) / stride 12 | return int(output_dim + 1) 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description="Script for post-training quantization of a pre-trained model", 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--filename', help='filename of quantized model', type=str, default='convnet_mnist_quant.th') 18 | parser.add_argument('--save_dir', help='save directory', default='../saved_models', type=Path) 19 | 20 | args = parser.parse_args() 21 | 22 |
saved_stats = torch.load(args.save_dir / + args.filename) 23 | state_dict = saved_stats['state_dict'] 24 | channel_sizes = saved_stats['channel_sizes'] 25 | 26 | H1_conv = get_output_dim(28, kernel_size=3, stride=1) 27 | W1_conv = H1_conv 28 | H1_pool = get_output_dim(H1_conv, kernel_size=2, stride=2) 29 | W1_pool = H1_pool 30 | 31 | H2_conv = get_output_dim(H1_pool, kernel_size=3, stride=1) 32 | W2_conv = H2_conv 33 | H2_pool = get_output_dim(H2_conv, kernel_size=2, stride=2) 34 | W2_pool = H2_pool 35 | 36 | 37 | 38 | # create header file 39 | with open('../src/convnet_params.h', 'w') as f: 40 | f.write('/*******************************************************************\n') 41 | f.write('@file convnet_params.h\n* @brief variable prototypes for model parameters and amax values\n*\n*\n') 42 | f.write('* @author Benjamin Fuhrer\n*\n') 43 | f.write('*******************************************************************/\n') 44 | f.write('#ifndef CONVNET_PARAMS\n#define CONVNET_PARAMS\n\n') 45 | 46 | f.write(f'#define INPUT_DIM {28*28}\n') 47 | f.write(f'#define H1 28\n#define W1 28\n#define H1_conv {H1_conv}\n#define W1_conv {W1_conv}\n#define H1_pool {H1_pool}\n#define W1_pool {W1_pool}\n') 48 | f.write(f'#define H2_conv {H2_conv}\n#define W2_conv {W2_conv}\n#define H2_pool {H2_pool}\n#define W2_pool {W2_pool}\n') 49 | f.write(f'#define C0 1\n#define C1 {channel_sizes[0]}\n#define C2 {channel_sizes[1]}\n') 50 | f.write(f'#define OUTPUT_DIM {10}\n\n') 51 | f.write('#include \n\n\n') 52 | 53 | 54 | f.write('// quantization/dequantization constants\n') 55 | 56 | for layer_idx in range(1, 4): 57 | 58 | 59 | name = f'layer_{layer_idx}_s_x' 60 | f.write(f"extern const int {name};\n") 61 | 62 | name = f'layer_{layer_idx}_s_x_inv' 63 | f.write(f"extern const int {name};\n") 64 | 65 | name = f'layer_{layer_idx}_s_w_inv' 66 | value = state_dict[name] 67 | f.write(f"extern const int {name}[{len(value)}];\n") 68 | 69 | f.write('// Layer quantized parameters\n') 70 | for layer_idx 
in range(1, 4): 71 | name = f'layer_{layer_idx}_weight' 72 | param = state_dict[f'layer_{layer_idx}_weight'] 73 | f.write(f"extern const int8_t {name}[{len(param.flatten())}];\n") 74 | 75 | f.write('\n#endif // end of CONVNET_PARAMS\n') 76 | 77 | # create source file 78 | with open('../src/convnet_params.c', 'w') as f: 79 | f.write('#include "convnet_params.h"\n\n\n') 80 | 81 | for layer_idx in range(1, 4): 82 | name = f'layer_{layer_idx}_s_x' 83 | fxp_value = (state_dict[name] * (2**16)).round() 84 | f.write(f"const int {name} = {int(fxp_value)};\n\n") 85 | 86 | name = f'layer_{layer_idx}_s_x_inv' 87 | fxp_value = (state_dict[name] * (2**16)).round() 88 | f.write(f"const int {name} = {int(fxp_value)};\n\n") 89 | 90 | name = f'layer_{layer_idx}_s_w_inv' 91 | fxp_value = (state_dict[name] * (2**16)).round() 92 | f.write(f"const int {name}[{len(fxp_value)}] = {{") 93 | 94 | for idx in range(len(fxp_value)): 95 | f.write(f"{int(fxp_value[idx])}") 96 | if idx < len(fxp_value) - 1: 97 | f.write(", ") 98 | f.write("};\n\n") 99 | 100 | for layer_idx in range(1, 4): 101 | name = f'layer_{layer_idx}_weight' 102 | param = state_dict[f'layer_{layer_idx}_weight'] 103 | param = param.flatten() 104 | f.write(f"const int8_t {name}[{len(param)}] = {{") 105 | for idx in range(len(param)): 106 | f.write(f"{param[idx]}") 107 | if idx < len(param) - 1: 108 | f.write(", ") 109 | f.write("};\n") 110 | -------------------------------------------------------------------------------- /scripts/train_mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for training a simple MLP for classification on the MNIST dataset 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.datasets as datasets 10 | import torchvision.transforms as transforms 11 | from neural_nets import MLP, ConvNet 12 | from torch.optim import Adam 13 | from torch.utils.data import DataLoader, 
random_split 14 | 15 | 16 | def train_epoch(model:nn.Module, data_loader:DataLoader, optimizer:Adam, loss_fn:nn.CrossEntropyLoss): 17 | """ 18 | Train model for 1 epoch and return dictionary with the average training metric values 19 | Args: 20 | model (nn.Module) 21 | data_loader (DataLoader) 22 | optimizer (Adam) 23 | loss_fn (nn.CrossEntropyLoss) 24 | 25 | Returns: 26 | [Float]: average training loss on epoch 27 | """ 28 | model.train(mode=True) 29 | num_batches = len(data_loader) 30 | 31 | loss = 0 32 | for x, y in data_loader: 33 | optimizer.zero_grad() 34 | logits = model(x) 35 | 36 | batch_loss = loss_fn(logits, y) 37 | 38 | batch_loss.backward() 39 | optimizer.step() 40 | 41 | loss += batch_loss.item() 42 | return loss / num_batches 43 | 44 | 45 | def eval_epoch(model: nn.Module, data_loader:DataLoader, loss_fn:nn.CrossEntropyLoss): 46 | """ 47 | Evaluate epoch on validation data 48 | Args: 49 | model (nn.Module) 50 | data_loader (DataLoader) 51 | loss_fn (nn.CrossEntropyLoss) 52 | 53 | Returns: 54 | [Float]: average validation loss 55 | """ 56 | model.eval() 57 | num_batches = len(data_loader) 58 | 59 | loss = 0 60 | with torch.no_grad(): 61 | for x, y in data_loader: 62 | pred_y = model(x) 63 | batch_loss = loss_fn(pred_y, y) 64 | loss += batch_loss.item() 65 | return loss / num_batches 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser(description="Script for training a model", 70 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 71 | parser.add_argument('--hidden_sizes', help='hidden layer dimensions', nargs='+', type=int, default=[128, 64]) 72 | parser.add_argument('--num_epochs', help='number of training epochs', type=int, default=10) 73 | parser.add_argument('--batch_size', help='batch size', type=int, default=128) 74 | parser.add_argument('--train_val_split', help='Train validation split ratio', type=float, default=0.8) 75 | parser.add_argument('--data_dir', help='directory of folder containing the MNIST 
dataset', default='../data') 76 | parser.add_argument('--save_dir', help='save directory', default='../saved_models', type=Path) 77 | 78 | args = parser.parse_args() 79 | 80 | mnist_trainset = datasets.MNIST(root=args.data_dir, train=True, download=True, transform=transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize( 83 | (0.1307,), (0.3081,))])) 84 | 85 | # split training data to train/validation 86 | split_r = args.train_val_split 87 | mnist_trainset, mnist_valset = random_split(mnist_trainset, [round(len(mnist_trainset)*split_r), round(len(mnist_trainset)*(1 - split_r))]) 88 | 89 | mnist_testset = datasets.MNIST(root=args.data_dir, train=False, download=False, transform=transforms.Compose([ 90 | transforms.ToTensor(), 91 | transforms.Normalize( 92 | (0.1307,), (0.3081,))])) 93 | 94 | model = MLP(in_dim=28*28, out_dim=10, hidden_sizes=args.hidden_sizes) 95 | 96 | optimizer = Adam(model.parameters()) 97 | 98 | loss_fnc = nn.CrossEntropyLoss() 99 | 100 | train_loader = DataLoader(mnist_trainset, batch_size=args.batch_size, num_workers=1, shuffle=True) 101 | val_loader = DataLoader(mnist_valset, batch_size=args.batch_size, num_workers=1, shuffle=True) 102 | test_loader = DataLoader(mnist_testset, batch_size=args.batch_size, num_workers=1, shuffle=True) 103 | print('Training') 104 | for epoch in range(args.num_epochs): 105 | train_loss = train_epoch(model, train_loader, optimizer, loss_fnc) 106 | val_loss = eval_epoch(model, val_loader, loss_fnc) 107 | print(f"Epoch: {epoch + 1} - train loss: {train_loss:.5f} validation loss: {val_loss:.5f}") 108 | 109 | print('Evaluate model on test data') 110 | model.eval() 111 | with torch.no_grad(): 112 | acc = 0 113 | for samples, labels in test_loader: 114 | logits = model(samples.float()) 115 | probs = torch.nn.functional.softmax(logits, dim=1) 116 | preds = torch.argmax(probs, dim=1) 117 | acc += (preds == labels).sum() 118 | 119 | print(f"Accuracy: {(acc / len(mnist_testset.data))*100.0:.3f}%") 120 | 
torch.save({'state_dict': model.state_dict(), 121 | 'hidden_sizes': args.hidden_sizes, 122 | 'train_loss': train_loss, 123 | 'val_loss': val_loss, 124 | 'test_acc': acc}, 125 | args.save_dir / 'mlp_mnist.th') 126 | -------------------------------------------------------------------------------- /scripts/train_convnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for training a simple ConvNet for classification on the MNIST dataset 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.datasets as datasets 10 | import torchvision.transforms as transforms 11 | from neural_nets import MLP, ConvNet 12 | from torch.optim import Adam 13 | from torch.utils.data import DataLoader, random_split 14 | 15 | 16 | def train_epoch(model:nn.Module, data_loader:DataLoader, optimizer:Adam, loss_fn:nn.CrossEntropyLoss): 17 | """ 18 | Train model for 1 epoch and return the average training loss over all batches 19 | Args: 20 | model (nn.Module) 21 | data_loader (DataLoader) 22 | optimizer (Adam) 23 | loss_fn (nn.CrossEntropyLoss) 24 | 25 | Returns: 26 | [Float]: average training loss on epoch 27 | """ 28 | model.train(mode=True) 29 | num_batches = len(data_loader) 30 | 31 | loss = 0 32 | for x, y in data_loader: 33 | optimizer.zero_grad() 34 | logits = model(x) 35 | 36 | batch_loss = loss_fn(logits, y) 37 | 38 | batch_loss.backward() 39 | optimizer.step() 40 | 41 | loss += batch_loss.item() 42 | return loss / num_batches 43 | 44 | 45 | def eval_epoch(model: nn.Module, data_loader:DataLoader, loss_fn:nn.CrossEntropyLoss): 46 | """ 47 | Evaluate the model on validation data without updating its weights 48 | Args: 49 | model (nn.Module) 50 | data_loader (DataLoader) 51 | loss_fn (nn.CrossEntropyLoss) 52 | 53 | Returns: 54 | [Float]: average validation loss 55 | """ 56 | model.eval() 57 | num_batches = len(data_loader) 58 | 59 | loss = 0 60 | with torch.no_grad(): 61 | for x, y in
data_loader: 62 | pred_y = model(x) 63 | batch_loss = loss_fn(pred_y, y) 64 | loss += batch_loss.item() 65 | return loss / num_batches 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser(description="Script for training a model", 70 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 71 | parser.add_argument('--channel_sizes', help='channel layer dimensions', nargs='+', type=int, default=[16, 16]) 72 | parser.add_argument('--num_epochs', help='number of training epochs', type=int, default=10) 73 | parser.add_argument('--batch_size', help='batch size', type=int, default=128) 74 | parser.add_argument('--train_val_split', help='Train validation split ratio', type=float, default=0.8) 75 | parser.add_argument('--data_dir', help='directory of folder containing the MNIST dataset', default='../data') 76 | parser.add_argument('--save_dir', help='save directory', default='../saved_models', type=Path) 77 | 78 | 79 | args = parser.parse_args() 80 | 81 | mnist_trainset = datasets.MNIST(root=args.data_dir, train=True, download=False, transform=transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize( 84 | (0.1307,), (0.3081,))])) 85 | 86 | # split training data to train/validation 87 | split_r = args.train_val_split 88 | mnist_trainset, mnist_valset = random_split(mnist_trainset, [round(len(mnist_trainset)*split_r), round(len(mnist_trainset)*(1 - split_r))]) 89 | 90 | mnist_testset = datasets.MNIST(root=args.data_dir, train=False, download=False, transform=transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize( 93 | (0.1307,), (0.3081,))])) 94 | 95 | model = ConvNet(channel_sizes= args.channel_sizes, out_dim=10) 96 | 97 | optimizer = Adam(model.parameters()) 98 | 99 | loss_fnc = nn.CrossEntropyLoss() 100 | 101 | train_loader = DataLoader(mnist_trainset, batch_size=args.batch_size, num_workers=1, shuffle=True) 102 | val_loader = DataLoader(mnist_valset, batch_size=args.batch_size, num_workers=1, shuffle=True) 
103 | test_loader = DataLoader(mnist_testset, batch_size=args.batch_size, num_workers=1, shuffle=True) 104 | print('Training') 105 | for epoch in range(args.num_epochs): 106 | train_loss = train_epoch(model, train_loader, optimizer, loss_fnc) 107 | val_loss = eval_epoch(model, val_loader, loss_fnc) 108 | print(f"Epoch: {epoch + 1} - train loss: {train_loss:.5f} validation loss: {val_loss:.5f}") 109 | 110 | print('Evaluate model on test data') 111 | model.eval() 112 | with torch.no_grad(): 113 | acc = 0 114 | for samples, labels in test_loader: 115 | logits = model(samples.float()) 116 | probs = torch.nn.functional.softmax(logits, dim=1) 117 | preds = torch.argmax(probs, dim=1) 118 | acc += (preds == labels).sum() 119 | 120 | print(f"Accuracy: {(acc / len(mnist_testset.data))*100.0:.3f}%") 121 | torch.save({'state_dict': model.state_dict(), 122 | 'channel_sizes': args.channel_sizes, 123 | 'train_loss': train_loss, 124 | 'val_loss': val_loss, 125 | 'test_acc': acc}, 126 | args.save_dor / 'convnet_mnist.th') 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Integer-Only Inference for Deep Learning in Native C 2 | A repository containing Native C-code implementation of a convolutional neural network and multi-layer perceptron (MLP) models for integer-only inference. Model parameters are quantized to 8-bit integers, and floats are replaced with the fixed-point representation. 3 | The repository contains: 4 | - scripts for training model with PyTorch 5 | - post-training quantization of model parameters to 8-bit integers, 6 | - writing the relevant parameters in C 7 | - interfacing the C code for integer-only inference via C-types. 8 | 9 | The ideas presented in this tutorial were used to quantize and write an inference-only C code to deploy a deep reinforcement learning algorithm on a network interface card (NIC) in Tessler et al. 
2021[1]. 10 | 11 | 12 | # Requirements 13 | Quantization is based on NVIDIA's pytorch-quantization, which is part of TensorRT. 14 | https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization 15 | pytorch-quantization allows for more sophisticated quantization methods than what is presented here. For more details, see Wu et al. 2020[2]. 16 | 17 | **NOTE** pytorch-quantization requires a GPU and will not work without it. 18 | 19 | ## C-code 20 | The C code is structured to have separate files for the MLP and ConvNet models. 21 | The C code is located within the `src` directory, in which: 22 | - `nn_math` - source and header files contain relevant mathematical functions 23 | - `nn` - source and header files contain relevant layers to create the neural network models 24 | - `mlp` - source and header files contain the MLP architecture to run for inference 25 | - `convnet` - source and header files contain the ConvNet architecture to run for inference 26 | - `mlp_params` - source and header files are generated via `scripts/create_mlp_c_params.py` and contain network weights, scale factors, and other relevant constants for the MLP model 27 | - `convnet_params` - source and header files are generated via `scripts/create_convnet_c_params.py` and contain network weights, scale factors, and other relevant constants for the ConvNet model 28 | ### Compilation 29 | The repository was tested using gcc.
30 | To compile and generate a shared library that can be called from Python using ctypes, run the following commands from within the `src` directory: 31 | #### MLP 32 | ``` 33 | gcc -Wall -fPIC -c mlp_params.c mlp.c nn_math.c nn.c 34 | gcc -shared mlp_params.o mlp.o nn_math.o nn.o -o mlp.so 35 | ``` 36 | #### ConvNet 37 | ``` 38 | gcc -Wall -fPIC -c convnet_params.c convnet.c nn_math.c nn.c 39 | gcc -shared convnet_params.o convnet.o nn_math.o nn.o -o convnet.so 40 | ``` 41 | ## Scripts 42 | - `scripts/train_mlp.py` and `scripts/train_convnet.py` are used to train an MLP/ConvNet model using PyTorch 43 | - `scripts/quantize_with_package.py` is used to quantize the models using the pytorch-quantization package 44 | - `scripts/create_mlp_c_params.py` and `scripts/create_convnet_c_params.py` create the header and source C files with relevant constants (network parameters, scale factors, and more) required to run the C code. 45 | - `scripts/test_mlp_c.py` and `scripts/test_convnet_c.py` run inference on the models using ctypes to interface the C code from Python 46 | 47 | 48 | ## Results - on the MNIST dataset 49 | ### MLP 50 | ``` 51 | Training 52 | Epoch: 1 - train loss: 0.35650 validation loss: 0.20097 53 | Epoch: 2 - train loss: 0.14854 validation loss: 0.13693 54 | Epoch: 3 - train loss: 0.10302 validation loss: 0.11963 55 | Epoch: 4 - train loss: 0.07892 validation loss: 0.11841 56 | Epoch: 5 - train loss: 0.06072 validation loss: 0.09850 57 | Epoch: 6 - train loss: 0.04874 validation loss: 0.09466 58 | Epoch: 7 - train loss: 0.04126 validation loss: 0.09458 59 | Epoch: 8 - train loss: 0.03457 validation loss: 0.10938 60 | Epoch: 9 - train loss: 0.02713 validation loss: 0.09077 61 | Epoch: 10 - train loss: 0.02135 validation loss: 0.09448 62 | Evaluating model on test data 63 | Accuracy: 97.450% 64 | ``` 65 | ``` 66 | Evaluating integer-only C model on test data 67 | Accuracy: 97.27% 68 | ``` 69 | ### ConvNet 70 | ``` 71 | Training 72 | Epoch: 1 - train loss: 0.37127 validation loss: 0.12948 73 | Epoch: 2 - train
loss: 0.09653 validation loss: 0.08608 74 | Epoch: 3 - train loss: 0.07089 validation loss: 0.07480 75 | Epoch: 4 - train loss: 0.05846 validation loss: 0.06347 76 | Epoch: 5 - train loss: 0.05044 validation loss: 0.05909 77 | Epoch: 6 - train loss: 0.04567 validation loss: 0.05466 78 | Epoch: 7 - train loss: 0.04071 validation loss: 0.05099 79 | Epoch: 8 - train loss: 0.03668 validation loss: 0.05336 80 | Epoch: 9 - train loss: 0.03543 validation loss: 0.04965 81 | Epoch: 10 - train loss: 0.03164 validation loss: 0.04883 82 | Evaluate model on test data 83 | Accuracy: 98.620% 84 | ``` 85 | ``` 86 | Evaluating integer-only C model on test data 87 | Accuracy: 98.58% 88 | ``` 89 | # References 90 | [1] Tessler, C., Shpigelman, Y., Dalal, G., Mandelbaum, A., Kazakov, D. H., Fuhrer, B., Chechik, G., & Mannor, S. (2021). Reinforcement Learning for Datacenter Congestion Control. http://arxiv.org/abs/2102.09337 91 | [2] Wu, H., Judd, P., Zhang, X., Isaev, M., & Micikevicius, P. (2020). Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation. 
http://arxiv.org/abs/2004.09602 92 | -------------------------------------------------------------------------------- /scripts/create_mlp_c_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for writing param header and source files in C with weights and amax values calculate in python 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import torch 8 | 9 | def get_output_dim(input_dim, kernel_size, stride): 10 | output_dim = (input_dim -(kernel_size-1) - 1) / stride 11 | return int(output_dim + 1) 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description="Script for post-training quantization of a pre-trained model", 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('--filename', help='filename', type=str, default='mlp_mnist_quant.th') 17 | parser.add_argument('--save_dir', help='save directory', default='../saved_models', type=Path) 18 | 19 | args = parser.parse_args() 20 | 21 | saved_stats = torch.load(args.save_dir / + args.filename) 22 | state_dict = saved_stats['state_dict'] 23 | hidden_sizes = saved_stats['hidden_sizes'] 24 | 25 | 26 | 27 | # create header file 28 | with open('../src/mlp_params.h', 'w') as f: 29 | f.write('/*******************************************************************\n') 30 | f.write('@file mlp_params.h\n* @brief variable prototypes for model parameters and amax values\n*\n*\n') 31 | f.write('* @author Benjamin Fuhrer\n*\n') 32 | f.write('*******************************************************************/\n') 33 | f.write('#ifndef MLP_PARAMS\n#define MLP_PARAMS\n\n') 34 | 35 | f.write(f'#define INPUT_DIM {28*28}\n') 36 | for idx, hidden_size in enumerate(hidden_sizes, start=1): 37 | f.write(f'#define H{idx} {hidden_size}\n') 38 | 39 | f.write(f'#define OUTPUT_DIM {10}\n') 40 | f.write('#include \n\n\n') 41 | 42 | 43 | f.write('// quantization/dequantization constants\n') 44 | 45 | for layer_idx in 
range(1, 4): 46 | 47 | 48 | name = f'layer_{layer_idx}_s_x' 49 | f.write(f"extern const int {name};\n") 50 | 51 | name = f'layer_{layer_idx}_s_x_inv' 52 | f.write(f"extern const int {name};\n") 53 | 54 | name = f'layer_{layer_idx}_s_w_inv' 55 | value = state_dict[name] 56 | f.write(f"extern const int {name}[{len(value)}];\n") 57 | 58 | name = f'layer_{layer_idx}_s_x_f' 59 | f.write(f"extern const float {name};\n") 60 | 61 | name = f'layer_{layer_idx}_s_x_inv_f' 62 | f.write(f"extern const float {name};\n") 63 | 64 | name = f'layer_{layer_idx}_s_w_inv_f' 65 | value = state_dict[name.replace('_f', '')] 66 | f.write(f"extern const float {name}[{len(value)}];\n") 67 | 68 | 69 | f.write('// Layer quantized parameters\n') 70 | for layer_idx in range(1, 4): 71 | name = f'layer_{layer_idx}_weight' 72 | param = state_dict[f'layer_{layer_idx}_weight'] 73 | f.write(f"extern const int8_t {name}[{len(param.flatten())}];\n") 74 | 75 | f.write('\n#endif // end of MLP_PARAMS\n') 76 | 77 | # create source file 78 | with open('../src/mlp_params.c', 'w') as f: 79 | f.write('#include "mlp_params.h"\n\n\n') 80 | 81 | for layer_idx in range(1, 4): 82 | name = f'layer_{layer_idx}_s_x' 83 | fxp_value = (state_dict[name] * (2**16)).round() 84 | f.write(f"const int {name} = {int(fxp_value)};\n\n") 85 | 86 | name = f'layer_{layer_idx}_s_x_inv' 87 | fxp_value = (state_dict[name] * (2**16)).round() 88 | f.write(f"const int {name} = {int(fxp_value)};\n\n") 89 | 90 | name = f'layer_{layer_idx}_s_w_inv' 91 | fxp_value = (state_dict[name] * (2**16)).round() 92 | f.write(f"const int {name}[{len(fxp_value)}] = {{") 93 | 94 | for idx in range(len(fxp_value)): 95 | f.write(f"{int(fxp_value[idx])}") 96 | if idx < len(fxp_value) - 1: 97 | f.write(", ") 98 | f.write("};\n\n") 99 | 100 | name = f'layer_{layer_idx}_s_x' 101 | value = state_dict[name] 102 | f.write(f"const float {name}_f = {float(value)};\n\n") 103 | 104 | name = f'layer_{layer_idx}_s_x_inv' 105 | value = state_dict[name] 106 | 
f.write(f"const float {name}_f = {float(value)};\n\n") 107 | 108 | name = f'layer_{layer_idx}_s_w_inv' 109 | value = state_dict[name] 110 | f.write(f"const float {name}_f[{len(value)}] = {{") 111 | 112 | for idx in range(len(value)): 113 | f.write(f"{float(value[idx])}") 114 | if idx < len(value) - 1: 115 | f.write(", ") 116 | f.write("};\n\n") 117 | 118 | 119 | for layer_idx in range(1, 4): 120 | name = f'layer_{layer_idx}_weight' 121 | param = state_dict[f'layer_{layer_idx}_weight'] 122 | param = param.flatten() 123 | f.write(f"const int8_t {name}[{len(param)}] = {{") 124 | for idx in range(len(param)): 125 | f.write(f"{param[idx]}") 126 | if idx < len(param) - 1: 127 | f.write(", ") 128 | f.write("};\n") 129 | -------------------------------------------------------------------------------- /scripts/quantize_with_package.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for PTQ using pytorch-quantization package 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torchvision.datasets as datasets 11 | import torchvision.transforms as transforms 12 | from neural_nets import MLP, ConvNet 13 | from pytorch_quantization import calib 14 | from pytorch_quantization import nn as quant_nn 15 | from pytorch_quantization import quant_modules 16 | from pytorch_quantization.tensor_quant import QuantDescriptor 17 | from torch.utils.data import DataLoader 18 | 19 | 20 | def collect_stats(model, data_loader, num_bins): 21 | """Feed data to the network and collect statistic""" 22 | model.eval() 23 | # Enable calibrators 24 | for name, module in model.named_modules(): 25 | if isinstance(module, quant_nn.TensorQuantizer): 26 | if module._calibrator is not None: 27 | module.disable_quant() 28 | module.enable_calib() 29 | if isinstance(module._calibrator, calib.HistogramCalibrator): 30 | module._calibrator._num_bins = num_bins 31 | else: 32 | 
module.disable() 33 | 34 | for batch, _ in data_loader: 35 | x = batch.float() 36 | model(x) 37 | 38 | # Disable calibrators 39 | for _, module in model.named_modules(): 40 | if isinstance(module, quant_nn.TensorQuantizer): 41 | if module._calibrator is not None: 42 | module.enable_quant() 43 | module.disable_calib() 44 | else: 45 | module.enable() 46 | 47 | def compute_amax(model, **kwargs): 48 | # Load calib result 49 | for name, module in model.named_modules(): 50 | if isinstance(module, quant_nn.TensorQuantizer): 51 | if module._calibrator is not None: 52 | if isinstance(module._calibrator, calib.MaxCalibrator): 53 | module.load_calib_amax() 54 | else: 55 | module.load_calib_amax(**kwargs) 56 | print(F"{name:40}: {module}") 57 | 58 | 59 | def quantize_model_params(model): 60 | """Quantize layer weights using the calibrated amax values 61 | and compute the scale constants needed by the C code 62 | 63 | Args: 64 | model (nn.Module): calibrated MLP or ConvNet whose quantizers hold the computed _amax buffers 65 | Returns: dict with per-layer weights rounded and clamped to [-127, 127] (layer_i_weight, numpy) and scales layer_i_s_x, layer_i_s_x_inv, layer_i_s_w_inv 66 | """ 67 | 68 | is_mlp = isinstance(model, MLP) 69 | 70 | indices = [0, 2, 4] if is_mlp else [0, 3, 7] # positions of the weighted layers inside model.net 71 | scale_factor = 127 # 127 for 8 bits 72 | 73 | 74 | state_dict = dict() 75 | 76 | for layer_idx, idx in enumerate(indices, start=1): 77 | # quantize all parameters 78 | weight = model.state_dict()[f'net.{idx}.weight'] 79 | s_w = model.state_dict()[f'net.{idx}._weight_quantizer._amax'].numpy() 80 | s_x = model.state_dict()[f'net.{idx}._input_quantizer._amax'].numpy() 81 | 82 | scale = weight * (scale_factor / s_w) 83 | state_dict[f'layer_{layer_idx}_weight'] = torch.clamp(scale.round(), min=-127, max=127).to(int) 84 | if is_mlp or layer_idx == 3: # presumably the linear layers: transpose so the C code reads weights as (in, out) - confirm against neural_nets 85 | state_dict[f'layer_{layer_idx}_weight'] = state_dict[f'layer_{layer_idx}_weight'].T 86 | state_dict[f'layer_{layer_idx}_weight'] = state_dict[f'layer_{layer_idx}_weight'].numpy() 87 | 88 | state_dict[f'layer_{layer_idx}_s_x'] = scale_factor / s_x 89 | state_dict[f'layer_{layer_idx}_s_x_inv'] = s_x / scale_factor 90 |
state_dict[f'layer_{layer_idx}_s_w_inv'] = (s_w / scale_factor).squeeze() 91 | 92 | return state_dict 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | parser = argparse.ArgumentParser(description="Script for post-training quantization of a pre-trained model", 98 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 99 | parser.add_argument('--filename', help='filename', type=str, default='mlp_mnist.th') 100 | parser.add_argument('--num_bins', help='number of bins', type=int, default=128) 101 | parser.add_argument('--data_dir', help='directory of folder containing the MNIST dataset', default='../data') 102 | parser.add_argument('--save_dir', help='save directory', default='../saved_models', type=Path) 103 | 104 | args = parser.parse_args() 105 | # load model 106 | saved_stats = torch.load(args.save_dir + args.filename) 107 | 108 | 109 | 110 | state_dict = saved_stats['state_dict'] 111 | 112 | hidden_sizes = None if 'convnet' in args.filename else saved_stats['hidden_sizes'] 113 | channel_sizes = None if 'mlp' in args.filename else saved_stats['channel_sizes'] 114 | 115 | 116 | quant_nn.QuantLinear.set_default_quant_desc_input(QuantDescriptor(calib_method='histogram')) 117 | quant_nn.QuantConv2d.set_default_quant_desc_input(QuantDescriptor(calib_method='histogram')) 118 | quant_modules.initialize() 119 | 120 | 121 | model = MLP(in_dim=28*28, hidden_sizes=hidden_sizes, out_dim=10) if 'mlp' in args.filename else ConvNet(channel_sizes=channel_sizes, out_dim=10) 122 | model.load_state_dict(state_dict) 123 | 124 | mnist_trainset = datasets.MNIST(root=args.data_dir train=True, download=False, transform=transforms.Compose([ 125 | transforms.ToTensor(), 126 | transforms.Normalize( 127 | (0.1307,), (0.3081,))])) 128 | 129 | train_loader = DataLoader(mnist_trainset, batch_size=len(mnist_trainset.data), num_workers=1, shuffle=False) 130 | 131 | # It is a bit slow since we collect histograms on CPU 132 | with torch.no_grad(): 133 | collect_stats(model, train_loader, 
args.num_bins) 134 | compute_amax(model, method="entropy") 135 | 136 | state_dict = quantize_model_params(model) 137 | saved_stats['state_dict'] = state_dict 138 | 139 | name = args.filename.replace('.th', '_quant.th') 140 | 141 | torch.save(saved_stats, 142 | args.save_dir / name) 143 | -------------------------------------------------------------------------------- /src/nn_math.c: -------------------------------------------------------------------------------- 1 | 2 | #include "nn_math.h" 3 | 4 | void mat_mult(const int8_t *mat_l, const int8_t *mat_r, int *result, const unsigned int N, const unsigned int K, const unsigned int M) 5 | { 6 | unsigned int n, k, m; 7 | unsigned int row, col; 8 | int accumulator; 9 | 10 | for (m = 0; m < M; m++) 11 | { 12 | for (n = 0; n < N; n++) 13 | { 14 | row = n*K; 15 | accumulator = 0; 16 | for (k = 0; k < K; k++) 17 | { 18 | col = k*M; 19 | accumulator += mat_l[row + k] * mat_r[col + m]; 20 | } 21 | result[n*M + m] = accumulator; 22 | } 23 | } 24 | } 25 | 26 | 27 | int get_output_dim(int input_dim, int kernel_size, int stride) 28 | { 29 | int output_dim = (input_dim -(kernel_size-1) - 1) / stride; 30 | return output_dim + 1; 31 | } 32 | 33 | 34 | void conv2d(const int8_t *x, const int8_t *w, int *y, int N, int C_in, int C_out, int H, int W, int H_new, int W_new, 35 | int k_size_h, int k_size_w, int stride_h, int stride_w) 36 | { 37 | int n_i, c_out_j, c_in_i; /* sample and channels*/ 38 | int n, m; /* kernel iterations */ 39 | int i, j; /* output image iteration*/ 40 | 41 | for (n_i = 0; n_i < N; n_i++) 42 | { 43 | int N_idx_y = n_i*C_out*H_new*W_new; 44 | int N_idx_x = n_i*C_in*H*W; 45 | 46 | for (c_out_j = 0; c_out_j < C_out; c_out_j++) 47 | { 48 | int C_out_idx_y = c_out_j*H_new*W_new; 49 | int C_out_idx_kernel = c_out_j*C_in*k_size_h*k_size_w; 50 | 51 | for (i = 0; i < H_new; i++) 52 | { 53 | for (j = 0; j < W_new; j++) 54 | { 55 | int output_idx_y = i*W_new + j; 56 | int output_idx_x = i*stride_h*W + j*stride_w; 57 | int 
sum = 0; 58 | for (c_in_i = 0; c_in_i < C_in; c_in_i++) 59 | { 60 | int C_in_idx_x = c_in_i*H*W; 61 | int C_in_idx_kernel = c_in_i*k_size_h*k_size_w; 62 | for (n = 0; n < k_size_h; n++) 63 | { 64 | for (m = 0; m < k_size_w; m++) 65 | { 66 | int kernel_idx = n*k_size_w + m; 67 | int kernel_idx_x = n*W + m; 68 | int x_value = (int)x[N_idx_x + C_in_idx_x + kernel_idx_x + output_idx_x]; 69 | int w_value = (int)w[C_out_idx_kernel + C_in_idx_kernel + kernel_idx]; 70 | sum += x_value*w_value; 71 | } 72 | } 73 | } 74 | y[N_idx_y + C_out_idx_y + output_idx_y] = sum; 75 | } 76 | 77 | } 78 | } 79 | } 80 | } 81 | 82 | 83 | void pooling2d(int *x, int *y, int N, int C_out, int H, int W, int H_new, int W_new, 84 | int k_size_h, int k_size_w, int stride_h, int stride_w) 85 | { 86 | int n_i, c_out_j; /* sample and channels*/ 87 | int n, m; /* kernel iterations */ 88 | int i, j; /* output image iteration*/ 89 | 90 | for (n_i = 0; n_i < N; n_i++) 91 | { 92 | int N_idx_y = n_i*C_out*H_new*W_new; 93 | int N_idx_x = n_i*C_out*H*W; 94 | 95 | for (c_out_j = 0; c_out_j < C_out; c_out_j++) 96 | { 97 | int C_out_idx_y = c_out_j*H_new*W_new; 98 | int C_out_idx_x = c_out_j*H*W; 99 | 100 | for (i = 0; i < H_new; i++) 101 | { 102 | for (j = 0; j < W_new; j++) 103 | { 104 | int output_idx_y = i*W_new + j; 105 | int output_idx_x = i*stride_h*W + j*stride_w; 106 | 107 | int max = x[N_idx_x+ C_out_idx_x + output_idx_x]; 108 | for (n = 0; n < k_size_h; n++) 109 | { 110 | for (m = 0; m < k_size_w; m++) 111 | { 112 | int kernel_idx = n*W + m; /* n walks kernel rows (bounded by k_size_h), m walks kernel columns (bounded by k_size_w) */ 113 | 114 | int value = x[N_idx_x+ C_out_idx_x + kernel_idx + output_idx_x]; 115 | if (value > max) 116 | max = value; 117 | } 118 | } 119 | y[N_idx_y + C_out_idx_y + output_idx_y] = max; 120 | } 121 | 122 | } 123 | } 124 | } 125 | } 126 | 127 | 128 | void relu(int *tensor, const unsigned int size) 129 | { 130 | unsigned int i; 131 | for (i = 0; i < size; i++) 132 | tensor[i] = MAX(tensor[i], 0); 133 | } 134 | 135 | void quantize(const int *tensor_in, int8_t
*tensor_q, const int scale_factor, 136 | const int scale_factor_inv, const unsigned int size) 137 | { 138 | unsigned int i; 139 | int rounded_value, tensor_int, tensor_frac; 140 | // separation to integer and fraction parts 141 | int scale_factor_int = (scale_factor + ROUND_CONST) >> FXP_VALUE; 142 | int scale_factor_frac = scale_factor - (scale_factor_int << FXP_VALUE); 143 | // element wise operation - we iterate throughout the entire length of the flattened tensor 144 | for (i = 0; i < size; i++) 145 | { 146 | tensor_int = (tensor_in[i] + ROUND_CONST) >> FXP_VALUE; 147 | if (tensor_in[i] > INT8_MAX_VALUE*scale_factor_inv) /* saturate in fixed point: tensor_in[i] and INT8_MAX_VALUE*scale_factor_inv both carry FXP_VALUE fractional bits */ 148 | tensor_q[i] = (int8_t)INT8_MAX_VALUE; 149 | else if (tensor_in[i] < -INT8_MAX_VALUE*scale_factor_inv) 150 | tensor_q[i] = -(int8_t)INT8_MAX_VALUE; 151 | else 152 | { 153 | tensor_frac = tensor_in[i] - (tensor_int << FXP_VALUE); 154 | // int * fxp = result is in fxp */ 155 | rounded_value = tensor_int*scale_factor_frac + scale_factor_int*tensor_frac; 156 | // fxp * fxp = fix-point multiplication with result is in fxp */ 157 | rounded_value += (tensor_frac*scale_factor_frac + ROUND_CONST) >> FXP_VALUE; 158 | // convert fxp to int and add to integer parts as final value should be a rounded integer 159 | rounded_value = ((rounded_value + ROUND_CONST) >> FXP_VALUE) + tensor_int*scale_factor_int; 160 | 161 | tensor_q[i] = (int8_t)rounded_value; /* store quantized value in output tensor */ 162 | } 163 | } 164 | } 165 | 166 | 167 | void dequantize_per_row(int *mat_in, const int *scale_factor_w_inv, const int scale_factor_x_inv, 168 | const unsigned int N, const unsigned int M) 169 | { 170 | unsigned int k, n; 171 | 172 | int out_value; 173 | 174 | 175 | for (n = 0; n < N; n++) 176 | { 177 | for (k = 0; k < M; k++) 178 | { 179 | 180 | out_value = scale_factor_w_inv[k] * scale_factor_x_inv; 181 | if (out_value > (1 << FXP_VALUE)) 182 | mat_in[n*M + k] *= ((out_value + ROUND_CONST) >> FXP_VALUE); 183 | else 184 | mat_in[n*M + k] = (out_value*mat_in[n*M
+ k] + ROUND_CONST) >> FXP_VALUE; 185 | } 186 | } 187 | } 188 | 189 | 190 | 191 | void dequantize_per_channel(int *tensor_in, const int *amax_w, const int amax_x, const unsigned int N, const unsigned int C, const unsigned int K) 192 | { 193 | unsigned int k, n, c; 194 | 195 | int out_value; 196 | 197 | for (n = 0; n < N; n++) 198 | { 199 | for (c = 0; c < C; c++) 200 | { 201 | for (k =0; k < K; k++) /* indexing treats tensor_in as N x C x K, so consecutive samples are C*K elements apart */ 202 | { 203 | out_value = amax_w[c] *amax_x; 204 | if (out_value > (1 << FXP_VALUE)) 205 | tensor_in[n*C*K + c*K + k] *= ((out_value + ROUND_CONST) >> FXP_VALUE); 206 | else 207 | tensor_in[n*C*K + c*K + k] = (out_value*tensor_in[n*C*K + c*K + k] + ROUND_CONST) >> FXP_VALUE; 208 | } 209 | } 210 | 211 | } 212 | } 213 | 214 | 215 | void argmax_over_cols(const int *mat_in, unsigned int *indices, const unsigned int N, const unsigned int M) 216 | { 217 | 218 | // calculate max of each row 219 | unsigned int n, m, max_idx; 220 | int row_max, value; 221 | for (n = 0; n < N; n++) 222 | { 223 | row_max = mat_in[n*M]; 224 | max_idx = 0; 225 | for (m = 0; m < M; m++) 226 | { 227 | value = mat_in[n*M + m]; 228 | if (value > row_max) 229 | { 230 | row_max = value; 231 | max_idx = m; // return column 232 | } 233 | } 234 | indices[n] = max_idx; 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /src/convnet_params.c: -------------------------------------------------------------------------------- 1 | #include "convnet_params.h" 2 | 3 | 4 | const int layer_1_s_x = 2949889; 5 | 6 | const int layer_1_s_x_inv = 1456; 7 | 8 | const int layer_1_s_w_inv[16] = {413, 272, 298, 260, 249, 331, 206, 332, 290, 235, 229, 302, 241, 307, 240, 229}; 9 | 10 | const int layer_2_s_x = 1565693; 11 | 12 | const int layer_2_s_x_inv = 2743; 13 | 14 | const int layer_2_s_w_inv[16] = {216, 167, 181, 204, 212, 175, 228, 155, 135, 133, 285, 168, 185, 193, 175, 201}; 15 | 16 | const int layer_3_s_x = 615328; 17 | 18 | const int layer_3_s_x_inv = 6980; 19 | 20 | const
int layer_3_s_w_inv[10] = {170, 175, 168, 137, 113, 158, 154, 150, 140, 153}; 21 | 22 | const int8_t layer_1_weight[144] = {-51, -70, 74, 37, -127, -105, 90, 55, -15, -27, 73, -54, -38, -8, -127, 61, 107, 19, -68, 66, 40, -120, 51, 10, -127, 33, 66, 115, 113, -50, 127, -62, -104, -1, -34, -111, 85, -90, -127, -72, -75, 86, 8, 124, 97, -44, 38, 19, -40, -94, -127, -33, 33, 73, -127, 62, 40, 38, 105, 33, 77, -75, 56, -52, 79, 30, -47, 50, 84, -126, -127, -94, -108, -82, -25, -59, 69, 56, 127, -9, 77, -127, -18, 122, 79, 102, 7, 50, -99, -61, -26, -68, -127, -90, -13, -38, -101, -100, 88, 127, -33, -105, 46, -99, -75, -97, -90, 17, 80, 15, -29, 67, 24, -45, 67, 127, 121, 47, 2, 22, -50, -30, 75, -127, -3, 67, 93, 94, 116, -10, -15, 29, -98, -96, -127, 55, 89, 78, 14, -49, 11, 51, 127, 0}; 23 | const int8_t layer_2_weight[2304] = {-29, -25, 22, -32, -11, 19, 67, 63, 69, -28, -15, 18, 3, -28, 15, 22, -17, 38, -44, -71, -54, 10, -54, -41, -21, -41, -32, -105, -101, -127, -101, -65, -40, -55, -29, -9, 30, 8, 1, 4, -11, -13, -28, 31, 39, 23, 50, 47, -13, 3, 8, -44, 38, 14, -25, -20, 30, -2, 5, 3, 33, -15, 12, 2, -4, 27, -25, 5, 30, -14, 30, 13, -18, 53, 24, 35, 3, -3, -7, -30, -19, -35, -12, 32, 29, 30, 14, -5, -18, 20, 39, 53, 11, -17, 20, 2, -44, 7, -26, 16, 6, -39, -9, -34, -42, -14, -42, -81, 23, 6, 5, 5, 27, 35, 55, 30, 9, -58, -24, -60, -15, -20, -7, -3, 22, -12, -24, -17, -16, -58, 14, 63, -5, 12, 23, 5, 7, 28, 32, -7, 9, 52, 29, 39, -54, -97, -118, -92, -123, 39, -87, 46, 82, -22, 32, -11, 24, -2, 69, 37, -2, 18, -50, -7, -55, -98, -99, -27, -92, 2, 22, -67, -42, 103, 43, 83, 70, 85, 55, -18, 32, 57, 6, 56, 59, 50, -7, 52, -12, 70, 24, -38, -9, -124, -67, -120, -111, -93, -14, 31, 25, -2, 30, 13, 38, 0, -34, 3, -92, -20, -127, -110, -57, -33, 22, 70, 40, 58, 26, 16, -1, -52, -58, 2, -10, 15, 69, 120, 51, 82, 43, 58, 39, 41, 20, -28, -92, 12, -59, -31, -25, -90, -90, -61, -39, -3, 32, 29, 70, -27, -32, -11, -18, -5, 16, 49, 17, 6, 17, -37, -65, 6, -18, -16, 2, -1, 
14, -69, 50, 13, 26, -11, 41, -22, 43, 21, 45, 55, -10, -25, 6, 30, 5, 42, -10, 53, -26, -30, 67, 24, -25, -38, -89, -58, 9, -2, -75, 75, 12, -22, 8, -71, 6, 51, 26, -35, 82, -9, -29, 33, -29, 38, -32, 0, -61, -37, 52, 29, 42, 86, -1, 11, -27, 73, 98, 7, -16, -22, 127, 66, 109, 98, 3, -21, -29, -44, -39, 42, 56, 57, -46, 13, -5, 22, -17, 1, -50, 13, -5, 33, 55, -31, 59, -40, 41, 10, -6, -60, -39, 14, -11, -10, -80, -7, 24, 65, 7, 74, 28, 23, 89, 0, 58, -31, 7, -33, 74, 66, -20, -18, -37, -36, 44, -13, 33, -45, -19, -26, 34, 53, 23, 71, 23, -40, -41, 69, 3, 4, -52, -70, -76, -36, -34, -23, -34, 85, -11, -27, -49, 17, 27, 11, -11, -33, -17, 42, 26, 73, 45, -52, -59, -33, 93, 26, 5, -13, -10, -56, -39, -20, -17, 10, 27, 7, 7, -8, -39, -27, -78, -50, 2, 28, 29, -62, 2, -33, -12, 14, -33, -46, 9, -1, -20, 18, -20, 55, -18, -12, 38, 19, 30, 10, 55, 37, -37, 75, 65, -82, -55, -84, 18, -120, -53, 68, -57, -8, -9, -6, -14, 11, 10, 40, 63, 21, -11, -16, 50, 19, -51, 29, -5, -33, -3, -49, -3, 2, 38, -27, 27, 9, -1, 9, -3, 81, 10, -38, -19, -84, -127, 40, -85, -51, -22, 36, -25, -47, -3, -49, -46, -11, -27, -3, -20, 25, 50, -21, 23, 28, 12, 48, 29, 38, 18, 62, 56, 66, 21, 58, 73, 38, 51, 3, -44, 75, -40, -2, 18, -45, -15, -27, -51, -2, -12, -55, 45, -41, -22, 36, -27, -20, 11, -34, -22, 0, -43, -79, 57, 79, 43, -34, 62, -22, -16, 47, -64, -19, -5, -93, -65, 12, -21, -43, 50, -9, -2, -41, -64, 35, -4, 23, 28, -23, 6, 81, -39, 25, 88, -65, 61, 117, -24, 56, 113, -48, -69, 82, 34, -42, 95, -3, -24, 41, -22, -49, -42, -10, 22, -35, 20, -52, -5, -4, -3, -40, 16, -44, -110, 7, -66, -89, 2, 36, 61, -61, 19, 38, -15, 38, 40, 32, -1, -89, 26, -73, -88, 48, -44, -47, 14, 3, 6, -22, 24, -46, 50, 38, -21, 25, 18, 53, -18, 18, 16, -6, 22, 13, 39, 21, 1, 42, 47, -33, 43, -30, -70, 13, 2, -42, 48, -23, -18, 39, 8, 25, -49, -12, 3, -69, 2, 46, -53, -4, 44, -46, 54, 69, -22, 46, 74, -28, 43, 44, 22, -29, -58, 33, -88, -124, 6, -92, -127, -40, 39, 43, -54, 28, 15, -35, -1, 16, 55, 52, 26, -24, 
1, -41, -72, -8, -95, 57, -5, -26, 34, 2, -16, 52, 28, 19, 9, -2, 44, 4, 0, 74, 34, -34, 24, 53, 17, 7, 31, 13, 20, -27, 63, 113, 62, -11, 40, 4, 43, 74, 23, 25, 39, -17, -74, 5, -91, -49, -17, 10, -22, -17, 3, -24, -27, 37, 13, -6, -18, -15, 33, -72, 29, -23, -61, -38, -99, -1, -19, -64, 0, -34, -19, -54, -27, 33, 12, 27, 27, 37, 13, -41, -5, 3, 59, -9, -6, 56, -60, -64, 34, -36, -20, 28, -27, 1, -22, -13, -58, 25, 15, -33, -23, -54, 1, 39, 18, -2, -11, 41, 16, -48, 29, 33, -11, 37, 38, 93, 3, 42, 35, 13, -2, -30, -87, -18, -35, -127, -84, -96, -95, -70, -50, -4, -15, -75, 23, 15, -35, 11, 28, -13, 18, -76, -96, -21, -102, -124, 2, -65, -59, 6, -26, -26, 13, 12, 11, 6, 35, 31, -11, 40, 35, 25, 15, 12, -42, -20, 0, 20, -45, 19, -33, 13, 62, 17, 49, 32, 4, 12, 38, 55, 21, -31, 46, 12, -56, 10, 7, 25, 24, 22, -32, -4, -53, -50, -34, -24, 15, 8, 18, 15, 20, 22, -16, -18, -34, -43, 10, -97, -52, -43, -127, -36, -12, 19, 26, 9, 25, -9, 23, 9, -24, -39, -8, 54, -10, 12, 46, 54, 36, 4, 4, 24, 21, -6, 50, 2, 24, -5, -18, 20, 8, 21, 19, 44, 40, 30, 56, 35, -47, -19, 21, 4, -10, -6, 1, 12, 22, -33, 36, 42, -3, 1, -7, -4, 14, -15, -63, -76, -33, -86, -54, -44, -35, -38, -59, -42, 3, 14, 0, 7, -6, 32, 18, 8, -107, -23, 15, -64, -67, -52, -34, -86, 47, -16, -7, 55, 21, 39, -7, 13, 17, -9, 37, 2, -42, -26, -11, -83, -45, -118, -76, 5, 36, -47, -44, -80, -69, 3, 45, 6, -22, 14, 3, 53, 90, 1, 15, 48, 35, -5, 40, 47, 7, 86, 50, -43, 0, -45, 3, 17, -37, 9, 57, 67, 41, 64, 36, -46, 55, 31, -17, -34, 10, 24, 55, -30, -13, -36, 3, 52, 47, 39, 52, 42, -10, 71, -15, -79, -15, 21, 36, 76, 107, 65, 21, -41, -13, -44, 31, -92, -70, -119, -100, -3, -34, -22, -11, -25, -19, -77, -99, -38, -13, 1, -8, 23, 37, -2, 82, -11, -20, -64, 34, -35, -87, -58, -24, -75, -24, 33, -3, 86, -7, -127, -27, 34, 27, 88, 91, -20, -40, -41, 39, 28, 34, 32, -4, 30, 58, 91, -42, -2, 37, -6, 54, -3, -6, 63, -2, -29, 31, -19, -72, -84, -43, -21, -9, -16, -97, -34, -62, -96, -46, -87, -10, -64, -45, -44, 13, -2, 3, 
-34, -40, -16, 35, 78, 31, -34, 16, 28, 52, 48, 57, -16, 50, -16, -45, 1, 21, 53, 36, 44, 11, 1, 60, 39, -1, 40, -37, -35, -17, 24, 5, -6, 36, 34, 50, 127, 105, -34, 17, 13, 41, 28, -47, -70, -112, 0, -20, 15, -19, 79, 2, 22, 65, 12, -22, -24, -50, -27, -6, -82, -12, -36, -15, 30, 8, -76, -67, -43, -15, -39, 25, -6, -43, -1, 59, 9, 41, 19, -45, 9, -86, -88, -28, 1, 17, 25, -18, 24, -9, -1, -39, -28, 82, 69, 91, 32, 84, 97, 76, 48, 24, -3, 7, 15, 49, 3, -11, -45, -26, 19, 30, -65, -31, 46, 84, -3, 11, -7, 29, 39, 1, 59, 67, 38, 5, -23, 5, 5, 42, 1, -5, -28, -61, -65, -59, -2, -7, 48, 80, -119, -4, 24, -3, 63, 38, 31, 35, 23, 97, 61, 104, 51, -2, -11, 49, -41, -7, 71, -20, 0, 32, -17, -11, 27, -8, -37, -31, 31, 19, 92, 6, -37, -60, -45, 9, -23, 9, 30, 4, 59, 102, 46, 25, 1, -2, 35, 29, 113, -57, -89, -95, 47, -24, -23, 8, 29, 55, 5, 0, 40, -19, 1, -3, -31, -57, -71, -23, -41, -59, -54, 4, -10, -24, -64, -87, 51, -14, -41, 1, -13, 18, 33, 50, 7, -127, -70, -85, -4, -4, -11, -25, -61, -17, 4, 53, 57, -32, 31, -31, 36, 50, 74, 127, 101, 92, -68, -53, -63, -33, -21, 20, -66, -41, 30, -1, 8, 24, 63, 127, 89, 27, 112, 113, 18, 7, -2, 42, 45, 46, -38, 10, 29, 4, 2, -13, 17, 54, -24, 11, 38, 22, 29, 36, 40, 15, -9, 10, -34, -43, -73, -11, 10, 16, -7, -2, -1, -28, -39, 3, -45, -33, -44, -17, -23, -26, -15, -73, -6, -7, -11, -16, 10, -17, -30, -12, -10, 15, 1, -18, 11, 23, 27, 36, 3, 44, 27, -7, -15, 4, -7, -20, -5, -41, -24, 11, 4, 7, 7, 0, -30, -36, -6, -8, -4, 5, -13, 10, -41, -32, 12, -49, -68, -29, -4, 2, 8, -52, -97, -39, -2, -73, -50, -4, 9, -32, 4, 26, 21, -18, 23, 28, 16, -8, -29, 29, 52, -35, 35, 38, 23, -7, -13, 29, 10, -23, -14, 31, 8, -31, 9, -9, -23, -19, -13, 5, -15, 7, 14, -21, 22, 7, -86, -79, -86, -127, -73, -76, -19, -11, 12, -52, 12, 35, 21, 24, 16, 39, 9, -12, 26, 4, 37, 35, -24, 74, -41, 13, -1, -93, 43, 1, -35, 18, -4, -14, 21, 34, -24, 5, 14, 51, 27, -10, -12, -51, -53, 35, -26, -5, 10, -52, -29, -15, 7, 2, -10, 15, 32, 30, 59, 3, 37, -3, -69, -42, -74, 
11, -124, -95, -14, 28, 36, 3, 9, -13, -35, 56, 13, -59, 9, 16, -18, -11, 25, 16, 36, 63, -1, 19, -27, -15, 55, 2, -18, -14, -24, -42, -66, -73, -63, -17, 70, 17, 31, 89, -14, -15, 51, 43, -27, 43, 29, 26, 16, 11, 36, 38, 28, 57, 15, 63, 21, -4, 13, 51, 19, -78, -33, -83, -49, -99, -78, -105, -16, 59, 47, 10, 42, 5, -12, 0, 18, 26, 25, -21, 22, 32, -19, 116, 22, 11, 4, 68, 45, -1, -34, -71, -46, -49, -87, -8, -48, -21, 25, -30, -51, 21, -75, -54, 13, 20, -16, -44, -18, 61, -112, -52, 14, -2, 53, 113, -92, -56, -30, -50, -98, -23, -52, -30, 29, -15, -5, -4, 63, 14, 34, 43, 3, 22, 5, 40, 21, -66, -48, -50, 19, 10, -13, 32, 21, 71, 127, 83, 35, -14, 8, 78, -106, -12, 18, -97, -106, -74, 10, -13, 10, 2, 34, 60, -60, 0, -6, -40, -56, 11, 16, -62, -13, 1, 8, 45, -54, -74, -19, -30, -49, -15, -60, -21, 58, 61, 56, 19, -17, -3, 5, -48, -25, -53, -6, -22, 16, 37, 35, -13, -19, -36, -25, 14, -26, -35, 40, 9, 70, 42, 63, 99, -2, 33, 6, 34, 11, 58, -66, -45, -36, -75, -14, 56, -17, 85, 109, -13, 34, 99, 29, 53, 57, 4, 25, 72, -42, 0, 23, -26, -49, -8, -21, 51, 51, -16, 51, 58, 9, 4, 112, 27, -40, 18, -30, -104, -54, 42, 19, 37, -33, -59, -4, -39, -71, -4, 8, -5, -70, -54, -91, -74, 10, -66, -91, 48, 16, -6, 23, -3, -26, -35, 23, 11, -41, -19, -46, -33, 19, 62, 37, 43, 75, 21, 39, -10, 0, -39, -21, -84, -16, -2, 40, 7, 75, 34, -18, -14, 47, -2, 17, -19, 17, -26, -64, -108, -120, -33, -98, -127, -3, -43, -26, -26, -111, -115, 20, -60, -107, 2, 9, -27, -9, 28, 43, -79, 0, 57, 37, 27, 59, 42, 74, -22, -3, 87, 70, 1, -45, -22, 58, -29, -72, 65, 59, -15, 22, -8, -14, -19, 25, 13, -51, 3, -9, 18, 32, 116, 27, 51, 17, 91, 84, 35, 47, 35, 73, 31, 45, 11, 28, 27, 14, -38, -40, 15, -38, -3, 25, 17, 39, 18, 60, 21, -58, -18, 16, -20, -33, 3, -37, 43, 54, -12, 3, 17, 32, -36, -45, 9, -53, 37, 31, -72, -69, 3, 25, -72, 10, -5, -54, -49, 7, -5, -1, -21, 4, 17, -113, -49, 29, 15, 1, 30, 22, 11, 78, 8, 7, 11, -27, 25, 76, -94, -54, 13, -2, -40, -74, -9, 18, 35, -2, -6, 40, -13, 31, 0, -23, 
-42, -59, -24, -78, -26, -8, -28, 11, -117, -83, 18, -45, -19, -23, 37, 9, -42, 54, 18, 23, -21, -14, 3, -39, -109, -74, 36, 44, 1, -3, 1, 9, -127, -58, 9, 4, -19, -6, 34, 49, 55, 56, -19, -81, 22, -24, -36, 33, -6, 21, 34, -3, -62, 13, 17, 80, -48, 26, 61, -33, -60, -61, -55, -35, -8, 35, 14, 66, -16, -31, 40, -44, -40, -54, -42, 38, -27, 41, -13, 40, 0, -23, 18, -39, -26, -127, -64, -66, -69, -20, 27, -34, 38, 75, 62, -1, -19, -6, 77, 44, 18, 75, 47, 54, 25, 0, 39, -75, -61, -62, -17, -6, 21, 99, 25, 58, 56, 17, 40, -62, -35, 7, -61, -25, 17, -13, 4, -27, 72, 91, 56, 24, 19, 42, -76, -6, -58, -42, -20, -62, 12, 33, -1, 57, 30, 7, 16, 2, -5, -14, -9, 42, 47, 25, 0, 26, -7, -29, 5, 9, -18, -28, -45, -66, 16, 38, 65, 7, -36, -6, -37, -73, -94, -80, -41, -17, 88, 21, 72, 51, 98, 83, -113, -8, -33, 30, -15, 42, -30, -37, -35, 11, 12, -9}; 24 | const int8_t layer_3_weight[4000] = {45, 25, -25, -41, 51, -28, -17, 33, 10, -43, -3, -32, 0, -18, -10, -1, -39, -2, 4, 18, 30, 13, 15, 33, -47, -2, -36, 11, 26, -2, 25, -14, -4, 33, -2, -12, -34, -37, 23, 11, 17, -13, -14, -11, 3, 13, 3, -9, 20, -12, 33, 14, -1, -21, -20, -1, -108, 19, -35, 25, 27, -9, 23, 15, -59, -3, -55, 5, 18, 8, 10, -15, 15, 13, -96, 1, -12, 7, 2, 39, 17, 26, -29, -23, -58, 9, -21, -1, -36, 22, 30, 13, -39, -15, -24, 26, 38, 18, -42, -34, -6, -35, 74, 1, 47, -20, -74, -11, -15, -13, 2, -53, 49, -3, 45, -13, -12, 18, -47, -21, 0, 50, 6, 60, -37, -17, 13, 24, -47, -77, -91, 80, -57, 62, -60, 24, 56, -5, -12, 21, -11, -23, -45, 12, -56, 43, 57, 1, -1, -28, -10, 25, 14, 39, -37, 11, -23, -3, 1, -70, -28, -2, 9, 15, 7, 20, -32, 59, -9, -56, -54, 4, 29, 18, -31, -34, 55, -7, 37, -33, 3, 14, 2, -5, 15, -31, 32, -27, 19, -5, 28, -10, 6, 2, -38, -4, 19, 34, -43, -24, -36, 58, -21, 7, -23, -3, -5, 15, 4, -15, -40, 19, -10, 26, -35, 47, -57, -11, -25, -15, 21, 19, 28, 36, -9, 55, -2, -44, -51, -34, 22, -21, 53, 2, 42, 0, 31, -48, -59, -10, 22, 18, 38, -8, -38, -11, 0, -50, -1, -52, -21, -6, -12, 20, -89, -10, 64, 
-48, 48, -5, 20, -50, 13, -32, -44, -4, 35, -31, 24, 40, 41, 21, 4, -39, -120, -2, -8, -8, 18, 59, 4, -7, -9, -36, -110, 43, 24, -41, 12, -21, 10, 27, -43, 0, -44, 33, 60, -63, -29, -48, -43, -2, 24, 2, -64, -19, 0, 10, 17, 26, 11, -6, 28, 9, -58, -13, -19, -37, 19, 36, 10, -34, -5, 6, -33, -40, -20, 9, 18, 27, -16, -13, -9, 22, -3, -24, 12, 20, -5, 11, -35, 22, -42, 24, 52, 25, -20, 14, -31, -38, 27, 9, -12, -83, -9, -1, 52, 3, 8, 25, 21, 5, -5, -30, 57, -8, 7, -26, 31, -14, -4, 11, -5, 26, 20, -20, 14, -33, -9, 25, -23, -25, -30, 127, -50, -72, -26, 14, 43, 12, 11, 0, 12, 5, 6, -112, -60, 32, 57, 20, 23, -30, 29, -91, 17, -24, 34, -3, 54, 8, 2, -19, 34, -103, 44, -40, 17, -46, 63, -26, -33, -29, 42, -26, -50, -14, 23, -52, 80, 6, -24, -22, 14, 53, 30, -21, -17, 9, -2, 4, 5, -22, 8, -6, 5, -55, 34, 42, -32, 10, 6, 5, 54, -33, 23, -35, -16, -80, 40, -29, 2, -3, 22, -44, 46, -13, 9, -9, 24, -5, 10, -11, 10, -11, 8, 7, 12, -11, -46, 11, 9, -34, 23, 64, -23, 27, 18, -26, -7, -41, -5, 17, -32, 34, -19, 8, 2, 32, -7, 18, 9, 6, -1, -3, -13, -40, -12, 7, 0, 4, -29, -7, 5, 30, 40, -3, -25, -11, 39, 11, -14, -42, 6, 17, -7, -11, 0, -12, 26, 41, 14, 6, -16, -30, 12, 22, -16, -17, 23, 50, -36, 17, -47, -110, 76, 50, 19, 25, 13, -7, 30, -5, 26, 4, 9, -89, 32, -22, -13, 17, 0, -42, -32, 86, -6, -9, 9, -66, 10, 32, 16, -11, -27, 28, 20, -14, 21, -61, 15, 41, -37, -37, 5, 39, -6, -17, 64, -16, 16, 25, -45, -6, -30, -33, 57, 39, 13, -14, 48, -23, 2, 41, 47, 7, -17, -31, 20, -7, 4, -45, 17, -5, 63, 56, -24, -31, -11, 25, -19, -51, 5, -23, 16, 44, 8, -13, -9, -12, -16, 15, 5, -70, 8, 35, -51, -30, 44, 9, 82, -36, 10, -4, 44, -22, 36, -91, -27, 3, 81, -60, -11, 65, 3, -43, -36, 24, 8, 83, 40, -72, 17, -6, -6, -15, -22, 1, -21, 102, -10, -45, 50, -71, 4, -15, 22, 18, -2, -34, 2, 74, -12, 20, 1, 38, -42, -29, 1, -43, 47, -5, 44, -14, 14, -79, 43, -47, 21, 82, 30, -50, -39, 36, -16, -24, 63, 40, -97, 84, -68, 10, -23, 4, -7, 23, -5, 25, -109, 59, -43, 52, 21, -74, 31, 60, -25, 22, -72, 
-8, -9, 26, 2, -75, 9, 19, -93, 8, 50, -54, 38, 9, 29, -14, 11, -7, -6, 35, 21, -32, -25, -52, 21, 43, -21, -45, -13, 9, -15, 8, -9, -21, 4, 58, -15, 18, 21, 1, 6, 2, -2, -61, -37, 74, 8, -107, 69, -2, 127, -78, -47, -66, -35, 93, -92, 5, 56, -15, 59, -44, -32, -65, -76, 9, -55, 25, 20, 50, -10, 13, -15, -36, -3, 6, 25, 2, -31, 5, -9, 28, 34, -21, 34, -2, 63, 9, -10, -37, 40, -89, -58, -7, 12, 33, 59, -66, 29, -6, 27, 9, -86, -66, -68, 46, 13, -33, 78, -41, 33, 41, -49, -51, -42, 14, -3, 40, -19, -34, -18, 47, -12, -8, -28, 29, 29, -2, -37, 31, -11, 32, -23, -3, -2, 26, -11, 100, -90, -31, 41, 51, -98, -34, 51, 36, -44, 2, -42, -38, 2, 127, -79, -112, 14, 50, 6, -46, -53, -22, -54, -51, 1, 22, -8, 22, -6, 12, 27, -42, -61, -11, -9, 42, -42, 61, 2, -8, -52, 65, -32, 40, -60, 22, -37, 37, -16, -1, -74, 10, -52, 96, 3, -65, 31, 108, 12, -94, -87, -9, 4, 95, -47, -82, 2, 33, 35, -34, -17, 11, -54, -26, -64, 12, 27, -15, -22, -5, 21, 3, -12, 24, -20, 44, -10, 39, 7, -45, -46, 36, -10, -12, -19, -22, -36, 84, 27, -48, -1, 26, -65, 42, -94, -15, -13, 32, -6, -55, 69, 24, -59, 44, -94, 13, 4, 1, -55, -6, 32, 39, -16, 32, -52, 9, -9, 13, -30, 26, -33, -1, -24, 28, 11, -17, -14, -31, -25, 12, -23, 20, 88, -31, 14, -23, -16, 20, -17, -34, 24, 14, 57, 9, -36, -54, -35, 37, -4, -35, 102, 38, -8, -46, -12, -60, -33, 27, -42, -39, 78, 39, 1, -97, 30, -95, -1, 28, -30, -121, 105, 14, 35, -48, -13, -102, 23, -40, -50, -63, 34, 17, 52, -15, 34, 13, -20, -18, -42, -56, 56, 32, 9, 1, -17, 12, -15, -5, -26, -35, 74, -1, -5, 1, -18, -11, -19, 41, 20, 22, 56, -74, -53, 57, 0, -38, 43, 29, 54, 46, 62, -127, -97, -7, 27, -67, 1, 1, -53, -30, 20, -15, 63, -43, 10, 19, 15, -46, -35, -55, 29, 22, 25, -15, 21, 30, -14, 30, -24, 7, -71, 4, -8, 20, 11, 8, 15, 75, 29, -89, 73, -14, -67, 31, -49, -12, 47, -38, 4, -98, 12, -123, -25, 37, 7, 10, 24, -39, 7, -31, -27, -20, 26, 5, 15, 1, 31, -25, -3, -36, 26, -21, 29, -17, 29, -36, -54, 40, -4, -55, 25, -64, 10, -6, 19, -3, -10, 3, 3, -18, 37, -8, -7, 
16, -24, 48, 43, -70, -65, -20, 16, -7, -38, 75, -34, -1, 35, -12, 8, 1, -18, -34, 2, -10, 3, -26, 29, 15, 44, -61, 3, -15, 20, 1, 33, -50, -83, 25, -41, -62, 54, -27, -29, 50, 26, 18, -57, 27, -52, -13, 23, 3, -57, 63, -30, 30, -57, -41, -77, 53, 64, 47, -56, -36, 16, 28, 4, -60, 4, 4, -19, 25, 13, 3, -10, 8, -5, -1, -2, 16, -21, -30, 8, 22, 0, -19, -2, 12, 8, -18, 41, -5, 16, -25, 7, -2, -10, 62, 14, 14, -7, -35, -15, 12, -15, -14, 10, 16, 39, -7, 46, 10, -30, -25, 18, -11, -22, -28, -5, -6, 8, -55, 37, -15, 50, 39, -16, -40, -48, -78, 99, 17, 5, -3, -13, 5, 18, 13, 4, -5, 61, -2, 8, 12, -21, 1, 14, -14, -3, 34, 32, -59, -65, 13, 1, 15, 13, 13, 12, 30, 64, -60, -88, 55, 8, 17, 2, -14, -14, -73, 17, -55, 49, -68, 25, 36, 31, -34, -43, -81, 34, 21, 5, -71, 3, 8, 49, -18, -20, -32, -21, 21, 26, 1, 31, -41, -6, -10, -2, -36, 3, -28, -23, 14, 14, 30, 3, 2, 15, -13, 8, -13, -16, 12, -32, 36, 20, -8, 30, 2, 64, -40, -39, -56, 64, 11, 11, -9, -2, -14, 49, -33, 14, -38, -10, 11, 29, -47, 16, -13, -48, 13, 49, -32, 33, -37, -54, 42, -9, -8, -48, 4, 60, -16, 8, 1, -16, 16, -18, -12, 42, 1, 25, -7, 22, 8, -7, 2, 25, 27, -44, -96, -47, 51, -15, 60, 36, -5, -31, 13, -43, -15, -3, -9, -27, 3, 75, -14, 56, 16, -104, 37, -9, -36, 18, -70, -6, 8, 18, -4, -51, 18, -7, -15, 13, -11, -18, 45, 20, 42, 1, -18, -8, -18, 36, -23, 16, -20, -35, 22, 4, -24, 41, -37, 20, -38, 4, 9, -17, -10, -28, -42, 63, -2, -13, 29, -5, 32, -8, -22, -21, 11, 17, -29, 23, -4, -4, 26, -17, -29, 31, 10, -7, 2, 8, -8, -6, 26, -71, -65, 57, -3, 26, 0, 13, -18, 34, -29, -25, -58, -19, 26, 47, -4, 0, -3, -14, -17, -3, -33, 50, -7, 20, -7, 25, 26, 21, 8, -6, -29, 34, -1, 25, -22, -46, 4, -26, 3, -3, -18, 46, -26, -13, 27, -1, 21, 15, 21, 22, 12, 38, -37, 11, 32, -26, -54, 18, 0, -33, -40, 14, -70, 34, 7, -13, 30, 30, -6, -54, -76, 44, -2, 30, -27, -13, 28, 20, 15, -40, 30, -21, -19, 4, -2, -4, -19, -37, 10, 10, 8, 21, -76, 2, 21, 36, -8, 9, -8, 25, -36, 27, -105, -47, 31, 6, -21, 32, -58, 19, -61, 9, -35, 35, 9, 
16, 29, 37, -5, 51, -108, 29, -46, 36, -17, 87, -41, -44, 13, 34, -30, -46, -4, 32, 2, 34, 6, 22, -41, 21, -4, -7, -20, 9, 9, -22, 9, 8, -14, 38, -39, -21, -5, -9, 44, -81, 36, 29, -1, 11, -12, 5, -45, 10, 35, 15, -51, 2, 15, 14, -82, 32, -34, -36, 12, 12, -7, 39, -15, -4, 19, 39, 19, -13, -10, -36, -14, 15, -64, 18, 27, -24, 23, -2, -65, 10, -23, 28, -70, -57, 43, -23, 26, -26, 7, 27, -12, -13, -55, 2, -2, 24, -25, -60, 10, -10, 0, 6, 5, -14, 21, -53, -1, -58, 0, 13, 36, -30, -12, -5, 54, -71, 17, -46, -1, 31, 48, -9, -7, -18, 20, -76, 41, -44, -6, 11, 38, 19, 14, -27, -28, -25, 18, -17, 6, -11, 12, 31, -5, 16, 23, -42, -19, -40, -3, 18, 43, -1, 35, -6, 45, -43, -18, -11, 15, 13, -2, -13, -5, -5, 48, -59, -8, -10, 33, 10, 28, -21, 4, -45, 35, -76, 3, 0, 14, 2, 29, -25, 16, -35, -36, 5, 56, 3, 51, -7, -16, 16, -15, 88, -16, 19, -54, -50, 0, 6, -28, 0, 4, 57, -61, 12, 25, -36, 48, 34, -77, -44, -25, 24, 26, -7, 12, 13, -62, -9, 29, -74, -40, -42, 80, -9, -39, 49, -35, 41, 41, -9, -30, 17, -17, 88, -56, -15, -20, 34, 39, -5, 34, 29, -18, -17, 33, -22, -42, 16, -72, -10, 5, 18, -29, 40, 18, -25, 34, 2, 3, -16, -69, 47, 27, 27, 2, 5, -17, 0, 24, -12, -6, -41, 1, 27, -18, 7, 20, 6, -45, -4, 58, 20, -64, 97, -36, -43, 36, -99, -22, -2, 17, 7, 5, 1, 15, -49, 69, -39, -3, -15, -16, -26, 39, -8, 28, -25, -47, -14, 3, 17, -51, 33, 62, -66, 49, 15, -90, 19, -44, 18, 19, 25, 1, -117, 39, 2, -89, 28, -39, -5, 15, 16, -5, -24, -2, 22, -119, -18, 0, -35, -13, 31, 19, -51, -95, 22, 50, -14, -16, -54, -23, 11, 29, -36, -49, -14, 40, 34, -16, -24, -60, -8, 4, -107, -3, -38, 62, -22, 48, -3, -23, 13, -25, -47, 38, 30, 26, 11, 32, -17, -22, 6, -19, -13, 39, 36, -14, 39, -43, -25, 24, 31, 51, -19, -64, -45, -12, 13, -10, 9, -9, 15, 65, -71, -9, -78, 5, -21, -9, -14, -34, 29, 9, -76, -17, -36, 43, -15, 8, 15, -43, 16, -2, -84, 62, -6, 5, 8, 8, -31, -45, -78, -22, -1, 56, 6, -2, 19, 3, -22, 32, -15, -16, 16, 28, -74, 2, 20, -12, -9, -16, 23, -10, 20, 23, -79, 14, -61, 2, -9, -7, 5, 11, 
16, -8, -11, 15, 11, 6, -29, -25, -33, 54, -62, 4, 19, 12, 18, 24, -7, -32, -26, 0, 69, 6, -34, -31, 64, -21, -17, -6, 30, -50, 35, 28, -41, -24, 17, -7, -32, -30, 3, -35, 63, 38, -36, -4, 13, 5, -37, -43, 2, -1, 7, 3, -13, -40, 37, 55, -5, 21, -19, -11, 17, -4, 10, 5, -26, 15, 54, -30, 39, -59, 18, -17, 27, 31, -8, -87, -10, 25, 1, 2, -28, 2, 26, -37, -29, 3, 21, -8, 3, 42, -31, 18, 8, -57, -33, 38, 18, 0, 23, 23, -72, 6, 5, -26, -32, -12, 30, -9, 5, 29, -56, -4, 13, -10, 18, -19, 22, 30, 68, -20, 40, -17, 8, -6, -33, -28, 24, -16, -10, 19, -22, -35, 56, 18, 21, -54, -64, -14, -2, 9, -115, -5, -18, 1, 12, 82, -30, -38, 0, 32, -64, 29, -32, 22, 29, 33, -28, -59, 9, -7, -104, 29, -4, -6, 10, 57, -13, -20, -30, -29, -37, 56, 19, -28, -11, 22, -4, -22, 37, 25, -28, -33, -35, -6, -32, 15, 21, 3, 24, 46, -82, -13, -70, 3, 29, 7, -6, -73, 16, 13, -61, -25, -37, 32, 20, 31, 3, -55, -21, -20, -58, 34, 9, 13, 12, -8, -30, -15, -51, -37, 5, 103, 2, -32, 75, -38, -28, -13, 27, 30, 12, 20, 2, 10, 2, -22, -15, -18, 42, 0, 33, 2, -45, 24, -14, -12, -18, -54, -19, 32, 46, 18, -11, -35, -44, 76, -81, -127, -32, 43, 4, 9, 63, -37, 13, 0, 19, -58, 20, -30, 12, -25, -11, 5, 88, -73, -17, 26, -14, -2, 47, 33, -8, -11, -7, 18, 8, -12, -7, 13, 24, 1, -29, 39, 10, -4, -23, -29, 0, -3, 12, -12, -23, 11, 20, 56, 34, -65, 14, -26, 71, -30, 11, 13, -11, -10, 22, -12, 57, -81, 15, 9, -16, 49, -34, -120, 18, 58, -9, 13, 0, 20, -8, -48, 8, -18, 0, 27, -21, -27, -6, 10, 18, -22, 33, 1, 24, -11, -11, -25, 30, -15, 24, -22, 38, -8, -9, 31, 18, -18, -22, -33, 12, -32, 18, -3, 12, 55, 30, -19, 42, 2, -2, -47, 5, -75, -13, 22, 4, -7, -20, 3, 16, 8, -18, -58, -45, 37, 7, 53, 43, 2, -45, 41, -30, -121, -13, 18, 20, 58, 0, -42, -47, 26, -23, 1, 12, -35, 33, 9, -30, -51, -61, 29, 28, 45, 19, -21, -3, 52, -76, -64, -36, -46, 31, 27, -4, 24, -57, -35, -5, -4, -38, 66, 10, -30, -2, -10, -48, -96, 66, 14, 3, 23, 21, 15, -3, 8, 3, -52, -7, 0, -27, -37, 62, 33, 36, -4, 8, 7, -3, 2, -35, -10, -7, 0, 35, 8, -28, 
-52, -53, -16, 33, -15, 22, 3, 20, -27, -57, -84, 50, 21, 7, -56, 20, 35, 12, -68, -58, -72, 32, 15, -4, -41, 80, 29, 5, -12, 3, 6, 34, 34, 1, -56, 75, -22, 21, -27, -20, 28, 5, 5, 13, 1, 13, -18, 28, -17, -19, 5, -29, 25, 35, -8, -11, -49, 9, -33, -12, -2, 23, 15, 25, -75, -15, 17, 15, -23, -3, -30, -55, 3, 43, 24, -16, 13, 1, 12, 4, 25, -70, 18, 18, -50, 25, -54, -9, -20, -9, 36, -38, 10, 49, 3, 35, -45, -35, 11, -45, -2, -21, 15, 8, -33, 36, -36, 27, -49, -2, 38, -39, 13, -30, -41, 34, -24, 18, -27, -1, 3, -38, 41, 10, -72, 25, 12, 57, -4, 13, -36, -76, -22, 37, -40, 16, -36, -56, 6, 53, -12, 36, -31, -3, -5, -35, -9, -127, 48, 71, -54, 51, -47, -37, -22, -14, 18, 0, -8, 5, 17, -36, -18, 60, 4, -21, -58, 5, -16, 10, -12, -22, -1, 15, -8, -4, -16, -15, 28, -10, 14, 2, -41, -5, -30, -2, 24, 7, 32, -8, 11, 28, 0, -2, -43, -9, 2, -5, 65, -24, 15, 77, 12, 41, -39, -5, -33, 15, -4, -34, -1, 37, -1, -14, 28, -16, -26, -18, -19, -6, -30, 11, -2, 23, 29, 9, 16, 3, 13, -40, 18, 23, -25, -13, 20, -22, 36, -4, -16, -29, 22, 14, -85, -15, 50, -23, 15, 13, 26, -4, 18, 29, -74, -33, 25, 34, -7, 32, -21, -1, -23, 36, -37, 17, 24, -25, 10, 40, -22, -14, -35, 48, 17, 12, -35, -10, -1, 0, 21, -28, 0, -64, -35, 12, -35, 21, 15, -13, 25, 0, -8, 1, -47, -30, 18, -15, 0, 44, -7, 29, -38, -34, -37, 21, 8, -8, -3, 10, -53, 6, -44, 33, 0, -4, -3, 14, 7, 32, -35, 12, -45, 2, -28, 35, -5, 41, -40, -21, -6, 42, -20, 10, 10, 13, -37, 37, -46, -7, -21, -17, -9, 49, -1, 23, 23, -30, 18, -11, -3, -35, -5, -7, -6, 7, 29, -9, -4, 12, -11, 15, 6, -43, 21, -23, 3, -9, -23, 31, 36, 4, -23, -30, -16, -6, 23, 5, -46, -34, 17, -19, -1, 26, 15, 3, 34, -10, -5, 4, -5, -18, 38, 29, 13, -28, -5, -11, 3, -2, -45, -68, 19, -24, 36, 18, 13, 16, -22, -55, 32, 21, -27, -54, -68, 18, 50, 22, -35, -113, -14, 74, 16, 11, -63, -47, 77, -21, -127, -19, -28, 16, 35, -51, 2, -62, 41, -29, -33, 5, -44, 10, -9, -101, 42, -27, -19, -70, 23, -17, -75, -31, -105, -36, 31, 30, -61, 19, -53, -46, 23, 2, 22, -5, -34, -35, 50, 
-31, -7, -36, 3, 13, 10, -50, -39, -59, 48, -55, -33, 2, -2, 48, 41, -19, 0, -22, 71, -26, -31, 4, -38, -4, -40, -48, 62, 20, 36, -31, -14, -20, -72, -94, -102, -59, 68, 84, -43, 1, 11, -86, 0, -1, 27, -27, 4, -21, 28, 12, -12, -95, -7, 44, -10, 46, 17, -112, 1, -21, 32, -5, -80, -10, 9, 16, 22, -21, -26, -31, 48, 9, -87, -50, 45, 20, 5, -34, 29, -23, -12, 4, 14, 21, -44, 8, 60, -49, 6, -13, -27, -12, 17, -53, -19, 43, -2, -1, -46, 9, 7, -20, -33, -49, 12, 46, 22, -25, 11, -46, 53, -61, -48, 0, 58, 52, 9, -29, -45, -44, 26, -16, 12, 14, 15, -5, -13, 18, 16, -15, -14, 28, 2, 127, -122, -9, -29, -29, 11, -32, -71, 5, 6, -13, 43, -13, 18, -18, -68, 1, 3, 11, -36, -3, -8, -64, 9, 24, -66, -14, 5, 36, 8, 34, -2, -89, 11, 49, -52, -17, -46, 33, 17, 38, -46, -21, -2, -5, 12, -54, -12, -7, 25, 63, -65, 60, -47, 11, 18, -51, -7, 1, 49, -23, 4, -16, 4, 60, 1, 25, -98, -30, 70, 11, -3, -22, 24, 41, -44, -1, -126, -9, 21, 39, 43, 2, -35, -65, -23, 0, 10, 12, -25, 29, 16, -82, -34, -27, 2, 0, 10, 63, -62, 29, 69, -44, -94, -3, -23, -28, 39, -7, 21, -47, -43, 19, 38, 8, 24, 38, -11, 1, 5, -16, -58, 55, 21, -16, -5, 30, 41, 22, -8, -17, -22, -22, -21, -110, 9, 81, 29, 37, -33, 18, -14, -24, -25, -23, 25, 1, 14, 11, -43, 6, 9, -94, -58, 7, -4, 31, 51, 25, -86, -44, -97, 52, 4, -14, -54, 35, 31, 5, -39, -61, -44, 63, 35, 6, -60, 56, 45, 26, -22, -21, 64, 63, -32, -48, -37, 70, 18, 3, -18, -35, 65, 27, 10, 15, -32, 3, 16, -12, -20, -18, -18, -22, 6, 51, -2, -26, -33, 31, -46, -18, 18, -9, -22, 19, -80, -22, -8, 11, -35, -20, -25, 16, -14, 17, -24, -9, 7, 19, -1, -8, 48, -2, -22, 29, -29, -12, -42, -54, -5, 3, 17, -26, 14, 4, -47, 27, -5, -26, 2, 4, 33, -3, 17, 14, -52, 39, -19, 23, -71, 10, 39, -72, -22, 12, -68, 55, -34, 2, -15, 29, -45, -59, 16, 47, -46, 27, -49, 36, -30, 8, -35, 9, -10, 24, 5, 0, -20, -53, 7, 70, -13, 51, -26, -4, -6, 6, -7, -105, 51, 19, -56, 88, -6, -26, -30, -31, 64, 9, -10, -3, 19, 5, -24, 14, 9, -37, 38, -12, -1, -8, 10, 38, 7, -8, 23, 2, -13, -10, -24, -2, 
29, 58, 26, -26, 3, 35, 6, 0, -25, -3, 24, 25, 10, -21, -4, 47, -9, -8, -10, -8, 16, -21, 10, -5, -32, 31, 3, -2, -3, -16, 27, 1, -48, -13, 42, -17, -59, -18, -35, -4, 17, -22, -16, -51, 65, -13, -28, 20, -60, 9, -14, -39, -1, -44, 31, 7, 15, 13, 25, 6, -10, -28, 8, -35, 18, 28, 25, 29, -27, -9, -27, 39, 18, -31, 5, 2, 17, -51, 19, 79, 43, -4, -21, -12, 56, 11, -85, -29, -52, 10, -31, 12, 3, -29, 6, 27, 40, -2, -29, -4, -73, 59, 36, 46, -2, 37, 25, -33, 37, -3, -4, -78, 31, 29, 18, -13, -45, 6, -12, -5, -8, 68, 14, 10, 7, -13, -12, -66, 34, 11, 53, -5, 67, -18, -9, -40, -27, -10, -14, -63, 12, 54, 37, -49, 9, -31, 49, 17, -72, -27, -35, 55, 27, 37, -54, 23, 10, -40, -19, -7, 13, 2, 19, 20, -7, 35, -89, -44, -7, -9, 13, 69, 10, -21, 46, 34, -44, -42, 21, -46, 57, 39, 26, -48, -28, -59, -14, -15, 42, -28, 9, -7, 71, -39, -40, -1, 48, 21, -23, 12, -8, -9, -7, 8, -17, -38, 34, 13, 23, 35, 2, 1, -35, 46, -19, -20, -60, -51, 37, 77, -26, 30, -5, -2, -21, -19, -20, 22, 18, 0, 6, -4, 4, -59, -14, -21, -16, 4, 6, 4, 6, 33, 24, -50, 2, -11, -6, -8, -25, 15, 18, -34, -21, -26, 6, -42, 21, 0, 13, -26, -1, -29, 16, -2, -1, -27, 11, -31, 20, -28, -62, -39, 46, -11, 15, 24, -31, 40, 25, 44, -2, -20, -58, -10, 13, -20, 6, 32, -1, 11, 20, -127, -27, -21, 38, 5, -45, 23, 21, 16, -15, -71, 66, 79, 72, -69, -69, -19, 53, -30, -18, -24, 28, 78, 37, -53, -20, -15, -7, -35, -31, 8, 49, 10, 48, -2, -58, -29, 19, 77, 22, -100, 34, -31, 21, 4, -39, -31, 11, 46, -45, -16, -1, -26, 69, -34, 17, 6, 12, 17, -5, 13, 15, 74, 26, -56, 3, -53, 8, 4, 24, -25, 42, 32, -44, -37, 15, -14, 20, 19, -19, -25, 2, -59, 17, -9, 33, -73, 71, -13, 12, 8, 35, -62, -7, 14, -46, -20, 91, -67, 26, 0, 38, -88, 44, -69, 46, -51, 66, -21, 15, 65, 13, -47, 21, -19, -22, -49, 37, 8, 0, 26, -20, -74, 37, -23, 36, 18, -14, 18, -39, 16, -35, -72, 51, 10, 31, -30, 56, -63, 88, 19, 41, -30, 35, -120, 4, 18, 1, -7, 63, -49, 64, -65, -25, -127, 13, 34, 87, 76, 32, -92, 44, 31, -73, -85, -14, 65, 25, 25, -6, -77, -24, 88, -35, 
-59, -17, 48, -5, 47, -56, -10, 11, 11, -46, -25, -28}; 25 | --------------------------------------------------------------------------------