├── .gitignore
├── DREAMER
│   ├── DREAMER_Arousal_FastGRNN_48_16.ipynb
│   ├── DREAMER_Arousal_GRU_48_16.ipynb
│   ├── DREAMER_Arousal_LSTM_48_16.ipynb
│   ├── DREAMER_Arousal_LSTM_64_16.ipynb
│   ├── DREAMER_Data_Preprocessing.ipynb
│   ├── DREAMER_Dominance_FastGRNN_48_16.ipynb
│   ├── DREAMER_Dominance_GRU_48_16.ipynb
│   ├── DREAMER_Dominance_LSTM_48_16.ipynb
│   ├── DREAMER_SVM.ipynb
│   ├── DREAMER_Valence_FastGRNN_48_16.ipynb
│   ├── DREAMER_Valence_GRU_48_16.ipynb
│   └── DREAMER_Valence_LSTM_48_16.ipynb
├── ICMLA_Presentation.pdf
├── LICENSE
├── NOTICE
├── README.md
├── SWELL-KW
│   ├── SWELL-KW_Analysis.ipynb
│   ├── SWELL-KW_FastGRNN.ipynb
│   ├── SWELL-KW_GRU.ipynb
│   ├── SWELL-KW_LSTM.ipynb
│   └── SWELL-KW_Scores.ipynb
└── WESAD
    ├── WESAD_Analysis.ipynb
    ├── WESAD_Data.ipynb
    ├── WESAD_Data_Numpy.py
    ├── WESAD_Extract.ipynb
    ├── WESAD_FastGRNN.ipynb
    ├── WESAD_GRU.ipynb
    ├── WESAD_Get_Single.py
    ├── WESAD_Inference.ipynb
    └── WESAD_LSTM.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data
2 | Data/
3 |
4 | # Paper
5 | SAD
6 |
7 | # Other
8 | .DS_Store
9 | .ipynb_checkpoints
10 | __pycache__
11 |
--------------------------------------------------------------------------------
/DREAMER/DREAMER_Arousal_GRU_48_16.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# DREAMER Arousal EMI-GRU 48_16"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "import numpy as np\n",
25 | "from tabulate import tabulate\n",
26 | "import os\n",
27 | "import datetime as datetime\n",
28 | "import pickle as pkl\n",
29 | "import pathlib"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {
36 | "ExecuteTime": {
37 | "end_time": "2018-12-14T14:17:51.796585Z",
38 | "start_time": "2018-12-14T14:17:49.648375Z"
39 | }
40 | },
41 | "outputs": [
42 | {
43 | "name": "stderr",
44 | "output_type": "stream",
45 | "text": [
46 | "Using TensorFlow backend.\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "from __future__ import print_function\n",
52 | "import os\n",
53 | "import sys\n",
54 | "import tensorflow as tf\n",
55 | "import numpy as np\n",
56 | "# Making sure edgeml is part of python path\n",
57 | "sys.path.insert(0, '../../')\n",
58 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
59 | "\n",
60 | "np.random.seed(42)\n",
61 | "tf.set_random_seed(42)\n",
62 | "\n",
63 | "# MI-RNN and EMI-RNN imports\n",
64 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
65 | "from edgeml.graph.rnn import EMI_GRU\n",
66 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
67 | "import edgeml.utils\n",
68 | "\n",
69 | "import keras.backend as K\n",
70 | "cfg = K.tf.ConfigProto()\n",
71 | "cfg.gpu_options.allow_growth = True\n",
72 | "K.set_session(K.tf.Session(config=cfg))"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 3,
78 | "metadata": {
79 | "ExecuteTime": {
80 | "end_time": "2018-12-14T14:17:51.803381Z",
81 | "start_time": "2018-12-14T14:17:51.798799Z"
82 | }
83 | },
84 | "outputs": [],
85 | "source": [
86 | "# Network parameters for our LSTM + FC Layer\n",
87 | "NUM_HIDDEN = 128\n",
88 | "NUM_TIMESTEPS = 48\n",
89 | "ORIGINAL_NUM_TIMESTEPS = 128\n",
90 | "NUM_FEATS = 16\n",
91 | "FORGET_BIAS = 1.0\n",
92 | "NUM_OUTPUT = 5\n",
93 | "USE_DROPOUT = True\n",
94 | "KEEP_PROB = 0.75\n",
95 | "\n",
96 | "# For dataset API\n",
97 | "PREFETCH_NUM = 5\n",
98 | "BATCH_SIZE = 32\n",
99 | "\n",
100 | "# Number of epochs in *one iteration*\n",
101 | "NUM_EPOCHS = 2\n",
102 | "\n",
103 | "# Number of iterations in *one round*. After each iteration,\n",
104 | "# the model is dumped to disk. At the end of the current\n",
105 | "# round, the best model among all the dumped models in the\n",
106 | "# current round is picked up..\n",
107 | "NUM_ITER = 4\n",
108 | "\n",
109 | "# A round consists of multiple training iterations and a belief\n",
110 | "# update step using the best model from all of these iterations\n",
111 | "NUM_ROUNDS = 10\n",
112 | "LEARNING_RATE=0.001\n",
113 | "\n",
114 | "# A staging direcory to store models\n",
115 | "MODEL_PREFIX = '/home/sf/data/DREAMER/Arousal/models/model-gru'"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "heading_collapsed": true
122 | },
123 | "source": [
124 | "# Loading Data"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 4,
130 | "metadata": {
131 | "ExecuteTime": {
132 | "end_time": "2018-12-14T14:17:52.040352Z",
133 | "start_time": "2018-12-14T14:17:51.805319Z"
134 | },
135 | "hidden": true
136 | },
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "x_train shape is: (61735, 6, 48, 16)\n",
143 | "y_train shape is: (61735, 6, 5)\n",
144 | "x_test shape is: (6860, 6, 48, 16)\n",
145 | "y_test shape is: (6860, 6, 5)\n"
146 | ]
147 | }
148 | ],
149 | "source": [
150 | "# Loading the data\n",
151 | "x_train, y_train = np.load('/home/sf/data/DREAMER/Arousal/48_16/x_train.npy'), np.load('/home/sf/data/DREAMER/Arousal/48_16/y_train.npy')\n",
152 | "x_test, y_test = np.load('/home/sf/data/DREAMER/Arousal/48_16/x_test.npy'), np.load('/home/sf/data/DREAMER/Arousal/48_16/y_test.npy')\n",
153 | "x_val, y_val = np.load('/home/sf/data/DREAMER/Arousal/48_16/x_val.npy'), np.load('/home/sf/data/DREAMER/Arousal/48_16/y_val.npy')\n",
154 | "\n",
155 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
156 | "# step of EMI/MI RNN\n",
157 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
158 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
159 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
160 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
161 | "print(\"x_train shape is:\", x_train.shape)\n",
162 | "print(\"y_train shape is:\", y_train.shape)\n",
163 | "print(\"x_test shape is:\", x_val.shape)\n",
164 | "print(\"y_test shape is:\", y_val.shape)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "# Computation Graph"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 5,
177 | "metadata": {
178 | "ExecuteTime": {
179 | "end_time": "2018-12-14T14:17:52.053161Z",
180 | "start_time": "2018-12-14T14:17:52.042928Z"
181 | }
182 | },
183 | "outputs": [],
184 | "source": [
185 | "# Define the linear secondary classifier\n",
186 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
187 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
188 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
189 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
190 | " self.output = y_cap\n",
191 | " self.graphCreated = True\n",
192 | "\n",
193 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
194 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
195 | " self.output = y_cap\n",
196 | " self.graphCreated = True\n",
197 | " \n",
198 | "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
199 | " if inference is False:\n",
200 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
201 | " else:\n",
202 | " feedDict = {self._emiGraph.keep_prob: 1.0}\n",
203 | " return feedDict\n",
204 | " \n",
205 | "EMI_GRU._createExtendedGraph = createExtendedGraph\n",
206 | "EMI_GRU._restoreExtendedGraph = restoreExtendedGraph\n",
207 | "\n",
208 | "if USE_DROPOUT is True:\n",
209 | " EMI_Driver.feedDictFunc = feedDictFunc"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 6,
215 | "metadata": {
216 | "ExecuteTime": {
217 | "end_time": "2018-12-14T14:17:52.335299Z",
218 | "start_time": "2018-12-14T14:17:52.055483Z"
219 | }
220 | },
221 | "outputs": [],
222 | "source": [
223 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
224 | "emiGRU = EMI_GRU(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
225 | " useDropout=USE_DROPOUT)\n",
226 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
227 | " stepSize=LEARNING_RATE)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 7,
233 | "metadata": {
234 | "ExecuteTime": {
235 | "end_time": "2018-12-14T14:18:05.031382Z",
236 | "start_time": "2018-12-14T14:17:52.338750Z"
237 | }
238 | },
239 | "outputs": [],
240 | "source": [
241 | "tf.reset_default_graph()\n",
242 | "g1 = tf.Graph() \n",
243 | "with g1.as_default():\n",
244 | " # Obtain the iterators to each batch of the data\n",
245 | " x_batch, y_batch = inputPipeline()\n",
246 | " # Create the forward computation graph based on the iterators\n",
247 | " y_cap = emiGRU(x_batch)\n",
248 | " # Create loss graphs and training routines\n",
249 | " emiTrainer(y_cap, y_batch)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "# EMI Driver"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 8,
262 | "metadata": {
263 | "ExecuteTime": {
264 | "end_time": "2018-12-14T14:35:15.209910Z",
265 | "start_time": "2018-12-14T14:18:05.034359Z"
266 | }
267 | },
268 | "outputs": [
269 | {
270 | "name": "stdout",
271 | "output_type": "stream",
272 | "text": [
273 | "Update policy: top-k\n",
274 | "Training with MI-RNN loss for 5 rounds\n",
275 | "Round: 0\n",
276 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02828 Acc 0.39062 | Val acc 0.36778 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1000\n",
277 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02747 Acc 0.44271 | Val acc 0.40481 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1001\n",
278 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02699 Acc 0.46875 | Val acc 0.44227 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1002\n",
279 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02365 Acc 0.59375 | Val acc 0.47216 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1003\n",
280 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1003\n",
281 | "Round: 1\n",
282 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02228 Acc 0.60938 | Val acc 0.51793 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1004\n",
283 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02062 Acc 0.63021 | Val acc 0.53601 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1005\n",
284 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02085 Acc 0.61458 | Val acc 0.55641 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1006\n",
285 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01955 Acc 0.62500 | Val acc 0.58499 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1007\n",
286 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1007\n",
287 | "Round: 2\n",
288 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01866 Acc 0.67188 | Val acc 0.59650 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1008\n",
289 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01890 Acc 0.69792 | Val acc 0.61589 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1009\n",
290 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01878 Acc 0.69271 | Val acc 0.62085 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1010\n",
291 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01786 Acc 0.68229 | Val acc 0.62857 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1011\n",
292 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1011\n",
293 | "Round: 3\n",
294 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01762 Acc 0.75521 | Val acc 0.63980 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1012\n",
295 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01638 Acc 0.76562 | Val acc 0.64490 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1013\n",
296 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01718 Acc 0.71875 | Val acc 0.63936 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1014\n",
297 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01692 Acc 0.69792 | Val acc 0.65000 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1015\n",
298 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1015\n",
299 | "Round: 4\n",
300 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01506 Acc 0.77083 | Val acc 0.65729 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1016\n",
301 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01559 Acc 0.75000 | Val acc 0.66166 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1017\n",
302 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01509 Acc 0.75521 | Val acc 0.66166 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1018\n",
303 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01540 Acc 0.74479 | Val acc 0.66283 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1019\n",
304 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1019\n",
305 | "Round: 5\n",
306 | "Switching to EMI-Loss function\n",
307 | "Epoch 1 Batch 1925 ( 3855) Loss 0.92278 Acc 0.77083 | Val acc 0.64446 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1020\n",
308 | "Epoch 1 Batch 1925 ( 3855) Loss 0.93813 Acc 0.73438 | Val acc 0.65058 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1021\n",
309 | "Epoch 1 Batch 1925 ( 3855) Loss 0.94226 Acc 0.70833 | Val acc 0.64781 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1022\n",
310 | "Epoch 1 Batch 1925 ( 3855) Loss 0.90150 Acc 0.75000 | Val acc 0.65437 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1023\n",
311 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1023\n",
312 | "Round: 6\n",
313 | "Epoch 1 Batch 1925 ( 3855) Loss 0.88298 Acc 0.77083 | Val acc 0.65277 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1024\n",
314 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87756 Acc 0.72917 | Val acc 0.65306 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1025\n",
315 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87819 Acc 0.71875 | Val acc 0.64927 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1026\n",
316 | "Epoch 1 Batch 1925 ( 3855) Loss 0.82530 Acc 0.71354 | Val acc 0.65350 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1027\n",
317 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1027\n",
318 | "Round: 7\n",
319 | "Epoch 1 Batch 1925 ( 3855) Loss 0.84559 Acc 0.76562 | Val acc 0.65306 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1028\n",
320 | "Epoch 1 Batch 1925 ( 3855) Loss 0.81977 Acc 0.79167 | Val acc 0.64636 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1029\n",
321 | "Epoch 1 Batch 1925 ( 3855) Loss 0.83083 Acc 0.76042 | Val acc 0.64665 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1030\n",
322 | "Epoch 1 Batch 1925 ( 3855) Loss 0.80918 Acc 0.75000 | Val acc 0.64796 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1031\n",
323 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1028\n",
324 | "Round: 8\n",
325 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87074 Acc 0.71875 | Val acc 0.65015 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1032\n",
326 | "Epoch 1 Batch 1925 ( 3855) Loss 0.83685 Acc 0.76562 | Val acc 0.65131 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1033\n",
327 | "Epoch 1 Batch 1925 ( 3855) Loss 0.81599 Acc 0.75521 | Val acc 0.65117 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1034\n",
328 | "Epoch 1 Batch 1925 ( 3855) Loss 0.83342 Acc 0.72917 | Val acc 0.64869 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1035\n",
329 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1033\n",
330 | "Round: 9\n",
331 | "Epoch 1 Batch 1925 ( 3855) Loss 0.82510 Acc 0.76042 | Val acc 0.64606 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1036\n",
332 | "Epoch 1 Batch 1925 ( 3855) Loss 0.82719 Acc 0.76562 | Val acc 0.65335 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1037\n",
333 | "Epoch 1 Batch 1925 ( 3855) Loss 0.81562 Acc 0.76042 | Val acc 0.64592 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1038\n",
334 | "Epoch 1 Batch 1925 ( 3855) Loss 0.82774 Acc 0.79688 | Val acc 0.65087 | Model saved to /home/sf/data/DREAMER/Arousal/models/model-gru, global_step 1039\n",
335 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1037\n"
336 | ]
337 | }
338 | ],
339 | "source": [
340 | "with g1.as_default():\n",
341 | " emiDriver = EMI_Driver(inputPipeline, emiGRU, emiTrainer)\n",
342 | "\n",
343 | "emiDriver.initializeSession(g1)\n",
344 | "y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
345 | " y_train=y_train, bag_train=BAG_TRAIN,\n",
346 | " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
347 | " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
348 | " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
349 | " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
350 | " fracEMI=0.5, updatePolicy='top-k', k=1)"
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {},
356 | "source": [
357 | "# Evaluating the trained model"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 9,
363 | "metadata": {
364 | "ExecuteTime": {
365 | "end_time": "2018-12-14T14:35:15.218040Z",
366 | "start_time": "2018-12-14T14:35:15.211771Z"
367 | }
368 | },
369 | "outputs": [],
370 | "source": [
371 | "# Early Prediction Policy: We make an early prediction based on the predicted classes\n",
372 | "# probability. If the predicted class probability > minProb at some step, we make\n",
373 | "# a prediction at that step.\n",
374 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
375 | " assert instanceOut.ndim == 2\n",
376 | " classes = np.argmax(instanceOut, axis=1)\n",
377 | " prob = np.max(instanceOut, axis=1)\n",
378 | " index = np.where(prob >= minProb)[0]\n",
379 | " if len(index) == 0:\n",
380 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
381 | " return classes[-1], len(instanceOut) - 1\n",
382 | " index = index[0]\n",
383 | " return classes[index], index\n",
384 | "\n",
385 | "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
386 | " predictionStep = predictionStep + 1\n",
387 | " predictionStep = np.reshape(predictionStep, -1)\n",
388 | " totalSteps = np.sum(predictionStep)\n",
389 | " maxSteps = len(predictionStep) * numTimeSteps\n",
390 | " savings = 1.0 - (totalSteps / maxSteps)\n",
391 | " if returnTotal:\n",
392 | " return savings, totalSteps\n",
393 | " return savings"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 10,
399 | "metadata": {
400 | "ExecuteTime": {
401 | "end_time": "2018-12-14T14:35:16.257489Z",
402 | "start_time": "2018-12-14T14:35:15.221029Z"
403 | },
404 | "scrolled": true
405 | },
406 | "outputs": [
407 | {
408 | "name": "stdout",
409 | "output_type": "stream",
410 | "text": [
411 | "Accuracy at k = 2: 0.646918\n",
412 | "Savings due to MI-RNN : 0.625000\n",
413 | "Savings due to Early prediction: 0.133430\n",
414 | "Total Savings: 0.675036\n"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "k = 2\n",
420 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
421 | " minProb=0.99, keep_prob=1.0)\n",
422 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
423 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
424 | "mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
425 | "emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
426 | "total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
427 | "print('Savings due to MI-RNN : %f' % mi_savings)\n",
428 | "print('Savings due to Early prediction: %f' % emi_savings)\n",
429 | "print('Total Savings: %f' % (total_savings))"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 11,
435 | "metadata": {
436 | "ExecuteTime": {
437 | "end_time": "2018-12-14T14:35:17.044115Z",
438 | "start_time": "2018-12-14T14:35:16.259280Z"
439 | },
440 | "scrolled": false
441 | },
442 | "outputs": [
443 | {
444 | "name": "stdout",
445 | "output_type": "stream",
446 | "text": [
447 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
448 | "0 1 0.646452 0.608454 0.689071 0.581906 0.646452 0.646452 \n",
449 | "1 2 0.646918 0.611034 0.620723 0.606872 0.646918 0.646918 \n",
450 | "2 3 0.607674 0.556231 0.581544 0.603607 0.607674 0.607674 \n",
451 | "3 4 0.528252 0.509008 0.619781 0.556418 0.528252 0.528252 \n",
452 | "4 5 0.466266 0.472236 0.663204 0.517449 0.466266 0.466266 \n",
453 | "5 6 0.422182 0.442379 0.697692 0.487339 0.422182 0.422182 \n",
454 | "\n",
455 | " micro-rec \n",
456 | "0 0.646452 \n",
457 | "1 0.646918 \n",
458 | "2 0.607674 \n",
459 | "3 0.528252 \n",
460 | "4 0.466266 \n",
461 | "5 0.422182 \n",
462 | "Max accuracy 0.646918 at subsequencelength 2\n",
463 | "Max micro-f 0.646918 at subsequencelength 2\n",
464 | "Micro-precision 0.646918 at subsequencelength 2\n",
465 | "Micro-recall 0.646918 at subsequencelength 2\n",
466 | "Max macro-f 0.611034 at subsequencelength 2\n",
467 | "macro-precision 0.620723 at subsequencelength 2\n",
468 | "macro-recall 0.606872 at subsequencelength 2\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "# A slightly more detailed analysis method is provided. \n",
474 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {},
480 | "source": [
481 | "## Picking the best model"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": 12,
487 | "metadata": {
488 | "ExecuteTime": {
489 | "end_time": "2018-12-14T14:35:54.899340Z",
490 | "start_time": "2018-12-14T14:35:17.047464Z"
491 | }
492 | },
493 | "outputs": [
494 | {
495 | "name": "stdout",
496 | "output_type": "stream",
497 | "text": [
498 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1003\n",
499 | "Round: 0, Validation accuracy: 0.4722, Test Accuracy (k = 2): 0.474197, Total Savings: 0.629132\n",
500 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1007\n",
501 | "Round: 1, Validation accuracy: 0.5850, Test Accuracy (k = 2): 0.583999, Total Savings: 0.633690\n",
502 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1011\n",
503 | "Round: 2, Validation accuracy: 0.6286, Test Accuracy (k = 2): 0.632340, Total Savings: 0.637225\n",
504 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1015\n",
505 | "Round: 3, Validation accuracy: 0.6500, Test Accuracy (k = 2): 0.652983, Total Savings: 0.640312\n",
506 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1019\n",
507 | "Round: 4, Validation accuracy: 0.6628, Test Accuracy (k = 2): 0.666919, Total Savings: 0.643258\n",
508 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1023\n",
509 | "Round: 5, Validation accuracy: 0.6544, Test Accuracy (k = 2): 0.654440, Total Savings: 0.661873\n",
510 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1027\n",
511 | "Round: 6, Validation accuracy: 0.6535, Test Accuracy (k = 2): 0.654440, Total Savings: 0.668051\n",
512 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1028\n",
513 | "Round: 7, Validation accuracy: 0.6531, Test Accuracy (k = 2): 0.652983, Total Savings: 0.670961\n",
514 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1033\n",
515 | "Round: 8, Validation accuracy: 0.6513, Test Accuracy (k = 2): 0.653741, Total Savings: 0.673200\n",
516 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Arousal/models/model-gru-1037\n",
517 | "Round: 9, Validation accuracy: 0.6534, Test Accuracy (k = 2): 0.646918, Total Savings: 0.675036\n"
518 | ]
519 | }
520 | ],
521 | "source": [
522 | "devnull = open(os.devnull, 'r')\n",
523 | "for val in modelStats:\n",
524 | " round_, acc, modelPrefix, globalStep = val\n",
525 | " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
526 | " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
527 | " minProb=0.99, keep_prob=1.0)\n",
528 | "\n",
529 | " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
530 | " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
531 | " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
532 | " mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
533 | " emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
534 | " total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
535 | " print(\"Total Savings: %f\" % total_savings)"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 13,
541 | "metadata": {},
542 | "outputs": [],
543 | "source": [
544 | "params = {\n",
545 | " \"NUM_HIDDEN\" : 128,\n",
546 | " \"NUM_TIMESTEPS\" : 48, #subinstance length.\n",
547 | " \"ORIGINAL_NUM_TIMESTEPS\" : 128,\n",
548 | " \"NUM_FEATS\" : 16,\n",
549 | " \"FORGET_BIAS\" : 1.0,\n",
550 | " \"NUM_OUTPUT\" : 5,\n",
551 | " \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
552 | " \"KEEP_PROB\" : 0.75,\n",
553 | " \"PREFETCH_NUM\" : 5,\n",
554 | " \"BATCH_SIZE\" : 32,\n",
555 | " \"NUM_EPOCHS\" : 2,\n",
556 | " \"NUM_ITER\" : 4,\n",
557 | " \"NUM_ROUNDS\" : 10,\n",
558 | " \"LEARNING_RATE\" : 0.001,\n",
559 | " \"MODEL_PREFIX\" : '/home/sf/data/DREAMER/Arousal/model-gru'\n",
560 | "}"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 14,
566 | "metadata": {},
567 | "outputs": [
568 | {
569 | "name": "stdout",
570 | "output_type": "stream",
571 | "text": [
572 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
573 | "0 1 0.646452 0.608454 0.689071 0.581906 0.646452 0.646452 \n",
574 | "1 2 0.646918 0.611034 0.620723 0.606872 0.646918 0.646918 \n",
575 | "2 3 0.607674 0.556231 0.581544 0.603607 0.607674 0.607674 \n",
576 | "3 4 0.528252 0.509008 0.619781 0.556418 0.528252 0.528252 \n",
577 | "4 5 0.466266 0.472236 0.663204 0.517449 0.466266 0.466266 \n",
578 | "5 6 0.422182 0.442379 0.697692 0.487339 0.422182 0.422182 \n",
579 | "\n",
580 | " micro-rec \n",
581 | "0 0.646452 \n",
582 | "1 0.646918 \n",
583 | "2 0.607674 \n",
584 | "3 0.528252 \n",
585 | "4 0.466266 \n",
586 | "5 0.422182 \n",
587 | "Max accuracy 0.646918 at subsequencelength 2\n",
588 | "Max micro-f 0.646918 at subsequencelength 2\n",
589 | "Micro-precision 0.646918 at subsequencelength 2\n",
590 | "Micro-recall 0.646918 at subsequencelength 2\n",
591 | "Max macro-f 0.611034 at subsequencelength 2\n",
592 | "macro-precision 0.620723 at subsequencelength 2\n",
593 | "macro-recall 0.606872 at subsequencelength 2\n",
594 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
595 | "| | len | acc | macro-fsc | macro-pre | macro-rec | micro-fsc | micro-pre | micro-rec |\n",
596 | "+====+=======+==========+=============+=============+=============+=============+=============+=============+\n",
597 | "| 0 | 1 | 0.646452 | 0.608454 | 0.689071 | 0.581906 | 0.646452 | 0.646452 | 0.646452 |\n",
598 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
599 | "| 1 | 2 | 0.646918 | 0.611034 | 0.620723 | 0.606872 | 0.646918 | 0.646918 | 0.646918 |\n",
600 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
601 | "| 2 | 3 | 0.607674 | 0.556231 | 0.581544 | 0.603607 | 0.607674 | 0.607674 | 0.607674 |\n",
602 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
603 | "| 3 | 4 | 0.528252 | 0.509008 | 0.619781 | 0.556418 | 0.528252 | 0.528252 | 0.528252 |\n",
604 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
605 | "| 4 | 5 | 0.466266 | 0.472236 | 0.663204 | 0.517449 | 0.466266 | 0.466266 | 0.466266 |\n",
606 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
607 | "| 5 | 6 | 0.422182 | 0.442379 | 0.697692 | 0.487339 | 0.422182 | 0.422182 | 0.422182 |\n",
608 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n"
609 | ]
610 | }
611 | ],
612 | "source": [
613 | "gru_dict = {**params}\n",
614 | "gru_dict[\"k\"] = k\n",
615 | "gru_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
616 | "gru_dict[\"total_savings\"] = total_savings\n",
617 | "gru_dict[\"y_test\"] = BAG_TEST\n",
618 | "gru_dict[\"y_pred\"] = bagPredictions\n",
619 | "\n",
620 | "# A slightly more detailed analysis method is provided. \n",
621 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
622 | "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))"
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 15,
628 | "metadata": {},
629 | "outputs": [
630 | {
631 | "name": "stdout",
632 | "output_type": "stream",
633 | "text": [
634 | "Results for this run have been saved at /home/sf/data/DREAMER/Arousal/GRU/ .\n"
635 | ]
636 | }
637 | ],
638 | "source": [
639 | "dirname = \"/home/sf/data/DREAMER/Arousal/GRU/\"\n",
640 | "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
641 | "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
642 | "\n",
643 | "now = datetime.datetime.now()\n",
644 | "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
645 | "filename = ''.join(filename)\n",
646 | "\n",
647 | "# Save the dictionary containing the params and the results.\n",
648 | "pkl.dump(gru_dict,open(dirname + filename + \".pkl\",mode='wb'))"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 16,
654 | "metadata": {},
655 | "outputs": [
656 | {
657 | "data": {
658 | "text/plain": [
659 | "'/home/sf/data/DREAMER/Arousal/GRU/2019-8-14|1-54.pkl'"
660 | ]
661 | },
662 | "execution_count": 16,
663 | "metadata": {},
664 | "output_type": "execute_result"
665 | }
666 | ],
667 | "source": [
668 | "dirname+filename+'.pkl'"
669 | ]
670 | }
671 | ],
672 | "metadata": {
673 | "kernelspec": {
674 | "display_name": "Python 3",
675 | "language": "python",
676 | "name": "python3"
677 | },
678 | "language_info": {
679 | "codemirror_mode": {
680 | "name": "ipython",
681 | "version": 3
682 | },
683 | "file_extension": ".py",
684 | "mimetype": "text/x-python",
685 | "name": "python",
686 | "nbconvert_exporter": "python",
687 | "pygments_lexer": "ipython3",
688 | "version": "3.7.3"
689 | },
690 | "latex_envs": {
691 | "LaTeX_envs_menu_present": true,
692 | "autoclose": false,
693 | "autocomplete": true,
694 | "bibliofile": "biblio.bib",
695 | "cite_by": "apalike",
696 | "current_citInitial": 1,
697 | "eqLabelWithNumbers": true,
698 | "eqNumInitial": 1,
699 | "hotkeys": {
700 | "equation": "Ctrl-E",
701 | "itemize": "Ctrl-I"
702 | },
703 | "labels_anchors": false,
704 | "latex_user_defs": false,
705 | "report_style_numbering": false,
706 | "user_envs_cfg": false
707 | }
708 | },
709 | "nbformat": 4,
710 | "nbformat_minor": 2
711 | }
712 |
--------------------------------------------------------------------------------
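
The "48_16" bagging convention used by the notebook above is not spelled out in the file itself. The sketch below is an inference from the notebook names and the printed shapes (x_train is (61735, 6, 48, 16) while ORIGINAL_NUM_TIMESTEPS = 128): each 128-timestep window appears to be cut into 6 overlapping subinstances of length 48 with stride 16, since (128 - 48) / 16 + 1 = 6. It also reproduces the savings arithmetic from the evaluation cells. Treat it as an illustrative assumption, not code from the repository.

    # Hypothetical sketch: derive (num_bags, 6, 48, num_feats) subinstance bags from
    # (num_bags, 128, num_feats) windows, assuming subinstance length 48 and stride 16.
    import numpy as np

    ORIGINAL_NUM_TIMESTEPS = 128
    SUBINSTANCE_LEN = 48       # NUM_TIMESTEPS in the notebook
    SUBINSTANCE_STRIDE = 16

    def make_bags(windows):
        starts = range(0, ORIGINAL_NUM_TIMESTEPS - SUBINSTANCE_LEN + 1, SUBINSTANCE_STRIDE)
        return np.stack([windows[:, s:s + SUBINSTANCE_LEN, :] for s in starts], axis=1)

    dummy = np.zeros((10, ORIGINAL_NUM_TIMESTEPS, 16))
    print(make_bags(dummy).shape)   # (10, 6, 48, 16) -- matches x_train.shape[1:]

    # Savings arithmetic used in the evaluation cells above:
    mi_savings = 1 - SUBINSTANCE_LEN / ORIGINAL_NUM_TIMESTEPS   # 0.625
    emi_savings = 0.133430                                      # early-prediction savings, as printed
    total = mi_savings + (1 - mi_savings) * emi_savings         # 0.675036, as printed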
/DREAMER/DREAMER_Data_Preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import numpy as np"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 4,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "ecg = pd.read_csv(\"ECG.csv\")\n",
20 | "eeg = pd.read_csv(\"EEG.csv\")\n",
21 | "labels = pd.read_csv(\"Labels.csv\")"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 23,
27 | "metadata": {
28 | "scrolled": true
29 | },
30 | "outputs": [],
31 | "source": [
32 | "Arousal = labels[labels['Label'] == 'Arousal']\n",
33 | "Dominance = labels[labels['Label'] == 'Dominance']\n",
34 | "Valence = labels[labels['Label'] == 'Valence']"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 28,
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "data": {
44 | "text/plain": [
45 | "True"
46 | ]
47 | },
48 | "execution_count": 28,
49 | "metadata": {},
50 | "output_type": "execute_result"
51 | }
52 | ],
53 | "source": [
54 | "labels.shape[0] == Arousal.shape[0] + Dominance.shape[0] + Valence.shape[0]"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 25,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "Arousal = Arousal.rename({'Score':'Arousal'}, axis=1)\n",
64 | "Dominance = Dominance.rename({'Score':'Dominance'}, axis=1)\n",
65 | "Valence = Valence.rename({'Score':'Valence'}, axis=1)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 32,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "Index(['ECG1', 'ECG2', 'Person', 'Movie'], dtype='object')\n",
78 | "Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',\n",
79 | " '13', 'Person', 'Movie'],\n",
80 | " dtype='object')\n",
81 | "Index(['Score', 'Person', 'Movie', 'Label'], dtype='object')\n"
82 | ]
83 | }
84 | ],
85 | "source": [
86 | "print(ecg.columns)\n",
87 | "print(eeg.columns)\n",
88 | "print(labels.columns)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 33,
94 | "metadata": {
95 | "scrolled": true
96 | },
97 | "outputs": [
98 | {
99 | "name": "stdout",
100 | "output_type": "stream",
101 | "text": [
102 | "(10975232, 16)\n",
103 | "(21950464, 4)\n"
104 | ]
105 | }
106 | ],
107 | "source": [
108 | "print(eeg.shape)\n",
109 | "print(ecg.shape)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 34,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "ecg = pd.merge(ecg, Arousal.drop('Label', axis=1), how='left', on=['Person', 'Movie'])\n",
119 | "ecg = pd.merge(ecg, Dominance.drop('Label', axis=1), how='left', on=['Person', 'Movie'])\n",
120 | "ecg = pd.merge(ecg, Valence.drop('Label', axis=1), how='left', on=['Person', 'Movie'])"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 38,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "eeg = pd.merge(eeg, Arousal.drop('Label', axis=1), how='left', on=['Person', 'Movie'])\n",
130 | "eeg = pd.merge(eeg, Dominance.drop('Label', axis=1), how='left', on=['Person', 'Movie'])\n",
131 | "eeg = pd.merge(eeg, Valence.drop('Label', axis=1), how='left', on=['Person', 'Movie'])"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 39,
137 | "metadata": {},
138 | "outputs": [
139 | {
140 | "name": "stdout",
141 | "output_type": "stream",
142 | "text": [
143 | "(10975232, 19)\n"
144 | ]
145 | }
146 | ],
147 | "source": [
148 | "print(eeg.shape)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 40,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "name": "stdout",
158 | "output_type": "stream",
159 | "text": [
160 | "(10975232, 19)\n",
161 | "(21950464, 7)\n"
162 | ]
163 | }
164 | ],
165 | "source": [
166 | "print(eeg.shape)\n",
167 | "print(ecg.shape)"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 41,
173 | "metadata": {},
174 | "outputs": [
175 | {
176 | "data": {
177 | "text/html": [
178 | "
\n",
179 | "\n",
192 | "
\n",
193 | " \n",
194 | " \n",
195 | " | \n",
196 | " 0 | \n",
197 | " 1 | \n",
198 | " 2 | \n",
199 | " 3 | \n",
200 | " 4 | \n",
201 | " 5 | \n",
202 | " 6 | \n",
203 | " 7 | \n",
204 | " 8 | \n",
205 | " 9 | \n",
206 | " 10 | \n",
207 | " 11 | \n",
208 | " 12 | \n",
209 | " 13 | \n",
210 | " Person | \n",
211 | " Movie | \n",
212 | " Arousal | \n",
213 | " Dominance | \n",
214 | " Valence | \n",
215 | "
\n",
216 | " \n",
217 | " \n",
218 | " \n",
219 | " 0 | \n",
220 | " 4388.2 | \n",
221 | " 4102.6 | \n",
222 | " 4219.5 | \n",
223 | " 4465.1 | \n",
224 | " 4370.8 | \n",
225 | " 4399.5 | \n",
226 | " 4443.1 | \n",
227 | " 4023.1 | \n",
228 | " 4365.1 | \n",
229 | " 4310.3 | \n",
230 | " 3953.8 | \n",
231 | " 4454.4 | \n",
232 | " 4326.2 | \n",
233 | " 4165.1 | \n",
234 | " 1 | \n",
235 | " 1 | \n",
236 | " 3 | \n",
237 | " 2 | \n",
238 | " 4 | \n",
239 | "
\n",
240 | " \n",
241 | " 1 | \n",
242 | " 4375.9 | \n",
243 | " 4093.8 | \n",
244 | " 4252.8 | \n",
245 | " 4522.6 | \n",
246 | " 4435.9 | \n",
247 | " 4411.8 | \n",
248 | " 4488.7 | \n",
249 | " 4108.7 | \n",
250 | " 4399.5 | \n",
251 | " 4384.6 | \n",
252 | " 4007.7 | \n",
253 | " 4466.7 | \n",
254 | " 4372.8 | \n",
255 | " 4247.2 | \n",
256 | " 1 | \n",
257 | " 1 | \n",
258 | " 3 | \n",
259 | " 2 | \n",
260 | " 4 | \n",
261 | "
\n",
262 | " \n",
263 | " 2 | \n",
264 | " 4378.5 | \n",
265 | " 4091.3 | \n",
266 | " 4230.3 | \n",
267 | " 4488.2 | \n",
268 | " 4370.3 | \n",
269 | " 4402.6 | \n",
270 | " 4461.0 | \n",
271 | " 4077.4 | \n",
272 | " 4378.5 | \n",
273 | " 4328.7 | \n",
274 | " 3986.2 | \n",
275 | " 4461.0 | \n",
276 | " 4328.2 | \n",
277 | " 4203.6 | \n",
278 | " 1 | \n",
279 | " 1 | \n",
280 | " 3 | \n",
281 | " 2 | \n",
282 | " 4 | \n",
283 | "
\n",
284 | " \n",
285 | " 3 | \n",
286 | " 4393.8 | \n",
287 | " 4101.0 | \n",
288 | " 4193.3 | \n",
289 | " 4419.0 | \n",
290 | " 4270.3 | \n",
291 | " 4392.3 | \n",
292 | " 4411.3 | \n",
293 | " 3982.6 | \n",
294 | " 4336.4 | \n",
295 | " 4213.3 | \n",
296 | " 3930.3 | \n",
297 | " 4442.6 | \n",
298 | " 4261.0 | \n",
299 | " 4100.0 | \n",
300 | " 1 | \n",
301 | " 1 | \n",
302 | " 3 | \n",
303 | " 2 | \n",
304 | " 4 | \n",
305 | "
\n",
306 | " \n",
307 | " 4 | \n",
308 | " 4396.4 | \n",
309 | " 4108.7 | \n",
310 | " 4210.8 | \n",
311 | " 4436.4 | \n",
312 | " 4310.8 | \n",
313 | " 4401.0 | \n",
314 | " 4426.7 | \n",
315 | " 3980.5 | \n",
316 | " 4349.7 | \n",
317 | " 4238.5 | \n",
318 | " 3945.1 | \n",
319 | " 4446.7 | \n",
320 | " 4289.7 | \n",
321 | " 4115.4 | \n",
322 | " 1 | \n",
323 | " 1 | \n",
324 | " 3 | \n",
325 | " 2 | \n",
326 | " 4 | \n",
327 | "
\n",
328 | " \n",
329 | "
\n",
330 | "
"
331 | ],
332 | "text/plain": [
333 | " 0 1 2 3 4 5 6 7 8 \\\n",
334 | "0 4388.2 4102.6 4219.5 4465.1 4370.8 4399.5 4443.1 4023.1 4365.1 \n",
335 | "1 4375.9 4093.8 4252.8 4522.6 4435.9 4411.8 4488.7 4108.7 4399.5 \n",
336 | "2 4378.5 4091.3 4230.3 4488.2 4370.3 4402.6 4461.0 4077.4 4378.5 \n",
337 | "3 4393.8 4101.0 4193.3 4419.0 4270.3 4392.3 4411.3 3982.6 4336.4 \n",
338 | "4 4396.4 4108.7 4210.8 4436.4 4310.8 4401.0 4426.7 3980.5 4349.7 \n",
339 | "\n",
340 | " 9 10 11 12 13 Person Movie Arousal Dominance \\\n",
341 | "0 4310.3 3953.8 4454.4 4326.2 4165.1 1 1 3 2 \n",
342 | "1 4384.6 4007.7 4466.7 4372.8 4247.2 1 1 3 2 \n",
343 | "2 4328.7 3986.2 4461.0 4328.2 4203.6 1 1 3 2 \n",
344 | "3 4213.3 3930.3 4442.6 4261.0 4100.0 1 1 3 2 \n",
345 | "4 4238.5 3945.1 4446.7 4289.7 4115.4 1 1 3 2 \n",
346 | "\n",
347 | " Valence \n",
348 | "0 4 \n",
349 | "1 4 \n",
350 | "2 4 \n",
351 | "3 4 \n",
352 | "4 4 "
353 | ]
354 | },
355 | "execution_count": 41,
356 | "metadata": {},
357 | "output_type": "execute_result"
358 | }
359 | ],
360 | "source": [
361 | "eeg.head()"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": 42,
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "ecg.to_csv(\"ECG_Labelled.csv\")\n",
371 | "eeg.to_csv(\"EEG_Labelled.csv\")"
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "## Combine ECG and EEG"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 20,
384 | "metadata": {},
385 | "outputs": [
386 | {
387 | "name": "stderr",
388 | "output_type": "stream",
389 | "text": [
390 | "/home/sf/.local/lib/python3.6/site-packages/numpy/lib/arraysetops.py:472: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
391 | " mask |= (ar1 == a)\n"
392 | ]
393 | }
394 | ],
395 | "source": [
396 | "ecg = pd.read_csv(\"ECG_Labelled.csv\", index_col=0)\n",
397 | "eeg = pd.read_csv(\"EEG_Labelled.csv\", index_col=0)"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": 21,
403 | "metadata": {},
404 | "outputs": [
405 | {
406 | "data": {
407 | "text/html": [
408 | "\n",
409 | "\n",
422 | "
\n",
423 | " \n",
424 | " \n",
425 | " | \n",
426 | " 0 | \n",
427 | " 1 | \n",
428 | " 2 | \n",
429 | " 3 | \n",
430 | " 4 | \n",
431 | " 5 | \n",
432 | " 6 | \n",
433 | " 7 | \n",
434 | " 8 | \n",
435 | " 9 | \n",
436 | " 10 | \n",
437 | " 11 | \n",
438 | " 12 | \n",
439 | " 13 | \n",
440 | " Person | \n",
441 | " Movie | \n",
442 | " Arousal | \n",
443 | " Dominance | \n",
444 | " Valence | \n",
445 | "
\n",
446 | " \n",
447 | " \n",
448 | " \n",
449 | " 0 | \n",
450 | " 4388.2 | \n",
451 | " 4102.6 | \n",
452 | " 4219.5 | \n",
453 | " 4465.1 | \n",
454 | " 4370.8 | \n",
455 | " 4399.5 | \n",
456 | " 4443.1 | \n",
457 | " 4023.1 | \n",
458 | " 4365.1 | \n",
459 | " 4310.3 | \n",
460 | " 3953.8 | \n",
461 | " 4454.4 | \n",
462 | " 4326.2 | \n",
463 | " 4165.1 | \n",
464 | " 1 | \n",
465 | " 1 | \n",
466 | " 3 | \n",
467 | " 2 | \n",
468 | " 4 | \n",
469 | "
\n",
470 | " \n",
471 | " 1 | \n",
472 | " 4375.9 | \n",
473 | " 4093.8 | \n",
474 | " 4252.8 | \n",
475 | " 4522.6 | \n",
476 | " 4435.9 | \n",
477 | " 4411.8 | \n",
478 | " 4488.7 | \n",
479 | " 4108.7 | \n",
480 | " 4399.5 | \n",
481 | " 4384.6 | \n",
482 | " 4007.7 | \n",
483 | " 4466.7 | \n",
484 | " 4372.8 | \n",
485 | " 4247.2 | \n",
486 | " 1 | \n",
487 | " 1 | \n",
488 | " 3 | \n",
489 | " 2 | \n",
490 | " 4 | \n",
491 | "
\n",
492 | " \n",
493 | " 2 | \n",
494 | " 4378.5 | \n",
495 | " 4091.3 | \n",
496 | " 4230.3 | \n",
497 | " 4488.2 | \n",
498 | " 4370.3 | \n",
499 | " 4402.6 | \n",
500 | " 4461.0 | \n",
501 | " 4077.4 | \n",
502 | " 4378.5 | \n",
503 | " 4328.7 | \n",
504 | " 3986.2 | \n",
505 | " 4461.0 | \n",
506 | " 4328.2 | \n",
507 | " 4203.6 | \n",
508 | " 1 | \n",
509 | " 1 | \n",
510 | " 3 | \n",
511 | " 2 | \n",
512 | " 4 | \n",
513 | "
\n",
514 | " \n",
515 | " 3 | \n",
516 | " 4393.8 | \n",
517 | " 4101.0 | \n",
518 | " 4193.3 | \n",
519 | " 4419.0 | \n",
520 | " 4270.3 | \n",
521 | " 4392.3 | \n",
522 | " 4411.3 | \n",
523 | " 3982.6 | \n",
524 | " 4336.4 | \n",
525 | " 4213.3 | \n",
526 | " 3930.3 | \n",
527 | " 4442.6 | \n",
528 | " 4261.0 | \n",
529 | " 4100.0 | \n",
530 | " 1 | \n",
531 | " 1 | \n",
532 | " 3 | \n",
533 | " 2 | \n",
534 | " 4 | \n",
535 | "
\n",
536 | " \n",
537 | " 4 | \n",
538 | " 4396.4 | \n",
539 | " 4108.7 | \n",
540 | " 4210.8 | \n",
541 | " 4436.4 | \n",
542 | " 4310.8 | \n",
543 | " 4401.0 | \n",
544 | " 4426.7 | \n",
545 | " 3980.5 | \n",
546 | " 4349.7 | \n",
547 | " 4238.5 | \n",
548 | " 3945.1 | \n",
549 | " 4446.7 | \n",
550 | " 4289.7 | \n",
551 | " 4115.4 | \n",
552 | " 1 | \n",
553 | " 1 | \n",
554 | " 3 | \n",
555 | " 2 | \n",
556 | " 4 | \n",
557 | "
\n",
558 | " \n",
559 | "
\n",
560 | "
"
561 | ],
562 | "text/plain": [
563 | " 0 1 2 3 4 5 6 7 8 \\\n",
564 | "0 4388.2 4102.6 4219.5 4465.1 4370.8 4399.5 4443.1 4023.1 4365.1 \n",
565 | "1 4375.9 4093.8 4252.8 4522.6 4435.9 4411.8 4488.7 4108.7 4399.5 \n",
566 | "2 4378.5 4091.3 4230.3 4488.2 4370.3 4402.6 4461.0 4077.4 4378.5 \n",
567 | "3 4393.8 4101.0 4193.3 4419.0 4270.3 4392.3 4411.3 3982.6 4336.4 \n",
568 | "4 4396.4 4108.7 4210.8 4436.4 4310.8 4401.0 4426.7 3980.5 4349.7 \n",
569 | "\n",
570 | " 9 10 11 12 13 Person Movie Arousal Dominance \\\n",
571 | "0 4310.3 3953.8 4454.4 4326.2 4165.1 1 1 3 2 \n",
572 | "1 4384.6 4007.7 4466.7 4372.8 4247.2 1 1 3 2 \n",
573 | "2 4328.7 3986.2 4461.0 4328.2 4203.6 1 1 3 2 \n",
574 | "3 4213.3 3930.3 4442.6 4261.0 4100.0 1 1 3 2 \n",
575 | "4 4238.5 3945.1 4446.7 4289.7 4115.4 1 1 3 2 \n",
576 | "\n",
577 | " Valence \n",
578 | "0 4 \n",
579 | "1 4 \n",
580 | "2 4 \n",
581 | "3 4 \n",
582 | "4 4 "
583 | ]
584 | },
585 | "execution_count": 21,
586 | "metadata": {},
587 | "output_type": "execute_result"
588 | }
589 | ],
590 | "source": [
591 | "eeg.head()"
592 | ]
593 | },
594 | {
595 | "cell_type": "code",
596 | "execution_count": 22,
597 | "metadata": {},
598 | "outputs": [],
599 | "source": [
600 | "ecg_downsampled = ecg[np.arange(len(ecg)) % 2 == 0] # deleting all odd indices"
601 | ]
602 | },
603 | {
604 | "cell_type": "code",
605 | "execution_count": 23,
606 | "metadata": {},
607 | "outputs": [
608 | {
609 | "data": {
610 | "text/plain": [
611 | "44544"
612 | ]
613 | },
614 | "execution_count": 23,
615 | "metadata": {},
616 | "output_type": "execute_result"
617 | }
618 | ],
619 | "source": [
620 | "ecg_downsampled[(ecg_downsampled['Person'] == 1) & (ecg_downsampled['Movie'] == 3)].shape[0]"
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": 24,
626 | "metadata": {},
627 | "outputs": [],
628 | "source": [
629 | "ecg_downsampled.reset_index(inplace=True, drop=True)"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": 25,
635 | "metadata": {},
636 | "outputs": [
637 | {
638 | "data": {
639 | "text/html": [
640 | "\n",
641 | "\n",
654 | "
\n",
655 | " \n",
656 | " \n",
657 | " | \n",
658 | " ECG1 | \n",
659 | " ECG2 | \n",
660 | " Person | \n",
661 | " Movie | \n",
662 | " Arousal | \n",
663 | " Dominance | \n",
664 | " Valence | \n",
665 | "
\n",
666 | " \n",
667 | " \n",
668 | " \n",
669 | " 0 | \n",
670 | " 2046 | \n",
671 | " 2056 | \n",
672 | " 1 | \n",
673 | " 1 | \n",
674 | " 3 | \n",
675 | " 2 | \n",
676 | " 4 | \n",
677 | "
\n",
678 | " \n",
679 | " 1 | \n",
680 | " 2039 | \n",
681 | " 2059 | \n",
682 | " 1 | \n",
683 | " 1 | \n",
684 | " 3 | \n",
685 | " 2 | \n",
686 | " 4 | \n",
687 | "
\n",
688 | " \n",
689 | " 2 | \n",
690 | " 2041 | \n",
691 | " 2060 | \n",
692 | " 1 | \n",
693 | " 1 | \n",
694 | " 3 | \n",
695 | " 2 | \n",
696 | " 4 | \n",
697 | "
\n",
698 | " \n",
699 | " 3 | \n",
700 | " 2039 | \n",
701 | " 2059 | \n",
702 | " 1 | \n",
703 | " 1 | \n",
704 | " 3 | \n",
705 | " 2 | \n",
706 | " 4 | \n",
707 | "
\n",
708 | " \n",
709 | " 4 | \n",
710 | " 2040 | \n",
711 | " 2056 | \n",
712 | " 1 | \n",
713 | " 1 | \n",
714 | " 3 | \n",
715 | " 2 | \n",
716 | " 4 | \n",
717 | "
\n",
718 | " \n",
719 | "
\n",
720 | "
"
721 | ],
722 | "text/plain": [
723 | " ECG1 ECG2 Person Movie Arousal Dominance Valence\n",
724 | "0 2046 2056 1 1 3 2 4\n",
725 | "1 2039 2059 1 1 3 2 4\n",
726 | "2 2041 2060 1 1 3 2 4\n",
727 | "3 2039 2059 1 1 3 2 4\n",
728 | "4 2040 2056 1 1 3 2 4"
729 | ]
730 | },
731 | "execution_count": 25,
732 | "metadata": {},
733 | "output_type": "execute_result"
734 | }
735 | ],
736 | "source": [
737 | "ecg_downsampled.head()"
738 | ]
739 | },
740 | {
741 | "cell_type": "code",
742 | "execution_count": 26,
743 | "metadata": {},
744 | "outputs": [
745 | {
746 | "data": {
747 | "text/plain": [
748 | "44544"
749 | ]
750 | },
751 | "execution_count": 26,
752 | "metadata": {},
753 | "output_type": "execute_result"
754 | }
755 | ],
756 | "source": [
757 | "eeg[(eeg['Person'] == 1) & (eeg['Movie'] == 3)].shape[0] # check"
758 | ]
759 | },
760 | {
761 | "cell_type": "code",
762 | "execution_count": 27,
763 | "metadata": {},
764 | "outputs": [
765 | {
766 | "data": {
767 | "text/plain": [
768 | "True"
769 | ]
770 | },
771 | "execution_count": 27,
772 | "metadata": {},
773 | "output_type": "execute_result"
774 | }
775 | ],
776 | "source": [
777 | "eeg.shape[0] == ecg_downsampled.shape[0] # check"
778 | ]
779 | },
780 | {
781 | "cell_type": "code",
782 | "execution_count": 28,
783 | "metadata": {},
784 | "outputs": [
785 | {
786 | "data": {
787 | "text/plain": [
788 | "Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',\n",
789 | " '13', 'Person', 'Movie', 'Arousal', 'Dominance', 'Valence'],\n",
790 | " dtype='object')"
791 | ]
792 | },
793 | "execution_count": 28,
794 | "metadata": {},
795 | "output_type": "execute_result"
796 | }
797 | ],
798 | "source": [
799 | "eeg.columns"
800 | ]
801 | },
802 | {
803 | "cell_type": "code",
804 | "execution_count": 29,
805 | "metadata": {},
806 | "outputs": [],
807 | "source": [
808 | "del ecg"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "execution_count": 30,
814 | "metadata": {},
815 | "outputs": [],
816 | "source": [
817 | "combined = pd.concat([eeg, ecg_downsampled.drop(['Person', 'Movie', 'Arousal', 'Dominance', 'Valence'], axis=1)], axis=1)"
818 | ]
819 | },
820 | {
821 | "cell_type": "code",
822 | "execution_count": 31,
823 | "metadata": {},
824 | "outputs": [
825 | {
826 | "data": {
827 | "text/plain": [
828 | "(10975232, 7)"
829 | ]
830 | },
831 | "execution_count": 31,
832 | "metadata": {},
833 | "output_type": "execute_result"
834 | }
835 | ],
836 | "source": [
837 | "ecg_downsampled.shape"
838 | ]
839 | },
840 | {
841 | "cell_type": "code",
842 | "execution_count": 32,
843 | "metadata": {},
844 | "outputs": [
845 | {
846 | "data": {
847 | "text/plain": [
848 | "(10975232, 21)"
849 | ]
850 | },
851 | "execution_count": 32,
852 | "metadata": {},
853 | "output_type": "execute_result"
854 | }
855 | ],
856 | "source": [
857 | "combined.shape"
858 | ]
859 | },
860 | {
861 | "cell_type": "code",
862 | "execution_count": 34,
863 | "metadata": {},
864 | "outputs": [
865 | {
866 | "data": {
867 | "text/plain": [
868 | "Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',\n",
869 | " '13', 'Person', 'Movie', 'Arousal', 'Dominance', 'Valence', 'ECG1',\n",
870 | " 'ECG2'],\n",
871 | " dtype='object')"
872 | ]
873 | },
874 | "execution_count": 34,
875 | "metadata": {},
876 | "output_type": "execute_result"
877 | }
878 | ],
879 | "source": [
880 | "combined.columns"
881 | ]
882 | },
883 | {
884 | "cell_type": "code",
885 | "execution_count": 35,
886 | "metadata": {},
887 | "outputs": [],
888 | "source": [
889 | "combined.to_csv(\"DREAMER_combined.csv\")"
890 | ]
891 | }
892 | ],
893 | "metadata": {
894 | "kernelspec": {
895 | "display_name": "Python 3",
896 | "language": "python",
897 | "name": "python3"
898 | },
899 | "language_info": {
900 | "codemirror_mode": {
901 | "name": "ipython",
902 | "version": 3
903 | },
904 | "file_extension": ".py",
905 | "mimetype": "text/x-python",
906 | "name": "python",
907 | "nbconvert_exporter": "python",
908 | "pygments_lexer": "ipython3",
909 | "version": "3.7.3"
910 | }
911 | },
912 | "nbformat": 4,
913 | "nbformat_minor": 2
914 | }
915 |
--------------------------------------------------------------------------------
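
The notebook above stops after writing DREAMER_combined.csv, a table with 14 EEG channels ('0'-'13') plus ECG1 and ECG2, i.e. the 16 features (NUM_FEATS = 16) consumed by the model notebooks. The windowing step that turns this table into fixed-length (num_windows, 128, 16) arrays is not shown in this file; the sketch below is only one plausible way to do it, assuming non-overlapping 128-sample windows per (Person, Movie) recording and a single label column. Window length, stride, and label handling are assumptions for illustration.

    # Hypothetical sketch: cut the combined 16-channel table into 128-sample windows.
    import numpy as np
    import pandas as pd

    FEATURE_COLS = [str(i) for i in range(14)] + ['ECG1', 'ECG2']   # 14 EEG + 2 ECG channels
    WINDOW = 128                                                    # ORIGINAL_NUM_TIMESTEPS in the model notebooks

    def window_signals(combined, label_col='Arousal'):
        windows, labels = [], []
        for _, group in combined.groupby(['Person', 'Movie']):
            feats = group[FEATURE_COLS].to_numpy()
            label = group[label_col].iloc[0]          # one rating per (Person, Movie)
            for start in range(0, len(feats) - WINDOW + 1, WINDOW):
                windows.append(feats[start:start + WINDOW])
                labels.append(label)
        return np.asarray(windows), np.asarray(labels)

    # x, y = window_signals(pd.read_csv("DREAMER_combined.csv", index_col=0))
    # x.shape -> (num_windows, 128, 16)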
/DREAMER/DREAMER_Dominance_GRU_48_16.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# DREAMER Dominance EMI-GRU 48_16"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "import numpy as np\n",
25 | "from tabulate import tabulate\n",
26 | "import os\n",
27 | "import datetime as datetime\n",
28 | "import pickle as pkl\n",
29 | "import pathlib"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {
36 | "ExecuteTime": {
37 | "end_time": "2018-12-14T14:17:51.796585Z",
38 | "start_time": "2018-12-14T14:17:49.648375Z"
39 | }
40 | },
41 | "outputs": [
42 | {
43 | "name": "stderr",
44 | "output_type": "stream",
45 | "text": [
46 | "Using TensorFlow backend.\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "from __future__ import print_function\n",
52 | "import os\n",
53 | "import sys\n",
54 | "import tensorflow as tf\n",
55 | "import numpy as np\n",
56 | "# Making sure edgeml is part of python path\n",
57 | "sys.path.insert(0, '../../')\n",
58 | "#For processing on CPU.\n",
59 | "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
60 | "\n",
61 | "np.random.seed(42)\n",
62 | "tf.set_random_seed(42)\n",
63 | "\n",
64 | "# MI-RNN and EMI-RNN imports\n",
65 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
66 | "from edgeml.graph.rnn import EMI_GRU\n",
67 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
68 | "import edgeml.utils\n",
69 | "\n",
70 | "import keras.backend as K\n",
71 | "cfg = K.tf.ConfigProto()\n",
72 | "cfg.gpu_options.allow_growth = True\n",
73 | "K.set_session(K.tf.Session(config=cfg))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 10,
79 | "metadata": {
80 | "ExecuteTime": {
81 | "end_time": "2018-12-14T14:17:51.803381Z",
82 | "start_time": "2018-12-14T14:17:51.798799Z"
83 | }
84 | },
85 | "outputs": [],
86 | "source": [
87 | "# Network parameters for our LSTM + FC Layer\n",
88 | "NUM_HIDDEN = 128\n",
89 | "NUM_TIMESTEPS = 48\n",
90 | "ORIGINAL_NUM_TIMESTEPS = 128\n",
91 | "NUM_FEATS = 16\n",
92 | "FORGET_BIAS = 1.0\n",
93 | "NUM_OUTPUT = 5\n",
94 | "USE_DROPOUT = True\n",
95 | "KEEP_PROB = 0.75\n",
96 | "\n",
97 | "# For dataset API\n",
98 | "PREFETCH_NUM = 5\n",
99 | "BATCH_SIZE = 32\n",
100 | "\n",
101 | "# Number of epochs in *one iteration*\n",
102 | "NUM_EPOCHS = 2\n",
103 | "# Number of iterations in *one round*. After each iteration,\n",
104 | "# the model is dumped to disk. At the end of the current\n",
105 | "# round, the best model among all the dumped models in the\n",
106 | "# current round is picked up..\n",
107 | "NUM_ITER = 4\n",
108 | "# A round consists of multiple training iterations and a belief\n",
109 | "# update step using the best model from all of these iterations\n",
110 | "NUM_ROUNDS = 10\n",
111 | "LEARNING_RATE=0.001\n",
112 | "\n",
113 | "# A staging direcory to store models\n",
114 | "MODEL_PREFIX = '/home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru'"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {
120 | "heading_collapsed": true
121 | },
122 | "source": [
123 | "# Loading Data"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 11,
129 | "metadata": {
130 | "ExecuteTime": {
131 | "end_time": "2018-12-14T14:17:52.040352Z",
132 | "start_time": "2018-12-14T14:17:51.805319Z"
133 | },
134 | "hidden": true
135 | },
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "x_train shape is: (61735, 6, 48, 16)\n",
142 | "y_train shape is: (61735, 6, 5)\n",
143 | "x_test shape is: (6860, 6, 48, 16)\n",
144 | "y_test shape is: (6860, 6, 5)\n"
145 | ]
146 | }
147 | ],
148 | "source": [
149 | "# Loading the data\n",
150 | "path='/home/sf/data/DREAMER/Dominance/Fast_GRNN/48_16/'\n",
151 | "x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
152 | "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
153 | "x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
154 | "\n",
155 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
156 | "# step of EMI/MI RNN\n",
157 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
158 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
159 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
160 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
161 | "print(\"x_train shape is:\", x_train.shape)\n",
162 | "print(\"y_train shape is:\", y_train.shape)\n",
163 | "print(\"x_test shape is:\", x_val.shape)\n",
164 | "print(\"y_test shape is:\", y_val.shape)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "# Computation Graph"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 12,
177 | "metadata": {
178 | "ExecuteTime": {
179 | "end_time": "2018-12-14T14:17:52.053161Z",
180 | "start_time": "2018-12-14T14:17:52.042928Z"
181 | }
182 | },
183 | "outputs": [],
184 | "source": [
185 | "# Define the linear secondary classifier\n",
186 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
187 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
188 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
189 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
190 | " self.output = y_cap\n",
191 | " self.graphCreated = True\n",
192 | "\n",
193 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
194 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
195 | " self.output = y_cap\n",
196 | " self.graphCreated = True\n",
197 | " \n",
198 | "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
199 | " if inference is False:\n",
200 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
201 | " else:\n",
202 | " feedDict = {self._emiGraph.keep_prob: 1.0}\n",
203 | " return feedDict\n",
204 | " \n",
205 | "EMI_GRU._createExtendedGraph = createExtendedGraph\n",
206 | "EMI_GRU._restoreExtendedGraph = restoreExtendedGraph\n",
207 | "\n",
208 | "if USE_DROPOUT is True:\n",
209 | " EMI_Driver.feedDictFunc = feedDictFunc"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 13,
215 | "metadata": {
216 | "ExecuteTime": {
217 | "end_time": "2018-12-14T14:17:52.335299Z",
218 | "start_time": "2018-12-14T14:17:52.055483Z"
219 | }
220 | },
221 | "outputs": [],
222 | "source": [
223 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
224 | "emiGRU = EMI_GRU(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
225 | " useDropout=USE_DROPOUT)\n",
226 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
227 | " stepSize=LEARNING_RATE)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 14,
233 | "metadata": {
234 | "ExecuteTime": {
235 | "end_time": "2018-12-14T14:18:05.031382Z",
236 | "start_time": "2018-12-14T14:17:52.338750Z"
237 | }
238 | },
239 | "outputs": [],
240 | "source": [
241 | "tf.reset_default_graph()\n",
242 | "g1 = tf.Graph() \n",
243 | "with g1.as_default():\n",
244 | " # Obtain the iterators to each batch of the data\n",
245 | " x_batch, y_batch = inputPipeline()\n",
246 | " # Create the forward computation graph based on the iterators\n",
247 | " y_cap = emiGRU(x_batch)\n",
248 | " # Create loss graphs and training routines\n",
249 | " emiTrainer(y_cap, y_batch)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "# EMI Driver"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 15,
262 | "metadata": {
263 | "ExecuteTime": {
264 | "end_time": "2018-12-14T14:35:15.209910Z",
265 | "start_time": "2018-12-14T14:18:05.034359Z"
266 | },
267 | "scrolled": true
268 | },
269 | "outputs": [
270 | {
271 | "name": "stdout",
272 | "output_type": "stream",
273 | "text": [
274 | "Update policy: top-k\n",
275 | "Training with MI-RNN loss for 5 rounds\n",
276 | "Round: 0\n",
277 | "Epoch 1 Batch 1925 ( 3855) Loss 0.03100 Acc 0.36979 | Val acc 0.38717 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1000\n",
278 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02785 Acc 0.40104 | Val acc 0.41647 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1001\n",
279 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02730 Acc 0.39062 | Val acc 0.45000 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1002\n",
280 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02400 Acc 0.46354 | Val acc 0.48513 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1003\n",
281 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1003\n",
282 | "Round: 1\n",
283 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02273 Acc 0.52604 | Val acc 0.51706 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1004\n",
284 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02227 Acc 0.55729 | Val acc 0.55423 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1005\n",
285 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02006 Acc 0.60417 | Val acc 0.58017 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1006\n",
286 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01900 Acc 0.66667 | Val acc 0.59985 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1007\n",
287 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1007\n",
288 | "Round: 2\n",
289 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01803 Acc 0.67708 | Val acc 0.61268 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1008\n",
290 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01809 Acc 0.63021 | Val acc 0.62828 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1009\n",
291 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01681 Acc 0.65625 | Val acc 0.63499 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1010\n",
292 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01673 Acc 0.70312 | Val acc 0.64227 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1011\n",
293 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1011\n",
294 | "Round: 3\n",
295 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01625 Acc 0.66667 | Val acc 0.65262 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1012\n",
296 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01641 Acc 0.68750 | Val acc 0.66122 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1013\n",
297 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01519 Acc 0.70833 | Val acc 0.65583 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1014\n",
298 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01486 Acc 0.70312 | Val acc 0.66268 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1015\n",
299 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1015\n",
300 | "Round: 4\n",
301 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01496 Acc 0.69792 | Val acc 0.67128 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1016\n",
302 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01475 Acc 0.72917 | Val acc 0.67303 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1017\n",
303 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01540 Acc 0.70833 | Val acc 0.66764 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1018\n",
304 | "Epoch 1 Batch 1925 ( 3855) Loss 0.01428 Acc 0.75000 | Val acc 0.67609 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1019\n",
305 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1019\n",
306 | "Round: 5\n",
307 | "Switching to EMI-Loss function\n",
308 | "Epoch 1 Batch 1925 ( 3855) Loss 0.90647 Acc 0.74479 | Val acc 0.65758 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1020\n",
309 | "Epoch 1 Batch 1925 ( 3855) Loss 0.91761 Acc 0.64583 | Val acc 0.65933 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1021\n",
310 | "Epoch 1 Batch 1925 ( 3855) Loss 0.88657 Acc 0.70833 | Val acc 0.65991 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1022\n",
311 | "Epoch 1 Batch 1925 ( 3855) Loss 0.89605 Acc 0.68229 | Val acc 0.66749 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1023\n",
312 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1023\n",
313 | "Round: 6\n",
314 | "Epoch 1 Batch 1925 ( 3855) Loss 0.88819 Acc 0.69792 | Val acc 0.66910 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1024\n",
315 | "Epoch 1 Batch 1925 ( 3855) Loss 0.85815 Acc 0.71875 | Val acc 0.66910 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1025\n",
316 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87857 Acc 0.71354 | Val acc 0.67128 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1026\n",
317 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87222 Acc 0.72396 | Val acc 0.66574 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1027\n",
318 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1026\n",
319 | "Round: 7\n",
320 | "Epoch 1 Batch 1925 ( 3855) Loss 0.86202 Acc 0.70312 | Val acc 0.66720 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1028\n",
321 | "Epoch 1 Batch 1925 ( 3855) Loss 0.88255 Acc 0.68229 | Val acc 0.66545 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1029\n",
322 | "Epoch 1 Batch 1925 ( 3855) Loss 0.89641 Acc 0.64583 | Val acc 0.66647 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1030\n",
323 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87491 Acc 0.67708 | Val acc 0.66297 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1031\n",
324 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1028\n",
325 | "Round: 8\n",
326 | "Epoch 1 Batch 1925 ( 3855) Loss 0.88003 Acc 0.67708 | Val acc 0.66429 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1032\n",
327 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87404 Acc 0.69271 | Val acc 0.65904 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1033\n",
328 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87308 Acc 0.72396 | Val acc 0.66603 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1034\n",
329 | "Epoch 1 Batch 1925 ( 3855) Loss 0.86821 Acc 0.67708 | Val acc 0.66822 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1035\n",
330 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1035\n",
331 | "Round: 9\n",
332 | "Epoch 1 Batch 1925 ( 3855) Loss 0.90025 Acc 0.69792 | Val acc 0.66020 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1036\n",
333 | "Epoch 1 Batch 1925 ( 3855) Loss 0.89532 Acc 0.66146 | Val acc 0.66676 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1037\n",
334 | "Epoch 1 Batch 1925 ( 3855) Loss 0.87133 Acc 0.69271 | Val acc 0.66706 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1038\n",
335 | "Epoch 1 Batch 1925 ( 3855) Loss 0.86914 Acc 0.68750 | Val acc 0.66837 | Model saved to /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru, global_step 1039\n",
336 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1039\n"
337 | ]
338 | }
339 | ],
340 | "source": [
341 | "with g1.as_default():\n",
342 | " emiDriver = EMI_Driver(inputPipeline, emiGRU, emiTrainer)\n",
343 | "\n",
344 | "emiDriver.initializeSession(g1)\n",
345 | "y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
346 | " y_train=y_train, bag_train=BAG_TRAIN,\n",
347 | " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
348 | " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
349 | " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
350 | " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
351 | " fracEMI=0.5, updatePolicy='top-k', k=1)"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "# Evaluating the trained model"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 16,
364 | "metadata": {
365 | "ExecuteTime": {
366 | "end_time": "2018-12-14T14:35:15.218040Z",
367 | "start_time": "2018-12-14T14:35:15.211771Z"
368 | }
369 | },
370 | "outputs": [],
371 | "source": [
372 | "# Early Prediction Policy: We make an early prediction based on the predicted classes\n",
373 | "# probability. If the predicted class probability > minProb at some step, we make\n",
374 | "# a prediction at that step.\n",
375 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
376 | " assert instanceOut.ndim == 2\n",
377 | " classes = np.argmax(instanceOut, axis=1)\n",
378 | " prob = np.max(instanceOut, axis=1)\n",
379 | " index = np.where(prob >= minProb)[0]\n",
380 | " if len(index) == 0:\n",
381 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
382 | " return classes[-1], len(instanceOut) - 1\n",
383 | " index = index[0]\n",
384 | " return classes[index], index\n",
385 | "\n",
386 | "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
387 | " predictionStep = predictionStep + 1\n",
388 | " predictionStep = np.reshape(predictionStep, -1)\n",
389 | " totalSteps = np.sum(predictionStep)\n",
390 | " maxSteps = len(predictionStep) * numTimeSteps\n",
391 | " savings = 1.0 - (totalSteps / maxSteps)\n",
392 | " if returnTotal:\n",
393 | " return savings, totalSteps\n",
394 | " return savings"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 17,
400 | "metadata": {
401 | "ExecuteTime": {
402 | "end_time": "2018-12-14T14:35:16.257489Z",
403 | "start_time": "2018-12-14T14:35:15.221029Z"
404 | }
405 | },
406 | "outputs": [
407 | {
408 | "name": "stdout",
409 | "output_type": "stream",
410 | "text": [
411 | "Accuracy at k = 2: 0.669893\n",
412 | "Savings due to MI-RNN : 0.625000\n",
413 | "Savings due to Early prediction: 0.133547\n",
414 | "Total Savings: 0.675080\n"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "k = 2\n",
420 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
421 | " minProb=0.99, keep_prob=1.0)\n",
422 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
423 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
424 | "mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
425 | "emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
426 | "total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
427 | "print('Savings due to MI-RNN : %f' % mi_savings)\n",
428 | "print('Savings due to Early prediction: %f' % emi_savings)\n",
429 | "print('Total Savings: %f' % (total_savings))"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 18,
435 | "metadata": {
436 | "ExecuteTime": {
437 | "end_time": "2018-12-14T14:35:17.044115Z",
438 | "start_time": "2018-12-14T14:35:16.259280Z"
439 | },
440 | "scrolled": false
441 | },
442 | "outputs": [
443 | {
444 | "name": "stdout",
445 | "output_type": "stream",
446 | "text": [
447 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
448 | "0 1 0.670185 0.635187 0.713956 0.607175 0.670185 0.670185 \n",
449 | "1 2 0.669893 0.639764 0.639024 0.641768 0.669893 0.669893 \n",
450 | "2 3 0.631232 0.574154 0.594423 0.641324 0.631232 0.631232 \n",
451 | "3 4 0.541956 0.520735 0.628534 0.588688 0.541956 0.541956 \n",
452 | "4 5 0.470873 0.480797 0.664928 0.542186 0.470873 0.470873 \n",
453 | "5 6 0.420783 0.450270 0.700213 0.506986 0.420783 0.420783 \n",
454 | "\n",
455 | " micro-rec \n",
456 | "0 0.670185 \n",
457 | "1 0.669893 \n",
458 | "2 0.631232 \n",
459 | "3 0.541956 \n",
460 | "4 0.470873 \n",
461 | "5 0.420783 \n",
462 | "Max accuracy 0.670185 at subsequencelength 1\n",
463 | "Max micro-f 0.670185 at subsequencelength 1\n",
464 | "Micro-precision 0.670185 at subsequencelength 1\n",
465 | "Micro-recall 0.670185 at subsequencelength 1\n",
466 | "Max macro-f 0.639764 at subsequencelength 2\n",
467 | "macro-precision 0.639024 at subsequencelength 2\n",
468 | "macro-recall 0.641768 at subsequencelength 2\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "# A slightly more detailed analysis method is provided. \n",
474 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {},
480 | "source": [
481 | "## Picking the best model"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "metadata": {
488 | "ExecuteTime": {
489 | "end_time": "2018-12-14T14:35:54.899340Z",
490 | "start_time": "2018-12-14T14:35:17.047464Z"
491 | }
492 | },
493 | "outputs": [
494 | {
495 | "name": "stdout",
496 | "output_type": "stream",
497 | "text": [
498 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1003\n",
499 | "Round: 0, Validation accuracy: 0.4851, Test Accuracy (k = 2): 0.486151, Total Savings: 0.628213\n",
500 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1007\n",
501 | "Round: 1, Validation accuracy: 0.5999, Test Accuracy (k = 2): 0.601143, Total Savings: 0.633056\n",
502 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1011\n",
503 | "Round: 2, Validation accuracy: 0.6423, Test Accuracy (k = 2): 0.638405, Total Savings: 0.635002\n",
504 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1015\n",
505 | "Round: 3, Validation accuracy: 0.6627, Test Accuracy (k = 2): 0.657881, Total Savings: 0.638982\n",
506 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1019\n",
507 | "Round: 4, Validation accuracy: 0.6761, Test Accuracy (k = 2): 0.674033, Total Savings: 0.641478\n",
508 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1023\n",
509 | "Round: 5, Validation accuracy: 0.6675, Test Accuracy (k = 2): 0.658114, Total Savings: 0.659983\n",
510 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1026\n",
511 | "Round: 6, Validation accuracy: 0.6713, Test Accuracy (k = 2): 0.665053, Total Savings: 0.664241\n",
512 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1028\n",
513 | "Round: 7, Validation accuracy: 0.6672, Test Accuracy (k = 2): 0.662837, Total Savings: 0.664690\n",
514 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1035\n",
515 | "Round: 8, Validation accuracy: 0.6682, Test Accuracy (k = 2): 0.665578, Total Savings: 0.670677\n",
516 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Dominance/48_16/models/GRU/model-gru-1039\n"
517 | ]
518 | }
519 | ],
520 | "source": [
521 | "devnull = open(os.devnull, 'r')\n",
522 | "for val in modelStats:\n",
523 | " round_, acc, modelPrefix, globalStep = val\n",
524 | " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
525 | " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
526 | " minProb=0.99, keep_prob=1.0)\n",
527 | "\n",
528 | " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
529 | " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
530 | " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
531 | " mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
532 | " emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
533 | " total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
534 | " print(\"Total Savings: %f\" % total_savings)"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": null,
540 | "metadata": {},
541 | "outputs": [],
542 | "source": [
543 | "params = {\n",
544 | " \"NUM_HIDDEN\" : 128,\n",
545 | " \"NUM_TIMESTEPS\" : 48, #subinstance length.\n",
546 | " \"ORIGINAL_NUM_TIMESTEPS\" : 128,\n",
547 | " \"NUM_FEATS\" : 16,\n",
548 | " \"FORGET_BIAS\" : 1.0,\n",
549 | " \"NUM_OUTPUT\" : 5,\n",
550 | " \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
551 | " \"KEEP_PROB\" : 0.75,\n",
552 | " \"PREFETCH_NUM\" : 5,\n",
553 | " \"BATCH_SIZE\" : 32,\n",
554 | " \"NUM_EPOCHS\" : 2,\n",
555 | " \"NUM_ITER\" : 4,\n",
556 | " \"NUM_ROUNDS\" : 10,\n",
557 | " \"LEARNING_RATE\" : 0.001,\n",
558 | " \"MODEL_PREFIX\" : '/home/sf/data/DREAMER/Dominance/model-gru'\n",
559 | "}"
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": null,
565 | "metadata": {},
566 | "outputs": [],
567 | "source": [
568 | "gru_dict = {**params}\n",
569 | "gru_dict[\"k\"] = k\n",
570 | "gru_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
571 | "gru_dict[\"total_savings\"] = total_savings\n",
572 | "gru_dict[\"y_test\"] = BAG_TEST\n",
573 | "gru_dict[\"y_pred\"] = bagPredictions\n",
574 | "\n",
575 | "# A slightly more detailed analysis method is provided. \n",
576 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
577 | "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))"
578 | ]
579 | },
580 | {
581 | "cell_type": "code",
582 | "execution_count": null,
583 | "metadata": {},
584 | "outputs": [],
585 | "source": [
586 | "dirname = \"home/sf/data/DREAMER/Dominance/GRU/\"\n",
587 | "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
588 | "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
589 | "\n",
590 | "now = datetime.datetime.now()\n",
591 | "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
592 | "filename = ''.join(filename)\n",
593 | "\n",
594 | "#Save the dictionary containing the params and the results.\n",
595 | "pkl.dump(gru_dict,open(dirname + filename + \".pkl\",mode='wb'))"
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "execution_count": 23,
601 | "metadata": {},
602 | "outputs": [
603 | {
604 | "data": {
605 | "text/plain": [
606 | "'home/sf/data/DREAMER/Dominance/GRU/2019-8-11|2-30.pkl'"
607 | ]
608 | },
609 | "execution_count": 23,
610 | "metadata": {},
611 | "output_type": "execute_result"
612 | }
613 | ],
614 | "source": [
615 | "dirname+filename+'.pkl'"
616 | ]
617 | }
618 | ],
619 | "metadata": {
620 | "kernelspec": {
621 | "display_name": "Python 3",
622 | "language": "python",
623 | "name": "python3"
624 | },
625 | "language_info": {
626 | "codemirror_mode": {
627 | "name": "ipython",
628 | "version": 3
629 | },
630 | "file_extension": ".py",
631 | "mimetype": "text/x-python",
632 | "name": "python",
633 | "nbconvert_exporter": "python",
634 | "pygments_lexer": "ipython3",
635 | "version": "3.7.3"
636 | },
637 | "latex_envs": {
638 | "LaTeX_envs_menu_present": true,
639 | "autoclose": false,
640 | "autocomplete": true,
641 | "bibliofile": "biblio.bib",
642 | "cite_by": "apalike",
643 | "current_citInitial": 1,
644 | "eqLabelWithNumbers": true,
645 | "eqNumInitial": 1,
646 | "hotkeys": {
647 | "equation": "Ctrl-E",
648 | "itemize": "Ctrl-I"
649 | },
650 | "labels_anchors": false,
651 | "latex_user_defs": false,
652 | "report_style_numbering": false,
653 | "user_envs_cfg": false
654 | }
655 | },
656 | "nbformat": 4,
657 | "nbformat_minor": 2
658 | }
659 |
--------------------------------------------------------------------------------
/DREAMER/DREAMER_SVM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from sklearn.svm import SVC\n",
10 | "import pandas as pd\n",
11 | "from sklearn.utils import shuffle\n",
12 | "import pickle"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "df = pd.read_csv(\"DREAMER_combined.csv\", index_col=0)"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "df_train = df.drop(['Movie', 'Person', 'Arousal','Dominance', 'Valence'], axis=1)\n",
31 | "df_target = df['Dominance']\n",
32 | "df_target = df_target.replace({1:0,2:1,3:2,4:3,5:4})"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "clf = SVC()"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {
48 | "scrolled": true
49 | },
50 | "outputs": [],
51 | "source": [
52 | "%%time\n",
53 | "clf.fit(df_train, df_target)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "filename = 'DREAMER_SVM.sav'\n",
63 | "pickle.dump(model, open(filename, 'wb'))"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "%%time\n",
73 | "clf.predict(df_train[:64])"
74 | ]
75 | }
76 | ],
77 | "metadata": {
78 | "kernelspec": {
79 | "display_name": "Python 3",
80 | "language": "python",
81 | "name": "python3"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.7.3"
94 | }
95 | },
96 | "nbformat": 4,
97 | "nbformat_minor": 2
98 | }
99 |
--------------------------------------------------------------------------------
/DREAMER/DREAMER_Valence_GRU_48_16.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# DREAMER Valence EMI-GRU 48_16"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "import numpy as np\n",
25 | "from tabulate import tabulate\n",
26 | "import os\n",
27 | "import datetime as datetime\n",
28 | "import pickle as pkl\n",
29 | "import pathlib"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {
36 | "ExecuteTime": {
37 | "end_time": "2018-12-14T14:17:51.796585Z",
38 | "start_time": "2018-12-14T14:17:49.648375Z"
39 | }
40 | },
41 | "outputs": [
42 | {
43 | "name": "stderr",
44 | "output_type": "stream",
45 | "text": [
46 | "Using TensorFlow backend.\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "from __future__ import print_function\n",
52 | "import os\n",
53 | "import sys\n",
54 | "import tensorflow as tf\n",
55 | "import numpy as np\n",
56 | "# Making sure edgeml is part of python path\n",
57 | "sys.path.insert(0, '../../')\n",
58 | "#For processing on CPU.\n",
59 | "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
60 | "\n",
61 | "np.random.seed(42)\n",
62 | "tf.set_random_seed(42)\n",
63 | "\n",
64 | "# MI-RNN and EMI-RNN imports\n",
65 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
66 | "from edgeml.graph.rnn import EMI_GRU\n",
67 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
68 | "import edgeml.utils\n",
69 | "\n",
70 | "import keras.backend as K\n",
71 | "cfg = K.tf.ConfigProto()\n",
72 | "cfg.gpu_options.allow_growth = True\n",
73 | "K.set_session(K.tf.Session(config=cfg))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 3,
79 | "metadata": {
80 | "ExecuteTime": {
81 | "end_time": "2018-12-14T14:17:51.803381Z",
82 | "start_time": "2018-12-14T14:17:51.798799Z"
83 | }
84 | },
85 | "outputs": [],
86 | "source": [
87 | "# Network parameters for our LSTM + FC Layer\n",
88 | "NUM_HIDDEN = 128\n",
89 | "NUM_TIMESTEPS = 48\n",
90 | "ORIGINAL_NUM_TIMESTEPS = 128\n",
91 | "NUM_FEATS = 16\n",
92 | "FORGET_BIAS = 1.0\n",
93 | "NUM_OUTPUT = 5\n",
94 | "USE_DROPOUT = True\n",
95 | "KEEP_PROB = 0.75\n",
96 | "\n",
97 | "# For dataset API\n",
98 | "PREFETCH_NUM = 5\n",
99 | "BATCH_SIZE = 32\n",
100 | "\n",
101 | "# Number of epochs in *one iteration*\n",
102 | "NUM_EPOCHS = 2\n",
103 | "# Number of iterations in *one round*. After each iteration,\n",
104 | "# the model is dumped to disk. At the end of the current\n",
105 | "# round, the best model among all the dumped models in the\n",
106 | "# current round is picked up..\n",
107 | "NUM_ITER = 4\n",
108 | "# A round consists of multiple training iterations and a belief\n",
109 | "# update step using the best model from all of these iterations\n",
110 | "NUM_ROUNDS = 10\n",
111 | "LEARNING_RATE=0.001\n",
112 | "\n",
113 | "# A staging direcory to store models\n",
114 | "MODEL_PREFIX = '/home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru'"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {
120 | "heading_collapsed": true
121 | },
122 | "source": [
123 | "# Loading Data"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 4,
129 | "metadata": {
130 | "ExecuteTime": {
131 | "end_time": "2018-12-14T14:17:52.040352Z",
132 | "start_time": "2018-12-14T14:17:51.805319Z"
133 | },
134 | "hidden": true
135 | },
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "x_train shape is: (61735, 6, 48, 16)\n",
142 | "y_train shape is: (61735, 6, 5)\n",
143 | "x_test shape is: (6860, 6, 48, 16)\n",
144 | "y_test shape is: (6860, 6, 5)\n"
145 | ]
146 | }
147 | ],
148 | "source": [
149 | "# Loading the data\n",
150 | "path='/home/sf/data/DREAMER/Valence/Fast_GRNN/48_16/'\n",
151 | "x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
152 | "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
153 | "x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
154 | "\n",
155 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
156 | "# step of EMI/MI RNN\n",
157 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
158 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
159 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
160 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
161 | "print(\"x_train shape is:\", x_train.shape)\n",
162 | "print(\"y_train shape is:\", y_train.shape)\n",
163 | "print(\"x_test shape is:\", x_val.shape)\n",
164 | "print(\"y_test shape is:\", y_val.shape)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "# Computation Graph"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 5,
177 | "metadata": {
178 | "ExecuteTime": {
179 | "end_time": "2018-12-14T14:17:52.053161Z",
180 | "start_time": "2018-12-14T14:17:52.042928Z"
181 | }
182 | },
183 | "outputs": [],
184 | "source": [
185 | "# Define the linear secondary classifier\n",
186 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
187 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
188 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
189 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
190 | " self.output = y_cap\n",
191 | " self.graphCreated = True\n",
192 | "\n",
193 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
194 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
195 | " self.output = y_cap\n",
196 | " self.graphCreated = True\n",
197 | " \n",
198 | "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
199 | " if inference is False:\n",
200 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
201 | " else:\n",
202 | " feedDict = {self._emiGraph.keep_prob: 1.0}\n",
203 | " return feedDict\n",
204 | " \n",
205 | "EMI_GRU._createExtendedGraph = createExtendedGraph\n",
206 | "EMI_GRU._restoreExtendedGraph = restoreExtendedGraph\n",
207 | "\n",
208 | "if USE_DROPOUT is True:\n",
209 | " EMI_Driver.feedDictFunc = feedDictFunc"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 6,
215 | "metadata": {
216 | "ExecuteTime": {
217 | "end_time": "2018-12-14T14:17:52.335299Z",
218 | "start_time": "2018-12-14T14:17:52.055483Z"
219 | }
220 | },
221 | "outputs": [],
222 | "source": [
223 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
224 | "emiGRU = EMI_GRU(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
225 | " useDropout=USE_DROPOUT)\n",
226 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
227 | " stepSize=LEARNING_RATE)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 7,
233 | "metadata": {
234 | "ExecuteTime": {
235 | "end_time": "2018-12-14T14:18:05.031382Z",
236 | "start_time": "2018-12-14T14:17:52.338750Z"
237 | }
238 | },
239 | "outputs": [],
240 | "source": [
241 | "tf.reset_default_graph()\n",
242 | "g1 = tf.Graph() \n",
243 | "with g1.as_default():\n",
244 | " # Obtain the iterators to each batch of the data\n",
245 | " x_batch, y_batch = inputPipeline()\n",
246 | " # Create the forward computation graph based on the iterators\n",
247 | " y_cap = emiGRU(x_batch)\n",
248 | " # Create loss graphs and training routines\n",
249 | " emiTrainer(y_cap, y_batch)"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {},
255 | "source": [
256 | "# EMI Driver"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 8,
262 | "metadata": {
263 | "ExecuteTime": {
264 | "end_time": "2018-12-14T14:35:15.209910Z",
265 | "start_time": "2018-12-14T14:18:05.034359Z"
266 | },
267 | "scrolled": true
268 | },
269 | "outputs": [
270 | {
271 | "name": "stdout",
272 | "output_type": "stream",
273 | "text": [
274 | "Update policy: top-k\n",
275 | "Training with MI-RNN loss for 5 rounds\n",
276 | "Round: 0\n",
277 | "Epoch 1 Batch 1925 ( 3855) Loss 0.03087 Acc 0.32292 | Val acc 0.30146 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1000\n",
278 | "Epoch 1 Batch 1925 ( 3855) Loss 0.03040 Acc 0.39062 | Val acc 0.34198 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1001\n",
279 | "Epoch 1 Batch 1925 ( 3855) Loss 0.03005 Acc 0.36979 | Val acc 0.37318 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1002\n",
280 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02878 Acc 0.41146 | Val acc 0.40292 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1003\n",
281 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1003\n",
282 | "Round: 1\n",
283 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02790 Acc 0.40104 | Val acc 0.43105 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1004\n",
284 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02739 Acc 0.40625 | Val acc 0.46633 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1005\n",
285 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02682 Acc 0.40625 | Val acc 0.49417 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1006\n",
286 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02684 Acc 0.45833 | Val acc 0.52172 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1007\n",
287 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1007\n",
288 | "Round: 2\n",
289 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02423 Acc 0.47396 | Val acc 0.54111 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1008\n",
290 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02415 Acc 0.46354 | Val acc 0.55758 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1009\n",
291 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02348 Acc 0.50521 | Val acc 0.56545 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1010\n",
292 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02517 Acc 0.44271 | Val acc 0.57638 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1011\n",
293 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1011\n",
294 | "Round: 3\n",
295 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02395 Acc 0.47396 | Val acc 0.57857 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1012\n",
296 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02223 Acc 0.52083 | Val acc 0.58280 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1013\n",
297 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02149 Acc 0.51562 | Val acc 0.59125 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1014\n",
298 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02089 Acc 0.52083 | Val acc 0.59475 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1015\n",
299 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1015\n",
300 | "Round: 4\n",
301 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02177 Acc 0.53125 | Val acc 0.59650 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1016\n",
302 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02202 Acc 0.50000 | Val acc 0.59927 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1017\n",
303 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02128 Acc 0.51562 | Val acc 0.60087 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1018\n",
304 | "Epoch 1 Batch 1925 ( 3855) Loss 0.02187 Acc 0.53646 | Val acc 0.60394 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1019\n",
305 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1019\n",
306 | "Round: 5\n",
307 | "Switching to EMI-Loss function\n",
308 | "Epoch 1 Batch 1925 ( 3855) Loss 1.22457 Acc 0.50000 | Val acc 0.58630 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1020\n",
309 | "Epoch 1 Batch 1925 ( 3855) Loss 1.18352 Acc 0.48958 | Val acc 0.60481 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1021\n",
310 | "Epoch 1 Batch 1925 ( 3855) Loss 1.21332 Acc 0.48438 | Val acc 0.61093 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1022\n",
311 | "Epoch 1 Batch 1925 ( 3855) Loss 1.19744 Acc 0.53646 | Val acc 0.61414 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1023\n",
312 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1023\n",
313 | "Round: 6\n",
314 | "Epoch 1 Batch 1925 ( 3855) Loss 1.17704 Acc 0.52604 | Val acc 0.61414 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1024\n",
315 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13723 Acc 0.56250 | Val acc 0.60918 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1025\n",
316 | "Epoch 1 Batch 1925 ( 3855) Loss 1.16212 Acc 0.51042 | Val acc 0.62143 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1026\n",
317 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13103 Acc 0.55208 | Val acc 0.61545 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1027\n",
318 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1026\n",
319 | "Round: 7\n",
320 | "Epoch 1 Batch 1925 ( 3855) Loss 1.18400 Acc 0.47396 | Val acc 0.61837 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1028\n",
321 | "Epoch 1 Batch 1925 ( 3855) Loss 1.10993 Acc 0.51042 | Val acc 0.62041 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1029\n",
322 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13089 Acc 0.54167 | Val acc 0.61749 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1030\n",
323 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13852 Acc 0.49479 | Val acc 0.62128 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1031\n",
324 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1031\n",
325 | "Round: 8\n",
326 | "Epoch 1 Batch 1925 ( 3855) Loss 1.12852 Acc 0.53125 | Val acc 0.62493 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1032\n",
327 | "Epoch 1 Batch 1925 ( 3855) Loss 1.12762 Acc 0.53125 | Val acc 0.61997 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1033\n",
328 | "Epoch 1 Batch 1925 ( 3855) Loss 1.12535 Acc 0.50521 | Val acc 0.62245 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1034\n",
329 | "Epoch 1 Batch 1925 ( 3855) Loss 1.11068 Acc 0.55208 | Val acc 0.63134 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1035\n",
330 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1035\n",
331 | "Round: 9\n",
332 | "Epoch 1 Batch 1925 ( 3855) Loss 1.12000 Acc 0.48958 | Val acc 0.63017 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1036\n",
333 | "Epoch 1 Batch 1925 ( 3855) Loss 1.11178 Acc 0.48438 | Val acc 0.63265 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1037\n",
334 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13696 Acc 0.53125 | Val acc 0.62682 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1038\n",
335 | "Epoch 1 Batch 1925 ( 3855) Loss 1.13168 Acc 0.54688 | Val acc 0.62464 | Model saved to /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru, global_step 1039\n",
336 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1037\n"
337 | ]
338 | }
339 | ],
340 | "source": [
341 | "with g1.as_default():\n",
342 | " emiDriver = EMI_Driver(inputPipeline, emiGRU, emiTrainer)\n",
343 | "\n",
344 | "emiDriver.initializeSession(g1)\n",
345 | "y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
346 | " y_train=y_train, bag_train=BAG_TRAIN,\n",
347 | " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
348 | " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
349 | " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
350 | " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
351 | " fracEMI=0.5, updatePolicy='top-k', k=1)"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "# Evaluating the trained model"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 9,
364 | "metadata": {
365 | "ExecuteTime": {
366 | "end_time": "2018-12-14T14:35:15.218040Z",
367 | "start_time": "2018-12-14T14:35:15.211771Z"
368 | }
369 | },
370 | "outputs": [],
371 | "source": [
372 | "# Early Prediction Policy: We make an early prediction based on the predicted classes\n",
373 | "# probability. If the predicted class probability > minProb at some step, we make\n",
374 | "# a prediction at that step.\n",
375 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
376 | " assert instanceOut.ndim == 2\n",
377 | " classes = np.argmax(instanceOut, axis=1)\n",
378 | " prob = np.max(instanceOut, axis=1)\n",
379 | " index = np.where(prob >= minProb)[0]\n",
380 | " if len(index) == 0:\n",
381 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
382 | " return classes[-1], len(instanceOut) - 1\n",
383 | " index = index[0]\n",
384 | " return classes[index], index\n",
385 | "\n",
386 | "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
387 | " predictionStep = predictionStep + 1\n",
388 | " predictionStep = np.reshape(predictionStep, -1)\n",
389 | " totalSteps = np.sum(predictionStep)\n",
390 | " maxSteps = len(predictionStep) * numTimeSteps\n",
391 | " savings = 1.0 - (totalSteps / maxSteps)\n",
392 | " if returnTotal:\n",
393 | " return savings, totalSteps\n",
394 | " return savings"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 10,
400 | "metadata": {
401 | "ExecuteTime": {
402 | "end_time": "2018-12-14T14:35:16.257489Z",
403 | "start_time": "2018-12-14T14:35:15.221029Z"
404 | }
405 | },
406 | "outputs": [
407 | {
408 | "name": "stdout",
409 | "output_type": "stream",
410 | "text": [
411 | "Accuracy at k = 2: 0.630941\n",
412 | "Savings due to MI-RNN : 0.625000\n",
413 | "Savings due to Early prediction: 0.118245\n",
414 | "Total Savings: 0.669342\n"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "k = 2\n",
420 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
421 | " minProb=0.99, keep_prob=1.0)\n",
422 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
423 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
424 | "mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
425 | "emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
426 | "total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
427 | "print('Savings due to MI-RNN : %f' % mi_savings)\n",
428 | "print('Savings due to Early prediction: %f' % emi_savings)\n",
429 | "print('Total Savings: %f' % (total_savings))"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 11,
435 | "metadata": {
436 | "ExecuteTime": {
437 | "end_time": "2018-12-14T14:35:17.044115Z",
438 | "start_time": "2018-12-14T14:35:16.259280Z"
439 | },
440 | "scrolled": false
441 | },
442 | "outputs": [
443 | {
444 | "name": "stdout",
445 | "output_type": "stream",
446 | "text": [
447 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
448 | "0 1 0.612980 0.611018 0.657280 0.610441 0.612980 0.612980 \n",
449 | "1 2 0.630941 0.631356 0.640975 0.628401 0.630941 0.630941 \n",
450 | "2 3 0.606158 0.608647 0.637136 0.603571 0.606158 0.606158 \n",
451 | "3 4 0.550061 0.560872 0.669536 0.547729 0.550061 0.550061 \n",
452 | "4 5 0.506035 0.520660 0.712413 0.504197 0.506035 0.506035 \n",
453 | "5 6 0.469532 0.482675 0.743530 0.468139 0.469532 0.469532 \n",
454 | "\n",
455 | " micro-rec \n",
456 | "0 0.612980 \n",
457 | "1 0.630941 \n",
458 | "2 0.606158 \n",
459 | "3 0.550061 \n",
460 | "4 0.506035 \n",
461 | "5 0.469532 \n",
462 | "Max accuracy 0.630941 at subsequencelength 2\n",
463 | "Max micro-f 0.630941 at subsequencelength 2\n",
464 | "Micro-precision 0.630941 at subsequencelength 2\n",
465 | "Micro-recall 0.630941 at subsequencelength 2\n",
466 | "Max macro-f 0.631356 at subsequencelength 2\n",
467 | "macro-precision 0.640975 at subsequencelength 2\n",
468 | "macro-recall 0.628401 at subsequencelength 2\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "# A slightly more detailed analysis method is provided. \n",
474 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {},
480 | "source": [
481 | "## Picking the best model"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": 12,
487 | "metadata": {
488 | "ExecuteTime": {
489 | "end_time": "2018-12-14T14:35:54.899340Z",
490 | "start_time": "2018-12-14T14:35:17.047464Z"
491 | }
492 | },
493 | "outputs": [
494 | {
495 | "name": "stdout",
496 | "output_type": "stream",
497 | "text": [
498 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1003\n",
499 | "Round: 0, Validation accuracy: 0.4029, Test Accuracy (k = 2): 0.403231, Total Savings: 0.630370\n",
500 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1007\n",
501 | "Round: 1, Validation accuracy: 0.5217, Test Accuracy (k = 2): 0.520555, Total Savings: 0.631844\n",
502 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1011\n",
503 | "Round: 2, Validation accuracy: 0.5764, Test Accuracy (k = 2): 0.577818, Total Savings: 0.635449\n",
504 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1015\n",
505 | "Round: 3, Validation accuracy: 0.5948, Test Accuracy (k = 2): 0.596186, Total Savings: 0.637888\n",
506 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1019\n",
507 | "Round: 4, Validation accuracy: 0.6039, Test Accuracy (k = 2): 0.609831, Total Savings: 0.640020\n",
508 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1023\n",
509 | "Round: 5, Validation accuracy: 0.6141, Test Accuracy (k = 2): 0.612572, Total Savings: 0.655162\n",
510 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1026\n",
511 | "Round: 6, Validation accuracy: 0.6214, Test Accuracy (k = 2): 0.618054, Total Savings: 0.658907\n",
512 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1031\n",
513 | "Round: 7, Validation accuracy: 0.6213, Test Accuracy (k = 2): 0.621494, Total Savings: 0.662773\n",
514 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1035\n",
515 | "Round: 8, Validation accuracy: 0.6313, Test Accuracy (k = 2): 0.623652, Total Savings: 0.666211\n",
516 | "INFO:tensorflow:Restoring parameters from /home/sf/data/DREAMER/Valence/48_16/models/GRU/model-gru-1037\n",
517 | "Round: 9, Validation accuracy: 0.6327, Test Accuracy (k = 2): 0.630941, Total Savings: 0.669342\n"
518 | ]
519 | }
520 | ],
521 | "source": [
522 | "devnull = open(os.devnull, 'r')\n",
523 | "for val in modelStats:\n",
524 | " round_, acc, modelPrefix, globalStep = val\n",
525 | " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
526 | " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
527 | " minProb=0.99, keep_prob=1.0)\n",
528 | "\n",
529 | " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
530 | " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
531 | " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
532 | " mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
533 | " emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
534 | " total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
535 | " print(\"Total Savings: %f\" % total_savings)"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 13,
541 | "metadata": {},
542 | "outputs": [],
543 | "source": [
544 | "params = {\n",
545 | " \"NUM_HIDDEN\" : 128,\n",
546 | " \"NUM_TIMESTEPS\" : 48, #subinstance length.\n",
547 | " \"ORIGINAL_NUM_TIMESTEPS\" : 128,\n",
548 | " \"NUM_FEATS\" : 16,\n",
549 | " \"FORGET_BIAS\" : 1.0,\n",
550 | " \"NUM_OUTPUT\" : 5,\n",
551 | " \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
552 | " \"KEEP_PROB\" : 0.75,\n",
553 | " \"PREFETCH_NUM\" : 5,\n",
554 | " \"BATCH_SIZE\" : 32,\n",
555 | " \"NUM_EPOCHS\" : 2,\n",
556 | " \"NUM_ITER\" : 4,\n",
557 | " \"NUM_ROUNDS\" : 10,\n",
558 | " \"LEARNING_RATE\" : 0.001,\n",
559 | " \"MODEL_PREFIX\" : '/home/sf/data/DREAMER/Valence/model-gru'\n",
560 | "}"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 14,
566 | "metadata": {},
567 | "outputs": [
568 | {
569 | "name": "stdout",
570 | "output_type": "stream",
571 | "text": [
572 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
573 | "0 1 0.612980 0.611018 0.657280 0.610441 0.612980 0.612980 \n",
574 | "1 2 0.630941 0.631356 0.640975 0.628401 0.630941 0.630941 \n",
575 | "2 3 0.606158 0.608647 0.637136 0.603571 0.606158 0.606158 \n",
576 | "3 4 0.550061 0.560872 0.669536 0.547729 0.550061 0.550061 \n",
577 | "4 5 0.506035 0.520660 0.712413 0.504197 0.506035 0.506035 \n",
578 | "5 6 0.469532 0.482675 0.743530 0.468139 0.469532 0.469532 \n",
579 | "\n",
580 | " micro-rec \n",
581 | "0 0.612980 \n",
582 | "1 0.630941 \n",
583 | "2 0.606158 \n",
584 | "3 0.550061 \n",
585 | "4 0.506035 \n",
586 | "5 0.469532 \n",
587 | "Max accuracy 0.630941 at subsequencelength 2\n",
588 | "Max micro-f 0.630941 at subsequencelength 2\n",
589 | "Micro-precision 0.630941 at subsequencelength 2\n",
590 | "Micro-recall 0.630941 at subsequencelength 2\n",
591 | "Max macro-f 0.631356 at subsequencelength 2\n",
592 | "macro-precision 0.640975 at subsequencelength 2\n",
593 | "macro-recall 0.628401 at subsequencelength 2\n",
594 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
595 | "| | len | acc | macro-fsc | macro-pre | macro-rec | micro-fsc | micro-pre | micro-rec |\n",
596 | "+====+=======+==========+=============+=============+=============+=============+=============+=============+\n",
597 | "| 0 | 1 | 0.61298 | 0.611018 | 0.65728 | 0.610441 | 0.61298 | 0.61298 | 0.61298 |\n",
598 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
599 | "| 1 | 2 | 0.630941 | 0.631356 | 0.640975 | 0.628401 | 0.630941 | 0.630941 | 0.630941 |\n",
600 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
601 | "| 2 | 3 | 0.606158 | 0.608647 | 0.637136 | 0.603571 | 0.606158 | 0.606158 | 0.606158 |\n",
602 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
603 | "| 3 | 4 | 0.550061 | 0.560872 | 0.669536 | 0.547729 | 0.550061 | 0.550061 | 0.550061 |\n",
604 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
605 | "| 4 | 5 | 0.506035 | 0.52066 | 0.712413 | 0.504197 | 0.506035 | 0.506035 | 0.506035 |\n",
606 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
607 | "| 5 | 6 | 0.469532 | 0.482675 | 0.74353 | 0.468139 | 0.469532 | 0.469532 | 0.469532 |\n",
608 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n"
609 | ]
610 | }
611 | ],
612 | "source": [
613 | "gru_dict = {**params}\n",
614 | "gru_dict[\"k\"] = k\n",
615 | "gru_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
616 | "gru_dict[\"total_savings\"] = total_savings\n",
617 | "gru_dict[\"y_test\"] = BAG_TEST\n",
618 | "gru_dict[\"y_pred\"] = bagPredictions\n",
619 | "\n",
620 | "# A slightly more detailed analysis method is provided. \n",
621 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
622 | "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))"
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 15,
628 | "metadata": {},
629 | "outputs": [
630 | {
631 | "name": "stdout",
632 | "output_type": "stream",
633 | "text": [
634 | "Results for this run have been saved at home/sf/data/DREAMER/Valence/GRU/ .\n"
635 | ]
636 | }
637 | ],
638 | "source": [
639 | "dirname = \"home/sf/data/DREAMER/Valence/GRU/\"\n",
640 | "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
641 | "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
642 | "\n",
643 | "now = datetime.datetime.now()\n",
644 | "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
645 | "filename = ''.join(filename)\n",
646 | "\n",
647 | "#Save the dictionary containing the params and the results.\n",
648 | "pkl.dump(gru_dict,open(dirname + filename + \".pkl\",mode='wb'))"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 16,
654 | "metadata": {},
655 | "outputs": [
656 | {
657 | "data": {
658 | "text/plain": [
659 | "'home/sf/data/DREAMER/Valence/GRU/2019-8-16|3-1.pkl'"
660 | ]
661 | },
662 | "execution_count": 16,
663 | "metadata": {},
664 | "output_type": "execute_result"
665 | }
666 | ],
667 | "source": [
668 | "dirname+filename+'.pkl'"
669 | ]
670 | }
671 | ],
672 | "metadata": {
673 | "kernelspec": {
674 | "display_name": "Python 3",
675 | "language": "python",
676 | "name": "python3"
677 | },
678 | "language_info": {
679 | "codemirror_mode": {
680 | "name": "ipython",
681 | "version": 3
682 | },
683 | "file_extension": ".py",
684 | "mimetype": "text/x-python",
685 | "name": "python",
686 | "nbconvert_exporter": "python",
687 | "pygments_lexer": "ipython3",
688 | "version": "3.7.3"
689 | },
690 | "latex_envs": {
691 | "LaTeX_envs_menu_present": true,
692 | "autoclose": false,
693 | "autocomplete": true,
694 | "bibliofile": "biblio.bib",
695 | "cite_by": "apalike",
696 | "current_citInitial": 1,
697 | "eqLabelWithNumbers": true,
698 | "eqNumInitial": 1,
699 | "hotkeys": {
700 | "equation": "Ctrl-E",
701 | "itemize": "Ctrl-I"
702 | },
703 | "labels_anchors": false,
704 | "latex_user_defs": false,
705 | "report_style_numbering": false,
706 | "user_envs_cfg": false
707 | }
708 | },
709 | "nbformat": 4,
710 | "nbformat_minor": 2
711 | }
712 |
--------------------------------------------------------------------------------
/ICMLA_Presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nandahkrishna/StressAffectDetection/1482e2dea6060d3bccdcbc5a70a286b2cb11bb58/ICMLA_Presentation.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Abhijith Ragav, Nanda H Krishna, Naveen Narayanan, Kevin Thelly, Vineeth Vijayaraghavan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | The code in this repository incorporates materials from https://github.com/microsoft/EdgeML which has been licensed under the MIT License (c) Microsoft Corporation.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stress and Affect Detection on Resource-Constrained Devices
2 |
3 | This repository contains code for the paper "Scalable Deep Learning for Stress and Affect Detection on Resource-Constrained Devices", presented at the 18th IEEE International Conference on Machine Learning and Applications (ICMLA 2019), Boca Raton, FL, USA.
4 |
5 | In the paper, we present an efficient and accurate approach to stress and affect detection on resource-constrained devices. The datasets used are:
6 | * WESAD
7 | * SWELL-KW
8 | * DREAMER
9 |
10 | The paper was published in the Proceedings of the 2019 18th IEEE International Conference on Machine Learning and Applications (ICMLA) and is available on IEEE Xplore.
11 |
12 | Link: https://ieeexplore.ieee.org/document/8999216
13 |
14 | Citation:
15 |
16 | ```
17 | @inproceedings{8999216,
18 | author={A. {Ragav} and N. H. {Krishna} and N. {Narayanan} and K. {Thelly} and V. {Vijayaraghavan}},
19 | booktitle={2019 18th IEEE International Conference On Machine Learning And Applications (ICMLA)},
20 | title={Scalable Deep Learning for Stress and Affect Detection on Resource-Constrained Devices},
21 | year={2019},
22 | volume={},
23 | number={},
24 | pages={1585-1592},
25 | }
26 | ```
27 |
28 | Authors: [Abhijith Ragav](https://github.com/abhijithragav), [Nanda H Krishna](https://github.com/nandahkrishna), [Naveen Narayanan](https://github.com/naveenggmu), [Kevin Thelly](https://github.com/KevinThelly), Vineeth Vijayaraghavan
29 |
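30 | ## Results
31 | 
32 | Each training notebook dumps its hyperparameters, test labels and predictions to a pickled dictionary (keys such as `accuracy`, `y_test` and `y_pred`). A minimal sketch of loading one of these files for inspection follows; the path below is only an example, since the notebooks write to machine-specific directories and name the files by timestamp (e.g. `DREAMER/Valence/GRU/2019-8-16|3-1.pkl`):
33 | 
34 | ```python
35 | import pickle as pkl
36 | 
37 | # Hypothetical path: replace with the directory and timestamped file
38 | # written by the notebook you ran.
39 | with open("path/to/results.pkl", "rb") as f:
40 |     results = pkl.load(f)
41 | 
42 | print(results["accuracy"])     # bag-level test accuracy
43 | print(results["y_pred"][:10])  # first few bag-level predictions
44 | ```
45 | 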
--------------------------------------------------------------------------------
/SWELL-KW/SWELL-KW_Scores.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SWELL-KW Scores"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {
21 | "ExecuteTime": {
22 | "end_time": "2019-04-30T09:51:58.751435Z",
23 | "start_time": "2019-04-30T09:51:57.442626Z"
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "from __future__ import print_function\n",
29 | "import os\n",
30 | "import sys\n",
31 | "import tensorflow as tf\n",
32 | "import numpy as np\n",
33 | "# To include edgeml in python path\n",
34 | "sys.path.insert(0, '../../')\n",
35 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
36 | "\n",
37 | "# MI-RNN and EMI-RNN imports\n",
38 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
39 | "from edgeml.graph.rnn import EMI_FastGRNN,EMI_FastRNN\n",
40 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
41 | "import edgeml.utils"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {
48 | "ExecuteTime": {
49 | "end_time": "2019-07-28T09:37:09.981820Z",
50 | "start_time": "2019-07-28T09:37:09.975515Z"
51 | }
52 | },
53 | "outputs": [],
54 | "source": [
55 | "NUM_HIDDEN = 128\n",
56 | "NUM_TIMESTEPS = 30\n",
57 | "NUM_FEATS = 22\n",
58 | "FORGET_BIAS = 1.0\n",
59 | "NUM_OUTPUT = 3\n",
60 | "USE_DROPOUT = 0\n",
61 | "\n",
62 | "KEEP_PROB = 0.9\n",
63 | "UPDATE_NL = \"quantTanh\"\n",
64 | "GATE_NL = \"quantSigm\"\n",
65 | "WRANK = 5\n",
66 | "URANK = 6\n",
67 | "PREFETCH_NUM = 5\n",
68 | "BATCH_SIZE = 32\n",
69 | "NUM_EPOCHS = 50\n",
70 | "NUM_ITER = 4\n",
71 | "NUM_ROUNDS = 2\n",
72 | "\n",
73 | "# A staging directory to store models\n",
74 | "MODEL_PREFIX = '/home/sf/data/SWELL-KW/FGModels_30_10/'"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "# Loading Data"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {
88 | "ExecuteTime": {
89 | "end_time": "2019-04-30T09:52:00.022110Z",
90 | "start_time": "2019-04-30T09:51:59.925101Z"
91 | }
92 | },
93 | "outputs": [],
94 | "source": [
95 | "path=\"/home/sf/data/SWELL-KW/30_10/\"\n",
96 | "x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
97 | "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
98 | "x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
99 | "\n",
100 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
101 | "# step of EMI/MI RNN\n",
102 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
103 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
104 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
105 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
106 | "print(\"x_train shape is:\", x_train.shape)\n",
107 | "print(\"y_train shape is:\", y_train.shape)\n",
108 | "print(\"x_val shape is:\", x_val.shape)\n",
109 | "print(\"y_val shape is:\", y_val.shape)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "ExecuteTime": {
117 | "end_time": "2019-04-30T09:52:00.604049Z",
118 | "start_time": "2019-04-30T09:52:00.589634Z"
119 | }
120 | },
121 | "outputs": [],
122 | "source": [
123 | "# Define the linear secondary classifier\n",
124 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
125 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
126 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
127 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
128 | " self.output = y_cap\n",
129 | " self.graphCreated = True\n",
130 | " \n",
131 | "def addExtendedAssignOps(self, graph, W_val=None, B_val=None):\n",
132 | " W1 = graph.get_tensor_by_name('W1:0')\n",
133 | " B1 = graph.get_tensor_by_name('B1:0')\n",
134 | " W1_op = tf.assign(W1, W_val)\n",
135 | " B1_op = tf.assign(B1, B_val)\n",
136 | " self.assignOps.extend([W1_op, B1_op])\n",
137 | "\n",
138 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
139 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
140 | " self.output = y_cap\n",
141 | " self.graphCreated = True\n",
142 | " \n",
143 | "def feedDictFunc(self, keep_prob, **kwargs):\n",
144 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
145 | " return feedDict\n",
146 | " \n",
147 | "EMI_FastGRNN._createExtendedGraph = createExtendedGraph\n",
148 | "EMI_FastGRNN._restoreExtendedGraph = restoreExtendedGraph\n",
149 | "EMI_FastGRNN.addExtendedAssignOps = addExtendedAssignOps\n",
150 | "\n",
151 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
152 | " assert instanceOut.ndim == 2\n",
153 | " classes = np.argmax(instanceOut, axis=1)\n",
154 | " prob = np.max(instanceOut, axis=1)\n",
155 | " index = np.where(prob >= minProb)[0]\n",
156 | " if len(index) == 0:\n",
157 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
158 | " return classes[-1], len(instanceOut) - 1\n",
159 | " index = index[0]\n",
160 | " return classes[index], index\n",
161 | "\n",
162 | "\n",
163 | "if USE_DROPOUT is True:\n",
164 | " EMI_Driver.feedDictFunc = feedDictFunc"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {
170 | "ExecuteTime": {
171 | "end_time": "2018-08-19T09:34:06.288012Z",
172 | "start_time": "2018-08-19T09:34:06.285286Z"
173 | }
174 | },
175 | "source": [
176 | "## 1. Initializing a New Computation Graph"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {
183 | "ExecuteTime": {
184 | "end_time": "2019-04-30T09:52:10.701762Z",
185 | "start_time": "2019-04-30T09:52:02.074816Z"
186 | }
187 | },
188 | "outputs": [],
189 | "source": [
190 | "tf.reset_default_graph()\n",
191 | "\n",
192 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
193 | "emiLSTM = EMI_FastGRNN(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS, wRank=WRANK, uRank=URANK, \n",
194 | " gate_non_linearity=GATE_NL, update_non_linearity=UPDATE_NL, useDropout=USE_DROPOUT)\n",
195 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy')\n",
196 | "\n",
197 | "# Construct the graph\n",
198 | "g1 = tf.Graph() \n",
199 | "with g1.as_default():\n",
200 | " x_batch, y_batch = inputPipeline()\n",
201 | " y_cap = emiLSTM(x_batch)\n",
202 | " emiTrainer(y_cap, y_batch)\n",
203 | " \n",
204 | "with g1.as_default():\n",
205 | " emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {
212 | "ExecuteTime": {
213 | "end_time": "2019-04-25T06:25:20.220063Z",
214 | "start_time": "2019-04-25T06:25:19.987538Z"
215 | }
216 | },
217 | "outputs": [],
218 | "source": [
219 | "emiDriver.initializeSession(g1)\n",
220 | "#y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
221 | "# y_train=y_train, bag_train=BAG_TRAIN,\n",
222 | "# x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
223 | "# numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
224 | "# numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
225 | "# numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
226 | "# fracEMI=0.5, updatePolicy='top-k', k=1)"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "ExecuteTime": {
234 | "end_time": "2018-08-19T11:48:33.294431Z",
235 | "start_time": "2018-08-19T11:48:32.897376Z"
236 | }
237 | },
238 | "outputs": [],
239 | "source": [
240 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
241 | " assert instanceOut.ndim == 2\n",
242 | " classes = np.argmax(instanceOut, axis=1)\n",
243 | " prob = np.max(instanceOut, axis=1)\n",
244 | " index = np.where(prob >= minProb)[0]\n",
245 | " if len(index) == 0:\n",
246 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
247 | " return classes[-1], len(instanceOut) - 1\n",
248 | " index = index[0]\n",
249 | " return classes[index], index\n",
250 | "\n",
251 | "emiDriver.initializeSession(g1)\n",
252 | "\n",
253 | "k = 2\n",
254 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
255 | " minProb=0.99, keep_prob=1.0)\n",
256 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
257 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "## 2. Loading a Saved Graph into EMI-Driver"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "metadata": {
271 | "ExecuteTime": {
272 | "end_time": "2019-04-25T06:24:09.713351Z",
273 | "start_time": "2019-04-25T06:24:09.638610Z"
274 | }
275 | },
276 | "outputs": [],
277 | "source": [
278 | "tf.reset_default_graph()\n",
279 | "emiDriver.initializeSession(g1)\n",
280 | "emiDriver.loadSavedGraphToNewSession(MODEL_PREFIX , 1039)\n",
281 | "k = 1"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {
288 | "ExecuteTime": {
289 | "end_time": "2019-04-25T06:24:09.713351Z",
290 | "start_time": "2019-04-25T06:24:09.638610Z"
291 | }
292 | },
293 | "outputs": [],
294 | "source": [
295 | "%%time\n",
296 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test[:64], y_test[:64], earlyPolicy_minProb,\n",
297 | " minProb=0.99, keep_prob=1.0)\n",
298 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {
305 | "ExecuteTime": {
306 | "end_time": "2019-04-25T06:24:09.713351Z",
307 | "start_time": "2019-04-25T06:24:09.638610Z"
308 | }
309 | },
310 | "outputs": [],
311 | "source": [
312 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))"
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {},
318 | "source": [
319 | "## 3. Initializing using a Saved Graph"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {
326 | "ExecuteTime": {
327 | "end_time": "2019-04-24T12:07:09.616748Z",
328 | "start_time": "2019-04-24T12:07:09.596906Z"
329 | }
330 | },
331 | "outputs": [],
332 | "source": [
333 | "# Making sure the old graph and sessions are closed\n",
334 | "sess = emiDriver.getCurrentSession()\n",
335 | "sess.close()\n",
336 | "tf.reset_default_graph()"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "metadata": {
343 | "ExecuteTime": {
344 | "end_time": "2019-04-30T09:52:19.568739Z",
345 | "start_time": "2019-04-30T09:52:10.703663Z"
346 | }
347 | },
348 | "outputs": [],
349 | "source": [
350 | "tf.reset_default_graph()\n",
351 | "\n",
352 | "sess = tf.Session()\n",
353 | "graphManager = edgeml.utils.GraphManager()\n",
354 | "graph = graphManager.loadCheckpoint(sess, MODEL_PREFIX, globalStep=1004)"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": null,
360 | "metadata": {
361 | "ExecuteTime": {
362 | "end_time": "2019-04-30T09:52:20.570380Z",
363 | "start_time": "2019-04-30T09:52:19.571022Z"
364 | }
365 | },
366 | "outputs": [],
367 | "source": [
368 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT, graph=graph)\n",
369 | "emiLSTM = EMI_FastGRNN(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS, wRank=WRANK, uRank=URANK, \n",
370 | " gate_non_linearity=GATE_NL, update_non_linearity=UPDATE_NL, useDropout=USE_DROPOUT)\n",
371 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy', graph=graph)\n",
372 | "\n",
373 | "g1 = graph\n",
374 | "with g1.as_default():\n",
375 | " x_batch, y_batch = inputPipeline()\n",
376 | " y_cap = emiLSTM(x_batch)\n",
377 | " emiTrainer(y_cap, y_batch)\n",
378 | " \n",
379 | "with g1.as_default():\n",
380 | " emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": null,
386 | "metadata": {
387 | "ExecuteTime": {
388 | "end_time": "2019-04-30T09:52:20.574716Z",
389 | "start_time": "2019-04-30T09:52:20.572193Z"
390 | }
391 | },
392 | "outputs": [],
393 | "source": [
394 | "emiDriver.setSession(sess)"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": null,
400 | "metadata": {
401 | "ExecuteTime": {
402 | "end_time": "2019-04-30T09:52:34.913965Z",
403 | "start_time": "2019-04-30T09:52:32.795936Z"
404 | }
405 | },
406 | "outputs": [],
407 | "source": [
408 | "%%time\n",
409 | "\n",
410 | "# tf.reset_default_graph()\n",
411 | "# emiDriver.initializeSession(g1)\n",
412 | "# emiDriver.loadSavedGraphToNewSession(MODEL_PREFIX, 1007)\n",
413 | "k = 1\n",
414 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
415 | " minProb=0.99, keep_prob=1.0)\n",
416 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
417 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {
424 | "ExecuteTime": {
425 | "end_time": "2019-04-30T09:52:42.334950Z",
426 | "start_time": "2019-04-30T09:52:42.318103Z"
427 | }
428 | },
429 | "outputs": [],
430 | "source": [
431 | "x_test.shape, bagPredictions.shape,y_test.shape"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {},
437 | "source": [
438 | "## 4. Restoring from Numpy Matrices"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "metadata": {
445 | "ExecuteTime": {
446 | "end_time": "2018-08-19T11:48:44.379901Z",
447 | "start_time": "2018-08-19T11:48:44.326706Z"
448 | }
449 | },
450 | "outputs": [],
451 | "source": [
452 | "graph = tf.get_default_graph()\n",
453 | "W1 = graph.get_tensor_by_name('W1:0')\n",
454 | "B1 = graph.get_tensor_by_name('B1:0')\n",
455 | "allVars = emiLSTM.varList + [W1, B1]\n",
456 | "sess = emiDriver.getCurrentSession()\n",
457 | "allVars = sess.run(allVars)\n",
458 | "\n",
459 | "base = '/tmp/models/'\n",
460 | "np.save(base + 'kernel.npy', allVars[0])\n",
461 | "np.save(base + 'bias.npy', allVars[1])\n",
462 | "np.save(base + 'W1.npy', allVars[2])\n",
463 | "np.save(base + 'B1.npy', allVars[3])"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": null,
469 | "metadata": {
470 | "ExecuteTime": {
471 | "end_time": "2018-08-19T11:48:44.389724Z",
472 | "start_time": "2018-08-19T11:48:44.381802Z"
473 | }
474 | },
475 | "outputs": [],
476 | "source": [
477 | "sess = emiDriver.getCurrentSession()\n",
478 | "sess.close()\n",
479 | "tf.reset_default_graph()"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": null,
485 | "metadata": {
486 | "ExecuteTime": {
487 | "end_time": "2018-08-19T11:48:44.442241Z",
488 | "start_time": "2018-08-19T11:48:44.391384Z"
489 | }
490 | },
491 | "outputs": [],
492 | "source": [
493 | "base = '/home/iot/Documents/EdgeML-master/tf/examples/EMI-RNN/GRNN model/'\n",
494 | "kernel = np.load(base + 'kernel.npy')\n",
495 | "bias = np.load(base + 'bias.npy')\n",
496 | "W = np.load(base + 'W1.npy')\n",
497 | "B = np.load(base + 'B1.npy')"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {
504 | "ExecuteTime": {
505 | "end_time": "2018-08-19T11:48:51.378377Z",
506 | "start_time": "2018-08-19T11:48:44.444182Z"
507 | }
508 | },
509 | "outputs": [],
510 | "source": [
511 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS,\n",
512 | " NUM_OUTPUT)\n",
513 | "emiLSTM = EMI_FastGRNN(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS, wRank=WRANK, uRank=URANK,\n",
514 | "                       gate_non_linearity=GATE_NL, update_non_linearity=UPDATE_NL, useDropout=USE_DROPOUT)\n",
515 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy')\n",
516 | "\n",
517 | "tf.reset_default_graph()\n",
518 | "graph = tf.Graph()\n",
519 | "\n",
520 | "with graph.as_default():\n",
521 | " x_batch, y_batch = inputPipeline()\n",
522 | " y_cap = emiLSTM(x_batch)\n",
523 | " emiTrainer(y_cap, y_batch)\n",
524 | " # Add the assignment operations\n",
525 | " emiLSTM.addBaseAssignOps(graph, [kernel, bias])\n",
526 | " emiLSTM.addExtendedAssignOps(graph, W, B)\n",
527 | " # Setup the driver. You can run the initializations manually as well\n",
528 | " emiDriver = EMI_Driver(inputPipeline, emiLSTM, emiTrainer)\n",
529 | "\n",
530 | "emiDriver.initializeSession(graph)\n",
531 | "# Run the assignment operations\n",
532 | "sess = emiDriver.getCurrentSession()\n",
533 | "sess.run(emiLSTM.assignOps)\n",
534 | "\n",
535 | "k = 2\n",
536 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test,\n",
537 | " earlyPolicy_minProb,\n",
538 | " minProb=0.99,\n",
539 | " keep_prob=1.0)\n",
540 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k,\n",
541 | " numClass=NUM_OUTPUT)\n",
542 | "print('PART IV: Accuracy at k = %d: %f' % (k, np.mean((bagPredictions ==\n",
543 | " BAG_TEST).astype(int))))"
544 | ]
545 | }
546 | ],
547 | "metadata": {
548 | "kernelspec": {
549 | "display_name": "Python 3",
550 | "language": "python",
551 | "name": "python3"
552 | },
553 | "language_info": {
554 | "codemirror_mode": {
555 | "name": "ipython",
556 | "version": 3
557 | },
558 | "file_extension": ".py",
559 | "mimetype": "text/x-python",
560 | "name": "python",
561 | "nbconvert_exporter": "python",
562 | "pygments_lexer": "ipython3",
563 | "version": "3.7.3"
564 | },
565 | "latex_envs": {
566 | "LaTeX_envs_menu_present": true,
567 | "autoclose": false,
568 | "autocomplete": true,
569 | "bibliofile": "biblio.bib",
570 | "cite_by": "apalike",
571 | "current_citInitial": 1,
572 | "eqLabelWithNumbers": true,
573 | "eqNumInitial": 1,
574 | "hotkeys": {
575 | "equation": "Ctrl-E",
576 | "itemize": "Ctrl-I"
577 | },
578 | "labels_anchors": false,
579 | "latex_user_defs": false,
580 | "report_style_numbering": false,
581 | "user_envs_cfg": false
582 | }
583 | },
584 | "nbformat": 4,
585 | "nbformat_minor": 2
586 | }
587 |
--------------------------------------------------------------------------------
/WESAD/WESAD_Analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# WESAD Dataset Analysis"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import matplotlib as mpl\n",
18 | "import matplotlib.pyplot as plt\n",
19 | "import seaborn as sns\n",
20 | "import pandas as pd"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "df = pd.read_csv(\"../../WESAD/allchest.csv\")"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "df = df[df['ID'] == 2]"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "df.reset_index(inplace=True, drop=True)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "df['label'].unique()"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "sns.set_context(\"paper\", rc={\"lines.linewidth\": 2.5})\n",
66 | "sns.set_palette(\"binary_d\")"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# Neutral\n",
76 | "sns.lineplot(data=df[df['label'] == 1].reset_index(drop=True)['chestResp'])\n",
77 | "plt.ylabel('RESP')\n",
78 | "plt.xlabel('Sequential Data-Points')\n",
79 | "plt.show()"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {
86 | "scrolled": false
87 | },
88 | "outputs": [],
89 | "source": [
90 | "# Stress\n",
91 | "sns.lineplot(data=df[df['label'] == 2].reset_index(drop=True)['chestResp'])\n",
92 | "plt.ylabel('RESP')\n",
93 | "plt.xlabel('Sequential Data-Points')\n",
94 | "plt.show()"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "# Neutral\n",
104 | "sns.lineplot(data=df[df['label'] == 1].reset_index(drop=True).iloc[400000:420000]['chestResp'])\n",
105 | "plt.ylabel('RESP')\n",
106 | "plt.xlabel('Sequential Data-Points')\n",
107 | "plt.show()"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "# Stress\n",
117 | "sns.lineplot(data=df[df['label'] == 2].reset_index(drop=True).iloc[50000:70000]['chestResp'])\n",
118 | "plt.ylabel('RESP')\n",
119 | "plt.xlabel('Sequential Data-Points')\n",
120 | "plt.show()"
121 | ]
122 | }
123 | ],
124 | "metadata": {
125 | "kernelspec": {
126 | "display_name": "Python 3",
127 | "language": "python",
128 | "name": "python3"
129 | },
130 | "language_info": {
131 | "codemirror_mode": {
132 | "name": "ipython",
133 | "version": 3
134 | },
135 | "file_extension": ".py",
136 | "mimetype": "text/x-python",
137 | "name": "python",
138 | "nbconvert_exporter": "python",
139 | "pygments_lexer": "ipython3",
140 | "version": "3.7.3"
141 | }
142 | },
143 | "nbformat": 4,
144 | "nbformat_minor": 2
145 | }
146 |
--------------------------------------------------------------------------------
/WESAD/WESAD_Data_Numpy.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.utils import shuffle
4 |
5 | df = pd.read_csv("../allchest.csv")
6 | df_list = list(map(lambda x: df[df['ID'] == x], [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17]))
7 | df_list = list(map(lambda x: x[x['label'] <= 4], df_list))
8 | del df
9 |
10 | train = []
11 | test = []
12 |
13 | for df in df_list:
14 | df.reset_index(inplace=True, drop=True)
15 | df = shuffle(df, random_state=42)
16 | df.reset_index(inplace=True, drop=True)
17 | train.append(df.iloc[:int(0.6 * df.shape[0]), :])
18 | test.append(df.iloc[int(0.6 * df.shape[0]):, :])
19 | del df_list
20 |
21 | train = pd.concat(train, axis=0)
22 | test = pd.concat(test, axis=0)
23 | train.reset_index(inplace=True, drop=True)
24 | test.reset_index(inplace=True, drop=True)
25 | train = train[['label', 'ID', 'chestACCx', 'chestACCy', 'chestACCz', 'chestECG', 'chestEMG', 'chestEDA', 'chestTemp', 'chestResp']]
26 | test = test[['label', 'ID', 'chestACCx', 'chestACCy', 'chestACCz', 'chestECG', 'chestEMG', 'chestEDA', 'chestTemp', 'chestResp']]
27 |
28 | train = np.array(train, dtype=np.float32)
29 | test = np.array(test, dtype=np.float32)
30 |
31 | np.save("train.npy", train)
32 | np.save("test.npy", test)
33 |
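34 | # Illustrative sanity check (not part of the original pipeline): the saved
35 | # arrays keep the label in column 0, the subject ID in column 1 and the eight
36 | # chest-sensor channels in columns 2-9, so they can be split back out like this.
37 | check = np.load("train.npy")
38 | labels, ids, feats = check[:, 0], check[:, 1], check[:, 2:]
39 | print("train shape:", check.shape, "labels:", np.unique(labels))
40 | 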
--------------------------------------------------------------------------------
/WESAD/WESAD_Extract.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "ExecuteTime": {
8 | "end_time": "2019-06-17T19:14:54.249772Z",
9 | "start_time": "2019-06-17T19:14:54.237352Z"
10 | },
11 | "code_folding": []
12 | },
13 | "outputs": [],
14 | "source": [
15 | "def pickle_to_csv_chest(ID):\n",
16 | " import pickle as pkl\n",
17 | " import pandas as pd\n",
18 | " with open('S' + str(ID) + '/S'+str(ID) + '.pkl', 'rb') as f:\n",
19 | " u = pkl._Unpickler(f)\n",
20 | " u.encoding = 'latin1'\n",
21 | " p = u.load()\n",
22 | " df = pd.DataFrame()\n",
23 | " df['chestACCx'] = [item[0] for item in p['signal']['chest']['ACC']]\n",
24 | " df['chestACCy'] = [item[1] for item in p['signal']['chest']['ACC']]\n",
25 | " df['chestACCz'] = [item[2] for item in p['signal']['chest']['ACC']]\n",
26 | " df['chestECG'] = [item for sublist in p['signal']['chest']['ECG'] for item in sublist]\n",
27 | " df['chestEMG'] = [item for sublist in p['signal']['chest']['EMG'] for item in sublist]\n",
28 | " df['chestEDA'] = [item for sublist in p['signal']['chest']['EDA'] for item in sublist]\n",
29 | " df['chestTemp'] = [item for sublist in p['signal']['chest']['Temp'] for item in sublist]\n",
30 | " df['chestResp'] = [item for sublist in p['signal']['chest']['Resp'] for item in sublist]\n",
31 | " df['ID'] = ID\n",
32 | " df['label'] = p['label']\n",
33 | " df = df[['ID','chestACCx','chestACCy','chestACCz','chestECG','chestEMG','chestEDA','chestTemp','chestResp','label']]\n",
34 | " df.to_csv('S' + str(ID) + 'chest.csv', index=False)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {
41 | "ExecuteTime": {
42 | "end_time": "2019-06-17T19:40:19.352229Z",
43 | "start_time": "2019-06-17T19:14:57.985458Z"
44 | }
45 | },
46 | "outputs": [],
47 | "source": [
48 | "for i in range(1,18):\n",
49 | " if i != 1 and i != 12:\n",
50 | " pickle_to_csv_chest(i)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 3,
56 | "metadata": {
57 | "ExecuteTime": {
58 | "end_time": "2019-06-17T19:43:12.702035Z",
59 | "start_time": "2019-06-17T19:40:19.354185Z"
60 | }
61 | },
62 | "outputs": [],
63 | "source": [
64 | "!cat *chest.csv > allchest.csv"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 4,
70 | "metadata": {
71 | "ExecuteTime": {
72 | "end_time": "2019-06-17T20:08:30.037790Z",
73 | "start_time": "2019-06-17T19:43:52.786621Z"
74 | }
75 | },
76 | "outputs": [
77 | {
78 | "name": "stdout",
79 | "output_type": "stream",
80 | "text": [
81 | "The following options are in effect for this COMPRESSION.\n",
82 | "Threading is ENABLED. Number of CPUs detected: 16\n",
83 | "Detected 33628700672 bytes ram\n",
84 | "Compression level 7\n",
85 | "Nice Value: 19\n",
86 | "Show Progress\n",
87 | "Verbose\n",
88 | "Remove input files on completion\n",
89 | "Temporary Directory set as: ./\n",
90 | "Compression mode is: LZMA. LZO Compressibility testing enabled\n",
91 | "Heuristically Computed Compression Window: 213 = 21300MB\n",
92 | "Output filename is: allchest.csv.lrz\n",
93 | "File size: 9488548771\n",
94 | "Will take 1 pass\n",
95 | "Beginning rzip pre-processing phase\n",
96 | "allchest.csv - Compression Ratio: 7.274. Average Compression Speed: 6.130MB/s.\n",
97 | "Total time: 00:24:37.12\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "!lrzip -v -D allchest.csv"
103 | ]
104 | }
105 | ],
106 | "metadata": {
107 | "kernelspec": {
108 | "display_name": "Python 3",
109 | "language": "python",
110 | "name": "python3"
111 | },
112 | "language_info": {
113 | "codemirror_mode": {
114 | "name": "ipython",
115 | "version": 3
116 | },
117 | "file_extension": ".py",
118 | "mimetype": "text/x-python",
119 | "name": "python",
120 | "nbconvert_exporter": "python",
121 | "pygments_lexer": "ipython3",
122 | "version": "3.7.3"
123 | }
124 | },
125 | "nbformat": 4,
126 | "nbformat_minor": 2
127 | }
128 |
--------------------------------------------------------------------------------
/WESAD/WESAD_FastGRNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# WESAD FastGRNN"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Adapted from Microsoft's notebooks, available at https://github.com/microsoft/EdgeML authored by Dennis et al."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "import numpy as np\n",
25 | "from tabulate import tabulate\n",
26 | "import os\n",
27 | "import datetime as datetime\n",
28 | "import pickle as pkl\n",
29 | "from sklearn.model_selection import train_test_split\n",
30 | "import pathlib\n",
31 | "from os import mkdir"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 2,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "def loadData(dirname):\n",
41 | " x_train = np.load(dirname + '/' + 'x_train.npy')\n",
42 | " y_train = np.load(dirname + '/' + 'y_train.npy')\n",
43 | " x_test = np.load(dirname + '/' + 'x_test.npy')\n",
44 | " y_test = np.load(dirname + '/' + 'y_test.npy')\n",
45 | " x_val = np.load(dirname + '/' + 'x_val.npy')\n",
46 | " y_val = np.load(dirname + '/' + 'y_val.npy')\n",
47 | " return x_train, y_train, x_test, y_test, x_val, y_val\n",
48 | "def makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir):\n",
49 | " x_train, y_train, x_test, y_test, x_val, y_val = loadData(sourceDir)\n",
50 | " x, y = bagData(x_train, y_train, subinstanceLen, subinstanceStride)\n",
51 | " np.save(outDir + '/x_train.npy', x)\n",
52 | " np.save(outDir + '/y_train.npy', y)\n",
53 | " print('Num train %d' % len(x))\n",
54 | " x, y = bagData(x_test, y_test, subinstanceLen, subinstanceStride)\n",
55 | " np.save(outDir + '/x_test.npy', x)\n",
56 | " np.save(outDir + '/y_test.npy', y)\n",
57 | " print('Num test %d' % len(x))\n",
58 | " x, y = bagData(x_val, y_val, subinstanceLen, subinstanceStride)\n",
59 | " np.save(outDir + '/x_val.npy', x)\n",
60 | " np.save(outDir + '/y_val.npy', y)\n",
61 | " print('Num val %d' % len(x))\n",
62 | "def bagData(X, Y, subinstanceLen, subinstanceStride):\n",
63 | " numClass = 3\n",
64 | " numSteps = 175\n",
65 | " numFeats = 8\n",
66 | " assert X.ndim == 3\n",
67 | " assert X.shape[1] == numSteps\n",
68 | " assert X.shape[2] == numFeats\n",
69 | " assert subinstanceLen <= numSteps\n",
70 | " assert subinstanceLen > 0\n",
71 | " assert subinstanceStride <= numSteps\n",
72 | " assert subinstanceStride >= 0\n",
73 | " assert len(X) == len(Y)\n",
74 | " assert Y.ndim == 2\n",
75 | " assert Y.shape[1] == numClass\n",
76 | " x_bagged = []\n",
77 | " y_bagged = []\n",
78 | " for i, point in enumerate(X[:, :, :]):\n",
79 | " instanceList = []\n",
80 | " start = 0\n",
81 | " end = subinstanceLen\n",
82 | " while True:\n",
83 | " x = point[start:end, :]\n",
84 | " if len(x) < subinstanceLen:\n",
85 | " x_ = np.zeros([subinstanceLen, x.shape[1]])\n",
86 | " x_[:len(x), :] = x[:, :]\n",
87 | " x = x_\n",
88 | " instanceList.append(x)\n",
89 | " if end >= numSteps:\n",
90 | " break\n",
91 | " start += subinstanceStride\n",
92 | " end += subinstanceStride\n",
93 | " bag = np.array(instanceList)\n",
94 | " numSubinstance = bag.shape[0]\n",
95 | " label = Y[i]\n",
96 | " label = np.argmax(label)\n",
97 | " labelBag = np.zeros([numSubinstance, numClass])\n",
98 | " labelBag[:, label] = 1\n",
99 | " x_bagged.append(bag)\n",
100 | " label = np.array(labelBag)\n",
101 | " y_bagged.append(label)\n",
102 | " return np.array(x_bagged), np.array(y_bagged)"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 4,
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "Num train 95450\n",
115 | "Num test 26514\n",
116 | "Num val 10606\n"
117 | ]
118 | }
119 | ],
120 | "source": [
121 | "subinstanceLen=88\n",
122 | "subinstanceStride=30\n",
123 | "extractedDir = '/home/sf/data/WESAD/'\n",
124 | "# mkdir('/home/sf/data/WESAD/Fast_GRNN/88_30')\n",
125 | "rawDir = extractedDir + '/RAW'\n",
126 | "sourceDir = rawDir\n",
127 | "outDir = extractedDir + 'Fast_GRNN' '/%d_%d/' % (subinstanceLen, subinstanceStride)\n",
128 | "makeEMIData(subinstanceLen, subinstanceStride, sourceDir, outDir)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 7,
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "data": {
138 | "text/plain": [
139 | "'/home/sf/data/WESAD/Fast_GRNN/88_30/'"
140 | ]
141 | },
142 | "execution_count": 7,
143 | "metadata": {},
144 | "output_type": "execute_result"
145 | }
146 | ],
147 | "source": [
148 | "outDir"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 8,
154 | "metadata": {
155 | "ExecuteTime": {
156 | "end_time": "2018-08-19T12:39:06.272261Z",
157 | "start_time": "2018-08-19T12:39:05.330668Z"
158 | }
159 | },
160 | "outputs": [],
161 | "source": [
162 | "from __future__ import print_function\n",
163 | "import os\n",
164 | "import sys\n",
165 | "import tensorflow as tf\n",
166 | "import numpy as np\n",
167 | "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
168 | "\n",
169 | "# FastGRNN and FastRNN imports\n",
170 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
171 | "from edgeml.graph.rnn import EMI_FastGRNN\n",
172 | "from edgeml.graph.rnn import EMI_FastRNN\n",
173 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
174 | "import edgeml.utils"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 9,
180 | "metadata": {
181 | "ExecuteTime": {
182 | "end_time": "2018-08-19T12:39:06.292205Z",
183 | "start_time": "2018-08-19T12:39:06.274254Z"
184 | }
185 | },
186 | "outputs": [],
187 | "source": [
188 | "# Network parameters for our FastGRNN + FC Layer\n",
189 | "NUM_HIDDEN = 128\n",
190 | "NUM_TIMESTEPS = 88\n",
191 | "NUM_FEATS = 8\n",
192 | "FORGET_BIAS = 1.0\n",
193 | "NUM_OUTPUT = 3\n",
194 | "USE_DROPOUT = False\n",
195 | "KEEP_PROB = 0.9\n",
196 | "\n",
197 | "# Non-linearities can be chosen among \"tanh, sigmoid, relu, quantTanh, quantSigm\"\n",
198 | "UPDATE_NL = \"quantTanh\"\n",
199 | "GATE_NL = \"quantSigm\"\n",
200 | "\n",
201 | "# Ranks of Parameter matrices for low-rank parameterisation to compress models.\n",
202 | "WRANK = 5\n",
203 | "URANK = 6\n",
204 | "\n",
205 | "# For dataset API\n",
206 | "PREFETCH_NUM = 5\n",
207 | "BATCH_SIZE = 175\n",
208 | "\n",
209 | "# Number of epochs in *one iteration*\n",
210 | "NUM_EPOCHS = 3\n",
211 | "\n",
212 | "# Number of iterations in *one round*. After each iteration,\n",
213 | "# the model is dumped to disk. At the end of the current\n",
214 | "# round, the best model among all the dumped models in the\n",
215 | "# current round is picked up.\n",
216 | "NUM_ITER = 4\n",
217 | "\n",
218 | "# A round consists of multiple training iterations and a belief\n",
219 | "# update step using the best model from all of these iterations\n",
220 | "NUM_ROUNDS = 6\n",
221 | "\n",
222 | "# A staging directory to store models\n",
223 | "MODEL_PREFIX = '/home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn'"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {},
229 | "source": [
230 | "# Loading Data"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 10,
236 | "metadata": {
237 | "ExecuteTime": {
238 | "end_time": "2018-08-19T12:39:06.410372Z",
239 | "start_time": "2018-08-19T12:39:06.294014Z"
240 | }
241 | },
242 | "outputs": [
243 | {
244 | "name": "stdout",
245 | "output_type": "stream",
246 | "text": [
247 | "x_train shape is: (95450, 4, 88, 8)\n",
248 | "y_train shape is: (95450, 4, 3)\n",
249 | "x_test shape is: (26514, 4, 88, 8)\n",
250 | "y_test shape is: (26514, 4, 3)\n"
251 | ]
252 | }
253 | ],
254 | "source": [
255 | "# Loading the data\n",
256 | "path='/home/sf/data/WESAD/Fast_GRNN/88_30/'\n",
257 | "x_train, y_train = np.load(path + 'x_train.npy'), np.load(path + 'y_train.npy')\n",
258 | "x_test, y_test = np.load(path + 'x_test.npy'), np.load(path + 'y_test.npy')\n",
259 | "x_val, y_val = np.load(path + 'x_val.npy'), np.load(path + 'y_val.npy')\n",
260 | "\n",
261 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
262 | "# step of EMI/MI RNN\n",
263 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
264 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
265 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
266 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
267 | "print(\"x_train shape is:\", x_train.shape)\n",
268 | "print(\"y_train shape is:\", y_train.shape)\n",
269 | "print(\"x_test shape is:\", x_test.shape)\n",
270 | "print(\"y_test shape is:\", y_test.shape)"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | "# Computation Graph"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 11,
283 | "metadata": {
284 | "ExecuteTime": {
285 | "end_time": "2018-08-19T12:39:06.653612Z",
286 | "start_time": "2018-08-19T12:39:06.412290Z"
287 | }
288 | },
289 | "outputs": [],
290 | "source": [
291 | "# Define the linear secondary classifier\n",
292 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
293 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
294 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
295 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
296 | " self.output = y_cap\n",
297 | " self.graphCreated = True\n",
298 | "\n",
299 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
300 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
301 | " self.output = y_cap\n",
302 | " self.graphCreated = True\n",
303 | " \n",
304 | "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
305 | " if inference is False:\n",
306 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
307 | " else:\n",
308 | " feedDict = {self._emiGraph.keep_prob: 1.0}\n",
309 | " return feedDict\n",
310 | "\n",
311 | " \n",
312 | "EMI_FastGRNN._createExtendedGraph = createExtendedGraph\n",
313 | "EMI_FastGRNN._restoreExtendedGraph = restoreExtendedGraph\n",
314 | "if USE_DROPOUT is True:\n",
315 | " EMI_FastGRNN.feedDictFunc = feedDictFunc"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 12,
321 | "metadata": {
322 | "ExecuteTime": {
323 | "end_time": "2018-08-19T12:39:06.701740Z",
324 | "start_time": "2018-08-19T12:39:06.655328Z"
325 | }
326 | },
327 | "outputs": [],
328 | "source": [
329 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
330 | "emiFastGRNN = EMI_FastGRNN(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS, wRank=WRANK, uRank=URANK, \n",
331 | " gate_non_linearity=GATE_NL, update_non_linearity=UPDATE_NL, useDropout=USE_DROPOUT)\n",
332 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy')"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 13,
338 | "metadata": {},
339 | "outputs": [
340 | {
341 | "name": "stdout",
342 | "output_type": "stream",
343 | "text": [
344 | "x_train shape is: (95450, 4, 88, 8)\n",
345 | "y_train shape is: (95450, 4, 3)\n",
346 | "x_test shape is: (10606, 4, 88, 8)\n",
347 | "y_test shape is: (10606, 4, 3)\n"
348 | ]
349 | }
350 | ],
351 | "source": [
352 | "print(\"x_train shape is:\", x_train.shape)\n",
353 | "print(\"y_train shape is:\", y_train.shape)\n",
354 | "print(\"x_test shape is:\", x_val.shape)\n",
355 | "print(\"y_test shape is:\", y_val.shape)"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 14,
361 | "metadata": {
362 | "ExecuteTime": {
363 | "end_time": "2018-08-19T12:39:14.187456Z",
364 | "start_time": "2018-08-19T12:39:06.703481Z"
365 | },
366 | "scrolled": true
367 | },
368 | "outputs": [],
369 | "source": [
370 | "tf.reset_default_graph()\n",
371 | "g1 = tf.Graph() \n",
372 | "with g1.as_default():\n",
373 | " # Obtain the iterators to each batch of the data\n",
374 | " x_batch, y_batch = inputPipeline()\n",
375 | " # Create the forward computation graph based on the iterators\n",
376 | " y_cap = emiFastGRNN(x_batch)\n",
377 | " # Create loss graphs and training routines\n",
378 | " emiTrainer(y_cap, y_batch)"
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "# EMI Driver"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 15,
391 | "metadata": {
392 | "ExecuteTime": {
393 | "end_time": "2018-08-19T12:51:45.803360Z",
394 | "start_time": "2018-08-19T12:39:14.189648Z"
395 | },
396 | "scrolled": true
397 | },
398 | "outputs": [
399 | {
400 | "name": "stdout",
401 | "output_type": "stream",
402 | "text": [
403 | "Update policy: top-k\n",
404 | "Training with MI-RNN loss for 3 rounds\n",
405 | "Round: 0\n",
406 | "Epoch 2 Batch 543 ( 1635) Loss 0.00083 Acc 0.98429 | Val acc 0.97662 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1000\n",
407 | "Epoch 2 Batch 543 ( 1635) Loss 0.00042 Acc 0.99143 | Val acc 0.98982 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1001\n",
408 | "Epoch 2 Batch 543 ( 1635) Loss 0.00036 Acc 0.99286 | Val acc 0.99133 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1002\n",
409 | "Epoch 2 Batch 543 ( 1635) Loss 0.00033 Acc 0.99429 | Val acc 0.99274 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1003\n",
410 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1003\n",
411 | "Round: 1\n",
412 | "Epoch 2 Batch 543 ( 1635) Loss 0.00027 Acc 0.99429 | Val acc 0.99481 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1004\n",
413 | "Epoch 2 Batch 543 ( 1635) Loss 0.00029 Acc 0.99286 | Val acc 0.99538 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1005\n",
414 | "Epoch 2 Batch 543 ( 1635) Loss 0.00022 Acc 0.99429 | Val acc 0.99632 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1006\n",
415 | "Epoch 2 Batch 543 ( 1635) Loss 0.00014 Acc 0.99429 | Val acc 0.99727 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1007\n",
416 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1007\n",
417 | "Round: 2\n",
418 | "Epoch 2 Batch 543 ( 1635) Loss 0.00022 Acc 0.98857 | Val acc 0.99529 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1008\n",
419 | "Epoch 2 Batch 543 ( 1635) Loss 0.00018 Acc 0.99000 | Val acc 0.99595 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1009\n",
420 | "Epoch 2 Batch 543 ( 1635) Loss 0.00024 Acc 0.98857 | Val acc 0.99595 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1010\n",
421 | "Epoch 2 Batch 543 ( 1635) Loss 0.00008 Acc 0.99714 | Val acc 0.99642 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1011\n",
422 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1011\n",
423 | "Round: 3\n",
424 | "Switching to EMI-Loss function\n",
425 | "Epoch 2 Batch 543 ( 1635) Loss 0.03035 Acc 0.99429 | Val acc 0.99538 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1012\n",
426 | "Epoch 2 Batch 543 ( 1635) Loss 0.01915 Acc 0.99429 | Val acc 0.99613 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1013\n",
427 | "Epoch 2 Batch 543 ( 1635) Loss 0.01444 Acc 0.99571 | Val acc 0.99613 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1014\n",
428 | "Epoch 2 Batch 543 ( 1635) Loss 0.01283 Acc 0.99571 | Val acc 0.99717 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1015\n",
429 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1015\n",
430 | "Round: 4\n",
431 | "Epoch 2 Batch 543 ( 1635) Loss 0.00784 Acc 1.00000 | Val acc 0.99745 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1016\n",
432 | "Epoch 2 Batch 543 ( 1635) Loss 0.00683 Acc 0.99857 | Val acc 0.99821 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1017\n",
433 | "Epoch 2 Batch 543 ( 1635) Loss 0.00627 Acc 1.00000 | Val acc 0.99849 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1018\n",
434 | "Epoch 2 Batch 543 ( 1635) Loss 0.00631 Acc 1.00000 | Val acc 0.99840 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1019\n",
435 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1018\n",
436 | "Round: 5\n",
437 | "Epoch 2 Batch 543 ( 1635) Loss 0.00631 Acc 1.00000 | Val acc 0.99840 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1020\n",
438 | "Epoch 2 Batch 543 ( 1635) Loss 0.00793 Acc 0.99857 | Val acc 0.99849 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1021\n",
439 | "Epoch 2 Batch 543 ( 1635) Loss 0.00700 Acc 0.99857 | Val acc 0.99830 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1022\n",
440 | "Epoch 2 Batch 543 ( 1635) Loss 0.00465 Acc 1.00000 | Val acc 0.99840 | Model saved to /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn, global_step 1023\n",
441 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1021\n"
442 | ]
443 | }
444 | ],
445 | "source": [
446 | "with g1.as_default():\n",
447 | " emiDriver = EMI_Driver(inputPipeline, emiFastGRNN, emiTrainer)\n",
448 | "\n",
449 | "emiDriver.initializeSession(g1)\n",
450 | "y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
451 | " y_train=y_train, bag_train=BAG_TRAIN,\n",
452 | " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
453 | " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
454 | " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
455 | " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
456 | " fracEMI=0.5, updatePolicy='top-k', k=1)"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 16,
462 | "metadata": {
463 | "ExecuteTime": {
464 | "end_time": "2018-08-19T12:51:45.832728Z",
465 | "start_time": "2018-08-19T12:51:45.805984Z"
466 | }
467 | },
468 | "outputs": [],
469 | "source": [
470 | "# Early Prediction Policy: We make an early prediction based on the predicted class\n",
471 | "# probability. If the predicted class probability > minProb at some step, we make\n",
472 | "# a prediction at that step.\n",
473 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
474 | " assert instanceOut.ndim == 2\n",
475 | " classes = np.argmax(instanceOut, axis=1)\n",
476 | " prob = np.max(instanceOut, axis=1)\n",
477 | " index = np.where(prob >= minProb)[0]\n",
478 | " if len(index) == 0:\n",
479 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
480 | " return classes[-1], len(instanceOut) - 1\n",
481 | " index = index[0]\n",
482 | " return classes[index], index\n",
483 | "\n",
484 | "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
485 | " predictionStep = predictionStep + 1\n",
486 | " predictionStep = np.reshape(predictionStep, -1)\n",
487 | " totalSteps = np.sum(predictionStep)\n",
488 | " maxSteps = len(predictionStep) * numTimeSteps\n",
489 | " savings = 1.0 - (totalSteps / maxSteps)\n",
490 | " if returnTotal:\n",
491 | " return savings, totalSteps\n",
492 | " return savings"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 17,
498 | "metadata": {
499 | "ExecuteTime": {
500 | "end_time": "2018-08-19T12:51:46.210240Z",
501 | "start_time": "2018-08-19T12:51:45.834534Z"
502 | }
503 | },
504 | "outputs": [
505 | {
506 | "name": "stdout",
507 | "output_type": "stream",
508 | "text": [
509 | "Accuracy at k = 2: 0.998567\n",
510 | "Additional savings: 0.960761\n"
511 | ]
512 | }
513 | ],
514 | "source": [
515 | "k = 2\n",
516 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb, minProb=0.99)\n",
517 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
518 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
519 | "print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS))"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 18,
525 | "metadata": {
526 | "ExecuteTime": {
527 | "end_time": "2018-08-19T12:51:46.677691Z",
528 | "start_time": "2018-08-19T12:51:46.212285Z"
529 | },
530 | "scrolled": false
531 | },
532 | "outputs": [
533 | {
534 | "name": "stdout",
535 | "output_type": "stream",
536 | "text": [
537 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
538 | "0 1 0.998831 0.998532 0.998504 0.998561 0.998831 0.998831 \n",
539 | "1 2 0.998567 0.998295 0.998517 0.998074 0.998567 0.998567 \n",
540 | "2 3 0.997850 0.997675 0.998287 0.997069 0.997850 0.997850 \n",
541 | "3 4 0.996040 0.995706 0.997310 0.994133 0.996040 0.996040 \n",
542 | "\n",
543 | " micro-rec \n",
544 | "0 0.998831 \n",
545 | "1 0.998567 \n",
546 | "2 0.997850 \n",
547 | "3 0.996040 \n",
548 | "Max accuracy 0.998831 at subsequencelength 1\n",
549 | "Max micro-f 0.998831 at subsequencelength 1\n",
550 | "Micro-precision 0.998831 at subsequencelength 1\n",
551 | "Micro-recall 0.998831 at subsequencelength 1\n",
552 | "Max macro-f 0.998532 at subsequencelength 1\n",
553 | "macro-precision 0.998504 at subsequencelength 1\n",
554 | "macro-recall 0.998561 at subsequencelength 1\n"
555 | ]
556 | }
557 | ],
558 | "source": [
559 | "# A slightly more detailed analysis method is provided. \n",
560 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
561 | ]
562 | },
563 | {
564 | "cell_type": "markdown",
565 | "metadata": {},
566 | "source": [
567 | "## Picking the best model"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 19,
573 | "metadata": {
574 | "ExecuteTime": {
575 | "end_time": "2018-08-19T13:06:04.024660Z",
576 | "start_time": "2018-08-19T13:04:47.045787Z"
577 | }
578 | },
579 | "outputs": [
580 | {
581 | "name": "stdout",
582 | "output_type": "stream",
583 | "text": [
584 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1003\n",
585 | "Round: 0, Validation accuracy: 0.9927, Test Accuracy (k = 2): 0.960361, Additional savings: 0.372858\n",
586 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1007\n",
587 | "Round: 1, Validation accuracy: 0.9973, Test Accuracy (k = 2): 0.926303, Additional savings: 0.508829\n",
588 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1011\n",
589 | "Round: 2, Validation accuracy: 0.9964, Test Accuracy (k = 2): 0.950743, Additional savings: 0.585428\n",
590 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1015\n",
591 | "Round: 3, Validation accuracy: 0.9972, Test Accuracy (k = 2): 0.997435, Additional savings: 0.945828\n",
592 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1018\n",
593 | "Round: 4, Validation accuracy: 0.9985, Test Accuracy (k = 2): 0.998491, Additional savings: 0.956765\n",
594 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/Fast_GRNN/88_30/models/model-fgrnn-1021\n",
595 | "Round: 5, Validation accuracy: 0.9985, Test Accuracy (k = 2): 0.998567, Additional savings: 0.960761\n"
596 | ]
597 | }
598 | ],
599 | "source": [
600 | "devnull = open(os.devnull, 'r')\n",
601 | "for val in modelStats:\n",
602 | " round_, acc, modelPrefix, globalStep = val\n",
603 | " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
604 | " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
605 | " minProb=0.99, keep_prob=1.0)\n",
606 | " \n",
607 | " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
608 | " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
609 | " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
610 | " print('Additional savings: %f' % getEarlySaving(predictionStep, NUM_TIMESTEPS)) "
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": 20,
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "dataset=\"WESAD\"\n",
620 | "model=\"fast-grnn\"\n",
621 | "params = {\n",
622 | " \"NUM_HIDDEN\" : 128,\n",
623 | " \"NUM_TIMESTEPS\" : 700, #subinstance length.\n",
624 | " \"NUM_FEATS\" : 8,\n",
625 | " \"FORGET_BIAS\" : 1.0,\n",
626 | " \"NUM_OUTPUT\" : 3,\n",
627 | " \"USE_DROPOUT\" : 0, # '1' -> True. '0' -> False\n",
628 | " \"KEEP_PROB\" : 0.9,\n",
629 | " \"UPDATE_NL\" : \"quantTanh\",\n",
630 | " \"GATE_NL\" : \"quantSigm\",\n",
631 | " \"WRANK\" : 5,\n",
632 | " \"URANK\" : 6,\n",
633 | " \"PREFETCH_NUM\" : 5,\n",
634 | " \"BATCH_SIZE\" : 175,\n",
635 | " \"NUM_EPOCHS\" : 3,\n",
636 | " \"NUM_ITER\" : 4,\n",
637 | " \"NUM_ROUNDS\" : 4,\n",
638 | " \"MODEL_PREFIX\" : dataset + '/model-' + str(model)\n",
639 | "}\n",
640 | "\n",
641 | "fast_dict = {**params}\n",
642 | "fast_dict[\"k\"] = k\n",
643 | "fast_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
644 | "fast_dict[\"additional_savings\"] = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
645 | "fast_dict[\"y_test\"] = BAG_TEST\n",
646 | "fast_dict[\"y_pred\"] = bagPredictions"
647 | ]
648 | },
649 | {
650 | "cell_type": "code",
651 | "execution_count": 21,
652 | "metadata": {},
653 | "outputs": [
654 | {
655 | "name": "stdout",
656 | "output_type": "stream",
657 | "text": [
658 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
659 | "0 1 0.998831 0.998532 0.998504 0.998561 0.998831 0.998831 \n",
660 | "1 2 0.998567 0.998295 0.998517 0.998074 0.998567 0.998567 \n",
661 | "2 3 0.997850 0.997675 0.998287 0.997069 0.997850 0.997850 \n",
662 | "3 4 0.996040 0.995706 0.997310 0.994133 0.996040 0.996040 \n",
663 | "\n",
664 | " micro-rec \n",
665 | "0 0.998831 \n",
666 | "1 0.998567 \n",
667 | "2 0.997850 \n",
668 | "3 0.996040 \n",
669 | "Max accuracy 0.998831 at subsequencelength 1\n",
670 | "Max micro-f 0.998831 at subsequencelength 1\n",
671 | "Micro-precision 0.998831 at subsequencelength 1\n",
672 | "Micro-recall 0.998831 at subsequencelength 1\n",
673 | "Max macro-f 0.998532 at subsequencelength 1\n",
674 | "macro-precision 0.998504 at subsequencelength 1\n",
675 | "macro-recall 0.998561 at subsequencelength 1\n",
676 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
677 | "| | len | acc | macro-fsc | macro-pre | macro-rec | micro-fsc | micro-pre | micro-rec |\n",
678 | "+====+=======+==========+=============+=============+=============+=============+=============+=============+\n",
679 | "| 0 | 1 | 0.998831 | 0.998532 | 0.998504 | 0.998561 | 0.998831 | 0.998831 | 0.998831 |\n",
680 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
681 | "| 1 | 2 | 0.998567 | 0.998295 | 0.998517 | 0.998074 | 0.998567 | 0.998567 | 0.998567 |\n",
682 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
683 | "| 2 | 3 | 0.99785 | 0.997675 | 0.998287 | 0.997069 | 0.99785 | 0.99785 | 0.99785 |\n",
684 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
685 | "| 3 | 4 | 0.99604 | 0.995706 | 0.99731 | 0.994133 | 0.99604 | 0.99604 | 0.99604 |\n",
686 | "+----+-------+----------+-------------+-------------+-------------+-------------+-------------+-------------+\n",
687 | "Results for this run have been saved at /home/sf/data/WESAD/Fast_GRNN/ .\n"
688 | ]
689 | }
690 | ],
691 | "source": [
692 | "# A slightly more detailed analysis method is provided. \n",
693 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
694 | "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))\n",
695 | "\n",
696 | "dirname = \"/home/sf/data/WESAD/Fast_GRNN/\"\n",
697 | "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
698 | "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
699 | "\n",
700 | "now = datetime.datetime.now()\n",
701 | "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
702 | "filename = ''.join(filename)\n",
703 | "\n",
704 | "#Save the dictionary containing the params and the results.\n",
705 | "pkl.dump(fast_dict,open(dirname + \"/fast_dict_\" + filename + \".pkl\",mode='wb'))"
706 | ]
707 | }
708 | ],
709 | "metadata": {
710 | "kernelspec": {
711 | "display_name": "Python 3",
712 | "language": "python",
713 | "name": "python3"
714 | },
715 | "language_info": {
716 | "codemirror_mode": {
717 | "name": "ipython",
718 | "version": 3
719 | },
720 | "file_extension": ".py",
721 | "mimetype": "text/x-python",
722 | "name": "python",
723 | "nbconvert_exporter": "python",
724 | "pygments_lexer": "ipython3",
725 | "version": "3.7.3"
726 | }
727 | },
728 | "nbformat": 4,
729 | "nbformat_minor": 2
730 | }
731 |
--------------------------------------------------------------------------------
/WESAD/WESAD_GRU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# WESAD GRU"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "from tabulate import tabulate\n",
19 | "import os\n",
20 | "import datetime as datetime\n",
21 | "import pickle as pkl\n",
22 | "import pathlib"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {
29 | "ExecuteTime": {
30 | "end_time": "2018-12-14T14:17:51.796585Z",
31 | "start_time": "2018-12-14T14:17:49.648375Z"
32 | }
33 | },
34 | "outputs": [
35 | {
36 | "name": "stderr",
37 | "output_type": "stream",
38 | "text": [
39 | "Using TensorFlow backend.\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "from __future__ import print_function\n",
45 | "import os\n",
46 | "import sys\n",
47 | "import tensorflow as tf\n",
48 | "import numpy as np\n",
49 | "# Making sure edgeml is part of python path\n",
50 | "sys.path.insert(0, '../../')\n",
51 | "#For processing on CPU.\n",
52 | "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
53 | "\n",
54 | "np.random.seed(42)\n",
55 | "tf.set_random_seed(42)\n",
56 | "\n",
57 | "# MI-RNN and EMI-RNN imports\n",
58 | "from edgeml.graph.rnn import EMI_DataPipeline\n",
59 | "from edgeml.graph.rnn import EMI_GRU\n",
60 | "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
61 | "import edgeml.utils\n",
62 | "\n",
63 | "import keras.backend as K\n",
64 | "cfg = K.tf.ConfigProto()\n",
65 | "cfg.gpu_options.allow_growth = True\n",
66 | "K.set_session(K.tf.Session(config=cfg))"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {
73 | "ExecuteTime": {
74 | "end_time": "2018-12-14T14:17:51.803381Z",
75 | "start_time": "2018-12-14T14:17:51.798799Z"
76 | }
77 | },
78 | "outputs": [],
79 | "source": [
80 | "# Network parameters for our LSTM + FC Layer\n",
81 | "NUM_HIDDEN = 128\n",
82 | "NUM_TIMESTEPS = 88\n",
83 | "ORIGINAL_NUM_TIMESTEPS = 175\n",
84 | "NUM_FEATS = 8\n",
85 | "FORGET_BIAS = 1.0\n",
86 | "NUM_OUTPUT = 3\n",
87 | "USE_DROPOUT = True\n",
88 | "KEEP_PROB = 0.75\n",
89 | "\n",
90 | "# For dataset API\n",
91 | "PREFETCH_NUM = 5\n",
92 | "BATCH_SIZE = 175\n",
93 | "\n",
94 | "# Number of epochs in *one iteration*\n",
95 | "NUM_EPOCHS = 2\n",
96 | "# Number of iterations in *one round*. After each iteration,\n",
97 | "# the model is dumped to disk. At the end of the current\n",
98 | "# round, the best model among all the dumped models in the\n",
99 | "# current round is picked up..\n",
100 | "NUM_ITER = 4\n",
101 | "# A round consists of multiple training iterations and a belief\n",
102 | "# update step using the best model from all of these iterations\n",
103 | "NUM_ROUNDS = 6\n",
104 | "LEARNING_RATE=0.001\n",
105 | "\n",
106 | "# A staging direcory to store models\n",
107 | "MODEL_PREFIX = '/home/sf/data/WESAD/GRU/88_30/models/model-gru'"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {
113 | "heading_collapsed": true
114 | },
115 | "source": [
116 | "# Loading Data"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 4,
122 | "metadata": {
123 | "ExecuteTime": {
124 | "end_time": "2018-12-14T14:17:52.040352Z",
125 | "start_time": "2018-12-14T14:17:51.805319Z"
126 | },
127 | "hidden": true
128 | },
129 | "outputs": [
130 | {
131 | "name": "stdout",
132 | "output_type": "stream",
133 | "text": [
134 | "x_train shape is: (95450, 4, 88, 8)\n",
135 | "y_train shape is: (95450, 4, 3)\n",
136 | "x_test shape is: (10606, 4, 88, 8)\n",
137 | "y_test shape is: (10606, 4, 3)\n"
138 | ]
139 | }
140 | ],
141 | "source": [
142 | "# Loading the data\n",
143 | "x_train, y_train = np.load('/home/sf/data/WESAD/88_30/x_train.npy'), np.load('/home/sf/data/WESAD/88_30/y_train.npy')\n",
144 | "x_test, y_test = np.load('/home/sf/data/WESAD/88_30/x_test.npy'), np.load('/home/sf/data/WESAD/88_30/y_test.npy')\n",
145 | "x_val, y_val = np.load('/home/sf/data/WESAD/88_30/x_val.npy'), np.load('/home/sf/data/WESAD/88_30/y_val.npy')\n",
146 | "\n",
147 | "# BAG_TEST, BAG_TRAIN, BAG_VAL represent bag_level labels. These are used for the label update\n",
148 | "# step of EMI/MI RNN\n",
149 | "BAG_TEST = np.argmax(y_test[:, 0, :], axis=1)\n",
150 | "BAG_TRAIN = np.argmax(y_train[:, 0, :], axis=1)\n",
151 | "BAG_VAL = np.argmax(y_val[:, 0, :], axis=1)\n",
152 | "NUM_SUBINSTANCE = x_train.shape[1]\n",
153 | "print(\"x_train shape is:\", x_train.shape)\n",
154 | "print(\"y_train shape is:\", y_train.shape)\n",
155 | "print(\"x_test shape is:\", x_val.shape)\n",
156 | "print(\"y_test shape is:\", y_val.shape)"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "# Computation Graph"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 5,
169 | "metadata": {
170 | "ExecuteTime": {
171 | "end_time": "2018-12-14T14:17:52.053161Z",
172 | "start_time": "2018-12-14T14:17:52.042928Z"
173 | }
174 | },
175 | "outputs": [],
176 | "source": [
177 | "# Define the linear secondary classifier\n",
178 | "def createExtendedGraph(self, baseOutput, *args, **kwargs):\n",
179 | " W1 = tf.Variable(np.random.normal(size=[NUM_HIDDEN, NUM_OUTPUT]).astype('float32'), name='W1')\n",
180 | " B1 = tf.Variable(np.random.normal(size=[NUM_OUTPUT]).astype('float32'), name='B1')\n",
181 | " y_cap = tf.add(tf.tensordot(baseOutput, W1, axes=1), B1, name='y_cap_tata')\n",
182 | " self.output = y_cap\n",
183 | " self.graphCreated = True\n",
184 | "\n",
185 | "def restoreExtendedGraph(self, graph, *args, **kwargs):\n",
186 | " y_cap = graph.get_tensor_by_name('y_cap_tata:0')\n",
187 | " self.output = y_cap\n",
188 | " self.graphCreated = True\n",
189 | " \n",
190 | "def feedDictFunc(self, keep_prob=None, inference=False, **kwargs):\n",
191 | " if inference is False:\n",
192 | " feedDict = {self._emiGraph.keep_prob: keep_prob}\n",
193 | " else:\n",
194 | " feedDict = {self._emiGraph.keep_prob: 1.0}\n",
195 | " return feedDict\n",
196 | " \n",
197 | "EMI_GRU._createExtendedGraph = createExtendedGraph\n",
198 | "EMI_GRU._restoreExtendedGraph = restoreExtendedGraph\n",
199 | "\n",
200 | "if USE_DROPOUT is True:\n",
201 | " EMI_Driver.feedDictFunc = feedDictFunc"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 6,
207 | "metadata": {
208 | "ExecuteTime": {
209 | "end_time": "2018-12-14T14:17:52.335299Z",
210 | "start_time": "2018-12-14T14:17:52.055483Z"
211 | }
212 | },
213 | "outputs": [],
214 | "source": [
215 | "inputPipeline = EMI_DataPipeline(NUM_SUBINSTANCE, NUM_TIMESTEPS, NUM_FEATS, NUM_OUTPUT)\n",
216 | "emiGRU = EMI_GRU(NUM_SUBINSTANCE, NUM_HIDDEN, NUM_TIMESTEPS, NUM_FEATS,\n",
217 | " useDropout=USE_DROPOUT)\n",
218 | "emiTrainer = EMI_Trainer(NUM_TIMESTEPS, NUM_OUTPUT, lossType='xentropy',\n",
219 | " stepSize=LEARNING_RATE)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 7,
225 | "metadata": {
226 | "ExecuteTime": {
227 | "end_time": "2018-12-14T14:18:05.031382Z",
228 | "start_time": "2018-12-14T14:17:52.338750Z"
229 | }
230 | },
231 | "outputs": [],
232 | "source": [
233 | "tf.reset_default_graph()\n",
234 | "g1 = tf.Graph() \n",
235 | "with g1.as_default():\n",
236 | " # Obtain the iterators to each batch of the data\n",
237 | " x_batch, y_batch = inputPipeline()\n",
238 | " # Create the forward computation graph based on the iterators\n",
239 | " y_cap = emiGRU(x_batch)\n",
240 | " # Create loss graphs and training routines\n",
241 | " emiTrainer(y_cap, y_batch)"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "# EMI Driver"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 9,
254 | "metadata": {
255 | "ExecuteTime": {
256 | "end_time": "2018-12-14T14:35:15.209910Z",
257 | "start_time": "2018-12-14T14:18:05.034359Z"
258 | },
259 | "scrolled": true
260 | },
261 | "outputs": [
262 | {
263 | "name": "stdout",
264 | "output_type": "stream",
265 | "text": [
266 | "Update policy: top-k\n",
267 | "Training with MI-RNN loss for 3 rounds\n",
268 | "Round: 0\n",
269 | "Epoch 1 Batch 534 ( 1080) Loss 0.00133 Acc 0.96000 | Val acc 0.97869 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1000\n",
270 | "Epoch 1 Batch 534 ( 1080) Loss 0.00082 Acc 0.97143 | Val acc 0.98190 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1001\n",
271 | "Epoch 1 Batch 534 ( 1080) Loss 0.00024 Acc 0.99429 | Val acc 0.97134 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1002\n",
272 | "Epoch 1 Batch 534 ( 1080) Loss 0.00020 Acc 0.99143 | Val acc 0.97596 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1003\n",
273 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1001\n",
274 | "Round: 1\n",
275 | "Epoch 1 Batch 534 ( 1080) Loss 0.00008 Acc 1.00000 | Val acc 0.98435 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1004\n",
276 | "Epoch 1 Batch 534 ( 1080) Loss 0.00033 Acc 0.99143 | Val acc 0.96361 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1005\n",
277 | "Epoch 1 Batch 534 ( 1080) Loss 0.00009 Acc 0.99571 | Val acc 0.96134 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1006\n",
278 | "Epoch 1 Batch 534 ( 1080) Loss 0.00026 Acc 0.99286 | Val acc 0.95418 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1007\n",
279 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1004\n",
280 | "Round: 2\n",
281 | "Epoch 1 Batch 534 ( 1080) Loss 0.00008 Acc 0.99714 | Val acc 0.97237 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1008\n",
282 | "Epoch 1 Batch 534 ( 1080) Loss 0.00007 Acc 0.99714 | Val acc 0.95616 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1009\n",
283 | "Epoch 1 Batch 534 ( 1080) Loss 0.00004 Acc 0.99857 | Val acc 0.95418 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1010\n",
284 | "Epoch 1 Batch 534 ( 1080) Loss 0.00007 Acc 0.99857 | Val acc 0.95587 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1011\n",
285 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1008\n",
286 | "Round: 3\n",
287 | "Switching to EMI-Loss function\n",
288 | "Epoch 1 Batch 534 ( 1080) Loss 0.07520 Acc 0.99286 | Val acc 0.93485 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1012\n",
289 | "Epoch 1 Batch 534 ( 1080) Loss 0.03711 Acc 1.00000 | Val acc 0.89421 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1013\n",
290 | "Epoch 1 Batch 534 ( 1080) Loss 0.03318 Acc 1.00000 | Val acc 0.85998 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1014\n",
291 | "Epoch 1 Batch 534 ( 1080) Loss 0.03119 Acc 1.00000 | Val acc 0.81897 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1015\n",
292 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1012\n",
293 | "Round: 4\n",
294 | "Epoch 1 Batch 534 ( 1080) Loss 0.04101 Acc 1.00000 | Val acc 0.90958 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1016\n",
295 | "Epoch 1 Batch 534 ( 1080) Loss 0.03496 Acc 1.00000 | Val acc 0.84546 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1017\n",
296 | "Epoch 1 Batch 534 ( 1080) Loss 0.03953 Acc 0.99571 | Val acc 0.82500 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1018\n",
297 | "Epoch 1 Batch 534 ( 1080) Loss 0.03629 Acc 0.99857 | Val acc 0.79370 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1019\n",
298 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1016\n",
299 | "Round: 5\n",
300 | "Epoch 1 Batch 534 ( 1080) Loss 0.03494 Acc 1.00000 | Val acc 0.82302 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1020\n",
301 | "Epoch 1 Batch 534 ( 1080) Loss 0.03172 Acc 0.99857 | Val acc 0.85122 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1021\n",
302 | "Epoch 1 Batch 534 ( 1080) Loss 0.02683 Acc 1.00000 | Val acc 0.80200 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1022\n",
303 | "Epoch 1 Batch 534 ( 1080) Loss 0.02836 Acc 1.00000 | Val acc 0.78682 | Model saved to /home/sf/data/WESAD/GRU/88_30/models/model-gru, global_step 1023\n",
304 | "INFO:tensorflow:Restoring parameters from /home/sf/data/WESAD/GRU/88_30/models/model-gru-1021\n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "with g1.as_default():\n",
310 | " emiDriver = EMI_Driver(inputPipeline, emiGRU, emiTrainer)\n",
311 | "\n",
312 | "emiDriver.initializeSession(g1)\n",
313 | "y_updated, modelStats = emiDriver.run(numClasses=NUM_OUTPUT, x_train=x_train,\n",
314 | " y_train=y_train, bag_train=BAG_TRAIN,\n",
315 | " x_val=x_val, y_val=y_val, bag_val=BAG_VAL,\n",
316 | " numIter=NUM_ITER, keep_prob=KEEP_PROB,\n",
317 | " numRounds=NUM_ROUNDS, batchSize=BATCH_SIZE,\n",
318 | " numEpochs=NUM_EPOCHS, modelPrefix=MODEL_PREFIX,\n",
319 | " fracEMI=0.5, updatePolicy='top-k', k=1)"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {},
325 | "source": [
326 | "# Evaluating the trained model"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 10,
332 | "metadata": {
333 | "ExecuteTime": {
334 | "end_time": "2018-12-14T14:35:15.218040Z",
335 | "start_time": "2018-12-14T14:35:15.211771Z"
336 | }
337 | },
338 | "outputs": [],
339 | "source": [
340 | "# Early Prediction Policy: We make an early prediction based on the predicted classes\n",
341 | "# probability. If the predicted class probability > minProb at some step, we make\n",
342 | "# a prediction at that step.\n",
343 | "def earlyPolicy_minProb(instanceOut, minProb, **kwargs):\n",
344 | " assert instanceOut.ndim == 2\n",
345 | " classes = np.argmax(instanceOut, axis=1)\n",
346 | " prob = np.max(instanceOut, axis=1)\n",
347 | " index = np.where(prob >= minProb)[0]\n",
348 | " if len(index) == 0:\n",
349 | " assert (len(instanceOut) - 1) == (len(classes) - 1)\n",
350 | " return classes[-1], len(instanceOut) - 1\n",
351 | " index = index[0]\n",
352 | " return classes[index], index\n",
353 | "\n",
354 | "def getEarlySaving(predictionStep, numTimeSteps, returnTotal=False):\n",
355 | " predictionStep = predictionStep + 1\n",
356 | " predictionStep = np.reshape(predictionStep, -1)\n",
357 | " totalSteps = np.sum(predictionStep)\n",
358 | " maxSteps = len(predictionStep) * numTimeSteps\n",
359 | " savings = 1.0 - (totalSteps / maxSteps)\n",
360 | " if returnTotal:\n",
361 | " return savings, totalSteps\n",
362 | " return savings"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 11,
368 | "metadata": {
369 | "ExecuteTime": {
370 | "end_time": "2018-12-14T14:35:16.257489Z",
371 | "start_time": "2018-12-14T14:35:15.221029Z"
372 | }
373 | },
374 | "outputs": [
375 | {
376 | "name": "stdout",
377 | "output_type": "stream",
378 | "text": [
379 | "Accuracy at k = 2: 0.852908\n",
380 | "Savings due to MI-RNN : 0.497143\n",
381 | "Savings due to Early prediction: 0.826181\n",
382 | "Total Savings: 0.912594\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "k = 2\n",
388 | "predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
389 | " minProb=0.99, keep_prob=1.0)\n",
390 | "bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
391 | "print('Accuracy at k = %d: %f' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))))\n",
392 | "mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
393 | "emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
394 | "total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
395 | "print('Savings due to MI-RNN : %f' % mi_savings)\n",
396 | "print('Savings due to Early prediction: %f' % emi_savings)\n",
397 | "print('Total Savings: %f' % (total_savings))"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": 12,
403 | "metadata": {
404 | "ExecuteTime": {
405 | "end_time": "2018-12-14T14:35:17.044115Z",
406 | "start_time": "2018-12-14T14:35:16.259280Z"
407 | },
408 | "scrolled": false
409 | },
410 | "outputs": [
411 | {
412 | "name": "stdout",
413 | "output_type": "stream",
414 | "text": [
415 | " len acc macro-fsc macro-pre macro-rec micro-fsc micro-pre \\\n",
416 | "0 1 0.844648 0.835791 0.838901 0.850929 0.844648 0.844648 \n",
417 | "1 2 0.852908 0.842272 0.845445 0.853562 0.852908 0.852908 \n",
418 | "2 3 0.855397 0.843846 0.848683 0.851069 0.855397 0.855397 \n",
419 | "3 4 0.852757 0.837862 0.847351 0.840189 0.852757 0.852757 \n",
420 | "\n",
421 | " micro-rec \n",
422 | "0 0.844648 \n",
423 | "1 0.852908 \n",
424 | "2 0.855397 \n",
425 | "3 0.852757 \n",
426 | "Max accuracy 0.855397 at subsequencelength 3\n",
427 | "Max micro-f 0.855397 at subsequencelength 3\n",
428 | "Micro-precision 0.855397 at subsequencelength 3\n",
429 | "Micro-recall 0.855397 at subsequencelength 3\n",
430 | "Max macro-f 0.843846 at subsequencelength 3\n",
431 | "macro-precision 0.848683 at subsequencelength 3\n",
432 | "macro-recall 0.851069 at subsequencelength 3\n"
433 | ]
434 | }
435 | ],
436 | "source": [
437 | "# A slightly more detailed analysis method is provided. \n",
438 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {},
444 | "source": [
445 | "## Picking the best model"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": null,
451 | "metadata": {
452 | "ExecuteTime": {
453 | "end_time": "2018-12-14T14:35:54.899340Z",
454 | "start_time": "2018-12-14T14:35:17.047464Z"
455 | }
456 | },
457 | "outputs": [],
458 | "source": [
459 | "devnulldevnull = open(os.devnull, 'r')\n",
460 | "for val in modelStats:\n",
461 | " round_, acc, modelPrefix, globalStep = val\n",
462 | " emiDriver.loadSavedGraphToNewSession(modelPrefix, globalStep, redirFile=devnull)\n",
463 | " predictions, predictionStep = emiDriver.getInstancePredictions(x_test, y_test, earlyPolicy_minProb,\n",
464 | " minProb=0.99, keep_prob=1.0)\n",
465 | "\n",
466 | " bagPredictions = emiDriver.getBagPredictions(predictions, minSubsequenceLen=k, numClass=NUM_OUTPUT)\n",
467 | " print(\"Round: %2d, Validation accuracy: %.4f\" % (round_, acc), end='')\n",
468 | " print(', Test Accuracy (k = %d): %f, ' % (k, np.mean((bagPredictions == BAG_TEST).astype(int))), end='')\n",
469 | " mi_savings = (1 - NUM_TIMESTEPS / ORIGINAL_NUM_TIMESTEPS)\n",
470 | " emi_savings = getEarlySaving(predictionStep, NUM_TIMESTEPS)\n",
471 | " total_savings = mi_savings + (1 - mi_savings) * emi_savings\n",
472 | " print(\"Total Savings: %f\" % total_savings)"
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "execution_count": null,
478 | "metadata": {},
479 | "outputs": [],
480 | "source": [
481 | "params = {\n",
482 | " \"NUM_HIDDEN\" : 128,\n",
483 | " \"NUM_TIMESTEPS\" : 64, #subinstance length.\n",
484 | " \"ORIGINAL_NUM_TIMESTEPS\" : 128,\n",
485 | " \"NUM_FEATS\" : 16,\n",
486 | " \"FORGET_BIAS\" : 1.0,\n",
487 | " \"NUM_OUTPUT\" : 5,\n",
488 | " \"USE_DROPOUT\" : 1, # '1' -> True. '0' -> False\n",
489 | " \"KEEP_PROB\" : 0.75,\n",
490 | " \"PREFETCH_NUM\" : 5,\n",
491 | " \"BATCH_SIZE\" : 32,\n",
492 | " \"NUM_EPOCHS\" : 2,\n",
493 | " \"NUM_ITER\" : 4,\n",
494 | " \"NUM_ROUNDS\" : 10,\n",
495 | " \"LEARNING_RATE\" : 0.001,\n",
496 | " \"MODEL_PREFIX\" : '/home/sf/data/DREAMER/Dominance/model-gru'\n",
497 | "}"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "gru_dict = {**params}\n",
507 | "gru_dict[\"k\"] = k\n",
508 | "gru_dict[\"accuracy\"] = np.mean((bagPredictions == BAG_TEST).astype(int))\n",
509 | "gru_dict[\"total_savings\"] = total_savings\n",
510 | "gru_dict[\"y_test\"] = BAG_TEST\n",
511 | "gru_dict[\"y_pred\"] = bagPredictions\n",
512 | "\n",
513 | "# A slightly more detailed analysis method is provided. \n",
514 | "df = emiDriver.analyseModel(predictions, BAG_TEST, NUM_SUBINSTANCE, NUM_OUTPUT)\n",
515 | "print (tabulate(df, headers=list(df.columns), tablefmt='grid'))"
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": null,
521 | "metadata": {},
522 | "outputs": [],
523 | "source": [
524 | "dirname = \"/home/sf/data/WESAD/GRU/\"\n",
525 | "pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n",
526 | "print (\"Results for this run have been saved at\" , dirname, \".\")\n",
527 | "\n",
528 | "now = datetime.datetime.now()\n",
529 | "filename = list((str(now.year),\"-\",str(now.month),\"-\",str(now.day),\"|\",str(now.hour),\"-\",str(now.minute)))\n",
530 | "filename = ''.join(filename)\n",
531 | "\n",
532 | "#Save the dictionary containing the params and the results.\n",
533 | "pkl.dump(gru_dict,open(dirname + filename + \".pkl\",mode='wb'))"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": null,
539 | "metadata": {},
540 | "outputs": [],
541 | "source": [
542 | "dirname+filename+'.pkl'"
543 | ]
544 | }
545 | ],
546 | "metadata": {
547 | "kernelspec": {
548 | "display_name": "Python 3",
549 | "language": "python",
550 | "name": "python3"
551 | },
552 | "language_info": {
553 | "codemirror_mode": {
554 | "name": "ipython",
555 | "version": 3
556 | },
557 | "file_extension": ".py",
558 | "mimetype": "text/x-python",
559 | "name": "python",
560 | "nbconvert_exporter": "python",
561 | "pygments_lexer": "ipython3",
562 | "version": "3.7.3"
563 | },
564 | "latex_envs": {
565 | "LaTeX_envs_menu_present": true,
566 | "autoclose": false,
567 | "autocomplete": true,
568 | "bibliofile": "biblio.bib",
569 | "cite_by": "apalike",
570 | "current_citInitial": 1,
571 | "eqLabelWithNumbers": true,
572 | "eqNumInitial": 1,
573 | "hotkeys": {
574 | "equation": "Ctrl-E",
575 | "itemize": "Ctrl-I"
576 | },
577 | "labels_anchors": false,
578 | "latex_user_defs": false,
579 | "report_style_numbering": false,
580 | "user_envs_cfg": false
581 | }
582 | },
583 | "nbformat": 4,
584 | "nbformat_minor": 2
585 | }
586 |
--------------------------------------------------------------------------------
/WESAD/WESAD_Get_Single.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.utils import shuffle
4 |
5 | get_id = int(input())  # subject ID to extract, read from stdin
6 | df = pd.read_csv("../allchest.csv")  # combined chest-sensor data for all subjects
7 | df = df[df['ID'] == get_id]  # keep only the rows for the requested subject
8 | df.reset_index(inplace=True, drop=True)
9 | df = shuffle(df)  # shuffle rows before splitting
10 | df.reset_index(inplace=True, drop=True)
11 |
12 | train = df.iloc[:int(df.shape[0] * 0.2), :]  # first 20% of the shuffled rows
13 | test = df.iloc[int(df.shape[0] * 0.2):, :]   # remaining 80%
14 | del df
15 |
16 | train.reset_index(inplace=True, drop=True)
17 | test.reset_index(inplace=True, drop=True)
18 | train = train[['label', 'ID', 'chestACCx', 'chestACCy', 'chestACCz', 'chestECG', 'chestEMG', 'chestEDA', 'chestTemp', 'chestResp']]
19 | test = test[['label', 'ID', 'chestACCx', 'chestACCy', 'chestACCz', 'chestECG', 'chestEMG', 'chestEDA', 'chestTemp', 'chestResp']]
20 |
21 | train = np.array(train, dtype=np.float32)
22 | test = np.array(test, dtype=np.float32)
23 |
24 | np.save("train.npy", train)
25 | np.save("test.npy", test)
26 |
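27 | # --- Usage notes (illustrative sketch, not part of the processing above) ---
28 | # The script assumes ../allchest.csv, a combined chest-sensor table with one row per
29 | # sample and the columns selected above (label, ID, chestACCx/y/z, chestECG, chestEMG,
30 | # chestEDA, chestTemp, chestResp). It reads a subject ID from stdin, e.g. subject 2:
31 | #
32 | #     echo 2 | python WESAD_Get_Single.py
33 | #
34 | # and writes train.npy / test.npy for that subject, which can be reloaded later, e.g.:
35 | #
36 | #     train = np.load("train.npy")   # shape (n_rows, 10): label, ID, 8 chest channels
37 | #     test = np.load("test.npy")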
--------------------------------------------------------------------------------