├── .gitignore ├── LICENSE ├── README.md ├── configs ├── amass.yaml ├── application.yaml ├── basic.yaml ├── basic_dog.yaml ├── dog_mocap.yaml ├── dog_skel.yaml ├── evaluation.yaml ├── generative.yaml ├── gmp.yaml └── smpl.yaml ├── data └── .gitkeep ├── outputs └── .gitkeep ├── overview.png ├── requirements.txt └── src ├── application.py ├── arguments.py ├── datasets ├── __init__.py ├── amass.py └── animal.py ├── evaluation.py ├── holden ├── Animation.py ├── AnimationStructure.py ├── BVH.py ├── InverseKinematics.py ├── Pivots.py ├── Quaternions.py └── __init__.py ├── nemf ├── __init__.py ├── base_model.py ├── basic.py ├── fk.py ├── generative.py ├── global_motion.py ├── loss_record.py ├── losses.py ├── neural_motion.py ├── prior.py ├── residual_blocks.py └── skeleton.py ├── rotations.py ├── soft_dtw_cuda.py ├── train.py ├── train_basic.py ├── train_gmp.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | data/* 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | *.ipynb 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | outputs/* 134 | !outputs/.gitkeep 135 | outputs_*/ 136 | config_v*.yml 137 | analysis/ 138 | analysis_*/ 139 | optimize/ 140 | optimize_*/ 141 | *.bak* 142 | *.out 143 | *.csv 144 | inbetween_benchmark/ 145 | comparison/ 146 | nemf_test/ 147 | motion_recon/ 148 | logs/ 149 | optim*/ 150 | evaluate*/ 151 | max_sneaks/ 152 | 153 | !data/.gitkeep 154 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Chengan He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeMF: Neural Motion Fields for Kinematic Animation (NeurIPS 2022) 2 | 3 | ### [Paper](https://arxiv.org/abs/2206.03287) | [Project](https://cs.yale.edu/homes/che/projects/nemf/) 4 | 5 | [Chengan He](https://cs.yale.edu/homes/che/)1, [Jun Saito](https://research.adobe.com/person/jun-saito/)2, [James Zachary](https://jameszachary.com/)2, [Holly Rushmeier](https://graphics.cs.yale.edu/people/holly-rushmeier)1, [Yi Zhou](https://zhouyisjtu.github.io/)2 6 | 7 | 1Yale University, 2Adobe Research 8 | 9 | ![NeMF Overview](overview.png) 10 | 11 | ## Prerequisites 12 | 13 | This code was developed on Ubuntu 20.04 with Python 3.9, CUDA 11.3 and PyTorch 1.9.0. 14 | 15 | ### Environment Setup 16 | 17 | To begin with, please set up the virtual environment with Anaconda: 18 | ```bash 19 | conda create -n nemf python=3.9 20 | conda activate nemf 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### Body Model 25 | 26 | Our code relies on [SMPL](https://smpl.is.tue.mpg.de/) as the body model. Please download it from Max Planck and accept the terms of license there. 
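The `smpl_body_model` entry in `configs/smpl.yaml` points to `./data/body_models/smpl` by default, so one option is to place the downloaded model files there. The commands below are only a sketch — the exact file names depend on the SMPL release you download, and you can keep the model anywhere as long as you update `smpl.yaml` accordingly:
```bash
# Illustrative only: create the directory that configs/smpl.yaml expects by default
mkdir -p data/body_models/smpl
# ...then extract/copy the SMPL model files from the download into data/body_models/smpl/
```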
### Datasets

#### AMASS

[AMASS](https://amass.is.tue.mpg.de/) mocap data is used to train and evaluate our model in most of the experiments. Please download it from Max Planck and accept the license terms there. After downloading the raw data, we run the [processing scripts](https://github.com/davrempe/humor/tree/main/data#amass) provided by HuMoR to unify frame rates, detect contacts, and remove problematic sequences. We then split the data into training, validation, and testing sets by running:
```bash
python src/datasets/amass.py amass.yaml
```

#### MANN

The dog mocap data in [MANN](https://github.com/sebastianstarke/AI4Animation/tree/master/AI4Animation/SIGGRAPH_2018) is used to evaluate the reconstruction capability of our model on long sequences; its raw data can be downloaded from [here](http://www.starke-consult.de/AI4Animation/SIGGRAPH_2018/MotionCapture.zip). To process the data, we first manually remove some sequences on uneven terrain and then run:
```bash
python src/datasets/animal.py dog_mocap.yaml
```

> **Note:** All the data mentioned above should be downloaded **from its original source under the corresponding license** and processed into the `data` directory; otherwise, update the configuration files to point to wherever you extracted it.

## Quickstart

We provide a pre-trained [generative NeMF model](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/ERadqRp5XedAn0pOyGeFWOMB3LH7g7w9c4OyzB7nmweHqA?e=8P1NWU) and [global motion predictor](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/ERmQUqW5z_tAncGS1XzTJ6sBGnIxEawVgK9J-vID-uT6iw?e=TC6wfQ). Download and extract them to the `outputs` directory.

### Application

We deploy our trained model as a generic motion prior to solve different motion tasks. To run the applications we showed in the paper, use:
```bash
python src/application.py --config application.yaml --task {application task} --save_path {save path}
```
Here we implement several applications including `motion_reconstruction`, `latent_interpolation`, `motion_inbetweening`, `motion_renavigating`, `aist_inbetweening`, and `time_translation`. For `aist_inbetweening`, you need to download the motion data and dance videos from the [AIST++ Dataset](https://google.github.io/aistplusplus_dataset/download.html) and place them under `data/aist`. You also need [FFmpeg](https://ffmpeg.org/) installed to process these videos.
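For example, a command like the following (the save path is just an illustration — any writable directory works) runs motion in-betweening with the pre-trained prior:
```bash
python src/application.py --config application.yaml --task motion_inbetweening --save_path outputs/inbetweening
```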
### Evaluation

To evaluate our trained model on different tasks, use:
```bash
python src/evaluation.py --config evaluation.yaml --task {evaluation task} --load_path {load path}
```
The tasks we implement here include `ablation_study`, `comparison`, `inbetween_benchmark`, `super_sampling`, and `smoothness_test`, which cover the tables and figures we showed in the paper. The quantitative results will be saved in `.csv` files.

To evaluate FID and Diversity, we provide a pre-trained feature extractor [here](https://yaleedu-my.sharepoint.com/:u:/g/personal/chengan_he_yale_edu/EVXKF8Tc1t5CnrkswLS0qUoBi_1YsDa_BpMDAalEuTblSQ?e=bcxHAP), which is essentially an auto-encoder. You can train a new one on your own data by running:
```
python src/train.py evaluation.yaml
```

## Train NeMF from Scratch

In our paper we proposed three different models: a single-motion NeMF that overfits specific motion sequences, a generative NeMF that learns a motion prior, and a global motion predictor that generates root translations separately. Below we describe how to train these three models from scratch.

### Training Single-motion NeMF

To train the single-motion NeMF on AMASS sequences, use:
```
python src/train_basic.py basic.yaml
```
The code will take sequences of 32, 64, 128, 256, and 512 frames and reconstruct them at 30, 60, 120, and 240 fps.

To train the model on dog mocap sequences, use:
```
python src/train_basic.py basic_dog.yaml
```

### Training Generative NeMF

To train the generative NeMF on the AMASS dataset, use:
```
python src/train.py generative.yaml
```

### Training Global Motion Predictor

To train the global motion predictor on the AMASS dataset, use:
```
python src/train_gmp.py gmp.yaml
```

## Visualization and Rendering

Our codebase outputs `.npz` data following the AMASS data format, so the results can be visualized directly with the [SMPL-X Blender add-on](https://smpl-x.is.tue.mpg.de/). To render skeleton animations from `.bvh` files, we use the rendering scripts provided in [deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing).

## Acknowledgements

- Our BVH I/O code is adapted from the [work](https://theorangeduck.com/media/uploads/other_stuff/motionsynth_code.zip) of [Daniel Holden](https://theorangeduck.com/page/publications).
- The code in `src/rotations.py` is adapted from [PyTorch3D](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py).
- The code in `src/datasets/amass.py` is adapted from [AMASS](https://github.com/nghorbani/amass/blob/master/src/amass/data/prepare_data.py).
- The code in `src/nemf/skeleton.py` is taken from [deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing), and our code structure is also based on it.
- Part of the code in `src/evaluation.py` is adapted from [Action2Motion](https://github.com/EricGuo5513/action-to-motion/tree/master/eval_scripts).
- Part of the code in `src/utils.py` is taken from [HuMoR](https://github.com/davrempe/humor/blob/b86c2d9faf7abd497749621821a5d46211304d62/humor/scripts/process_amass_data.py).
- The code in `src/soft_dtw_cuda.py` is taken from [pytorch-softdtw-cuda](https://github.com/Maghoumi/pytorch-softdtw-cuda).

**Huge thanks to these great open-source projects!**

## Citation

If you found this code or paper useful, please consider citing:
```
@article{he2022nemf,
    title={NeMF: Neural Motion Fields for Kinematic Animation},
    author={He, Chengan and Saito, Jun and Zachary, James and Rushmeier, Holly and Zhou, Yi},
    journal={Advances in Neural Information Processing Systems},
    volume={35},
    pages={4244--4256},
    year={2022}
}
```

## Contact

If you run into any problems or have questions, please create an issue or contact `chengan.he@yale.edu`.
134 | -------------------------------------------------------------------------------- /configs/amass.yaml: -------------------------------------------------------------------------------- 1 | dataset_dir: ./data/amass/test 2 | # dataset_dir: ./data/amass/gmp 3 | # dataset_dir: ./data/amass/single 4 | 5 | data: 6 | fps: 30 7 | clip_length: 128 8 | gender: male 9 | up: z 10 | 11 | canonical: False 12 | unified_orientation: True 13 | single_motion: False 14 | normalize: False -------------------------------------------------------------------------------- /configs/application.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: False 2 | verbose: False 3 | 4 | is_train: False 5 | log: False 6 | 7 | iterations: 600 8 | 9 | batch_size: 1 10 | num_workers: 0 11 | 12 | dataset_dir: ./data/amass/generative 13 | save_dir: ./outputs/generative 14 | bvh_viz: False 15 | 16 | output_trans: False 17 | pretrained_gmp: gmp.yaml 18 | 19 | initialization: True 20 | learning_rate: 0.1 21 | geodesic_loss: True 22 | l1_loss: True 23 | lambda_rot: 1 24 | lambda_pos: 10 25 | lambda_orient: 1 26 | lambda_trans: 1 27 | dtw_loss: True 28 | lambda_dtw: 0.5 29 | lambda_angle: 1 30 | 31 | lambda_kl: 0.0001 32 | 33 | data: 34 | fps: 30 35 | clip_length: 128 36 | gender: male 37 | up: z 38 | root_transform: True 39 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 'height', 'root_orient', 'root_vel'] 40 | 41 | local_prior: 42 | activation: tanh 43 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang) 44 | use_residual_blocks: True 45 | z_dim: 1024 46 | temporal_scale: 8 47 | kernel_size: 4 48 | num_layers: 4 49 | skeleton_dist: 2 50 | extra_conv: 0 51 | padding_mode: reflect 52 | skeleton_pool: mean 53 | upsampling: linear 54 | 55 | global_prior: 56 | activation: tanh 57 | in_channels: 6 58 | kernel_size: 4 59 | temporal_scale: 8 60 | z_dim: 256 61 | 62 | nemf: 63 | siren: False 64 | skip_connection: True 65 | norm_layer: True 66 | bandwidth: 7 67 | hidden_neuron: 1024 68 | local_z: 1024 69 | global_z: 256 70 | local_output: 144 71 | global_output: 6 72 | 73 | scheduler: 74 | step_size: 200 75 | gamma: 0.7 -------------------------------------------------------------------------------- /configs/basic.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: False 2 | verbose: True 3 | 4 | is_train: True 5 | log: True 6 | 7 | epoch_begin: 0 8 | iterations: 500 9 | 10 | amass_data: True 11 | dataset_dir: ./data/amass/single 12 | save_dir: ./outputs/basic 13 | bvh_viz: False 14 | 15 | learning_rate: 0.0001 16 | geodesic_loss: True 17 | l1_loss: True 18 | lambda_rotmat: 1 19 | lambda_pos: 10 20 | lambda_orient: 1 21 | lambda_v: 1 22 | lambda_up: 1 23 | lambda_trans: 1 24 | 25 | data: 26 | fps: 30 27 | up: z 28 | gender: male 29 | root_transform: True 30 | 31 | nemf: 32 | siren: False 33 | skip_connection: True 34 | norm_layer: True 35 | bandwidth: 7 36 | hidden_neuron: 1024 37 | local_z: 0 38 | global_z: 0 39 | local_output: 154 # 24 x 6 + 6 + 4 40 | global_output: 1 41 | 42 | scheduler: 43 | name: -------------------------------------------------------------------------------- /configs/basic_dog.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: False 2 | verbose: True 3 | 4 | is_train: True 5 | log: True 6 | 7 | epoch_begin: 0 8 | iterations: 8000 9 | 10 | amass_data: False 11 | dataset_dir: ./data/dog 12 | save_dir: ./outputs/dog 13 | bvh_viz: True 14 | 15 | 
learning_rate: 0.0001 16 | geodesic_loss: True 17 | l1_loss: True 18 | lambda_rotmat: 1 19 | lambda_pos: 10 20 | lambda_orient: 1 21 | lambda_v: 1 22 | lambda_up: 1 23 | lambda_trans: 1 24 | 25 | data: 26 | fps: 60 27 | up: y 28 | root_transform: True 29 | 30 | nemf: 31 | siren: False 32 | skip_connection: True 33 | norm_layer: True 34 | bandwidth: 7 35 | hidden_neuron: 1024 36 | local_z: 0 37 | global_z: 0 38 | local_output: 172 # 27 x 6 + 6 + 4 39 | global_output: 1 40 | 41 | scheduler: 42 | name: -------------------------------------------------------------------------------- /configs/dog_mocap.yaml: -------------------------------------------------------------------------------- 1 | dataset_dir: ./data/dog 2 | 3 | data: 4 | fps: 60 5 | up: y 6 | 7 | canonical: False 8 | unified_orientation: True -------------------------------------------------------------------------------- /configs/dog_skel.yaml: -------------------------------------------------------------------------------- 1 | offsets: [ 2 | [-0.100563, 0.077338, -4.725500], 3 | [ 0.000000, 0.000000, 0.000000], 4 | [ 0.190000, 0.000000, 0.000000], 5 | [ 0.225000, 0.006000, 0.000000], 6 | [ 0.140000, 0.000309, 0.000000], 7 | [ 0.170000, 0.000000, 0.000000], 8 | [ 0.198000, 0.037000, 0.043000], 9 | [ 0.080000, 0.000000, 0.000000], 10 | [ 0.152000, 0.000000, 0.000000], 11 | [ 0.178000, 0.000000, 0.000000], 12 | [ 0.072000, 0.000000, 0.000000], 13 | [ 0.198000, 0.037000, -0.043000], 14 | [ 0.080000, 0.000000, 0.001517], 15 | [ 0.152000, 0.000000, 0.000000], 16 | [ 0.178000, 0.000000, 0.000000], 17 | [ 0.072000, 0.000000, 0.000000], 18 | [ 0.059843, -0.076660, 0.047888], 19 | [ 0.160000, 0.000000, 0.000000], 20 | [ 0.180000, 0.000000, 0.000000], 21 | [ 0.000000, -0.108000, 0.000000], 22 | [ 0.059843, -0.076660, -0.047888], 23 | [ 0.160000, 0.000000, 0.000000], 24 | [ 0.180000, 0.000000, 0.000000], 25 | [ 0.000000, -0.108000, 0.000000], 26 | [ 0.068370, -0.007226, 0.000000], 27 | [ 0.120000, 0.000000, 0.000000], 28 | [ 0.120000, 0.000000, 0.000000] 29 | ] 30 | 31 | parents: [-1, 0, 1, 2, 3, 4, 2, 6, 7, 8, 9, 2, 11, 12, 13, 14, 0, 16, 17, 18, 0, 20, 21, 22, 0, 24, 25] 32 | 33 | joint_names: ['Hips', 'Spine', 'Spine1', 'Neck', 'Head', 'Head_End', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LHand_End', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RHand_End', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LFoot_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RFoot_End', 'Tail', 'Tail1', 'Tail_End'] 34 | -------------------------------------------------------------------------------- /configs/evaluation.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: False 2 | verbose: True 3 | 4 | is_train: True 5 | log: True 6 | 7 | epoch_begin: 0 8 | epoch_num: 1000 9 | checkpoint: 100 10 | 11 | batch_size: 16 12 | num_workers: 0 13 | 14 | dataset_dir: ./data/amass/generative 15 | save_dir: ./outputs/prior 16 | bvh_viz: False 17 | 18 | output_trans: True 19 | pretrained_gmp: 20 | 21 | adam_optimizer: True 22 | learning_rate: 0.0001 23 | weight_decay: 0.0001 24 | geodesic_loss: True 25 | l1_loss: True 26 | lambda_rotmat: 1 27 | lambda_pos: 10 28 | lambda_orient: 1 29 | 30 | lambda_v: 1 31 | lambda_up: 1 32 | lambda_trans: 1 33 | lambda_contacts: 0.5 34 | lambda_kl: 0 35 | 36 | annealing_cycles: 50 37 | annealing_warmup: 25 38 | 39 | data: 40 | fps: 30 41 | clip_length: 128 42 | gender: male 43 | up: z 44 | root_transform: True 45 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 
'root_orient', 'root_vel'] 46 | 47 | local_prior: 48 | activation: tanh 49 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang) 50 | use_residual_blocks: True 51 | z_dim: 1024 52 | temporal_scale: 8 53 | kernel_size: 4 54 | num_layers: 4 55 | skeleton_dist: 2 56 | extra_conv: 0 57 | padding_mode: reflect 58 | skeleton_pool: mean 59 | upsampling: linear 60 | 61 | global_prior: 62 | activation: tanh 63 | in_channels: 9 # 6 (root orient) + 3 (root vel) 64 | kernel_size: 4 65 | temporal_scale: 8 66 | z_dim: 256 67 | 68 | nemf: 69 | siren: False 70 | skip_connection: True 71 | norm_layer: True 72 | bandwidth: 7 73 | hidden_neuron: 1024 74 | local_z: 1024 75 | global_z: 256 76 | local_output: 144 77 | global_output: 18 # 6 + 3 + 1 + 8 78 | 79 | 80 | scheduler: 81 | name: 82 | -------------------------------------------------------------------------------- /configs/generative.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: True 2 | verbose: True 3 | 4 | is_train: True 5 | log: True 6 | 7 | epoch_begin: 0 8 | epoch_num: 1000 9 | checkpoint: 100 10 | 11 | batch_size: 16 12 | num_workers: 0 13 | 14 | dataset_dir: ./data/amass/generative 15 | save_dir: ./outputs/generative 16 | bvh_viz: False 17 | 18 | output_trans: False 19 | # pretrained_gmp: gmp.yaml 20 | pretrained_gmp: 21 | 22 | adam_optimizer: False 23 | learning_rate: 0.0001 24 | weight_decay: 0.0001 25 | geodesic_loss: True 26 | l1_loss: True 27 | lambda_rotmat: 1 28 | lambda_pos: 10 29 | lambda_orient: 1 30 | 31 | lambda_v: 1 32 | lambda_up: 1 33 | lambda_trans: 1 34 | lambda_contacts: 0.5 35 | lambda_kl: 0.00001 36 | 37 | annealing_cycles: 50 38 | annealing_warmup: 25 39 | 40 | data: 41 | fps: 30 42 | clip_length: 128 43 | gender: male 44 | up: z 45 | root_transform: True 46 | normalize: ['pos', 'velocity', 'global_xform', 'angular', 'root_orient', 'root_vel'] 47 | 48 | local_prior: 49 | activation: tanh 50 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang) 51 | use_residual_blocks: True 52 | z_dim: 1024 53 | temporal_scale: 8 54 | kernel_size: 4 55 | num_layers: 4 56 | skeleton_dist: 2 57 | extra_conv: 0 58 | padding_mode: reflect 59 | skeleton_pool: mean 60 | upsampling: linear 61 | 62 | global_prior: 63 | activation: tanh 64 | in_channels: 6 65 | kernel_size: 4 66 | temporal_scale: 8 67 | z_dim: 256 68 | 69 | nemf: 70 | siren: False 71 | skip_connection: True 72 | norm_layer: True 73 | bandwidth: 7 74 | hidden_neuron: 1024 75 | local_z: 1024 76 | global_z: 256 77 | local_output: 144 78 | global_output: 6 79 | 80 | 81 | scheduler: 82 | name: 83 | -------------------------------------------------------------------------------- /configs/gmp.yaml: -------------------------------------------------------------------------------- 1 | multi_gpu: True 2 | verbose: True 3 | 4 | is_train: True 5 | log: True 6 | 7 | epoch_begin: 0 8 | epoch_num: 1000 9 | checkpoint: 100 10 | 11 | batch_size: 16 12 | num_workers: 0 13 | 14 | dataset_dir: ./data/amass/gmp 15 | save_dir: ./outputs/gmp 16 | 17 | learning_rate: 0.0001 18 | weight_decay: 0.0001 19 | l1_loss: True 20 | lambda_v: 1 21 | lambda_up: 1 22 | lambda_trans: 1 23 | lambda_contacts: 0.5 24 | 25 | data: 26 | fps: 30 27 | clip_length: 128 28 | gender: male 29 | up: z 30 | normalize: ['pos', 'velocity', 'rot6d', 'angular'] 31 | 32 | global_motion: 33 | activation: relu 34 | channel_base: 15 # 3(pos) + 3(vel) + 6(rot) + 3(ang) 35 | out_channels: 12 # 3 + 1 + 8 36 | use_residual_blocks: True 37 | kernel_size: 15 38 | num_layers: 3 39 | 
skeleton_dist: 1 40 | extra_conv: 0 41 | padding_mode: reflect 42 | skeleton_pool: mean 43 | 44 | scheduler: 45 | name: -------------------------------------------------------------------------------- /configs/smpl.yaml: -------------------------------------------------------------------------------- 1 | smpl_body_model: ./data/body_models/smpl 2 | 3 | offsets: 4 | male: [ 5 | [-0.002174, 0.972724, 0.028584], 6 | [ 0.058581, -0.082280, -0.017664], 7 | [ 0.043451, -0.386469, 0.008037], 8 | [-0.014790, -0.426874, -0.037428], 9 | [ 0.041054, -0.060286, 0.122042], 10 | [-0.060310, -0.090513, -0.013543], 11 | [-0.043257, -0.383688, -0.004843], 12 | [ 0.019056, -0.420046, -0.034562], 13 | [-0.034840, -0.062106, 0.130323], 14 | [ 0.004439, 0.124404, -0.038385], 15 | [ 0.004488, 0.137956, 0.026820], 16 | [-0.002265, 0.056032, 0.002855], 17 | [-0.013390, 0.211635, -0.033468], 18 | [ 0.010113, 0.088937, 0.050410], 19 | [ 0.071702, 0.114000, -0.018898], 20 | [ 0.122921, 0.045205, -0.019046], 21 | [ 0.255332, -0.015649, -0.022946], 22 | [ 0.265709, 0.012698, -0.007375], 23 | [ 0.086691, -0.010636, -0.015594], 24 | [-0.082954, 0.112472, -0.023707], 25 | [-0.113228, 0.046853, -0.008472], 26 | [-0.260127, -0.014369, -0.031269], 27 | [-0.269108, 0.006794, -0.006027], 28 | [-0.088754, -0.008652, -0.010107] 29 | ] 30 | female: [ 31 | [-0.000876, 0.909315, 0.027821], 32 | [ 0.071361, -0.089584, -0.008046], 33 | [ 0.030669, -0.364209, -0.006689], 34 | [-0.011554, -0.383348, -0.043502], 35 | [ 0.023338, -0.054645, 0.114370], 36 | [-0.069012, -0.088960, -0.004796], 37 | [-0.036152, -0.370650, -0.009185], 38 | [ 0.014029, -0.383638, -0.041892], 39 | [-0.022043, -0.046410, 0.117900], 40 | [-0.002508, 0.103257, -0.022185], 41 | [ 0.003581, 0.127658, -0.001713], 42 | [ 0.002027, 0.049072, 0.027867], 43 | [-0.001963, 0.208243, -0.049765], 44 | [ 0.003517, 0.062322, 0.050205], 45 | [ 0.075298, 0.117780, -0.036875], 46 | [ 0.085317, 0.031739, -0.007293], 47 | [ 0.251247, -0.011967, -0.027518], 48 | [ 0.238019, 0.009007, 0.000044], 49 | [ 0.079668, -0.009683, -0.013206], 50 | [-0.077033, 0.115606, -0.041811], 51 | [-0.089203, 0.032785, -0.009802], 52 | [-0.245990, -0.013152, -0.020162], 53 | [-0.245177, 0.008622, -0.003532], 54 | [-0.080400, -0.007248, -0.010419] 55 | ] 56 | 57 | parents: [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 17, 11, 19, 20, 21, 22] 58 | 59 | joint_names: ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Foot', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Foot', 'Spine1', 'Spine2', 'Spine3', 'Neck', 'Head', 'L_Collar', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Collar', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] 60 | 61 | joints_to_use: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37] 62 | 63 | leaf_joints: [10, 11, 15, 22, 23] 64 | 65 | lfoot_index: [7, 10] 66 | 67 | rfoot_index: [8, 11] -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/data/.gitkeep -------------------------------------------------------------------------------- /outputs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/outputs/.gitkeep -------------------------------------------------------------------------------- /overview.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/nghorbani/human_body_prior 2 | git+https://github.com/nghorbani/configer 3 | 4 | autopep8 5 | pylint 6 | torch==1.9.0 7 | matplotlib 8 | plyfile 9 | scikit-learn 10 | tensorboard 11 | pandas 12 | tqdm 13 | pyyaml 14 | pybullet 15 | numba 16 | umap-learn[plot] 17 | transformations==2019.4.22 18 | torchgeometry==0.1.2 19 | setuptools==59.5.0 -------------------------------------------------------------------------------- /src/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import yaml 5 | 6 | 7 | class Struct: 8 | def __init__(self, **entries): 9 | self.__dict__.update(entries) 10 | 11 | 12 | class Arguments: 13 | def __init__(self, config_path, filename='default.yaml'): 14 | with open(os.path.join(config_path, 'smpl.yaml'), 'r') as f: 15 | smpl = yaml.safe_load(f) 16 | self.smpl = Struct(**smpl) 17 | self.smpl.offsets['male'] = np.array(self.smpl.offsets['male']) 18 | self.smpl.offsets['female'] = np.array(self.smpl.offsets['female']) 19 | self.smpl.parents = np.array(self.smpl.parents).astype(np.int32) 20 | self.smpl.joint_num = len(self.smpl.joints_to_use) 21 | self.smpl.joints_to_use = np.array(self.smpl.joints_to_use) 22 | self.smpl.joints_to_use = np.arange(0, 156).reshape((-1, 3))[self.smpl.joints_to_use].reshape(-1) 23 | 24 | with open(os.path.join(config_path, 'dog_skel.yaml'), 'r') as f: 25 | animal = yaml.safe_load(f) 26 | self.animal = Struct(**animal) 27 | self.animal.offsets = np.array(self.animal.offsets) 28 | self.animal.parents = np.array(self.animal.parents).astype(np.int32) 29 | self.animal.joint_num = len(self.animal.parents) 30 | 31 | self.filename = os.path.splitext(filename)[0] 32 | with open(os.path.join(config_path, filename), 'r') as f: 33 | config = yaml.safe_load(f) 34 | 35 | for key, value in config.items(): 36 | if isinstance(value, dict): 37 | setattr(self, key, Struct(**value)) 38 | else: 39 | setattr(self, key, value) 40 | 41 | self.json = config 42 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/animal.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | 5 | file_path = os.path.dirname(os.path.realpath(__file__)) 6 | sys.path.append(os.path.join(file_path, '..')) 7 | 8 | import holden.BVH as BVH 9 | import numpy as np 10 | import torch 11 | from arguments import Arguments 12 | from human_body_prior.tools.omni_tools import log2file, makepath 13 | from nemf.fk import ForwardKinematicsLayer 14 | from rotations import matrix_to_rotation_6d, quaternion_to_matrix, rotation_6d_to_matrix 15 | from torch.utils.data import Dataset 16 | from tqdm import tqdm 17 | from utils import build_canonical_frame, estimate_angular_velocity, estimate_linear_velocity 18 | 19 | 20 | def 
dump_animal2single(animal_data_dir, logger): 21 | fk = ForwardKinematicsLayer(parents=args.animal.parents, positions=args.animal.offsets) 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | bvh_fnames = glob.glob(os.path.join(animal_data_dir, '*.bvh')) 25 | index = 0 26 | for bvh_fname in tqdm(bvh_fnames): 27 | try: 28 | anim, _, ftime = BVH.load(bvh_fname) 29 | except: 30 | logger('Could not read %s! skipping..' % bvh_fname) 31 | continue 32 | 33 | N = len(anim.rotations) 34 | 35 | fps = np.around(1.0 / ftime).astype(np.int32) 36 | logger(f'{bvh_fname} frame: {N}\tFPS: {fps}') 37 | 38 | data_rotation = anim.rotations.qs 39 | data_translation = anim.positions[:, 0] / 100.0 40 | 41 | # insert identity quaternion 42 | data_rotation = np.insert(data_rotation, [5, 9, 13, 16, 19, 21], [1, 0, 0, 0], axis=1) 43 | poses = torch.from_numpy(np.asarray(data_rotation, np.float32)).to(device) # quaternion (T, J, 4) 44 | trans = torch.from_numpy(np.asarray(data_translation, np.float32)).to(device) # global translation (T, 3) 45 | 46 | # Compute necessary data for model training. 47 | rotmat = quaternion_to_matrix(poses) # rotation matrix (T, J, 3, 3) 48 | root_orient = rotmat[:, 0].clone() 49 | root_orient = matrix_to_rotation_6d(root_orient) # root orientation (T, 6) 50 | if args.unified_orientation: 51 | identity = torch.eye(3).cuda() 52 | identity = identity.view(1, 3, 3).repeat(rotmat.shape[0], 1, 1) 53 | rotmat[:, 0] = identity 54 | rot6d = matrix_to_rotation_6d(rotmat) # 6D rotation representation (T, J, 6) 55 | 56 | rot_seq = rotmat.clone() 57 | angular = estimate_angular_velocity(rot_seq.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # angular velocity of all the joints (T, J, 3) 58 | 59 | pos, global_xform = fk(rot6d) # local joint positions (T, J, 3), global transformation matrix for each joint (T, J, 4, 4) 60 | pos = pos.contiguous() 61 | global_xform = global_xform.contiguous() 62 | velocity = estimate_linear_velocity(pos.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # linear velocity of all the joints (T, J, 3) 63 | 64 | if args.unified_orientation: 65 | root_rotation = rotation_6d_to_matrix(root_orient) # (T, 3, 3) 66 | root_rotation = root_rotation.unsqueeze(1).repeat(1, args.animal.joint_num, 1, 1) # (T, J, 3, 3) 67 | global_pos = torch.matmul(root_rotation, pos.unsqueeze(-1)).squeeze(-1) 68 | height = global_pos + trans.unsqueeze(1) 69 | else: 70 | height = pos + trans.unsqueeze(1) 71 | height = height[..., 'xyz'.index(args.data.up)] # (T, J) 72 | root_vel = estimate_linear_velocity(trans.unsqueeze(0), dt=1.0 / args.data.fps).squeeze(0) # linear velocity of the root joint (T, 3) 73 | 74 | out_posepath = makepath(os.path.join(work_dir, 'train', f'trans_{index}.pt'), isfile=True) 75 | torch.save(trans.detach().cpu(), out_posepath) # (T, 3) 76 | torch.save(root_vel.detach().cpu(), out_posepath.replace('trans', 'root_vel')) # (T, 3) 77 | torch.save(pos.detach().cpu(), out_posepath.replace('trans', 'pos')) # (T, J, 3) 78 | torch.save(rotmat.detach().cpu(), out_posepath.replace('trans', 'rotmat')) # (T, J, 3, 3) 79 | torch.save(height.detach().cpu(), out_posepath.replace('trans', 'height')) # (T, J) 80 | 81 | if args.canonical: 82 | forward = rotmat[:, 0, :, 2].clone() 83 | canonical_frame = build_canonical_frame(forward, up_axis=args.data.up) 84 | root_rotation = canonical_frame.transpose(-2, -1) # (T, 3, 3) 85 | root_rotation = root_rotation.unsqueeze(1).repeat(1, args.animal.joint_num, 1, 1) # (T, J, 3, 3) 86 | 87 | if args.data.up == 'x': 88 | theta = 
torch.atan2(forward[..., 2], forward[..., 1]) 89 | elif args.data.up == 'y': 90 | theta = torch.atan2(forward[..., 0], forward[..., 2]) 91 | else: 92 | theta = torch.atan2(forward[..., 1], forward[..., 0]) 93 | dt = 1.0 / args.data.fps 94 | forward_ang = (theta[1:] - theta[:-1]) / dt 95 | forward_ang = torch.cat((forward_ang, forward_ang[-1:]), dim=-1) 96 | 97 | local_pos = torch.matmul(root_rotation, pos.unsqueeze(-1)).squeeze(-1) 98 | local_vel = torch.matmul(root_rotation, velocity.unsqueeze(-1)).squeeze(-1) 99 | local_rot = torch.matmul(root_rotation, global_xform[:, :, :3, :3]) 100 | local_rot = matrix_to_rotation_6d(local_rot) 101 | local_ang = torch.matmul(root_rotation, angular.unsqueeze(-1)).squeeze(-1) 102 | 103 | torch.save(forward.detach().cpu(), out_posepath.replace('trans', 'forward')) 104 | torch.save(forward_ang.detach().cpu(), out_posepath.replace('trans', 'forward_ang')) 105 | torch.save(local_pos.detach().cpu(), out_posepath.replace('trans', 'local_pos')) 106 | torch.save(local_vel.detach().cpu(), out_posepath.replace('trans', 'local_vel')) 107 | torch.save(local_rot.detach().cpu(), out_posepath.replace('trans', 'local_rot')) 108 | torch.save(local_ang.detach().cpu(), out_posepath.replace('trans', 'local_ang')) 109 | else: 110 | global_xform = global_xform[:, :, :3, :3] # (T, J, 3, 3) 111 | global_xform = matrix_to_rotation_6d(global_xform) # (T, J, 6) 112 | 113 | torch.save(rot6d.detach().cpu(), out_posepath.replace('trans', 'rot6d')) # (N, T, J, 6) 114 | torch.save(angular.detach().cpu(), out_posepath.replace('trans', 'angular')) # (N, T, J, 3) 115 | torch.save(global_xform.detach().cpu(), out_posepath.replace('trans', 'global_xform')) # (N, T, J, 6) 116 | torch.save(velocity.detach().cpu(), out_posepath.replace('trans', 'velocity')) # (N, T, J, 3) 117 | torch.save(root_orient.detach().cpu(), out_posepath.replace('trans', 'root_orient')) # (N, T, 3, 3) 118 | 119 | index += 1 120 | 121 | 122 | class Animal(Dataset): 123 | def __init__(self, dataset_dir): 124 | self.ds = {} 125 | for data_fname in glob.glob(os.path.join(dataset_dir, '*.pt')): 126 | k = os.path.basename(data_fname).split('-')[0] 127 | self.ds[k] = torch.load(data_fname) 128 | 129 | def __len__(self): 130 | return len(self.ds['trans']) 131 | 132 | def __getitem__(self, idx): 133 | data = {k: self.ds[k][idx] for k in self.ds.keys()} 134 | 135 | return data 136 | 137 | 138 | if __name__ == '__main__': 139 | animal_data_dir = '../AnimalData/' 140 | 141 | args = Arguments('./configs', filename='dog_mocap.yaml') 142 | work_dir = makepath(args.dataset_dir) 143 | 144 | log_name = os.path.join(work_dir, 'animal.log') 145 | if os.path.exists(log_name): 146 | os.remove(log_name) 147 | logger = log2file(log_name) 148 | 149 | logger('Start processing the animal mocap data ...') 150 | dump_animal2single(animal_data_dir, logger) 151 | -------------------------------------------------------------------------------- /src/holden/AnimationStructure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sparse 3 | import Animation as Animation 4 | 5 | 6 | """ Maya Functions """ 7 | 8 | def load_from_maya(root): 9 | """ 10 | Load joint parents and names from maya 11 | 12 | Parameters 13 | ---------- 14 | 15 | root : PyNode 16 | Root Maya Node 17 | 18 | Returns 19 | ------- 20 | 21 | (names, parents) : ([str], (J) ndarray) 22 | List of joint names and array 23 | of indices representing the parent 24 | joint for each joint J. 
25 | 26 | Joint index -1 is used to represent 27 | that there is no parent joint 28 | """ 29 | 30 | import pymel.core as pm 31 | 32 | names = [] 33 | parents = [] 34 | 35 | def unload_joint(j, parents, par): 36 | 37 | id = len(names) 38 | names.append(j) 39 | parents.append(par) 40 | 41 | children = [c for c in j.getChildren() if 42 | isinstance(c, pm.nt.Transform) and 43 | not isinstance(c, pm.nt.Constraint) and 44 | not any(pm.listRelatives(c, s=True)) and 45 | (any(pm.listRelatives(c, ad=True, ap=False, type='joint')) or isinstance(c, pm.nt.Joint))] 46 | 47 | map(lambda c: unload_joint(c, parents, id), children) 48 | 49 | unload_joint(root, parents, -1) 50 | 51 | return (names, parents) 52 | 53 | 54 | """ Family Functions """ 55 | 56 | def joints(parents): 57 | """ 58 | Parameters 59 | ---------- 60 | 61 | parents : (J) ndarray 62 | parents array 63 | 64 | Returns 65 | ------- 66 | 67 | joints : (J) ndarray 68 | Array of joint indices 69 | """ 70 | return np.arange(len(parents), dtype=int) 71 | 72 | def joints_list(parents): 73 | """ 74 | Parameters 75 | ---------- 76 | 77 | parents : (J) ndarray 78 | parents array 79 | 80 | Returns 81 | ------- 82 | 83 | joints : [ndarray] 84 | List of arrays of joint idices for 85 | each joint 86 | """ 87 | return list(joints(parents)[:,np.newaxis]) 88 | 89 | def parents_list(parents): 90 | """ 91 | Parameters 92 | ---------- 93 | 94 | parents : (J) ndarray 95 | parents array 96 | 97 | Returns 98 | ------- 99 | 100 | parents : [ndarray] 101 | List of arrays of joint idices for 102 | the parents of each joint 103 | """ 104 | return list(parents[:,np.newaxis]) 105 | 106 | 107 | def children_list(parents): 108 | """ 109 | Parameters 110 | ---------- 111 | 112 | parents : (J) ndarray 113 | parents array 114 | 115 | Returns 116 | ------- 117 | 118 | children : [ndarray] 119 | List of arrays of joint indices for 120 | the children of each joint 121 | """ 122 | 123 | def joint_children(i): 124 | return [j for j, p in enumerate(parents) if p == i] 125 | 126 | return list(map(lambda j: np.array(joint_children(j)), joints(parents))) 127 | 128 | 129 | def descendants_list(parents): 130 | """ 131 | Parameters 132 | ---------- 133 | 134 | parents : (J) ndarray 135 | parents array 136 | 137 | Returns 138 | ------- 139 | 140 | descendants : [ndarray] 141 | List of arrays of joint idices for 142 | the descendants of each joint 143 | """ 144 | 145 | children = children_list(parents) 146 | 147 | def joint_descendants(i): 148 | return sum([joint_descendants(j) for j in children[i]], list(children[i])) 149 | 150 | return list(map(lambda j: np.array(joint_descendants(j)), joints(parents))) 151 | 152 | 153 | def ancestors_list(parents): 154 | """ 155 | Parameters 156 | ---------- 157 | 158 | parents : (J) ndarray 159 | parents array 160 | 161 | Returns 162 | ------- 163 | 164 | ancestors : [ndarray] 165 | List of arrays of joint idices for 166 | the ancestors of each joint 167 | """ 168 | 169 | decendants = descendants_list(parents) 170 | 171 | def joint_ancestors(i): 172 | return [j for j in joints(parents) if i in decendants[j]] 173 | 174 | return list(map(lambda j: np.array(joint_ancestors(j)), joints(parents))) 175 | 176 | 177 | """ Mask Functions """ 178 | 179 | def mask(parents, filter): 180 | """ 181 | Constructs a Mask for a give filter 182 | 183 | A mask is a (J, J) ndarray truth table for a given 184 | condition over J joints. For example there 185 | may be a mask specifying if a joint N is a 186 | child of another joint M. 
187 | 188 | This could be constructed into a mask using 189 | `m = mask(parents, children_list)` and the condition 190 | of childhood tested using `m[N, M]`. 191 | 192 | Parameters 193 | ---------- 194 | 195 | parents : (J) ndarray 196 | parents array 197 | 198 | filter : (J) ndarray -> [ndarray] 199 | function that outputs a list of arrays 200 | of joint indices for some condition 201 | 202 | Returns 203 | ------- 204 | 205 | mask : (N, N) ndarray 206 | boolean truth table of given condition 207 | """ 208 | m = np.zeros((len(parents), len(parents))).astype(bool) 209 | jnts = joints(parents) 210 | fltr = filter(parents) 211 | for i,f in enumerate(fltr): m[i,:] = np.any(jnts[:,np.newaxis] == f[np.newaxis,:], axis=1) 212 | return m 213 | 214 | def joints_mask(parents): return np.eye(len(parents)).astype(bool) 215 | def children_mask(parents): return mask(parents, children_list) 216 | def parents_mask(parents): return mask(parents, parents_list) 217 | def descendants_mask(parents): return mask(parents, descendants_list) 218 | def ancestors_mask(parents): return mask(parents, ancestors_list) 219 | 220 | """ Search Functions """ 221 | 222 | def joint_chain_ascend(parents, start, end): 223 | chain = [] 224 | while start != end: 225 | chain.append(start) 226 | start = parents[start] 227 | chain.append(end) 228 | return np.array(chain, dtype=int) 229 | 230 | 231 | """ Constraints """ 232 | 233 | def constraints(anim, **kwargs): 234 | """ 235 | Constraint list for Animation 236 | 237 | This constraint list can be used in the 238 | VerletParticle solver to constrain 239 | a animation global joint positions. 240 | 241 | Parameters 242 | ---------- 243 | 244 | anim : Animation 245 | Input animation 246 | 247 | masses : (F, J) ndarray 248 | Optional list of masses 249 | for joints J across frames F 250 | defaults to weighting by 251 | vertical height 252 | 253 | Returns 254 | ------- 255 | 256 | constraints : [(int, int, (F, J) ndarray, (F, J) ndarray, (F, J) ndarray)] 257 | A list of constraints in the format: 258 | (Joint1, Joint2, Masses1, Masses2, Lengths) 259 | 260 | """ 261 | 262 | masses = kwargs.pop('masses', None) 263 | 264 | children = children_list(anim.parents) 265 | constraints = [] 266 | 267 | points_offsets = Animation.offsets_global(anim) 268 | points = Animation.positions_global(anim) 269 | 270 | if masses is None: 271 | masses = 1.0 / (0.1 + np.absolute(points_offsets[:,1])) 272 | masses = masses[np.newaxis].repeat(len(anim), axis=0) 273 | 274 | for j in xrange(anim.shape[1]): 275 | 276 | """ Add constraints between all joints and their children """ 277 | for c0 in children[j]: 278 | 279 | dists = np.sum((points[:, c0] - points[:, j])**2.0, axis=1)**0.5 280 | constraints.append((c0, j, masses[:,c0], masses[:,j], dists)) 281 | 282 | """ Add constraints between all children of joint """ 283 | for c1 in children[j]: 284 | if c0 == c1: continue 285 | 286 | dists = np.sum((points[:, c0] - points[:, c1])**2.0, axis=1)**0.5 287 | constraints.append((c0, c1, masses[:,c0], masses[:,c1], dists)) 288 | 289 | return constraints 290 | 291 | """ Graph Functions """ 292 | 293 | def graph(anim): 294 | """ 295 | Generates a weighted adjacency matrix 296 | using local joint distances along 297 | the skeletal structure. 298 | 299 | Joints which are not connected 300 | are assigned the weight `0`. 301 | 302 | Joints which actually have zero distance 303 | between them, but are still connected, are 304 | perturbed by some minimal amount. 
305 | 306 | The output of this routine can be used 307 | with the `scipy.sparse.csgraph` 308 | routines for graph analysis. 309 | 310 | Parameters 311 | ---------- 312 | 313 | anim : Animation 314 | input animation 315 | 316 | Returns 317 | ------- 318 | 319 | graph : (N, N) ndarray 320 | weight adjacency matrix using 321 | local distances along the 322 | skeletal structure from joint 323 | N to joint M. If joints are not 324 | directly connected are assigned 325 | the weight `0`. 326 | """ 327 | 328 | graph = np.zeros(anim.shape[1], anim.shape[1]) 329 | lengths = np.sum(anim.offsets**2.0, axis=1)**0.5 + 0.001 330 | 331 | for i,p in enumerate(anim.parents): 332 | if p == -1: continue 333 | graph[i,p] = lengths[p] 334 | graph[p,i] = lengths[p] 335 | 336 | return graph 337 | 338 | 339 | def distances(anim): 340 | """ 341 | Generates a distance matrix for 342 | pairwise joint distances along 343 | the skeletal structure 344 | 345 | Parameters 346 | ---------- 347 | 348 | anim : Animation 349 | input animation 350 | 351 | Returns 352 | ------- 353 | 354 | distances : (N, N) ndarray 355 | array of pairwise distances 356 | along skeletal structure 357 | from some joint N to some 358 | joint M 359 | """ 360 | 361 | distances = np.zeros((anim.shape[1], anim.shape[1])) 362 | generated = distances.copy().astype(bool) 363 | 364 | joint_lengths = np.sum(anim.offsets**2.0, axis=1)**0.5 365 | joint_children = children_list(anim) 366 | joint_parents = parents_list(anim) 367 | 368 | def find_distance(distances, generated, prev, i, j): 369 | 370 | """ If root, identity, or already generated, return """ 371 | if j == -1: return (0.0, True) 372 | if j == i: return (0.0, True) 373 | if generated[i,j]: return (distances[i,j], True) 374 | 375 | """ Find best distances along parents and children """ 376 | par_dists = [(joint_lengths[j], find_distance(distances, generated, j, i, p)) for p in joint_parents[j] if p != prev] 377 | out_dists = [(joint_lengths[c], find_distance(distances, generated, j, i, c)) for c in joint_children[j] if c != prev] 378 | 379 | """ Check valid distance and not dead end """ 380 | par_dists = [a + d for (a, (d, f)) in par_dists if f] 381 | out_dists = [a + d for (a, (d, f)) in out_dists if f] 382 | 383 | """ All dead ends """ 384 | if (out_dists + par_dists) == []: return (0.0, False) 385 | 386 | """ Get minimum path """ 387 | dist = min(out_dists + par_dists) 388 | distances[i,j] = dist; distances[j,i] = dist 389 | generated[i,j] = True; generated[j,i] = True 390 | 391 | for i in xrange(anim.shape[1]): 392 | for j in xrange(anim.shape[1]): 393 | find_distance(distances, generated, -1, i, j) 394 | 395 | return distances 396 | 397 | def edges(parents): 398 | """ 399 | Animation structure edges 400 | 401 | Parameters 402 | ---------- 403 | 404 | parents : (J) ndarray 405 | parents array 406 | 407 | Returns 408 | ------- 409 | 410 | edges : (M, 2) ndarray 411 | array of pairs where each 412 | pair contains two indices of a joints 413 | which corrisponds to an edge in the 414 | joint structure going from parent to child. 
415 | """ 416 | 417 | return np.array(list(zip(parents, joints(parents)))[1:]) 418 | 419 | 420 | def incidence(parents): 421 | """ 422 | Incidence Matrix 423 | 424 | Parameters 425 | ---------- 426 | 427 | parents : (J) ndarray 428 | parents array 429 | 430 | Returns 431 | ------- 432 | 433 | incidence : (N, M) ndarray 434 | 435 | Matrix of N joint positions by 436 | M edges which each entry is either 437 | 1 or -1 and multiplication by the 438 | joint positions returns the an 439 | array of vectors along each edge 440 | of the structure 441 | """ 442 | 443 | es = edges(parents) 444 | 445 | inc = np.zeros((len(parents)-1, len(parents))).astype(np.int) 446 | for i, e in enumerate(es): 447 | inc[i,e[0]] = 1 448 | inc[i,e[1]] = -1 449 | 450 | return inc.T 451 | -------------------------------------------------------------------------------- /src/holden/BVH.py: -------------------------------------------------------------------------------- 1 | from Quaternions import Quaternions 2 | from Animation import Animation 3 | import re 4 | import numpy as np 5 | import sys 6 | 7 | 8 | channelmap = { 9 | 'Xrotation': 'x', 10 | 'Yrotation': 'y', 11 | 'Zrotation': 'z' 12 | } 13 | 14 | channelmap_inv = { 15 | 'x': 'Xrotation', 16 | 'y': 'Yrotation', 17 | 'z': 'Zrotation', 18 | } 19 | 20 | ordermap = { 21 | 'x': 0, 22 | 'y': 1, 23 | 'z': 2, 24 | } 25 | 26 | 27 | def load(filename, start=None, end=None, order=None, world=False): 28 | """ 29 | Reads a BVH file and constructs an animation 30 | 31 | Parameters 32 | ---------- 33 | filename: str 34 | File to be opened 35 | 36 | start : int 37 | Optional Starting Frame 38 | 39 | end : int 40 | Optional Ending Frame 41 | 42 | order : str 43 | Optional Specifier for joint order. 44 | Given as string E.G 'xyz', 'zxy' 45 | 46 | world : bool 47 | If set to true euler angles are applied 48 | together in world space rather than local 49 | space 50 | 51 | Returns 52 | ------- 53 | 54 | (animation, joint_names, frametime) 55 | Tuple of loaded animation and joint names 56 | """ 57 | 58 | f = open(filename, "r") 59 | 60 | i = 0 61 | active = -1 62 | end_site = False 63 | 64 | names = [] 65 | orients = Quaternions.id(0) 66 | offsets = np.array([]).reshape((0, 3)) 67 | parents = np.array([], dtype=int) 68 | 69 | for line in f: 70 | 71 | if "HIERARCHY" in line: 72 | continue 73 | if "MOTION" in line: 74 | continue 75 | 76 | rmatch = re.match(r"ROOT (\w+)", line) 77 | if rmatch: 78 | names.append(rmatch.group(1)) 79 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 80 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 81 | parents = np.append(parents, active) 82 | active = (len(parents)-1) 83 | continue 84 | 85 | if "{" in line: 86 | continue 87 | 88 | if "}" in line: 89 | if end_site: 90 | end_site = False 91 | else: 92 | active = parents[active] 93 | continue 94 | 95 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 96 | if offmatch: 97 | if not end_site: 98 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 99 | continue 100 | 101 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 102 | if chanmatch: 103 | channels = int(chanmatch.group(1)) 104 | if order is None: 105 | channelis = 0 if channels == 3 else 3 106 | channelie = 3 if channels == 3 else 6 107 | parts = line.split()[2+channelis:2+channelie] 108 | if any([p not in channelmap for p in parts]): 109 | continue 110 | order = "".join([channelmap[p] for p in parts]) 111 | continue 112 | 113 | jmatch = re.match("\s*JOINT\s+(\w+)", 
line) 114 | if jmatch: 115 | names.append(jmatch.group(1)) 116 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 117 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 118 | parents = np.append(parents, active) 119 | active = (len(parents)-1) 120 | continue 121 | 122 | if "End Site" in line: 123 | end_site = True 124 | continue 125 | 126 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 127 | if fmatch: 128 | if start and end: 129 | fnum = (end - start)-1 130 | else: 131 | fnum = int(fmatch.group(1)) 132 | jnum = len(parents) 133 | # result: [fnum, J, 3] 134 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 135 | # result: [fnum, len(orients), 3] 136 | rotations = np.zeros((fnum, len(orients), 3)) 137 | continue 138 | 139 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 140 | if fmatch: 141 | frametime = float(fmatch.group(1)) 142 | continue 143 | 144 | if (start and end) and (i < start or i >= end-1): 145 | i += 1 146 | continue 147 | 148 | dmatch = line.strip().split() 149 | if dmatch: 150 | data_block = np.array(list(map(float, dmatch))) 151 | N = len(parents) 152 | fi = i - start if start else i 153 | if channels == 3: 154 | # This should be root positions[0:1] & all rotations 155 | positions[fi, 0:1] = data_block[0:3] 156 | rotations[fi, :] = data_block[3:].reshape(N, 3) 157 | elif channels == 6: 158 | data_block = data_block.reshape(N, 6) 159 | # fill in all positions 160 | positions[fi, :] = data_block[:, 0:3] 161 | rotations[fi, :] = data_block[:, 3:6] 162 | elif channels == 9: 163 | positions[fi, 0] = data_block[0:3] 164 | data_block = data_block[3:].reshape(N-1, 9) 165 | rotations[fi, 1:] = data_block[:, 3:6] 166 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 167 | else: 168 | raise Exception("Too many channels! %i" % channels) 169 | 170 | i += 1 171 | 172 | f.close() 173 | 174 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 175 | 176 | return (Animation(rotations, positions, orients, offsets, parents), names, frametime) 177 | 178 | 179 | def load_bfa(filename, start=None, end=None, order=None, world=False): 180 | """ 181 | Reads a BVH file and constructs an animation 182 | 183 | !!! Read from bfa, will replace the end sites of arms by two joints (w/ unit rotation) 184 | 185 | Parameters 186 | ---------- 187 | filename: str 188 | File to be opened 189 | 190 | start : int 191 | Optional Starting Frame 192 | 193 | end : int 194 | Optional Ending Frame 195 | 196 | order : str 197 | Optional Specifier for joint order. 
198 | Given as string E.G 'xyz', 'zxy' 199 | 200 | world : bool 201 | If set to true euler angles are applied 202 | together in world space rather than local 203 | space 204 | 205 | Returns 206 | ------- 207 | 208 | (animation, joint_names, frametime) 209 | Tuple of loaded animation and joint names 210 | """ 211 | 212 | f = open(filename, "r") 213 | 214 | i = 0 215 | active = -1 216 | end_site = False 217 | 218 | hand_idx = [9, 14] 219 | 220 | names = [] 221 | orients = Quaternions.id(0) 222 | offsets = np.array([]).reshape((0, 3)) 223 | parents = np.array([], dtype=int) 224 | 225 | for line in f: 226 | 227 | if "HIERARCHY" in line: 228 | continue 229 | if "MOTION" in line: 230 | continue 231 | 232 | rmatch = re.match(r"ROOT (\w+)", line) 233 | if rmatch: 234 | names.append(rmatch.group(1)) 235 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 236 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 237 | parents = np.append(parents, active) 238 | active = (len(parents)-1) 239 | continue 240 | 241 | if "{" in line: 242 | continue 243 | 244 | if "}" in line: 245 | if end_site: 246 | end_site = False 247 | else: 248 | active = parents[active] 249 | continue 250 | 251 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 252 | if offmatch: 253 | if not end_site: 254 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 255 | """ 256 | else: 257 | print("active = ", active) 258 | if active in hand_idx: 259 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 260 | """ 261 | continue 262 | 263 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 264 | if chanmatch: 265 | channels = int(chanmatch.group(1)) 266 | if order is None: 267 | channelis = 0 if channels == 3 else 3 268 | channelie = 3 if channels == 3 else 6 269 | parts = line.split()[2+channelis:2+channelie] 270 | if any([p not in channelmap for p in parts]): 271 | continue 272 | order = "".join([channelmap[p] for p in parts]) 273 | continue 274 | 275 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 276 | if jmatch: 277 | names.append(jmatch.group(1)) 278 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 279 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 280 | parents = np.append(parents, active) 281 | active = (len(parents)-1) 282 | continue 283 | 284 | if "End Site" in line: 285 | if active + 1 in hand_idx: 286 | print("parent:", names[-1]) 287 | name = "LeftHandIndex" if active + 1 == hand_idx[0] else "RightHandIndex" 288 | names.append(name) 289 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 290 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 291 | parents = np.append(parents, active) 292 | active = (len(parents)-1) 293 | else: 294 | end_site = True 295 | continue 296 | 297 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 298 | if fmatch: 299 | if start and end: 300 | fnum = (end - start)-1 301 | else: 302 | fnum = int(fmatch.group(1)) 303 | jnum = len(parents) 304 | # result: [fnum, J, 3] 305 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 306 | # result: [fnum, len(orients), 3] 307 | rotations = np.zeros((fnum, len(orients), 3)) 308 | continue 309 | 310 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 311 | if fmatch: 312 | frametime = float(fmatch.group(1)) 313 | continue 314 | 315 | if (start and end) and (i < start or i >= end-1): 316 | i += 1 317 | continue 318 | 319 | dmatch = line.strip().split() 320 | if dmatch: 321 | data_block = 
np.array(list(map(float, dmatch))) 322 | N = len(parents) 323 | fi = i - start if start else i 324 | if channels == 3: 325 | # This should be root positions[0:1] & all rotations 326 | positions[fi, 0:1] = data_block[0:3] 327 | tmp = data_block[3:].reshape(N - 2, 3) 328 | tmp = np.concatenate([tmp[:hand_idx[0]], 329 | np.array([[0, 0, 0]]), 330 | tmp[hand_idx[0]: hand_idx[1] - 1], 331 | np.array([[0, 0, 0]]), 332 | tmp[hand_idx[1] - 1:]], axis=0) 333 | rotations[fi, :] = tmp.reshape(N, 3) 334 | elif channels == 6: 335 | data_block = data_block.reshape(N, 6) 336 | # fill in all positions 337 | positions[fi, :] = data_block[:, 0:3] 338 | rotations[fi, :] = data_block[:, 3:6] 339 | elif channels == 9: 340 | positions[fi, 0] = data_block[0:3] 341 | data_block = data_block[3:].reshape(N-1, 9) 342 | rotations[fi, 1:] = data_block[:, 3:6] 343 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 344 | else: 345 | raise Exception("Too many channels! %i" % channels) 346 | 347 | i += 1 348 | 349 | f.close() 350 | 351 | rotations = Quaternions.from_euler(np.radians(rotations), order=order, world=world) 352 | 353 | return (Animation(rotations, positions, orients, offsets, parents), names, frametime) 354 | 355 | 356 | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True): 357 | """ 358 | Saves an Animation to file as BVH 359 | 360 | Parameters 361 | ---------- 362 | filename: str 363 | File to be saved to 364 | 365 | anim : Animation 366 | Animation to save 367 | 368 | names : [str] 369 | List of joint names 370 | 371 | order : str 372 | Optional Specifier for joint order. 373 | Given as string E.G 'xyz', 'zxy' 374 | 375 | frametime : float 376 | Optional Animation Frame time 377 | 378 | positions : bool 379 | Optional specfier to save bone 380 | positions for each frame 381 | 382 | orients : bool 383 | Multiply joint orients to the rotations 384 | before saving. 
385 | 386 | """ 387 | 388 | if names is None: 389 | names = ["joint_" + str(i) for i in range(len(anim.parents))] 390 | 391 | with open(filename, 'w') as f: 392 | 393 | t = "" 394 | f.write("%sHIERARCHY\n" % t) 395 | f.write("%sROOT %s\n" % (t, names[0])) 396 | f.write("%s{\n" % t) 397 | t += '\t' 398 | 399 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0, 0], anim.offsets[0, 1], anim.offsets[0, 2])) 400 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 401 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 402 | 403 | for i in range(anim.shape[1]): 404 | if anim.parents[i] == 0: 405 | t = save_joint(f, anim, names, t, i, order=order, positions=positions) 406 | 407 | t = t[:-1] 408 | f.write("%s}\n" % t) 409 | 410 | f.write("MOTION\n") 411 | f.write("Frames: %i\n" % anim.shape[0]) 412 | f.write("Frame Time: %f\n" % frametime) 413 | 414 | # if orients: 415 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1])) 416 | # else: 417 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 418 | rots = np.degrees(anim.rotations.euler(order=order[::-1])) 419 | poss = anim.positions 420 | 421 | for i in range(anim.shape[0]): 422 | for j in range(anim.shape[1]): 423 | 424 | if positions or j == 0: 425 | 426 | f.write("%f %f %f %f %f %f " % ( 427 | poss[i, j, 0], poss[i, j, 1], poss[i, j, 2], 428 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], rots[i, j, ordermap[order[2]]])) 429 | 430 | else: 431 | 432 | f.write("%f %f %f " % ( 433 | rots[i, j, ordermap[order[0]]], rots[i, j, ordermap[order[1]]], rots[i, j, ordermap[order[2]]])) 434 | 435 | f.write("\n") 436 | 437 | 438 | def save_joint(f, anim, names, t, i, order='zyx', positions=False): 439 | 440 | f.write("%sJOINT %s\n" % (t, names[i])) 441 | f.write("%s{\n" % t) 442 | t += '\t' 443 | 444 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i, 0], anim.offsets[i, 1], anim.offsets[i, 2])) 445 | 446 | if positions: 447 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 448 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 449 | else: 450 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 451 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 452 | 453 | end_site = True 454 | 455 | for j in range(anim.shape[1]): 456 | if anim.parents[j] == i: 457 | t = save_joint(f, anim, names, t, j, order=order, positions=positions) 458 | end_site = False 459 | 460 | if end_site: 461 | f.write("%sEnd Site\n" % t) 462 | f.write("%s{\n" % t) 463 | t += '\t' 464 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 465 | t = t[:-1] 466 | f.write("%s}\n" % t) 467 | 468 | t = t[:-1] 469 | f.write("%s}\n" % t) 470 | 471 | return t 472 | -------------------------------------------------------------------------------- /src/holden/Pivots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from Quaternions import Quaternions 4 | 5 | 6 | class Pivots: 7 | """ 8 | Pivots is an ndarray of angular rotations 9 | 10 | This wrapper provides some functions for 11 | working with pivots. 
12 | 13 | These are particularly useful as a number 14 | of atomic operations (such as adding or 15 | subtracting) cannot be achieved using 16 | the standard arithmatic and need to be 17 | defined differently to work correctly 18 | """ 19 | 20 | def __init__(self, ps): self.ps = np.array(ps) 21 | def __str__(self): return "Pivots(" + str(self.ps) + ")" 22 | def __repr__(self): return "Pivots(" + repr(self.ps) + ")" 23 | 24 | def __add__(self, other): return Pivots(np.arctan2(np.sin(self.ps + other.ps), np.cos(self.ps + other.ps))) 25 | def __sub__(self, other): return Pivots(np.arctan2(np.sin(self.ps - other.ps), np.cos(self.ps - other.ps))) 26 | def __mul__(self, other): return Pivots(self.ps * other.ps) 27 | def __div__(self, other): return Pivots(self.ps / other.ps) 28 | def __mod__(self, other): return Pivots(self.ps % other.ps) 29 | def __pow__(self, other): return Pivots(self.ps ** other.ps) 30 | 31 | def __lt__(self, other): return self.ps < other.ps 32 | def __le__(self, other): return self.ps <= other.ps 33 | def __eq__(self, other): return self.ps == other.ps 34 | def __ne__(self, other): return self.ps != other.ps 35 | def __ge__(self, other): return self.ps >= other.ps 36 | def __gt__(self, other): return self.ps > other.ps 37 | 38 | def __abs__(self): return Pivots(abs(self.ps)) 39 | def __neg__(self): return Pivots(-self.ps) 40 | 41 | def __iter__(self): return iter(self.ps) 42 | def __len__(self): return len(self.ps) 43 | 44 | def __getitem__(self, k): return Pivots(self.ps[k]) 45 | def __setitem__(self, k, v): self.ps[k] = v.ps 46 | 47 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape)) 48 | 49 | def quaternions(self, plane='xz'): 50 | fa = self._ellipsis() 51 | axises = np.ones(self.ps.shape + (3,)) 52 | axises[fa + ("xyz".index(plane[0]),)] = 0.0 53 | axises[fa + ("xyz".index(plane[1]),)] = 0.0 54 | return Quaternions.from_angle_axis(self.ps, axises) 55 | 56 | def directions(self, plane='xz'): 57 | dirs = np.zeros((len(self.ps), 3)) 58 | dirs[..., "xyz".index(plane[0])] = np.sin(self.ps) 59 | dirs[..., "xyz".index(plane[1])] = np.cos(self.ps) 60 | return dirs 61 | 62 | def normalized(self): 63 | xs = np.copy(self.ps) 64 | while np.any(xs > np.pi): 65 | xs[xs > np.pi] = xs[xs > np.pi] - 2 * np.pi 66 | while np.any(xs < -np.pi): 67 | xs[xs < -np.pi] = xs[xs < -np.pi] + 2 * np.pi 68 | return Pivots(xs) 69 | 70 | def interpolate(self, ws): 71 | dir = np.average(self.directions, weights=ws, axis=0) 72 | return np.arctan2(dir[2], dir[0]) 73 | 74 | def copy(self): 75 | return Pivots(np.copy(self.ps)) 76 | 77 | @property 78 | def shape(self): 79 | return self.ps.shape 80 | 81 | @classmethod 82 | def from_quaternions(cls, qs, forward='z', plane='xz'): 83 | ds = np.zeros(qs.shape + (3,)) 84 | ds[..., 'xyz'.index(forward)] = 1.0 85 | return Pivots.from_directions(qs * ds, plane=plane) 86 | 87 | @classmethod 88 | def from_directions(cls, ds, plane='xz'): 89 | ys = ds[..., 'xyz'.index(plane[0])] 90 | xs = ds[..., 'xyz'.index(plane[1])] 91 | return Pivots(np.arctan2(ys, xs)) 92 | -------------------------------------------------------------------------------- /src/holden/Quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Quaternions: 5 | """ 6 | Quaternions is a wrapper around a numpy ndarray 7 | that allows it to act as if it were an narray of 8 | a quaternion data type. 
9 | 10 | Therefore addition, subtraction, multiplication, 11 | division, negation, absolute, are all defined 12 | in terms of quaternion operations such as quaternion 13 | multiplication. 14 | 15 | This allows for much neater code and many routines 16 | which conceptually do the same thing to be written 17 | in the same way for point data and for rotation data. 18 | 19 | The Quaternions class has been desgined such that it 20 | should support broadcasting and slicing in all of the 21 | usual ways. 22 | """ 23 | 24 | def __init__(self, qs): 25 | if isinstance(qs, np.ndarray): 26 | 27 | if len(qs.shape) == 1: 28 | qs = np.array([qs]) 29 | self.qs = qs 30 | return 31 | 32 | if isinstance(qs, Quaternions): 33 | self.qs = qs.qs 34 | return 35 | 36 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) 37 | 38 | def __str__(self): return "Quaternions(" + str(self.qs) + ")" 39 | def __repr__(self): return "Quaternions(" + repr(self.qs) + ")" 40 | """ Helper Methods for Broadcasting and Data extraction """ 41 | 42 | @classmethod 43 | def _broadcast(cls, sqs, oqs, scalar=False): 44 | if isinstance(oqs, float): 45 | return sqs, oqs * np.ones(sqs.shape[:-1]) 46 | 47 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) 48 | os = np.array(oqs.shape) 49 | 50 | if len(ss) != len(os): 51 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 52 | 53 | if np.all(ss == os): 54 | return sqs, oqs 55 | 56 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): 57 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 58 | 59 | sqsn, oqsn = sqs.copy(), oqs.copy() 60 | 61 | for a in np.where(ss == 1)[0]: 62 | sqsn = sqsn.repeat(os[a], axis=a) 63 | for a in np.where(os == 1)[0]: 64 | oqsn = oqsn.repeat(ss[a], axis=a) 65 | 66 | return sqsn, oqsn 67 | 68 | """ Adding Quaterions is just Defined as Multiplication """ 69 | 70 | def __add__(self, other): return self * other 71 | def __sub__(self, other): return self / other 72 | """ Quaterion Multiplication """ 73 | 74 | def __mul__(self, other): 75 | """ 76 | Quaternion multiplication has three main methods. 77 | 78 | When multiplying a Quaternions array by Quaternions 79 | normal quaternion multiplication is performed. 80 | 81 | When multiplying a Quaternions array by a vector 82 | array of the same shape, where the last axis is 3, 83 | it is assumed to be a Quaternion by 3D-Vector 84 | multiplication and the 3D-Vectors are rotated 85 | in space by the Quaternions. 86 | 87 | When multipplying a Quaternions array by a scalar 88 | or vector of different shape it is assumed to be 89 | a Quaternions by Scalars multiplication and the 90 | Quaternions are scaled using Slerp and the identity 91 | quaternions. 
92 | """ 93 | 94 | """ If Quaternions type do Quaternions * Quaternions """ 95 | if isinstance(other, Quaternions): 96 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs) 97 | 98 | q0 = sqs[..., 0] 99 | q1 = sqs[..., 1] 100 | q2 = sqs[..., 2] 101 | q3 = sqs[..., 3] 102 | r0 = oqs[..., 0] 103 | r1 = oqs[..., 1] 104 | r2 = oqs[..., 2] 105 | r3 = oqs[..., 3] 106 | 107 | qs = np.empty(sqs.shape) 108 | qs[..., 0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 109 | qs[..., 1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 110 | qs[..., 2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 111 | qs[..., 3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 112 | 113 | return Quaternions(qs) 114 | 115 | """ If array type do Quaternions * Vectors """ 116 | if isinstance(other, np.ndarray) and other.shape[-1] == 3: 117 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) 118 | 119 | return (self * (vs * -self)).imaginaries 120 | 121 | """ If float do Quaternions * Scalars """ 122 | if isinstance(other, np.ndarray) or isinstance(other, float): 123 | return Quaternions.slerp(Quaternions.id_like(self), self, other) 124 | 125 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) 126 | 127 | def __div__(self, other): 128 | """ 129 | When a Quaternion type is supplied, division is defined 130 | as multiplication by the inverse of that Quaternion. 131 | 132 | When a scalar or vector is supplied it is defined 133 | as multiplicaion of one over the supplied value. 134 | Essentially a scaling. 135 | """ 136 | 137 | if isinstance(other, Quaternions): 138 | return self * (-other) 139 | if isinstance(other, np.ndarray): 140 | return self * (1.0 / other) 141 | if isinstance(other, float): 142 | return self * (1.0 / other) 143 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) 144 | 145 | def __eq__(self, other): return self.qs == other.qs 146 | def __ne__(self, other): return self.qs != other.qs 147 | 148 | def __neg__(self): 149 | """ Invert Quaternions """ 150 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) 151 | 152 | def __abs__(self): 153 | """ Unify Quaternions To Single Pole """ 154 | qabs = self.normalized().copy() 155 | top = np.sum((qabs.qs) * np.array([1, 0, 0, 0]), axis=-1) 156 | bot = np.sum((-qabs.qs) * np.array([1, 0, 0, 0]), axis=-1) 157 | qabs.qs[top < bot] = -qabs.qs[top < bot] 158 | return qabs 159 | 160 | def __iter__(self): return iter(self.qs) 161 | def __len__(self): return len(self.qs) 162 | 163 | def __getitem__(self, k): return Quaternions(self.qs[k]) 164 | def __setitem__(self, k, v): self.qs[k] = v.qs 165 | 166 | @property 167 | def lengths(self): 168 | return np.sum(self.qs**2.0, axis=-1)**0.5 169 | 170 | @property 171 | def reals(self): 172 | return self.qs[..., 0] 173 | 174 | @property 175 | def imaginaries(self): 176 | return self.qs[..., 1:4] 177 | 178 | @property 179 | def shape(self): return self.qs.shape[:-1] 180 | 181 | def repeat(self, n, **kwargs): 182 | return Quaternions(self.qs.repeat(n, **kwargs)) 183 | 184 | def normalized(self): 185 | return Quaternions(self.qs / self.lengths[..., np.newaxis]) 186 | 187 | def log(self): 188 | norm = abs(self.normalized()) 189 | imgs = norm.imaginaries 190 | lens = np.sqrt(np.sum(imgs**2, axis=-1)) 191 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) 192 | return imgs * lens[..., np.newaxis] 193 | 194 | def constrained(self, axis): 195 | 196 | rl = self.reals 197 | im = np.sum(axis * self.imaginaries, axis=-1) 198 | 199 | t1 = -2 * np.arctan2(rl, 
im) + np.pi 200 | t2 = -2 * np.arctan2(rl, im) - np.pi 201 | 202 | top = Quaternions.exp(axis[np.newaxis] * (t1[:, np.newaxis] / 2.0)) 203 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:, np.newaxis] / 2.0)) 204 | img = self.dot(top) > self.dot(bot) 205 | 206 | ret = top.copy() 207 | ret[img] = top[img] 208 | ret[~img] = bot[~img] 209 | return ret 210 | 211 | def constrained_x(self): return self.constrained(np.array([1, 0, 0])) 212 | def constrained_y(self): return self.constrained(np.array([0, 1, 0])) 213 | def constrained_z(self): return self.constrained(np.array([0, 0, 1])) 214 | 215 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1) 216 | 217 | def copy(self): return Quaternions(np.copy(self.qs)) 218 | 219 | def reshape(self, s): 220 | self.qs.reshape(s) 221 | return self 222 | 223 | def interpolate(self, ws): 224 | return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) 225 | 226 | def euler(self, order='xyz'): 227 | 228 | q = self.normalized().qs 229 | q0 = q[..., 0] 230 | q1 = q[..., 1] 231 | q2 = q[..., 2] 232 | q3 = q[..., 3] 233 | es = np.zeros(self.shape + (3,)) 234 | 235 | if order == 'xyz': 236 | es[..., 0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 237 | es[..., 1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1, 1)) 238 | es[..., 2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 239 | elif order == 'yzx': 240 | es[..., 0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) 241 | es[..., 1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) 242 | es[..., 2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1, 1)) 243 | else: 244 | raise NotImplementedError('Cannot convert from ordering %s' % order) 245 | 246 | """ 247 | 248 | # These conversion don't appear to work correctly for Maya. 
249 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ 250 | 251 | if order == 'xyz': 252 | es[fa + (0,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 253 | es[fa + (1,)] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) 254 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 255 | elif order == 'yzx': 256 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 257 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) 258 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 259 | elif order == 'zxy': 260 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 261 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) 262 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 263 | elif order == 'xzy': 264 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 265 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) 266 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 267 | elif order == 'yxz': 268 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 269 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) 270 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 271 | elif order == 'zyx': 272 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 273 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) 274 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 275 | else: 276 | raise KeyError('Unknown ordering %s' % order) 277 | 278 | """ 279 | 280 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp 281 | # Use this class and convert from matrix 282 | 283 | return es 284 | 285 | def average(self): 286 | 287 | if len(self.shape) == 1: 288 | 289 | import numpy.core.umath_tests as ut 290 | system = ut.matrix_multiply(self.qs[:, :, np.newaxis], self.qs[:, np.newaxis, :]).sum(axis=0) 291 | w, v = np.linalg.eigh(system) 292 | qiT_dot_qref = (self.qs[:, :, np.newaxis] * v[np.newaxis, :, :]).sum(axis=1) 293 | return Quaternions(v[:, np.argmin((1.-qiT_dot_qref**2).sum(axis=0))]) 294 | 295 | else: 296 | 297 | raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') 298 | 299 | def angle_axis(self): 300 | 301 | norm = self.normalized() 302 | s = np.sqrt(1 - (norm.reals**2.0)) 303 | s[s == 0] = 0.001 304 | 305 | angles = 2.0 * np.arccos(norm.reals) 306 | axis = norm.imaginaries / s[..., np.newaxis] 307 | 308 | return angles, axis 309 | 310 | def transforms(self): 311 | 312 | qw = self.qs[..., 0] 313 | qx = self.qs[..., 1] 314 | qy = self.qs[..., 2] 315 | qz = self.qs[..., 3] 316 | 317 | x2 = qx + qx 318 | y2 = qy + qy 319 | z2 = qz + qz 320 | xx = qx * x2 321 | yy = qy * y2 322 | wx = qw * x2 323 | xy = qx * y2 324 | yz = qy * z2 325 | wy = qw * y2 326 | xz = qx * z2 327 | zz = qz * z2 328 | wz = qw * z2 329 | 330 | m = np.empty(self.shape + (3, 3)) 331 | m[..., 0, 0] = 1.0 - (yy + zz) 332 | m[..., 0, 1] = xy - wz 333 | m[..., 0, 2] = xz + wy 334 | m[..., 1, 0] = xy + wz 335 | m[..., 1, 1] = 1.0 - (xx + zz) 336 | m[..., 1, 2] = yz - wx 337 | m[..., 2, 0] = xz - wy 338 | m[..., 2, 1] = yz + wx 339 | m[..., 2, 2] = 1.0 - (xx + 
yy) 340 | 341 | return m 342 | 343 | def ravel(self): 344 | return self.qs.ravel() 345 | 346 | @classmethod 347 | def id(cls, n): 348 | 349 | if isinstance(n, tuple): 350 | qs = np.zeros(n + (4,)) 351 | qs[..., 0] = 1.0 352 | return Quaternions(qs) 353 | 354 | if isinstance(n, int): 355 | qs = np.zeros((n, 4)) 356 | qs[:, 0] = 1.0 357 | return Quaternions(qs) 358 | 359 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) 360 | 361 | @classmethod 362 | def id_like(cls, a): 363 | qs = np.zeros(a.shape + (4,)) 364 | qs[..., 0] = 1.0 365 | return Quaternions(qs) 366 | 367 | @classmethod 368 | def exp(cls, ws): 369 | 370 | ts = np.sum(ws**2.0, axis=-1)**0.5 371 | ts[ts == 0] = 0.001 372 | ls = np.sin(ts) / ts 373 | 374 | qs = np.empty(ws.shape[:-1] + (4,)) 375 | qs[..., 0] = np.cos(ts) 376 | qs[..., 1] = ws[..., 0] * ls 377 | qs[..., 2] = ws[..., 1] * ls 378 | qs[..., 3] = ws[..., 2] * ls 379 | 380 | return Quaternions(qs).normalized() 381 | 382 | @classmethod 383 | def slerp(cls, q0s, q1s, a): 384 | 385 | fst, snd = cls._broadcast(q0s.qs, q1s.qs) 386 | fst, a = cls._broadcast(fst, a, scalar=True) 387 | snd, a = cls._broadcast(snd, a, scalar=True) 388 | 389 | len = np.sum(fst * snd, axis=-1) 390 | 391 | neg = len < 0.0 392 | len[neg] = -len[neg] 393 | snd[neg] = -snd[neg] 394 | 395 | amount0 = np.zeros(a.shape) 396 | amount1 = np.zeros(a.shape) 397 | 398 | linear = (1.0 - len) < 0.01 399 | omegas = np.arccos(len[~linear]) 400 | sinoms = np.sin(omegas) 401 | 402 | amount0[linear] = 1.0 - a[linear] 403 | amount1[linear] = a[linear] 404 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms 405 | amount1[~linear] = np.sin(a[~linear] * omegas) / sinoms 406 | 407 | return Quaternions( 408 | amount0[..., np.newaxis] * fst + 409 | amount1[..., np.newaxis] * snd) 410 | 411 | @classmethod 412 | def between(cls, v0s, v1s): 413 | a = np.cross(v0s, v1s) 414 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) 415 | return Quaternions(np.concatenate([w[..., np.newaxis], a], axis=-1)).normalized() 416 | 417 | @classmethod 418 | def from_angle_axis(cls, angles, axis): 419 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[..., np.newaxis] 420 | sines = np.sin(angles / 2.0)[..., np.newaxis] 421 | cosines = np.cos(angles / 2.0)[..., np.newaxis] 422 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) 423 | 424 | @classmethod 425 | def from_euler(cls, es, order='xyz', world=False): 426 | 427 | axis = { 428 | 'x': np.array([1, 0, 0]), 429 | 'y': np.array([0, 1, 0]), 430 | 'z': np.array([0, 0, 1]), 431 | } 432 | 433 | q0s = Quaternions.from_angle_axis(es[..., 0], axis[order[0]]) 434 | q1s = Quaternions.from_angle_axis(es[..., 1], axis[order[1]]) 435 | q2s = Quaternions.from_angle_axis(es[..., 2], axis[order[2]]) 436 | 437 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) 438 | 439 | @classmethod 440 | def from_transforms(cls, ts): 441 | 442 | d0, d1, d2 = ts[..., 0, 0], ts[..., 1, 1], ts[..., 2, 2] 443 | 444 | q0 = (d0 + d1 + d2 + 1.0) / 4.0 445 | q1 = (d0 - d1 - d2 + 1.0) / 4.0 446 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0 447 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0 448 | 449 | q0 = np.sqrt(q0.clip(0, None)) 450 | q1 = np.sqrt(q1.clip(0, None)) 451 | q2 = np.sqrt(q2.clip(0, None)) 452 | q3 = np.sqrt(q3.clip(0, None)) 453 | 454 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) 455 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) 456 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) 457 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) 458 | 
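# Editorial note on the block that follows: c0..c3 mark, per input matrix, which quaternion component has the largest magnitude; the sign fix-ups below recover the signs of the remaining components from the (anti-)symmetric off-diagonal matrix entries, which keeps the matrix-to-quaternion conversion numerically stable.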
459 | q1[c0] *= np.sign(ts[c0, 2, 1] - ts[c0, 1, 2]) 460 | q2[c0] *= np.sign(ts[c0, 0, 2] - ts[c0, 2, 0]) 461 | q3[c0] *= np.sign(ts[c0, 1, 0] - ts[c0, 0, 1]) 462 | 463 | q0[c1] *= np.sign(ts[c1, 2, 1] - ts[c1, 1, 2]) 464 | q2[c1] *= np.sign(ts[c1, 1, 0] + ts[c1, 0, 1]) 465 | q3[c1] *= np.sign(ts[c1, 0, 2] + ts[c1, 2, 0]) 466 | 467 | q0[c2] *= np.sign(ts[c2, 0, 2] - ts[c2, 2, 0]) 468 | q1[c2] *= np.sign(ts[c2, 1, 0] + ts[c2, 0, 1]) 469 | q3[c2] *= np.sign(ts[c2, 2, 1] + ts[c2, 1, 2]) 470 | 471 | q0[c3] *= np.sign(ts[c3, 1, 0] - ts[c3, 0, 1]) 472 | q1[c3] *= np.sign(ts[c3, 2, 0] + ts[c3, 0, 2]) 473 | q2[c3] *= np.sign(ts[c3, 2, 1] + ts[c3, 1, 2]) 474 | 475 | qs = np.empty(ts.shape[:-2] + (4,)) 476 | qs[..., 0] = q0 477 | qs[..., 1] = q1 478 | qs[..., 2] = q2 479 | qs[..., 3] = q3 480 | 481 | return cls(qs) 482 | -------------------------------------------------------------------------------- /src/holden/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | BASEPATH = os.path.dirname(__file__) 4 | sys.path.insert(0, BASEPATH) 5 | -------------------------------------------------------------------------------- /src/nemf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-he/NeMF/9ca4599f7c8f72b39e2dcb3e36114f840cab3d5b/src/nemf/__init__.py -------------------------------------------------------------------------------- /src/nemf/base_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | import torch 6 | import torch.nn.init 7 | import torch.optim 8 | 9 | 10 | def weights_init(init_type='default'): 11 | def init_fun(m): 12 | classname = m.__class__.__name__ 13 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 14 | if init_type == 'gaussian': 15 | torch.nn.init.normal_(m.weight.data) 16 | elif init_type == 'xavier': 17 | torch.nn.init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 18 | elif init_type == 'kaiming': 19 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 20 | elif init_type == 'orthogonal': 21 | torch.nn.init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 22 | elif init_type == 'default': 23 | torch.nn.init.normal_(m.weight.data) 24 | weight_norm = m.weight.pow(2).sum(2, keepdim=True).sum(1, keepdim=True).add(1e-8).sqrt() 25 | m.weight.data.div_(weight_norm) 26 | else: 27 | raise NotImplementedError(f"Unsupported initialization: {init_type}") 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | torch.nn.init.constant_(m.bias.data, 0.0) 30 | return init_fun 31 | 32 | 33 | class BaseModel(ABC): 34 | """This class is an abstract base class (ABC) for models. 35 | To create a subclass, you need to implement the following five functions: 36 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 37 | -- <set_input>: unpack data from dataset and apply preprocessing. 38 | -- <forward>: produce intermediate results. 39 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
40 | """ 41 | 42 | def __init__(self, args): 43 | self.args = args 44 | self.is_train = args.is_train 45 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | self.model_save_dir = os.path.join(args.save_dir, 'checkpoints') # save all the checkpoints to save_dir 47 | 48 | if self.is_train and args.log: 49 | import time 50 | from torch.utils.tensorboard import SummaryWriter 51 | from .loss_record import LossRecorder 52 | 53 | log_dir = os.path.join(args.save_dir, 'logs') 54 | if args.epoch_begin != 0: 55 | all = [d for d in os.listdir(log_dir)] 56 | if len(all) == 0: 57 | raise RuntimeError(f'Empty logging path {log_dir}') 58 | timestamp = sorted(all)[-1] 59 | else: 60 | timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime(time.time())) 61 | self.log_path = os.path.join(log_dir, timestamp) 62 | os.makedirs(self.log_path, exist_ok=True) 63 | self.writer = SummaryWriter(self.log_path) 64 | self.loss_recorder = LossRecorder(self.writer) 65 | 66 | self.epoch_cnt = 0 67 | self.schedulers = [] 68 | self.optimizers = [] 69 | self.models = [] 70 | 71 | @abstractmethod 72 | def set_input(self, input): 73 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 74 | Parameters: 75 | input (dict): includes the data itself and its metadata information. 76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def compute_test_result(self): 81 | """ 82 | After forward, do something like output bvh, get error value 83 | """ 84 | pass 85 | 86 | @abstractmethod 87 | def forward(self): 88 | """Run forward pass; called by both functions <optimize_parameters> and <test>.""" 89 | pass 90 | 91 | @abstractmethod 92 | def optimize_parameters(self): 93 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 94 | pass 95 | 96 | def get_scheduler(self, optimizer): 97 | if self.args.scheduler.name == 'linear': 98 | def lambda_rule(epoch): 99 | lr_l = 1.0 - max(0, epoch - self.args.n_epochs_origin) / float(self.args.n_epochs_decay + 1) 100 | return lr_l 101 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 102 | if self.args.scheduler.name == 'StepLR': 103 | print('StepLR scheduler set') 104 | return torch.optim.lr_scheduler.StepLR(optimizer, self.args.scheduler.step_size, self.args.scheduler.gamma, verbose=self.args.verbose) 105 | if self.args.scheduler.name == 'Plateau': 106 | print('ReduceLROnPlateau scheduler set') 107 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self.args.scheduler.factor, threshold=self.args.scheduler.threshold, patience=5, verbose=True) 108 | if self.args.scheduler.name == 'MultiStep': 109 | print('MultiStep scheduler set') 110 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.args.scheduler.milestones, gamma=self.args.scheduler.gamma, verbose=self.args.verbose) 111 | else: 112 | print('No scheduler set') 113 | return None 114 | 115 | def setup(self): 116 | """Load and print networks; create schedulers 117 | Parameters: 118 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 119 | """ 120 | if self.is_train: 121 | self.schedulers = [self.get_scheduler(optimizer) for optimizer in self.optimizers] 122 | 123 | def epoch(self): 124 | if self.args.verbose: 125 | self.loss_recorder.epoch() 126 | for scheduler in self.schedulers: 127 | if scheduler is not None: 128 | if self.args.scheduler.name == 'Plateau': 129 | loss = self.loss_recorder.losses['total_loss'].current() 130 | scheduler.step(loss[1])
131 | else: 132 | scheduler.step() 133 | self.epoch_cnt += 1 134 | 135 | def test(self): 136 | """Forward function used in test time. 137 | This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop 138 | It also calls <compute_test_result> to produce additional visualization results 139 | """ 140 | with torch.no_grad(): 141 | self.forward() 142 | self.compute_test_result() 143 | 144 | def train(self): 145 | for model in self.models: 146 | model.train() 147 | for param in model.parameters(): 148 | param.requires_grad = True 149 | 150 | def eval(self): 151 | for model in self.models: 152 | model.eval() 153 | for param in model.parameters(): 154 | param.requires_grad = False 155 | 156 | def print(self): 157 | for model in self.models: 158 | print(model) 159 | 160 | def count_params(self): 161 | params = 0 162 | for model in self.models: 163 | params += sum(p.numel() for p in model.parameters() if p.requires_grad) 164 | 165 | return params 166 | -------------------------------------------------------------------------------- /src/nemf/basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import holden.BVH as BVH 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from holden.Animation import Animation 8 | from holden.Quaternions import Quaternions 9 | from human_body_prior.tools.omni_tools import copy2cpu as c2c 10 | from rotations import matrix_to_quaternion, quaternion_to_axis_angle, rotation_6d_to_matrix 11 | from utils import align_joints, compute_trajectory 12 | 13 | from .base_model import BaseModel 14 | from .fk import ForwardKinematicsLayer 15 | from .losses import GeodesicLoss 16 | from .neural_motion import NeuralMotionField 17 | 18 | 19 | class Architecture(BaseModel): 20 | def __init__(self, args, ngpu): 21 | super(Architecture, self).__init__(args) 22 | 23 | self.args = args 24 | 25 | if args.amass_data: 26 | self.fk = ForwardKinematicsLayer(args) 27 | else: 28 | self.fk = ForwardKinematicsLayer(parents=args.animal.parents, positions=args.animal.offsets) 29 | 30 | self.field = NeuralMotionField(args.nemf).to(self.device) 31 | if args.multi_gpu is True: 32 | self.field = nn.DataParallel(self.field, device_ids=range(ngpu)) 33 | 34 | self.models = [self.field] 35 | 36 | ordermap = { 37 | 'x': 0, 38 | 'y': 1, 39 | 'z': 2, 40 | } 41 | self.up_axis = ordermap[self.args.data.up] 42 | self.v_axis = [x for x in ordermap.values() if x != self.up_axis] 43 | 44 | print(self.up_axis) 45 | print(self.v_axis) 46 | 47 | self.input_data = dict() 48 | self.recon_data = dict() 49 | 50 | if self.is_train: 51 | self.optimizer = torch.optim.Adam(self.field.parameters(), args.learning_rate) 52 | self.optimizers = [self.optimizer] 53 | self.criterion_rec = nn.L1Loss() if args.l1_loss else nn.MSELoss() 54 | 55 | self.criterion_geo = GeodesicLoss().to(self.device) 56 | self.fps = args.data.fps 57 | if args.bvh_viz: 58 | self.viz_dir = os.path.join(args.save_dir, 'results', 'bvh') 59 | else: 60 | self.viz_dir = os.path.join(args.save_dir, 'results', 'smpl') 61 | 62 | def set_input(self, input): 63 | self.input_data = {k: v.float().to(self.device) for k, v in input.items() if k in ['pos', 'global_xform', 'root_orient', 'root_vel', 'trans']} 64 | self.input_data['rotmat'] = rotation_6d_to_matrix(self.input_data['global_xform']) 65 | 66 | def forward(self, step=1): 67 | b_size, t_length, n_joints = self.input_data['pos'].shape[:3] 68 | 69 | t = torch.arange(start=0, end=t_length, step=step).unsqueeze(0) # (1, T) 70 | t = t
/ t_length * 2 - 1 # map t to [-1, 1] 71 | t = t.expand(b_size, -1).unsqueeze(-1).to(self.device) # (B, T, 1) 72 | f, _ = self.field(t) # (B, T, D), rot6d + root orient + root vel + root height 73 | 74 | self.recon_data = dict() 75 | rot6d_recon = f[:, :, :n_joints * 6] # (B, T, J x 6) 76 | rot6d_recon = rot6d_recon.view(b_size, -1, n_joints, 6) # (B, T, J, 6) 77 | rotmat_recon = rotation_6d_to_matrix(rot6d_recon) # (B, T, J, 3, 3) 78 | if self.args.data.root_transform: 79 | identity = torch.zeros(b_size, t.shape[1], 3, 3).to(self.device) 80 | identity[:, :, 0, 0] = 1 81 | identity[:, :, 1, 1] = 1 82 | identity[:, :, 2, 2] = 1 83 | rotmat_recon[:, :, 0] = identity 84 | rotmat_recon[:, :, -2] = rotmat_recon[:, :, -4] 85 | rotmat_recon[:, :, -1] = rotmat_recon[:, :, -3] 86 | local_rotmat = self.fk.global_to_local(rotmat_recon.view(-1, n_joints, 3, 3)) # (B x T, J, 3. 3) 87 | pos_recon, _ = self.fk(local_rotmat) # (B x T, J, 3) 88 | pos_recon = pos_recon.contiguous().view(b_size, -1, n_joints, 3) # (B, T, J, 3) 89 | self.recon_data['rotmat'] = rotmat_recon 90 | self.recon_data['pos'] = pos_recon 91 | 92 | if self.args.data.root_transform: 93 | root_orient = f[:, :, n_joints * 6:n_joints * 6 + 6] # (B, T, 6) 94 | self.recon_data['root_orient'] = root_orient 95 | else: 96 | self.recon_data['root_orient'] = rot6d_recon[:, :, 0] # (B, T, 6) 97 | 98 | root_vel = f[:, :, -4:-1] # (B, T, 3) 99 | root_height = f[:, :, -1] # (B, T) 100 | 101 | self.recon_data['root_vel'] = root_vel 102 | self.recon_data['root_height'] = root_height 103 | 104 | dt = 1.0 / self.args.data.fps * step 105 | origin = torch.zeros(b_size, 3).to(self.device) 106 | trans = compute_trajectory(root_vel, root_height, origin, dt, up_axis=self.args.data.up) 107 | self.recon_data['trans'] = trans 108 | 109 | def backward(self, validation=False): 110 | loss = 0 111 | 112 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient']) 113 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient']) 114 | if self.args.geodesic_loss: 115 | orient_recon_loss = self.criterion_geo(root_orient.view(-1, 3, 3), root_orient_gt.view(-1, 3, 3)) 116 | else: 117 | orient_recon_loss = self.criterion_rec(root_orient, root_orient_gt) 118 | self.loss_recorder.add_scalar('orient_recon_loss', orient_recon_loss, validation=validation) 119 | loss += self.args.lambda_orient * orient_recon_loss 120 | 121 | if self.args.geodesic_loss: 122 | rotmat_recon_loss = self.criterion_geo(self.recon_data['rotmat'].view(-1, 3, 3), self.input_data['rotmat'].view(-1, 3, 3)) 123 | else: 124 | rotmat_recon_loss = self.criterion_rec(self.recon_data['rotmat'], self.input_data['rotmat']) 125 | self.loss_recorder.add_scalar('rotmat_recon_loss', rotmat_recon_loss, validation=validation) 126 | loss += self.args.lambda_rotmat * rotmat_recon_loss 127 | 128 | pos_recon_loss = self.criterion_rec(self.recon_data['pos'], self.input_data['pos']) 129 | self.loss_recorder.add_scalar('pos_recon_loss', pos_recon_loss, validation=validation) 130 | loss += self.args.lambda_pos * pos_recon_loss 131 | 132 | v_loss = self.criterion_rec(self.recon_data['root_vel'], self.input_data['root_vel']) 133 | up_loss = self.criterion_rec(self.recon_data['root_height'], self.input_data['trans'][..., 'xyz'.index(self.args.data.up)]) 134 | self.loss_recorder.add_scalar('v_loss', v_loss, validation=validation) 135 | self.loss_recorder.add_scalar('up_loss', up_loss, validation=validation) 136 | loss += self.args.lambda_v * v_loss + self.args.lambda_up * up_loss 137 | 138 | origin = 
self.input_data['trans'][:, 0] 139 | trans = self.recon_data['trans'] 140 | trans[..., self.v_axis] = trans[..., self.v_axis] + origin[..., self.v_axis].unsqueeze(1) 141 | trans_loss = self.criterion_rec(trans, self.input_data['trans']) 142 | self.loss_recorder.add_scalar('trans_loss', trans_loss, validation=validation) 143 | loss += self.args.lambda_trans * trans_loss 144 | 145 | self.loss_recorder.add_scalar('total_loss', loss, validation=validation) 146 | 147 | if not validation: 148 | loss.backward() 149 | 150 | def optimize_parameters(self): 151 | self.optimizer.zero_grad() 152 | self.forward() 153 | self.backward(validation=False) 154 | self.optimizer.step() 155 | 156 | def report_errors(self): 157 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient']) # (B, T, 3, 3) 158 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient']) # (B, T, 3, 3) 159 | orientation_error = self.criterion_geo(root_orient.view(-1, 3, 3), root_orient_gt.view(-1, 3, 3)) 160 | 161 | local_rotmat = self.fk.global_to_local(self.recon_data['rotmat'].view(-1, self.args.smpl.joint_num, 3, 3)) # (B x T, J, 3, 3) 162 | local_rotmat_gt = self.fk.global_to_local(self.input_data['rotmat'].view(-1, self.args.smpl.joint_num, 3, 3)) # (B x T, J, 3, 3) 163 | rotation_error = self.criterion_geo(local_rotmat[:, 1:].reshape(-1, 3, 3), local_rotmat_gt[:, 1:].reshape(-1, 3, 3)) 164 | 165 | pos = self.recon_data['pos'] # (B, T, J, 3) 166 | pos_gt = self.input_data['pos'] # (B, T, J, 3) 167 | position_error = torch.linalg.norm((pos - pos_gt), dim=-1).mean() 168 | 169 | origin = self.input_data['trans'][:, 0] 170 | trans = self.recon_data['trans'] 171 | trans[..., self.v_axis] = trans[..., self.v_axis] + origin[..., self.v_axis].unsqueeze(1) 172 | trans_gt = self.input_data['trans'] # (B, T, 3) 173 | translation_error = torch.linalg.norm((trans - trans_gt), dim=-1).mean() 174 | 175 | return { 176 | 'rotation': c2c(rotation_error), 177 | 'position': c2c(position_error), 178 | 'orientation': c2c(orientation_error), 179 | 'translation': c2c(translation_error) 180 | } 181 | 182 | def verbose(self): 183 | res = {} 184 | for loss in self.loss_recorder.losses.values(): 185 | res[loss.name] = {'train': loss.current()[0]} 186 | 187 | return res 188 | 189 | def save(self, optimal=False): 190 | if optimal: 191 | path = os.path.join(self.args.save_dir, 'results', 'model') 192 | else: 193 | path = os.path.join(self.model_save_dir, f'{self.epoch_cnt:04d}') 194 | 195 | os.makedirs(path, exist_ok=True) 196 | nemf = self.field.module if isinstance(self.field, nn.DataParallel) else self.field 197 | torch.save(nemf.state_dict(), os.path.join(path, 'nemf.pth')) 198 | torch.save(self.optimizer.state_dict(), os.path.join(path, 'optimizer.pth')) 199 | if self.args.scheduler.name: 200 | torch.save(self.schedulers[0].state_dict(), os.path.join(path, 'scheduler.pth')) 201 | self.loss_recorder.save(path) 202 | 203 | print(f'Save at {path} succeeded') 204 | 205 | def load(self, epoch=None, optimal=False): 206 | if optimal: 207 | path = os.path.join(self.args.save_dir, 'results', 'model') 208 | else: 209 | if epoch is None: 210 | all = [int(q) for q in os.listdir(self.model_save_dir)] 211 | if len(all) == 0: 212 | raise RuntimeError(f'Empty loading path {self.model_save_dir}') 213 | epoch = sorted(all)[-1] 214 | path = os.path.join(self.model_save_dir, f'{epoch:04d}') 215 | 216 | print(f'Loading from {path}') 217 | nemf = self.field.module if isinstance(self.field, nn.DataParallel) else self.field 218 | 
nemf.load_state_dict(torch.load(os.path.join(path, 'nemf.pth'), map_location=self.device)) 219 | if self.is_train: 220 | self.optimizer.load_state_dict(torch.load(os.path.join(path, 'optimizer.pth'))) 221 | self.loss_recorder.load(path) 222 | if self.args.scheduler.name: 223 | self.schedulers[0].load_state_dict(torch.load(os.path.join(path, 'scheduler.pth'))) 224 | self.epoch_cnt = epoch if not optimal else 0 225 | 226 | print('Load succeeded') 227 | 228 | def super_sampling(self, step=1): 229 | with torch.no_grad(): 230 | self.forward(step) 231 | self.fps = self.args.data.fps / step 232 | self.compute_test_result() 233 | 234 | def compute_test_result(self): 235 | os.makedirs(self.viz_dir, exist_ok=True) 236 | 237 | rotmat = self.recon_data['rotmat'][0] # (T, J, 3, 3) 238 | rotmat_gt = self.input_data['rotmat'][0] # (T, J, 3, 3) 239 | local_rotmat = self.fk.global_to_local(rotmat) # (T, J, 3, 3) 240 | local_rotmat_gt = self.fk.global_to_local(rotmat_gt) # (T, J, 3, 3) 241 | 242 | if self.args.data.root_transform: 243 | root_orient = rotation_6d_to_matrix(self.recon_data['root_orient'][0]) # (T, 3, 3) 244 | root_orient_gt = rotation_6d_to_matrix(self.input_data['root_orient'][0]) # (T, 3, 3) 245 | local_rotmat[:, 0] = root_orient 246 | local_rotmat_gt[:, 0] = root_orient_gt 247 | rotation = matrix_to_quaternion(local_rotmat) # (T, J, 4) 248 | rotation_gt = matrix_to_quaternion(local_rotmat_gt) # (T, J, 4) 249 | 250 | origin = self.input_data['trans'][0, 0].clone() # [3] 251 | position = self.recon_data['trans'][0].clone() # (T, 3) 252 | position[:, self.v_axis] = position[:, self.v_axis] + origin[self.v_axis].unsqueeze(0) 253 | position_gt = self.input_data['trans'][0] # (T, 3) 254 | 255 | if self.args.bvh_viz: 256 | if self.args.amass_data: 257 | rotation = align_joints(rotation) 258 | rotation_gt = align_joints(rotation_gt) 259 | 260 | position = position.unsqueeze(1) # (T, 1, 3) 261 | position_gt = position_gt.unsqueeze(1) # (T, 1, 3) 262 | 263 | # export data to BVH files 264 | if self.args.amass_data: 265 | offsets = self.args.smpl.offsets[self.args.data.gender] 266 | parents = self.args.smpl.parents 267 | names = self.args.smpl.joint_names 268 | else: 269 | offsets = self.args.animal.offsets 270 | parents = self.args.animal.parents 271 | names = self.args.animal.joint_names 272 | 273 | anim = Animation(Quaternions(c2c(rotation)), c2c(position), None, offsets=offsets, parents=parents) 274 | BVH.save(os.path.join(self.viz_dir, f'test_{int(self.fps)}fps.bvh'), anim, names=names, frametime=1 / self.fps) 275 | 276 | anim.rotations = Quaternions(c2c(rotation_gt)) 277 | anim.positions = c2c(position_gt) 278 | BVH.save(os.path.join(self.viz_dir, 'test_gt.bvh'), anim, names=names, frametime=1 / self.args.data.fps) 279 | else: 280 | poses = c2c(quaternion_to_axis_angle(rotation)) # (T, J, 3) 281 | poses_gt = c2c(quaternion_to_axis_angle(rotation_gt)) # (T, J, 3) 282 | 283 | poses = poses.reshape((poses.shape[0], -1)) # (T, J x 3) 284 | poses = np.pad(poses, [(0, 0), (0, 93)], mode='constant') 285 | poses_gt = poses_gt.reshape((poses_gt.shape[0], -1)) # (T, J x 3) 286 | poses_gt = np.pad(poses_gt, [(0, 0), (0, 93)], mode='constant') 287 | 288 | np.savez(os.path.join(self.viz_dir, f'test_{int(self.fps)}fps.npz'), 289 | poses=poses, trans=c2c(position), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.fps) 290 | np.savez(os.path.join(self.viz_dir, 'test_gt.npz'), 291 | poses=poses_gt, trans=c2c(position_gt), betas=np.zeros(10), gender=self.args.data.gender, 
mocap_framerate=self.args.data.fps) 292 | -------------------------------------------------------------------------------- /src/nemf/fk.py: -------------------------------------------------------------------------------- 1 | """Based on Daniel Holden code from: 2 | A Deep Learning Framework for Character Motion Synthesis and Editing 3 | (http://www.ipab.inf.ed.ac.uk/cgvu/motionsynthesis.pdf) 4 | """ 5 | 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from rotations import euler_angles_to_matrix, quaternion_to_matrix, rotation_6d_to_matrix 12 | 13 | 14 | class ForwardKinematicsLayer(nn.Module): 15 | """ Forward Kinematics Layer Class """ 16 | 17 | def __init__(self, args=None, parents=None, positions=None, device=None): 18 | super().__init__() 19 | self.b_idxs = None 20 | if device is None: 21 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | else: 23 | self.device = device 24 | if parents is None and positions is None: 25 | # Load SMPL skeleton (their joint order is different from the one we use for bvh export) 26 | smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz') 27 | smpl_data = np.load(smpl_fname, encoding='latin1') 28 | self.parents = torch.from_numpy(smpl_data['kintree_table'][0].astype(np.int32)).to(self.device) 29 | self.parents = self.parents.long() 30 | self.positions = torch.from_numpy(smpl_data['J'].astype(np.float32)).to(self.device) 31 | self.positions[1:] -= self.positions[self.parents[1:]] 32 | else: 33 | self.parents = torch.from_numpy(parents).to(self.device) 34 | self.parents = self.parents.long() 35 | self.positions = torch.from_numpy(positions).to(self.device) 36 | self.positions = self.positions.float() 37 | self.positions[0] = 0 38 | 39 | def rotate(self, t0s, t1s): 40 | return torch.matmul(t0s, t1s) 41 | 42 | def identity_rotation(self, rotations): 43 | diagonal = torch.diag(torch.tensor([1.0, 1.0, 1.0, 1.0])).to(self.device) 44 | diagonal = torch.reshape( 45 | diagonal, torch.Size([1] * len(rotations.shape[:2]) + [4, 4])) 46 | ts = diagonal.repeat(rotations.shape[:2] + torch.Size([1, 1])) 47 | return ts 48 | 49 | def make_fast_rotation_matrices(self, positions, rotations): 50 | if len(rotations.shape) == 4 and rotations.shape[-2:] == torch.Size([3, 3]): 51 | rot_matrices = rotations 52 | elif rotations.shape[-1] == 3: 53 | rot_matrices = euler_angles_to_matrix(rotations, convention='XYZ') 54 | elif rotations.shape[-1] == 4: 55 | rot_matrices = quaternion_to_matrix(rotations) 56 | elif rotations.shape[-1] == 6: 57 | rot_matrices = rotation_6d_to_matrix(rotations) 58 | else: 59 | raise NotImplementedError(f'Unimplemented rotation representation in FK layer, shape of {rotations.shape}') 60 | 61 | rot_matrices = torch.cat([rot_matrices, positions[..., None]], dim=-1) 62 | zeros = torch.zeros(rot_matrices.shape[:-2] + torch.Size([1, 3])).to(self.device) 63 | ones = torch.ones(rot_matrices.shape[:-2] + torch.Size([1, 1])).to(self.device) 64 | zerosones = torch.cat([zeros, ones], dim=-1) 65 | rot_matrices = torch.cat([rot_matrices, zerosones], dim=-2) 66 | return rot_matrices 67 | 68 | def rotate_global(self, parents, positions, rotations): 69 | locals = self.make_fast_rotation_matrices(positions, rotations) 70 | globals = self.identity_rotation(rotations) 71 | 72 | globals = torch.cat([locals[:, 0:1], globals[:, 1:]], dim=1) 73 | b_size = positions.shape[0] 74 | if self.b_idxs is None: 75 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) 76 | elif 
self.b_idxs.shape[-1] != b_size: 77 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) 78 | 79 | for i in range(1, positions.shape[1]): 80 | globals[:, i] = self.rotate( 81 | globals[self.b_idxs, parents[i]], locals[:, i]) 82 | 83 | return globals 84 | 85 | def get_tpose_joints(self, offsets, parents): 86 | num_joints = len(parents) 87 | joints = [offsets[:, 0]] 88 | for j in range(1, len(parents)): 89 | joints.append(joints[parents[j]] + offsets[:, j]) 90 | 91 | return torch.stack(joints, dim=1) 92 | 93 | def canonical_to_local(self, canonical_xform, global_orient=None): 94 | """ 95 | Args: 96 | canonical_xform: (B, J, 3, 3) 97 | global_orient: (B, 3, 3) 98 | 99 | Returns: 100 | local_xform: (B, J, 3, 3) 101 | """ 102 | local_xform = torch.zeros_like(canonical_xform) 103 | 104 | if global_orient is None: 105 | global_xform = canonical_xform 106 | else: 107 | global_xform = torch.matmul(global_orient.unsqueeze(1), canonical_xform) 108 | for i in range(global_xform.shape[1]): 109 | if i == 0: 110 | local_xform[:, i] = global_xform[:, i] 111 | else: 112 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) 113 | 114 | return local_xform 115 | 116 | def global_to_local(self, global_xform): 117 | """ 118 | Args: 119 | global_xform: (B, J, 3, 3) 120 | 121 | Returns: 122 | local_xform: (B, J, 3, 3) 123 | """ 124 | local_xform = torch.zeros_like(global_xform) 125 | 126 | for i in range(global_xform.shape[1]): 127 | if i == 0: 128 | local_xform[:, i] = global_xform[:, i] 129 | else: 130 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) 131 | 132 | return local_xform 133 | 134 | def forward(self, rotations, positions=None): 135 | """ 136 | Args: 137 | rotations (B, J, D) 138 | 139 | Returns: 140 | The global position of each joint after FK (B, J, 3) 141 | """ 142 | # Get the full transform with rotations for skinning 143 | b_size = rotations.shape[0] 144 | if positions is None: 145 | positions = self.positions.repeat(b_size, 1, 1) 146 | transforms = self.rotate_global(self.parents, positions, rotations) 147 | coordinates = transforms[:, :, :3, 3] / transforms[:, :, 3:, 3] 148 | 149 | return coordinates, transforms 150 | -------------------------------------------------------------------------------- /src/nemf/global_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from human_body_prior.tools.omni_tools import copy2cpu as c2c 7 | from rotations import matrix_to_axis_angle, rotation_6d_to_matrix 8 | from utils import compute_trajectory, export_ply_trajectory, normalize 9 | 10 | from .base_model import BaseModel 11 | from .fk import ForwardKinematicsLayer 12 | from .residual_blocks import ResidualBlock, SkeletonResidual, residual_ratio 13 | from .skeleton import SkeletonConv, SkeletonPool, build_edge_topology, find_neighbor 14 | 15 | 16 | class Predictor(nn.Module): 17 | def __init__(self, args, topology): 18 | super(Predictor, self).__init__() 19 | self.topologies = [topology] 20 | self.channel_base = [args.channel_base] 21 | self.channel_list = [] 22 | self.edge_num = [len(topology)] 23 | self.pooling_list = [] 24 | self.layers = nn.ModuleList() 25 | self.args = args 26 | # self.convs = [] 27 | 28 | kernel_size = args.kernel_size 29 | padding = (kernel_size - 1) // 2 30 | bias = True 31 | 32 | for _ in range(args.num_layers): 33 | 
self.channel_base.append(self.channel_base[-1] * 2) 34 | 35 | for i in range(args.num_layers): 36 | seq = [] 37 | neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist) 38 | in_channels = self.channel_base[i] * self.edge_num[i] 39 | out_channels = self.channel_base[i + 1] * self.edge_num[i] 40 | if i == 0: 41 | self.channel_list.append(in_channels) 42 | self.channel_list.append(out_channels) 43 | last_pool = True if i == args.num_layers - 1 else False 44 | 45 | # (T, J, D) => (T, J', D) 46 | pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool, 47 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) 48 | 49 | if args.use_residual_blocks: 50 | # (T, J, D) => (T, J', 2D) 51 | seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels, 52 | kernel_size=kernel_size, stride=1, padding=padding, padding_mode=args.padding_mode, bias=bias, 53 | extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, activation=args.activation, last_pool=last_pool)) 54 | else: 55 | for _ in range(args.extra_conv): 56 | # (T, J, D) => (T, J, D) 57 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, 58 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=1, 59 | padding=padding, padding_mode=args.padding_mode, bias=bias)) 60 | seq.append(nn.PReLU()) 61 | # (T, J, D) => (T, J, 2D) 62 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 63 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=1, 64 | padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False, 65 | in_offset_channel=3 * self.channel_base[i] // self.channel_base[0])) 66 | # self.convs.append(seq[-1]) 67 | 68 | seq.append(pool) 69 | seq.append(nn.PReLU()) 70 | self.layers.append(nn.Sequential(*seq)) 71 | 72 | self.topologies.append(pool.new_edges) 73 | self.pooling_list.append(pool.pooling_list) 74 | self.edge_num.append(len(self.topologies[-1])) 75 | 76 | in_channels = self.channel_base[-1] * len(self.pooling_list[-1]) 77 | out_channels = args.out_channels # root orient (6) + root vel (3) + root height (1) + contacts (8) 78 | 79 | self.global_motion = nn.Sequential( 80 | ResidualBlock(in_channels=in_channels, out_channels=512, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(1), activation=args.activation), 81 | ResidualBlock(in_channels=512, out_channels=256, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(2), activation=args.activation), 82 | ResidualBlock(in_channels=256, out_channels=128, kernel_size=kernel_size, stride=1, padding=padding, residual_ratio=residual_ratio(3), activation=args.activation), 83 | ResidualBlock(in_channels=128, out_channels=out_channels, kernel_size=kernel_size, stride=1, padding=padding, 84 | residual_ratio=residual_ratio(4), activation=args.activation, last_layer=True) 85 | ) 86 | 87 | def forward(self, input): 88 | feature = input 89 | for layer in self.layers: 90 | feature = layer(feature) 91 | 92 | return self.global_motion(feature) 93 | 94 | 95 | class GlobalMotionPredictor(BaseModel): 96 | def __init__(self, args, ngpu): 97 | super(GlobalMotionPredictor, self).__init__(args) 98 | 99 | self.args = args 100 | 101 | smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz') 102 | smpl_data = np.load(smpl_fname, encoding='latin1') 103 | parents = 
smpl_data['kintree_table'][0].astype(np.int32) 104 | edges = build_edge_topology(parents) 105 | 106 | self.fk = ForwardKinematicsLayer(args) 107 | 108 | self.predictor = Predictor(args.global_motion, edges).to(self.device) 109 | if args.multi_gpu is True: 110 | self.predictor = nn.DataParallel(self.predictor, device_ids=range(ngpu)) 111 | self.models = [self.predictor] 112 | 113 | ordermap = { 114 | 'x': 0, 115 | 'y': 1, 116 | 'z': 2, 117 | } 118 | self.up_axis = ordermap[self.args.data.up] 119 | self.v_axis = [x for x in ordermap.values() if x != self.up_axis] 120 | 121 | print(self.up_axis) 122 | print(self.v_axis) 123 | 124 | self.input_data = {} 125 | self.pred_data = {} 126 | 127 | self.mean = torch.load(os.path.join(args.dataset_dir, f'mean-{args.data.gender}-{args.data.clip_length}-{args.data.fps}fps.pt'), map_location=self.device) 128 | print(f'mean: {list(self.mean.keys())}') 129 | self.std = torch.load(os.path.join(args.dataset_dir, f'std-{args.data.gender}-{args.data.clip_length}-{args.data.fps}fps.pt'), map_location=self.device) 130 | print(f'std: {list(self.std.keys())}') 131 | 132 | if self.is_train: 133 | self.optimizer = torch.optim.Adam(self.predictor.parameters(), args.learning_rate, weight_decay=args.weight_decay) 134 | self.optimizers = [self.optimizer] 135 | self.criterion_pred = nn.L1Loss() if args.l1_loss else nn.MSELoss() 136 | self.criterion_bce = nn.BCEWithLogitsLoss() 137 | else: 138 | self.test_index = 0 139 | self.smpl_dir = os.path.join(args.save_dir, 'results', 'smpl') 140 | 141 | def set_input(self, input): 142 | self.input_data = {k: v.to(self.device) for k, v in input.items() if k in ['pos', 'velocity', 'rot6d', 'angular', 'global_xform', 143 | 'root_vel', 'trans', 'contacts']} 144 | 145 | def forward(self): 146 | self.pred_data.clear() 147 | self.input_data['origin'] = self.input_data['trans'][:, 0] 148 | self.pred_data = self.predict(data=self.input_data, dt=1.0 / self.args.data.fps) 149 | 150 | def predict(self, data, dt): 151 | b_size, t_length = data['pos'].shape[:2] 152 | 153 | if 'pos' in self.args.data.normalize: 154 | pos = normalize(data['pos'], mean=self.mean['pos'], std=self.std['pos']) 155 | else: 156 | pos = data['pos'] 157 | if 'velocity' in self.args.data.normalize: 158 | velocity = normalize(data['velocity'], mean=self.mean['velocity'], std=self.std['velocity']) 159 | else: 160 | velocity = data['velocity'] 161 | if 'rot6d' in self.args.data.normalize: 162 | rot6d = normalize(data['rot6d'], mean=self.mean['rot6d'], std=self.std['rot6d']) 163 | else: 164 | rot6d = data['rot6d'] 165 | if 'angular' in self.args.data.normalize: 166 | angular = normalize(data['angular'], mean=self.mean['angular'], std=self.std['angular']) 167 | else: 168 | angular = data['angular'] 169 | 170 | x = torch.cat((pos, velocity, rot6d, angular), dim=-1) 171 | x = x.view(b_size, t_length, -1) # (B, T, J x D) 172 | x = x.permute(0, 2, 1) # (B, J x D, T) 173 | 174 | global_motion = self.predictor(x) 175 | global_motion = global_motion.permute(0, 2, 1) # (B, T, D) 176 | 177 | output = dict() 178 | output['root_vel'] = global_motion[:, :, :3] 179 | output['root_height'] = global_motion[:, :, 3] 180 | output['contacts'] = global_motion[:, :, 4:] 181 | if 'origin' in data.keys(): 182 | origin = data['origin'] 183 | else: 184 | origin = torch.zeros(b_size, 3).to(self.device) 185 | output['trans'] = compute_trajectory(output['root_vel'], output['root_height'], origin, dt, up_axis=self.args.data.up) 186 | 187 | return output 188 | 189 | def backward(self, validation=False): 
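# Editorial note on the method that follows: it accumulates the weighted global-motion losses (root velocity, root height, trajectory, foot contacts) and only backpropagates them when validation is False.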
190 | loss = 0 191 | 192 | v_loss = self.criterion_pred(self.pred_data['root_vel'], self.input_data['root_vel']) 193 | up_loss = self.criterion_pred(self.pred_data['root_height'], self.input_data['trans'][..., self.up_axis]) 194 | self.loss_recorder.add_scalar('v_loss', v_loss, validation=validation) 195 | self.loss_recorder.add_scalar('up_loss', up_loss, validation=validation) 196 | loss += self.args.lambda_v * v_loss + self.args.lambda_up * up_loss 197 | 198 | trans_loss = self.criterion_pred(self.pred_data['trans'], self.input_data['trans']) 199 | self.loss_recorder.add_scalar('trans_loss', trans_loss, validation=validation) 200 | loss += self.args.lambda_trans * trans_loss 201 | 202 | contacts_loss = self.criterion_bce(self.pred_data['contacts'], self.input_data['contacts']) 203 | self.loss_recorder.add_scalar('contacts_loss', contacts_loss, validation=validation) 204 | loss += self.args.lambda_contacts * contacts_loss 205 | 206 | self.loss_recorder.add_scalar('total_loss', loss, validation=validation) 207 | 208 | if not validation: 209 | loss.backward() 210 | 211 | def optimize_parameters(self): 212 | self.optimizer.zero_grad() 213 | self.forward() 214 | self.backward(validation=False) 215 | self.optimizer.step() 216 | 217 | def report_errors(self): 218 | trans = self.pred_data['trans'] # (B, T, 3) 219 | trans_gt = self.input_data['trans'] # (B, T, 3) 220 | translation_error = torch.linalg.norm((trans - trans_gt), dim=-1).mean() 221 | 222 | return { 223 | 'translation': c2c(translation_error) 224 | } 225 | 226 | def verbose(self): 227 | res = {} 228 | for loss in self.loss_recorder.losses.values(): 229 | res[loss.name] = {'train': loss.current()[0], 'val': loss.current()[1]} 230 | 231 | return res 232 | 233 | def validate(self): 234 | with torch.no_grad(): 235 | self.forward() 236 | self.backward(validation=True) 237 | 238 | def save(self, optimal=False): 239 | if optimal: 240 | path = os.path.join(self.args.save_dir, 'results', 'model') 241 | else: 242 | path = os.path.join(self.model_save_dir, f'{self.epoch_cnt:04d}') 243 | 244 | os.makedirs(path, exist_ok=True) 245 | predictor = self.predictor.module if isinstance(self.predictor, nn.DataParallel) else self.predictor 246 | torch.save(predictor.state_dict(), os.path.join(path, 'predictor.pth')) 247 | torch.save(self.optimizer.state_dict(), os.path.join(path, 'optimizer.pth')) 248 | if self.args.scheduler.name: 249 | torch.save(self.schedulers[0].state_dict(), os.path.join(path, 'scheduler.pth')) 250 | self.loss_recorder.save(path) 251 | 252 | print(f'Save at {path} succeeded') 253 | 254 | def load(self, epoch=None, optimal=False): 255 | if optimal: 256 | path = os.path.join(self.args.save_dir, 'results', 'model') 257 | else: 258 | if epoch is None: 259 | all = [int(q) for q in os.listdir(self.model_save_dir)] 260 | if len(all) == 0: 261 | raise RuntimeError(f'Empty loading path {self.model_save_dir}') 262 | epoch = sorted(all)[-1] 263 | path = os.path.join(self.model_save_dir, f'{epoch:04d}') 264 | 265 | print(f'Loading from {path}') 266 | predictor = self.predictor.module if isinstance(self.predictor, nn.DataParallel) else self.predictor 267 | predictor.load_state_dict(torch.load(os.path.join(path, 'predictor.pth'), map_location=self.device)) 268 | if self.is_train: 269 | self.optimizer.load_state_dict(torch.load(os.path.join(path, 'optimizer.pth'))) 270 | self.loss_recorder.load(path) 271 | if self.args.scheduler.name: 272 | self.schedulers[0].load_state_dict(torch.load(os.path.join(path, 'scheduler.pth'))) 273 | self.epoch_cnt = 
epoch if not optimal else 0 274 | 275 | print('Load succeeded') 276 | 277 | def compute_test_result(self): 278 | os.makedirs(self.smpl_dir, exist_ok=True) 279 | 280 | b_size = self.input_data['trans'].shape[0] 281 | n_joints = self.input_data['pos'].shape[2] 282 | 283 | trans = self.pred_data['trans'] 284 | trans_gt = self.input_data['trans'] 285 | 286 | rotmat_gt = rotation_6d_to_matrix(self.input_data['global_xform']) # (B, T, J, 3, 3) 287 | local_rotmat_gt = self.fk.global_to_local(rotmat_gt.view(-1, n_joints, 3, 3)) # (B x T, J, 3, 3) 288 | local_rotmat_gt = local_rotmat_gt.view(b_size, -1, n_joints, 3, 3) # (B, T, J, 3, 3) 289 | local_rotmat = local_rotmat_gt.clone() 290 | 291 | for i in range(b_size): 292 | export_ply_trajectory(points=trans[i], color=(255, 0, 0), ply_fname=os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}.ply')) 293 | export_ply_trajectory(points=trans_gt[i], color=(0, 255, 0), ply_fname=os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}_gt.ply')) 294 | 295 | poses = c2c(matrix_to_axis_angle(local_rotmat[i])) # (T, J, 3) 296 | poses_gt = c2c(matrix_to_axis_angle(local_rotmat_gt[i])) # (T, J, 3) 297 | 298 | poses = poses.reshape((poses.shape[0], -1)) # (T, J x 3) 299 | poses = np.pad(poses, [(0, 0), (0, 93)], mode='constant') 300 | poses_gt = poses_gt.reshape((poses_gt.shape[0], -1)) # (T, J x 3) 301 | poses_gt = np.pad(poses_gt, [(0, 0), (0, 93)], mode='constant') 302 | 303 | np.savez(os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}.npz'), 304 | poses=poses, trans=c2c(trans[i]), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.args.data.fps) 305 | np.savez(os.path.join(self.smpl_dir, f'test_{self.test_index + i:03d}_gt.npz'), 306 | poses=poses_gt, trans=c2c(trans_gt[i]), betas=np.zeros(10), gender=self.args.data.gender, mocap_framerate=self.args.data.fps) 307 | 308 | self.test_index += b_size 309 | -------------------------------------------------------------------------------- /src/nemf/loss_record.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | 8 | class SingleLoss: 9 | def __init__(self, name: str, writer: SummaryWriter): 10 | self.name = name 11 | self.loss_step = [] 12 | self.loss_epoch = [] 13 | self.loss_epoch_tmp = [] 14 | self.loss_epoch_val = [] 15 | self.writer = writer 16 | 17 | def add_scalar(self, val, step=None, validation=False): 18 | if validation: 19 | self.loss_epoch_val.append(val) 20 | else: 21 | if step is None: 22 | step = len(self.loss_step) 23 | self.loss_step.append(val) 24 | self.loss_epoch_tmp.append(val) 25 | self.writer.add_scalar('Loss/step_' + self.name, val, step) 26 | 27 | def current(self): 28 | return self.loss_epoch[-1] 29 | 30 | def epoch(self, step=None): 31 | if step is None: 32 | step = len(self.loss_epoch) 33 | if self.loss_epoch_tmp: 34 | loss_avg = sum(self.loss_epoch_tmp) / len(self.loss_epoch_tmp) 35 | else: 36 | loss_avg = 0 37 | self.loss_epoch_tmp = [] 38 | if self.loss_epoch_val: 39 | loss_avg_val = sum(self.loss_epoch_val) / len(self.loss_epoch_val) 40 | else: 41 | loss_avg_val = 0 42 | self.loss_epoch_val = [] 43 | self.loss_epoch.append([loss_avg, loss_avg_val]) 44 | self.writer.add_scalars('Loss/epoch_' + self.name, {'train': loss_avg, 'val': loss_avg_val}, step) 45 | 46 | def save(self): 47 | return { 48 | 'loss_step': self.loss_step, 49 | 'loss_epoch': self.loss_epoch 50 | } 51 | 52 | 53 | class 
LossRecorder: 54 | def __init__(self, writer: SummaryWriter): 55 | self.losses = {} 56 | self.writer = writer 57 | 58 | def add_scalar(self, name, val, step=None, validation=False): 59 | if isinstance(val, torch.Tensor): 60 | val = val.item() 61 | if name not in self.losses: 62 | self.losses[name] = SingleLoss(name, self.writer) 63 | self.losses[name].add_scalar(val, step, validation) 64 | 65 | def epoch(self, step=None): 66 | for loss in self.losses.values(): 67 | loss.epoch(step) 68 | 69 | def save(self, path): 70 | data = {} 71 | for key, value in self.losses.items(): 72 | data[key] = value.save() 73 | with open(os.path.join(path, 'loss.pkl'), 'wb') as f: 74 | pickle.dump(data, f) 75 | 76 | def load(self, path): 77 | with open(os.path.join(path, 'loss.pkl'), 'rb') as f: 78 | data = pickle.load(f) 79 | for key, value in data.items(): 80 | self.losses[key] = SingleLoss(key, self.writer) 81 | self.losses[key].loss_step = value['loss_step'] 82 | self.losses[key].loss_epoch = value['loss_epoch'] 83 | -------------------------------------------------------------------------------- /src/nemf/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GeodesicLoss(nn.Module): 6 | def __init__(self): 7 | super(GeodesicLoss, self).__init__() 8 | 9 | def compute_geodesic_distance(self, m1, m2, epsilon=1e-7): 10 | """ Compute the geodesic distance between two rotation matrices. 11 | Args: 12 | m1, m2: Two rotation matrices with the shape (batch x 3 x 3). 13 | Returns: 14 | The minimal angular difference between two rotation matrices in radian form [0, pi]. 15 | """ 16 | batch = m1.shape[0] 17 | m = torch.bmm(m1, m2.permute(0, 2, 1)) # batch*3*3 18 | 19 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 20 | # cos = (m.diagonal(dim1=-2, dim2=-1).sum(-1) -1) /2 21 | # cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda())) 22 | # cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda()) * -1) 23 | cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon) 24 | theta = torch.acos(cos) 25 | 26 | return theta 27 | 28 | def __call__(self, m1, m2, reduction='mean'): 29 | loss = self.compute_geodesic_distance(m1, m2) 30 | 31 | if reduction == 'mean': 32 | return loss.mean() 33 | elif reduction == 'none': 34 | return loss 35 | else: 36 | raise RuntimeError(f'unsupported reduction: {reduction}') 37 | raise RuntimeError(f'unsupported reduction: {reduction}') 38 | -------------------------------------------------------------------------------- /src/nemf/neural_motion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, n_functions): 9 | super(PositionalEncoding, self).__init__() 10 | 11 | self.register_buffer('frequencies', 2.0 ** torch.arange(n_functions)) 12 | 13 | def forward(self, x): 14 | """ 15 | Args: 16 | x: tensor of shape [..., dim] 17 | 18 | Returns: 19 | embedding: a temporal embedding of `x` of shape [..., n_functions * dim * 2] 20 | """ 21 | freq = (x[..., None] * self.frequencies).view(*x.shape[:-1], -1) 22 | 23 | embedding = torch.zeros(*freq.shape[:-1], freq.shape[-1] * 2).cuda() 24 | embedding[..., 0::2] = freq.sin() 25 | embedding[..., 1::2] = freq.cos() 26 | 27 | return embedding 28 | 29 | 30 | class Sine(nn.Module): 31 | def __init__(self, w0=1.): 32 | super(Sine, 
self).__init__() 33 | 34 | self.w0 = w0 35 | 36 | def forward(self, x): 37 | return torch.sin(self.w0 * x) 38 | 39 | 40 | class SirenBlock(nn.Module): 41 | def __init__(self, in_features, out_features, w0=30, c=6, is_first=False, use_bias=True, activation=None): 42 | super(SirenBlock, self).__init__() 43 | 44 | self.in_features = in_features 45 | self.is_first = is_first 46 | 47 | weight = torch.zeros(out_features, in_features) 48 | bias = torch.zeros(out_features) if use_bias else None 49 | self.init(weight, bias, c=c, w0=w0) 50 | 51 | self.weight = nn.Parameter(weight) 52 | self.bias = nn.Parameter(bias) if use_bias else None 53 | self.activation = Sine(w0) if activation is None else activation 54 | 55 | def init(self, weight, bias, c, w0): 56 | n = self.in_features 57 | 58 | w_std = (1 / n) if self.is_first else (np.sqrt(c / n) / w0) 59 | weight.uniform_(-w_std, w_std) 60 | 61 | if bias is not None: 62 | bias.uniform_(-w_std, w_std) 63 | 64 | def forward(self, x): 65 | out = F.linear(x, self.weight, self.bias) 66 | 67 | return self.activation(out) 68 | 69 | 70 | class FCBlock(nn.Module): 71 | def __init__(self, in_features, out_features, norm_layer=False, activation=None): 72 | super(FCBlock, self).__init__() 73 | 74 | self.fc = nn.Linear(in_features, out_features) 75 | self.residual = (in_features == out_features) # when the input and output have the same dimensions, build a residual block 76 | self.norm_layer = nn.LayerNorm(out_features) if norm_layer else None 77 | self.activation = nn.ReLU(inplace=True) if activation is None else activation 78 | 79 | def forward(self, x): 80 | """ 81 | Args: 82 | x: (B, T, D), features are in the last dimension. 83 | """ 84 | out = self.fc(x) 85 | 86 | if self.norm_layer is not None: 87 | out = self.norm_layer(out) 88 | 89 | if self.residual: 90 | return self.activation(out) + x 91 | 92 | return self.activation(out) 93 | 94 | 95 | class NeuralMotionField(nn.Module): 96 | def __init__(self, args): 97 | super(NeuralMotionField, self).__init__() 98 | 99 | self.args = args 100 | 101 | if args.bandwidth != 0: 102 | self.positional_encoding = PositionalEncoding(args.bandwidth) 103 | embedding_dim = args.bandwidth * 2 104 | else: 105 | embedding_dim = 1 106 | 107 | hidden_neuron = args.hidden_neuron 108 | in_features = 1 + args.local_z if args.siren else embedding_dim + args.local_z 109 | local_in_features = hidden_neuron + in_features if args.skip_connection else hidden_neuron 110 | global_in_features = local_in_features + args.global_z if args.skip_connection else hidden_neuron 111 | 112 | layers = [ 113 | SirenBlock(in_features, hidden_neuron, is_first=True) if args.siren else FCBlock(in_features, hidden_neuron, norm_layer=args.norm_layer), 114 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 115 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 116 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 117 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 118 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 119 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, 
norm_layer=args.norm_layer), 120 | SirenBlock(local_in_features, hidden_neuron) if args.siren else FCBlock(local_in_features, hidden_neuron, norm_layer=args.norm_layer), 121 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer), 122 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer), 123 | SirenBlock(global_in_features, hidden_neuron) if args.siren else FCBlock(global_in_features, hidden_neuron, norm_layer=args.norm_layer) 124 | ] 125 | self.mlp = nn.ModuleList(layers) 126 | self.skip_layers = [] if not args.skip_connection else list(range(1, len(self.mlp))) 127 | self.local_layers = list(range(8)) 128 | self.local_linear = nn.Sequential(nn.Linear(hidden_neuron, args.local_output)) 129 | self.global_layers = list(range(8, len(self.mlp))) 130 | self.global_linear = nn.Sequential(nn.Linear(hidden_neuron, args.global_output)) 131 | 132 | def forward(self, t, z_l=None, z_g=None): 133 | if self.args.bandwidth != 0 and self.args.siren is False: 134 | t = self.positional_encoding(t) 135 | 136 | if z_l is not None: 137 | z_l = z_l.unsqueeze(1).expand(-1, t.shape[1], -1) 138 | x = torch.cat([t, z_l], dim=-1) 139 | else: 140 | x = t 141 | 142 | skip_in = x 143 | for i in self.local_layers: 144 | if i in self.skip_layers: 145 | x = torch.cat((x, skip_in), dim=-1) 146 | x = self.mlp[i](x) 147 | local_output = self.local_linear(x) 148 | 149 | if z_g is not None: 150 | z_g = z_g.unsqueeze(1).expand(-1, t.shape[1], -1) 151 | for i in self.global_layers: 152 | if i in self.skip_layers: 153 | x = torch.cat((x, skip_in, z_g), dim=-1) 154 | x = self.mlp[i](x) 155 | global_output = self.global_linear(x) 156 | else: 157 | global_output = None 158 | 159 | return local_output, global_output 160 | -------------------------------------------------------------------------------- /src/nemf/prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .residual_blocks import ResidualBlock, SkeletonResidual, residual_ratio 5 | from .skeleton import SkeletonConv, SkeletonPool, find_neighbor 6 | 7 | 8 | class LocalEncoder(nn.Module): 9 | def __init__(self, args, topology): 10 | super(LocalEncoder, self).__init__() 11 | self.topologies = [topology] 12 | self.channel_base = [args.channel_base] 13 | 14 | self.channel_list = [] 15 | self.edge_num = [len(topology)] 16 | self.pooling_list = [] 17 | self.layers = nn.ModuleList() 18 | self.args = args 19 | # self.convs = [] 20 | 21 | kernel_size = args.kernel_size 22 | kernel_even = False if kernel_size % 2 else True 23 | padding = (kernel_size - 1) // 2 24 | bias = True 25 | 26 | for _ in range(args.num_layers): 27 | self.channel_base.append(self.channel_base[-1] * 2) 28 | 29 | for i in range(args.num_layers): 30 | seq = [] 31 | neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist) 32 | in_channels = self.channel_base[i] * self.edge_num[i] 33 | out_channels = self.channel_base[i + 1] * self.edge_num[i] 34 | if i == 0: 35 | self.channel_list.append(in_channels) 36 | self.channel_list.append(out_channels) 37 | last_pool = True if i == args.num_layers - 1 else False 38 | 39 | # (T, J, D) => (T, J', D) 40 | pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool, 41 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) 42 | 43 | if args.use_residual_blocks: 44 | # (T, 
J, D) => (T/2, J', 2D) 45 | seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels, 46 | kernel_size=kernel_size, stride=2, padding=padding, padding_mode=args.padding_mode, bias=bias, 47 | extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, activation=args.activation, last_pool=last_pool)) 48 | else: 49 | for _ in range(args.extra_conv): 50 | # (T, J, D) => (T, J, D) 51 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, 52 | joint_num=self.edge_num[i], kernel_size=kernel_size - 1 if kernel_even else kernel_size, 53 | stride=1, 54 | padding=padding, padding_mode=args.padding_mode, bias=bias)) 55 | seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) 56 | # (T, J, D) => (T/2, J, 2D) 57 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 58 | joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2, 59 | padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False, 60 | in_offset_channel=3 * self.channel_base[i] // self.channel_base[0])) 61 | # self.convs.append(seq[-1]) 62 | 63 | seq.append(pool) 64 | seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) 65 | self.layers.append(nn.Sequential(*seq)) 66 | 67 | self.topologies.append(pool.new_edges) 68 | self.pooling_list.append(pool.pooling_list) 69 | self.edge_num.append(len(self.topologies[-1])) 70 | 71 | in_features = self.channel_base[-1] * len(self.pooling_list[-1]) 72 | in_features *= args.temporal_scale 73 | self.mu = nn.Linear(in_features, args.z_dim) 74 | self.logvar = nn.Linear(in_features, args.z_dim) 75 | 76 | def forward(self, input): 77 | output = input 78 | for layer in self.layers: 79 | output = layer(output) 80 | output = output.view(output.shape[0], -1) 81 | 82 | return self.mu(output), self.logvar(output) 83 | 84 | 85 | class GlobalEncoder(nn.Module): 86 | def __init__(self, args): 87 | super(GlobalEncoder, self).__init__() 88 | 89 | in_channels = args.in_channels # root orientation + root velocity 90 | out_channels = 512 91 | kernel_size = args.kernel_size 92 | padding = (kernel_size - 1) // 2 93 | 94 | self.encoder = nn.Sequential( 95 | ResidualBlock(in_channels=in_channels, out_channels=128, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(1), activation=args.activation), 96 | ResidualBlock(in_channels=128, out_channels=256, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(2), activation=args.activation), 97 | ResidualBlock(in_channels=256, out_channels=512, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(3), activation=args.activation), 98 | ResidualBlock(in_channels=512, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=padding, residual_ratio=residual_ratio(4), activation=args.activation) 99 | ) 100 | 101 | in_features = out_channels * args.temporal_scale 102 | self.mu = nn.Linear(in_features, args.z_dim) 103 | self.logvar = nn.Linear(in_features, args.z_dim) 104 | 105 | def forward(self, input): 106 | feature = self.encoder(input) 107 | feature = feature.view(feature.shape[0], -1) 108 | 109 | return self.mu(feature), self.logvar(feature) 110 | -------------------------------------------------------------------------------- /src/nemf/residual_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 
| 4 | from .skeleton import SkeletonConv, SkeletonPool, SkeletonUnpool 5 | 6 | 7 | def residual_ratio(k): 8 | return 1 / (k + 1) 9 | 10 | 11 | class Affine(nn.Module): 12 | def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0): 13 | super(Affine, self).__init__() 14 | if scale: 15 | self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init) 16 | else: 17 | self.register_parameter('scale', None) 18 | 19 | if bias: 20 | self.bias = nn.Parameter(torch.zeros(num_parameters)) 21 | else: 22 | self.register_parameter('bias', None) 23 | 24 | def forward(self, input): 25 | output = input 26 | if self.scale is not None: 27 | scale = self.scale.unsqueeze(0) 28 | while scale.dim() < input.dim(): 29 | scale = scale.unsqueeze(2) 30 | output = output.mul(scale) 31 | 32 | if self.bias is not None: 33 | bias = self.bias.unsqueeze(0) 34 | while bias.dim() < input.dim(): 35 | bias = bias.unsqueeze(2) 36 | output += bias 37 | 38 | return output 39 | 40 | 41 | class BatchStatistics(nn.Module): 42 | def __init__(self, affine=-1): 43 | super(BatchStatistics, self).__init__() 44 | self.affine = nn.Sequential() if affine == -1 else Affine(affine) 45 | self.loss = 0 46 | 47 | def clear_loss(self): 48 | self.loss = 0 49 | 50 | def compute_loss(self, input): 51 | input_flat = input.view(input.size(1), input.numel() // input.size(1)) 52 | mu = input_flat.mean(1) 53 | logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log() 54 | 55 | self.loss = mu.pow(2).mean() + logvar.pow(2).mean() 56 | 57 | def forward(self, input): 58 | self.compute_loss(input) 59 | return self.affine(input) 60 | 61 | 62 | class ResidualBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False): 64 | super(ResidualBlock, self).__init__() 65 | 66 | self.residual_ratio = residual_ratio 67 | self.shortcut_ratio = 1 - residual_ratio 68 | 69 | residual = [] 70 | residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)) 71 | if batch_statistics: 72 | residual.append(BatchStatistics(out_channels)) 73 | if not last_layer: 74 | residual.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) 75 | self.residual = nn.Sequential(*residual) 76 | 77 | self.shortcut = nn.Sequential( 78 | nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(), 79 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0), 80 | BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential() 81 | ) 82 | 83 | def forward(self, input): 84 | return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) 85 | 86 | 87 | class ResidualBlockTranspose(nn.Module): 88 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation): 89 | super(ResidualBlockTranspose, self).__init__() 90 | 91 | self.residual_ratio = residual_ratio 92 | self.shortcut_ratio = 1 - residual_ratio 93 | 94 | self.residual = nn.Sequential( 95 | nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), 96 | nn.PReLU() if activation == 'relu' else nn.Tanh() 97 | ) 98 | 99 | self.shortcut = nn.Sequential( 100 | nn.Upsample(scale_factor=2, mode='linear', align_corners=False) if stride == 2 else nn.Sequential(), 101 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 102 | ) 103 | 104 | def forward(self, input): 105 | return 
self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) 106 | 107 | 108 | class SkeletonResidual(nn.Module): 109 | def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding, padding_mode, bias, extra_conv, pooling_mode, activation, last_pool): 110 | super(SkeletonResidual, self).__init__() 111 | 112 | kernel_even = False if kernel_size % 2 else True 113 | 114 | seq = [] 115 | for _ in range(extra_conv): 116 | # (T, J, D) => (T, J, D) 117 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, 118 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, 119 | stride=1, 120 | padding=padding, padding_mode=padding_mode, bias=bias)) 121 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) 122 | # (T, J, D) => (T/2, J, 2D) 123 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 124 | joint_num=joint_num, kernel_size=kernel_size, stride=stride, 125 | padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) 126 | seq.append(nn.GroupNorm(8, out_channels)) # FIXME: REMEMBER TO CHANGE BACK !!! 127 | self.residual = nn.Sequential(*seq) 128 | 129 | # (T, J, D) => (T/2, J, 2D) 130 | self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 131 | joint_num=joint_num, kernel_size=1, stride=stride, padding=0, 132 | bias=True, add_offset=False) 133 | 134 | seq = [] 135 | # (T/2, J, 2D) => (T/2, J', 2D) 136 | pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode, 137 | channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) 138 | if len(pool.pooling_list) != pool.edge_num: 139 | seq.append(pool) 140 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) 141 | self.common = nn.Sequential(*seq) 142 | 143 | def forward(self, input): 144 | output = self.residual(input) + self.shortcut(input) 145 | 146 | return self.common(output) 147 | 148 | 149 | class SkeletonResidualTranspose(nn.Module): 150 | def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias, extra_conv, pooling_list, upsampling, activation, last_layer): 151 | super(SkeletonResidualTranspose, self).__init__() 152 | 153 | kernel_even = False if kernel_size % 2 else True 154 | 155 | seq = [] 156 | # (T, J, D) => (2T, J, D) 157 | if upsampling is not None: 158 | seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False)) 159 | # (2T, J, D) => (2T, J', D) 160 | unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list)) 161 | if unpool.input_edge_num != unpool.output_edge_num: 162 | seq.append(unpool) 163 | self.common = nn.Sequential(*seq) 164 | 165 | seq = [] 166 | for _ in range(extra_conv): 167 | # (2T, J', D) => (2T, J', D) 168 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, 169 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, 170 | stride=1, 171 | padding=padding, padding_mode=padding_mode, bias=bias)) 172 | seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) 173 | # (2T, J', D) => (2T, J', D/2) 174 | seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 175 | joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, 176 | stride=1, 177 | padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) 178 
| self.residual = nn.Sequential(*seq) 179 | 180 | # (2T, J', D) => (2T, J', D/2) 181 | self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, 182 | joint_num=joint_num, kernel_size=1, stride=1, padding=0, 183 | bias=True, add_offset=False) 184 | 185 | if activation == 'relu': 186 | self.activation = nn.PReLU() if not last_layer else None 187 | else: 188 | self.activation = nn.Tanh() if not last_layer else None 189 | 190 | def forward(self, input): 191 | output = self.common(input) 192 | output = self.residual(output) + self.shortcut(output) 193 | 194 | if self.activation is not None: 195 | return self.activation(output) 196 | else: 197 | return output 198 | -------------------------------------------------------------------------------- /src/nemf/skeleton.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SkeletonConv(nn.Module): 10 | def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0, 11 | bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0): 12 | self.in_channels_per_joint = in_channels // joint_num 13 | self.out_channels_per_joint = out_channels // joint_num 14 | if in_channels % joint_num != 0 or out_channels % joint_num != 0: 15 | raise Exception('BAD') 16 | super(SkeletonConv, self).__init__() 17 | 18 | if padding_mode == 'zeros': 19 | padding_mode = 'constant' 20 | if padding_mode == 'reflection': 21 | padding_mode = 'reflect' 22 | 23 | self.expanded_neighbour_list = [] 24 | self.expanded_neighbour_list_offset = [] 25 | self.neighbour_list = neighbour_list 26 | self.add_offset = add_offset 27 | self.joint_num = joint_num 28 | 29 | self.stride = stride 30 | self.dilation = 1 31 | self.groups = 1 32 | self.padding = padding 33 | self.padding_mode = padding_mode 34 | self._padding_repeated_twice = (padding, padding) 35 | 36 | for neighbour in neighbour_list: 37 | expanded = [] 38 | for k in neighbour: 39 | for i in range(self.in_channels_per_joint): 40 | expanded.append(k * self.in_channels_per_joint + i) 41 | self.expanded_neighbour_list.append(expanded) 42 | 43 | if self.add_offset: 44 | self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels) 45 | 46 | for neighbour in neighbour_list: 47 | expanded = [] 48 | for k in neighbour: 49 | for i in range(add_offset): 50 | expanded.append(k * in_offset_channel + i) 51 | self.expanded_neighbour_list_offset.append(expanded) 52 | 53 | self.weight = torch.zeros(out_channels, in_channels, kernel_size) 54 | if bias: 55 | self.bias = torch.zeros(out_channels) 56 | else: 57 | self.register_parameter('bias', None) 58 | 59 | self.mask = torch.zeros_like(self.weight) 60 | for i, neighbour in enumerate(self.expanded_neighbour_list): 61 | self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] 
= 1 62 | self.mask = nn.Parameter(self.mask, requires_grad=False) 63 | 64 | self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \ 65 | 'joint_num={}, stride={}, padding={}, bias={})'.format( 66 | in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias 67 | ) 68 | 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self): 72 | for i, neighbour in enumerate(self.expanded_neighbour_list): 73 | """ Use temporary variable to avoid assign to copy of slice, which might lead to unexpected result """ 74 | tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), 75 | neighbour, ...]) 76 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) 77 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), 78 | neighbour, ...] = tmp 79 | if self.bias is not None: 80 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out( 81 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...]) 82 | bound = 1 / math.sqrt(fan_in) 83 | tmp = torch.zeros_like( 84 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)]) 85 | nn.init.uniform_(tmp, -bound, bound) 86 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp 87 | 88 | self.weight = nn.Parameter(self.weight) 89 | if self.bias is not None: 90 | self.bias = nn.Parameter(self.bias) 91 | 92 | def set_offset(self, offset): 93 | if not self.add_offset: 94 | raise Exception('Wrong Combination of Parameters') 95 | self.offset = offset.reshape(offset.shape[0], -1) 96 | 97 | def forward(self, input): 98 | # print('SkeletonConv') 99 | weight_masked = self.weight * self.mask 100 | # print(f'input: {input.size()}') 101 | res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 102 | weight_masked, self.bias, self.stride, 103 | 0, self.dilation, self.groups) 104 | 105 | if self.add_offset: 106 | offset_res = self.offset_enc(self.offset) 107 | offset_res = offset_res.reshape(offset_res.shape + (1, )) 108 | res += offset_res / 100 109 | # print(f'res: {res.size()}') 110 | return res 111 | 112 | 113 | class SkeletonLinear(nn.Module): 114 | def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False): 115 | super(SkeletonLinear, self).__init__() 116 | self.neighbour_list = neighbour_list 117 | self.in_channels = in_channels 118 | self.out_channels = out_channels 119 | self.in_channels_per_joint = in_channels // len(neighbour_list) 120 | self.out_channels_per_joint = out_channels // len(neighbour_list) 121 | self.extra_dim1 = extra_dim1 122 | self.expanded_neighbour_list = [] 123 | 124 | for neighbour in neighbour_list: 125 | expanded = [] 126 | for k in neighbour: 127 | for i in range(self.in_channels_per_joint): 128 | expanded.append(k * self.in_channels_per_joint + i) 129 | self.expanded_neighbour_list.append(expanded) 130 | 131 | self.weight = torch.zeros(out_channels, in_channels) 132 | self.mask = torch.zeros(out_channels, in_channels) 133 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 134 | 135 | self.reset_parameters() 136 | 137 | def reset_parameters(self): 138 | for i, neighbour in enumerate(self.expanded_neighbour_list): 139 | tmp = torch.zeros_like( 140 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] 141 | ) 142 | self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, 
neighbour] = 1 143 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) 144 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp 145 | 146 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 147 | bound = 1 / math.sqrt(fan_in) 148 | nn.init.uniform_(self.bias, -bound, bound) 149 | 150 | self.weight = nn.Parameter(self.weight) 151 | self.mask = nn.Parameter(self.mask, requires_grad=False) 152 | 153 | def forward(self, input): 154 | input = input.reshape(input.shape[0], -1) 155 | weight_masked = self.weight * self.mask 156 | res = F.linear(input, weight_masked, self.bias) 157 | if self.extra_dim1: 158 | res = res.reshape(res.shape + (1,)) 159 | return res 160 | 161 | 162 | class SkeletonPool(nn.Module): 163 | def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False): 164 | super(SkeletonPool, self).__init__() 165 | 166 | if pooling_mode != 'mean': 167 | raise Exception('Unimplemented pooling mode in matrix_implementation') 168 | 169 | self.channels_per_edge = channels_per_edge 170 | self.pooling_mode = pooling_mode 171 | self.edge_num = len(edges) 172 | # self.edge_num = len(edges) + 1 173 | self.seq_list = [] 174 | self.pooling_list = [] 175 | self.new_edges = [] 176 | degree = [0] * 100 # each element represents the degree of the corresponding joint 177 | 178 | for edge in edges: 179 | degree[edge[0]] += 1 180 | degree[edge[1]] += 1 181 | 182 | # seq_list contains multiple sub-lists where each sub-list is an edge chain from the joint whose degree > 2 to the end effectors or joints whose degree > 2. 183 | def find_seq(j, seq): 184 | nonlocal self, degree, edges 185 | 186 | if degree[j] > 2 and j != 0: 187 | self.seq_list.append(seq) 188 | seq = [] 189 | 190 | if degree[j] == 1: 191 | self.seq_list.append(seq) 192 | return 193 | 194 | for idx, edge in enumerate(edges): 195 | if edge[0] == j: 196 | find_seq(edge[1], seq + [idx]) 197 | 198 | find_seq(0, []) 199 | # print(f'self.seq_list: {self.seq_list}') 200 | 201 | for seq in self.seq_list: 202 | if last_pool: 203 | self.pooling_list.append(seq) 204 | continue 205 | if len(seq) % 2 == 1: 206 | self.pooling_list.append([seq[0]]) 207 | self.new_edges.append(edges[seq[0]]) 208 | seq = seq[1:] 209 | for i in range(0, len(seq), 2): 210 | self.pooling_list.append([seq[i], seq[i + 1]]) 211 | self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]]) 212 | # print(f'self.pooling_list: {self.pooling_list}') 213 | # print(f'self.new_egdes: {self.new_edges}') 214 | 215 | # add global position 216 | # self.pooling_list.append([self.edge_num - 1]) 217 | 218 | self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format( 219 | len(edges), len(self.pooling_list) 220 | ) 221 | 222 | self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge) 223 | 224 | for i, pair in enumerate(self.pooling_list): 225 | for j in pair: 226 | for c in range(channels_per_edge): 227 | self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair) 228 | 229 | self.weight = nn.Parameter(self.weight, requires_grad=False) 230 | 231 | def forward(self, input: torch.Tensor): 232 | # print('SkeletonPool') 233 | # print(f'input: {input.size()}') 234 | # print(f'self.weight: {self.weight.size()}') 235 | return torch.matmul(self.weight, input) 236 | 237 | 238 | class SkeletonUnpool(nn.Module): 239 | def __init__(self, pooling_list, channels_per_edge): 240 | super(SkeletonUnpool, self).__init__() 241 | self.pooling_list = pooling_list 
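# The unpooling weight assembled below is a fixed (requires_grad=False) 0/1 matrix that copies each
# pooled edge's channels back to every original edge in its pooling group, reversing the grouping
# performed by SkeletonPool.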
242 | self.input_edge_num = len(pooling_list) 243 | self.output_edge_num = 0 244 | self.channels_per_edge = channels_per_edge 245 | for t in self.pooling_list: 246 | self.output_edge_num += len(t) 247 | 248 | self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format( 249 | self.input_edge_num, self.output_edge_num, 250 | ) 251 | 252 | self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge) 253 | 254 | for i, pair in enumerate(self.pooling_list): 255 | for j in pair: 256 | for c in range(channels_per_edge): 257 | self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1 258 | 259 | self.weight = nn.Parameter(self.weight) 260 | self.weight.requires_grad_(False) 261 | 262 | def forward(self, input: torch.Tensor): 263 | # print('SkeletonUnpool') 264 | # print(f'input: {input.size()}') 265 | # print(f'self.weight: {self.weight.size()}') 266 | return torch.matmul(self.weight, input) 267 | 268 | 269 | """ 270 | Helper functions for skeleton operation 271 | """ 272 | 273 | 274 | def dfs(x, fa, vis, dist): 275 | vis[x] = 1 276 | for y in range(len(fa)): 277 | if (fa[y] == x or fa[x] == y) and vis[y] == 0: 278 | dist[y] = dist[x] + 1 279 | dfs(y, fa, vis, dist) 280 | 281 | 282 | """ 283 | def find_neighbor_joint(fa, threshold): 284 | neighbor_list = [[]] 285 | for x in range(1, len(fa)): 286 | vis = [0 for _ in range(len(fa))] 287 | dist = [0 for _ in range(len(fa))] 288 | dist[0] = 10000 289 | dfs(x, fa, vis, dist) 290 | neighbor = [] 291 | for j in range(1, len(fa)): 292 | if dist[j] <= threshold: 293 | neighbor.append(j) 294 | neighbor_list.append(neighbor) 295 | 296 | neighbor = [0] 297 | for i, x in enumerate(neighbor_list): 298 | if i == 0: continue 299 | if 1 in x: 300 | neighbor.append(i) 301 | neighbor_list[i] = [0] + neighbor_list[i] 302 | neighbor_list[0] = neighbor 303 | return neighbor_list 304 | 305 | 306 | def build_edge_topology(topology, offset): 307 | # get all edges (pa, child, offset) 308 | edges = [] 309 | joint_num = len(topology) 310 | for i in range(1, joint_num): 311 | edges.append((topology[i], i, offset[i])) 312 | return edges 313 | """ 314 | 315 | 316 | def build_edge_topology(topology): 317 | # get all edges (pa, child) 318 | edges = [] 319 | joint_num = len(topology) 320 | edges.append((0, joint_num)) # add an edge between the root joint and a virtual joint 321 | for i in range(1, joint_num): 322 | edges.append((topology[i], i)) 323 | return edges 324 | 325 | 326 | def build_joint_topology(edges, origin_names): 327 | parent = [] 328 | offset = [] 329 | names = [] 330 | edge2joint = [] 331 | joint_from_edge = [] # -1 means virtual joint 332 | joint_cnt = 0 333 | out_degree = [0] * (len(edges) + 10) 334 | for edge in edges: 335 | out_degree[edge[0]] += 1 336 | 337 | # add root joint 338 | joint_from_edge.append(-1) 339 | parent.append(0) 340 | offset.append(np.array([0, 0, 0])) 341 | names.append(origin_names[0]) 342 | joint_cnt += 1 343 | 344 | def make_topology(edge_idx, pa): 345 | nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt 346 | edge = edges[edge_idx] 347 | if out_degree[edge[0]] > 1: 348 | parent.append(pa) 349 | offset.append(np.array([0, 0, 0])) 350 | names.append(origin_names[edge[1]] + '_virtual') 351 | edge2joint.append(-1) 352 | pa = joint_cnt 353 | joint_cnt += 1 354 | 355 | parent.append(pa) 356 | offset.append(edge[2]) 357 | names.append(origin_names[edge[1]]) 358 | edge2joint.append(edge_idx) 359 | pa = joint_cnt 360 | joint_cnt += 1 361 | 
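# Recurse into every edge whose parent joint is the child joint of the current edge.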
362 | for idx, e in enumerate(edges): 363 | if e[0] == edge[1]: 364 | make_topology(idx, pa) 365 | 366 | for idx, e in enumerate(edges): 367 | if e[0] == 0: 368 | make_topology(idx, 0) 369 | 370 | return parent, offset, names, edge2joint 371 | 372 | 373 | def calc_edge_mat(edges): 374 | edge_num = len(edges) 375 | # edge_mat[i][j] = distance between edge(i) and edge(j) 376 | edge_mat = [[100000] * edge_num for _ in range(edge_num)] 377 | for i in range(edge_num): 378 | edge_mat[i][i] = 0 379 | 380 | # initialize edge_mat with direct neighbor 381 | for i, a in enumerate(edges): 382 | for j, b in enumerate(edges): 383 | link = 0 384 | for x in range(2): 385 | for y in range(2): 386 | if a[x] == b[y]: 387 | link = 1 388 | if link: 389 | edge_mat[i][j] = 1 390 | 391 | # calculate all the pairs distance 392 | for k in range(edge_num): 393 | for i in range(edge_num): 394 | for j in range(edge_num): 395 | edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j]) 396 | return edge_mat 397 | 398 | 399 | def find_neighbor(edges, d): 400 | """ 401 | Args: 402 | edges: The list contains N elements, each element represents (parent, child). 403 | d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1). 404 | 405 | Returns: 406 | The list contains N elements, each element is a list of edge indices whose distance <= d. 407 | """ 408 | edge_mat = calc_edge_mat(edges) 409 | neighbor_list = [] 410 | edge_num = len(edge_mat) 411 | for i in range(edge_num): 412 | neighbor = [] 413 | for j in range(edge_num): 414 | if edge_mat[i][j] <= d: 415 | neighbor.append(j) 416 | neighbor_list.append(neighbor) 417 | 418 | # # add neighbor for global part 419 | # global_part_neighbor = neighbor_list[0].copy() 420 | # """ 421 | # Line #373 is buggy. Thanks @crissallan!! 422 | # See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30) 423 | # However, fixing this bug will make it unable to load the pretrained model and 424 | # affect the reproducibility of quantitative error reported in the paper. 425 | # It is not a fatal bug so we didn't touch it and we are looking for possible solutions. 426 | # """ 427 | # for i in global_part_neighbor: 428 | # neighbor_list[i].append(edge_num) 429 | # neighbor_list.append(global_part_neighbor) 430 | 431 | return neighbor_list 432 | 433 | 434 | def calc_node_depth(topology): 435 | def dfs(node, topology): 436 | if topology[node] < 0: 437 | return 0 438 | return 1 + dfs(topology[node], topology) 439 | depth = [] 440 | for i in range(len(topology)): 441 | depth.append(dfs(i, topology)) 442 | 443 | return depth 444 | -------------------------------------------------------------------------------- /src/soft_dtw_cuda.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Mehran Maghoumi 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # ---------------------------------------------------------------------------------------------------------------------- 23 | 24 | import numpy as np 25 | import torch 26 | import torch.cuda 27 | from numba import jit 28 | from torch.autograd import Function 29 | from numba import cuda 30 | import math 31 | 32 | # ---------------------------------------------------------------------------------------------------------------------- 33 | @cuda.jit 34 | def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R): 35 | """ 36 | :param seq_len: The length of the sequence (both inputs are assumed to be of the same size) 37 | :param n_passes: 2 * seq_len - 1 (The number of anti-diagonals) 38 | """ 39 | # Each block processes one pair of examples 40 | b = cuda.blockIdx.x 41 | # We have as many threads as seq_len, because the most number of threads we need 42 | # is equal to the number of elements on the largest anti-diagonal 43 | tid = cuda.threadIdx.x 44 | 45 | # Compute I, J, the indices from [0, seq_len) 46 | 47 | # The row index is always the same as tid 48 | I = tid 49 | 50 | inv_gamma = 1.0 / gamma 51 | 52 | # Go over each anti-diagonal. Only process threads that fall on the current on the anti-diagonal 53 | for p in range(n_passes): 54 | 55 | # The index is actually 'p - tid' but need to force it in-bounds 56 | J = max(0, min(p - tid, max_j - 1)) 57 | 58 | # For simplicity, we define i, j which start from 1 (offset from I, J) 59 | i = I + 1 60 | j = J + 1 61 | 62 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds 63 | if I + J == p and (I < max_i and J < max_j): 64 | # Don't compute if outside bandwidth 65 | if not (abs(i - j) > bandwidth > 0): 66 | r0 = -R[b, i - 1, j - 1] * inv_gamma 67 | r1 = -R[b, i - 1, j] * inv_gamma 68 | r2 = -R[b, i, j - 1] * inv_gamma 69 | rmax = max(max(r0, r1), r2) 70 | rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax) 71 | softmin = -gamma * (math.log(rsum) + rmax) 72 | R[b, i, j] = D[b, i - 1, j - 1] + softmin 73 | 74 | # Wait for other threads in this block 75 | cuda.syncthreads() 76 | 77 | # ---------------------------------------------------------------------------------------------------------------------- 78 | @cuda.jit 79 | def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E): 80 | k = cuda.blockIdx.x 81 | tid = cuda.threadIdx.x 82 | 83 | # Indexing logic is the same as above, however, the anti-diagonal needs to 84 | # progress backwards 85 | I = tid 86 | 87 | for p in range(n_passes): 88 | # Reverse the order to make the loop go backward 89 | rev_p = n_passes - p - 1 90 | 91 | # convert tid to I, J, then i, j 92 | J = max(0, min(rev_p - tid, max_j - 1)) 93 | 94 | i = I + 1 95 | j = J + 1 96 | 97 | # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds 98 | if I + J == rev_p and (I < max_i and J < max_j): 99 | 100 | if math.isinf(R[k, i, j]): 101 | R[k, i, j] = -math.inf 102 | 103 | # Don't compute 
if outside bandwidth 104 | if not (abs(i - j) > bandwidth > 0): 105 | a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma) 106 | b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma) 107 | c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma) 108 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 109 | 110 | # Wait for other threads in this block 111 | cuda.syncthreads() 112 | 113 | # ---------------------------------------------------------------------------------------------------------------------- 114 | class _SoftDTWCUDA(Function): 115 | """ 116 | CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444: 117 | "Developing a pattern discovery method in time series data and its GPU acceleration" 118 | """ 119 | 120 | @staticmethod 121 | def forward(ctx, D, gamma, bandwidth): 122 | dev = D.device 123 | dtype = D.dtype 124 | gamma = torch.cuda.FloatTensor([gamma]) 125 | bandwidth = torch.cuda.FloatTensor([bandwidth]) 126 | 127 | B = D.shape[0] 128 | N = D.shape[1] 129 | M = D.shape[2] 130 | threads_per_block = max(N, M) 131 | n_passes = 2 * threads_per_block - 1 132 | 133 | # Prepare the output array 134 | R = torch.ones((B, N + 2, M + 2), device=dev, dtype=dtype) * math.inf 135 | R[:, 0, 0] = 0 136 | 137 | # Run the CUDA kernel. 138 | # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair) 139 | # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal) 140 | compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()), 141 | gamma.item(), bandwidth.item(), N, M, n_passes, 142 | cuda.as_cuda_array(R)) 143 | ctx.save_for_backward(D, R.clone(), gamma, bandwidth) 144 | return R[:, -2, -2] 145 | 146 | @staticmethod 147 | def backward(ctx, grad_output): 148 | dev = grad_output.device 149 | dtype = grad_output.dtype 150 | D, R, gamma, bandwidth = ctx.saved_tensors 151 | 152 | B = D.shape[0] 153 | N = D.shape[1] 154 | M = D.shape[2] 155 | threads_per_block = max(N, M) 156 | n_passes = 2 * threads_per_block - 1 157 | 158 | D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev) 159 | D_[:, 1:N + 1, 1:M + 1] = D 160 | 161 | R[:, :, -1] = -math.inf 162 | R[:, -1, :] = -math.inf 163 | R[:, -1, -1] = R[:, -2, -2] 164 | 165 | E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev) 166 | E[:, -1, -1] = 1 167 | 168 | # Grid and block sizes are set same as done above for the forward() call 169 | compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_), 170 | cuda.as_cuda_array(R), 171 | 1.0 / gamma.item(), bandwidth.item(), N, M, n_passes, 172 | cuda.as_cuda_array(E)) 173 | E = E[:, 1:N + 1, 1:M + 1] 174 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None 175 | 176 | 177 | # ---------------------------------------------------------------------------------------------------------------------- 178 | # 179 | # The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw 180 | # Credit goes to Kanru Hua. 181 | # I've added support for batching and pruning. 
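#
# Both the CUDA kernel above and the Numba-compiled CPU routines below evaluate the soft-DTW recursion
#   R[i, j] = D[i-1, j-1] + softmin_gamma(R[i-1, j-1], R[i-1, j], R[i, j-1]),
# where softmin_gamma(a, b, c) = -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma)) is
# computed with a max shift for numerical stability, and cells with |i - j| > bandwidth are skipped
# whenever a positive (Sakoe-Chiba) bandwidth is given.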
182 | # 183 | # ---------------------------------------------------------------------------------------------------------------------- 184 | @jit(nopython=True) 185 | def compute_softdtw(D, gamma, bandwidth): 186 | B = D.shape[0] 187 | N = D.shape[1] 188 | M = D.shape[2] 189 | R = np.ones((B, N + 2, M + 2)) * np.inf 190 | R[:, 0, 0] = 0 191 | for b in range(B): 192 | for j in range(1, M + 1): 193 | for i in range(1, N + 1): 194 | 195 | # Check the pruning condition 196 | if 0 < bandwidth < np.abs(i - j): 197 | continue 198 | 199 | r0 = -R[b, i - 1, j - 1] / gamma 200 | r1 = -R[b, i - 1, j] / gamma 201 | r2 = -R[b, i, j - 1] / gamma 202 | rmax = max(max(r0, r1), r2) 203 | rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) 204 | softmin = - gamma * (np.log(rsum) + rmax) 205 | R[b, i, j] = D[b, i - 1, j - 1] + softmin 206 | return R 207 | 208 | # ---------------------------------------------------------------------------------------------------------------------- 209 | @jit(nopython=True) 210 | def compute_softdtw_backward(D_, R, gamma, bandwidth): 211 | B = D_.shape[0] 212 | N = D_.shape[1] 213 | M = D_.shape[2] 214 | D = np.zeros((B, N + 2, M + 2)) 215 | E = np.zeros((B, N + 2, M + 2)) 216 | D[:, 1:N + 1, 1:M + 1] = D_ 217 | E[:, -1, -1] = 1 218 | R[:, :, -1] = -np.inf 219 | R[:, -1, :] = -np.inf 220 | R[:, -1, -1] = R[:, -2, -2] 221 | for k in range(B): 222 | for j in range(M, 0, -1): 223 | for i in range(N, 0, -1): 224 | 225 | if np.isinf(R[k, i, j]): 226 | R[k, i, j] = -np.inf 227 | 228 | # Check the pruning condition 229 | if 0 < bandwidth < np.abs(i - j): 230 | continue 231 | 232 | a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma 233 | b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma 234 | c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma 235 | a = np.exp(a0) 236 | b = np.exp(b0) 237 | c = np.exp(c0) 238 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 239 | return E[:, 1:N + 1, 1:M + 1] 240 | 241 | # ---------------------------------------------------------------------------------------------------------------------- 242 | class _SoftDTW(Function): 243 | """ 244 | CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw 245 | """ 246 | 247 | @staticmethod 248 | def forward(ctx, D, gamma, bandwidth): 249 | dev = D.device 250 | dtype = D.dtype 251 | gamma = torch.Tensor([gamma]).to(dev).type(dtype) # dtype fixed 252 | bandwidth = torch.Tensor([bandwidth]).to(dev).type(dtype) 253 | D_ = D.detach().cpu().numpy() 254 | g_ = gamma.item() 255 | b_ = bandwidth.item() 256 | R = torch.Tensor(compute_softdtw(D_, g_, b_)).to(dev).type(dtype) 257 | ctx.save_for_backward(D, R, gamma, bandwidth) 258 | return R[:, -2, -2] 259 | 260 | @staticmethod 261 | def backward(ctx, grad_output): 262 | dev = grad_output.device 263 | dtype = grad_output.dtype 264 | D, R, gamma, bandwidth = ctx.saved_tensors 265 | D_ = D.detach().cpu().numpy() 266 | R_ = R.detach().cpu().numpy() 267 | g_ = gamma.item() 268 | b_ = bandwidth.item() 269 | E = torch.Tensor(compute_softdtw_backward(D_, R_, g_, b_)).to(dev).type(dtype) 270 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None 271 | 272 | # ---------------------------------------------------------------------------------------------------------------------- 273 | class SoftDTW(torch.nn.Module): 274 | """ 275 | The soft DTW implementation that optionally supports CUDA 276 | """ 277 | 278 | def __init__(self, use_cuda, gamma=1.0, normalize=False, 
bandwidth=None, dist_func=None): 279 | """ 280 | Initializes a new instance using the supplied parameters 281 | :param use_cuda: Flag indicating whether the CUDA implementation should be used 282 | :param gamma: sDTW's gamma parameter 283 | :param normalize: Flag indicating whether to perform normalization 284 | (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790) 285 | :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning. 286 | :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used. 287 | """ 288 | super(SoftDTW, self).__init__() 289 | self.normalize = normalize 290 | self.gamma = gamma 291 | self.bandwidth = 0 if bandwidth is None else float(bandwidth) 292 | self.use_cuda = use_cuda 293 | 294 | # Set the distance function 295 | if dist_func is not None: 296 | self.dist_func = dist_func 297 | else: 298 | self.dist_func = SoftDTW._euclidean_dist_func 299 | 300 | def _get_func_dtw(self, x, y): 301 | """ 302 | Checks the inputs and selects the proper implementation to use. 303 | """ 304 | bx, lx, dx = x.shape 305 | by, ly, dy = y.shape 306 | # Make sure the dimensions match 307 | assert bx == by # Equal batch sizes 308 | assert dx == dy # Equal feature dimensions 309 | 310 | use_cuda = self.use_cuda 311 | 312 | if use_cuda and (lx > 1024 or ly > 1024): # We should be able to spawn enough threads in CUDA 313 | print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)") 314 | use_cuda = False 315 | 316 | # Finally, return the correct function 317 | return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply 318 | 319 | @staticmethod 320 | def _euclidean_dist_func(x, y): 321 | """ 322 | Calculates the Euclidean distance between each element in x and y per timestep 323 | """ 324 | n = x.size(1) 325 | m = y.size(1) 326 | d = x.size(2) 327 | x = x.unsqueeze(2).expand(-1, n, m, d) 328 | y = y.unsqueeze(1).expand(-1, n, m, d) 329 | return torch.pow(x - y, 2).sum(3) 330 | 331 | def forward(self, X, Y): 332 | """ 333 | Compute the soft-DTW value between X and Y 334 | :param X: One batch of examples, batch_size x seq_len x dims 335 | :param Y: The other batch of examples, batch_size x seq_len x dims 336 | :return: The computed results 337 | """ 338 | 339 | # Check the inputs and get the correct implementation 340 | func_dtw = self._get_func_dtw(X, Y) 341 | 342 | if self.normalize: 343 | # Stack everything up and run 344 | x = torch.cat([X, X, Y]) 345 | y = torch.cat([Y, X, Y]) 346 | D = self.dist_func(x, y) 347 | out = func_dtw(D, self.gamma, self.bandwidth) 348 | out_xy, out_xx, out_yy = torch.split(out, X.shape[0]) 349 | return out_xy - 1 / 2 * (out_xx + out_yy) 350 | else: 351 | D_xy = self.dist_func(X, Y) 352 | return func_dtw(D_xy, self.gamma, self.bandwidth) 353 | 354 | # ---------------------------------------------------------------------------------------------------------------------- 355 | def timed_run(a, b, sdtw): 356 | """ 357 | Runs a and b through sdtw, and times the forward and backward passes. 358 | Assumes that a requires gradients. 
359 | :return: timing, forward result, backward result 360 | """ 361 | from timeit import default_timer as timer 362 | 363 | # Forward pass 364 | start = timer() 365 | forward = sdtw(a, b) 366 | end = timer() 367 | t = end - start 368 | 369 | grad_outputs = torch.ones_like(forward) 370 | 371 | # Backward 372 | start = timer() 373 | grads = torch.autograd.grad(forward, a, grad_outputs=grad_outputs)[0] 374 | end = timer() 375 | 376 | # Total time 377 | t += end - start 378 | 379 | return t, forward, grads 380 | 381 | # ---------------------------------------------------------------------------------------------------------------------- 382 | def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward): 383 | sdtw = SoftDTW(False, gamma=1.0, normalize=False) 384 | sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False) 385 | n_iters = 6 386 | 387 | print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims)) 388 | 389 | times_cpu = [] 390 | times_gpu = [] 391 | 392 | for i in range(n_iters): 393 | a_cpu = torch.rand((batch_size, seq_len_a, dims), requires_grad=True) 394 | b_cpu = torch.rand((batch_size, seq_len_b, dims)) 395 | a_gpu = a_cpu.cuda() 396 | b_gpu = b_cpu.cuda() 397 | 398 | # GPU 399 | t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda) 400 | 401 | # CPU 402 | t_cpu, forward_cpu, backward_cpu = timed_run(a_cpu, b_cpu, sdtw) 403 | 404 | # Verify the results 405 | assert torch.allclose(forward_cpu, forward_gpu.cpu()) 406 | assert torch.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward) 407 | 408 | if i > 0: # Ignore the first time we run, in case this is a cold start (because timings are off at a cold start of the script) 409 | times_cpu += [t_cpu] 410 | times_gpu += [t_gpu] 411 | 412 | # Average and log 413 | avg_cpu = np.mean(times_cpu) 414 | avg_gpu = np.mean(times_gpu) 415 | print("\tCPU: ", avg_cpu) 416 | print("\tGPU: ", avg_gpu) 417 | print("\tSpeedup: ", avg_cpu / avg_gpu) 418 | print() 419 | 420 | # ---------------------------------------------------------------------------------------------------------------------- 421 | if __name__ == "__main__": 422 | from timeit import default_timer as timer 423 | 424 | torch.manual_seed(1234) 425 | 426 | profile(128, 17, 15, 2, tol_backward=1e-6) 427 | profile(512, 64, 64, 2, tol_backward=1e-4) 428 | profile(512, 256, 256, 2, tol_backward=1e-3) 429 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data.dataloader import DataLoader 11 | from tqdm import tqdm 12 | 13 | from datasets.amass import AMASS 14 | from arguments import Arguments 15 | from nemf.generative import Architecture 16 | 17 | 18 | def train(): 19 | model = Architecture(args, ngpu, len(train_data_loader)) 20 | model.print() 21 | model.setup() 22 | 23 | loss_min = None 24 | if args.epoch_begin != 0: 25 | model.load(epoch=args.epoch_begin) 26 | model.eval() 27 | for data in valid_data_loader: 28 | model.set_input(data) 29 | model.validate() 30 | loss_min = model.verbose() 31 | 32 | epoch_begin = args.epoch_begin + 1 33 | epoch_end = epoch_begin + args.epoch_num 34 | start_time = time.time() 35 | 36 | for epoch in range(epoch_begin, epoch_end): 37 | 
model.train() 38 | with tqdm(train_data_loader, unit="batch") as tepoch: 39 | for data in tepoch: 40 | tepoch.set_description(f"Epoch {epoch}") 41 | model.set_input(data) 42 | model.optimize_parameters() 43 | 44 | model.eval() 45 | for data in valid_data_loader: 46 | model.set_input(data) 47 | model.validate() 48 | 49 | model.epoch() 50 | res = model.verbose() 51 | 52 | if args.verbose: 53 | print(f'Epoch {epoch}/{epoch_end - 1}:') 54 | print(json.dumps(res, sort_keys=True, indent=4)) 55 | 56 | if loss_min is None or res['total_loss']['val'] < loss_min['total_loss']['val']: 57 | loss_min = res 58 | model.save(optimal=True) 59 | 60 | if epoch % args.checkpoint == 0 or epoch == epoch_end - 1: 61 | model.save() 62 | 63 | end_time = time.time() 64 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}') 65 | print('Final Loss:') 66 | print(json.dumps(loss_min, sort_keys=True, indent=4)) 67 | df = pd.DataFrame.from_dict(loss_min) 68 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}.csv'), index=False) 69 | 70 | 71 | def test(steps): 72 | model = Architecture(args, ngpu, len(test_data_loader)) 73 | model.load(optimal=True) 74 | model.eval() 75 | 76 | for step in steps: 77 | statistics = dict() 78 | model.test_index = 0 79 | for data in test_data_loader: 80 | model.set_input(data) 81 | model.super_sampling(step=step) 82 | if step == 1.0: 83 | errors = model.report_errors() 84 | if not statistics: 85 | statistics = { 86 | 'rotation_error': [errors['rotation'] * 180.0 / np.pi], 87 | 'position_error': [errors['position'] * 100.0], 88 | 'orientation_error': [errors['orientation'] * 180.0 / np.pi], 89 | 'translation_error': [errors['translation'] * 100.0] 90 | } 91 | else: 92 | statistics['rotation_error'].append(errors['rotation'] * 180.0 / np.pi) 93 | statistics['position_error'].append(errors['position'] * 100.0) 94 | statistics['orientation_error'].append(errors['orientation'] * 180.0 / np.pi) 95 | statistics['translation_error'].append(errors['translation'] * 100.0) 96 | 97 | if step == 1.0: 98 | df = pd.DataFrame.from_dict(statistics) 99 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}_test.csv'), index=False) 100 | 101 | 102 | if __name__ == '__main__': 103 | if len(sys.argv) == 1: 104 | args = Arguments('./configs', filename='generative.yaml') 105 | else: 106 | args = Arguments('./configs', filename=sys.argv[1]) 107 | print(json.dumps(args.json, sort_keys=True, indent=4)) 108 | 109 | torch.set_default_dtype(torch.float32) 110 | 111 | torch.manual_seed(0) 112 | np.random.seed(0) 113 | random.seed(0) 114 | 115 | ngpu = 1 116 | if args.multi_gpu is True: 117 | ngpu = torch.cuda.device_count() 118 | if ngpu == 1: 119 | args.multi_gpu = False 120 | print(f'Number of GPUs: {ngpu}') 121 | 122 | # dataset definition 123 | train_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'train')) 124 | train_data_loader = DataLoader(train_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True, drop_last=True) 125 | valid_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'valid')) 126 | valid_data_loader = DataLoader(valid_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True) 127 | test_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'test')) 128 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True) 129 | 130 | if args.is_train: 131 | train() 132 | 
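    # Note: the config name is taken from sys.argv[1] and resolved against ./configs
    # (defaulting to generative.yaml), so an invocation would look something like
    # `python train.py generative.yaml`, run from a directory where ./configs resolves.
    # After training, is_train is flipped below so the same run also evaluates the
    # optimal checkpoint: step=1.0 reports reconstruction errors to <filename>_test.csv,
    # while step=0.5 additionally exercises super-sampling at a finer temporal step.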
133 | args.is_train = False 134 | test(steps=[0.5, 1.0]) 135 | -------------------------------------------------------------------------------- /src/train_basic.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | 11 | from arguments import Arguments 12 | from nemf.basic import Architecture 13 | 14 | 15 | def load_data(data_path, index, frames): 16 | data = {} 17 | for key in ['pos', 'global_xform', 'root_orient', 'root_vel', 'trans']: 18 | data[key] = torch.load(os.path.join(data_path, f'{key}_{index}.pt')) 19 | if frames != -1: 20 | data[key] = data[key][350:frames + 350].unsqueeze(0) 21 | else: 22 | data[key] = data[key].unsqueeze(0) 23 | 24 | return data 25 | 26 | 27 | def train_basic(frames, save_dir, steps): 28 | data_path = os.path.join(args.dataset_dir, 'train') 29 | file_indices = list(range(16)) if args.amass_data else [22] 30 | 31 | statistics = dict() 32 | for index in file_indices: 33 | print(f'Fitting sequence {index} with {frames} frames:') 34 | args.save_dir = os.path.join(save_dir, f'frame_{frames}', f'sequence_{index}') 35 | model = Architecture(args, ngpu) 36 | model.setup() 37 | model.print() 38 | print(f'# of parameters: {model.count_params()}') 39 | model.set_input(load_data(data_path, index, frames)) 40 | 41 | start_time = time.time() 42 | 43 | iterations = args.iterations * (frames // 32) if args.amass_data else args.iterations 44 | if args.is_train: 45 | model.train() 46 | loss_min = None 47 | for iter in range(iterations): 48 | model.optimize_parameters() 49 | 50 | model.epoch() 51 | res = model.verbose() 52 | 53 | if args.verbose: 54 | print(f'Iteration {iter}/{iterations}:') 55 | print(json.dumps(res, sort_keys=True, indent=4)) 56 | 57 | if loss_min is None or res['total_loss']['train'] < loss_min['total_loss']['train']: 58 | loss_min = res 59 | model.save(optimal=True) 60 | 61 | end_time = time.time() 62 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}') 63 | print('Final Loss:') 64 | print(json.dumps(loss_min, sort_keys=True, indent=4)) 65 | 66 | model.eval() 67 | model.load(optimal=True) 68 | 69 | for step in steps: 70 | model.super_sampling(step) 71 | 72 | if step == 1.0: 73 | errors = model.report_errors() 74 | if not statistics: 75 | statistics = { 76 | 'iterations': [iterations], 77 | 'rotation_error': [errors['rotation'] * 180.0 / np.pi], 78 | 'position_error': [errors['position'] * 100.0], 79 | 'orientation_error': [errors['orientation'] * 180.0 / np.pi], 80 | 'translation_error': [errors['translation'] * 100.0] 81 | } 82 | else: 83 | statistics['iterations'].append(iterations) 84 | statistics['rotation_error'].append(errors['rotation'] * 180.0 / np.pi) 85 | statistics['position_error'].append(errors['position'] * 100.0) 86 | statistics['orientation_error'].append(errors['orientation'] * 180.0 / np.pi) 87 | statistics['translation_error'].append(errors['translation'] * 100.0) 88 | 89 | if statistics: 90 | df = pd.DataFrame.from_dict(statistics) 91 | df.to_csv(os.path.join(save_dir, f'recon_{frames}.csv'), index=False) 92 | 93 | 94 | if __name__ == '__main__': 95 | if len(sys.argv) == 1: 96 | args = Arguments('./configs', filename='basic.yaml') 97 | else: 98 | args = Arguments('./configs', filename=sys.argv[1]) 99 | print(json.dumps(args.json, sort_keys=True, indent=4)) 100 | 101 | torch.set_default_dtype(torch.float32) 102 | 
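    # Note: train_basic over-fits one basic NeMF model per sequence rather than
    # training across a dataset. The seeds below keep those per-sequence fits
    # reproducible, and the frames/steps lists at the bottom of this block sweep
    # window lengths (32-512 frames on AMASS, the whole clip otherwise) and
    # super-sampling steps for the reconstruction errors written to recon_<frames>.csv.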
103 | torch.manual_seed(0) 104 | np.random.seed(0) 105 | random.seed(0) 106 | 107 | ngpu = 1 108 | if args.multi_gpu is True: 109 | ngpu = torch.cuda.device_count() 110 | if ngpu == 1: 111 | args.multi_gpu = False 112 | print(f'Number of GPUs: {ngpu}') 113 | 114 | save_dir = args.save_dir 115 | frames = [32, 64, 128, 256, 512] if args.amass_data else [-1] 116 | steps = [1.0, 0.5, 0.25, 0.125] if args.amass_data else [1.0] 117 | for f in frames: 118 | train_basic(frames=f, save_dir=save_dir, steps=steps) 119 | -------------------------------------------------------------------------------- /src/train_gmp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data.dataloader import DataLoader 11 | from tqdm import tqdm 12 | 13 | from arguments import Arguments 14 | from datasets.amass import AMASS 15 | from nemf.global_motion import GlobalMotionPredictor 16 | 17 | 18 | def train(): 19 | model = GlobalMotionPredictor(args, ngpu) 20 | model.print() 21 | model.setup() 22 | 23 | loss_min = None 24 | if args.epoch_begin != 0: 25 | model.load(epoch=args.epoch_begin) 26 | model.eval() 27 | for data in valid_data_loader: 28 | model.set_input(data) 29 | model.validate() 30 | loss_min = model.verbose() 31 | 32 | epoch_begin = args.epoch_begin + 1 33 | epoch_end = epoch_begin + args.epoch_num 34 | start_time = time.time() 35 | 36 | for epoch in range(epoch_begin, epoch_end): 37 | model.train() 38 | with tqdm(train_data_loader, unit="batch") as tepoch: 39 | for data in tepoch: 40 | tepoch.set_description(f"Epoch {epoch}") 41 | model.set_input(data) 42 | model.optimize_parameters() 43 | 44 | model.eval() 45 | for data in valid_data_loader: 46 | model.set_input(data) 47 | model.validate() 48 | 49 | model.epoch() 50 | res = model.verbose() 51 | 52 | if args.verbose: 53 | print(f'Epoch {epoch}/{epoch_end - 1}:') 54 | print(json.dumps(res, sort_keys=True, indent=4)) 55 | 56 | if loss_min is None or res['total_loss']['val'] < loss_min['total_loss']['val']: 57 | loss_min = res 58 | model.save(optimal=True) 59 | 60 | if epoch % args.checkpoint == 0 or epoch == epoch_end - 1: 61 | model.save() 62 | 63 | end_time = time.time() 64 | print(f'Training finished in {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}') 65 | print('Final Loss:') 66 | print(json.dumps(loss_min, sort_keys=True, indent=4)) 67 | df = pd.DataFrame.from_dict(loss_min) 68 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}.csv'), index=False) 69 | 70 | 71 | def test(): 72 | model = GlobalMotionPredictor(args, ngpu) 73 | model.load(optimal=True) 74 | model.eval() 75 | 76 | statistics = dict() 77 | for data in test_data_loader: 78 | model.set_input(data) 79 | model.test() 80 | 81 | errors = model.report_errors() 82 | if not statistics: 83 | statistics = { 84 | 'translation_error': [errors['translation'] * 100.0] 85 | } 86 | else: 87 | statistics['translation_error'].append(errors['translation'] * 100.0) 88 | 89 | df = pd.DataFrame.from_dict(statistics) 90 | df.to_csv(os.path.join(args.save_dir, f'{args.filename}_test.csv'), index=False) 91 | 92 | 93 | if __name__ == '__main__': 94 | if len(sys.argv) == 1: 95 | args = Arguments('./configs', filename='gmp.yaml') 96 | else: 97 | args = Arguments('./configs', filename=sys.argv[1]) 98 | print(json.dumps(args.json, sort_keys=True, indent=4)) 99 | 100 | 
torch.set_default_dtype(torch.float32) 101 | 102 | torch.manual_seed(0) 103 | np.random.seed(0) 104 | random.seed(0) 105 | 106 | ngpu = 1 107 | if args.multi_gpu is True: 108 | ngpu = torch.cuda.device_count() 109 | if ngpu == 1: 110 | args.multi_gpu = False 111 | print(f'Number of GPUs: {ngpu}') 112 | 113 | # dataset definition 114 | train_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'train')) 115 | train_data_loader = DataLoader(train_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True) 116 | valid_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'valid')) 117 | valid_data_loader = DataLoader(valid_dataset, batch_size=ngpu * args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True) 118 | test_dataset = AMASS(dataset_dir=os.path.join(args.dataset_dir, 'test')) 119 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True) 120 | 121 | if args.is_train: 122 | train() 123 | 124 | args.is_train = False 125 | test() 126 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import plyfile 5 | import torch 6 | import torch.nn.functional as F 7 | from human_body_prior.tools.omni_tools import copy2cpu as c2c 8 | from scipy.spatial.transform import Rotation as R 9 | from scipy.spatial.transform import Slerp 10 | 11 | import holden.BVH as BVH 12 | from holden.Animation import Animation 13 | from holden.Quaternions import Quaternions 14 | 15 | SMPL_JOINTS = { 16 | 'Pelvis': 0, 17 | 'L_Hip': 1, 'L_Knee': 4, 'L_Ankle': 7, 'L_Foot': 10, 18 | 'R_Hip': 2, 'R_Knee': 5, 'R_Ankle': 8, 'R_Foot': 11, 19 | 'Spine1': 3, 'Spine2': 6, 'Spine3': 9, 'Neck': 12, 'Head': 15, 20 | 'L_Collar': 13, 'L_Shoulder': 16, 'L_Elbow': 18, 'L_Wrist': 20, 'L_Hand': 22, 21 | 'R_Collar': 14, 'R_Shoulder': 17, 'R_Elbow': 19, 'R_Wrist': 21, 'R_Hand': 23 22 | } 23 | 24 | FOOT_IDX = [SMPL_JOINTS['L_Ankle'], SMPL_JOINTS['R_Ankle'], SMPL_JOINTS['L_Foot'], SMPL_JOINTS['R_Foot']] 25 | 26 | CONTACTS_IDX = [SMPL_JOINTS['L_Ankle'], SMPL_JOINTS['R_Ankle'], 27 | SMPL_JOINTS['L_Foot'], SMPL_JOINTS['R_Foot'], 28 | SMPL_JOINTS['L_Wrist'], SMPL_JOINTS['R_Wrist'], 29 | SMPL_JOINTS['L_Knee'], SMPL_JOINTS['R_Knee']] 30 | 31 | 32 | def align_joints(tensor, smpl_to_bvh=True): 33 | """ 34 | Args: 35 | tensor (T x J x D): The 3D torch tensor we need to process. 36 | 37 | Returns: 38 | The 3D tensor whose joint order is compatible for bvh export. 
39 | """ 40 | order = list(SMPL_JOINTS.values()) 41 | result = tensor.clone() if isinstance(tensor, torch.Tensor) else tensor.copy() 42 | for i in range(len(order)): 43 | if smpl_to_bvh: 44 | result[:, i] = tensor[:, order[i]] 45 | else: 46 | result[:, order[i]] = tensor[:, i] 47 | 48 | return result 49 | 50 | 51 | def slerp(quat, trans, key_times, times, mask=True): 52 | """ 53 | Args: 54 | quat: (T x J x 4) 55 | trans: (T x 3) 56 | """ 57 | if mask: 58 | quat = c2c(quat[key_times]) 59 | trans = c2c(trans[key_times]) 60 | else: 61 | quat = c2c(quat) 62 | trans = c2c(trans) 63 | 64 | quats = [] 65 | for j in range(quat.shape[1]): 66 | key_rots = R.from_quat(quat[:, j]) 67 | s = Slerp(key_times, key_rots) 68 | interp_rots = s(times) 69 | quats.append(interp_rots.as_quat()) 70 | slerp_quat = np.stack(quats, axis=1) 71 | 72 | lerp_trans = np.zeros((len(times), 3)) 73 | for i in range(3): 74 | lerp_trans[:, i] = np.interp(times, key_times, trans[:, i]) 75 | 76 | return slerp_quat, lerp_trans 77 | 78 | 79 | def compute_orient_angle(matrix, traj): 80 | """ 81 | Args: 82 | matrix (N x T x 3 x 3): The 3D rotation matrix at the root joint 83 | traj (N x T x 3): The trajectory to align 84 | """ 85 | forward = matrix[:, :, :, 2].clone() 86 | forward[:, :, 2] = 0 87 | forward = F.normalize(forward, dim=-1) # normalized forward vector (N, T, 3) 88 | 89 | traj[:, :, 2] = 0 # make sure the trajectory is projected to the plane 90 | 91 | # first steps is forward diff 92 | init_tan = traj[:, 1:2] - traj[:, :1] 93 | # middle steps are second order 94 | middle_tan = (traj[:, 2:] - traj[:, 0:-2]) / 2 95 | # last step is backward diff 96 | final_tan = traj[:, -1:] - traj[:, -2:-1] 97 | 98 | tangent = torch.cat([init_tan, middle_tan, final_tan], dim=1) 99 | tangent = F.normalize(tangent, dim=-1) # normalized tangent vector (N, T, 3) 100 | 101 | cos = torch.sum(forward * tangent, dim=-1) 102 | 103 | return cos 104 | 105 | 106 | def compute_trajectory(velocity, up, origin, dt, up_axis='z'): 107 | """ 108 | Args: 109 | velocity: (B, T, 3) 110 | up: (B, T) 111 | origin: (B, 3) 112 | up_axis: x, y, or z 113 | 114 | Returns: 115 | trajectory: (B, T, 3) 116 | """ 117 | ordermap = { 118 | 'x': 0, 119 | 'y': 1, 120 | 'z': 2, 121 | } 122 | v_axis = [x for x in ordermap.values() if x != ordermap[up_axis]] 123 | 124 | origin = origin.unsqueeze(1) # (B, 3) => (B, 1, 3) 125 | trajectory = origin.repeat(1, up.shape[1], 1) # (B, 1, 3) => (B, T, 3) 126 | 127 | for t in range(1, up.shape[1]): 128 | trajectory[:, t, v_axis[0]] = trajectory[:, t - 1, v_axis[0]] + velocity[:, t - 1, v_axis[0]] * dt 129 | trajectory[:, t, v_axis[1]] = trajectory[:, t - 1, v_axis[1]] + velocity[:, t - 1, v_axis[1]] * dt 130 | 131 | trajectory[:, :, ordermap[up_axis]] = up 132 | 133 | return trajectory 134 | 135 | 136 | def build_canonical_frame(forward, up_axis='z'): 137 | """ 138 | Args: 139 | forward: (..., 3) 140 | 141 | Returns: 142 | frame: (..., 3, 3) 143 | """ 144 | forward[..., 'xyz'.index(up_axis)] = 0 145 | forward = F.normalize(forward, dim=-1) # normalized forward vector 146 | 147 | up = torch.zeros_like(forward) 148 | up[..., 'xyz'.index(up_axis)] = 1 # normalized up vector 149 | right = torch.cross(up, forward) 150 | frame = torch.stack((right, up, forward), dim=-1) # canonical frame 151 | 152 | return frame 153 | 154 | 155 | def estimate_linear_velocity(data_seq, dt): 156 | ''' 157 | Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates 158 | the velocity for the middle T-2 steps using a second order 
central difference scheme. 159 | The first and last frames are with forward and backward first-order 160 | differences, respectively 161 | - h : step size 162 | ''' 163 | # first steps is forward diff (t+1 - t) / dt 164 | init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt 165 | # middle steps are second order (t+1 - t-1) / 2dt 166 | middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) 167 | # last step is backward diff (t - t-1) / dt 168 | final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt 169 | 170 | vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) 171 | return vel_seq 172 | 173 | 174 | def estimate_angular_velocity(rot_seq, dt): 175 | ''' 176 | Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 177 | Input sequence should be of shape (B, T, ..., 3, 3) 178 | ''' 179 | # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix 180 | dRdt = estimate_linear_velocity(rot_seq, dt) 181 | R = rot_seq 182 | RT = R.transpose(-1, -2) 183 | # compute skew-symmetric angular velocity tensor 184 | w_mat = torch.matmul(dRdt, RT) 185 | # pull out angular velocity vector by averaging symmetric entries 186 | w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 187 | w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 188 | w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 189 | w = torch.stack([w_x, w_y, w_z], axis=-1) 190 | return w 191 | 192 | 193 | def normalize(tensor, mean=None, std=None): 194 | """ 195 | Args: 196 | tensor: (B, T, ...) 197 | 198 | Returns: 199 | normalized tensor with 0 mean and 1 standard deviation, std, mean 200 | """ 201 | if mean is None or std is None: 202 | # std, mean = torch.std_mean(tensor, dim=0, unbiased=False, keepdim=True) 203 | std, mean = torch.std_mean(tensor, dim=(0, 1), unbiased=False, keepdim=True) 204 | std[std == 0.0] = 1.0 205 | 206 | return (tensor - mean) / std, mean, std 207 | 208 | return (tensor - mean) / std 209 | 210 | 211 | def denormalize(tensor, mean, std): 212 | """ 213 | Args: 214 | tensor: B x T x D 215 | mean: 216 | std: 217 | """ 218 | return tensor * std + mean 219 | 220 | 221 | def export_bvh_animation(rotations, positions, offsets, parents, output_dir, prefix, joint_names, fps): 222 | """ 223 | Args: 224 | rotations: quaternions of the shape (B, T, J, 4) 225 | positions: global translations of the shape (B, T, 3) 226 | """ 227 | os.makedirs(output_dir, exist_ok=True) 228 | for i in range(rotations.shape[0]): 229 | rotation = align_joints(rotations[i]) 230 | position = positions[i] 231 | position = position.unsqueeze(1) 232 | anim = Animation(Quaternions(c2c(rotation)), c2c(position), None, offsets=offsets, parents=parents) 233 | BVH.save(os.path.join(output_dir, f'{prefix}_{i}.bvh'), anim, names=joint_names, frametime=1 / fps) 234 | 235 | 236 | def export_ply_trajectory(points, color, ply_fname): 237 | v = [] 238 | for p in points: 239 | v += [(p[0], p[1], p[2], color[0], color[1], color[2])] 240 | v = np.array(v, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 241 | el_v = plyfile.PlyElement.describe(v, 'vertex') 242 | 243 | e = np.empty(len(points) - 1, dtype=[('vertex1', 'i4'), ('vertex2', 'i4')]) 244 | edge_data = np.array([[i, i + 1] for i in range(len(points) - 1)], dtype='i4') 245 | e['vertex1'] = edge_data[:, 0] 246 | e['vertex2'] = edge_data[:, 1] 247 | el_e = plyfile.PlyElement.describe(e, 'edge') 248 | 249 | plyfile.PlyData([el_v, el_e], text=True).write(ply_fname) 250 | 
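# Note: compute_trajectory() acts as the approximate inverse of estimate_linear_velocity()
# for the two ground-plane axes: it forward-integrates per-frame root velocities from
# `origin` with step `dt`, while the up-axis coordinate is taken directly from `up`
# rather than integrated. The default up_axis='z' matches the z-up convention used by
# compute_orient_angle() above, which projects trajectories onto the z=0 plane.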
--------------------------------------------------------------------------------
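A minimal sketch of using the `SoftDTW` module from `src/soft_dtw_cuda.py` as a differentiable sequence distance (for example, as a motion reconstruction loss). The import path, tensor shapes, and `gamma` value below are illustrative assumptions, not settings taken from the repository's configs.

```python
import torch

from soft_dtw_cuda import SoftDTW  # assumes src/ is on PYTHONPATH

# Two batches of sequences with the same batch size and feature dimension;
# soft-DTW allows the sequence lengths (15 vs. 20 here) to differ.
x = torch.rand(8, 15, 3, requires_grad=True)
y = torch.rand(8, 20, 3)

# CPU variant; SoftDTW(True, ...) switches to the CUDA kernels for sequences
# up to 1024 frames (longer inputs fall back to the CPU path automatically).
sdtw = SoftDTW(use_cuda=False, gamma=0.1, normalize=False)

dist = sdtw(x, y)     # one soft-DTW value per batch element, shape (8,)
loss = dist.mean()
loss.backward()       # gradients flow back to x through the squared-Euclidean cost
print(loss.item(), x.grad.shape)
```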