├── README.md
├── generator.py
├── helper_functions.h
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Deep Compression for PyTorch Model Deployment on Microcontrollers

In this repository, you can find the source code of the paper *"Deep Compression for PyTorch Model Deployment on Microcontrollers"*.

This work follows the paper [Efficient Neural Network Deployment for Microcontroller](https://arxiv.org/abs/2007.01348) by Hasan Unlu. You can find the repository of the source code of that paper [here](https://github.com/hasanunlu/neural_network_deployment_for_uC).

## Dependencies
These are the only versions tested; therefore, other versions may be incompatible.
* Python 3.8.5
* PyTorch 1.8
* Tensorboard 2.4.1
* Neural Network Intelligence (NNI) 2.1

## Usage
Running `generator.py` generates the `main.c` and `main.h` files in the `outputs/<dataset_name>` folder. `main.c` also requires the `helper_functions.h` file.

Only two network architectures are included in this generator. To switch between the two, change this line in `generator.py`:
```python
dataset_name = 'mnist' # change this for different networks. can be 'mnist' or 'cifar10'
```

Some networks might be sensitive to input activation quantization. To disable input quantization, change this line in `generator.py`:
```python
quantize_input = dataset_name != 'cifar10' # change this for input quantization. can be True or False
```

If your network is pre-trained, you can disable initial training. Your pre-trained network should be saved as `saves/<dataset_name>/original.pt`. To disable initial training, change this line in `generator.py`:
```python
pre_trained = True # change this if your model is pre-trained. can be True or False
```

If your GPU supports CUDA, you can enable CUDA usage to speed up the process.
The CPU will be used if CUDA is not enabled. To use CUDA, change this line in `generator.py`:
```python
use_cuda = False # change this if your GPU supports CUDA. can be True or False
```

You can also use this generator with networks other than the included **LeNet-5** and **CIFAR-10 test network**. You need to create the network using PyTorch building blocks and adjust the optimizer. Supported PyTorch building blocks are:
* Conv2d
* MaxPool2d
* Linear
* Flatten
* ReLU

You can look at our CIFAR-10 test network implementation as a reference for supported model implementations:
```python
import torch.nn as nn

nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
    nn.ReLU(),

    nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

    nn.Conv2d(32, 16, kernel_size=5, stride=1, padding=2),
    nn.ReLU(),

    nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

    nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
    nn.ReLU(),

    nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

    nn.Flatten(),

    nn.Linear(4*4*32, 10),
)
```

--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision.datasets.mnist import MNIST

import utils
import math
import copy
from pathlib import Path

from nni.compression.pytorch.compressor import PrunerModuleWrapper
from nni.compression.pytorch.compressor import QuantizerModuleWrapper
from nni.algorithms.compression.pytorch.quantization.quantizers import NaiveQuantizer

import torchvision.transforms as transforms

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np

dataset_name = 'mnist' # change this for different networks. can be 'mnist' or 'cifar10'

quantize_input = dataset_name != 'cifar10' # change this for input quantization. can be True or False

pre_trained = True # change this if your model is pre-trained. can be True or False

use_cuda = False # change this if your GPU supports CUDA. can be True or False

dtype = torch.float32

if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

if dataset_name == 'cifar10':
    # Example network on CIFAR10 dataset
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = dset.CIFAR10(root='./data', train=True,
                            download=True, transform=transform)
    data_train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                                    shuffle=True)

    testset = dset.CIFAR10(root='./data', train=False,
                           download=True, transform=transform)
    data_test_loader = torch.utils.data.DataLoader(testset, batch_size=512,
                                                   shuffle=False)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    print_every = 50
    print('using device:', device)

    # Neural network architecture
    # The generator currently supports only Conv2d-ReLU-MaxPool2d groups and Linear layers
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

        nn.Conv2d(32, 16, kernel_size=5, stride=1, padding=2),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

        nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2,
                     stride=2, padding=0),

        nn.Flatten(),

        nn.Linear(4*4*32, 10),
    )

elif dataset_name == 'mnist':
    # Example network (LeNet-5) on MNIST dataset
    data_train = MNIST('./data/mnist',
                       download=True,
                       transform=transforms.Compose([
                           transforms.Resize((32, 32)),
                           transforms.ToTensor()]))

    data_test = MNIST('./data/mnist',
                      train=False,
                      download=True,
                      transform=transforms.Compose([
                          transforms.Resize((32, 32)),
                          transforms.ToTensor()]))

    data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True)
    data_test_loader = DataLoader(data_test, batch_size=1024)


    print_every = 50
    print('using device:', device)


    # LeNet-5 architecture
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Flatten(),
        nn.Linear(5*5*16, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, 10),
    )

Path("./saves/"+dataset_name).mkdir(parents=True, exist_ok=True)
Path("./outputs/"+dataset_name).mkdir(parents=True, exist_ok=True)

best_model = None
initial_acc = None

if not pre_trained:
    best_model, initial_acc = utils.train_model(model, data_train_loader, data_test_loader, 4, optim.Adam(model.parameters(), lr=2e-3), True)
    print(best_model)
    torch.save(best_model, './saves/'+dataset_name+'/original.pt')

# +-+-+-+ BINARY SEARCH FOR OPTIMAL SPARSITY VALUE +-+-+-+
current_model = torch.load("./saves/"+dataset_name+"/original.pt").to(device) # load network.
# make sure to name correctly for pre-trained networks
best_model = current_model

if initial_acc is None:
    initial_acc = utils.evaluate_model(best_model, data_test_loader)

tolerated_acc_loss = 0.01 # manual parameter to tolerate accuracy loss
min_search_step = 0.001 # manual parameter to stop the binary search
step = 0.5
sparsity = 0.5
best_sparsity = 0
while step > min_search_step: # continue until min search step is crossed
    model = copy.deepcopy(current_model)
    utils.level_prune_model(model, [{ 'sparsity': sparsity, 'op_types': ['default'] }])
    _, acc = utils.train_model(model, data_train_loader, data_test_loader, 4, optimizer=optim.Adam(model.parameters(), lr=2e-3))
    step /= 2
    if acc >= initial_acc - tolerated_acc_loss:
        # accuracy is still within tolerance: keep this model and try a higher sparsity
        best_model = model
        best_sparsity = sparsity
        sparsity += step
        print("Current sparsity: " + str(best_sparsity))
    else:
        # too much accuracy lost: try a lower sparsity
        sparsity -= step
final_acc = utils.evaluate_model(best_model, data_test_loader)
result_str = "Initial Accuracy: " + str(initial_acc) + " Sparsity: " + str(best_sparsity) + " Accuracy: " + str(final_acc)
print("Best sparsity found! " + result_str)

torch.save(best_model, './saves/'+dataset_name+'/pruned.pt')

middle_acc = utils.evaluate_model(best_model, data_test_loader)

# copy the pruned weights element by element into a fresh copy of the original network
middle_model = torch.load("./saves/"+dataset_name+"/original.pt").to(device) # load network from original.pt
model_params = list(best_model.parameters())
q_count = 0
with torch.no_grad():
    for param in middle_model.parameters():
        flat_weights = param.flatten().numpy()
        for idx in range(len(flat_weights)):
            if len(param.shape) == 1:
                param[idx].data += torch.from_numpy(np.array(model_params[q_count][idx])).data - param[idx].data
            elif len(param.shape) == 2:
                i_0, i_1 = divmod(idx, param.shape[1])
                param[i_0][i_1].data += torch.from_numpy(np.array(model_params[q_count][i_0][i_1])).data - param[i_0][i_1].data
            elif len(param.shape) == 4:
                i_2, i_3 = divmod(idx, param.shape[3])
                i_1, i_2 = divmod(i_2, param.shape[2])
                i_0, i_1 = divmod(i_1, param.shape[1])
                param[i_0][i_1][i_2][i_3].data += torch.from_numpy(np.array(model_params[q_count][i_0][i_1][i_2][i_3])).data - param[i_0][i_1][i_2][i_3].data
        q_count += 1

# +-+-+-+ 8-BIT WEIGHT QUANTIZATION +-+-+-+
quantized_model = middle_model
print("Accuracy after pruning: "+str(middle_acc))
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': 8,
    'op_types': ['Conv2d', 'Linear']
}]
quantizer = NaiveQuantizer(quantized_model, config_list)
quantizer.compress()
final_acc = utils.evaluate_model(quantized_model, data_test_loader)
print("Accuracy after weight quantization (8-bits): "+str(final_acc))

# +-+-+-+ 8-BIT ACTIVATION QUANTIZATION +-+-+-+
activation_bounds = utils.get_activation_bounds(quantized_model, data_train_loader)
print(activation_bounds)
final_acc = utils.evaluate_quantized_model(quantized_model, data_test_loader, activation_bounds, quantize_input = quantize_input)
print("Accuracy after activation quantization (8-bits): "+str(final_acc))

best_model = quantized_model
layer_parameters = {}
# {
#     'layer_name': {
#         'scale': ...,
#         'bias': [ ... ],
#         'quantized': [ ... ],
#         'indices': [ ... ]
#     },
# }
for name, param in best_model.state_dict().items(): # extract weights and biases from the model
    name = name.replace('.module', '')
    if 'mask' in name or 'old' in name:
        continue
    arr = param.cpu().numpy()
    shape_of_params = arr.shape
    print(shape_of_params)
    param_size = len(arr.flatten())
    print(name)
    layer_name = 'w_'+name.split('.', 1)[0]
    if layer_name not in layer_parameters.keys():
        layer_parameters[layer_name] = {}
        layer_parameters[layer_name]['scale'] = quantizer.layer_scale[layer_name.split('_')[1]].numpy()
    if "bias" in name:
        layer_parameters[layer_name]['bias'] = arr.flatten()
    elif "weight" in name:
        quantized, indices = utils.sparse_matrix_1d(arr.flatten())
        layer_parameters[layer_name]['quantized'] = torch.div(torch.Tensor(np.array(quantized)), torch.Tensor(layer_parameters[layer_name]['scale'])).type(torch.int8).numpy()
        layer_parameters[layer_name]['indices'] = indices

torch.save(best_model, './saves/'+dataset_name+'/quantized.pt')

# weights header and network generator.
# This generates main.h and main.c
weights_file = open('./outputs/'+dataset_name+'/main.h', 'w')
weights_file.write('typedef float data_t;\n\
typedef int8_t quan_t;\n\
typedef uint8_t index_t;\n\n')

c_file = open('./outputs/'+dataset_name+'/main.c', 'w')
# The header names are missing from the first four #include lines below; the generated
# code needs at least <stdio.h> (printf), <stdint.h> (int8_t, uint32_t) and <string.h> (memcpy).
c_file.write('// Initial Accuracy: '+str(initial_acc)+' Final Accuracy: '+str(final_acc)+'\n\
#include \n\
#include \n\
#include \n\
#include \n\
#include "main.h"\n\
#include "helper_functions.h"\n')

c_file.write('\n\
int main()\n\
{\n')

test_vector_batch = 5
test_vector_index_in_batch = 6

test_vector = None

for i, (images, labels) in enumerate(data_test_loader): # sample input data
    if i == test_vector_batch:
        test_vector_index_in_batch = 3
        img = images[test_vector_index_in_batch].numpy()
        test_vector = images

        weights_file.write('const '+('quan_t' if quantize_input else 'data_t')+' test['+str(img.size)+']={')
        data = images[test_vector_index_in_batch].flatten()
        for x in range(len(data)):
            if x != 0:
                weights_file.write(',')
            if quantize_input:
                # map pixel values from [0, 1] to the int8 range [-128, 127]
                weights_file.write(str(int(data[x].item() * 255 - 128)))
            else:
                weights_file.write(str(data[x].item() if data[x] != 0 else 0))
        weights_file.write('};\n')
        break

result = test_vector
input_size = result.shape[1]*result.shape[2]*result.shape[3]
L = np.empty(0) # activation sizes of the input and of every layer output
L = np.append(L, np.uint32(input_size))

previous_padding = None

index = 0

meta_list = list()

c_file.write('\t twoD_t meta_data'+str(index)+' = {\n\
        .r = '+ str(result.shape[2]) +',\n\
        .c = '+str(result.shape[3])+',\n\
        .channel = '+str(result.shape[1])+',\n\
        .scale = '+str(1/255)+',\n\
        .zero_quan = -128,\n\
        .data = buffer'+str(index%2)+',\n\
        .indices = NULL,\n\
        .bias = NULL\n\
    };\n\n')

meta_list.append(('meta_data'+str(index),(0,0,0)))
prev_channel_size = 1
max_kernel_size = 0

for i in best_model:
    result = i(result)
    if isinstance(i, nn.Conv2d):
        previous_padding = i.padding[0]
    if 'ReLU' in str(i):
        continue
    if isinstance(i, nn.MaxPool2d):
        index += 1
        size = prev_channel_size * result.shape[2] * result.shape[3]
        if max_kernel_size < size:
            max_kernel_size = size
        L = np.append(L, result.shape[1]*result.shape[2]*result.shape[3])
        c_file.write('\t twoD_t meta_data'+str(index)+' = {\n\
        .r = '+str(result.shape[2])+',\n\
        .c = '+str(result.shape[3])+',\n\
        .channel = '+str(result.shape[1])+',\n\
        .scale = '+str(activation_bounds[index - 1]['scale'])+',\n\
        .zero_quan = '+str(activation_bounds[index - 1]['zero'])+',\n\
        .data = buffer'+str(index%2)+',\n\
        .indices = NULL,\n\
        .bias = NULL\n\
    };\n')
        meta_list.append(('meta_data'+str(index),(i.stride, i.kernel_size, previous_padding if previous_padding is not None else 0)))
        prev_channel_size = result.shape[1]

    if 'Linear' in str(i):
        index += 1
        L = np.append(L, result.shape[1])
        c_file.write('\t twoD_t meta_data'+str(index)+' = {\n\
        .r = '+str(result.shape[1]) +',\n\
        .c = 1,\n\
        .channel = 1,\n\
        .scale = '+str(activation_bounds[index - 1]['scale'])+',\n\
        .zero_quan = '+str(activation_bounds[index - 1]['zero'])+',\n\
        .data = buffer'+str(index%2)+',\n\
        .indices = NULL,\n\
        .bias = NULL\n\
    };\n\n')
        meta_list.append(('meta_data'+str(index),(0,0,0)))

c_file.write('\n\t memcpy(buffer0, test, sizeof(test));\n')

c_file.write('\n\t printf("---Network starts---\\n");\n')


inx = np.argsort(L)

# the two ping-pong buffers are sized for the largest and second largest activations
weights_file.write('\n\
quan_t buffer'+str(inx[-1]%2)+'['+str(int(L[inx[-1]]))+'];\n\
quan_t buffer'+str((inx[-1]+1)%2)+'['+str(int(L[inx[-2]]))+'];\n\
quan_t w_kernel['+str(max_kernel_size)+'];\n\
\n\
typedef struct twoD\n\
{\n\
    uint32_t r;\n\
    uint32_t c;\n\
    uint32_t in_channel;\n\
    uint32_t channel;\n\
    data_t scale;\n\
    quan_t zero_quan;\n\
    quan_t *data;\n\
    index_t *indices;\n\
    data_t *bias;\n\
} twoD_t;\n\n')

prev_shapes = None
prev_arr_name = None
is_bias_first = None

index = 0

for name, param in best_model.state_dict().items():
    name = name.replace('.module', '')
    if 'mask' in name or 'old' in name:
        continue
    arr = param.cpu().numpy()
    shape_of_params = arr.shape
    print(shape_of_params)
    param_size = len(arr.flatten())
    print(name)
    layer_name = 'w_'+name.split('.', 1)[0]
    if "bias" in name:
        if is_bias_first is None:
            is_bias_first = True
        underscore_arr_name = 'w_'+name.replace('.', '_')
        array_name = underscore_arr_name + '[' + str(param_size) + ']='
        weights_file.write('const data_t '+array_name+'{')
        for x in range(len(layer_parameters[layer_name]['bias'])):
            if x != 0:
                weights_file.write(',')
            weights_file.write(str(layer_parameters[layer_name]['bias'][x] if layer_parameters[layer_name]['bias'][x] != 0 else 0))
        weights_file.write('};\n')
    elif "weight" in name:
        if is_bias_first is None:
            is_bias_first = False
        # Saving 8-bit quantized weights
        underscore_arr_name = 'w_'+name.replace('.', '_')+'_quantized'
        array_name = underscore_arr_name + '[' + str(len(layer_parameters[layer_name]['quantized'])) + ']='
        weights_file.write('const quan_t '+array_name+'{')
        for x in range(len(layer_parameters[layer_name]['quantized'])):
            if x != 0:
                weights_file.write(',')
            weights_file.write(str(layer_parameters[layer_name]['quantized'][x]))
        weights_file.write('};\n')

        # Saving indices of weight values
        underscore_arr_name = 'w_'+name.replace('.', '_')+'_indices'
        array_name = underscore_arr_name + '[' + str(len(layer_parameters[layer_name]['indices'])) + ']='
        weights_file.write('const index_t '+array_name+'{')
        for x in range(len(layer_parameters[layer_name]['indices'])):
            if x != 0:
                weights_file.write(',')
            weights_file.write(str(layer_parameters[layer_name]['indices'][x]))
        weights_file.write('};\n')
        underscore_arr_name = 'w_'+name.replace('.', '_')
    else:
        continue

    if not is_bias_first and len(shape_of_params) == 1:
        index += 1
        if len(prev_shapes) > 2:
            print('Conv layer')
            out_channel = prev_shapes[0]
            in_channel = prev_shapes[1]
            c_file.write('\t conv2D'+\
                '(&'+meta_list[index-1][0]+', &'+prev_arr_name+'_2d, &'+meta_list[index][0]+', &reLU, '+str(meta_list[index][1][0])+', '+\
                str(meta_list[index][1][1])+', '+str(meta_list[index][1][1])+', '+str(meta_list[index][1][2])+');\n')

        else:
            print('linear layer')
            out_channel = 1
            in_channel = 1
            if len(meta_list) == (index+1):
                c_file.write('\t dot'+'(&'+meta_list[index-1][0]+', &'+prev_arr_name+'_2d, &'+meta_list[index][0]+', NULL);\n')
            else:
                c_file.write('\t dot'+'(&'+meta_list[index-1][0]+', &'+prev_arr_name+'_2d, &'+meta_list[index][0]+', &reLU);\n')

        weights_file.write(('const twoD_t '+prev_arr_name+'_2d = {\n'+\
            '\t.r = '+ str(prev_shapes[-1]) +',\n'+\
            '\t.c = '+str(prev_shapes[-2])+',\n'+\
            '\t.in_channel = '+str(in_channel)+',\n'+\
            '\t.channel = '+str(out_channel)+',\n'+\
            '\t.scale = '+str(layer_parameters[layer_name]['scale'])+',\n\t.data = '+prev_arr_name+'_quantized,\n\t.indices = '+prev_arr_name+'_indices,\n'+\
            '\t.bias = '+prev_arr_name+'\n'+\
            '};\n\n'))

    elif is_bias_first and len(shape_of_params) != 1:
        index += 1
        if len(shape_of_params) > 2:
            print('Conv layer')
            out_channel = shape_of_params[0]
            in_channel = shape_of_params[1]
            c_file.write('\t conv2D'\
                +'(&'+meta_list[index-1][0]+', &'+underscore_arr_name+'_2d, &'+meta_list[index][0]+', &reLU, '+str(meta_list[index][1][0])+', '\
                +str(meta_list[index][1][1])+', '+str(meta_list[index][1][1])+', '+str(meta_list[index][1][2])+');\n')
        else:
            print('linear layer')
            out_channel = 1
            in_channel = 1
            if len(meta_list) == (index+1):
                c_file.write('\t dot(&'+meta_list[index-1][0]+', &'+underscore_arr_name+'_2d, &'+meta_list[index][0]+', NULL);\n')
            else:
                c_file.write('\t dot(&'+meta_list[index-1][0]+', &'+underscore_arr_name+'_2d, &'+meta_list[index][0]+', &reLU);\n')

        weights_file.write(('const twoD_t '+underscore_arr_name+'_2d = {\n'+\
            '\t.r = '+ str(shape_of_params[-1]) +',\n'+\
            '\t.c = '+str(shape_of_params[-2])+',\n'+\
            '\t.in_channel = '+str(in_channel)+',\n'+\
            '\t.channel = '+str(out_channel)+',\n'+\
            '\t.scale = '+str(layer_parameters[layer_name]['scale'])+',\n\t.data = '+underscore_arr_name+'_quantized,\n\t.indices = '+underscore_arr_name+'_indices,\n'+\
            '\t.bias = '+prev_arr_name+'\n'+\
            '};\n\n'))

    prev_shapes = shape_of_params
    prev_arr_name = underscore_arr_name

c_file.write('\n\t print_twoD(&'+meta_list[-1][0]+', 0);\n')
c_file.write('\t printf("PREDICTION: %d\\n", get_class(&'+meta_list[-1][0]+'));\n')

class network_partial(nn.Module):
    def __init__(self, original_model):
        super(network_partial, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:-6])

    def forward(self, x):
        x = self.features(x)
        return x

intermediate_network = network_partial(best_model)

weights_file.close()

c_file.write('\t return 0;\n}')

c_file.close()
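
The `*_indices` arrays that `generator.py` writes into `main.h` are read in `helper_functions.h` by accumulating them (`sum_indices += indices[++s]`), which suggests that each entry stores the distance to the previously stored weight rather than an absolute position. The sketch below illustrates such a relative-index encoding under that assumption. The names `encode_relative`, `decode_relative` and `max_gap` are hypothetical and are not part of this repository, and the real `utils.sparse_matrix_1d` may differ in its details.

```python
def encode_relative(dense, max_gap=255):
    """Store only nonzero weights plus the gap to the previously stored entry.

    Hypothetical helper: max_gap=255 assumes the gaps must fit into a uint8
    (index_t). Longer runs of zeros are bridged with filler entries of value 0.
    """
    values, gaps = [], []
    last = 0  # position of the previously stored entry (0 before the first one)
    for pos, weight in enumerate(dense):
        if weight == 0:
            continue
        gap = pos - last
        while gap > max_gap:
            # filler entry: a stored zero that only advances the position
            values.append(0)
            gaps.append(max_gap)
            gap -= max_gap
        values.append(weight)
        gaps.append(gap)
        last = pos
    return values, gaps


def decode_relative(values, gaps, size):
    """Rebuild the dense vector by summing the gaps, as the C loops do."""
    dense = [0] * size
    pos = 0
    for weight, gap in zip(values, gaps):
        pos += gap
        dense[pos] = weight
    return dense


# Round trip on a small vector that contains a run of more than 255 zeros.
dense = [0, 0, 3, 0, -1] + [0] * 300 + [7]
values, gaps = encode_relative(dense)
assert decode_relative(values, gaps, len(dense)) == dense
```

With this layout the first gap doubles as the absolute position of the first stored weight, which matches the way `dot` and `conv2D` initialise `sum_indices` from `indices[0]`.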
--------------------------------------------------------------------------------
/helper_functions.h:
--------------------------------------------------------------------------------
// The header names are missing from the three #include lines below; this file
// needs at least <stdio.h> (printf) and <stdint.h> (uint32_t, int32_t).
#include
#include
#include

#define DEBUG 1U

void print_twoD(twoD_t *tmp, uint32_t channel)
{
#if (DEBUG)
    printf("Size: %uX%uX%u\n", tmp->channel, tmp->r, tmp->c);
    for (uint32_t i = 0; i < tmp->r; i++)
    {
        for (uint32_t j = 0; j < tmp->c; j++)
        {
            printf("%i ", (tmp->data[channel * tmp->c * tmp->r + i * tmp->c + j]));
        }
        printf("\n");
    }
#endif
}

// Choose stride>=c=r to reuse existing buffer again.
void maxPooling(twoD_t *inout, uint32_t r, uint32_t c, uint32_t stride)
{
    data_t max = 0;
    for (uint32_t i = 0; i < inout->r; i += stride)
    {
        for (uint32_t j = 0; j < inout->c; j += stride)
        {
            max = 0;
            for (uint32_t ii = 0; ii < r; ++ii)
            {
                for (uint32_t jj = 0; jj < c; ++jj)
                {
                    if (inout->data[(i + ii) * inout->c + (j + jj)] > max)
                    {
                        max = inout->data[(i + ii) * inout->c + (j + jj)];
                    }
                }
            }
            inout->data[i / stride * inout->r / stride + j / stride] = max;
        }
    }
}

void dot(twoD_t *input, const twoD_t *weights, twoD_t *output, data_t (*activation)(data_t))
{
    if (weights->r != input->r * input->c * input->channel)
    {
        printf("size mismatch\n");
    }
    output->r = weights->c;
    output->c = 1;

    data_t sum;
    uint32_t s = 0;                             // current sparse array index
    uint32_t sum_indices = weights->indices[0]; // sum of sparse index differences
    for (uint32_t i = 0; i < weights->c; ++i)
    {
        sum = 0;
        for (uint32_t j = 0; j < weights->r; ++j)
        {
            const uint32_t index = i * weights->r + j;
            // iterates over sparse indices until it passes the index
            while (sum_indices < index)
            {
                sum_indices += weights->indices[++s];
            }

            // if the weight at the index is nonzero, adds it to the sum
            if (sum_indices == index)
            {
                sum += (input->data[j] - input->zero_quan) * input->scale * weights->data[s];
            }
        }
        // multiplies the sum with the scale of quantized weights
        sum *= weights->scale;
        if (weights->bias)
        {
            sum += weights->bias[i];
        }
        if (activation)
        {
            sum = activation(sum);
        }
        int32_t out = sum / output->scale + output->zero_quan;
        // clips quantized activations into 8-bit bounds
        if (out > 127)
        {
            out = 127;
        }
        else if (out < -128)
        {
            out = -128;
        }
        output->data[i] = out;
    }
}

data_t reLU(data_t a)
{
    return a > 0 ? a : 0;
}

// only works when pooling stride >= r = c
void conv2D(twoD_t *input, const twoD_t *kernel, twoD_t *output, data_t (*activation)(data_t), uint32_t stride, uint32_t r, uint32_t c, uint32_t padding)
{
    uint32_t aux_r = input->r - kernel->r + 1 + 2 * padding;
    uint32_t aux_c = input->c - kernel->c + 1 + 2 * padding;
    data_t aux_sum;
    data_t max_pool;

    output->r = aux_r / stride;
    output->c = aux_c / stride;

    uint32_t s = 0;                                                      // current sparse array index
    uint32_t sum_indices = kernel->indices[0];                           // sum of sparse index differences
    const uint32_t weight_size = kernel->r * kernel->c * input->channel; // size of each weight kernel
    for (uint32_t out_ch = 0; out_ch < output->channel; ++out_ch)
    {
        const uint32_t weight_begin_idx = out_ch * weight_size; // begin index of current weight kernel
        for (uint32_t i = 0; i < weight_size; ++i)
        {
            const uint32_t index = weight_begin_idx + i; // index of weights used in following loops
            // iterates over sparse indices until it passes the index
            while (sum_indices < index)
            {
                sum_indices += kernel->indices[++s];
            }
            // adds the weight value to the
            // array
            if (sum_indices == index)
            {
                w_kernel[i] = kernel->data[s];
            }
            else
            {
                w_kernel[i] = 0;
            }
        }
        for (uint32_t i = 0; i < aux_r; i += stride)
        {
            for (uint32_t j = 0; j < aux_c; j += stride)
            {
                max_pool = 0; /* it is safe because post ReLU values are being compared */
                for (uint32_t i_pool = 0; i_pool < r; ++i_pool)
                {
                    for (uint32_t j_pool = 0; j_pool < c; ++j_pool)
                    {
                        aux_sum = 0;
                        for (uint32_t ii = 0; ii < kernel->r; ++ii)
                        {
                            for (uint32_t jj = 0; jj < kernel->c; ++jj)
                            {
                                for (uint32_t in_ch = 0; in_ch < input->channel; ++in_ch)
                                {
                                    const uint32_t index = in_ch * kernel->r * kernel->c + ii * kernel->c + jj;
                                    if (!w_kernel[index])
                                    {
                                        continue;
                                    }
                                    uint32_t sub_i = i + i_pool + ii;
                                    uint32_t sub_j = j + j_pool + jj;
                                    if ((sub_i < padding) || (sub_j < padding) || (sub_i >= (input->r + padding)) || (sub_j >= (input->c + padding)))
                                    {
                                        continue;
                                    }
                                    // multiplies quantized input activations with their scale value since integer only arithmetic overflows
                                    aux_sum += (input->data[in_ch * input->r * input->c + (sub_i - padding) * input->c + (sub_j - padding)] - input->zero_quan) * input->scale * w_kernel[index];
                                }
                            }
                        }
                        // multiplies the sum with the scale of quantized weights
                        aux_sum *= kernel->scale;
                        if (kernel->bias)
                        {
                            aux_sum += kernel->bias[out_ch];
                        }
                        if (activation)
                        {
                            aux_sum = activation(aux_sum);
                        }
                        if (aux_sum > max_pool)
                        {
                            max_pool = aux_sum;
                        }
                    }
                }
                int32_t out = max_pool / output->scale + output->zero_quan;
                // clips quantized activations into 8-bit bounds
                if (out > 127)
                {
                    out = 127;
                }
                else if (out < -128)
                {
                    out = -128;
                }
                output->data[out_ch * output->r * output->c + i / stride *
                             output->c + j / stride] = out;
            }
        }
    }
}

uint32_t get_class(const twoD_t *output)
{
    uint32_t max_inx = 0;
    for (uint32_t i = 0; i < output->r; ++i)
    {
        if (output->data[i] > output->data[max_inx])
        {
            max_inx = i;
        }
    }
    return max_inx;
}

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.nn.functional as F
import torchvision.datasets as dset

from nni.algorithms.compression.pytorch.pruning import LevelPruner
from nni.compression.pytorch.compressor import QuantizerModuleWrapper
from nni.compression.pytorch.compressor import PrunerModuleWrapper

import numpy as np
import time
import sys

if 0:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def evaluate_model(model, test_loader, process_time=False, print_out=False):
    """
    Find the accuracy of a model on a dataset.

    Inputs:
    - model: A PyTorch Module giving the model to find the accuracy of
    - test_loader: A data loader object to receive test data
    - process_time: (Optional) Boolean to print the evaluation time
    - print_out: (Optional) Boolean to print the accuracy in each iteration

    Returns: The accuracy of the model
    """
    num_correct = 0
    num_samples = 0
    model.eval() # set model to evaluation mode
    with torch.no_grad():
        if process_time and print_out:
            start = time.process_time()
        for x, y in test_loader:
            x = x.to(device=device) # move to device, e.g. GPU
            y = y.to(device=device) # move to device, e.g. GPU
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        if process_time:
            end = time.process_time()
        acc = float(num_correct) / num_samples
        if print_out:
            print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
        if process_time and print_out:
            print("Time elapsed: " + str(end - start))
    return acc

def train_model(model, train_loader, validate_loader, epochs, optimizer, print_out=False):
    """
    Train a model on the given dataset using the PyTorch Module API.

    Inputs:
    - model: A PyTorch Module giving the model to train
    - train_loader: A data loader object to receive train data
    - validate_loader: A data loader object to receive validate data
    - epochs: A Python integer giving the number of epochs to train for
    - optimizer: An Optimizer object we will use to train the model
    - print_out: (Optional) Boolean to print the accuracy in each iteration

    Returns: Best model and its accuracy
    """
    best_acc = -1
    best_model = None
    model = model.to(device=device)
    for e in range(epochs):
        for t, (x, y) in enumerate(train_loader):
            model.train() # put model to training mode
            x = x.to(device=device) # move to device, e.g. GPU
            y = y.to(device=device)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
92 | optimizer.step() 93 | 94 | if t % 50 == 0: 95 | if print_out: 96 | print('Iteration %d, loss = %.4f' % (t, loss.item())) 97 | acc = evaluate_model(model, validate_loader) 98 | if acc > best_acc: 99 | best_acc = acc 100 | best_model = model # reference to the live model, not a snapshot of its weights 101 | if print_out: 102 | print() 103 | if print_out: 104 | print('Best accuracy found', best_acc) 105 | return best_model, best_acc 106 | 107 | def get_activation_bounds(model, train_loader): 108 | """ 109 | Find the activation bounds of a network on the training data. 110 | The activation bounds are the minimum and maximum values a node 111 | in the network can have. 112 | 113 | Inputs: 114 | - model: A PyTorch Module giving the model whose activations are measured 115 | - train_loader: A data loader object to receive training data 116 | 117 | Returns: Activation bounds, scale values, and zero indices for each layer 118 | """ 119 | bounds = {} 120 | # { 121 | # 'output_order': { 122 | # 'min': ..., 123 | # 'max': ..., 124 | # 'scale': ..., 125 | # 'zero': ... 126 | # }, 127 | # } 128 | model.eval() 129 | with torch.no_grad(): 130 | children = list(model.children()) # get all layers 131 | for x, y in train_loader: 132 | x = x.to(device=device) 133 | y = y.to(device=device) 134 | output = x 135 | order = 0 136 | for i in range(len(children)): 137 | output = children[i](output) 138 | # find the max and min activations for each layer 139 | if not isinstance(children[i], nn.Flatten) and (i + 1 == len(children) or isinstance(children[i + 1], QuantizerModuleWrapper) or isinstance(children[i + 1], nn.Conv2d) or isinstance(children[i + 1], nn.Linear) or isinstance(children[i + 1], nn.Flatten)): 140 | if order not in bounds.keys(): 141 | bounds[order] = { 'min': sys.float_info.max, 'max': -sys.float_info.max, 'scale': 0, 'zero': 0 } # sys.float_info.min is the smallest positive float, so it cannot seed a running maximum 142 | if torch.max(output).item() > bounds[order]['max']: 143 | bounds[order]['max'] = torch.max(output).item() 144 | if torch.min(output).item() < bounds[order]['min']: 145 | bounds[order]['min'] = torch.min(output).item() 146 | order +=
1 147 | for idx in bounds: # find the scale and the index of zero for each layer 148 | bounds[idx]['scale'] = float((bounds[idx]['max'] - bounds[idx]['min']) / 255) 149 | bounds[idx]['zero'] = int(-(bounds[idx]['min'] / bounds[idx]['scale']) - 128) 150 | return bounds 151 | 152 | def evaluate_quantized_model(model, test_loader, activation_bounds, quantize_input=True): 153 | """ 154 | Find the accuracy of a quantized model on a test dataset. Contrary to the 155 | evaluate_model function, weights and activations are quantized to evaluate 156 | the model accurately. 157 | 158 | Inputs: 159 | - model: A PyTorch Module giving the model to find the accuracy of 160 | - test_loader: A data loader object to receive test data 161 | - activation_bounds: Activation bounds of the given model 162 | - quantize_input: (Optional) Boolean to quantize input activations 163 | 164 | Returns: The accuracy of the model 165 | """ 166 | num_correct = 0 167 | num_samples = 0 168 | model.eval() # set model to evaluation mode 169 | with torch.no_grad(): 170 | children = list(model.children()) # get all layers 171 | for x, y in test_loader: 172 | x = x.to(device=device) # move to device, e.g. GPU 173 | y = y.to(device=device) # move to device, e.g. 
GPU 174 | scores = x 175 | if quantize_input: # quantize input activations 176 | scores = torch.clamp(torch.add(torch.mul(x, 255), -128), -128, 127) 177 | scores = torch.floor(scores) 178 | scores = torch.div(torch.sub(scores, -128), 255) 179 | order = 0 180 | for i in range(len(children)): 181 | scores = children[i](scores) 182 | # quantize output activations of the layer 183 | if not isinstance(children[i], nn.Flatten) and (i + 1 == len(children) or isinstance(children[i + 1], QuantizerModuleWrapper) or isinstance(children[i + 1], nn.Conv2d) or isinstance(children[i + 1], nn.Linear) or isinstance(children[i + 1], nn.Flatten)): 184 | scores = torch.clamp(torch.add(torch.div(scores, activation_bounds[order]['scale']), activation_bounds[order]['zero']), -128, 127) 185 | scores = torch.floor(scores) 186 | scores = torch.mul(torch.sub(scores, activation_bounds[order]['zero']), activation_bounds[order]['scale']) 187 | order += 1 188 | _, preds = scores.max(1) 189 | num_correct += (preds == y).sum() 190 | num_samples += preds.size(0) 191 | acc = float(num_correct) / num_samples 192 | return acc 193 | 194 | def level_prune_model(model, config_list): 195 | """ 196 | Prune a model with the given configuration using the Neural Network Intelligence's (NNI) 197 | LevelPruner tool. 198 | 199 | Inputs: 200 | - model: A PyTorch Module giving the model to prune 201 | - config_list: A configuration object for the LevelPruner 202 | 203 | Returns: Nothing 204 | """ 205 | pruner = LevelPruner(model, config_list) 206 | pruner.compress() 207 | 208 | def sparse_matrix_1d(array, max_bit_size=8): 209 | """ 210 | Converts an array to compressed sparse column (CSC) format. 
211 | 212 | Inputs: 213 | - array: An array to be converted to CSC 214 | - max_bit_size: (Optional) An integer as the maximum bit width of a stored index 215 | 216 | Returns: A list of values and a list of relative indices (distance from the previously stored entry) 217 | """ 218 | values = [] 219 | indices = [] 220 | index_diff = 0 221 | for x in array: 222 | if x != 0: 223 | values.append(x) 224 | indices.append(index_diff) 225 | index_diff = 0 226 | if index_diff == 2 ** max_bit_size - 1: 227 | values.append(0) 228 | indices.append(index_diff) 229 | index_diff = 0 230 | index_diff += 1 231 | return values, indices --------------------------------------------------------------------------------
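The relative-index encoding produced by `sparse_matrix_1d` can be checked with a round trip. The sketch below is illustrative only and is not part of the repository: `encode_relative` restates the loop from `utils.py` so the block is self-contained, and `decode_relative` is a hypothetical inverse written for this example. It assumes the caller knows the original array length, since trailing zeros are not stored.

```python
def encode_relative(array, max_bit_size=8):
    """Restates the loop of utils.sparse_matrix_1d (copied here for self-containment)."""
    values, indices = [], []
    index_diff = 0
    for x in array:
        if x != 0:
            values.append(x)
            indices.append(index_diff)
            index_diff = 0
        # A gap that no longer fits in max_bit_size bits is bridged by a filler zero.
        if index_diff == 2 ** max_bit_size - 1:
            values.append(0)
            indices.append(index_diff)
            index_diff = 0
        index_diff += 1
    return values, indices


def decode_relative(values, indices, length):
    """Hypothetical inverse: rebuild the dense list from values and relative indices."""
    dense = [0] * length
    position = 0
    for value, step in zip(values, indices):
        position += step          # each index is the distance from the previous entry
        dense[position] = value   # filler entries simply write a zero
    return dense


if __name__ == "__main__":
    original = [0, 0, 0, 0, 0, 9]
    values, indices = encode_relative(original, max_bit_size=2)
    print(values, indices)
    print(decode_relative(values, indices, len(original)) == original)
```

With a small `max_bit_size`, the filler entry becomes visible: a run of zeros longer than the largest representable gap is split by a stored zero, which is what keeps every index within the chosen bit width.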