├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt └── src ├── 01-dataset-tutorial.ipynb ├── 02-deep-learning-tutorial.ipynb ├── 03-inference-tutorial.ipynb ├── build-dataset.py ├── build_dataset.sh ├── common ├── data_utils.py ├── logging.py ├── losses.py ├── preprocessing.py ├── quaternion.py ├── rotations.py └── skeleton.py ├── matlab ├── get_folder_path.m ├── joint_angle_segments.m ├── joint_angles.m ├── load_partial_mvnx.m ├── mvnx_to_csv.m ├── mvnx_to_hdf.m ├── mvnx_to_hdf_batch.m ├── process_mvnx_files.m ├── segment_orientation.m ├── segment_position.m └── segment_reference.m ├── seq2seq ├── seq2seq.py └── training_utils.py ├── test-seq2seq.py ├── test-transformer.py ├── test_seq2seq.sh ├── test_transformer.sh ├── train-seq2seq.py ├── train-transformer.py ├── train_seq2seq.sh ├── train_transformer.sh └── transformers ├── training_utils.py └── transformers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .idea/ 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # data 128 | *.csv 129 | *.mvnx 130 | *.h5 131 | *.pt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-present, Assistive Robotics Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Virginia Tech Natural Motion Processing 2 | 3 | This repository was written to help analyze the Virginia Tech Natural Motion Dataset. The dataset contains 40 hours of unscripted human motion collected in the open world using XSens MVN Link. The dataset, metadata and more information are available through the Virginia Tech University Libraries: https://data.lib.vt.edu/articles/dataset/Virginia_Tech_Natural_Motion_Dataset/14114054/2. 4 | 5 | ## Table of Contents 6 | 7 | - [Layout](#project-layout) 8 | - [Setup](#setup) 9 | - [Dependencies](#dependencies) 10 | - [Conda Environment](#conda-environment) 11 | 12 | ## Project Layout 13 | 14 | src/ 15 | common/ 16 | seq2seq/ 17 | transformers/ 18 | matlab/ 19 | 20 | ## Dependencies 21 | 22 | numpy==1.18.1 23 | h5py==2.10.0 24 | matplotlib==3.1.3 25 | torch==1.6.0 26 | 27 | ## Setup 28 | 29 | - Clone the repo locally 30 | - Set up the conda environment 31 | - `$ conda create -n vt-nmp python=3.7` 32 | - Install requirements 33 | - `$ pip install -r requirements.txt` 34 | 35 | ## Conda Environment 36 | 37 | An Anaconda environment is used to help with development. The environment's main dependency is PyTorch, which will be installed when following the Setup steps above. 38 | 39 | ## License 40 | 41 | Please see the LICENSE for more details. 
If you use our code or models in your research, please cite our paper: 42 | 43 | ``` 44 | @article{geissinger2020motion, 45 | title={Motion inference using sparse inertial sensors, self-supervised learning, and a new dataset of unscripted human motion}, 46 | author={Geissinger, Jack H and Asbeck, Alan T}, 47 | journal={Sensors}, 48 | volume={20}, 49 | number={21}, 50 | pages={6330}, 51 | year={2020}, 52 | publisher={Multidisciplinary Digital Publishing Institute} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | h5py==2.10.0 3 | matplotlib==3.1.3 4 | torch==1.6.0 5 | -------------------------------------------------------------------------------- /src/02-deep-learning-tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Building a dataset, training a Seq2Seq model, and testing it\n", 8 | "\n", 9 | "The Virginia Tech Natural Motion Dataset contains .h5 files with unscripted human motion data collected in real-world environments as participants went about their day-to-day lives. This is a brief tutorial on using the dataset and then training and testing a neural network.\n", 10 | "\n", 11 | "This tutorial illustrates how to use the shell (.sh) scripts to train a seq2seq model (particularly **train_seq2seq.sh** and **test_seq2seq.sh**). 
Similar shell scripts are also available for the Transformers (see **train_transformer.sh** and **test_transformer.sh**)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Building a dataset\n", 19 | "\n", 20 | "We will first cover how to build a dataset with data from a few participants using the build-dataset.py file.\n", 21 | "\n", 22 | "We are running the script from a Jupyter Notebook, but this can just as easily be run as a shell script (see build_dataset.sh).\n", 23 | "\n", 24 | "In this case, we are drawing data from the h5-dataset folder located in the cloud. We are going to output the training.h5, validation.h5, and testing.h5 files to the folder data/set-2.\n", 25 | "\n", 26 | "We will be using participants 1, 5, and 10 (P1, P5, P10, respectively) and extracting normOrientation and normAcceleration data on a few segments (norm* means data normalized relative to the pelvis). As output data we will be extracting normOrientation data for every segment.\n", 27 | "\n", 28 | "In other words, our task is as follows: use orientation and acceleration from a set of sparse segments and try to train a model mapping that input data to orientations for every segment on the human body." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "2020-08-10 13:09:31 INFO Writing X to the training file group...\n", 41 | "2020-08-10 13:09:37 INFO Writing X to the validation file group...\n", 42 | "2020-08-10 13:09:40 INFO Writing X to the testing file group...\n", 43 | "2020-08-10 13:09:50 INFO Writing Y to the training file group...\n", 44 | "2020-08-10 13:09:57 INFO Writing Y to the validation file group...\n", 45 | "2020-08-10 13:09:58 INFO Writing Y to the testing file group...\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/data/set-2\n", 51 | "!python build-dataset.py --data-path \"/groups/MotionPred/h5-dataset\" \\\n", 52 | " --output-path \"/home/jackg7/VT-Natural-Motion-Processing/data/set-2\" \\\n", 53 | " --training \"P1\" \\\n", 54 | " --validation \"P5\" \\\n", 55 | " --testing \"P10\" \\\n", 56 | " --task-input \"normOrientation normAcceleration\" \\\n", 57 | " --input-label-request \"T8 RightForeArm RightLowerLeg LeftForeArm LeftLowerLeg\" \\\n", 58 | " --task-output \"normOrientation\" \\\n", 59 | " --output-label-request \"all\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "### Training a seq2seq model\n", 67 | "\n", 68 | "We can now train a seq2seq model to map the normOrientation and normAcceleration data from the sparse segments to the full-body normOrientation data.\n", 69 | "\n", 70 | "We will be using a seq-length of 30 (at 240 Hz) and downsampling it by a factor of 6 (to 40 Hz). The resulting sequences will be of length 5 for the input and output. The in-out-ratio will then be used to reduce the output sequence length to 1.\n", 71 | "\n", 72 | "The input sequence will be of shape (B, 5, 35) and the output sequence will be of shape (B, 1, 92). 
Orientations are stored as quaternions, so orientation value will be 4 in length. The number 35 comes from our use of 5 segment orientations and accelerations or $5*4 + 5*3 = 35$. The full-body has 23 segments and we're predicting orientation values for each one or $23*4 = 92$\n", 73 | "\n", 74 | "We're training a seq2seq model with a hidden size of 512, a bidirectional encoder and dot product attention. The model will be trained for a single epoch.\n", 75 | "\n", 76 | "Our loss function for training will be the L1Loss and our validation losses will be the L1Loss and the QuatDistance (cosine similarity) loss." 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 2, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "2020-08-10 13:10:06 INFO task - conversion\n", 89 | "2020-08-10 13:10:06 INFO data_path - /home/jackg7/VT-Natural-Motion-Processing/data/set-2\n", 90 | "2020-08-10 13:10:06 INFO model_file_path - /home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\n", 91 | "2020-08-10 13:10:06 INFO representation - quaternions\n", 92 | "2020-08-10 13:10:06 INFO auxiliary_acc - False\n", 93 | "2020-08-10 13:10:06 INFO batch_size - 32\n", 94 | "2020-08-10 13:10:06 INFO learning_rate - 0.001\n", 95 | "2020-08-10 13:10:06 INFO seq_length - 30\n", 96 | "2020-08-10 13:10:06 INFO downsample - 6\n", 97 | "2020-08-10 13:10:06 INFO in_out_ratio - 5\n", 98 | "2020-08-10 13:10:06 INFO stride - 30\n", 99 | "2020-08-10 13:10:06 INFO num_epochs - 1\n", 100 | "2020-08-10 13:10:06 INFO hidden_size - 512\n", 101 | "2020-08-10 13:10:06 INFO dropout - 0.0\n", 102 | "2020-08-10 13:10:06 INFO bidirectional - True\n", 103 | "2020-08-10 13:10:06 INFO attention - dot\n", 104 | "2020-08-10 13:10:06 INFO Starting seq2seq model training...\n", 105 | "2020-08-10 13:10:06 INFO Retrieving training data for sequences 125 ms long and downsampling to 40.0 Hz...\n", 106 | "2020-08-10 13:10:09 INFO Data for 
training have shapes (X, y): torch.Size([259570, 35]), torch.Size([51914, 92])\n", 107 | "2020-08-10 13:10:09 INFO Reshaped training shapes (X, y): torch.Size([51914, 5, 35]), torch.Size([51914, 1, 92])\n", 108 | "2020-08-10 13:10:09 INFO Number of training samples: 51914\n", 109 | "2020-08-10 13:10:09 INFO Retrieving validation data for sequences 125 ms long and downsampling to 40.0 Hz...\n", 110 | "2020-08-10 13:10:09 INFO Data for validation have shapes (X, y): torch.Size([90880, 35]), torch.Size([18176, 92])\n", 111 | "2020-08-10 13:10:09 INFO Reshaped validation shapes (X, y): torch.Size([18176, 5, 35]), torch.Size([18176, 1, 92])\n", 112 | "2020-08-10 13:10:09 INFO Number of validation samples: 18176\n", 113 | "2020-08-10 13:10:09 INFO Encoder for training: EncoderRNN(\n", 114 | " (gru): GRU(35, 512, bidirectional=True)\n", 115 | " (dropout): Dropout(p=0.0, inplace=False)\n", 116 | " (fc): Linear(in_features=1024, out_features=512, bias=True)\n", 117 | ")\n", 118 | "2020-08-10 13:10:09 INFO Decoder for training: AttnDecoderRNN(\n", 119 | " (attention): Attention()\n", 120 | " (rnn): GRU(1116, 512)\n", 121 | " (out): Linear(in_features=1628, out_features=92, bias=True)\n", 122 | ")\n", 123 | "2020-08-10 13:10:09 INFO Number of parameters: 4864876\n", 124 | "2020-08-10 13:10:09 INFO Optimizers for training: AdamW (\n", 125 | "Parameter Group 0\n", 126 | " amsgrad: False\n", 127 | " betas: (0.9, 0.999)\n", 128 | " eps: 1e-08\n", 129 | " initial_lr: 0.001\n", 130 | " lr: 0.001\n", 131 | " weight_decay: 0.05\n", 132 | ")\n", 133 | "2020-08-10 13:10:09 INFO Criterion for training: L1Loss()\n", 134 | "2020-08-10 13:10:09 INFO Epoch 1 / 1\n", 135 | "2020-08-10 13:10:09 INFO Total time elapsed: 0.2845275402069092 - Batch Number: 0 / 1622 - Training loss: 0.5867175199891831\n", 136 | "2020-08-10 13:10:28 INFO Total time elapsed: 19.075539588928223 - Batch Number: 162 / 1622 - Training loss: 0.06888134323321701\n", 137 | "2020-08-10 13:10:46 INFO Total time elapsed: 
36.41547679901123 - Batch Number: 324 / 1622 - Training loss: 0.06516135748519093\n", 138 | "2020-08-10 13:11:06 INFO Total time elapsed: 56.93833088874817 - Batch Number: 486 / 1622 - Training loss: 0.06124250432828066\n", 139 | "2020-08-10 13:11:27 INFO Total time elapsed: 77.28701090812683 - Batch Number: 648 / 1622 - Training loss: 0.06000007542016456\n", 140 | "2020-08-10 13:11:49 INFO Total time elapsed: 99.8214840888977 - Batch Number: 810 / 1622 - Training loss: 0.06179196632301481\n", 141 | "2020-08-10 13:12:13 INFO Total time elapsed: 123.77892279624939 - Batch Number: 972 / 1622 - Training loss: 0.055964607362772964\n", 142 | "2020-08-10 13:12:34 INFO Total time elapsed: 144.65489411354065 - Batch Number: 1134 / 1622 - Training loss: 0.05693873948561733\n", 143 | "2020-08-10 13:12:58 INFO Total time elapsed: 167.9723401069641 - Batch Number: 1296 / 1622 - Training loss: 0.05655428921935604\n", 144 | "2020-08-10 13:13:21 INFO Total time elapsed: 191.32336831092834 - Batch Number: 1458 / 1622 - Training loss: 0.05438031324405435\n", 145 | "2020-08-10 13:13:44 INFO Total time elapsed: 214.2629885673523 - Batch Number: 1620 / 1622 - Training loss: 0.05223539072065468\n", 146 | "2020-08-10 13:14:03 INFO Training Loss: 0.06224823312900671 - Val Loss: 0.10371989231654884, 22.61474124110037\n", 147 | "2020-08-10 13:14:03 INFO Saving model to /home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\n", 148 | "2020-08-10 13:14:04 INFO Completed Training...\n", 149 | "2020-08-10 13:14:04 INFO \n", 150 | "\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/models/set-2\n", 156 | "!python train-seq2seq.py --task conversion \\\n", 157 | " --data-path \"/home/jackg7/VT-Natural-Motion-Processing/data/set-2\" \\\n", 158 | " --model-file-path \"/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\" \\\n", 159 | " --representation quaternions \\\n", 160 | " --batch-size=32 \\\n", 161 | " 
--seq-length=30 \\\n", 162 | " --downsample=6 \\\n", 163 | " --in-out-ratio=5 \\\n", 164 | " --stride=30 \\\n", 165 | " --learning-rate=0.001 \\\n", 166 | " --num-epochs=1 \\\n", 167 | " --hidden-size=512 \\\n", 168 | " --attention=dot \\\n", 169 | " --bidirectional" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "### Testing our model\n", 177 | "\n", 178 | "We can now test our model and output a histogram of performance over the testing data. The model parameters must be the same to properly read in the model." 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 3, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "2020-08-10 13:14:06 INFO task - conversion\n", 191 | "2020-08-10 13:14:06 INFO data_path_parent - /home/jackg7/VT-Natural-Motion-Processing/data\n", 192 | "2020-08-10 13:14:06 INFO figure_file_path - /home/jackg7/VT-Natural-Motion-Processing/images/seq2seq-test.pdf\n", 193 | "2020-08-10 13:14:06 INFO figure_title - Seq2Seq\n", 194 | "2020-08-10 13:14:06 INFO include_legend - False\n", 195 | "2020-08-10 13:14:06 INFO model_dir - /home/jackg7/VT-Natural-Motion-Processing/models/set-2\n", 196 | "2020-08-10 13:14:06 INFO representation - quaternions\n", 197 | "2020-08-10 13:14:06 INFO batch_size - 512\n", 198 | "2020-08-10 13:14:06 INFO seq_length - 30\n", 199 | "2020-08-10 13:14:06 INFO downsample - 6\n", 200 | "2020-08-10 13:14:06 INFO in_out_ratio - 5\n", 201 | "2020-08-10 13:14:06 INFO stride - 30\n", 202 | "2020-08-10 13:14:06 INFO hidden_size - 512\n", 203 | "2020-08-10 13:14:06 INFO dropout - 0.0\n", 204 | "2020-08-10 13:14:06 INFO bidirectional - True\n", 205 | "2020-08-10 13:14:06 INFO attention - dot\n", 206 | "2020-08-10 13:14:06 INFO Starting seq2seq model testing...\n", 207 | "2020-08-10 13:14:06 INFO Retrieving testing data for sequences 125 ms long and downsampling to 40.0 Hz...\n", 208 | 
"2020-08-10 13:14:09 INFO Data for testing have shapes (X, y): torch.Size([452760, 35]), torch.Size([90552, 92])\n", 209 | "2020-08-10 13:14:09 INFO Reshaped testing shapes (X, y): torch.Size([90552, 5, 35]), torch.Size([90552, 1, 92])\n", 210 | "2020-08-10 13:14:09 INFO Number of testing samples: 90552\n", 211 | "2020-08-10 13:14:37 INFO Inference Loss for set-2: 14.272361435695645\n", 212 | "2020-08-10 13:14:37 INFO Completed Testing...\n", 213 | "2020-08-10 13:14:37 INFO \n", 214 | "\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/images\n", 220 | "!python test-seq2seq.py --task conversion \\\n", 221 | " --data-path-parent /home/jackg7/VT-Natural-Motion-Processing/data \\\n", 222 | " --figure-file-path /home/jackg7/VT-Natural-Motion-Processing/images/seq2seq-test.pdf \\\n", 223 | " --figure-title \"Seq2Seq\" \\\n", 224 | " --model-dir /home/jackg7/VT-Natural-Motion-Processing/models/set-2 \\\n", 225 | " --representation quaternions \\\n", 226 | " --batch-size=512 \\\n", 227 | " --seq-length=30 \\\n", 228 | " --downsample=6 \\\n", 229 | " --in-out-ratio=5 \\\n", 230 | " --stride=30 \\\n", 231 | " --hidden-size=512 \\\n", 232 | " --attention=dot \\\n", 233 | " --bidirectional" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "We can now visualize the performance of the seq2seq model on the test data." 
241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 4, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "text/html": [ 251 | "\n", 252 | " \n", 259 | " " 260 | ], 261 | "text/plain": [ 262 | "" 263 | ] 264 | }, 265 | "execution_count": 4, 266 | "metadata": {}, 267 | "output_type": "execute_result" 268 | } 269 | ], 270 | "source": [ 271 | "from IPython.display import IFrame\n", 272 | "IFrame(\"../images/seq2seq-test.pdf\", width=600, height=300)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [] 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "display_name": "Python 3", 286 | "language": "python", 287 | "name": "python3" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | "name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.7.4" 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 2 304 | } 305 | -------------------------------------------------------------------------------- /src/03-inference-tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Running Inference with our Seq2Seq model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2\n", 18 | "%matplotlib notebook\n", 19 | "\n", 20 | "import torch\n", 21 | "import torch.nn.functional as F\n", 22 | "import h5py\n", 23 | "import numpy as np\n", 24 | "import glob\n", 25 | "from common.quaternion import quat_mul\n", 26 | "from common.data_utils import read_h5\n", 27 | "from common.skeleton 
import Skeleton\n", 28 | "from seq2seq.training_utils import get_encoder, get_attn_decoder" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "### Loading in some data to test with\n", 36 | "\n", 37 | "Let's start by loading in data.\n", 38 | "\n", 39 | "We trained our model using \"normOrientation\" and \"normAcceleration\" from the T8 (sternum), both forearms, and both lower legs as input, so we'll read in that data. We'll also read in some data for \"normOrientation\" on the entire body because this is the output of our model. Finally, we'll read in data for the orientation of the pelvis. This is important because we want to rotate the \"normOrientation\" data back into its original reference frame." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "group = [\"T8\", \"RightForeArm\", \"RightLowerLeg\", \"LeftForeArm\", \"LeftLowerLeg\"]\n", 49 | "\n", 50 | "filepaths = glob.glob(\"../data/*.h5\")\n", 51 | "requests = {\"normOrientation\" : [\"all\"], \"orientation\" : [\"Pelvis\"]}\n", 52 | "dataset = read_h5(filepaths, requests)\n", 53 | "\n", 54 | "filename = filepaths[0].split(\"/\")[-1]\n", 55 | "fullBodyOrientations = torch.Tensor(dataset[filename][\"normOrientation\"]).double()\n", 56 | "root = torch.Tensor(dataset[filename]['orientation']).double()\n", 57 | "\n", 58 | "requests = {\"normOrientation\" : group, \"normAcceleration\" : group}\n", 59 | "dataset = read_h5(filepaths, requests)\n", 60 | "orientationInputs = torch.Tensor(dataset[filename][\"normOrientation\"]).double()\n", 61 | "accelerationInputs = torch.Tensor(dataset[filename][\"normAcceleration\"]).double()\n", 62 | "\n", 63 | "with h5py.File(\"../data/set-2/normalization.h5\", \"r\") as f:\n", 64 | " mean, std_dev = torch.Tensor(f[\"mean\"]), torch.Tensor(f[\"std_dev\"])" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 
71 | "### Loading in the Seq2Seq Model\n", 72 | "\n", 73 | "We can now load in our Seq2Seq models (encoder and decoder).\n", 74 | "\n", 75 | "The models must have the same arguments used during training so that errors don't pop up." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "encoder_feature_size = len(group)*4 + len(group)*3\n", 85 | "decoder_feature_size = 92\n", 86 | "hidden_size = 512\n", 87 | "attention = \"dot\"\n", 88 | "bidirectional = True\n", 89 | "\n", 90 | "seq_length = 30\n", 91 | "downsample = 6\n", 92 | "in_out_ratio = 5\n", 93 | "\n", 94 | "\n", 95 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 96 | "\n", 97 | "encoder = get_encoder(encoder_feature_size,\n", 98 | " device,\n", 99 | " hidden_size=hidden_size,\n", 100 | " bidirectional=bidirectional)\n", 101 | "\n", 102 | "decoder = get_attn_decoder(decoder_feature_size,\n", 103 | " attention,\n", 104 | " device,\n", 105 | " hidden_size=hidden_size,\n", 106 | " bidirectional_encoder=bidirectional)\n", 107 | "\n", 108 | " \n", 109 | "decoder.batch_size = 1\n", 110 | "decoder.attention.batch_size = 1\n", 111 | "\n", 112 | "PATH = \"/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\"\n", 113 | "\n", 114 | "checkpoint = torch.load(PATH, map_location=device)\n", 115 | "\n", 116 | "encoder.load_state_dict(checkpoint['encoder_state_dict'])\n", 117 | "decoder.load_state_dict(checkpoint['decoder_state_dict'])\n", 118 | "\n", 119 | "encoder.eval()\n", 120 | "decoder.eval()\n", 121 | "models = (encoder, decoder)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Defining the inference function\n", 129 | "\n", 130 | "Our inference function needs to take in a batch of input data, pass it through both the encoder and decoder, and then return the output.\n", 131 | "\n", 132 | "Note that this function is very similar to 
the *loss_batch* function defined in *src/seq2seq/training_utils.py*. Here it returns the outputs instead of the loss for training/validation purposes." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "def inference(data, models, device, use_attention=False, norm_quaternions=False):\n", 142 | " encoder, decoder = models\n", 143 | " input_batch, target_batch = data\n", 144 | "\n", 145 | " input_batch = input_batch.to(device)\n", 146 | " target_batch = target_batch.to(device)\n", 147 | "\n", 148 | " seq_length = target_batch.shape[1]\n", 149 | "\n", 150 | " input = input_batch.permute(1, 0, 2)\n", 151 | " encoder_outputs, encoder_hidden = encoder(input)\n", 152 | "\n", 153 | " decoder_hidden = encoder_hidden\n", 154 | " decoder_input = torch.ones_like(target_batch[:, 0, :]).unsqueeze(0)\n", 155 | " \n", 156 | " outputs = torch.zeros_like(target_batch)\n", 157 | "\n", 158 | " for t in range(seq_length):\n", 159 | "\n", 160 | " if use_attention:\n", 161 | " decoder_output, decoder_hidden = decoder(\n", 162 | " decoder_input, decoder_hidden, encoder_outputs)\n", 163 | " else:\n", 164 | " decoder_output, decoder_hidden = decoder(\n", 165 | " decoder_input, decoder_hidden)\n", 166 | " \n", 167 | " target = target_batch[:, t, :].unsqueeze(0).double()\n", 168 | " \n", 169 | " output = decoder_output\n", 170 | "\n", 171 | " if norm_quaternions:\n", 172 | " original_shape = output.shape\n", 173 | "\n", 174 | " output = output.view(-1,4)\n", 175 | " output = F.normalize(output, p=2, dim=1).view(original_shape)\n", 176 | "\n", 177 | " outputs[:, t, :] = output\n", 178 | " \n", 179 | " decoder_input = output.detach()\n", 180 | "\n", 181 | " return outputs" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "### Sampling some data and running it through the inference function\n", 189 | "\n", 190 | "We can now sample some data and 
use our model to perform inference.\n", 191 | "\n", 192 | "We only trained the model to predict a single posture and did not use much training data as this is just a tutorial.\n", 193 | "\n", 194 | "Note we used normOrientations as output, so we have to put the orientation back into the original frame using the root's (pelvis) orientation." 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 5, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "i = 15330\n", 204 | "\n", 205 | "inp = torch.cat((orientationInputs[i:i+seq_length:downsample, :], accelerationInputs[i:i+seq_length:downsample, :]), dim=1)\n", 206 | "inp = inp.sub(mean).div(std_dev).double()\n", 207 | "\n", 208 | "out = fullBodyOrientations[i:i+seq_length:downsample, :].double()\n", 209 | "out = out[-1,:].unsqueeze(0) # trained our model to predict only a single pose\n", 210 | "\n", 211 | "data = (inp.unsqueeze(0), out.unsqueeze(0))\n", 212 | " \n", 213 | "output = inference(data, models, device, use_attention=True, norm_quaternions=True)\n", 214 | "\n", 215 | "full_body = fullBodyOrientations[i:i+seq_length:downsample,:].clone()\n", 216 | "full_body = full_body[-1, :].unsqueeze(0)\n", 217 | "\n", 218 | "seq2seq_body = output.clone().squeeze(0)\n", 219 | "\n", 220 | "root_motion = root[i:i+seq_length:downsample, :].clone()\n", 221 | "root_motion = root_motion[-1, :].unsqueeze(0)\n", 222 | "\n", 223 | "root_motion = root_motion.unsqueeze(1).repeat(1, full_body.shape[1]//4, 1)\n", 224 | "\n", 225 | "full_body = quat_mul(root_motion, full_body.view(-1, full_body.shape[1]//4, 4)).view(full_body.shape)\n", 226 | "seq2seq_body = quat_mul(root_motion, seq2seq_body.view(-1, seq2seq_body.shape[1]//4, 4)).view(seq2seq_body.shape)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "Now that we have the ground truth posture and our seq2seq output, we can compare the motion using the Skeleton. 
\n", 234 | "\n", 235 | "The output won't look that great because we only used a single participant and a single epoch to train the model." 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 6, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAD7CAYAAAC7WecDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAcqklEQVR4nO3deXCc9Z3n8c/Tp87WYUm25UM22NhcNoc9mEA4DANsiJ3YGGZgHZINU7XLMpNkEmonm0zYmmy2andnSFKQDJlkLmoIkwwYL5gzmAUvp4MJwYAxtoMPWZasw7rVre5+nmf/aHVLji+1ju5f9/N+Vbkw7qcff4V49Onfbbmu6woAAOSVL98FAAAAAhkAACMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAJ4KZNu21d7eLtu2810KgJMYHh5WR0eHHMfJdylAznkqkHt6ejRz5ky1tbXxwAMG6uzsVENDgzo7O+W6br7LAXLKU4GclkgkNDg4SEsZMFQ0GtXw8DChDE/xZCAHAgFCGTCYZVmEMjzHk4EspUI5mUwSyoCB/H4/oQzP8WwgW5ZFKAMGGxvKsViMUEbR82wgS4QyYLp0KMdiMUIZRS+Q7wLybWwoJ5NJWZaV75KQQ5Zl8T03nN/vl23bSiQSCofDhLKHeO359HwgS6OhLKXWQcI7LMtSOBz21ENfiPx+fyaYk8lkvstBjnjt+SSQR6S/4V75xkNyXTfzi++7+XhGvcWLzyeBPIZXvukYRfdnYfFaF6bXee359PSkLgAATEEgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkADGX39GjozTfzXQZyJJDvAgAAJ0ocblHz2rWy/H41vfKyfOXl+S4J04wWMgAYKDCnUYE5jXLjcQ29+lq+y0EOEMgAYCDLslR+3fWSpMGXtua5GuQCgQwAhiq//jpJ0tCrr8mJxfJcDaYbgQwAhgqff74Cs2fLjUYVfeONfJeDaUYgA4ChLMtS2XWrJUmDW1/KczWYbgQyABisYmQceej/bZObSOS5GkwnAhmnFD9wQP3PPJPvMgBPC1+0XP66Ojn9A4pu357vcjCNCGScVLKjQ213362Ob31b/Vu25LscwLMsn0/lq6+VNI5ua8eW7HgOqsJ0IJBxAmdgQG333KPkkVYF5s9T2RVX5LskwNPKrx9Z/vTyy3KTyeNf7Dsi33s/V2DzXQr9cIlCD1woHfskD1VistipC8dxEwkd/fo3FP94j/y1tZr90EPy19bm
uyzA00ouvVS+6mo5PT2KvfO2yhri8n3yf+Xb/4p8HR+dcH3gpfuUvPWRPFSKySCQkeE6jjruu0/R7dtllZVp1o9/pODcuSe7ULLoXAFyxQoE5B8JZP8v7lSooTfzmitLbuMlcs5aLbd+qQJP/kf59/1Kzr6tchZdn8eqkS0CGRlD27Zp4NnnJL9fs+++QWVdz8h64WFpsEPWUKc02Jn6Z6xXzsVfUPKmv8l3yUDRStgJvdLyip7Y94T2HHxHPzuQGhsureqXW94g56zVqV8LrpLKRnux7JYdCvz6Iflf+o6chVdJ/lC+vgRkiUCGbMfWW21vadarL8qSVLOwT1Utfyu1nPo9vt8+Il3xDalyds7qBLzgcP9hbf5ks7Z8skXHho9Jki5ucSRJocqkfPPPV/xLL0i+k//4tq+8V/4PH5fv2O/k3/Ez2Zfdk7PaMTkEsoe5rqs3Wt/Qg+89qH29+/TXb7hqklTWMCxnzgq5DefJLauTW1YvldfJLauTyusUeO4b8h3+tfzv/1L2p76W7y8DKHhJJ6ltLdv0xL4ntP3o6NKm+tJ63dR0kxLb/nnkD5JK/rv7TxnGkqRwpZLXfEfBZ74i/2t/I/v8DVLFzOn9AjAlCGSP2nVslx747QPa0b5DklQy7Gpemy1JCp89S4nbN0nB0pO+116+MRXI7z0q+/KvMJ4MTNKPd/5Y/7L7XyRJlixdPvtyrT97va5svFIBSVub/0mS9O7i+Zo7+6Iz3s+58DY5v/ln+Vp/o8Ar31Pysw9OZ/mYIvwk9ZjDA4f1rTe+pTt/dad2tO9Q0BfUxiUbtabrLPlcaSDiyLr1+6cMY0lylq6RG6qQ1XNA1iH21wUmazAxKElaNmOZnvzsk3rg6gd0zdxrFPAFJF9AH563TLuXBDSw6kvju6HlU/IP/4ckyf/+L2W17JimyjGVCGSPSDpJff/d72vDsxv0q0O/kiVLn1nwGT1x8xP62sVf0yVHA3Ik7Z0bltt05elvFiqXc946SZL/vZ9Pf/FAkVs5c6UkKWpH1VjReMLrO879iv783P+pqnOuG/c93TmXyr7wjyVJgRe/nVodAaMRyB7xZuubevTjR5V0klo1a5UeufERfXfVdzW7PDUp69fX3KdvrPsjtd/8vXHdz16+UZLk+/gZKdZ7hqsBnM6lDZdKkvb27FV3rPuE1xO2K0kK+Kys7pu85ttyQxXytb4r3/u/nHyhmFYEskcsrV0q38hY772X3KslNUuOe72tN67d7kqVLzp3XPdzZ18kp/5cWcmYfB9umvJ6AS+pLanVoqpFkqR32t854XXbGQlkf3aBrIqZsq+8N/Xel/+7FOubXKGYVgSyR9SX1uvyWZdLkp7e//QJr7f0pg4/b6wuGd8NLUvO8n8viW5rYCqku63fbn/7hNfSgezPsoUsSfaKP5FTu0jWUKf8r7N3gMkIZA9Zc9YaSdIzB56R7djHvdY6Eshzqk49mev32edvkOsPyXf0fVlt709doYAHrZi5QpL09tETAzkxEsjBCQSy/CElr08NRfl3/L2szj0TLxLTikD2kKsar1JVqEod0Q691fZW5s+jcVvHBlPnrI67hSxJZbVyzvmMJFrJwGRdUn+JfJZPh/oPqX2o/bjXJtNCliT37NWyF98ky0kqsPUvJdeddL2YegSyh4T8Id3UdJMkacv+0SMV2/uHJUmlQZ8iJdktTbeX3yFJ8u3aJCWiU1Qp4D2VoUotrVkqSZn9AdKG4qkTnkpD/gnfP3ndX0lS6kCKfS9M+D6YPgSyx6w9a60kaVvLNvUM90iSOgdTe+TWVYRlWdl9AncXXCW3ap6sWG9qxjWACcuMI/9et3V/LBXIleFJ7OVU3pD5bfDxOzk32UAEsscsqVmic6rPUcJJ6PmDz0uSOgdSLeS6iglsQm/5Mmsd/TsfnbI6AS9a2TAayO5It7LjuBqMp+Z8REqCE773CWPHA+0nvxB5QyB7ULqVnO627hxIt5AndiqM
vex2ubLkO/ia1L1/aooEPGh5/XIFfAG1DbWpZTB1usvAcDIz5FuZ5ZDSWG7jxRr++ieK//Fjim98Sqo6ydGqyCsC2YNuarpJAV9AH3d/rD3dezKBPKN8gse0Vc3N7Gft/y2HogMTVRoo1YUzLpQk7TiaGkfuG+muDgV8CgUm+SM7XCF34dVy562a3H0wLQhkD6oOV+uqxqskpVrJR3pG1iBXZTHDeqxojyw31aUWeOtByUlOSZ2AF61oOH75U8fIkFL9BHuwUDgIZI9Kd1s/d/A5HerulyTNqx3/GuTjlFYrsebHmX/lwAlg4tITu3a075DrujralwrkmZFwPstCDnD8oketmrVKdSV16ox1aij6jqQlmlszwUCW5Fxwq5Jde+Xb+4Lc6gVTVifgNRfMuEBhf1hdsS7t79uvtr7Uj+mZlRPswULBoIXsUQFfQDcvvFmSNBR+U5I0t3rigSxJ9tXfUuJPtknV8yddH+BVIX9IF9Wlzjx+++jbmRbyLFrIRY8WsoetWbhGD3/0sPwVe1ReNqiasokvqQAwdVbOXKntR7frF3t+IXtwh0L1Uqua9PT+PaoKVSkSjqgqVKWqUJUqQ5Wpc5NR8PguetiCyAItrDhP+wd2qaJupyzrcye9znZsxexY6lcypmgymvl9zI6pIlihi+ovynH1QPG6fPbl+tHOH6l5oFlSs8J10us90uvbT359RbAiFdDhKkVCEUVCEd3YdKOunnN1TuvG5BDIHre86jrtH9ilaNlW3bX1wGjQjoRtNBlV3Dn9jj4rGlboJ6t/kqOKgeK3pGaJHrr2Ie3r2aefvrFLXdEerVoUVjg0rN54r3qHe9Ub79VAYkCSNJAY0EBiILN2WZLOrR3fUaowB4HscU3hT8m1fyrHP6j3Ot874/Ul/hKVBkpH/xkoUVOkKQeVAt6ycuZKrZy5Uj995jUN98b0p59bqWVzqo67JukkNZAYUM9wj/rifZmg7ov36eL6i/NUOSaKQPa4WDyooQP/WSvPGdSXVi1SSaBEpf5U0KaDtyRQohJ/icL+7Pe6BjA5PdHUSWzVpSfO8Qj4AqoOV6s6XJ3rsjANCGSP64sm5cRnaknFPK2et0SS5LouwQsYIJ50NDSyj3XVSQIZxYVA9pi+jna17t2tWH+fov198r/5ie7q61PklUv12LuDig0kFI8mdf5VjVq1bmG+ywU8xXFsHXzvXfV3dSg20K+O1i596Uiz/PLp5b+t1/BgUsNDCQVCft38Zxeoqn5ySxVhFgLZQ+xkUv/23/6LYgP9mT8LjPxSzxz1xkY3mz/4/jECGcixD19+Udse/tlxf1YpSQqp89BA5s/iUVt73mrXyjXM3ygmBLKHxAb6U2FsWVr8B59SSWVERzoc9Q8HtXDJeTp/2VIlYkm9+Pe7FY+xHzWQa46d6p4uq67RgosuleMv1YGDcVllFbruhiUqqUgF81ub9+vAzk6t+Ox8hpeKCIHsIcODqU/Y4bJy3XjP1096zWBPalegRMxmLBnIscoZ9ZKkipparf7y3Se9praxTL9+6oB622Pqbh1SbWN5LkvENGLrTA9Jd1WXVFSc8ppgiV+S5Niu7KSbk7oApETqGyRJfZ0dp7wmVBLQnKWpWdUHdnblpC7kBoHsIbGBVAu5pKLylNcEQ/7M7xN0WwM5lW4hx/r7FI9FT3ndwuUzJEkH3iOQiwmB7CGZFnL5qVvIls9SMJwK5XjUzkldAFLC5eUKl6W6oPtP00qef0GtLJ907MiQejtOHdwoLASyh8QGz9xClqRQaSqQE8MEMpBrlXWpbuvTBXJJeVCzF6V27aLbungQyB6SmdR1mjFkaXQcmZnWQO5F6lPd1qcLZIlu62JEIHvI6KSuM7SQS1KT7xMxWshArqXHkfs62097XdOyGZIldRwc0ED3cC5KwzQjkD0kM6nrNGPI0pgWMmPIQM6NdlmfPpDLIiHNXBiRRLd1sSCQPSTdQg6Xn6mFnO6yJpCBXEt3
WZ9u6VPagmW1kgjkYkEge8hwZlLX+FrILHsCci/TQu44fQtZkhaMjCMf/V2fov2nP7cc5iOQPWQ865Cl0TFkWshA7kVGAjna36fE8OnHhitrS1Q3r0Kum9p/HoWNQPaQ8ezUJY2dZU0gA7kWLi9XqLRM0plnWkujrWRmWxc+AtkjkvG4kvHUp+2ScY4h02UN5EdlXXrp05m7rdPLn1r29Gp4iGe2kBHIHjE8NChJsnw+hcrKTnstXdZAfqW7rc+09EmSqhpKVTOrTK7j6tAHdFsXMgLZI0ZnWFec8QSn0UldBDKQD5Xj3BwkLdNtzWzrgkYge8R49rFOY9kTkF+RcWyfOVY6kA/v7mHL2wJGIHvEeGdYSyx7AvKtMosuayl1RnKkrkR2wlHzru7pLA3TiED2iOHB0S7rMwmVMoYM5NPopK7xtZAty2K2dREgkD0imxZyaMwYsuu601oXgBNFRgJ5qLcnszriTBYsSwVy865jSiacaasN04dA9ojxrkGWRrusHduVzYMN5Fy4vELBklJJ428l18+vUHl1SIlhRy0f90xneZgmBLJHjPcsZEkKhvzSyERsuq2B3LMsK6s9rSXJ8lmZVjLd1oWJQPaIzD7W4xhDtnyWgmGWPgH5VJnlTGtpdLb1oQ+OybHp3So0BLJHZNYhj6PLWhozjswSCiAvxnsu8lgzz4qopCKo4aGkjuztm67SME0IZI/IZlKXNPZMZJY+AflQUZtq7Q4cG3/3s89nqenC1JGMhz5k165CQyB7RDYbg0hsnwnk21BvamJWWaQqq/fNmFM+8n6OYyw0BLIHuK478RYygQzkRXdriySpenZjVu+jd6twEcgeEO3vk52IS5al8uqacb0nxH7WQF71jARyzew5Wb0vvbEPz27hIZA9oK/jqCSpoqZW/mBwXO/JzLJmUheQc8n4cGa5U03j3KzeO7oXPS3kQkMge0BfR2qWZqS+YdzvYT9rIH962lol11W4vEKllZGs3psJ5CgfpgsNgewB6UCurJs57vfQQgbyJzN+PKvxjMel/j72oi9cBLIH9HemW8j1434PRzAC+dPTekSSVNOY3fixNNq7ZSccNgcpMASyB0ysy5qJIUC+dE9wQpc0umRRotu60BDIHpCe1BWpn0CXNYEM5NxkAtnntxQIpX60M7GrsBDIRc5xbPV3dkqSInXjbyGzdSaQH67rjnZZZ7kGOY2JXYWJQC5yg93dcuykfP6Aymtrx/0+NgYB8mOw+5gSwzH5/H5FGmZN6B5BJnYVJAK5yKW7qytn1Mnn84/7fUFayEBedB85LEmKNMyUPxA4w9UnF2LZYkEikIvcRCZ0SWPHkHmggVzqbhvprp6V/fhxWmbpE13WBYVALnKZNchZBvLoGLIj13GnvC4AJ9d9ZGRC1wSWPKWxW1dhIpCL3GgLefwzrKXRFrIkJeJ8ygZypSdzqMTEA5k5IIWJQC5y/Z3pJU/ZtZD9QZ8sX2qHIJY+AbkzmSVPaZnjU+myLigEcpHLtJCzWPIkSZZlsVsXkGPxWFQDx7okTXzJkySFSpkDUogI5CJmJxIa6D4mKfsWsjT2gAkCGciFnrZWSVJpZWTcZ5efDC3kwkQgF7H+rg7JdRUIhVUaqcr6/RwwAeTWVEzokkZbyEzqKiwEchEbndBVn/WJMRITQ4Bc62lLn/I0yUAuYWOQQkQgF7GJHLs4VogWMpBTU9VCzgw30WVdUAjkIjZ6qET248fS2DFkur2AXJiKGdYSXdaFikAuYhPdpSuNE5+A3HEdJzOpa9KBzKSugkQgF7G+zoltCpLGsicgd/q7OmUn4vIFAqqsr5/UvTLLnoZtdtorIARyEZt0C5kDJoCcSXdXV8+cndVBMCeT3sta4vktJARykYrHoor190nKflOQtOBItxdd1sD0m6oJXZLkD/jkD6RWVtBtXTgI5CLV39khSQqXlStcXj6he4RoIQM5k56EOZkNQcZKf6DuPjo0JffD9COQi9Rk
Z1hLo5O6GEMGpt+sxUskSR+9+rI6Dx2Y9P3KIkFJ0gs/2aWnfrBTH791lA/XhiOQi1R6/LhiRt2E75EeQ472x+XYzpTUBeDkFl92hRZcvEJOMqkX/+4B2YnEpO531R2L1bSsVpbPUvuBfr36r/v06Hfe1mu/3KeOQ/1yXSZ7mYZALlLl1TWSpJbdHyra1zuhe5RFQpKknraonvhfv1Xzru4pqw/A8SzL0uov362Syoi6mg9q+xO/mNT96uZV6A/vOle3/9UKrVzTpEhdiRLDtna/cVRP3r9T/+ev39OuV1s1PMRaZVMQyEXqrBWXqb5poeJDQ3pr079O6B4z5pbryj86W+HygHqORvXC3+3S8w99qO5WxqSA6VBWVa1r/8N/kiT95tkndeTjjyZ/z0hIy6+fq1u/fYk+86fn6+xL6+QPWOpqGdQbj3+iR+97W9se2aOj+/sm/XdhcgjkIuXz+fXpjXdJkj58Zas6DnyS9T0sy9LST83SbX95qS68tlE+v6XDu3v0xP9+V68/9jvFBibXpQbgRGevuExLr7xGcl1t/dmDikejU3Jfy2epcXG1rr1ziW7/7kqtWr9QNbPLZCcc7X27Q7tfb5uSvwcTRyAXscYl52rxqisk19WrP//HCY8ZhcsCuuzzC3XLf71YTctq5TrSR6+16d++947ef7lFdpLxZWAqfXrjl1U5o04DXV1q3bt7yu9fUh7UBVc3av1fXKS1f75M56xq0NIrZk3534PsWK6HRva7urpUV1en5uZmVVWdeBxhOBwuuokO/V2d+vlf/JmS8bhuvOfrWnzZFZO+55G9vdq+eb+6WgYlSZG6Ev3B5xao6cLaCZ0qlS+u68p1XYXDYfl8fDY1QUtLi+bOnavDhw8rEokc91ogEFAwGJRte2OmcOvejxUIhVTftDDfpeSFF59Pb3yVHlY5o06XfHadJKn9k31Tcs/GxVX63L3L9enbF6k0ElRfZ0xb/2G3tv7D7qL7QAPky+zFSzwbxl4VOPMl3uA4TuYTWbG56Ka1mnf+cs1adM6UfX2WJZ1zWYMWLJ+hnS8d1gcvH1HDgtSGBoXy3/Cxxx7TmjVrFA6H810KzsBxnKJ+RnGigwcPqrm5WatXr853KTlDl7Uk27Zl27YCgUBBdbmaZKB7WKUVQfmDhdHpEo1GdcMNN2j+/Pl67LHHVFJSku+SoJN3WTuOo2QyKcuyFAwGCWSPePzxx/XVr35Vmzdv1vXXX5/vcnLC8y3kdBj7fD75/X4CeYKq6sryXUJWKioqtGXLFq1du1YbNmzQpk2baCkbaGwY+/1+z4wlQrrtttsUi8W0bt06bdq0STfccEO+S5p2ng7ksWEcDAZ52D2mrq5OTz/9tNauXav169dr06ZNtJQN8vthHAwG810ScsiyLH3xi1+Uz+fLPJ833nhjvsuaVp5NIMIYklRbW6unnnpKra2tWr9+vWKxWL5LgghjpFiWpS984Qu6//77dcstt+j555/Pd0nTypMp5DgOYYyM2tpabdmyRUePHtW6desIZQPYtk0YQ1IqlDdu3Kgf/OAH2rBhg5577rl8lzRtPJlEjuMQxjhOTU2NtmzZoo6ODn3+859XdIp2R0J2xk7YIoyRZlmW7rjjDv3whz/UrbfeqmeffTbfJU0LT82y7ujoUENDg3bu3Kmampp8lwMD9fX1aePGjaqpqdHTTz+t0tLSfJfkKYcOHVJTU5M++OCDk27eA2zevFnf/OY39fDDD+uWW24pqom4ngrk5uZmzZ8/P99loEB0dHSorm7ix1ciewcOHNDChWyGgfHp7e09YUe3QuapQHYcR0eOHFFlZWVRfarC9OD/k9zjGUU2iu3/E08FMgAApmJGEwAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwA
gAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAb4/8ajl/QUiWwtAAAAAElFTkSuQmCC\n", 246 | "text/plain": [ 247 | "
" 248 | ] 249 | }, 250 | "execution_count": 6, 251 | "metadata": {}, 252 | "output_type": "execute_result" 253 | } 254 | ], 255 | "source": [ 256 | "skeleton = Skeleton()\n", 257 | "bodies = torch.cat((full_body, seq2seq_body), dim=0).float()\n", 258 | "skeleton.compare_motion(bodies, azim=0, elev=0)" 259 | ] 260 | } 261 | ], 262 | "metadata": { 263 | "kernelspec": { 264 | "display_name": "Python 3", 265 | "language": "python", 266 | "name": "python3" 267 | }, 268 | "language_info": { 269 | "codemirror_mode": { 270 | "name": "ipython", 271 | "version": 3 272 | }, 273 | "file_extension": ".py", 274 | "mimetype": "text/x-python", 275 | "name": "python", 276 | "nbconvert_exporter": "python", 277 | "pygments_lexer": "ipython3", 278 | "version": "3.7.4" 279 | } 280 | }, 281 | "nbformat": 4, 282 | "nbformat_minor": 2 283 | } 284 | -------------------------------------------------------------------------------- /src/build-dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Assistive Robotics Lab 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from common.data_utils import read_h5 9 | from common.logging import logger 10 | import h5py 11 | import argparse 12 | import sys 13 | import glob 14 | import numpy as np 15 | 16 | 17 | def parse_args(): 18 | """Parse arguments for module. 
19 | 20 | Returns: 21 | argparse.Namespace: contains accessible arguments passed in to module 22 | """ 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument("--training", 26 | help=("participants for training, space separated; " 27 | "e.g., W1 W2")) 28 | parser.add_argument("--validation", 29 | help=("participants for validation, space separated; " 30 | "e.g., P1 P2")) 31 | parser.add_argument("--testing", 32 | help=("participants for testing, space separated; " 33 | "e.g., P4 P5")) 34 | parser.add_argument("-f", 35 | "--data-path", 36 | help="path to h5 files for reading data") 37 | parser.add_argument("-o", "--output-path", 38 | help=("path to directory to save h5 files for " 39 | "training, validation, and testing")) 40 | parser.add_argument("-x", "--task-input", 41 | help=("input type; " 42 | "e.g., orientation, relativePosition, " 43 | "or jointAngle")) 44 | parser.add_argument("--input-label-request", 45 | help=("input label requests, space separated; " 46 | "e.g., all or Pelvis RightForearm")) 47 | parser.add_argument("-y", "--task-output", 48 | help="output type; e.g., orientation or jointAngle") 49 | parser.add_argument("--output-label-request", 50 | help=("output label requests, space separated; " 51 | "e.g., all or jRightElbow")) 52 | parser.add_argument("--aux-task-output", 53 | help=("auxiliary task output in addition " 54 | "to regular task output")) 55 | parser.add_argument("--aux-output-label-request", 56 | help="aux output label requests, space separated") 57 | 58 | args = parser.parse_args() 59 | 60 | if None in [args.training, args.validation, args.testing]: 61 | logger.info(("Participant numbers for training, validation, " 62 | "or testing dataset were not provided.")) 63 | parser.print_help() 64 | sys.exit() 65 | 66 | if None in [args.data_path, args.output_path]: 67 | logger.error("Data path or output path were not provided.") 68 | parser.print_help() 69 | sys.exit() 70 | 71 | if None in [args.task_input, 
def setup_filepaths(data_path, participant_numbers):
    """Set up filepaths for reading in participant .h5 files.

    For every participant, all files named ``<participant>_*.h5`` directly
    inside ``data_path`` are collected.

    Args:
        data_path (str): path to directory containing .h5 files
        participant_numbers (list): participant numbers for filepaths; empty
            entries (e.g., produced by splitting on repeated spaces) are
            skipped

    Returns:
        list: filepaths to all of the .h5 files, grouped by participant in
            the order given and sorted by name within each participant
    """
    all_filepaths = []
    for participant_number in participant_numbers:
        if not participant_number:
            # An empty id would otherwise turn into the pattern "_*.h5".
            continue
        # Escape the literal part so characters such as "[" or "*" in the
        # directory or participant id are not interpreted as wildcards.
        prefix = glob.escape(data_path + "/" + participant_number)
        # glob returns matches in arbitrary order; sort for reproducibility.
        all_filepaths.extend(sorted(glob.glob(prefix + "_*.h5")))
    return all_filepaths
def write_dataset(filepath_groups, variable, experiment_setup, requests,
                  output_path=None):
    """Write data to training, validation, testing .h5 files.

    For every source file, the arrays of all data groups listed under
    ``experiment_setup[variable]`` are concatenated along axis 1 and stored
    in the output file as ``<source filename>/<variable>``. Source files
    that lack one of the data groups are reported and skipped.

    Args:
        filepath_groups (list): list of tuples associate files to data group
        variable (str): the machine learning variable X or Y to be written to
        experiment_setup (dict): map to reference task for variable
        requests (dict): requests to read from files to store with variable
        output_path (str, optional): directory the .h5 files are written to.
            Defaults to None, in which case the module-level
            ``args.output_path`` set by the script entry point is used.
    """
    if output_path is None:
        # Fall back to the command-line arguments parsed in __main__.
        output_path = args.output_path

    data_groups = experiment_setup[variable]

    for filepaths, group in filepath_groups:
        logger.info(f"Writing {variable} to the {group} set...")
        h5_filename = output_path + "/" + group + ".h5"
        # Opened in append mode so X and Y can be written by separate calls.
        with h5py.File(h5_filename, "a") as h5_file:
            dataset = read_h5(filepaths, requests)
            for filename, file_data in dataset.items():
                arrays = []
                missing_group = None
                for data_group in data_groups:
                    if data_group not in file_data:
                        missing_group = data_group
                        break
                    arrays.append(file_data[data_group])

                if missing_group is not None:
                    logger.info(
                        f"{filename} does not contain {missing_group}")
                    continue
                if not arrays:
                    # Nothing requested for this variable; nothing to write.
                    continue

                h5_file.create_dataset(filename + "/" + variable,
                                       data=np.concatenate(arrays, axis=1))
class XSensDataIndices:
    """XSensDataIndices helps with retrieving data.

    XSens has a data layout which requires interfacing data measurements
    (position, orientation, etc.) with places (RightLowerLeg, RightWrist)
    where those measurements are available.
    """

    def __init__(self):
        """Initialize object for use in reading data from .h5 files."""
        sensor_group = ["sensorFreeAcceleration",
                        "sensorMagneticField",
                        "sensorOrientation"]

        segment_group = ["position", "normPositions",
                         "velocity", "acceleration",
                         "normAcceleration", "sternumNormAcceleration",
                         "angularVelocity", "angularAcceleration",
                         "orientation", "normOrientation",
                         "sternumNormOrientation"]

        joint_group = ["jointAngle", "jointAngleXZY"]

        joint_ergo_group = ["jointAngleErgo", "jointAngleErgoXZY"]

        groups = [sensor_group, segment_group,
                  joint_group, joint_ergo_group]
        self._labels_to_items(groups)

    def __call__(self, requests):
        """Retrieve indices using a request.

        Requests are dicts that use groups like "position" as keys and
        labels like ["Pelvis"]. Requests can be passed to an instance of
        XSensDataIndices to retrieve the indices for the labels of each group.

        Groups that are not known are left out of the result and reported
        with a warning, mirroring how unknown labels are handled.

        Args:
            requests (dict): maps groups to desired labels and will retrieve
                indices for those group labels.

        Returns:
            dict: a map of the groups to the indices for the labels.
        """
        label_indices = {}
        for label, items in requests.items():
            if label in self.label_items:
                label_indices[label] = self._request(label, items)
            else:
                warnings.warn(
                    "Requested group {} not recognized.".format(label))
        return label_indices

    def _labels_to_items(self, groups):
        """Build ``self.label_items``: measurement group -> valid item names.

        Args:
            groups (list): four lists of group names, ordered as sensor,
                segment, joint and ergonomic-joint groups.
        """
        self.label_items = {}
        sensors = ["Pelvis", "T8", "Head", "RightShoulder", "RightUpperArm",
                   "RightForeArm", "RightHand", "LeftShoulder", "LeftUpperArm",
                   "LeftForeArm", "LeftHand", "RightUpperLeg", "RightLowerLeg",
                   "RightFoot", "LeftUpperLeg", "LeftLowerLeg", "LeftFoot"]

        segments = ["Pelvis", "L5", "L3", "T12",
                    "T8", "Neck", "Head",
                    "RightShoulder", "RightUpperArm",
                    "RightForeArm", "RightHand",
                    "LeftShoulder", "LeftUpperArm",
                    "LeftForeArm", "LeftHand",
                    "RightUpperLeg", "RightLowerLeg",
                    "RightFoot", "RightToe",
                    "LeftUpperLeg", "LeftLowerLeg",
                    "LeftFoot", "LeftToe"]

        joints = ["jL5S1", "jL4L3", "jL1T12",
                  "jT9T8", "jT1C7", "jC1Head",
                  "jRightT4Shoulder", "jRightShoulder",
                  "jRightElbow", "jRightWrist",
                  "jLeftT4Shoulder", "jLeftShoulder",
                  "jLeftElbow", "jLeftWrist",
                  "jRightHip", "jRightKnee",
                  "jRightAnkle", "jRightBallFoot",
                  "jLeftHip", "jLeftKnee",
                  "jLeftAnkle", "jLeftBallFoot"]

        ergo_joints = ["T8_Head", "T8_LeftUpperArm", "T8_RightUpperArm",
                       "Pelvis_T8", "Vertical_Pelvis", "Vertical_T8"]

        item_groups = [sensors, segments, joints, ergo_joints]

        # Every group name shares the item list of its category.
        for group, items in zip(groups, item_groups):
            for label in group:
                self.label_items[label] = items

    def _request(self, req_label, req_items):
        """Return the column indices of ``req_items`` within ``req_label``.

        Each item occupies a run of consecutive columns, laid out in the
        order of the group's item list: four columns for the segment
        orientation groups and three for every other group.

        Args:
            req_label (str): a known measurement group.
            req_items (list): item names, or a collection containing "all"
                to select every item of the group.

        Returns:
            list: one list of column indices per item that was found.
        """
        valid_items = self.label_items[req_label]

        if "all" in req_items:
            req_items = valid_items

        num_valid_items = len(valid_items)
        # NOTE(review): "sensorOrientation" is not listed here and therefore
        # gets 3 columns per item -- confirm that matches the stored data.
        orientation_groups = ["orientation", "normOrientation",
                              "sternumNormOrientation"]
        dims = 4 if req_label in orientation_groups else 3

        indices = [list(range(i, i+dims))
                   for i in range(0, dims*num_valid_items, dims)]

        index_map = dict(zip(valid_items, indices))

        return self._find_indices(index_map, req_items)

    def _find_indices(self, index_map, items):
        """Look up ``items`` in ``index_map``, warning about unknown ones.

        Args:
            index_map (dict): item name -> list of column indices.
            items (list): requested item names.

        Returns:
            list: the index lists of the items present in ``index_map``, in
                request order.
        """
        mapped_indices = []

        for item in items:
            if item in index_map:
                mapped_indices.append(index_map[item])
            else:
                warnings.warn("Requested item {} not in file.".format(item))

        return mapped_indices
def read_h5(filepaths, requests):
    """Read data from an h5 file and store in a dataset dict.

    Primarily used for building a dataset (see build-dataset.py)

    Files are opened read-only, one at a time, and closed again as soon as
    their data has been copied out. Files that cannot be opened are logged
    and skipped.

    Args:
        filepaths (list): list of file paths to draw data from.
        requests (dict): dictionary of requests to make to files.

    Returns:
        dict: dictionary containing files mapped to labels mapped to data
    """
    xsensIndices = XSensDataIndices()
    indices = xsensIndices(requests)

    # The column selection is the same for every file: flatten the per-item
    # index lists of each label into one sorted list of columns.
    label_columns = {
        label: sorted(index
                      for item_indices in indices[label]
                      for index in item_indices)
        for label in indices
    }

    dataset = {}
    for filepath in filepaths:
        try:
            # Read-only: nothing is written back to the source files.
            h5_file = h5py.File(filepath, "r")
        except OSError:
            logger.info(f"OSError: Unable to open file {filepath}")
            continue

        filename = os.path.basename(filepath)
        # The context manager closes the file even if a read below raises.
        with h5_file:
            dataset[filename] = {}
            for label, columns in label_columns.items():
                file_data = np.array(h5_file[label])
                # NOTE(review): this swaps the two axis lengths with reshape,
                # not with a transpose -- confirm this matches how the files
                # were written.
                file_data = file_data.reshape(file_data.shape[1],
                                              file_data.shape[0])

                dataset[filename][label] = np.array(file_data[:, columns])

    return dataset
def load_dataloader(args, set_type, normalize, norm_data=None, shuffle=True):
    """Create dataloaders for PyTorch machine learning tasks.

    Args:
        args (argparse.Namespace): contains accessible arguments passed in
            to module
        set_type (str): set to read from when gathering data (either training,
            validation or testing sets)
        normalize (bool): whether to normalize the data before adding to
            dataloader
        norm_data (tuple, optional): if passed will contain mean and std_dev
            data to normalize input data with. Defaults to None, in which
            case the statistics are computed from this set and also written
            to ``normalization.h5`` inside ``args.data_path``.
        shuffle (bool, optional): whether to shuffle the data stored in the
            dataloader. Only the training set is ever shuffled; validation
            and testing data keep their order regardless of this flag.
            Defaults to True.

    Returns:
        tuple: returns a tuple containing the DataLoader and the normalization
            data
    """
    file_path = args.data_path + "/" + set_type + ".h5"
    seq_length = int(args.seq_length)
    downsample = int(args.downsample)
    batch_size = int(args.batch_size)
    in_out_ratio = int(args.in_out_ratio)
    # Evaluation sets always use a fixed stride of half a sequence.
    stride = int(args.stride) if set_type == "training" else seq_length//2

    logger.info((f"Retrieving {set_type} data "
                 f"for sequences {int(seq_length/240*1000)} ms long and "
                 f"downsampling to {240/downsample} Hz..."))

    X, y = read_variables(file_path, args.task, seq_length, stride, downsample,
                          in_out_ratio=in_out_ratio)

    if normalize:
        if norm_data is None:
            mean, std_dev = X.mean(dim=0), X.std(dim=0)
            norm_data = (mean, std_dev)
            with h5py.File(args.data_path + "/normalization.h5", "w") as f:
                f["mean"], f["std_dev"] = mean, std_dev
        else:
            mean, std_dev = norm_data
        # The small epsilon guards against division by a zero deviation.
        X = X.sub(mean).div(std_dev + 1e-8)

    logger.info(f"Data for {set_type} have shapes "
                f"(X, y): {X.shape}, {y.shape}")

    # Regroup the flat frame rows into (sequence, frame, feature).
    X = X.view(-1, math.ceil(seq_length/downsample), X.shape[1])
    y = y.view(-1, math.ceil(seq_length/(downsample*in_out_ratio)), y.shape[1])

    logger.info(f"Reshaped {set_type} shapes (X, y): {X.shape}, {y.shape}")

    dataset = TensorDataset(X, y)

    # Honor the caller's flag, but never shuffle evaluation data.
    shuffle = bool(shuffle) and set_type == "training"
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=True)

    logger.info(f"Number of {set_type} samples: {len(dataset)}")

    return dataloader, norm_data
27 | 28 | Returns: 29 | torch.Tensor: average angular difference in degrees between 30 | quaternions in predictions and targets. 31 | """ 32 | predictions = predictions.contiguous().view(-1, 1, 4) 33 | targets = targets.contiguous().view(-1, 4, 1) 34 | 35 | inner_prod = torch.bmm(predictions, targets).view(-1) 36 | 37 | x = torch.clamp(torch.abs(inner_prod), min=0.0, max=1.0-1e-7) 38 | 39 | theta = torch.acos(x) 40 | 41 | return (360.0/math.pi)*torch.mean(theta) 42 | -------------------------------------------------------------------------------- /src/common/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Assistive Robotics Lab 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | import h5py 11 | from .logging import logger 12 | from .rotations import quat_to_rotMat, rotMat_to_quat 13 | from .quaternion import quat_fix 14 | 15 | 16 | def add_normalized_positions(filepaths, new_group_name): 17 | """Add position data normalized relative to the pelvis to h5 files. 
18 | 19 | Args: 20 | filepaths (list): paths to files to add data to 21 | new_group_name (str): what the new group will be called in the h5 22 | file 23 | """ 24 | for filepath in filepaths: 25 | try: 26 | h5_file = h5py.File(filepath, "r+") 27 | except OSError: 28 | logger.info(f"OSError: Unable to open file {filepath}") 29 | continue 30 | 31 | quat = np.array(h5_file["orientation"][:, :]) 32 | quat = quat.reshape(quat.shape[1], quat.shape[0]) 33 | quat = quat.reshape(quat.shape[0], -1, 4) 34 | 35 | pos = np.array(h5_file["position"][:, :]) 36 | pos = pos.reshape(pos.shape[1], pos.shape[0]) 37 | pos = pos.reshape(pos.shape[0], -1, 3) 38 | 39 | quat = quat_fix(quat) 40 | 41 | norm_pos = np.zeros(pos.shape) 42 | 43 | pelvis_rot = np.linalg.inv( 44 | quat_to_rotMat(torch.tensor(quat[:, 0, :])) 45 | ) 46 | pelvis_pos = pos[:, 0, :] 47 | for i in range(0, quat.shape[1]): 48 | relative_pos = np.expand_dims(pos[:, i, :] - pelvis_pos, axis=2) 49 | norm_pos[:, i, :] = np.squeeze(np.matmul(pelvis_rot, relative_pos), 50 | axis=2) 51 | 52 | norm_pos = norm_pos.reshape(norm_pos.shape[0], -1) 53 | norm_pos = norm_pos.reshape(norm_pos.shape[1], norm_pos.shape[0]) 54 | 55 | try: 56 | logger.info(f"Writing to file {filepath}") 57 | h5_file.create_dataset(new_group_name, data=norm_pos) 58 | except RuntimeError: 59 | logger.info(("RuntimeError: Unable to create link " 60 | f"(name already exists) in {filepath}")) 61 | h5_file.close() 62 | 63 | 64 | def add_normalized_accelerations(filepaths, group_name, new_group_name, 65 | root=0): 66 | """Add acceleration data normalized relative to a root to the h5 files. 67 | 68 | Args: 69 | filepaths (list): paths to files to add data to 70 | group_name (str): acceleration group to normalize 71 | (typically acceleration, but can also be sensorFreeAcceleration) 72 | new_group_name (str): new group name for normalized acceleration data 73 | root (int, optional): index of root (e.g., 0 is pelvis, 4 is sternum). 74 | Defaults to 0. 
75 | """ 76 | for filepath in filepaths: 77 | try: 78 | h5_file = h5py.File(filepath, "r+") 79 | except OSError: 80 | logger.info(f"OSError: Unable to open file {filepath}") 81 | continue 82 | 83 | quat = np.array(h5_file["orientation"][:, :]) 84 | quat = quat.reshape(quat.shape[1], quat.shape[0]) 85 | quat = quat.reshape(quat.shape[0], -1, 4) 86 | 87 | acc = np.array(h5_file[group_name][:, :]) 88 | acc = acc.reshape(acc.shape[1], acc.shape[0]) 89 | acc = acc.reshape(acc.shape[0], -1, 3) 90 | 91 | quat = quat_fix(quat) 92 | 93 | norm_acc = np.zeros(acc.shape) 94 | 95 | root_rot = np.linalg.inv( 96 | quat_to_rotMat(torch.tensor(quat[:, root, :])) 97 | ) 98 | root_acc = acc[:, root, :] 99 | for i in range(0, acc.shape[1]): 100 | relative_acc = np.expand_dims(acc[:, i, :] - root_acc, axis=2) 101 | norm_acc[:, i, :] = np.squeeze(np.matmul(root_rot, relative_acc), 102 | axis=2) 103 | 104 | norm_acc = norm_acc.reshape(norm_acc.shape[0], -1) 105 | norm_acc = norm_acc.reshape(norm_acc.shape[1], norm_acc.shape[0]) 106 | 107 | try: 108 | logger.info(f"Writing to file {filepath}") 109 | h5_file.create_dataset(new_group_name, data=norm_acc) 110 | except RuntimeError: 111 | logger.info(("RuntimeError: Unable to create link " 112 | f"(name already exists) in {filepath}")) 113 | h5_file.close() 114 | 115 | 116 | def add_normalized_quaternions(filepaths, group_name, new_group_name, root=0): 117 | """Add orientation data normalized relative to a root to the h5 files. 118 | 119 | Args: 120 | filepaths (list): paths to files to add data to 121 | group_name (str): orientation group to normalize 122 | (typically orientation, but can also be sensorOrientation) 123 | new_group_name (str): new group name for normalized orientation data 124 | root (int, optional): index of root (e.g., 0 is pelvis, 4 is sternum). 125 | Defaults to 0. 
126 | """ 127 | for filepath in filepaths: 128 | try: 129 | h5_file = h5py.File(filepath, "r+") 130 | except OSError: 131 | logger.info(f"OSError: Unable to open file {filepath}") 132 | continue 133 | 134 | quat = np.array(h5_file[group_name][:, :]) 135 | quat = quat.reshape(quat.shape[1], quat.shape[0]) 136 | quat = quat.reshape(quat.shape[0], -1, 4) 137 | 138 | quat = quat_fix(quat) 139 | 140 | norm_quat = np.zeros(quat.shape) 141 | 142 | root_rotMat = np.linalg.inv( 143 | quat_to_rotMat(torch.tensor(quat[:, root, :])) 144 | ) 145 | for i in range(0, quat.shape[1]): 146 | rotMat = quat_to_rotMat(torch.tensor(quat[:, i, :])) 147 | norm_rotMat = np.matmul(root_rotMat, rotMat) 148 | norm_quat[:, i, :] = rotMat_to_quat(norm_rotMat) 149 | 150 | norm_quat = norm_quat.reshape(norm_quat.shape[0], -1) 151 | norm_quat = norm_quat.reshape(norm_quat.shape[1], norm_quat.shape[0]) 152 | 153 | try: 154 | logger.info(f"Writing to file {filepath}") 155 | h5_file.create_dataset(new_group_name, data=norm_quat) 156 | except RuntimeError: 157 | logger.info(("RuntimeError: Unable to create link " 158 | f"(name already exists) in {filepath}")) 159 | h5_file.close() 160 | -------------------------------------------------------------------------------- /src/common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the Attribution-NonCommercial 5 | # 4.0 International license and is borrowed from the QuaterNet library. 6 | # See https://github.com/facebookresearch/QuaterNet/blob/master/LICENSE for 7 | # more details. 8 | # 9 | 10 | import torch 11 | import numpy as np 12 | 13 | 14 | def quat_fix(q): 15 | """Enforce quaternion continuity across the time dimension. 
16 | 17 | Borrowed from QuaterNet: 18 | https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L119 19 | 20 | This function falls under the Attribution-NonCommercial 4.0 International 21 | license. 22 | 23 | Selects the representation (q or -q) with minimal distance 24 | (or, equivalently, maximal dot product) between two consecutive frames. 25 | 26 | Expects a tensor of shape (L, J, 4), where L is the sequence length and 27 | J is the number of joints. 28 | Returns a tensor of the same shape. 29 | 30 | Args: 31 | q (np.ndarray): quaternions of size (L, J, 4) to enforce continuity. 32 | 33 | Returns: 34 | np.ndarray: quaternion of size (L, J, 4) that is continuous 35 | in time dimension. 36 | """ 37 | assert len(q.shape) == 3 38 | assert q.shape[-1] == 4 39 | 40 | result = q.copy() 41 | dot_products = np.sum(q[1:]*q[:-1], axis=2) 42 | mask = dot_products < 0 43 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 44 | result[1:][mask] *= -1 45 | return result 46 | 47 | 48 | def quat_mul(q, r): 49 | """Multiply quaternion(s) q with quaternion(s) r. 50 | 51 | Borrowed from QuaterNet: 52 | https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L13 53 | 54 | This function falls under the Attribution-NonCommercial 4.0 International 55 | license. 56 | 57 | Expects two equally-sized tensors of shape (*, 4), where * denotes any 58 | number of dimensions. 59 | Returns q*r as a tensor of shape (*, 4). 
60 | 61 | Args: 62 | q (torch.Tensor): quaternions of size (*, 4) 63 | r (torch.Tensor): quaternions of size (*, 4) 64 | 65 | Returns: 66 | torch.Tensor: quaternions of size (*, 4) 67 | """ 68 | assert q.shape[-1] == 4 69 | assert r.shape[-1] == 4 70 | 71 | original_shape = q.shape 72 | 73 | # Compute outer product 74 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 75 | 76 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 77 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 78 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 79 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 80 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 81 | -------------------------------------------------------------------------------- /src/common/rotations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Assistive Robotics Lab 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | 11 | def quat_to_rotMat(q): 12 | """Convert quaternions to rotation matrices. 
13 | 14 | Using equation provided in XSens MVN Manual: 15 | https://www.xsens.com/hubfs/Downloads/usermanual/MVN_User_Manual.pdf 16 | 17 | Args: 18 | q (torch.Tensor): quaternion(s) to convert to rotation matrix format 19 | 20 | Returns: 21 | torch.Tensor: rotation matrix converted from quaternion format 22 | """ 23 | if len(q.shape) != 2: 24 | q = q.unsqueeze(0) 25 | 26 | assert q.shape[1] == 4 27 | 28 | r0c0 = q[:, 0]**2 + q[:, 1]**2 - q[:, 2]**2 - q[:, 3]**2 29 | r0c1 = 2*q[:, 1]*q[:, 2] - 2*q[:, 0]*q[:, 3] 30 | r0c2 = 2*q[:, 1]*q[:, 3] + 2*q[:, 0]*q[:, 2] 31 | 32 | r1c0 = 2*q[:, 1]*q[:, 2] + 2*q[:, 0]*q[:, 3] 33 | r1c1 = q[:, 0]**2 - q[:, 1]**2 + q[:, 2]**2 - q[:, 3]**2 34 | r1c2 = 2*q[:, 2]*q[:, 3] - 2*q[:, 0]*q[:, 1] 35 | 36 | r2c0 = 2*q[:, 1]*q[:, 3] - 2*q[:, 0]*q[:, 2] 37 | r2c1 = 2*q[:, 2]*q[:, 3] + 2*q[:, 0]*q[:, 1] 38 | r2c2 = q[:, 0]**2 - q[:, 1]**2 - q[:, 2]**2 + q[:, 3]**2 39 | 40 | r0 = torch.stack([r0c0, r0c1, r0c2], dim=1) 41 | r1 = torch.stack([r1c0, r1c1, r1c2], dim=1) 42 | r2 = torch.stack([r2c0, r2c1, r2c2], dim=1) 43 | 44 | R = torch.stack([r0, r1, r2], dim=2) 45 | 46 | return R.permute(0, 2, 1) 47 | 48 | 49 | def rotMat_to_quat(rotMat): 50 | """Convert rotation matrices back to quaternions. 51 | 52 | Ported from Matlab: 53 | https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/mhmublv/Motion/rotmat2quat.m#L4 54 | 55 | 56 | Args: 57 | rotMat (torch.Tensor): rotation matrix or matrices to convert to 58 | quaternion format. 
59 | 60 | Returns: 61 | torch.Tensor: quaternion(s) converted from rotation matrix format 62 | """ 63 | if len(rotMat.shape) != 3: 64 | rotMat = rotMat.unsqueeze(0) 65 | 66 | assert rotMat.shape[1] == 3 and rotMat.shape[2] == 3 67 | 68 | diffMat = rotMat - torch.transpose(rotMat, 1, 2) 69 | 70 | r = torch.zeros((rotMat.shape[0], 3), dtype=torch.float64) 71 | 72 | r[:, 0] = -diffMat[:, 1, 2] 73 | r[:, 1] = diffMat[:, 0, 2] 74 | r[:, 2] = -diffMat[:, 0, 1] 75 | 76 | sin_theta = torch.norm(r, dim=1)/2 77 | sin_theta = sin_theta.unsqueeze(1) 78 | 79 | r0 = r / (torch.norm(r, dim=1).unsqueeze(1) + 1e-9) 80 | 81 | cos_theta = (rotMat.diagonal(dim1=-2, dim2=-1).sum(-1) - 1) / 2 82 | cos_theta = cos_theta.unsqueeze(1) 83 | 84 | theta = torch.atan2(sin_theta, cos_theta) 85 | 86 | theta = theta.squeeze(1) 87 | 88 | q = torch.zeros((rotMat.shape[0], 4), dtype=torch.float64) 89 | 90 | q[:, 0] = torch.cos(theta/2) 91 | q[:, 1:] = r0*torch.sin(theta/2).unsqueeze(1) 92 | 93 | return q 94 | -------------------------------------------------------------------------------- /src/common/skeleton.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-present, Assistive Robotics Lab 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | from .rotations import quat_to_rotMat 10 | from .data_utils import XSensDataIndices 11 | import matplotlib.pyplot as plt 12 | import mpl_toolkits.mplot3d.axes3d as p3 13 | import matplotlib.animation as animation 14 | 15 | 16 | class Skeleton: 17 | """Skeleton for modeling and visualizing forward kinematics.""" 18 | 19 | def __init__(self): 20 | """Initialize the skeleton, using segment lengths from P1. 21 | 22 | Have to initialize segments, parents of the segments, and map them 23 | together. 
24 | """ 25 | segments = ["Pelvis", "L5", "L3", 26 | "T12", "T8", "Neck", "Head", 27 | "RightShoulder", "RightUpperArm", 28 | "RightForeArm", "RightHand", 29 | "LeftShoulder", "LeftUpperArm", 30 | "LeftForeArm", "LeftHand", 31 | "RightUpperLeg", "RightLowerLeg", 32 | "RightFoot", "RightToe", 33 | "LeftUpperLeg", "LeftLowerLeg", 34 | "LeftFoot", "LeftToe"] 35 | 36 | parents = [None, "Pelvis", "L5", 37 | "L3", "T12", "T8", "Neck", 38 | "T8", "RightShoulder", "RightUpperArm", "RightForeArm", 39 | "T8", "LeftShoulder", "LeftUpperArm", "LeftForeArm", 40 | "Pelvis", "RightUpperLeg", "RightLowerLeg", "RightFoot", 41 | "Pelvis", "LeftUpperLeg", "LeftLowerLeg", "LeftFoot"] 42 | 43 | body_frame_segments = [None, 44 | [0.000000, 0.000000, 0.103107], 45 | [0.000000, 0.000000, 0.095793], 46 | [0.000000, 0.000000, 0.087564], 47 | [0.000000, 0.000000, 0.120823], 48 | [0.000000, 0.000000, 0.103068], 49 | [0.000000, 0.000000, 0.208570], 50 | [0.000000, -0.027320, 0.068158], 51 | [0.000000, -0.141007, 0.000000], 52 | [0.000000, -0.291763, 0.000000], 53 | [0.000071, -0.240367, 0.000000], 54 | [0.000000, 0.027320, 0.068158], 55 | [0.000000, 0.141007, 0.000000], 56 | [0.000000, 0.291763, 0.000000], 57 | [0.000071, 0.240367, 0.000000], 58 | [0.000000, -0.087677, 0.000000], 59 | [-0.000055, 0.000000, -0.439960], 60 | [0.000248, 0.000000, -0.445123], 61 | [0.192542, 0.000000, -0.087304], 62 | [0.000000, 0.087677, 0.000000], 63 | [-0.000055, 0.000000, -0.439960], 64 | [0.000248, 0.000000, -0.445123], 65 | [0.192542, 0.000000, -0.087304]] 66 | 67 | self.skeleton_tree = [[0, 1, 2, 3, 4], 68 | [4, 7, 8, 9, 10], 69 | [4, 11, 12, 13, 14], 70 | [4, 5, 6], 71 | [0, 15, 16, 17, 18], 72 | [0, 19, 20, 21, 22]] 73 | 74 | self.segments = segments 75 | self.index_of = dict(zip(segments, range(len(segments)))) 76 | self.segment_parents = dict(zip(segments, parents)) 77 | self.segment_positions_in_parent_frame = dict(zip(segments, 78 | body_frame_segments)) 79 | 80 | def forward_kinematics(self, 
orientations): 81 | """Compute positions of segment endpoints using orientation of segments. 82 | 83 | Args: 84 | orientations (torch.Tensor): orientations of segments in the 85 | skeleton 86 | 87 | Returns: 88 | torch.Tensor: position of segment endpoints 89 | """ 90 | xsens_indices = XSensDataIndices() 91 | 92 | positions = torch.zeros([len(self.segments), 93 | orientations.shape[0], 94 | 3], dtype=torch.float32) 95 | 96 | for i, segment in enumerate(self.segments): 97 | parent = self.segment_parents[segment] 98 | 99 | if parent is None: 100 | continue 101 | else: 102 | indices_map = xsens_indices({"orientation": [parent]}) 103 | parent_indices = indices_map["orientation"][0] 104 | 105 | x_B = torch.tensor( 106 | self.segment_positions_in_parent_frame[segment], 107 | dtype=torch.float32) 108 | 109 | x_B = x_B.view(1, -1, 1).repeat(orientations.shape[0], 1, 1) 110 | 111 | R_GB = quat_to_rotMat(orientations[:, parent_indices]) 112 | positions[i] = (positions[self.index_of[parent]] + 113 | R_GB.bmm(x_B).squeeze(2)) 114 | 115 | return positions.permute(1, 0, 2) 116 | 117 | def animate_motion(self, orientations, azim, elev, title=None): 118 | """Animate frames of orientation data using forward kinematics. 119 | 120 | Args: 121 | orientations (torch.Tensor): orientations of the segments in the 122 | kinematic chain 123 | azim (float): azimuth of the plot point of view 124 | elev (float): elevation of the plot point of view 125 | title (str, optional): plot title. Defaults to None. 
126 | 127 | Returns: 128 | animation.FuncAnimation: returns an animation that can be saved or 129 | viewed in a Jupyter Notebook 130 | """ 131 | if len(orientations.shape) == 1: 132 | orientations = orientations.unsqueeze(0) 133 | 134 | def update_lines(num, data, lines): 135 | positions = data[num] 136 | 137 | for i, line in enumerate(lines): 138 | xs = list(positions[self.skeleton_tree[i], 0]) 139 | ys = list(positions[self.skeleton_tree[i], 1]) 140 | zs = list(positions[self.skeleton_tree[i], 2]) 141 | 142 | line.set_data(xs, ys) 143 | line.set_3d_properties(zs) 144 | line.set_linestyle("-") 145 | return lines 146 | 147 | fig = plt.figure() 148 | ax = p3.Axes3D(fig) 149 | 150 | if title is not None: 151 | ax.set_title(title) 152 | 153 | data = self.forward_kinematics(orientations) 154 | 155 | lines = [ax.plot([0], [0], [0])[0] for _ in range(6)] 156 | limits = [-1.0, 1.0] 157 | 158 | self._setup_axis(ax, limits, azim, elev) 159 | 160 | line_ani = animation.FuncAnimation(fig, 161 | update_lines, 162 | frames=range(data.shape[0]), 163 | fargs=(data, lines), 164 | interval=25, 165 | blit=True) 166 | plt.show() 167 | return line_ani 168 | 169 | def compare_motion(self, orientations, azim, elev, 170 | fig_filename=None, titles=None): 171 | """Display plots of different orientation frames. 172 | 173 | Primarily useful for plotting skeletons for orientation outputs from 174 | different models. 175 | 176 | Args: 177 | orientations (torch.Tensor): orientations of the segments in the 178 | kinematic chain, typically of orientations from different 179 | models. 180 | azim (float): azimuth of the plot point of view 181 | elev (float): elevation of the plot point of view 182 | fig_filename (str, optional): figure filename. Defaults to None. 183 | titles (str, optional): plot titles. Defaults to None. 184 | 185 | Returns: 186 | matplotlib.pyplot.figure: figure that will be displayed and saved 187 | if fig_filename is provided. 
188 | """ 189 | if len(orientations.shape) == 1: 190 | orientations = orientations.unsqueeze(0) 191 | 192 | def update_lines(num, data, lines): 193 | positions = data[num] 194 | 195 | for i, line in enumerate(lines): 196 | xs = list(positions[self.skeleton_tree[i], 0]) 197 | ys = list(positions[self.skeleton_tree[i], 1]) 198 | zs = list(positions[self.skeleton_tree[i], 2]) 199 | 200 | line.set_data(xs, ys) 201 | line.set_3d_properties(zs) 202 | return lines 203 | 204 | fig = plt.figure(figsize=(orientations.shape[0]*3, 3)) 205 | data = self.forward_kinematics(orientations) 206 | 207 | limits = [-1.0, 1.0] 208 | for i in range(orientations.shape[0]): 209 | ax = fig.add_subplot(1, 210 | orientations.shape[0], 211 | i+1, 212 | projection="3d") 213 | 214 | lines = [ax.plot([0], [0], [0])[0] for _ in range(6)] 215 | 216 | self._setup_axis(ax, limits, azim, elev) 217 | 218 | if titles is not None: 219 | ax.set_title(titles[i]) 220 | 221 | update_lines(i, data, lines) 222 | 223 | plt.subplots_adjust(wspace=0) 224 | plt.show() 225 | if fig_filename: 226 | fig.savefig(fig_filename, bbox_inches="tight") 227 | return fig 228 | 229 | def _setup_axis(self, ax, limits, azim, elev): 230 | ax.set_xlim3d(limits) 231 | ax.set_ylim3d(limits) 232 | ax.set_zlim3d(limits) 233 | 234 | ax.grid(False) 235 | 236 | ax.set_xticks([]) 237 | ax.set_yticks([]) 238 | ax.set_zticks([]) 239 | 240 | ax.view_init(azim=azim, elev=elev) 241 | -------------------------------------------------------------------------------- /src/matlab/get_folder_path.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2020-present, Assistive Robotics Lab 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | function [parentpath] = get_folder_path() 9 | % GET_FOLDER_PATH Returns the path to the parent folder. 
10 | % 11 | % parentPath = get_folder_path() returns the path to the parent folder 12 | % for ease of use in other files. 13 | % 14 | % See also MVNX_TO_CSV 15 | parentpath = cd(cd('..')); 16 | end -------------------------------------------------------------------------------- /src/matlab/joint_angle_segments.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2020-present, Assistive Robotics Lab 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | function [jointAngleSegments] = joint_angle_segments() 9 | % JOINT_ANGLE_SEGMENTS Returns a cell array containing the joint angles of 10 | % the XSens data file in proper order. 11 | % 12 | % jointAngleSegments = JOINT_ANGLE_SEGMENTS() is used elsewhere to 13 | % construct joint angle maps and fill in data. 14 | % 15 | % See also INIT_JOINT_ANGLE_MAP, READ_DATA 16 | % 17 | jointAngleSegments = {'jL5S1','jL4L3','jL1T12','jT9T8','jT1C7','jC1Head','jRightT4Shoulder',... 18 | 'jRightShoulder','jRightElbow','jRightWrist','jLeftT4Shoulder','jLeftShoulder','jLeftElbow',... 19 | 'jLeftWrist','jRightHip','jRightKnee','jRightAnkle','jRightBallFoot','jLeftHip','jLeftKnee','jLeftAnkle',... 20 | 'jLeftBallFoot'}; 21 | end -------------------------------------------------------------------------------- /src/matlab/joint_angles.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2020-present, Assistive Robotics Lab 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 
6 | % 7 | 8 | function[matrix]= joint_angles(values, frames, index) 9 | % JDEG Re-shuffles each joint angle in data to a joint 10 | % 11 | % matrix = JDEG(values, frames, index) will associate flexion, abduction, 12 | % and extension to the proper joint. 13 | % 14 | % See also READ_DATA 15 | matrix= zeros(frames, 3); 16 | for i= 1:frames 17 | matrix(i, :)= values(i).jointAngle(index:index+2); 18 | end 19 | end -------------------------------------------------------------------------------- /src/matlab/load_partial_mvnx.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2020-present, Assistive Robotics Lab 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | function [data,time,cal_data,extra] = load_partial_mvnx(filename,section,varargin) 9 | % This function allows matlab to read in larger mvnx files since it does 10 | % not try to load the entire file into memory which can be several GB 11 | % large. 12 | % 13 | % 14 | % data = load_partial_mvnx(filename,section) 15 | % Loads all of the frame data from the fields specified in section, with 16 | % multiple fields passed as a cell array. 17 | % 18 | % data = load_partial_mvnx(filename,section,n) 19 | % Loads n frames of the data. Where n<=0 and n=0 indicates to 20 | % read to the end of the file. If n is greater than the number of frames 21 | % then the data is read until the end of the file. 22 | % 23 | % data = load_partial_mvnx(filename,section,m,n) 24 | % Loads n frames of the data starting at frame m. 
25 | % 26 | % Inputs 27 | % filename - Type:string, path to the file to be read 28 | % 29 | % section - Types:string,cell array, labels of data to read 30 | % Valid value for field are 'segments','joints','orientation','position','velocity', 31 | % 'acceleration','angularVelocity','angularAcceleration','jointAngle','jointAngleXYZ' 32 | % 33 | % 'segments' and 'joints' cannot be used in the cell array. 34 | % Order does not matter in the cell array 35 | % Example section = 'position' or {'acceleration','position',...} 36 | % 37 | % m,n - Type:integers, m>=0, n>=0 38 | % 39 | % Outputs 40 | % data - Type:Varies based on arguements 41 | % 42 | % time - Type:struct time values for the frames, output only for sections 43 | % 'orientation','position','velocity','acceleration','angularVelocity', 44 | % 'angularAcceleration','jointAngle','jointAngleXYZ' 45 | % 46 | % cal_data - Type:struct calibration pose data, output only for sections 47 | % 'orientation','position','velocity','acceleration','angularVelocity', 48 | % 'angularAcceleration','jointAngle','jointAngleXYZ' 49 | % 50 | % extra - Type:struct other data from sections 'orientation','position', 51 | % 'velocity','acceleration','angularVelocity','angularAcceleration', 52 | % 'jointAngle','jointAngleXYZ' 53 | 54 | 55 | % preallocate some space probably will not be enough 56 | N=1000000; 57 | cells=cell(N,1); 58 | 59 | % number of calibration poses, aka frames with type other than 'normal', 60 | % this might need to be made variable in the future 61 | n_cal = 3; 62 | 63 | % error handling and setting up initial parameters 64 | frame_data = {'orientation', 'position', 'velocity', 'acceleration', 'angularVelocity', ... 65 | 'angularAcceleration', 'footContacts', 'sensorFreeAcceleration', ... 66 | 'sensorMagneticField', 'sensorOrientation', 'jointAngle', 'jointAngleXZY', ... 
67 | 'jointAngleErgo', 'jointAngleErgoXZY', 'centerOfMass'}; 68 | 69 | if iscell(section) 70 | parfor i=1:length(section) 71 | f_test(i)=any(strcmp(frame_data, section{i})); 72 | end 73 | else 74 | f_test=any(strcmp(frame_data, section)); 75 | end 76 | 77 | if iscell(section) 78 | ME=MException('Input:error','Supplied arg %s %s %s\nSyntax: load_partial_mvnx(filename,section,[[limit_start],][limit_stop])\nSupported section values are:\n Singular Arguments - segments, joints\n These can be applied individually or as a cell array - orientation, position, velocity, acceleration, angularVelocity, angularAcceleration, jointAngle, jointAngleXYZ',filename,section{:},varargin{:}); 79 | else 80 | ME=MException('Input:error','Supplied arg %s %s %s\nSyntax: load_partial_mvnx(filename,section,[[limit_start],][limit_stop])\nSupported section values are:\n Singular Arguments - segments, joints\n These can be applied individually or as a cell array - orientation, position, velocity, acceleration, angularVelocity, angularAcceleration, jointAngle, jointAngleXYZ',filename,section,varargin{:}); 81 | end 82 | LE=MException('Limit:error','0<=limit_start<=limit_end'); 83 | 84 | if nargin-length(varargin)<2 || nargin>4 85 | throw(ME) 86 | end 87 | run_cal=1; 88 | % set up to get calibration data 89 | if length(varargin)==1 && ischar(varargin{1}) 90 | limit_start=0; 91 | limit_stop=n_cal; 92 | section={'orientation','position'}; 93 | f_test=1; 94 | run_cal=0; 95 | %set up to get segment/joint info 96 | elseif ~iscell(section) && (strcmp(section,'segments') || strcmp(section,'joints')) 97 | limit_start=0; 98 | limit_stop=0; 99 | % set converting limits for all other formats 100 | elseif length(varargin)==1 101 | if varargin{1}<0 102 | throw(LE) 103 | end 104 | limit_start=n_cal; 105 | limit_stop=varargin{1}+n_cal; 106 | elseif length(varargin)==2 107 | if varargin{1}<0 || varargin{2}',''}; 122 | en = {'','',''}; 123 | stop = ''; 124 | st_depth = 1; 125 | en_depth = 0; 126 | 
max_depth=length(st); 127 | elseif ~iscell(section) && strcmp(section,'joints') 128 | start = ''; 129 | st = {'',''}; 131 | en = {''}; 132 | stop = ''; 133 | st_depth = 1; 134 | en_depth = 0; 135 | max_depth=length(st); 136 | elseif all(f_test) && ~any(strcmp('segments', section)) && ~any(strcmp('joints', section)) 137 | start = ''}; 145 | stop = ''; 146 | st_depth = 1; 147 | en_depth = 0; 148 | max_depth=length(st); 149 | else 150 | throw(ME); 151 | end 152 | 153 | % Check file name 154 | if isempty(strfind(filename,'mvnx')) 155 | filename = [get_folder_path() '\' filename '.mvnx']; 156 | end 157 | if ~exist(filename,'file') 158 | error([mfilename ':xsens:filename'],['No file with filename: ' filename ', file is not present or file has wrong format (function only reads .mvnx)']) 159 | end 160 | 161 | 162 | % Open file 163 | disp('Reading file'); 164 | fid = fopen(filename, 'r', 'n', 'UTF-8'); 165 | 166 | l = fgetl(fid); 167 | k = 1; 168 | 169 | % find start 170 | while isempty(strfind(l,start)) 171 | l = fgetl(fid); 172 | end 173 | cells{k} = l; 174 | k=k+1; 175 | l = fgetl(fid); 176 | count = 0; 177 | while ~feof(fid) 178 | if ~isempty(strfind(l,stop)) 179 | cells{k} = l; 180 | k = k+1; 181 | break 182 | end 183 | % test for the start key at the current depth 184 | found = 0; 185 | if st_depth<=max_depth && ~isempty(strfind(l,st{st_depth})) 186 | found = 1; 187 | if count>=limit_start 188 | cells{k} = l; 189 | k = k+1; 190 | end 191 | st_depth = st_depth+1; 192 | en_depth = en_depth+1; 193 | elseif st_depth>max_depth 194 | for i=1:length(key) 195 | if ~isempty(strfind(l,key{i})) 196 | if count>=limit_start 197 | cells{k} = l; 198 | k = k+1; 199 | end 200 | break 201 | end 202 | end 203 | end 204 | % if a start key was not fould 205 | if en_depth>0 && ~isempty(strfind(l,en{en_depth})) 206 | if found==0 207 | if count>=limit_start 208 | cells{k} = l; 209 | k = k+1; 210 | end 211 | if en_depth == 1 212 | count=count+1; 213 | end 214 | st_depth = st_depth-1; 215 | 
en_depth = en_depth-1; 216 | else 217 | if en_depth == 1 218 | count=count+1; 219 | end 220 | st_depth = st_depth-1; 221 | en_depth = en_depth-1; 222 | end 223 | end 224 | if limit_stop~=0 && count>=limit_stop 225 | cells{k}=stop; 226 | break 227 | end 228 | l = fgetl(fid); 229 | end 230 | if k=8 && strcmp('comment',line(2:8)) 273 | % add exception for comment 274 | word{n} = line(2:hooks(1)-1); 275 | iLine = hooks(1)+1; 276 | elseif openandclose(n) 277 | word{n} = line(2:find(line==62,1)-1); 278 | iLine = find(line==62,1)+1; 279 | else 280 | word{n} = line(2:end-1); 281 | oneword = true; 282 | end 283 | if word{n}(1) ~= '/' 284 | if ~oneword && ~openandclose(n) 285 | k = find(line == 34); 286 | k = reshape(k,2,length(k)/2)'; 287 | l = [iLine find(line(iLine:end) == 61)+iLine-2]; 288 | fieldname = cell(1,length(l)-1); value = cell(1,length(l)-1); 289 | if ~isempty(k) 290 | for il=1:size(k,1) 291 | fieldname{il} = line(iLine:find(line(iLine:end) == 61,1)+iLine-2); 292 | if size(k,1) > 1 && il < size(k,1) 293 | a = strfind(line(iLine:end),'" ')+iLine+1; 294 | iLine = a(1); 295 | end 296 | value{il} = line(k(il,1)+1:k(il,2)-1); 297 | end 298 | else 299 | value = []; fieldname =[]; 300 | value = line(find(line == 62,1)+1:end); 301 | end 302 | elseif ~oneword && openandclose(n) 303 | value = []; fieldname =[]; 304 | value = line(find(line == 62,1)+1:find(line==60,1,'last')-1); 305 | else 306 | value = NaN;fieldname = []; 307 | end 308 | wordvalue{n} = value; 309 | wordfields{n} = fieldname; 310 | end 311 | end 312 | %% get values 313 | parfor n=1:length(wordvalue) 314 | if iscell(wordvalue{n}) 315 | if length(wordvalue{n}) == 1 316 | B = []; 317 | try 318 | B = str2double(wordvalue{n}{1}); 319 | end 320 | if ~isempty(B) 321 | wordvalue{n} = B; 322 | else 323 | wordvalue{n} = wordvalue{n}{1}; 324 | end 325 | else 326 | for m=1:length(wordvalue{n}) 327 | try 328 | B = str2double(wordvalue{n}{m}); 329 | if ~isempty(B) 330 | wordvalue{n}{m} = B; 331 | end 332 | end 333 | 
function data = form_struct(lab,val)
% FORM_STRUCT Build a struct array from field labels and a cell matrix.
%
%   data = FORM_STRUCT(lab, val) pairs the i-th label of the cell array LAB
%   with the i-th column of the cell matrix VAL and hands all pairs to
%   STRUCT. DATA therefore has one field per label; because each value is a
%   cell column, STRUCT returns one element per row of VAL.
%
%   Any positive number of labels is accepted (the earlier implementation
%   was a hard-coded switch limited to 1-8 labels). An empty LAB raises an
%   error instead of printing a message and leaving DATA unassigned.
n_lab = length(lab);
if n_lab < 1
    error([mfilename ':form_struct:labels'], 'data format error');
end

% Interleave labels and value columns: {lab1, col1, lab2, col2, ...}
args = cell(1, 2*n_lab);
for i = 1:n_lab
    args{2*i-1} = lab{i};
    args{2*i} = val(:,i);
end

% Expanding the cell reproduces struct(lab{1},val(:,1),lab{2},val(:,2),...)
data = struct(args{:});
end
function data = mvnx_to_csv(file)
%MVNX_TO_CSV Convert one .mvnx recording to a .csv file.
%   data = MVNX_TO_CSV(file) reads <folder>/mvnx-files/<file> through
%   LOAD_PARTIAL_MVNX, where <folder> is whatever GET_FOLDER_PATH returns,
%   and writes one row per frame to <folder>/csv-files/<name>.csv.
%   Each row is the frame's orientation, position and jointAngle values,
%   flattened and concatenated in that order. The loaded frames are
%   returned in DATA.
%
%   Reading .mvnx files is expensive, so the csv acts as a cheaper copy of
%   these three quantities.
%
%   The number of columns is taken from the first frame instead of being
%   fixed at 227, so recordings with a different number of values per frame
%   no longer fail on the row assignment.

filename = [get_folder_path(), '/mvnx-files/', file];

[data, ~, ~, ~] = load_partial_mvnx(filename, {'orientation', 'position', 'jointAngle'});

n_frames = size(data,1);
if n_frames > 0
    n_cols = numel(data(1).orientation) + numel(data(1).position) + numel(data(1).jointAngle);
else
    % No frames to measure: keep the historical width so an empty recording
    % still produces the same (empty) matrix shape as before.
    n_cols = 227;
end

csvData = zeros(n_frames, n_cols);

for i = 1:n_frames

    csvData(i, :) = [data(i).orientation(:)' data(i).position(:)' data(i).jointAngle(:)'];

end

% file(1:end-5) drops the last five characters (the '.mvnx' extension).
output = [get_folder_path(), '/', 'csv-files', '/', file(1:end-5), '.csv'];

csvwrite(output, csvData);

end
16 | 'jointAngleErgo', 'jointAngleErgoXZY', 'centerOfMass'}; 17 | 18 | disp(input_file); 19 | 20 | tic 21 | for i = 1:length(dataset_names) 22 | [data, ~, ~, ~] = load_partial_mvnx(input_file, dataset_names(i)); 23 | data = struct2cell(data); 24 | 25 | dataset = cell2mat(data); 26 | dataset = reshape(dataset, length(data), length(dataset)/length(data)); 27 | 28 | h5create(output_file, ['/', dataset_names{i}], size(dataset)); 29 | h5write(output_file, ['/', dataset_names{i}], dataset); 30 | end 31 | toc 32 | 33 | h5disp(output_file); 34 | end 35 | -------------------------------------------------------------------------------- /src/matlab/mvnx_to_hdf_batch.m: -------------------------------------------------------------------------------- 1 | % Copyright (c) 2020-present, Assistive Robotics Lab 2 | % All rights reserved. 3 | % 4 | % This source code is licensed under the license found in the 5 | % LICENSE file in the root directory of this source tree. 6 | % 7 | 8 | clear 9 | 10 | fprintf('Starting job'); 11 | % 12 | % BATCH defines the job and sends it for execution. 13 | % 14 | job = batch ('process_mvnx_files', 'Profile', 'dtlogin_R2018a', 'AttachedFiles', {'get_folder_path', 'joint_angle_segments', 'joint_angles', ... 15 | 'load_partial_mvnx', 'mvnx_to_hdf', ... 16 | 'segment_orientation', 'segment_position', 'segment_reference' }, ... 17 | 'CurrentFolder', '.', 'Pool', 9); 18 | % 19 | % WAIT pauses the MATLAB session til the job completes. 20 | % 21 | wait (job); 22 | 23 | % 24 | % DIARY displays any messages printed during execution. 25 | % 26 | diary (job); 27 | 28 | % 29 | % These commands clean up data about the job we no longer need. 
function [matrix] = segment_orientation(values, frames, index)
% SEGMENT_ORIENTATION Collect one segment's quaternion for every frame.
%
%   matrix = SEGMENT_ORIENTATION(values, frames, index) returns a
%   frames-by-4 matrix. Row k holds elements index..index+3 of
%   values(k).orientation, i.e. the four quaternion components that start
%   at the given segmentation index.
%
%   See also READ_DATA, GET_ROLL, GET_PITCH, GET_YAW
quat_cols = index + (0:3);   % the four consecutive quaternion entries
matrix = zeros(frames, 4);
row = 1;
while row <= frames
    frame = values(row);
    matrix(row, :) = frame.orientation(quat_cols);
    row = row + 1;
end
end
function [segmentOrientationMap, segmentPositionMap, jointMap]= segment_reference
% SEGMENT_REFERENCE Lookup tables from segment/joint names to data indices.
%
%   [segmentOrientationMap, segmentPositionMap, jointMap] = SEGMENT_REFERENCE
%   returns three containers.Map objects:
%     segmentOrientationMap - segment name -> first index of its 4 values
%                             (1, 5, 9, ..., 89)
%     segmentPositionMap    - segment name -> first index of its 3 values
%                             (1, 4, 7, ..., 67)
%     jointMap              - joint name   -> first index of its 3 values
%                             (1, 4, 7, ..., 64)
%   The indices are assigned to the names in the order listed below.
%
%   See also READ_DATA, LOAD_PARTIAL_MVNX

segmentKeys= {'Pelvis';'L5';'L3';'T12';'T8';'Neck';'Head';'RightShoulder';'RightUpperArm';'RightForeArm';'RightHand';'LeftShoulder';'LeftUpperArm';'LeftForeArm';'LeftHand';'RightUpperLeg';'RightLowerLeg';'RightFoot';'RightToe';'LeftUpperLeg';'LeftLowerLeg';'LeftFoot';'LeftToe'};
jointKeys= {'jL5S1';'jL4L3';'jL1T12';'jT9T8';'jT1C7';'jC1Head';'jRightC7Shoulder';'jRightShoulder';'jRightElbow';'jRightWrist';'jLeftC7Shoulder';'jLeftShoulder';'jLeftElbow';'jLeftWrist';'jRightHip';'jRightKnee';'jRightAnkle';'jRightBallFoot';'jLeftHip';'jLeftKnee';'jLeftAnkle';'jLeftBallFoot'};

% Start indices, generated instead of spelled out:
% 23 orientation blocks of 4, 23 position blocks of 3, 22 joint blocks of 3.
orientationIndices = num2cell(1:4:89);
positionIndices = num2cell(1:3:67);
jointIndices = num2cell(1:3:64);

segmentOrientationMap = containers.Map(segmentKeys, orientationIndices);
segmentPositionMap = containers.Map(segmentKeys, positionIndices);
jointMap = containers.Map(jointKeys, jointIndices);
class EncoderRNN(nn.Module):
    """GRU encoder for the encoder half of a seq2seq model."""

    def __init__(self, input_size, hidden_size,
                 dropout=0.0, bidirectional=False):
        """Build the GRU, the dropout layer and the hidden-state projection.

        Args:
            input_size (int): Number of features in each input step.
            hidden_size (int): Number of hidden units in the GRU layer.
            dropout (float, optional): Dropout applied to the GRU output.
                Defaults to 0.0.
            bidirectional (bool, optional): Whether the GRU runs in both
                directions. Defaults to False.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        self.gru = nn.GRU(input_size, hidden_size, bidirectional=bidirectional)
        self.dropout = nn.Dropout(dropout)

        if bidirectional:
            self.directions = 2
        else:
            self.directions = 1
        # Projects the concatenated per-direction hidden states back down to
        # hidden_size; only used by forward() when bidirectional.
        self.fc = nn.Linear(self.directions * hidden_size, hidden_size)

    def forward(self, input):
        """Encode a batch of sequences.

        Args:
            input (torch.Tensor): (seq_len, batch_size, input_size)

        Returns:
            tuple: Output and hidden state of the encoder.
                output (torch.Tensor): (seq_len,
                                        batch_size,
                                        directions*hidden_size)
                hidden (torch.Tensor): (1, batch_size, hidden_size)
        """
        gru_out, hidden = self.gru(input)
        output = self.dropout(gru_out)

        if not self.bidirectional:
            return output, hidden

        # Merge the last two hidden states (one per direction) into a single
        # hidden_size state.
        both_directions = torch.cat((hidden[-2], hidden[-1]), dim=1)
        merged = torch.tanh(self.fc(both_directions))
        return output, merged.unsqueeze(0)
76 | 77 | Args: 78 | input_size (int): number of features in input 79 | hidden_size (int): number of hidden units in GRU layer 80 | output_size (int): number of features in output 81 | dropout (float, optional): Dropout applied after GRU layer. 82 | Defaults to 0.0. 83 | """ 84 | super(DecoderRNN, self).__init__() 85 | self.hidden_size = hidden_size 86 | 87 | self.gru = nn.GRU(input_size, hidden_size) 88 | self.dropout = nn.Dropout(dropout) 89 | self.out = nn.Linear(hidden_size, output_size) 90 | 91 | def forward(self, input, hidden): 92 | """Forward pass through decoder. 93 | 94 | Args: 95 | input (torch.Tensor): input batch to pass through RNN 96 | (1, batch_size, input_size) 97 | hidden (torch.Tensor): hidden state of the decoder 98 | (1, batch_size, hidden_size) 99 | 100 | Returns: 101 | tuple: Returns output and hidden state of decoder. 102 | output (torch.Tensor): (1, batch_size, output_size) 103 | hidden (torch.Tensor): (1, batch_size, hidden_size) 104 | """ 105 | output, hidden = self.gru(input, hidden) 106 | output = self.dropout(output) 107 | output = self.out(output) 108 | return output, hidden 109 | 110 | 111 | class AttnDecoderRNN(nn.Module): 112 | """A decoder with an attention layer for use in seq2seq architectures.""" 113 | 114 | def __init__(self, output_size, feature_size, enc_hidden_size, 115 | dec_hidden_size, attention, bidirectional_encoder=False): 116 | """Initialize AttnDecoderRNN. 117 | 118 | Args: 119 | output_size (int): size of output 120 | features_size (int): number of features in input batch 121 | enc_hidden_size (int): hidden size of encoder 122 | dec_hidden_size (int): hidden size of decoder 123 | attention (str): attention method for use in Attention layer 124 | bidirectional_encoder (bool, optional): Whether encoder used with 125 | decoder is bidirectional. Defaults to False. 
class Attention(nn.Module):
    """Attention layer for the AttnDecoder with several scoring methods.

    Supported methods are ``general``, ``biased-general``,
    ``activated-general``, ``dot``, ``add`` and ``concat``.
    """

    _GENERAL_METHODS = ("general", "biased-general", "activated-general")
    METHODS = _GENERAL_METHODS + ("dot", "add", "concat")

    def __init__(self, hidden_size, batch_size,
                 method, bidirectional_encoder=False):
        """Initialize Attention layer.

        Args:
            hidden_size (int): Size of hidden state in decoder.
            batch_size (int): Nominal batch size. Stored for backward
                compatibility only; forward() reads the actual batch size
                from its inputs, so a smaller final batch is accepted.
            method (str): Attention technique/method to use. One of
                general, biased-general, activated-general, dot, add, and
                concat.
            bidirectional_encoder (bool, optional): Whether encoder used with
                decoder is bidirectional. Defaults to False.

        Raises:
            ValueError: If ``method`` is not one of the supported methods.
        """
        super().__init__()

        # Fail at construction time; previously an unknown method only
        # surfaced in forward() as an error on a None score.
        if method not in self.METHODS:
            raise ValueError(
                f"Unknown attention method {method!r}; "
                f"expected one of {', '.join(self.METHODS)}")

        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.method = method
        self.directions = 2 if bidirectional_encoder else 1

        annotation_size = self.directions * hidden_size

        if method in self._GENERAL_METHODS:
            # Plain "general" has no bias term; the other two variants do.
            self.Wa = nn.Linear(hidden_size, annotation_size,
                                bias=(method != "general"))
        elif method == "add":
            self.Wa = nn.Linear(annotation_size, hidden_size, bias=False)
            self.Wb = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(hidden_size))
        elif method == "concat":
            self.Wa = nn.Linear(annotation_size + hidden_size,
                                hidden_size, bias=False)
            self.va = nn.Parameter(torch.rand(hidden_size))
        # "dot" has no learnable parameters.

    def forward(self, hidden, annotations):
        """Forward pass through attention layer.

        Args:
            hidden (torch.Tensor): (1, batch_size, hidden_size)
            annotations (torch.Tensor): (seq_len,
                                         batch_size,
                                         directions * hidden_size)

        Returns:
            torch.Tensor: (batch_size, seq_len) attention weights; each row
                is a softmax over the sequence dimension.

        Raises:
            ValueError: If the input shapes do not match the layout above.
        """
        if (hidden.dim() != 3 or hidden.shape[0] != 1
                or hidden.shape[2] != self.hidden_size):
            raise ValueError(
                f"hidden must be (1, batch_size, {self.hidden_size}), "
                f"got {tuple(hidden.shape)}")

        batch_size = hidden.shape[1]
        annotation_size = self.directions * self.hidden_size
        if (annotations.dim() != 3 or annotations.shape[1] != batch_size
                or annotations.shape[2] != annotation_size):
            raise ValueError(
                f"annotations must be (seq_len, {batch_size}, "
                f"{annotation_size}), got {tuple(annotations.shape)}")

        # Kept as an attribute because earlier versions exposed it.
        self.seq_len = annotations.shape[0]

        score = self._score(hidden.squeeze(0), annotations.permute(1, 0, 2))

        return F.softmax(score, dim=1)

    def _score(self, hidden, annotations):
        """Compute an attention score with hidden state and annotations.

        Args:
            hidden (torch.Tensor): (batch_size, hidden_size)
            annotations (torch.Tensor): (batch_size,
                                         seq_len,
                                         directions * hidden_size)

        Returns:
            torch.Tensor: (batch_size, seq_len)
        """
        if self.method in self._GENERAL_METHODS:
            query = self.Wa(hidden).unsqueeze(-1)
            score = annotations.bmm(query)
            if self.method == "activated-general":
                score = torch.tanh(score)
            return score.squeeze(-1)

        if self.method == "dot":
            # Tile the hidden state once per encoder direction so its length
            # matches the annotation feature size.
            query = hidden.unsqueeze(-1).repeat(1, self.directions, 1)
            return annotations.bmm(query).squeeze(-1)

        if self.method == "add":
            # NOTE(review): unlike "concat", no tanh is applied to the energy
            # here. Left unchanged so existing checkpoints keep their
            # behaviour - confirm whether this is intended.
            energy = self.Wa(annotations) + self.Wb(hidden).unsqueeze(1)
            return torch.matmul(energy, self.va)

        # "concat": score every step from [hidden; annotation].
        seq_len = annotations.shape[1]
        tiled_hidden = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(
            self.Wa(torch.cat((tiled_hidden, annotations), dim=2)))
        return torch.matmul(energy, self.va)
def get_attn_decoder(feature_size, method, device, batch_size=32,
                     hidden_size=64, bidirectional_encoder=False):
    """Build an attention decoder, cast it to double and move it to device.

    The decoder's input and output feature sizes are both ``feature_size``,
    and its encoder and decoder hidden sizes are both ``hidden_size``.

    Args:
        feature_size (int): number of features in the input to this model
        method (str): attention method handed to the Attention layer
        device (torch.device): what device to put this decoder on
        batch_size (int, optional): batch size handed to the Attention
            layer. Defaults to 32.
        hidden_size (int, optional): number of hidden units. Defaults to 64.
        bidirectional_encoder (bool, optional): whether the encoder is
            bidirectional. Defaults to False.

    Returns:
        AttnDecoderRNN: a decoder with attention for use in seq2seq tasks
    """
    attention_layer = Attention(
        hidden_size,
        batch_size,
        method,
        bidirectional_encoder=bidirectional_encoder,
    )
    return AttnDecoderRNN(
        output_size=feature_size,
        feature_size=feature_size,
        enc_hidden_size=hidden_size,
        dec_hidden_size=hidden_size,
        attention=attention_layer,
        bidirectional_encoder=bidirectional_encoder,
    ).double().to(device)
133 | """ 134 | training = (opts is not None) 135 | 136 | encoder, decoder = models 137 | input_batch, target_batch = data 138 | 139 | input_batch = input_batch.to(device).double() 140 | target_batch = target_batch.to(device).double() 141 | 142 | if training: 143 | encoder.train(), decoder.train() 144 | encoder_opt, decoder_opt = opts 145 | else: 146 | encoder.eval(), decoder.eval() 147 | 148 | loss = 0 149 | seq_length = target_batch.shape[1] 150 | 151 | input = input_batch.permute(1, 0, 2) 152 | encoder_outputs, encoder_hidden = encoder(input) 153 | 154 | decoder_hidden = encoder_hidden 155 | decoder_input = torch.ones_like(target_batch[:, 0, :]).unsqueeze(0) 156 | EOS = torch.zeros_like(target_batch[:, 0, :]).unsqueeze(0) 157 | outputs = torch.zeros_like(target_batch) 158 | 159 | use_teacher_forcing = (True if random.random() < teacher_forcing_ratio 160 | else False) 161 | 162 | if not average_batch: 163 | if training: 164 | logger.warning("average_batch must be true when training") 165 | sys.exit() 166 | losses = [0 for i in range(target_batch.shape[0])] 167 | 168 | for t in range(seq_length): 169 | 170 | if use_attention: 171 | decoder_output, decoder_hidden = decoder( 172 | decoder_input, decoder_hidden, encoder_outputs) 173 | else: 174 | decoder_output, decoder_hidden = decoder( 175 | decoder_input, decoder_hidden) 176 | 177 | target = target_batch[:, t, :].unsqueeze(0) 178 | 179 | output = decoder_output 180 | 181 | if norm_quaternions: 182 | original_shape = output.shape 183 | 184 | output = output.contiguous().view(-1, 4) 185 | output = F.normalize(output, p=2, dim=1).view(original_shape) 186 | 187 | outputs[:, t, :] = output 188 | 189 | if use_teacher_forcing: 190 | decoder_input = target 191 | else: 192 | if torch.all(torch.eq(decoder_output, EOS)): 193 | break 194 | decoder_input = output.detach() 195 | 196 | loss = criterion(outputs, target_batch) 197 | 198 | if training: 199 | loss.backward() 200 | 201 | encoder_opt.step() 202 | 
def fit(models, optims, epochs, dataloaders, training_criterion,
        validation_criteria, schedulers, device, model_file_path,
        teacher_forcing_ratio=0.0, use_attention=False,
        norm_quaternions=False, schedule_rate=1.0):
    """Fit a seq2seq model to data, logging training and validation loss.

    Each epoch trains on every batch of the training dataloader, evaluates
    every validation criterion on the validation dataloader, steps the
    schedulers, and saves a checkpoint whenever the FIRST validation
    criterion reaches a new minimum.

    Args:
        models (tuple): tuple containing the encoder and decoder
        optims (tuple): tuple containing the encoder and decoder optimizers for
            training
        epochs (int): number of epochs to train for
        dataloaders (tuple): tuple containing the training dataloader and val
            dataloader.
        training_criterion (nn.Module): criterion for backpropagation during
            training
        validation_criteria (list): list of criteria for validation; the
            first one decides when the model is saved
        schedulers (list): list of schedulers to control learning rate for
            optimizers
        device (torch.device): device to place data on
        model_file_path (str): where to save the model when validation loss
            reaches new minimum
        teacher_forcing_ratio (float, optional): percent of the time to use
            teacher forcing in the decoder while training. Defaults to 0.0.
        use_attention (bool, optional): whether decoder uses attention or not.
            Defaults to False.
        norm_quaternions (bool, optional): whether quaternions should be
            normalized after they are output from the decoder.
            Defaults to False.
        schedule_rate (float, optional): factor the teacher forcing ratio is
            multiplied by after every epoch. Defaults to 1.0.
    """

    train_dataloader, val_dataloader = dataloaders

    # Report progress about ten times per epoch. max() keeps the interval at
    # least 1 so a loader with fewer than 10 batches no longer triggers a
    # modulo-by-zero.
    log_interval = max(1, len(train_dataloader) // 10)

    min_val_loss = math.inf
    for epoch in range(epochs):
        losses = []
        total_time = 0

        logger.info(f"Epoch {epoch+1} / {epochs}")

        for index, data in enumerate(train_dataloader, 0):
            with Timer() as timer:
                # The ratio must be forwarded here; it was previously
                # dropped, so teacher forcing (and its schedule) never took
                # effect.
                loss = loss_batch(data, models,
                                  optims, training_criterion, device,
                                  teacher_forcing_ratio=teacher_forcing_ratio,
                                  use_attention=use_attention,
                                  norm_quaternions=norm_quaternions)

            losses.append(loss)
            total_time += timer.interval
            if index % log_interval == 0:
                logger.info((f"Total time elapsed: {total_time} - "
                             f"Batch Number: {index} / {len(train_dataloader)}"
                             f" - Training loss: {loss}"
                             ))

        # Validation runs without optimizers and without teacher forcing.
        val_loss = []
        for validation_criterion in validation_criteria:
            with torch.no_grad():
                val_losses = [loss_batch(data, models,
                                         None, validation_criterion, device,
                                         use_attention=use_attention,
                                         norm_quaternions=norm_quaternions)
                              for data in val_dataloader]

            val_loss.append(np.sum(val_losses) / len(val_losses))

        loss = np.sum(losses) / len(losses)

        for scheduler in schedulers:
            scheduler.step()

        val_loss_strs = ", ".join(map(str, val_loss))
        logger.info(f"Training Loss: {loss} - Val Loss: {val_loss_strs}")

        # Decay (or grow) the teacher forcing ratio for the next epoch.
        teacher_forcing_ratio *= schedule_rate
        if val_loss[0] < min_val_loss:
            min_val_loss = val_loss[0]
            logger.info(f"Saving model to {model_file_path}")
            torch.save({
                "encoder_state_dict": models[0].state_dict(),
                "decoder_state_dict": models[1].state_dict(),
                "optimizerA_state_dict": optims[0].state_dict(),
                "optimizerB_state_dict": optims[1].state_dict(),
            }, model_file_path)
torch.manual_seed(42)
np.random.seed(42)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


def parse_args():
    """Parse command-line arguments for the seq2seq testing script.

    No argument declares a ``type``, so values supplied on the command line
    arrive as strings while untouched defaults keep their Python type; the
    script converts them where it uses them.

    Returns:
        argparse.Namespace: contains accessible arguments passed in to module
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--task",
                        help=("task for neural network to train on; "
                              "either prediction or conversion"))
    parser.add_argument("--data-path-parent",
                        help=("parent folder of directories with h5 files "
                              "(folders must contain normalization.h5 "
                              "and testing.h5)"))
    parser.add_argument("--figure-file-path",
                        help="path to where the figure should be saved")
    parser.add_argument("--figure-title",
                        help="title of the histogram plot")
    parser.add_argument("--include-legend",
                        help="will add a legend to the histogram plot",
                        default=False,
                        action="store_true")
    parser.add_argument("--model-dir",
                        help="path to model file directory")
    parser.add_argument("--representation",
                        help="orientation representation",
                        default="quaternions")
    parser.add_argument("--batch-size",
                        help="batch size for training", default=32)
    parser.add_argument("--seq-length",
                        help=("sequence length for model, will be "
                              "downsampled if downsample is provided"),
                        default=20)
    parser.add_argument("--downsample",
                        help=("reduce sampling frequency of recorded data; "
                              "default sampling frequency is 240 Hz"),
                        default=1)
    parser.add_argument("--in-out-ratio",
                        help=("ratio of input/output; "
                              "seq_length / downsample = input length = 10, "
                              "output length = input length / in_out_ratio"),
                        default=1)
    parser.add_argument("--stride",
                        help=("stride used when reading data "
                              "in for running prediction tasks"),
                        default=3)
    parser.add_argument("--hidden-size",
                        help="hidden size in both the encoder and decoder")
    parser.add_argument("--dropout",
                        help="dropout percentage in encoder and decoder",
                        default=0.0)
    parser.add_argument("--bidirectional",
                        help="will use bidirectional encoder",
                        default=False,
                        action="store_true")
    parser.add_argument("--attention",
                        help="use decoder with specified attention",
                        default="general")

    return parser.parse_args()


if __name__ == "__main__":
    # Select the non-interactive Agg backend before any figure is created.
    matplotlib.use("Agg")

    args = parse_args()

    for arg in vars(args):
        logger.info("{} - {}".format(arg, getattr(args, arg)))

    logger.info("Starting seq2seq model testing...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = int(args.batch_size)

    data_paths = [args.data_path_parent + "/" + name
                  for name in os.listdir(args.data_path_parent)
                  if os.path.isdir(args.data_path_parent + "/" + name)]

    model_paths = [args.model_dir + "/" + name
                   for name in os.listdir(args.model_dir)]

    # Data directories and checkpoints are paired by sorted position.
    data_paths.sort()
    model_paths.sort()

    # Fail up front instead of with an IndexError after some datasets have
    # already been evaluated.
    if len(model_paths) < len(data_paths):
        raise ValueError(
            f"Found {len(model_paths)} model files in {args.model_dir} but "
            f"{len(data_paths)} data directories in "
            f"{args.data_path_parent}; each data directory needs a model.")

    plt.style.use("fivethirtyeight")
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["axes.titlesize"] = 24
    plt.rcParams["axes.labelsize"] = 23
    plt.rcParams["xtick.labelsize"] = 22
    plt.rcParams["ytick.labelsize"] = 22
    plt.rcParams["figure.facecolor"] = "white"
    plt.rcParams["axes.facecolor"] = "white"
    plt.rcParams["savefig.edgecolor"] = "white"
    plt.rcParams["savefig.facecolor"] = "white"

    fig = plt.figure()
    ax = fig.add_subplot(111)

    attention_options = ["add", "dot", "concat",
                         "general", "activated-general", "biased-general"]

    for i, data_path in enumerate(data_paths):
        # load_dataloader receives the namespace, so the current dataset
        # directory is handed over through args.data_path.
        args.data_path = data_path

        with h5py.File(args.data_path + "/normalization.h5", "r") as f:
            mean, std_dev = torch.Tensor(f["mean"]), torch.Tensor(f["std_dev"])
            norm_data = (mean, std_dev)
        test_dataloader, _ = load_dataloader(args,
                                             "testing",
                                             True,
                                             norm_data=norm_data)

        encoder_feature_size = test_dataloader.dataset[0][0].shape[1]
        decoder_feature_size = test_dataloader.dataset[0][1].shape[1]

        bidirectional = args.bidirectional
        encoder = get_encoder(encoder_feature_size,
                              device,
                              hidden_size=int(args.hidden_size),
                              dropout=float(args.dropout),
                              bidirectional=bidirectional)

        use_attention = args.attention in attention_options
        if use_attention:
            decoder = get_attn_decoder(decoder_feature_size,
                                       args.attention,
                                       device,
                                       hidden_size=int(args.hidden_size),
                                       bidirectional_encoder=bidirectional)
        else:
            decoder = get_decoder(decoder_feature_size,
                                  device,
                                  dropout=float(args.dropout),
                                  hidden_size=int(args.hidden_size))

        checkpoint = torch.load(model_paths[i], map_location=device)

        encoder.load_state_dict(checkpoint["encoder_state_dict"])
        decoder.load_state_dict(checkpoint["decoder_state_dict"])

        decoder.batch_size = batch_size
        if use_attention:
            decoder.attention.batch_size = batch_size

        models = (encoder.double(), decoder.double())
        criterion = QuatDistance()
        norm_quaternions = (args.representation == "quaternions")

        with torch.no_grad():
            # average_batch=False yields one list of per-sample losses for
            # every batch.
            batch_losses = [loss_batch(data, models,
                                       None, criterion, device,
                                       use_attention=use_attention,
                                       norm_quaternions=norm_quaternions,
                                       average_batch=False)
                            for data in test_dataloader]

        inference_losses = [sample_loss
                            for batch in batch_losses
                            for sample_loss in batch]

        ax.hist(inference_losses, bins=60, density=True,
                histtype=u"step", linewidth=2)
        ax.set_xlim(0, 40)
        ax.set_xticks(range(0, 45, 5))
        ax.set_xticklabels(range(0, 45, 5))

        ax.set_ylim(0, 0.20)
        ax.set_yticks(np.arange(0, 0.20, 0.05))
        ax.set_yticklabels(np.arange(0, 0.20, 0.05).round(decimals=2))

        inference_loss = np.sum(inference_losses) / len(inference_losses)
        group = data_path.split("/")[-1]
        logger.info(f"Inference Loss for {group}: {inference_loss}")

    ax.set_title(args.figure_title)
    ax.set_xlabel("Sequence Angular Error in Degrees")
    ax.set_ylabel("Percentage")
    if args.include_legend:
        # One label per plotted dataset, in the sorted order used above.
        ax.legend([f"Config. {k}" for k in range(1, len(data_paths) + 1)])
    figname = args.figure_file_path
    fig.savefig(figname, bbox_inches="tight")

    logger.info("Completed Testing...")
    logger.info("\n")
torch.manual_seed(42)
np.random.seed(42)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


def parse_args():
    """Parse command-line arguments for the Transformer testing script.

    No argument declares a ``type``, so values supplied on the command line
    arrive as strings while untouched defaults keep their Python type. The
    usage text is printed when ``--data-path-parent`` is missing, and the
    parsed namespace is returned regardless.

    Returns:
        argparse.Namespace: contains accessible arguments passed in to module
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--task",
                        help=("task for neural network to train on; "
                              "either prediction or conversion"))
    parser.add_argument("--data-path-parent",
                        help=("parent folder of directories with h5 files "
                              "(folders must contain normalization.h5 "
                              "and testing.h5)"))
    parser.add_argument("--figure-file-path",
                        help="path to where the figure should be saved")
    parser.add_argument("--figure-title",
                        help="title of the histogram plot")
    parser.add_argument("--include-legend",
                        help="will add a legend to the histogram plot",
                        default=False,
                        action="store_true")
    parser.add_argument("--model-dir",
                        help="path to model file directory")
    parser.add_argument("--representation",
                        help="orientation representation",
                        default="quaternions")
    parser.add_argument("--full-transformer",
                        help=("will use full Transformer if true, "
                              "will only use encoder if false"),
                        default=False,
                        action="store_true")
    parser.add_argument("--batch-size",
                        help="batch size for training", default=32)
    parser.add_argument("--seq-length",
                        help=("sequence length for model, will be "
                              "downsampled if downsample is provided"),
                        default=20)
    parser.add_argument("--downsample",
                        help=("reduce sampling frequency of recorded data; "
                              "default sampling frequency is 240 Hz"),
                        default=1)
    parser.add_argument("--in-out-ratio",
                        help=("ratio of input/output; "
                              "seq_length / downsample = input length = 10, "
                              "output length = input length / in_out_ratio"),
                        default=1)
    parser.add_argument("--stride",
                        help=("stride used when reading data "
                              "in for running prediction tasks"),
                        default=3)
    parser.add_argument("--num-heads",
                        help="number of heads in Transformer Encoder")
    parser.add_argument("--dim-feedforward",
                        help="number of dimensions in feedforward "
                             "layer in Transformer Encoder")
    parser.add_argument("--dropout",
                        help="dropout percentage in Transformer Encoder")
    parser.add_argument("--num-layers",
                        help="number of layers in Transformer Encoder")

    args = parser.parse_args()

    if args.data_path_parent is None:
        parser.print_help()

    return args


if __name__ == "__main__":
    # Select the non-interactive Agg backend before any figure is created.
    matplotlib.use("Agg")

    args = parse_args()

    for arg in vars(args):
        logger.info(f"{arg} - {getattr(args, arg)}")

    logger.info("Starting Transformer testing...")

    logger.info(f"Device count: {str(torch.cuda.device_count())}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Testing on {}...".format(device))

    data_paths = [args.data_path_parent + "/" + name
                  for name in os.listdir(args.data_path_parent)
                  if os.path.isdir(args.data_path_parent + "/" + name)]

    model_paths = [args.model_dir + "/" + name
                   for name in os.listdir(args.model_dir)]

    # Data directories and checkpoints are paired by sorted position.
    data_paths.sort()
    model_paths.sort()

    # Fail up front instead of with an IndexError after some datasets have
    # already been evaluated.
    if len(model_paths) < len(data_paths):
        raise ValueError(
            f"Found {len(model_paths)} model files in {args.model_dir} but "
            f"{len(data_paths)} data directories in "
            f"{args.data_path_parent}; each data directory needs a model.")

    plt.style.use("fivethirtyeight")
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["axes.titlesize"] = 24
    plt.rcParams["axes.labelsize"] = 22
    plt.rcParams["xtick.labelsize"] = 18
    plt.rcParams["ytick.labelsize"] = 18
    plt.rcParams["figure.facecolor"] = "white"
    plt.rcParams["axes.facecolor"] = "white"
    plt.rcParams["savefig.edgecolor"] = "white"
    plt.rcParams["savefig.facecolor"] = "white"

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Hyperparameters that do not depend on the dataset are parsed once.
    dim_feedforward = int(args.dim_feedforward)
    dropout = float(args.dropout)
    num_layers = int(args.num_layers)
    quaternions = (args.representation == "quaternions")

    for i, data_path in enumerate(data_paths):
        # load_dataloader receives the namespace, so the current dataset
        # directory is handed over through args.data_path.
        args.data_path = data_path
        with h5py.File(args.data_path + "/normalization.h5", "r") as f:
            mean, std_dev = torch.Tensor(f["mean"]), torch.Tensor(f["std_dev"])
            norm_data = (mean, std_dev)

        test_dataloader, _ = load_dataloader(args,
                                             "testing",
                                             True,
                                             norm_data=norm_data)

        encoder_feature_size = test_dataloader.dataset[0][0].shape[1]
        decoder_feature_size = test_dataloader.dataset[0][1].shape[1]

        if args.full_transformer:
            num_heads = int(args.num_heads)
            model = InferenceTransformer(decoder_feature_size, num_heads,
                                         dim_feedforward, dropout,
                                         num_layers, quaternions=quaternions)
        else:
            # The encoder-only model is built with one head per input
            # feature; --num-heads is ignored on this path.
            num_heads = encoder_feature_size
            model = InferenceTransformerEncoder(encoder_feature_size,
                                                num_heads, dim_feedforward,
                                                dropout, num_layers,
                                                decoder_feature_size,
                                                quaternions=quaternions)

        checkpoint = torch.load(model_paths[i], map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])

        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        model = model.to(device).double()

        criterion = QuatDistance()

        with torch.no_grad():
            # average_batch=False yields one list of per-sample losses for
            # every batch.
            if args.full_transformer:
                batch_losses = [inference(model, data,
                                          criterion, device,
                                          average_batch=False)
                                for data in test_dataloader]
            else:
                batch_losses = [loss_batch(model, None,
                                           data, criterion, device,
                                           full_transformer=False,
                                           average_batch=False)
                                for data in test_dataloader]

        inference_losses = [sample_loss
                            for batch in batch_losses
                            for sample_loss in batch]

        inference_loss = np.sum(inference_losses) / len(inference_losses)
        logger.info("Inference Loss: {}".format(inference_loss))

        ax.hist(inference_losses, bins=60,
                density=True, histtype=u"step",
                linewidth=2)
        ax.set_xlim(0, 40)
        ax.set_xticks(range(0, 45, 5))
        ax.set_xticklabels(range(0, 45, 5))

        ax.set_ylim(0, 0.20)
        ax.set_yticks(np.arange(0, 0.20, 0.05))
        ax.set_yticklabels(np.arange(0, 0.20, 0.05).round(decimals=2))

    ax.set_title(args.figure_title)
    ax.set_xlabel("Sequence Angular Error in Degrees")
    ax.set_ylabel("Percentage")
    if args.include_legend:
        # One label per plotted dataset, in the sorted order used above.
        ax.legend([f"Config. {k}" for k in range(1, len(data_paths) + 1)])
    figname = args.figure_file_path
    fig.savefig(figname, bbox_inches="tight")

    logger.info("Completed testing...")
    logger.info("\n")
torch.manual_seed(42)
np.random.seed(42)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


def parse_args():
    """Parse command-line arguments for the seq2seq training script.

    No argument declares a ``type``, so values supplied on the command line
    arrive as strings while untouched defaults keep their Python type. The
    usage text is printed when ``--data-path`` is missing, and the parsed
    namespace is returned regardless.

    Returns:
        argparse.Namespace: contains accessible arguments passed in to module
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--task",
                        help=("task for neural network to train on; "
                              "either prediction or conversion"))
    parser.add_argument("--data-path",
                        help=("path to h5 files containing data "
                              "(must contain training.h5 and validation.h5)"))
    parser.add_argument("--model-file-path",
                        help="path to model file for saving it after training")
    # The script enables quaternion normalization only when this equals
    # "quaternions", so the default has to use that exact spelling.
    parser.add_argument("--representation",
                        help="data representation", default="quaternions")
    parser.add_argument("--auxiliary-acc",
                        help="will train on auxiliary acceleration if true",
                        default=False,
                        action="store_true")
    parser.add_argument("--batch-size",
                        help="batch size for training", default=32)
    parser.add_argument("--learning-rate",
                        help="learning rate for encoder and decoder",
                        default=0.001)
    parser.add_argument("--seq-length",
                        help="sequence length for encoder/decoder",
                        default=20)
    parser.add_argument("--downsample",
                        help="reduce sampling frequency of recorded data; "
                             "default sampling frequency is 240 Hz",
                        default=1)
    parser.add_argument("--in-out-ratio",
                        help=("ratio of input/output; "
                              "seq_length / downsample = input length = 10, "
                              "output length = input length / in_out_ratio"),
                        default=1)
    parser.add_argument("--stride",
                        help="stride used when running prediction tasks",
                        default=3)
    parser.add_argument("--num-epochs",
                        help="number of epochs for training", default=1)
    parser.add_argument("--hidden-size",
                        help="hidden size in both the encoder and decoder")
    parser.add_argument("--dropout",
                        help="dropout percentage in encoder and decoder",
                        default=0.0)
    parser.add_argument("--bidirectional",
                        help="will use bidirectional encoder",
                        default=False,
                        action="store_true")
    parser.add_argument("--attention",
                        help="will use decoder with given attention method",
                        default="general")

    args = parser.parse_args()

    if args.data_path is None:
        parser.print_help()

    return args


if __name__ == "__main__":
    args = parse_args()

    for arg in vars(args):
        logger.info(f"{arg} - {getattr(args, arg)}")

    logger.info("Starting seq2seq model training...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_length = int(args.seq_length)
    lr = float(args.learning_rate)

    # Raised explicitly so the check still runs under ``python -O``.
    # NOTE(review): unlike train-transformer.py this compares the raw
    # sequence length, not seq_length // downsample - confirm intent.
    if seq_length % int(args.in_out_ratio) != 0:
        raise ValueError(
            f"--seq-length ({seq_length}) must be divisible by "
            f"--in-out-ratio ({args.in_out_ratio})")

    normalize = True
    train_dataloader, norm_data = load_dataloader(args, "training",
                                                  normalize, norm_data=None)
    # Validation data reuses the statistics returned for the training set.
    val_dataloader, _ = load_dataloader(args, "validation",
                                        normalize, norm_data=norm_data)

    encoder_feature_size = train_dataloader.dataset[0][0].shape[1]
    decoder_feature_size = train_dataloader.dataset[0][1].shape[1]

    bidirectional = args.bidirectional
    encoder = get_encoder(encoder_feature_size,
                          device,
                          hidden_size=int(args.hidden_size),
                          dropout=float(args.dropout),
                          bidirectional=bidirectional)

    encoder_optim = optim.AdamW(encoder.parameters(), lr=lr, weight_decay=0.05)
    encoder_sched = optim.lr_scheduler.MultiStepLR(encoder_optim,
                                                   milestones=[5],
                                                   gamma=0.1)

    attention_options = ["add", "dot", "concat",
                         "general", "activated-general", "biased-general"]
    use_attention = args.attention in attention_options
    if use_attention:
        decoder = get_attn_decoder(decoder_feature_size,
                                   args.attention,
                                   device,
                                   hidden_size=int(args.hidden_size),
                                   bidirectional_encoder=bidirectional)
    else:
        decoder = get_decoder(decoder_feature_size,
                              device,
                              dropout=float(args.dropout),
                              hidden_size=int(args.hidden_size))

    decoder_optim = optim.AdamW(decoder.parameters(), lr=lr, weight_decay=0.05)
    decoder_sched = optim.lr_scheduler.MultiStepLR(decoder_optim,
                                                   milestones=[5],
                                                   gamma=0.1)

    encoder_params = sum(p.numel()
                         for p in encoder.parameters() if p.requires_grad)
    decoder_params = sum(p.numel()
                         for p in decoder.parameters() if p.requires_grad)

    models = (encoder, decoder)
    optims = (encoder_optim, decoder_optim)
    dataloaders = (train_dataloader, val_dataloader)
    epochs = int(args.num_epochs)
    training_criterion = nn.L1Loss()
    validation_criteria = [nn.L1Loss(), QuatDistance()]
    norm_quaternions = (args.representation == "quaternions")

    schedulers = (encoder_sched, decoder_sched)

    logger.info(f"Encoder for training: {encoder}")
    logger.info(f"Decoder for training: {decoder}")
    logger.info(f"Number of parameters: {encoder_params + decoder_params}")
    logger.info(f"Optimizers for training: {encoder_optim}")
    logger.info(f"Criterion for training: {training_criterion}")

    fit(models, optims, epochs, dataloaders, training_criterion,
        validation_criteria, schedulers, device, args.model_file_path,
        use_attention=use_attention, norm_quaternions=norm_quaternions)

    logger.info("Completed Training...")
    logger.info("\n")
torch.manual_seed(42)
np.random.seed(42)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False


def parse_args():
    """Parse command-line arguments for the Transformer training script.

    No argument declares a ``type``, so values supplied on the command line
    arrive as strings while untouched defaults keep their Python type. The
    usage text is printed when ``--data-path`` is missing, and the parsed
    namespace is returned regardless.

    Returns:
        argparse.Namespace: contains accessible arguments passed in to module
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--task",
                        help=("task for neural network to train on; "
                              "either prediction or conversion"))
    parser.add_argument("--data-path",
                        help=("path to h5 files containing data "
                              "(must contain training.h5 and validation.h5)"))
    # The script compares this value against "quaternions", so the default
    # has to use that exact spelling.
    parser.add_argument("--representation",
                        help=("will normalize if quaternions, will use expmap "
                              "to quat validation loss if expmap"),
                        default="quaternions")
    parser.add_argument("--full-transformer",
                        help=("will use Transformer with both encoder and "
                              "decoder if true, will only use encoder "
                              "if false"),
                        default=False,
                        action="store_true")
    parser.add_argument("--model-file-path",
                        help="path to model file for saving it after training")
    parser.add_argument("--batch-size",
                        help="batch size for training", default=32)
    parser.add_argument("--learning-rate",
                        help="initial learning rate for training",
                        default=0.001)
    parser.add_argument("--beta-one",
                        help="beta1 for adam optimizer (momentum)",
                        default=0.9)
    parser.add_argument("--beta-two",
                        help="beta2 for adam optimizer", default=0.999)
    parser.add_argument("--seq-length",
                        help=("sequence length for model, will be divided "
                              "by downsample if downsample is provided"),
                        default=20)
    parser.add_argument("--downsample",
                        help=("reduce sampling frequency of recorded data; "
                              "default sampling frequency is 240 Hz"),
                        default=1)
    parser.add_argument("--in-out-ratio",
                        help=("ratio of input/output; "
                              "seq_length / downsample = input length = 10, "
                              "output length = input length / in_out_ratio"),
                        default=1)
    parser.add_argument("--stride",
                        help=("stride used when reading data in "
                              "for running prediction tasks"),
                        default=3)
    parser.add_argument("--num-epochs",
                        help="number of epochs for training", default=1)
    parser.add_argument("--num-heads",
                        help="number of heads in Transformer")
    parser.add_argument("--dim-feedforward",
                        help=("number of dimensions in feedforward layer "
                              "in Transformer"))
    parser.add_argument("--dropout",
                        help="dropout percentage in Transformer")
    parser.add_argument("--num-layers",
                        help="number of layers in Transformer")

    args = parser.parse_args()

    if args.data_path is None:
        parser.print_help()

    return args


if __name__ == "__main__":
    args = parse_args()

    for arg in vars(args):
        logger.info(f"{arg} - {getattr(args, arg)}")

    logger.info("Starting Transformer training...")

    logger.info(f"Device count: {torch.cuda.device_count()}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Training on {device}...")
    seq_length = int(args.seq_length)//int(args.downsample)

    # Raised explicitly so the check still runs under ``python -O``.
    if seq_length % int(args.in_out_ratio) != 0:
        raise ValueError(
            f"downsampled sequence length ({seq_length}) must be divisible "
            f"by --in-out-ratio ({args.in_out_ratio})")

    lr = float(args.learning_rate)

    normalize = True
    train_dataloader, norm_data = load_dataloader(args, "training", normalize)
    # Validation data reuses the statistics returned for the training set.
    val_dataloader, _ = load_dataloader(args, "validation", normalize,
                                        norm_data=norm_data)

    encoder_feature_size = train_dataloader.dataset[0][0].shape[1]
    decoder_feature_size = train_dataloader.dataset[0][1].shape[1]

    num_heads = int(args.num_heads)
    dim_feedforward = int(args.dim_feedforward)
    dropout = float(args.dropout)
    num_layers = int(args.num_layers)
    quaternions = (args.representation == "quaternions")

    if args.full_transformer:
        model = InferenceTransformer(decoder_feature_size, num_heads,
                                     dim_feedforward, dropout,
                                     num_layers, quaternions=quaternions)
    else:
        model = InferenceTransformerEncoder(encoder_feature_size, num_heads,
                                            dim_feedforward, dropout,
                                            num_layers, decoder_feature_size,
                                            quaternions=quaternions)

    # Counted before the optional DataParallel wrap below.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device).double()

    epochs = int(args.num_epochs)
    beta1 = float(args.beta_one)
    beta2 = float(args.beta_two)

    optimizer = optim.AdamW(model.parameters(),
                            lr=lr,
                            betas=(beta1, beta2),
                            weight_decay=0.03)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[1, 3],
                                               gamma=0.1)

    dataloaders = (train_dataloader, val_dataloader)
    training_criterion = nn.L1Loss()
    validation_criteria = [nn.L1Loss(), QuatDistance()]

    logger.info(f"Model for training: {model}")
    logger.info(f"Number of parameters: {num_params}")
    logger.info(f"Optimizer for training: {optimizer}")
    logger.info(f"Criterion for training: {training_criterion}")

    fit(model, optimizer, scheduler, epochs, dataloaders, training_criterion,
        validation_criteria, device, args.model_file_path,
        full_transformer=args.full_transformer)

    logger.info("Completed Training...")
    logger.info("\n")
class Timer:
    """Context manager that measures how long its ``with`` body takes.

    On exit, ``start`` and ``end`` hold the ``time.time()`` readings taken
    when the block was entered and left, and ``interval`` holds their
    difference in seconds.
    """

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        # Nothing is returned, so an exception raised in the body propagates.
        self.end = time.time()
        self.interval = self.end - self.start


def _unwrap_model(model):
    """Return the wrapped module if ``model`` is an ``nn.DataParallel``."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module
    return model


def _pad_features(inputs, num_features):
    """Zero-pad the last dimension of ``inputs`` up to ``num_features``."""
    missing = num_features - inputs.shape[2]
    if missing <= 0:
        return inputs
    padding = torch.zeros((inputs.shape[0], inputs.shape[1], missing),
                          dtype=inputs.dtype, device=inputs.device)
    return torch.cat((inputs, padding), dim=2)


def _per_sample_losses(criterion, outputs, targets):
    """Evaluate ``criterion`` separately for each index along dimension 1."""
    return [criterion(outputs[:, b, :], targets[:, b, :]).item()
            for b in range(outputs.shape[1])]


def _mean_losses(batch_fn, dataloader, criteria):
    """Average ``batch_fn(data, criterion)`` over a dataloader, per criterion.

    Gradients are disabled while the batches are evaluated.
    """
    averages = []
    for criterion in criteria:
        with torch.no_grad():
            batch_losses = [batch_fn(data, criterion) for data in dataloader]
        averages.append(np.sum(batch_losses) / len(batch_losses))
    return averages


def loss_batch(model, optimizer, data, criterion, device,
               full_transformer=False, average_batch=True):
    """Run a batch through the Transformer Encoder or Transformer.

    Used in the fit function to train and validate the models. When an
    optimizer is given, the model is put in training mode and one
    optimization step is taken; otherwise the model is put in eval mode and
    only the loss is computed.

    Args:
        model (nn.Module): model to pass batch through
        optimizer (optim.Optimizer or None): optimizer for training; pass
            None to evaluate without updating the weights
        data (tuple): contains input and targets; both are permuted with
            ``permute(1, 0, 2)`` and cast to double before use
        criterion (nn.Module): criterion to evaluate model on
        device (torch.device): device to put data on
        full_transformer (bool, optional): whether the model is a full
            transformer; if so the model is called with the input and the
            targets shifted by a zero start-of-sequence step (teacher
            forcing). Defaults to False.
        average_batch (bool, optional): whether to average the batch; useful
            for plotting histograms if average_batch is false.
            Defaults to True.

    Returns:
        float or list: returns a single loss or a list of losses depending on
            average_batch argument.
    """
    training = optimizer is not None
    model.train(training)

    inputs, targets = data
    inputs = inputs.permute(1, 0, 2).to(device).double()
    targets = targets.permute(1, 0, 2).to(device).double()

    if full_transformer:
        # Prepend an all-zero start-of-sequence step; dropping the last step
        # below keeps the decoder input the same length as the targets.
        sos = torch.zeros_like(targets[:1])
        tgt = torch.cat((sos, targets), dim=0)
        # The model is built with one feature size, so narrower inputs are
        # zero-padded up to the target width.
        inputs = _pad_features(inputs, tgt.shape[2])
        outputs = model(inputs, tgt[:-1, :])
    else:
        outputs = model(inputs)

    loss = criterion(outputs, targets)

    if training:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

    if average_batch:
        return loss.item()
    return _per_sample_losses(criterion, outputs, targets)


def inference(model, data, criterion, device, average_batch=True):
    """Run inference with full Transformer.

    Used to determine actual validation loss that will be seen in practice
    since full Transformers use teacher forcing. Predictions are generated
    one step at a time, each step being fed the predictions made so far.

    Args:
        model (nn.Module): model to run inference with; may be wrapped in
            nn.DataParallel, in which case the wrapped module is used
        data (tuple): contains input and targets to use for inference
        criterion (nn.Module): criterion to evaluate model on
        device (torch.device): device to put data on
        average_batch (bool, optional): whether to average the batch; useful
            for plotting histograms if average_batch is false.
            Defaults to True.

    Returns:
        float or list: returns a single loss or a list of losses depending on
            average_batch argument.
    """
    model.eval()
    # nn.DataParallel only forwards __call__; the encoder/decoder pieces used
    # below live on the wrapped module.
    core = _unwrap_model(model)

    inputs, targets = data
    inputs = inputs.permute(1, 0, 2).to(device).double()
    targets = targets.permute(1, 0, 2).to(device).double()

    # Row 0 is the all-zero start-of-sequence step; rows 1.. are filled in
    # one at a time by the loop below.
    pred = torch.zeros((targets.shape[0] + 1,
                        targets.shape[1],
                        targets.shape[2]),
                       dtype=targets.dtype, device=targets.device)

    inputs = _pad_features(inputs, pred.shape[2])
    memory = core.encoder(core.pos_decoder(inputs))

    for i in range(pred.shape[0] - 1):
        # Pass a copy so later in-place writes to ``pred`` cannot touch the
        # tensor the decoder was called with.
        next_pred = core.inference(memory, pred[:i + 1, :].clone())
        pred[i + 1, :] = next_pred[-1, :]

    if average_batch:
        return criterion(pred[1:, :], targets).item()
    return _per_sample_losses(criterion, pred[1:, :], targets)


def fit(model, optimizer, scheduler, epochs, dataloaders, training_criterion,
        validation_criteria, device, model_file_path,
        full_transformer=False, min_val_loss=math.inf):
    """Fit a Transformer model to data, logging training and validation loss.

    The model and optimizer state are saved to ``model_file_path`` whenever
    the first validation criterion reaches a new minimum.

    Args:
        model (nn.Module): the model to train
        optimizer (optim.Optimizer): the optimizer for training
        scheduler (lr scheduler): learning-rate scheduler; stepped once per
            epoch
        epochs (int): number of epochs to train for
        dataloaders (tuple): tuple containing the training dataloader and val
            dataloader.
        training_criterion (nn.Module): criterion for backpropagation during
            training
        validation_criteria (list): list of criteria for validation
        device (torch.device): device to place data on
        model_file_path (str): where to save the model when validation loss
            reaches new minimum
        full_transformer (bool): whether the model is a full transformer and
            needs to run inference for evaluation
        min_val_loss (float, optional): minimum validation loss

    Returns:
        float: minimum validation loss reached during training
    """
    train_dataloader, val_dataloader = dataloaders
    num_batches = len(train_dataloader)
    # Log roughly ten times per epoch; never let the interval reach zero,
    # which would make the modulo below raise for loaders under 10 batches.
    log_interval = max(1, num_batches // 10)
    total_time = 0

    def teacher_forced_loss(data, criterion):
        return loss_batch(model, None, data, criterion, device,
                          full_transformer=full_transformer)

    def autoregressive_loss(data, criterion):
        return inference(model, data, criterion, device)

    for epoch in range(epochs):
        logger.info("Epoch {}".format(epoch + 1))
        epoch_loss = 0
        window_loss = 0
        window_size = 0
        for index, data in enumerate(train_dataloader, 0):
            with Timer() as timer:
                loss = loss_batch(model, optimizer, data, training_criterion,
                                  device, full_transformer=full_transformer)
            epoch_loss += loss
            window_loss += loss
            window_size += 1
            total_time += timer.interval
            if index % log_interval == 0 and index != 0:
                # Divide by the number of batches actually accumulated so the
                # first window (which also contains batch 0) is not inflated.
                avg_training_loss = window_loss / window_size
                logger.info((f"Total time elapsed: {total_time}"
                             " - "
                             f"Batch number: {index} / {len(train_dataloader)}"
                             " - "
                             f"Training loss: {avg_training_loss}"
                             " - "
                             f"LR: {optimizer.param_groups[0]['lr']}"
                             ))
                window_loss = 0
                window_size = 0

        val_loss = _mean_losses(teacher_forced_loss, val_dataloader,
                                validation_criteria)

        loss = epoch_loss / num_batches if num_batches else math.nan

        scheduler.step()
        val_loss_str = ", ".join(map(str, val_loss))
        logger.info(f"Epoch {epoch+1} - "
                    f"Training Loss: {loss} - "
                    f"Val Loss: {val_loss_str}")

        if full_transformer:
            # Teacher-forced validation loss is optimistic for a full
            # Transformer, so also report the step-by-step inference loss.
            inference_loss = _mean_losses(autoregressive_loss, val_dataloader,
                                          validation_criteria)
            inference_loss_str = ", ".join(map(str, inference_loss))
            logger.info(f"Inference Loss: {inference_loss_str}")

        # Checkpointing is driven by the first validation criterion only.
        if val_loss[0] < min_val_loss:
            min_val_loss = val_loss[0]
            logger.info(f"Saving model to {model_file_path}")
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, model_file_path)

    return min_val_loss
def _normalize_quaternions(output):
    """Rescale each consecutive group of four features to unit L2 norm.

    The tensor is flattened to rows of four values, every row is normalized,
    and the original shape is restored.
    """
    original_shape = output.shape
    # reshape (rather than view) also accepts non-contiguous tensors.
    flat = output.reshape(-1, 4)
    return F.normalize(flat, p=2, dim=1).reshape(original_shape)


class PositionalEncoding(nn.Module):
    """PositionalEncoding injects position-dependent signals in input/targets.

    Transformers have no concept of sequence like RNNs, so this acts to inject
    information about the order of a sequence.

    Useful reference for more info:
    https://datascience.stackexchange.com/questions/51065/what-is-the-positional-encoding-in-the-transformer-model
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the sinusoidal table.

        Args:
            d_model (int): number of features per time step
            dropout (float, optional): dropout applied after adding the
                encoding. Defaults to 0.1.
            max_len (int, optional): number of positions in the precomputed
                table. Defaults to 5000.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # the cosine table is truncated to fit.
        pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]
        # Shape (max_len, 1, d_model) so it broadcasts over dimension 1 of x.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Registered as a buffer: follows .to()/.double() but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions to ``x``.

        Args:
            x (torch.Tensor): tensor whose first dimension indexes position
                and whose last dimension has ``d_model`` features

        Returns:
            torch.Tensor: ``x`` plus the positional signal, after dropout
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class InferenceTransformerEncoder(nn.Module):
    """Transformer Encoder for use in inferring unit quaternions."""

    def __init__(self, num_input_features, num_heads, dim_feedforward, dropout,
                 num_layers, num_output_features, quaternions=False):
        """Initialize the Transformer Encoder.

        Args:
            num_input_features (int): number of features in the input
                data
            num_heads (int): number of heads in each layer for multi-head
                attention
            dim_feedforward (int): dimensionality of the feedforward layers in
                each layer
            dropout (float): dropout amount in the layers
            num_layers (int): number of layers in Encoder
            num_output_features (int): number of features in the output
                data
            quaternions (bool, optional): whether quaternions are used in the
                output; will normalize if True. Defaults to False.
        """
        super().__init__()
        self.pos_encoder = PositionalEncoding(num_input_features, dropout)
        self.encoder_layer = nn.TransformerEncoderLayer(num_input_features,
                                                        num_heads,
                                                        dim_feedforward,
                                                        dropout)
        self.encoder = nn.TransformerEncoder(self.encoder_layer,
                                             num_layers,
                                             norm=None)
        self.linear1 = nn.Linear(num_input_features, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, num_output_features)
        self.quaternions = quaternions

    def forward(self, src):
        """Forward pass through Transformer Encoder model for training/testing.

        Args:
            src (torch.Tensor): used by encoder to generate output

        Returns:
            torch.Tensor: output from the Transformer
        """
        pos_enc = self.pos_encoder(src)
        enc_output = self.encoder(pos_enc)
        # Two linear layers map the encoder features to the output width.
        output = self.linear2(self.linear1(enc_output))

        if self.quaternions:
            output = _normalize_quaternions(output)

        return output


class InferenceTransformer(nn.Module):
    """Transformer for use in inferring unit quaternions."""

    def __init__(self, num_features, num_heads, dim_feedforward, dropout,
                 num_layers, quaternions=False):
        """Initialize the Transformer model.

        Args:
            num_features (int): number of features in the input and target
                data
            num_heads (int): number of heads in each layer for multi-head
                attention
            dim_feedforward (int): dimensionality of the feedforward layers in
                each layer
            dropout (float): dropout amount in the layers
            num_layers (int): number of layers in Encoder and Decoder
            quaternions (bool, optional): whether quaternions are used in the
                output; will normalize if True. Defaults to False.
        """
        super().__init__()

        self.pos_encoder = PositionalEncoding(num_features, dropout)
        self.pos_decoder = PositionalEncoding(num_features, dropout)

        self.encoder_layer = nn.TransformerEncoderLayer(num_features,
                                                        num_heads,
                                                        dim_feedforward,
                                                        dropout)
        self.encoder = nn.TransformerEncoder(self.encoder_layer,
                                             num_layers,
                                             norm=None)
        self.decoder_layer = nn.TransformerDecoderLayer(num_features,
                                                        num_heads,
                                                        dim_feedforward,
                                                        dropout)
        self.decoder = nn.TransformerDecoder(self.decoder_layer,
                                             num_layers,
                                             norm=None)
        # Cached causal mask; rebuilt when the target length, device or dtype
        # changes.
        self.tgt_mask = None
        self.quaternions = quaternions

    def generate_square_subsequent_mask(self, sz, device=None, dtype=None):
        """Mask the upcoming values in the tensor to avoid cheating.

        Args:
            sz (int): sequence length of tensor
            device (torch.device, optional): device to build the mask on.
                Defaults to None (PyTorch's default device).
            dtype (torch.dtype, optional): floating-point dtype of the mask.
                Defaults to None (PyTorch's default dtype).

        Returns:
            torch.Tensor: (sz, sz) mask holding -inf strictly above the
                diagonal and 0 elsewhere
        """
        mask = torch.triu(torch.ones(sz, sz, device=device, dtype=dtype), 1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def _causal_mask_for(self, target):
        """Return a cached causal mask matching ``target``."""
        mask = self.tgt_mask
        stale = (mask is None
                 or mask.size(0) != len(target)
                 or mask.device != target.device
                 or mask.dtype != target.dtype)
        if stale:
            # Matching the target's dtype keeps the additive mask compatible
            # with the attention scores when the model runs in float64.
            mask = self.generate_square_subsequent_mask(len(target),
                                                        device=target.device,
                                                        dtype=target.dtype)
            self.tgt_mask = mask
        return mask

    def _decode(self, memory, target):
        """Run the masked decoder on ``target`` and post-process the output."""
        tgt_mask = self._causal_mask_for(target)
        pos_dec = self.pos_decoder(target)
        output = self.decoder(pos_dec, memory, tgt_mask=tgt_mask)

        if self.quaternions:
            output = _normalize_quaternions(output)

        return output

    def inference(self, memory, target):
        """Forward pass through Transformer model at validation/test time.

        Args:
            memory (torch.Tensor): memory passed from the Encoder
            target (torch.Tensor): predictions built up over time during
                inference

        Returns:
            torch.Tensor: output from the model (passed into model in next
                iteration)
        """
        return self._decode(memory, target)

    def forward(self, src, target):
        """Forward pass through Transformer model for training.

        Use inference function at validation/test time to get accurate
        measure of performance.

        Args:
            src (torch.Tensor): used by encoder to generate memory
            target (torch.Tensor): target for decoder to try to match;
                Transformers use teacher forcing so targets are used as input

        Returns:
            torch.Tensor: output from the Transformer
        """
        # NOTE(review): src is encoded with pos_decoder, leaving pos_encoder
        # unused. Both are built with the same arguments, so the values are
        # the same; kept as-is because external code may rely on the same
        # attribute when computing the encoder memory.
        pos_enc = self.pos_decoder(src)
        memory = self.encoder(pos_enc)

        return self._decode(memory, target)