├── GSC_v2_data.py ├── GTSRB_data.py ├── README.md ├── baseline_execute.py ├── cifar10 ├── __init__.py ├── cifar10.data-00000-of-00001 ├── cifar10.index ├── cifar10.meta ├── cifar10_fisher.npy ├── cifar10_network_fisher.npy ├── cifar10_network_weight.npy ├── cifar10_weight.npy └── pintle.py ├── cifar10_data.py ├── download_dataset.sh ├── gsc ├── __init__.py ├── gsc.data-00000-of-00001 ├── gsc.index ├── gsc.meta ├── gsc_fisher.npy ├── gsc_network_fisher.npy ├── gsc_network_weight.npy ├── gsc_weight.npy └── pintle.py ├── gtsrb ├── __init__.py ├── gtsrb.data-00000-of-00001 ├── gtsrb.index ├── gtsrb.meta ├── gtsrb_fisher.npy ├── gtsrb_network_fisher.npy ├── gtsrb_network_weight.npy ├── gtsrb_weight.npy └── pintle.py ├── in-memory_execute.py ├── joint_optimization.sh ├── mnist ├── __init__.py ├── mnist.data-00000-of-00001 ├── mnist.index ├── mnist.meta ├── mnist_fisher.npy ├── mnist_network_fisher.npy ├── mnist_network_weight.npy ├── mnist_weight.npy └── pintle.py ├── mnist_data.py ├── sequential_optimization.sh ├── svhn ├── __init__.py ├── pintle.py ├── svhn.data-00000-of-00001 ├── svhn.index ├── svhn.meta ├── svhn_fisher.npy ├── svhn_network_fisher.npy ├── svhn_network_weight.npy └── svhn_weight.npy ├── svhn_data.py ├── tf_operation.so ├── tf_operation ├── build_tf_operation.sh ├── build_tf_operation_nano.sh ├── build_weight_loader.sh ├── build_weight_loader_nano.sh ├── tf_operation.cc ├── tf_operation.cu ├── tf_operation.cu.o ├── tf_operation.so ├── weight_loader.c ├── weight_loader.cu ├── weight_loader.cu.o └── weight_loader.so ├── weight_loader.so └── weight_virtualization.py /GSC_v2_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | import struct 5 | 6 | train_path = '...' 7 | test_path = '...' 8 | validation_path = '...' 
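# NOTE: train_path/test_path/validation_path above are left as '...' placeholders on purpose;
# they are only needed when regenerating the .npy files via create_data_files(). Normal use
# relies on the prebuilt GSC_v2_*.npy files fetched by download_dataset.sh, loaded just below.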
9 | 10 | #""" 11 | GSC_v2_train_data = np.load('GSC_v2_train_data.npy') 12 | GSC_v2_train_label = np.load('GSC_v2_train_label.npy') 13 | 14 | GSC_v2_test_data = np.load('GSC_v2_test_data.npy') 15 | GSC_v2_test_label = np.load('GSC_v2_test_label.npy') 16 | 17 | GSC_v2_validation_data = np.load('GSC_v2_validation_data.npy') 18 | GSC_v2_validation_label = np.load('GSC_v2_validation_label.npy') 19 | #""" 20 | 21 | def train_set(): 22 | return GSC_v2_train_data, GSC_v2_train_label 23 | 24 | def test_set(): 25 | return GSC_v2_test_data, GSC_v2_test_label 26 | 27 | def validation_set(): 28 | return GSC_v2_validation_data, GSC_v2_validation_label 29 | 30 | def get_GSC_v2(filename): 31 | file = open(filename, "rb") 32 | file_content = file.read() 33 | file.close() 34 | 35 | (num_frame, num_filter) = struct.unpack("ii", file_content[:8]) 36 | 37 | GSC_v2_tuple = struct.unpack("f" * num_frame * num_filter, file_content[8:8+num_frame*num_filter*4]) 38 | GSC_v2 = np.asarray(GSC_v2_tuple).reshape((num_frame,num_filter)) 39 | 40 | crop_start_filter = 1 41 | crop_end_filter = 13 42 | GSC_v2_cropped = GSC_v2[:,crop_start_filter:crop_start_filter+crop_end_filter] 43 | 44 | GSC_v2_resized = GSC_v2_cropped.reshape((num_frame*(crop_start_filter+crop_end_filter-1))) 45 | 46 | normalized_GSC_v2 = GSC_v2_resized / np.linalg.norm(GSC_v2_resized) 47 | 48 | return normalized_GSC_v2 49 | 50 | def get_GSC_v2_batch(dir_path): 51 | GSC_v2_batch = [] 52 | labels = [] 53 | label = 0 54 | 55 | for (dirpath, dirnames, filenames) in os.walk(dir_path): 56 | if os.path.basename(dirpath) == '': 57 | continue 58 | 59 | for (dirpath2, dirnames2, filenames2) in os.walk(dirpath): 60 | for file in sorted(filenames2): 61 | filepath = os.path.join(dirpath2, file) 62 | GSC_v2_batch.append(get_GSC_v2(filepath)) 63 | labels.append(label) 64 | 65 | label += 1 66 | 67 | GSC_v2_batch = np.vstack((GSC_v2_batch)) 68 | labels_one_hot = np.zeros((GSC_v2_batch.shape[0], label)) 69 | labels_one_hot[np.arange(GSC_v2_batch.shape[0]), labels] = 1 70 | 71 | return GSC_v2_batch, labels_one_hot 72 | 73 | def create_data_files(): 74 | GSC_v2_train_data, GSC_v2_train_label = get_GSC_v2_batch(train_path) 75 | GSC_v2_test_data, GSC_v2_test_label = get_GSC_v2_batch(test_path) 76 | GSC_v2_validation_data, GSC_v2_validation_label = get_GSC_v2_batch(validation_path) 77 | 78 | np.save('GSC_v2_train_data', GSC_v2_train_data) 79 | np.save('GSC_v2_train_label', GSC_v2_train_label) 80 | np.save('GSC_v2_test_data', GSC_v2_test_data) 81 | np.save('GSC_v2_test_label', GSC_v2_test_label) 82 | np.save('GSC_v2_validation_data', GSC_v2_validation_data) 83 | np.save('GSC_v2_validation_label', GSC_v2_validation_label) 84 | 85 | def main(): 86 | #create_data_files() 87 | print(train_set()[0]) 88 | print(train_set()[0].shape) 89 | print(train_set()[1]) 90 | print(train_set()[1].shape) 91 | 92 | if __name__ == '__main__': 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /GTSRB_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import numpy as np 4 | import os 5 | 6 | train_path = '...' 7 | test_path = '...' 
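# NOTE: train_path/test_path above are '...' placeholders, needed only when rebuilding the
# .npy files with create_data_files(); the prebuilt GTSRB_*.npy files from download_dataset.sh
# are loaded just below.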
8 | 9 | #""" 10 | GTSRB_train_data = np.load('GTSRB_train_data.npy') 11 | GTSRB_train_label = np.load('GTSRB_train_label.npy') 12 | 13 | GTSRB_test_data = np.load('GTSRB_test_data.npy') 14 | GTSRB_test_label = np.load('GTSRB_test_label.npy') 15 | #""" 16 | 17 | def train_set(): 18 | return GTSRB_train_data, GTSRB_train_label 19 | 20 | def test_set(): 21 | return GTSRB_test_data, GTSRB_test_label 22 | 23 | def validation_set(): 24 | return None, None 25 | 26 | def get_GTSRB_test_batch(dir_path): 27 | GTSRB_batch = [] 28 | labels = [] 29 | 30 | csv_filepath = os.path.join(dir_path, 'GT-final_test.csv') 31 | if not os.path.exists(csv_filepath): 32 | return 33 | 34 | print(csv_filepath) 35 | 36 | with open(csv_filepath) as f: 37 | for line in f: 38 | lists = line.split(';') 39 | img_filepath = os.path.join(dir_path, lists[0]) 40 | new_img_filepath = os.path.join(dir_path, os.path.basename(img_filepath)) 41 | new_img_filepath = os.path.splitext(new_img_filepath)[0] + '.png' 42 | if not os.path.exists(new_img_filepath): 43 | continue 44 | 45 | img = Image.open(new_img_filepath) 46 | GTSRB = np.array(img) 47 | GTSRB = GTSRB.reshape((np.size(GTSRB))) 48 | normalized_GTSRB = GTSRB / np.linalg.norm(GTSRB) 49 | label = int(lists[7]) 50 | GTSRB_batch.append(normalized_GTSRB) 51 | labels.append(label) 52 | 53 | GTSRB_batch = np.stack((GTSRB_batch)) 54 | print(GTSRB_batch.shape) 55 | 56 | labels_one_hot = np.zeros((GTSRB_batch.shape[0], np.max(labels)+1)) 57 | labels_one_hot[np.arange(GTSRB_batch.shape[0]), labels] = 1 58 | print(labels_one_hot.shape) 59 | 60 | return GTSRB_batch, labels_one_hot 61 | 62 | def get_GTSRB_train_batch(dir_path): 63 | GTSRB_batch = [] 64 | labels = [] 65 | label = 0 66 | 67 | for (dirpath, dirnames, filenames) in sorted(os.walk(dir_path)): 68 | if os.path.basename(dirpath) == '': 69 | continue 70 | 71 | for (dirpath2, dirnames2, filenames2) in sorted(os.walk(dirpath)): 72 | for file in sorted(filenames2): 73 | filepath = os.path.join(dirpath2, file) 74 | img = Image.open(filepath) 75 | GTSRB = np.array(img) 76 | GTSRB = GTSRB.reshape((np.size(GTSRB))) 77 | normalized_GTSRB = GTSRB / np.linalg.norm(GTSRB) 78 | GTSRB_batch.append(normalized_GTSRB) 79 | labels.append(label) 80 | 81 | label += 1 82 | 83 | GTSRB_batch = np.vstack((GTSRB_batch)) 84 | labels_one_hot = np.zeros((GTSRB_batch.shape[0], label)) 85 | labels_one_hot[np.arange(GTSRB_batch.shape[0]), labels] = 1 86 | 87 | print(labels_one_hot.shape) 88 | 89 | return GTSRB_batch, labels_one_hot 90 | 91 | def create_data_files(): 92 | GTSRB_train_data, GTSRB_train_label = get_GTSRB_train_batch(train_path) 93 | GTSRB_test_data, GTSRB_test_label = get_GTSRB_test_batch(test_path) 94 | 95 | np.save('GTSRB_train_data', GTSRB_train_data) 96 | np.save('GTSRB_train_label', GTSRB_train_label) 97 | np.save('GTSRB_test_data', GTSRB_test_data) 98 | np.save('GTSRB_test_label', GTSRB_test_label) 99 | 100 | def main(): 101 | #create_data_files() 102 | print(train_set()[0]) 103 | print(train_set()[0].shape) 104 | print(train_set()[1]) 105 | print(train_set()[1].shape) 106 | 107 | if __name__ == '__main__': 108 | main() 109 | 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [MobiSys 2020] Fast and Scalable In-memory Deep Multitask Learning via Neural Weight Virtualization 2 | 3 | ## Introduction 4 | This is an open-source repository of the [MobiSys 2020](https://www.sigmobile.org/mobisys/2020/) paper titled 
"***Fast and Scalable In-memory Deep Multitask Learning via Neural Weight Virtualization***". It enables fast and scalable in-memory multitask deep learning on memory-constrained embedded systems by (1) packing multiple deep neural networks (DNNs) into a fixed-sized main memory whose combined memory requirement is larger than the main memory, and (2) enabling fast in-memory execution of the DNNs. 5 | 6 | This repository implements (1) *virtualization of weight parameters* of multiple heterogeneous DNNs of arbitrary network architectures, and (2) *in-memory execution and context-switching* of deep neural network (DNN) tasks. For the user's convenience, we provide a step-by-step guideline for the weight virtualization and in-memory execution of the five DNN that are used for the multitask learning IoT device, one of the application systems we implemented in the paper. The sizes of those DNNs are small, so the entire process of weight virtualization can be easily demonstrated in a reasonable time without requiring to spend several days. The five DNNs to be virtualized are [MNIST](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf), [GoogleSpeechCommands (GSC)](https://arxiv.org/abs/1804.03209), [German Traffic Sign Recognition Benchmark (GTSRB)](https://www.ini.rub.de/upload/file/1470692848_f03494010c16c36bab9e/StallkampEtAl_GTSRB_IJCNN2011.pdf), [CIFAR-10](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf), and [Street View House Numbers (SVHN)](http://ufldl.stanford.edu/housenumbers/nips2011_housenumbers.pdf). It shows an example of weight virtualization, which can be easily applied to any other DNNs, i.e., they can be virtualized in the same way presented here. 7 | 8 |   9 | ## Software Install and Setup 10 | Neural weight virtualization is implemented by using Python, TensorFlow, and NVIDIA CUDA (custom TensorFlow operation). The TensorFlow version should be lower than or equal to 1.13.2; the latest version (1.14) seems to have a problem of executing custom operations. We used Python 2.7, Tensorflow 1.13.1, and CUDA 10.0. A GPU is required to perform the weight virtualization (i.e., weight-page matching and optimization) as well as in-memory execution of GPU RAM. We used NVIDIA RTX 20280 Ti GPU. 11 | 12 | **Step 1.** Install [Python (>= 2.7)](https://www.python.org/downloads/). 13 | 14 | **Step 2.** Install [Tensorflow (<= 1.13.2)](https://www.tensorflow.org/). 15 | 16 | **Step 3.** Install [NVIDIA CUDA (>= 10.0)](https://developer.nvidia.com/cuda-downloads). 17 | 18 | **Step 4.** Clone this NeuralWeightVirtualization repository. 19 | ```sh 20 | $ git clone https://github.com/multitaskdnn-mobisys2020/NeuralWeightVirtualization.git 21 | Cloning into 'NeuralWeightVirtualization'... 22 | remote: Enumerating objects: 225, done. 23 | remote: Counting objects: 100% (225/225), done. 24 | remote: Compressing objects: 100% (178/178), done. 25 | remote: Total 225 (delta 90), reused 164 (delta 42), pack-reused 0 26 | Receiving objects: 100% (225/225), 11.81 MiB | 15.90 MiB/s, done. 27 | Resolving deltas: 100% (90/90), done. 28 | ``` 29 | 30 | ## 1) Download Datasets (Preliminary Step 1) 31 | Before getting into the weight virtualization, download the five datasets by executing the following script (*download_dataset.sh*). The script uses [curl](https://curl.haxx.se/download.html) for downloading the datasets. 32 | ```sh 33 | $ ./download_dataset.sh 34 | [1/4] Downloading CIFAR10 dataset... 
35 | % Total % Received % Xferd Average Speed Time Time Time Current 36 | Dload Upload Total Spent Left Speed 37 | 100 388 0 388 0 0 2604 0 --:--:-- --:--:-- --:--:-- 2604 38 | 100 234M 0 234M 0 0 63.2M 0 --:--:-- 0:00:03 --:--:-- 71.2M 39 | % Total % Received % Xferd Average Speed Time Time Time Current 40 | Dload Upload Total Spent Left Speed 41 | 100 388 0 388 0 0 937 0 --:--:-- --:--:-- --:--:-- 934 42 | 100 781k 100 781k 0 0 1260k 0 --:--:-- --:--:-- --:--:-- 1260k 43 | % Total % Received % Xferd Average Speed Time Time Time Current 44 | Dload Upload Total Spent Left Speed 45 | 100 388 0 388 0 0 2675 0 --:--:-- --:--:-- --:--:-- 2675 46 | 100 1171M 0 1171M 0 0 67.8M 0 --:--:-- 0:00:17 --:--:-- 46.5M 47 | % Total % Received % Xferd Average Speed Time Time Time Current 48 | Dload Upload Total Spent Left Speed 49 | 100 388 0 388 0 0 665 0 --:--:-- --:--:-- --:--:-- 664 50 | 100 3906k 0 3906k 0 0 4706k 0 --:--:-- --:--:-- --:--:-- 4706k 51 | ... 52 | ... 53 | ... 54 | [4/4] Downloading SVHN dataset... 55 | % Total % Received % Xferd Average Speed Time Time Time Current 56 | Dload Upload Total Spent Left Speed 57 | 100 388 0 388 0 0 2791 0 --:--:-- --:--:-- --:--:-- 2791 58 | 100 305M 0 305M 0 0 79.1M 0 --:--:-- 0:00:03 --:--:-- 90.4M 59 | % Total % Received % Xferd Average Speed Time Time Time Current 60 | Dload Upload Total Spent Left Speed 61 | 100 388 0 388 0 0 885 0 --:--:-- --:--:-- --:--:-- 885 62 | 100 1017k 100 1017k 0 0 1473k 0 --:--:-- --:--:-- --:--:-- 1473k 63 | % Total % Received % Xferd Average Speed Time Time Time Current 64 | Dload Upload Total Spent Left Speed 65 | 100 388 0 388 0 0 2639 0 --:--:-- --:--:-- --:--:-- 2621 66 | 100 772M 0 772M 0 0 84.7M 0 --:--:-- 0:00:09 --:--:-- 93.4M 67 | % Total % Received % Xferd Average Speed Time Time Time Current 68 | Dload Upload Total Spent Left Speed 69 | 100 388 0 388 0 0 958 0 --:--:-- --:--:-- --:--:-- 955 70 | 100 2575k 0 2575k 0 0 3974k 0 --:--:-- --:--:-- --:--:-- 3974k 71 | % Total % Received % Xferd Average Speed Time Time Time Current 72 | Dload Upload Total Spent Left Speed 73 | 100 388 0 388 0 0 174 0 --:--:-- 0:00:02 --:--:-- 174 74 | 100 85.8M 0 85.8M 0 0 26.1M 0 --:--:-- 0:00:03 --:--:-- 96.3M 75 | % Total % Received % Xferd Average Speed Time Time Time Current 76 | Dload Upload Total Spent Left Speed 77 | 100 388 0 388 0 0 2380 0 --:--:-- --:--:-- --:--:-- 2380 78 | 100 286k 100 286k 0 0 788k 0 --:--:-- --:--:-- --:--:-- 788k 79 | ``` 80 | 81 | ## 2) Prepare and Train DNN Models (Preliminary Step 2) 82 | The next preliminary step is to obtain and train DNN models for the five datasets. For the user's convenience, we include pre-trained models of the five DNNs in this repository. They are located in separate folders. 83 | ```sh 84 | $ ls -d mnist gsc gtsrb cifar10 svhn 85 | cifar10 gsc gtsrb mnist svhn 86 | ``` 87 | 88 | The number of weight parameters, memory usage, and inference accuracy of each DNN model are shown in the below table, which is the same as in the paper. In short, a total of ***268,692 weight parameters (1,046 KB)*** are virtualized into ***66,475 weights (258 KB)***, achieving ***4x*** of packing efficiency. 
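As a quick sanity check of these totals before the table (a minimal sketch, not part of the repository; it assumes 4-byte float32 weights, which is roughly consistent with the KB figures):

```python
# Weight counts from the table below: MNIST, GSC, GTSRB, CIFAR-10, SVHN
weights = [45706, 65531, 66475, 45490, 45490]
total = sum(weights)               # 268,692 weights across the five DNNs
virtualized = 66475                # weights after virtualization (same count as GTSRB, the largest DNN)
print(total * 4 / 1024.0)          # ~1050 KB at 4 bytes/weight (reported as 1,046 KB)
print(virtualized * 4 / 1024.0)    # ~260 KB (reported as 258 KB)
print(float(total) / virtualized)  # ~4.04 -> the ~4x packing efficiency
```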
89 | 90 | | DNN | Number of weights | Memory (KB) | Accuracy (%) | 91 | | :-------------: | -------------: | -------------: | -------------: | 92 | | MNIST | 45,706 | 178 | 98.33 | 93 | | GSC | 65,531 | 256 | 69.86 | 94 | | GTSRB | 66,475 | 258 | 92.74 | 95 | | CIFAR-10 | 45,490 | 176 | 55.68 | 96 | | SVHN | 45,490 | 176 | 81.55 | 97 | | **Total** | ***268,692*** | ***1,046*** | - | 98 | | **Virtualized** | ***66,475*** | ***258*** | - | 99 | 100 | ## 3) Weight Virtualization Step 1: Weight-Page Matching 101 | The first step of weight virtualization is the weight-page matching, which is performed by a Python script (*weight_virtualization.py*). It first computes Fisher information of the DNN and then performs weight-page matching as described in the paper. Each DNN performs the weight-page matching one by one. 102 | 103 | Perform weight-page matching for the first DNN (MNIST) with the following Python script. 104 | ```sh 105 | $ python weight_virtualization.py -mode=a -network_path=mnist 106 | init new weight pages 107 | add_vnn 108 | mnist/mnist_network_weight.npy 109 | compute_fisher 110 | Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes. 111 | Extracting MNIST_data/train-images-idx3-ubyte.gz 112 | Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes. 113 | Extracting MNIST_data/train-labels-idx1-ubyte.gz 114 | Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes. 115 | Extracting MNIST_data/t10k-images-idx3-ubyte.gz 116 | Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes. 117 | Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 118 | do_compute_fisher 119 | sample num: 0, data_idx: 40422 120 | sample num: 1, data_idx: 43444 121 | sample num: 2, data_idx: 23402 122 | ... 123 | ... 124 | ... 125 | sample num: 97, data_idx: 40313 126 | sample num: 98, data_idx: 21500 127 | sample num: 99, data_idx: 3595 128 | mnist/mnist_network_fisher.npy 129 | [calculate_cost] 130 | toal_cost: 0.0 131 | 458 pages allocated for 45706 weights 132 | total_network_cost: 0 133 | ``` 134 | 135 | Perform weight-page matching for the second (GSC) with the following Python script. 136 | ```sh 137 | $ python weight_virtualization.py -mode=a -network_path=gsc 138 | add_vnn 139 | gsc/gsc_network_weight.npy 140 | compute_fisher 141 | do_compute_fisher 142 | sample num: 0, data_idx: 19860 143 | sample num: 1, data_idx: 51449 144 | sample num: 2, data_idx: 30773 145 | ... 146 | ... 147 | ... 148 | sample num: 97, data_idx: 41594 149 | sample num: 98, data_idx: 30133 150 | sample num: 99, data_idx: 44799 151 | gsc/gsc_network_fisher.npy 152 | [match_page_by_cost] 153 | occupation: 0 154 | len(page_list): 207 155 | len(network_page_list): 656 156 | 0-th page 157 | 206-th page 158 | cost: 0.0 159 | 160 | occupation: 1 161 | len(page_list): 458 162 | len(network_page_list): 449 163 | 0-th page 164 | 448-th page 165 | cost: 0.03946808 166 | 167 | assing_page 146.488 ms 168 | [calculate_cost] 169 | toal_cost: 0.0394762507876294 170 | 656 pages allocated for 65531 weights 171 | total_network_cost: 0.15915566682815552 172 | ``` 173 | 174 | Perform weight-page matching for the third (GTSRB) with the following Python script. 175 | ```sh 176 | $ python weight_virtualization.py -mode=a -network_path=gtsrb 177 | add_vnn 178 | gtsrb/gtsrb_network_weight.npy 179 | compute_fisher 180 | do_compute_fisher 181 | sample num: 0, data_idx: 23099 182 | sample num: 1, data_idx: 22485 183 | sample num: 2, data_idx: 15947 184 | ... 185 | ... 186 | ... 
187 | sample num: 97, data_idx: 6798 188 | sample num: 98, data_idx: 9251 189 | sample num: 99, data_idx: 18952 190 | gtsrb/gtsrb_network_fisher.npy 191 | [match_page_by_cost] 192 | occupation: 0 193 | len(page_list): 0 194 | len(network_page_list): 665 195 | cost: 0 196 | 197 | occupation: 1 198 | len(page_list): 216 199 | len(network_page_list): 665 200 | 0-th page 201 | 215-th page 202 | cost: 1.4526434 203 | 204 | occupation: 2 205 | len(page_list): 449 206 | len(network_page_list): 449 207 | 0-th page 208 | 448-th page 209 | cost: 0.047510564 210 | 211 | assing_page 150.184 ms 212 | [calculate_cost] 213 | toal_cost: 1.5001573274303155 214 | 665 pages allocated for 66475 weights 215 | total_network_cost: 6.379258215427399 216 | ``` 217 | 218 | Perform weight-page matching for the fourth (CIFAR-10) with the following Python script. 219 | ```sh 220 | $ python weight_virtualization.py -mode=a -network_path=cifar10 221 | add_vnn 222 | cifar10/cifar10_network_weight.npy 223 | compute_fisher 224 | do_compute_fisher 225 | sample num: 0, data_idx: 30796 226 | sample num: 1, data_idx: 44166 227 | sample num: 2, data_idx: 2649 228 | ... 229 | ... 230 | ... 231 | sample num: 97, data_idx: 6889 232 | sample num: 98, data_idx: 36036 233 | sample num: 99, data_idx: 1621 234 | cifar10/cifar10_network_fisher.npy 235 | [match_page_by_cost] 236 | occupation: 0 237 | len(page_list): 0 238 | len(network_page_list): 455 239 | cost: 0 240 | 241 | occupation: 1 242 | len(page_list): 0 243 | len(network_page_list): 455 244 | cost: 0 245 | 246 | occupation: 2 247 | len(page_list): 216 248 | len(network_page_list): 455 249 | 0-th page 250 | 215-th page 251 | cost: 13.863106 252 | 253 | occupation: 3 254 | len(page_list): 449 255 | len(network_page_list): 239 256 | 0-th page 257 | 238-th page 258 | cost: 0.0028860972 259 | 260 | assing_page 134.211 ms 261 | [calculate_cost] 262 | toal_cost: 13.865990731732381 263 | 455 pages allocated for 45490 weights 264 | total_network_cost: 71.27165424823761 265 | ``` 266 | 267 | Perform weight-page matching for the fifth (SVHN) with the following Python script. 268 | ```sh 269 | $ python weight_virtualization.py -mode=a -network_path=svhn 270 | add_vnn 271 | svhn/svhn_network_weight.npy 272 | compute_fisher 273 | do_compute_fisher 274 | sample num: 0, data_idx: 51356 275 | sample num: 1, data_idx: 47162 276 | sample num: 2, data_idx: 624 277 | ... 278 | ... 279 | ... 
280 | sample num: 97, data_idx: 46074 281 | sample num: 98, data_idx: 41740 282 | sample num: 99, data_idx: 42296 283 | svhn/svhn_network_fisher.npy 284 | [match_page_by_cost] 285 | occupation: 0 286 | len(page_list): 0 287 | len(network_page_list): 455 288 | cost: 0 289 | 290 | occupation: 1 291 | len(page_list): 0 292 | len(network_page_list): 455 293 | cost: 0 294 | 295 | occupation: 2 296 | len(page_list): 0 297 | len(network_page_list): 455 298 | cost: 0 299 | 300 | occupation: 3 301 | len(page_list): 426 302 | len(network_page_list): 455 303 | 0-th page 304 | 425-th page 305 | cost: 5.48569 306 | 307 | occupation: 4 308 | len(page_list): 239 309 | len(network_page_list): 29 310 | 0-th page 311 | 28-th page 312 | cost: 0.0003839913 313 | 314 | assing_page 143.431 ms 315 | [calculate_cost] 316 | toal_cost: 5.486162122021597 317 | 455 pages allocated for 45490 weights 318 | total_network_cost: 114.62664997577667 319 | ``` 320 | 321 | ## 4) Weight Virtualization Step 2: Weight-Page Optimization 322 | The next step of weight virtualization is the weight-page optimization, which combines the matched weight-pages into single virtual weight-pages and optimizes them for the DNN tasks. Here, we perform joint optimization in which all the DNN tasks are optimized together by executing a shell script (*joint_optimization.sh*). 323 | ```sh 324 | $ ./joint_optimization.sh 325 | 1-th joint optimization 326 | get_matching_loss 327 | v_train 328 | Extracting MNIST_data/train-images-idx3-ubyte.gz 329 | Extracting MNIST_data/train-labels-idx1-ubyte.gz 330 | Extracting MNIST_data/t10k-images-idx3-ubyte.gz 331 | Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 332 | step 0, training accuracy: 0.100000 original loss: 6.907214 matching loss: 7.037911 333 | step 0, Validation accuracy: 0.113500 334 | step 100, training accuracy: 0.390000 original loss: 6.157534 matching loss: 3.197577 335 | step 100, Validation accuracy: 0.484900 336 | get new weight for 0.4849 337 | step 200, training accuracy: 0.890000 original loss: 5.160514 matching loss: 1.595586 338 | step 200, Validation accuracy: 0.861600 339 | get new weight for 0.8616 340 | ... 341 | ... 342 | ... 343 | step 1700, Validation accuracy: 0.839966 344 | step 1800, training accuracy: 0.890000 original loss:: 4.937591 matching loss: 0.009319 345 | step 1800, Validation accuracy: 0.840005 346 | step 1900, training accuracy: 0.880000 original loss:: 5.124782 matching loss: 0.009399 347 | step 1900, Validation accuracy: 0.846727 348 | get new weight for 0.84672713 349 | step 1999, training accuracy: 0.890000 original loss: 4.990521 matching loss 0.009340 350 | step 1999, Validation accuracy: 0.842732 351 | svhn/svhn_weight.npy 352 | Extracting MNIST_data/train-images-idx3-ubyte.gz 353 | Extracting MNIST_data/train-labels-idx1-ubyte.gz 354 | Extracting MNIST_data/t10k-images-idx3-ubyte.gz 355 | Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 356 | Inference accuracy: 0.983100 357 | Inference accuracy: 0.745388 358 | Inference accuracy: 0.948931 359 | Inference accuracy: 0.585900 360 | Inference accuracy: 0.849647 361 | ``` 362 | 363 | After the optimization is finished, the virtual weight-pages are generated (*virtual_weight_page.npy*). The final inference accuracy of each DNN that is achieved with the virtual weight-pages can be checked with the following Python script (*weight_virtualization.py*). 
364 | ```sh 365 | $ python weight_virtualization.py -mode=e -vnn_name=mnist 366 | Extracting MNIST_data/train-images-idx3-ubyte.gz 367 | Extracting MNIST_data/train-labels-idx1-ubyte.gz 368 | Extracting MNIST_data/t10k-images-idx3-ubyte.gz 369 | Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 370 | Inference accuracy: 0.983100 371 | ``` 372 | ```sh 373 | $ python weight_virtualization.py -mode=e -vnn_name=gsc 374 | Inference accuracy: 0.745388 375 | ``` 376 | ```sh 377 | $ python weight_virtualization.py -mode=e -vnn_name=gtsrb 378 | Inference accuracy: 0.948931 379 | ``` 380 | ```sh 381 | $ python weight_virtualization.py -mode=e -vnn_name=cifar10 382 | Inference accuracy: 0.585900 383 | ``` 384 | ```sh 385 | $ python weight_virtualization.py -mode=e -vnn_name=svhn 386 | Inference accuracy: 0.849647 387 | ``` 388 | 389 | The table below compares the inference accuracy of the non-virtualized (stand-alone) DNNs and the virtualized DNNs. 390 | 391 | | | Accuracy (%) | Accuracy (%) | 392 | | :-------------: | -------------: | -------------: | 393 | | DNN | Non-Virtualized DNN | ***Virtualized DNN*** | 394 | | MNIST | 98.33 | ***98.31*** | 395 | | GSC | 69.86 | ***74.53*** | 396 | | GTSRB | 92.74 | ***94.89*** | 397 | | CIFAR-10 | 55.68 | ***58.59*** | 398 | | SVHN | 81.55 | ***84.96*** | 399 | 400 | ## 5) In-Memory Execution (ours) vs. No In-Memory Execution (previous) 401 | Once the weight virtualization is completed, the virtual weight-pages that will be shared between the DNNs are generated (*virtual_weight_page.npy*) and loaded into GPU RAM. Using the virtual weight-pages, the DNNs are executed entirely in GPU RAM, which enables fast and responsive execution of the DNNs. 402 | 403 | We compare the DNN switching (weight-parameter loading) and execution time of in-memory multitask execution against non-virtualized baseline DNNs, which perform DNN switching (restoring the saved weight parameters) and execution from a secondary storage module (e.g., HDD or SSD), as many state-of-the-art DNN systems do today. 404 | 405 | First, the in-memory execution is performed by the following Python script (*in-memory_execute.py*), which executes 30 randomly selected DNN tasks and measures the DNN switching and execution time. The results show that the total DNN switching time and execution time of the 30 DNN executions are ***3.919 ms*** and ***4443.045 ms***, respectively. 406 | 407 | ```sh 408 | $ python in-memory_execute.py 409 | virtual_weight address: 140111287156736 410 | init virtual_weight 87.536 ms 411 | [VNN 0][cifar10] init page table 3.399 ms 412 | [VNN 1][gsc] init page table 2.733 ms 413 | [VNN 2][gtsrb] init page table 2.462 ms 414 | [VNN 3][mnist] init page table 2.361 ms 415 | [VNN 4][svhn] init page table 2.465 ms 416 | tf.global_variables_initializer 696.639 ms 417 | [Executing] svhn 418 | weights load time : 0.204 ms 419 | DNN execution time: 1350.998 ms 420 | Inference accuracy: 0.849647 421 | [Executing] gsc 422 | weights load time : 0.165 ms 423 | DNN execution time: 273.202 ms 424 | Inference accuracy: 0.745388 425 | ... 426 | ... 427 | ...
428 | [Executing] cifar10 429 | weights load time : 0.123 ms 430 | DNN execution time: 92.804 ms 431 | Inference accuracy: 0.585900 432 | [Executing] svhn 433 | weights load time : 0.123 ms 434 | DNN execution time: 178.231 ms 435 | Inference accuracy: 0.849647 436 | total weights load time : 3.919 ms 437 | total DNN execution time: 4443.045 ms 438 | ``` 439 | 440 | Next, the baseline (no in-memory) execution is performed by the following Python script (*baseline_execute.py*), which executes 30 randomly selected DNN tasks and measures the DNN switching and execution time. The results show that the total DNN switching time and execution time of the 30 DNN executions are ***1801.726 ms*** and ***4559.881 ms***, respectively. 441 | ```sh 442 | $ python baseline_execute.py 443 | [Executing] cifar10 444 | weights load time : 377.985 ms 445 | DNN execution time: 968.260 ms 446 | Inference accuracy: 0.555200 447 | Extracting MNIST_data/train-images-idx3-ubyte.gz 448 | Extracting MNIST_data/train-labels-idx1-ubyte.gz 449 | Extracting MNIST_data/t10k-images-idx3-ubyte.gz 450 | Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 451 | [Executing] mnist 452 | weights load time : 45.943 ms 453 | DNN execution time: 169.006 ms 454 | Inference accuracy: 0.980800 455 | ... 456 | ... 457 | ... 458 | [Executing] gtsrb 459 | weights load time : 38.172 ms 460 | DNN execution time: 74.077 ms 461 | Inference accuracy: 0.928029 462 | [Executing] gtsrb 463 | weights load time : 37.190 ms 464 | DNN execution time: 74.039 ms 465 | Inference accuracy: 0.928029 466 | [Executing] gsc 467 | weights load time : 41.281 ms 468 | DNN execution time: 54.813 ms 469 | Inference accuracy: 0.694684 470 | total weights load time : 1801.726 ms 471 | total DNN execution time: 4559.881 ms 472 | ``` 473 | 474 | This shows that in-memory execution accelerates the DNN switching time by ***459x (1801.726 ms vs. 3.919 ms)***.
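The 459x figure follows directly from the two reported totals (a trivial check, not part of the repository):

```python
baseline_switch_ms = 1801.726   # total weight-load time over 30 executions, baseline (from storage)
inmemory_switch_ms = 3.919      # total weight-load time over 30 executions, in-memory execution
print(baseline_switch_ms / inmemory_switch_ms)   # ~459.7x faster DNN switching
```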
475 | 476 | | | DNN Switching Time (ms) | DNN Execution Time (ms) 477 | | :-------------: | -------------: | -------------: | 478 | | No In-Memory Execution | 1801.72 | 4559.881 | 479 | | ***In-Memory Execution*** | ***3.919*** | ***4443.04*** | 480 | 481 |   482 | ## Citation (BibTeX) 483 | **Fast and Scalable In-memory Deep Multitask Learning via Neural Weight Virtualization** 484 | ``` 485 | @inproceedings{10.1145/3386901.3388947, 486 | author = {Lee, Seulki and Nirjon, Shahriar}, 487 | title = {Fast and Scalable In-Memory Deep Multitask Learning via Neural Weight Virtualization}, 488 | year = {2020}, 489 | isbn = {9781450379540}, 490 | publisher = {Association for Computing Machinery}, 491 | address = {New York, NY, USA}, 492 | url = {https://doi.org/10.1145/3386901.3388947}, 493 | doi = {10.1145/3386901.3388947}, 494 | booktitle = {Proceedings of the 18th International Conference on Mobile Systems, Applications, and Services}, 495 | pages = {175–190}, 496 | numpages = {16}, 497 | keywords = {virtualization, in-memory, multitask learning, deep neural network}, 498 | location = {Toronto, Ontario, Canada}, 499 | series = {MobiSys ’20} 500 | } 501 | ``` 502 | -------------------------------------------------------------------------------- /baseline_execute.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 5 | import tensorflow as tf 6 | import importlib 7 | import time 8 | import ctypes 9 | from weight_virtualization import VNN 10 | from weight_virtualization import WeightVirtualization 11 | 12 | tf.logging.set_verbosity(tf.logging.ERROR) 13 | 14 | wv_op = tf.load_op_library('./tf_operation.so') 15 | _weight_loader = ctypes.CDLL('./weight_loader.so') 16 | _weight_loader.get_weight.argtypes = (ctypes.POINTER(ctypes.c_int64), ctypes.POINTER(ctypes.c_int), 17 | ctypes.c_int, ctypes.c_int64, ctypes.c_int64, ctypes.c_int) 18 | 19 | #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.060) 20 | gpu_options = None 21 | 22 | def baseline_execute(graph, sess, vnn, layers, data_set, label=None): 23 | print("[Executing]", vnn.name) 24 | 25 | saver = tf.train.import_meta_graph(vnn.meta_filepath) 26 | 27 | time1 = time.time() 28 | saver.restore(sess, vnn.model_filepath) 29 | time2 = time.time() 30 | weights_load_time = (time2-time1)*1000.0 31 | print('weights load time : %0.3f ms' % (weights_load_time)) 32 | 33 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 34 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 35 | x = graph.get_tensor_by_name("neuron_0:0") 36 | y = graph.get_tensor_by_name("neuron_" + str(layers-1) + ":0") 37 | 38 | data_set_reshaped = np.reshape(data_set, ([-1] + x.get_shape().as_list()[1:])) 39 | time1 = time.time() 40 | infer_result = sess.run(y, feed_dict={ 41 | x: data_set_reshaped, keep_prob_input: 1.0, keep_prob: 1.0}) 42 | time2 = time.time() 43 | DNN_execution_time = (time2-time1)*1000.0 44 | print('DNN execution time: %0.3f ms' % (DNN_execution_time)) 45 | 46 | if label is not None: 47 | y_ = graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | test_accuracy = sess.run(accuracy, feed_dict={ 50 | x: data_set_reshaped, y_: label, keep_prob_input: 1.0, keep_prob: 1.0}) 51 | print("Inference accuracy: %f" % test_accuracy) 52 | 53 | return weights_load_time, DNN_execution_time 54 | 55 | def main(): 56 | wv = WeightVirtualization() 57 | 58 | vnn_list = 
[] 59 | for name, vnn in sorted(wv.vnns.items()): 60 | vnn_list.append(vnn) 61 | 62 | data_list = [ 'cifar10_data', 'GSC_v2_data', 'GTSRB_data', 'mnist_data', 'svhn_data' ] 63 | layer_list = [ 7, 6, 7, 7, 7 ] 64 | 65 | total_weight_load_time = 0 66 | total_execution_time = 0 67 | num_execution = 30 68 | 69 | for i in range(num_execution): 70 | vnn_no = np.random.randint(len(vnn_list)) 71 | #print('vnn_no:', vnn_no) 72 | 73 | data = __import__(data_list[vnn_no]) 74 | data_set = data.test_set()[0]#[0:1000] 75 | label = data.test_set()[1]#[0:1000] 76 | 77 | with tf.Graph().as_default() as graph: 78 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 79 | weight_load_time, execution_time = baseline_execute(graph, 80 | sess, vnn_list[vnn_no], layer_list[vnn_no], 81 | data_set, label) 82 | 83 | total_weight_load_time += weight_load_time 84 | total_execution_time += execution_time 85 | 86 | print('total weights load time : %0.3f ms' % (total_weight_load_time)) 87 | print('total DNN execution time: %0.3f ms' % (total_execution_time)) 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /cifar10/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/__init__.py -------------------------------------------------------------------------------- /cifar10/cifar10.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10.data-00000-of-00001 -------------------------------------------------------------------------------- /cifar10/cifar10.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10.index -------------------------------------------------------------------------------- /cifar10/cifar10.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10.meta -------------------------------------------------------------------------------- /cifar10/cifar10_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10_fisher.npy -------------------------------------------------------------------------------- /cifar10/cifar10_network_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10_network_fisher.npy -------------------------------------------------------------------------------- /cifar10/cifar10_network_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10_network_weight.npy 
-------------------------------------------------------------------------------- /cifar10/cifar10_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/cifar10/cifar10_weight.npy -------------------------------------------------------------------------------- /cifar10/pintle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import_name = 'cifar10_data' 6 | 7 | def next_batch(data_set, batch_size): 8 | data = data_set[0] 9 | label = data_set[1] # one-hot vectors 10 | 11 | data_num = np.random.choice(data.shape[0], size=batch_size, replace=False) 12 | batch = data[data_num,:] 13 | label = label[data_num,:] # one-hot vectors 14 | 15 | return batch, label 16 | 17 | def v_input_variable_names(): 18 | input_variable_names = [ 'neuron_0', 'keep_prob_input', 'keep_prob' ] 19 | return input_variable_names 20 | 21 | def v_output_variable_names(): 22 | output_variable_names = [ 'neuron_6' ] 23 | return output_variable_names 24 | 25 | def v_train_input_variables(): 26 | data = __import__(import_name) 27 | train_set = data.train_set() 28 | train_image_reshaped = np.reshape(train_set[0], ([-1, 32, 32, 3])) 29 | return [[train_image_reshaped, train_set[1]], 1.0, 1.0] 30 | 31 | def v_test_input_variables(): 32 | data = __import__(import_name) 33 | test_set = data.test_set() 34 | test_image_reshaped = np.reshape(test_set[0], ([-1, 32, 32, 3])) 35 | return [[test_image_reshaped, test_set[1]], 1.0, 1.0] 36 | 37 | def v_execute(graph, sess, input_tensors, input_variables, ground_truth): 38 | tensor_y_name = "neuron_6:0" 39 | y = graph.get_tensor_by_name(tensor_y_name) 40 | 41 | # infer 42 | infer_result = sess.run(y, feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 43 | 44 | # accuracy 45 | test_accuracy = None 46 | if ground_truth is not None: 47 | y_ = graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | input_tensors.append(y_) 50 | input_variables.append(ground_truth) 51 | test_accuracy = sess.run(accuracy, 52 | feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 53 | print("Inference accuracy: %f" % test_accuracy) 54 | 55 | return infer_result, test_accuracy 56 | 57 | def v_train(graph, sess, matching_cost, batch_size, train_iteration, get_weight_func): 58 | print("v_train") 59 | 60 | data = __import__(import_name) 61 | train_set = data.train_set() 62 | validation_set = data.test_set() 63 | 64 | # get tensors 65 | tensor_x_name = "neuron_0:0" 66 | x = graph.get_tensor_by_name("neuron_0:0") 67 | y_ = graph.get_tensor_by_name("y_:0") 68 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 69 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 70 | accuracy = graph.get_tensor_by_name("accuracy:0") 71 | cross_entropy = graph.get_tensor_by_name('cross_entropy:0') 72 | 73 | learning_rate = 0.001 74 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 75 | name='matching_cost_optimizer') 76 | loss = optimizer.minimize(tf.add(cross_entropy, matching_cost)) 77 | 78 | sess.run(tf.variables_initializer(optimizer.variables())) 79 | 80 | input_images_validation = validation_set[0] 81 | input_images_validation_reshaped = np.reshape(validation_set[0], ([-1] + x.get_shape().as_list()[1:])) 82 | labels_validation = validation_set[1] 83 
| highest_accuracy = 0 84 | new_weight_vector = None 85 | 86 | # train 87 | for i in range(train_iteration): 88 | input_data, labels = next_batch(train_set, batch_size) 89 | input_data_reshpaed = np.reshape(input_data, ([-1] + x.get_shape().as_list()[1:])) 90 | 91 | if i % (100) == 0 or i == (train_iteration-1): 92 | original_loss, matching_loss, train_accuracy = sess.run([cross_entropy, matching_cost, accuracy], 93 | feed_dict={x: input_data_reshpaed, y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 94 | print("step %d, training accuracy: %f original loss: %f matching loss: %f" 95 | % (i, train_accuracy, original_loss, matching_loss)) 96 | 97 | # validate 98 | test_accuracy = sess.run(accuracy, feed_dict={ 99 | x: input_images_validation_reshaped, y_: labels_validation, 100 | keep_prob_input: 1.0, keep_prob: 1.0}) 101 | print("step %d, Validation accuracy: %f" % (i, test_accuracy)) 102 | 103 | 104 | if i == 0: 105 | highest_accuracy = test_accuracy 106 | else: 107 | if test_accuracy > highest_accuracy: 108 | new_weight_vector = get_weight_func(sess) 109 | highest_accuracy = test_accuracy 110 | print('get new weight for', highest_accuracy) 111 | 112 | sess.run(loss, feed_dict={x: input_data_reshpaed, 113 | y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 114 | 115 | return new_weight_vector 116 | 117 | def v_fx_tensors(graph): 118 | y = graph.get_tensor_by_name("neuron_6:0") 119 | row_idx = tf.range(tf.shape(y)[0]) 120 | col_idx = tf.argmax(y, axis=1, output_type=tf.dtypes.int32) 121 | full_indices = tf.stack([row_idx, col_idx], axis=1) 122 | fx_tensors = tf.gather_nd(y, full_indices) 123 | return fx_tensors 124 | 125 | def main(): 126 | print(v_test_input_variables()) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /cifar10_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | import numpy as np 3 | 4 | #""" 5 | cifar10_train_data = np.load('cifar10_train_data.npy') 6 | cifar10_train_label = np.load('cifar10_train_label.npy') 7 | cifar10_test_data = np.load('cifar10_test_data.npy') 8 | cifar10_test_label = np.load('cifar10_test_label.npy') 9 | #""" 10 | 11 | def train_set(): 12 | return cifar10_train_data, cifar10_train_label 13 | 14 | def test_set(): 15 | return cifar10_test_data, cifar10_test_label 16 | 17 | def validation_set(): 18 | return None, None 19 | 20 | def unpickle(file): 21 | import cPickle 22 | with open(file, 'rb') as fo: 23 | dict = cPickle.load(fo) 24 | return dict 25 | 26 | def get_cifar_test_batch(batch_size=10000): 27 | filename = "..." 28 | dict = unpickle(filename) 29 | 30 | idx = range(batch_size) 31 | img_flat = dict['data'][idx] 32 | 33 | labels = [] 34 | for i in range(batch_size): 35 | labels.append(dict['labels'][idx[i]]) 36 | 37 | img_R = img_flat[:,0:1024].reshape((batch_size, 32, 32)) 38 | img_G = img_flat[:,1024:2048].reshape((batch_size, 32, 32)) 39 | img_B = img_flat[:,2048:3072].reshape((batch_size, 32, 32)) 40 | img = np.stack((img_R, img_G, img_B), axis=3) 41 | batch = img / np.max(img) 42 | 43 | labels_one_hot = np.zeros((batch_size, 10)) 44 | labels_one_hot[np.arange(batch_size), labels] = 1 45 | 46 | return batch, labels_one_hot 47 | 48 | def get_cifar_train_batch(batch_size=10000): 49 | batch_list = [] 50 | labels_one_hot_list = [] 51 | 52 | for file_no in range (1, 5+1): 53 | filename = "..." 
+ str(file_no) 54 | dict = unpickle(filename) 55 | 56 | idx = range(batch_size) 57 | img_flat = dict['data'][idx] 58 | 59 | labels = [] 60 | for i in range(batch_size): 61 | labels.append(dict['labels'][idx[i]]) 62 | 63 | img_R = img_flat[:,0:1024].reshape((batch_size, 32, 32)) 64 | img_G = img_flat[:,1024:2048].reshape((batch_size, 32, 32)) 65 | img_B = img_flat[:,2048:3072].reshape((batch_size, 32, 32)) 66 | img = np.stack((img_R, img_G, img_B), axis=3) 67 | batch = img / np.max(img) 68 | 69 | labels_one_hot = np.zeros((batch_size, 10)) 70 | labels_one_hot[np.arange(batch_size), labels] = 1 71 | 72 | batch_list.append(batch) 73 | labels_one_hot_list.append(labels_one_hot) 74 | 75 | return np.vstack((batch_list)), np.vstack((labels_one_hot_list)) 76 | 77 | def create_data_files(): 78 | cifar10_train_data, cifar10_train_label = get_cifar_train_batch() 79 | cifar10_test_data, cifar10_test_label = get_cifar_test_batch() 80 | 81 | np.save('cifar10_train_data', cifar10_train_data) 82 | print (cifar10_train_data.shape) 83 | np.save('cifar10_train_label', cifar10_train_label) 84 | print (cifar10_train_label.shape) 85 | np.save('cifar10_test_data', cifar10_test_data) 86 | print (cifar10_test_data.shape) 87 | np.save('cifar10_test_label', cifar10_test_label) 88 | print (cifar10_test_label.shape) 89 | 90 | def main(): 91 | #create_data_files() 92 | 93 | print(train_set()[0]) 94 | print(train_set()[0].shape) 95 | print(train_set()[1]) 96 | print(train_set()[1].shape) 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "[1/4] Downloading CIFAR10 dataset..." 4 | 5 | fileid="1C1npRinwi5DLoWSD_AWEPNp1khwqCOV4" 6 | filename="cifar10_test_data.npy" 7 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 8 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 9 | 10 | fileid="1SfvD7V1d04JXe9k2JUmygXUi4d83wujL" 11 | filename="cifar10_test_label.npy" 12 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 13 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 14 | 15 | fileid="1W-7dyYS5szUCZT4M_8qKeaRKlv_rjqBI" 16 | filename="cifar10_train_data.npy" 17 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 18 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 19 | 20 | fileid="1rI5KxWMoMF4r8a1w1jTzJdIzxRjmlKXz" 21 | filename="cifar10_train_label.npy" 22 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 23 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 24 | 25 | echo "" 26 | echo "[2/4] Downloading Google Speech Command V2 dataset..." 
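# Note: every file in this script is fetched with the same two-step curl pattern used above:
# the first curl stores Google Drive's cookies for the given file id, awk extracts the download
# confirmation token from the cookie file, and the second curl downloads the .npy file with that
# token into ${filename}.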
27 | 28 | fileid="1tqN3TQlAB0XUaNSR8Yvb0H6uhY5sJyML" 29 | filename="GSC_v2_test_data.npy" 30 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 31 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 32 | 33 | fileid="150ho5EfYFzdDChh0hc0xwpmmsO_bxzbO" 34 | filename="GSC_v2_test_label.npy" 35 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 36 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 37 | 38 | fileid="1EIBfVVFN1IjEF5j4eO-g3SMsCnKktBar" 39 | filename="GSC_v2_train_data.npy" 40 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 41 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 42 | 43 | fileid="11GFXzBx35LgosSqARLh2w9qMXd1rzIUF" 44 | filename="GSC_v2_train_label.npy" 45 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 46 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 47 | 48 | fileid="1frOtUfBrh64ML7VG1mTpCWfL9t31fK6A" 49 | filename="GSC_v2_validation_data.npy" 50 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 51 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 52 | 53 | fileid="1jWXz2bQXkZZu9ZN-ebyzwr8vdSZjlU7C" 54 | filename="GSC_v2_validation_label.npy" 55 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 56 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 57 | 58 | echo "" 59 | echo "[3/4] Downloading GTSRB dataset..." 60 | 61 | fileid="1kup2bRDjRcr_Ofch8O95-LMIvlY_-R7x" 62 | filename="GTSRB_test_data.npy" 63 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 64 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 65 | 66 | fileid="1cEMIUZeDZn4SqIVeI6y7eJ3t0EAw10dB" 67 | filename="GTSRB_test_label.npy" 68 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 69 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 70 | 71 | fileid="1dp_PP2ipGwwR6wUt15mBejkPFilSG65t" 72 | filename="GTSRB_train_data.npy" 73 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 74 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 75 | 76 | fileid="1nQBq1z3ncB1c6JxI-z2gre3m7SfoMtSp" 77 | filename="GTSRB_train_label.npy" 78 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 79 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 80 | 81 | echo "" 82 | echo "[4/4] Downloading SVHN dataset..." 
83 | 84 | fileid="1QPFJkA-PLp5Ghcb1AG_hzEE9aIIeeY8A" 85 | filename="svhn_test_data.npy" 86 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 87 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 88 | 89 | fileid="1TGeLsOzpZbWnrKWehZg_ZvRViJsHPn7g" 90 | filename="svhn_test_label.npy" 91 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 92 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 93 | 94 | fileid="1Tqk2DqGAfdWLbQ6GpEoWnSe6HMwcctn9" 95 | filename="svhn_train_data.npy" 96 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 97 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 98 | 99 | fileid="1sHdjz8_W5ZDaOxReZ54xVVLiwYEoED2J" 100 | filename="svhn_train_label.npy" 101 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 102 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 103 | 104 | fileid="1XBMCZMEnHzQiU4qyXX4lx4HQxP0wfaj2" 105 | filename="svhn_validation_data.npy" 106 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 107 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 108 | 109 | fileid="1Pp7tVKU3iJe52GISudYqPxi7Zow2c38y" 110 | filename="svhn_validation_label.npy" 111 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 112 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} 113 | -------------------------------------------------------------------------------- /gsc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/__init__.py -------------------------------------------------------------------------------- /gsc/gsc.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc.data-00000-of-00001 -------------------------------------------------------------------------------- /gsc/gsc.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc.index -------------------------------------------------------------------------------- /gsc/gsc.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc.meta -------------------------------------------------------------------------------- /gsc/gsc_fisher.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc_fisher.npy -------------------------------------------------------------------------------- /gsc/gsc_network_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc_network_fisher.npy -------------------------------------------------------------------------------- /gsc/gsc_network_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc_network_weight.npy -------------------------------------------------------------------------------- /gsc/gsc_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gsc/gsc_weight.npy -------------------------------------------------------------------------------- /gsc/pintle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import_name = 'GSC_v2_data' 6 | 7 | def next_batch(data_set, batch_size): 8 | data = data_set[0] 9 | label = data_set[1] # one-hot vectors 10 | 11 | data_num = np.random.choice(data.shape[0], size=batch_size, replace=False) 12 | batch = data[data_num,:] 13 | label = label[data_num,:] # one-hot vectors 14 | 15 | return batch, label 16 | 17 | def v_input_variable_names(): 18 | input_variable_names = [ 'neuron_0', 'keep_prob_input', 'keep_prob' ] 19 | return input_variable_names 20 | 21 | def v_output_variable_names(): 22 | output_variable_names = [ 'neuron_5' ] 23 | return output_variable_names 24 | 25 | def v_train_input_variables(): 26 | data = __import__(import_name) 27 | train_set = data.train_set() 28 | train_image_reshaped = np.reshape(train_set[0], ([-1, 61, 13, 1])) 29 | return [[train_image_reshaped, train_set[1]], 1.0, 1.0] 30 | 31 | def v_test_input_variables(): 32 | data = __import__(import_name) 33 | test_set = data.test_set() 34 | test_image_reshaped = np.reshape(test_set[0], ([-1, 61, 13, 1])) 35 | return [[test_image_reshaped, test_set[1]], 1.0, 1.0] 36 | 37 | def v_execute(graph, sess, input_tensors, input_variables, ground_truth): 38 | tensor_y_name = "neuron_5:0" 39 | y = graph.get_tensor_by_name(tensor_y_name) 40 | 41 | # infer 42 | infer_result = sess.run(y, feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 43 | 44 | # accuracy 45 | test_accuracy = None 46 | if ground_truth is not None: 47 | y_ = graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | input_tensors.append(y_) 50 | input_variables.append(ground_truth) 51 | test_accuracy = sess.run(accuracy, 52 | feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 53 | print("Inference accuracy: %f" % test_accuracy) 54 | 55 | return infer_result, test_accuracy 56 | 57 | def v_train(graph, sess, matching_cost, batch_size, train_iteration, get_weight_func): 58 | print("v_train") 59 | 60 | data = __import__(import_name) 61 | train_set = data.train_set() 62 | validation_set = data.test_set() 63 | 64 | # get tensors 65 | tensor_x_name = "neuron_0:0" 66 | x = 
graph.get_tensor_by_name("neuron_0:0") 67 | y_ = graph.get_tensor_by_name("y_:0") 68 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 69 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 70 | accuracy = graph.get_tensor_by_name("accuracy:0") 71 | cross_entropy = graph.get_tensor_by_name('cross_entropy:0') 72 | 73 | learning_rate = 0.001 74 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 75 | name='matching_cost_optimizer') 76 | loss = optimizer.minimize(tf.add(cross_entropy, matching_cost)) 77 | 78 | sess.run(tf.variables_initializer(optimizer.variables())) 79 | 80 | input_images_validation = validation_set[0] 81 | input_images_validation_reshaped = np.reshape(validation_set[0], ([-1] + x.get_shape().as_list()[1:])) 82 | labels_validation = validation_set[1] 83 | highest_accuracy = 0 84 | new_weight_vector = None 85 | 86 | # train 87 | for i in range(train_iteration): 88 | input_data, labels = next_batch(train_set, batch_size) 89 | input_data_reshpaed = np.reshape(input_data, ([-1] + x.get_shape().as_list()[1:])) 90 | 91 | if i % (100) == 0 or i == (train_iteration-1): 92 | original_loss, matching_loss, train_accuracy = sess.run([cross_entropy, matching_cost, accuracy], 93 | feed_dict={x: input_data_reshpaed, y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 94 | print("step %d, training accuracy: %f original loss: %f matching loss: %f" 95 | % (i, train_accuracy, original_loss, matching_loss)) 96 | 97 | # validate 98 | test_accuracy = sess.run(accuracy, feed_dict={ 99 | x: input_images_validation_reshaped, y_: labels_validation, 100 | keep_prob_input: 1.0, keep_prob: 1.0}) 101 | print("step %d, Validation accuracy: %f" % (i, test_accuracy)) 102 | 103 | 104 | if i == 0: 105 | highest_accuracy = test_accuracy 106 | else: 107 | if test_accuracy > highest_accuracy: 108 | new_weight_vector = get_weight_func(sess) 109 | highest_accuracy = test_accuracy 110 | print('get new weight for', highest_accuracy) 111 | 112 | sess.run(loss, feed_dict={x: input_data_reshpaed, 113 | y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 114 | 115 | return new_weight_vector 116 | 117 | def v_fx_tensors(graph): 118 | y = graph.get_tensor_by_name("neuron_5:0") 119 | row_idx = tf.range(tf.shape(y)[0]) 120 | col_idx = tf.argmax(y, axis=1, output_type=tf.dtypes.int32) 121 | full_indices = tf.stack([row_idx, col_idx], axis=1) 122 | fx_tensors = tf.gather_nd(y, full_indices) 123 | return fx_tensors 124 | 125 | def main(): 126 | print(v_test_input_variables()) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /gtsrb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/__init__.py -------------------------------------------------------------------------------- /gtsrb/gtsrb.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb.data-00000-of-00001 -------------------------------------------------------------------------------- /gtsrb/gtsrb.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb.index -------------------------------------------------------------------------------- /gtsrb/gtsrb.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb.meta -------------------------------------------------------------------------------- /gtsrb/gtsrb_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb_fisher.npy -------------------------------------------------------------------------------- /gtsrb/gtsrb_network_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb_network_fisher.npy -------------------------------------------------------------------------------- /gtsrb/gtsrb_network_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb_network_weight.npy -------------------------------------------------------------------------------- /gtsrb/gtsrb_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/gtsrb/gtsrb_weight.npy -------------------------------------------------------------------------------- /gtsrb/pintle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import_name = 'GTSRB_data' 6 | 7 | def next_batch(data_set, batch_size): 8 | data = data_set[0] 9 | label = data_set[1] # one-hot vectors 10 | 11 | data_num = np.random.choice(data.shape[0], size=batch_size, replace=False) 12 | batch = data[data_num,:] 13 | label = label[data_num,:] # one-hot vectors 14 | 15 | return batch, label 16 | 17 | def v_input_variable_names(): 18 | input_variable_names = [ 'neuron_0', 'keep_prob_input', 'keep_prob' ] 19 | return input_variable_names 20 | 21 | def v_output_variable_names(): 22 | output_variable_names = [ 'neuron_6' ] 23 | return output_variable_names 24 | 25 | def v_train_input_variables(): 26 | data = __import__(import_name) 27 | train_set = data.train_set() 28 | train_image_reshaped = np.reshape(train_set[0], ([-1, 32, 32, 1])) 29 | return [[train_image_reshaped, train_set[1]], 1.0, 1.0] 30 | 31 | def v_test_input_variables(): 32 | data = __import__(import_name) 33 | test_set = data.test_set() 34 | test_image_reshaped = np.reshape(test_set[0], ([-1, 32, 32, 1])) 35 | return [[test_image_reshaped, test_set[1]], 1.0, 1.0] 36 | 37 | def v_execute(graph, sess, input_tensors, input_variables, ground_truth): 38 | tensor_y_name = "neuron_6:0" 39 | y = graph.get_tensor_by_name(tensor_y_name) 40 | 41 | # infer 42 | infer_result = sess.run(y, feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 43 | 44 | # accuracy 45 | test_accuracy = None 46 | if ground_truth is not None: 47 | y_ = 
graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | input_tensors.append(y_) 50 | input_variables.append(ground_truth) 51 | test_accuracy = sess.run(accuracy, 52 | feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 53 | print("Inference accuracy: %f" % test_accuracy) 54 | 55 | return infer_result, test_accuracy 56 | 57 | def v_train(graph, sess, matching_cost, batch_size, train_iteration, get_weight_func): 58 | print("v_train") 59 | 60 | data = __import__(import_name) 61 | train_set = data.train_set() 62 | validation_set = data.test_set() 63 | 64 | # get tensors 65 | tensor_x_name = "neuron_0:0" 66 | x = graph.get_tensor_by_name("neuron_0:0") 67 | y_ = graph.get_tensor_by_name("y_:0") 68 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 69 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 70 | accuracy = graph.get_tensor_by_name("accuracy:0") 71 | cross_entropy = graph.get_tensor_by_name('cross_entropy:0') 72 | 73 | learning_rate = 0.001 74 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 75 | name='matching_cost_optimizer') 76 | loss = optimizer.minimize(tf.add(cross_entropy, matching_cost)) 77 | 78 | sess.run(tf.variables_initializer(optimizer.variables())) 79 | 80 | input_images_validation = validation_set[0] 81 | input_images_validation_reshaped = np.reshape(validation_set[0], ([-1] + x.get_shape().as_list()[1:])) 82 | labels_validation = validation_set[1] 83 | highest_accuracy = 0 84 | new_weight_vector = None 85 | 86 | # train 87 | for i in range(train_iteration): 88 | input_data, labels = next_batch(train_set, batch_size) 89 | input_data_reshpaed = np.reshape(input_data, ([-1] + x.get_shape().as_list()[1:])) 90 | 91 | if i % (100) == 0 or i == (train_iteration-1): 92 | original_loss, matching_loss, train_accuracy = sess.run([cross_entropy, matching_cost, accuracy], 93 | feed_dict={x: input_data_reshpaed, y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 94 | print("step %d, training accuracy: %f original loss: %f matching loss: %f" 95 | % (i, train_accuracy, original_loss, matching_loss)) 96 | 97 | # validate 98 | test_accuracy = sess.run(accuracy, feed_dict={ 99 | x: input_images_validation_reshaped, y_: labels_validation, 100 | keep_prob_input: 1.0, keep_prob: 1.0}) 101 | print("step %d, Validation accuracy: %f" % (i, test_accuracy)) 102 | 103 | 104 | if i == 0: 105 | highest_accuracy = test_accuracy 106 | else: 107 | if test_accuracy > highest_accuracy: 108 | new_weight_vector = get_weight_func(sess) 109 | highest_accuracy = test_accuracy 110 | print('get new weight for', highest_accuracy) 111 | 112 | sess.run(loss, feed_dict={x: input_data_reshpaed, 113 | y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 114 | 115 | return new_weight_vector 116 | 117 | def v_fx_tensors(graph): 118 | y = graph.get_tensor_by_name("neuron_6:0") 119 | row_idx = tf.range(tf.shape(y)[0]) 120 | col_idx = tf.argmax(y, axis=1, output_type=tf.dtypes.int32) 121 | full_indices = tf.stack([row_idx, col_idx], axis=1) 122 | fx_tensors = tf.gather_nd(y, full_indices) 123 | return fx_tensors 124 | 125 | def main(): 126 | print(v_test_input_variables()) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /in-memory_execute.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 5 | 
import tensorflow as tf 6 | import importlib 7 | import time 8 | import ctypes 9 | from weight_virtualization import VNN 10 | from weight_virtualization import WeightVirtualization 11 | 12 | tf.logging.set_verbosity(tf.logging.ERROR) 13 | 14 | wv_op = tf.load_op_library('./tf_operation.so') 15 | _weight_loader = ctypes.CDLL('./weight_loader.so') 16 | _weight_loader.get_weight.argtypes = (ctypes.POINTER(ctypes.c_int64), ctypes.POINTER(ctypes.c_int), 17 | ctypes.c_int, ctypes.c_int64, ctypes.c_int64, ctypes.c_int) 18 | 19 | #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.060) 20 | gpu_options = None 21 | 22 | def init_virtualization(wv, sess): 23 | vnn_list = [] 24 | for name, vnn in sorted(wv.vnns.items()): 25 | vnn_list.append(vnn) 26 | 27 | virtual_weight_address = None 28 | #with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 29 | time1 = time.time() 30 | virtual_weight_address = sess.run(wv_op.init_weight(wv.weight_page)) 31 | time2 = time.time() 32 | print('virtual_weight address:', virtual_weight_address) 33 | print('init virtual_weight %0.3f ms' % ((time2-time1)*1000.0)) 34 | 35 | page_address_list = [] 36 | vnn_no = 0 37 | for vnn in vnn_list: 38 | time1 = time.time() 39 | page_address = sess.run(wv_op.init_page_table(vnn.weight_page_list)) 40 | time2 = time.time() 41 | print('[VNN %d][%s] init page table %0.3f ms' 42 | % (vnn_no, vnn.name, (time2-time1)*1000.0)) 43 | page_address_list.append(page_address) 44 | vnn_no += 1 45 | 46 | page_table_address_list = [] 47 | for i in range(len(page_address_list)): 48 | page_table_address = tf.constant(page_address_list[i], 49 | name='page_table_address/' + str(i)) 50 | page_table_address_list.append(page_table_address) 51 | 52 | for vnn in vnn_list: 53 | with tf.name_scope(vnn.name): 54 | tf.train.import_meta_graph(vnn.meta_filepath) 55 | 56 | weight_address_list = [] 57 | weight_len_list = [] 58 | 59 | for i in range(len(vnn_list)): 60 | train_weights = tf.trainable_variables(scope=vnn_list[i].name) 61 | weight_address, weight_len = sess.run(wv_op.get_weight_address(train_weights)) 62 | weight_address_list.append(weight_address) 63 | weight_len_list.append(weight_len) 64 | 65 | time1 = time.time() 66 | sess.run(tf.global_variables_initializer()) 67 | time2 = time.time() 68 | print('tf.global_variables_initializer %0.3f ms' % ((time2-time1)*1000.0)) 69 | 70 | return vnn_list, weight_address_list, weight_len_list, virtual_weight_address, page_address_list 71 | 72 | def load_weight_page(virtual_weight_address, weight_address_list, 73 | weight_len_list, page_address_list, weight_per_page): 74 | num_of_weight = len(weight_address_list) 75 | weight_address_list_array_type = ctypes.c_int64 * num_of_weight 76 | weight_len_list_array_type = ctypes.c_int * num_of_weight 77 | _weight_loader.get_weight( 78 | weight_address_list_array_type(*weight_address_list), 79 | weight_len_list_array_type(*weight_len_list), 80 | ctypes.c_int(num_of_weight), 81 | ctypes.c_int64(virtual_weight_address), 82 | ctypes.c_int64(page_address_list), 83 | ctypes.c_int(weight_per_page)) 84 | 85 | def in_memory_execute(graph, sess, vnn, layers, data_set, 86 | virtual_weight_address, weight_address_list, weight_len_list, 87 | page_address_list, weight_per_page, label=None): 88 | print("[Executing]", vnn.name) 89 | 90 | time1 = time.time() 91 | load_weight_page(virtual_weight_address, weight_address_list, 92 | weight_len_list, page_address_list, weight_per_page) 93 | time2 = time.time() 94 | weights_load_time = (time2-time1)*1000.0 95 | 
print('weights load time : %0.3f ms' % (weights_load_time)) 96 | 97 | keep_prob_input = graph.get_tensor_by_name(vnn.name + "/keep_prob_input:0") 98 | keep_prob = graph.get_tensor_by_name(vnn.name + "/keep_prob:0") 99 | x = graph.get_tensor_by_name(vnn.name + "/neuron_0:0") 100 | y = graph.get_tensor_by_name(vnn.name + "/neuron_" + str(layers-1) + ":0") 101 | 102 | data_set_reshaped = np.reshape(data_set, ([-1] + x.get_shape().as_list()[1:])) 103 | time1 = time.time() 104 | infer_result = sess.run(y, feed_dict={ 105 | x: data_set_reshaped, keep_prob_input: 1.0, keep_prob: 1.0}) 106 | time2 = time.time() 107 | DNN_execution_time = (time2-time1)*1000.0 108 | print('DNN execution time: %0.3f ms' % (DNN_execution_time)) 109 | 110 | if label is not None: 111 | y_ = graph.get_tensor_by_name(vnn.name + "/y_:0") 112 | accuracy = graph.get_tensor_by_name(vnn.name + "/accuracy:0") 113 | test_accuracy = sess.run(accuracy, feed_dict={ 114 | x: data_set_reshaped, y_: label, keep_prob_input: 1.0, keep_prob: 1.0}) 115 | print("Inference accuracy: %f" % test_accuracy) 116 | 117 | return weights_load_time, DNN_execution_time 118 | 119 | def main(): 120 | wv = WeightVirtualization() 121 | 122 | data_list = [ 'cifar10_data', 'GSC_v2_data', 'GTSRB_data', 'mnist_data', 'svhn_data' ] 123 | layer_list = [ 7, 6, 7, 7, 7 ] 124 | 125 | total_weight_load_time = 0 126 | total_execution_time = 0 127 | num_execution = 30 128 | 129 | with tf.Graph().as_default() as graph: 130 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 131 | vnn_list, weight_address_list, weight_len_list, \ 132 | virtual_weight_address, \ 133 | page_address_list = init_virtualization(wv, sess) 134 | 135 | for i in range(num_execution): 136 | vnn_no = np.random.randint(len(vnn_list)) 137 | #print('vnn_no:', vnn_no) 138 | 139 | data = __import__(data_list[vnn_no]) 140 | data_set = data.test_set()[0]#[0:1000] 141 | label = data.test_set()[1]#[0:1000] 142 | 143 | weight_load_time, execution_time = in_memory_execute(tf.get_default_graph(), 144 | sess, vnn_list[vnn_no], layer_list[vnn_no], data_set, 145 | virtual_weight_address, 146 | weight_address_list[vnn_no], weight_len_list[vnn_no], 147 | page_address_list[vnn_no], wv.weight_per_page, label) 148 | 149 | total_weight_load_time += weight_load_time 150 | total_execution_time += execution_time 151 | 152 | print('total weights load time : %0.3f ms' % (total_weight_load_time)) 153 | print('total DNN execution time: %0.3f ms' % (total_execution_time)) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /joint_optimization.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..20} 4 | do 5 | echo "$i-th joint optimization" 6 | python weight_virtualization.py -mode=t -vnn_name=mnist -iter=2000 7 | python weight_virtualization.py -mode=e -vnn_name=mnist 8 | python weight_virtualization.py -mode=e -vnn_name=gsc 9 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 10 | python weight_virtualization.py -mode=e -vnn_name=cifar10 11 | python weight_virtualization.py -mode=e -vnn_name=svhn 12 | 13 | python weight_virtualization.py -mode=t -vnn_name=gsc -iter=2000 14 | python weight_virtualization.py -mode=e -vnn_name=mnist 15 | python weight_virtualization.py -mode=e -vnn_name=gsc 16 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 17 | python weight_virtualization.py -mode=e -vnn_name=cifar10 18 | python 
weight_virtualization.py -mode=e -vnn_name=svhn 19 | 20 | python weight_virtualization.py -mode=t -vnn_name=gtsrb -iter=2000 21 | python weight_virtualization.py -mode=e -vnn_name=mnist 22 | python weight_virtualization.py -mode=e -vnn_name=gsc 23 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 24 | python weight_virtualization.py -mode=e -vnn_name=cifar10 25 | python weight_virtualization.py -mode=e -vnn_name=svhn 26 | 27 | python weight_virtualization.py -mode=t -vnn_name=cifar10 -iter=2000 28 | python weight_virtualization.py -mode=e -vnn_name=mnist 29 | python weight_virtualization.py -mode=e -vnn_name=gsc 30 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 31 | python weight_virtualization.py -mode=e -vnn_name=cifar10 32 | python weight_virtualization.py -mode=e -vnn_name=svhn 33 | 34 | python weight_virtualization.py -mode=t -vnn_name=svhn -iter=2000 35 | python weight_virtualization.py -mode=e -vnn_name=mnist 36 | python weight_virtualization.py -mode=e -vnn_name=gsc 37 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 38 | python weight_virtualization.py -mode=e -vnn_name=cifar10 39 | python weight_virtualization.py -mode=e -vnn_name=svhn 40 | done 41 | -------------------------------------------------------------------------------- /mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/__init__.py -------------------------------------------------------------------------------- /mnist/mnist.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist.data-00000-of-00001 -------------------------------------------------------------------------------- /mnist/mnist.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist.index -------------------------------------------------------------------------------- /mnist/mnist.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist.meta -------------------------------------------------------------------------------- /mnist/mnist_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist_fisher.npy -------------------------------------------------------------------------------- /mnist/mnist_network_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist_network_fisher.npy -------------------------------------------------------------------------------- /mnist/mnist_network_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist_network_weight.npy 
-------------------------------------------------------------------------------- /mnist/mnist_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/mnist/mnist_weight.npy -------------------------------------------------------------------------------- /mnist/pintle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import_name = 'mnist_data' 6 | 7 | def next_batch(data_set, batch_size): 8 | data = data_set[0] 9 | label = data_set[1] # one-hot vectors 10 | 11 | data_num = np.random.choice(data.shape[0], size=batch_size, replace=False) 12 | batch = data[data_num,:] 13 | label = label[data_num,:] # one-hot vectors 14 | 15 | return batch, label 16 | 17 | def v_input_variable_names(): 18 | input_variable_names = [ 'neuron_0', 'keep_prob_input', 'keep_prob' ] 19 | return input_variable_names 20 | 21 | def v_output_variable_names(): 22 | output_variable_names = [ 'neuron_6' ] 23 | return output_variable_names 24 | 25 | def v_train_input_variables(): 26 | data = __import__(import_name) 27 | train_set = data.train_set() 28 | train_image_reshaped = np.reshape(train_set[0], ([-1, 28, 28, 1])) 29 | return [[train_image_reshaped, train_set[1]], 1.0, 1.0] 30 | 31 | def v_test_input_variables(): 32 | data = __import__(import_name) 33 | test_set = data.test_set() 34 | test_image_reshaped = np.reshape(test_set[0], ([-1, 28, 28, 1])) 35 | return [[test_image_reshaped, test_set[1]], 1.0, 1.0] 36 | 37 | def v_execute(graph, sess, input_tensors, input_variables, ground_truth): 38 | tensor_y_name = "neuron_6:0" 39 | y = graph.get_tensor_by_name(tensor_y_name) 40 | 41 | # infer 42 | infer_result = sess.run(y, feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 43 | 44 | # accuracy 45 | test_accuracy = None 46 | if ground_truth is not None: 47 | y_ = graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | input_tensors.append(y_) 50 | input_variables.append(ground_truth) 51 | test_accuracy = sess.run(accuracy, 52 | feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 53 | print("Inference accuracy: %f" % test_accuracy) 54 | 55 | return infer_result, test_accuracy 56 | 57 | def v_train(graph, sess, matching_cost, batch_size, train_iteration, get_weight_func): 58 | print("v_train") 59 | 60 | data = __import__(import_name) 61 | train_set = data.train_set() 62 | validation_set = data.test_set() 63 | 64 | # get tensors 65 | tensor_x_name = "neuron_0:0" 66 | x = graph.get_tensor_by_name("neuron_0:0") 67 | y_ = graph.get_tensor_by_name("y_:0") 68 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 69 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 70 | accuracy = graph.get_tensor_by_name("accuracy:0") 71 | cross_entropy = graph.get_tensor_by_name('cross_entropy:0') 72 | 73 | learning_rate = 0.001 74 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 75 | name='matching_cost_optimizer') 76 | loss = optimizer.minimize(tf.add(cross_entropy, matching_cost)) 77 | 78 | sess.run(tf.variables_initializer(optimizer.variables())) 79 | 80 | input_images_validation = validation_set[0] 81 | input_images_validation_reshaped = np.reshape(validation_set[0], ([-1] + x.get_shape().as_list()[1:])) 82 | labels_validation = validation_set[1] 83 | 
highest_accuracy = 0 84 | new_weight_vector = None 85 | 86 | # train 87 | for i in range(train_iteration): 88 | input_data, labels = next_batch(train_set, batch_size) 89 | input_data_reshpaed = np.reshape(input_data, ([-1] + x.get_shape().as_list()[1:])) 90 | 91 | if i % (100) == 0 or i == (train_iteration-1): 92 | original_loss, matching_loss, train_accuracy = sess.run([cross_entropy, matching_cost, accuracy], 93 | feed_dict={x: input_data_reshpaed, y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 94 | print("step %d, training accuracy: %f original loss: %f matching loss: %f" 95 | % (i, train_accuracy, original_loss, matching_loss)) 96 | 97 | # validate 98 | test_accuracy = sess.run(accuracy, feed_dict={ 99 | x: input_images_validation_reshaped, y_: labels_validation, 100 | keep_prob_input: 1.0, keep_prob: 1.0}) 101 | print("step %d, Validation accuracy: %f" % (i, test_accuracy)) 102 | 103 | 104 | if i == 0: 105 | highest_accuracy = test_accuracy 106 | else: 107 | if test_accuracy > highest_accuracy: 108 | new_weight_vector = get_weight_func(sess) 109 | highest_accuracy = test_accuracy 110 | print('get new weight for', highest_accuracy) 111 | 112 | sess.run(loss, feed_dict={x: input_data_reshpaed, 113 | y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 114 | 115 | return new_weight_vector 116 | 117 | def v_fx_tensors(graph): 118 | y = graph.get_tensor_by_name("neuron_6:0") 119 | row_idx = tf.range(tf.shape(y)[0]) 120 | col_idx = tf.argmax(y, axis=1, output_type=tf.dtypes.int32) 121 | full_indices = tf.stack([row_idx, col_idx], axis=1) 122 | fx_tensors = tf.gather_nd(y, full_indices) 123 | return fx_tensors 124 | 125 | def main(): 126 | print(v_test_input_variables()) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /mnist_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | from tensorflow.examples.tutorials.mnist import input_data 5 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 6 | 7 | def train_set(): 8 | return mnist.train.images, mnist.train.labels 9 | 10 | def validation_set(): 11 | return mnist.validation.images, mnist.validation.labels 12 | 13 | def test_set(): 14 | return mnist.test.images, mnist.test.labels 15 | 16 | def main(): 17 | print(train_set()[0]) 18 | print(train_set()[0].shape) 19 | print(train_set()[1]) 20 | print(train_set()[1].shape) 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /sequential_optimization.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python weight_virtualization.py -mode=t -vnn_name=mnist -iter=20000 4 | python weight_virtualization.py -mode=e -vnn_name=mnist 5 | python weight_virtualization.py -mode=e -vnn_name=gsc 6 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 7 | python weight_virtualization.py -mode=e -vnn_name=cifar10 8 | python weight_virtualization.py -mode=e -vnn_name=svhn 9 | 10 | python weight_virtualization.py -mode=t -vnn_name=gsc -iter=20000 11 | python weight_virtualization.py -mode=e -vnn_name=mnist 12 | python weight_virtualization.py -mode=e -vnn_name=gsc 13 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 14 | python weight_virtualization.py -mode=e -vnn_name=cifar10 15 | python weight_virtualization.py -mode=e -vnn_name=svhn 16 | 17 | python 
weight_virtualization.py -mode=t -vnn_name=gtsrb -iter=20000 18 | python weight_virtualization.py -mode=e -vnn_name=mnist 19 | python weight_virtualization.py -mode=e -vnn_name=gsc 20 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 21 | python weight_virtualization.py -mode=e -vnn_name=cifar10 22 | python weight_virtualization.py -mode=e -vnn_name=svhn 23 | 24 | python weight_virtualization.py -mode=t -vnn_name=cifar10 -iter=20000 25 | python weight_virtualization.py -mode=e -vnn_name=mnist 26 | python weight_virtualization.py -mode=e -vnn_name=gsc 27 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 28 | python weight_virtualization.py -mode=e -vnn_name=cifar10 29 | python weight_virtualization.py -mode=e -vnn_name=svhn 30 | 31 | python weight_virtualization.py -mode=t -vnn_name=svhn -iter=20000 32 | python weight_virtualization.py -mode=e -vnn_name=mnist 33 | python weight_virtualization.py -mode=e -vnn_name=gsc 34 | python weight_virtualization.py -mode=e -vnn_name=gtsrb 35 | python weight_virtualization.py -mode=e -vnn_name=cifar10 36 | python weight_virtualization.py -mode=e -vnn_name=svhn 37 | -------------------------------------------------------------------------------- /svhn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/__init__.py -------------------------------------------------------------------------------- /svhn/pintle.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import_name = 'svhn_data' 6 | 7 | def next_batch(data_set, batch_size): 8 | data = data_set[0] 9 | label = data_set[1] # one-hot vectors 10 | 11 | data_num = np.random.choice(data.shape[0], size=batch_size, replace=False) 12 | batch = data[data_num,:] 13 | label = label[data_num,:] # one-hot vectors 14 | 15 | return batch, label 16 | 17 | def v_input_variable_names(): 18 | input_variable_names = [ 'neuron_0', 'keep_prob_input', 'keep_prob' ] 19 | return input_variable_names 20 | 21 | def v_output_variable_names(): 22 | output_variable_names = [ 'neuron_6' ] 23 | return output_variable_names 24 | 25 | def v_train_input_variables(): 26 | data = __import__(import_name) 27 | train_set = data.train_set() 28 | train_image_reshaped = np.reshape(train_set[0], ([-1, 32, 32, 3])) 29 | return [[train_image_reshaped, train_set[1]], 1.0, 1.0] 30 | 31 | def v_test_input_variables(): 32 | data = __import__(import_name) 33 | test_set = data.test_set() 34 | test_image_reshaped = np.reshape(test_set[0], ([-1, 32, 32, 3])) 35 | return [[test_image_reshaped, test_set[1]], 1.0, 1.0] 36 | 37 | def v_execute(graph, sess, input_tensors, input_variables, ground_truth): 38 | tensor_y_name = "neuron_6:0" 39 | y = graph.get_tensor_by_name(tensor_y_name) 40 | 41 | # infer 42 | infer_result = sess.run(y, feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 43 | 44 | # accuracy 45 | test_accuracy = None 46 | if ground_truth is not None: 47 | y_ = graph.get_tensor_by_name("y_:0") 48 | accuracy = graph.get_tensor_by_name("accuracy:0") 49 | input_tensors.append(y_) 50 | input_variables.append(ground_truth) 51 | test_accuracy = sess.run(accuracy, 52 | feed_dict={t: v for t,v in zip(input_tensors, input_variables)}) 53 | print("Inference accuracy: %f" % test_accuracy) 54 | 55 | return infer_result, test_accuracy 
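# v_execute (above) runs a plain forward pass: it zips the caller-supplied input
# tensors and values into a feed_dict and evaluates the accuracy tensor only when
# ground-truth labels are provided. v_train (below) fine-tunes the network against
# cross_entropy plus the matching_cost term passed in by the caller, keeping the
# weight vector that achieved the best validation accuracy so far.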
56 | 57 | def v_train(graph, sess, matching_cost, batch_size, train_iteration, get_weight_func): 58 | print("v_train") 59 | 60 | data = __import__(import_name) 61 | train_set = data.train_set() 62 | validation_set = data.test_set() 63 | 64 | # get tensors 65 | tensor_x_name = "neuron_0:0" 66 | x = graph.get_tensor_by_name("neuron_0:0") 67 | y_ = graph.get_tensor_by_name("y_:0") 68 | keep_prob_input = graph.get_tensor_by_name("keep_prob_input:0") 69 | keep_prob = graph.get_tensor_by_name("keep_prob:0") 70 | accuracy = graph.get_tensor_by_name("accuracy:0") 71 | cross_entropy = graph.get_tensor_by_name('cross_entropy:0') 72 | 73 | learning_rate = 0.001 74 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, 75 | name='matching_cost_optimizer') 76 | loss = optimizer.minimize(tf.add(cross_entropy, matching_cost)) 77 | 78 | sess.run(tf.variables_initializer(optimizer.variables())) 79 | 80 | input_images_validation = validation_set[0] 81 | input_images_validation_reshaped = np.reshape(validation_set[0], ([-1] + x.get_shape().as_list()[1:])) 82 | labels_validation = validation_set[1] 83 | highest_accuracy = 0 84 | new_weight_vector = None 85 | 86 | # train 87 | for i in range(train_iteration): 88 | input_data, labels = next_batch(train_set, batch_size) 89 | input_data_reshpaed = np.reshape(input_data, ([-1] + x.get_shape().as_list()[1:])) 90 | 91 | if i % (100) == 0 or i == (train_iteration-1): 92 | original_loss, matching_loss, train_accuracy = sess.run([cross_entropy, matching_cost, accuracy], 93 | feed_dict={x: input_data_reshpaed, y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 94 | print("step %d, training accuracy: %f original loss: %f matching loss: %f" 95 | % (i, train_accuracy, original_loss, matching_loss)) 96 | 97 | # validate 98 | test_accuracy = sess.run(accuracy, feed_dict={ 99 | x: input_images_validation_reshaped, y_: labels_validation, 100 | keep_prob_input: 1.0, keep_prob: 1.0}) 101 | print("step %d, Validation accuracy: %f" % (i, test_accuracy)) 102 | 103 | 104 | if i == 0: 105 | highest_accuracy = test_accuracy 106 | else: 107 | if test_accuracy > highest_accuracy: 108 | new_weight_vector = get_weight_func(sess) 109 | highest_accuracy = test_accuracy 110 | print('get new weight for', highest_accuracy) 111 | 112 | sess.run(loss, feed_dict={x: input_data_reshpaed, 113 | y_: labels, keep_prob_input: 1.0, keep_prob: 1.0}) 114 | 115 | return new_weight_vector 116 | 117 | def v_fx_tensors(graph): 118 | y = graph.get_tensor_by_name("neuron_6:0") 119 | row_idx = tf.range(tf.shape(y)[0]) 120 | col_idx = tf.argmax(y, axis=1, output_type=tf.dtypes.int32) 121 | full_indices = tf.stack([row_idx, col_idx], axis=1) 122 | fx_tensors = tf.gather_nd(y, full_indices) 123 | return fx_tensors 124 | 125 | def main(): 126 | print(v_test_input_variables()) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /svhn/svhn.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn.data-00000-of-00001 -------------------------------------------------------------------------------- /svhn/svhn.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn.index 
-------------------------------------------------------------------------------- /svhn/svhn.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn.meta -------------------------------------------------------------------------------- /svhn/svhn_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn_fisher.npy -------------------------------------------------------------------------------- /svhn/svhn_network_fisher.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn_network_fisher.npy -------------------------------------------------------------------------------- /svhn/svhn_network_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn_network_weight.npy -------------------------------------------------------------------------------- /svhn/svhn_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/svhn/svhn_weight.npy -------------------------------------------------------------------------------- /svhn_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | 5 | train_path = '...' 6 | test_path = '...' 
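# train_path/test_path above are unused placeholders; the svhn_*.npy arrays loaded
# below are expected to have been fetched beforehand by download_dataset.sh.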
7 | 8 | #""" 9 | svhn_train_data = np.load('svhn_train_data.npy') 10 | svhn_train_label = np.load('svhn_train_label.npy') 11 | 12 | svhn_test_data = np.load('svhn_test_data.npy') 13 | svhn_test_label = np.load('svhn_test_label.npy') 14 | 15 | svhn_validation_data = np.load('svhn_validation_data.npy') 16 | svhn_validation_label = np.load('svhn_validation_label.npy') 17 | 18 | #""" 19 | 20 | def train_set(): 21 | return svhn_train_data, svhn_train_label 22 | 23 | def test_set(): 24 | return svhn_test_data, svhn_test_label 25 | 26 | def validation_set(): 27 | return svhn_validation_data, svhn_validation_label 28 | 29 | def create_data_files(): 30 | return 31 | 32 | def main(): 33 | #create_data_files() 34 | print(train_set()[0]) 35 | print(train_set()[0].shape) 36 | print(train_set()[1]) 37 | print(train_set()[1].shape) 38 | 39 | if __name__ == '__main__': 40 | main() 41 | 42 | -------------------------------------------------------------------------------- /tf_operation.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/tf_operation.so -------------------------------------------------------------------------------- /tf_operation/build_tf_operation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 4 | TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) 5 | 6 | nvcc -std=c++11 -c -o tf_operation.cu.o tf_operation.cu ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED 7 | g++ -std=c++11 -shared -o tf_operation.so tf_operation.cc tf_operation.cu.o ${TF_CFLAGS[@]} -fPIC -lcuda ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 8 | 9 | cp tf_operation.so .. 10 | -------------------------------------------------------------------------------- /tf_operation/build_tf_operation_nano.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 4 | TF_LFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) 5 | 6 | nvcc -std=c++11 -c -o tf_operation.cu.o tf_operation.cu ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -D_MWAITXINTRIN_H_INCLUDED 7 | g++ -std=c++11 -shared -o tf_operation.so tf_operation.cc tf_operation.cu.o ${TF_CFLAGS[@]} -fPIC -lcuda ${TF_LFLAGS[@]} -O3 -D GOOGLE_CUDA=1 -I/usr/loca/cuda/include 8 | 9 | cp tf_operation.so .. 10 | -------------------------------------------------------------------------------- /tf_operation/build_weight_loader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | nvcc -c -o weight_loader.cu.o weight_loader.cu -x cu -Xcompiler -fPIC 4 | gcc -shared -o weight_loader.so weight_loader.c weight_loader.cu.o -fPIC -L/usr/local/cuda/lib64 -lcuda -lcudart -O2 -lstdc++ 5 | 6 | cp weight_loader.so .. 
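# weight_loader.so is loaded at runtime with ctypes.CDLL('./weight_loader.so') by
# in-memory_execute.py, which calls its get_weight() entry point to copy the shared
# weight pages into each model's trainable variables before inference.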
7 | -------------------------------------------------------------------------------- /tf_operation/build_weight_loader_nano.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | nvcc -c -o weight_loader.cu.o weight_loader.cu -x cu -Xcompiler -fPIC 4 | gcc -shared -o weight_loader.so weight_loader.c weight_loader.cu.o -fPIC -L/usr/local/cuda/lib64 -lcuda -lcudart -O2 -lstdc++ 5 | 6 | cp weight_loader.so .. 7 | -------------------------------------------------------------------------------- /tf_operation/tf_operation.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include "tensorflow/core/util/tensor_format.h" 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("InitWeight") 9 | .Input("input: float") 10 | .Output("virtual_weight_address: int64"); 11 | 12 | REGISTER_OP("InitPageTable") 13 | .Input("input: int32") 14 | .Output("page_table_address: int64"); 15 | 16 | REGISTER_OP("ReadWeight") 17 | .Attr("start: int") 18 | .Attr("end: int") 19 | .Input("virtual_weight_address: int64") 20 | .Output("weight: float"); 21 | 22 | REGISTER_OP("ReadPageTable") 23 | .Attr("start: int") 24 | .Attr("end: int") 25 | .Input("page_table_address: int64") 26 | .Output("page: int32"); 27 | 28 | REGISTER_OP("GetWeight") 29 | .Attr("num_inputs: int >= 1") 30 | .Attr("page_size: int") 31 | .Input("ref: Ref(num_inputs * float)") 32 | .Input("virtual_weight_address: int64") 33 | .Input("page_table_address: int64") 34 | //.Output("output_ref: Ref(num_inputs * float)") 35 | .SetAllowsUninitializedInput(); 36 | 37 | REGISTER_OP("GetWeightAddress") 38 | .Attr("num_inputs: int >= 1") 39 | .Input("ref: Ref(num_inputs * float)") 40 | .Output("weight_address: int64") 41 | .Output("weight_len: int32") 42 | .SetAllowsUninitializedInput(); 43 | 44 | REGISTER_OP("FreeWeight") 45 | .Input("virtual_weight_address: int64"); 46 | 47 | REGISTER_OP("FreePageTable") 48 | .Input("page_table_address: int64"); 49 | 50 | REGISTER_OP("SharingCost") 51 | .Input("fisher1: float") 52 | .Input("weight1: float") 53 | .Input("fisher2: float") 54 | .Input("weight2: float") 55 | .Output("cost: float"); 56 | 57 | REGISTER_OP("PageAlloc") 58 | .Attr("page_size: int") 59 | .Input("fisher1: float") 60 | .Input("weight1: float") 61 | .Input("page_list1: int64") 62 | .Input("fisher2: float") 63 | .Input("weight2: float") 64 | .Input("page_list2: int64") 65 | .Output("page_allocation: int64") 66 | .Output("total_cost: float"); 67 | 68 | REGISTER_OP("PageAllocMulti") 69 | .Attr("num_of_list: int >= 1") 70 | .Attr("page_size: int") 71 | .Input("curent_fisher: float") 72 | .Input("base_weight: float") 73 | .Input("base_page_list: int64") 74 | .Input("new_fisher: num_of_list * float") 75 | .Input("new_weight: num_of_list * float") 76 | .Input("new_page_list: num_of_list * int64") 77 | .Output("page_allocation: num_of_list * int64") 78 | .Output("total_cost: float"); 79 | 80 | REGISTER_OP("HarmonicMean") 81 | .Attr("num_inputs: int >= 1") 82 | .Input("inputs: num_inputs * float") 83 | .Output("harmonic_mean: float"); 84 | 85 | void InitWeight(const float *weight, const int weight_len, 86 | long long int *address); 87 | void InitPageTable(const int *page_table, const int page_table_len, 88 | long long int *page_table_address); 89 | void ReadWeight(const long long int *address, 90 | int start, int end, 
float *output); 91 | void ReadPageTable(const long long int *address, 92 | int start, int end, int *output); 93 | void GetWeightKernelLauncher(float *input, const int input_len, 94 | const long long int *address, const long long int *page_table_address, 95 | const int page_size, const int start, const int end); 96 | void GetWeightAddress(int64 *weight_address_list, int *weight_len_list, int input_len, 97 | int64 *output1, int *output2); 98 | void FreeWeight(const long long int *address); 99 | void FreePageTable(const long long int *address); 100 | void SharingCostKernelLauncher(const float *input1, const float *input2, const float *input3, 101 | const float *input4, int len, float *output); 102 | void PageAlloc(const float *fisher1, const float *weight1, const float *fisher2, 103 | const float *weight2, int page_list1_len, int page_list2_len, 104 | const long long int *page_list1, const long long int *page_list2, 105 | int page_size, long long int *page_allocation, float *total_cost); 106 | void PageAllocMulti(const float *fisher_addr[], const float *weight_addr[], 107 | const long long int *page_list_addr[], 108 | int num_of_list, int *fisher_len, int *weight_len, int *page_list_len, 109 | int page_size, long long int *page_allocation[], float *total_cost); 110 | void HarmonicMeanKernelLauncher(const float *input_data[], int input_size, int input_len, 111 | float *output); 112 | 113 | class InitWeightOp : public OpKernel { 114 | public: 115 | explicit InitWeightOp(OpKernelConstruction* context) : OpKernel(context) {} 116 | 117 | void Compute(OpKernelContext* context) override { 118 | const Tensor& input_tensor = context->input(0); 119 | auto weight = input_tensor.flat(); 120 | const int weight_len = weight.size(); 121 | 122 | Tensor* output_tensor = nullptr; 123 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), 124 | &output_tensor)); 125 | auto address = output_tensor->template flat(); 126 | 127 | InitWeight(weight.data(), weight_len, address.data()); 128 | } 129 | }; 130 | 131 | class InitPageTableOp : public OpKernel { 132 | public: 133 | explicit InitPageTableOp(OpKernelConstruction* context) : OpKernel(context) {} 134 | 135 | void Compute(OpKernelContext* context) override { 136 | const Tensor& input_tensor = context->input(0); 137 | auto page_table = input_tensor.flat(); 138 | const int page_table_len = page_table.size(); 139 | 140 | Tensor* output_tensor = nullptr; 141 | OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), 142 | &output_tensor)); 143 | auto page_table_address = output_tensor->template flat(); 144 | 145 | InitPageTable(page_table.data(), page_table_len, page_table_address.data()); 146 | } 147 | }; 148 | 149 | class ReadWeightOp : public OpKernel { 150 | public: 151 | int start_, end_; 152 | explicit ReadWeightOp(OpKernelConstruction* context) : OpKernel(context) { 153 | OP_REQUIRES_OK(context, context->GetAttr("start", &start_)); 154 | OP_REQUIRES_OK(context, context->GetAttr("end", &end_)); 155 | } 156 | 157 | void Compute(OpKernelContext* context) override { 158 | const Tensor& address_tensor = context->input(0); 159 | auto address = address_tensor.flat(); 160 | 161 | Tensor* output_tensor = nullptr; 162 | OP_REQUIRES_OK(context, context->allocate_output(0, 163 | TensorShape({end_-start_+1}), &output_tensor)); 164 | auto output = output_tensor->template flat(); 165 | 166 | ReadWeight(address.data(), start_, end_, output.data()); 167 | } 168 | }; 169 | 170 | class ReadPageTableOp : public OpKernel { 171 | public: 172 | int 
start_, end_; 173 | explicit ReadPageTableOp(OpKernelConstruction* context) : OpKernel(context) { 174 | OP_REQUIRES_OK(context, context->GetAttr("start", &start_)); 175 | OP_REQUIRES_OK(context, context->GetAttr("end", &end_)); 176 | } 177 | 178 | void Compute(OpKernelContext* context) override { 179 | const Tensor& address_tensor = context->input(0); 180 | auto address = address_tensor.flat(); 181 | 182 | Tensor* output_tensor = nullptr; 183 | OP_REQUIRES_OK(context, context->allocate_output(0, 184 | TensorShape({end_-start_+1}), &output_tensor)); 185 | auto output = output_tensor->template flat(); 186 | 187 | ReadPageTable(address.data(), start_, end_, output.data()); 188 | } 189 | }; 190 | 191 | class GetWeightOp : public OpKernel { 192 | public: 193 | int page_size_; 194 | explicit GetWeightOp(OpKernelConstruction* context) : OpKernel(context) { 195 | OP_REQUIRES_OK(context, context->GetAttr("page_size", &page_size_)); 196 | } 197 | 198 | void Compute(OpKernelContext* context) override { 199 | 200 | OpMutableInputList ref_inputs; 201 | OP_REQUIRES_OK(context, 202 | context->mutable_input_list("ref", &ref_inputs)); 203 | 204 | int start_id = ref_inputs.size(); 205 | 206 | const Tensor& address_tensor = context->input(start_id); 207 | auto address = address_tensor.flat(); 208 | 209 | const Tensor& page_table_address_tensor = context->input(start_id+1); 210 | auto page_table_address = page_table_address_tensor.flat(); 211 | 212 | int start = 0; 213 | int end = 0; 214 | 215 | for (int i = 0; i < ref_inputs.size(); i++) { 216 | Tensor ref_tensor = ref_inputs.at(i, /*lock_held=*/ true); 217 | auto ref_input = ref_tensor.flat(); 218 | int num_weight = ref_input.size(); 219 | 220 | AllocatorAttributes attr; 221 | attr.set_gpu_compatible(true); 222 | attr.set_nic_compatible(true); 223 | 224 | PersistentTensor persistent_tensor; 225 | Tensor* new_tensor = nullptr; 226 | OP_REQUIRES_OK(context, 227 | context->allocate_persistent(ref_tensor.dtype(), 228 | ref_tensor.shape(), &persistent_tensor, &new_tensor, attr)); 229 | context->clear_recorded_memory(); 230 | context->replace_ref_input(i, *new_tensor, /* lock_held */ true); 231 | Tensor unlocked_input_tensor = context->mutable_input(i, 232 | /* lock_held */ false); 233 | auto input = unlocked_input_tensor.flat(); 234 | const int input_len = input.size(); 235 | 236 | end = start + num_weight - 1; 237 | GetWeightKernelLauncher(input.data(), input_len, address.data(), 238 | page_table_address.data(), page_size_, start, end); 239 | start = end + 1; 240 | } 241 | } 242 | }; 243 | 244 | class GetWeightAddressOp : public OpKernel { 245 | public: 246 | explicit GetWeightAddressOp(OpKernelConstruction* context) : OpKernel(context) {} 247 | 248 | void Compute(OpKernelContext* context) override { 249 | 250 | OpMutableInputList ref_inputs; 251 | OP_REQUIRES_OK(context, 252 | context->mutable_input_list("ref", &ref_inputs)); 253 | 254 | int start_id = ref_inputs.size(); 255 | 256 | Tensor* output_tensor1 = nullptr; 257 | OP_REQUIRES_OK(context, context->allocate_output(0, 258 | TensorShape({ref_inputs.size()}), &output_tensor1)); 259 | auto output1 = output_tensor1->template flat(); 260 | 261 | Tensor* output_tensor2 = nullptr; 262 | OP_REQUIRES_OK(context, context->allocate_output(1, 263 | TensorShape({ref_inputs.size()}), &output_tensor2)); 264 | auto output2 = output_tensor2->template flat(); 265 | 266 | int64 weight_address_list[ref_inputs.size()]; 267 | int weight_len_list[ref_inputs.size()]; 268 | 269 | int start = 0; 270 | int end = 0; 271 | 272 | for 
(int i = 0; i < ref_inputs.size(); i++) { 273 | Tensor ref_tensor = ref_inputs.at(i, /*lock_held=*/ true); 274 | auto ref_input = ref_tensor.flat(); 275 | 276 | AllocatorAttributes attr; 277 | attr.set_gpu_compatible(true); 278 | attr.set_nic_compatible(true); 279 | 280 | PersistentTensor persistent_tensor; 281 | Tensor* new_tensor = nullptr; 282 | OP_REQUIRES_OK(context, 283 | context->allocate_persistent(ref_tensor.dtype(), 284 | ref_tensor.shape(), &persistent_tensor, &new_tensor, attr)); 285 | context->clear_recorded_memory(); 286 | context->replace_ref_input(i, *new_tensor, /* lock_held */ true); 287 | Tensor unlocked_input_tensor = context->mutable_input(i, 288 | /* lock_held */ false); 289 | auto input = unlocked_input_tensor.flat(); 290 | 291 | weight_address_list[i] = (int64)input.data(); 292 | weight_len_list[i] = input.size(); 293 | } 294 | 295 | GetWeightAddress(weight_address_list, weight_len_list, 296 | ref_inputs.size(), output1.data(), output2.data()); 297 | } 298 | }; 299 | 300 | 301 | class FreeWeightOp : public OpKernel { 302 | public: 303 | explicit FreeWeightOp(OpKernelConstruction* context) : OpKernel(context) {} 304 | 305 | void Compute(OpKernelContext* context) override { 306 | const Tensor& address_tensor = context->input(0); 307 | auto address = address_tensor.flat(); 308 | 309 | FreeWeight(address.data()); 310 | } 311 | }; 312 | 313 | class FreePageTableOp : public OpKernel { 314 | public: 315 | explicit FreePageTableOp(OpKernelConstruction* context) : OpKernel(context) {} 316 | 317 | void Compute(OpKernelContext* context) override { 318 | const Tensor& address_tensor = context->input(0); 319 | auto address = address_tensor.flat(); 320 | 321 | FreePageTable(address.data()); 322 | } 323 | }; 324 | 325 | class SharingCostOp : public OpKernel { 326 | public: 327 | explicit SharingCostOp(OpKernelConstruction* context) : OpKernel(context) {} 328 | 329 | void Compute(OpKernelContext* context) override { 330 | const Tensor& input_tensor1 = context->input(0); 331 | auto input1 = input_tensor1.flat(); 332 | 333 | const Tensor& input_tensor2 = context->input(1); 334 | auto input2 = input_tensor2.flat(); 335 | 336 | const Tensor& input_tensor3 = context->input(2); 337 | auto input3 = input_tensor3.flat(); 338 | 339 | const Tensor& input_tensor4 = context->input(3); 340 | auto input4 = input_tensor4.flat(); 341 | 342 | OP_REQUIRES(context, input_tensor1.NumElements() 343 | == input_tensor2.NumElements(), 344 | errors::InvalidArgument("size of input tensors needs to be same")); 345 | OP_REQUIRES(context, input_tensor2.NumElements() 346 | == input_tensor3.NumElements(), 347 | errors::InvalidArgument("size of input tensors needs to be same")); 348 | OP_REQUIRES(context, input_tensor3.NumElements() 349 | == input_tensor4.NumElements(), 350 | errors::InvalidArgument("size of input tensors needs to be same")); 351 | 352 | Tensor* output_tensor = nullptr; 353 | OP_REQUIRES_OK(context, context->allocate_output(0, 354 | TensorShape({}), &output_tensor)); 355 | auto output = output_tensor->template flat(); 356 | 357 | SharingCostKernelLauncher(input1.data(), input2.data(), input3.data(), 358 | input4.data(), input_tensor1.dim_size(0), output.data()); 359 | } 360 | }; 361 | 362 | class PageAllocOp : public OpKernel { 363 | public: 364 | int page_size_; 365 | explicit PageAllocOp(OpKernelConstruction* context) : OpKernel(context) { 366 | OP_REQUIRES_OK(context, context->GetAttr("page_size", &page_size_)); 367 | } 368 | 369 | void Compute(OpKernelContext* context) override { 370 | const 
Tensor& input_tensor1 = context->input(0); 371 | auto fisher1 = input_tensor1.flat(); 372 | 373 | const Tensor& input_tensor2 = context->input(1); 374 | auto weight1 = input_tensor2.flat(); 375 | 376 | const Tensor& input_tensor3 = context->input(2); 377 | auto page_list1 = input_tensor3.flat(); 378 | 379 | const Tensor& input_tensor4 = context->input(3); 380 | auto fisher2 = input_tensor4.flat(); 381 | 382 | const Tensor& input_tensor5 = context->input(4); 383 | auto weight2 = input_tensor5.flat(); 384 | 385 | const Tensor& input_tensor6 = context->input(5); 386 | auto page_list2 = input_tensor6.flat(); 387 | 388 | OP_REQUIRES(context, input_tensor1.dims() == 1, 389 | errors::InvalidArgument("dims of input tensor needs to be 1")); 390 | OP_REQUIRES(context, input_tensor2.dims() == 1, 391 | errors::InvalidArgument("dims of input tensor needs to be 1")); 392 | OP_REQUIRES(context, input_tensor4.dims() == 1, 393 | errors::InvalidArgument("dims of input tensor needs to be 1")); 394 | OP_REQUIRES(context, input_tensor5.dims() == 1, 395 | errors::InvalidArgument("dims of input tensor needs to be 1")); 396 | OP_REQUIRES(context, input_tensor1.NumElements() 397 | == input_tensor2.NumElements(), 398 | errors::InvalidArgument("size of input tensor 1, 2 needs to be same")); 399 | OP_REQUIRES(context, input_tensor4.NumElements() 400 | == input_tensor5.NumElements(), 401 | errors::InvalidArgument("size of input tensor 4, 5 needs to be same")); 402 | 403 | Tensor* output_tensor1 = nullptr; 404 | int output_tensor_len = page_list1.size() < page_list2.size() 405 | ? page_list1.size() : page_list2.size(); 406 | OP_REQUIRES_OK(context, context->allocate_output(0, 407 | TensorShape({output_tensor_len,2}), &output_tensor1)); 408 | auto page_allocation = output_tensor1->template flat(); 409 | 410 | Tensor* output_tensor2 = nullptr; 411 | OP_REQUIRES_OK(context, context->allocate_output(1, 412 | TensorShape({}), &output_tensor2)); 413 | auto total_cost = output_tensor2->template flat(); 414 | 415 | PageAlloc(fisher1.data(), weight1.data(), fisher2.data(), weight2.data(), 416 | page_list1.size(), page_list2.size(), page_list1.data(), 417 | page_list2.data(), page_size_, page_allocation.data(), 418 | total_cost.data()); 419 | } 420 | }; 421 | 422 | class PageAllocMultiOp : public OpKernel { 423 | public: 424 | int page_size_; 425 | explicit PageAllocMultiOp(OpKernelConstruction* context) : OpKernel(context) { 426 | OP_REQUIRES_OK(context, context->GetAttr("page_size", &page_size_)); 427 | } 428 | 429 | void Compute(OpKernelContext* context) override { 430 | const Tensor& input_tensor1 = context->input(0); 431 | auto base_fisher = input_tensor1.flat(); 432 | 433 | const Tensor& input_tensor2 = context->input(1); 434 | auto base_weight = input_tensor2.flat(); 435 | 436 | const Tensor& input_tensor3 = context->input(2); 437 | auto base_page_list = input_tensor3.flat(); 438 | int base_page_list_len = input_tensor3.NumElements(); 439 | 440 | OP_REQUIRES(context, input_tensor1.NumElements() == input_tensor2.NumElements(), 441 | errors::InvalidArgument("size of current fisher and weight needs to be same")); 442 | OP_REQUIRES(context, input_tensor1.NumElements()/page_size_ == input_tensor3.NumElements(), 443 | errors::InvalidArgument("size of current fisher and page list needs to be same")); 444 | 445 | OpInputList new_fisher_list; 446 | OP_REQUIRES_OK(context, 447 | context->input_list("new_fisher", &new_fisher_list)); 448 | int new_fisher_size = new_fisher_list.size(); 449 | 450 | OpInputList new_weight_list; 451 | 
OP_REQUIRES_OK(context, 452 | context->input_list("new_weight", &new_weight_list)); 453 | int new_weight_size = new_weight_list.size(); 454 | 455 | OpInputList new_page_list_list; 456 | OP_REQUIRES_OK(context, 457 | context->input_list("new_page_list", &new_page_list_list)); 458 | int new_page_list_size = new_page_list_list.size(); 459 | 460 | OP_REQUIRES(context, new_fisher_size == new_weight_size, 461 | errors::InvalidArgument("size of new_fisher and new_weight_list needs to be same")); 462 | OP_REQUIRES(context, new_fisher_size == new_page_list_size, 463 | errors::InvalidArgument("size of new_fisher and new_page_list_list needs to be same")); 464 | 465 | const float *fisher_addr[new_fisher_size+1]; 466 | const float *weight_addr[new_weight_size+1]; 467 | const long long int *page_list_addr[new_page_list_size+1]; 468 | int fisher_len[new_fisher_size+1]; 469 | int weight_len[new_weight_size+1]; 470 | int page_list_len[new_page_list_size+1]; 471 | long long int *output_page_list_addr[new_page_list_size]; 472 | 473 | fisher_addr[0] = base_fisher.data(); 474 | weight_addr[0] = base_weight.data(); 475 | page_list_addr[0] = base_page_list.data(); 476 | fisher_len[0] = input_tensor1.NumElements(); 477 | weight_len[0] = input_tensor2.NumElements(); 478 | page_list_len[0] = input_tensor3.NumElements(); 479 | 480 | OpOutputList page_allocation_list; 481 | OP_REQUIRES_OK(context, context->output_list("page_allocation", 482 | &page_allocation_list)); 483 | 484 | for (int i = 0; i < new_fisher_size; i++) { 485 | fisher_len[i+1] = new_fisher_list[i].NumElements(); 486 | weight_len[i+1] = new_weight_list[i].NumElements(); 487 | page_list_len[i+1] = new_page_list_list[i].NumElements(); 488 | 489 | OP_REQUIRES(context, fisher_len[i] == weight_len[i], 490 | errors::InvalidArgument("size of fisher_len and weight_len needs to be same")); 491 | OP_REQUIRES(context, fisher_len[i]/page_size_ == page_list_len[i], 492 | errors::InvalidArgument("size of fisher_len and page_list_len needs to be same")); 493 | 494 | auto new_fisher = new_fisher_list[i].flat(); 495 | fisher_addr[i+1] = new_fisher.data(); 496 | auto new_weight = new_weight_list[i].flat(); 497 | weight_addr[i+1] = new_weight.data(); 498 | auto new_page_list = new_page_list_list[i].flat(); 499 | page_list_addr[i+1] = new_page_list.data(); 500 | 501 | Tensor* page_allocation = nullptr; 502 | OP_REQUIRES_OK(context, page_allocation_list.allocate(i, 503 | TensorShape({page_list_len[i+1]}), &page_allocation)); 504 | auto output_page_list = page_allocation->template flat(); 505 | output_page_list_addr[i] = output_page_list.data(); 506 | } 507 | 508 | Tensor* cost_tensor = nullptr; 509 | OP_REQUIRES_OK(context, context->allocate_output(new_page_list_size, 510 | TensorShape({}), &cost_tensor)); 511 | auto cost = cost_tensor->template flat(); 512 | 513 | PageAllocMulti(fisher_addr, weight_addr, page_list_addr, new_fisher_size+1, 514 | fisher_len, weight_len, page_list_len, page_size_, 515 | output_page_list_addr, cost.data()); 516 | } 517 | }; 518 | 519 | 520 | class HarmonicMeanOp : public OpKernel { 521 | public: 522 | explicit HarmonicMeanOp(OpKernelConstruction* context) : OpKernel(context) {} 523 | 524 | void Compute(OpKernelContext* context) override { 525 | OpInputList inputs; 526 | OP_REQUIRES_OK(context, context->input_list("inputs", &inputs)); 527 | int input_size = inputs.size(); 528 | int input_len = inputs[0].NumElements(); 529 | const float *input_data_addr[input_size]; 530 | 531 | for (int i = 0; i < input_size; i++) { 532 | auto input = 
inputs[i].flat(); 533 | 534 | OP_REQUIRES(context, inputs[i].NumElements() == input_len, 535 | errors::InvalidArgument("len of input tensors needs to be same")); 536 | input_data_addr[i] = (float *)input.data(); 537 | } 538 | 539 | Tensor* output_tensor = nullptr; 540 | OP_REQUIRES_OK(context, context->allocate_output(0, 541 | TensorShape({}), &output_tensor)); 542 | auto output = output_tensor->template flat(); 543 | 544 | HarmonicMeanKernelLauncher(input_data_addr, input_size, input_len, 545 | output.data()); 546 | } 547 | }; 548 | 549 | REGISTER_KERNEL_BUILDER(Name("InitWeight").Device(DEVICE_GPU), InitWeightOp); 550 | REGISTER_KERNEL_BUILDER(Name("InitPageTable").Device(DEVICE_GPU), InitPageTableOp); 551 | REGISTER_KERNEL_BUILDER(Name("ReadWeight").Device(DEVICE_GPU), ReadWeightOp); 552 | REGISTER_KERNEL_BUILDER(Name("ReadPageTable").Device(DEVICE_GPU), ReadPageTableOp); 553 | REGISTER_KERNEL_BUILDER(Name("GetWeight").Device(DEVICE_GPU), GetWeightOp); 554 | REGISTER_KERNEL_BUILDER(Name("GetWeightAddress").Device(DEVICE_GPU), GetWeightAddressOp); 555 | REGISTER_KERNEL_BUILDER(Name("FreeWeight").Device(DEVICE_GPU), FreeWeightOp); 556 | REGISTER_KERNEL_BUILDER(Name("FreePageTable").Device(DEVICE_GPU), FreePageTableOp); 557 | REGISTER_KERNEL_BUILDER(Name("SharingCost").Device(DEVICE_GPU), SharingCostOp); 558 | REGISTER_KERNEL_BUILDER(Name("PageAlloc").Device(DEVICE_GPU), PageAllocOp); 559 | REGISTER_KERNEL_BUILDER(Name("PageAllocMulti").Device(DEVICE_GPU), PageAllocMultiOp); 560 | REGISTER_KERNEL_BUILDER(Name("HarmonicMean").Device(DEVICE_GPU), HarmonicMeanOp); 561 | -------------------------------------------------------------------------------- /tf_operation/tf_operation.cu: -------------------------------------------------------------------------------- 1 | #define EIGEN_USE_GPU 2 | #include 3 | #include 4 | #include 5 | 6 | void InitWeight(const float* input, const int input_len, 7 | long long int* output) 8 | { 9 | float *weight; 10 | int64_t address; 11 | 12 | cudaMalloc(&weight, sizeof(float)*input_len); 13 | cudaMemcpy(weight, input, sizeof(float)*input_len, cudaMemcpyDeviceToDevice); 14 | 15 | address = (int64_t)weight; 16 | cudaMemcpy(output, &address, sizeof(address), cudaMemcpyHostToDevice); 17 | } 18 | 19 | void InitPageTable(const int *input, const int input_len, 20 | long long int *output) 21 | { 22 | int *page_table; 23 | int64_t page_table_address; 24 | 25 | cudaMalloc(&page_table, sizeof(int)*input_len); 26 | cudaMemcpy(page_table, input, sizeof(int)*input_len, cudaMemcpyDeviceToDevice); 27 | 28 | page_table_address = (int64_t)page_table; 29 | cudaMemcpy(output, &page_table_address, sizeof(page_table_address), 30 | cudaMemcpyHostToDevice); 31 | } 32 | 33 | __global__ void GetWeightKernel(float *input, const int input_len, float *addr, 34 | int *page_table_addr, const int page_size, const int start, const int end) 35 | { 36 | int idx, page_num, page, offset; 37 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_len; 38 | i += blockDim.x * gridDim.x) { 39 | idx = start+i; 40 | page_num = idx / page_size; 41 | page = page_table_addr[page_num]; 42 | offset = idx % page_size; 43 | input[i] = addr[page*page_size + offset]; 44 | } 45 | } 46 | 47 | void GetWeightKernelLauncher(float *input, const int input_len, 48 | const long long int* address, const long long int* page_table_address, 49 | const int page_size, const int start, const int end) 50 | { 51 | int64_t addr; 52 | int64_t page_table_addr; 53 | 54 | cudaMemcpy(&addr, address, sizeof(addr), 
cudaMemcpyDeviceToHost); 55 | cudaMemcpy(&page_table_addr, page_table_address, sizeof(page_table_addr), 56 | cudaMemcpyDeviceToHost); 57 | 58 | GetWeightKernel<<<32, 256>>>(input, input_len, (float *)addr, 59 | (int *)page_table_addr, page_size, start, end); 60 | cudaDeviceSynchronize(); 61 | } 62 | 63 | void GetWeightAddress(long long int *weight_address_list, int *weight_len_list, 64 | int input_len, long long int *output1, int *output2) 65 | { 66 | for (int i = 0; i < input_len; i++) { 67 | cudaMemcpy(&output1[i], &weight_address_list[i], sizeof(long long int), 68 | cudaMemcpyHostToDevice); 69 | cudaMemcpy(&output2[i], &weight_len_list[i], sizeof(int), 70 | cudaMemcpyHostToDevice); 71 | } 72 | } 73 | 74 | void ReadWeight(const long long int* address, int start, int end, float *output) 75 | { 76 | int64_t addr; 77 | cudaMemcpy(&addr, address, sizeof(addr), cudaMemcpyDeviceToHost); 78 | cudaMemcpy(output, (const void *)addr, 79 | sizeof(float)*(end-start+1), cudaMemcpyDeviceToDevice); 80 | } 81 | 82 | void ReadPageTable(const long long int* address, int start, int end, int *output) 83 | { 84 | int64_t addr; 85 | cudaMemcpy(&addr, address, sizeof(addr), cudaMemcpyDeviceToHost); 86 | cudaMemcpy(output, (const void *)addr, 87 | sizeof(int)*(end-start+1), cudaMemcpyDeviceToDevice); 88 | } 89 | 90 | void FreeWeight(const long long int* address) 91 | { 92 | cudaFree((void *)address); 93 | } 94 | 95 | void FreePageTable(const long long int* address) 96 | { 97 | cudaFree((void *)address); 98 | } 99 | 100 | __device__ __forceinline__ float atomicMinFloat(float *addr, float value) { 101 | float old; 102 | old = (value >= 0) ? __int_as_float(atomicMin((int *)addr, __float_as_int(value))) : 103 | __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value))); 104 | 105 | return old; 106 | } 107 | 108 | __device__ __forceinline__ float atomicMaxFloat(float *addr, float value) { 109 | float old; 110 | old = (value >= 0) ? 
__int_as_float(atomicMax((int *)addr, __float_as_int(value))) : 111 | __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value))); 112 | 113 | return old; 114 | } 115 | 116 | __device__ float SharingCost(const float *fisher1, const float *weight1, 117 | const float *fisher2, const float *weight2, int len) 118 | { 119 | float cost = 0; 120 | for (int p = 0; p < len; p++) { 121 | if (fisher1[p] <= 0 || fisher2[p] <= 0) { 122 | continue; 123 | } 124 | 125 | float fisher_cost = fisher1[p] + fisher2[p]; 126 | //float fisher_cost = fisher1[p] * fisher2[p]; 127 | float weight_dist = weight1[p] - weight2[p]; 128 | float weight_cost = weight_dist * weight_dist; 129 | cost += (fisher_cost * weight_cost); 130 | } 131 | 132 | return cost; 133 | } 134 | 135 | __global__ void SharingCostKernel(const float *fisher1, const float *weight1, 136 | const float *fisher2, const float *weight2, int len, float *cost) 137 | { 138 | *cost = SharingCost(fisher1, weight1, fisher2, weight2, len); 139 | } 140 | 141 | void SharingCostKernelLauncher(const float* input1, const float* input2, const float* input3, 142 | const float* input4, int len, float *output) 143 | { 144 | SharingCostKernel<<<32, 128>>>(input1, input2, input3, input4, len, output); 145 | cudaDeviceSynchronize(); 146 | } 147 | 148 | __global__ void MinSharingCostKernel(const float *fisher1, const float *weight1, 149 | const float *fisher2, const float *weight2, long long int *page_list, 150 | int page_num, int page_size, float *min_cost, int *min_idx) 151 | { 152 | float cost, old_cost; 153 | *min_idx = -1; 154 | *min_cost = 10000000000.0; 155 | 156 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < page_num; 157 | i += blockDim.x * gridDim.x) { 158 | if (page_list[i] < 0) { 159 | continue; 160 | } 161 | 162 | cost = SharingCost(fisher1, weight1, &fisher2[page_size*page_list[i]], 163 | &weight2[page_size*page_list[i]], page_size); 164 | old_cost = atomicMinFloat(min_cost, cost); 165 | if (cost < old_cost) 166 | atomicExch(min_idx, i); 167 | } 168 | } 169 | 170 | __global__ void MinSharingCostMultiKernel(const float *fisher1, const float *weight1, 171 | const float *fisher_addr[], const float *weight_addr[], 172 | long long int *base_page_list, int base_page_num, float *page_cost_accum, 173 | long long int *page_occupation_unsorted_addr[], 174 | int page_size, int num_of_list, float *min_cost, int *min_idx) 175 | { 176 | float cost, old_cost; 177 | *min_idx = -1; 178 | *min_cost = 10000000000.0; 179 | 180 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < base_page_num; 181 | i += blockDim.x * gridDim.x) { 182 | if (base_page_list[i] < 0) { 183 | continue; 184 | } 185 | 186 | cost = page_cost_accum[base_page_list[i]]; 187 | const float *fisher2 = fisher_addr[0] + page_size*base_page_list[i]; 188 | const float *weight2 = weight_addr[0] + page_size*base_page_list[i]; 189 | 190 | cost += SharingCost(fisher1, weight1, fisher2, weight2, page_size); 191 | 192 | for (int j = 1; j < num_of_list; j++) { 193 | long long int page 194 | = *(page_occupation_unsorted_addr[j] + base_page_list[i]); 195 | if (page >= 0) { 196 | fisher2 = fisher_addr[j] + page_size*page; 197 | weight2 = weight_addr[j] + page_size*page; 198 | cost += SharingCost(fisher1, weight1, fisher2, weight2, 199 | page_size); 200 | } 201 | } 202 | 203 | old_cost = atomicMinFloat(min_cost, cost); 204 | if (cost < old_cost) 205 | atomicExch(min_idx, i); 206 | } 207 | } 208 | 209 | __global__ void BitonicSortKernel(float *dev_values, long long int *dev_idxs, int j, int k) 210 
| { 211 | unsigned int i, ixj; /* Sorting partners: i and ixj */ 212 | i = threadIdx.x + blockDim.x * blockIdx.x; 213 | ixj = i^j; 214 | 215 | /* The threads with the lowest ids sort the array. */ 216 | if ((ixj)>i) { 217 | if ((i&k)==0) { 218 | if (dev_values[i]<dev_values[ixj]) { 219 | float temp = dev_values[i]; 220 | dev_values[i] = dev_values[ixj]; 221 | dev_values[ixj] = temp; 222 | long long int temp2 = dev_idxs[i]; 223 | dev_idxs[i] = dev_idxs[ixj]; 224 | dev_idxs[ixj] = temp2; 225 | } 226 | } 227 | 228 | if ((i&k)!=0) { 229 | if (dev_values[i]>dev_values[ixj]) { 230 | float temp = dev_values[i]; 231 | dev_values[i] = dev_values[ixj]; 232 | dev_values[ixj] = temp; 233 | long long int temp2 = dev_idxs[i]; 234 | dev_idxs[i] = dev_idxs[ixj]; 235 | dev_idxs[ixj] = temp2; 236 | } 237 | } 238 | } 239 | } 240 | 241 | void BitonicSort(float *values, long long int *idxs, int len) 242 | { 243 | int dev = 0; 244 | int block, thread; 245 | int max_thread = 0; 246 | 247 | cudaSetDevice(dev); 248 | cudaDeviceProp deviceProp; 249 | cudaGetDeviceProperties(&deviceProp, dev); 250 | max_thread = deviceProp.maxThreadsPerBlock; 251 | 252 | if (len <= max_thread) { 253 | block = 1; 254 | thread = len; 255 | } else { 256 | block = len/max_thread; 257 | thread = max_thread; 258 | } 259 | 260 | dim3 blocks(block, 1); 261 | dim3 threads(thread, 1); 262 | 263 | int j, k; 264 | /* Major step */ 265 | for (k = 2; k <= len; k <<= 1) { 266 | /* Minor step */ 267 | for (j=k>>1; j>0; j=j>>1) { 268 | BitonicSortKernel<<<blocks, threads>>>(values, idxs, j, k); 269 | } 270 | } 271 | } 272 | 273 | __global__ void PageFisherSumKernel(const float *fisher, long long int *page_list, 274 | int page_list_len, int page_size, float *sum) 275 | { 276 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < page_list_len; 277 | i += blockDim.x * gridDim.x) { 278 | sum[i] = 0; 279 | for (int j = 0; j < page_size; j++) { 280 | sum[i] += fisher[page_size*page_list[i]+j]; 281 | } 282 | } 283 | } 284 | 285 | __global__ void FloatSetKernel(float *data, float set_value, int len) 286 | { 287 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; 288 | i += blockDim.x * gridDim.x) { 289 | data[i] = set_value; 290 | } 291 | } 292 | 293 | void PageAlloc(const float* fisher1, const float* weight1, const float* fisher2, 294 | const float* weight2, int page_list1_len, int page_list2_len, 295 | const long long int *page_list1, const long long int *page_list2, 296 | int page_size, long long int *page_allocation, float *total_cost) 297 | { 298 | long long int *page_list1_sorted, *page_list2_sorted; 299 | long long int *page_list1_dev, *page_list2_dev; 300 | float *page_sum1_sorted, *page_sum2_sorted; 301 | float *page_sum1_sorted_dev, *page_sum2_sorted_dev; 302 | int valid_page_num = page_list1_len < page_list2_len ?
page_list1_len : page_list2_len; 303 | int page_list1_len_extended = 1 << (int)ceil(log2(page_list1_len)); 304 | int page_list2_len_extended = 1 << (int)ceil(log2(page_list2_len)); 305 | float total_min_cost = 0; 306 | float min_cost; 307 | float *min_cost_dev; 308 | int min_idx; 309 | int *min_idx_dev; 310 | int pos1 = 0, pos2 = 0; 311 | int idx = 0; 312 | 313 | /* copying page_list1 & page_list2 */ 314 | cudaMalloc(&page_list1_dev, sizeof(long long int)*page_list1_len_extended); 315 | cudaMalloc(&page_list2_dev, sizeof(long long int)*page_list2_len_extended); 316 | cudaMemset(page_list1_dev, 0, sizeof(long long int)*page_list1_len_extended); 317 | cudaMemset(page_list2_dev, 0, sizeof(long long int)*page_list2_len_extended); 318 | cudaMemcpy(page_list1_dev, page_list1, sizeof(long long int)*page_list1_len, 319 | cudaMemcpyDeviceToDevice); 320 | cudaMemcpy(page_list2_dev, page_list2, sizeof(long long int)*page_list2_len, 321 | cudaMemcpyDeviceToDevice); 322 | 323 | /* fisher sum per page */ 324 | cudaMalloc(&page_sum1_sorted_dev, sizeof(float)*page_list1_len_extended); 325 | cudaMalloc(&page_sum2_sorted_dev, sizeof(float)*page_list2_len_extended); 326 | FloatSetKernel<<<32, 128>>>(page_sum1_sorted_dev, -page_size-1.0, page_list1_len_extended); 327 | cudaDeviceSynchronize(); 328 | FloatSetKernel<<<32, 128>>>(page_sum2_sorted_dev, -page_size-1.0, page_list2_len_extended); 329 | cudaDeviceSynchronize(); 330 | 331 | PageFisherSumKernel<<<32, 128>>>(fisher1, page_list1_dev, page_list1_len, 332 | page_size, page_sum1_sorted_dev); 333 | cudaDeviceSynchronize(); 334 | PageFisherSumKernel<<<32, 128>>>(fisher2, page_list2_dev, page_list2_len, 335 | page_size, page_sum2_sorted_dev); 336 | cudaDeviceSynchronize(); 337 | #if 0 338 | float *page_sum1_sorted2 = (float *)malloc(sizeof(float)*page_list1_len); 339 | float *page_sum2_sorted2 = (float *)malloc(sizeof(float)*page_list2_len); 340 | cudaMemcpy(page_sum1_sorted2, page_sum1_sorted_dev, 341 | sizeof(float)*page_list1_len, cudaMemcpyDeviceToHost); 342 | cudaMemcpy(page_sum2_sorted2, page_sum2_sorted_dev, 343 | sizeof(float)*page_list2_len, cudaMemcpyDeviceToHost); 344 | 345 | for (int i = 0; i < page_list1_len; i++) { 346 | printf("[%d] %f ", i, page_sum1_sorted2[i]); 347 | } 348 | printf("\n"); 349 | #endif 350 | /* sort fisher sum page (descending order) */ 351 | BitonicSort(page_sum1_sorted_dev, page_list1_dev, page_list1_len_extended); 352 | BitonicSort(page_sum2_sorted_dev, page_list2_dev, page_list2_len_extended); 353 | 354 | page_sum1_sorted = (float *)malloc(sizeof(float)*page_list1_len); 355 | page_sum2_sorted = (float *)malloc(sizeof(float)*page_list2_len); 356 | cudaMemcpy(page_sum1_sorted, page_sum1_sorted_dev, 357 | sizeof(float)*page_list1_len, cudaMemcpyDeviceToHost); 358 | cudaMemcpy(page_sum2_sorted, page_sum2_sorted_dev, 359 | sizeof(float)*page_list2_len, cudaMemcpyDeviceToHost); 360 | cudaFree(page_sum1_sorted_dev); 361 | cudaFree(page_sum2_sorted_dev); 362 | 363 | #if 0 364 | for (int i = 0; i < page_list1_len; i++) { 365 | printf("%f ", page_sum1_sorted[i]); 366 | } 367 | printf("\n"); 368 | for (int i = 0; i < page_list2_len; i++) { 369 | printf("%f ", page_sum2_sorted[i]); 370 | } 371 | printf("\n"); 372 | #endif 373 | page_list1_sorted = (long long int *)malloc(sizeof(long long int)*page_list1_len); 374 | page_list2_sorted = (long long int *)malloc(sizeof(long long int)*page_list2_len); 375 | cudaMemcpy(page_list1_sorted, page_list1_dev, sizeof(long long int)*page_list1_len, 376 | cudaMemcpyDeviceToHost); 377 |
cudaMemcpy(page_list2_sorted, page_list2_dev, sizeof(long long int)*page_list2_len, 378 | cudaMemcpyDeviceToHost); 379 | #if 0 380 | for (int i = 0; i < page_list1_len; i++) { 381 | printf("%lld ", page_list1_sorted[i]); 382 | } 383 | printf("\n"); 384 | #endif 385 | /* skip redundant pages */ 386 | if (page_list1_len < page_list2_len) { 387 | //for (int i = 0; i < (page_list2_len - page_list1_len); i++) { 388 | for (int i = page_list2_len-1; i >= page_list1_len; i--) { 389 | page_list2_sorted[i] = -1; 390 | cudaMemcpy(&page_list2_dev[i], &page_list2_sorted[i], 391 | sizeof(long long int), cudaMemcpyHostToDevice); 392 | } 393 | } else if (page_list1_len > page_list2_len) { 394 | for (int i = 0; i < (page_list1_len - page_list2_len); i++) { 395 | page_list1_sorted[i] = -1; 396 | cudaMemcpy(&page_list1_dev[i], &page_list1_sorted[i], 397 | sizeof(long long int), cudaMemcpyHostToDevice); 398 | } 399 | } 400 | 401 | /* calculate total min cost */ 402 | cudaMalloc(&min_cost_dev, sizeof(float)); 403 | cudaMalloc(&min_idx_dev, sizeof(int)); 404 | 405 | for (int i = 0; i < valid_page_num; i++) { 406 | if ((i % 10000 == 0) or (i == valid_page_num-1)) { 407 | printf("%8d-th page\n", i); 408 | } 409 | 410 | while (page_list1_sorted[pos1] < 0) { 411 | pos1++; 412 | } 413 | while (page_list2_sorted[pos2] < 0) { 414 | pos2++; 415 | } 416 | 417 | if (page_sum1_sorted[pos1] > page_sum2_sorted[pos2]) { 418 | MinSharingCostKernel<<<32, 128>>>(fisher1+page_size*page_list1_sorted[pos1], 419 | weight1+page_size*page_list1_sorted[pos1], 420 | fisher2, weight2, page_list2_dev, page_list2_len, 421 | page_size, min_cost_dev, min_idx_dev); 422 | cudaDeviceSynchronize(); 423 | 424 | cudaMemcpy(&min_idx, min_idx_dev, sizeof(int), cudaMemcpyDeviceToHost); 425 | cudaMemcpy(&page_allocation[idx++], &page_list1_sorted[pos1], 426 | sizeof(long long int), cudaMemcpyHostToDevice); 427 | cudaMemcpy(&page_allocation[idx++], &page_list2_sorted[min_idx], 428 | sizeof(long long int), cudaMemcpyHostToDevice); 429 | 430 | page_list1_sorted[pos1] = -1; 431 | page_list2_sorted[min_idx] = -1; 432 | 433 | cudaMemcpy(&page_list1_dev[pos1], &page_list1_sorted[pos1], 434 | sizeof(long long int), cudaMemcpyHostToDevice); 435 | cudaMemcpy(&page_list2_dev[min_idx], &page_list2_sorted[min_idx], 436 | sizeof(long long int), cudaMemcpyHostToDevice); 437 | } else { 438 | MinSharingCostKernel<<<32, 128>>>(fisher2+page_size*page_list2_sorted[pos2], 439 | weight2+page_size*page_list2_sorted[pos2], 440 | fisher1, weight1, page_list1_dev, page_list1_len, 441 | page_size, min_cost_dev, min_idx_dev); 442 | cudaDeviceSynchronize(); 443 | 444 | cudaMemcpy(&min_idx, min_idx_dev, sizeof(int), cudaMemcpyDeviceToHost); 445 | cudaMemcpy(&page_allocation[idx++], &page_list1_sorted[min_idx], 446 | sizeof(long long int), cudaMemcpyHostToDevice); 447 | cudaMemcpy(&page_allocation[idx++], &page_list2_sorted[pos2], 448 | sizeof(long long int), cudaMemcpyHostToDevice); 449 | 450 | page_list1_sorted[min_idx] = -1; 451 | page_list2_sorted[pos2] = -1; 452 | 453 | cudaMemcpy(&page_list1_dev[min_idx], &page_list1_sorted[min_idx], 454 | sizeof(long long int), cudaMemcpyHostToDevice); 455 | cudaMemcpy(&page_list2_dev[pos2], &page_list2_sorted[pos2], 456 | sizeof(long long int), cudaMemcpyHostToDevice); 457 | } 458 | 459 | cudaMemcpy(&min_cost, min_cost_dev, sizeof(float), cudaMemcpyDeviceToHost); 460 | total_min_cost += min_cost; 461 | } 462 | 463 | cudaMemcpy(total_cost, &total_min_cost, sizeof(float), cudaMemcpyHostToDevice); 464 | 465 | // sanity check 466 | for (int i 
= 0; i < page_list1_len; i++) { 467 | if (page_list1_sorted[i] != -1) { 468 | printf("[PROGRAM EXIT] page alloc fails 1\n"); 469 | exit(EXIT_FAILURE); 470 | } 471 | } 472 | 473 | // sanity check 474 | for (int i = 0; i < page_list2_len; i++) { 475 | if (page_list2_sorted[i] != -1) { 476 | printf("[PROGRAM EXIT] page alloc fails 2\n"); 477 | exit(EXIT_FAILURE); 478 | } 479 | } 480 | 481 | cudaFree(min_idx_dev); 482 | cudaFree(min_cost_dev); 483 | free(page_list2_sorted); 484 | free(page_list1_sorted); 485 | free(page_sum2_sorted); 486 | free(page_sum1_sorted); 487 | cudaFree(page_list2_dev); 488 | cudaFree(page_list1_dev); 489 | } 490 | 491 | void PageAllocMulti(const float *fisher_addr[], const float *weight_addr[], 492 | const long long int *page_list_addr[], 493 | int num_of_list, int *fisher_len, int *weight_len, int *page_list_len, 494 | int page_size, long long int *page_allocation[], float *total_cost) 495 | { 496 | float *page_sum_sorted_addr[num_of_list]; 497 | long long int *page_list_sorted_addr[num_of_list]; 498 | long long int *page_list_sorted_addr_dev[num_of_list]; 499 | long long int *page_occupation_unsorted_addr[num_of_list]; 500 | long long int *page_occupation_unsorted_addr_dev[num_of_list]; 501 | 502 | for (int i = 0; i < num_of_list; i++) { 503 | float *page_sum_sorted_dev; 504 | int page_list_len_extended = 1 << (int)ceil(log2(page_list_len[i])); 505 | 506 | cudaMalloc(&page_list_sorted_addr_dev[i], 507 | sizeof(long long int)*page_list_len_extended); 508 | cudaMemset(page_list_sorted_addr_dev[i], 0, 509 | sizeof(long long int)*page_list_len_extended); 510 | cudaMemcpy(page_list_sorted_addr_dev[i], page_list_addr[i], 511 | sizeof(long long int)*page_list_len[i], 512 | cudaMemcpyDeviceToDevice); 513 | 514 | cudaMalloc(&page_sum_sorted_dev, sizeof(float)*page_list_len_extended); 515 | FloatSetKernel<<<32, 128>>>(page_sum_sorted_dev, -page_size-1.0, 516 | page_list_len_extended); 517 | cudaDeviceSynchronize(); 518 | PageFisherSumKernel<<<32, 128>>>(fisher_addr[i], page_list_sorted_addr_dev[i], 519 | page_list_len[i], page_size, page_sum_sorted_dev); 520 | cudaDeviceSynchronize(); 521 | 522 | BitonicSort(page_sum_sorted_dev, page_list_sorted_addr_dev[i], 523 | page_list_len_extended); 524 | cudaDeviceSynchronize(); 525 | page_sum_sorted_addr[i] = (float *)malloc(sizeof(float)*page_list_len[i]); 526 | cudaMemcpy(page_sum_sorted_addr[i], page_sum_sorted_dev, 527 | sizeof(float)*page_list_len[i], cudaMemcpyDeviceToHost); 528 | page_list_sorted_addr[i] 529 | = (long long int *)malloc(sizeof(long long int)*page_list_len[i]); 530 | cudaMemcpy(page_list_sorted_addr[i], page_list_sorted_addr_dev[i], 531 | sizeof(long long int)*page_list_len[i], cudaMemcpyDeviceToHost); 532 | 533 | page_occupation_unsorted_addr[i] 534 | = (long long int *)malloc(sizeof(long long int)*page_list_len[0]); 535 | for (int j = 0; j < page_list_len[0]; j++) { 536 | *(page_occupation_unsorted_addr[i] + j) = -1; 537 | } 538 | cudaMalloc(&page_occupation_unsorted_addr_dev[i], 539 | sizeof(long long int)*page_list_len[0]); 540 | cudaMemcpy(page_occupation_unsorted_addr_dev[i], 541 | page_occupation_unsorted_addr[i], 542 | sizeof(long long int)*page_list_len[0], cudaMemcpyHostToDevice); 543 | 544 | cudaFree(page_sum_sorted_dev); 545 | } 546 | 547 | int total_new_page_list_len = 0; 548 | 549 | for (int i = 1; i < num_of_list; i++) { 550 | total_new_page_list_len += page_list_len[i]; 551 | } 552 | 553 | if (page_list_len[0] > total_new_page_list_len) { 554 | int diff = page_list_len[0] - 
total_new_page_list_len; 555 | for (int i = 0; i < diff; i++) { 556 | *(page_list_sorted_addr[0] + i) = -1; 557 | } 558 | cudaMemcpy(page_list_sorted_addr_dev[0], page_list_sorted_addr[0], 559 | sizeof(long long int)*diff, cudaMemcpyHostToDevice); 560 | } 561 | 562 | long long int *page_occupation_addr_dev[num_of_list]; 563 | 564 | for (int i = 0; i < num_of_list; i++) { 565 | cudaMalloc(&page_occupation_addr_dev[i], 566 | sizeof(long long int)*page_list_len[0]); 567 | cudaMemcpy(page_occupation_addr_dev[i], page_list_sorted_addr_dev[0], 568 | sizeof(long long int)*page_list_len[0], cudaMemcpyDeviceToDevice); 569 | } 570 | 571 | float *min_cost_dev; 572 | float total_min_cost = 0; 573 | float total_min_cost2 = 0; 574 | int min_idx, *min_idx_dev; 575 | int pos[num_of_list] = { 0, }; 576 | int is_end[num_of_list] = { 0, }; 577 | long long int occupied = -1; 578 | 579 | cudaMalloc(&min_cost_dev, sizeof(float)); 580 | cudaMalloc(&min_idx_dev, sizeof(int)); 581 | 582 | float *page_cost_accum_dev; 583 | cudaMalloc(&page_cost_accum_dev, sizeof(float)*page_list_len[0]); 584 | cudaMemset(page_cost_accum_dev, 0, sizeof(float)*page_list_len[0]); 585 | 586 | for (int i = 0; i < total_new_page_list_len; i++) { 587 | if ((i % 10000 == 0) or (i == total_new_page_list_len-1)) { 588 | printf("%8d-th page\n", i); 589 | } 590 | 591 | float largest_sum = -1; 592 | int largest_set = -1; 593 | 594 | for (int j = 1; j < num_of_list; j++) { 595 | while (!is_end[j] && *(page_list_sorted_addr[j] + pos[j]) < 0) { 596 | pos[j] += 1; 597 | if (pos[j] >= page_list_len[j]) { 598 | is_end[j] = 1; 599 | break; 600 | } 601 | } 602 | 603 | if (is_end[j]) { 604 | continue; 605 | } 606 | 607 | float fisher_sum = *(page_sum_sorted_addr[j] + pos[j]); 608 | if (fisher_sum > largest_sum) { 609 | largest_sum = fisher_sum; 610 | largest_set = j; 611 | } 612 | } 613 | 614 | long long int largest_page 615 | = *(page_list_sorted_addr[largest_set] + pos[largest_set]); 616 | const float *fisher1 617 | = fisher_addr[largest_set] + page_size*largest_page; 618 | const float *weight1 619 | = weight_addr[largest_set] + page_size*largest_page; 620 | 621 | if (largest_set == 0) { 622 | #if 0 623 | float min_min_cost = -1; 624 | int smallest_set = -1; 625 | 626 | for (int j = 1; j < num_of_list; j++) { 627 | int is_occupied 628 | = *(page_occupation_unsorted_addr[j] + largest_page); 629 | if (is_occupied == occupied) { 630 | continue; 631 | } 632 | 633 | MinSharingCostKernel<<<32, 128>>>(fisher1, weight1, fisher_addr[j], 634 | weight_addr[j], page_list_sorted_addr_dev[j], 635 | page_list_len[j], page_size, 636 | min_cost_dev, min_idx_dev); 637 | cudaDeviceSynchronize(); 638 | cudaMemcpy(&min_cost, min_cost_dev, sizeof(float), 639 | cudaMemcpyDeviceToHost); 640 | cudaMemcpy(&min_idx, min_idx_dev, sizeof(int), 641 | cudaMemcpyDeviceToHost); 642 | 643 | if (smallest_set == -1) { 644 | min_min_cost = min_cost; 645 | smallest_set = j; 646 | } else { 647 | if (min_cost < min_min_cost) { 648 | min_min_cost = min_cost; 649 | smallest_set = j; 650 | } 651 | } 652 | } 653 | 654 | long long int smallest_page 655 | = *(page_list_sorted_addr[smallest_set] + min_idx); 656 | 657 | cudaMemcpy(page_allocation[smallest_set] + smallest_page, 658 | page_list_sorted_addr_dev[largest_set] + pos[largest_set], 659 | sizeof(long long int), cudaMemcpyDeviceToDevice); 660 | 661 | *(page_occupation_unsorted_addr[smallest_set] + largest_page) 662 | = (int)occupied; 663 | #endif 664 | } else { 665 | const float **fisher_addr_dev; 666 | cudaMalloc((void **)&fisher_addr_dev, 
sizeof(float *)*num_of_list); 667 | cudaMemcpy(fisher_addr_dev, fisher_addr, 668 | sizeof(float *)*num_of_list, cudaMemcpyHostToDevice); 669 | const float **weight_addr_dev; 670 | cudaMalloc((void **)&weight_addr_dev, sizeof(float *)*num_of_list); 671 | cudaMemcpy(weight_addr_dev, weight_addr, 672 | sizeof(float *)*num_of_list, cudaMemcpyHostToDevice); 673 | long long int **page_occupation_unsorted_addr_dev_dev; 674 | cudaMalloc((void **)&page_occupation_unsorted_addr_dev_dev, 675 | sizeof(long long int *)*num_of_list); 676 | cudaMemcpy(page_occupation_unsorted_addr_dev_dev, 677 | page_occupation_unsorted_addr_dev, 678 | sizeof(long long int *)*num_of_list, 679 | cudaMemcpyHostToDevice); 680 | 681 | MinSharingCostMultiKernel<<<32, 128>>>(fisher1, weight1, fisher_addr_dev, 682 | weight_addr_dev, page_occupation_addr_dev[largest_set], 683 | page_list_len[0], page_cost_accum_dev, 684 | page_occupation_unsorted_addr_dev_dev, 685 | page_size, num_of_list, min_cost_dev, min_idx_dev); 686 | cudaDeviceSynchronize(); 687 | 688 | cudaFree(fisher_addr_dev); 689 | cudaFree(weight_addr_dev); 690 | cudaFree(page_occupation_unsorted_addr_dev_dev); 691 | 692 | float min_cost; 693 | cudaMemcpy(&min_cost, min_cost_dev, 694 | sizeof(float), cudaMemcpyDeviceToHost); 695 | if (min_cost < 0) { 696 | printf("[EXIT] min_cost < 0\n"); 697 | exit(EXIT_FAILURE); 698 | } 699 | 700 | total_min_cost2 += min_cost; 701 | 702 | cudaMemcpy(&min_idx, min_idx_dev, sizeof(int), cudaMemcpyDeviceToHost); 703 | if (min_idx < 0 || min_idx >= page_list_len[0]) { 704 | printf("[EXIT] min_idx < 0 (%d)\n", min_idx); 705 | exit(EXIT_FAILURE); 706 | } 707 | 708 | cudaMemcpy(page_allocation[largest_set-1] + largest_page, 709 | page_occupation_addr_dev[largest_set] + min_idx, 710 | sizeof(long long int), cudaMemcpyDeviceToDevice); 711 | 712 | *(page_list_sorted_addr[largest_set] + pos[largest_set]) = occupied; 713 | cudaMemcpy(page_list_sorted_addr_dev[largest_set] + pos[largest_set], 714 | &occupied, sizeof(long long int), cudaMemcpyHostToDevice); 715 | 716 | cudaMemcpy(page_occupation_addr_dev[largest_set] + min_idx, 717 | &occupied, sizeof(long long int), cudaMemcpyHostToDevice); 718 | 719 | long long int base_page = -1; 720 | cudaMemcpy(&base_page, page_list_sorted_addr_dev[0] + min_idx, 721 | sizeof(long long int), cudaMemcpyDeviceToHost); 722 | *(page_occupation_unsorted_addr[largest_set] + base_page) 723 | = largest_page; 724 | cudaMemcpy(page_occupation_unsorted_addr_dev[largest_set] + base_page, 725 | &largest_page, sizeof(long long int), 726 | cudaMemcpyHostToDevice); 727 | 728 | cudaMemcpy(&page_cost_accum_dev[base_page], min_cost_dev, 729 | sizeof(float), cudaMemcpyDeviceToDevice); 730 | } 731 | } 732 | 733 | float *page_cost_accum = (float *)malloc(sizeof(float)*page_list_len[0]); 734 | cudaMemcpy(page_cost_accum, page_cost_accum_dev, 735 | sizeof(float)*page_list_len[0], cudaMemcpyDeviceToHost); 736 | 737 | for (int i = 0; i < page_list_len[0]; i++) { 738 | total_min_cost += page_cost_accum[i]; 739 | } 740 | free(page_cost_accum); 741 | cudaMemcpy(total_cost, &total_min_cost, sizeof(float), cudaMemcpyHostToDevice); 742 | printf("total_min_cost2 = %f\n", total_min_cost2); 743 | 744 | // sanity check 1 745 | for (int i = 1; i < num_of_list; i++) { 746 | for (int j = 0; j < page_list_len[i]; j++) { 747 | if (*(page_list_sorted_addr[i] + j) != occupied) { 748 | printf("[EXIT] page alloc fails 1-1 (%d)\n", i); 749 | exit(EXIT_FAILURE); 750 | } 751 | long long int page; 752 | cudaMemcpy(&page, page_list_sorted_addr_dev[i] + j, 
753 | sizeof(long long int), cudaMemcpyDeviceToHost); 754 | if (page != occupied) { 755 | printf("[EXIT] page alloc fails 1-2 (%d)\n", i); 756 | exit(EXIT_FAILURE); 757 | } 758 | } 759 | } 760 | 761 | // sanity check 2 762 | for (int i = 1; i < num_of_list; i++) { 763 | int num_of_occupied = 0; 764 | for (int j = 0; j < page_list_len[0]; j++) { 765 | long long int page; 766 | cudaMemcpy(&page, page_occupation_addr_dev[i] + j, 767 | sizeof(long long int), cudaMemcpyDeviceToHost); 768 | if (page == occupied) { 769 | num_of_occupied += 1; 770 | } 771 | } 772 | 773 | if (num_of_occupied != page_list_len[i]) { 774 | printf("[EXIT] page alloc fails 2 (%d)\n", i); 775 | exit(EXIT_FAILURE); 776 | } 777 | } 778 | 779 | // sanity check 3 780 | for (int i = 1; i < num_of_list; i++) { 781 | int num_of_occupation = 0; 782 | int num_of_occupation_dev = 0; 783 | for (int j = 0; j < page_list_len[0]; j++) { 784 | long long int page; 785 | page = *(page_occupation_unsorted_addr[i] + j); 786 | if (page >= 0) { 787 | num_of_occupation +=1; 788 | } 789 | 790 | cudaMemcpy(&page, page_occupation_unsorted_addr_dev[i] + j, 791 | sizeof(long long int), cudaMemcpyDeviceToHost); 792 | if (page >= 0) { 793 | num_of_occupation_dev +=1; 794 | } 795 | } 796 | 797 | if (num_of_occupation != page_list_len[i]) { 798 | printf("[EXIT] page alloc fails 3-1 (%d)\n", i); 799 | exit(EXIT_FAILURE); 800 | } 801 | 802 | if (num_of_occupation_dev != page_list_len[i]) { 803 | printf("[EXIT] page alloc fails 3-2 (%d)\n", i); 804 | exit(EXIT_FAILURE); 805 | } 806 | } 807 | 808 | cudaFree(min_cost_dev); 809 | cudaFree(min_idx_dev); 810 | 811 | for (int i = 0; i < num_of_list; i++) { 812 | cudaFree(page_cost_accum_dev); 813 | cudaFree(page_occupation_addr_dev[i]); 814 | free(page_sum_sorted_addr[i]); 815 | free(page_list_sorted_addr[i]); 816 | cudaFree(page_list_sorted_addr_dev[i]); 817 | cudaFree(page_occupation_unsorted_addr_dev[i]); 818 | free(page_occupation_unsorted_addr[i]); 819 | } 820 | } 821 | 822 | __global__ void HarmonicMeanKernel(const float **input_data_addr, int input_size, 823 | int input_len, float *output) 824 | { 825 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_len; 826 | i += blockDim.x * gridDim.x) { 827 | int numerator = 0; 828 | float denominator = 0; 829 | 830 | for (int j = 0; j < input_size; j++) { 831 | float num = *(input_data_addr[j] + i); 832 | if (num > 0) { 833 | denominator += (1.0f / num); 834 | numerator += 1; 835 | } 836 | } 837 | 838 | if (denominator > 0) { 839 | float harmonic_mean = (float)numerator / denominator; 840 | atomicAdd(output, harmonic_mean); 841 | } 842 | } 843 | } 844 | 845 | void HarmonicMeanKernelLauncher(const float *input_data_addr[], int input_size, 846 | int input_len, float *output) 847 | { 848 | const float **input_data_addr_dev; 849 | cudaMalloc((void **)&input_data_addr_dev, sizeof(float *)*input_size); 850 | cudaMemcpy(input_data_addr_dev, input_data_addr, 851 | sizeof(float *)*input_size, cudaMemcpyHostToDevice); 852 | HarmonicMeanKernel<<<32, 256>>>(input_data_addr_dev, input_size, input_len, output); 853 | } 854 | -------------------------------------------------------------------------------- /tf_operation/tf_operation.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/tf_operation/tf_operation.cu.o -------------------------------------------------------------------------------- 
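A minimal usage sketch for the custom ops compiled into tf_operation.so. This example is an illustration only: it assumes the library has been built with tf_operation/build_tf_operation.sh, that a CUDA GPU is available (every kernel above is registered for DEVICE_GPU only), and that the Python op names are the snake_case forms of the registered names, as in the page_alloc call made by weight_virtualization.py below; the toy fisher/weight values are made up for the example.

import numpy as np
import tensorflow as tf

# Assumption: the compiled library sits next to this script, as in the repo root.
op_lib = tf.load_op_library('./tf_operation.so')

page_size = 100   # should match weight_per_page used by WeightVirtualization
num_pages = 8

# Toy per-weight Fisher information and weight values for two networks,
# each covering num_pages pages of page_size weights.
fisher1 = np.random.rand(num_pages * page_size).astype(np.float32)
weight1 = np.random.normal(scale=0.01, size=num_pages * page_size).astype(np.float32)
fisher2 = np.random.rand(num_pages * page_size).astype(np.float32)
weight2 = np.random.normal(scale=0.01, size=num_pages * page_size).astype(np.float32)
page_list1 = np.arange(num_pages, dtype=np.int32)
page_list2 = np.arange(num_pages, dtype=np.int32)

with tf.Graph().as_default():
    with tf.Session() as sess:
        # SharingCostOp: sum of (F1 + F2) * (w1 - w2)^2 over positions
        # where both Fisher values are positive.
        cost_op = op_lib.sharing_cost(fisher1, weight1, fisher2, weight2)
        # PageAllocOp: greedy page matching between the two networks;
        # returns an (N, 2) array of matched page pairs and the total cost.
        alloc_op = op_lib.page_alloc(fisher1, weight1, page_list1,
                                     fisher2, weight2, page_list2,
                                     page_size=page_size)
        pair_cost = sess.run(cost_op)
        page_match, total_cost = sess.run(alloc_op)
        print('sharing cost:', pair_cost)
        print('page match shape:', page_match.shape, 'total cost:', total_cost)

weight_virtualization.py drives the same page_alloc op this way when matching a new network's pages against the shared virtual weight pages.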
/tf_operation/tf_operation.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/tf_operation/tf_operation.so -------------------------------------------------------------------------------- /tf_operation/weight_loader.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void GetWeightKernelLauncher(float *input, int input_len, float* addr, 4 | int* page_table_addr, int page_size, int start, int end); 5 | 6 | void get_weight(long long int *weight_address_list, int *weight_len_list, int num_of_weight, 7 | long long int virtual_weight_address, long long int page_table_address, 8 | int page_size) 9 | { 10 | int start = 0; 11 | int end = 0; 12 | 13 | for (int i = 0; i < num_of_weight; i++) { 14 | float *input = (float *)weight_address_list[i]; 15 | int input_len = weight_len_list[i]; 16 | float *address = (float *)virtual_weight_address; 17 | int *page_table_addr = (int *)page_table_address; 18 | 19 | end = start + input_len - 1; 20 | GetWeightKernelLauncher(input, input_len, address, page_table_addr, 21 | page_size, start, end); 22 | start = end + 1; 23 | } 24 | } 25 | 26 | -------------------------------------------------------------------------------- /tf_operation/weight_loader.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | __global__ void GetWeightKernel(float *input, int input_len, float *addr, 5 | int *page_table_addr, int page_size, int start, int end) 6 | { 7 | int idx, page_num, page, offset; 8 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_len; 9 | i += blockDim.x * gridDim.x) { 10 | idx = start+i; 11 | page_num = idx / page_size; 12 | page = page_table_addr[page_num]; 13 | offset = idx % page_size; 14 | input[i] = addr[page*page_size + offset]; 15 | } 16 | } 17 | 18 | extern "C" { 19 | void GetWeightKernelLauncher(float *input, int input_len, float* addr, 20 | int* page_table_addr, int page_size, int start, int end) 21 | { 22 | GetWeightKernel<<<32, 256>>>(input, input_len, addr, 23 | page_table_addr, page_size, start, end); 24 | cudaDeviceSynchronize(); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /tf_operation/weight_loader.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/tf_operation/weight_loader.cu.o -------------------------------------------------------------------------------- /tf_operation/weight_loader.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/tf_operation/weight_loader.so -------------------------------------------------------------------------------- /weight_loader.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning1234embed/NeuralWeightVirtualization/b799860c54ac7a9b3cdaf8398bd7b035e6757ae2/weight_loader.so -------------------------------------------------------------------------------- /weight_virtualization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 
import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.python.framework import ops 5 | import os 6 | import copy 7 | import pickle 8 | import struct 9 | import sys 10 | import argparse 11 | import importlib 12 | import matplotlib.pyplot as plt 13 | import time 14 | 15 | tf.logging.set_verbosity(tf.logging.ERROR) 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 17 | #np.set_printoptions(threshold=sys.maxsize) 18 | 19 | class VNN: 20 | def __init__(self, network_path, id, state=-1, weight_page_list=None): 21 | self.id = id 22 | self.name = os.path.basename(os.path.normpath(network_path)) 23 | self.state = state 24 | self.weight_page_list = weight_page_list 25 | 26 | self.network_path = network_path 27 | assert os.path.exists(self.network_path), 'No network path %s exists' % self.network_path 28 | 29 | self.meta_filename = self.name + '.meta' 30 | self.meta_filepath = os.path.join(self.network_path, self.meta_filename) 31 | assert os.path.exists(self.meta_filepath), 'No filepath %s exists' % self.meta_filepath 32 | 33 | self.num_of_weight = self.get_weight_num() 34 | 35 | self.model_filepath = os.path.join(self.network_path, self.name) 36 | 37 | self.network_weight_filename = self.name + '_network_weight.npy' 38 | self.network_weight_filepath = os.path.join(self.network_path, self.network_weight_filename) 39 | self.network_fisher_filename = self.name + '_network_fisher.npy' 40 | self.network_fisher_filepath = os.path.join(self.network_path, self.network_fisher_filename) 41 | 42 | self.weight_filename = self.name + '_weight.npy' 43 | self.weight_filepath = os.path.join(self.network_path, self.weight_filename) 44 | self.fisher_filename = self.name + '_fisher.npy' 45 | self.fisher_filepath = os.path.join(self.network_path, self.fisher_filename) 46 | 47 | self.pintle_filename = 'pintle.py' 48 | self.pintle_filepath = os.path.join(self.network_path, self.pintle_filename) 49 | 50 | self.filepath = self.name + '.vnn' 51 | 52 | def get_weight_num(self): 53 | meta_graph_def = tf.MetaGraphDef() 54 | with open(self.meta_filepath, 'rb') as f: 55 | meta_graph_def.MergeFromString(f.read()) 56 | 57 | num_of_weight = 0 58 | with tf.Graph().as_default() as graph: 59 | tf.train.import_meta_graph(meta_graph_def) 60 | trainable_variables = tf.trainable_variables() 61 | for trainable_variable in trainable_variables: 62 | num_of_weight += np.prod(trainable_variable.get_shape().as_list()) 63 | 64 | return num_of_weight 65 | 66 | class WeightVirtualization: 67 | __instance = None 68 | @staticmethod 69 | def getInstance(): 70 | if WeightVirtualization.__instance == None: 71 | WeightVirtualization() 72 | return WeightVirtualization.__instance 73 | 74 | def __init__(self, num_of_weight_page=665, weight_per_page=100, 75 | weight_page_filename='virtual_weight_page.npy', 76 | weight_page_occupation_filename='weight_page_occupation.npy', 77 | weight_virtualization_op_filename='./tf_operation.so'): 78 | 79 | if WeightVirtualization.__instance != None: 80 | raise Exception("this class is a singleton") 81 | else: 82 | WeightVirtualization.__instance = self 83 | 84 | self.num_of_weight_page = num_of_weight_page 85 | self.weight_per_page = weight_per_page 86 | self.weight_page = None 87 | 88 | self.weight_page_filename = weight_page_filename 89 | self.weight_page_occupation_filename = weight_page_occupation_filename 90 | self.weight_virtualization_op_filename = weight_virtualization_op_filename 91 | 92 | if self.load_weight_page() is False: 93 | print('init new weight pages') 94 | self.init_weight_page() 95 | 
self.save_weight_page() 96 | 97 | self.next_vnn_id = 0 98 | self.vnns = {} 99 | self.load_vnns() 100 | 101 | def create_vnn(self, network_path): 102 | # create a vnn 103 | vnn = VNN(network_path, self.next_vnn_id) 104 | if vnn.name in self.vnns: 105 | raise Exception('vnn named %s is already there' % vnn.name) 106 | 107 | # save the network weights 108 | self.save_network_weight(vnn) 109 | 110 | # get and save fisher of network 111 | fisher_information = self.compute_fisher(vnn, 'network') 112 | self.save_network_fisher(vnn, fisher_information) 113 | 114 | return vnn 115 | 116 | def add_vnn(self, network_path): 117 | print('add_vnn') 118 | # create a vnn 119 | vnn = self.create_vnn(network_path) 120 | 121 | # allocate weight pages 122 | self.match_weight_page(vnn) 123 | 124 | # increment next_vnn_id and add vnn to the dictionary 125 | self.vnns[vnn.name] = vnn 126 | self.next_vnn_id += 1 127 | 128 | # save vnn 129 | self.save_vnn(vnn) 130 | 131 | total_vnn_list = [] 132 | for name, vnn in self.vnns.items(): 133 | total_vnn_list.append(vnn) 134 | 135 | total_network_cost = self.calculate_network_cost(total_vnn_list) 136 | print('total_network_cost:', total_network_cost) 137 | 138 | def add_multi_vnns(self, network_path_list): 139 | print('add_multi_vnns') 140 | new_vnn_list = [] 141 | for network_path in network_path_list: 142 | vnn = self.create_vnn(network_path) 143 | new_vnn_list.append(vnn) 144 | self.next_vnn_id += 1 145 | 146 | self.match_weight_page_multi(new_vnn_list) 147 | 148 | for vnn in new_vnn_list: 149 | self.vnns[vnn.name] = vnn 150 | self.save_vnn(vnn) 151 | 152 | total_vnn_list = [] 153 | for name, vnn in self.vnns.items(): 154 | total_vnn_list.append(vnn) 155 | 156 | total_network_cost = self.calculate_network_cost(total_vnn_list) 157 | print('total_network_cost:', total_network_cost) 158 | 159 | def remove_vnn(self, vnn): 160 | self.dematch_weight_page(vnn) 161 | 162 | if vnn.name in self.vnns: 163 | del self.vnns[vnn.name] 164 | 165 | if os.path.exists(vnn.network_weight_filepath): 166 | os.remove(vnn.network_weight_filepath) 167 | 168 | if os.path.exists(vnn.network_fisher_filepath): 169 | os.remove(vnn.network_fisher_filepath) 170 | 171 | if os.path.exists(vnn.weight_filepath): 172 | os.remove(vnn.weight_filepath) 173 | 174 | if os.path.exists(vnn.fisher_filepath): 175 | os.remove(vnn.fisher_filepath) 176 | 177 | if os.path.exists(vnn.filepath): 178 | os.remove(vnn.filepath) 179 | 180 | def train_vnn(self, vnn, iteration): 181 | with tf.Graph().as_default() as graph: 182 | with tf.Session(graph=graph) as sess: 183 | self.restore_vnn(vnn, graph, sess) 184 | matching_loss = self.get_matching_loss(vnn, sess, lamb=10.0) 185 | pintle = self.import_pintle(vnn) 186 | weight_vector = pintle.pintle.v_train(graph, sess, matching_loss, 187 | 100, iteration, self.get_weight_from_vnn) 188 | #weight_vector = self.get_weight_from_vnn(sess) 189 | if weight_vector is not None: 190 | self.apply_weight_to_page(vnn, weight_vector) 191 | self.save_weight_page() 192 | 193 | def execute_vnn(self, vnn, input_variables, ground_truth=None): 194 | pintle = self.import_pintle(vnn) 195 | input_variable_names = pintle.pintle.v_input_variable_names() 196 | input_tensors = [] 197 | 198 | with tf.Graph().as_default() as graph: 199 | with tf.Session(graph=graph) as sess: 200 | self.restore_vnn(vnn, graph, sess) 201 | 202 | for variable_name in input_variable_names: 203 | input_tensor_name = variable_name + ':0' 204 | input_tensors.append(graph.get_tensor_by_name(input_tensor_name)) 205 | 206 | return 
pintle.pintle.v_execute(graph, sess, 207 | input_tensors, input_variables, ground_truth) 208 | 209 | def load_vnns(self): 210 | for file in sorted(os.listdir("./")): 211 | if file.endswith(".vnn"): 212 | vnn = self.load_vnn(file) 213 | self.vnns[vnn.name] = vnn 214 | if vnn.id >= self.next_vnn_id: 215 | self.next_vnn_id = vnn.id + 1 216 | 217 | def load_vnn(self, filepath): 218 | with open(filepath, 'rb') as f: 219 | vnn = pickle.load(f) 220 | return vnn 221 | 222 | def save_vnn(self, vnn): 223 | with open(vnn.filepath, 'wb') as f: 224 | pickle.dump(vnn, f) 225 | 226 | def load_weight(self, vnn): 227 | weight = np.load(vnn.weight_filepath, allow_pickle=True) 228 | return weight 229 | 230 | def save_weight(self, vnn): 231 | with tf.Graph().as_default() as graph: 232 | with tf.Session(graph=graph) as sess: 233 | self.restore_vnn(vnn, graph, sess) 234 | tensor_weights = tf.trainable_variables() 235 | weights = sess.run(tensor_weights) 236 | 237 | np.save(vnn.weight_filepath, weights) 238 | print(vnn.weight_filepath) 239 | 240 | def load_network_weight(self, vnn): 241 | network_weight = np.load(vnn.network_weight_filepath, allow_pickle=True) 242 | return network_weight 243 | 244 | def save_network_weight(self, vnn): 245 | with tf.Graph().as_default() as graph: 246 | with tf.Session(graph=graph) as sess: 247 | self.restore_network(vnn, sess) 248 | tensor_weights = tf.trainable_variables() 249 | network_weights = sess.run(tensor_weights) 250 | 251 | np.save(vnn.network_weight_filepath, network_weights) 252 | print(vnn.network_weight_filepath) 253 | 254 | def load_network_fisher(self, vnn): 255 | fisher_information = np.load(vnn.network_fisher_filepath, allow_pickle=True) 256 | return fisher_information 257 | 258 | def save_network_fisher(self, vnn, fisher_information): 259 | np.save(vnn.network_fisher_filepath, fisher_information) 260 | print(vnn.network_fisher_filepath) 261 | 262 | def load_vnn_fisher(self, vnn): 263 | if os.path.exists(vnn.fisher_filepath): 264 | fisher_information = np.load(vnn.fisher_filepath, allow_pickle=True) 265 | return fisher_information 266 | else: 267 | return None 268 | 269 | def save_vnn_fisher(self, vnn, fisher_information): 270 | np.save(vnn.fisher_filepath, fisher_information) 271 | print(vnn.fisher_filepath) 272 | 273 | def load_weight_to_vnn(self, vnn, graph, sess, weight_vector): 274 | assign_tensor_weight = [] 275 | tensor_weights = tf.trainable_variables() 276 | start_idx = 0 277 | end_idx = 0 278 | 279 | for weight in tensor_weights: 280 | end_idx = start_idx + np.prod(weight.get_shape().as_list()) 281 | assign_weight = tf.assign(weight, weight_vector[start_idx:end_idx].reshape(weight.shape)) 282 | assign_tensor_weight.append(assign_weight) 283 | start_idx = end_idx 284 | 285 | sess.run(assign_tensor_weight) 286 | 287 | def get_weight_from_vnn(self, sess): 288 | tensor_weights = tf.trainable_variables() 289 | weights = sess.run(tensor_weights) 290 | 291 | weight_vector_list = [] 292 | for weight in weights: 293 | weight_vector_list.append(weight.reshape((weight.size))) 294 | 295 | weight_vector = np.concatenate(weight_vector_list) 296 | return weight_vector 297 | 298 | def get_weight_from_page(self, vnn): 299 | weight_vector_list = [] 300 | for page in vnn.weight_page_list: 301 | weight_vector_list.append(self.weight_page[page]) 302 | 303 | weight_vector = np.concatenate(weight_vector_list) 304 | return weight_vector[0:vnn.num_of_weight] 305 | 306 | def apply_weight_to_page(self, vnn, weight_vector): 307 | start_idx = 0 308 | end_idx = 0 309 | 310 | for 
page in vnn.weight_page_list: 311 | end_idx = start_idx + self.weight_page[page].size 312 | if end_idx <= len(weight_vector): 313 | self.weight_page[page] = copy.deepcopy(weight_vector[start_idx:end_idx]) 314 | start_idx = end_idx 315 | else: 316 | end_idx = len(weight_vector) 317 | for i in range(end_idx-start_idx): 318 | self.weight_page[page][i] = copy.deepcopy(weight_vector[start_idx+i]) 319 | 320 | def init_weight_page(self, num_of_weight_page=None, weight_per_page=None): 321 | if num_of_weight_page is not None: 322 | self.num_of_weight_page = num_of_weight_page 323 | 324 | if weight_per_page is not None: 325 | self.weight_per_page = weight_per_page 326 | 327 | self.weight_page = np.random.normal(scale=0.01, 328 | size=(self.num_of_weight_page, self.weight_per_page)).astype(np.float32) 329 | 330 | def load_weight_page(self, weight_page_filename=None): 331 | if weight_page_filename is not None: 332 | self.weight_page_filename = weight_page_filename 333 | if os.path.exists(self.weight_page_filename): 334 | self.weight_page = np.load(self.weight_page_filename, allow_pickle=True) 335 | self.num_of_weight_page = len(self.weight_page) 336 | self.weight_per_page = len(self.weight_page[0]) 337 | return True 338 | else: 339 | return False 340 | 341 | def save_weight_page(self, weight_page_filename=None): 342 | if weight_page_filename is not None: 343 | self.weight_page_filename = weight_page_filename 344 | np.save(self.weight_page_filename, self.weight_page) 345 | 346 | def update_weight_page_occupation(self, vnn): 347 | weight_page_occupation = self.load_weight_page_occupation() 348 | 349 | for i in range(len(vnn.weight_page_list)): 350 | page_no = vnn.weight_page_list[i] 351 | weight_page_occupation[page_no].append([vnn.id, i]) 352 | 353 | self.save_weight_page_occupation(weight_page_occupation) 354 | 355 | def load_weight_page_occupation(self): 356 | if os.path.exists(self.weight_page_occupation_filename): 357 | return np.load(self.weight_page_occupation_filename, allow_pickle=True) 358 | else: 359 | weight_page_occupation = [[] for _ in np.arange(self.num_of_weight_page, dtype=np.int32)] 360 | return weight_page_occupation 361 | 362 | def save_weight_page_occupation(self, weight_page_occupation): 363 | np.save(self.weight_page_occupation_filename, weight_page_occupation) 364 | 365 | def import_pintle(self, vnn): 366 | pintle_name = os.path.splitext(vnn.pintle_filepath)[0] 367 | pintle_import_name = pintle_name.replace('/', '.') 368 | pintle = __import__(pintle_import_name) 369 | return pintle 370 | 371 | def restore_network(self, vnn, sess): 372 | saver = tf.train.import_meta_graph(vnn.meta_filepath) 373 | saver.restore(sess, vnn.model_filepath) 374 | 375 | def restore_vnn(self, vnn, graph, sess): 376 | tf.train.import_meta_graph(vnn.meta_filepath) 377 | weight_vector = self.get_weight_from_page(vnn) 378 | self.load_weight_to_vnn(vnn, graph, sess, weight_vector) 379 | 380 | def do_compute_fisher(self, sess, fx_tensors, x_tensors, input_tensors,\ 381 | input_variables, num_samples=100): 382 | print("do_compute_fisher") 383 | 384 | # input_variable[0] is data 385 | assert input_variables[0].shape[0] >= num_samples 386 | 387 | fisher_information = [] 388 | for v in range(len(x_tensors)): 389 | fisher_information.append(np.zeros(x_tensors[v].get_shape().as_list()).astype(np.float32)) 390 | 391 | for i in range(num_samples): 392 | data_idx = np.random.randint(input_variables[0].shape[0]) 393 | sampled_data = input_variables[0][data_idx:data_idx+1] 394 | sampled_input_variables = [ sampled_data 
] + input_variables[1:] 395 | print ('sample num: %4d, data_idx: %5d' % (i, data_idx)) 396 | 397 | derivatives, prob = sess.run([tf.gradients(tf.log(fx_tensors), x_tensors), fx_tensors], 398 | feed_dict={t: v for t,v in zip(input_tensors, sampled_input_variables)}) 399 | 400 | for v in range(len(fisher_information)): 401 | fisher_information[v] += np.square(derivatives[v]) * prob 402 | 403 | for v in range(len(fisher_information)): 404 | fisher_information[v] /= num_samples 405 | 406 | return fisher_information 407 | 408 | def compute_fisher(self, vnn, target): 409 | print('compute_fisher') 410 | 411 | pintle = self.import_pintle(vnn) 412 | input_variable_names = pintle.pintle.v_input_variable_names() 413 | 414 | with tf.Graph().as_default() as graph: 415 | with tf.Session(graph=graph) as sess: 416 | if target == 'network': 417 | self.restore_network(vnn, sess) 418 | elif target == 'vnn': 419 | self.restore_vnn(vnn, graph, sess) 420 | else: 421 | raise Exception('Neither network nor vnn') 422 | 423 | input_tensors = [] 424 | for variable_name in input_variable_names: 425 | input_tensor_name = variable_name + ':0' 426 | input_tensors.append(graph.get_tensor_by_name(input_tensor_name)) 427 | 428 | weight_tensors = tf.trainable_variables() 429 | pintle = self.import_pintle(vnn) 430 | raw_input_variables = pintle.pintle.v_train_input_variables() 431 | input_variables = [ raw_input_variables[0][0] ] 432 | for i in range(1, len(raw_input_variables)): 433 | input_variables.append(raw_input_variables[i]) 434 | 435 | target_tensors = pintle.pintle.v_fx_tensors(graph) 436 | 437 | fisher_information = self.do_compute_fisher(sess, target_tensors, \ 438 | weight_tensors, input_tensors, input_variables, num_samples=100) 439 | 440 | return fisher_information 441 | 442 | def get_fisher_sum_vector(self): 443 | fisher_dic = {} 444 | for name_, vnn_ in self.vnns.items(): 445 | fisher = self.load_network_fisher(vnn_) 446 | #fisher = self.load_vnn_fisher(vnn_) 447 | fisher_vector = self.vectorize_list(fisher) 448 | fisher_dic[vnn_.id] = fisher_vector 449 | 450 | fisher_sum_vector = np.zeros(self.num_of_weight_page*self.weight_per_page, dtype=np.float32) 451 | weight_page_occupation = self.load_weight_page_occupation() 452 | 453 | for i in range(len(weight_page_occupation)): 454 | fisher_list = [] 455 | for occupation in weight_page_occupation[i]: 456 | size = self.weight_per_page 457 | src_start = occupation[1]*self.weight_per_page 458 | if src_start + size > len(fisher_dic[occupation[0]]): 459 | size = len(fisher_dic[occupation[0]]) % self.weight_per_page 460 | src_end = src_start + size 461 | fisher = fisher_dic[occupation[0]][src_start:src_end] 462 | if len(fisher) != self.weight_per_page: 463 | fisher = np.concatenate([fisher, 464 | np.zeros(self.weight_per_page-len(fisher), dtype=np.float32)]) 465 | fisher_list.append(fisher) 466 | 467 | if not fisher_list: 468 | continue 469 | 470 | fisher_page_sum = np.sum(fisher_list, axis=0) 471 | dst_start = i*self.weight_per_page 472 | dst_end = dst_start+self.weight_per_page 473 | fisher_sum_vector[dst_start:dst_end] = fisher_page_sum 474 | 475 | return fisher_sum_vector 476 | 477 | def get_fisher_vector_page_order(self, vnn, target): 478 | fisher = None 479 | if target == 'network': 480 | fisher = self.load_network_fisher(vnn) 481 | elif target == 'vnn': 482 | fisher = self.load_vnn_fisher(vnn) 483 | else: 484 | raise Exception('Neither network nor vnn') 485 | 486 | fisher_vector = self.vectorize_list(fisher) 487 | 488 | fisher_vector_page_order = 
np.zeros(self.num_of_weight_page*self.weight_per_page, 489 | dtype=np.float32) 490 | weight_page_occupation = self.load_weight_page_occupation() 491 | 492 | for i in range(len(weight_page_occupation)): 493 | for occupation in weight_page_occupation[i]: 494 | if occupation[0] == vnn.id: 495 | size = self.weight_per_page 496 | src_start = occupation[1]*self.weight_per_page 497 | if src_start + size > len(fisher_vector): 498 | size = len(fisher_vector) % self.weight_per_page 499 | src_end = src_start + size 500 | fisher = fisher_vector[src_start:src_end] 501 | if len(fisher) != self.weight_per_page: 502 | fisher = np.concatenate([fisher, 503 | np.zeros(self.weight_per_page-len(fisher), 504 | dtype=np.float32)]) 505 | 506 | dst_start = i*self.weight_per_page 507 | dst_end = dst_start+self.weight_per_page 508 | fisher_vector_page_order[dst_start:dst_end] = fisher 509 | break 510 | 511 | return fisher_vector_page_order 512 | 513 | def get_weight_vector_page_order(self, vnn, target): 514 | weight = None 515 | if target == 'network': 516 | weight = self.load_network_weight(vnn) 517 | elif target == 'vnn': 518 | weight = self.load_vnn_weight(vnn) 519 | else: 520 | raise Exception('Neither network nor vnn') 521 | 522 | weight_vector = self.vectorize_list(weight) 523 | 524 | weight_vector_page_order = np.zeros(self.num_of_weight_page*self.weight_per_page, 525 | dtype=np.float32) 526 | weight_page_occupation = self.load_weight_page_occupation() 527 | 528 | for i in range(len(weight_page_occupation)): 529 | for occupation in weight_page_occupation[i]: 530 | if occupation[0] == vnn.id: 531 | size = self.weight_per_page 532 | src_start = occupation[1]*self.weight_per_page 533 | if src_start + size > len(weight_vector): 534 | size = len(weight_vector) % self.weight_per_page 535 | src_end = src_start + size 536 | weight = weight_vector[src_start:src_end] 537 | if len(weight) != self.weight_per_page: 538 | weight = np.concatenate([weight, 539 | np.zeros(self.weight_per_page-len(weight), 540 | dtype=np.float32)]) 541 | 542 | dst_start = i*self.weight_per_page 543 | dst_end = dst_start+self.weight_per_page 544 | weight_vector_page_order[dst_start:dst_end] = weight 545 | break 546 | 547 | return weight_vector_page_order 548 | 549 | def matching_cost_pair(self, fisher1, weight1, fisher2, weight2): 550 | assert len(fisher1) == len(weight1) 551 | assert len(fisher2) == len(weight2) 552 | assert len(fisher1) == len(fisher2) 553 | 554 | fisher_sum = np.add(fisher1, fisher2) 555 | square_weight_diff = np.square(np.subtract(weight1, weight2)) 556 | cost = np.sum(np.multiply(fisher_sum, square_weight_diff)) 557 | 558 | return cost 559 | 560 | def calculate_cost(self, vnn): 561 | print('[calculate_cost]') 562 | fisher_sum_vector = self.get_fisher_sum_vector() 563 | weight_vector = self.weight_page.flatten() 564 | assert len(fisher_sum_vector) == len(weight_vector) 565 | 566 | 567 | network_fisher_vector = self.get_fisher_vector_pad(vnn, 'network') 568 | network_weight_vector = self.get_weight_vector_pad(vnn, 'network') 569 | assert len(network_fisher_vector) == len(network_weight_vector) 570 | 571 | total_cost = 0.0 572 | idx = 0 573 | 574 | for page_no in vnn.weight_page_list: 575 | size = self.weight_per_page 576 | start_n = idx * self.weight_per_page 577 | end_n = start_n + size 578 | start_s = page_no * self.weight_per_page 579 | end_s = start_s + size 580 | 581 | fisher_network = network_fisher_vector[start_n:end_n] 582 | weight_network = network_weight_vector[start_n:end_n] 583 | fisher_star = 
fisher_sum_vector[start_s:end_s] 584 | weight_star = weight_vector[start_s:end_s] 585 | 586 | zero_fisher_network = np.where(fisher_network == 0) 587 | fisher_star[zero_fisher_network] = 0 588 | zero_fisher_star = np.where(fisher_star == 0) 589 | fisher_network[zero_fisher_star] = 0 590 | 591 | fisher_cost = np.add(fisher_network, fisher_star) 592 | #fisher_cost = np.multiply(fisher_network, fisher_star) 593 | weight_cost = np.square(weight_network - weight_star) 594 | cost = np.sum(np.multiply(fisher_cost, weight_cost)) 595 | total_cost += cost 596 | idx += 1 597 | 598 | print('total_cost:', total_cost) 599 | return total_cost 600 | 601 | def calculate_network_cost(self, vnn_list): 602 | total_cost = 0 603 | 604 | for i in range(len(vnn_list)): 605 | for j in range(i+1, len(vnn_list)): 606 | fisher1 = self.get_fisher_vector_page_order(vnn_list[i], 'network') 607 | weight1 = self.get_weight_vector_page_order(vnn_list[i], 'network') 608 | fisher2 = self.get_fisher_vector_page_order(vnn_list[j], 'network') 609 | weight2 = self.get_weight_vector_page_order(vnn_list[j], 'network') 610 | zero_fisher1 = np.where(fisher1 <= 0) 611 | fisher2[zero_fisher1] = 0 612 | zero_fisher2 = np.where(fisher2 <= 0) 613 | fisher1[zero_fisher2] = 0 614 | total_cost += self.matching_cost_pair(fisher1, weight1, fisher2, weight2) 615 | 616 | return total_cost 617 | 618 | def match_page_by_cost(self, vnn): 619 | print('[match_page_by_cost]') 620 | fisher_sum_vector = self.get_fisher_sum_vector() 621 | weight_vector = self.weight_page.flatten() 622 | assert len(fisher_sum_vector) == len(weight_vector) 623 | 624 | network_fisher_vector = self.get_fisher_vector_pad(vnn, 'network') 625 | network_weight_vector = self.get_weight_vector_pad(vnn, 'network') 626 | assert len(network_fisher_vector) == len(network_weight_vector) 627 | 628 | """ 629 | page_to_alloc = len(network_weight_vector)/self.weight_per_page 630 | page_list = np.arange(self.num_of_weight_page, dtype=np.int32) 631 | network_page_list = np.arange(page_to_alloc, dtype=np.int32) 632 | weight_virtualization_op = tf.load_op_library(self.weight_virtualization_op_filename) 633 | 634 | with tf.Graph().as_default() as graph: 635 | with tf.Session() as sess: 636 | page_alloc_op = weight_virtualization_op.page_alloc(fisher_sum_vector, 637 | weight_vector, page_list, network_fisher_vector, 638 | network_weight_vector, network_page_list, 639 | page_size=self.weight_per_page) 640 | page_match, cost = sess.run(page_alloc_op) 641 | 642 | print('cost:', cost) 643 | print('') 644 | weight_page_list = page_match[page_match[:,1].argsort()][:,0].astype(np.int32) 645 | if len(weight_page_list) > len(set(weight_page_list)): 646 | raise Exception('weight_page_list is not unique') 647 | assert len(weight_page_list) == page_to_alloc 648 | """ 649 | 650 | #""" 651 | weight_page_occupation = self.load_weight_page_occupation() 652 | len_list_of_occupation = np.asarray([len(page_occupation) for page_occupation in weight_page_occupation]) 653 | max_occupation = np.max(len_list_of_occupation) 654 | page_to_alloc = len(network_weight_vector)/self.weight_per_page 655 | network_page_list = np.arange(page_to_alloc, dtype=np.int32) 656 | page_match_list = [] 657 | total_cost = 0 658 | 659 | weight_virtualization_op = tf.load_op_library(self.weight_virtualization_op_filename) 660 | with tf.Graph().as_default() as graph: 661 | with tf.Session() as sess: 662 | for occupation in range(max_occupation+1): 663 | page_list = np.where(len_list_of_occupation == occupation)[0] 664 | print('occupation:', 
occupation) 665 | print('len(page_list):', len(page_list)) 666 | print('len(network_page_list):', len(network_page_list)) 667 | if len(page_list) <= 0: 668 | print('cost: 0\n') 669 | continue 670 | 671 | page_alloc_op = weight_virtualization_op.page_alloc(fisher_sum_vector, 672 | weight_vector, page_list, network_fisher_vector, 673 | network_weight_vector, network_page_list, 674 | page_size=self.weight_per_page) 675 | page_match, cost = sess.run(page_alloc_op) 676 | 677 | total_cost += cost 678 | print('cost:', cost) 679 | print('') 680 | page_match_list.append(page_match) 681 | network_page_list = list(set(network_page_list) - set(page_match[:,1])) 682 | 683 | if not network_page_list: 684 | break 685 | 686 | if network_page_list: 687 | raise Exception('network_page_list is not empty') 688 | 689 | page_match_array = np.concatenate(page_match_list) 690 | weight_page_list = page_match_array[page_match_array[:,1].argsort()][:,0].astype(np.int32) 691 | if len(weight_page_list) > len(set(weight_page_list)): 692 | raise Exception('weight_page_list is not unique') 693 | assert len(weight_page_list) == page_to_alloc 694 | 695 | #print('total_cost:', total_cost) 696 | #""" 697 | 698 | vnn.weight_page_list = weight_page_list 699 | 700 | def get_fisher_vector_pad(self, vnn, target): 701 | if target == 'network': 702 | fisher = self.load_network_fisher(vnn) 703 | elif target == 'vnn': 704 | fisher = self.load_vnn_fisher(vnn) 705 | else: 706 | raise Exception('Neither network nor vnn') 707 | 708 | fisher_vector = self.vectorize_list(fisher) 709 | pad_len = self.weight_per_page - (len(fisher_vector) % self.weight_per_page) 710 | fisher_vector_pad = np.concatenate([fisher_vector, 711 | np.zeros(pad_len, dtype=np.float32)]) 712 | assert len(fisher_vector_pad) % self.weight_per_page == 0 713 | 714 | return fisher_vector_pad 715 | 716 | def get_weight_vector_pad(self, vnn, target): 717 | if target == 'network': 718 | weight = self.load_network_weight(vnn) 719 | elif target == 'vnn': 720 | weight = self.load_vnn_weight(vnn) 721 | else: 722 | raise Exception('Neither network nor vnn') 723 | 724 | weight_vector = self.vectorize_list(weight) 725 | pad_len = self.weight_per_page - (len(weight_vector) % self.weight_per_page) 726 | weight_vector_pad = np.concatenate([weight_vector, 727 | np.zeros(pad_len, dtype=np.float32)]) 728 | assert len(weight_vector_pad) % self.weight_per_page == 0 729 | 730 | return weight_vector_pad 731 | 732 | def match_page_by_random_multi(self, vnn_list): 733 | print('[match_page_by_random_multi]') 734 | for vnn in vnn_list: 735 | network_fisher_vector = self.get_fisher_vector_pad(vnn, 'network') 736 | network_weight_vector = self.get_weight_vector_pad(vnn, 'network') 737 | assert len(network_fisher_vector) == len(network_weight_vector) 738 | page_to_alloc = len(network_weight_vector)/self.weight_per_page 739 | vnn.weight_page_list = np.random.choice(self.num_of_weight_page, page_to_alloc, replace=False) 740 | 741 | def match_page_by_cost_multi(self, vnn_list): 742 | print('[match_page_by_cost_multi]') 743 | fisher_sum_vector = self.get_fisher_sum_vector() 744 | weight_vector = self.weight_page.flatten() 745 | assert len(fisher_sum_vector) == len(weight_vector) 746 | base_page_list = np.arange(self.num_of_weight_page, dtype=np.int32) 747 | 748 | network_fisher_vector_list = [] 749 | network_weight_vector_list = [] 750 | network_page_list_list = [] 751 | page_to_alloc_list = [] 752 | 753 | for vnn in vnn_list: 754 | network_fisher_vector = self.get_fisher_vector_pad(vnn, 'network') 755 
| network_weight_vector = self.get_weight_vector_pad(vnn, 'network') 756 | assert len(network_fisher_vector) == len(network_weight_vector) 757 | network_fisher_vector_list.append(network_fisher_vector) 758 | network_weight_vector_list.append(network_weight_vector) 759 | page_to_alloc = len(network_weight_vector)/self.weight_per_page 760 | page_to_alloc_list.append(page_to_alloc) 761 | network_page_list = np.arange(page_to_alloc, dtype=np.int32) 762 | network_page_list_list.append(network_page_list) 763 | 764 | weight_virtualization_op = tf.load_op_library(self.weight_virtualization_op_filename) 765 | with tf.Graph().as_default() as graph: 766 | with tf.Session() as sess: 767 | page_alloc_multi_op = weight_virtualization_op.page_alloc_multi(fisher_sum_vector, 768 | weight_vector, base_page_list, network_fisher_vector_list, 769 | network_weight_vector_list, network_page_list_list, 770 | page_size=self.weight_per_page) 771 | page_match_list, cost = sess.run(page_alloc_multi_op) 772 | 773 | print('total cost[d]:', cost) 774 | 775 | for vnn, page_match, num_of_page in zip(vnn_list, page_match_list, page_to_alloc_list): 776 | weight_page_list = page_match.astype(np.int32) 777 | assert len(weight_page_list) == num_of_page 778 | if len(weight_page_list) > len(set(weight_page_list)): 779 | raise Exception('weight_page_list is not unique') 780 | vnn.weight_page_list = weight_page_list 781 | 782 | def match_page_by_random(self, vnn, num_of_page_to_select): 783 | print('[match_page_by_random]') 784 | weight_page_occupation = self.load_weight_page_occupation() 785 | len_list_of_occupation = np.asarray([len(page_occupation) for page_occupation in weight_page_occupation]) 786 | max_occupation = np.max(len_list_of_occupation) 787 | page_sorted_by_occupation = [] 788 | num_of_weight_page_not_allocated = num_of_page_to_select 789 | weight_page_list = [] 790 | 791 | for occupation in range(max_occupation+1): 792 | page_sorted_by_occupation.append(np.where(len_list_of_occupation == occupation)) 793 | pages = page_sorted_by_occupation[occupation][0] 794 | 795 | if num_of_weight_page_not_allocated > len(pages): 796 | weight_page_list = np.concatenate((weight_page_list, pages)) 797 | num_of_weight_page_not_allocated -= len(pages) 798 | else: 799 | if occupation != 0: 800 | np.random.shuffle(pages) # By RANDOM 801 | weight_page_list = np.concatenate((weight_page_list, 802 | pages[0:num_of_weight_page_not_allocated])) 803 | num_of_weight_page_not_allocated = 0 804 | 805 | np.random.shuffle(weight_page_list) 806 | if len(weight_page_list) > len(set(weight_page_list)): 807 | raise Exception('weight_page_list is not unique') 808 | 809 | vnn.weight_page_list = weight_page_list.astype(np.int32) 810 | #vnn.weight_page_list = np.random.choice(self.num_of_weight_page, num_of_page_to_select, replace=False) 811 | 812 | def match_weight_page(self, vnn): 813 | num_of_weight_page = vnn.num_of_weight // self.weight_per_page 814 | if vnn.num_of_weight % self.weight_per_page != 0: 815 | num_of_weight_page += 1 816 | assert num_of_weight_page <= self.num_of_weight_page,\ 817 | "%d vs. 
%d" % (num_of_weight_page, self.num_of_weight_page) 818 | 819 | if not self.vnns: 820 | vnn.weight_page_list = np.arange(num_of_weight_page, dtype=np.int32) 821 | else: 822 | time1 = time.time() 823 | #self.match_page_by_random(vnn, num_of_weight_page) 824 | self.match_page_by_cost(vnn) 825 | time2 = time.time() 826 | print('assing_page %0.3f ms' % ((time2-time1)*1000.0)) 827 | 828 | self.calculate_cost(vnn) 829 | self.update_weight_page_occupation(vnn) 830 | print("%d pages allocated for %d weights" % 831 | (len(vnn.weight_page_list), vnn.num_of_weight)) 832 | 833 | def match_weight_page_multi(self, vnn_list): 834 | for vnn in vnn_list: 835 | num_of_weight_page = vnn.num_of_weight // self.weight_per_page 836 | if vnn.num_of_weight % self.weight_per_page != 0: 837 | num_of_weight_page += 1 838 | assert num_of_weight_page <= self.num_of_weight_page,\ 839 | "%d vs. %d" % (num_of_weight_page, self.num_of_weight_page) 840 | 841 | self.match_page_by_cost_multi(vnn_list) 842 | 843 | for vnn in vnn_list: 844 | self.update_weight_page_occupation(vnn) 845 | print("%d pages allocated for %d weights" % 846 | (len(vnn.weight_page_list), vnn.num_of_weight)) 847 | 848 | new_network_cost = self.calculate_network_cost(vnn_list) 849 | print('new_network_cost:', new_network_cost) 850 | 851 | """ 852 | if os.path.exists(self.weight_page_occupation_filename): 853 | os.remove(self.weight_page_occupation_filename) 854 | self.match_page_by_random_multi(vnn_list) 855 | for vnn in vnn_list: 856 | self.update_weight_page_occupation(vnn) 857 | total_cost = self.calculate_network_cost(vnn_list) 858 | print('total cost:', total_cost) 859 | 860 | if os.path.exists(self.weight_page_occupation_filename): 861 | os.remove(self.weight_page_occupation_filename) 862 | 863 | exit(1) 864 | """ 865 | 866 | def dematch_weight_page(self, vnn): 867 | weight_page_occupation = self.load_weight_page_occupation() 868 | for page in weight_page_occupation: 869 | for occupation in page: 870 | if occupation[0] == vnn.id: 871 | page.remove(occupation) 872 | break 873 | self.save_weight_page_occupation(weight_page_occupation) 874 | 875 | def quadratic_mean(self, cost_list): 876 | if not cost_list: 877 | return 0 878 | 879 | stacked = tf.stack(cost_list) 880 | non_zero = tf.cast(tf.count_nonzero(stacked, 0), tf.float32) 881 | non_zero_pad = tf.where(tf.equal(non_zero, 0), tf.ones_like(non_zero), non_zero) 882 | return tf.reduce_sum(tf.sqrt(tf.reduce_sum(tf.square(stacked), 0) / non_zero_pad)) 883 | 884 | def arithmetic_mean(self, cost_list): 885 | if not cost_list: 886 | return 0 887 | 888 | stacked = tf.stack(cost_list) 889 | non_zero = tf.cast(tf.count_nonzero(stacked, 0), tf.float32) 890 | non_zero_pad = tf.where(tf.equal(non_zero, 0), tf.ones_like(non_zero), non_zero) 891 | return tf.reduce_sum(tf.reduce_sum(stacked, 0) / non_zero_pad) 892 | 893 | def harmonic_mean(self, cost_list): 894 | @ops.RegisterGradient("HarmonicMean") 895 | def harmonic_mean_grad(op, grad): 896 | input_list = [] 897 | for i in range(len(op.inputs)): 898 | input_list.append(op.inputs[i]) 899 | 900 | stacked = tf.stack(input_list) 901 | non_zero = tf.count_nonzero(stacked, 0) 902 | grad_list = [] 903 | 904 | for i in range(len(op.inputs)): 905 | gradient = tf.where(op.inputs[i] <= 0, tf.zeros_like(non_zero), non_zero) 906 | grad_list.append(tf.cast(gradient, tf.float32)) 907 | 908 | return grad_list 909 | 910 | if not cost_list: 911 | return 0 912 | 913 | weight_virtualization_op = tf.load_op_library(self.weight_virtualization_op_filename) 914 | return 
weight_virtualization_op.harmonic_mean(cost_list) 915 | 916 | def get_matching_loss(self, vnn, sess, lamb=10.0): 917 | print ("get_matching_loss") 918 | matching_loss = tf.constant(0.0) 919 | 920 | tensor_weights = tf.trainable_variables() 921 | tensor_weights_concat = [] 922 | for weight in tensor_weights: 923 | tensor_weights_concat.append(tf.reshape(weight, [tf.size(weight)])) 924 | 925 | new_weight_vector = tf.concat(tensor_weights_concat, 0) 926 | new_weight_vector_len = new_weight_vector.get_shape().as_list()[0] 927 | weight_page_occupation = self.load_weight_page_occupation() 928 | cost_list = [] 929 | 930 | for name_, vnn_ in self.vnns.items(): 931 | if vnn.id == vnn_.id: 932 | continue; 933 | 934 | fisher = self.load_network_fisher(vnn_) 935 | #fisher = self.load_vnn_fisher(vnn_) 936 | if fisher is None: 937 | continue 938 | fisher_vector = self.vectorize_list(fisher) 939 | 940 | #weight = self.load_network_weight(vnn_) 941 | weight = self.load_weight(vnn_) 942 | #weight_vector = self.get_weight_from_page(vnn_) 943 | weight_vector = self.vectorize_list(weight) 944 | assert len(fisher_vector) == len(weight_vector) 945 | 946 | fisher_vector_reordered = np.zeros(new_weight_vector_len, dtype=np.float32) 947 | weight_vector_reordered = np.zeros(new_weight_vector_len, dtype=np.float32) 948 | 949 | page_idx = 0 950 | for page in vnn.weight_page_list: 951 | size = self.weight_per_page 952 | page_pos = -1 953 | 954 | for occupation in weight_page_occupation[page]: 955 | if occupation[0] == vnn_.id: 956 | page_pos = occupation[1] 957 | break 958 | 959 | if page_pos == -1: 960 | page_idx += 1 961 | continue 962 | 963 | if page_idx*self.weight_per_page+size > new_weight_vector_len: 964 | size = new_weight_vector_len % self.weight_per_page 965 | if page_pos*self.weight_per_page+size > len(fisher_vector): 966 | size = len(fisher_vector) % self.weight_per_page 967 | 968 | dst_start = page_idx*self.weight_per_page 969 | dst_end = dst_start+size 970 | src_start = page_pos*self.weight_per_page 971 | src_end = src_start+size 972 | fisher_vector_reordered[dst_start:dst_end] = fisher_vector[src_start:src_end] 973 | weight_vector_reordered[dst_start:dst_end] = weight_vector[src_start:src_end] 974 | page_idx += 1 975 | 976 | weight_diff_square = tf.square(new_weight_vector - weight_vector_reordered) 977 | cost_list.append(tf.multiply(weight_diff_square, fisher_vector_reordered)) 978 | 979 | if cost_list: 980 | #matching_loss += self.quadratic_mean(cost_list) 981 | matching_loss += self.arithmetic_mean(cost_list) 982 | #matching_loss += self.harmonic_mean(cost_list) 983 | 984 | return lamb*matching_loss 985 | 986 | def vectorize_list(self, list_to_vectorize): 987 | vector_list = [] 988 | for item in list_to_vectorize: 989 | vector_list.append(item.flatten()) 990 | return np.concatenate(vector_list) 991 | 992 | def plot_vnn_fisher(self, vnn_list=None): 993 | if vnn_list is None: 994 | vnn_list = [] 995 | for name, vnn in self.vnns.items(): 996 | vnn_list.append(vnn) 997 | 998 | for vnn in vnn_list: 999 | fisher = self.get_fisher_vector_page_order(vnn, 'vnn') 1000 | plt.plot(fisher, label=vnn.name) 1001 | plt.legend(loc='upper right') 1002 | plt.show() 1003 | 1004 | def plot_vnn_weight(self, vnn_list=None): 1005 | if vnn_list is None: 1006 | vnn_list = [] 1007 | for name, vnn in self.vnns.items(): 1008 | vnn_list.append(vnn) 1009 | 1010 | for vnn in vnn_list: 1011 | weight = self.get_weight_from_page(vnn) 1012 | plt.plot(weight, label=vnn.name) 1013 | plt.legend(loc='upper right') 1014 | plt.show() 1015 | 
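# Illustration of the sharing cost used above: matching_cost_pair() and calculate_cost()
# score a page assignment with a Fisher-weighted quadratic penalty,
# cost = sum((F1 + F2) * (w1 - w2)**2), so two tasks can share a page cheaply when the
# weights they each care about (high Fisher information) already agree, while
# get_matching_loss() applies the same idea at training time, penalizing squared deviation
# from each already-registered task's weights scaled by that task's Fisher values.
# A minimal NumPy check on a hypothetical page of two weights (values are arbitrary examples):
#   f1 = np.array([1.0, 0.0], dtype=np.float32); w1 = np.array([0.5, 0.2], dtype=np.float32)
#   f2 = np.array([2.0, 1.0], dtype=np.float32); w2 = np.array([0.4, 0.2], dtype=np.float32)
#   np.sum((f1 + f2) * np.square(w1 - w2))  # ~0.03, the same value matching_cost_pair(f1, w1, f2, w2) returns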
1016 | def plot_network_fisher(self, vnn_list=None): 1017 | if vnn_list is None: 1018 | vnn_list = [] 1019 | for name, vnn in self.vnns.items(): 1020 | vnn_list.append(vnn) 1021 | 1022 | for vnn in vnn_list: 1023 | fisher = self.load_network_fisher(vnn) 1024 | plt.plot(self.vectorize_list(fisher), label=vnn.name) 1025 | plt.legend(loc='upper right') 1026 | plt.show() 1027 | 1028 | def plot_network_weight(self, vnn_list=None): 1029 | if vnn_list is None: 1030 | vnn_list = [] 1031 | for name, vnn in self.vnns.items(): 1032 | vnn_list.append(vnn) 1033 | 1034 | for vnn in vnn_list: 1035 | weight = self.load_network_weight(vnn) 1036 | plt.plot(self.vectorize_list(weight), label=vnn.name) 1037 | plt.legend(loc='upper right') 1038 | plt.show() 1039 | 1040 | def plot_network_fisher_histogram(self, vnn): 1041 | network_fisher = self.load_network_fisher(vnn) 1042 | plt.hist(network_fisher, density=False, bins=100) 1043 | plt.show() 1044 | 1045 | def plot_network_weight_histogram(self, vnn): 1046 | network_weight = self.load_network_weight(vnn) 1047 | plt.hist(network_weight, density=False, bins=100) 1048 | plt.show() 1049 | 1050 | def plot_sharing_cost_heatmap(self, vnn_list, page_size=100): 1051 | vnn1 = vnn_list[0] 1052 | print('vnn1.num_of_weight', vnn1.num_of_weight) 1053 | vnn1_num_of_page = vnn1.num_of_weight // page_size 1054 | print('vnn1_num_of_page', vnn1_num_of_page) 1055 | 1056 | vnn2 = vnn_list[1] 1057 | print('vnn2.num_of_weight', vnn2.num_of_weight) 1058 | vnn2_num_of_page = vnn2.num_of_weight // page_size 1059 | print('vnn2_num_of_page', vnn2_num_of_page) 1060 | 1061 | weight1 = self.vectorize_list(self.load_network_weight(vnn1)) 1062 | print(weight1.shape) 1063 | fisher1 = self.vectorize_list(self.load_network_fisher(vnn1)) 1064 | print(fisher1.shape) 1065 | 1066 | weight2 = self.vectorize_list(self.load_network_weight(vnn2)) 1067 | fisher2 = self.vectorize_list(self.load_network_fisher(vnn2)) 1068 | 1069 | sharing_cost_matrix = np.zeros((vnn2_num_of_page, vnn1_num_of_page)) 1070 | 1071 | for i in range(vnn1_num_of_page): 1072 | if i >= vnn2_num_of_page: 1073 | break 1074 | 1075 | print(i) 1076 | sharing_cost_vector = np.zeros((vnn2_num_of_page)) 1077 | w1 = weight1[i*page_size:(i+1)*page_size] 1078 | f1 = fisher1[i*page_size:(i+1)*page_size] 1079 | 1080 | for j in range(vnn2_num_of_page): 1081 | w2 = weight2[j*page_size:(j+1)*page_size] 1082 | f2 = fisher2[j*page_size:(j+1)*page_size] 1083 | sharing_cost = np.sum(np.multiply(np.square(w1 - w2), (f1 + f2))) 1084 | sharing_cost_matrix[j,i] = sharing_cost 1085 | 1086 | small = vnn1_num_of_page 1087 | if small > vnn2_num_of_page: 1088 | small = vnn2_num_of_page 1089 | 1090 | sharing_cost_matrix = sharing_cost_matrix[0:small, 0:small] 1091 | print(sharing_cost_matrix.shape) 1092 | 1093 | import seaborn as sns 1094 | from matplotlib.colors import LogNorm 1095 | from matplotlib.pyplot import figure 1096 | figure(num=None, figsize=(8, 4.5)) 1097 | sns.set(font_scale=1.7) 1098 | ax = sns.heatmap(sharing_cost_matrix, 1099 | norm=LogNorm(vmin=np.min(sharing_cost_matrix), vmax=np.max(sharing_cost_matrix)), 1100 | cmap='Blues', 1101 | xticklabels=100, yticklabels=100) 1102 | ax.invert_yaxis() 1103 | plt.xlabel('xlabel', fontsize=20) 1104 | plt.ylabel('ylabel', fontsize=20) 1105 | plt.savefig("sharing_score_heatmap.pdf", bbox_inches='tight') 1106 | plt.show() 1107 | 1108 | def parse_arguments(argv): 1109 | parser = argparse.ArgumentParser() 1110 | 1111 | parser.add_argument('-mode', type=str, help='mode', default='l') 1112 | # a: add a vnn from 
a network 1113 | # am: add multiple vnns from networks 1114 | # r: remove a vnn 1115 | # t: train a vnn 1116 | # e: execute inference of a vnn 1117 | # f: compute fisher information of a vnn 1118 | # c: calculate matching cost; pf/pw, pnf/pnw: plot vnn/network fisher and weights; pnfh/pnwh: histograms; heatmap: sharing-cost heatmap 1119 | 1120 | parser.add_argument('-network_path', type=str, help='network_path', default=None) 1121 | parser.add_argument('-vnn_name', type=str, help='vnn_name', default=None) 1122 | parser.add_argument('-iter', type=int, help='training iteration', default=5000) 1123 | 1124 | return parser.parse_args(argv) 1125 | 1126 | def main(args): 1127 | wv = WeightVirtualization() 1128 | 1129 | if args.mode == 'l': 1130 | print('[VNN list]') 1131 | for vnn_name in wv.vnns: 1132 | vnn = wv.vnns[vnn_name] 1133 | print('Name: %s, id: %d, path: %s, num_of_weight: %d' 1134 | % (vnn.name, vnn.id, vnn.network_path, vnn.num_of_weight)) 1135 | 1136 | weight_page_occupation = wv.load_weight_page_occupation() 1137 | print('\n[Weight page] total %d weight pages' % len(weight_page_occupation)) 1138 | print(weight_page_occupation) 1139 | 1140 | elif args.mode == 'a': 1141 | if args.network_path is None: 1142 | print('no network_path') 1143 | return 1144 | 1145 | assert os.path.exists(args.network_path) 1146 | wv.add_vnn(args.network_path) 1147 | 1148 | elif args.mode == 'am': 1149 | if args.network_path is None: 1150 | print('no network_path') 1151 | return 1152 | 1153 | network_path_list = args.network_path.split(',') 1154 | for network_path in network_path_list: 1155 | assert os.path.exists(network_path) 1156 | wv.add_multi_vnns(network_path_list) 1157 | 1158 | elif args.mode == 'r': 1159 | if args.vnn_name is None: 1160 | print('no vnn name') 1161 | return 1162 | 1163 | vnn = wv.vnns[args.vnn_name] 1164 | wv.remove_vnn(vnn) 1165 | 1166 | elif args.mode == 't': 1167 | if args.vnn_name is None: 1168 | print('no vnn name') 1169 | return 1170 | 1171 | vnn = wv.vnns[args.vnn_name] 1172 | wv.train_vnn(vnn, args.iter) 1173 | wv.save_weight(vnn) 1174 | 1175 | elif args.mode == 'e': 1176 | if args.vnn_name is None: 1177 | print('no vnn name') 1178 | return 1179 | 1180 | vnn = wv.vnns[args.vnn_name] 1181 | pintle = wv.import_pintle(vnn) 1182 | raw_input_variables = pintle.pintle.v_test_input_variables() 1183 | input_variables = [ raw_input_variables[0][0] ] 1184 | for i in range(1, len(raw_input_variables)): 1185 | input_variables.append(raw_input_variables[i]) 1186 | ground_truth = raw_input_variables[0][1] 1187 | 1188 | result, accuracy = wv.execute_vnn(vnn, input_variables, ground_truth) 1189 | if accuracy: 1190 | with open(args.vnn_name + '.accuracy', "a") as file: 1191 | file.write(str(accuracy) + '\n') 1192 | 1193 | elif args.mode == 'f': 1194 | if args.vnn_name is None: 1195 | print('no vnn name') 1196 | return 1197 | 1198 | vnn = wv.vnns[args.vnn_name] 1199 | fisher_information = wv.compute_fisher(vnn, 'vnn') 1200 | wv.save_vnn_fisher(vnn, fisher_information) 1201 | 1202 | elif args.mode == 'c': 1203 | if args.vnn_name is None: 1204 | print('no vnn_name') 1205 | return 1206 | 1207 | vnn_name_list = args.vnn_name.split(',') 1208 | vnn_list = [] 1209 | 1210 | for name in vnn_name_list: 1211 | vnn = wv.vnns[name] 1212 | assert os.path.exists(vnn.network_path) 1213 | vnn_list.append(vnn) 1214 | 1215 | total_cost = wv.calculate_network_cost(vnn_list) 1216 | print('total cost:', total_cost) 1217 | 1218 | elif args.mode == 'pf': 1219 | vnn_list = [] 1220 | if args.vnn_name is None: 1221 | for name, vnn in wv.vnns.items(): 1222 | vnn_list.append(vnn) 1223 | else: 1224 | vnn_name_list = 
args.vnn_name.split(',') 1225 | for name in vnn_name_list: 1226 | vnn = wv.vnns[name] 1227 | assert os.path.exists(vnn.network_path) 1228 | vnn_list.append(vnn) 1229 | wv.plot_vnn_fisher(vnn_list) 1230 | 1231 | elif args.mode == 'pw': 1232 | vnn_list = [] 1233 | if args.vnn_name is None: 1234 | for name, vnn in wv.vnns.items(): 1235 | vnn_list.append(vnn) 1236 | else: 1237 | vnn_name_list = args.vnn_name.split(',') 1238 | for name in vnn_name_list: 1239 | vnn = wv.vnns[name] 1240 | assert os.path.exists(vnn.network_path) 1241 | vnn_list.append(vnn) 1242 | wv.plot_vnn_weight(vnn_list) 1243 | 1244 | elif args.mode == 'pnf': 1245 | vnn_list = [] 1246 | if args.vnn_name is None: 1247 | for name, vnn in wv.vnns.items(): 1248 | vnn_list.append(vnn) 1249 | else: 1250 | vnn_name_list = args.vnn_name.split(',') 1251 | for name in vnn_name_list: 1252 | vnn = wv.vnns[name] 1253 | assert os.path.exists(vnn.network_path) 1254 | vnn_list.append(vnn) 1255 | wv.plot_network_fisher(vnn_list) 1256 | 1257 | elif args.mode == 'pnw': 1258 | vnn_list = [] 1259 | if args.vnn_name is None: 1260 | for name, vnn in wv.vnns.items(): 1261 | vnn_list.append(vnn) 1262 | else: 1263 | vnn_name_list = args.vnn_name.split(',') 1264 | for name in vnn_name_list: 1265 | vnn = wv.vnns[name] 1266 | assert os.path.exists(vnn.network_path) 1267 | vnn_list.append(vnn) 1268 | wv.plot_network_weight(vnn_list) 1269 | 1270 | elif args.mode == 'pnfh': 1271 | if args.vnn_name is None: 1272 | print('no vnn name') 1273 | return 1274 | 1275 | vnn = wv.vnns[args.vnn_name] 1276 | wv.plot_network_fisher_histogram(vnn) 1277 | 1278 | elif args.mode == 'pnwh': 1279 | if args.vnn_name is None: 1280 | print('no vnn name') 1281 | return 1282 | 1283 | vnn = wv.vnns[args.vnn_name] 1284 | wv.plot_network_weight_histogram(vnn) 1285 | 1286 | elif args.mode == 'heatmap': 1287 | vnn_list = [] 1288 | if args.vnn_name is None: 1289 | print('no vnn name') 1290 | return 1291 | else: 1292 | vnn_name_list = args.vnn_name.split(',') 1293 | assert len(vnn_name_list) == 2 1294 | for name in vnn_name_list: 1295 | vnn = wv.vnns[name] 1296 | assert os.path.exists(vnn.network_path) 1297 | vnn_list.append(vnn) 1298 | wv.plot_sharing_cost_heatmap(vnn_list) 1299 | 1300 | if __name__ == '__main__': 1301 | main(parse_arguments(sys.argv[1:])) 1302 | --------------------------------------------------------------------------------
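A brief usage sketch for the command-line interface defined above, assuming the script is run from the repository root; the 'mnist' network path, VNN name, and iteration count are placeholders for whatever network directory and VNN name are actually registered. The parse_arguments()/main() entry points can be driven from the shell (python weight_virtualization.py -mode ...) or programmatically:

    # Hypothetical driver; 'mnist' stands in for the registered network directory / VNN name.
    from weight_virtualization import main, parse_arguments

    main(parse_arguments(['-mode', 'l']))                                         # list VNNs and weight-page occupation
    main(parse_arguments(['-mode', 'a', '-network_path', 'mnist']))               # add a VNN from a trained network
    main(parse_arguments(['-mode', 't', '-vnn_name', 'mnist', '-iter', '2000']))  # retrain it against the shared weight pages
    main(parse_arguments(['-mode', 'e', '-vnn_name', 'mnist']))                   # inference; accuracy is appended to mnist.accuracy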