├── AP_accel_max.pkl ├── LSTM_scaler.pkl ├── V_accel_max.pkl ├── Alcantara_ISB_World_Athletics_Manuscript.pdf ├── data └── Sub_Info_one_sub.csv ├── .gitignore ├── README.md ├── Train_LSTM.ipynb ├── pre_processing.py ├── LICENSE ├── Test_LSTM.ipynb └── LSTM_Example.ipynb /AP_accel_max.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alcantarar/Recurrent_GRF_Prediction/HEAD/AP_accel_max.pkl -------------------------------------------------------------------------------- /LSTM_scaler.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alcantarar/Recurrent_GRF_Prediction/HEAD/LSTM_scaler.pkl -------------------------------------------------------------------------------- /V_accel_max.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alcantarar/Recurrent_GRF_Prediction/HEAD/V_accel_max.pkl -------------------------------------------------------------------------------- /Alcantara_ISB_World_Athletics_Manuscript.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alcantarar/Recurrent_GRF_Prediction/HEAD/Alcantara_ISB_World_Athletics_Manuscript.pdf -------------------------------------------------------------------------------- /data/Sub_Info_one_sub.csv: -------------------------------------------------------------------------------- 1 | Sub,Condition,Sex,Age,Height,Mass,RFS,MFS,FFS,Speed,Slope,StepFreq 2 | 2,NA,m,NA,173,76.8,0,100,0,2.5,10,NA 3 | 2,NA,m,NA,173,76.8,100,0,0,4.17,5,NA 4 | 2,NA,m,NA,173,76.8,100,0,0,2.5,0,NA 5 | 2,NA,m,NA,173,76.8,100,0,0,3.33,0,NA 6 | 2,NA,m,NA,173,76.8,100,0,0,3.33,0,NA 7 | 2,NA,m,NA,173,76.8,100,0,0,3.33,0,NA 8 | 2,NA,m,NA,173,76.8,100,0,0,4.17,0,NA 9 | 2,NA,m,NA,173,76.8,100,0,0,2.5,-5,NA 10 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-5,NA 11 | 
2,NA,m,NA,173,76.8,100,0,0,3.33,-5,NA 12 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-5,NA 13 | 2,NA,m,NA,173,76.8,0,100,0,3.33,10,NA 14 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-5,NA 15 | 2,NA,m,NA,173,76.8,100,0,0,4.17,-5,NA 16 | 2,NA,m,NA,173,76.8,100,0,0,2.5,-10,NA 17 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-10,NA 18 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-10,NA 19 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-10,NA 20 | 2,NA,m,NA,173,76.8,100,0,0,4.17,-10,NA 21 | 2,NA,m,NA,173,76.8,0,86.66666667,13.33333333,3.33,10,NA 22 | 2,NA,m,NA,173,76.8,40,60,0,3.33,10,NA 23 | 2,NA,m,NA,173,76.8,0,100,0,4.17,10,NA 24 | 2,NA,m,NA,173,76.8,100,0,0,2.5,5,NA 25 | 2,NA,m,NA,173,76.8,100,0,0,3.33,5,NA 26 | 2,NA,m,NA,173,76.8,100,0,0,3.33,5,NA 27 | 2,NA,m,NA,173,76.8,100,0,0,3.33,5,NA 28 | 2,NA,m,NA,173,76.8,0,86.66666667,13.33333333,3.33,10,NA 29 | 2,NA,m,NA,173,76.8,100,0,0,3.33,5,NA 30 | 2,NA,m,NA,173,76.8,100,0,0,3.33,0,NA 31 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-5,NA 32 | 2,NA,m,NA,173,76.8,100,0,0,3.33,-10,NA 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Matlab ### 2 | # Windows default autosave extension 3 | *.asv 4 | 5 | # OSX / *nix default autosave extension 6 | *.m~ 7 | 8 | # Compiled MEX binaries (all platforms) 9 | *.mex* 10 | 11 | # Packaged app and toolbox files 12 | *.mlappinstall 13 | *.mltbx 14 | 15 | # Generated helpsearch folders 16 | helpsearch*/ 17 | 18 | # Simulink code generation folders 19 | slprj/ 20 | sccprj/ 21 | 22 | # Matlab code generation folders 23 | codegen/ 24 | 25 | # Simulink autosave extension 26 | *.autosave 27 | 28 | # Octave session info 29 | octave-workspace 30 | 31 | ### R ### 32 | # History files 33 | .Rhistory 34 | .Rapp.history 35 | 36 | # Session Data files 37 | .RData 38 | 39 | # Example code in package build process 40 | *-Ex.R 41 | 42 | # Output files from R CMD build 43 | /*.tar.gz 44 | 45 | # Output files from R CMD check 46 | 
/*.Rcheck/ 47 | 48 | # RStudio files 49 | .Rproj.user/ 50 | 51 | # produced vignettes 52 | vignettes/*.html 53 | vignettes/*.pdf 54 | 55 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 56 | .httr-oauth 57 | 58 | # knitr and R markdown default cache directories 59 | /*_cache/ 60 | /cache/ 61 | 62 | # Temporary files created by R markdown 63 | *.utf8.md 64 | *.knit.md 65 | 66 | # Incomplete Rplots 67 | Rplots.pdf 68 | 69 | ### R.Bookdown Stack ### 70 | # R package: bookdown caching files 71 | /*_files/ 72 | 73 | ### MS Word ### 74 | 75 | # Word temporary 76 | ~$*.doc* 77 | 78 | ### Other things to ignore ### 79 | .DS_Store 80 | 81 | # End of https://www.gitignore.io/api/r,matlab 82 | 83 | # PyCharm env files 84 | *.idea 85 | 86 | # Jupyter notebook checkpoints 87 | *.ipynb_checkpoints 88 | 89 | # Python cache 90 | __pycache__ 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Predicting continuous ground reaction forces from accelerometers during uphill and downhill running: A recurrent neural network solution 2 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=alcantarar.Recurrent_GRF_Prediction) 3 | 4 | This repository supports [Predicting continuous ground reaction 5 | forces from accelerometers during uphill and downhill running: A recurrent neural network 6 | solution](https://peerj.com/articles/12752/). 7 | 8 | The final models and data supporting the published manuscript are archived [here](https://zenodo.org/record/5224624). 9 | 10 | ## Contents 11 | 12 | `Train_LSTM.ipynb` is a notebook that generates the model from the archived data. 13 | 14 | `Test_LSTM.ipynb` is a notebook that shows you how to use the trained LSTM to predict GRFs from your own accelerometer data. 
15 | 16 | `LSTM_Example.ipynb` is a notebook that provides a tutorial on how a Long Short-Term Memory Network (LSTM) can be used to 17 | predict ground reaction force (GRF) data from accelerometer data during running. 18 | 19 | `pre_processing.py` contains helper functions used in `LSTM_Example.ipynb` and `Test_LSTM.ipynb`. 20 | 21 | `data/` contains example accelerometer data, GRF data, condition/demographic data, and an LSTM model file. It supports `Test_LSTM.ipynb` and `LSTM_Example.ipynb`. 22 | 23 | If you train an LSTM model using [Google Colab](https://colab.research.google.com/) (recommended), make sure 24 | you select a GPU runtime type. You will need to adjust the path to `data/` depending on where the files are uploaded in 25 | Google Colab. 26 | 27 | ## Questions? 28 | [Open an issue](https://github.com/alcantarar/Recurrent_GRF_Prediction/issues/new) if you have a question or if 29 | something is broken. You can also email me at the address listed in the associated publication. 
30 | -------------------------------------------------------------------------------- /Train_LSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Train_LSTM.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "outputs": [], 21 | "source": [ 22 | "# Import packages\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "import numpy as np\n", 25 | "from tensorflow import keras\n", 26 | "import numpy as np" 27 | ], 28 | "metadata": { 29 | "collapsed": false, 30 | "pycharm": { 31 | "name": "#%%\n" 32 | } 33 | } 34 | }, 35 | { 36 | "cell_type": "code", 37 | "metadata": { 38 | "id": "z3TG85RZGMYv" 39 | }, 40 | "source": [ 41 | "# Load Data\n", 42 | "data_filename = '/PATH/TO/all_subs_data_w_footstrike.pkl'\n", 43 | "\n", 44 | "df = np.load(data_filename, allow_pickle=True)\n", 45 | "feats = df['train_feats'] # (trials, frames per trial, features)\n", 46 | "y = df['train_y'] # (trials, frames per trial)\n", 47 | "sub_info = df['train_Sub_Info'] # (trials, features)\n", 48 | "\n", 49 | "# Test LSTM on Representative Subject (#14 when sorted from lowest -> highest RMSE in paper). 
\n", 50 | "# Expect Validation MSE ~ 0.03 BW (RMSE = 0.17 BW).\n", 51 | "sub_num = 2\n", 52 | "\n", 53 | "test_X = feats[sub_info['Sub'] == sub_num,:,:]\n", 54 | "test_y = y[sub_info['Sub'] == sub_num,:]\n", 55 | "\n", 56 | "train_X = feats[sub_info['Sub'] != sub_num,:,:] \n", 57 | "train_y = y[sub_info['Sub'] != sub_num,:]" 58 | ], 59 | "execution_count": 3, 60 | "outputs": [] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "id": "TsAW5LbLGTBk" 66 | }, 67 | "source": [ 68 | "# Make sure GPU runtime is activated before training LSTM!\n", 69 | "\n", 70 | "# Build Model\n", 71 | "def build_model(lstm_size, lstm_act, dropout_rate, dense_act, lr=0.001, loss='mean_squared_error'):\n", 72 | "\n", 73 | " #accelerometer data lstm model\n", 74 | " model_inputs = keras.Input(shape=(None,train_X.shape[2]))\n", 75 | " model_features = keras.layers.Dropout(0.2, seed=541)(model_inputs)\n", 76 | " model_features = keras.layers.Bidirectional(keras.layers.LSTM(lstm_size, activation=lstm_act, return_sequences=True), merge_mode='ave')(model_features)\n", 77 | " model_features = keras.layers.Dropout(dropout_rate, seed=541)(model_features)\n", 78 | " model_features = keras.layers.Dense(128, activation=dense_act)(model_features)\n", 79 | " model_features = keras.layers.Dense(384, activation=dense_act)(model_features)\n", 80 | " model_features = keras.layers.Dense(320, activation=dense_act)(model_features)\n", 81 | " model_outputs = keras.layers.Dense(1, activation='linear')(model_features)\n", 82 | "\n", 83 | " model_out = keras.Model(inputs=model_inputs, outputs=model_outputs, name='LSTM')\n", 84 | " # define optimizer algorithm and learning rate\n", 85 | " opt = keras.optimizers.Adam(learning_rate =lr)\n", 86 | " # compile model and define loss function\n", 87 | " model_out.compile(optimizer=opt, loss=loss)\n", 88 | "\n", 89 | " return model_out\n", 90 | "\n", 91 | "model = build_model(\n", 92 | " lstm_size=512,\n", 93 | " lstm_act='tanh',\n", 94 | " 
dropout_rate=0.4,\n", 95 | " dense_act='relu',\n", 96 | " lr=0.001\n", 97 | " )\n", 98 | "\n", 99 | "# Plot Model Architecture\n", 100 | "# keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)" 101 | ], 102 | "execution_count": 4, 103 | "outputs": [] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "id": "eqsEfcKhdde2" 109 | }, 110 | "source": [ 111 | "# Train Model\n", 112 | "\n", 113 | "# Define Early Stopping and Checkpoint Callbacks\n", 114 | "model_filename = '/PATH/TO/MODEL.h5'\n", 115 | "\n", 116 | "# Early Stopping\n", 117 | "es = keras.callbacks.EarlyStopping(monitor='val_loss', \n", 118 | " mode='min', \n", 119 | " verbose=0, \n", 120 | " patience=30, \n", 121 | " min_delta=0.001, \n", 122 | " restore_best_weights=True\n", 123 | " )\n", 124 | "# Model Checkpoint\n", 125 | "mc = keras.callbacks.ModelCheckpoint(\n", 126 | " model_filename,\n", 127 | " monitor='val_loss', \n", 128 | " mode='min', \n", 129 | " verbose=1, \n", 130 | " save_best_only=True, \n", 131 | " save_weights_only=False\n", 132 | " )\n", 133 | "\n", 134 | "# Fit Model\n", 135 | "history_accel = model.fit(\n", 136 | " train_X, \n", 137 | " train_y, \n", 138 | " epochs=1000,\n", 139 | " validation_data=(test_X, test_y), \n", 140 | " verbose=1,\n", 141 | " batch_size=32, \n", 142 | " callbacks=[es, mc]\n", 143 | " )\n", 144 | "\n", 145 | "# Plot Train/Validation Loss across epochs\n", 146 | "plt.plot(history_accel.history['loss'], label = 'mse_train')\n", 147 | "plt.plot(history_accel.history['val_loss'], label = 'mse_validation')\n", 148 | "plt.legend()\n", 149 | "plt.show()\n", 150 | "\n", 151 | "keras.backend.clear_session()" 152 | ], 153 | "execution_count": null, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "colab": { 160 | "base_uri": "https://localhost:8080/" 161 | }, 162 | "id": "fghi2QqPXXw7", 163 | "outputId": "de9b8181-c943-4e70-8264-2e3fa4b02ab6" 164 | }, 165 | "source": [ 166 | "# Load Trained 
Model and Calculate RMSE for GRF Waveform\n", 167 | "saved_model = keras.models.load_model(model_filename)\n", 168 | "\n", 169 | "pred_y = saved_model.predict(test_X)\n", 170 | "test_sub_info = sub_info.loc[sub_info['Sub'] == sub_num,:]\n", 171 | "\n", 172 | "test_y = np.squeeze(test_y)\n", 173 | "pred_y = np.squeeze(pred_y)\n", 174 | "\n", 175 | "rmse = []\n", 176 | "trim = 100 # Number of frames to ignore at edge of trial due to lack of prior data for LSTM.\n", 177 | "\n", 178 | "for trial in range(test_sub_info.shape[0]):\n", 179 | " # Calculate RMSE\n", 180 | " trial_rmse = np.sqrt(np.mean((pred_y[trial, trim:-trim] - test_y[trial, trim:-trim])**2))\n", 181 | " trial_rmse = np.round(trial_rmse,3)\n", 182 | " # Calculate rRMSE\n", 183 | " trial_rrmse = trial_rmse / np.mean((\n", 184 | " np.max(pred_y[trial, trim:-trim]) - np.min(pred_y[trial, trim:-trim]),\n", 185 | " np.max(test_y[trial, trim:-trim]) - np.min(test_y[trial, trim:-trim])\n", 186 | " ))*100\n", 187 | " trial_rrmse = np.round(trial_rrmse, 2)\n", 188 | "\n", 189 | " rmse.append(trial_rmse)\n", 190 | "\n", 191 | "# Expect RMSE of approximately 0.17 ± 0.07 BW for Representative Subject (#14 in paper)\n", 192 | "print('MEAN:', np.round(np.mean(rmse),2))\n", 193 | "print('SD:', np.round(np.std(rmse),2))" 194 | ], 195 | "execution_count": 9, 196 | "outputs": [ 197 | { 198 | "output_type": "stream", 199 | "text": [ 200 | "MEAN: 0.16\n", 201 | "SD: 0.05\n" 202 | ], 203 | "name": "stdout" 204 | } 205 | ] 206 | } 207 | ] 208 | } -------------------------------------------------------------------------------- /pre_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.lib.stride_tricks import as_strided as ast 3 | from scipy.signal import butter, filtfilt 4 | 5 | def buttfilt(x, fs, fc, order, axis=1): 6 | """ 7 | BUTTFILT applies a lowpass butterworth filter with a correction factor described in 8 | Research Methods in 
Biomechanics (2ed) and explained further here: 9 | https://github.com/alcantarar/dryft/issues/22#issuecomment-557771825 10 | :param x: signal to filter 11 | :param fs: sampling frequency (Hz) 12 | :param fc: lowpass cutoff frequency (Hz). If tuple, will do bandpass (low, high). 13 | :param order: final desired filter order. Must be even number. 14 | :param axis: axis to filter along. Default is `axis=1`. 15 | :return: 16 | """ 17 | n_pass = 2 # two passes (one forward, one backward to be zero-lag) 18 | if (order % 2) != 0: 19 | raise ValueError('order must be even integer') 20 | else: 21 | order = order / 2 22 | fn = fs / 2 23 | # Correction factor per Research Methods in Biomechanics (2e) pg 288 24 | c = (2 ** (1 / n_pass) - 1) ** (1 / (2 * order)) 25 | if type(fc) is tuple: 26 | # bandpass filter: 27 | wn_low = (np.tan(np.pi*fc[0]/fs))/c # Apply correction to adjusted cutoff freq (lower boundary) 28 | wn_up = (np.tan(np.pi*fc[1]/fs))/c # Apply correction to adjusted cutoff freq (upper boundary) 29 | fc_corrected_low = np.arctan(wn_low)*fs/np.pi # Hz 30 | fc_corrected_up = np.arctan(wn_up)*fs/np.pi # Hz 31 | b, a = butter(order, [fc_corrected_low/fn, fc_corrected_up/fn], btype='band') 32 | 33 | x_filt = filtfilt(b, a, x, axis=axis) 34 | 35 | else: 36 | # lowpass filter: 37 | wn = (np.tan(np.pi*fc/fs))/c # Apply correction to adjusted cutoff freq 38 | fc_corrected = np.arctan(wn)*fs/np.pi # Hz 39 | b, a = butter(order, fc_corrected/fn) 40 | 41 | x_filt = filtfilt(b, a, x, axis=axis) 42 | 43 | return x_filt 44 | 45 | 46 | def chunk_data(data, window_size, overlap_size=0, flatten_inside_window=True): 47 | """ 48 | CHUNK_DATA was made by Matthew Johnson (github username mattjj). Gist URL (accessed 8/2020): 49 | https://gist.github.com/mattjj/5213172. It is clever. I use this function within window_data_centered(). 
50 | :param data: data to be windowed 51 | :param window_size: size of window 52 | :param overlap_size: size of overlap between windows 53 | :param flatten_inside_window: flatten ndim data inside window 54 | :return: ndarray with shape: (trials, number of windows, window size). 55 | """ 56 | assert data.ndim == 1 or data.ndim == 2 57 | if data.ndim == 1: 58 | data = data.reshape((-1, 1)) 59 | 60 | # get the number of overlapping windows that fit into the data 61 | num_windows = (data.shape[0] - window_size) // (window_size - overlap_size) + 1 62 | overhang = data.shape[0] - (num_windows * window_size - (num_windows - 1) * overlap_size) 63 | 64 | # if there's overhang, need an extra window and a zero pad on the data 65 | # (numpy 1.7 has a nice pad function I'm not using here) 66 | if overhang != 0: 67 | num_windows += 1 68 | newdata = np.zeros((num_windows * window_size - (num_windows - 1) * overlap_size, data.shape[1])) 69 | newdata[:data.shape[0]] = data 70 | data = newdata 71 | 72 | sz = data.dtype.itemsize 73 | ret = ast( 74 | data, 75 | shape=(num_windows, window_size * data.shape[1]), 76 | strides=((window_size - overlap_size) * data.shape[1] * sz, sz) 77 | ) 78 | 79 | if flatten_inside_window: 80 | return ret 81 | else: 82 | return ret.reshape((num_windows, -1, data.shape[1])) 83 | 84 | 85 | def window_data_centered(data, window_size, verbose=True): 86 | """ 87 | WINDOW_DATA_CENTERED uses chunk_data() to create windows centered about a given frame. This is accomplished 88 | by padding the trial with data at start/end. Input of 5 trials, each 1000 frames long (5, 1000) and window_size of 6 89 | would result in output shape of (5, 1000, 6) because overlap between windows is window_size-1. X at time points t-3, 90 | t-2, t-1, t, t+1, and t+2 create the 6-frame window that corresponds to Y at time t. 91 | :param data: data (shape (trials, frames)) to be windowed using chunk_data(). 92 | :param window_size: size of windows. Must be even number. 
93 | :param verbose: If True, prints input/output shapes. Default True. 94 | :return: ndarray of shape (trials, number of windows, window size) 95 | """ 96 | if (window_size % 2) != 0: 97 | raise ValueError('window_size must be divisible by 2') 98 | else: 99 | pad = int(window_size/2) 100 | overlap = window_size - 1 101 | ds = data.shape 102 | # pad with nearest value at edges 103 | data = np.pad(data, [(0, 0), (pad, pad-1)], 'edge') 104 | 105 | # apply chunk_data() 106 | data = np.apply_along_axis(chunk_data, 1, data, window_size, overlap, True) 107 | if verbose: 108 | print('input shape:', ds) 109 | print('output shape:', data.shape) 110 | 111 | return data 112 | 113 | 114 | def signal_features(data, fs): 115 | """ 116 | SIGNAL_FEATURES used in generate_features() and calculates the following features for each window: 117 | - mean (np.mean) 118 | - standard deviation (np.std) 119 | - range (np.ptp) 120 | - Average 1st Derivative (np.gradient) #commented out 121 | - Average 2nd Derivative (np.gradient) #commented out 122 | - Average 1st Integral (np.cumtrapz) #commented out 123 | - Average 2nd Integral (np.cumtrapz) #commented out 124 | 125 | :param data: ndarray with shape (trials, number of windows, window size) 126 | :return: ndarray with shape (trials, number of windows, number of features) 127 | """ 128 | mean = np.mean(data, axis=2) 129 | std = np.std(data, axis=2) 130 | rg = np.ptp(data, axis=2) 131 | # diff = np.mean(np.gradient(data, axis=2), axis=2) 132 | # diffdiff = np.mean(np.gradient(np.gradient(data, axis=2), axis=2), axis=2) 133 | # integral = np.mean(cumtrapz(data, np.linspace(0, data.shape[2]/fs, data.shape[2])), axis=2) 134 | # temp_integral = cumtrapz(data, np.linspace(0, data.shape[2]/fs, data.shape[2])) 135 | # integralintegral = np.mean(cumtrapz(temp_integral, 136 | # np.linspace(0,(data.shape[2]-1)/fs, (data.shape[2]-1)) 137 | # ), 138 | # axis=2) 139 | features = np.concatenate((np.expand_dims(mean, 2), 140 | np.expand_dims(std, 2), 141 | 
np.expand_dims(rg, 2), 142 | # np.expand_dims(diff, 2), 143 | # np.expand_dims(diffdiff, 2), 144 | # np.expand_dims(integral, 2), 145 | # np.expand_dims(integralintegral, 2) 146 | ), axis=2) 147 | 148 | return features 149 | 150 | 151 | def subject_info_features(subinfo, test_sub_num, data_shape): 152 | """ 153 | SUBJECT_INFO_FEATURES used in generate_features() and extracts the following features from subinfo df: 154 | - Subject height 155 | - Subject mass 156 | - Running speed 157 | - Treadmill slope 158 | - % of steps Rearfoot strike (RF) 159 | - % of steps Midfoot Strike (MF) 160 | - % of steps Forefoot Strike (FF) 161 | :param subinfo: Pandas dataframe containing information about subject and trial conditions. Has the following column 162 | headers: 'Sub', 'Shoe', 'Condition', 'Sex', 'Age', 'Height', 'Mass'. 163 | :param test_sub_num: Integer of subject to be used to test model when doing Leave-One-Subject-Out Cross Validation. 164 | To ONLY calculate features for test_sub_num, make it negative. 165 | :param data_shape: shape of windowed data to match feature shapes 166 | :return: concatenated features with shape of data_shape. 
167 | """ 168 | if test_sub_num >= 0: 169 | ht = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['Height']].to_numpy(), data_shape) 170 | ms = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['Mass']].to_numpy(), data_shape) 171 | sp = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['Speed']].to_numpy(), data_shape) 172 | sl = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['Slope']].to_numpy(), data_shape) 173 | rf = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['RFS']].to_numpy(), data_shape) 174 | mf = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['MFS']].to_numpy(), data_shape) 175 | ff = np.tile(subinfo.loc[subinfo['Sub'] != test_sub_num, ['FFS']].to_numpy(), data_shape) 176 | else: 177 | test_sub_num = np.abs(test_sub_num) 178 | ht = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['Height']].to_numpy(), data_shape) 179 | ms = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['Mass']].to_numpy(), data_shape) 180 | sp = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['Speed']].to_numpy(), data_shape) 181 | sl = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['Slope']].to_numpy(), data_shape) 182 | rf = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['RFS']].to_numpy(), data_shape) 183 | mf = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['MFS']].to_numpy(), data_shape) 184 | ff = np.tile(subinfo.loc[subinfo['Sub'] == test_sub_num, ['FFS']].to_numpy(), data_shape) 185 | features = np.concatenate((np.expand_dims(ht, 2), 186 | np.expand_dims(ms, 2), 187 | np.expand_dims(sp, 2), 188 | np.expand_dims(sl, 2), 189 | np.expand_dims(rf, 2), 190 | np.expand_dims(mf, 2), 191 | np.expand_dims(ff, 2)), axis=2) 192 | 193 | return features 194 | 195 | 196 | def generate_features(data, fs, subinfo, test_sub_num, include_sub_info_feats=True): 197 | """ 198 | 199 | :param data: Pandas dataframe containing windowed data (see window_data_centered()). 200 | :param fs: Sampling frequency of [data] in Hz. 
201 | :param subinfo: Pandas dataframe containing information about subject and trial conditions. Has the following column 202 | headers: 'Sub', 'Shoe', 'Condition', 'Sex', 'Age', 'Height', 'Mass'. 203 | :param test_sub_num: Integer of subject to be used to test model when doing Leave-One-Subject-Out Cross Validation. 204 | :param include_sub_info_feats: Boolean to concatenate signal and subject info/condition features. Default True. If 205 | False, return only features calculated from `data`. 206 | :return: concatenated features with shape of data_shape. 207 | """ 208 | sig_feats = signal_features(data=data, fs=fs) 209 | if include_sub_info_feats: 210 | sub_cond_feats = subject_info_features(subinfo, test_sub_num, data.shape[1]) 211 | features = np.concatenate((sig_feats, sub_cond_feats), axis=2) 212 | else: 213 | features = sig_feats 214 | 215 | return features 216 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Test_LSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Test_LSTM.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "colab": { 24 | "base_uri": "https://localhost:8080/" 25 | }, 26 | "id": "GfTtJnxYyFvd", 27 | "outputId": "6bc828f2-84c8-474c-a0d5-fd408ab40f20" 28 | }, 29 | "outputs": [ 30 | { 31 | "output_type": "stream", 32 | "name": "stdout", 33 | "text": [ 34 | "Mounted at /content/gdrive\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "# Step 0: Download repository and unzipp folder in google drive. 
\n", 40 | "\n", 41 | "# Mount Google Drive to access data\n", 42 | "from google.colab import drive\n", 43 | "drive.mount('/content/gdrive')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "source": [ 49 | "import matplotlib.pyplot as plt\n", 50 | "import numpy as np\n", 51 | "import pandas as pd\n", 52 | "from tensorflow import keras\n", 53 | "from pickle import load\n", 54 | "# Import helper functions from pre_processing.py:\n", 55 | "import sys\n", 56 | "sys.path.append('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/')\n", 57 | "from pre_processing import *" 58 | ], 59 | "metadata": { 60 | "id": "rmPC56BayHiQ" 61 | }, 62 | "execution_count": 2, 63 | "outputs": [] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "source": [ 68 | "# import data\n", 69 | "df = pd.read_csv('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Sample_test_data.csv')\n", 70 | "\n", 71 | "# Notes about test data:\n", 72 | "# NO FOOTSTRIKE DATA\n", 73 | "# X is vertical acceleration (including gravitation)\n", 74 | "# Z is anteroposterior (sign is flipped though)\n", 75 | "# Velocity is around 3 m/s\n", 76 | "# Height is 175 cm\n", 77 | "# Mass is 60 kg\n", 78 | "# Sampling Frequency is 512 Hz" 79 | ], 80 | "metadata": { 81 | "id": "-M_Gf98MyRxc" 82 | }, 83 | "execution_count": 3, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "source": [ 89 | "# reorient to match training data coordinate system\n", 90 | "Sacrum_V = df['Accel_LN_X']*-1\n", 91 | "Sacrum_AP = df['Accel_LN_Z']*-1\n", 92 | "\n", 93 | "plt.plot(Sacrum_AP)\n", 94 | "plt.plot(Sacrum_V)" 95 | ], 96 | "metadata": { 97 | "colab": { 98 | "base_uri": "https://localhost:8080/", 99 | "height": 282 100 | }, 101 | "id": "8rb4BtP4zC1w", 102 | "outputId": "129f1044-2451-46f2-9a5b-0eb1d5b01d5a" 103 | }, 104 | "execution_count": 4, 105 | "outputs": [ 106 | { 107 | "output_type": "execute_result", 108 | "data": { 109 | "text/plain": [ 110 | "[]" 111 | ] 112 | }, 113 | "metadata": {}, 114 | 
"execution_count": 4 115 | }, 116 | { 117 | "output_type": "display_data", 118 | "data": { 119 | "image/png": "\n", 120 | "text/plain": [ 121 | "
" 122 | ] 123 | }, 124 | "metadata": { 125 | "needs_background": "light" 126 | } 127 | } 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "source": [ 133 | "# pre processing for acceleration signal\n", 134 | "# filter\n", 135 | "Sacrum_V_f = buttfilt(Sacrum_V, 512, 20, 4, axis=0)\n", 136 | "Sacrum_AP_f = buttfilt(Sacrum_AP, 512, 20, 4, axis=0)\n", 137 | "\n", 138 | "# set negative sacrum vertical acceleration to 0\n", 139 | "Sacrum_V_f[Sacrum_V_f < 0] = 0\n", 140 | "\n", 141 | "# load scaler for subject info and signals\n", 142 | "sub_scaler = load(open('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/LSTM_scaler.pkl', 'rb'))\n", 143 | "V_accel_max = load(open('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/V_accel_max.pkl', 'rb'))\n", 144 | "AP_accel_max = load(open('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/AP_accel_max.pkl', 'rb'))\n", 145 | "\n", 146 | "\n", 147 | "# Split trial into 2 parts. Window_data_centered() expects 2D array\n", 148 | "X_V = np.vstack((Sacrum_V_f[1000:1500], Sacrum_V_f[1500:2000])) # each trial is a row\n", 149 | "X_AP = np.vstack((Sacrum_AP_f[1000:1500], Sacrum_AP_f[1500:2000])) # each trial is a row\n", 150 | "\n", 151 | "# Insert your subject and condition info here (2 trials in this example):\n", 152 | "sub_info = {'Height': [175, 175], 'Mass': [60,60], 'Speed': [3.0, 3.0], 'Slope': [0, 0]}\n", 153 | "sub_info = pd.DataFrame(data=sub_info)\n", 154 | "sub_info_scaled = sub_scaler.transform(sub_info)\n", 155 | "\n", 156 | "# break signal into 6-frame windows for LSTM\n", 157 | "window_size = 6\n", 158 | "X_V = window_data_centered(X_V, window_size, verbose = False)\n", 159 | "X_AP = window_data_centered(X_AP, window_size, verbose = False)\n", 160 | "\n", 161 | "# Generate features\n", 162 | "feats_V = signal_features(X_V, 512)\n", 163 | "feats_AP = signal_features(X_AP, 512)\n", 164 | "\n", 165 | "# scale input features based on model training data\n", 166 | "def max_scale_3d(feats, train_max):\n", 
167 | " feats_scaled = feats / train_max\n", 168 | " print('max used to normalize: ',train_max)\n", 169 | " return feats_scaled\n", 170 | "\n", 171 | "feats_V_scaled = max_scale_3d(feats_V, V_accel_max)\n", 172 | "feats_AP_scaled = max_scale_3d(feats_AP, AP_accel_max) \n", 173 | "\n", 174 | "# join AP & V accel signal features together\n", 175 | "signal_feats = np.concatenate((feats_V_scaled, feats_AP_scaled), axis = 2)\n", 176 | "print('Shape of signal features: ', signal_feats.shape)\n", 177 | "\n", 178 | "# create 3d array of scaled subject/condition info features (height, mass, speed, slope) that matches shape of signal features\n", 179 | "sub_cond_feats = np.ones((2, 500, 4))*sub_info_scaled[0,:]\n", 180 | "print('Shape of subject/condition features: ', sub_cond_feats.shape)\n", 181 | "\n", 182 | "#concatenate subject and signal features \n", 183 | "input_features = np.concatenate((signal_feats, sub_cond_feats), axis = 2)\n", 184 | "print('Input feature shape: ', input_features.shape)" 185 | ], 186 | "metadata": { 187 | "colab": { 188 | "base_uri": "https://localhost:8080/" 189 | }, 190 | "id": "ph9MCFv44L9Y", 191 | "outputId": "f53b3390-f675-481e-9c9b-d14f95149c7f" 192 | }, 193 | "execution_count": 5, 194 | "outputs": [ 195 | { 196 | "output_type": "stream", 197 | "name": "stdout", 198 | "text": [ 199 | "max used to normalize: [22.61166667 11.90211559 27.11 ]\n", 200 | "max used to normalize: [ 5.42166667 12.39942931 31.06 ]\n", 201 | "Shape of signal features: (2, 500, 6)\n", 202 | "Shape of subject/condition features: (2, 500, 4)\n", 203 | "Input feature shape: (2, 500, 10)\n" 204 | ] 205 | } 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "source": [ 211 | "# NO FOOTSTRIKE DATA AVAILABLE, so use all_subs_model_wo_footstrike.h5\n", 212 | "#Download model file from https://zenodo.org/record/5213939 and save in /data/ folder.\n", 213 | "saved_model = keras.models.load_model('/content/gdrive/My 
Drive/Recurrent_GRF_Prediction-main/data/all_subs_model_wo_footstrike.h5')" 214 | ], 215 | "metadata": { 216 | "id": "zxg6uibsCCfL" 217 | }, 218 | "execution_count": 6, 219 | "outputs": [] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "source": [ 224 | "# Make prediction \n", 225 | "GRF = saved_model.predict(input_features)" 226 | ], 227 | "metadata": { 228 | "id": "bTiXd_L7EztL" 229 | }, 230 | "execution_count": 7, 231 | "outputs": [] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "source": [ 236 | "# Plot prediction for 1st trial\n", 237 | "plt.plot(np.squeeze(GRF[0,:]), label='vGRF [Bodyweight]')\n", 238 | "plt.plot(Sacrum_V_f[1000:1500]/9.81, label='Vertical Sacrum Accel [g]')\n", 239 | "plt.legend(bbox_to_anchor = (0, 1.02, 1, 0.2), loc = 'lower left', mode='expand', ncol=2)\n", 240 | "plt.xlabel('Time [seconds]')\n" 241 | ], 242 | "metadata": { 243 | "colab": { 244 | "base_uri": "https://localhost:8080/", 245 | "height": 324 246 | }, 247 | "id": "_iCt4B2sfzyq", 248 | "outputId": "5ba03059-c9db-4ab8-f7d7-c3db42e63b81" 249 | }, 250 | "execution_count": 8, 251 | "outputs": [ 252 | { 253 | "output_type": "execute_result", 254 | "data": { 255 | "text/plain": [ 256 | "Text(0.5, 0, 'Time [seconds]')" 257 | ] 258 | }, 259 | "metadata": {}, 260 | "execution_count": 8 261 | }, 262 | { 263 | "output_type": "display_data", 264 | "data": { 265 | "image/png": "\n", 266 | "text/plain": [ 267 | "
" 268 | ] 269 | }, 270 | "metadata": { 271 | "needs_background": "light" 272 | } 273 | } 274 | ] 275 | } 276 | ] 277 | } -------------------------------------------------------------------------------- /LSTM_Example.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"LSTM_Example.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"BZrg_p7vXcZa"},"source":["## Load packages and functions"]},{"cell_type":"code","metadata":{"id":"GBF-IYr4h7VG","executionInfo":{"status":"ok","timestamp":1641319957206,"user_tz":480,"elapsed":6493,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}}},"source":["\n","import matplotlib.pyplot as plt\n","import numpy as np\n","np.random.seed(541)\n","import pandas as pd\n","from sklearn.preprocessing import MinMaxScaler\n","from tensorflow import keras, random\n","random.set_seed(541)\n","import numpy as np\n","import sys\n","import os\n"],"execution_count":1,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"12fTjeMWX80S"},"source":["## User Inputs\n","Mount Google Drive, define path to repository, and import functions\n"]},{"cell_type":"code","metadata":{"id":"371yqPsVX-uQ","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1641320013106,"user_tz":480,"elapsed":12806,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}},"outputId":"db564bf1-c2cf-4643-87a7-750b9836cbe7"},"source":["# Mounting Google Drive is faster than uploading files. 
Move repository into Google Drive and update paths\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","sys.path.append('/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/') # or '/content/gdrive/My Drive/PATH/TO/REPOSITORY/'\n","directory = '/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/'\n","# Import functions from pre_processing.py\n","from pre_processing import *\n","# Data directory\n","directory = '/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/' # or '/content/gdrive/My Drive/PATH/TO/REPOSITORY/'"],"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}]},{"cell_type":"markdown","metadata":{"id":"PZh5qk0CX3Ls"},"source":["## Load Data"]},{"cell_type":"code","metadata":{"id":"AKy4ut4011AK","executionInfo":{"status":"ok","timestamp":1641320018224,"user_tz":480,"elapsed":1818,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}}},"source":["fs = 500 # Sampling Frequency\n","\n","# Accelerometer Data\n","X = pd.read_csv(os.path.join(directory, 'Accelerometer_one_sub.csv'), header=None).values\n","# Subject and Condition Data\n","Sub_Info = pd.read_csv(os.path.join(directory, 'Sub_Info_one_sub.csv'))\n","# Normal GRF data\n","y = pd.read_csv(os.path.join(directory, 'GRF_one_sub.csv'), header=None).values\n","\n","## Train/Test Split ----\n","# Test set: trials with Slope == -5 or Slope == 5; the training set is meant to be all remaining trials.\n","# NOTE(review): (Slope != -5) | (Slope != 5) is True for every row, so the train masks below keep ALL trials,\n","# including the test trials (the printout further down reports 31 training vs. 
13 testing trials).\n","# Joining the two != conditions with & instead of | would exclude the test trials from training.\n","train_X = X[(Sub_Info['Slope'] != -5) | (Sub_Info['Slope'] != 5),:]\n","train_y = y[(Sub_Info['Slope'] != -5) | (Sub_Info['Slope'] != 5),:]\n","train_Sub_Info = Sub_Info.loc[(Sub_Info['Slope'] != -5) | (Sub_Info['Slope'] != 5),:]\n","train_Sub_Info.reset_index(drop=True, inplace=True)\n","\n","test_X = X[(Sub_Info['Slope'] == -5) | (Sub_Info['Slope'] == 5),:]\n","test_y = y[(Sub_Info['Slope'] == -5) | (Sub_Info['Slope'] == 5),:]\n","test_Sub_Info = Sub_Info.loc[(Sub_Info['Slope'] == -5) | 
(Sub_Info['Slope'] == 5),:]\n","test_Sub_Info.reset_index(drop=True, inplace=True)\n"],"execution_count":4,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PJp05JJoYjLO"},"source":["## Scale Features "]},{"cell_type":"code","metadata":{"id":"nUT285l7YnQH","executionInfo":{"status":"ok","timestamp":1641320020513,"user_tz":480,"elapsed":118,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}}},"source":["l = 2480 # Accelerometer data is concatenated\n","train_X_V = train_X[:,0:l]\n","train_X_AP = train_X[:,l:l*2]\n","test_X_V = test_X[:,0:l]\n","test_X_AP = test_X[:,l:l*2]\n","\n","# Sub_Info\n","train_Sub_Feats = train_Sub_Info[['Height', 'Mass', 'Speed', 'Slope']]\n","test_Sub_Feats = test_Sub_Info[['Height', 'Mass', 'Speed', 'Slope']]\n","scaler_sub = MinMaxScaler()\n","train_Sub_Feats = pd.DataFrame(scaler_sub.fit_transform(train_Sub_Feats))\n","test_Sub_Feats = pd.DataFrame(scaler_sub.transform(test_Sub_Feats))\n","train_Sub_Feats.columns = ['Height', 'Mass', 'Speed', 'Slope']\n","test_Sub_Feats.columns = ['Height', 'Mass', 'Speed', 'Slope']\n","\n","#Don't include foot strike percentages in min/max scaler because they're \n","#already scaled to each other. 
min/max scaler is per feature (column).\n","train_Sub_Feats[['RFS', 'MFS', 'FFS']] = train_Sub_Info[['RFS', 'MFS', 'FFS']].copy()/100\n","test_Sub_Feats[['RFS', 'MFS', 'FFS']] = test_Sub_Info[['RFS', 'MFS', 'FFS']].copy()/100\n","\n","# Copy over Subject ID\n","train_Sub_Feats['Sub'] = train_Sub_Info['Sub']\n","test_Sub_Feats['Sub'] = train_Sub_Info['Sub'][0]\n","sub_num = 2"],"execution_count":5,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7tgjoPOiYsHG"},"source":["## Window accelerometer data"]},{"cell_type":"code","metadata":{"id":"UfaLfETuYrVw","executionInfo":{"status":"ok","timestamp":1641320022735,"user_tz":480,"elapsed":112,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}}},"source":["\n","# N-frame windows centered on prediction frame. Pad w/ nearest value\n","window_size = 6 # 6 frames @ 500 Hz == 12 ms\n","\n","train_X_V = window_data_centered(train_X_V, window_size, verbose=False)\n","train_X_AP = window_data_centered(train_X_AP, window_size, verbose=False)\n","train_y = np.reshape(train_y, (train_y.shape[0],train_y.shape[1],1))\n","\n","test_X_V = window_data_centered(test_X_V, window_size, verbose=False)\n","test_X_AP = window_data_centered(test_X_AP, window_size, verbose=False)\n","test_y = np.reshape(test_y, (test_y.shape[0],test_y.shape[1],1))"],"execution_count":6,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"TNw8taZdY0p8"},"source":["## Generate Features"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"7gC7Gw8o78hW","executionInfo":{"status":"ok","timestamp":1641320024458,"user_tz":480,"elapsed":282,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}},"outputId":"68ee8039-0d02-4cef-9824-21465c92d930"},"source":["# Calculates mean, sd, range in 
each window of accelerometer data. generate_features() can also reshape sub_info to match dimensions.\n","train_feats_V = generate_features(train_X_V, fs, train_Sub_Feats, -sub_num, include_sub_info_feats=False) # Just signal feats\n","train_feats_AP = generate_features(train_X_AP, fs, train_Sub_Feats, -sub_num, include_sub_info_feats=True) # Signal and Sub_Info feats\n","sub_info_feats = train_feats_AP[-7:-1].copy()\n","test_feats_V = generate_features(test_X_V, fs, test_Sub_Feats, -sub_num, include_sub_info_feats=False) # Just signal feats\n","test_feats_AP = generate_features(test_X_AP, fs, test_Sub_Feats, -sub_num, include_sub_info_feats=True) # Signal and Sub_Info feats\n","\n","train_sub_info_feats = train_feats_AP[:,:,-7:].copy()\n","train_feats_AP = train_feats_AP[:,:,0:3].copy()\n","\n","test_sub_info_feats = test_feats_AP[:,:,-7:].copy()\n","test_feats_AP = test_feats_AP[:,:,0:3].copy()\n","\n","# #remove footstrike features if that's not your thing\n","# train_sub_info_feats = train_sub_info_feats[:,:,0:4]\n","# test_sub_info_feats = test_sub_info_feats[:,:,0:4]\n","\n","def max_scale_3d(feats, train_max):\n"," feats_scaled = feats / train_max\n"," return feats_scaled\n","\n","train_feats_V_scaled = max_scale_3d(train_feats_V, np.max(train_feats_V, axis=(0,1)))\n","train_feats_AP_scaled = max_scale_3d(train_feats_AP, np.max(train_feats_AP, axis=(0,1)))\n","\n","test_feats_V_scaled = max_scale_3d(test_feats_V, np.max(train_feats_V, axis=(0,1)))\n","test_feats_AP_scaled = max_scale_3d(test_feats_AP, np.max(train_feats_AP, axis=(0,1)))\n","\n","train_feats = np.concatenate((train_feats_V_scaled, train_feats_AP_scaled, train_sub_info_feats), axis = 2)\n","test_feats = np.concatenate((test_feats_V_scaled, test_feats_AP_scaled, test_sub_info_feats), axis = 2)\n","\n","print('TRAINING DATA: \\nNumber of trials: %d \\nNumber of overlapping windows: %d \\nNumber of features per window: %d' % train_feats.shape)\n","print('\\nTESTING DATA: \\nNumber of trials: %d 
\\nNumber of overlapping windows: %d \\nNumber of features per window: %d' % test_feats.shape)"],"execution_count":7,"outputs":[{"output_type":"stream","name":"stdout","text":["TRAINING DATA: \n","Number of trials: 31 \n","Number of overlapping windows: 2480 \n","Number of features per window: 13\n","\n","TESTING DATA: \n","Number of trials: 13 \n","Number of overlapping windows: 2480 \n","Number of features per window: 13\n"]}]},{"cell_type":"markdown","metadata":{"id":"oq__Sm7fZJz3"},"source":["## Construct Model"]},{"cell_type":"code","metadata":{"id":"_nLIjkaUr9O8","executionInfo":{"status":"ok","timestamp":1641320040825,"user_tz":480,"elapsed":3543,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}}},"source":["def build_model(lr=0.001, loss='mean_squared_error'):\n","\n"," #accelerometer data lstm model\n"," accel_inputs = keras.Input(shape=(None,train_feats.shape[2])) # Shape is number of features. 
Undefined number of windows\n"," accel_features = keras.layers.Dropout(0.2, seed=541,)(accel_inputs)\n"," accel_features = keras.layers.Bidirectional(keras.layers.LSTM(512, activation='tanh', return_sequences=True), merge_mode='ave')(accel_features)\n"," accel_features = keras.layers.Dropout(0.4, seed=541)(accel_features)\n"," accel_features = keras.layers.Dense(128, activation='relu')(accel_features)\n"," accel_features = keras.layers.Dense(384, activation='relu')(accel_features)\n"," accel_features = keras.layers.Dense(320, activation='relu')(accel_features)\n"," accel_outputs = keras.layers.Dense(1, activation='linear')(accel_features)\n","\n"," model_out = keras.Model(inputs=accel_inputs, outputs=accel_outputs, name='Accel_subcond_LSTM')\n"," # define optimizer algorithm and learning rate\n"," opt = keras.optimizers.Adam(learning_rate =lr)\n"," # compile model and define loss function\n"," model_out.compile(optimizer=opt, loss=loss)\n","\n"," return model_out\n","\n","# Build Model\n","model = build_model()\n","\n","# Plot Model\n","# keras.utils.plot_model(model, show_shapes=True, show_layer_names=False)\n"],"execution_count":8,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ThGBJo0ZZN_D"},"source":["## Train Model\n","Make sure you're using the GPU runtime type in Colab (Runtime > Change runtime type > GPU)"]},{"cell_type":"code","metadata":{"id":"0cn2-olBEk08","colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"status":"ok","timestamp":1641320579157,"user_tz":480,"elapsed":322539,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}},"outputId":"7618a1b0-2254-4011-abe0-eb18d65d539f"},"source":["# Define Early Stopping and Checkpoint Callbacks\n","model_filename = os.path.join(directory, 'Model.h5')\n","\n","# early stopping\n","es = keras.callbacks.EarlyStopping(monitor='val_loss', \n"," mode='min', \n"," 
verbose=0, \n"," patience=25, # low for example \n"," min_delta=0.001, \n"," restore_best_weights=True\n"," )\n","\n","# model checkpoint\n","mc = keras.callbacks.ModelCheckpoint(\n"," model_filename,\n"," monitor='val_loss', \n"," mode='min', \n"," verbose=1, \n"," save_best_only=True, \n"," save_weights_only=False\n"," )\n","\n","# Fit Model\n","history_accel = model.fit(\n"," train_feats, \n"," train_y, \n"," epochs=100, # low for example\n"," validation_data=(test_feats, test_y), \n"," verbose=1,\n"," batch_size=32, \n"," callbacks=[es, mc]\n"," )\n","\n","# Plot Train/Validation Loss across epochs\n","plt.plot(history_accel.history['loss'], label = 'mse_train')\n","plt.plot(history_accel.history['val_loss'], label = 'mse_validation')\n","plt.legend()\n","plt.show()"],"execution_count":11,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/100\n","1/1 [==============================] - ETA: 0s - loss: 0.8701\n","Epoch 00001: val_loss improved from inf to 0.87679, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 4s 4s/step - loss: 0.8701 - val_loss: 0.8768\n","Epoch 2/100\n","1/1 [==============================] - ETA: 0s - loss: 0.8569\n","Epoch 00002: val_loss improved from 0.87679 to 0.82845, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.8569 - val_loss: 0.8284\n","Epoch 3/100\n","1/1 [==============================] - ETA: 0s - loss: 0.8108\n","Epoch 00003: val_loss improved from 0.82845 to 0.75060, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.8108 - val_loss: 0.7506\n","Epoch 4/100\n","1/1 [==============================] - ETA: 0s - loss: 0.7371\n","Epoch 00004: val_loss improved from 0.75060 to 0.67830, saving model to /content/gdrive/My 
Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.7371 - val_loss: 0.6783\n","Epoch 5/100\n","1/1 [==============================] - ETA: 0s - loss: 0.6751\n","Epoch 00005: val_loss did not improve from 0.67830\n","1/1 [==============================] - 2s 2s/step - loss: 0.6751 - val_loss: 0.7121\n","Epoch 6/100\n","1/1 [==============================] - ETA: 0s - loss: 0.7043\n","Epoch 00006: val_loss did not improve from 0.67830\n","1/1 [==============================] - 3s 3s/step - loss: 0.7043 - val_loss: 0.7137\n","Epoch 7/100\n","1/1 [==============================] - ETA: 0s - loss: 0.7096\n","Epoch 00007: val_loss improved from 0.67830 to 0.63056, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.7096 - val_loss: 0.6306\n","Epoch 8/100\n","1/1 [==============================] - ETA: 0s - loss: 0.6294\n","Epoch 00008: val_loss improved from 0.63056 to 0.59665, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.6294 - val_loss: 0.5966\n","Epoch 9/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5958\n","Epoch 00009: val_loss improved from 0.59665 to 0.59193, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.5958 - val_loss: 0.5919\n","Epoch 10/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5839\n","Epoch 00010: val_loss improved from 0.59193 to 0.58144, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.5839 - val_loss: 0.5814\n","Epoch 11/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5730\n","Epoch 00011: val_loss improved from 0.58144 to 0.54089, saving 
model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.5730 - val_loss: 0.5409\n","Epoch 12/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5340\n","Epoch 00012: val_loss improved from 0.54089 to 0.43864, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.5340 - val_loss: 0.4386\n","Epoch 13/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4403\n","Epoch 00013: val_loss did not improve from 0.43864\n","1/1 [==============================] - 3s 3s/step - loss: 0.4403 - val_loss: 0.9307\n","Epoch 14/100\n","1/1 [==============================] - ETA: 0s - loss: 0.9899\n","Epoch 00014: val_loss did not improve from 0.43864\n","1/1 [==============================] - 3s 3s/step - loss: 0.9899 - val_loss: 0.4753\n","Epoch 15/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4626\n","Epoch 00015: val_loss did not improve from 0.43864\n","1/1 [==============================] - 3s 3s/step - loss: 0.4626 - val_loss: 0.5768\n","Epoch 16/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5532\n","Epoch 00016: val_loss did not improve from 0.43864\n","1/1 [==============================] - 2s 2s/step - loss: 0.5532 - val_loss: 0.5917\n","Epoch 17/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5675\n","Epoch 00017: val_loss did not improve from 0.43864\n","1/1 [==============================] - 3s 3s/step - loss: 0.5675 - val_loss: 0.5442\n","Epoch 18/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5287\n","Epoch 00018: val_loss did not improve from 0.43864\n","1/1 [==============================] - 3s 3s/step - loss: 0.5287 - val_loss: 0.4455\n","Epoch 19/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4510\n","Epoch 00019: val_loss improved from 0.43864 to 0.40471, saving model to 
/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.4510 - val_loss: 0.4047\n","Epoch 20/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4468\n","Epoch 00020: val_loss improved from 0.40471 to 0.37465, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.4468 - val_loss: 0.3747\n","Epoch 21/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4075\n","Epoch 00021: val_loss improved from 0.37465 to 0.32763, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.4075 - val_loss: 0.3276\n","Epoch 22/100\n","1/1 [==============================] - ETA: 0s - loss: 0.3391\n","Epoch 00022: val_loss did not improve from 0.32763\n","1/1 [==============================] - 2s 2s/step - loss: 0.3391 - val_loss: 0.3322\n","Epoch 23/100\n","1/1 [==============================] - ETA: 0s - loss: 0.3358\n","Epoch 00023: val_loss improved from 0.32763 to 0.29082, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.3358 - val_loss: 0.2908\n","Epoch 24/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2999\n","Epoch 00024: val_loss improved from 0.29082 to 0.19393, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.2999 - val_loss: 0.1939\n","Epoch 25/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2183\n","Epoch 00025: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.2183 - val_loss: 0.4725\n","Epoch 26/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5013\n","Epoch 00026: val_loss did not improve 
from 0.19393\n","1/1 [==============================] - 2s 2s/step - loss: 0.5013 - val_loss: 0.3533\n","Epoch 27/100\n","1/1 [==============================] - ETA: 0s - loss: 0.3524\n","Epoch 00027: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.3524 - val_loss: 0.5258\n","Epoch 28/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4990\n","Epoch 00028: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.4990 - val_loss: 0.5414\n","Epoch 29/100\n","1/1 [==============================] - ETA: 0s - loss: 0.5089\n","Epoch 00029: val_loss did not improve from 0.19393\n","1/1 [==============================] - 2s 2s/step - loss: 0.5089 - val_loss: 0.4434\n","Epoch 30/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4210\n","Epoch 00030: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.4210 - val_loss: 0.2944\n","Epoch 31/100\n","1/1 [==============================] - ETA: 0s - loss: 0.3223\n","Epoch 00031: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.3223 - val_loss: 0.3272\n","Epoch 32/100\n","1/1 [==============================] - ETA: 0s - loss: 0.4339\n","Epoch 00032: val_loss did not improve from 0.19393\n","1/1 [==============================] - 2s 2s/step - loss: 0.4339 - val_loss: 0.2391\n","Epoch 33/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2720\n","Epoch 00033: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.2720 - val_loss: 0.2240\n","Epoch 34/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2364\n","Epoch 00034: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.2364 - val_loss: 0.2212\n","Epoch 35/100\n","1/1 [==============================] - ETA: 0s - loss: 
0.2383\n","Epoch 00035: val_loss did not improve from 0.19393\n","1/1 [==============================] - 3s 3s/step - loss: 0.2383 - val_loss: 0.2093\n","Epoch 36/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2323\n","Epoch 00036: val_loss improved from 0.19393 to 0.18982, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.2323 - val_loss: 0.1898\n","Epoch 37/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2184\n","Epoch 00037: val_loss improved from 0.18982 to 0.16987, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.2184 - val_loss: 0.1699\n","Epoch 38/100\n","1/1 [==============================] - ETA: 0s - loss: 0.2014\n","Epoch 00038: val_loss improved from 0.16987 to 0.15399, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.2014 - val_loss: 0.1540\n","Epoch 39/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1921\n","Epoch 00039: val_loss improved from 0.15399 to 0.13887, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1921 - val_loss: 0.1389\n","Epoch 40/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1746\n","Epoch 00040: val_loss improved from 0.13887 to 0.11846, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1746 - val_loss: 0.1185\n","Epoch 41/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1606\n","Epoch 00041: val_loss improved from 0.11846 to 0.10000, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] 
- 3s 3s/step - loss: 0.1606 - val_loss: 0.1000\n","Epoch 42/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1459\n","Epoch 00042: val_loss improved from 0.10000 to 0.09653, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1459 - val_loss: 0.0965\n","Epoch 43/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1326\n","Epoch 00043: val_loss did not improve from 0.09653\n","1/1 [==============================] - 2s 2s/step - loss: 0.1326 - val_loss: 0.1054\n","Epoch 44/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1384\n","Epoch 00044: val_loss did not improve from 0.09653\n","1/1 [==============================] - 3s 3s/step - loss: 0.1384 - val_loss: 0.1110\n","Epoch 45/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1397\n","Epoch 00045: val_loss did not improve from 0.09653\n","1/1 [==============================] - 3s 3s/step - loss: 0.1397 - val_loss: 0.1051\n","Epoch 46/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1383\n","Epoch 00046: val_loss did not improve from 0.09653\n","1/1 [==============================] - 3s 3s/step - loss: 0.1383 - val_loss: 0.0966\n","Epoch 47/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1388\n","Epoch 00047: val_loss improved from 0.09653 to 0.08971, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1388 - val_loss: 0.0897\n","Epoch 48/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1364\n","Epoch 00048: val_loss improved from 0.08971 to 0.08350, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1364 - val_loss: 0.0835\n","Epoch 49/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1202\n","Epoch 00049: 
val_loss improved from 0.08350 to 0.07772, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1202 - val_loss: 0.0777\n","Epoch 50/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1110\n","Epoch 00050: val_loss improved from 0.07772 to 0.06849, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1110 - val_loss: 0.0685\n","Epoch 51/100\n","1/1 [==============================] - ETA: 0s - loss: 0.1008\n","Epoch 00051: val_loss improved from 0.06849 to 0.05752, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.1008 - val_loss: 0.0575\n","Epoch 52/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0923\n","Epoch 00052: val_loss improved from 0.05752 to 0.05125, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0923 - val_loss: 0.0513\n","Epoch 53/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0822\n","Epoch 00053: val_loss did not improve from 0.05125\n","1/1 [==============================] - 2s 2s/step - loss: 0.0822 - val_loss: 0.0514\n","Epoch 54/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0798\n","Epoch 00054: val_loss did not improve from 0.05125\n","1/1 [==============================] - 3s 3s/step - loss: 0.0798 - val_loss: 0.0525\n","Epoch 55/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0795\n","Epoch 00055: val_loss did not improve from 0.05125\n","1/1 [==============================] - 3s 3s/step - loss: 0.0795 - val_loss: 0.0514\n","Epoch 56/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0748\n","Epoch 00056: val_loss improved from 0.05125 to 0.05019, saving model 
to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0748 - val_loss: 0.0502\n","Epoch 57/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0738\n","Epoch 00057: val_loss improved from 0.05019 to 0.04883, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0738 - val_loss: 0.0488\n","Epoch 58/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0724\n","Epoch 00058: val_loss improved from 0.04883 to 0.04612, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0724 - val_loss: 0.0461\n","Epoch 59/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0674\n","Epoch 00059: val_loss improved from 0.04612 to 0.04296, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0674 - val_loss: 0.0430\n","Epoch 60/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0617\n","Epoch 00060: val_loss improved from 0.04296 to 0.03937, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0617 - val_loss: 0.0394\n","Epoch 61/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0593\n","Epoch 00061: val_loss improved from 0.03937 to 0.03407, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0593 - val_loss: 0.0341\n","Epoch 62/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0554\n","Epoch 00062: val_loss improved from 0.03407 to 0.03168, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 
[==============================] - 3s 3s/step - loss: 0.0554 - val_loss: 0.0317\n","Epoch 63/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0564\n","Epoch 00063: val_loss improved from 0.03168 to 0.03013, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0564 - val_loss: 0.0301\n","Epoch 64/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0554\n","Epoch 00064: val_loss improved from 0.03013 to 0.02954, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0554 - val_loss: 0.0295\n","Epoch 65/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0554\n","Epoch 00065: val_loss improved from 0.02954 to 0.02862, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0554 - val_loss: 0.0286\n","Epoch 66/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0550\n","Epoch 00066: val_loss did not improve from 0.02862\n","1/1 [==============================] - 2s 2s/step - loss: 0.0550 - val_loss: 0.0289\n","Epoch 67/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0534\n","Epoch 00067: val_loss improved from 0.02862 to 0.02757, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0534 - val_loss: 0.0276\n","Epoch 68/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0501\n","Epoch 00068: val_loss did not improve from 0.02757\n","1/1 [==============================] - 3s 3s/step - loss: 0.0501 - val_loss: 0.0284\n","Epoch 69/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0493\n","Epoch 00069: val_loss improved from 0.02757 to 0.02530, saving model to /content/gdrive/My 
Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0493 - val_loss: 0.0253\n","Epoch 70/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0479\n","Epoch 00070: val_loss did not improve from 0.02530\n","1/1 [==============================] - 3s 3s/step - loss: 0.0479 - val_loss: 0.0253\n","Epoch 71/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0479\n","Epoch 00071: val_loss did not improve from 0.02530\n","1/1 [==============================] - 3s 3s/step - loss: 0.0479 - val_loss: 0.0259\n","Epoch 72/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0477\n","Epoch 00072: val_loss did not improve from 0.02530\n","1/1 [==============================] - 2s 2s/step - loss: 0.0477 - val_loss: 0.0262\n","Epoch 73/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0470\n","Epoch 00073: val_loss improved from 0.02530 to 0.02512, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0470 - val_loss: 0.0251\n","Epoch 74/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0449\n","Epoch 00074: val_loss did not improve from 0.02512\n","1/1 [==============================] - 2s 2s/step - loss: 0.0449 - val_loss: 0.0264\n","Epoch 75/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0461\n","Epoch 00075: val_loss improved from 0.02512 to 0.02510, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0461 - val_loss: 0.0251\n","Epoch 76/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0447\n","Epoch 00076: val_loss improved from 0.02510 to 0.02374, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0447 - val_loss: 0.0237\n","Epoch 
77/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0459\n","Epoch 00077: val_loss improved from 0.02374 to 0.02309, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0459 - val_loss: 0.0231\n","Epoch 78/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0429\n","Epoch 00078: val_loss did not improve from 0.02309\n","1/1 [==============================] - 3s 3s/step - loss: 0.0429 - val_loss: 0.0234\n","Epoch 79/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0435\n","Epoch 00079: val_loss improved from 0.02309 to 0.02282, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0435 - val_loss: 0.0228\n","Epoch 80/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0440\n","Epoch 00080: val_loss did not improve from 0.02282\n","1/1 [==============================] - 3s 3s/step - loss: 0.0440 - val_loss: 0.0232\n","Epoch 81/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0432\n","Epoch 00081: val_loss improved from 0.02282 to 0.02265, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0432 - val_loss: 0.0227\n","Epoch 82/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0428\n","Epoch 00082: val_loss improved from 0.02265 to 0.02173, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0428 - val_loss: 0.0217\n","Epoch 83/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0407\n","Epoch 00083: val_loss improved from 0.02173 to 0.02128, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 
0.0407 - val_loss: 0.0213\n","Epoch 84/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0415\n","Epoch 00084: val_loss improved from 0.02128 to 0.02053, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0415 - val_loss: 0.0205\n","Epoch 85/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0407\n","Epoch 00085: val_loss did not improve from 0.02053\n","1/1 [==============================] - 3s 3s/step - loss: 0.0407 - val_loss: 0.0209\n","Epoch 86/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0402\n","Epoch 00086: val_loss did not improve from 0.02053\n","1/1 [==============================] - 3s 3s/step - loss: 0.0402 - val_loss: 0.0206\n","Epoch 87/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0400\n","Epoch 00087: val_loss did not improve from 0.02053\n","1/1 [==============================] - 2s 2s/step - loss: 0.0400 - val_loss: 0.0207\n","Epoch 88/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0394\n","Epoch 00088: val_loss improved from 0.02053 to 0.02030, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0394 - val_loss: 0.0203\n","Epoch 89/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0379\n","Epoch 00089: val_loss did not improve from 0.02030\n","1/1 [==============================] - 2s 2s/step - loss: 0.0379 - val_loss: 0.0211\n","Epoch 90/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0388\n","Epoch 00090: val_loss did not improve from 0.02030\n","1/1 [==============================] - 2s 2s/step - loss: 0.0388 - val_loss: 0.0208\n","Epoch 91/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0382\n","Epoch 00091: val_loss did not improve from 0.02030\n","1/1 [==============================] - 3s 3s/step - loss: 0.0382 - 
val_loss: 0.0205\n","Epoch 92/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0391\n","Epoch 00092: val_loss improved from 0.02030 to 0.01967, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0391 - val_loss: 0.0197\n","Epoch 93/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0372\n","Epoch 00093: val_loss did not improve from 0.01967\n","1/1 [==============================] - 2s 2s/step - loss: 0.0372 - val_loss: 0.0197\n","Epoch 94/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0378\n","Epoch 00094: val_loss improved from 0.01967 to 0.01884, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0378 - val_loss: 0.0188\n","Epoch 95/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0373\n","Epoch 00095: val_loss improved from 0.01884 to 0.01880, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0373 - val_loss: 0.0188\n","Epoch 96/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0372\n","Epoch 00096: val_loss improved from 0.01880 to 0.01866, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0372 - val_loss: 0.0187\n","Epoch 97/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0366\n","Epoch 00097: val_loss improved from 0.01866 to 0.01860, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0366 - val_loss: 0.0186\n","Epoch 98/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0360\n","Epoch 00098: val_loss improved from 0.01860 to 0.01855, saving model to 
/content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0360 - val_loss: 0.0185\n","Epoch 99/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0361\n","Epoch 00099: val_loss improved from 0.01855 to 0.01842, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0361 - val_loss: 0.0184\n","Epoch 100/100\n","1/1 [==============================] - ETA: 0s - loss: 0.0357\n","Epoch 00100: val_loss improved from 0.01842 to 0.01822, saving model to /content/gdrive/My Drive/Recurrent_GRF_Prediction-main/data/Model.h5\n","1/1 [==============================] - 3s 3s/step - loss: 0.0357 - val_loss: 0.0182\n"]},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"F3pPYRbNZcqz"},"source":["## Load best model and evaluate\n","\n"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":371},"id":"zmNX1lxwVvvE","executionInfo":{"status":"ok","timestamp":1641320616894,"user_tz":480,"elapsed":6466,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}},"outputId":"1ffa8e65-59d2-4f04-a70c-0eb263cd0bcb"},"source":["saved_model = keras.models.load_model(model_filename)\n","\n","test_scores = saved_model.evaluate(test_feats, test_y, verbose = 2)\n","pred_final = saved_model.predict(test_feats)\n","pred_final = np.squeeze(pred_final)\n","\n","# Plot Prediction vs Measured GRF ----\n","i = 0 # trial to plot\n","s = 500 # frame to start plot\n","e = 1000 # frame to end plot\n","fig, ax = plt.subplots(1,1, figsize=(10,5))\n","ax.plot(test_y[i,s:e], '-', label='true')\n","ax.plot(pred_final[i,s:e], 'r--', label='prediction')\n","ax.grid()\n","ax.set_title(' Speed: ' + str(test_Sub_Info['Speed'].iloc[i]) + \n"," ' Slope: ' + str(test_Sub_Info['Slope'].iloc[i])\n"," )\n","print()\n","plt.show()"],"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["1/1 - 2s - loss: 0.0182 - 2s/epoch - 2s/step\n","\n"]},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"TdaZsexZZpUe"},"source":["## Additional Information about RNN"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"7-BqsqkRtrhf","executionInfo":{"status":"ok","timestamp":1641320619811,"user_tz":480,"elapsed":131,"user":{"displayName":"Ryan Alcantara","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiGvdgh0URM4PNeT3MUDpX4MovyDEXk5cA0b2z-d_Y=s64","userId":"01174355928364926683"}},"outputId":"64f2a374-3b93-4166-ed6c-47bc179c6b03"},"source":["## Get Info about Saved Model ----\n","saved_model.summary()\n","# saved_model.get_config()"],"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["Model: \"Accel_subcond_LSTM\"\n","_________________________________________________________________\n"," Layer (type) Output Shape Param # \n","=================================================================\n"," input_1 (InputLayer) [(None, None, 13)] 0 \n"," \n"," dropout (Dropout) (None, None, 13) 0 \n"," \n"," bidirectional (Bidirectiona (None, None, 512) 2154496 \n"," l) \n"," \n"," dropout_1 (Dropout) (None, None, 512) 0 \n"," \n"," dense (Dense) (None, None, 128) 65664 \n"," \n"," dense_1 (Dense) (None, None, 384) 49536 \n"," \n"," dense_2 (Dense) (None, None, 320) 123200 \n"," \n"," dense_3 (Dense) (None, None, 1) 321 \n"," \n","=================================================================\n","Total params: 2,393,217\n","Trainable params: 2,393,217\n","Non-trainable params: 0\n","_________________________________________________________________\n"]}]}]} --------------------------------------------------------------------------------