├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
└── src
    ├── 01-dataset-tutorial.ipynb
    ├── 02-deep-learning-tutorial.ipynb
    ├── 03-inference-tutorial.ipynb
    ├── build-dataset.py
    ├── build_dataset.sh
    ├── common
    │   ├── data_utils.py
    │   ├── logging.py
    │   ├── losses.py
    │   ├── preprocessing.py
    │   ├── quaternion.py
    │   ├── rotations.py
    │   └── skeleton.py
    ├── matlab
    │   ├── get_folder_path.m
    │   ├── joint_angle_segments.m
    │   ├── joint_angles.m
    │   ├── load_partial_mvnx.m
    │   ├── mvnx_to_csv.m
    │   ├── mvnx_to_hdf.m
    │   ├── mvnx_to_hdf_batch.m
    │   ├── process_mvnx_files.m
    │   ├── segment_orientation.m
    │   ├── segment_position.m
    │   └── segment_reference.m
    ├── seq2seq
    │   ├── seq2seq.py
    │   └── training_utils.py
    ├── test-seq2seq.py
    ├── test-transformer.py
    ├── test_seq2seq.sh
    ├── test_transformer.sh
    ├── train-seq2seq.py
    ├── train-transformer.py
    ├── train_seq2seq.sh
    ├── train_transformer.sh
    └── transformers
        ├── training_utils.py
        └── transformers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .idea/
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | # data
128 | *.csv
129 | *.mvnx
130 | *.h5
131 | *.pt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020-present, Assistive Robotics Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Virginia Tech Natural Motion Processing
2 |
3 | This repository was written to help analyze the Virginia Tech Natural Motion Dataset. The dataset contains 40 hours of unscripted human motion collected in the open world using XSens MVN Link. The dataset, metadata, and more information are available through the Virginia Tech University Libraries: https://data.lib.vt.edu/articles/dataset/Virginia_Tech_Natural_Motion_Dataset/14114054/2.
4 |
5 | ## Table of Contents
6 |
7 | - [Project Layout](#project-layout)
8 | - [Dependencies](#dependencies)
9 | - [Setup](#setup)
10 | - [Conda Environment](#conda-environment)
11 |
12 | ## Project Layout
13 |
14 | src/
15 | common/
16 | seq2seq/
17 |         transformers/
18 | matlab/
19 |
20 | ## Dependencies
21 |
22 | numpy==1.18.1
23 | h5py==2.10.0
24 | matplotlib==3.1.3
25 | torch==1.6.0
26 |
27 | ## Setup
28 |
29 | - Clone the repo locally
30 | - Setup the conda environment
31 | - `$ conda create -n vt-nmp python=3.7`
32 | - Activate the environment and install the requirements
33 |   - `$ conda activate vt-nmp && pip install -r requirements.txt`
34 |
35 | ## Conda Environment
36 |
37 | An Anaconda environment is used to help with development. The environment's main dependency is PyTorch, which is installed by the setup steps above.
38 |
39 | ## License
40 |
41 | Please see the LICENSE for more details. If you use our code or models in your research, please cite our paper:
42 |
43 | ```
44 | @article{geissinger2020motion,
45 | title={Motion inference using sparse inertial sensors, self-supervised learning, and a new dataset of unscripted human motion},
46 | author={Geissinger, Jack H and Asbeck, Alan T},
47 | journal={Sensors},
48 | volume={20},
49 | number={21},
50 | pages={6330},
51 | year={2020},
52 | publisher={Multidisciplinary Digital Publishing Institute}
53 | }
54 | ```
55 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.1
2 | h5py==2.10.0
3 | matplotlib==3.1.3
4 | torch==1.6.0
5 |
--------------------------------------------------------------------------------
/src/02-deep-learning-tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Building a dataset, training a Seq2Seq model, and testing it\n",
8 | "\n",
9 | "The Virginia Tech Natural Motion Dataset contains .h5 files with unscripted human motion data collected in real-world environments as participants went about their day-to-day lives. This is a brief tutorial in using the dataset and then training and testing a neural network.\n",
10 | "\n",
11 | "This tutorial illustrates how to use the shell (.sh) scripts to train a seq2seq model (particularly **train_seq2seq.sh** and **test_seq2seq.sh**). Similar shell scripts are also available for the Transformers (see **train_transformer.sh** and **test_transformer.sh**)"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "### Building a dataset\n",
19 | "\n",
20 | "We will first cover how to build a dataset with data from a few participants using the build-dataset.py file.\n",
21 | "\n",
22 | "We are running the script from a Jupyter Notebook, but this can just as easily be run as a shell script (see build_dataset.sh).\n",
23 | "\n",
24 | "In this case, we are drawing data from the h5-dataset folder located in the cloud. We are going to output the training.h5, validation.h5, and testing.h5 files to the folder data/set-2.\n",
25 | "\n",
26 | "We will be using participants 1, 5, and 10 (P1, P5, P10, respectively) and extracting normOrientation and normAcceleration data on a few segments (norm* means data normalized relative to the pelvis). As output data we will be extracting normOrientation data for every segment.\n",
27 | "\n",
28 | "In other words, our task is as follows: use orientation and acceleration from a set of sparse segments and try to train a model mapping that input data to orientations for every segment on the human body."
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 1,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stdout",
38 | "output_type": "stream",
39 | "text": [
40 | "2020-08-10 13:09:31 INFO Writing X to the training file group...\n",
41 | "2020-08-10 13:09:37 INFO Writing X to the validation file group...\n",
42 | "2020-08-10 13:09:40 INFO Writing X to the testing file group...\n",
43 | "2020-08-10 13:09:50 INFO Writing Y to the training file group...\n",
44 | "2020-08-10 13:09:57 INFO Writing Y to the validation file group...\n",
45 | "2020-08-10 13:09:58 INFO Writing Y to the testing file group...\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/data/set-2\n",
51 | "!python build-dataset.py --data-path \"/groups/MotionPred/h5-dataset\" \\\n",
52 | " --output-path \"/home/jackg7/VT-Natural-Motion-Processing/data/set-2\" \\\n",
53 | " --training \"P1\" \\\n",
54 | " --validation \"P5\" \\\n",
55 | " --testing \"P10\" \\\n",
56 | " --task-input \"normOrientation normAcceleration\" \\\n",
57 | " --input-label-request \"T8 RightForeArm RightLowerLeg LeftForeArm LeftLowerLeg\" \\\n",
58 | " --task-output \"normOrientation\" \\\n",
59 | " --output-label-request \"all\""
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "### Training a seq2seq model\n",
67 | "\n",
68 | "We can now train a seq2seq model to map the normOrientation and normAcceleration data from the sparse segments to the full-body normOrientation data.\n",
69 | "\n",
70 | "We will be using a seq-length of 30 (at 240 Hz) downsample it by a factor of 6 (to 40 Hz). The resulting sequences will be of length 5 for the input and output. The in-out-ratio will then be used to reduce the output sequence length to 1.\n",
71 | "\n",
72 | "The input sequence will be of shape (B, 5, 35) and output shape will be of shape (B, 1, 92). Orientations are stored as quaternions, so orientation value will be 4 in length. The number 35 comes from our use of 5 segment orientations and accelerations or $5*4 + 5*3 = 35$. The full-body has 23 segments and we're predicting orientation values for each one or $23*4 = 92$\n",
73 | "\n",
74 | "We're training a seq2seq model with a hidden size of 512, a bidirectional encoder and dot product attention. The model will be trained for a single epoch.\n",
75 | "\n",
76 | "Our loss function for training will be the L1Loss and our validation losses will be the L1Loss and the QuatDistance (cosine similarity) loss."
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 2,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "name": "stdout",
86 | "output_type": "stream",
87 | "text": [
88 | "2020-08-10 13:10:06 INFO task - conversion\n",
89 | "2020-08-10 13:10:06 INFO data_path - /home/jackg7/VT-Natural-Motion-Processing/data/set-2\n",
90 | "2020-08-10 13:10:06 INFO model_file_path - /home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\n",
91 | "2020-08-10 13:10:06 INFO representation - quaternions\n",
92 | "2020-08-10 13:10:06 INFO auxiliary_acc - False\n",
93 | "2020-08-10 13:10:06 INFO batch_size - 32\n",
94 | "2020-08-10 13:10:06 INFO learning_rate - 0.001\n",
95 | "2020-08-10 13:10:06 INFO seq_length - 30\n",
96 | "2020-08-10 13:10:06 INFO downsample - 6\n",
97 | "2020-08-10 13:10:06 INFO in_out_ratio - 5\n",
98 | "2020-08-10 13:10:06 INFO stride - 30\n",
99 | "2020-08-10 13:10:06 INFO num_epochs - 1\n",
100 | "2020-08-10 13:10:06 INFO hidden_size - 512\n",
101 | "2020-08-10 13:10:06 INFO dropout - 0.0\n",
102 | "2020-08-10 13:10:06 INFO bidirectional - True\n",
103 | "2020-08-10 13:10:06 INFO attention - dot\n",
104 | "2020-08-10 13:10:06 INFO Starting seq2seq model training...\n",
105 | "2020-08-10 13:10:06 INFO Retrieving training data for sequences 125 ms long and downsampling to 40.0 Hz...\n",
106 | "2020-08-10 13:10:09 INFO Data for training have shapes (X, y): torch.Size([259570, 35]), torch.Size([51914, 92])\n",
107 | "2020-08-10 13:10:09 INFO Reshaped training shapes (X, y): torch.Size([51914, 5, 35]), torch.Size([51914, 1, 92])\n",
108 | "2020-08-10 13:10:09 INFO Number of training samples: 51914\n",
109 | "2020-08-10 13:10:09 INFO Retrieving validation data for sequences 125 ms long and downsampling to 40.0 Hz...\n",
110 | "2020-08-10 13:10:09 INFO Data for validation have shapes (X, y): torch.Size([90880, 35]), torch.Size([18176, 92])\n",
111 | "2020-08-10 13:10:09 INFO Reshaped validation shapes (X, y): torch.Size([18176, 5, 35]), torch.Size([18176, 1, 92])\n",
112 | "2020-08-10 13:10:09 INFO Number of validation samples: 18176\n",
113 | "2020-08-10 13:10:09 INFO Encoder for training: EncoderRNN(\n",
114 | " (gru): GRU(35, 512, bidirectional=True)\n",
115 | " (dropout): Dropout(p=0.0, inplace=False)\n",
116 | " (fc): Linear(in_features=1024, out_features=512, bias=True)\n",
117 | ")\n",
118 | "2020-08-10 13:10:09 INFO Decoder for training: AttnDecoderRNN(\n",
119 | " (attention): Attention()\n",
120 | " (rnn): GRU(1116, 512)\n",
121 | " (out): Linear(in_features=1628, out_features=92, bias=True)\n",
122 | ")\n",
123 | "2020-08-10 13:10:09 INFO Number of parameters: 4864876\n",
124 | "2020-08-10 13:10:09 INFO Optimizers for training: AdamW (\n",
125 | "Parameter Group 0\n",
126 | " amsgrad: False\n",
127 | " betas: (0.9, 0.999)\n",
128 | " eps: 1e-08\n",
129 | " initial_lr: 0.001\n",
130 | " lr: 0.001\n",
131 | " weight_decay: 0.05\n",
132 | ")\n",
133 | "2020-08-10 13:10:09 INFO Criterion for training: L1Loss()\n",
134 | "2020-08-10 13:10:09 INFO Epoch 1 / 1\n",
135 | "2020-08-10 13:10:09 INFO Total time elapsed: 0.2845275402069092 - Batch Number: 0 / 1622 - Training loss: 0.5867175199891831\n",
136 | "2020-08-10 13:10:28 INFO Total time elapsed: 19.075539588928223 - Batch Number: 162 / 1622 - Training loss: 0.06888134323321701\n",
137 | "2020-08-10 13:10:46 INFO Total time elapsed: 36.41547679901123 - Batch Number: 324 / 1622 - Training loss: 0.06516135748519093\n",
138 | "2020-08-10 13:11:06 INFO Total time elapsed: 56.93833088874817 - Batch Number: 486 / 1622 - Training loss: 0.06124250432828066\n",
139 | "2020-08-10 13:11:27 INFO Total time elapsed: 77.28701090812683 - Batch Number: 648 / 1622 - Training loss: 0.06000007542016456\n",
140 | "2020-08-10 13:11:49 INFO Total time elapsed: 99.8214840888977 - Batch Number: 810 / 1622 - Training loss: 0.06179196632301481\n",
141 | "2020-08-10 13:12:13 INFO Total time elapsed: 123.77892279624939 - Batch Number: 972 / 1622 - Training loss: 0.055964607362772964\n",
142 | "2020-08-10 13:12:34 INFO Total time elapsed: 144.65489411354065 - Batch Number: 1134 / 1622 - Training loss: 0.05693873948561733\n",
143 | "2020-08-10 13:12:58 INFO Total time elapsed: 167.9723401069641 - Batch Number: 1296 / 1622 - Training loss: 0.05655428921935604\n",
144 | "2020-08-10 13:13:21 INFO Total time elapsed: 191.32336831092834 - Batch Number: 1458 / 1622 - Training loss: 0.05438031324405435\n",
145 | "2020-08-10 13:13:44 INFO Total time elapsed: 214.2629885673523 - Batch Number: 1620 / 1622 - Training loss: 0.05223539072065468\n",
146 | "2020-08-10 13:14:03 INFO Training Loss: 0.06224823312900671 - Val Loss: 0.10371989231654884, 22.61474124110037\n",
147 | "2020-08-10 13:14:03 INFO Saving model to /home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\n",
148 | "2020-08-10 13:14:04 INFO Completed Training...\n",
149 | "2020-08-10 13:14:04 INFO \n",
150 | "\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/models/set-2\n",
156 | "!python train-seq2seq.py --task conversion \\\n",
157 | " --data-path \"/home/jackg7/VT-Natural-Motion-Processing/data/set-2\" \\\n",
158 | " --model-file-path \"/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\" \\\n",
159 | " --representation quaternions \\\n",
160 | " --batch-size=32 \\\n",
161 | " --seq-length=30 \\\n",
162 | " --downsample=6 \\\n",
163 | " --in-out-ratio=5 \\\n",
164 | " --stride=30 \\\n",
165 | " --learning-rate=0.001 \\\n",
166 | " --num-epochs=1 \\\n",
167 | " --hidden-size=512 \\\n",
168 | " --attention=dot \\\n",
169 | " --bidirectional"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 | "### Testing our model\n",
177 | "\n",
178 | "We can now test our model and output a histogram of performance over the testing data. The model parameters must be the same to properly read in the model."
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 3,
184 | "metadata": {},
185 | "outputs": [
186 | {
187 | "name": "stdout",
188 | "output_type": "stream",
189 | "text": [
190 | "2020-08-10 13:14:06 INFO task - conversion\n",
191 | "2020-08-10 13:14:06 INFO data_path_parent - /home/jackg7/VT-Natural-Motion-Processing/data\n",
192 | "2020-08-10 13:14:06 INFO figure_file_path - /home/jackg7/VT-Natural-Motion-Processing/images/seq2seq-test.pdf\n",
193 | "2020-08-10 13:14:06 INFO figure_title - Seq2Seq\n",
194 | "2020-08-10 13:14:06 INFO include_legend - False\n",
195 | "2020-08-10 13:14:06 INFO model_dir - /home/jackg7/VT-Natural-Motion-Processing/models/set-2\n",
196 | "2020-08-10 13:14:06 INFO representation - quaternions\n",
197 | "2020-08-10 13:14:06 INFO batch_size - 512\n",
198 | "2020-08-10 13:14:06 INFO seq_length - 30\n",
199 | "2020-08-10 13:14:06 INFO downsample - 6\n",
200 | "2020-08-10 13:14:06 INFO in_out_ratio - 5\n",
201 | "2020-08-10 13:14:06 INFO stride - 30\n",
202 | "2020-08-10 13:14:06 INFO hidden_size - 512\n",
203 | "2020-08-10 13:14:06 INFO dropout - 0.0\n",
204 | "2020-08-10 13:14:06 INFO bidirectional - True\n",
205 | "2020-08-10 13:14:06 INFO attention - dot\n",
206 | "2020-08-10 13:14:06 INFO Starting seq2seq model testing...\n",
207 | "2020-08-10 13:14:06 INFO Retrieving testing data for sequences 125 ms long and downsampling to 40.0 Hz...\n",
208 | "2020-08-10 13:14:09 INFO Data for testing have shapes (X, y): torch.Size([452760, 35]), torch.Size([90552, 92])\n",
209 | "2020-08-10 13:14:09 INFO Reshaped testing shapes (X, y): torch.Size([90552, 5, 35]), torch.Size([90552, 1, 92])\n",
210 | "2020-08-10 13:14:09 INFO Number of testing samples: 90552\n",
211 | "2020-08-10 13:14:37 INFO Inference Loss for set-2: 14.272361435695645\n",
212 | "2020-08-10 13:14:37 INFO Completed Testing...\n",
213 | "2020-08-10 13:14:37 INFO \n",
214 | "\n"
215 | ]
216 | }
217 | ],
218 | "source": [
219 | "!mkdir -p /home/jackg7/VT-Natural-Motion-Processing/images\n",
220 | "!python test-seq2seq.py --task conversion \\\n",
221 | " --data-path-parent /home/jackg7/VT-Natural-Motion-Processing/data \\\n",
222 | " --figure-file-path /home/jackg7/VT-Natural-Motion-Processing/images/seq2seq-test.pdf \\\n",
223 | " --figure-title \"Seq2Seq\" \\\n",
224 | " --model-dir /home/jackg7/VT-Natural-Motion-Processing/models/set-2 \\\n",
225 | " --representation quaternions \\\n",
226 | " --batch-size=512 \\\n",
227 | " --seq-length=30 \\\n",
228 | " --downsample=6 \\\n",
229 | " --in-out-ratio=5 \\\n",
230 | " --stride=30 \\\n",
231 | " --hidden-size=512 \\\n",
232 | " --attention=dot \\\n",
233 | " --bidirectional"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "We can now visualize the performance of the seq2seq model on the test data."
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 4,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "data": {
250 | "text/html": [
251 | "\n",
252 | " \n",
259 | " "
260 | ],
261 | "text/plain": [
262 | ""
263 | ]
264 | },
265 | "execution_count": 4,
266 | "metadata": {},
267 | "output_type": "execute_result"
268 | }
269 | ],
270 | "source": [
271 | "from IPython.display import IFrame\n",
272 | "IFrame(\"../images/seq2seq-test.pdf\", width=600, height=300)"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "metadata": {},
279 | "outputs": [],
280 | "source": []
281 | }
282 | ],
283 | "metadata": {
284 | "kernelspec": {
285 | "display_name": "Python 3",
286 | "language": "python",
287 | "name": "python3"
288 | },
289 | "language_info": {
290 | "codemirror_mode": {
291 | "name": "ipython",
292 | "version": 3
293 | },
294 | "file_extension": ".py",
295 | "mimetype": "text/x-python",
296 | "name": "python",
297 | "nbconvert_exporter": "python",
298 | "pygments_lexer": "ipython3",
299 | "version": "3.7.4"
300 | }
301 | },
302 | "nbformat": 4,
303 | "nbformat_minor": 2
304 | }
305 |
--------------------------------------------------------------------------------
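The shape bookkeeping in the tutorial above (feature sizes of 35 and 92, input length 5, output length 1) can be sanity-checked with a few lines of Python. This is a minimal sketch of that arithmetic only; the variable names are illustrative and not part of the repository.

```python
# Sketch: verify the feature sizes and sequence lengths quoted in the tutorial.
n_input_segments = 5   # T8, RightForeArm, RightLowerLeg, LeftForeArm, LeftLowerLeg
n_body_segments = 23   # full-body segment count
quat_dims, acc_dims = 4, 3

encoder_feature_size = n_input_segments * quat_dims + n_input_segments * acc_dims
decoder_feature_size = n_body_segments * quat_dims
assert encoder_feature_size == 35   # input batches have shape (B, 5, 35)
assert decoder_feature_size == 92   # output batches have shape (B, 1, 92)

# 30 frames at 240 Hz, downsampled by a factor of 6, give 5 frames at 40 Hz;
# an in-out-ratio of 5 then keeps only the final output frame.
seq_length, downsample, in_out_ratio = 30, 6, 5
input_len = seq_length // downsample
output_len = input_len // in_out_ratio
assert (input_len, output_len) == (5, 1)
```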
/src/03-inference-tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Running Inference with our Seq2Seq model"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%load_ext autoreload\n",
17 | "%autoreload 2\n",
18 | "%matplotlib notebook\n",
19 | "\n",
20 | "import torch\n",
21 | "import torch.nn.functional as F\n",
22 | "import h5py\n",
23 | "import numpy as np\n",
24 | "import glob\n",
25 | "from common.quaternion import quat_mul\n",
26 | "from common.data_utils import read_h5\n",
27 | "from common.skeleton import Skeleton\n",
28 | "from seq2seq.training_utils import get_encoder, get_attn_decoder"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "### Loading in some data to test with\n",
36 | "\n",
37 | "Let's start by loading in data.\n",
38 | "\n",
39 | "We trained our model using \"normOrientation\" and \"normAcceleration\" from the T8 (sternum), both forearms, and both lower legs as input, so we'll read in that data. We'll also read in some data for \"normOrientation\" on the entire body because this is the output of our model. Finally, we'll read in data for the orientation of the pelvis. This is important because we want to rotate the \"normOrientation\" data back into it's original reference frame."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 2,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "group = [\"T8\", \"RightForeArm\", \"RightLowerLeg\", \"LeftForeArm\", \"LeftLowerLeg\"]\n",
49 | "\n",
50 | "filepaths = glob.glob(\"../data/*.h5\")\n",
51 | "requests = {\"normOrientation\" : [\"all\"], \"orientation\" : [\"Pelvis\"]}\n",
52 | "dataset = read_h5(filepaths, requests)\n",
53 | "\n",
54 | "filename = filepaths[0].split(\"/\")[-1]\n",
55 | "fullBodyOrientations = torch.Tensor(dataset[filename][\"normOrientation\"]).double()\n",
56 | "root = torch.Tensor(dataset[filename]['orientation']).double()\n",
57 | "\n",
58 | "requests = {\"normOrientation\" : group, \"normAcceleration\" : group}\n",
59 | "dataset = read_h5(filepaths, requests)\n",
60 | "orientationInputs = torch.Tensor(dataset[filename][\"normOrientation\"]).double()\n",
61 | "accelerationInputs = torch.Tensor(dataset[filename][\"normAcceleration\"]).double()\n",
62 | "\n",
63 | "with h5py.File(\"../data/set-2/normalization.h5\", \"r\") as f:\n",
64 | " mean, std_dev = torch.Tensor(f[\"mean\"]), torch.Tensor(f[\"std_dev\"])"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "### Loading in the Seq2Seq Model\n",
72 | "\n",
73 | "We can now load in our Seq2Seq models (encoder and decoder).\n",
74 | "\n",
75 | "The models must have the same arguments used during training so that errors don't pop up."
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 3,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "encoder_feature_size = len(group)*4 + len(group)*3\n",
85 | "decoder_feature_size = 92\n",
86 | "hidden_size = 512\n",
87 | "attention = \"dot\"\n",
88 | "bidirectional = True\n",
89 | "\n",
90 | "seq_length = 30\n",
91 | "downsample = 6\n",
92 | "in_out_ratio = 5\n",
93 | "\n",
94 | "\n",
95 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
96 | "\n",
97 | "encoder = get_encoder(encoder_feature_size,\n",
98 | " device,\n",
99 | " hidden_size=hidden_size,\n",
100 | " bidirectional=bidirectional)\n",
101 | "\n",
102 | "decoder = get_attn_decoder(decoder_feature_size,\n",
103 | " attention,\n",
104 | " device,\n",
105 | " hidden_size=hidden_size,\n",
106 | " bidirectional_encoder=bidirectional)\n",
107 | "\n",
108 | " \n",
109 | "decoder.batch_size = 1\n",
110 | "decoder.attention.batch_size = 1\n",
111 | "\n",
112 | "PATH = \"/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt\"\n",
113 | "\n",
114 | "checkpoint = torch.load(PATH, map_location=device)\n",
115 | "\n",
116 | "encoder.load_state_dict(checkpoint['encoder_state_dict'])\n",
117 | "decoder.load_state_dict(checkpoint['decoder_state_dict'])\n",
118 | "\n",
119 | "encoder.eval()\n",
120 | "decoder.eval()\n",
121 | "models = (encoder, decoder)"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "### Defining the inference function\n",
129 | "\n",
130 | "Our inference function needs to take in a batch of input data, pass it through both the encoder and decoder, and then return the output.\n",
131 | "\n",
132 | "Note that this function is very similar to the *loss_batch* function defined in *src/seq2seq/training_utils.py*. Here it returns the outputs instead of the loss for training/validation purposes."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 4,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "def inference(data, models, device, use_attention=False, norm_quaternions=False):\n",
142 | " encoder, decoder = models\n",
143 | " input_batch, target_batch = data\n",
144 | "\n",
145 | " input_batch = input_batch.to(device)\n",
146 | " target_batch = target_batch.to(device)\n",
147 | "\n",
148 | " seq_length = target_batch.shape[1]\n",
149 | "\n",
150 | " input = input_batch.permute(1, 0, 2)\n",
151 | " encoder_outputs, encoder_hidden = encoder(input)\n",
152 | "\n",
153 | " decoder_hidden = encoder_hidden\n",
154 | " decoder_input = torch.ones_like(target_batch[:, 0, :]).unsqueeze(0)\n",
155 | " \n",
156 | " outputs = torch.zeros_like(target_batch)\n",
157 | "\n",
158 | " for t in range(seq_length):\n",
159 | "\n",
160 | " if use_attention:\n",
161 | " decoder_output, decoder_hidden = decoder(\n",
162 | " decoder_input, decoder_hidden, encoder_outputs)\n",
163 | " else:\n",
164 | " decoder_output, decoder_hidden = decoder(\n",
165 | " decoder_input, decoder_hidden)\n",
166 | " \n",
167 | " target = target_batch[:, t, :].unsqueeze(0).double()\n",
168 | " \n",
169 | " output = decoder_output\n",
170 | "\n",
171 | " if norm_quaternions:\n",
172 | " original_shape = output.shape\n",
173 | "\n",
174 | " output = output.view(-1,4)\n",
175 | " output = F.normalize(output, p=2, dim=1).view(original_shape)\n",
176 | "\n",
177 | " outputs[:, t, :] = output\n",
178 | " \n",
179 | " decoder_input = output.detach()\n",
180 | "\n",
181 | " return outputs"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {},
187 | "source": [
188 | "### Sampling some data and running it through the inference function\n",
189 | "\n",
190 | "We can now sample some data and use our model to perform inference.\n",
191 | "\n",
192 | "We only trained the model to predict a single posture and did not use much training data as this is just a tutorial.\n",
193 | "\n",
194 | "Note we used normOrientations as output, so we have to put the orientation back into the original frame using the root's (pelvis) orientation."
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 5,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "i = 15330\n",
204 | "\n",
205 | "inp = torch.cat((orientationInputs[i:i+seq_length:downsample, :], accelerationInputs[i:i+seq_length:downsample, :]), dim=1)\n",
206 | "inp = inp.sub(mean).div(std_dev).double()\n",
207 | "\n",
208 | "out = fullBodyOrientations[i:i+seq_length:downsample, :].double()\n",
209 | "out = out[-1,:].unsqueeze(0) # trained our model to predict only a single pose\n",
210 | "\n",
211 | "data = (inp.unsqueeze(0), out.unsqueeze(0))\n",
212 | " \n",
213 | "output = inference(data, models, device, use_attention=True, norm_quaternions=True)\n",
214 | "\n",
215 | "full_body = fullBodyOrientations[i:i+seq_length:downsample,:].clone()\n",
216 | "full_body = full_body[-1, :].unsqueeze(0)\n",
217 | "\n",
218 | "seq2seq_body = output.clone().squeeze(0)\n",
219 | "\n",
220 | "root_motion = root[i:i+seq_length:downsample, :].clone()\n",
221 | "root_motion = root_motion[-1, :].unsqueeze(0)\n",
222 | "\n",
223 | "root_motion = root_motion.unsqueeze(1).repeat(1, full_body.shape[1]//4, 1)\n",
224 | "\n",
225 | "full_body = quat_mul(root_motion, full_body.view(-1, full_body.shape[1]//4, 4)).view(full_body.shape)\n",
226 | "seq2seq_body = quat_mul(root_motion, seq2seq_body.view(-1, seq2seq_body.shape[1]//4, 4)).view(seq2seq_body.shape)"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {},
232 | "source": [
233 | "Now that we have the ground truth posture and our seq2seq output, we can compare the motion using the Skeleton. \n",
234 | "\n",
235 | "The output won't look that great because we only used a single participant and a single epoch to train the model."
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 6,
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "data": {
245 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAD7CAYAAAC7WecDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAcqklEQVR4nO3deXCc9Z3n8c/Tp87WYUm25UM22NhcNoc9mEA4DANsiJ3YGGZgHZINU7XLMpNkEmonm0zYmmy2andnSFKQDJlkLmoIkwwYL5gzmAUvp4MJwYAxtoMPWZasw7rVre5+nmf/aHVLji+1ju5f9/N+Vbkw7qcff4V49Onfbbmu6woAAOSVL98FAAAAAhkAACMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAJ4KZNu21d7eLtu2810KgJMYHh5WR0eHHMfJdylAznkqkHt6ejRz5ky1tbXxwAMG6uzsVENDgzo7O+W6br7LAXLKU4GclkgkNDg4SEsZMFQ0GtXw8DChDE/xZCAHAgFCGTCYZVmEMjzHk4EspUI5mUwSyoCB/H4/oQzP8WwgW5ZFKAMGGxvKsViMUEbR82wgS4QyYLp0KMdiMUIZRS+Q7wLybWwoJ5NJWZaV75KQQ5Zl8T03nN/vl23bSiQSCofDhLKHeO359HwgS6OhLKXWQcI7LMtSOBz21ENfiPx+fyaYk8lkvstBjnjt+SSQR6S/4V75xkNyXTfzi++7+XhGvcWLzyeBPIZXvukYRfdnYfFaF6bXee359PSkLgAATEEgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkADGX39GjozTfzXQZyJJDvAgAAJ0ocblHz2rWy/H41vfKyfOXl+S4J04wWMgAYKDCnUYE5jXLjcQ29+lq+y0EOEMgAYCDLslR+3fWSpMGXtua5GuQCgQwAhiq//jpJ0tCrr8mJxfJcDaYbgQwAhgqff74Cs2fLjUYVfeONfJeDaUYgA4ChLMtS2XWrJUmDW1/KczWYbgQyABisYmQceej/bZObSOS5GkwnAhmnFD9wQP3PPJPvMgBPC1+0XP66Ojn9A4pu357vcjCNCGScVLKjQ213362Ob31b/Vu25LscwLMsn0/lq6+VNI5ua8eW7HgOqsJ0IJBxAmdgQG333KPkkVYF5s9T2RVX5LskwNPKrx9Z/vTyy3KTyeNf7Dsi33s/V2DzXQr9cIlCD1woHfskD1VistipC8dxEwkd/fo3FP94j/y1tZr90EPy19bmuyzA00ouvVS+6mo5PT2KvfO2yhri8n3yf+Xb/4p8HR+dcH3gpfuUvPWRPFSKySCQkeE6jjruu0/R7dtllZVp1o9/pODcuSe7ULLoXAFyxQoE5B8JZP8v7lSooTfzmitLbuMlcs5aLbd+qQJP/kf59/1Kzr6tchZdn8eqkS0CGRlD27Zp4NnnJL9fs+++QWVdz8h64WFpsEPWUKc02Jn6Z6xXzsVfUPKmv8l3yUDRStgJvdLyip7Y94T2HHxHPzuQGhsureqXW94g56zVqV8LrpLKRnux7JYdCvz6Iflf+o6chVdJ/lC+vgRkiUCGbMfWW21vadarL8qSVLOwT1Utfyu1nPo9vt8+Il3xDalyds7qBLzgcP9hbf5ks7Z8skXHho9Jki5ucSRJocqkfPPPV/xLL0i+k//4tq+8V/4PH5fv2O/k3/Ez2Zfdk7PaMTkEsoe5rqs3Wt/Qg+89qH29+/TXb7hqklTWMCxnzgq5DefJLauTW1YvldfJLauTyusUeO4b8h3+tfzv/1L2p76W7y8DKHhJJ6ltLdv0xL4ntP3o6NKm+tJ63dR0kxLb/nnkD5JK/rv7TxnGkqRwpZLXfEfBZ74i/2t/I/v8DVLFzOn9AjAlCGSP2nVslx747QPa0b5DklQy7Gpemy1JCp89S4nbN0nB0pO+116+MRXI7z0q+/KvMJ4MTNKPd/5Y/7L7XyRJlixdPvtyrT97va5svFIBSVub/0mS9O7i+Zo7+6Iz3s+58DY5v/ln+Vp/o8Ar31Pysw9OZ/mYIvwk9ZjDA4f1rTe+pTt/dad2tO9Q0BfUxiUbtabrLPlcaSDiyLr1+6cMY0lylq6RG6qQ1XNA1iH21wUmazAxKElaNmOZnvzsk3rg6gd0zdxrFPAFJF9AH563TLuXBDSw6kvju6HlU/IP/4ckyf/+L2W17JimyjGVCGSPSDpJff/d72vDsxv0q0O/kiVLn1nwGT1x8xP62sVf0yVHA3Ik7Z0bltt05elvFiqXc946SZL/vZ9Pf/FAkVs5c6UkKWpH1VjReMLrO879iv783P+pqnOuG/c93TmXyr7wjyVJgRe/nVodAaMRyB7xZuubevTjR5V0klo1a5UeufERfXfVdzW7PDUp69fX3KdvrPsjtd/8vXHdz16+UZLk+/gZKdZ7hqsBnM6lDZdKkvb27FV3rPuE1xO2K0kK+Kys7pu85ttyQxXytb4r3/u/nHyhmFYEskcsrV0q38hY772X3KslNUuOe72tN67d7kqVLzp3XPdzZ18kp/5cWcmYfB9umvJ6AS+pLanVoqpFkqR32t854XXbGQlkf3aBrIqZsq+8N/Xel/+7FOubXKGYVgSyR9SX1uvyWZdLkp7e//QJr7f0pg4/b6wuGd8NLUvO8n8viW5rYCqku63fbn/7hNfSgezPsoUsSfaKP5FTu0jWUKf8r7N3gMkIZA9Zc9YaSdIzB56R7djHvdY6Eshzqk49mev32edvkOsPyXf0fVlt709doYAHrZi5QpL09tETAzkxEsjBCQSy/CElr08NRfl3/L2szj0TLxLTikD2kKsar1JVqEod0Q691fZW5s+jcVvHBlPnrI67hSxJZbVyzvmMJFrJwGRdUn+JfJZPh/oPqX2o/bjXJtNCliT37NWyF98ky0kqsPUvJdeddL2YegSyh4T8Id3UdJMkacv+0SMV2/uHJUmlQZ8iJdktTbeX3yFJ8u3aJCWiU1Qp4D2VoUotrVkqSZn9AdKG4qkTnkpD/gnfP3ndX0lS6kCKfS9M+D6YPgSyx6w9a60kaVvLNvUM90iSOgdTe+TWVYRlWdl9AncXXCW3ap6sWG9qxjW
ACcuMI/9et3V/LBXIleFJ7OVU3pD5bfDxOzk32UAEsscsqVmic6rPUcJJ6PmDz0uSOgdSLeS6iglsQm/5Mmsd/TsfnbI6AS9a2TAayO5It7LjuBqMp+Z8REqCE773CWPHA+0nvxB5QyB7ULqVnO627hxIt5AndiqMvex2ubLkO/ia1L1/aooEPGh5/XIFfAG1DbWpZTB1usvAcDIz5FuZ5ZDSWG7jxRr++ieK//Fjim98Sqo6ydGqyCsC2YNuarpJAV9AH3d/rD3dezKBPKN8gse0Vc3N7Gft/y2HogMTVRoo1YUzLpQk7TiaGkfuG+muDgV8CgUm+SM7XCF34dVy562a3H0wLQhkD6oOV+uqxqskpVrJR3pG1iBXZTHDeqxojyw31aUWeOtByUlOSZ2AF61oOH75U8fIkFL9BHuwUDgIZI9Kd1s/d/A5HerulyTNqx3/GuTjlFYrsebHmX/lwAlg4tITu3a075DrujralwrkmZFwPstCDnD8oketmrVKdSV16ox1aij6jqQlmlszwUCW5Fxwq5Jde+Xb+4Lc6gVTVifgNRfMuEBhf1hdsS7t79uvtr7Uj+mZlRPswULBoIXsUQFfQDcvvFmSNBR+U5I0t3rigSxJ9tXfUuJPtknV8yddH+BVIX9IF9Wlzjx+++jbmRbyLFrIRY8WsoetWbhGD3/0sPwVe1ReNqiasokvqQAwdVbOXKntR7frF3t+IXtwh0L1Uqua9PT+PaoKVSkSjqgqVKWqUJUqQ5Wpc5NR8PguetiCyAItrDhP+wd2qaJupyzrcye9znZsxexY6lcypmgymvl9zI6pIlihi+ovynH1QPG6fPbl+tHOH6l5oFlSs8J10us90uvbT359RbAiFdDhKkVCEUVCEd3YdKOunnN1TuvG5BDIHre86jrtH9ilaNlW3bX1wGjQjoRtNBlV3Dn9jj4rGlboJ6t/kqOKgeK3pGaJHrr2Ie3r2aefvrFLXdEerVoUVjg0rN54r3qHe9Ub79VAYkCSNJAY0EBiILN2WZLOrR3fUaowB4HscU3hT8m1fyrHP6j3Ot874/Ul/hKVBkpH/xkoUVOkKQeVAt6ycuZKrZy5Uj995jUN98b0p59bqWVzqo67JukkNZAYUM9wj/rifZmg7ov36eL6i/NUOSaKQPa4WDyooQP/WSvPGdSXVi1SSaBEpf5U0KaDtyRQohJ/icL+7Pe6BjA5PdHUSWzVpSfO8Qj4AqoOV6s6XJ3rsjANCGSP64sm5cRnaknFPK2et0SS5LouwQsYIJ50NDSyj3XVSQIZxYVA9pi+jna17t2tWH+fov198r/5ie7q61PklUv12LuDig0kFI8mdf5VjVq1bmG+ywU8xXFsHXzvXfV3dSg20K+O1i596Uiz/PLp5b+t1/BgUsNDCQVCft38Zxeoqn5ySxVhFgLZQ+xkUv/23/6LYgP9mT8LjPxSzxz1xkY3mz/4/jECGcixD19+Udse/tlxf1YpSQqp89BA5s/iUVt73mrXyjXM3ygmBLKHxAb6U2FsWVr8B59SSWVERzoc9Q8HtXDJeTp/2VIlYkm9+Pe7FY+xHzWQa46d6p4uq67RgosuleMv1YGDcVllFbruhiUqqUgF81ub9+vAzk6t+Ox8hpeKCIHsIcODqU/Y4bJy3XjP1096zWBPalegRMxmLBnIscoZ9ZKkipparf7y3Se9praxTL9+6oB622Pqbh1SbWN5LkvENGLrTA9Jd1WXVFSc8ppgiV+S5Niu7KSbk7oApETqGyRJfZ0dp7wmVBLQnKWpWdUHdnblpC7kBoHsIbGBVAu5pKLylNcEQ/7M7xN0WwM5lW4hx/r7FI9FT3ndwuUzJEkH3iOQiwmB7CGZFnL5qVvIls9SMJwK5XjUzkldAFLC5eUKl6W6oPtP00qef0GtLJ907MiQejtOHdwoLASyh8QGz9xClqRQaSqQE8MEMpBrlXWpbuvTBXJJeVCzF6V27aLbungQyB6SmdR1mjFkaXQcmZnWQO5F6lPd1qcLZIlu62JEIHvI6KSuM7SQS1KT7xMxWshArqXHkfs62097XdOyGZIldRwc0ED3cC5KwzQjkD0kM6nrNGPI0pgWMmPIQM6NdlmfPpDLIiHNXBiRRLd1sSCQPSTdQg6Xn6mFnO6yJpCBXEt3WZ9u6VPagmW1kgjkYkEge8hwZlLX+FrILHsCci/TQu44fQtZkhaMjCMf/V2fov2nP7cc5iOQPWQ865Cl0TFkWshA7kVGAjna36fE8OnHhitrS1Q3r0Kum9p/HoWNQPaQ8ezUJY2dZU0gA7kWLi9XqLRM0plnWkujrWRmWxc+AtkjkvG4kvHUp+2ScY4h02UN5EdlXXrp05m7rdPLn1r29Gp4iGe2kBHIHjE8NChJsnw+hcrKTnstXdZAfqW7rc+09EmSqhpKVTOrTK7j6tAHdFsXMgLZI0ZnWFec8QSn0UldBDKQD5Xj3BwkLdNtzWzrgkYge8R49rFOY9kTkF+RcWyfOVY6kA/v7mHL2wJGIHvEeGdYSyx7AvKtMosuayl1RnKkrkR2wlHzru7pLA3TiED2iOHB0S7rMwmVMoYM5NPopK7xtZAty2K2dREgkD0imxZyaMwYsuu601oXgBNFRgJ5qLcnszriTBYsSwVy865jSiacaasN04dA9ojxrkGWRrusHduVzYMN5Fy4vELBklJJ428l18+vUHl1SIlhRy0f90xneZgmBLJHjPcsZEkKhvzSyERsuq2B3LMsK6s9rSXJ8lmZVjLd1oWJQPaIzD7W4xhDtnyWgmGWPgH5VJnlTGtpdLb1oQ+OybHp3So0BLJHZNYhj6PLWhozjswSCiAvxnsu8lgzz4qopCKo4aGkjuztm67SME0IZI/IZlKXNPZMZJY+AflQUZtq7Q4cG3/3s89nqenC1JGMhz5k165CQyB7RDYbg0hsnwnk21BvamJWWaQqq/fNmFM+8n6OYyw0BLIHuK478RYygQzkRXdriySpenZjVu+jd6twEcgeEO3vk52IS5al8uqacb0nxH7WQF71jARyzew5Wb0vvbEPz27hIZA9oK/jqCSpoqZW/mBwXO/JzLJmUheQc8n4cGa5U03j3KzeO7oXPS3kQkMge0BfR2qWZqS+YdzvYT9rIH962lol11W4vEKllZGs3psJ5CgfpgsNgewB6UCurJs57vfQQgbyJzN+PKvxjMel/j72oi9cBLIH9HemW8j1434PRzAC+dPTekSSVNOY3fixNNq7ZSccNgcpMASyB0ysy5qJIUC+dE9wQpc0umRRotu60BDIHpCe1BWpn0CXNYEM5NxkAtnntxQIpX60M7GrsBDIRc5xbPV3dkqSInXjbyGzdSaQH67rjnZZZ7kGOY2JXYWJQC5yg93dcuykfP6Aymtrx/0+NgYB8mOw+5gSwzH5/H5FGmZN6B5BJnYVJAK5yKW7qytn1Mnn84/7fUFayEBedB85LEmKNMyUPxA4w9UnF2LZYkEikIvcRCZ0SWPHkHmggVzqbhvprp6V/fhxWmbpE13WBYVALnKZNchZBvLoGLIj13
GnvC4AJ9d9ZGRC1wSWPKWxW1dhIpCL3GgLefwzrKXRFrIkJeJ8ygZypSdzqMTEA5k5IIWJQC5y/Z3pJU/ZtZD9QZ8sX2qHIJY+AbkzmSVPaZnjU+myLigEcpHLtJCzWPIkSZZlsVsXkGPxWFQDx7okTXzJkySFSpkDUogI5CJmJxIa6D4mKfsWsjT2gAkCGciFnrZWSVJpZWTcZ5efDC3kwkQgF7H+rg7JdRUIhVUaqcr6/RwwAeTWVEzokkZbyEzqKiwEchEbndBVn/WJMRITQ4Bc62lLn/I0yUAuYWOQQkQgF7GJHLs4VogWMpBTU9VCzgw30WVdUAjkIjZ6qET248fS2DFkur2AXJiKGdYSXdaFikAuYhPdpSuNE5+A3HEdJzOpa9KBzKSugkQgF7G+zoltCpLGsicgd/q7OmUn4vIFAqqsr5/UvTLLnoZtdtorIARyEZt0C5kDJoCcSXdXV8+cndVBMCeT3sta4vktJARykYrHoor190nKflOQtOBItxdd1sD0m6oJXZLkD/jkD6RWVtBtXTgI5CLV39khSQqXlStcXj6he4RoIQM5k56EOZkNQcZKf6DuPjo0JffD9COQi9RkZ1hLo5O6GEMGpt+sxUskSR+9+rI6Dx2Y9P3KIkFJ0gs/2aWnfrBTH791lA/XhiOQi1R6/LhiRt2E75EeQ472x+XYzpTUBeDkFl92hRZcvEJOMqkX/+4B2YnEpO531R2L1bSsVpbPUvuBfr36r/v06Hfe1mu/3KeOQ/1yXSZ7mYZALlLl1TWSpJbdHyra1zuhe5RFQpKknraonvhfv1Xzru4pqw/A8SzL0uov362Syoi6mg9q+xO/mNT96uZV6A/vOle3/9UKrVzTpEhdiRLDtna/cVRP3r9T/+ev39OuV1s1PMRaZVMQyEXqrBWXqb5poeJDQ3pr079O6B4z5pbryj86W+HygHqORvXC3+3S8w99qO5WxqSA6VBWVa1r/8N/kiT95tkndeTjjyZ/z0hIy6+fq1u/fYk+86fn6+xL6+QPWOpqGdQbj3+iR+97W9se2aOj+/sm/XdhcgjkIuXz+fXpjXdJkj58Zas6DnyS9T0sy9LST83SbX95qS68tlE+v6XDu3v0xP9+V68/9jvFBibXpQbgRGevuExLr7xGcl1t/dmDikejU3Jfy2epcXG1rr1ziW7/7kqtWr9QNbPLZCcc7X27Q7tfb5uSvwcTRyAXscYl52rxqisk19WrP//HCY8ZhcsCuuzzC3XLf71YTctq5TrSR6+16d++947ef7lFdpLxZWAqfXrjl1U5o04DXV1q3bt7yu9fUh7UBVc3av1fXKS1f75M56xq0NIrZk3534PsWK6HRva7urpUV1en5uZmVVWdeBxhOBwuuokO/V2d+vlf/JmS8bhuvOfrWnzZFZO+55G9vdq+eb+6WgYlSZG6Ev3B5xao6cLaCZ0qlS+u68p1XYXDYfl8fDY1QUtLi+bOnavDhw8rEokc91ogEFAwGJRte2OmcOvejxUIhVTftDDfpeSFF59Pb3yVHlY5o06XfHadJKn9k31Tcs/GxVX63L3L9enbF6k0ElRfZ0xb/2G3tv7D7qL7QAPky+zFSzwbxl4VOPMl3uA4TuYTWbG56Ka1mnf+cs1adM6UfX2WJZ1zWYMWLJ+hnS8d1gcvH1HDgtSGBoXy3/Cxxx7TmjVrFA6H810KzsBxnKJ+RnGigwcPqrm5WatXr853KTlDl7Uk27Zl27YCgUBBdbmaZKB7WKUVQfmDhdHpEo1GdcMNN2j+/Pl67LHHVFJSku+SoJN3WTuOo2QyKcuyFAwGCWSPePzxx/XVr35Vmzdv1vXXX5/vcnLC8y3kdBj7fD75/X4CeYKq6sryXUJWKioqtGXLFq1du1YbNmzQpk2baCkbaGwY+/1+z4wlQrrtttsUi8W0bt06bdq0STfccEO+S5p2ng7ksWEcDAZ52D2mrq5OTz/9tNauXav169dr06ZNtJQN8vthHAwG810ScsiyLH3xi1+Uz+fLPJ833nhjvsuaVp5NIMIYklRbW6unnnpKra2tWr9+vWKxWL5LgghjpFiWpS984Qu6//77dcstt+j555/Pd0nTypMp5DgOYYyM2tpabdmyRUePHtW6desIZQPYtk0YQ1IqlDdu3Kgf/OAH2rBhg5577rl8lzRtPJlEjuMQxjhOTU2NtmzZoo6ODn3+859XdIp2R0J2xk7YIoyRZlmW7rjjDv3whz/UrbfeqmeffTbfJU0LT82y7ujoUENDg3bu3Kmampp8lwMD9fX1aePGjaqpqdHTTz+t0tLSfJfkKYcOHVJTU5M++OCDk27eA2zevFnf/OY39fDDD+uWW24pqom4ngrk5uZmzZ8/P99loEB0dHSorm7ix1ciewcOHNDChWyGgfHp7e09YUe3QuapQHYcR0eOHFFlZWVRfarC9OD/k9zjGUU2iu3/E08FMgAApmJGEwAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAYgkAEAMACBDACAAQhkAAAMQCADAGAAAhkAAAMQyAAAGIBABgDAAAQyAAAGIJABADAAgQwAgAEIZAAADEAgAwBgAAIZAAADEMgAABiAQAYAwAAEMgAABiCQAQAwAIEMAIABCGQAAAxAIAMAYAACGQAAAxDIAAAYgEAGAMAABDIAAAb4/8ajl/QUiWwtAAAAAElFTkSuQmCC\n",
246 | "text/plain": [
247 | ""
248 | ]
249 | },
250 | "execution_count": 6,
251 | "metadata": {},
252 | "output_type": "execute_result"
253 | }
254 | ],
255 | "source": [
256 | "skeleton = Skeleton()\n",
257 | "bodies = torch.cat((full_body, seq2seq_body), dim=0).float()\n",
258 | "skeleton.compare_motion(bodies, azim=0, elev=0)"
259 | ]
260 | }
261 | ],
262 | "metadata": {
263 | "kernelspec": {
264 | "display_name": "Python 3",
265 | "language": "python",
266 | "name": "python3"
267 | },
268 | "language_info": {
269 | "codemirror_mode": {
270 | "name": "ipython",
271 | "version": 3
272 | },
273 | "file_extension": ".py",
274 | "mimetype": "text/x-python",
275 | "name": "python",
276 | "nbconvert_exporter": "python",
277 | "pygments_lexer": "ipython3",
278 | "version": "3.7.4"
279 | }
280 | },
281 | "nbformat": 4,
282 | "nbformat_minor": 2
283 | }
284 |
--------------------------------------------------------------------------------
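The last cells of the inference tutorial rotate the pelvis-normalized orientations back into the global frame with `common.quaternion.quat_mul`. The sketch below shows the underlying Hamilton product in NumPy for readers following that step outside the repository; the (w, x, y, z) component ordering is an assumption here, and this is not a drop-in replacement for the repository's own `quat_mul`.

```python
import numpy as np

def quat_mul_sketch(q, r):
    """Hamilton product of quaternion arrays of shape (..., 4).

    Assumes (w, x, y, z) ordering; an illustrative stand-in for
    common.quaternion.quat_mul, not a drop-in replacement.
    """
    w1, x1, y1, z1 = np.moveaxis(q, -1, 0)
    w2, x2, y2, z2 = np.moveaxis(r, -1, 0)
    return np.stack([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ], axis=-1)

# Example: rotating a pelvis-normalized segment orientation back into the
# global frame using the pelvis (root) orientation, as in the notebook.
root = np.array([1.0, 0.0, 0.0, 0.0])                 # identity rotation
norm_orientation = np.array([0.707, 0.0, 0.707, 0.0])
global_orientation = quat_mul_sketch(root, norm_orientation)
```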
/src/build-dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from common.data_utils import read_h5
9 | from common.logging import logger
10 | import h5py
11 | import argparse
12 | import sys
13 | import glob
14 | import numpy as np
15 |
16 |
17 | def parse_args():
18 | """Parse arguments for module.
19 |
20 | Returns:
21 | argparse.Namespace: contains accessible arguments passed in to module
22 | """
23 | parser = argparse.ArgumentParser()
24 |
25 | parser.add_argument("--training",
26 | help=("participants for training, space separated; "
27 | "e.g., W1 W2"))
28 | parser.add_argument("--validation",
29 | help=("participants for validation, space separated; "
30 | "e.g., P1 P2"))
31 | parser.add_argument("--testing",
32 | help=("participants for testing, space separated; "
33 | "e.g., P4 P5"))
34 | parser.add_argument("-f",
35 | "--data-path",
36 | help="path to h5 files for reading data")
37 | parser.add_argument("-o", "--output-path",
38 | help=("path to directory to save h5 files for "
39 | "training, validation, and testing"))
40 | parser.add_argument("-x", "--task-input",
41 | help=("input type; "
42 | "e.g., orientation, relativePosition, "
43 | "or jointAngle"))
44 | parser.add_argument("--input-label-request",
45 | help=("input label requests, space separated; "
46 | "e.g., all or Pelvis RightForearm"))
47 | parser.add_argument("-y", "--task-output",
48 | help="output type; e.g., orientation or jointAngle")
49 | parser.add_argument("--output-label-request",
50 | help=("output label requests, space separated; "
51 | "e.g., all or jRightElbow"))
52 | parser.add_argument("--aux-task-output",
53 | help=("auxiliary task output in addition "
54 | "to regular task output"))
55 | parser.add_argument("--aux-output-label-request",
56 | help="aux output label requests, space separated")
57 |
58 | args = parser.parse_args()
59 |
60 | if None in [args.training, args.validation, args.testing]:
61 | logger.info(("Participant numbers for training, validation, "
62 | "or testing dataset were not provided."))
63 | parser.print_help()
64 | sys.exit()
65 |
66 | if None in [args.data_path, args.output_path]:
67 | logger.error("Data path or output path were not provided.")
68 | parser.print_help()
69 | sys.exit()
70 |
71 | if None in [args.task_input, args.input_label_request, args.task_output]:
72 | logger.error(("Task input and label requests "
73 | "or task output were not given."))
74 | parser.print_help()
75 | sys.exit()
76 |
77 | if args.output_label_request is None:
78 | if args.task_input == args.task_output:
79 | logger.info("Will create h5 files with input data only.")
80 | else:
81 | logger.error("Label output requests were not given for the task.")
82 | parser.print_help()
83 | sys.exit()
84 |
85 | if args.aux_task_output == args.task_output:
86 | logger.error("Auxiliary task should not be the same as the main task.")
87 | parser.print_help()
88 | sys.exit()
89 |
90 | if (args.aux_task_output is not None and
91 | args.aux_output_label_request is None):
92 | logger.error("Need auxiliary output labels if using aux output task")
93 | parser.print_help()
94 | sys.exit()
95 |
96 | if args.task_input == args.task_output:
97 | if args.output_label_request is None:
98 | logger.info(("Will create h5 files with only input "
99 | "data for self-supervision tasks..."))
100 | else:
101 | logger.info("Will create h5 files with input and output data.")
102 |
103 | return args
104 |
105 |
106 | def setup_filepaths(data_path, participant_numbers):
107 | """Set up filepaths for reading in participant .h5 files.
108 |
109 | Args:
110 | data_path (str): path to directory containing .h5 files
111 | participant_numbers (list): participant numbers for filepaths
112 |
113 | Returns:
114 | list: filepaths to all of the .h5 files
115 | """
116 | all_filepaths = []
117 | for participant_number in participant_numbers:
118 | filepaths = glob.glob(data_path + "/" + participant_number + "_*.h5")
119 | all_filepaths += filepaths
120 | return all_filepaths
121 |
122 |
123 | def map_requests(tasks, labels):
124 | """Generate a dict of tasks mapped to labels.
125 |
126 | Args:
127 | tasks (list): list of tasks/groups that will be mapped to labels
128 | labels (list): list of labels that will be the value for each task
129 |
130 | Returns:
131 | dict: dictionary mapping each task to the list of labels
132 | """
133 | requests = dict(map(lambda e: (e, labels), tasks))
134 | return requests
135 |
136 |
137 | def write_dataset(filepath_groups, variable, experiment_setup, requests):
138 | """Write data to training, validation, testing .h5 files.
139 |
140 | Args:
141 | filepath_groups (list): list of tuples associating filepaths with a data group
142 | variable (str): the machine learning variable (X or Y) to be written
143 | experiment_setup (dict): map to reference task for variable
144 | requests (dict): requests to read from files to store with variable
145 | """
146 | for filepaths, group in filepath_groups:
147 | logger.info(f"Writing {variable} to the {group} set...")
148 | h5_filename = args.output_path + "/" + group + ".h5"
149 | with h5py.File(h5_filename, "a") as h5_file:
150 | dataset = read_h5(filepaths, requests)
151 | for filename in dataset.keys():
152 | temp_dataset = None
153 | for j, data_group in enumerate(experiment_setup[variable]):
154 | if temp_dataset is None:
155 | temp_dataset = dataset[filename][data_group]
156 | else:
157 | temp_dataset = np.append(temp_dataset,
158 | dataset[filename][data_group],
159 | axis=1)
160 | try:
161 | h5_file.create_dataset(filename + "/" + variable,
162 | data=temp_dataset)
163 | except KeyError:
164 | logger.info(f"{filename} does not contain {data_group}")
165 |
166 |
167 | if __name__ == "__main__":
168 | args = parse_args()
169 |
170 | train_filepaths = setup_filepaths(args.data_path, args.training.split(" "))
171 | val_filepaths = setup_filepaths(args.data_path, args.validation.split(" "))
172 | test_filepaths = setup_filepaths(args.data_path, args.testing.split(" "))
173 |
174 | filepath_groups = [(train_filepaths, "training"),
175 | (val_filepaths, "validation"),
176 | (test_filepaths, "testing")]
177 |
178 | task_input = args.task_input.split(" ")
179 | input_label_request = args.input_label_request.split(" ")
180 |
181 | task_output = args.task_output.split(" ")
182 | if args.output_label_request is not None:
183 | output_label_request = args.output_label_request.split(" ")
184 |
185 | if (args.task_input == args.task_output
186 | and args.output_label_request is None):
187 | experiment_setup = {"X": task_input}
188 | requests = map_requests(task_input, input_label_request)
189 |
190 | write_dataset(filepath_groups, "X", experiment_setup, requests)
191 | else:
192 | experiment_setup = {"X": task_input, "Y": task_output}
193 |
194 | input_requests = map_requests(task_input, input_label_request)
195 | output_requests = map_requests(task_output, output_label_request)
196 |
197 | if args.aux_task_output is not None:
198 | aux_task_output = args.aux_task_output.split(" ")
199 | aux_output_label_request = args.aux_output_label_request.split(" ")
200 | experiment_setup["Y"] += aux_task_output
201 | aux_output_requests = map_requests(aux_task_output,
202 | aux_output_label_request)
203 | output_requests.update(aux_output_requests)
204 |
205 | write_dataset(filepath_groups, "X", experiment_setup, input_requests)
206 | write_dataset(filepath_groups, "Y", experiment_setup, output_requests)
207 |
--------------------------------------------------------------------------------
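For reference, the request dictionaries that `build-dataset.py` passes to `read_h5` are just the task groups fanned out over a shared label list (see `map_requests` above). A minimal sketch of how the tutorial's arguments turn into those dicts, with the labels taken from `build_dataset.sh` below:

```python
# Sketch: how the CLI arguments become request dicts for read_h5.
# Labels are copied from build_dataset.sh; everything else is illustrative.
def map_requests(tasks, labels):
    return {task: labels for task in tasks}

task_input = "normOrientation normAcceleration".split(" ")
input_labels = "T8 RightForeArm RightLowerLeg LeftForeArm LeftLowerLeg".split(" ")

input_requests = map_requests(task_input, input_labels)
# {'normOrientation': ['T8', ...], 'normAcceleration': ['T8', ...]}

output_requests = map_requests(["normOrientation"], ["all"])
# {'normOrientation': ['all']}
```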
/src/build_dataset.sh:
--------------------------------------------------------------------------------
1 | python build-dataset.py --data-path "/groups/MotionPred/h5-dataset" \
2 | --output-path "/home/jackg7/VT-Natural-Motion-Processing/data/set-2" \
3 | --training "P1" \
4 | --validation "P5" \
5 | --testing "P10" \
6 | --task-input "normOrientation normAcceleration" \
7 | --input-label-request "T8 RightForeArm RightLowerLeg LeftForeArm LeftLowerLeg" \
8 | --task-output "normOrientation" \
9 | --output-label-request "all"
10 |
--------------------------------------------------------------------------------
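After the script above runs, each of training.h5, validation.h5, and testing.h5 contains one group per source recording with an `X` dataset and, for conversion tasks, a `Y` dataset (see `write_dataset` in `build-dataset.py`). A minimal h5py sketch for inspecting the result, assuming the output path used above:

```python
import h5py

# Path assumed from build_dataset.sh; adjust to your own --output-path.
with h5py.File("/home/jackg7/VT-Natural-Motion-Processing/data/set-2/training.h5", "r") as f:
    for filename in f.keys():      # one group per source recording
        X = f[filename]["X"]       # normOrientation + normAcceleration inputs
        Y = f[filename]["Y"]       # full-body normOrientation targets
        print(filename, X.shape, Y.shape)
```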
/src/common/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 | import warnings
11 | import h5py
12 | import math
13 | import os
14 | from .logging import logger
15 | import sys
16 | from torch.utils.data import TensorDataset, DataLoader
17 |
18 |
19 | class XSensDataIndices:
20 | """XSensDataIndices helps with retrieving data.
21 |
22 | XSens has a data layout which requires interfacing data measurements
23 | (position, orientation, etc.) with places (RightLowerLeg, RightWrist)
24 | where those measurements are available.
25 | """
26 |
27 | def __init__(self):
28 | """Initialize object for use in reading data from .h5 files."""
29 | sensor_group = ["sensorFreeAcceleration",
30 | "sensorMagneticField",
31 | "sensorOrientation"]
32 |
33 | segment_group = ["position", "normPositions",
34 | "velocity", "acceleration",
35 | "normAcceleration", "sternumNormAcceleration",
36 | "angularVelocity", "angularAcceleration",
37 | "orientation", "normOrientation",
38 | "sternumNormOrientation"]
39 |
40 | joint_group = ["jointAngle", "jointAngleXZY"]
41 |
42 | joint_ergo_group = ["jointAngleErgo", "jointAngleErgoXZY"]
43 |
44 | groups = [sensor_group, segment_group,
45 | joint_group, joint_ergo_group]
46 | self._labels_to_items(groups)
47 |
48 | def __call__(self, requests):
49 | """Retrieve indices using a request.
50 |
51 | Requests are dicts that use groups like "position" as keys and
52 | labels like ["Pelvis"] as values. Requests can be passed to an instance of
53 | XSensDataIndices to retrieve the indices for the labels of each group.
54 |
55 | Args:
56 | requests (dict): maps groups to desired labels and will retrieve
57 | indices for those group labels.
58 |
59 | Returns:
60 | dict: a map of the groups to the indices for the labels.
61 | """
62 | label_indices = {}
63 | for label, items in requests.items():
64 | if label in self.label_items:
65 | label_indices[label] = self._request(label, items)
66 | return label_indices
67 |
68 | def _labels_to_items(self, groups):
69 | self.label_items = {}
70 | sensors = ["Pelvis", "T8", "Head", "RightShoulder", "RightUpperArm",
71 | "RightForeArm", "RightHand", "LeftShoulder", "LeftUpperArm",
72 | "LeftForeArm", "LeftHand", "RightUpperLeg", "RightLowerLeg",
73 | "RightFoot", "LeftUpperLeg", "LeftLowerLeg", "LeftFoot"]
74 |
75 | segments = ["Pelvis", "L5", "L3", "T12",
76 | "T8", "Neck", "Head",
77 | "RightShoulder", "RightUpperArm",
78 | "RightForeArm", "RightHand",
79 | "LeftShoulder", "LeftUpperArm",
80 | "LeftForeArm", "LeftHand",
81 | "RightUpperLeg", "RightLowerLeg",
82 | "RightFoot", "RightToe",
83 | "LeftUpperLeg", "LeftLowerLeg",
84 | "LeftFoot", "LeftToe"]
85 |
86 | joints = ["jL5S1", "jL4L3", "jL1T12",
87 | "jT9T8", "jT1C7", "jC1Head",
88 | "jRightT4Shoulder", "jRightShoulder",
89 | "jRightElbow", "jRightWrist",
90 | "jLeftT4Shoulder", "jLeftShoulder",
91 | "jLeftElbow", "jLeftWrist",
92 | "jRightHip", "jRightKnee",
93 | "jRightAnkle", "jRightBallFoot",
94 | "jLeftHip", "jLeftKnee",
95 | "jLeftAnkle", "jLeftBallFoot"]
96 |
97 | ergo_joints = ["T8_Head", "T8_LeftUpperArm", "T8_RightUpperArm",
98 | "Pelvis_T8", "Vertical_Pelvis", "Vertical_T8"]
99 |
100 | item_groups = [sensors, segments, joints, ergo_joints]
101 |
102 | for index, group in enumerate(groups):
103 | for label in group:
104 | self.label_items[label] = item_groups[index]
105 |
106 | def _request(self, req_label, req_items):
107 | valid_items = self.label_items[req_label]
108 |
109 | if "all" in req_items:
110 | req_items = valid_items
111 |
112 | num_valid_items = len(valid_items)
113 | orientation_groups = ["orientation", "normOrientation",
114 | "sternumNormOrientation"]
115 | dims = 4 if req_label in orientation_groups else 3
116 |
117 | indices = [list(range(i, i+dims))
118 | for i in range(0, dims*num_valid_items, dims)]
119 |
120 | index_map = dict(zip(valid_items, indices))
121 |
122 | return self._find_indices(index_map, req_items)
123 |
124 | def _find_indices(self, index_map, items):
125 | mapped_indices = []
126 |
127 | for item in items:
128 | if item in index_map:
129 | mapped_indices.append(index_map[item])
130 | else:
131 | warnings.warn("Requested item {} not in file.".format(item))
132 |
133 | return mapped_indices
134 |
135 |
136 | def discard_remainder(data, seq_length):
137 | """Discard data that does not fit inside sequence length.
138 |
139 | Args:
140 | data (np.ndarray): data to truncate
141 | seq_length (int): sequence length to find data that doesn't fit into
142 | sequences
143 |
144 | Returns:
145 | np.ndarray: truncated data
146 | """
147 | new_row_num = data.shape[0] - (data.shape[0] % seq_length)
148 | data = data[:new_row_num]
149 | return data
150 |
151 |
152 | def stride_downsample_sequences(data, seq_length, stride, downsample,
153 | offset=0, in_out_ratio=1):
154 | """Build sequences with an array of data tensor.
155 |
156 | Args:
157 | data (np.ndarray): data to turn into sequences
158 | seq_length (int): sequence length of the original sequence (e.g., 30
159 | frames will be downsampled to 5 frames if downsample is 6.)
160 | stride (int): step size over data when looping over frames
161 | downsample (int): amount to downsample data (e.g., 6 will take 240 Hz
162 | to 40 Hz.)
163 | offset (int, optional): offset for index when looping; useful for the
164 | prediction task when making output data. Defaults to 0.
165 | in_out_ratio (int, optional): ratio of input to output; useful for the
166 | conversion task when making output data. Defaults to 1.
167 |
168 | Returns:
169 | np.ndarray: data broken into sequences
170 | """
171 | samples = []
172 | for i in range(0, data.shape[0] - 2*seq_length, stride):
173 | i_shift = i+offset
174 | sample = data[i_shift:i_shift+seq_length:downsample, :]
175 |
176 | ratio_shift = sample.shape[0] - sample.shape[0]//in_out_ratio
177 | sample = sample[ratio_shift:, :]
178 | samples.append(sample)
179 | samples = np.concatenate(samples, axis=0)
180 | return samples
181 |
182 |
183 | def read_h5(filepaths, requests):
184 | """Read data from an h5 file and store in a dataset dict.
185 |
186 | Primarily used for building a dataset (see build-dataset.py)
187 |
188 | Args:
189 | filepaths (list): list of file paths to draw data from.
190 | requests (dict): dictionary of requests to make to files.
191 |
192 | Returns:
193 | dict: dictionary containing files mapped to labels mapped to data
194 | """
195 | xsensIndices = XSensDataIndices()
196 | indices = xsensIndices(requests)
197 |
198 | def flatten(l): return [item for sublist in l for item in sublist]
199 |
200 | h5_files = []
201 | for filepath in filepaths:
202 | try:
203 | h5_file = h5py.File(filepath, "r+")
204 | except OSError:
205 | logger.info(f"OSError: Unable to open file {filepath}")
206 | continue
207 | h5_files.append((h5_file, os.path.basename(filepath)))
208 |
209 | dataset = {}
210 | for h5_file, filename in h5_files:
211 | dataset[filename] = {}
212 | for label in indices:
213 | label_indices = flatten(indices[label])
214 | label_indices.sort()
215 |
216 | file_data = np.array(h5_file[label])
217 | file_data = file_data.reshape(file_data.shape[1],
218 | file_data.shape[0])
219 |
220 | data = np.array(file_data[:, label_indices])
221 |
222 | dataset[filename][label] = data
223 |
224 | h5_file.close()
225 |
226 | return dataset
227 |
228 |
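
A hedged example call; the file path below is only a placeholder for an h5 file produced by the MATLAB conversion scripts:

    # Hypothetical request dict; "all" expands to every item in a label group.
    requests = {"orientation": ["all"],
                "jointAngle": ["jRightKnee", "jLeftKnee"]}

    dataset = read_h5(["../h5-files/P1.h5"], requests)

    # dataset is keyed by file name, then by label, e.g.
    # dataset["P1.h5"]["orientation"] -> (frames, 23 * 4) array, assuming all
    # 23 segments are present in the file.
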
229 | def read_variables(h5_file_path, task, seq_length, stride, downsample,
230 | in_out_ratio=1):
231 | """Read data from dataset and store in X and y variables.
232 |
233 | Args:
234 | h5_file_path (str): h5 file containing dataset built previously
235 | task (str): either prediction or conversion; task that will be modeled
236 | by the machine learning model
237 | seq_length (int): original sequence length before downsampling data
238 | stride (int): step size over data when building sequences
239 | downsample (int): amount to downsample data by (e.g., 6 to reduce
240 | 240 Hz sampling rate to 40 Hz.)
241 | in_out_ratio (int, optional): input length compared to output length.
242 | Defaults to 1.
243 |
244 | Returns:
245 | tuple: returns a tuple of variables X and y for use in a machine
246 | learning task.
247 | """
248 | X, y = None, None
249 | h5_file = h5py.File(h5_file_path, "r")
250 | for filename in h5_file.keys():
251 | X_temp = h5_file[filename]["X"]
252 | X_temp = discard_remainder(X_temp, 2*seq_length)
253 |
254 | if task == "prediction":
255 | y_temp = stride_downsample_sequences(X_temp, seq_length, stride,
256 | downsample, offset=seq_length)
257 | elif task == "conversion":
258 | y_temp = h5_file[filename]["Y"]
259 | y_temp = discard_remainder(y_temp, 2*seq_length)
260 | y_temp = stride_downsample_sequences(y_temp, seq_length, stride,
261 | downsample,
262 | in_out_ratio=in_out_ratio)
263 | else:
264 | logger.error(("Task must be either prediction or conversion, "
265 | f"found {task}"))
266 | sys.exit()
267 |
268 | X_temp = stride_downsample_sequences(X_temp, seq_length,
269 | stride, downsample)
270 |
271 | assert not np.any(np.isnan(X_temp))
272 | assert not np.any(np.isnan(y_temp))
273 |
274 | if X is None and y is None:
275 | X = torch.tensor(X_temp)
276 | y = torch.tensor(y_temp)
277 | else:
278 | X = torch.cat((X, torch.tensor(X_temp)), dim=0)
279 | y = torch.cat((y, torch.tensor(y_temp)), dim=0)
280 | h5_file.close()
281 | return X, y
282 |
283 |
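
A hedged example for the prediction task, assuming a training split built by build-dataset.py with an "X" group per recording (the path is illustrative only):

    X, y = read_variables("../dataset/training.h5", "prediction",
                          seq_length=30, stride=30, downsample=6)

    # X and y are flat (num_windows * downsampled_len, features) tensors;
    # y is the same signal shifted forward by seq_length frames.
    print(X.shape, y.shape)
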
284 | def load_dataloader(args, set_type, normalize, norm_data=None, shuffle=True):
285 | """Create dataloaders for PyTorch machine learning tasks.
286 |
287 | Args:
288 | args (argparse.Namespace): contains accessible arguments passed in
289 | to module
290 | set_type (str): set to read from when gathering data (either training,
291 | validation or testing sets)
292 | normalize (bool): whether to normalize the data before adding to
293 | dataloader
294 | norm_data (tuple, optional): if passed will contain mean and std_dev
295 | data to normalize input data with. Defaults to None.
296 | shuffle (bool, optional): whether to shuffle the data stored in the
297 | dataloader. Defaults to True.
298 |
299 | Returns:
300 | tuple: returns a tuple containing the DataLoader and the normalization
301 | data
302 | """
303 | file_path = args.data_path + "/" + set_type + ".h5"
304 | seq_length = int(args.seq_length)
305 | downsample = int(args.downsample)
306 | batch_size = int(args.batch_size)
307 | in_out_ratio = int(args.in_out_ratio)
308 | stride = int(args.stride) if set_type == "training" else seq_length//2
309 |
310 | logger.info((f"Retrieving {set_type} data "
311 | f"for sequences {int(seq_length/240*1000)} ms long and "
312 | f"downsampling to {240/downsample} Hz..."))
313 |
314 | X, y = read_variables(file_path, args.task, seq_length, stride, downsample,
315 | in_out_ratio=in_out_ratio)
316 |
317 | if normalize:
318 | mean, std_dev = None, None
319 | if norm_data is None:
320 | mean, std_dev = X.mean(dim=0), X.std(dim=0)
321 | norm_data = (mean, std_dev)
322 | with h5py.File(args.data_path + "/normalization.h5", "w") as f:
323 | f["mean"], f["std_dev"] = mean, std_dev
324 | else:
325 | mean, std_dev = norm_data
326 | X = X.sub(mean).div(std_dev + 1e-8)
327 |
328 | logger.info(f"Data for {set_type} have shapes "
329 | f"(X, y): {X.shape}, {y.shape}")
330 |
331 | X = X.view(-1, math.ceil(seq_length/downsample), X.shape[1])
332 | y = y.view(-1, math.ceil(seq_length/(downsample*in_out_ratio)), y.shape[1])
333 |
334 | logger.info(f"Reshaped {set_type} shapes (X, y): {X.shape}, {y.shape}")
335 |
336 | dataset = TensorDataset(X, y)
337 |
338 |     shuffle = shuffle and set_type == "training"  # never shuffle eval data
339 | dataloader = DataLoader(dataset, batch_size=batch_size,
340 | shuffle=shuffle, drop_last=True)
341 |
342 | logger.info(f"Number of {set_type} samples: {len(dataset)}")
343 |
344 | return dataloader, norm_data
345 |
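
A hedged sketch of how the training scripts might call this; the Namespace below stands in for the argparse arguments they parse, and the field values are illustrative only:

    from argparse import Namespace

    args = Namespace(data_path="../dataset", task="prediction", seq_length=30,
                     stride=30, downsample=6, in_out_ratio=1, batch_size=32)

    train_loader, norm_data = load_dataloader(args, "training", normalize=True)
    # Reuse the training mean/std so validation data is normalized consistently.
    val_loader, _ = load_dataloader(args, "validation", normalize=True,
                                    norm_data=norm_data)
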
--------------------------------------------------------------------------------
/src/common/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import
9 |
10 | import logging
11 |
12 | logging.basicConfig(
13 | format="%(asctime)s %(levelname)-8s %(message)s",
14 | level=logging.INFO,
15 | datefmt="%Y-%m-%d %H:%M:%S")
16 | logger = logging.getLogger()
17 | logger.setLevel(logging.INFO)
18 |
--------------------------------------------------------------------------------
/src/common/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import torch.nn as nn
10 | import math
11 |
12 |
13 | class QuatDistance(nn.Module):
14 | """Loss function for calculating cosine similarity of quaternions."""
15 |
16 | def __init__(self):
17 | """Initialize QuatDistance loss."""
18 | super(QuatDistance, self).__init__()
19 |
20 | def forward(self, predictions, targets):
21 | """Forward pass through the QuatDistance loss.
22 |
23 | Args:
24 | predictions (torch.Tensor): the predictions from the model in
25 | quaternion form.
26 | targets (torch.Tensor): the targets in quaternion form.
27 |
28 | Returns:
29 | torch.Tensor: average angular difference in degrees between
30 | quaternions in predictions and targets.
31 | """
32 | predictions = predictions.contiguous().view(-1, 1, 4)
33 | targets = targets.contiguous().view(-1, 4, 1)
34 |
35 | inner_prod = torch.bmm(predictions, targets).view(-1)
36 |
37 | x = torch.clamp(torch.abs(inner_prod), min=0.0, max=1.0-1e-7)
38 |
39 | theta = torch.acos(x)
40 |
41 | return (360.0/math.pi)*torch.mean(theta)
42 |
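
A quick sanity check (not part of the module): the loss evaluates 2*acos(|<p, q>|) in degrees, so identical quaternions score near zero and orthogonal unit quaternions score 180:

    import torch

    criterion = QuatDistance()

    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # identity rotation
    r = torch.tensor([[0.0, 1.0, 0.0, 0.0]])  # 180-degree rotation about x

    print(criterion(q, q).item())  # ~0 (the clamp keeps it slightly above 0)
    print(criterion(q, r).item())  # 180.0
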
--------------------------------------------------------------------------------
/src/common/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 | import h5py
11 | from .logging import logger
12 | from .rotations import quat_to_rotMat, rotMat_to_quat
13 | from .quaternion import quat_fix
14 |
15 |
16 | def add_normalized_positions(filepaths, new_group_name):
17 | """Add position data normalized relative to the pelvis to h5 files.
18 |
19 | Args:
20 | filepaths (list): paths to files to add data to
21 | new_group_name (str): what the new group will be called in the h5
22 | file
23 | """
24 | for filepath in filepaths:
25 | try:
26 | h5_file = h5py.File(filepath, "r+")
27 | except OSError:
28 | logger.info(f"OSError: Unable to open file {filepath}")
29 | continue
30 |
31 | quat = np.array(h5_file["orientation"][:, :])
32 | quat = quat.reshape(quat.shape[1], quat.shape[0])
33 | quat = quat.reshape(quat.shape[0], -1, 4)
34 |
35 | pos = np.array(h5_file["position"][:, :])
36 | pos = pos.reshape(pos.shape[1], pos.shape[0])
37 | pos = pos.reshape(pos.shape[0], -1, 3)
38 |
39 | quat = quat_fix(quat)
40 |
41 | norm_pos = np.zeros(pos.shape)
42 |
43 | pelvis_rot = np.linalg.inv(
44 | quat_to_rotMat(torch.tensor(quat[:, 0, :]))
45 | )
46 | pelvis_pos = pos[:, 0, :]
47 | for i in range(0, quat.shape[1]):
48 | relative_pos = np.expand_dims(pos[:, i, :] - pelvis_pos, axis=2)
49 | norm_pos[:, i, :] = np.squeeze(np.matmul(pelvis_rot, relative_pos),
50 | axis=2)
51 |
52 | norm_pos = norm_pos.reshape(norm_pos.shape[0], -1)
53 | norm_pos = norm_pos.reshape(norm_pos.shape[1], norm_pos.shape[0])
54 |
55 | try:
56 | logger.info(f"Writing to file {filepath}")
57 | h5_file.create_dataset(new_group_name, data=norm_pos)
58 | except RuntimeError:
59 | logger.info(("RuntimeError: Unable to create link "
60 | f"(name already exists) in {filepath}"))
61 | h5_file.close()
62 |
63 |
64 | def add_normalized_accelerations(filepaths, group_name, new_group_name,
65 | root=0):
66 | """Add acceleration data normalized relative to a root to the h5 files.
67 |
68 | Args:
69 | filepaths (list): paths to files to add data to
70 | group_name (str): acceleration group to normalize
71 | (typically acceleration, but can also be sensorFreeAcceleration)
72 | new_group_name (str): new group name for normalized acceleration data
73 | root (int, optional): index of root (e.g., 0 is pelvis, 4 is sternum).
74 | Defaults to 0.
75 | """
76 | for filepath in filepaths:
77 | try:
78 | h5_file = h5py.File(filepath, "r+")
79 | except OSError:
80 | logger.info(f"OSError: Unable to open file {filepath}")
81 | continue
82 |
83 | quat = np.array(h5_file["orientation"][:, :])
84 | quat = quat.reshape(quat.shape[1], quat.shape[0])
85 | quat = quat.reshape(quat.shape[0], -1, 4)
86 |
87 | acc = np.array(h5_file[group_name][:, :])
88 | acc = acc.reshape(acc.shape[1], acc.shape[0])
89 | acc = acc.reshape(acc.shape[0], -1, 3)
90 |
91 | quat = quat_fix(quat)
92 |
93 | norm_acc = np.zeros(acc.shape)
94 |
95 | root_rot = np.linalg.inv(
96 | quat_to_rotMat(torch.tensor(quat[:, root, :]))
97 | )
98 | root_acc = acc[:, root, :]
99 | for i in range(0, acc.shape[1]):
100 | relative_acc = np.expand_dims(acc[:, i, :] - root_acc, axis=2)
101 | norm_acc[:, i, :] = np.squeeze(np.matmul(root_rot, relative_acc),
102 | axis=2)
103 |
104 | norm_acc = norm_acc.reshape(norm_acc.shape[0], -1)
105 | norm_acc = norm_acc.reshape(norm_acc.shape[1], norm_acc.shape[0])
106 |
107 | try:
108 | logger.info(f"Writing to file {filepath}")
109 | h5_file.create_dataset(new_group_name, data=norm_acc)
110 | except RuntimeError:
111 | logger.info(("RuntimeError: Unable to create link "
112 | f"(name already exists) in {filepath}"))
113 | h5_file.close()
114 |
115 |
116 | def add_normalized_quaternions(filepaths, group_name, new_group_name, root=0):
117 | """Add orientation data normalized relative to a root to the h5 files.
118 |
119 | Args:
120 | filepaths (list): paths to files to add data to
121 | group_name (str): orientation group to normalize
122 | (typically orientation, but can also be sensorOrientation)
123 | new_group_name (str): new group name for normalized orientation data
124 | root (int, optional): index of root (e.g., 0 is pelvis, 4 is sternum).
125 | Defaults to 0.
126 | """
127 | for filepath in filepaths:
128 | try:
129 | h5_file = h5py.File(filepath, "r+")
130 | except OSError:
131 | logger.info(f"OSError: Unable to open file {filepath}")
132 | continue
133 |
134 | quat = np.array(h5_file[group_name][:, :])
135 | quat = quat.reshape(quat.shape[1], quat.shape[0])
136 | quat = quat.reshape(quat.shape[0], -1, 4)
137 |
138 | quat = quat_fix(quat)
139 |
140 | norm_quat = np.zeros(quat.shape)
141 |
142 | root_rotMat = np.linalg.inv(
143 | quat_to_rotMat(torch.tensor(quat[:, root, :]))
144 | )
145 | for i in range(0, quat.shape[1]):
146 | rotMat = quat_to_rotMat(torch.tensor(quat[:, i, :]))
147 | norm_rotMat = np.matmul(root_rotMat, rotMat)
148 | norm_quat[:, i, :] = rotMat_to_quat(norm_rotMat)
149 |
150 | norm_quat = norm_quat.reshape(norm_quat.shape[0], -1)
151 | norm_quat = norm_quat.reshape(norm_quat.shape[1], norm_quat.shape[0])
152 |
153 | try:
154 | logger.info(f"Writing to file {filepath}")
155 | h5_file.create_dataset(new_group_name, data=norm_quat)
156 | except RuntimeError:
157 | logger.info(("RuntimeError: Unable to create link "
158 | f"(name already exists) in {filepath}"))
159 | h5_file.close()
160 |
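
A hedged usage sketch; the glob pattern and the new group names below are illustrative, except that "normOrientation" matches a label group expected by data_utils.XSensDataIndices:

    import glob

    files = glob.glob("../h5-files/*.h5")

    add_normalized_positions(files, "normPosition")
    add_normalized_quaternions(files, "orientation", "normOrientation", root=0)
    add_normalized_accelerations(files, "sensorFreeAcceleration",
                                 "normAcceleration", root=0)
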
--------------------------------------------------------------------------------
/src/common/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the Attribution-NonCommercial
5 | # 4.0 International license and is borrowed from the QuaterNet library.
6 | # See https://github.com/facebookresearch/QuaterNet/blob/master/LICENSE for
7 | # more details.
8 | #
9 |
10 | import torch
11 | import numpy as np
12 |
13 |
14 | def quat_fix(q):
15 | """Enforce quaternion continuity across the time dimension.
16 |
17 | Borrowed from QuaterNet:
18 | https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L119
19 |
20 | This function falls under the Attribution-NonCommercial 4.0 International
21 | license.
22 |
23 | Selects the representation (q or -q) with minimal distance
24 | (or, equivalently, maximal dot product) between two consecutive frames.
25 |
26 | Expects a tensor of shape (L, J, 4), where L is the sequence length and
27 | J is the number of joints.
28 | Returns a tensor of the same shape.
29 |
30 | Args:
31 | q (np.ndarray): quaternions of size (L, J, 4) to enforce continuity.
32 |
33 | Returns:
34 | np.ndarray: quaternion of size (L, J, 4) that is continuous
35 | in time dimension.
36 | """
37 | assert len(q.shape) == 3
38 | assert q.shape[-1] == 4
39 |
40 | result = q.copy()
41 | dot_products = np.sum(q[1:]*q[:-1], axis=2)
42 | mask = dot_products < 0
43 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
44 | result[1:][mask] *= -1
45 | return result
46 |
47 |
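
A tiny illustration (not from the original file) of the sign fix:

    import numpy as np

    # Frame 2 is the antipodal representation of frame 1; quat_fix flips it back.
    q = np.array([[[1.0, 0.0, 0.0, 0.0]],
                  [[-1.0, 0.0, 0.0, 0.0]]])  # shape (L=2, J=1, 4)
    print(quat_fix(q)[1, 0])  # [1. 0. 0. 0.]
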
48 | def quat_mul(q, r):
49 | """Multiply quaternion(s) q with quaternion(s) r.
50 |
51 | Borrowed from QuaterNet:
52 | https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L13
53 |
54 | This function falls under the Attribution-NonCommercial 4.0 International
55 | license.
56 |
57 | Expects two equally-sized tensors of shape (*, 4), where * denotes any
58 | number of dimensions.
59 | Returns q*r as a tensor of shape (*, 4).
60 |
61 | Args:
62 | q (torch.Tensor): quaternions of size (*, 4)
63 | r (torch.Tensor): quaternions of size (*, 4)
64 |
65 | Returns:
66 | torch.Tensor: quaternions of size (*, 4)
67 | """
68 | assert q.shape[-1] == 4
69 | assert r.shape[-1] == 4
70 |
71 | original_shape = q.shape
72 |
73 | # Compute outer product
74 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
75 |
76 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
77 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
78 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
79 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
80 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
81 |
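
A small check (not from the original file) that two 90-degree rotations about z compose to a 180-degree rotation:

    import math
    import torch

    s = math.sqrt(0.5)
    q = torch.tensor([[s, 0.0, 0.0, s]])  # 90 degrees about z, (w, x, y, z)
    print(quat_mul(q, q))                 # ~[[0., 0., 0., 1.]]
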
--------------------------------------------------------------------------------
/src/common/rotations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 |
10 |
11 | def quat_to_rotMat(q):
12 | """Convert quaternions to rotation matrices.
13 |
14 | Using equation provided in XSens MVN Manual:
15 | https://www.xsens.com/hubfs/Downloads/usermanual/MVN_User_Manual.pdf
16 |
17 | Args:
18 | q (torch.Tensor): quaternion(s) to convert to rotation matrix format
19 |
20 | Returns:
21 | torch.Tensor: rotation matrix converted from quaternion format
22 | """
23 | if len(q.shape) != 2:
24 | q = q.unsqueeze(0)
25 |
26 | assert q.shape[1] == 4
27 |
28 | r0c0 = q[:, 0]**2 + q[:, 1]**2 - q[:, 2]**2 - q[:, 3]**2
29 | r0c1 = 2*q[:, 1]*q[:, 2] - 2*q[:, 0]*q[:, 3]
30 | r0c2 = 2*q[:, 1]*q[:, 3] + 2*q[:, 0]*q[:, 2]
31 |
32 | r1c0 = 2*q[:, 1]*q[:, 2] + 2*q[:, 0]*q[:, 3]
33 | r1c1 = q[:, 0]**2 - q[:, 1]**2 + q[:, 2]**2 - q[:, 3]**2
34 | r1c2 = 2*q[:, 2]*q[:, 3] - 2*q[:, 0]*q[:, 1]
35 |
36 | r2c0 = 2*q[:, 1]*q[:, 3] - 2*q[:, 0]*q[:, 2]
37 | r2c1 = 2*q[:, 2]*q[:, 3] + 2*q[:, 0]*q[:, 1]
38 | r2c2 = q[:, 0]**2 - q[:, 1]**2 - q[:, 2]**2 + q[:, 3]**2
39 |
40 | r0 = torch.stack([r0c0, r0c1, r0c2], dim=1)
41 | r1 = torch.stack([r1c0, r1c1, r1c2], dim=1)
42 | r2 = torch.stack([r2c0, r2c1, r2c2], dim=1)
43 |
44 | R = torch.stack([r0, r1, r2], dim=2)
45 |
46 | return R.permute(0, 2, 1)
47 |
48 |
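
A minimal check (not part of the module) that the identity quaternion maps to the identity matrix:

    import torch

    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
    R = quat_to_rotMat(q)
    print(torch.allclose(R, torch.eye(3).unsqueeze(0)))  # True
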
49 | def rotMat_to_quat(rotMat):
50 | """Convert rotation matrices back to quaternions.
51 |
52 | Ported from Matlab:
53 | https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/mhmublv/Motion/rotmat2quat.m#L4
54 |
55 |
56 | Args:
57 | rotMat (torch.Tensor): rotation matrix or matrices to convert to
58 | quaternion format.
59 |
60 | Returns:
61 | torch.Tensor: quaternion(s) converted from rotation matrix format
62 | """
63 | if len(rotMat.shape) != 3:
64 | rotMat = rotMat.unsqueeze(0)
65 |
66 | assert rotMat.shape[1] == 3 and rotMat.shape[2] == 3
67 |
68 | diffMat = rotMat - torch.transpose(rotMat, 1, 2)
69 |
70 | r = torch.zeros((rotMat.shape[0], 3), dtype=torch.float64)
71 |
72 | r[:, 0] = -diffMat[:, 1, 2]
73 | r[:, 1] = diffMat[:, 0, 2]
74 | r[:, 2] = -diffMat[:, 0, 1]
75 |
76 | sin_theta = torch.norm(r, dim=1)/2
77 | sin_theta = sin_theta.unsqueeze(1)
78 |
79 | r0 = r / (torch.norm(r, dim=1).unsqueeze(1) + 1e-9)
80 |
81 | cos_theta = (rotMat.diagonal(dim1=-2, dim2=-1).sum(-1) - 1) / 2
82 | cos_theta = cos_theta.unsqueeze(1)
83 |
84 | theta = torch.atan2(sin_theta, cos_theta)
85 |
86 | theta = theta.squeeze(1)
87 |
88 | q = torch.zeros((rotMat.shape[0], 4), dtype=torch.float64)
89 |
90 | q[:, 0] = torch.cos(theta/2)
91 | q[:, 1:] = r0*torch.sin(theta/2).unsqueeze(1)
92 |
93 | return q
94 |
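
A round-trip sketch (not part of the module), quaternion to rotation matrix and back:

    import math
    import torch

    s = math.sqrt(0.5)
    q = torch.tensor([[s, s, 0.0, 0.0]], dtype=torch.float64)  # 90 degrees about x
    print(rotMat_to_quat(quat_to_rotMat(q)))  # ~[[0.7071, 0.7071, 0., 0.]]
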
--------------------------------------------------------------------------------
/src/common/skeleton.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | from .rotations import quat_to_rotMat
10 | from .data_utils import XSensDataIndices
11 | import matplotlib.pyplot as plt
12 | import mpl_toolkits.mplot3d.axes3d as p3
13 | import matplotlib.animation as animation
14 |
15 |
16 | class Skeleton:
17 | """Skeleton for modeling and visualizing forward kinematics."""
18 |
19 | def __init__(self):
20 | """Initialize the skeleton, using segment lengths from P1.
21 |
22 |         Initializes the segments, the parent of each segment, and the
23 |         mapping between them.
24 | """
25 | segments = ["Pelvis", "L5", "L3",
26 | "T12", "T8", "Neck", "Head",
27 | "RightShoulder", "RightUpperArm",
28 | "RightForeArm", "RightHand",
29 | "LeftShoulder", "LeftUpperArm",
30 | "LeftForeArm", "LeftHand",
31 | "RightUpperLeg", "RightLowerLeg",
32 | "RightFoot", "RightToe",
33 | "LeftUpperLeg", "LeftLowerLeg",
34 | "LeftFoot", "LeftToe"]
35 |
36 | parents = [None, "Pelvis", "L5",
37 | "L3", "T12", "T8", "Neck",
38 | "T8", "RightShoulder", "RightUpperArm", "RightForeArm",
39 | "T8", "LeftShoulder", "LeftUpperArm", "LeftForeArm",
40 | "Pelvis", "RightUpperLeg", "RightLowerLeg", "RightFoot",
41 | "Pelvis", "LeftUpperLeg", "LeftLowerLeg", "LeftFoot"]
42 |
43 | body_frame_segments = [None,
44 | [0.000000, 0.000000, 0.103107],
45 | [0.000000, 0.000000, 0.095793],
46 | [0.000000, 0.000000, 0.087564],
47 | [0.000000, 0.000000, 0.120823],
48 | [0.000000, 0.000000, 0.103068],
49 | [0.000000, 0.000000, 0.208570],
50 | [0.000000, -0.027320, 0.068158],
51 | [0.000000, -0.141007, 0.000000],
52 | [0.000000, -0.291763, 0.000000],
53 | [0.000071, -0.240367, 0.000000],
54 | [0.000000, 0.027320, 0.068158],
55 | [0.000000, 0.141007, 0.000000],
56 | [0.000000, 0.291763, 0.000000],
57 | [0.000071, 0.240367, 0.000000],
58 | [0.000000, -0.087677, 0.000000],
59 | [-0.000055, 0.000000, -0.439960],
60 | [0.000248, 0.000000, -0.445123],
61 | [0.192542, 0.000000, -0.087304],
62 | [0.000000, 0.087677, 0.000000],
63 | [-0.000055, 0.000000, -0.439960],
64 | [0.000248, 0.000000, -0.445123],
65 | [0.192542, 0.000000, -0.087304]]
66 |
67 | self.skeleton_tree = [[0, 1, 2, 3, 4],
68 | [4, 7, 8, 9, 10],
69 | [4, 11, 12, 13, 14],
70 | [4, 5, 6],
71 | [0, 15, 16, 17, 18],
72 | [0, 19, 20, 21, 22]]
73 |
74 | self.segments = segments
75 | self.index_of = dict(zip(segments, range(len(segments))))
76 | self.segment_parents = dict(zip(segments, parents))
77 | self.segment_positions_in_parent_frame = dict(zip(segments,
78 | body_frame_segments))
79 |
80 | def forward_kinematics(self, orientations):
81 | """Compute positions of segment endpoints using orientation of segments.
82 |
83 | Args:
84 | orientations (torch.Tensor): orientations of segments in the
85 | skeleton
86 |
87 | Returns:
88 | torch.Tensor: position of segment endpoints
89 | """
90 | xsens_indices = XSensDataIndices()
91 |
92 | positions = torch.zeros([len(self.segments),
93 | orientations.shape[0],
94 | 3], dtype=torch.float32)
95 |
96 | for i, segment in enumerate(self.segments):
97 | parent = self.segment_parents[segment]
98 |
99 | if parent is None:
100 | continue
101 | else:
102 | indices_map = xsens_indices({"orientation": [parent]})
103 | parent_indices = indices_map["orientation"][0]
104 |
105 | x_B = torch.tensor(
106 | self.segment_positions_in_parent_frame[segment],
107 | dtype=torch.float32)
108 |
109 | x_B = x_B.view(1, -1, 1).repeat(orientations.shape[0], 1, 1)
110 |
111 | R_GB = quat_to_rotMat(orientations[:, parent_indices])
112 | positions[i] = (positions[self.index_of[parent]] +
113 | R_GB.bmm(x_B).squeeze(2))
114 |
115 | return positions.permute(1, 0, 2)
116 |
117 | def animate_motion(self, orientations, azim, elev, title=None):
118 | """Animate frames of orientation data using forward kinematics.
119 |
120 | Args:
121 | orientations (torch.Tensor): orientations of the segments in the
122 | kinematic chain
123 | azim (float): azimuth of the plot point of view
124 | elev (float): elevation of the plot point of view
125 | title (str, optional): plot title. Defaults to None.
126 |
127 | Returns:
128 | animation.FuncAnimation: returns an animation that can be saved or
129 | viewed in a Jupyter Notebook
130 | """
131 | if len(orientations.shape) == 1:
132 | orientations = orientations.unsqueeze(0)
133 |
134 | def update_lines(num, data, lines):
135 | positions = data[num]
136 |
137 | for i, line in enumerate(lines):
138 | xs = list(positions[self.skeleton_tree[i], 0])
139 | ys = list(positions[self.skeleton_tree[i], 1])
140 | zs = list(positions[self.skeleton_tree[i], 2])
141 |
142 | line.set_data(xs, ys)
143 | line.set_3d_properties(zs)
144 | line.set_linestyle("-")
145 | return lines
146 |
147 | fig = plt.figure()
148 | ax = p3.Axes3D(fig)
149 |
150 | if title is not None:
151 | ax.set_title(title)
152 |
153 | data = self.forward_kinematics(orientations)
154 |
155 | lines = [ax.plot([0], [0], [0])[0] for _ in range(6)]
156 | limits = [-1.0, 1.0]
157 |
158 | self._setup_axis(ax, limits, azim, elev)
159 |
160 | line_ani = animation.FuncAnimation(fig,
161 | update_lines,
162 | frames=range(data.shape[0]),
163 | fargs=(data, lines),
164 | interval=25,
165 | blit=True)
166 | plt.show()
167 | return line_ani
168 |
169 | def compare_motion(self, orientations, azim, elev,
170 | fig_filename=None, titles=None):
171 | """Display plots of different orientation frames.
172 |
173 | Primarily useful for plotting skeletons for orientation outputs from
174 | different models.
175 |
176 | Args:
177 | orientations (torch.Tensor): orientations of the segments in the
178 | kinematic chain, typically of orientations from different
179 | models.
180 | azim (float): azimuth of the plot point of view
181 | elev (float): elevation of the plot point of view
182 | fig_filename (str, optional): figure filename. Defaults to None.
183 | titles (str, optional): plot titles. Defaults to None.
184 |
185 | Returns:
186 | matplotlib.pyplot.figure: figure that will be displayed and saved
187 | if fig_filename is provided.
188 | """
189 | if len(orientations.shape) == 1:
190 | orientations = orientations.unsqueeze(0)
191 |
192 | def update_lines(num, data, lines):
193 | positions = data[num]
194 |
195 | for i, line in enumerate(lines):
196 | xs = list(positions[self.skeleton_tree[i], 0])
197 | ys = list(positions[self.skeleton_tree[i], 1])
198 | zs = list(positions[self.skeleton_tree[i], 2])
199 |
200 | line.set_data(xs, ys)
201 | line.set_3d_properties(zs)
202 | return lines
203 |
204 | fig = plt.figure(figsize=(orientations.shape[0]*3, 3))
205 | data = self.forward_kinematics(orientations)
206 |
207 | limits = [-1.0, 1.0]
208 | for i in range(orientations.shape[0]):
209 | ax = fig.add_subplot(1,
210 | orientations.shape[0],
211 | i+1,
212 | projection="3d")
213 |
214 | lines = [ax.plot([0], [0], [0])[0] for _ in range(6)]
215 |
216 | self._setup_axis(ax, limits, azim, elev)
217 |
218 | if titles is not None:
219 | ax.set_title(titles[i])
220 |
221 | update_lines(i, data, lines)
222 |
223 | plt.subplots_adjust(wspace=0)
224 | plt.show()
225 | if fig_filename:
226 | fig.savefig(fig_filename, bbox_inches="tight")
227 | return fig
228 |
229 | def _setup_axis(self, ax, limits, azim, elev):
230 | ax.set_xlim3d(limits)
231 | ax.set_ylim3d(limits)
232 | ax.set_zlim3d(limits)
233 |
234 | ax.grid(False)
235 |
236 | ax.set_xticks([])
237 | ax.set_yticks([])
238 | ax.set_zticks([])
239 |
240 | ax.view_init(azim=azim, elev=elev)
241 |
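
A minimal single-frame sketch (not part of the class) of forward_kinematics: every segment at the identity orientation, laid out in the same 23 x 4 column order that XSensDataIndices uses for "orientation":

    import torch

    skeleton = Skeleton()

    identity = torch.tensor([1.0, 0.0, 0.0, 0.0]).repeat(23)  # (92,)
    orientations = identity.unsqueeze(0)                      # (1, 92)

    positions = skeleton.forward_kinematics(orientations)
    print(positions.shape)  # torch.Size([1, 23, 3])
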
--------------------------------------------------------------------------------
/src/matlab/get_folder_path.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function [parentpath] = get_folder_path()
9 | % GET_FOLDER_PATH Returns the path to the parent folder.
10 | %
11 | % parentPath = get_folder_path() returns the path to the parent folder
12 | % for ease of use in other files.
13 | %
14 | % See also MVNX_TO_CSV
15 | parentpath = cd(cd('..'));
16 | end
--------------------------------------------------------------------------------
/src/matlab/joint_angle_segments.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function [jointAngleSegments] = joint_angle_segments()
9 | % JOINT_ANGLE_SEGMENTS Returns a cell array containing the joint angle labels
10 | % of the XSens data file in the proper order.
11 | %
12 | % jointAngleSegments = JOINT_ANGLE_SEGMENTS() is used elsewhere to
13 | % construct joint angle maps and fill in data.
14 | %
15 | % See also INIT_JOINT_ANGLE_MAP, READ_DATA
16 | %
17 | jointAngleSegments = {'jL5S1','jL4L3','jL1T12','jT9T8','jT1C7','jC1Head','jRightT4Shoulder',...
18 | 'jRightShoulder','jRightElbow','jRightWrist','jLeftT4Shoulder','jLeftShoulder','jLeftElbow',...
19 | 'jLeftWrist','jRightHip','jRightKnee','jRightAnkle','jRightBallFoot','jLeftHip','jLeftKnee','jLeftAnkle',...
20 | 'jLeftBallFoot'};
21 | end
--------------------------------------------------------------------------------
/src/matlab/joint_angles.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function[matrix]= joint_angles(values, frames, index)
9 | % JDEG Re-shuffles each joint angle in data to a joint
10 | %
11 | % matrix = JDEG(values, frames, index) will associate flexion, abduction,
12 | % and extension to the proper joint.
13 | %
14 | % See also READ_DATA
15 | matrix= zeros(frames, 3);
16 | for i= 1:frames
17 | matrix(i, :)= values(i).jointAngle(index:index+2);
18 | end
19 | end
--------------------------------------------------------------------------------
/src/matlab/load_partial_mvnx.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function [data,time,cal_data,extra] = load_partial_mvnx(filename,section,varargin)
9 | % This function allows MATLAB to read in larger mvnx files, since it does
10 | % not try to load the entire file, which can be several GB in size, into
11 | % memory.
12 | %
13 | %
14 | % data = load_partial_mvnx(filename,section)
15 | % Loads all of the frame data from the fields specified in section, with
16 | % multiple fields passed as a cell array.
17 | %
18 | % data = load_partial_mvnx(filename,section,n)
19 | % Loads n frames of the data, where n>=0; n=0 indicates reading to
20 | % the end of the file. If n is greater than the number of frames
21 | % then the data is read until the end of the file.
22 | %
23 | % data = load_partial_mvnx(filename,section,m,n)
24 | % Loads n frames of the data starting at frame m.
25 | %
26 | % Inputs
27 | % filename - Type:string, path to the file to be read
28 | %
29 | % section - Types:string,cell array, labels of data to read
30 | %     Valid values for section are 'segments','joints','orientation','position','velocity',
31 | % 'acceleration','angularVelocity','angularAcceleration','jointAngle','jointAngleXYZ'
32 | %
33 | % 'segments' and 'joints' cannot be used in the cell array.
34 | % Order does not matter in the cell array
35 | % Example section = 'position' or {'acceleration','position',...}
36 | %
37 | % m,n - Type:integers, m>=0, n>=0
38 | %
39 | % Outputs
40 | %   data - Type:Varies based on arguments
41 | %
42 | % time - Type:struct time values for the frames, output only for sections
43 | % 'orientation','position','velocity','acceleration','angularVelocity',
44 | % 'angularAcceleration','jointAngle','jointAngleXYZ'
45 | %
46 | % cal_data - Type:struct calibration pose data, output only for sections
47 | % 'orientation','position','velocity','acceleration','angularVelocity',
48 | % 'angularAcceleration','jointAngle','jointAngleXYZ'
49 | %
50 | % extra - Type:struct other data from sections 'orientation','position',
51 | % 'velocity','acceleration','angularVelocity','angularAcceleration',
52 | % 'jointAngle','jointAngleXYZ'
53 |
54 |
55 | % preallocate some space; this probably will not be enough
56 | N=1000000;
57 | cells=cell(N,1);
58 |
59 | % number of calibration poses, aka frames with type other than 'normal',
60 | % this might need to be made variable in the future
61 | n_cal = 3;
62 |
63 | % error handling and setting up initial parameters
64 | frame_data = {'orientation', 'position', 'velocity', 'acceleration', 'angularVelocity', ...
65 | 'angularAcceleration', 'footContacts', 'sensorFreeAcceleration', ...
66 | 'sensorMagneticField', 'sensorOrientation', 'jointAngle', 'jointAngleXZY', ...
67 | 'jointAngleErgo', 'jointAngleErgoXZY', 'centerOfMass'};
68 |
69 | if iscell(section)
70 | parfor i=1:length(section)
71 | f_test(i)=any(strcmp(frame_data, section{i}));
72 | end
73 | else
74 | f_test=any(strcmp(frame_data, section));
75 | end
76 |
77 | if iscell(section)
78 | ME=MException('Input:error','Supplied arg %s %s %s\nSyntax: load_partial_mvnx(filename,section,[[limit_start],][limit_stop])\nSupported section values are:\n Singular Arguments - segments, joints\n These can be applied individually or as a cell array - orientation, position, velocity, acceleration, angularVelocity, angularAcceleration, jointAngle, jointAngleXYZ',filename,section{:},varargin{:});
79 | else
80 | ME=MException('Input:error','Supplied arg %s %s %s\nSyntax: load_partial_mvnx(filename,section,[[limit_start],][limit_stop])\nSupported section values are:\n Singular Arguments - segments, joints\n These can be applied individually or as a cell array - orientation, position, velocity, acceleration, angularVelocity, angularAcceleration, jointAngle, jointAngleXYZ',filename,section,varargin{:});
81 | end
82 | LE=MException('Limit:error','0<=limit_start<=limit_end');
83 |
84 | if nargin-length(varargin)<2 || nargin>4
85 | throw(ME)
86 | end
87 | run_cal=1;
88 | % set up to get calibration data
89 | if length(varargin)==1 && ischar(varargin{1})
90 | limit_start=0;
91 | limit_stop=n_cal;
92 | section={'orientation','position'};
93 | f_test=1;
94 | run_cal=0;
95 | %set up to get segment/joint info
96 | elseif ~iscell(section) && (strcmp(section,'segments') || strcmp(section,'joints'))
97 | limit_start=0;
98 | limit_stop=0;
99 | % set converting limits for all other formats
100 | elseif length(varargin)==1
101 | if varargin{1}<0
102 | throw(LE)
103 | end
104 | limit_start=n_cal;
105 | limit_stop=varargin{1}+n_cal;
106 | elseif length(varargin)==2
107 | if varargin{1}<0 || varargin{2}',''};
122 | en = {'','',''};
123 | stop = '';
124 | st_depth = 1;
125 | en_depth = 0;
126 | max_depth=length(st);
127 | elseif ~iscell(section) && strcmp(section,'joints')
128 | start = '';
129 | st = {'',''};
131 | en = {''};
132 | stop = '';
133 | st_depth = 1;
134 | en_depth = 0;
135 | max_depth=length(st);
136 | elseif all(f_test) && ~any(strcmp('segments', section)) && ~any(strcmp('joints', section))
137 | start = ''};
145 | stop = '';
146 | st_depth = 1;
147 | en_depth = 0;
148 | max_depth=length(st);
149 | else
150 | throw(ME);
151 | end
152 |
153 | % Check file name
154 | if isempty(strfind(filename,'mvnx'))
155 | filename = [get_folder_path() '\' filename '.mvnx'];
156 | end
157 | if ~exist(filename,'file')
158 | error([mfilename ':xsens:filename'],['No file with filename: ' filename ', file is not present or file has wrong format (function only reads .mvnx)'])
159 | end
160 |
161 |
162 | % Open file
163 | disp('Reading file');
164 | fid = fopen(filename, 'r', 'n', 'UTF-8');
165 |
166 | l = fgetl(fid);
167 | k = 1;
168 |
169 | % find start
170 | while isempty(strfind(l,start))
171 | l = fgetl(fid);
172 | end
173 | cells{k} = l;
174 | k=k+1;
175 | l = fgetl(fid);
176 | count = 0;
177 | while ~feof(fid)
178 | if ~isempty(strfind(l,stop))
179 | cells{k} = l;
180 | k = k+1;
181 | break
182 | end
183 | % test for the start key at the current depth
184 | found = 0;
185 | if st_depth<=max_depth && ~isempty(strfind(l,st{st_depth}))
186 | found = 1;
187 | if count>=limit_start
188 | cells{k} = l;
189 | k = k+1;
190 | end
191 | st_depth = st_depth+1;
192 | en_depth = en_depth+1;
193 | elseif st_depth>max_depth
194 | for i=1:length(key)
195 | if ~isempty(strfind(l,key{i}))
196 | if count>=limit_start
197 | cells{k} = l;
198 | k = k+1;
199 | end
200 | break
201 | end
202 | end
203 | end
204 |     % if a start key was not found
205 | if en_depth>0 && ~isempty(strfind(l,en{en_depth}))
206 | if found==0
207 | if count>=limit_start
208 | cells{k} = l;
209 | k = k+1;
210 | end
211 | if en_depth == 1
212 | count=count+1;
213 | end
214 | st_depth = st_depth-1;
215 | en_depth = en_depth-1;
216 | else
217 | if en_depth == 1
218 | count=count+1;
219 | end
220 | st_depth = st_depth-1;
221 | en_depth = en_depth-1;
222 | end
223 | end
224 | if limit_stop~=0 && count>=limit_stop
225 | cells{k}=stop;
226 | break
227 | end
228 | l = fgetl(fid);
229 | end
230 | if k=8 && strcmp('comment',line(2:8))
273 | % add exception for comment
274 | word{n} = line(2:hooks(1)-1);
275 | iLine = hooks(1)+1;
276 | elseif openandclose(n)
277 | word{n} = line(2:find(line==62,1)-1);
278 | iLine = find(line==62,1)+1;
279 | else
280 | word{n} = line(2:end-1);
281 | oneword = true;
282 | end
283 | if word{n}(1) ~= '/'
284 | if ~oneword && ~openandclose(n)
285 | k = find(line == 34);
286 | k = reshape(k,2,length(k)/2)';
287 | l = [iLine find(line(iLine:end) == 61)+iLine-2];
288 | fieldname = cell(1,length(l)-1); value = cell(1,length(l)-1);
289 | if ~isempty(k)
290 | for il=1:size(k,1)
291 | fieldname{il} = line(iLine:find(line(iLine:end) == 61,1)+iLine-2);
292 | if size(k,1) > 1 && il < size(k,1)
293 | a = strfind(line(iLine:end),'" ')+iLine+1;
294 | iLine = a(1);
295 | end
296 | value{il} = line(k(il,1)+1:k(il,2)-1);
297 | end
298 | else
299 | value = []; fieldname =[];
300 | value = line(find(line == 62,1)+1:end);
301 | end
302 | elseif ~oneword && openandclose(n)
303 | value = []; fieldname =[];
304 | value = line(find(line == 62,1)+1:find(line==60,1,'last')-1);
305 | else
306 | value = NaN;fieldname = [];
307 | end
308 | wordvalue{n} = value;
309 | wordfields{n} = fieldname;
310 | end
311 | end
312 | %% get values
313 | parfor n=1:length(wordvalue)
314 | if iscell(wordvalue{n})
315 | if length(wordvalue{n}) == 1
316 | B = [];
317 | try
318 | B = str2double(wordvalue{n}{1});
319 | end
320 | if ~isempty(B)
321 | wordvalue{n} = B;
322 | else
323 | wordvalue{n} = wordvalue{n}{1};
324 | end
325 | else
326 | for m=1:length(wordvalue{n})
327 | try
328 | B = str2double(wordvalue{n}{m});
329 | if ~isempty(B)
330 | wordvalue{n}{m} = B;
331 | end
332 | end
333 | end
334 | end
335 | else
336 | try
337 | B = str2num(wordvalue{n});
338 | if ~isempty(B)
339 | wordvalue{n} = B;
340 | end
341 | end
342 | end
343 | end
344 | %%
345 | disp('Logging data');
346 | % form tree structure for segment and joint data
347 | if ~iscell(section) && (strcmp(section,'segments') || strcmp(section,'joints'))
348 | data=rec_struct(word,wordfields,wordvalue,wordindex,1,length(section));
349 | extra = [];
350 | cal_data = [];
351 | time=[];
352 | else
353 | %store extra values and get calibration data
354 | extra = struct(wordfields{1,1}{1},wordvalue{1,1}{1},wordfields{1,1}{2},wordvalue{1,1}{2});
355 | if run_cal
356 | cal_data = load_partial_mvnx(filename,section,'cal');
357 | else
358 | cal_data=[];
359 | end
360 |
361 | f = strcmp(word,'frame');
362 | fi=find(f);
363 | n_f = length(wordfields{fi(1)});
364 | fr_lab = wordfields{f};
365 |
366 | n_k=length(key);
367 |
368 | for i=1:length(key)
369 | k = strcmp(word,key{i});
370 | kf = find(k);
371 | kc{i}=kf;
372 | end
373 |
374 | %preallocate
375 | val=cell(length(fi),n_f);
376 | val2=cell(length(fi),n_k);
377 |
378 | for i=1:length(fi)
379 | for j=1:n_f
380 | if isnumeric(wordvalue{fi(i)}{j})
381 | val{i,j}=wordvalue{fi(i)}{j};
382 | elseif ischar(wordvalue{fi(i)}{j})
383 | val{i,j}=wordvalue{fi(i)}{j};
384 | else
385 | disp('error');
386 | end
387 | end
388 | for j=1:n_k
389 | val2{i,j}=wordvalue{kc{j}(i)};
390 | end
391 | end
392 |
393 | % put data into structures
394 | time=form_struct(fr_lab,val);
395 | data=form_struct(key,val2);
396 |
397 | end
398 | end
399 | function data = form_struct(lab,val)
400 | switch length(lab)
401 | case 1
402 | data=struct(lab{1},val(:,1));
403 | case 2
404 | data=struct(lab{1},val(:,1),lab{2},val(:,2));
405 | case 3
406 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3));
407 | case 4
408 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3),lab{4},val(:,4));
409 | case 5
410 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3),lab{4},val(:,4),lab{5}, val(:,5));
411 | case 6
412 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3),lab{4},val(:,4),lab{5},val(:,5),...
413 | lab{6},val(:,6));
414 | case 7
415 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3),lab{4},val(:,4),lab{5},val(:,5),...
416 | lab{6},val(:,6),lab{7},val(:,7));
417 | case 8
418 | data=struct(lab{1},val(:,1),lab{2},val(:,2),lab{3},val(:,3),lab{4},val(:,4),lab{5},val(:,5),...
419 | lab{6},val(:,6),lab{7},val(:,7),lab{8},val(:,8));
420 | otherwise
421 | disp('data format error');
422 | end
423 | end
424 | % Recursively builds up structs for segments and joints
425 | function [s,j] = rec_struct(word,field,value,index,i,len)
426 | in=index{i};
427 | numfields = length(field{i});
428 | j=i+1;
429 | k=1;
430 | si={};
431 | while index{j}>in
432 | [sn,j]=rec_struct(word,field,value,index,j,len);
433 | if ~isempty(sn)
434 | si{k}=sn;
435 | k=k+1;
436 | end
437 | end
438 | if isempty(si)
439 | si=value{i};
440 | end
441 | switch numfields
442 | case 0
443 | s=si;
444 | case 1
445 | s=struct(char(field{i}),value{i},word{i+1},si);
446 | case 2
447 | s=struct(field{i}{1},value{i}{1},field{i}{2},value{i}{2},word{i+1},si);
448 | case 3
449 | s=struct(field{i}{1},value{i}{1},field{i}{2},value{i}{2},field{i}{3},value{i}{3},word{i+1},si);
450 | case 4
451 | s=struct(field{i}{1},value{i}{1},field{i}{2},value{i}{2},field{i}{3},value{i}{3},field{i}{4},value{i}{4},word{i+1},si);
452 | case 5
453 | s=struct(field{i}{1},value{i}{1},field{i}{2},value{i}{2},field{i}{3},value{i}{3},field{i}{4},value{i}{4},field{i}{5},value{i}{5},word{i+1},si);
454 | otherwise
455 | disp('error')
456 | end
457 | end
458 |
--------------------------------------------------------------------------------
/src/matlab/mvnx_to_csv.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function data = mvnx_to_csv(file)
9 | %MVNX_TO_CSV File for reading in .mvnx files and converting them to .csv files.
10 | % Reading the .mvnx files is expensive, so this converts orientation, position, and
11 | % joint angle data to csv format.
12 |
13 | filename = [get_folder_path(), '/mvnx-files/', file];
14 |
15 | [data, ~, ~, ~] = load_partial_mvnx(filename, {'orientation', 'position', 'jointAngle'});
16 |
17 | csvData = zeros(size(data,1), 227);
18 |
19 | for i = 1:size(data,1)
20 |
21 | csvData(i, :) = [data(i).orientation(:)' data(i).position(:)' data(i).jointAngle(:)'];
22 |
23 | end
24 | output = [get_folder_path(), '/', 'csv-files', '/', file(1:end-5), '.csv'];
25 |
26 | csvwrite(output, csvData);
27 |
28 | end
--------------------------------------------------------------------------------
/src/matlab/mvnx_to_hdf.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function mvnx_to_hdf(file)
9 |
10 | input_file = [get_folder_path(), '/mvnx-files/', file];
11 | output_file = [get_folder_path(), '/h5-files/', file(1:end-5), '.h5'];
12 |
13 | dataset_names = {'orientation', 'position', 'velocity', 'acceleration', 'angularVelocity', ...
14 | 'angularAcceleration', 'footContacts', 'sensorFreeAcceleration', ...
15 | 'sensorMagneticField', 'sensorOrientation', 'jointAngle', 'jointAngleXZY', ...
16 | 'jointAngleErgo', 'jointAngleErgoXZY', 'centerOfMass'};
17 |
18 | disp(input_file);
19 |
20 | tic
21 | for i = 1:length(dataset_names)
22 | [data, ~, ~, ~] = load_partial_mvnx(input_file, dataset_names(i));
23 | data = struct2cell(data);
24 |
25 | dataset = cell2mat(data);
26 | dataset = reshape(dataset, length(data), length(dataset)/length(data));
27 |
28 | h5create(output_file, ['/', dataset_names{i}], size(dataset));
29 | h5write(output_file, ['/', dataset_names{i}], dataset);
30 | end
31 | toc
32 |
33 | h5disp(output_file);
34 | end
35 |
--------------------------------------------------------------------------------
/src/matlab/mvnx_to_hdf_batch.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | clear
9 |
10 | fprintf('Starting job');
11 | %
12 | % BATCH defines the job and sends it for execution.
13 | %
14 | job = batch ('process_mvnx_files', 'Profile', 'dtlogin_R2018a', 'AttachedFiles', {'get_folder_path', 'joint_angle_segments', 'joint_angles', ...
15 | 'load_partial_mvnx', 'mvnx_to_hdf', ...
16 | 'segment_orientation', 'segment_position', 'segment_reference' }, ...
17 | 'CurrentFolder', '.', 'Pool', 9);
18 | %
19 | % WAIT pauses the MATLAB session until the job completes.
20 | %
21 | wait (job);
22 |
23 | %
24 | % DIARY displays any messages printed during execution.
25 | %
26 | diary (job);
27 |
28 | %
29 | % These commands clean up data about the job we no longer need.
30 | delete ( job ); %Use delete() for R2012a or later
31 |
32 | fprintf(1, '\n');
33 | fprintf(1, 'Done');
34 |
--------------------------------------------------------------------------------
/src/matlab/process_mvnx_files.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | files = dir([get_folder_path(), '/mvnx-files/*.mvnx']);
9 | i = 0;
10 |
11 | for file = files'
12 | i = i + 1;
13 | disp(i);
14 | mvnx_to_hdf(file.name);
15 | end
16 |
--------------------------------------------------------------------------------
/src/matlab/segment_orientation.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function[matrix]= segment_orientation(values, frames, index)
9 | % SEGMENT_ORIENTATION Re-shuffles each quaternion in data to a segment
10 | %
11 | % matrix = SEGMENT_ORIENTATION(values, frames, index) will associate 4
12 | % parts of quaternion to the proper segment using XSens'
13 | % segmentation indices.
14 | %
15 | % See also READ_DATA, GET_ROLL, GET_PITCH, GET_YAW
16 | matrix= zeros(frames, 4);
17 | for i= 1:frames
18 | matrix(i, :)= values(i).orientation(index:index+3);
19 | end
20 | end
--------------------------------------------------------------------------------
/src/matlab/segment_position.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function[matrix]= segment_position(values, frames, index)
9 | % SEGMENT_POSITION Re-shuffles each x,y,z position in data to a segment
10 | %
11 | % matrix = SEGMENT_POSITION(values, frames, index) will associate 3
12 | % parts of position to the proper segment using XSens'
13 | % segmentation indices.
14 | %
15 | % See also READ_DATA
16 | matrix= zeros(frames, 3);
17 | for i= 1:frames
18 | matrix(i, :)= values(i).position(index:index + 2);
19 | end
20 | end
--------------------------------------------------------------------------------
/src/matlab/segment_reference.m:
--------------------------------------------------------------------------------
1 | % Copyright (c) 2020-present, Assistive Robotics Lab
2 | % All rights reserved.
3 | %
4 | % This source code is licensed under the license found in the
5 | % LICENSE file in the root directory of this source tree.
6 | %
7 |
8 | function [segmentOrientationMap, segmentPositionMap, jointMap]= segment_reference
9 | % SEGMENT_REFERENCE Places all of the segments and joints into maps for
10 | % accessing. Necessary for reading data from MVNX files.
11 | %
12 | %   [segor_map, segpos_map, joint_map] = segment_reference will return a map for
13 | % segment orientation, segment position, and joint angles. These maps can
14 | % then be used to access the proper index in the XSens data for
15 | % processing.
16 | %
17 | % See also READ_DATA, LOAD_PARTIAL_MVNX
18 |
19 | orientationIndices = {1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89};
20 | positionIndices = {1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67};
21 | jointIndices = {1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64};
22 |
23 | segmentKeys= {'Pelvis';'L5';'L3';'T12';'T8';'Neck';'Head';'RightShoulder';'RightUpperArm';'RightForeArm';'RightHand';'LeftShoulder';'LeftUpperArm';'LeftForeArm';'LeftHand';'RightUpperLeg';'RightLowerLeg';'RightFoot';'RightToe';'LeftUpperLeg';'LeftLowerLeg';'LeftFoot';'LeftToe'};
24 | jointKeys= {'jL5S1';'jL4L3';'jL1T12';'jT9T8';'jT1C7';'jC1Head';'jRightC7Shoulder';'jRightShoulder';'jRightElbow';'jRightWrist';'jLeftC7Shoulder';'jLeftShoulder';'jLeftElbow';'jLeftWrist';'jRightHip';'jRightKnee';'jRightAnkle';'jRightBallFoot';'jLeftHip';'jLeftKnee';'jLeftAnkle';'jLeftBallFoot'};
25 |
26 | segmentOrientationMap = containers.Map(segmentKeys, orientationIndices);
27 | segmentPositionMap = containers.Map(segmentKeys, positionIndices);
28 | jointMap = containers.Map(jointKeys, jointIndices);
29 |
--------------------------------------------------------------------------------
/src/seq2seq/seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Seq2Seq Encoders and Decoders.
10 |
11 | Reference:
12 | [1] https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
13 | """
14 | import torch
15 | from torch import nn
16 | import torch.nn.functional as F
17 |
18 |
19 | class EncoderRNN(nn.Module):
20 | """An encoder for seq2seq architectures."""
21 |
22 | def __init__(self, input_size, hidden_size,
23 | dropout=0.0, bidirectional=False):
24 | """Initialize encoder for use with decoder in seq2seq architecture.
25 |
26 | Args:
27 | input_size (int): Number of features in the input
28 | hidden_size (int): Number of hidden units in the GRU layer
29 | dropout (float, optional): Dropout applied after GRU layer.
30 | Defaults to 0.0.
31 | bidirectional (bool, optional): Whether encoder is bidirectional.
32 | Defaults to False.
33 | """
34 | super(EncoderRNN, self).__init__()
35 |
36 | self.hidden_size = hidden_size
37 | self.bidirectional = bidirectional
38 |
39 | self.gru = nn.GRU(input_size, hidden_size, bidirectional=bidirectional)
40 | self.dropout = nn.Dropout(dropout)
41 |
42 | self.directions = 2 if bidirectional else 1
43 | self.fc = nn.Linear(hidden_size * self.directions, hidden_size)
44 |
45 | def forward(self, input):
46 | """Forward pass through encoder.
47 |
48 | Args:
49 | input (torch.tensor): (seq_len,
50 | batch_size,
51 | input_size)
52 |
53 | Returns:
54 | tuple: Returns output and hidden state of decoder.
55 | output (torch.Tensor): (seq_len,
56 | batch_size,
57 | directions*hidden_size)
58 | hidden (torch.Tensor): (1, batch_size, hidden_size)
59 | """
60 | output, hidden = self.gru(input)
61 | output = self.dropout(output)
62 |
63 | if self.bidirectional:
64 | hidden = torch.tanh(
65 | self.fc(torch.cat((hidden[-2, :, :],
66 | hidden[-1, :, :]), dim=1))).unsqueeze(0)
67 |
68 | return output, hidden
69 |
70 |
71 | class DecoderRNN(nn.Module):
72 | """A decoder for use in seq2seq architectures."""
73 |
74 | def __init__(self, input_size, hidden_size, output_size, dropout=0.0):
75 | """Initialize DecoderRNN.
76 |
77 | Args:
78 | input_size (int): number of features in input
79 | hidden_size (int): number of hidden units in GRU layer
80 | output_size (int): number of features in output
81 | dropout (float, optional): Dropout applied after GRU layer.
82 | Defaults to 0.0.
83 | """
84 | super(DecoderRNN, self).__init__()
85 | self.hidden_size = hidden_size
86 |
87 | self.gru = nn.GRU(input_size, hidden_size)
88 | self.dropout = nn.Dropout(dropout)
89 | self.out = nn.Linear(hidden_size, output_size)
90 |
91 | def forward(self, input, hidden):
92 | """Forward pass through decoder.
93 |
94 | Args:
95 | input (torch.Tensor): input batch to pass through RNN
96 | (1, batch_size, input_size)
97 | hidden (torch.Tensor): hidden state of the decoder
98 | (1, batch_size, hidden_size)
99 |
100 | Returns:
101 | tuple: Returns output and hidden state of decoder.
102 | output (torch.Tensor): (1, batch_size, output_size)
103 | hidden (torch.Tensor): (1, batch_size, hidden_size)
104 | """
105 | output, hidden = self.gru(input, hidden)
106 | output = self.dropout(output)
107 | output = self.out(output)
108 | return output, hidden
109 |
110 |
111 | class AttnDecoderRNN(nn.Module):
112 | """A decoder with an attention layer for use in seq2seq architectures."""
113 |
114 | def __init__(self, output_size, feature_size, enc_hidden_size,
115 | dec_hidden_size, attention, bidirectional_encoder=False):
116 | """Initialize AttnDecoderRNN.
117 |
118 | Args:
119 | output_size (int): size of output
120 |             feature_size (int): number of features in input batch
121 | enc_hidden_size (int): hidden size of encoder
122 | dec_hidden_size (int): hidden size of decoder
123 | attention (str): attention method for use in Attention layer
124 | bidirectional_encoder (bool, optional): Whether encoder used with
125 | decoder is bidirectional. Defaults to False.
126 | """
127 | super().__init__()
128 |
129 | self.enc_hidden_size = enc_hidden_size
130 | self.dec_hidden_size = dec_hidden_size
131 | self.output_size = output_size
132 | self.feature_size = feature_size
133 | self.attention = attention
134 |
135 | self.directions = 2 if bidirectional_encoder else 1
136 |
137 | self.rnn = nn.GRU(self.directions * enc_hidden_size + feature_size,
138 | dec_hidden_size)
139 |
140 | self.out = nn.Linear(self.directions * enc_hidden_size +
141 | dec_hidden_size + feature_size,
142 | output_size)
143 |
144 | def forward(self, input, hidden, annotations):
145 | """Forward pass through decoder.
146 |
147 | Args:
148 | input (torch.Tensor): (1, batch_size, feature_size)
149 | hidden (torch.Tensor): (1, batch_size, dec_hidden_size)
150 | annotations (torch.Tensor): (seq_len,
151 | batch_size,
152 | directions * enc_hidden_size)
153 |
154 | Returns:
155 | tuple: Returns output and hidden state of decoder.
156 | output (torch.Tensor): (1, batch_size, output_size)
157 | hidden (torch.Tensor): (1, batch_size, dec_hidden_size)
158 | """
159 | attention = self.attention(hidden, annotations)
160 |
161 | attention = attention.unsqueeze(1)
162 |
163 | annotations = annotations.permute(1, 0, 2)
164 |
165 | context_vector = torch.bmm(attention, annotations)
166 |
167 | context_vector = context_vector.permute(1, 0, 2)
168 |
169 | rnn_input = torch.cat((input, context_vector), dim=2)
170 |
171 | output, hidden = self.rnn(rnn_input, hidden)
172 |
173 | # assert torch.all(torch.isclose(output, hidden))
174 |
175 | input = input.squeeze(0)
176 | output = output.squeeze(0)
177 | context_vector = context_vector.squeeze(0)
178 |
179 | output = self.out(torch.cat((output, context_vector, input), dim=1))
180 |
181 | # output = [batch_size, output_size]
182 |
183 | return output.unsqueeze(0), hidden
184 |
185 |
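
A hedged sketch (not part of the module) wiring EncoderRNN, the Attention module defined just below, and AttnDecoderRNN for a single decoding step; all sizes are illustrative, and the full training logic lives elsewhere in the repository:

    import torch

    seq_len, batch_size, feature_size, hidden_size = 5, 32, 92, 128

    encoder = EncoderRNN(feature_size, hidden_size, bidirectional=True)
    attention = Attention(hidden_size, batch_size, method="general",
                          bidirectional_encoder=True)
    decoder = AttnDecoderRNN(feature_size, feature_size, hidden_size,
                             hidden_size, attention, bidirectional_encoder=True)

    src = torch.randn(seq_len, batch_size, feature_size)
    annotations, hidden = encoder(src)        # (L, B, 2H), (1, B, H)

    decoder_input = src[-1].unsqueeze(0)      # seed with the last input frame
    output, hidden = decoder(decoder_input, hidden, annotations)
    print(output.shape)                       # torch.Size([1, 32, 92])
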
186 | class Attention(nn.Module):
187 | """An Attention layer for the AttnDecoder with multiple methods."""
188 |
189 | def __init__(self, hidden_size, batch_size,
190 | method, bidirectional_encoder=False):
191 | """Initialize Attention layer.
192 |
193 | Args:
194 | hidden_size (int): Size of hidden state in decoder.
195 | batch_size (int): Size of batch, used for shape checks.
196 | method (str): Attention technique/method to use. Supports
197 | general, biased-general, activated-general, dot, add, and
198 | concat.
199 | bidirectional_encoder (bool, optional): Whether encoder used with
200 | decoder is bidirectional. Defaults to False.
201 | """
202 | super().__init__()
203 |
204 | self.batch_size = batch_size
205 | self.hidden_size = hidden_size
206 | self.method = method
207 | self.directions = 2 if bidirectional_encoder else 1
208 |
209 | if method in ["general", "biased-general", "activated-general"]:
210 |             bias = (method != "general")
211 | self.Wa = nn.Linear(hidden_size,
212 | self.directions * hidden_size,
213 | bias=bias)
214 | elif method == "add":
215 | self.Wa = nn.Linear((self.directions * hidden_size),
216 | hidden_size,
217 | bias=False)
218 | self.Wb = nn.Linear(hidden_size,
219 | hidden_size,
220 | bias=False)
221 | self.va = nn.Parameter(torch.rand(hidden_size))
222 | elif method == "concat":
223 | self.Wa = nn.Linear((self.directions * hidden_size) + hidden_size,
224 | hidden_size, bias=False)
225 | self.va = nn.Parameter(torch.rand(hidden_size))
226 |
227 | def forward(self, hidden, annotations):
228 | """Forward pass through attention layer.
229 |
230 | Args:
231 | hidden (torch.Tensor): (1, batch_size, hidden_size)
232 | annotations (torch.Tensor): (seq_len,
233 | batch_size,
234 | directions * hidden_size)
235 |
236 | Returns:
237 | torch.Tensor: (batch_size, seq_len)
238 | """
239 | assert list(hidden.shape) == [1, self.batch_size, self.hidden_size]
240 |
241 | assert self.batch_size == annotations.shape[1]
242 | assert self.directions * self.hidden_size == annotations.shape[2]
243 | self.seq_len = annotations.shape[0]
244 |
245 | hidden = hidden.squeeze(0)
246 |
247 | assert list(hidden.shape) == [self.batch_size, self.hidden_size]
248 |
249 | annotations = annotations.permute(1, 0, 2)
250 |
251 | assert list(annotations.shape) == [self.batch_size,
252 | self.seq_len,
253 | self.directions * self.hidden_size]
254 |
255 | score = self._score(hidden, annotations)
256 |
257 | assert list(score.shape) == [self.batch_size, self.seq_len]
258 |
259 | return F.softmax(score, dim=1)
260 |
261 | def _score(self, hidden, annotations):
262 | """Compute an attention score with hidden state and annotations.
263 |
264 | Args:
265 | hidden (torch.Tensor): (batch_size, hidden_size)
266 | annotations (torch.Tensor): (batch_size,
267 | seq_len,
268 | directions * hidden_size)
269 |
270 | Returns:
271 | torch.Tensor: (batch_size, seq_len)
272 | """
273 | if "general" in self.method:
274 | x = self.Wa(hidden)
275 |
276 | x = x.unsqueeze(-1)
277 |
278 | score = annotations.bmm(x)
279 |
280 | if self.method == "activated-general":
281 | score = torch.tanh(score)
282 |
283 | assert list(score.shape) == [self.batch_size,
284 | self.seq_len,
285 | 1]
286 |
287 | score = score.squeeze(-1)
288 |
289 | return score
290 |
291 | elif self.method == "dot":
292 | hidden = hidden.unsqueeze(-1)
293 |
294 | hidden = hidden.repeat(1, self.directions, 1)
295 |
296 | score = annotations.bmm(hidden)
297 |
298 | assert list(score.shape) == [self.batch_size,
299 | self.seq_len,
300 | 1]
301 |
302 | score = score.squeeze(-1)
303 |
304 | return score
305 |
306 | elif self.method == "add":
307 | x1 = self.Wa(annotations)
308 |
309 | x2 = self.Wb(hidden)
310 |
311 | x2 = x2.unsqueeze(1)
312 |
313 | x2 = x2.repeat(1, self.seq_len, 1)
314 |
315 | energy = x1 + x2
316 |
317 | energy = energy.permute(0, 2, 1)
318 |
319 | assert list(energy.shape) == [self.batch_size,
320 | self.hidden_size,
321 | self.seq_len]
322 |
323 | va = self.va.repeat(self.batch_size, 1).unsqueeze(1)
324 |
325 | score = torch.bmm(va, energy)
326 |
327 | assert list(score.shape) == [self.batch_size,
328 | 1,
329 | self.seq_len]
330 |
331 | score = score.squeeze(1)
332 |
333 | return score
334 |
335 | elif self.method == "concat":
336 | hidden = hidden.unsqueeze(1)
337 |
338 | hidden = hidden.repeat(1, self.seq_len, 1)
339 |
340 | energy = torch.tanh(self.Wa(torch.cat((hidden, annotations), 2)))
341 |
342 | energy = energy.permute(0, 2, 1)
343 |
344 | assert list(energy.shape) == [self.batch_size,
345 | self.hidden_size,
346 | self.seq_len]
347 |
348 | va = self.va.repeat(self.batch_size, 1).unsqueeze(1)
349 |
350 | assert list(va.shape) == [self.batch_size,
351 | 1,
352 | self.hidden_size]
353 |
354 | score = torch.bmm(va, energy)
355 |
356 | assert list(score.shape) == [self.batch_size,
357 | 1,
358 | self.seq_len]
359 |
360 | score = score.squeeze(1)
361 |
362 | return score
363 |
--------------------------------------------------------------------------------
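
To make the tensor bookkeeping above concrete, here is a minimal shape-check sketch (not part of the repository) that wires an Attention layer into an AttnDecoderRNN with toy tensors. The sizes are arbitrary, the encoder is assumed unidirectional with "general" attention, and the snippet assumes it is run from src/ so the seq2seq package is importable.

```python
import torch

from seq2seq.seq2seq import Attention, AttnDecoderRNN

batch_size, seq_len, hidden_size, feature_size = 4, 10, 64, 28

# Attention takes the batch size up front because of its internal shape checks.
attn = Attention(hidden_size, batch_size, method="general")
decoder = AttnDecoderRNN(output_size=feature_size,
                         feature_size=feature_size,
                         enc_hidden_size=hidden_size,
                         dec_hidden_size=hidden_size,
                         attention=attn)

decoder_input = torch.ones(1, batch_size, feature_size)      # previous output frame
decoder_hidden = torch.zeros(1, batch_size, hidden_size)     # decoder GRU state
annotations = torch.randn(seq_len, batch_size, hidden_size)  # encoder outputs

output, decoder_hidden = decoder(decoder_input, decoder_hidden, annotations)
print(output.shape)          # torch.Size([1, 4, 28])
print(decoder_hidden.shape)  # torch.Size([1, 4, 64])
```
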
/src/seq2seq/training_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | import random
11 | import time
12 | import math
13 | import sys
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from common.logging import logger
17 | from .seq2seq import (
18 | EncoderRNN,
19 | DecoderRNN,
20 | AttnDecoderRNN,
21 | Attention
22 | )
23 | plt.switch_backend("agg")
24 | torch.manual_seed(0)
25 |
26 |
27 | class Timer:
28 | def __enter__(self):
29 | self.start = time.time()
30 | return self
31 |
32 | def __exit__(self, *args):
33 | self.end = time.time()
34 | self.interval = self.end - self.start
35 |
36 |
37 | def get_encoder(feature_size, device, hidden_size=64,
38 | dropout=0.0, bidirectional=False):
39 | """Function to help set up an encoder.
40 |
41 | Args:
42 | feature_size (int): number of features in the input to this model
43 | device (torch.device): what device to put this encoder on
44 | hidden_size (int, optional): number of hidden units in the encoder.
45 | Defaults to 64.
46 | dropout (float, optional): dropout to apply in the encoder.
47 | Defaults to 0.0.
48 | bidirectional (bool, optional): whether to use a bidirectional encoder.
49 | Defaults to False.
50 |
51 | Returns:
52 | EncoderRNN: an encoder for use in seq2seq tasks
53 | """
54 | encoder = EncoderRNN(feature_size, hidden_size,
55 | dropout=dropout, bidirectional=bidirectional)
56 | encoder = encoder.double().to(device)
57 | return encoder
58 |
59 |
60 | def get_decoder(feature_size, device, hidden_size=64, dropout=0.0):
61 | """Function to help set up a decoder.
62 |
63 | Args:
64 | feature_size (int): number of features in the input to this model
65 |         device (torch.device): what device to put this decoder on
66 |         hidden_size (int, optional): number of hidden units in the decoder.
67 |             Defaults to 64.
68 |         dropout (float, optional): dropout to apply in the decoder.
69 | Defaults to 0.0.
70 |
71 | Returns:
72 | DecoderRNN: a decoder for use in seq2seq tasks
73 | """
74 | decoder = DecoderRNN(feature_size, hidden_size,
75 | feature_size, dropout=dropout)
76 | decoder = decoder.double().to(device)
77 | return decoder
78 |
79 |
80 | def get_attn_decoder(feature_size, method, device, batch_size=32,
81 | hidden_size=64, bidirectional_encoder=False):
82 | """Function to help set up a decoder with attention.
83 |
84 | Args:
85 | feature_size (int): number of features in the input to this model
86 |         method (str): attention method for the Attention layer (general,
87 |             biased-general, activated-general, dot, add, or concat)
88 |         device (torch.device): what device to put this decoder on
89 |         batch_size (int, optional): batch size, used by the Attention
90 |             layer for shape checks. Defaults to 32.
91 |         hidden_size (int, optional): number of hidden units in the encoder
92 |             and decoder. Defaults to 64.
93 | bidirectional_encoder (bool, optional): whether the encoder is
94 | bidirectional. Defaults to False.
95 |
96 | Returns:
97 | AttnDecoderRNN: a decoder with attention for use in seq2seq tasks
98 | """
99 | attn = Attention(hidden_size, batch_size, method,
100 | bidirectional_encoder=bidirectional_encoder)
101 | decoder = AttnDecoderRNN(feature_size, feature_size,
102 | hidden_size, hidden_size,
103 | attn,
104 | bidirectional_encoder=bidirectional_encoder)
105 | decoder = decoder.double().to(device)
106 | return decoder
107 |
108 |
109 | def loss_batch(data, models, opts, criterion, device,
110 | teacher_forcing_ratio=0.0, use_attention=False,
111 | norm_quaternions=False, average_batch=True):
112 | """Train or validate encoder and decoder models on a single batch of data.
113 |
114 | Args:
115 | data (tuple): tuple containing input and target batches.
116 | models (tuple): tuple containing encoder and decoder for use during
117 | training.
118 | opts (tuple): tuple containing encoder and decoder optimizers
119 | criterion (nn.Module): criterion to use for training or validation
120 | device (torch.device): device to put batches on
121 | teacher_forcing_ratio (float, optional): percent of the time to use
122 | teacher forcing in the decoder. Defaults to 0.0.
123 | use_attention (bool, optional): whether the decoder uses attention.
124 | Defaults to False.
125 | norm_quaternions (bool, optional): whether the quaternion outputs
126 | should be normalized. Defaults to False.
127 | average_batch (bool, optional): For use during training; can be set to
128 | false for plotting histograms during testing. Defaults to True.
129 |
130 | Returns:
131 | float or list: will return a single loss or list of losses depending
132 | on the average_batch flag.
133 | """
134 | training = (opts is not None)
135 |
136 | encoder, decoder = models
137 | input_batch, target_batch = data
138 |
139 | input_batch = input_batch.to(device).double()
140 | target_batch = target_batch.to(device).double()
141 |
142 | if training:
143 | encoder.train(), decoder.train()
144 | encoder_opt, decoder_opt = opts
145 | else:
146 | encoder.eval(), decoder.eval()
147 |
148 | loss = 0
149 | seq_length = target_batch.shape[1]
150 |
151 | input = input_batch.permute(1, 0, 2)
152 | encoder_outputs, encoder_hidden = encoder(input)
153 |
154 | decoder_hidden = encoder_hidden
155 | decoder_input = torch.ones_like(target_batch[:, 0, :]).unsqueeze(0)
156 | EOS = torch.zeros_like(target_batch[:, 0, :]).unsqueeze(0)
157 | outputs = torch.zeros_like(target_batch)
158 |
159 | use_teacher_forcing = (True if random.random() < teacher_forcing_ratio
160 | else False)
161 |
162 | if not average_batch:
163 | if training:
164 | logger.warning("average_batch must be true when training")
165 | sys.exit()
166 | losses = [0 for i in range(target_batch.shape[0])]
167 |
168 | for t in range(seq_length):
169 |
170 | if use_attention:
171 | decoder_output, decoder_hidden = decoder(
172 | decoder_input, decoder_hidden, encoder_outputs)
173 | else:
174 | decoder_output, decoder_hidden = decoder(
175 | decoder_input, decoder_hidden)
176 |
177 | target = target_batch[:, t, :].unsqueeze(0)
178 |
179 | output = decoder_output
180 |
181 | if norm_quaternions:
182 | original_shape = output.shape
183 |
184 | output = output.contiguous().view(-1, 4)
185 | output = F.normalize(output, p=2, dim=1).view(original_shape)
186 |
187 | outputs[:, t, :] = output
188 |
189 | if use_teacher_forcing:
190 | decoder_input = target
191 | else:
192 | if torch.all(torch.eq(decoder_output, EOS)):
193 | break
194 | decoder_input = output.detach()
195 |
196 | loss = criterion(outputs, target_batch)
197 |
198 | if training:
199 | loss.backward()
200 |
201 | encoder_opt.step()
202 | encoder_opt.zero_grad()
203 |
204 | decoder_opt.step()
205 | decoder_opt.zero_grad()
206 |
207 | if average_batch:
208 | return loss.item()
209 | else:
210 | losses = []
211 | for b in range(outputs.shape[0]):
212 | sample_loss = criterion(outputs[b, :], target_batch[b, :])
213 | losses.append(sample_loss.item())
214 |
215 | return losses
216 |
217 |
218 | def fit(models, optims, epochs, dataloaders, training_criterion,
219 | validation_criteria, schedulers, device, model_file_path,
220 | teacher_forcing_ratio=0.0, use_attention=False,
221 | norm_quaternions=False, schedule_rate=1.0):
222 | """Fit a seq2seq model to data, logging training and validation loss.
223 |
224 | Args:
225 | models (tuple): tuple containing the encoder and decoder
226 | optims (tuple): tuple containing the encoder and decoder optimizers for
227 | training
228 | epochs (int): number of epochs to train for
229 | dataloaders (tuple): tuple containing the training dataloader and val
230 | dataloader.
231 | training_criterion (nn.Module): criterion for backpropagation during
232 | training
233 | validation_criteria (list): list of criteria for validation
234 | schedulers (list): list of schedulers to control learning rate for
235 | optimizers
236 | device (torch.device): device to place data on
237 | model_file_path (str): where to save the model when validation loss
238 | reaches new minimum
239 | teacher_forcing_ratio (float, optional): percent of the time to use
240 | teacher forcing in the decoder. Defaults to 0.0.
241 | use_attention (bool, optional): whether decoder uses attention or not.
242 | Defaults to False.
243 | norm_quaternions (bool, optional): whether quaternions should be
244 | normalized after they are output from the decoder.
245 | Defaults to False.
246 | schedule_rate (float, optional): rate to increase or decrease teacher
247 | forcing ratio. Defaults to 1.0.
248 | """
249 |
250 | train_dataloader, val_dataloader = dataloaders
251 |
252 | min_val_loss = math.inf
253 | for epoch in range(epochs):
254 | losses = []
255 | total_time = 0
256 |
257 | logger.info(f"Epoch {epoch+1} / {epochs}")
258 |
259 | for index, data in enumerate(train_dataloader, 0):
260 | with Timer() as timer:
261 | loss = loss_batch(data, models,
262 | optims, training_criterion, device,
263 | use_attention=use_attention,
264 | norm_quaternions=norm_quaternions)
265 |
266 | losses.append(loss)
267 | total_time += timer.interval
268 | if index % (len(train_dataloader) // 10) == 0:
269 | logger.info((f"Total time elapsed: {total_time} - "
270 | f"Batch Number: {index} / {len(train_dataloader)}"
271 | f" - Training loss: {loss}"
272 | ))
273 | val_loss = []
274 | for validation_criterion in validation_criteria:
275 | with torch.no_grad():
276 | val_losses = [loss_batch(data, models,
277 | None, validation_criterion, device,
278 | use_attention=use_attention,
279 | norm_quaternions=norm_quaternions)
280 | for _, data in enumerate(val_dataloader, 0)]
281 |
282 | val_loss.append(np.sum(val_losses) / len(val_losses))
283 |
284 | loss = np.sum(losses) / len(losses)
285 |
286 | for scheduler in schedulers:
287 | scheduler.step()
288 |
289 | val_loss_strs = ", ".join(map(str, val_loss))
290 | logger.info(f"Training Loss: {loss} - Val Loss: {val_loss_strs}")
291 |
292 | teacher_forcing_ratio *= schedule_rate
293 | if val_loss[0] < min_val_loss:
294 | min_val_loss = val_loss[0]
295 | logger.info(f"Saving model to {model_file_path}")
296 | torch.save({
297 | "encoder_state_dict": models[0].state_dict(),
298 | "decoder_state_dict": models[1].state_dict(),
299 | "optimizerA_state_dict": optims[0].state_dict(),
300 | "optimizerB_state_dict": optims[1].state_dict(),
301 | }, model_file_path)
302 |
--------------------------------------------------------------------------------
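
The train and test scripts elsewhere in src/ show the full pipeline; as a smaller, self-contained illustration of the loss_batch calling convention (a sketch with made-up shapes and feature sizes, not repository code), a single validation-style pass looks roughly like this, with None passed in place of the optimizers so both models run in eval mode:

```python
import torch
from torch import nn

from seq2seq.training_utils import get_attn_decoder, get_encoder, loss_batch

device = torch.device("cpu")
batch_size, in_len, out_len = 8, 10, 10
in_features, out_features = 44, 28   # illustrative sizes only

encoder = get_encoder(in_features, device, hidden_size=64)
decoder = get_attn_decoder(out_features, "general", device,
                           batch_size=batch_size, hidden_size=64)

# loss_batch moves the batch to the device and casts it to double internally.
inputs = torch.randn(batch_size, in_len, in_features)
targets = torch.randn(batch_size, out_len, out_features)

with torch.no_grad():
    val_loss = loss_batch((inputs, targets), (encoder, decoder),
                          None, nn.L1Loss(), device,
                          use_attention=True)
print(val_loss)  # scalar L1 loss for the batch
```
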
/src/test-seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from seq2seq.training_utils import (
9 | loss_batch,
10 | get_encoder,
11 | get_decoder,
12 | get_attn_decoder
13 | )
14 | from common.losses import QuatDistance
15 | from common.data_utils import load_dataloader
16 | from common.logging import logger
17 | import torch
18 | import numpy as np
19 | import os
20 | import argparse
21 | import h5py
22 | import matplotlib.pyplot as plt
23 | import matplotlib.font_manager
24 | matplotlib.use("Agg")
25 |
26 | torch.manual_seed(42)
27 | np.random.seed(42)
28 |
29 | torch.backends.cudnn.deterministic = False
30 | torch.backends.cudnn.benchmark = False
31 |
32 |
33 | def parse_args():
34 | """Parse arguments for module.
35 |
36 | Returns:
37 | argparse.Namespace: contains accessible arguments passed in to module
38 | """
39 | parser = argparse.ArgumentParser()
40 |
41 | parser.add_argument("--task",
42 | help=("task for neural network to train on; "
43 | "either prediction or conversion"))
44 | parser.add_argument("--data-path-parent",
45 | help=("parent folder of directories with h5 files "
46 | "(folders must contain normalization.h5 "
47 | "and testing.h5)"))
48 | parser.add_argument("--figure-file-path",
49 | help="path to where the figure should be saved")
50 | parser.add_argument("--figure-title",
51 | help="title of the histogram plot")
52 | parser.add_argument("--include-legend",
53 |                         help="will include a legend in the figure",
54 | default=False,
55 | action="store_true")
56 | parser.add_argument("--model-dir",
57 | help="path to model file directory")
58 | parser.add_argument("--representation",
59 | help="orientation representation",
60 | default="quaternions")
61 | parser.add_argument("--batch-size",
62 |                         help="batch size for testing", default=32)
63 | parser.add_argument("--seq-length",
64 | help=("sequence length for model, will be "
65 | "downsampled if downsample is provided"),
66 | default=20)
67 | parser.add_argument("--downsample",
68 | help=("reduce sampling frequency of recorded data; "
69 | "default sampling frequency is 240 Hz"),
70 | default=1)
71 | parser.add_argument("--in-out-ratio",
72 | help=("ratio of input/output; "
73 | "seq_length / downsample = input length = 10, "
74 | "output length = input length / in_out_ratio"),
75 | default=1)
76 | parser.add_argument("--stride",
77 | help=("stride used when reading data "
78 | "in for running prediction tasks"),
79 | default=3)
80 | parser.add_argument("--hidden-size",
81 | help="hidden size in both the encoder and decoder")
82 | parser.add_argument("--dropout",
83 | help="dropout percentage in encoder and decoder",
84 | default=0.0)
85 | parser.add_argument("--bidirectional",
86 | help="will use bidirectional encoder",
87 | default=False,
88 | action="store_true")
89 | parser.add_argument("--attention",
90 | help="use decoder with specified attention",
91 | default="general")
92 |
93 | args = parser.parse_args()
94 |
95 | return args
96 |
97 |
98 | if __name__ == "__main__":
99 | args = parse_args()
100 |
101 | for arg in vars(args):
102 | logger.info("{} - {}".format(arg, getattr(args, arg)))
103 |
104 | logger.info("Starting seq2seq model testing...")
105 |
106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107 |
108 | seq_length = int(args.seq_length)
109 | stride = int(args.stride)
110 | batch_size = int(args.batch_size)
111 |
112 | data_paths = [args.data_path_parent + "/" + name
113 | for name in os.listdir(args.data_path_parent)
114 | if os.path.isdir(args.data_path_parent + "/" + name)]
115 |
116 | model_paths = [args.model_dir + "/" + name
117 | for name in os.listdir(args.model_dir)]
118 |
119 | data_paths.sort()
120 | model_paths.sort()
121 |
122 | plt.style.use("fivethirtyeight")
123 | plt.rcParams["font.family"] = "Times New Roman"
124 | plt.rcParams["axes.titlesize"] = 24
125 | plt.rcParams["axes.labelsize"] = 23
126 | plt.rcParams["xtick.labelsize"] = 22
127 | plt.rcParams["ytick.labelsize"] = 22
128 | plt.rcParams["figure.facecolor"] = "white"
129 | plt.rcParams["axes.facecolor"] = "white"
130 | plt.rcParams["savefig.edgecolor"] = "white"
131 | plt.rcParams["savefig.facecolor"] = "white"
132 |
133 | fig = plt.figure()
134 | ax = fig.add_subplot(111)
135 |
136 | for i, data_path in enumerate(data_paths):
137 | args.data_path = data_path
138 |
139 | with h5py.File(args.data_path + "/normalization.h5", "r") as f:
140 | mean, std_dev = torch.Tensor(f["mean"]), torch.Tensor(f["std_dev"])
141 | norm_data = (mean, std_dev)
142 | test_dataloader, _ = load_dataloader(args,
143 | "testing",
144 | True,
145 | norm_data=norm_data)
146 |
147 | encoder_feature_size = test_dataloader.dataset[0][0].shape[1]
148 | decoder_feature_size = test_dataloader.dataset[0][1].shape[1]
149 |
150 | bidirectional = args.bidirectional
151 | encoder = get_encoder(encoder_feature_size,
152 | device,
153 | hidden_size=int(args.hidden_size),
154 | dropout=float(args.dropout),
155 | bidirectional=bidirectional)
156 |
157 | use_attention = False
158 | attention_options = ["add", "dot", "concat",
159 | "general", "activated-general", "biased-general"]
160 |
161 | if args.attention in attention_options:
162 | decoder = get_attn_decoder(decoder_feature_size,
163 | args.attention,
164 | device,
165 | hidden_size=int(args.hidden_size),
166 | bidirectional_encoder=bidirectional)
167 | use_attention = True
168 | else:
169 | decoder = get_decoder(decoder_feature_size,
170 | device,
171 | dropout=float(args.dropout),
172 | hidden_size=int(args.hidden_size))
173 |
174 | checkpoint = torch.load(model_paths[i], map_location=device)
175 |
176 | encoder.load_state_dict(checkpoint["encoder_state_dict"])
177 | decoder.load_state_dict(checkpoint["decoder_state_dict"])
178 |
179 | decoder.batch_size = batch_size
180 | if use_attention:
181 | decoder.attention.batch_size = batch_size
182 |
183 | models = (encoder.double(), decoder.double())
184 | criterion = QuatDistance()
185 | norm_quaternions = (args.representation == "quaternions")
186 |
187 | with torch.no_grad():
188 | inference_losses = [loss_batch(data, models,
189 | None, criterion, device,
190 | use_attention=use_attention,
191 | norm_quaternions=norm_quaternions,
192 | average_batch=False)
193 | for _, data in enumerate(test_dataloader, 0)]
194 |
195 | def flatten(l): return [item for sublist in l for item in sublist]
196 |
197 | inference_losses = flatten(inference_losses)
198 |
199 | ax.hist(inference_losses, bins=60, density=True,
200 | histtype=u"step", linewidth=2)
201 | ax.set_xlim(0, 40)
202 | ax.set_xticks(range(0, 45, 5))
203 | ax.set_xticklabels(range(0, 45, 5))
204 |
205 | ax.set_ylim(0, 0.20)
206 | ax.set_yticks(np.arange(0, 0.20, 0.05))
207 | ax.set_yticklabels(np.arange(0, 0.20, 0.05).round(decimals=2))
208 |
209 | inference_loss = np.sum(inference_losses) / len(inference_losses)
210 | group = data_path.split("/")[-1]
211 | logger.info(f"Inference Loss for {group}: {inference_loss}")
212 |
213 | ax.set_title(args.figure_title)
214 | ax.set_xlabel("Sequence Angular Error in Degrees")
215 | ax.set_ylabel("Percentage")
216 | if args.include_legend:
217 | ax.legend(["Config. 1", "Config. 2", "Config. 3", "Config. 4"])
218 | figname = args.figure_file_path
219 | fig.savefig(figname, bbox_inches="tight")
220 |
221 | logger.info("Completed Testing...")
222 | logger.info("\n")
223 |
--------------------------------------------------------------------------------
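
For reference, the per-sequence errors behind the histogram come from the average_batch=False path of loss_batch, which returns one QuatDistance value per sequence in each batch; the nested lists are then flattened and averaged. A toy sketch of that aggregation (illustrative numbers only, not repository code):

```python
import numpy as np

# Each loss_batch(..., average_batch=False) call yields one value per sequence.
batch_losses = [[3.1, 2.7], [4.0, 2.2]]   # two hypothetical batches of size 2

per_sample = [x for batch in batch_losses for x in batch]
print(per_sample)            # [3.1, 2.7, 4.0, 2.2]  -> histogram input
print(np.mean(per_sample))   # 3.0  -> the reported "Inference Loss"
```
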
/src/test-transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from transformers.training_utils import (
9 | inference,
10 | loss_batch
11 | )
12 | from transformers.transformers import (
13 | InferenceTransformer,
14 | InferenceTransformerEncoder
15 | )
16 | from common.data_utils import load_dataloader
17 | from common.logging import logger
18 | from common.losses import QuatDistance
19 | import torch
20 | from torch import nn
21 | import numpy as np
22 | import os
23 | import argparse
24 | import h5py
25 | import matplotlib
26 | import matplotlib.pyplot as plt
27 | import matplotlib.font_manager
28 | matplotlib.use("Agg")
29 |
30 |
31 | torch.manual_seed(42)
32 | np.random.seed(42)
33 |
34 | torch.backends.cudnn.deterministic = False
35 | torch.backends.cudnn.benchmark = False
36 |
37 |
38 | def parse_args():
39 | """Parse arguments for module.
40 |
41 | Returns:
42 | argparse.Namespace: contains accessible arguments passed in to module
43 | """
44 | parser = argparse.ArgumentParser()
45 |
46 | parser.add_argument("--task",
47 | help=("task for neural network to train on; "
48 | "either prediction or conversion"))
49 | parser.add_argument("--data-path-parent",
50 | help=("parent folder of directories with h5 files "
51 | "(folders must contain normalization.h5 "
52 | "and testing.h5)"))
53 | parser.add_argument("--figure-file-path",
54 | help="path to where the figure should be saved")
55 | parser.add_argument("--figure-title",
56 | help="title of the histogram plot")
57 | parser.add_argument("--include-legend",
58 |                         help="will include a legend in the figure",
59 | default=False,
60 | action="store_true")
61 | parser.add_argument("--model-dir",
62 | help="path to model file directory")
63 | parser.add_argument("--representation",
64 | help="orientation representation",
65 | default="quaternions")
66 | parser.add_argument("--full-transformer",
67 | help=("will use full Transformer if true, "
68 | "will only use encoder if false"),
69 | default=False,
70 | action="store_true")
71 | parser.add_argument("--batch-size",
72 |                         help="batch size for testing", default=32)
73 | parser.add_argument("--seq-length",
74 | help=("sequence length for model, will be "
75 | "downsampled if downsample is provided"),
76 | default=20)
77 | parser.add_argument("--downsample",
78 | help=("reduce sampling frequency of recorded data; "
79 | "default sampling frequency is 240 Hz"),
80 | default=1)
81 | parser.add_argument("--in-out-ratio",
82 | help=("ratio of input/output; "
83 | "seq_length / downsample = input length = 10, "
84 | "output length = input length / in_out_ratio"),
85 | default=1)
86 | parser.add_argument("--stride",
87 | help=("stride used when reading data "
88 | "in for running prediction tasks"),
89 | default=3)
90 | parser.add_argument("--num-heads",
91 | help="number of heads in Transformer Encoder")
92 | parser.add_argument("--dim-feedforward",
93 | help="number of dimensions in feedforward "
94 | "layer in Transformer Encoder")
95 | parser.add_argument("--dropout",
96 | help="dropout percentage in Transformer Encoder")
97 | parser.add_argument("--num-layers",
98 | help="number of layers in Transformer Encoder")
99 |
100 | args = parser.parse_args()
101 |
102 | if args.data_path_parent is None:
103 | parser.print_help()
104 |
105 | return args
106 |
107 |
108 | if __name__ == "__main__":
109 | args = parse_args()
110 |
111 | for arg in vars(args):
112 | logger.info(f"{arg} - {getattr(args, arg)}")
113 |
114 | logger.info("Starting Transformer testing...")
115 |
116 | logger.info(f"Device count: {str(torch.cuda.device_count())}")
117 |
118 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119 | logger.info("Testing on {}...".format(device))
120 | seq_length = int(args.seq_length)//int(args.downsample)
121 |
122 | data_paths = [args.data_path_parent + "/" + name
123 | for name in os.listdir(args.data_path_parent)
124 | if os.path.isdir(args.data_path_parent + "/" + name)]
125 |
126 | model_paths = [args.model_dir + "/" + name
127 | for name in os.listdir(args.model_dir)]
128 |
129 | data_paths.sort()
130 | model_paths.sort()
131 |
132 | plt.style.use("fivethirtyeight")
133 | plt.rcParams["font.family"] = "Times New Roman"
134 | plt.rcParams["axes.titlesize"] = 24
135 | plt.rcParams["axes.labelsize"] = 22
136 | plt.rcParams["xtick.labelsize"] = 18
137 | plt.rcParams["ytick.labelsize"] = 18
138 | plt.rcParams["figure.facecolor"] = "white"
139 | plt.rcParams["axes.facecolor"] = "white"
140 | plt.rcParams["savefig.edgecolor"] = "white"
141 | plt.rcParams["savefig.facecolor"] = "white"
142 |
143 | fig = plt.figure()
144 | ax = fig.add_subplot(111)
145 |
146 | for i, data_path in enumerate(data_paths):
147 | args.data_path = data_path
148 | with h5py.File(args.data_path + "/normalization.h5", "r") as f:
149 | mean, std_dev = torch.Tensor(f["mean"]), torch.Tensor(f["std_dev"])
150 | norm_data = (mean, std_dev)
151 |
152 | test_dataloader, _ = load_dataloader(args,
153 | "testing",
154 | True,
155 | norm_data=norm_data)
156 |
157 | encoder_feature_size = test_dataloader.dataset[0][0].shape[1]
158 | decoder_feature_size = test_dataloader.dataset[0][1].shape[1]
159 |
160 | num_heads = (int(args.num_heads) if args.full_transformer
161 | else encoder_feature_size)
162 | dim_feedforward = int(args.dim_feedforward)
163 | dropout = float(args.dropout)
164 | num_layers = int(args.num_layers)
165 | quaternions = (args.representation == "quaternions")
166 |
167 | if args.full_transformer:
168 | model = InferenceTransformer(decoder_feature_size, num_heads,
169 | dim_feedforward, dropout,
170 | num_layers, quaternions=quaternions)
171 | else:
172 | num_heads = encoder_feature_size
173 | model = InferenceTransformerEncoder(encoder_feature_size,
174 | num_heads, dim_feedforward,
175 | dropout, num_layers,
176 | decoder_feature_size,
177 | quaternions=quaternions)
178 |
179 | checkpoint = torch.load(model_paths[i], map_location=device)
180 |
181 | model.load_state_dict(checkpoint["model_state_dict"])
182 |
183 | if torch.cuda.device_count() > 1:
184 | model = nn.DataParallel(model)
185 |
186 | model = model.to(device).double()
187 |
188 | criterion = QuatDistance()
189 |
190 | with torch.no_grad():
191 | if args.full_transformer:
192 | inference_losses = [inference(model, data,
193 | criterion, device,
194 | average_batch=False)
195 | for _, data in enumerate(test_dataloader,
196 | 0)]
197 | else:
198 | inference_losses = [loss_batch(model, None,
199 | data, criterion, device,
200 | full_transformer=False,
201 | average_batch=False)
202 | for _, data in enumerate(test_dataloader,
203 | 0)]
204 |
205 | def flatten(l): return [item for sublist in l for item in sublist]
206 |
207 | inference_losses = flatten(inference_losses)
208 |
209 | inference_loss = np.sum(inference_losses) / len(inference_losses)
210 | logger.info("Inference Loss: {}".format(inference_loss))
211 |
212 | ax.hist(inference_losses, bins=60,
213 | density=True, histtype=u"step",
214 | linewidth=2)
215 | ax.set_xlim(0, 40)
216 | ax.set_xticks(range(0, 45, 5))
217 | ax.set_xticklabels(range(0, 45, 5))
218 |
219 | ax.set_ylim(0, 0.20)
220 | ax.set_yticks(np.arange(0, 0.20, 0.05))
221 | ax.set_yticklabels(np.arange(0, 0.20, 0.05).round(decimals=2))
222 |
223 | ax.set_title(args.figure_title)
224 | ax.set_xlabel("Sequence Angular Error in Degrees")
225 | ax.set_ylabel("Percentage")
226 | if args.include_legend:
227 | ax.legend(["Config. 1", "Config. 2", "Config. 3", "Config. 4"])
228 | figname = args.figure_file_path
229 | fig.savefig(figname, bbox_inches="tight")
230 |
231 | logger.info("Completed testing...")
232 | logger.info("\n")
233 |
--------------------------------------------------------------------------------
/src/test_seq2seq.sh:
--------------------------------------------------------------------------------
1 | python test-seq2seq.py --task conversion \
2 | --data-path-parent "/home/jackg7/VT-Natural-Motion-Processing/data" \
3 | --figure-file-path "/home/jackg7/VT-Natural-Motion-Processing/images/seq2seq-test.pdf" \
4 | --figure-title "Seq2Seq" \
5 | --model-dir "/home/jackg7/VT-Natural-Motion-Processing/models/set-2" \
6 | --representation quaternions \
7 | --batch-size=512 \
8 | --seq-length=30 \
9 | --downsample=6 \
10 | --in-out-ratio=5 \
11 | --stride=30 \
12 | --hidden-size=512 \
13 | --attention=dot \
14 | --bidirectional
15 |
--------------------------------------------------------------------------------
/src/test_transformer.sh:
--------------------------------------------------------------------------------
1 | python test-transformer.py --task conversion \
2 | --data-path-parent "/home/jackg7/VT-Natural-Motion-Processing/data" \
3 | --figure-file-path "/home/jackg7/VT-Natural-Motion-Processing/images/transformer-test.pdf" \
4 |                            --figure-title "Transformer" \
5 | --model-dir "/home/jackg7/VT-Natural-Motion-Processing/models/set-2" \
6 | --full-transformer \
7 | --representation quaternions \
8 | --batch-size=512 \
9 | --seq-length=30 \
10 | --downsample=6 \
11 | --in-out-ratio=1 \
12 | --stride=30 \
13 | --num-heads=4 \
14 | --dim-feedforward=512 \
15 | --dropout=0.0 \
16 | --num-layers=2
17 |
--------------------------------------------------------------------------------
/src/train-seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from seq2seq.training_utils import (
9 | fit,
10 | get_encoder,
11 | get_decoder,
12 | get_attn_decoder
13 | )
14 | from common.losses import QuatDistance
15 | from common.data_utils import load_dataloader
16 | from common.logging import logger
17 | import torch
18 | from torch import nn, optim
19 | import numpy as np
20 | import argparse
21 |
22 | torch.manual_seed(42)
23 | np.random.seed(42)
24 |
25 | torch.backends.cudnn.deterministic = False
26 | torch.backends.cudnn.benchmark = False
27 |
28 |
29 | def parse_args():
30 | """Parse arguments for module.
31 |
32 | Returns:
33 | argparse.Namespace: contains accessible arguments passed in to module
34 | """
35 | parser = argparse.ArgumentParser()
36 |
37 | parser.add_argument("--task",
38 | help=("task for neural network to train on; "
39 | "either prediction or conversion"))
40 | parser.add_argument("--data-path",
41 | help=("path to h5 files containing data "
42 | "(must contain training.h5 and validation.h5)"))
43 | parser.add_argument("--model-file-path",
44 | help="path to model file for saving it after training")
45 | parser.add_argument("--representation",
46 |                         help="data representation", default="quaternions")
47 | parser.add_argument("--auxiliary-acc",
48 | help="will train on auxiliary acceleration if true",
49 | default=False,
50 | action="store_true")
51 | parser.add_argument("--batch-size",
52 | help="batch size for training", default=32)
53 | parser.add_argument("--learning-rate",
54 | help="learning rate for encoder and decoder",
55 | default=0.001)
56 | parser.add_argument("--seq-length",
57 | help="sequence length for encoder/decoder",
58 | default=20)
59 | parser.add_argument("--downsample",
60 | help="reduce sampling frequency of recorded data; "
61 | "default sampling frequency is 240 Hz",
62 | default=1)
63 | parser.add_argument("--in-out-ratio",
64 | help=("ratio of input/output; "
65 | "seq_length / downsample = input length = 10, "
66 | "output length = input length / in_out_ratio"),
67 | default=1)
68 | parser.add_argument("--stride",
69 | help="stride used when running prediction tasks",
70 | default=3)
71 | parser.add_argument("--num-epochs",
72 | help="number of epochs for training", default=1)
73 | parser.add_argument("--hidden-size",
74 | help="hidden size in both the encoder and decoder")
75 | parser.add_argument("--dropout",
76 | help="dropout percentage in encoder and decoder",
77 | default=0.0)
78 | parser.add_argument("--bidirectional",
79 | help="will use bidirectional encoder",
80 | default=False,
81 | action="store_true")
82 | parser.add_argument("--attention",
83 | help="will use decoder with given attention method",
84 | default="general")
85 |
86 | args = parser.parse_args()
87 |
88 | if args.data_path is None:
89 | parser.print_help()
90 |
91 | return args
92 |
93 |
94 | if __name__ == "__main__":
95 | args = parse_args()
96 |
97 | for arg in vars(args):
98 | logger.info(f"{arg} - {getattr(args, arg)}")
99 |
100 | logger.info("Starting seq2seq model training...")
101 |
102 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103 |
104 | seq_length = int(args.seq_length)
105 | stride = int(args.stride)
106 | lr = float(args.learning_rate)
107 |
108 | assert seq_length % int(args.in_out_ratio) == 0
109 |
110 | normalize = True
111 | train_dataloader, norm_data = load_dataloader(args, "training",
112 | normalize, norm_data=None)
113 | val_dataloader, _ = load_dataloader(args, "validation",
114 | normalize, norm_data=norm_data)
115 |
116 | encoder_feature_size = train_dataloader.dataset[0][0].shape[1]
117 | decoder_feature_size = train_dataloader.dataset[0][1].shape[1]
118 |
119 | bidirectional = args.bidirectional
120 | encoder = get_encoder(encoder_feature_size,
121 | device,
122 | hidden_size=int(args.hidden_size),
123 | dropout=float(args.dropout),
124 | bidirectional=bidirectional)
125 |
126 | encoder_optim = optim.AdamW(encoder.parameters(), lr=lr, weight_decay=0.05)
127 | encoder_sched = optim.lr_scheduler.MultiStepLR(encoder_optim,
128 | milestones=[5],
129 | gamma=0.1)
130 |
131 | use_attention = False
132 | attention_options = ["add", "dot", "concat",
133 | "general", "activated-general", "biased-general"]
134 | if args.attention in attention_options:
135 | decoder = get_attn_decoder(decoder_feature_size,
136 | args.attention,
137 | device,
138 | hidden_size=int(args.hidden_size),
139 | bidirectional_encoder=bidirectional)
140 | use_attention = True
141 | else:
142 | decoder = get_decoder(decoder_feature_size,
143 | device,
144 | dropout=float(args.dropout),
145 | hidden_size=int(args.hidden_size))
146 |
147 | decoder_optim = optim.AdamW(decoder.parameters(), lr=lr, weight_decay=0.05)
148 | decoder_sched = optim.lr_scheduler.MultiStepLR(decoder_optim,
149 | milestones=[5],
150 | gamma=0.1)
151 |
152 | encoder_params = sum(p.numel()
153 | for p in encoder.parameters() if p.requires_grad)
154 | decoder_params = sum(p.numel()
155 | for p in decoder.parameters() if p.requires_grad)
156 |
157 | models = (encoder, decoder)
158 | optims = (encoder_optim, decoder_optim)
159 | dataloaders = (train_dataloader, val_dataloader)
160 | epochs = int(args.num_epochs)
161 | training_criterion = nn.L1Loss()
162 | validation_criteria = [nn.L1Loss(), QuatDistance()]
163 | norm_quaternions = (args.representation == "quaternions")
164 |
165 | schedulers = (encoder_sched, decoder_sched)
166 |
167 | logger.info(f"Encoder for training: {encoder}")
168 | logger.info(f"Decoder for training: {decoder}")
169 | logger.info(f"Number of parameters: {encoder_params + decoder_params}")
170 | logger.info(f"Optimizers for training: {encoder_optim}")
171 | logger.info(f"Criterion for training: {training_criterion}")
172 |
173 | fit(models, optims, epochs, dataloaders, training_criterion,
174 | validation_criteria, schedulers, device, args.model_file_path,
175 | use_attention=use_attention, norm_quaternions=norm_quaternions)
176 |
177 | logger.info("Completed Training...")
178 | logger.info("\n")
179 |
--------------------------------------------------------------------------------
/src/train-transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from transformers.training_utils import fit
9 | from transformers.transformers import (
10 | InferenceTransformerEncoder,
11 | InferenceTransformer
12 | )
13 | from common.data_utils import load_dataloader
14 | from common.logging import logger
15 | from common.losses import QuatDistance
16 | import torch
17 | from torch import nn, optim
18 | import numpy as np
19 | import argparse
20 |
21 | torch.manual_seed(42)
22 | np.random.seed(42)
23 |
24 | torch.backends.cudnn.deterministic = False
25 | torch.backends.cudnn.benchmark = False
26 |
27 |
28 | def parse_args():
29 | """Parse arguments for module.
30 |
31 | Returns:
32 | argparse.Namespace: contains accessible arguments passed in to module
33 | """
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument("--task",
37 | help=("task for neural network to train on; "
38 | "either prediction or conversion"))
39 | parser.add_argument("--data-path",
40 | help=("path to h5 files containing data "
41 | "(must contain training.h5 and validation.h5)"))
42 | parser.add_argument("--representation",
43 |                         help=("will normalize outputs if quaternions; will "
44 |                               "use expmap-to-quat validation loss if expmap"),
45 |                         default="quaternions")
46 | parser.add_argument("--full-transformer",
47 | help=("will use Transformer with both encoder and "
48 | "decoder if true, will only use encoder "
49 | "if false"),
50 | default=False,
51 | action="store_true")
52 | parser.add_argument("--model-file-path",
53 | help="path to model file for saving it after training")
54 | parser.add_argument("--batch-size",
55 | help="batch size for training", default=32)
56 | parser.add_argument("--learning-rate",
57 | help="initial learning rate for training",
58 | default=0.001)
59 | parser.add_argument("--beta-one",
60 | help="beta1 for adam optimizer (momentum)",
61 | default=0.9)
62 | parser.add_argument("--beta-two",
63 | help="beta2 for adam optimizer", default=0.999)
64 | parser.add_argument("--seq-length",
65 | help=("sequence length for model, will be divided "
66 | "by downsample if downsample is provided"),
67 | default=20)
68 | parser.add_argument("--downsample",
69 | help=("reduce sampling frequency of recorded data; "
70 | "default sampling frequency is 240 Hz"),
71 | default=1)
72 | parser.add_argument("--in-out-ratio",
73 | help=("ratio of input/output; "
74 | "seq_length / downsample = input length = 10, "
75 | "output length = input length / in_out_ratio"),
76 | default=1)
77 | parser.add_argument("--stride",
78 | help=("stride used when reading data in "
79 | "for running prediction tasks"),
80 | default=3)
81 | parser.add_argument("--num-epochs",
82 | help="number of epochs for training", default=1)
83 | parser.add_argument("--num-heads",
84 | help="number of heads in Transformer")
85 | parser.add_argument("--dim-feedforward",
86 | help=("number of dimensions in feedforward layer "
87 | "in Transformer"))
88 | parser.add_argument("--dropout",
89 | help="dropout percentage in Transformer")
90 | parser.add_argument("--num-layers",
91 | help="number of layers in Transformer")
92 |
93 | args = parser.parse_args()
94 |
95 | if args.data_path is None:
96 | parser.print_help()
97 |
98 | return args
99 |
100 |
101 | if __name__ == "__main__":
102 | args = parse_args()
103 |
104 | for arg in vars(args):
105 | logger.info(f"{arg} - {getattr(args, arg)}")
106 |
107 | logger.info("Starting Transformer training...")
108 |
109 | logger.info(f"Device count: {torch.cuda.device_count()}")
110 |
111 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112 | logger.info(f"Training on {device}...")
113 | seq_length = int(args.seq_length)//int(args.downsample)
114 |
115 | assert seq_length % int(args.in_out_ratio) == 0
116 |
117 | lr = float(args.learning_rate)
118 |
119 | normalize = True
120 | train_dataloader, norm_data = load_dataloader(args, "training", normalize)
121 | val_dataloader, _ = load_dataloader(args, "validation", normalize,
122 | norm_data=norm_data)
123 |
124 | encoder_feature_size = train_dataloader.dataset[0][0].shape[1]
125 | decoder_feature_size = train_dataloader.dataset[0][1].shape[1]
126 |
127 | num_heads = int(args.num_heads)
128 | dim_feedforward = int(args.dim_feedforward)
129 | dropout = float(args.dropout)
130 | num_layers = int(args.num_layers)
131 | quaternions = (args.representation == "quaternions")
132 |
133 | if args.full_transformer:
134 | model = InferenceTransformer(decoder_feature_size, num_heads,
135 | dim_feedforward, dropout,
136 | num_layers, quaternions=quaternions)
137 | else:
138 | model = InferenceTransformerEncoder(encoder_feature_size, num_heads,
139 | dim_feedforward, dropout,
140 | num_layers, decoder_feature_size,
141 | quaternions=quaternions)
142 |
143 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
144 |
145 | if torch.cuda.device_count() > 1:
146 | model = nn.DataParallel(model)
147 |
148 | model = model.to(device).double()
149 |
150 | epochs = int(args.num_epochs)
151 | beta1 = float(args.beta_one)
152 | beta2 = float(args.beta_two)
153 |
154 | optimizer = optim.AdamW(model.parameters(),
155 | lr=lr,
156 | betas=(beta1, beta2),
157 | weight_decay=0.03)
158 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
159 | milestones=[1, 3],
160 | gamma=0.1)
161 |
162 | dataloaders = (train_dataloader, val_dataloader)
163 | training_criterion = nn.L1Loss()
164 | validation_criteria = [nn.L1Loss(), QuatDistance()]
165 |
166 | logger.info(f"Model for training: {model}")
167 | logger.info(f"Number of parameters: {num_params}")
168 | logger.info(f"Optimizer for training: {optimizer}")
169 | logger.info(f"Criterion for training: {training_criterion}")
170 |
171 | fit(model, optimizer, scheduler, epochs, dataloaders, training_criterion,
172 | validation_criteria, device, args.model_file_path,
173 | full_transformer=args.full_transformer)
174 |
175 | logger.info("Completed Training...")
176 | logger.info("\n")
177 |
--------------------------------------------------------------------------------
/src/train_seq2seq.sh:
--------------------------------------------------------------------------------
1 | python train-seq2seq.py --task conversion \
2 | --data-path "/home/jackg7/VT-Natural-Motion-Processing/data/set-2" \
3 | --model-file-path "/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt" \
4 | --representation quaternions \
5 | --batch-size=32 \
6 | --seq-length=30 \
7 | --downsample=6 \
8 | --in-out-ratio=5 \
9 | --stride=30 \
10 | --learning-rate=0.001 \
11 | --num-epochs=1 \
12 | --hidden-size=512 \
13 | --attention=dot \
14 | --bidirectional
15 |
--------------------------------------------------------------------------------
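
If the --seq-length, --downsample, and --in-out-ratio help text is taken at face value, the settings in this script give an input window of 30 / 6 = 5 downsampled frames and an output window of 5 / 5 = 1 frame, i.e. the seq2seq model maps five input frames of the source representation to a single target frame.
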
/src/train_transformer.sh:
--------------------------------------------------------------------------------
1 | python train-transformer.py --task conversion \
2 | --data-path "/home/jackg7/VT-Natural-Motion-Processing/data/set-2" \
3 | --model-file-path "/home/jackg7/VT-Natural-Motion-Processing/models/set-2/model.pt" \
4 | --full-transformer \
5 | --representation quaternions \
6 | --batch-size=32 \
7 | --seq-length=30 \
8 | --downsample=6 \
9 | --in-out-ratio=1 \
10 | --stride=30 \
11 | --learning-rate=0.001 \
12 | --beta-one=0.95 \
13 | --beta-two=0.999 \
14 | --num-epochs=5 \
15 | --num-heads=4 \
16 | --dim-feedforward=512 \
17 | --dropout=0.0 \
18 | --num-layers=2
19 |
--------------------------------------------------------------------------------
/src/transformers/training_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import time
10 | import math
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from common.logging import logger
14 | plt.switch_backend("agg")
15 | torch.manual_seed(0)
16 |
17 |
18 | class Timer:
19 | def __enter__(self):
20 | self.start = time.time()
21 | return self
22 |
23 | def __exit__(self, *args):
24 | self.end = time.time()
25 | self.interval = self.end - self.start
26 |
27 |
28 | def loss_batch(model, optimizer, data, criterion, device,
29 | full_transformer=False, average_batch=True):
30 |     """Run a batch through the Transformer Encoder or full Transformer.
31 |
32 |     Used in the fit function for training and validation (optimizer=None).
33 |
34 | Args:
35 | model (nn.Module): model to pass batch through
36 | optimizer (optim.Optimizer): optimizer for training
37 | data (tuple): contains input and targets to use for inference
38 | criterion (nn.Module): criterion to evaluate model on
39 | device (torch.device): device to put data on
40 | full_transformer (bool, optional): whether the model is a full
41 | transformer; the forward pass will operate differently if so.
42 | Defaults to False.
43 | average_batch (bool, optional): whether to average the batch; useful
44 | for plotting histograms if average_batch is false.
45 | Defaults to True.
46 |
47 | Returns:
48 | float or list: returns a single loss or a list of losses depending on
49 | average_batch argument.
50 | """
51 | if optimizer is None:
52 | model.eval()
53 | else:
54 | model.train()
55 |
56 | inputs, targets = data
57 | inputs = inputs.permute(1, 0, 2).to(device).double()
58 | targets = targets.permute(1, 0, 2).to(device).double()
59 |
60 | outputs = None
61 | if full_transformer:
62 | SOS = torch.zeros_like(targets[0, :]).unsqueeze(0)
63 | tgt = torch.cat((SOS, targets), dim=0)
64 |
65 | if tgt.shape[2] > inputs.shape[2]:
66 | padding = torch.zeros((inputs.shape[0], inputs.shape[1],
67 | tgt.shape[2] - inputs.shape[2]))
68 | padding = padding.to(device).double()
69 | inputs = torch.cat((inputs, padding), dim=2)
70 |
71 | outputs = model(inputs, tgt[:-1, :])
72 | else:
73 | outputs = model(inputs)
74 |
75 | loss = criterion(outputs, targets)
76 |
77 | if optimizer is not None:
78 | loss.backward()
79 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
80 | optimizer.step()
81 | optimizer.zero_grad()
82 |
83 | if average_batch:
84 | return loss.item()
85 | else:
86 | losses = []
87 | for b in range(outputs.shape[1]):
88 | sample_loss = criterion(outputs[:, b, :], targets[:, b, :])
89 | losses.append(sample_loss.item())
90 |
91 | return losses
92 |
93 |
94 | def inference(model, data, criterion, device, average_batch=True):
95 | """Run inference with full Transformer.
96 |
97 | Used to determine actual validation loss that will be seen in practice
98 | since full Transformers use teacher forcing.
99 |
100 |
101 | Args:
102 | model (nn.Module): model to run inference with
103 | data (tuple): contains input and targets to use for inference
104 | criterion (nn.Module): criterion to evaluate model on
105 | device (torch.device): device to put data on
106 | average_batch (bool, optional): whether to average the batch; useful
107 | for plotting histograms if average_batch is false.
108 | Defaults to True.
109 |
110 | Returns:
111 | float or list: returns a single loss or a list of losses depending on
112 | average_batch argument.
113 | """
114 | model.eval()
115 |
116 | inputs, targets = data
117 | inputs = inputs.permute(1, 0, 2).to(device).double()
118 | targets = targets.permute(1, 0, 2).to(device).double()
119 |
120 | pred = torch.zeros((targets.shape[0]+1,
121 | targets.shape[1],
122 | targets.shape[2])).to(device).double()
123 |
124 | if pred.shape[2] > inputs.shape[2]:
125 | padding = torch.zeros((inputs.shape[0],
126 | inputs.shape[1],
127 | pred.shape[2] - inputs.shape[2]))
128 | padding = padding.to(device).double()
129 | inputs = torch.cat((inputs.clone(), padding), dim=2)
130 |
131 | memory = model.encoder(model.pos_decoder(inputs))
132 |
133 | for i in range(pred.shape[0]-1):
134 | next_pred = model.inference(memory, pred[:i+1, :].clone())
135 | pred[i+1, :] = pred[i+1, :].clone() + next_pred[-1, :].clone()
136 |
137 | if average_batch:
138 | loss = criterion(pred[1:, :], targets)
139 | return loss.item()
140 | else:
141 | losses = []
142 | for b in range(pred.shape[1]):
143 | loss = criterion(pred[1:, b, :], targets[:, b, :])
144 | losses.append(loss.item())
145 | return losses
146 |
147 |
148 | def fit(model, optimizer, scheduler, epochs, dataloaders, training_criterion,
149 | validation_criteria, device, model_file_path,
150 | full_transformer=False, min_val_loss=math.inf):
151 | """Fit a Transformer model to data, logging training and validation loss.
152 |
153 | Args:
154 | model (nn.Module): the model to train
155 |         optimizer (optim.Optimizer): the optimizer for training
156 |         scheduler (optim.lr_scheduler._LRScheduler): scheduler to control
157 |             the learning rate of the optimizer
158 | epochs (int): number of epochs to train for
159 | dataloaders (tuple): tuple containing the training dataloader and val
160 | dataloader.
161 | training_criterion (nn.Module): criterion for backpropagation during
162 | training
163 | validation_criteria (list): list of criteria for validation
164 | device (torch.device): device to place data on
165 | model_file_path (str): where to save the model when validation loss
166 | reaches new minimum
167 | full_transformer (bool): whether the model is a full transformer and
168 | needs to run inference for evaluation
169 |         min_val_loss (float, optional): starting minimum validation loss;
170 |             the model is saved only when it is beaten. Defaults to math.inf.
171 | Returns:
172 | float: minimum validation loss reached during training
173 | """
174 | train_dataloader, val_dataloader = dataloaders
175 | total_time = 0
176 |
177 | for epoch in range(epochs):
178 | losses = 0
179 | logger.info("Epoch {}".format(epoch + 1))
180 | avg_loss = 0
181 | for index, data in enumerate(train_dataloader, 0):
182 | with Timer() as timer:
183 | loss = loss_batch(model, optimizer, data, training_criterion,
184 | device, full_transformer=full_transformer)
185 | losses += loss
186 | avg_loss += loss
187 | total_time += timer.interval
188 | if index % (len(train_dataloader) // 10) == 0 and index != 0:
189 | avg_training_loss = avg_loss / (len(train_dataloader) // 10)
190 | logger.info((f"Total time elapsed: {total_time}"
191 | " - "
192 | f"Batch number: {index} / {len(train_dataloader)}"
193 | " - "
194 | f"Training loss: {avg_training_loss}"
195 | " - "
196 | f"LR: {optimizer.param_groups[0]['lr']}"
197 | ))
198 | avg_loss = 0
199 |
200 | val_loss = []
201 | for validation_criterion in validation_criteria:
202 | with torch.no_grad():
203 | val_losses = [loss_batch(model, None, data,
204 | validation_criterion, device,
205 | full_transformer=full_transformer)
206 | for _, data in enumerate(val_dataloader, 0)]
207 |
208 | val_loss.append(np.sum(val_losses) / len(val_losses))
209 |
210 | loss = losses / len(train_dataloader)
211 |
212 | scheduler.step()
213 | val_loss_str = ", ".join(map(str, val_loss))
214 | logger.info(f"Epoch {epoch+1} - "
215 | f"Training Loss: {loss} - "
216 | f"Val Loss: {val_loss_str}")
217 |
218 | if full_transformer:
219 | inference_loss = []
220 | for validation_criterion in validation_criteria:
221 | with torch.no_grad():
222 | inference_losses = [inference(model, data,
223 | validation_criterion, device)
224 | for _, data in
225 | enumerate(val_dataloader, 0)]
226 | inference_loss.append(
227 | np.sum(inference_losses) / len(inference_losses)
228 | )
229 | inference_loss_str = ", ".join(map(str, inference_loss))
230 | logger.info(f"Inference Loss: {inference_loss_str}")
231 |
232 | if val_loss[0] < min_val_loss:
233 | min_val_loss = val_loss[0]
234 | logger.info(f"Saving model to {model_file_path}")
235 | torch.save({
236 | "model_state_dict": model.state_dict(),
237 | "optimizer_state_dict": optimizer.state_dict(),
238 | }, model_file_path)
239 |
240 | return min_val_loss
241 |
--------------------------------------------------------------------------------
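
One detail worth spelling out from loss_batch and inference above: for the full Transformer, a zero "SOS" frame is prepended to the targets and the final frame is dropped, so the decoder is trained (with teacher forcing) to predict frame t from earlier frames, while inference instead grows the prediction one frame at a time starting from that same zero frame. A toy illustration of the shift, with made-up shapes:

```python
import torch

seq_len, batch_size, num_features = 5, 2, 4
targets = torch.randn(seq_len, batch_size, num_features)

# Teacher forcing (loss_batch): prepend a zero SOS frame, drop the last frame.
SOS = torch.zeros_like(targets[0, :]).unsqueeze(0)   # (1, batch, features)
tgt = torch.cat((SOS, targets), dim=0)               # (seq_len + 1, batch, features)
decoder_input = tgt[:-1, :]                          # what the decoder sees at train time

assert decoder_input.shape == targets.shape
assert torch.equal(decoder_input[1:], targets[:-1])  # shifted by one frame
```
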
/src/transformers/transformers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-present, Assistive Robotics Lab
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Transformer classes with quaternion normalization.
10 |
11 | Reference:
12 | [1] https://pytorch.org/tutorials/beginner/transformer_tutorial.html
13 | """
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import math
18 |
19 |
20 | class PositionalEncoding(nn.Module):
21 | """PositionalEncoding injects position-dependent signals in input/targets.
22 |
23 | Transformers have no concept of sequence like RNNs, so this acts to inject
24 | information about the order of a sequence.
25 |
26 | Useful reference for more info:
27 | https://datascience.stackexchange.com/questions/51065/what-is-the-positional-encoding-in-the-transformer-model
28 | """
29 |
30 | def __init__(self, d_model, dropout=0.1, max_len=5000):
31 | super(PositionalEncoding, self).__init__()
32 | self.dropout = nn.Dropout(p=dropout)
33 |
34 | pe = torch.zeros(max_len, d_model)
35 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
36 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
37 | (-math.log(10000.0) / d_model))
38 | pe[:, 0::2] = torch.sin(position * div_term)
39 | pe[:, 1::2] = torch.cos(position * div_term)[:, :d_model//2]
40 | pe = pe.unsqueeze(0).transpose(0, 1)
41 | self.register_buffer("pe", pe)
42 |
43 | def forward(self, x):
44 | x = x + self.pe[:x.size(0), :]
45 | return self.dropout(x)
46 |
47 |
48 | class InferenceTransformerEncoder(nn.Module):
49 | """Transformer Encoder for use in inferring unit quaternions."""
50 |
51 | def __init__(self, num_input_features, num_heads, dim_feedforward, dropout,
52 | num_layers, num_output_features, quaternions=False):
53 | """Initialize the Transformer Encoder.
54 |
55 | Args:
56 | num_input_features (int): number of features in the input
57 | data
58 | num_heads (int): number of heads in each layer for multi-head
59 | attention
60 | dim_feedforward (int): dimensionality of the feedforward layers in
61 | each layer
62 | dropout (float): dropout amount in the layers
63 | num_layers (int): number of layers in Encoder
64 | num_output_features (int): number of features in the output
65 | data
66 | quaternions (bool, optional): whether quaternions are used in the
67 | output; will normalize if True. Defaults to False.
68 | """
69 | super(InferenceTransformerEncoder, self).__init__()
70 | self.pos_encoder = PositionalEncoding(num_input_features, dropout)
71 | self.encoder_layer = nn.TransformerEncoderLayer(num_input_features,
72 | num_heads,
73 | dim_feedforward,
74 | dropout)
75 | self.encoder = nn.TransformerEncoder(self.encoder_layer,
76 | num_layers,
77 | norm=None)
78 | self.linear1 = nn.Linear(num_input_features, dim_feedforward)
79 | self.linear2 = nn.Linear(dim_feedforward, num_output_features)
80 | self.quaternions = quaternions
81 |
82 | def forward(self, src):
83 | """Forward pass through Transformer Encoder model for training/testing.
84 |
85 | Args:
86 | src (torch.Tensor): used by encoder to generate output
87 |
88 | Returns:
89 | torch.Tensor: output from the Transformer
90 | """
91 | pos_enc = self.pos_encoder(src)
92 | enc_output = self.encoder(pos_enc)
93 | output = self.linear2(self.linear1(enc_output))
94 |
95 | if self.quaternions:
96 | original_shape = output.shape
97 |
98 | output = output.view(-1, 4)
99 | output = F.normalize(output, p=2, dim=1).view(original_shape)
100 |
101 | return output
102 |
103 |
104 | class InferenceTransformer(nn.Module):
105 | """Transformer for use in inferring unit quaternions."""
106 |
107 | def __init__(self, num_features, num_heads, dim_feedforward, dropout,
108 | num_layers, quaternions=False):
109 | """Initialize the Transformer model.
110 |
111 | Args:
112 | num_features (int): number of features in the input and target
113 | data
114 | num_heads (int): number of heads in each layer for multi-head
115 | attention
116 | dim_feedforward (int): dimensionality of the feedforward layers in
117 | each layer
118 | dropout (float): dropout amount in the layers
119 | num_layers (int): number of layers in Encoder and Decoder
120 | quaternions (bool, optional): whether quaternions are used in the
121 | output; will normalize if True. Defaults to False.
122 | """
123 | super(InferenceTransformer, self).__init__()
124 |
125 | self.pos_encoder = PositionalEncoding(num_features, dropout)
126 | self.pos_decoder = PositionalEncoding(num_features, dropout)
127 |
128 | self.encoder_layer = nn.TransformerEncoderLayer(num_features,
129 | num_heads,
130 | dim_feedforward,
131 | dropout)
132 | self.encoder = nn.TransformerEncoder(self.encoder_layer,
133 | num_layers,
134 | norm=None)
135 | self.decoder_layer = nn.TransformerDecoderLayer(num_features,
136 | num_heads,
137 | dim_feedforward,
138 | dropout)
139 | self.decoder = nn.TransformerDecoder(self.decoder_layer,
140 | num_layers,
141 | norm=None)
142 | self.tgt_mask = None
143 | self.quaternions = quaternions
144 |
145 | def generate_square_subsequent_mask(self, sz):
146 |         """Mask future positions so the decoder cannot attend to them.
147 |
148 | Args:
149 | sz (int): sequence length of tensor
150 |
151 | Returns:
152 | torch.Tensor: mask of subsequent entries during forward pass
153 | """
154 | mask = torch.triu(torch.ones(sz, sz), 1)
155 | mask = mask.masked_fill(mask == 1, float("-inf"))
156 | return mask
157 |
158 | def inference(self, memory, target):
159 | """Forward pass through Transformer model at validation/test time.
160 |
161 | Args:
162 | memory (torch.Tensor): memory passed from the Encoder
163 | target (torch.Tensor): predictions built up over time during
164 | inference
165 |
166 | Returns:
167 | torch.Tensor: output from the model (passed into model in next
168 | iteration)
169 | """
170 | if self.tgt_mask is None or self.tgt_mask.size(0) != len(target):
171 | self.tgt_mask = self.generate_square_subsequent_mask(len(target))
172 | self.tgt_mask = self.tgt_mask.to(target.device)
173 |
174 | pos_dec = self.pos_decoder(target)
175 | output = self.decoder(pos_dec, memory, tgt_mask=self.tgt_mask)
176 |
177 | if self.quaternions:
178 | original_shape = output.shape
179 | output = F.normalize(output.view(-1, 4),
180 | p=2,
181 | dim=1).view(original_shape)
182 |
183 | return output
184 |
185 | def forward(self, src, target):
186 | """Forward pass through Transformer model for training.
187 |
188 |         Use the inference method at validation/test time to get an
189 |         accurate measure of performance.
190 |
191 | Args:
192 | src (torch.Tensor): used by encoder to generate memory
193 |             target (torch.Tensor): target for the decoder to try to match;
194 |                 Transformers use teacher forcing, so targets are used as input
195 |
196 | Returns:
197 | torch.Tensor: output from the Transformer
198 | """
199 | if self.tgt_mask is None or self.tgt_mask.size(0) != len(target):
200 | self.tgt_mask = self.generate_square_subsequent_mask(len(target))
201 | self.tgt_mask = self.tgt_mask.to(target.device)
202 |
203 |         pos_enc = self.pos_encoder(src)
204 | memory = self.encoder(pos_enc)
205 |
206 | pos_dec = self.pos_decoder(target)
207 | output = self.decoder(pos_dec, memory, tgt_mask=self.tgt_mask)
208 |
209 | if self.quaternions:
210 | original_shape = output.shape
211 | output = F.normalize(output.view(-1, 4),
212 | p=2,
213 | dim=1).view(original_shape)
214 |
215 | return output
216 |
--------------------------------------------------------------------------------
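
For orientation, the sketch below shows one way the classes above could be exercised; it is not part of the repository. The hyperparameters, tensor shapes, and the greedy decoding loop are illustrative assumptions (the project's own train/test scripts drive these models differently). It only assumes that InferenceTransformer is in scope and that num_features is divisible by num_heads and, with quaternions=True, by 4 so the view(-1, 4) normalization step is valid.

import torch

# Illustrative hyperparameters (assumptions, not the project's settings).
model = InferenceTransformer(num_features=64, num_heads=8,
                             dim_feedforward=256, dropout=0.1,
                             num_layers=4, quaternions=True)
model.eval()

# nn.Transformer layout: (seq_len, batch, features).
src = torch.randn(30, 2, 64)   # observed input sequence
tgt = torch.randn(15, 2, 64)   # ground-truth continuation (teacher forcing)

with torch.no_grad():
    # Training-style pass: the target is fed to the decoder directly,
    # with the square subsequent mask hiding future positions.
    train_out = model(src, tgt)

    # Validation/test-style pass: encode once, then build the decoder
    # input autoregressively via inference(), seeding with the last
    # observed frame (an illustrative choice, not the repository's scheme).
    memory = model.encoder(model.pos_encoder(src))
    decoded = src[-1:, :, :]
    for _ in range(tgt.size(0)):
        out = model.inference(memory, decoded)
        decoded = torch.cat([decoded, out[-1:, :, :]], dim=0)

The training-style forward pass relies on teacher forcing, so its loss understates real performance; the autoregressive loop is the fairer measure, which is why the training utilities log a separate inference loss when full_transformer is set.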