├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── amass.yaml
│   ├── application.yaml
│   ├── basic.yaml
│   ├── basic_dog.yaml
│   ├── dog_mocap.yaml
│   ├── dog_skel.yaml
│   ├── evaluation.yaml
│   ├── generative.yaml
│   ├── gmp.yaml
│   └── smpl.yaml
├── data
│   └── .gitkeep
├── outputs
│   └── .gitkeep
├── overview.png
├── requirements.txt
└── src
    ├── application.py
    ├── arguments.py
    ├── datasets
    │   ├── __init__.py
    │   ├── amass.py
    │   └── animal.py
    ├── evaluation.py
    ├── holden
    │   ├── Animation.py
    │   ├── AnimationStructure.py
    │   ├── BVH.py
    │   ├── InverseKinematics.py
    │   ├── Pivots.py
    │   ├── Quaternions.py
    │   └── __init__.py
    ├── nemf
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── basic.py
    │   ├── fk.py
    │   ├── generative.py
    │   ├── global_motion.py
    │   ├── loss_record.py
    │   ├── losses.py
    │   ├── neural_motion.py
    │   ├── prior.py
    │   ├── residual_blocks.py
    │   └── skeleton.py
    ├── rotations.py
    ├── soft_dtw_cuda.py
    ├── train.py
    ├── train_basic.py
    ├── train_gmp.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | data/*
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | *.ipynb
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | outputs/*
134 | !outputs/.gitkeep
135 | outputs_*/
136 | config_v*.yml
137 | analysis/
138 | analysis_*/
139 | optimize/
140 | optimize_*/
141 | *.bak*
142 | *.out
143 | *.csv
144 | inbetween_benchmark/
145 | comparison/
146 | nemf_test/
147 | motion_recon/
148 | logs/
149 | optim*/
150 | evaluate*/
151 | max_sneaks/
152 |
153 | !data/.gitkeep
154 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Chengan He
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeMF: Neural Motion Fields for Kinematic Animation (NeurIPS 2022)
2 |
3 | ### [Paper](https://arxiv.org/abs/2206.03287) | [Project](https://cs.yale.edu/homes/che/projects/nemf/)
4 |
5 | [Chengan He](https://cs.yale.edu/homes/che/)<sup>1</sup>, [Jun Saito](https://research.adobe.com/person/jun-saito/)<sup>2</sup>, [James Zachary](https://jameszachary.com/)<sup>2</sup>, [Holly Rushmeier](https://graphics.cs.yale.edu/people/holly-rushmeier)<sup>1</sup>, [Yi Zhou](https://zhouyisjtu.github.io/)<sup>2</sup>
6 |
7 | <sup>1</sup>Yale University, <sup>2</sup>Adobe Research
8 |
9 | 
10 |
11 | ## Prerequisites
12 |
13 | This code was developed on Ubuntu 20.04 with Python 3.9, CUDA 11.3 and PyTorch 1.9.0.
14 |
15 | ### Environment Setup
16 |
17 | To begin with, please set up the virtual environment with Anaconda:
18 | ```bash
19 | conda create -n nemf python=3.9
20 | conda activate nemf
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ### Body Model
25 |
26 | Our code relies on [SMPL](https://smpl.is.tue.mpg.de/) as the body model. Please download it from the Max Planck website and accept the license terms there.
27 |
28 | ### Datasets
29 |
30 | #### AMASS
31 |
32 | [AMASS](https://amass.is.tue.mpg.de/) mocap data is used to train and evaluate our model in most of the experiments. Please download it from the Max Planck website and accept the license terms there. After downloading the raw data, we run the [processing scripts](https://github.com/davrempe/humor/tree/main/data#amass) provided by HuMoR to unify frame rates, detect contacts, and remove problematic sequences. We then split the data into training, validation, and test sets by running:
33 | ```bash
34 | python src/datasets/amass.py amass.yaml
35 | ```
36 |
37 | #### MANN
38 |
39 | The dog mocap data in [MANN](https://github.com/sebastianstarke/AI4Animation/tree/master/AI4Animation/SIGGRAPH_2018) is used to evaluate the reconstruction capability of our model on long sequences; the raw data can be downloaded from [here](http://www.starke-consult.de/AI4Animation/SIGGRAPH_2018/MotionCapture.zip). To process the data, we first manually remove some sequences on uneven terrain and then run:
40 | ```bash
41 | python src/datasets/animal.py dog_mocap.yaml
42 | ```
43 |
44 | > **Note:** All the data mentioned above should be downloaded **from its original source under the corresponding license** and processed into the `data` directory. Otherwise, you need to update the configuration files to point to the paths where you extracted the data.
45 |
46 | ## Quickstart
47 |
48 | We provide a pre-trained [generative NeMF model](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/ERadqRp5XedAn0pOyGeFWOMB3LH7g7w9c4OyzB7nmweHqA?e=8P1NWU) and [global motion predictor](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/ERmQUqW5z_tAncGS1XzTJ6sBGnIxEawVgK9J-vID-uT6iw?e=TC6wfQ). Download and extract them to the `outputs` directory.
49 |
50 | ### Application
51 |
52 | We deploy our trained model as a generic motion prior to solve different motion tasks. To run the applications shown in the paper, use:
53 | ```bash
54 | python src/application.py --config application.yaml --task {application task} --save_path {save path}
55 | ```
56 | Here we implement several applications including `motion_reconstruction`, `latent_interpolation`, `motion_inbetweening`, `motion_renavigating`, `aist_inbetweening`, and `time_translation`. For `aist_inbetweening`, you need to download the motion data and dance videos from [AIST++ Dataset](https://google.github.io/aistplusplus_dataset/download.html) and place them under `data/aist`. You also need to have [FFmpeg](https://ffmpeg.org/) installed to process these videos.
57 |
58 | ### Evaluation
59 |
60 | To evaluate our trained model on different tasks, use:
61 | ```bash
62 | python src/evaluation.py --config evaluation.yaml --task {evaluation task} --load_path {load path}
63 | ```
64 | The tasks we implement here include `ablation_study`, `comparison`, `inbetween_benchmark`, `super_sampling`, and `smoothness_test`, which cover the tables and figures we showed in the paper. The quantitative results will be saved in `.csv` files.
65 |
66 | To evaluate FID and Diversity, we provide a pre-trained feature extractor [here](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/EVXKF8Tc1t5CnrkswLS0qUoBi_1YsDa_BpMDAalEuTblSQ?e=bcxHAP), which is essentially an auto-encoder. You can train a new one on your own data by running:
67 | ```
68 | python src/train.py evaluation.yaml
69 | ```
70 |
71 | ## Train NeMF from Scratch
72 |
73 | In our paper we proposed three different models: a single-motion NeMF that overfits specific motion sequences, a generative NeMF that learns a motion prior, and a global motion predictor that generates root translations separately. Below we describe how to train these three models from scratch.
74 |
75 | ### Training Single-motion NeMF
76 |
77 | To train the single-motion NeMF on AMASS sequences, use:
78 | ```
79 | python src/train_basic.py basic.yaml
80 | ```
81 | The code will obtain sequences of 32, 64, 128, 256, and 512 frames and reconstruct them at 30, 60, 120, and 240 fps.
82 |
83 | To train the model on dog mocap sequences, use:
84 | ```
85 | python src/train_basic.py basic_dog.yaml
86 | ```
87 |
88 | ### Training Generative NeMF
89 |
90 | To train the generative NeMF on AMASS dataset, use:
91 | ```
92 | python src/train.py generative.yaml
93 | ```
94 |
95 | ### Training Global Motion Predictor
96 |
97 | To train the global motion predictor on AMASS dataset, use:
98 | ```
99 | python src/train_gmp.py gmp.yaml
100 | ```
101 |
102 | ## Visualization and Rendering
103 |
104 | Our codebase outputs `.npz` data following the AMASS data format, so the results can be visualized directly with the [SMPL-X Blender add-on](https://smpl-x.is.tue.mpg.de/). To render skeleton animations from `.bvh` files, we use the rendering scripts provided in [deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing).
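
As a quick sanity check before rendering, an output `.npz` can be inspected with NumPy. A minimal sketch, assuming the file carries the usual AMASS keys; the path below is hypothetical:
```python
import numpy as np

motion = np.load('outputs/generative/sample.npz')  # hypothetical output file
print(list(motion.keys()))    # AMASS-style keys, e.g. 'poses', 'trans', ...
print(motion['poses'].shape)  # (T, 156) axis-angle SMPL-H pose parameters
print(motion['trans'].shape)  # (T, 3) global root translation
```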
105 |
106 | ## Acknowledgements
107 |
108 | - Our BVH I/O code is adapted from the [work](https://theorangeduck.com/media/uploads/other_stuff/motionsynth_code.zip) of [Daniel Holden](https://theorangeduck.com/page/publications).
109 | - The code in `src/rotations.py` is adapted from [PyTorch3D](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py).
110 | - The code in `src/datasets/amass.py` is adapted from [AMASS](https://github.com/nghorbani/amass/blob/master/src/amass/data/prepare_data.py).
111 | - The code in `src/nemf/skeleton.py` is taken from [deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing) and our code structure is also based on it.
112 | - Part of the code in `src/evaluation.py` is adapted from [Action2Motion](https://github.com/EricGuo5513/action-to-motion/tree/master/eval_scripts).
113 | - Part of the code in `src/utils.py` is taken from [HuMoR](https://github.com/davrempe/humor/blob/b86c2d9faf7abd497749621821a5d46211304d62/humor/scripts/process_amass_data.py).
114 | - The code in `src/soft_dtw_cuda.py` is taken from [pytorch-softdtw-cuda](https://github.com/Maghoumi/pytorch-softdtw-cuda).
115 |
116 | **Huge thanks to these great open-source projects!**
117 |
118 | ## Citation
119 |
120 | If you find this code or paper useful, please consider citing:
121 | ```
122 | @article{he2022nemf,
123 | title={NeMF: Neural Motion Fields for Kinematic Animation},
124 | author={He, Chengan and Saito, Jun and Zachary, James and Rushmeier, Holly and Zhou, Yi},
125 | journal={Advances in Neural Information Processing Systems},
126 | volume={35},
127 | pages={4244--4256},
128 | year={2022}
129 | }
130 | ```
131 |
132 | ## Contact
133 | If you run into any problems or have questions, please create an issue or contact `chengan.he@yale.edu`.
134 |
--------------------------------------------------------------------------------
/configs/amass.yaml:
--------------------------------------------------------------------------------
1 | dataset_dir: ./data/amass/test
2 | # dataset_dir: ./data/amass/gmp
3 | # dataset_dir: ./data/amass/single
4 |
5 | data:
6 | fps: 30
7 | clip_length: 128
8 | gender: male
9 | up: z
10 |
11 | canonical: False
12 | unified_orientation: True
13 | single_motion: False
14 | normalize: False
--------------------------------------------------------------------------------
/configs/application.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: False
2 | verbose: False
3 |
4 | is_train: False
5 | log: False
6 |
7 | iterations: 600
8 |
9 | batch_size: 1
10 | num_workers: 0
11 |
12 | dataset_dir: ./data/amass/generative
13 | save_dir: ./outputs/generative
14 | bvh_viz: False
15 |
16 | output_trans: False
17 | pretrained_gmp: gmp.yaml
18 |
19 | initialization: True
20 | learning_rate: 0.1
21 | geodesic_loss: True
22 | l1_loss: True
23 | lambda_rot: 1
24 | lambda_pos: 10
25 | lambda_orient: 1
26 | lambda_trans: 1
27 | dtw_loss: True
28 | lambda_dtw: 0.5
29 | lambda_angle: 1
30 |
31 | lambda_kl: 0.0001
32 |
33 | data:
34 | fps: 30
35 | clip_length: 128
36 | gender: male
37 | up: z
38 | root_transform: True
39 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 'height', 'root_orient', 'root_vel']
40 |
41 | local_prior:
42 | activation: tanh
43 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang)
44 | use_residual_blocks: True
45 | z_dim: 1024
46 | temporal_scale: 8
47 | kernel_size: 4
48 | num_layers: 4
49 | skeleton_dist: 2
50 | extra_conv: 0
51 | padding_mode: reflect
52 | skeleton_pool: mean
53 | upsampling: linear
54 |
55 | global_prior:
56 | activation: tanh
57 | in_channels: 6
58 | kernel_size: 4
59 | temporal_scale: 8
60 | z_dim: 256
61 |
62 | nemf:
63 | siren: False
64 | skip_connection: True
65 | norm_layer: True
66 | bandwidth: 7
67 | hidden_neuron: 1024
68 | local_z: 1024
69 | global_z: 256
70 | local_output: 144
71 | global_output: 6
72 |
73 | scheduler:
74 | step_size: 200
75 | gamma: 0.7
--------------------------------------------------------------------------------
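The `iterations`, `learning_rate`, and `scheduler` entries above drive the test-time latent optimization performed by `src/application.py`. As a rough, hypothetical sketch of how these three values plug into a PyTorch optimization loop (the stand-in `task_loss` and the latent shape are illustrative, not the repository's actual code):
```python
import torch

def task_loss(z):
    # Hypothetical stand-in for decoding z through NeMF and summing the
    # weighted rotation/position/orientation/translation terms configured above.
    return (z ** 2).mean()

z = torch.zeros(1, 1024, requires_grad=True)  # latent sized like nemf.local_z
optimizer = torch.optim.Adam([z], lr=0.1)     # learning_rate: 0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.7)  # scheduler block

for _ in range(600):  # iterations: 600
    optimizer.zero_grad()
    loss = task_loss(z)
    loss.backward()
    optimizer.step()
    scheduler.step()
```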
/configs/basic.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: False
2 | verbose: True
3 |
4 | is_train: True
5 | log: True
6 |
7 | epoch_begin: 0
8 | iterations: 500
9 |
10 | amass_data: True
11 | dataset_dir: ./data/amass/single
12 | save_dir: ./outputs/basic
13 | bvh_viz: False
14 |
15 | learning_rate: 0.0001
16 | geodesic_loss: True
17 | l1_loss: True
18 | lambda_rotmat: 1
19 | lambda_pos: 10
20 | lambda_orient: 1
21 | lambda_v: 1
22 | lambda_up: 1
23 | lambda_trans: 1
24 |
25 | data:
26 | fps: 30
27 | up: z
28 | gender: male
29 | root_transform: True
30 |
31 | nemf:
32 | siren: False
33 | skip_connection: True
34 | norm_layer: True
35 | bandwidth: 7
36 | hidden_neuron: 1024
37 | local_z: 0
38 | global_z: 0
39 | local_output: 154 # 24 x 6 + 6 + 4
40 | global_output: 1
41 |
42 | scheduler:
43 | name:
--------------------------------------------------------------------------------
/configs/basic_dog.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: False
2 | verbose: True
3 |
4 | is_train: True
5 | log: True
6 |
7 | epoch_begin: 0
8 | iterations: 8000
9 |
10 | amass_data: False
11 | dataset_dir: ./data/dog
12 | save_dir: ./outputs/dog
13 | bvh_viz: True
14 |
15 | learning_rate: 0.0001
16 | geodesic_loss: True
17 | l1_loss: True
18 | lambda_rotmat: 1
19 | lambda_pos: 10
20 | lambda_orient: 1
21 | lambda_v: 1
22 | lambda_up: 1
23 | lambda_trans: 1
24 |
25 | data:
26 | fps: 60
27 | up: y
28 | root_transform: True
29 |
30 | nemf:
31 | siren: False
32 | skip_connection: True
33 | norm_layer: True
34 | bandwidth: 7
35 | hidden_neuron: 1024
36 | local_z: 0
37 | global_z: 0
38 | local_output: 172 # 27 x 6 + 6 + 4
39 | global_output: 1
40 |
41 | scheduler:
42 | name:
--------------------------------------------------------------------------------
/configs/dog_mocap.yaml:
--------------------------------------------------------------------------------
1 | dataset_dir: ./data/dog
2 |
3 | data:
4 | fps: 60
5 | up: y
6 |
7 | canonical: False
8 | unified_orientation: True
--------------------------------------------------------------------------------
/configs/dog_skel.yaml:
--------------------------------------------------------------------------------
1 | offsets: [
2 | [-0.100563, 0.077338, -4.725500],
3 | [ 0.000000, 0.000000, 0.000000],
4 | [ 0.190000, 0.000000, 0.000000],
5 | [ 0.225000, 0.006000, 0.000000],
6 | [ 0.140000, 0.000309, 0.000000],
7 | [ 0.170000, 0.000000, 0.000000],
8 | [ 0.198000, 0.037000, 0.043000],
9 | [ 0.080000, 0.000000, 0.000000],
10 | [ 0.152000, 0.000000, 0.000000],
11 | [ 0.178000, 0.000000, 0.000000],
12 | [ 0.072000, 0.000000, 0.000000],
13 | [ 0.198000, 0.037000, -0.043000],
14 | [ 0.080000, 0.000000, 0.001517],
15 | [ 0.152000, 0.000000, 0.000000],
16 | [ 0.178000, 0.000000, 0.000000],
17 | [ 0.072000, 0.000000, 0.000000],
18 | [ 0.059843, -0.076660, 0.047888],
19 | [ 0.160000, 0.000000, 0.000000],
20 | [ 0.180000, 0.000000, 0.000000],
21 | [ 0.000000, -0.108000, 0.000000],
22 | [ 0.059843, -0.076660, -0.047888],
23 | [ 0.160000, 0.000000, 0.000000],
24 | [ 0.180000, 0.000000, 0.000000],
25 | [ 0.000000, -0.108000, 0.000000],
26 | [ 0.068370, -0.007226, 0.000000],
27 | [ 0.120000, 0.000000, 0.000000],
28 | [ 0.120000, 0.000000, 0.000000]
29 | ]
30 |
31 | parents: [-1, 0, 1, 2, 3, 4, 2, 6, 7, 8, 9, 2, 11, 12, 13, 14, 0, 16, 17, 18, 0, 20, 21, 22, 0, 24, 25]
32 |
33 | joint_names: ['Hips', 'Spine', 'Spine1', 'Neck', 'Head', 'Head_End', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LHand_End', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RHand_End', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LFoot_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RFoot_End', 'Tail', 'Tail1', 'Tail_End']
34 |
--------------------------------------------------------------------------------
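A quick way to sanity-check this skeleton is to walk `parents` and print the joint hierarchy. A minimal sketch, assuming it is run from the repository root:
```python
import yaml

with open('configs/dog_skel.yaml') as f:
    skel = yaml.safe_load(f)

def print_tree(joint, depth=0):
    # Recursively print each joint indented under its parent.
    print('  ' * depth + skel['joint_names'][joint])
    for child, parent in enumerate(skel['parents']):
        if parent == joint:
            print_tree(child, depth + 1)

print_tree(0)  # 'Hips' is the root (parents[0] == -1)
```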
/configs/evaluation.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: False
2 | verbose: True
3 |
4 | is_train: True
5 | log: True
6 |
7 | epoch_begin: 0
8 | epoch_num: 1000
9 | checkpoint: 100
10 |
11 | batch_size: 16
12 | num_workers: 0
13 |
14 | dataset_dir: ./data/amass/generative
15 | save_dir: ./outputs/prior
16 | bvh_viz: False
17 |
18 | output_trans: True
19 | pretrained_gmp:
20 |
21 | adam_optimizer: True
22 | learning_rate: 0.0001
23 | weight_decay: 0.0001
24 | geodesic_loss: True
25 | l1_loss: True
26 | lambda_rotmat: 1
27 | lambda_pos: 10
28 | lambda_orient: 1
29 |
30 | lambda_v: 1
31 | lambda_up: 1
32 | lambda_trans: 1
33 | lambda_contacts: 0.5
34 | lambda_kl: 0
35 |
36 | annealing_cycles: 50
37 | annealing_warmup: 25
38 |
39 | data:
40 | fps: 30
41 | clip_length: 128
42 | gender: male
43 | up: z
44 | root_transform: True
45 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 'root_orient', 'root_vel']
46 |
47 | local_prior:
48 | activation: tanh
49 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang)
50 | use_residual_blocks: True
51 | z_dim: 1024
52 | temporal_scale: 8
53 | kernel_size: 4
54 | num_layers: 4
55 | skeleton_dist: 2
56 | extra_conv: 0
57 | padding_mode: reflect
58 | skeleton_pool: mean
59 | upsampling: linear
60 |
61 | global_prior:
62 | activation: tanh
63 | in_channels: 9 # 6 (root orient) + 3 (root vel)
64 | kernel_size: 4
65 | temporal_scale: 8
66 | z_dim: 256
67 |
68 | nemf:
69 | siren: False
70 | skip_connection: True
71 | norm_layer: True
72 | bandwidth: 7
73 | hidden_neuron: 1024
74 | local_z: 1024
75 | global_z: 256
76 | local_output: 144
77 | global_output: 18 # 6 + 3 + 1 + 8
78 |
79 |
80 | scheduler:
81 | name:
82 |
--------------------------------------------------------------------------------
/configs/generative.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: True
2 | verbose: True
3 |
4 | is_train: True
5 | log: True
6 |
7 | epoch_begin: 0
8 | epoch_num: 1000
9 | checkpoint: 100
10 |
11 | batch_size: 16
12 | num_workers: 0
13 |
14 | dataset_dir: ./data/amass/generative
15 | save_dir: ./outputs/generative
16 | bvh_viz: False
17 |
18 | output_trans: False
19 | # pretrained_gmp: gmp.yaml
20 | pretrained_gmp:
21 |
22 | adam_optimizer: False
23 | learning_rate: 0.0001
24 | weight_decay: 0.0001
25 | geodesic_loss: True
26 | l1_loss: True
27 | lambda_rotmat: 1
28 | lambda_pos: 10
29 | lambda_orient: 1
30 |
31 | lambda_v: 1
32 | lambda_up: 1
33 | lambda_trans: 1
34 | lambda_contacts: 0.5
35 | lambda_kl: 0.00001
36 |
37 | annealing_cycles: 50
38 | annealing_warmup: 25
39 |
40 | data:
41 | fps: 30
42 | clip_length: 128
43 | gender: male
44 | up: z
45 | root_transform: True
46 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 'root_orient', 'root_vel']
47 |
48 | local_prior:
49 | activation: tanh
50 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang)
51 | use_residual_blocks: True
52 | z_dim: 1024
53 | temporal_scale: 8
54 | kernel_size: 4
55 | num_layers: 4
56 | skeleton_dist: 2
57 | extra_conv: 0
58 | padding_mode: reflect
59 | skeleton_pool: mean
60 | upsampling: linear
61 |
62 | global_prior:
63 | activation: tanh
64 | in_channels: 6
65 | kernel_size: 4
66 | temporal_scale: 8
67 | z_dim: 256
68 |
69 | nemf:
70 | siren: False
71 | skip_connection: True
72 | norm_layer: True
73 | bandwidth: 7
74 | hidden_neuron: 1024
75 | local_z: 1024
76 | global_z: 256
77 | local_output: 144
78 | global_output: 6
79 |
80 |
81 | scheduler:
82 | name:
83 |
--------------------------------------------------------------------------------
/configs/gmp.yaml:
--------------------------------------------------------------------------------
1 | multi_gpu: True
2 | verbose: True
3 |
4 | is_train: True
5 | log: True
6 |
7 | epoch_begin: 0
8 | epoch_num: 1000
9 | checkpoint: 100
10 |
11 | batch_size: 16
12 | num_workers: 0
13 |
14 | dataset_dir: ./data/amass/gmp
15 | save_dir: ./outputs/gmp
16 |
17 | learning_rate: 0.0001
18 | weight_decay: 0.0001
19 | l1_loss: True
20 | lambda_v: 1
21 | lambda_up: 1
22 | lambda_trans: 1
23 | lambda_contacts: 0.5
24 |
25 | data:
26 | fps: 30
27 | clip_length: 128
28 | gender: male
29 | up: z
30 | normalize: ['pos', 'velocity', 'rot6d', 'angular']
31 |
32 | global_motion:
33 | activation: relu
34 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang)
35 | out_channels: 12 # 3 + 1 + 8
36 | use_residual_blocks: True
37 | kernel_size: 15
38 | num_layers: 3
39 | skeleton_dist: 1
40 | extra_conv: 0
41 | padding_mode: reflect
42 | skeleton_pool: mean
43 |
44 | scheduler:
45 | name:
--------------------------------------------------------------------------------
/configs/smpl.yaml:
--------------------------------------------------------------------------------
1 | smpl_body_model: ./data/body_models/smpl
2 |
3 | offsets:
4 | male: [
5 | [-0.002174, 0.972724, 0.028584],
6 | [ 0.058581, -0.082280, -0.017664],
7 | [ 0.043451, -0.386469, 0.008037],
8 | [-0.014790, -0.426874, -0.037428],
9 | [ 0.041054, -0.060286, 0.122042],
10 | [-0.060310, -0.090513, -0.013543],
11 | [-0.043257, -0.383688, -0.004843],
12 | [ 0.019056, -0.420046, -0.034562],
13 | [-0.034840, -0.062106, 0.130323],
14 | [ 0.004439, 0.124404, -0.038385],
15 | [ 0.004488, 0.137956, 0.026820],
16 | [-0.002265, 0.056032, 0.002855],
17 | [-0.013390, 0.211635, -0.033468],
18 | [ 0.010113, 0.088937, 0.050410],
19 | [ 0.071702, 0.114000, -0.018898],
20 | [ 0.122921, 0.045205, -0.019046],
21 | [ 0.255332, -0.015649, -0.022946],
22 | [ 0.265709, 0.012698, -0.007375],
23 | [ 0.086691, -0.010636, -0.015594],
24 | [-0.082954, 0.112472, -0.023707],
25 | [-0.113228, 0.046853, -0.008472],
26 | [-0.260127, -0.014369, -0.031269],
27 | [-0.269108, 0.006794, -0.006027],
28 | [-0.088754, -0.008652, -0.010107]
29 | ]
30 | female: [
31 | [-0.000876, 0.909315, 0.027821],
32 | [ 0.071361, -0.089584, -0.008046],
33 | [ 0.030669, -0.364209, -0.006689],
34 | [-0.011554, -0.383348, -0.043502],
35 | [ 0.023338, -0.054645, 0.114370],
36 | [-0.069012, -0.088960, -0.004796],
37 | [-0.036152, -0.370650, -0.009185],
38 | [ 0.014029, -0.383638, -0.041892],
39 | [-0.022043, -0.046410, 0.117900],
40 | [-0.002508, 0.103257, -0.022185],
41 | [ 0.003581, 0.127658, -0.001713],
42 | [ 0.002027, 0.049072, 0.027867],
43 | [-0.001963, 0.208243, -0.049765],
44 | [ 0.003517, 0.062322, 0.050205],
45 | [ 0.075298, 0.117780, -0.036875],
46 | [ 0.085317, 0.031739, -0.007293],
47 | [ 0.251247, -0.011967, -0.027518],
48 | [ 0.238019, 0.009007, 0.000044],
49 | [ 0.079668, -0.009683, -0.013206],
50 | [-0.077033, 0.115606, -0.041811],
51 | [-0.089203, 0.032785, -0.009802],
52 | [-0.245990, -0.013152, -0.020162],
53 | [-0.245177, 0.008622, -0.003532],
54 | [-0.080400, -0.007248, -0.010419]
55 | ]
56 |
57 | parents: [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 17, 11, 19, 20, 21, 22]
58 |
59 | joint_names: ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Foot', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Foot', 'Spine1', 'Spine2', 'Spine3', 'Neck', 'Head', 'L_Collar', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Collar', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand']
60 |
61 | joints_to_use: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37]
62 |
63 | leaf_joints: [10, 11, 15, 22, 23]
64 |
65 | lfoot_index: [7, 10]
66 |
67 | rfoot_index: [8, 11]
--------------------------------------------------------------------------------
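The `joints_to_use` list selects 24 joints out of the 52-joint SMPL-H skeleton (indices 0-21 are body joints; 22 and 37 correspond to hand joints standing in for `L_Hand`/`R_Hand`). `src/arguments.py` expands these indices into per-axis positions of the 156-dim pose vector; a small sketch of that mapping:
```python
import numpy as np

joints_to_use = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37])
# 52 joints x 3 axis-angle components = 156 pose parameters; pick the
# three components belonging to each selected joint (as arguments.py does).
pose_indices = np.arange(0, 156).reshape((-1, 3))[joints_to_use].reshape(-1)
print(pose_indices.shape)  # (72,) = 24 joints x 3
print(pose_indices[-3:])   # [111 112 113] -> components of joint 37
```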
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/data/.gitkeep
--------------------------------------------------------------------------------
/outputs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/outputs/.gitkeep
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/nghorbani/human_body_prior
2 | git+https://github.com/nghorbani/configer
3 |
4 | autopep8
5 | pylint
6 | torch==1.9.0
7 | matplotlib
8 | plyfile
9 | scikit-learn
10 | tensorboard
11 | pandas
12 | tqdm
13 | pyyaml
14 | pybullet
15 | numba
16 | umap-learn[plot]
17 | transformations==2019.4.22
18 | torchgeometry==0.1.2
19 | setuptools==59.5.0
--------------------------------------------------------------------------------
/src/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import yaml
5 |
6 |
7 | class Struct:
8 | def __init__(self, **entries):
9 | self.__dict__.update(entries)
10 |
11 |
12 | class Arguments:
13 | def __init__(self, config_path, filename='default.yaml'):
14 | with open(os.path.join(config_path, 'smpl.yaml'), 'r') as f:
15 | smpl = yaml.safe_load(f)
16 | self.smpl = Struct(**smpl)
17 | self.smpl.offsets['male'] = np.array(self.smpl.offsets['male'])
18 | self.smpl.offsets['female'] = np.array(self.smpl.offsets['female'])
19 | self.smpl.parents = np.array(self.smpl.parents).astype(np.int32)
20 | self.smpl.joint_num = len(self.smpl.joints_to_use)
21 | self.smpl.joints_to_use = np.array(self.smpl.joints_to_use)
22 |         self.smpl.joints_to_use = np.arange(0, 156).reshape((-1, 3))[self.smpl.joints_to_use].reshape(-1)  # expand each joint index into its three axis-angle components of the 156-dim SMPL-H pose vector
23 |
24 | with open(os.path.join(config_path, 'dog_skel.yaml'), 'r') as f:
25 | animal = yaml.safe_load(f)
26 | self.animal = Struct(**animal)
27 | self.animal.offsets = np.array(self.animal.offsets)
28 | self.animal.parents = np.array(self.animal.parents).astype(np.int32)
29 | self.animal.joint_num = len(self.animal.parents)
30 |
31 | self.filename = os.path.splitext(filename)[0]
32 | with open(os.path.join(config_path, filename), 'r') as f:
33 | config = yaml.safe_load(f)
34 |
35 | for key, value in config.items():
36 | if isinstance(value, dict):
37 | setattr(self, key, Struct(**value))
38 | else:
39 | setattr(self, key, value)
40 |
41 | self.json = config
42 |
--------------------------------------------------------------------------------
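Usage sketch for `Arguments`: nested YAML sections become attribute-accessible `Struct`s, so experiment code can read e.g. `args.data.fps`. Assuming `src/` is on `sys.path` and the working directory is the repository root:
```python
from arguments import Arguments

args = Arguments('./configs', filename='generative.yaml')
print(args.data.fps)        # 30
print(args.nemf.local_z)    # 1024
print(args.smpl.joint_num)  # 24, the number of joints selected by joints_to_use
```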
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/src/datasets/__init__.py
--------------------------------------------------------------------------------
/src/datasets/animal.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import sys
4 |
5 | file_path = os.path.dirname(os.path.realpath(__file__))
6 | sys.path.append(os.path.join(file_path, '..'))
7 |
8 | import holden.BVH as BVH
9 | import numpy as np
10 | import torch
11 | from arguments import Arguments
12 | from human_body_prior.tools.omni_tools import log2file, makepath
13 | from nemf.fk import ForwardKinematicsLayer
14 | from rotations import matrix_to_rotation_6d, quaternion_to_matrix, rotation_6d_to_matrix
15 | from torch.utils.data import Dataset
16 | from tqdm import tqdm
17 | from utils import build_canonical_frame, estimate_angular_velocity, estimate_linear_velocity
18 |
19 |
20 | def dump_animal2single(animal_data_dir, logger):
21 | fk = ForwardKinematicsLayer(parents=args.animal.parents, positions=args.animal.offsets)
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 |
24 | bvh_fnames = glob.glob(os.path.join(animal_data_dir, '*.bvh'))
25 | index = 0
26 | for bvh_fname in tqdm(bvh_fnames):
27 | try:
28 | anim, _, ftime = BVH.load(bvh_fname)
29 |         except Exception:
30 |             logger('Could not read %s! Skipping...' % bvh_fname)
31 | continue
32 |
33 | N = len(anim.rotations)
34 |
35 | fps = np.around(1.0 / ftime).astype(np.int32)
36 | logger(f'{bvh_fname} frame: {N}\tFPS: {fps}')
37 |
38 | data_rotation = anim.rotations.qs
39 | data_translation = anim.positions[:, 0] / 100.0
40 |
41 | # insert identity quaternion
42 | data_rotation = np.insert(data_rotation, [5, 9, 13, 16, 19, 21], [1, 0, 0, 0], axis=1)
43 | poses = torch.from_numpy(np.asarray(data_rotation, np.float32)).to(device) # quaternion (T, J, 4)
44 | trans = torch.from_numpy(np.asarray(data_translation, np.float32)).to(device) # global translation (T, 3)
45 |
46 | # Compute necessary data for model training.
47 | rotmat = quaternion_to_matrix(poses) # rotation matrix (T, J, 3, 3)
48 | root_orient = rotmat[:, 0].clone()
49 | root_orient = matrix_to_rotation_6d(root_orient) # root orientation (T, 6)
50 | if args.unified_orientation:
51 |             identity = torch.eye(3, device=device)  # stay on the same device as the data (works without CUDA)
52 | identity = identity.view(1, 3, 3).repeat(rotmat.shape[0], 1, 1)
53 | rotmat[:, 0] = identity
54 | rot6d = matrix_to_rotation_6d(rotmat) # 6D rotation representation (T, J, 6)
55 |
56 | rot_seq = rotmat.clone()
57 | angular = estimate_angular_velocity(rot_seq.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # angular velocity of all the joints (T, J, 3)
58 |
59 | pos, global_xform = fk(rot6d) # local joint positions (T, J, 3), global transformation matrix for each joint (T, J, 4, 4)
60 | pos = pos.contiguous()
61 | global_xform = global_xform.contiguous()
62 | velocity = estimate_linear_velocity(pos.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # linear velocity of all the joints (T, J, 3)
63 |
64 | if args.unified_orientation:
65 | root_rotation = rotation_6d_to_matrix(root_orient) # (T, 3, 3)
66 | root_rotation = root_rotation.unsqueeze(1).repeat(1, args.animal.joint_num, 1, 1) # (T, J, 3, 3)
67 | global_pos = torch.matmul(root_rotation, pos.unsqueeze(-1)).squeeze(-1)
68 | height = global_pos + trans.unsqueeze(1)
69 | else:
70 | height = pos + trans.unsqueeze(1)
71 | height = height[..., 'xyz'.index(args.data.up)] # (T, J)
72 | root_vel = estimate_linear_velocity(trans.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # linear velocity of the root joint (T, 3)
73 |
74 | out_posepath = makepath(os.path.join(work_dir, 'train', f'trans_{index}.pt'), isfile=True)
75 | torch.save(trans.detach().cpu(), out_posepath) # (T, 3)
76 | torch.save(root_vel.detach().cpu(), out_posepath.replace('trans', 'root_vel')) # (T, 3)
77 | torch.save(pos.detach().cpu(), out_posepath.replace('trans', 'pos')) # (T, J, 3)
78 | torch.save(rotmat.detach().cpu(), out_posepath.replace('trans', 'rotmat')) # (T, J, 3, 3)
79 | torch.save(height.detach().cpu(), out_posepath.replace('trans', 'height')) # (T, J)
80 |
81 | if args.canonical:
82 | forward = rotmat[:, 0, :, 2].clone()
83 | canonical_frame = build_canonical_frame(forward, up_axis=args.data.up)
84 | root_rotation = canonical_frame.transpose(-2, -1) # (T, 3, 3)
85 | root_rotation = root_rotation.unsqueeze(1).repeat(1, args.animal.joint_num, 1, 1) # (T, J, 3, 3)
86 |
87 | if args.data.up == 'x':
88 | theta = torch.atan2(forward[..., 2], forward[..., 1])
89 | elif args.data.up == 'y':
90 | theta = torch.atan2(forward[..., 0], forward[..., 2])
91 | else:
92 | theta = torch.atan2(forward[..., 1], forward[..., 0])
93 | dt = 1.0 / args.data.fps
94 | forward_ang = (theta[1:] - theta[:-1]) / dt
95 | forward_ang = torch.cat((forward_ang, forward_ang[-1:]), dim=-1)
96 |
97 | local_pos = torch.matmul(root_rotation, pos.unsqueeze(-1)).squeeze(-1)
98 | local_vel = torch.matmul(root_rotation, velocity.unsqueeze(-1)).squeeze(-1)
99 | local_rot = torch.matmul(root_rotation, global_xform[:, :, :3, :3])
100 | local_rot = matrix_to_rotation_6d(local_rot)
101 | local_ang = torch.matmul(root_rotation, angular.unsqueeze(-1)).squeeze(-1)
102 |
103 | torch.save(forward.detach().cpu(), out_posepath.replace('trans', 'forward'))
104 | torch.save(forward_ang.detach().cpu(), out_posepath.replace('trans', 'forward_ang'))
105 | torch.save(local_pos.detach().cpu(), out_posepath.replace('trans', 'local_pos'))
106 | torch.save(local_vel.detach().cpu(), out_posepath.replace('trans', 'local_vel'))
107 | torch.save(local_rot.detach().cpu(), out_posepath.replace('trans', 'local_rot'))
108 | torch.save(local_ang.detach().cpu(), out_posepath.replace('trans', 'local_ang'))
109 | else:
110 | global_xform = global_xform[:, :, :3, :3] # (T, J, 3, 3)
111 | global_xform = matrix_to_rotation_6d(global_xform) # (T, J, 6)
112 |
113 |             torch.save(rot6d.detach().cpu(), out_posepath.replace('trans', 'rot6d'))  # (T, J, 6)
114 |             torch.save(angular.detach().cpu(), out_posepath.replace('trans', 'angular'))  # (T, J, 3)
115 |             torch.save(global_xform.detach().cpu(), out_posepath.replace('trans', 'global_xform'))  # (T, J, 6)
116 |             torch.save(velocity.detach().cpu(), out_posepath.replace('trans', 'velocity'))  # (T, J, 3)
117 |             torch.save(root_orient.detach().cpu(), out_posepath.replace('trans', 'root_orient'))  # (T, 6)
118 |
119 | index += 1
120 |
121 |
122 | class Animal(Dataset):
123 | def __init__(self, dataset_dir):
124 | self.ds = {}
125 | for data_fname in glob.glob(os.path.join(dataset_dir, '*.pt')):
126 | k = os.path.basename(data_fname).split('-')[0]
127 | self.ds[k] = torch.load(data_fname)
128 |
129 | def __len__(self):
130 | return len(self.ds['trans'])
131 |
132 | def __getitem__(self, idx):
133 | data = {k: self.ds[k][idx] for k in self.ds.keys()}
134 |
135 | return data
136 |
137 |
138 | if __name__ == '__main__':
139 | animal_data_dir = '../AnimalData/'
140 |
141 | args = Arguments('./configs', filename='dog_mocap.yaml')
142 | work_dir = makepath(args.dataset_dir)
143 |
144 | log_name = os.path.join(work_dir, 'animal.log')
145 | if os.path.exists(log_name):
146 | os.remove(log_name)
147 | logger = log2file(log_name)
148 |
149 | logger('Start processing the animal mocap data ...')
150 | dump_animal2single(animal_data_dir, logger)
151 |
--------------------------------------------------------------------------------
/src/holden/AnimationStructure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sparse
3 | import Animation as Animation
4 |
5 |
6 | """ Maya Functions """
7 |
8 | def load_from_maya(root):
9 | """
10 | Load joint parents and names from maya
11 |
12 | Parameters
13 | ----------
14 |
15 | root : PyNode
16 | Root Maya Node
17 |
18 | Returns
19 | -------
20 |
21 | (names, parents) : ([str], (J) ndarray)
22 | List of joint names and array
23 | of indices representing the parent
24 | joint for each joint J.
25 |
26 | Joint index -1 is used to represent
27 | that there is no parent joint
28 | """
29 |
30 | import pymel.core as pm
31 |
32 | names = []
33 | parents = []
34 |
35 | def unload_joint(j, parents, par):
36 |
37 | id = len(names)
38 | names.append(j)
39 | parents.append(par)
40 |
41 | children = [c for c in j.getChildren() if
42 | isinstance(c, pm.nt.Transform) and
43 | not isinstance(c, pm.nt.Constraint) and
44 | not any(pm.listRelatives(c, s=True)) and
45 | (any(pm.listRelatives(c, ad=True, ap=False, type='joint')) or isinstance(c, pm.nt.Joint))]
46 |
47 | map(lambda c: unload_joint(c, parents, id), children)
48 |
49 | unload_joint(root, parents, -1)
50 |
51 | return (names, parents)
52 |
53 |
54 | """ Family Functions """
55 |
56 | def joints(parents):
57 | """
58 | Parameters
59 | ----------
60 |
61 | parents : (J) ndarray
62 | parents array
63 |
64 | Returns
65 | -------
66 |
67 | joints : (J) ndarray
68 | Array of joint indices
69 | """
70 | return np.arange(len(parents), dtype=int)
71 |
72 | def joints_list(parents):
73 | """
74 | Parameters
75 | ----------
76 |
77 | parents : (J) ndarray
78 | parents array
79 |
80 | Returns
81 | -------
82 |
83 | joints : [ndarray]
84 |         List of arrays of joint indices for
85 | each joint
86 | """
87 | return list(joints(parents)[:,np.newaxis])
88 |
89 | def parents_list(parents):
90 | """
91 | Parameters
92 | ----------
93 |
94 | parents : (J) ndarray
95 | parents array
96 |
97 | Returns
98 | -------
99 |
100 | parents : [ndarray]
101 |         List of arrays of joint indices for
102 | the parents of each joint
103 | """
104 | return list(parents[:,np.newaxis])
105 |
106 |
107 | def children_list(parents):
108 | """
109 | Parameters
110 | ----------
111 |
112 | parents : (J) ndarray
113 | parents array
114 |
115 | Returns
116 | -------
117 |
118 | children : [ndarray]
119 | List of arrays of joint indices for
120 | the children of each joint
121 | """
122 |
123 | def joint_children(i):
124 | return [j for j, p in enumerate(parents) if p == i]
125 |
126 | return list(map(lambda j: np.array(joint_children(j)), joints(parents)))
127 |
128 |
129 | def descendants_list(parents):
130 | """
131 | Parameters
132 | ----------
133 |
134 | parents : (J) ndarray
135 | parents array
136 |
137 | Returns
138 | -------
139 |
140 | descendants : [ndarray]
141 |         List of arrays of joint indices for
142 | the descendants of each joint
143 | """
144 |
145 | children = children_list(parents)
146 |
147 | def joint_descendants(i):
148 | return sum([joint_descendants(j) for j in children[i]], list(children[i]))
149 |
150 | return list(map(lambda j: np.array(joint_descendants(j)), joints(parents)))
151 |
152 |
153 | def ancestors_list(parents):
154 | """
155 | Parameters
156 | ----------
157 |
158 | parents : (J) ndarray
159 | parents array
160 |
161 | Returns
162 | -------
163 |
164 | ancestors : [ndarray]
165 |         List of arrays of joint indices for
166 | the ancestors of each joint
167 | """
168 |
169 |     descendants = descendants_list(parents)
170 |
171 | def joint_ancestors(i):
172 |         return [j for j in joints(parents) if i in descendants[j]]
173 |
174 | return list(map(lambda j: np.array(joint_ancestors(j)), joints(parents)))
175 |
176 |
177 | """ Mask Functions """
178 |
179 | def mask(parents, filter):
180 | """
181 |     Constructs a mask for a given filter
182 |
183 | A mask is a (J, J) ndarray truth table for a given
184 | condition over J joints. For example there
185 | may be a mask specifying if a joint N is a
186 | child of another joint M.
187 |
188 | This could be constructed into a mask using
189 | `m = mask(parents, children_list)` and the condition
190 | of childhood tested using `m[N, M]`.
191 |
192 | Parameters
193 | ----------
194 |
195 | parents : (J) ndarray
196 | parents array
197 |
198 | filter : (J) ndarray -> [ndarray]
199 | function that outputs a list of arrays
200 | of joint indices for some condition
201 |
202 | Returns
203 | -------
204 |
205 | mask : (N, N) ndarray
206 | boolean truth table of given condition
207 | """
208 | m = np.zeros((len(parents), len(parents))).astype(bool)
209 | jnts = joints(parents)
210 | fltr = filter(parents)
211 | for i,f in enumerate(fltr): m[i,:] = np.any(jnts[:,np.newaxis] == f[np.newaxis,:], axis=1)
212 | return m
213 |
214 | def joints_mask(parents): return np.eye(len(parents)).astype(bool)
215 | def children_mask(parents): return mask(parents, children_list)
216 | def parents_mask(parents): return mask(parents, parents_list)
217 | def descendants_mask(parents): return mask(parents, descendants_list)
218 | def ancestors_mask(parents): return mask(parents, ancestors_list)
219 |
220 | """ Search Functions """
221 |
222 | def joint_chain_ascend(parents, start, end):
223 | chain = []
224 | while start != end:
225 | chain.append(start)
226 | start = parents[start]
227 | chain.append(end)
228 | return np.array(chain, dtype=int)
229 |
230 |
231 | """ Constraints """
232 |
233 | def constraints(anim, **kwargs):
234 | """
235 | Constraint list for Animation
236 |
237 | This constraint list can be used in the
238 | VerletParticle solver to constrain
239 |     an animation's global joint positions.
240 |
241 | Parameters
242 | ----------
243 |
244 | anim : Animation
245 | Input animation
246 |
247 | masses : (F, J) ndarray
248 | Optional list of masses
249 | for joints J across frames F
250 | defaults to weighting by
251 | vertical height
252 |
253 | Returns
254 | -------
255 |
256 | constraints : [(int, int, (F, J) ndarray, (F, J) ndarray, (F, J) ndarray)]
257 | A list of constraints in the format:
258 | (Joint1, Joint2, Masses1, Masses2, Lengths)
259 |
260 | """
261 |
262 | masses = kwargs.pop('masses', None)
263 |
264 | children = children_list(anim.parents)
265 | constraints = []
266 |
267 | points_offsets = Animation.offsets_global(anim)
268 | points = Animation.positions_global(anim)
269 |
270 | if masses is None:
271 | masses = 1.0 / (0.1 + np.absolute(points_offsets[:,1]))
272 | masses = masses[np.newaxis].repeat(len(anim), axis=0)
273 |
274 |     for j in range(anim.shape[1]):
275 | 
276 |         """ Add constraints between all joints and their children """
277 |         for c0 in children[j]:
278 | 
279 |             dists = np.sum((points[:, c0] - points[:, j])**2.0, axis=1)**0.5
280 |             constraints.append((c0, j, masses[:,c0], masses[:,j], dists))
281 | 
282 |             """ Add constraints between all children of joint """
283 |             for c1 in children[j]:
284 |                 if c0 == c1: continue
285 | 
286 |                 dists = np.sum((points[:, c0] - points[:, c1])**2.0, axis=1)**0.5
287 |                 constraints.append((c0, c1, masses[:,c0], masses[:,c1], dists))
288 |
289 | return constraints
290 |
291 | """ Graph Functions """
292 |
293 | def graph(anim):
294 | """
295 | Generates a weighted adjacency matrix
296 | using local joint distances along
297 | the skeletal structure.
298 |
299 | Joints which are not connected
300 | are assigned the weight `0`.
301 |
302 | Joints which actually have zero distance
303 | between them, but are still connected, are
304 | perturbed by some minimal amount.
305 |
306 | The output of this routine can be used
307 | with the `scipy.sparse.csgraph`
308 | routines for graph analysis.
309 |
310 | Parameters
311 | ----------
312 |
313 | anim : Animation
314 | input animation
315 |
316 | Returns
317 | -------
318 |
319 | graph : (N, N) ndarray
320 | weight adjacency matrix using
321 | local distances along the
322 | skeletal structure from joint
323 |         N to joint M. Joints which are not
324 |         directly connected are assigned
325 | the weight `0`.
326 | """
327 |
328 |     graph = np.zeros((anim.shape[1], anim.shape[1]))  # shape must be passed as a tuple
329 | lengths = np.sum(anim.offsets**2.0, axis=1)**0.5 + 0.001
330 |
331 | for i,p in enumerate(anim.parents):
332 | if p == -1: continue
333 | graph[i,p] = lengths[p]
334 | graph[p,i] = lengths[p]
335 |
336 | return graph
337 |
338 |
339 | def distances(anim):
340 | """
341 | Generates a distance matrix for
342 | pairwise joint distances along
343 | the skeletal structure
344 |
345 | Parameters
346 | ----------
347 |
348 | anim : Animation
349 | input animation
350 |
351 | Returns
352 | -------
353 |
354 | distances : (N, N) ndarray
355 | array of pairwise distances
356 | along skeletal structure
357 | from some joint N to some
358 | joint M
359 | """
360 |
361 | distances = np.zeros((anim.shape[1], anim.shape[1]))
362 | generated = distances.copy().astype(bool)
363 |
364 | joint_lengths = np.sum(anim.offsets**2.0, axis=1)**0.5
365 |     joint_children = children_list(anim.parents)
366 |     joint_parents = parents_list(anim.parents)
367 |
368 | def find_distance(distances, generated, prev, i, j):
369 |
370 | """ If root, identity, or already generated, return """
371 | if j == -1: return (0.0, True)
372 | if j == i: return (0.0, True)
373 | if generated[i,j]: return (distances[i,j], True)
374 |
375 | """ Find best distances along parents and children """
376 | par_dists = [(joint_lengths[j], find_distance(distances, generated, j, i, p)) for p in joint_parents[j] if p != prev]
377 | out_dists = [(joint_lengths[c], find_distance(distances, generated, j, i, c)) for c in joint_children[j] if c != prev]
378 |
379 | """ Check valid distance and not dead end """
380 | par_dists = [a + d for (a, (d, f)) in par_dists if f]
381 | out_dists = [a + d for (a, (d, f)) in out_dists if f]
382 |
383 | """ All dead ends """
384 | if (out_dists + par_dists) == []: return (0.0, False)
385 |
386 | """ Get minimum path """
387 | dist = min(out_dists + par_dists)
388 | distances[i,j] = dist; distances[j,i] = dist
389 |         generated[i,j] = True; generated[j,i] = True; return (dist, True)  # return the computed distance
390 |
391 |     for i in range(anim.shape[1]):
392 |         for j in range(anim.shape[1]):
393 |             find_distance(distances, generated, -1, i, j)
394 |
395 | return distances
396 |
397 | def edges(parents):
398 | """
399 | Animation structure edges
400 |
401 | Parameters
402 | ----------
403 |
404 | parents : (J) ndarray
405 | parents array
406 |
407 | Returns
408 | -------
409 |
410 | edges : (M, 2) ndarray
411 | array of pairs where each
412 |         pair contains the indices of two joints
413 |         which correspond to an edge in the
414 | joint structure going from parent to child.
415 | """
416 |
417 | return np.array(list(zip(parents, joints(parents)))[1:])
418 |
419 |
420 | def incidence(parents):
421 | """
422 | Incidence Matrix
423 |
424 | Parameters
425 | ----------
426 |
427 | parents : (J) ndarray
428 | parents array
429 |
430 | Returns
431 | -------
432 |
433 | incidence : (N, M) ndarray
434 |
435 | Matrix of N joint positions by
436 | M edges which each entry is either
437 | 1 or -1 and multiplication by the
438 |         joint positions returns an
439 | array of vectors along each edge
440 | of the structure
441 | """
442 |
443 | es = edges(parents)
444 |
445 |     inc = np.zeros((len(parents)-1, len(parents))).astype(int)  # np.int was removed from NumPy
446 | for i, e in enumerate(es):
447 | inc[i,e[0]] = 1
448 | inc[i,e[1]] = -1
449 |
450 | return inc.T
451 |
--------------------------------------------------------------------------------
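Usage sketch for the family functions above on a toy three-joint chain, assuming `src/holden` is on `sys.path`:
```python
import numpy as np
from AnimationStructure import children_list, descendants_list, joint_chain_ascend

parents = np.array([-1, 0, 1])  # root -> child -> grandchild
print(children_list(parents))             # children of each joint: [array([1]), array([2]), array([], ...)]
print(descendants_list(parents)[0])       # [1 2]: every joint below the root
print(joint_chain_ascend(parents, 2, 0))  # [2 1 0]: path from the grandchild up to the root
```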
/src/holden/BVH.py:
--------------------------------------------------------------------------------
1 | from Quaternions import Quaternions
2 | from Animation import Animation
3 | import re
4 | import numpy as np
5 | import sys
6 |
7 |
8 | channelmap = {
9 | 'Xrotation': 'x',
10 | 'Yrotation': 'y',
11 | 'Zrotation': 'z'
12 | }
13 |
14 | channelmap_inv = {
15 | 'x': 'Xrotation',
16 | 'y': 'Yrotation',
17 | 'z': 'Zrotation',
18 | }
19 |
20 | ordermap = {
21 | 'x': 0,
22 | 'y': 1,
23 | 'z': 2,
24 | }
25 |
26 |
27 | def load(filename, start=None, end=None, order=None, world=False):
28 | """
29 | Reads a BVH file and constructs an animation
30 |
31 | Parameters
32 | ----------
33 | filename: str
34 | File to be opened
35 |
36 | start : int
37 | Optional Starting Frame
38 |
39 | end : int
40 | Optional Ending Frame
41 |
42 | order : str
43 | Optional Specifier for joint order.
44 |         Given as a string, e.g. 'xyz', 'zxy'
45 |
46 | world : bool
47 | If set to true euler angles are applied
48 | together in world space rather than local
49 | space
50 |
51 | Returns
52 | -------
53 |
54 | (animation, joint_names, frametime)
55 |         Tuple of loaded animation, joint names, and frame time
56 | """
57 |
58 | f = open(filename, "r")
59 |
60 | i = 0
61 | active = -1
62 | end_site = False
63 |
64 | names = []
65 | orients = Quaternions.id(0)
66 | offsets = np.array([]).reshape((0, 3))
67 | parents = np.array([], dtype=int)
68 |
69 | for line in f:
70 |
71 | if "HIERARCHY" in line:
72 | continue
73 | if "MOTION" in line:
74 | continue
75 |
76 | rmatch = re.match(r"ROOT (\w+)", line)
77 | if rmatch:
78 | names.append(rmatch.group(1))
79 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
80 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
81 | parents = np.append(parents, active)
82 | active = (len(parents)-1)
83 | continue
84 |
85 | if "{" in line:
86 | continue
87 |
88 | if "}" in line:
89 | if end_site:
90 | end_site = False
91 | else:
92 | active = parents[active]
93 | continue
94 |
95 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line)
96 | if offmatch:
97 | if not end_site:
98 | offsets[active] = np.array([list(map(float, offmatch.groups()))])
99 | continue
100 |
101 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
102 | if chanmatch:
103 | channels = int(chanmatch.group(1))
104 | if order is None:
105 | channelis = 0 if channels == 3 else 3
106 | channelie = 3 if channels == 3 else 6
107 | parts = line.split()[2+channelis:2+channelie]
108 | if any([p not in channelmap for p in parts]):
109 | continue
110 | order = "".join([channelmap[p] for p in parts])
111 | continue
112 |
113 |         jmatch = re.match(r"\s*JOINT\s+(\w+)", line)
114 | if jmatch:
115 | names.append(jmatch.group(1))
116 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
117 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
118 | parents = np.append(parents, active)
119 | active = (len(parents)-1)
120 | continue
121 |
122 | if "End Site" in line:
123 | end_site = True
124 | continue
125 |
126 |         fmatch = re.match(r"\s*Frames:\s+(\d+)", line)
127 | if fmatch:
128 | if start and end:
129 | fnum = (end - start)-1
130 | else:
131 | fnum = int(fmatch.group(1))
132 | jnum = len(parents)
133 | # result: [fnum, J, 3]
134 | positions = offsets[np.newaxis].repeat(fnum, axis=0)
135 | # result: [fnum, len(orients), 3]
136 | rotations = np.zeros((fnum, len(orients), 3))
137 | continue
138 |
139 |         fmatch = re.match(r"\s*Frame Time:\s+([\d\.]+)", line)
140 | if fmatch:
141 | frametime = float(fmatch.group(1))
142 | continue
143 |
144 | if (start and end) and (i < start or i >= end-1):
145 | i += 1
146 | continue
147 |
148 | dmatch = line.strip().split()
149 | if dmatch:
150 | data_block = np.array(list(map(float, dmatch)))
151 | N = len(parents)
152 | fi = i - start if start else i
153 | if channels == 3:
154 | # This should be root positions[0:1] & all rotations
155 | positions[fi, 0:1] = data_block[0:3]
156 | rotations[fi, :] = data_block[3:].reshape(N, 3)
157 | elif channels == 6:
158 | data_block = data_block.reshape(N, 6)
159 | # fill in all positions
160 | positions[fi, :] = data_block[:, 0:3]
161 | rotations[fi, :] = data_block[:, 3:6]
162 | elif channels == 9:
163 | positions[fi, 0] = data_block[0:3]
164 | data_block = data_block[3:].reshape(N-1, 9)
165 | rotations[fi, 1:] = data_block[:, 3:6]
166 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
167 | else:
168 | raise Exception("Too many channels! %i" % channels)
169 |
170 | i += 1
171 |
172 | f.close()
173 |
174 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world)
175 |
176 | return (Animation(rotations, positions, orients, offsets, parents), names, frametime)
177 |
178 |
179 | def load_bfa(filename, start=None, end=None, order=None, world=False):
180 | """
181 | Reads a BVH file and constructs an animation
182 |
183 |     !!! Reads from BFA and replaces the end sites of the arms with two joints (with identity rotation)
184 |
185 | Parameters
186 | ----------
187 | filename: str
188 | File to be opened
189 |
190 | start : int
191 | Optional Starting Frame
192 |
193 | end : int
194 | Optional Ending Frame
195 |
196 | order : str
197 | Optional Specifier for joint order.
198 |         Given as a string, e.g. 'xyz', 'zxy'
199 |
200 | world : bool
201 | If set to true euler angles are applied
202 | together in world space rather than local
203 | space
204 |
205 | Returns
206 | -------
207 |
208 | (animation, joint_names, frametime)
209 |         Tuple of loaded animation, joint names, and frame time
210 | """
211 |
212 | f = open(filename, "r")
213 |
214 | i = 0
215 | active = -1
216 | end_site = False
217 |
218 | hand_idx = [9, 14]
219 |
220 | names = []
221 | orients = Quaternions.id(0)
222 | offsets = np.array([]).reshape((0, 3))
223 | parents = np.array([], dtype=int)
224 |
225 | for line in f:
226 |
227 | if "HIERARCHY" in line:
228 | continue
229 | if "MOTION" in line:
230 | continue
231 |
232 | rmatch = re.match(r"ROOT (\w+)", line)
233 | if rmatch:
234 | names.append(rmatch.group(1))
235 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
236 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
237 | parents = np.append(parents, active)
238 | active = (len(parents)-1)
239 | continue
240 |
241 | if "{" in line:
242 | continue
243 |
244 | if "}" in line:
245 | if end_site:
246 | end_site = False
247 | else:
248 | active = parents[active]
249 | continue
250 |
251 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line)
252 | if offmatch:
253 | if not end_site:
254 | offsets[active] = np.array([list(map(float, offmatch.groups()))])
255 | """
256 | else:
257 | print("active = ", active)
258 | if active in hand_idx:
259 | offsets[active] = np.array([list(map(float, offmatch.groups()))])
260 | """
261 | continue
262 |
263 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
264 | if chanmatch:
265 | channels = int(chanmatch.group(1))
266 | if order is None:
267 | channelis = 0 if channels == 3 else 3
268 | channelie = 3 if channels == 3 else 6
269 | parts = line.split()[2+channelis:2+channelie]
270 | if any([p not in channelmap for p in parts]):
271 | continue
272 | order = "".join([channelmap[p] for p in parts])
273 | continue
274 |
275 |         jmatch = re.match(r"\s*JOINT\s+(\w+)", line)
276 | if jmatch:
277 | names.append(jmatch.group(1))
278 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
279 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
280 | parents = np.append(parents, active)
281 | active = (len(parents)-1)
282 | continue
283 |
284 | if "End Site" in line:
285 | if active + 1 in hand_idx:
286 | print("parent:", names[-1])
287 | name = "LeftHandIndex" if active + 1 == hand_idx[0] else "RightHandIndex"
288 | names.append(name)
289 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
290 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
291 | parents = np.append(parents, active)
292 | active = (len(parents)-1)
293 | else:
294 | end_site = True
295 | continue
296 |
297 |         fmatch = re.match(r"\s*Frames:\s+(\d+)", line)
298 | if fmatch:
299 | if start and end:
300 | fnum = (end - start)-1
301 | else:
302 | fnum = int(fmatch.group(1))
303 | jnum = len(parents)
304 | # result: [fnum, J, 3]
305 | positions = offsets[np.newaxis].repeat(fnum, axis=0)
306 | # result: [fnum, len(orients), 3]
307 | rotations = np.zeros((fnum, len(orients), 3))
308 | continue
309 |
310 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line)
311 | if fmatch:
312 | frametime = float(fmatch.group(1))
313 | continue
314 |
315 | if (start and end) and (i < start or i >= end-1):
316 | i += 1
317 | continue
318 |
319 | dmatch = line.strip().split()
320 | if dmatch:
321 | data_block = np.array(list(map(float, dmatch)))
322 | N = len(parents)
323 | fi = i - start if start else i
324 | if channels == 3:
325 | # This should be root positions[0:1] & all rotations
326 | positions[fi, 0:1] = data_block[0:3]
327 | tmp = data_block[3:].reshape(N - 2, 3)
328 | tmp = np.concatenate([tmp[:hand_idx[0]],
329 | np.array([[0, 0, 0]]),
330 | tmp[hand_idx[0]: hand_idx[1] - 1],
331 | np.array([[0, 0, 0]]),
332 | tmp[hand_idx[1] - 1:]], axis=0)
333 | rotations[fi, :] = tmp.reshape(N, 3)
334 | elif channels == 6:
335 | data_block = data_block.reshape(N, 6)
336 | # fill in all positions
337 | positions[fi, :] = data_block[:, 0:3]
338 | rotations[fi, :] = data_block[:, 3:6]
339 | elif channels == 9:
340 | positions[fi, 0] = data_block[0:3]
341 | data_block = data_block[3:].reshape(N-1, 9)
342 | rotations[fi, 1:] = data_block[:, 3:6]
343 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
344 | else:
345 | raise Exception("Too many channels! %i" % channels)
346 |
347 | i += 1
348 |
349 | f.close()
350 |
351 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world)
352 |
353 | return (Animation(rotations, positions, orients, offsets, parents), names, frametime)
354 |
355 |
356 | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True):
357 | """
358 | Saves an Animation to file as BVH
359 |
360 | Parameters
361 | ----------
362 | filename: str
363 | File to be saved to
364 |
365 | anim : Animation
366 | Animation to save
367 |
368 | names : [str]
369 | List of joint names
370 |
371 | order : str
372 | Optional specifier for joint order,
373 | given as a string, e.g. 'xyz', 'zxy'
374 |
375 | frametime : float
376 | Optional Animation Frame time
377 |
378 | positions : bool
379 | Optional specifier to save bone
380 | positions for each frame
381 |
382 | orients : bool
383 | Multiply joint orients to the rotations
384 | before saving.
385 |
386 | """
387 |
388 | if names is None:
389 | names = ["joint_" + str(i) for i in range(len(anim.parents))]
390 |
391 | with open(filename, 'w') as f:
392 |
393 | t = ""
394 | f.write("%sHIERARCHY\n" % t)
395 | f.write("%sROOT %s\n" % (t, names[0]))
396 | f.write("%s{\n" % t)
397 | t += '\t'
398 |
399 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0, 0], anim.offsets[0, 1], anim.offsets[0, 2]))
400 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" %
401 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
402 |
403 | for i in range(anim.shape[1]):
404 | if anim.parents[i] == 0:
405 | t = save_joint(f, anim, names, t, i, order=order, positions=positions)
406 |
407 | t = t[:-1]
408 | f.write("%s}\n" % t)
409 |
410 | f.write("MOTION\n")
411 | f.write("Frames: %i\n" % anim.shape[0])
412 | f.write("Frame Time: %f\n" % frametime)
413 |
414 | # if orients:
415 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1]))
416 | # else:
417 | # rots = np.degrees(anim.rotations.euler(order=order[::-1]))
418 | rots = np.degrees(anim.rotations.euler(order=order[::-1]))
419 | poss = anim.positions
420 |
421 | for i in range(anim.shape[0]):
422 | for j in range(anim.shape[1]):
423 |
424 | if positions or j == 0:
425 |
426 | f.write("%f %f %f %f %f %f " % (
427 | poss[i, j, 0], poss[i, j, 1], poss[i, j, 2],
428 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], rots[i, j, ordermap[order[2]]]))
429 |
430 | else:
431 |
432 | f.write("%f %f %f " % (
433 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], rots[i, j, ordermap[order[2]]]))
434 |
435 | f.write("\n")
436 |
437 |
438 | def save_joint(f, anim, names, t, i, order='zyx', positions=False):
439 |
440 | f.write("%sJOINT %s\n" % (t, names[i]))
441 | f.write("%s{\n" % t)
442 | t += '\t'
443 |
444 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i, 0], anim.offsets[i, 1], anim.offsets[i, 2]))
445 |
446 | if positions:
447 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t,
448 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
449 | else:
450 | f.write("%sCHANNELS 3 %s %s %s\n" % (t,
451 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
452 |
453 | end_site = True
454 |
455 | for j in range(anim.shape[1]):
456 | if anim.parents[j] == i:
457 | t = save_joint(f, anim, names, t, j, order=order, positions=positions)
458 | end_site = False
459 |
460 | if end_site:
461 | f.write("%sEnd Site\n" % t)
462 | f.write("%s{\n" % t)
463 | t += '\t'
464 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0))
465 | t = t[:-1]
466 | f.write("%s}\n" % t)
467 |
468 | t = t[:-1]
469 | f.write("%s}\n" % t)
470 |
471 | return t
472 |
--------------------------------------------------------------------------------
/src/holden/Pivots.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from Quaternions import Quaternions
4 |
5 |
6 | class Pivots:
7 | """
8 | Pivots is an ndarray of angular rotations
9 |
10 | This wrapper provides some functions for
11 | working with pivots.
12 |
13 | These are particularly useful as a number
14 | of atomic operations (such as adding or
15 | subtracting) cannot be achieved using
16 | the standard arithmetic and need to be
17 | defined differently to work correctly
18 | """
19 |
20 | def __init__(self, ps): self.ps = np.array(ps)
21 | def __str__(self): return "Pivots(" + str(self.ps) + ")"
22 | def __repr__(self): return "Pivots(" + repr(self.ps) + ")"
23 |
24 | def __add__(self, other): return Pivots(np.arctan2(np.sin(self.ps + other.ps), np.cos(self.ps + other.ps)))
25 | def __sub__(self, other): return Pivots(np.arctan2(np.sin(self.ps - other.ps), np.cos(self.ps - other.ps)))
26 | def __mul__(self, other): return Pivots(self.ps * other.ps)
27 | def __div__(self, other): return Pivots(self.ps / other.ps)
28 | def __mod__(self, other): return Pivots(self.ps % other.ps)
29 | def __pow__(self, other): return Pivots(self.ps ** other.ps)
30 |
31 | def __lt__(self, other): return self.ps < other.ps
32 | def __le__(self, other): return self.ps <= other.ps
33 | def __eq__(self, other): return self.ps == other.ps
34 | def __ne__(self, other): return self.ps != other.ps
35 | def __ge__(self, other): return self.ps >= other.ps
36 | def __gt__(self, other): return self.ps > other.ps
37 |
38 | def __abs__(self): return Pivots(abs(self.ps))
39 | def __neg__(self): return Pivots(-self.ps)
40 |
41 | def __iter__(self): return iter(self.ps)
42 | def __len__(self): return len(self.ps)
43 |
44 | def __getitem__(self, k): return Pivots(self.ps[k])
45 | def __setitem__(self, k, v): self.ps[k] = v.ps
46 |
47 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape))
48 |
49 | def quaternions(self, plane='xz'):
50 | fa = self._ellipsis()
51 | axises = np.ones(self.ps.shape + (3,))
52 | axises[fa + ("xyz".index(plane[0]),)] = 0.0
53 | axises[fa + ("xyz".index(plane[1]),)] = 0.0
54 | return Quaternions.from_angle_axis(self.ps, axises)
55 |
56 | def directions(self, plane='xz'):
57 | dirs = np.zeros((len(self.ps), 3))
58 | dirs[..., "xyz".index(plane[0])] = np.sin(self.ps)
59 | dirs[..., "xyz".index(plane[1])] = np.cos(self.ps)
60 | return dirs
61 |
62 | def normalized(self):
63 | xs = np.copy(self.ps)
64 | while np.any(xs > np.pi):
65 | xs[xs > np.pi] = xs[xs > np.pi] - 2 * np.pi
66 | while np.any(xs < -np.pi):
67 | xs[xs < -np.pi] = xs[xs < -np.pi] + 2 * np.pi
68 | return Pivots(xs)
69 |
70 | def interpolate(self, ws):
71 | dir = np.average(self.directions(), weights=ws, axis=0)
72 | return np.arctan2(dir[2], dir[0])
73 |
74 | def copy(self):
75 | return Pivots(np.copy(self.ps))
76 |
77 | @property
78 | def shape(self):
79 | return self.ps.shape
80 |
81 | @classmethod
82 | def from_quaternions(cls, qs, forward='z', plane='xz'):
83 | ds = np.zeros(qs.shape + (3,))
84 | ds[..., 'xyz'.index(forward)] = 1.0
85 | return Pivots.from_directions(qs * ds, plane=plane)
86 |
87 | @classmethod
88 | def from_directions(cls, ds, plane='xz'):
89 | ys = ds[..., 'xyz'.index(plane[0])]
90 | xs = ds[..., 'xyz'.index(plane[1])]
91 | return Pivots(np.arctan2(ys, xs))
92 |
--------------------------------------------------------------------------------
/src/holden/Quaternions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Quaternions:
5 | """
6 | Quaternions is a wrapper around a numpy ndarray
7 | that allows it to act as if it were an ndarray of
8 | a quaternion data type.
9 |
10 | Therefore addition, subtraction, multiplication,
11 | division, negation, absolute, are all defined
12 | in terms of quaternion operations such as quaternion
13 | multiplication.
14 |
15 | This allows for much neater code and many routines
16 | which conceptually do the same thing to be written
17 | in the same way for point data and for rotation data.
18 |
19 | The Quaternions class has been designed such that it
20 | should support broadcasting and slicing in all of the
21 | usual ways.
22 | """
23 |
24 | def __init__(self, qs):
25 | if isinstance(qs, np.ndarray):
26 |
27 | if len(qs.shape) == 1:
28 | qs = np.array([qs])
29 | self.qs = qs
30 | return
31 |
32 | if isinstance(qs, Quaternions):
33 | self.qs = qs.qs
34 | return
35 |
36 | raise TypeError('Quaternions must be constructed from a numpy array or Quaternions, not %s' % type(qs))
37 |
38 | def __str__(self): return "Quaternions(" + str(self.qs) + ")"
39 | def __repr__(self): return "Quaternions(" + repr(self.qs) + ")"
40 | """ Helper Methods for Broadcasting and Data extraction """
41 |
42 | @classmethod
43 | def _broadcast(cls, sqs, oqs, scalar=False):
44 | if isinstance(oqs, float):
45 | return sqs, oqs * np.ones(sqs.shape[:-1])
46 |
47 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
48 | os = np.array(oqs.shape)
49 |
50 | if len(ss) != len(os):
51 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
52 |
53 | if np.all(ss == os):
54 | return sqs, oqs
55 |
56 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
57 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
58 |
59 | sqsn, oqsn = sqs.copy(), oqs.copy()
60 |
61 | for a in np.where(ss == 1)[0]:
62 | sqsn = sqsn.repeat(os[a], axis=a)
63 | for a in np.where(os == 1)[0]:
64 | oqsn = oqsn.repeat(ss[a], axis=a)
65 |
66 | return sqsn, oqsn
67 |
68 | """ Adding Quaterions is just Defined as Multiplication """
69 |
70 | def __add__(self, other): return self * other
71 | def __sub__(self, other): return self / other
72 | """ Quaterion Multiplication """
73 |
74 | def __mul__(self, other):
75 | """
76 | Quaternion multiplication has three main methods.
77 |
78 | When multiplying a Quaternions array by Quaternions
79 | normal quaternion multiplication is performed.
80 |
81 | When multiplying a Quaternions array by a vector
82 | array of the same shape, where the last axis is 3,
83 | it is assumed to be a Quaternion by 3D-Vector
84 | multiplication and the 3D-Vectors are rotated
85 | in space by the Quaternions.
86 |
87 | When multiplying a Quaternions array by a scalar
88 | or vector of different shape it is assumed to be
89 | a Quaternions by Scalars multiplication and the
90 | Quaternions are scaled using Slerp and the identity
91 | quaternions.
92 | """
93 |
94 | """ If Quaternions type do Quaternions * Quaternions """
95 | if isinstance(other, Quaternions):
96 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs)
97 |
98 | q0 = sqs[..., 0]
99 | q1 = sqs[..., 1]
100 | q2 = sqs[..., 2]
101 | q3 = sqs[..., 3]
102 | r0 = oqs[..., 0]
103 | r1 = oqs[..., 1]
104 | r2 = oqs[..., 2]
105 | r3 = oqs[..., 3]
106 |
107 | qs = np.empty(sqs.shape)
108 | qs[..., 0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
109 | qs[..., 1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
110 | qs[..., 2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
111 | qs[..., 3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
112 |
113 | return Quaternions(qs)
114 |
115 | """ If array type do Quaternions * Vectors """
116 | if isinstance(other, np.ndarray) and other.shape[-1] == 3:
117 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))
118 |
119 | return (self * (vs * -self)).imaginaries
120 |
121 | """ If float do Quaternions * Scalars """
122 | if isinstance(other, np.ndarray) or isinstance(other, float):
123 | return Quaternions.slerp(Quaternions.id_like(self), self, other)
124 |
125 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
126 |
127 | def __div__(self, other):
128 | """
129 | When a Quaternion type is supplied, division is defined
130 | as multiplication by the inverse of that Quaternion.
131 |
132 | When a scalar or vector is supplied it is defined
133 | as multiplication of one over the supplied value.
134 | Essentially a scaling.
135 | """
136 |
137 | if isinstance(other, Quaternions):
138 | return self * (-other)
139 | if isinstance(other, np.ndarray):
140 | return self * (1.0 / other)
141 | if isinstance(other, float):
142 | return self * (1.0 / other)
143 | raise TypeError('Cannot divide/subtract Quaternions with type %s' % str(type(other)))
144 |
145 | def __eq__(self, other): return self.qs == other.qs
146 | def __ne__(self, other): return self.qs != other.qs
147 |
148 | def __neg__(self):
149 | """ Invert Quaternions """
150 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))
151 |
152 | def __abs__(self):
153 | """ Unify Quaternions To Single Pole """
154 | qabs = self.normalized().copy()
155 | top = np.sum((qabs.qs) * np.array([1, 0, 0, 0]), axis=-1)
156 | bot = np.sum((-qabs.qs) * np.array([1, 0, 0, 0]), axis=-1)
157 | qabs.qs[top < bot] = -qabs.qs[top < bot]
158 | return qabs
159 |
160 | def __iter__(self): return iter(self.qs)
161 | def __len__(self): return len(self.qs)
162 |
163 | def __getitem__(self, k): return Quaternions(self.qs[k])
164 | def __setitem__(self, k, v): self.qs[k] = v.qs
165 |
166 | @property
167 | def lengths(self):
168 | return np.sum(self.qs**2.0, axis=-1)**0.5
169 |
170 | @property
171 | def reals(self):
172 | return self.qs[..., 0]
173 |
174 | @property
175 | def imaginaries(self):
176 | return self.qs[..., 1:4]
177 |
178 | @property
179 | def shape(self): return self.qs.shape[:-1]
180 |
181 | def repeat(self, n, **kwargs):
182 | return Quaternions(self.qs.repeat(n, **kwargs))
183 |
184 | def normalized(self):
185 | return Quaternions(self.qs / self.lengths[..., np.newaxis])
186 |
187 | def log(self):
188 | norm = abs(self.normalized())
189 | imgs = norm.imaginaries
190 | lens = np.sqrt(np.sum(imgs**2, axis=-1))
191 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
192 | return imgs * lens[..., np.newaxis]
193 |
194 | def constrained(self, axis):
195 |
196 | rl = self.reals
197 | im = np.sum(axis * self.imaginaries, axis=-1)
198 |
199 | t1 = -2 * np.arctan2(rl, im) + np.pi
200 | t2 = -2 * np.arctan2(rl, im) - np.pi
201 |
202 | top = Quaternions.exp(axis[np.newaxis] * (t1[:, np.newaxis] / 2.0))
203 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:, np.newaxis] / 2.0))
204 | img = self.dot(top) > self.dot(bot)
205 |
206 | ret = top.copy()
207 | ret[img] = top[img]
208 | ret[~img] = bot[~img]
209 | return ret
210 |
211 | def constrained_x(self): return self.constrained(np.array([1, 0, 0]))
212 | def constrained_y(self): return self.constrained(np.array([0, 1, 0]))
213 | def constrained_z(self): return self.constrained(np.array([0, 0, 1]))
214 |
215 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
216 |
217 | def copy(self): return Quaternions(np.copy(self.qs))
218 |
219 | def reshape(self, s):
220 | self.qs = self.qs.reshape(s)  # ndarray.reshape returns a new view; it must be assigned back
221 | return self
222 |
223 | def interpolate(self, ws):
224 | return Quaternions.exp(np.average(abs(self).log(), axis=0, weights=ws))
225 |
226 | def euler(self, order='xyz'):
227 |
228 | q = self.normalized().qs
229 | q0 = q[..., 0]
230 | q1 = q[..., 1]
231 | q2 = q[..., 2]
232 | q3 = q[..., 3]
233 | es = np.zeros(self.shape + (3,))
234 |
235 | if order == 'xyz':
236 | es[..., 0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
237 | es[..., 1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1, 1))
238 | es[..., 2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
239 | elif order == 'yzx':
240 | es[..., 0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)
241 | es[..., 1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)
242 | es[..., 2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1, 1))
243 | else:
244 | raise NotImplementedError('Cannot convert from ordering %s' % order)
245 |
246 | """
247 |
248 | # These conversions don't appear to work correctly for Maya.
249 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
250 |
251 | if order == 'xyz':
252 | es[fa + (0,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
253 | es[fa + (1,)] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
254 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
255 | elif order == 'yzx':
256 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
257 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))
258 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
259 | elif order == 'zxy':
260 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
261 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))
262 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
263 | elif order == 'xzy':
264 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
265 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))
266 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
267 | elif order == 'yxz':
268 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
269 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))
270 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
271 | elif order == 'zyx':
272 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
273 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))
274 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
275 | else:
276 | raise KeyError('Unknown ordering %s' % order)
277 |
278 | """
279 |
280 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp
281 | # Use this class and convert from matrix
282 |
283 | return es
284 |
285 | def average(self):
286 |
287 | if len(self.shape) == 1:
288 |
289 | import numpy.core.umath_tests as ut
290 | system = ut.matrix_multiply(self.qs[:, :, np.newaxis], self.qs[:, np.newaxis, :]).sum(axis=0)
291 | w, v = np.linalg.eigh(system)
292 | qiT_dot_qref = (self.qs[:, :, np.newaxis] * v[np.newaxis, :, :]).sum(axis=1)
293 | return Quaternions(v[:, np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])
294 |
295 | else:
296 |
297 | raise NotImplementedError('Cannot average multi-dimensional Quaternions')
298 |
299 | def angle_axis(self):
300 |
301 | norm = self.normalized()
302 | s = np.sqrt(1 - (norm.reals**2.0))
303 | s[s == 0] = 0.001
304 |
305 | angles = 2.0 * np.arccos(norm.reals)
306 | axis = norm.imaginaries / s[..., np.newaxis]
307 |
308 | return angles, axis
309 |
310 | def transforms(self):
311 |
312 | qw = self.qs[..., 0]
313 | qx = self.qs[..., 1]
314 | qy = self.qs[..., 2]
315 | qz = self.qs[..., 3]
316 |
317 | x2 = qx + qx
318 | y2 = qy + qy
319 | z2 = qz + qz
320 | xx = qx * x2
321 | yy = qy * y2
322 | wx = qw * x2
323 | xy = qx * y2
324 | yz = qy * z2
325 | wy = qw * y2
326 | xz = qx * z2
327 | zz = qz * z2
328 | wz = qw * z2
329 |
330 | m = np.empty(self.shape + (3, 3))
331 | m[..., 0, 0] = 1.0 - (yy + zz)
332 | m[..., 0, 1] = xy - wz
333 | m[..., 0, 2] = xz + wy
334 | m[..., 1, 0] = xy + wz
335 | m[..., 1, 1] = 1.0 - (xx + zz)
336 | m[..., 1, 2] = yz - wx
337 | m[..., 2, 0] = xz - wy
338 | m[..., 2, 1] = yz + wx
339 | m[..., 2, 2] = 1.0 - (xx + yy)
340 |
341 | return m
342 |
343 | def ravel(self):
344 | return self.qs.ravel()
345 |
346 | @classmethod
347 | def id(cls, n):
348 |
349 | if isinstance(n, tuple):
350 | qs = np.zeros(n + (4,))
351 | qs[..., 0] = 1.0
352 | return Quaternions(qs)
353 |
354 | if isinstance(n, int):
355 | qs = np.zeros((n, 4))
356 | qs[:, 0] = 1.0
357 | return Quaternions(qs)
358 |
359 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))
360 |
361 | @classmethod
362 | def id_like(cls, a):
363 | qs = np.zeros(a.shape + (4,))
364 | qs[..., 0] = 1.0
365 | return Quaternions(qs)
366 |
367 | @classmethod
368 | def exp(cls, ws):
369 |
370 | ts = np.sum(ws**2.0, axis=-1)**0.5
371 | ts[ts == 0] = 0.001
372 | ls = np.sin(ts) / ts
373 |
374 | qs = np.empty(ws.shape[:-1] + (4,))
375 | qs[..., 0] = np.cos(ts)
376 | qs[..., 1] = ws[..., 0] * ls
377 | qs[..., 2] = ws[..., 1] * ls
378 | qs[..., 3] = ws[..., 2] * ls
379 |
380 | return Quaternions(qs).normalized()
381 |
382 | @classmethod
383 | def slerp(cls, q0s, q1s, a):
384 |
385 | fst, snd = cls._broadcast(q0s.qs, q1s.qs)
386 | fst, a = cls._broadcast(fst, a, scalar=True)
387 | snd, a = cls._broadcast(snd, a, scalar=True)
388 |
389 | len = np.sum(fst * snd, axis=-1)
390 |
391 | neg = len < 0.0
392 | len[neg] = -len[neg]
393 | snd[neg] = -snd[neg]
394 |
395 | amount0 = np.zeros(a.shape)
396 | amount1 = np.zeros(a.shape)
397 |
398 | linear = (1.0 - len) < 0.01
399 | omegas = np.arccos(len[~linear])
400 | sinoms = np.sin(omegas)
401 |
402 | amount0[linear] = 1.0 - a[linear]
403 | amount1[linear] = a[linear]
404 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
405 | amount1[~linear] = np.sin(a[~linear] * omegas) / sinoms
406 |
407 | return Quaternions(
408 | amount0[..., np.newaxis] * fst +
409 | amount1[..., np.newaxis] * snd)
410 |
411 | @classmethod
412 | def between(cls, v0s, v1s):
413 | a = np.cross(v0s, v1s)
414 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
415 | return Quaternions(np.concatenate([w[..., np.newaxis], a], axis=-1)).normalized()
416 |
417 | @classmethod
418 | def from_angle_axis(cls, angles, axis):
419 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[..., np.newaxis]
420 | sines = np.sin(angles / 2.0)[..., np.newaxis]
421 | cosines = np.cos(angles / 2.0)[..., np.newaxis]
422 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))
423 |
424 | @classmethod
425 | def from_euler(cls, es, order='xyz', world=False):
426 |
427 | axis = {
428 | 'x': np.array([1, 0, 0]),
429 | 'y': np.array([0, 1, 0]),
430 | 'z': np.array([0, 0, 1]),
431 | }
432 |
433 | q0s = Quaternions.from_angle_axis(es[..., 0], axis[order[0]])
434 | q1s = Quaternions.from_angle_axis(es[..., 1], axis[order[1]])
435 | q2s = Quaternions.from_angle_axis(es[..., 2], axis[order[2]])
436 |
437 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))
438 |
439 | @classmethod
440 | def from_transforms(cls, ts):
441 |
442 | d0, d1, d2 = ts[..., 0, 0], ts[..., 1, 1], ts[..., 2, 2]
443 |
444 | q0 = (d0 + d1 + d2 + 1.0) / 4.0
445 | q1 = (d0 - d1 - d2 + 1.0) / 4.0
446 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0
447 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0
448 |
449 | q0 = np.sqrt(q0.clip(0, None))
450 | q1 = np.sqrt(q1.clip(0, None))
451 | q2 = np.sqrt(q2.clip(0, None))
452 | q3 = np.sqrt(q3.clip(0, None))
453 |
454 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
455 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
456 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
457 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)
458 |
459 | q1[c0] *= np.sign(ts[c0, 2, 1] - ts[c0, 1, 2])
460 | q2[c0] *= np.sign(ts[c0, 0, 2] - ts[c0, 2, 0])
461 | q3[c0] *= np.sign(ts[c0, 1, 0] - ts[c0, 0, 1])
462 |
463 | q0[c1] *= np.sign(ts[c1, 2, 1] - ts[c1, 1, 2])
464 | q2[c1] *= np.sign(ts[c1, 1, 0] + ts[c1, 0, 1])
465 | q3[c1] *= np.sign(ts[c1, 0, 2] + ts[c1, 2, 0])
466 |
467 | q0[c2] *= np.sign(ts[c2, 0, 2] - ts[c2, 2, 0])
468 | q1[c2] *= np.sign(ts[c2, 1, 0] + ts[c2, 0, 1])
469 | q3[c2] *= np.sign(ts[c2, 2, 1] + ts[c2, 1, 2])
470 |
471 | q0[c3] *= np.sign(ts[c3, 1, 0] - ts[c3, 0, 1])
472 | q1[c3] *= np.sign(ts[c3, 2, 0] + ts[c3, 0, 2])
473 | q2[c3] *= np.sign(ts[c3, 2, 1] + ts[c3, 1, 2])
474 |
475 | qs = np.empty(ts.shape[:-2] + (4,))
476 | qs[..., 0] = q0
477 | qs[..., 1] = q1
478 | qs[..., 2] = q2
479 | qs[..., 3] = q3
480 |
481 | return cls(qs)
482 |
--------------------------------------------------------------------------------
/src/holden/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | BASEPATH = os.path.dirname(__file__)
4 | sys.path.insert(0, BASEPATH)
5 |
--------------------------------------------------------------------------------
/src/nemf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/src/nemf/__init__.py
--------------------------------------------------------------------------------
/src/nemf/base_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from abc import ABC, abstractmethod
4 |
5 | import torch
6 | import torch.nn.init
7 | import torch.optim
8 |
9 |
10 | def weights_init(init_type='default'):
11 | def init_fun(m):
12 | classname = m.__class__.__name__
13 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
14 | if init_type == 'gaussian':
15 | torch.nn.init.normal_(m.weight.data)
16 | elif init_type == 'xavier':
17 | torch.nn.init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
18 | elif init_type == 'kaiming':
19 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
20 | elif init_type == 'orthogonal':
21 | torch.nn.init.orthogonal_(m.weight.data, gain=math.sqrt(2))
22 | elif init_type == 'default':
23 | torch.nn.init.normal_(m.weight.data)
24 | weight_norm = m.weight.pow(2).sum(2, keepdim=True).sum(1, keepdim=True).add(1e-8).sqrt()
25 | m.weight.data.div_(weight_norm)
26 | else:
27 | raise NotImplementedError(f"Unsupported initialization: {init_type}")
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | torch.nn.init.constant_(m.bias.data, 0.0)
30 | return init_fun
31 |
32 |
33 | class BaseModel(ABC):
34 | """This class is an abstract base class (ABC) for models.
35 | To create a subclass, you need to implement the following five functions:
36 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
37 | -- : unpack data from dataset and apply preprocessing.
38 | -- : produce intermediate results.
39 | -- : calculate losses, gradients, and update network weights.
40 | """
41 |
42 | def __init__(self, args):
43 | self.args = args
44 | self.is_train = args.is_train
45 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46 | self.model_save_dir = os.path.join(args.save_dir, 'checkpoints') # save all the checkpoints to save_dir
47 |
48 | if self.is_train and args.log:
49 | import time
50 | from torch.utils.tensorboard import SummaryWriter
51 | from .loss_record import LossRecorder
52 |
53 | log_dir = os.path.join(args.save_dir, 'logs')
54 | if args.epoch_begin != 0:
55 | runs = os.listdir(log_dir)
56 | if len(runs) == 0:
57 | raise RuntimeError(f'Empty logging path {log_dir}')
58 | timestamp = sorted(runs)[-1]
59 | else:
60 | timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime(time.time()))
61 | self.log_path = os.path.join(log_dir, timestamp)
62 | os.makedirs(self.log_path, exist_ok=True)
63 | self.writer = SummaryWriter(self.log_path)
64 | self.loss_recorder = LossRecorder(self.writer)
65 |
66 | self.epoch_cnt = 0
67 | self.schedulers = []
68 | self.optimizers = []
69 | self.models = []
70 |
71 | @abstractmethod
72 | def set_input(self, input):
73 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
74 | Parameters:
75 | input (dict): includes the data itself and its metadata information.
76 | """
77 | pass
78 |
79 | @abstractmethod
80 | def compute_test_result(self):
81 | """
82 | After forward, export results (e.g. write BVH output) and compute error values.
83 | """
84 | pass
85 |
86 | @abstractmethod
87 | def forward(self):
88 | """Run forward pass; called by both functions and ."""
89 | pass
90 |
91 | @abstractmethod
92 | def optimize_parameters(self):
93 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
94 | pass
95 |
96 | def get_scheduler(self, optimizer):
97 | if self.args.scheduler.name == 'linear':
98 | def lambda_rule(epoch):
99 | lr_l = 1.0 - max(0, epoch - self.args.n_epochs_origin) / float(self.args.n_epochs_decay + 1)
100 | return lr_l
101 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
102 | if self.args.scheduler.name == 'StepLR':
103 | print('StepLR scheduler set')
104 | return torch.optim.lr_scheduler.StepLR(optimizer, self.args.scheduler.step_size, self.args.scheduler.gamma, verbose=self.args.verbose)
105 | if self.args.scheduler.name == 'Plateau':
106 | print('ReduceLROnPlateau scheduler set')
107 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self.args.scheduler.factor, threshold=self.args.scheduler.threshold, patience=5, verbose=True)
108 | if self.args.scheduler.name == 'MultiStep':
110 | print('MultiStep scheduler set')
110 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.args.scheduler.milestones, gamma=self.args.scheduler.gamma, verbose=self.args.verbose)
111 | else:
112 | print('No scheduler set')
113 | return None
114 |
115 | def setup(self):
116 | """Load and print networks; create schedulers
117 | Parameters:
118 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
119 | """
120 | if self.is_train:
121 | self.schedulers = [self.get_scheduler(optimizer) for optimizer in self.optimizers]
122 |
123 | def epoch(self):
124 | if self.args.verbose:
125 | self.loss_recorder.epoch()
126 | for scheduler in self.schedulers:
127 | if scheduler is not None:
128 | if self.args.scheduler.name == 'Plateau':
129 | loss = self.loss_recorder.losses['total_loss'].current()
130 | scheduler.step(loss[1])
131 | else:
132 | scheduler.step()
133 | self.epoch_cnt += 1
134 |
135 | def test(self):
136 | """Forward function used in test time.
137 | This function wraps function in no_grad() so we don't save intermediate steps for backprop
138 | It also calls to produce additional visualization results
139 | """
140 | with torch.no_grad():
141 | self.forward()
142 | self.compute_test_result()
143 |
144 | def train(self):
145 | for model in self.models:
146 | model.train()
147 | for param in model.parameters():
148 | param.requires_grad = True
149 |
150 | def eval(self):
151 | for model in self.models:
152 | model.eval()
153 | for param in model.parameters():
154 | param.requires_grad = False
155 |
156 | def print(self):
157 | for model in self.models:
158 | print(model)
159 |
160 | def count_params(self):
161 | params = 0
162 | for model in self.models:
163 | params += sum(p.numel() for p in model.parameters() if p.requires_grad)
164 |
165 | return params
166 |
--------------------------------------------------------------------------------
/src/nemf/basic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import holden.BVH as BVH
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from holden.Animation import Animation
8 | from holden.Quaternions import Quaternions
9 | from human_body_prior.tools.omni_tools import copy2cpu as c2c
10 | from rotations import matrix_to_quaternion, quaternion_to_axis_angle, rotation_6d_to_matrix
11 | from utils import align_joints, compute_trajectory
12 |
13 | from .base_model import BaseModel
14 | from .fk import ForwardKinematicsLayer
15 | from .losses import GeodesicLoss
16 | from .neural_motion import NeuralMotionField
17 |
18 |
19 | class Architecture(BaseModel):
20 | def __init__(self, args, ngpu):
21 | super(Architecture, self).__init__(args)
22 |
23 | self.args = args
24 |
25 | if args.amass_data:
26 | self.fk = ForwardKinematicsLayer(args)
27 | else:
28 | self.fk = ForwardKinematicsLayer(parents=args.animal.parents, positions=args.animal.offsets)
29 |
30 | self.field = NeuralMotionField(args.nemf).to(self.device)
31 | if args.multi_gpu is True:
32 | self.field = nn.DataParallel(self.field, device_ids=range(ngpu))
33 |
34 | self.models = [self.field]
35 |
36 | ordermap = {
37 | 'x': 0,
38 | 'y': 1,
39 | 'z': 2,
40 | }
41 | self.up_axis = ordermap[self.args.data.up]
42 | self.v_axis = [x for x in ordermap.values() if x != self.up_axis]
43 |
44 | print(self.up_axis)
45 | print(self.v_axis)
46 |
47 | self.input_data = dict()
48 | self.recon_data = dict()
49 |
50 | if self.is_train:
51 | self.optimizer = torch.optim.Adam(self.field.parameters(), args.learning_rate)
52 | self.optimizers = [self.optimizer]
53 | self.criterion_rec = nn.L1Loss() if args.l1_loss else nn.MSELoss()
54 |
55 | self.criterion_geo = GeodesicLoss().to(self.device)
56 | self.fps = args.data.fps
57 | if args.bvh_viz:
58 | self.viz_dir = os.path.join(args.save_dir, 'results', 'bvh')
59 | else:
60 | self.viz_dir = os.path.join(args.save_dir, 'results', 'smpl')
61 |
62 | def set_input(self, input):
63 | self.input_data = {k: v.float().to(self.device) for k, v in input.items() if k in ['pos', 'global_xform', 'root_orient', 'root_vel', 'trans']}
64 | self.input_data['rotmat'] = rotation_6d_to_matrix(self.input_data['global_xform'])
65 |
66 | def forward(self, step=1):
67 | b_size, t_length, n_joints = self.input_data['pos'].shape[:3]
68 |
69 | t = torch.arange(start=0, end=t_length, step=step).unsqueeze(0) # (1, T)
70 | t = t / t_length * 2 - 1 # map t to [-1, 1]
71 | t = t.expand(b_size, -1).unsqueeze(-1).to(self.device) # (B, T, 1)
72 | f, _ = self.field(t) # (B, T, D), rot6d + root orient + root vel + root height
73 |
74 | self.recon_data = dict()
75 | rot6d_recon = f[:, :, :n_joints * 6] # (B, T, J x 6)
76 | rot6d_recon = rot6d_recon.view(b_size, -1, n_joints, 6) # (B, T, J, 6)
77 | rotmat_recon = rotation_6d_to_matrix(rot6d_recon) # (B, T, J, 3, 3)
78 | if self.args.data.root_transform:
79 | identity = torch.zeros(b_size, t.shape[1], 3, 3).to(self.device)
80 | identity[:, :, 0, 0] = 1
81 | identity[:, :, 1, 1] = 1
82 | identity[:, :, 2, 2] = 1
83 | rotmat_recon[:, :, 0] = identity
84 | rotmat_recon[:, :, -2] = rotmat_recon[:, :, -4]
85 | rotmat_recon[:, :, -1] = rotmat_recon[:, :, -3]
86 | local_rotmat = self.fk.global_to_local(rotmat_recon.view(-1, n_joints, 3, 3)) # (B x T, J, 3, 3)
87 | pos_recon, _ = self.fk(local_rotmat) # (B x T, J, 3)
88 | pos_recon = pos_recon.contiguous().view(b_size, -1, n_joints, 3) # (B, T, J, 3)
89 | self.recon_data['rotmat'] = rotmat_recon
90 | self.recon_data['pos'] = pos_recon
91 |
92 | if self.args.data.root_transform:
93 | root_orient = f[:, :, n_joints * 6:n_joints * 6 + 6] # (B, T, 6)
94 | self.recon_data['root_orient'] = root_orient
95 | else:
96 | self.recon_data['root_orient'] = rot6d_recon[:, :, 0] # (B, T, 6)
97 |
98 | root_vel = f[:, :, -4:-1] # (B, T, 3)
99 | root_height = f[:, :, -1] # (B, T)
100 |
101 | self.recon_data['root_vel'] = root_vel
102 | self.recon_data['root_height'] = root_height
103 |
104 | dt = 1.0 / self.args.data.fps * step
105 | origin = torch.zeros(b_size, 3).to(self.device)
106 | trans = compute_trajectory(root_vel, root_height, origin, dt, up_axis=self.args.data.up)
107 | self.recon_data['trans'] = trans
108 |
109 | def backward(self, validation=False):
110 | loss = 0
111 |
112 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient'])
113 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient'])
114 | if self.args.geodesic_loss:
115 | orient_recon_loss = self.criterion_geo(root_orient.view(-1, 3, 3), root_orient_gt.view(-1, 3, 3))
116 | else:
117 | orient_recon_loss = self.criterion_rec(root_orient, root_orient_gt)
118 | self.loss_recorder.add_scalar('orient_recon_loss', orient_recon_loss, validation=validation)
119 | loss += self.args.lambda_orient * orient_recon_loss
120 |
121 | if self.args.geodesic_loss:
122 | rotmat_recon_loss = self.criterion_geo(self.recon_data['rotmat'].view(-1, 3, 3), self.input_data['rotmat'].view(-1, 3, 3))
123 | else:
124 | rotmat_recon_loss = self.criterion_rec(self.recon_data['rotmat'], self.input_data['rotmat'])
125 | self.loss_recorder.add_scalar('rotmat_recon_loss', rotmat_recon_loss, validation=validation)
126 | loss += self.args.lambda_rotmat * rotmat_recon_loss
127 |
128 | pos_recon_loss = self.criterion_rec(self.recon_data['pos'], self.input_data['pos'])
129 | self.loss_recorder.add_scalar('pos_recon_loss', pos_recon_loss, validation=validation)
130 | loss += self.args.lambda_pos * pos_recon_loss
131 |
132 | v_loss = self.criterion_rec(self.recon_data['root_vel'], self.input_data['root_vel'])
133 | up_loss = self.criterion_rec(self.recon_data['root_height'], self.input_data['trans'][..., 'xyz'.index(self.args.data.up)])
134 | self.loss_recorder.add_scalar('v_loss', v_loss, validation=validation)
135 | self.loss_recorder.add_scalar('up_loss', up_loss, validation=validation)
136 | loss += self.args.lambda_v * v_loss + self.args.lambda_up * up_loss
137 |
138 | origin = self.input_data['trans'][:, 0]
139 | trans = self.recon_data['trans']
140 | trans[..., self.v_axis] = trans[..., self.v_axis] + origin[..., self.v_axis].unsqueeze(1)
141 | trans_loss = self.criterion_rec(trans, self.input_data['trans'])
142 | self.loss_recorder.add_scalar('trans_loss', trans_loss, validation=validation)
143 | loss += self.args.lambda_trans * trans_loss
144 |
145 | self.loss_recorder.add_scalar('total_loss', loss, validation=validation)
146 |
147 | if not validation:
148 | loss.backward()
149 |
150 | def optimize_parameters(self):
151 | self.optimizer.zero_grad()
152 | self.forward()
153 | self.backward(validation=False)
154 | self.optimizer.step()
155 |
156 | def report_errors(self):
157 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient']) # (B, T, 3, 3)
158 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient']) # (B, T, 3, 3)
159 | orientation_error = self.criterion_geo(root_orient.view(-1, 3, 3), root_orient_gt.view(-1, 3, 3))
160 |
161 | local_rotmat = self.fk.global_to_local(self.recon_data['rotmat'].view(-1, self.args.smpl.joint_num, 3, 3)) # (B x T, J, 3, 3)
162 | local_rotmat_gt = self.fk.global_to_local(self.input_data['rotmat'].view(-1, self.args.smpl.joint_num, 3, 3)) # (B x T, J, 3, 3)
163 | rotation_error = self.criterion_geo(local_rotmat[:, 1:].reshape(-1, 3, 3), local_rotmat_gt[:, 1:].reshape(-1, 3, 3))
164 |
165 | pos = self.recon_data['pos'] # (B, T, J, 3)
166 | pos_gt = self.input_data['pos'] # (B, T, J, 3)
167 | position_error = torch.linalg.norm((pos - pos_gt), dim=-1).mean()
168 |
169 | origin = self.input_data['trans'][:, 0]
170 | trans = self.recon_data['trans']
171 | trans[..., self.v_axis] = trans[..., self.v_axis] + origin[..., self.v_axis].unsqueeze(1)
172 | trans_gt = self.input_data['trans'] # (B, T, 3)
173 | translation_error = torch.linalg.norm((trans - trans_gt), dim=-1).mean()
174 |
175 | return {
176 | 'rotation': c2c(rotation_error),
177 | 'position': c2c(position_error),
178 | 'orientation': c2c(orientation_error),
179 | 'translation': c2c(translation_error)
180 | }
181 |
182 | def verbose(self):
183 | res = {}
184 | for loss in self.loss_recorder.losses.values():
185 | res[loss.name] = {'train': loss.current()[0]}
186 |
187 | return res
188 |
189 | def save(self, optimal=False):
190 | if optimal:
191 | path = os.path.join(self.args.save_dir, 'results', 'model')
192 | else:
193 | path = os.path.join(self.model_save_dir, f'{self.epoch_cnt:04d}')
194 |
195 | os.makedirs(path, exist_ok=True)
196 | nemf = self.field.module if isinstance(self.field, nn.DataParallel) else self.field
197 | torch.save(nemf.state_dict(), os.path.join(path, 'nemf.pth'))
198 | torch.save(self.optimizer.state_dict(), os.path.join(path, 'optimizer.pth'))
199 | if self.args.scheduler.name:
200 | torch.save(self.schedulers[0].state_dict(), os.path.join(path, 'scheduler.pth'))
201 | self.loss_recorder.save(path)
202 |
203 | print(f'Save at {path} succeeded')
204 |
205 | def load(self, epoch=None, optimal=False):
206 | if optimal:
207 | path = os.path.join(self.args.save_dir, 'results', 'model')
208 | else:
209 | if epoch is None:
210 | epochs = [int(q) for q in os.listdir(self.model_save_dir)]
211 | if len(epochs) == 0:
212 | raise RuntimeError(f'Empty loading path {self.model_save_dir}')
213 | epoch = sorted(epochs)[-1]
214 | path = os.path.join(self.model_save_dir, f'{epoch:04d}')
215 |
216 | print(f'Loading from {path}')
217 | nemf = self.field.module if isinstance(self.field, nn.DataParallel) else self.field
218 | nemf.load_state_dict(torch.load(os.path.join(path, 'nemf.pth'), map_location=self.device))
219 | if self.is_train:
220 | self.optimizer.load_state_dict(torch.load(os.path.join(path, 'optimizer.pth')))
221 | self.loss_recorder.load(path)
222 | if self.args.scheduler.name:
223 | self.schedulers[0].load_state_dict(torch.load(os.path.join(path, 'scheduler.pth')))
224 | self.epoch_cnt = epoch if not optimal else 0
225 |
226 | print('Load succeeded')
227 |
228 | def super_sampling(self, step=1):
229 | with torch.no_grad():
230 | self.forward(step)
231 | self.fps = self.args.data.fps / step
232 | self.compute_test_result()
233 |
234 | def compute_test_result(self):
235 | os.makedirs(self.viz_dir, exist_ok=True)
236 |
237 | rotmat = self.recon_data['rotmat'][0] # (T, J, 3, 3)
238 | rotmat_gt = self.input_data['rotmat'][0] # (T, J, 3, 3)
239 | local_rotmat = self.fk.global_to_local(rotmat) # (T, J, 3, 3)
240 | local_rotmat_gt = self.fk.global_to_local(rotmat_gt) # (T, J, 3, 3)
241 |
242 | if self.args.data.root_transform:
243 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient'][0]) # (T, 3, 3)
244 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient'][0]) # (T, 3, 3)
245 | local_rotmat[:, 0] = root_orient
246 | local_rotmat_gt[:, 0] = root_orient_gt
247 | rotation = matrix_to_quaternion(local_rotmat) # (T, J, 4)
248 | rotation_gt = matrix_to_quaternion(local_rotmat_gt) # (T, J, 4)
249 |
250 | origin = self.input_data['trans'][0, 0].clone() # [3]
251 | position = self.recon_data['trans'][0].clone() # (T, 3)
252 | position[:, self.v_axis] = position[:, self.v_axis] + origin[self.v_axis].unsqueeze(0)
253 | position_gt = self.input_data['trans'][0] # (T, 3)
254 |
255 | if self.args.bvh_viz:
256 | if self.args.amass_data:
257 | rotation = align_joints(rotation)
258 | rotation_gt = align_joints(rotation_gt)
259 |
260 | position = position.unsqueeze(1) # (T, 1, 3)
261 | position_gt = position_gt.unsqueeze(1) # (T, 1, 3)
262 |
263 | # export data to BVH files
264 | if self.args.amass_data:
265 | offsets = self.args.smpl.offsets[self.args.data.gender]
266 | parents = self.args.smpl.parents
267 | names = self.args.smpl.joint_names
268 | else:
269 | offsets = self.args.animal.offsets
270 | parents = self.args.animal.parents
271 | names = self.args.animal.joint_names
272 |
273 | anim = Animation(Quaternions(c2c(rotation)), c2c(position), None, offsets=offsets, parents=parents)
274 | BVH.save(os.path.join(self.viz_dir, f'test_{int(self.fps)}fps.bvh'), anim, names=names, frametime=1 / self.fps)
275 |
276 | anim.rotations = Quaternions(c2c(rotation_gt))
277 | anim.positions = c2c(position_gt)
278 | BVH.save(os.path.join(self.viz_dir, 'test_gt.bvh'), anim, names=names, frametime=1 / self.args.data.fps)
279 | else:
280 | poses = c2c(quaternion_to_axis_angle(rotation)) # (T, J, 3)
281 | poses_gt = c2c(quaternion_to_axis_angle(rotation_gt)) # (T, J, 3)
282 |
283 | poses = poses.reshape((poses.shape[0], -1)) # (T, J x 3)
284 | poses = np.pad(poses, [(0, 0), (0, 93)], mode='constant')
285 | poses_gt = poses_gt.reshape((poses_gt.shape[0], -1)) # (T, J x 3)
286 | poses_gt = np.pad(poses_gt, [(0, 0), (0, 93)], mode='constant')
287 |
288 | np.savez(os.path.join(self.viz_dir, f'test_{int(self.fps)}fps.npz'),
289 | poses=poses, trans=c2c(position), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.fps)
290 | np.savez(os.path.join(self.viz_dir, 'test_gt.npz'),
291 | poses=poses_gt, trans=c2c(position_gt), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.args.data.fps)
292 |
--------------------------------------------------------------------------------
/src/nemf/fk.py:
--------------------------------------------------------------------------------
1 | """Based on Daniel Holden code from:
2 | A Deep Learning Framework for Character Motion Synthesis and Editing
3 | (http://www.ipab.inf.ed.ac.uk/cgvu/motionsynthesis.pdf)
4 | """
5 |
6 | import os
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from rotations import euler_angles_to_matrix, quaternion_to_matrix, rotation_6d_to_matrix
12 |
13 |
14 | class ForwardKinematicsLayer(nn.Module):
15 | """ Forward Kinematics Layer Class """
16 |
17 | def __init__(self, args=None, parents=None, positions=None, device=None):
18 | super().__init__()
19 | self.b_idxs = None
20 | if device is None:
21 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 | else:
23 | self.device = device
24 | if parents is None and positions is None:
25 | # Load SMPL skeleton (their joint order is different from the one we use for bvh export)
26 | smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz')
27 | smpl_data = np.load(smpl_fname, encoding='latin1')
28 | self.parents = torch.from_numpy(smpl_data['kintree_table'][0].astype(np.int32)).to(self.device)
29 | self.parents = self.parents.long()
30 | self.positions = torch.from_numpy(smpl_data['J'].astype(np.float32)).to(self.device)
31 | self.positions[1:] -= self.positions[self.parents[1:]]
32 | else:
33 | self.parents = torch.from_numpy(parents).to(self.device)
34 | self.parents = self.parents.long()
35 | self.positions = torch.from_numpy(positions).to(self.device)
36 | self.positions = self.positions.float()
37 | self.positions[0] = 0
38 |
39 | def rotate(self, t0s, t1s):
40 | return torch.matmul(t0s, t1s)
41 |
42 | def identity_rotation(self, rotations):
43 | diagonal = torch.diag(torch.tensor([1.0, 1.0, 1.0, 1.0])).to(self.device)
44 | diagonal = torch.reshape(
45 | diagonal, torch.Size([1] * len(rotations.shape[:2]) + [4, 4]))
46 | ts = diagonal.repeat(rotations.shape[:2] + torch.Size([1, 1]))
47 | return ts
48 |
49 | def make_fast_rotation_matrices(self, positions, rotations):
50 | if len(rotations.shape) == 4 and rotations.shape[-2:] == torch.Size([3, 3]):
51 | rot_matrices = rotations
52 | elif rotations.shape[-1] == 3:
53 | rot_matrices = euler_angles_to_matrix(rotations, convention='XYZ')
54 | elif rotations.shape[-1] == 4:
55 | rot_matrices = quaternion_to_matrix(rotations)
56 | elif rotations.shape[-1] == 6:
57 | rot_matrices = rotation_6d_to_matrix(rotations)
58 | else:
59 | raise NotImplementedError(f'Unimplemented rotation representation in FK layer, shape of {rotations.shape}')
60 |
61 | rot_matrices = torch.cat([rot_matrices, positions[..., None]], dim=-1)
62 | zeros = torch.zeros(rot_matrices.shape[:-2] + torch.Size([1, 3])).to(self.device)
63 | ones = torch.ones(rot_matrices.shape[:-2] + torch.Size([1, 1])).to(self.device)
64 | zerosones = torch.cat([zeros, ones], dim=-1)
65 | rot_matrices = torch.cat([rot_matrices, zerosones], dim=-2)
66 | return rot_matrices
67 |
68 | def rotate_global(self, parents, positions, rotations):
69 | locals = self.make_fast_rotation_matrices(positions, rotations)
70 | globals = self.identity_rotation(rotations)
71 |
72 | globals = torch.cat([locals[:, 0:1], globals[:, 1:]], dim=1)
73 | b_size = positions.shape[0]
74 | if self.b_idxs is None:
75 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device)
76 | elif self.b_idxs.shape[-1] != b_size:
77 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device)
78 |
79 | for i in range(1, positions.shape[1]):
80 | globals[:, i] = self.rotate(
81 | globals[self.b_idxs, parents[i]], locals[:, i])
82 |
83 | return globals
84 |
85 | def get_tpose_joints(self, offsets, parents):
86 | num_joints = len(parents)
87 | joints = [offsets[:, 0]]
88 | for j in range(1, num_joints):
89 | joints.append(joints[parents[j]] + offsets[:, j])
90 |
91 | return torch.stack(joints, dim=1)
92 |
93 | def canonical_to_local(self, canonical_xform, global_orient=None):
94 | """
95 | Args:
96 | canonical_xform: (B, J, 3, 3)
97 | global_orient: (B, 3, 3)
98 |
99 | Returns:
100 | local_xform: (B, J, 3, 3)
101 | """
102 | local_xform = torch.zeros_like(canonical_xform)
103 |
104 | if global_orient is None:
105 | global_xform = canonical_xform
106 | else:
107 | global_xform = torch.matmul(global_orient.unsqueeze(1), canonical_xform)
108 | for i in range(global_xform.shape[1]):
109 | if i == 0:
110 | local_xform[:, i] = global_xform[:, i]
111 | else:
112 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i])
113 |
114 | return local_xform
115 |
116 | def global_to_local(self, global_xform):
117 | """
118 | Args:
119 | global_xform: (B, J, 3, 3)
120 |
121 | Returns:
122 | local_xform: (B, J, 3, 3)
123 | """
124 | local_xform = torch.zeros_like(global_xform)
125 |
126 | for i in range(global_xform.shape[1]):
127 | if i == 0:
128 | local_xform[:, i] = global_xform[:, i]
129 | else:
130 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i])
131 |
132 | return local_xform
133 |
134 | def forward(self, rotations, positions=None):
135 | """
136 | Args:
137 | rotations (B, J, D)
138 |
139 | Returns:
140 | The global position of each joint after FK (B, J, 3)
141 | """
142 | # Get the full transform with rotations for skinning
143 | b_size = rotations.shape[0]
144 | if positions is None:
145 | positions = self.positions.repeat(b_size, 1, 1)
146 | transforms = self.rotate_global(self.parents, positions, rotations)
147 | coordinates = transforms[:, :, :3, 3] / transforms[:, :, 3:, 3]
148 |
149 | return coordinates, transforms
150 |
--------------------------------------------------------------------------------
/src/nemf/global_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from human_body_prior.tools.omni_tools import copy2cpu as c2c
7 | from rotations import matrix_to_axis_angle, rotation_6d_to_matrix
8 | from utils import compute_trajectory, export_ply_trajectory, normalize
9 |
10 | from .base_model import BaseModel
11 | from .fk import ForwardKinematicsLayer
12 | from .residual_blocks import ResidualBlock, SkeletonResidual, residual_ratio
13 | from .skeleton import SkeletonConv, SkeletonPool, build_edge_topology, find_neighbor
14 |
15 |
16 | class Predictor(nn.Module):
17 | def __init__(self, args, topology):
18 | super(Predictor, self).__init__()
19 | self.topologies = [topology]
20 | self.channel_base = [args.channel_base]
21 | self.channel_list = []
22 | self.edge_num = [len(topology)]
23 | self.pooling_list = []
24 | self.layers = nn.ModuleList()
25 | self.args = args
26 | # self.convs = []
27 |
28 | kernel_size = args.kernel_size
29 | padding = (kernel_size - 1) // 2
30 | bias = True
31 |
32 | for _ in range(args.num_layers):
33 | self.channel_base.append(self.channel_base[-1] * 2)
34 |
35 | for i in range(args.num_layers):
36 | seq = []
37 | neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
38 | in_channels = self.channel_base[i] * self.edge_num[i]
39 | out_channels = self.channel_base[i + 1] * self.edge_num[i]
40 | if i == 0:
41 | self.channel_list.append(in_channels)
42 | self.channel_list.append(out_channels)
43 | last_pool = (i == args.num_layers - 1)
44 |
45 | # (T, J, D) => (T, J', D)
46 | pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool,
47 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)
48 |
49 | if args.use_residual_blocks:
50 | # (T, J, D) => (T, J', 2D)
51 | seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels,
52 | kernel_size=kernel_size, stride=1, padding=padding, padding_mode=args.padding_mode, bias=bias,
53 | extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, activation=args.activation, last_pool=last_pool))
54 | else:
55 | for _ in range(args.extra_conv):
56 | # (T, J, D) => (T, J, D)
57 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
58 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=1,
59 | padding=padding, padding_mode=args.padding_mode, bias=bias))
60 | seq.append(nn.PReLU())
61 | # (T, J, D) => (T, J, 2D)
62 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
63 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=1,
64 | padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False,
65 | in_offset_channel=3 * self.channel_base[i] // self.channel_base[0]))
66 | # self.convs.append(seq[-1])
67 |
68 | seq.append(pool)
69 | seq.append(nn.PReLU())
70 | self.layers.append(nn.Sequential(*seq))
71 |
72 | self.topologies.append(pool.new_edges)
73 | self.pooling_list.append(pool.pooling_list)
74 | self.edge_num.append(len(self.topologies[-1]))
75 |
76 | in_channels = self.channel_base[-1] * len(self.pooling_list[-1])
77 | out_channels = args.out_channels # root orient (6) + root vel (3) + root height (1) + contacts (8)
78 |
79 | self.global_motion = nn.Sequential(
80 | ResidualBlock(in_channels=in_channels, out_channels=512, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(1), activation=args.activation),
81 | ResidualBlock(in_channels=512, out_channels=256, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(2), activation=args.activation),
82 | ResidualBlock(in_channels=256, out_channels=128, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(3), activation=args.activation),
83 | ResidualBlock(in_channels=128, out_channels=out_channels, kernel_size=kernel_size, stride=1, padding=padding,
84 | residual_ratio=residual_ratio(4), activation=args.activation, last_layer=True)
85 | )
86 |
87 | def forward(self, input):
88 | feature = input
89 | for layer in self.layers:
90 | feature = layer(feature)
91 |
92 | return self.global_motion(feature)
93 |
94 |
95 | class GlobalMotionPredictor(BaseModel):
96 | def __init__(self, args, ngpu):
97 | super(GlobalMotionPredictor, self).__init__(args)
98 |
99 | self.args = args
100 |
101 | smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz')
102 | smpl_data = np.load(smpl_fname, encoding='latin1')
103 | parents = smpl_data['kintree_table'][0].astype(np.int32)
104 | edges = build_edge_topology(parents)
105 |
106 | self.fk = ForwardKinematicsLayer(args)
107 |
108 | self.predictor = Predictor(args.global_motion, edges).to(self.device)
109 | if args.multi_gpu is True:
110 | self.predictor = nn.DataParallel(self.predictor, device_ids=range(ngpu))
111 | self.models = [self.predictor]
112 |
113 | ordermap = {
114 | 'x': 0,
115 | 'y': 1,
116 | 'z': 2,
117 | }
118 | self.up_axis = ordermap[self.args.data.up]
119 | self.v_axis = [x for x in ordermap.values() if x != self.up_axis]
120 |
121 | print(self.up_axis)
122 | print(self.v_axis)
123 |
124 | self.input_data = {}
125 | self.pred_data = {}
126 |
127 | self.mean = torch.load(os.path.join(args.dataset_dir, f'mean-{args.data.gender}-{args.data.clip_length}-{args.data.fps}fps.pt'), map_location=self.device)
128 | print(f'mean: {list(self.mean.keys())}')
129 | self.std = torch.load(os.path.join(args.dataset_dir, f'std-{args.data.gender}-{args.data.clip_length}-{args.data.fps}fps.pt'), map_location=self.device)
130 | print(f'std: {list(self.std.keys())}')
131 |
132 | if self.is_train:
133 | self.optimizer = torch.optim.Adam(self.predictor.parameters(), args.learning_rate, weight_decay=args.weight_decay)
134 | self.optimizers = [self.optimizer]
135 | self.criterion_pred = nn.L1Loss() if args.l1_loss else nn.MSELoss()
136 | self.criterion_bce = nn.BCEWithLogitsLoss()
137 | else:
138 | self.test_index = 0
139 | self.smpl_dir = os.path.join(args.save_dir, 'results', 'smpl')
140 |
141 | def set_input(self, input):
142 | self.input_data = {k: v.to(self.device) for k, v in input.items() if k in ['pos', 'velocity', 'rot6d', 'angular', 'global_xform',
143 | 'root_vel', 'trans', 'contacts']}
144 |
145 | def forward(self):
146 | self.pred_data.clear()
147 | self.input_data['origin'] = self.input_data['trans'][:, 0]
148 | self.pred_data = self.predict(data=self.input_data, dt=1.0 / self.args.data.fps)
149 |
150 | def predict(self, data, dt):
151 | b_size, t_length = data['pos'].shape[:2]
152 |
153 | if 'pos' in self.args.data.normalize:
154 | pos = normalize(data['pos'], mean=self.mean['pos'], std=self.std['pos'])
155 | else:
156 | pos = data['pos']
157 | if 'velocity' in self.args.data.normalize:
158 | velocity = normalize(data['velocity'], mean=self.mean['velocity'], std=self.std['velocity'])
159 | else:
160 | velocity = data['velocity']
161 | if 'rot6d' in self.args.data.normalize:
162 | rot6d = normalize(data['rot6d'], mean=self.mean['rot6d'], std=self.std['rot6d'])
163 | else:
164 | rot6d = data['rot6d']
165 | if 'angular' in self.args.data.normalize:
166 | angular = normalize(data['angular'], mean=self.mean['angular'], std=self.std['angular'])
167 | else:
168 | angular = data['angular']
169 |
170 | x = torch.cat((pos, velocity, rot6d, angular), dim=-1)
171 | x = x.view(b_size, t_length, -1) # (B, T, J x D)
172 | x = x.permute(0, 2, 1) # (B, J x D, T)
173 |
174 | global_motion = self.predictor(x)
175 | global_motion = global_motion.permute(0, 2, 1) # (B, T, D)
176 |
177 | output = dict()
178 | output['root_vel'] = global_motion[:, :, :3]
179 | output['root_height'] = global_motion[:, :, 3]
180 | output['contacts'] = global_motion[:, :, 4:]
181 | if 'origin' in data.keys():
182 | origin = data['origin']
183 | else:
184 | origin = torch.zeros(b_size, 3).to(self.device)
185 | output['trans'] = compute_trajectory(output['root_vel'], output['root_height'], origin, dt, up_axis=self.args.data.up)
186 |
187 | return output
188 |
189 | def backward(self, validation=False):
190 | loss = 0
191 |
192 | v_loss = self.criterion_pred(self.pred_data['root_vel'], self.input_data['root_vel'])
193 | up_loss = self.criterion_pred(self.pred_data['root_height'], self.input_data['trans'][..., self.up_axis])
194 | self.loss_recorder.add_scalar('v_loss', v_loss, validation=validation)
195 | self.loss_recorder.add_scalar('up_loss', up_loss, validation=validation)
196 | loss += self.args.lambda_v * v_loss + self.args.lambda_up * up_loss
197 |
198 | trans_loss = self.criterion_pred(self.pred_data['trans'], self.input_data['trans'])
199 | self.loss_recorder.add_scalar('trans_loss', trans_loss, validation=validation)
200 | loss += self.args.lambda_trans * trans_loss
201 |
202 | contacts_loss = self.criterion_bce(self.pred_data['contacts'], self.input_data['contacts'])
203 | self.loss_recorder.add_scalar('contacts_loss', contacts_loss, validation=validation)
204 | loss += self.args.lambda_contacts * contacts_loss
205 |
206 | self.loss_recorder.add_scalar('total_loss', loss, validation=validation)
207 |
208 | if not validation:
209 | loss.backward()
210 |
211 | def optimize_parameters(self):
212 | self.optimizer.zero_grad()
213 | self.forward()
214 | self.backward(validation=False)
215 | self.optimizer.step()
216 |
217 | def report_errors(self):
218 | trans = self.pred_data['trans'] # (B, T, 3)
219 | trans_gt = self.input_data['trans'] # (B, T, 3)
220 | translation_error = torch.linalg.norm((trans - trans_gt), dim=-1).mean()
221 |
222 | return {
223 | 'translation': c2c(translation_error)
224 | }
225 |
226 | def verbose(self):
227 | res = {}
228 | for loss in self.loss_recorder.losses.values():
229 | res[loss.name] = {'train': loss.current()[0], 'val': loss.current()[1]}
230 |
231 | return res
232 |
233 | def validate(self):
234 | with torch.no_grad():
235 | self.forward()
236 | self.backward(validation=True)
237 |
238 | def save(self, optimal=False):
239 | if optimal:
240 | path = os.path.join(self.args.save_dir, 'results', 'model')
241 | else:
242 | path = os.path.join(self.model_save_dir, f'{self.epoch_cnt:04d}')
243 |
244 | os.makedirs(path, exist_ok=True)
245 | predictor = self.predictor.module if isinstance(self.predictor, nn.DataParallel) else self.predictor
246 | torch.save(predictor.state_dict(), os.path.join(path, 'predictor.pth'))
247 | torch.save(self.optimizer.state_dict(), os.path.join(path, 'optimizer.pth'))
248 | if self.args.scheduler.name:
249 | torch.save(self.schedulers[0].state_dict(), os.path.join(path, 'scheduler.pth'))
250 | self.loss_recorder.save(path)
251 |
252 | print(f'Save at {path} succeeded')
253 |
254 | def load(self, epoch=None, optimal=False):
255 | if optimal:
256 | path = os.path.join(self.args.save_dir, 'results', 'model')
257 | else:
258 | if epoch is None:
259 | epochs = [int(q) for q in os.listdir(self.model_save_dir)]
260 | if len(epochs) == 0:
261 | raise RuntimeError(f'Empty loading path {self.model_save_dir}')
262 | epoch = sorted(epochs)[-1]
263 | path = os.path.join(self.model_save_dir, f'{epoch:04d}')
264 |
265 | print(f'Loading from {path}')
266 | predictor = self.predictor.module if isinstance(self.predictor, nn.DataParallel) else self.predictor
267 | predictor.load_state_dict(torch.load(os.path.join(path, 'predictor.pth'), map_location=self.device))
268 | if self.is_train:
269 | self.optimizer.load_state_dict(torch.load(os.path.join(path, 'optimizer.pth')))
270 | self.loss_recorder.load(path)
271 | if self.args.scheduler.name:
272 | self.schedulers[0].load_state_dict(torch.load(os.path.join(path, 'scheduler.pth')))
273 | self.epoch_cnt = epoch if not optimal else 0
274 |
275 | print('Load succeeded')
276 |
277 | def compute_test_result(self):
278 | os.makedirs(self.smpl_dir, exist_ok=True)
279 |
280 | b_size = self.input_data['trans'].shape[0]
281 | n_joints = self.input_data['pos'].shape[2]
282 |
283 | trans = self.pred_data['trans']
284 | trans_gt = self.input_data['trans']
285 |
286 | rotmat_gt = rotation_6d_to_matrix(self.input_data['global_xform']) # (B, T, J, 3, 3)
287 | local_rotmat_gt = self.fk.global_to_local(rotmat_gt.view(-1, n_joints, 3, 3)) # (B x T, J, 3, 3)
288 | local_rotmat_gt = local_rotmat_gt.view(b_size, -1, n_joints, 3, 3) # (B, T, J, 3, 3)
289 | local_rotmat = local_rotmat_gt.clone()
290 |
291 | for i in range(b_size):
292 | export_ply_trajectory(points=trans[i], color=(255, 0, 0), ply_fname=os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}.ply'))
293 | export_ply_trajectory(points=trans_gt[i], color=(0, 255, 0), ply_fname=os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}_gt.ply'))
294 |
295 | poses = c2c(matrix_to_axis_angle(local_rotmat[i])) # (T, J, 3)
296 | poses_gt = c2c(matrix_to_axis_angle(local_rotmat_gt[i])) # (T, J, 3)
297 |
298 | poses = poses.reshape((poses.shape[0], -1)) # (T, J x 3)
299 | poses = np.pad(poses, [(0, 0), (0, 93)], mode='constant')  # zero-pad the flattened poses (e.g. 24 x 3 -> 165 = 55 x 3 entries)
300 | poses_gt = poses_gt.reshape((poses_gt.shape[0], -1)) # (T, J x 3)
301 | poses_gt = np.pad(poses_gt, [(0, 0), (0, 93)], mode='constant')
302 |
303 | np.savez(os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}.npz'),
304 | poses=poses, trans=c2c(trans[i]), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.args.data.fps)
305 | np.savez(os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}_gt.npz'),
306 | poses=poses_gt, trans=c2c(trans_gt[i]), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.args.data.fps)
307 |
308 | self.test_index += b_size
309 |
--------------------------------------------------------------------------------
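
compute_trajectory is imported from src/utils.py (not shown in this section) and turns the predicted root velocity and height into a world-space trajectory. As a rough illustrative sketch only -- the repository's actual helper may handle axes and frame conventions differently -- the integration amounts to a cumulative sum with the up axis overridden by the predicted absolute height:

# Illustrative sketch; integrate_root_velocity is a hypothetical stand-in for
# the repository's compute_trajectory, which may differ in its conventions.
import torch

def integrate_root_velocity(root_vel, root_height, origin, dt, up_axis=1):
    """root_vel: (B, T, 3), root_height: (B, T), origin: (B, 3) -> trans: (B, T, 3)."""
    disp = torch.cumsum(root_vel[:, :-1] * dt, dim=1)   # displacement accumulated after frame 0
    zero = torch.zeros_like(origin).unsqueeze(1)        # frame 0 starts at the origin
    trans = torch.cat([zero, disp], dim=1) + origin.unsqueeze(1)
    trans[..., up_axis] = root_height                   # absolute height replaces the integrated up component
    return trans
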
/src/nemf/loss_record.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import torch
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 |
8 | class SingleLoss:
9 | def __init__(self, name: str, writer: SummaryWriter):
10 | self.name = name
11 | self.loss_step = []
12 | self.loss_epoch = []
13 | self.loss_epoch_tmp = []
14 | self.loss_epoch_val = []
15 | self.writer = writer
16 |
17 | def add_scalar(self, val, step=None, validation=False):
18 | if validation:
19 | self.loss_epoch_val.append(val)
20 | else:
21 | if step is None:
22 | step = len(self.loss_step)
23 | self.loss_step.append(val)
24 | self.loss_epoch_tmp.append(val)
25 | self.writer.add_scalar('Loss/step_' + self.name, val, step)
26 |
27 | def current(self):
28 | return self.loss_epoch[-1]
29 |
30 | def epoch(self, step=None):
31 | if step is None:
32 | step = len(self.loss_epoch)
33 | if self.loss_epoch_tmp:
34 | loss_avg = sum(self.loss_epoch_tmp) / len(self.loss_epoch_tmp)
35 | else:
36 | loss_avg = 0
37 | self.loss_epoch_tmp = []
38 | if self.loss_epoch_val:
39 | loss_avg_val = sum(self.loss_epoch_val) / len(self.loss_epoch_val)
40 | else:
41 | loss_avg_val = 0
42 | self.loss_epoch_val = []
43 | self.loss_epoch.append([loss_avg, loss_avg_val])
44 | self.writer.add_scalars('Loss/epoch_' + self.name, {'train': loss_avg, 'val': loss_avg_val}, step)
45 |
46 | def save(self):
47 | return {
48 | 'loss_step': self.loss_step,
49 | 'loss_epoch': self.loss_epoch
50 | }
51 |
52 |
53 | class LossRecorder:
54 | def __init__(self, writer: SummaryWriter):
55 | self.losses = {}
56 | self.writer = writer
57 |
58 | def add_scalar(self, name, val, step=None, validation=False):
59 | if isinstance(val, torch.Tensor):
60 | val = val.item()
61 | if name not in self.losses:
62 | self.losses[name] = SingleLoss(name, self.writer)
63 | self.losses[name].add_scalar(val, step, validation)
64 |
65 | def epoch(self, step=None):
66 | for loss in self.losses.values():
67 | loss.epoch(step)
68 |
69 | def save(self, path):
70 | data = {}
71 | for key, value in self.losses.items():
72 | data[key] = value.save()
73 | with open(os.path.join(path, 'loss.pkl'), 'wb') as f:
74 | pickle.dump(data, f)
75 |
76 | def load(self, path):
77 | with open(os.path.join(path, 'loss.pkl'), 'rb') as f:
78 | data = pickle.load(f)
79 | for key, value in data.items():
80 | self.losses[key] = SingleLoss(key, self.writer)
81 | self.losses[key].loss_step = value['loss_step']
82 | self.losses[key].loss_epoch = value['loss_epoch']
83 |
--------------------------------------------------------------------------------
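
A minimal usage sketch for the recorder above (the log directory is arbitrary): per-step training scalars accumulate in a temporary buffer, and epoch() collapses the buffers into a [train, val] average pair.

# Assumes LossRecorder from src/nemf/loss_record.py is in scope.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('outputs/logs')               # arbitrary log directory
recorder = LossRecorder(writer)
for step in range(100):
    recorder.add_scalar('v_loss', 1.0 / (step + 1))  # training steps
recorder.add_scalar('v_loss', 0.3, validation=True)  # a validation value
recorder.epoch()                                     # average and reset the per-epoch buffers
print(recorder.losses['v_loss'].current())           # -> [train average, 0.3]
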
/src/nemf/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class GeodesicLoss(nn.Module):
6 | def __init__(self):
7 | super(GeodesicLoss, self).__init__()
8 |
9 | def compute_geodesic_distance(self, m1, m2, epsilon=1e-7):
10 | """ Compute the geodesic distance between two rotation matrices.
11 | Args:
12 | m1, m2: Two rotation matrices with the shape (batch x 3 x 3).
13 | Returns:
14 | The minimal angular difference between two rotation matrices in radian form [0, pi].
15 | """
16 | batch = m1.shape[0]
17 | m = torch.bmm(m1, m2.permute(0, 2, 1)) # batch*3*3
18 |
19 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2
20 | # cos = (m.diagonal(dim1=-2, dim2=-1).sum(-1) -1) /2
21 | # cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()))
22 | # cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda()) * -1)
23 | cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon)
24 | theta = torch.acos(cos)
25 |
26 | return theta
27 |
28 | def forward(self, m1, m2, reduction='mean'):
29 | loss = self.compute_geodesic_distance(m1, m2)
30 |
31 | if reduction == 'mean':
32 | return loss.mean()
33 | elif reduction == 'none':
34 | return loss
35 | else:
36 | raise RuntimeError(f'unsupported reduction: {reduction}')
37 |
38 |
--------------------------------------------------------------------------------
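
A quick numerical check of the loss above (a sketch; rot_z is a local helper, not part of the module): for two rotations about the same axis, the geodesic distance is the absolute difference of their angles.

import math
import torch

def rot_z(theta):
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

m1 = rot_z(0.3).unsqueeze(0)   # (1, 3, 3)
m2 = rot_z(1.0).unsqueeze(0)
print(GeodesicLoss()(m1, m2))  # ~0.7 radians
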
/src/nemf/neural_motion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class PositionalEncoding(nn.Module):
8 | def __init__(self, n_functions):
9 | super(PositionalEncoding, self).__init__()
10 |
11 | self.register_buffer('frequencies', 2.0 ** torch.arange(n_functions))
12 |
13 | def forward(self, x):
14 | """
15 | Args:
16 | x: tensor of shape [..., dim]
17 |
18 | Returns:
19 | embedding: a temporal embedding of `x` of shape [..., n_functions * dim * 2]
20 | """
21 | freq = (x[..., None] * self.frequencies).view(*x.shape[:-1], -1)
22 |
23 | embedding = torch.zeros(*freq.shape[:-1], freq.shape[-1] * 2, device=x.device)  # allocate on the input's device instead of assuming CUDA
24 | embedding[..., 0::2] = freq.sin()
25 | embedding[..., 1::2] = freq.cos()
26 |
27 | return embedding
28 |
29 |
30 | class Sine(nn.Module):
31 | def __init__(self, w0=1.):
32 | super(Sine, self).__init__()
33 |
34 | self.w0 = w0
35 |
36 | def forward(self, x):
37 | return torch.sin(self.w0 * x)
38 |
39 |
40 | class SirenBlock(nn.Module):
41 | def __init__(self, in_features, out_features, w0=30, c=6, is_first=False, use_bias=True, activation=None):
42 | super(SirenBlock, self).__init__()
43 |
44 | self.in_features = in_features
45 | self.is_first = is_first
46 |
47 | weight = torch.zeros(out_features, in_features)
48 | bias = torch.zeros(out_features) if use_bias else None
49 | self.init(weight, bias, c=c, w0=w0)
50 |
51 | self.weight = nn.Parameter(weight)
52 | self.bias = nn.Parameter(bias) if use_bias else None
53 | self.activation = Sine(w0) if activation is None else activation
54 |
55 | def init(self, weight, bias, c, w0):
56 | n = self.in_features
57 |
58 | w_std = (1 / n) if self.is_first else (np.sqrt(c / n) / w0)
59 | weight.uniform_(-w_std, w_std)
60 |
61 | if bias is not None:
62 | bias.uniform_(-w_std, w_std)
63 |
64 | def forward(self, x):
65 | out = F.linear(x, self.weight, self.bias)
66 |
67 | return self.activation(out)
68 |
69 |
70 | class FCBlock(nn.Module):
71 | def __init__(self, in_features, out_features, norm_layer=False, activation=None):
72 | super(FCBlock, self).__init__()
73 |
74 | self.fc = nn.Linear(in_features, out_features)
75 | self.residual = (in_features == out_features) # when the input and output have the same dimensions, build a residual block
76 | self.norm_layer = nn.LayerNorm(out_features) if norm_layer else None
77 | self.activation = nn.ReLU(inplace=True) if activation is None else activation
78 |
79 | def forward(self, x):
80 | """
81 | Args:
82 | x: (B, T, D), features are in the last dimension.
83 | """
84 | out = self.fc(x)
85 |
86 | if self.norm_layer is not None:
87 | out = self.norm_layer(out)
88 |
89 | if self.residual:
90 | return self.activation(out) + x
91 |
92 | return self.activation(out)
93 |
94 |
95 | class NeuralMotionField(nn.Module):
96 | def __init__(self, args):
97 | super(NeuralMotionField, self).__init__()
98 |
99 | self.args = args
100 |
101 | if args.bandwidth != 0:
102 | self.positional_encoding = PositionalEncoding(args.bandwidth)
103 | embedding_dim = args.bandwidth * 2
104 | else:
105 | embedding_dim = 1
106 |
107 | hidden_neuron = args.hidden_neuron
108 | in_features = 1 + args.local_z if args.siren else embedding_dim + args.local_z
109 | local_in_features = hidden_neuron + in_features if args.skip_connection else hidden_neuron
110 | global_in_features = local_in_features + args.global_z if args.skip_connection else hidden_neuron
111 |
112 | layers = [
113 | SirenBlock(in_features, hidden_neuron, is_first=True) if args.siren else FCBlock(in_features, hidden_neuron, norm_layer=args.norm_layer),
114 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
115 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
116 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
117 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
118 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
119 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
120 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer),
121 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer),
122 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer),
123 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer)
124 | ]
125 | self.mlp = nn.ModuleList(layers)
126 | self.skip_layers = [] if not args.skip_connection else list(range(1, len(self.mlp)))
127 | self.local_layers = list(range(8))
128 | self.local_linear = nn.Sequential(nn.Linear(hidden_neuron, args.local_output))
129 | self.global_layers = list(range(8, len(self.mlp)))
130 | self.global_linear = nn.Sequential(nn.Linear(hidden_neuron, args.global_output))
131 |
132 | def forward(self, t, z_l=None, z_g=None):
133 | if self.args.bandwidth != 0 and self.args.siren is False:
134 | t = self.positional_encoding(t)
135 |
136 | if z_l is not None:
137 | z_l = z_l.unsqueeze(1).expand(-1, t.shape[1], -1)
138 | x = torch.cat([t, z_l], dim=-1)
139 | else:
140 | x = t
141 |
142 | skip_in = x
143 | for i in self.local_layers:
144 | if i in self.skip_layers:
145 | x = torch.cat((x, skip_in), dim=-1)
146 | x = self.mlp[i](x)
147 | local_output = self.local_linear(x)
148 |
149 | if z_g is not None:
150 | z_g = z_g.unsqueeze(1).expand(-1, t.shape[1], -1)
151 | for i in self.global_layers:
152 | if i in self.skip_layers:
153 | x = torch.cat((x, skip_in, z_g), dim=-1)
154 | x = self.mlp[i](x)
155 | global_output = self.global_linear(x)
156 | else:
157 | global_output = None
158 |
159 | return local_output, global_output
160 |
--------------------------------------------------------------------------------
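
A shape check for PositionalEncoding above: each scalar time stamp expands into sin/cos pairs at n_functions octave-spaced frequencies, so a (B, T, 1) input yields n_functions * 1 * 2 embedding channels.

import torch

enc = PositionalEncoding(n_functions=4)
t = torch.linspace(0, 1, steps=16).view(1, 16, 1)  # (B, T, 1) normalized time stamps
print(enc(t).shape)                                # torch.Size([1, 16, 8])
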
/src/nemf/prior.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .residual_blocks import ResidualBlock, SkeletonResidual, residual_ratio
5 | from .skeleton import SkeletonConv, SkeletonPool, find_neighbor
6 |
7 |
8 | class LocalEncoder(nn.Module):
9 | def __init__(self, args, topology):
10 | super(LocalEncoder, self).__init__()
11 | self.topologies = [topology]
12 | self.channel_base = [args.channel_base]
13 |
14 | self.channel_list = []
15 | self.edge_num = [len(topology)]
16 | self.pooling_list = []
17 | self.layers = nn.ModuleList()
18 | self.args = args
19 | # self.convs = []
20 |
21 | kernel_size = args.kernel_size
22 | kernel_even = kernel_size % 2 == 0
23 | padding = (kernel_size - 1) // 2
24 | bias = True
25 |
26 | for _ in range(args.num_layers):
27 | self.channel_base.append(self.channel_base[-1] * 2)
28 |
29 | for i in range(args.num_layers):
30 | seq = []
31 | neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
32 | in_channels = self.channel_base[i] * self.edge_num[i]
33 | out_channels = self.channel_base[i + 1] * self.edge_num[i]
34 | if i == 0:
35 | self.channel_list.append(in_channels)
36 | self.channel_list.append(out_channels)
37 | last_pool = (i == args.num_layers - 1)
38 |
39 | # (T, J, D) => (T, J', D)
40 | pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool,
41 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)
42 |
43 | if args.use_residual_blocks:
44 | # (T, J, D) => (T/2, J', 2D)
45 | seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels,
46 | kernel_size=kernel_size, stride=2, padding=padding, padding_mode=args.padding_mode, bias=bias,
47 | extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, activation=args.activation, last_pool=last_pool))
48 | else:
49 | for _ in range(args.extra_conv):
50 | # (T, J, D) => (T, J, D)
51 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
52 | joint_num=self.edge_num[i], kernel_size=kernel_size - 1 if kernel_even else kernel_size,
53 | stride=1,
54 | padding=padding, padding_mode=args.padding_mode, bias=bias))
55 | seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
56 | # (T, J, D) => (T/2, J, 2D)
57 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
58 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2,
59 | padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False,
60 | in_offset_channel=3 * self.channel_base[i] // self.channel_base[0]))
61 | # self.convs.append(seq[-1])
62 |
63 | seq.append(pool)
64 | seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
65 | self.layers.append(nn.Sequential(*seq))
66 |
67 | self.topologies.append(pool.new_edges)
68 | self.pooling_list.append(pool.pooling_list)
69 | self.edge_num.append(len(self.topologies[-1]))
70 |
71 | in_features = self.channel_base[-1] * len(self.pooling_list[-1])
72 | in_features *= args.temporal_scale
73 | self.mu = nn.Linear(in_features, args.z_dim)
74 | self.logvar = nn.Linear(in_features, args.z_dim)
75 |
76 | def forward(self, input):
77 | output = input
78 | for layer in self.layers:
79 | output = layer(output)
80 | output = output.view(output.shape[0], -1)
81 |
82 | return self.mu(output), self.logvar(output)
83 |
84 |
85 | class GlobalEncoder(nn.Module):
86 | def __init__(self, args):
87 | super(GlobalEncoder, self).__init__()
88 |
89 | in_channels = args.in_channels # root orientation + root velocity
90 | out_channels = 512
91 | kernel_size = args.kernel_size
92 | padding = (kernel_size - 1) // 2
93 |
94 | self.encoder = nn.Sequential(
95 | ResidualBlock(in_channels=in_channels, out_channels=128, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(1), activation=args.activation),
96 | ResidualBlock(in_channels=128, out_channels=256, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(2), activation=args.activation),
97 | ResidualBlock(in_channels=256, out_channels=512, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(3), activation=args.activation),
98 | ResidualBlock(in_channels=512, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(4), activation=args.activation)
99 | )
100 |
101 | in_features = out_channels * args.temporal_scale
102 | self.mu = nn.Linear(in_features, args.z_dim)
103 | self.logvar = nn.Linear(in_features, args.z_dim)
104 |
105 | def forward(self, input):
106 | feature = self.encoder(input)
107 | feature = feature.view(feature.shape[0], -1)
108 |
109 | return self.mu(feature), self.logvar(feature)
110 |
--------------------------------------------------------------------------------
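
Both encoders return a (mu, logvar) pair; the sampling itself presumably happens in the model that owns them (e.g., generative.py, not shown in this section). For reference, the standard reparameterization trick such a pair feeds:

import torch

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # logvar -> standard deviation
    eps = torch.randn_like(std)    # unit Gaussian noise
    return mu + eps * std          # differentiable sample z ~ N(mu, std^2)
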
/src/nemf/residual_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .skeleton import SkeletonConv, SkeletonPool, SkeletonUnpool
5 |
6 |
7 | def residual_ratio(k):
8 | return 1 / (k + 1)
9 |
10 |
11 | class Affine(nn.Module):
12 | def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
13 | super(Affine, self).__init__()
14 | if scale:
15 | self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
16 | else:
17 | self.register_parameter('scale', None)
18 |
19 | if bias:
20 | self.bias = nn.Parameter(torch.zeros(num_parameters))
21 | else:
22 | self.register_parameter('bias', None)
23 |
24 | def forward(self, input):
25 | output = input
26 | if self.scale is not None:
27 | scale = self.scale.unsqueeze(0)
28 | while scale.dim() < input.dim():
29 | scale = scale.unsqueeze(2)
30 | output = output.mul(scale)
31 |
32 | if self.bias is not None:
33 | bias = self.bias.unsqueeze(0)
34 | while bias.dim() < input.dim():
35 | bias = bias.unsqueeze(2)
36 | output += bias
37 |
38 | return output
39 |
40 |
41 | class BatchStatistics(nn.Module):
42 | def __init__(self, affine=-1):
43 | super(BatchStatistics, self).__init__()
44 | self.affine = nn.Sequential() if affine == -1 else Affine(affine)
45 | self.loss = 0
46 |
47 | def clear_loss(self):
48 | self.loss = 0
49 |
50 | def compute_loss(self, input):
51 | input_flat = input.view(input.size(1), input.numel() // input.size(1))
52 | mu = input_flat.mean(1)
53 | log_std = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log()  # log of the standard deviation, not the log-variance
54 |
55 | self.loss = mu.pow(2).mean() + log_std.pow(2).mean()
56 |
57 | def forward(self, input):
58 | self.compute_loss(input)
59 | return self.affine(input)
60 |
61 |
62 | class ResidualBlock(nn.Module):
63 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False):
64 | super(ResidualBlock, self).__init__()
65 |
66 | self.residual_ratio = residual_ratio
67 | self.shortcut_ratio = 1 - residual_ratio
68 |
69 | residual = []
70 | residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))
71 | if batch_statistics:
72 | residual.append(BatchStatistics(out_channels))
73 | if not last_layer:
74 | residual.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
75 | self.residual = nn.Sequential(*residual)
76 |
77 | self.shortcut = nn.Sequential(
78 | nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
79 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
80 | BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential()
81 | )
82 |
83 | def forward(self, input):
84 | return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
85 |
86 |
87 | class ResidualBlockTranspose(nn.Module):
88 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
89 | super(ResidualBlockTranspose, self).__init__()
90 |
91 | self.residual_ratio = residual_ratio
92 | self.shortcut_ratio = 1 - residual_ratio
93 |
94 | self.residual = nn.Sequential(
95 | nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding),
96 | nn.PReLU() if activation == 'relu' else nn.Tanh()
97 | )
98 |
99 | self.shortcut = nn.Sequential(
100 | nn.Upsample(scale_factor=2, mode='linear', align_corners=False) if stride == 2 else nn.Sequential(),
101 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
102 | )
103 |
104 | def forward(self, input):
105 | return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
106 |
107 |
108 | class SkeletonResidual(nn.Module):
109 | def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding, padding_mode, bias, extra_conv, pooling_mode, activation, last_pool):
110 | super(SkeletonResidual, self).__init__()
111 |
112 | kernel_even = kernel_size % 2 == 0
113 |
114 | seq = []
115 | for _ in range(extra_conv):
116 | # (T, J, D) => (T, J, D)
117 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
118 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
119 | stride=1,
120 | padding=padding, padding_mode=padding_mode, bias=bias))
121 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
122 | # (T, J, D) => (T/2, J, 2D)
123 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
124 | joint_num=joint_num, kernel_size=kernel_size, stride=stride,
125 | padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
126 | seq.append(nn.GroupNorm(8, out_channels)) # FIXME: REMEMBER TO CHANGE BACK !!!
127 | self.residual = nn.Sequential(*seq)
128 |
129 | # (T, J, D) => (T/2, J, 2D)
130 | self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
131 | joint_num=joint_num, kernel_size=1, stride=stride, padding=0,
132 | bias=True, add_offset=False)
133 |
134 | seq = []
135 | # (T/2, J, 2D) => (T/2, J', 2D)
136 | pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode,
137 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)
138 | if len(pool.pooling_list) != pool.edge_num:
139 | seq.append(pool)
140 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
141 | self.common = nn.Sequential(*seq)
142 |
143 | def forward(self, input):
144 | output = self.residual(input) + self.shortcut(input)
145 |
146 | return self.common(output)
147 |
148 |
149 | class SkeletonResidualTranspose(nn.Module):
150 | def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias, extra_conv, pooling_list, upsampling, activation, last_layer):
151 | super(SkeletonResidualTranspose, self).__init__()
152 |
153 | kernel_even = kernel_size % 2 == 0
154 |
155 | seq = []
156 | # (T, J, D) => (2T, J, D)
157 | if upsampling is not None:
158 | seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
159 | # (2T, J, D) => (2T, J', D)
160 | unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
161 | if unpool.input_edge_num != unpool.output_edge_num:
162 | seq.append(unpool)
163 | self.common = nn.Sequential(*seq)
164 |
165 | seq = []
166 | for _ in range(extra_conv):
167 | # (2T, J', D) => (2T, J', D)
168 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
169 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
170 | stride=1,
171 | padding=padding, padding_mode=padding_mode, bias=bias))
172 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
173 | # (2T, J', D) => (2T, J', D/2)
174 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
175 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
176 | stride=1,
177 | padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
178 | self.residual = nn.Sequential(*seq)
179 |
180 | # (2T, J', D) => (2T, J', D/2)
181 | self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
182 | joint_num=joint_num, kernel_size=1, stride=1, padding=0,
183 | bias=True, add_offset=False)
184 |
185 | if activation == 'relu':
186 | self.activation = nn.PReLU() if not last_layer else None
187 | else:
188 | self.activation = nn.Tanh() if not last_layer else None
189 |
190 | def forward(self, input):
191 | output = self.common(input)
192 | output = self.residual(output) + self.shortcut(output)
193 |
194 | if self.activation is not None:
195 | return self.activation(output)
196 | else:
197 | return output
198 |
--------------------------------------------------------------------------------
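
With residual_ratio(k) = 1 / (k + 1), a block at depth k blends 1/(k + 1) of the convolutional path with k/(k + 1) of the 1x1 shortcut, so deeper blocks stay closer to their shortcut. A small shape sketch (assumes the classes above are in scope):

import torch

block = ResidualBlock(in_channels=8, out_channels=16, kernel_size=3, stride=2,
                      padding=1, residual_ratio=residual_ratio(2), activation='relu')
x = torch.randn(4, 8, 64)  # (batch, channels, frames)
print(block(x).shape)      # torch.Size([4, 16, 32]) -- stride 2 halves the frame count
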
/src/nemf/skeleton.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class SkeletonConv(nn.Module):
10 | def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
11 | bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
12 | self.in_channels_per_joint = in_channels // joint_num
13 | self.out_channels_per_joint = out_channels // joint_num
14 | if in_channels % joint_num != 0 or out_channels % joint_num != 0:
15 | raise ValueError('in_channels and out_channels must both be divisible by joint_num')
16 | super(SkeletonConv, self).__init__()
17 |
18 | if padding_mode == 'zeros':
19 | padding_mode = 'constant'
20 | if padding_mode == 'reflection':
21 | padding_mode = 'reflect'
22 |
23 | self.expanded_neighbour_list = []
24 | self.expanded_neighbour_list_offset = []
25 | self.neighbour_list = neighbour_list
26 | self.add_offset = add_offset
27 | self.joint_num = joint_num
28 |
29 | self.stride = stride
30 | self.dilation = 1
31 | self.groups = 1
32 | self.padding = padding
33 | self.padding_mode = padding_mode
34 | self._padding_repeated_twice = (padding, padding)
35 |
36 | for neighbour in neighbour_list:
37 | expanded = []
38 | for k in neighbour:
39 | for i in range(self.in_channels_per_joint):
40 | expanded.append(k * self.in_channels_per_joint + i)
41 | self.expanded_neighbour_list.append(expanded)
42 |
43 | if self.add_offset:
44 | self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
45 |
46 | for neighbour in neighbour_list:
47 | expanded = []
48 | for k in neighbour:
49 | for i in range(add_offset):
50 | expanded.append(k * in_offset_channel + i)
51 | self.expanded_neighbour_list_offset.append(expanded)
52 |
53 | self.weight = torch.zeros(out_channels, in_channels, kernel_size)
54 | if bias:
55 | self.bias = torch.zeros(out_channels)
56 | else:
57 | self.register_parameter('bias', None)
58 |
59 | self.mask = torch.zeros_like(self.weight)
60 | for i, neighbour in enumerate(self.expanded_neighbour_list):
61 | self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
62 | self.mask = nn.Parameter(self.mask, requires_grad=False)
63 |
64 | self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
65 | 'joint_num={}, stride={}, padding={}, bias={})'.format(
66 | in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
67 | )
68 |
69 | self.reset_parameters()
70 |
71 | def reset_parameters(self):
72 | for i, neighbour in enumerate(self.expanded_neighbour_list):
73 | """ Use temporary variable to avoid assign to copy of slice, which might lead to unexpected result """
74 | tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
75 | neighbour, ...])
76 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
77 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
78 | neighbour, ...] = tmp
79 | if self.bias is not None:
80 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
81 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
82 | bound = 1 / math.sqrt(fan_in)
83 | tmp = torch.zeros_like(
84 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
85 | nn.init.uniform_(tmp, -bound, bound)
86 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp
87 |
88 | self.weight = nn.Parameter(self.weight)
89 | if self.bias is not None:
90 | self.bias = nn.Parameter(self.bias)
91 |
92 | def set_offset(self, offset):
93 | if not self.add_offset:
94 | raise RuntimeError('set_offset() is only valid when add_offset=True')
95 | self.offset = offset.reshape(offset.shape[0], -1)
96 |
97 | def forward(self, input):
98 | # print('SkeletonConv')
99 | weight_masked = self.weight * self.mask
100 | # print(f'input: {input.size()}')
101 | res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
102 | weight_masked, self.bias, self.stride,
103 | 0, self.dilation, self.groups)
104 |
105 | if self.add_offset:
106 | offset_res = self.offset_enc(self.offset)
107 | offset_res = offset_res.reshape(offset_res.shape + (1, ))
108 | res += offset_res / 100
109 | # print(f'res: {res.size()}')
110 | return res
111 |
112 |
113 | class SkeletonLinear(nn.Module):
114 | def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
115 | super(SkeletonLinear, self).__init__()
116 | self.neighbour_list = neighbour_list
117 | self.in_channels = in_channels
118 | self.out_channels = out_channels
119 | self.in_channels_per_joint = in_channels // len(neighbour_list)
120 | self.out_channels_per_joint = out_channels // len(neighbour_list)
121 | self.extra_dim1 = extra_dim1
122 | self.expanded_neighbour_list = []
123 |
124 | for neighbour in neighbour_list:
125 | expanded = []
126 | for k in neighbour:
127 | for i in range(self.in_channels_per_joint):
128 | expanded.append(k * self.in_channels_per_joint + i)
129 | self.expanded_neighbour_list.append(expanded)
130 |
131 | self.weight = torch.zeros(out_channels, in_channels)
132 | self.mask = torch.zeros(out_channels, in_channels)
133 | self.bias = nn.Parameter(torch.Tensor(out_channels))
134 |
135 | self.reset_parameters()
136 |
137 | def reset_parameters(self):
138 | for i, neighbour in enumerate(self.expanded_neighbour_list):
139 | tmp = torch.zeros_like(
140 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
141 | )
142 | self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
143 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
144 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp
145 |
146 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
147 | bound = 1 / math.sqrt(fan_in)
148 | nn.init.uniform_(self.bias, -bound, bound)
149 |
150 | self.weight = nn.Parameter(self.weight)
151 | self.mask = nn.Parameter(self.mask, requires_grad=False)
152 |
153 | def forward(self, input):
154 | input = input.reshape(input.shape[0], -1)
155 | weight_masked = self.weight * self.mask
156 | res = F.linear(input, weight_masked, self.bias)
157 | if self.extra_dim1:
158 | res = res.reshape(res.shape + (1,))
159 | return res
160 |
161 |
162 | class SkeletonPool(nn.Module):
163 | def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
164 | super(SkeletonPool, self).__init__()
165 |
166 | if pooling_mode != 'mean':
167 | raise NotImplementedError(f'unsupported pooling mode: {pooling_mode}')
168 |
169 | self.channels_per_edge = channels_per_edge
170 | self.pooling_mode = pooling_mode
171 | self.edge_num = len(edges)
172 | # self.edge_num = len(edges) + 1
173 | self.seq_list = []
174 | self.pooling_list = []
175 | self.new_edges = []
176 | degree = [0] * 100  # degree of each joint, indexed by joint id (assumes fewer than 100 joints)
177 |
178 | for edge in edges:
179 | degree[edge[0]] += 1
180 | degree[edge[1]] += 1
181 |
182 | # seq_list contains multiple sub-lists where each sub-list is an edge chain from the joint whose degree > 2 to the end effectors or joints whose degree > 2.
183 | def find_seq(j, seq):
184 | nonlocal self, degree, edges
185 |
186 | if degree[j] > 2 and j != 0:
187 | self.seq_list.append(seq)
188 | seq = []
189 |
190 | if degree[j] == 1:
191 | self.seq_list.append(seq)
192 | return
193 |
194 | for idx, edge in enumerate(edges):
195 | if edge[0] == j:
196 | find_seq(edge[1], seq + [idx])
197 |
198 | find_seq(0, [])
199 | # print(f'self.seq_list: {self.seq_list}')
200 |
201 | for seq in self.seq_list:
202 | if last_pool:
203 | self.pooling_list.append(seq)
204 | continue
205 | if len(seq) % 2 == 1:
206 | self.pooling_list.append([seq[0]])
207 | self.new_edges.append(edges[seq[0]])
208 | seq = seq[1:]
209 | for i in range(0, len(seq), 2):
210 | self.pooling_list.append([seq[i], seq[i + 1]])
211 | self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
212 | # print(f'self.pooling_list: {self.pooling_list}')
213 | # print(f'self.new_egdes: {self.new_edges}')
214 |
215 | # add global position
216 | # self.pooling_list.append([self.edge_num - 1])
217 |
218 | self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
219 | len(edges), len(self.pooling_list)
220 | )
221 |
222 | self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
223 |
224 | for i, pair in enumerate(self.pooling_list):
225 | for j in pair:
226 | for c in range(channels_per_edge):
227 | self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
228 |
229 | self.weight = nn.Parameter(self.weight, requires_grad=False)
230 |
231 | def forward(self, input: torch.Tensor):
232 | # print('SkeletonPool')
233 | # print(f'input: {input.size()}')
234 | # print(f'self.weight: {self.weight.size()}')
235 | return torch.matmul(self.weight, input)
236 |
237 |
238 | class SkeletonUnpool(nn.Module):
239 | def __init__(self, pooling_list, channels_per_edge):
240 | super(SkeletonUnpool, self).__init__()
241 | self.pooling_list = pooling_list
242 | self.input_edge_num = len(pooling_list)
243 | self.output_edge_num = 0
244 | self.channels_per_edge = channels_per_edge
245 | for t in self.pooling_list:
246 | self.output_edge_num += len(t)
247 |
248 | self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format(
249 | self.input_edge_num, self.output_edge_num,
250 | )
251 |
252 | self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge)
253 |
254 | for i, pair in enumerate(self.pooling_list):
255 | for j in pair:
256 | for c in range(channels_per_edge):
257 | self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
258 |
259 | self.weight = nn.Parameter(self.weight)
260 | self.weight.requires_grad_(False)
261 |
262 | def forward(self, input: torch.Tensor):
263 | # print('SkeletonUnpool')
264 | # print(f'input: {input.size()}')
265 | # print(f'self.weight: {self.weight.size()}')
266 | return torch.matmul(self.weight, input)
267 |
268 |
269 | """
270 | Helper functions for skeleton operation
271 | """
272 |
273 |
274 | def dfs(x, fa, vis, dist):
275 | vis[x] = 1
276 | for y in range(len(fa)):
277 | if (fa[y] == x or fa[x] == y) and vis[y] == 0:
278 | dist[y] = dist[x] + 1
279 | dfs(y, fa, vis, dist)
280 |
281 |
282 | """
283 | def find_neighbor_joint(fa, threshold):
284 | neighbor_list = [[]]
285 | for x in range(1, len(fa)):
286 | vis = [0 for _ in range(len(fa))]
287 | dist = [0 for _ in range(len(fa))]
288 | dist[0] = 10000
289 | dfs(x, fa, vis, dist)
290 | neighbor = []
291 | for j in range(1, len(fa)):
292 | if dist[j] <= threshold:
293 | neighbor.append(j)
294 | neighbor_list.append(neighbor)
295 |
296 | neighbor = [0]
297 | for i, x in enumerate(neighbor_list):
298 | if i == 0: continue
299 | if 1 in x:
300 | neighbor.append(i)
301 | neighbor_list[i] = [0] + neighbor_list[i]
302 | neighbor_list[0] = neighbor
303 | return neighbor_list
304 |
305 |
306 | def build_edge_topology(topology, offset):
307 | # get all edges (pa, child, offset)
308 | edges = []
309 | joint_num = len(topology)
310 | for i in range(1, joint_num):
311 | edges.append((topology[i], i, offset[i]))
312 | return edges
313 | """
314 |
315 |
316 | def build_edge_topology(topology):
317 | # get all edges (pa, child)
318 | edges = []
319 | joint_num = len(topology)
320 | edges.append((0, joint_num)) # add an edge between the root joint and a virtual joint
321 | for i in range(1, joint_num):
322 | edges.append((topology[i], i))
323 | return edges
324 |
325 |
326 | def build_joint_topology(edges, origin_names):
327 | parent = []
328 | offset = []
329 | names = []
330 | edge2joint = []
331 | joint_from_edge = [] # -1 means virtual joint
332 | joint_cnt = 0
333 | out_degree = [0] * (len(edges) + 10)
334 | for edge in edges:
335 | out_degree[edge[0]] += 1
336 |
337 | # add root joint
338 | joint_from_edge.append(-1)
339 | parent.append(0)
340 | offset.append(np.array([0, 0, 0]))
341 | names.append(origin_names[0])
342 | joint_cnt += 1
343 |
344 | def make_topology(edge_idx, pa):
345 | nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt
346 | edge = edges[edge_idx]
347 | if out_degree[edge[0]] > 1:
348 | parent.append(pa)
349 | offset.append(np.array([0, 0, 0]))
350 | names.append(origin_names[edge[1]] + '_virtual')
351 | edge2joint.append(-1)
352 | pa = joint_cnt
353 | joint_cnt += 1
354 |
355 | parent.append(pa)
356 | offset.append(edge[2])  # requires (parent, child, offset) edges, cf. the commented-out builder above; the active 2-tuple variant has no edge[2]
357 | names.append(origin_names[edge[1]])
358 | edge2joint.append(edge_idx)
359 | pa = joint_cnt
360 | joint_cnt += 1
361 |
362 | for idx, e in enumerate(edges):
363 | if e[0] == edge[1]:
364 | make_topology(idx, pa)
365 |
366 | for idx, e in enumerate(edges):
367 | if e[0] == 0:
368 | make_topology(idx, 0)
369 |
370 | return parent, offset, names, edge2joint
371 |
372 |
373 | def calc_edge_mat(edges):
374 | edge_num = len(edges)
375 | # edge_mat[i][j] = distance between edge(i) and edge(j)
376 | edge_mat = [[100000] * edge_num for _ in range(edge_num)]
377 | for i in range(edge_num):
378 | edge_mat[i][i] = 0
379 |
380 | # initialize edge_mat with direct neighbor
381 | for i, a in enumerate(edges):
382 | for j, b in enumerate(edges):
383 | link = 0
384 | for x in range(2):
385 | for y in range(2):
386 | if a[x] == b[y]:
387 | link = 1
388 | if link:
389 | edge_mat[i][j] = 1
390 |
391 | # calculate all the pairs distance
392 | for k in range(edge_num):
393 | for i in range(edge_num):
394 | for j in range(edge_num):
395 | edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j])
396 | return edge_mat
397 |
398 |
399 | def find_neighbor(edges, d):
400 | """
401 | Args:
402 | edges: The list contains N elements, each element represents (parent, child).
403 | d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1).
404 |
405 | Returns:
406 | The list contains N elements, each element is a list of edge indices whose distance <= d.
407 | """
408 | edge_mat = calc_edge_mat(edges)
409 | neighbor_list = []
410 | edge_num = len(edge_mat)
411 | for i in range(edge_num):
412 | neighbor = []
413 | for j in range(edge_num):
414 | if edge_mat[i][j] <= d:
415 | neighbor.append(j)
416 | neighbor_list.append(neighbor)
417 |
418 | # # add neighbor for global part
419 | # global_part_neighbor = neighbor_list[0].copy()
420 | # """
421 | # Line #373 is buggy. Thanks @crissallan!!
422 | # See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30)
423 | # However, fixing this bug will make it unable to load the pretrained model and
424 | # affect the reproducibility of quantitative error reported in the paper.
425 | # It is not a fatal bug so we didn't touch it and we are looking for possible solutions.
426 | # """
427 | # for i in global_part_neighbor:
428 | # neighbor_list[i].append(edge_num)
429 | # neighbor_list.append(global_part_neighbor)
430 |
431 | return neighbor_list
432 |
433 |
434 | def calc_node_depth(topology):
435 | def dfs(node, topology):
436 | if topology[node] < 0:
437 | return 0
438 | return 1 + dfs(topology[node], topology)
439 | depth = []
440 | for i in range(len(topology)):
441 | depth.append(dfs(i, topology))
442 |
443 | return depth
444 |
--------------------------------------------------------------------------------
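
A tiny worked example for the helpers above (assumes they are in scope): a 5-joint skeleton where joint 1 branches. build_edge_topology prepends the virtual root edge, and find_neighbor lists, for each edge, every edge within the given distance.

parents = [0, 0, 1, 2, 1]         # parent of each joint; chain 0-1-2-3 with branch 1-4
edges = build_edge_topology(parents)
print(edges)                       # [(0, 5), (0, 1), (1, 2), (2, 3), (1, 4)]
print(find_neighbor(edges, d=1))   # e.g. edge (1, 2) neighbours (0, 1), (2, 3), (1, 4) and itself
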
/src/soft_dtw_cuda.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 Mehran Maghoumi
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | # ----------------------------------------------------------------------------------------------------------------------
23 |
24 | import numpy as np
25 | import torch
26 | import torch.cuda
27 | from numba import jit
28 | from torch.autograd import Function
29 | from numba import cuda
30 | import math
31 |
32 | # ----------------------------------------------------------------------------------------------------------------------
33 | @cuda.jit
34 | def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R):
35 | """
36 | :param max_i, max_j: The lengths of the two input sequences (both inputs are assumed to be of the same size)
37 | :param n_passes: 2 * seq_len - 1 (The number of anti-diagonals)
38 | """
39 | # Each block processes one pair of examples
40 | b = cuda.blockIdx.x
41 | # We have as many threads as seq_len, because the most number of threads we need
42 | # is equal to the number of elements on the largest anti-diagonal
43 | tid = cuda.threadIdx.x
44 |
45 | # Compute I, J, the indices from [0, seq_len)
46 |
47 | # The row index is always the same as tid
48 | I = tid
49 |
50 | inv_gamma = 1.0 / gamma
51 |
52 | # Go over each anti-diagonal. Only process threads that fall on the current anti-diagonal
53 | for p in range(n_passes):
54 |
55 | # The index is actually 'p - tid' but need to force it in-bounds
56 | J = max(0, min(p - tid, max_j - 1))
57 |
58 | # For simplicity, we define i, j which start from 1 (offset from I, J)
59 | i = I + 1
60 | j = J + 1
61 |
62 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
63 | if I + J == p and (I < max_i and J < max_j):
64 | # Don't compute if outside bandwidth
65 | if not (abs(i - j) > bandwidth > 0):
66 | r0 = -R[b, i - 1, j - 1] * inv_gamma
67 | r1 = -R[b, i - 1, j] * inv_gamma
68 | r2 = -R[b, i, j - 1] * inv_gamma
69 | rmax = max(max(r0, r1), r2)
70 | rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax)
71 | softmin = -gamma * (math.log(rsum) + rmax)
72 | R[b, i, j] = D[b, i - 1, j - 1] + softmin
73 |
74 | # Wait for other threads in this block
75 | cuda.syncthreads()
76 |
77 | # ----------------------------------------------------------------------------------------------------------------------
78 | @cuda.jit
79 | def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E):
80 | k = cuda.blockIdx.x
81 | tid = cuda.threadIdx.x
82 |
83 | # Indexing logic is the same as above, however, the anti-diagonal needs to
84 | # progress backwards
85 | I = tid
86 |
87 | for p in range(n_passes):
88 | # Reverse the order to make the loop go backward
89 | rev_p = n_passes - p - 1
90 |
91 | # convert tid to I, J, then i, j
92 | J = max(0, min(rev_p - tid, max_j - 1))
93 |
94 | i = I + 1
95 | j = J + 1
96 |
97 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
98 | if I + J == rev_p and (I < max_i and J < max_j):
99 |
100 | if math.isinf(R[k, i, j]):
101 | R[k, i, j] = -math.inf
102 |
103 | # Don't compute if outside bandwidth
104 | if not (abs(i - j) > bandwidth > 0):
105 | a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma)
106 | b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma)
107 | c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma)
108 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
109 |
110 | # Wait for other threads in this block
111 | cuda.syncthreads()
112 |
113 | # ----------------------------------------------------------------------------------------------------------------------
114 | class _SoftDTWCUDA(Function):
115 | """
116 | CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444:
117 | "Developing a pattern discovery method in time series data and its GPU acceleration"
118 | """
119 |
120 | @staticmethod
121 | def forward(ctx, D, gamma, bandwidth):
122 | dev = D.device
123 | dtype = D.dtype
124 | gamma = torch.cuda.FloatTensor([gamma])
125 | bandwidth = torch.cuda.FloatTensor([bandwidth])
126 |
127 | B = D.shape[0]
128 | N = D.shape[1]
129 | M = D.shape[2]
130 | threads_per_block = max(N, M)
131 | n_passes = 2 * threads_per_block - 1
132 |
133 | # Prepare the output array
134 | R = torch.ones((B, N + 2, M + 2), device=dev, dtype=dtype) * math.inf
135 | R[:, 0, 0] = 0
136 |
137 | # Run the CUDA kernel.
138 | # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair)
139 | # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal)
140 | compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()),
141 | gamma.item(), bandwidth.item(), N, M, n_passes,
142 | cuda.as_cuda_array(R))
143 | ctx.save_for_backward(D, R.clone(), gamma, bandwidth)
144 | return R[:, -2, -2]
145 |
146 | @staticmethod
147 | def backward(ctx, grad_output):
148 | dev = grad_output.device
149 | dtype = grad_output.dtype
150 | D, R, gamma, bandwidth = ctx.saved_tensors
151 |
152 | B = D.shape[0]
153 | N = D.shape[1]
154 | M = D.shape[2]
155 | threads_per_block = max(N, M)
156 | n_passes = 2 * threads_per_block - 1
157 |
158 | D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
159 | D_[:, 1:N + 1, 1:M + 1] = D
160 |
161 | R[:, :, -1] = -math.inf
162 | R[:, -1, :] = -math.inf
163 | R[:, -1, -1] = R[:, -2, -2]
164 |
165 | E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
166 | E[:, -1, -1] = 1
167 |
168 | # Grid and block sizes are set same as done above for the forward() call
169 | compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_),
170 | cuda.as_cuda_array(R),
171 | 1.0 / gamma.item(), bandwidth.item(), N, M, n_passes,
172 | cuda.as_cuda_array(E))
173 | E = E[:, 1:N + 1, 1:M + 1]
174 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None
175 |
176 |
177 | # ----------------------------------------------------------------------------------------------------------------------
178 | #
179 | # The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
180 | # Credit goes to Kanru Hua.
181 | # I've added support for batching and pruning.
182 | #
183 | # ----------------------------------------------------------------------------------------------------------------------
184 | @jit(nopython=True)
185 | def compute_softdtw(D, gamma, bandwidth):
186 | B = D.shape[0]
187 | N = D.shape[1]
188 | M = D.shape[2]
189 | R = np.ones((B, N + 2, M + 2)) * np.inf
190 | R[:, 0, 0] = 0
191 | for b in range(B):
192 | for j in range(1, M + 1):
193 | for i in range(1, N + 1):
194 |
195 | # Check the pruning condition
196 | if 0 < bandwidth < np.abs(i - j):
197 | continue
198 |
199 | r0 = -R[b, i - 1, j - 1] / gamma
200 | r1 = -R[b, i - 1, j] / gamma
201 | r2 = -R[b, i, j - 1] / gamma
202 | rmax = max(max(r0, r1), r2)
203 | rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
204 | softmin = - gamma * (np.log(rsum) + rmax)
205 | R[b, i, j] = D[b, i - 1, j - 1] + softmin
206 | return R
207 |
208 | # ----------------------------------------------------------------------------------------------------------------------
209 | @jit(nopython=True)
210 | def compute_softdtw_backward(D_, R, gamma, bandwidth):
211 | B = D_.shape[0]
212 | N = D_.shape[1]
213 | M = D_.shape[2]
214 | D = np.zeros((B, N + 2, M + 2))
215 | E = np.zeros((B, N + 2, M + 2))
216 | D[:, 1:N + 1, 1:M + 1] = D_
217 | E[:, -1, -1] = 1
218 | R[:, :, -1] = -np.inf
219 | R[:, -1, :] = -np.inf
220 | R[:, -1, -1] = R[:, -2, -2]
221 | for k in range(B):
222 | for j in range(M, 0, -1):
223 | for i in range(N, 0, -1):
224 |
225 | if np.isinf(R[k, i, j]):
226 | R[k, i, j] = -np.inf
227 |
228 | # Check the pruning condition
229 | if 0 < bandwidth < np.abs(i - j):
230 | continue
231 |
232 | a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma
233 | b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma
234 | c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma
235 | a = np.exp(a0)
236 | b = np.exp(b0)
237 | c = np.exp(c0)
238 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
239 | return E[:, 1:N + 1, 1:M + 1]
240 |
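In equation form, the loop above runs the soft-DTW gradient recursion from the bottom-right corner; the cropped `E` it returns is the gradient of the soft-DTW value with respect to the cost matrix `D`:

```latex
E_{i,j} = a\,E_{i+1,j} + b\,E_{i,j+1} + c\,E_{i+1,j+1},
\qquad
\begin{aligned}
a &= e^{(R_{i+1,j} - R_{i,j} - D_{i+1,j})/\gamma},\\
b &= e^{(R_{i,j+1} - R_{i,j} - D_{i,j+1})/\gamma},\\
c &= e^{(R_{i+1,j+1} - R_{i,j} - D_{i+1,j+1})/\gamma}
\end{aligned}
```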
241 | # ----------------------------------------------------------------------------------------------------------------------
242 | class _SoftDTW(Function):
243 | """
244 | CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
245 | """
246 |
247 | @staticmethod
248 | def forward(ctx, D, gamma, bandwidth):
249 | dev = D.device
250 | dtype = D.dtype
251 |         gamma = torch.Tensor([gamma]).to(dev).type(dtype)  # wrap as a tensor matching D's device and dtype
252 | bandwidth = torch.Tensor([bandwidth]).to(dev).type(dtype)
253 | D_ = D.detach().cpu().numpy()
254 | g_ = gamma.item()
255 | b_ = bandwidth.item()
256 | R = torch.Tensor(compute_softdtw(D_, g_, b_)).to(dev).type(dtype)
257 | ctx.save_for_backward(D, R, gamma, bandwidth)
258 | return R[:, -2, -2]
259 |
260 | @staticmethod
261 | def backward(ctx, grad_output):
262 | dev = grad_output.device
263 | dtype = grad_output.dtype
264 | D, R, gamma, bandwidth = ctx.saved_tensors
265 | D_ = D.detach().cpu().numpy()
266 | R_ = R.detach().cpu().numpy()
267 | g_ = gamma.item()
268 | b_ = bandwidth.item()
269 | E = torch.Tensor(compute_softdtw_backward(D_, R_, g_, b_)).to(dev).type(dtype)
270 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None
271 |
272 | # ----------------------------------------------------------------------------------------------------------------------
273 | class SoftDTW(torch.nn.Module):
274 | """
275 | The soft DTW implementation that optionally supports CUDA
276 | """
277 |
278 | def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
279 | """
280 | Initializes a new instance using the supplied parameters
281 | :param use_cuda: Flag indicating whether the CUDA implementation should be used
282 | :param gamma: sDTW's gamma parameter
283 | :param normalize: Flag indicating whether to perform normalization
284 | (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790)
285 | :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning.
286 | :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used.
287 | """
288 | super(SoftDTW, self).__init__()
289 | self.normalize = normalize
290 | self.gamma = gamma
291 | self.bandwidth = 0 if bandwidth is None else float(bandwidth)
292 | self.use_cuda = use_cuda
293 |
294 | # Set the distance function
295 | if dist_func is not None:
296 | self.dist_func = dist_func
297 | else:
298 | self.dist_func = SoftDTW._euclidean_dist_func
299 |
300 | def _get_func_dtw(self, x, y):
301 | """
302 | Checks the inputs and selects the proper implementation to use.
303 | """
304 | bx, lx, dx = x.shape
305 | by, ly, dy = y.shape
306 | # Make sure the dimensions match
307 | assert bx == by # Equal batch sizes
308 | assert dx == dy # Equal feature dimensions
309 |
310 | use_cuda = self.use_cuda
311 |
312 | if use_cuda and (lx > 1024 or ly > 1024): # We should be able to spawn enough threads in CUDA
313 | print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)")
314 | use_cuda = False
315 |
316 | # Finally, return the correct function
317 | return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply
318 |
319 | @staticmethod
320 | def _euclidean_dist_func(x, y):
321 | """
322 |         Calculates the pairwise squared Euclidean distance between the timesteps of x and y
323 | """
324 | n = x.size(1)
325 | m = y.size(1)
326 | d = x.size(2)
327 | x = x.unsqueeze(2).expand(-1, n, m, d)
328 | y = y.unsqueeze(1).expand(-1, n, m, d)
329 | return torch.pow(x - y, 2).sum(3)
330 |
331 | def forward(self, X, Y):
332 | """
333 | Compute the soft-DTW value between X and Y
334 | :param X: One batch of examples, batch_size x seq_len x dims
335 | :param Y: The other batch of examples, batch_size x seq_len x dims
336 | :return: The computed results
337 | """
338 |
339 | # Check the inputs and get the correct implementation
340 | func_dtw = self._get_func_dtw(X, Y)
341 |
342 | if self.normalize:
343 | # Stack everything up and run
344 | x = torch.cat([X, X, Y])
345 | y = torch.cat([Y, X, Y])
346 | D = self.dist_func(x, y)
347 | out = func_dtw(D, self.gamma, self.bandwidth)
348 | out_xy, out_xx, out_yy = torch.split(out, X.shape[0])
349 | return out_xy - 1 / 2 * (out_xx + out_yy)
350 | else:
351 | D_xy = self.dist_func(X, Y)
352 | return func_dtw(D_xy, self.gamma, self.bandwidth)
353 |
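A minimal usage sketch for the module above (batch sizes, lengths, and `gamma` are illustrative, not values from this repository's configs). With `normalize=False` the two batches may differ in sequence length; with `normalize=True`, `forward()` concatenates `X` and `Y` along the batch dimension, so the lengths must match, and identical inputs then score exactly zero:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
sdtw = SoftDTW(use_cuda=(device == 'cuda'), gamma=0.1, normalize=False)

x = torch.rand(8, 96, 15, device=device, requires_grad=True)  # (batch, seq_len, dims)
y = torch.rand(8, 128, 15, device=device)                     # lengths may differ per batch

loss = sdtw(x, y).mean()  # forward() returns one soft-DTW value per sequence pair
loss.backward()           # differentiable, so it can be used directly as a training loss
```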
354 | # ----------------------------------------------------------------------------------------------------------------------
355 | def timed_run(a, b, sdtw):
356 | """
357 | Runs a and b through sdtw, and times the forward and backward passes.
358 | Assumes that a requires gradients.
359 | :return: timing, forward result, backward result
360 | """
361 | from timeit import default_timer as timer
362 |
363 | # Forward pass
364 | start = timer()
365 | forward = sdtw(a, b)
366 | end = timer()
367 | t = end - start
368 |
369 | grad_outputs = torch.ones_like(forward)
370 |
371 | # Backward
372 | start = timer()
373 | grads = torch.autograd.grad(forward, a, grad_outputs=grad_outputs)[0]
374 | end = timer()
375 |
376 | # Total time
377 | t += end - start
378 |
379 | return t, forward, grads
380 |
381 | # ----------------------------------------------------------------------------------------------------------------------
382 | def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward):
383 | sdtw = SoftDTW(False, gamma=1.0, normalize=False)
384 | sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False)
385 | n_iters = 6
386 |
387 | print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims))
388 |
389 | times_cpu = []
390 | times_gpu = []
391 |
392 | for i in range(n_iters):
393 | a_cpu = torch.rand((batch_size, seq_len_a, dims), requires_grad=True)
394 | b_cpu = torch.rand((batch_size, seq_len_b, dims))
395 | a_gpu = a_cpu.cuda()
396 | b_gpu = b_cpu.cuda()
397 |
398 | # GPU
399 | t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda)
400 |
401 | # CPU
402 | t_cpu, forward_cpu, backward_cpu = timed_run(a_cpu, b_cpu, sdtw)
403 |
404 | # Verify the results
405 | assert torch.allclose(forward_cpu, forward_gpu.cpu())
406 | assert torch.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward)
407 |
408 |         if i > 0:  # skip the first iteration: cold-start timings (JIT compilation, CUDA init) are not representative
409 | times_cpu += [t_cpu]
410 | times_gpu += [t_gpu]
411 |
412 | # Average and log
413 | avg_cpu = np.mean(times_cpu)
414 | avg_gpu = np.mean(times_gpu)
415 | print("\tCPU: ", avg_cpu)
416 | print("\tGPU: ", avg_gpu)
417 | print("\tSpeedup: ", avg_cpu / avg_gpu)
418 | print()
419 |
420 | # ----------------------------------------------------------------------------------------------------------------------
421 | if __name__ == "__main__":
422 | from timeit import default_timer as timer
423 |
424 | torch.manual_seed(1234)
425 |
426 | profile(128, 17, 15, 2, tol_backward=1e-6)
427 | profile(512, 64, 64, 2, tol_backward=1e-4)
428 | profile(512, 256, 256, 2, tol_backward=1e-3)
429 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from torch.utils.data.dataloader import DataLoader
11 | from tqdm import tqdm
12 |
13 | from datasets.amass import AMASS
14 | from arguments import Arguments
15 | from nemf.generative import Architecture
16 |
17 |
18 | def train():
19 | model = Architecture(args, ngpu, len(train_data_loader))
20 | model.print()
21 | model.setup()
22 |
23 | loss_min = None
24 | if args.epoch_begin != 0:
25 | model.load(epoch=args.epoch_begin)
26 | model.eval()
27 | for data in valid_data_loader:
28 | model.set_input(data)
29 | model.validate()
30 | loss_min = model.verbose()
31 |
32 | epoch_begin = args.epoch_begin + 1
33 | epoch_end = epoch_begin + args.epoch_num
34 | start_time = time.time()
35 |
36 | for epoch in range(epoch_begin, epoch_end):
37 | model.train()
38 | with tqdm(train_data_loader, unit="batch") as tepoch:
39 | for data in tepoch:
40 | tepoch.set_description(f"Epoch {epoch}")
41 | model.set_input(data)
42 | model.optimize_parameters()
43 |
44 | model.eval()
45 | for data in valid_data_loader:
46 | model.set_input(data)
47 | model.validate()
48 |
49 | model.epoch()
50 | res = model.verbose()
51 |
52 | if args.verbose:
53 | print(f'Epoch {epoch}/{epoch_end - 1}:')
54 | print(json.dumps(res, sort_keys=True, indent=4))
55 |
56 | if loss_min is None or res['total_loss']['val'] < loss_min['total_loss']['val']:
57 | loss_min = res
58 | model.save(optimal=True)
59 |
60 | if epoch % args.checkpoint == 0 or epoch == epoch_end - 1:
61 | model.save()
62 |
63 | end_time = time.time()
64 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}')
65 | print('Final Loss:')
66 | print(json.dumps(loss_min, sort_keys=True, indent=4))
67 | df = pd.DataFrame.from_dict(loss_min)
68 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}.csv'), index=False)
69 |
70 |
71 | def test(steps):
72 | model = Architecture(args, ngpu, len(test_data_loader))
73 | model.load(optimal=True)
74 | model.eval()
75 |
76 | for step in steps:
77 | statistics = dict()
78 | model.test_index = 0
79 | for data in test_data_loader:
80 | model.set_input(data)
81 | model.super_sampling(step=step)
82 | if step == 1.0:
83 | errors = model.report_errors()
84 | if not statistics:
85 |                     statistics = {
86 |                         'rotation_error': [errors['rotation'] * 180.0 / np.pi],  # radians -> degrees
87 |                         'position_error': [errors['position'] * 100.0],  # meters -> centimeters
88 |                         'orientation_error': [errors['orientation'] * 180.0 / np.pi],  # radians -> degrees
89 |                         'translation_error': [errors['translation'] * 100.0]  # meters -> centimeters
90 |                     }
91 | else:
92 | statistics['rotation_error'].append(errors['rotation'] * 180.0 / np.pi)
93 | statistics['position_error'].append(errors['position'] * 100.0)
94 | statistics['orientation_error'].append(errors['orientation'] * 180.0 / np.pi)
95 | statistics['translation_error'].append(errors['translation'] * 100.0)
96 |
97 | if step == 1.0:
98 | df = pd.DataFrame.from_dict(statistics)
99 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}_test.csv'), index=False)
100 |
101 |
102 | if __name__ == '__main__':
103 | if len(sys.argv) == 1:
104 | args = Arguments('./configs', filename='generative.yaml')
105 | else:
106 | args = Arguments('./configs', filename=sys.argv[1])
107 | print(json.dumps(args.json, sort_keys=True, indent=4))
108 |
109 | torch.set_default_dtype(torch.float32)
110 |
111 | torch.manual_seed(0)
112 | np.random.seed(0)
113 | random.seed(0)
114 |
115 | ngpu = 1
116 | if args.multi_gpu is True:
117 | ngpu = torch.cuda.device_count()
118 | if ngpu == 1:
119 | args.multi_gpu = False
120 | print(f'Number of GPUs: {ngpu}')
121 |
122 | # dataset definition
123 | train_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'train'))
124 | train_data_loader = DataLoader(train_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True, drop_last=True)
125 | valid_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'valid'))
126 | valid_data_loader = DataLoader(valid_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True)
127 | test_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'test'))
128 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True)
129 |
130 | if args.is_train:
131 | train()
132 |
133 | args.is_train = False
134 | test(steps=[0.5, 1.0])
135 |
--------------------------------------------------------------------------------
/src/train_basic.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 |
11 | from arguments import Arguments
12 | from nemf.basic import Architecture
13 |
14 |
15 | def load_data(data_path, index, frames):
16 | data = {}
17 | for key in ['pos', 'global_xform', 'root_orient', 'root_vel', 'trans']:
18 | data[key] = torch.load(os.path.join(data_path, f'{key}_{index}.pt'))
19 | if frames != -1:
20 |             data[key] = data[key][350:frames + 350].unsqueeze(0)  # fixed-length window starting at frame 350
21 | else:
22 | data[key] = data[key].unsqueeze(0)
23 |
24 | return data
25 |
26 |
27 | def train_basic(frames, save_dir, steps):
28 | data_path = os.path.join(args.dataset_dir, 'train')
29 | file_indices = list(range(16)) if args.amass_data else [22]
30 |
31 | statistics = dict()
32 | for index in file_indices:
33 | print(f'Fitting sequence {index} with {frames} frames:')
34 | args.save_dir = os.path.join(save_dir, f'frame_{frames}', f'sequence_{index}')
35 | model = Architecture(args, ngpu)
36 | model.setup()
37 | model.print()
38 | print(f'# of parameters: {model.count_params()}')
39 | model.set_input(load_data(data_path, index, frames))
40 |
41 | start_time = time.time()
42 |
43 | iterations = args.iterations * (frames // 32) if args.amass_data else args.iterations
44 | if args.is_train:
45 | model.train()
46 | loss_min = None
47 |             for it in range(iterations):  # avoid shadowing the builtin iter()
48 | model.optimize_parameters()
49 |
50 | model.epoch()
51 | res = model.verbose()
52 |
53 | if args.verbose:
54 |                     print(f'Iteration {it}/{iterations}:')
55 | print(json.dumps(res, sort_keys=True, indent=4))
56 |
57 | if loss_min is None or res['total_loss']['train'] < loss_min['total_loss']['train']:
58 | loss_min = res
59 | model.save(optimal=True)
60 |
61 | end_time = time.time()
62 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}')
63 | print('Final Loss:')
64 | print(json.dumps(loss_min, sort_keys=True, indent=4))
65 |
66 | model.eval()
67 | model.load(optimal=True)
68 |
69 | for step in steps:
70 | model.super_sampling(step)
71 |
72 | if step == 1.0:
73 | errors = model.report_errors()
74 | if not statistics:
75 | statistics = {
76 | 'iterations': [iterations],
77 |                         'rotation_error': [errors['rotation'] * 180.0 / np.pi],  # radians -> degrees
78 |                         'position_error': [errors['position'] * 100.0],  # meters -> centimeters
79 |                         'orientation_error': [errors['orientation'] * 180.0 / np.pi],  # radians -> degrees
80 |                         'translation_error': [errors['translation'] * 100.0]  # meters -> centimeters
81 | }
82 | else:
83 | statistics['iterations'].append(iterations)
84 | statistics['rotation_error'].append(errors['rotation'] * 180.0 / np.pi)
85 | statistics['position_error'].append(errors['position'] * 100.0)
86 | statistics['orientation_error'].append(errors['orientation'] * 180.0 / np.pi)
87 | statistics['translation_error'].append(errors['translation'] * 100.0)
88 |
89 | if statistics:
90 | df = pd.DataFrame.from_dict(statistics)
91 | df.to_csv(os.path.join(save_dir, f'recon_{frames}.csv'), index=False)
92 |
93 |
94 | if __name__ == '__main__':
95 | if len(sys.argv) == 1:
96 | args = Arguments('./configs', filename='basic.yaml')
97 | else:
98 | args = Arguments('./configs', filename=sys.argv[1])
99 | print(json.dumps(args.json, sort_keys=True, indent=4))
100 |
101 | torch.set_default_dtype(torch.float32)
102 |
103 | torch.manual_seed(0)
104 | np.random.seed(0)
105 | random.seed(0)
106 |
107 | ngpu = 1
108 | if args.multi_gpu is True:
109 | ngpu = torch.cuda.device_count()
110 | if ngpu == 1:
111 | args.multi_gpu = False
112 | print(f'Number of GPUs: {ngpu}')
113 |
114 | save_dir = args.save_dir
115 | frames = [32, 64, 128, 256, 512] if args.amass_data else [-1]
116 | steps = [1.0, 0.5, 0.25, 0.125] if args.amass_data else [1.0]
117 | for f in frames:
118 | train_basic(frames=f, save_dir=save_dir, steps=steps)
119 |
--------------------------------------------------------------------------------
/src/train_gmp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from torch.utils.data.dataloader import DataLoader
11 | from tqdm import tqdm
12 |
13 | from arguments import Arguments
14 | from datasets.amass import AMASS
15 | from nemf.global_motion import GlobalMotionPredictor
16 |
17 |
18 | def train():
19 | model = GlobalMotionPredictor(args, ngpu)
20 | model.print()
21 | model.setup()
22 |
23 | loss_min = None
24 | if args.epoch_begin != 0:
25 | model.load(epoch=args.epoch_begin)
26 | model.eval()
27 | for data in valid_data_loader:
28 | model.set_input(data)
29 | model.validate()
30 | loss_min = model.verbose()
31 |
32 | epoch_begin = args.epoch_begin + 1
33 | epoch_end = epoch_begin + args.epoch_num
34 | start_time = time.time()
35 |
36 | for epoch in range(epoch_begin, epoch_end):
37 | model.train()
38 | with tqdm(train_data_loader, unit="batch") as tepoch:
39 | for data in tepoch:
40 | tepoch.set_description(f"Epoch {epoch}")
41 | model.set_input(data)
42 | model.optimize_parameters()
43 |
44 | model.eval()
45 | for data in valid_data_loader:
46 | model.set_input(data)
47 | model.validate()
48 |
49 | model.epoch()
50 | res = model.verbose()
51 |
52 | if args.verbose:
53 | print(f'Epoch {epoch}/{epoch_end - 1}:')
54 | print(json.dumps(res, sort_keys=True, indent=4))
55 |
56 | if loss_min is None or res['total_loss']['val'] < loss_min['total_loss']['val']:
57 | loss_min = res
58 | model.save(optimal=True)
59 |
60 | if epoch % args.checkpoint == 0 or epoch == epoch_end - 1:
61 | model.save()
62 |
63 | end_time = time.time()
64 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}')
65 | print('Final Loss:')
66 | print(json.dumps(loss_min, sort_keys=True, indent=4))
67 | df = pd.DataFrame.from_dict(loss_min)
68 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}.csv'), index=False)
69 |
70 |
71 | def test():
72 | model = GlobalMotionPredictor(args, ngpu)
73 | model.load(optimal=True)
74 | model.eval()
75 |
76 | statistics = dict()
77 | for data in test_data_loader:
78 | model.set_input(data)
79 | model.test()
80 |
81 | errors = model.report_errors()
82 | if not statistics:
83 | statistics = {
84 |                 'translation_error': [errors['translation'] * 100.0]  # meters -> centimeters
85 | }
86 | else:
87 | statistics['translation_error'].append(errors['translation'] * 100.0)
88 |
89 | df = pd.DataFrame.from_dict(statistics)
90 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}_test.csv'), index=False)
91 |
92 |
93 | if __name__ == '__main__':
94 | if len(sys.argv) == 1:
95 | args = Arguments('./configs', filename='gmp.yaml')
96 | else:
97 | args = Arguments('./configs', filename=sys.argv[1])
98 | print(json.dumps(args.json, sort_keys=True, indent=4))
99 |
100 | torch.set_default_dtype(torch.float32)
101 |
102 | torch.manual_seed(0)
103 | np.random.seed(0)
104 | random.seed(0)
105 |
106 | ngpu = 1
107 | if args.multi_gpu is True:
108 | ngpu = torch.cuda.device_count()
109 | if ngpu == 1:
110 | args.multi_gpu = False
111 | print(f'Number of GPUs: {ngpu}')
112 |
113 | # dataset definition
114 | train_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'train'))
115 | train_data_loader = DataLoader(train_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True)
116 | valid_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'valid'))
117 | valid_data_loader = DataLoader(valid_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True)
118 | test_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'test'))
119 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True)
120 |
121 | if args.is_train:
122 | train()
123 |
124 | args.is_train = False
125 | test()
126 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import plyfile
5 | import torch
6 | import torch.nn.functional as F
7 | from human_body_prior.tools.omni_tools import copy2cpu as c2c
8 | from scipy.spatial.transform import Rotation as R
9 | from scipy.spatial.transform import Slerp
10 |
11 | import holden.BVH as BVH
12 | from holden.Animation import Animation
13 | from holden.Quaternions import Quaternions
14 |
15 | SMPL_JOINTS = {
16 | 'Pelvis': 0,
17 | 'L_Hip': 1, 'L_Knee': 4, 'L_Ankle': 7, 'L_Foot': 10,
18 | 'R_Hip': 2, 'R_Knee': 5, 'R_Ankle': 8, 'R_Foot': 11,
19 | 'Spine1': 3, 'Spine2': 6, 'Spine3': 9, 'Neck': 12, 'Head': 15,
20 | 'L_Collar': 13, 'L_Shoulder': 16, 'L_Elbow': 18, 'L_Wrist': 20, 'L_Hand': 22,
21 | 'R_Collar': 14, 'R_Shoulder': 17, 'R_Elbow': 19, 'R_Wrist': 21, 'R_Hand': 23
22 | }
23 |
24 | FOOT_IDX = [SMPL_JOINTS['L_Ankle'], SMPL_JOINTS['R_Ankle'], SMPL_JOINTS['L_Foot'], SMPL_JOINTS['R_Foot']]
25 |
26 | CONTACTS_IDX = [SMPL_JOINTS['L_Ankle'], SMPL_JOINTS['R_Ankle'],
27 | SMPL_JOINTS['L_Foot'], SMPL_JOINTS['R_Foot'],
28 | SMPL_JOINTS['L_Wrist'], SMPL_JOINTS['R_Wrist'],
29 | SMPL_JOINTS['L_Knee'], SMPL_JOINTS['R_Knee']]
30 |
31 |
32 | def align_joints(tensor, smpl_to_bvh=True):
33 | """
34 | Args:
35 |         tensor (T x J x D): the 3D torch tensor to reorder.
36 |         smpl_to_bvh (bool): if True, map from SMPL joint order to BVH order; if False, invert the mapping.
37 |     Returns:
38 |         The tensor with its joint axis reordered (BVH-compatible when smpl_to_bvh is True).
39 | """
40 | order = list(SMPL_JOINTS.values())
41 | result = tensor.clone() if isinstance(tensor, torch.Tensor) else tensor.copy()
42 | for i in range(len(order)):
43 | if smpl_to_bvh:
44 | result[:, i] = tensor[:, order[i]]
45 | else:
46 | result[:, order[i]] = tensor[:, i]
47 |
48 | return result
49 |
50 |
51 | def slerp(quat, trans, key_times, times, mask=True):
52 | """
53 | Args:
54 |         quat (T x J x 4): keyframe rotations; trans (T x 3): keyframe root translations
55 |         key_times / times: keyframe times and query times; mask: if True, index quat and trans by key_times first
56 | """
57 | if mask:
58 | quat = c2c(quat[key_times])
59 | trans = c2c(trans[key_times])
60 | else:
61 | quat = c2c(quat)
62 | trans = c2c(trans)
63 |
64 | quats = []
65 | for j in range(quat.shape[1]):
66 | key_rots = R.from_quat(quat[:, j])
67 | s = Slerp(key_times, key_rots)
68 | interp_rots = s(times)
69 | quats.append(interp_rots.as_quat())
70 | slerp_quat = np.stack(quats, axis=1)
71 |
72 | lerp_trans = np.zeros((len(times), 3))
73 | for i in range(3):
74 | lerp_trans[:, i] = np.interp(times, key_times, trans[:, i])
75 |
76 | return slerp_quat, lerp_trans
77 |
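A hedged usage sketch for `slerp` with synthetic placeholder data (the identity rotations and random translations below are not from this repository): up-sample a 30-frame clip to roughly double the framerate by querying half-frame times.

```python
import numpy as np
import torch

T, J = 30, 24
quat = torch.zeros(T, J, 4)
quat[..., 3] = 1.0                       # identity rotations in scipy's (x, y, z, w) order
trans = torch.rand(T, 3)

key_times = np.arange(T)                 # treat every input frame as a keyframe
times = np.arange(0, T - 1 + 1e-9, 0.5)  # half-frame queries between the keyframes
new_quat, new_trans = slerp(quat, trans, key_times, times)
print(new_quat.shape, new_trans.shape)   # (59, 24, 4) and (59, 3)
```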
78 |
79 | def compute_orient_angle(matrix, traj):
80 | """
81 | Args:
82 | matrix (N x T x 3 x 3): The 3D rotation matrix at the root joint
83 | traj (N x T x 3): The trajectory to align
84 | """
85 | forward = matrix[:, :, :, 2].clone()
86 | forward[:, :, 2] = 0
87 | forward = F.normalize(forward, dim=-1) # normalized forward vector (N, T, 3)
88 |
89 | traj[:, :, 2] = 0 # make sure the trajectory is projected to the plane
90 |
91 |     # first step is forward diff
92 | init_tan = traj[:, 1:2] - traj[:, :1]
93 | # middle steps are second order
94 | middle_tan = (traj[:, 2:] - traj[:, 0:-2]) / 2
95 | # last step is backward diff
96 | final_tan = traj[:, -1:] - traj[:, -2:-1]
97 |
98 | tangent = torch.cat([init_tan, middle_tan, final_tan], dim=1)
99 | tangent = F.normalize(tangent, dim=-1) # normalized tangent vector (N, T, 3)
100 |
101 | cos = torch.sum(forward * tangent, dim=-1)
102 |
103 | return cos
104 |
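With both vectors projected onto the ground plane and unit-normalized, the returned value is the cosine of the angle between the root's facing direction and the trajectory's tangent, estimated with central differences in the interior and one-sided differences at the endpoints:

```latex
\cos\theta_t = \langle f_t,\, \tau_t \rangle,
\qquad
\tau_t \propto \tfrac{1}{2}\,(p_{t+1} - p_{t-1}) \quad (1 \le t \le T-2)
```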
105 |
106 | def compute_trajectory(velocity, up, origin, dt, up_axis='z'):
107 | """
108 | Args:
109 | velocity: (B, T, 3)
110 | up: (B, T)
111 | origin: (B, 3)
112 |         up_axis: 'x', 'y', or 'z'; dt: time step between consecutive frames
113 |
114 | Returns:
115 | trajectory: (B, T, 3)
116 | """
117 | ordermap = {
118 | 'x': 0,
119 | 'y': 1,
120 | 'z': 2,
121 | }
122 | v_axis = [x for x in ordermap.values() if x != ordermap[up_axis]]
123 |
124 | origin = origin.unsqueeze(1) # (B, 3) => (B, 1, 3)
125 | trajectory = origin.repeat(1, up.shape[1], 1) # (B, 1, 3) => (B, T, 3)
126 |
127 | for t in range(1, up.shape[1]):
128 | trajectory[:, t, v_axis[0]] = trajectory[:, t - 1, v_axis[0]] + velocity[:, t - 1, v_axis[0]] * dt
129 | trajectory[:, t, v_axis[1]] = trajectory[:, t - 1, v_axis[1]] + velocity[:, t - 1, v_axis[1]] * dt
130 |
131 | trajectory[:, :, ordermap[up_axis]] = up
132 |
133 | return trajectory
134 |
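The loop is forward-Euler integration of the two in-plane velocity components, while the up component is taken directly from the `up` input rather than integrated (written here for `up_axis='z'`):

```latex
p^{x}_{t} = p^{x}_{t-1} + v^{x}_{t-1}\,\Delta t,
\qquad
p^{y}_{t} = p^{y}_{t-1} + v^{y}_{t-1}\,\Delta t,
\qquad
p^{z}_{t} = \mathrm{up}_{t}
```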
135 |
136 | def build_canonical_frame(forward, up_axis='z'):
137 | """
138 | Args:
139 | forward: (..., 3)
140 |
141 | Returns:
142 | frame: (..., 3, 3)
143 | """
144 |     forward = forward.clone(); forward[..., 'xyz'.index(up_axis)] = 0  # copy first so the caller's tensor is not mutated, then zero the up component
145 |     forward = F.normalize(forward, dim=-1)  # normalized forward vector
146 |
147 | up = torch.zeros_like(forward)
148 | up[..., 'xyz'.index(up_axis)] = 1 # normalized up vector
149 |     right = torch.cross(up, forward, dim=-1)  # pass dim explicitly; relying on the default (first size-3 dim) is deprecated
150 | frame = torch.stack((right, up, forward), dim=-1) # canonical frame
151 |
152 | return frame
153 |
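Stacking along the last dimension makes the three unit vectors the columns of the returned matrix; since the up component of `forward` is zeroed before normalization, `up` and `forward` are orthogonal and the frame is orthonormal:

```latex
F = \begin{pmatrix} \vert & \vert & \vert \\ r & u & f \\ \vert & \vert & \vert \end{pmatrix},
\qquad
r = u \times f, \quad \lVert u \rVert = \lVert f \rVert = 1, \quad u \perp f
```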
154 |
155 | def estimate_linear_velocity(data_seq, dt):
156 | '''
157 | Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates
158 | the velocity for the middle T-2 steps using a second order central difference scheme.
159 |     The first and last frames are estimated with forward and backward first-order
160 |     differences, respectively.
161 |     - dt : the time step between consecutive frames
162 | '''
163 |     # first step is forward diff (t+1 - t) / dt
164 | init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt
165 | # middle steps are second order (t+1 - t-1) / 2dt
166 | middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt)
167 | # last step is backward diff (t - t-1) / dt
168 | final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt
169 |
170 | vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1)
171 | return vel_seq
172 |
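The three slices implement the standard finite-difference stencils:

```latex
v_0 = \frac{x_1 - x_0}{\Delta t},
\qquad
v_t = \frac{x_{t+1} - x_{t-1}}{2\,\Delta t} \quad (1 \le t \le T-2),
\qquad
v_{T-1} = \frac{x_{T-1} - x_{T-2}}{\Delta t}
```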
173 |
174 | def estimate_angular_velocity(rot_seq, dt):
175 | '''
176 | Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps.
177 | Input sequence should be of shape (B, T, ..., 3, 3)
178 | '''
179 | # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix
180 | dRdt = estimate_linear_velocity(rot_seq, dt)
181 | R = rot_seq
182 | RT = R.transpose(-1, -2)
183 | # compute skew-symmetric angular velocity tensor
184 | w_mat = torch.matmul(dRdt, RT)
185 | # pull out angular velocity vector by averaging symmetric entries
186 | w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0
187 | w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0
188 | w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0
189 |     w = torch.stack([w_x, w_y, w_z], dim=-1)
190 | return w
191 |
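This relies on the identity that for a rotation trajectory, the matrix `dR/dt * R^T` is the skew-symmetric angular-velocity matrix; because the finite-difference estimate is only approximately skew-symmetric, each component averages the two entries that should be equal and opposite (indices are 0-based, matching the code):

```latex
\dot{R}R^{\top} = [\omega]_{\times} =
\begin{pmatrix}
0 & -\omega_z & \omega_y \\
\omega_z & 0 & -\omega_x \\
-\omega_y & \omega_x & 0
\end{pmatrix},
\qquad
\omega_x = \tfrac{1}{2}(W_{21} - W_{12}),\;
\omega_y = \tfrac{1}{2}(W_{02} - W_{20}),\;
\omega_z = \tfrac{1}{2}(W_{10} - W_{01})
```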
192 |
193 | def normalize(tensor, mean=None, std=None):
194 | """
195 | Args:
196 | tensor: (B, T, ...)
197 |
198 | Returns:
199 |         the normalized tensor; if mean/std were not supplied, also returns the computed (mean, std)
200 | """
201 | if mean is None or std is None:
202 | # std, mean = torch.std_mean(tensor, dim=0, unbiased=False, keepdim=True)
203 | std, mean = torch.std_mean(tensor, dim=(0, 1), unbiased=False, keepdim=True)
204 | std[std == 0.0] = 1.0
205 |
206 | return (tensor - mean) / std, mean, std
207 |
208 | return (tensor - mean) / std
209 |
210 |
211 | def denormalize(tensor, mean, std):
212 | """
213 | Args:
214 | tensor: B x T x D
215 |         mean: mean used for normalization
216 |         std: standard deviation used for normalization
217 | """
218 | return tensor * std + mean
219 |
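A small usage sketch for the pair above (shapes are illustrative placeholders): compute the statistics once on training data via the three-value return, reuse them on other splits, and invert with `denormalize`.

```python
import torch

x_train = torch.rand(64, 128, 24)                   # (B, T, ...) placeholder features
x_valid = torch.rand(16, 128, 24)

x_train_n, mean, std = normalize(x_train)           # statistics over batch and time dims
x_valid_n = normalize(x_valid, mean=mean, std=std)  # reuse the training statistics
x_recon = denormalize(x_valid_n, mean, std)         # exact inverse of the normalization
assert torch.allclose(x_recon, x_valid, atol=1e-6)
```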
220 |
221 | def export_bvh_animation(rotations, positions, offsets, parents, output_dir, prefix, joint_names, fps):
222 | """
223 | Args:
224 | rotations: quaternions of the shape (B, T, J, 4)
225 | positions: global translations of the shape (B, T, 3)
226 | """
227 | os.makedirs(output_dir, exist_ok=True)
228 | for i in range(rotations.shape[0]):
229 | rotation = align_joints(rotations[i])
230 | position = positions[i]
231 | position = position.unsqueeze(1)
232 | anim = Animation(Quaternions(c2c(rotation)), c2c(position), None, offsets=offsets, parents=parents)
233 | BVH.save(os.path.join(output_dir, f'{prefix}_{i}.bvh'), anim, names=joint_names, frametime=1 / fps)
234 |
235 |
236 | def export_ply_trajectory(points, color, ply_fname):
237 | v = []
238 | for p in points:
239 | v += [(p[0], p[1], p[2], color[0], color[1], color[2])]
240 | v = np.array(v, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
241 | el_v = plyfile.PlyElement.describe(v, 'vertex')
242 |
243 | e = np.empty(len(points) - 1, dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
244 | edge_data = np.array([[i, i + 1] for i in range(len(points) - 1)], dtype='i4')
245 | e['vertex1'] = edge_data[:, 0]
246 | e['vertex2'] = edge_data[:, 1]
247 | el_e = plyfile.PlyElement.describe(e, 'edge')
248 |
249 | plyfile.PlyData([el_v, el_e], text=True).write(ply_fname)
250 |
--------------------------------------------------------------------------------