├── utils
│   ├── __init__.py
│   ├── paramUtil.py
│   ├── get_opt.py
│   ├── other_tools.py
│   ├── utils.py
│   ├── metrics.py
│   └── quaternion.py
├── datasets
│   ├── pymo
│   │   ├── __init__.py
│   │   ├── mocapplayer
│   │   │   ├── data-template.js
│   │   │   ├── libs
│   │   │   │   ├── threejs
│   │   │   │   │   └── Detector.js
│   │   │   │   ├── pace.min.js
│   │   │   │   └── papaparse.min.js
│   │   │   ├── styles
│   │   │   │   └── pace.css
│   │   │   ├── js
│   │   │   │   └── skeletonFactory.js
│   │   │   └── playURL.html
│   │   ├── features.py
│   │   ├── data.py
│   │   ├── writers.py
│   │   ├── rotation_tools.py
│   │   ├── viz_tools.py
│   │   └── parsers.py
│   ├── __init__.py
│   ├── extract_hubert.py
│   ├── dataloader.py
│   └── show.py
├── assets
│   ├── data.tar.gz
│   ├── beat_visualize.blend
│   ├── teaser_for_demo_cvpr.png
│   ├── requirements.txt
│   └── environment.yml
├── audios
│   ├── Forrest_tts.wav
│   └── 2_scott_0_3_3.wav
├── trainers
│   ├── __init__.py
│   └── loss_factory.py
├── models
│   ├── __init__.py
│   ├── respace.py
│   ├── motion_autoencoder.py
│   ├── scheduler.py
│   └── ddpm_utils.py
├── .gitignore
├── inference_custom_audio_beat.sh
├── inference_custom_audio_show.sh
├── LICENSE
├── options
│   ├── evaluate_options.py
│   ├── train_options.py
│   └── base_options.py
├── train_test_scripts.sh
├── README.md
└── runner.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/pymo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/data.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/data.tar.gz
--------------------------------------------------------------------------------
/audios/Forrest_tts.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/audios/Forrest_tts.wav
--------------------------------------------------------------------------------
/datasets/pymo/mocapplayer/data-template.js:
--------------------------------------------------------------------------------
1 | var dataBuffer = `$$DATA$$`;
2 |
3 | start(dataBuffer);
--------------------------------------------------------------------------------
/audios/2_scott_0_3_3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/audios/2_scott_0_3_3.wav
--------------------------------------------------------------------------------
/assets/beat_visualize.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/beat_visualize.blend
--------------------------------------------------------------------------------
/assets/teaser_for_demo_cvpr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JeremyCJM/DiffSHEG/HEAD/assets/teaser_for_demo_cvpr.png
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .ddpm_beat_trainer import DDPMTrainer_beat
2 | from .ddpm_show_trainer import DDPMTrainer_show
3 |
4 |
5 | __all__ = ['DDPMTrainer_beat', 'DDPMTrainer_show']
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .show import ShowDataset
2 | from .beat import BeatDataset
3 | from .dataloader import build_dataloader
4 |
5 |
6 |
7 | __all__ = [
8 | 'ShowDataset', 'BeatDataset', 'build_dataloader']
--------------------------------------------------------------------------------
/assets/requirements.txt:
--------------------------------------------------------------------------------
1 | pyarrow==3.0.0
2 | mmcv-full==1.6.0
3 | matplotlib==3.7.1
4 | scipy==1.10.1
5 | lmdb==0.96
6 | termcolor==2.1.1
7 | librosa==0.9.2
8 | loguru==0.6.0
9 | umap==0.1.1
10 | wandb==0.13.10
11 | transformers==4.31.0
12 | pandas==1.4.2
13 | IPython
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import MotionTransformer, UniDiffuser
2 | from .gaussian_diffusion import GaussianDiffusion
3 | # from .respace import GaussianDiffusion
4 | from .scheduler import get_schedule_jump, get_schedule_jump_paper
5 |
6 | __all__ = ['MotionTransformer', 'UniDiffuser', 'GaussianDiffusion', 'get_schedule_jump', 'get_schedule_jump_paper']
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | results
3 | wandb
4 | SMPLX_NEUTRAL_2020.npz
5 | rendered
6 | data
7 | rendering
8 | temp
9 |
10 | __pycache__
11 | *.pyc
12 |
13 | motion_ae_fid/logs
14 | test.sh
15 | train.sh
16 |
17 | videos
18 | rebuttal
19 |
20 | test*
21 |
22 | ## BEAT_ori
23 | 0_BEAT_ori/outputs/*
24 | 0_BEAT_ori/datasets
25 | 0_BEAT_ori/codes/audio2pose/docs/*
26 | 0_BEAT_ori/codes/audio2pose/*.sh
27 |
28 |
29 | ## TalkSHOW
30 | *.pkl
31 | *.zip
32 | *.mp4
33 | A_TalkSHOW_ori/.idea
34 | A_TalkSHOW_ori/data
35 | A_TalkSHOW_ori/experiments
36 | A_TalkSHOW_ori/visualise/smplx
37 | A_TalkSHOW_ori/visualise/video
38 | A_TalkSHOW_ori/visualise/mesh_vertices
39 | A_TalkSHOW_ori/visualise/mouth_keypoints
40 | A_TalkSHOW_ori/demo_audio
41 | A_TalkSHOW_ori/demo
42 | A_TalkSHOW_ori/results
43 |
44 |
45 |
46 | # SpeechSplit
47 | B_SpeechSplit/assets
48 | B_SpeechSplit/results*
49 |
50 |
--------------------------------------------------------------------------------
/inference_custom_audio_beat.sh:
--------------------------------------------------------------------------------
1 | ## For 25+ FPS on A100
2 | # PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
3 | # OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \
4 | # --dataset_name beat \
5 | # --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \
6 | # --n_poses 34 \
7 | # --ckpt fgd_best.tar \
8 | # --ddim \
9 | # --timestep_respacing ddim25 \
10 | # --overlap_len 4 \
11 | # --mode test_custom_audio \
12 | # --test_audio_path audios/2_scott_0_3_3.wav
13 | 
14 | ## For 30+ FPS on 3090; For 55+ FPS on A100
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \
17 | --dataset_name beat \
18 | --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \
19 | --n_poses 34 \
20 | --ckpt fgd_best.tar \
21 | --ddim \
22 | --timestep_respacing ddim25 \
23 | --overlap_len 4 \
24 | --mode test_custom_audio \
25 | --jump_n_sample 2 \
26 | --test_audio_path audios/2_scott_0_3_3.wav
27 | 
--------------------------------------------------------------------------------
/inference_custom_audio_show.sh:
--------------------------------------------------------------------------------
1 | # For 50+ FPS on A100
2 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
3 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \
4 | --dataset_name talkshow \
5 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \
6 | --n_poses 88 \
7 | --model_base transformer_encoder \
8 | --classifier_free \
9 | --cond_scale 1.15 \
10 | --ckpt ckpt_e2599.tar \
11 | --ddim \
12 | --timestep_respacing ddim25 \
13 | --overlap_len 10 \
14 | --mode test_custom_audio \
15 | --test_audio_path audios/Forrest_tts.wav
16 |
17 |
18 | ## For 120+ FPS on A100
19 | # PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
20 | # OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0 python -u runner.py \
21 | # --dataset_name talkshow \
22 | # --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \
23 | # --n_poses 88 \
24 | # --model_base transformer_encoder \
25 | # --classifier_free \
26 | # --cond_scale 1.15 \
27 | # --ckpt ckpt_e2599.tar \
28 | # --ddim \
29 | # --timestep_respacing ddim25 \
30 | # --overlap_len 10 \
31 | # --mode test_custom_audio \
32 | # --jump_n_sample 2 \
33 | # --test_audio_path audios/Forrest_tts.wav
--------------------------------------------------------------------------------
/datasets/pymo/features.py:
--------------------------------------------------------------------------------
1 | '''
2 | A set of mocap feature extraction functions
3 |
4 | Created by Omid Alemi | Nov 17 2017
5 |
6 | '''
7 | import numpy as np
8 | import pandas as pd
9 | import peakutils
10 | import matplotlib.pyplot as plt
11 |
12 | def get_foot_contact_idxs(signal, t=0.02, min_dist=120):
13 | up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist)
14 | down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist)
15 |
16 | return [up_idxs, down_idxs]
17 |
18 |
19 | def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120):
20 | signal = mocap_track.values[col_name].values
21 | idxs = get_foot_contact_idxs(signal, t, min_dist)
22 |
23 | step_signal = []
24 |
25 | c = start
26 | for f in range(len(signal)):
27 | if f in idxs[1]:
28 | c = 0
29 | elif f in idxs[0]:
30 | c = 1
31 |
32 | step_signal.append(c)
33 |
34 | return step_signal
35 |
36 | def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120):
37 |
38 | signal = mocap_track.values[col_name].values
39 | idxs = get_foot_contact_idxs(signal, t, min_dist)
40 |
41 | plt.plot(mocap_track.values.index, signal)
42 | plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro')
43 | plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go')
44 |
--------------------------------------------------------------------------------
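
A usage sketch for the peak-based contact detection above, run on a synthetic signal; a real call would pass a column of a parsed mocap track (e.g. a foot joint's height channel), so the data below is purely illustrative:

import numpy as np
from datasets.pymo.features import get_foot_contact_idxs

# Synthetic stand-in for a foot joint's height over 1000 frames:
# two slow oscillations, shifted so the signal stays well above zero.
signal = np.sin(np.linspace(0, 4 * np.pi, 1000)) + 2.0

# Returns [up_idxs, down_idxs]: frame indices of peaks and dips,
# spaced at least min_dist frames apart.
up_idxs, down_idxs = get_foot_contact_idxs(signal, t=0.02, min_dist=120)
print(up_idxs, down_idxs)
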
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, JeremyCJM
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/datasets/pymo/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Joint():
4 | def __init__(self, name, parent=None, children=None):
5 | self.name = name
6 | self.parent = parent
7 | self.children = children
8 |
9 | class MocapData():
10 | def __init__(self):
11 | self.skeleton = {}
12 | self.values = None
13 | self.channel_names = []
14 | self.framerate = 0.0
15 | self.root_name = ''
16 |
17 | def traverse(self, j=None):
18 | stack = [self.root_name]
19 | while stack:
20 | joint = stack.pop()
21 | yield joint
22 | for c in self.skeleton[joint]['children']:
23 | stack.append(c)
24 |
25 | def clone(self):
26 | import copy
27 | new_data = MocapData()
28 | new_data.skeleton = copy.copy(self.skeleton)
29 | new_data.values = copy.copy(self.values)
30 | new_data.channel_names = copy.copy(self.channel_names)
31 | new_data.root_name = copy.copy(self.root_name)
32 | new_data.framerate = copy.copy(self.framerate)
33 | return new_data
34 |
35 | def get_all_channels(self):
36 | '''Returns all of the channels parsed from the file as a 2D numpy array'''
37 |
38 | frames = [f[1] for f in self.values]
39 | return np.asarray([[channel[2] for channel in frame] for frame in frames])
40 |
41 | def get_skeleton_tree(self):
42 | tree = []
43 | root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0]
44 |
45 | root_joint = Joint(root_key)
46 |
47 | def get_empty_channels(self):
48 | #TODO
49 | pass
50 |
51 | def get_constant_channels(self):
52 | #TODO
53 | pass
54 |
--------------------------------------------------------------------------------
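
A small illustration of how the skeleton dictionary drives traverse(); the joints below are hypothetical, as real entries are filled in by the BVH parser:

from datasets.pymo.data import MocapData

data = MocapData()
data.root_name = 'Hips'
data.skeleton = {
    'Hips':      {'parent': None,   'children': ['Spine', 'LeftUpLeg']},
    'Spine':     {'parent': 'Hips', 'children': []},
    'LeftUpLeg': {'parent': 'Hips', 'children': []},
}

# traverse() walks the tree from the root using an explicit stack.
print(list(data.traverse()))  # ['Hips', 'LeftUpLeg', 'Spine']
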
/options/evaluate_options.py:
--------------------------------------------------------------------------------
1 | from options.base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
8 | self.parser.add_argument('--start_mov_len', type=int, default=10)
9 | self.parser.add_argument('--est_length', action="store_true", help="Whether to use sampled motion length")
10 | self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
11 | self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
12 | self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
13 | self.parser.add_argument('--no_clip', action='store_true', help='whether use clip pretrain')
14 | self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient attention')
15 |
16 |
17 | self.parser.add_argument('--repeat_times', type=int, default=3, help="Number of generation rounds for each text description")
18 | self.parser.add_argument('--split_file', type=str, default='test.txt')
19 | self.parser.add_argument('--text', type=str, default="", help='Text description for motion generation')
20 |         self.parser.add_argument('--motion_length', type=int, default=0, help='Number of frames for motion generation')
21 | self.parser.add_argument('--text_file', type=str, default="", help='Path of text description for motion generation')
22 | self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint that will be used')
23 | self.parser.add_argument('--result_path', type=str, default="./eval_results/", help='Path to save generation results')
24 | self.parser.add_argument('--num_results', type=int, default=40, help='Number of descriptions that will be used')
25 | self.parser.add_argument('--ext', type=str, default='default', help='Save file path extension')
26 |
27 | self.is_train = False
28 |
--------------------------------------------------------------------------------
/utils/paramUtil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Define a kinematic tree for the skeletal structure
4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5 |
6 | kit_raw_offsets = np.array(
7 | [
8 | [0, 0, 0],
9 | [0, 1, 0],
10 | [0, 1, 0],
11 | [0, 1, 0],
12 | [0, 1, 0],
13 | [1, 0, 0],
14 | [0, -1, 0],
15 | [0, -1, 0],
16 | [-1, 0, 0],
17 | [0, -1, 0],
18 | [0, -1, 0],
19 | [1, 0, 0],
20 | [0, -1, 0],
21 | [0, -1, 0],
22 | [0, 0, 1],
23 | [0, 0, 1],
24 | [-1, 0, 0],
25 | [0, -1, 0],
26 | [0, -1, 0],
27 | [0, 0, 1],
28 | [0, 0, 1]
29 | ]
30 | )
31 |
32 | t2m_raw_offsets = np.array([[0,0,0],
33 | [1,0,0],
34 | [-1,0,0],
35 | [0,1,0],
36 | [0,-1,0],
37 | [0,-1,0],
38 | [0,1,0],
39 | [0,-1,0],
40 | [0,-1,0],
41 | [0,1,0],
42 | [0,0,1],
43 | [0,0,1],
44 | [0,1,0],
45 | [1,0,0],
46 | [-1,0,0],
47 | [0,0,1],
48 | [0,-1,0],
49 | [0,-1,0],
50 | [0,-1,0],
51 | [0,-1,0],
52 | [0,-1,0],
53 | [0,-1,0]])
54 |
55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58 |
59 |
60 | kit_tgt_skel_id = '03950'
61 |
62 | t2m_tgt_skel_id = '000021'
63 |
64 |
--------------------------------------------------------------------------------
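
Each kinematic chain lists joint indices from the root (or a branching joint) out to an end effector; below is a sketch of how the chains are typically consumed when drawing a skeleton (the pose array is random, purely for illustration):

import numpy as np
from utils.paramUtil import t2m_kinematic_chain

joints = np.random.rand(22, 3)  # hypothetical single-frame pose: 22 joints x (x, y, z)

# Consecutive indices within a chain form the bones of one limb.
for chain in t2m_kinematic_chain:
    for a, b in zip(chain[:-1], chain[1:]):
        bone = joints[b] - joints[a]
        print(f"bone {a}->{b}: length {np.linalg.norm(bone):.3f}")
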
/datasets/pymo/writers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | class BVHWriter():
5 | def __init__(self):
6 | pass
7 |
8 | def write(self, X, ofile):
9 |
10 | # Writing the skeleton info
11 | ofile.write('HIERARCHY\n')
12 |
13 | self.motions_ = []
14 | self._printJoint(X, X.root_name, 0, ofile)
15 |
16 | # Writing the motion header
17 | ofile.write('MOTION\n')
18 | ofile.write('Frames: %d\n'%X.values.shape[0])
19 | ofile.write('Frame Time: %f\n'%X.framerate)
20 |
21 | # Writing the data
22 | self.motions_ = np.asarray(self.motions_).T
23 | lines = [" ".join(item) for item in self.motions_.astype(str)]
24 | ofile.write("".join("%s\n"%l for l in lines))
25 |
26 | def _printJoint(self, X, joint, tab, ofile):
27 |
28 | if X.skeleton[joint]['parent'] == None:
29 | ofile.write('ROOT %s\n'%joint)
30 | elif len(X.skeleton[joint]['children']) > 0:
31 | ofile.write('%sJOINT %s\n'%('\t'*(tab), joint))
32 | else:
33 |             ofile.write('%sEnd Site\n'%('\t'*(tab)))
34 |
35 | ofile.write('%s{\n'%('\t'*(tab)))
36 |
37 | ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1),
38 | X.skeleton[joint]['offsets'][0],
39 | X.skeleton[joint]['offsets'][1],
40 | X.skeleton[joint]['offsets'][2]))
41 | channels = X.skeleton[joint]['channels']
42 | n_channels = len(channels)
43 |
44 | if n_channels > 0:
45 | for ch in channels:
46 | self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values))
47 |
48 | if len(X.skeleton[joint]['children']) > 0:
49 | ch_str = ''.join(' %s'*n_channels%tuple(channels))
50 | ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str))
51 |
52 | for c in X.skeleton[joint]['children']:
53 | self._printJoint(X, c, tab+1, ofile)
54 |
55 | ofile.write('%s}\n'%('\t'*(tab)))
56 |
--------------------------------------------------------------------------------
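
A minimal round-trip sketch for the writer, assuming a BVH file parsed with pymo's BVHParser (defined in parsers.py, not shown here); 'input.bvh' and 'output.bvh' are placeholder paths:

from datasets.pymo.parsers import BVHParser
from datasets.pymo.writers import BVHWriter

parsed = BVHParser().parse('input.bvh')  # MocapData with skeleton + frame values

with open('output.bvh', 'w') as f:
    BVHWriter().write(parsed, f)
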
/trainers/loss_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) HuaWei, Inc. and its affiliates.
2 | # liu.haiyang@huawei.com
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class BCE_Loss(nn.Module):
11 | def __init__(self, args=None):
12 | super(BCE_Loss, self).__init__()
13 |
14 | def forward(self, fake_outputs, real_target):
15 |         final_loss = F.cross_entropy(fake_outputs, real_target, reduction="mean")
16 | return final_loss
17 |
18 |
19 | class HuberLoss(nn.Module):
20 | def __init__(self, beta=0.1, reduction="mean"):
21 | super(HuberLoss, self).__init__()
22 | self.beta = beta
23 | self.reduction = reduction
24 |
25 | def forward(self, outputs, targets):
26 | final_loss = F.smooth_l1_loss(outputs / self.beta, targets / self.beta, reduction=self.reduction) * self.beta
27 | return final_loss
28 |
29 |
30 | class KLDLoss(nn.Module):
31 | def __init__(self, beta=0.1):
32 | super(KLDLoss, self).__init__()
33 | self.beta = beta
34 |
35 | def forward(self, outputs, targets):
36 |         final_loss = F.smooth_l1_loss(outputs / self.beta, targets / self.beta) * self.beta
37 | return final_loss
38 |
39 |
40 | class REGLoss(nn.Module):
41 | def __init__(self, beta=0.1):
42 | super(REGLoss, self).__init__()
43 | self.beta = beta
44 |
45 | def forward(self, outputs, targets):
46 |         final_loss = F.smooth_l1_loss(outputs / self.beta, targets / self.beta) * self.beta
47 | return final_loss
48 |
49 |
50 | class L2Loss(nn.Module):
51 | def __init__(self):
52 | super(L2Loss, self).__init__()
53 |
54 | def forward(self, outputs, targets):
55 |         final_loss = F.mse_loss(outputs, targets)
56 | return final_loss
57 |
58 | LOSS_FUNC_LUT = {
59 | "bce_loss": BCE_Loss,
60 | "l2_loss": L2Loss,
61 | "huber_loss": HuberLoss,
62 | "kl_loss": KLDLoss,
63 | "id_loss": REGLoss,
64 | }
65 |
66 |
67 | def get_loss_func(loss_name, **kwargs):
68 | loss_func_class = LOSS_FUNC_LUT.get(loss_name)
69 | loss_func = loss_func_class(**kwargs)
70 | return loss_func
71 |
72 |
73 |
--------------------------------------------------------------------------------
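
Usage sketch for the factory: look up a loss class by name, instantiate it with its keyword arguments, and apply it (the tensors below are dummies):

import torch
from trainers.loss_factory import get_loss_func

huber = get_loss_func("huber_loss", beta=0.1)

pred = torch.randn(8, 141)    # hypothetical predicted motion features
target = torch.randn(8, 141)
print(huber(pred, target).item())
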
/train_test_scripts.sh:
--------------------------------------------------------------------------------
1 | ################################# BEAT #################################
2 |
3 | ### Train on BEAT.
4 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
5 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \
6 | --dataset_name beat \
7 | --name beat_diffsheg \
8 | --batch_size 2500 \
9 | --num_epochs 1000 \
10 | --save_every_e 20 \
11 | --eval_every_e 40 \
12 | --n_poses 34 \
13 | --ddim \
14 | --multiprocessing-distributed \
15 | --dist-url 'tcp://127.0.0.1:6666'
16 |
17 |
18 |
19 | ### Test on BEAT.
20 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
21 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \
22 | --dataset_name beat \
23 | --name beat_GesExpr_unify_addHubert_encodeHubert_mlpIncludeX_condRes_LN \
24 | --n_poses 34 \
25 | --multiprocessing-distributed \
26 | --dist-url 'tcp://127.0.0.1:8888' \
27 | --ckpt fgd_best.tar \
28 | --mode test_arbitrary_len \
29 | --ddim \
30 | --timestep_respacing ddim25 \
31 | --overlap_len 4
32 |
33 |
34 | ################################# SHOW #################################
35 |
36 | ### Train on SHOW.
37 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
38 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=1,2,3,4,5 python -u runner.py \
39 | --dataset_name talkshow \
40 | --name talkshow_diffsheg \
41 | --batch_size 950 \
42 | --num_epochs 4000 \
43 | --save_every_e 20 \
44 | --eval_every_e 40 \
45 | --n_poses 88 \
46 | --classifier_free \
47 | --multiprocessing-distributed \
48 | --dist-url 'tcp://127.0.0.1:6667' \
49 | --ddim \
50 | --max_eval_samples 200
51 |
52 |
53 |
54 | ### Test on SHOW.
55 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
56 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \
57 | --dataset_name talkshow \
58 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \
59 | --PE pe_sinu \
60 | --n_poses 88 \
61 | --multiprocessing-distributed \
62 | --dist-url 'tcp://127.0.0.1:8889' \
63 | --classifier_free \
64 | --cond_scale 1.25 \
65 | --ckpt ckpt_e2599.tar \
66 | --mode test_arbitrary_len \
67 | --ddim \
68 | --timestep_respacing ddim25 \
69 | --overlap_len 10
70 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from options.base_options import BaseOptions
2 | import argparse
3 |
4 | class TrainCompOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
8 | self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
9 | self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
10 | self.parser.add_argument('--no_clip', action='store_true', help='whether use clip pretrain')
11 | self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient attention')
12 |
13 | self.parser.add_argument('--num_epochs', type=int, default=5000, help='Number of epochs')
14 | self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
15 | self.parser.add_argument('--reset_lr', action='store_true', help='Reset the optimizer lr to args.lr after resume from a ckpt')
16 | self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU')
17 |         self.parser.add_argument('--times', type=int, default=1, help='Number of times to repeat the dataset per epoch')
18 |
19 | self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact')
20 |
21 |         self.parser.add_argument('--resume', action="store_true", help='Is this trial continued from a previous trial?')
22 |
23 | self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress (by iteration)')
24 | self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of saving models (by epoch)')
25 | self.parser.add_argument('--eval_every_e', type=int, default=5, help='Frequency of animation results (by epoch)')
26 | self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving models (by iteration)')
27 |
28 |
29 | self.parser.add_argument('--ckpt', type=str, default='latest.tar', help='choose which checkpoint to use')
30 | # self.parser.add_argument('--max_motion_length', type=int, default=200, help='max_motion_length')
31 |
32 | self.is_train = True
33 |
--------------------------------------------------------------------------------
/datasets/pymo/mocapplayer/libs/threejs/Detector.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @author alteredq / http://alteredqualia.com/
3 | * @author mr.doob / http://mrdoob.com/
4 | */
5 |
6 | var Detector = {
7 |
8 | canvas: !! window.CanvasRenderingContext2D,
9 | webgl: ( function () {
10 |
11 | try {
12 |
13 | var canvas = document.createElement( 'canvas' ); return !! ( window.WebGLRenderingContext && ( canvas.getContext( 'webgl' ) || canvas.getContext( 'experimental-webgl' ) ) );
14 |
15 | } catch ( e ) {
16 |
17 | return false;
18 |
19 | }
20 |
21 | } )(),
22 | workers: !! window.Worker,
23 | fileapi: window.File && window.FileReader && window.FileList && window.Blob,
24 |
25 | getWebGLErrorMessage: function () {
26 |
27 | var element = document.createElement( 'div' );
28 | element.id = 'webgl-error-message';
29 | element.style.fontFamily = 'monospace';
30 | element.style.fontSize = '13px';
31 | element.style.fontWeight = 'normal';
32 | element.style.textAlign = 'center';
33 | element.style.background = '#fff';
34 | element.style.color = '#000';
35 | element.style.padding = '1.5em';
36 | element.style.width = '400px';
37 | element.style.margin = '5em auto 0';
38 |
39 | if ( ! this.webgl ) {
40 |
41 | 			element.innerHTML = window.WebGLRenderingContext ? [
42 | 				'Your graphics card does not seem to support <a href="http://khronos.org/webgl/wiki/Getting_a_WebGL_Implementation" style="color:#000">WebGL</a>.<br />',
43 | 				'Find out how to get it <a href="http://get.webgl.org/" style="color:#000">here</a>.'
44 | 			].join( '\n' ) : [
45 | 				'Your browser does not seem to support <a href="http://khronos.org/webgl/wiki/Getting_a_WebGL_Implementation" style="color:#000">WebGL</a>.<br/>',
46 | 				'Find out how to get it <a href="http://get.webgl.org/" style="color:#000">here</a>.'
47 | 			].join( '\n' );
48 |
49 | }
50 |
51 | return element;
52 |
53 | },
54 |
55 | addGetWebGLMessage: function ( parameters ) {
56 |
57 | var parent, id, element;
58 |
59 | parameters = parameters || {};
60 |
61 | parent = parameters.parent !== undefined ? parameters.parent : document.body;
62 | id = parameters.id !== undefined ? parameters.id : 'oldie';
63 |
64 | element = Detector.getWebGLErrorMessage();
65 | element.id = id;
66 |
67 | parent.appendChild( element );
68 |
69 | }
70 |
71 | };
72 |
73 | // browserify support
74 | if ( typeof module === 'object' ) {
75 |
76 | module.exports = Detector;
77 |
78 | }
--------------------------------------------------------------------------------
/datasets/pymo/mocapplayer/styles/pace.css:
--------------------------------------------------------------------------------
1 | .pace {
2 | -webkit-pointer-events: none;
3 | pointer-events: none;
4 | -webkit-user-select: none;
5 | -moz-user-select: none;
6 | user-select: none;
7 | }
8 |
9 | .pace-inactive {
10 | display: none;
11 | }
12 |
13 | .pace .pace-progress {
14 | background: #29d;
15 | position: fixed;
16 | z-index: 2000;
17 | top: 0;
18 | right: 100%;
19 | width: 100%;
20 | height: 2px;
21 | }
22 |
23 | .pace .pace-progress-inner {
24 | display: block;
25 | position: absolute;
26 | right: 0px;
27 | width: 100px;
28 | height: 100%;
29 | box-shadow: 0 0 10px #29d, 0 0 5px #29d;
30 | opacity: 1.0;
31 | -webkit-transform: rotate(3deg) translate(0px, -4px);
32 | -moz-transform: rotate(3deg) translate(0px, -4px);
33 | -ms-transform: rotate(3deg) translate(0px, -4px);
34 | -o-transform: rotate(3deg) translate(0px, -4px);
35 | transform: rotate(3deg) translate(0px, -4px);
36 | }
37 |
38 | .pace .pace-activity {
39 | display: block;
40 | position: fixed;
41 | z-index: 2000;
42 | top: 15px;
43 | right: 20px;
44 | width: 34px;
45 | height: 34px;
46 | border: solid 2px transparent;
47 | border-top-color: #9ea7ac;
48 | border-left-color: #9ea7ac;
49 | border-radius: 30px;
50 | -webkit-animation: pace-spinner 700ms linear infinite;
51 | -moz-animation: pace-spinner 700ms linear infinite;
52 | -ms-animation: pace-spinner 700ms linear infinite;
53 | -o-animation: pace-spinner 700ms linear infinite;
54 | animation: pace-spinner 700ms linear infinite;
55 | }
56 |
57 | @-webkit-keyframes pace-spinner {
58 | 0% { -webkit-transform: rotate(0deg); transform: rotate(0deg); }
59 | 100% { -webkit-transform: rotate(360deg); transform: rotate(360deg); }
60 | }
61 | @-moz-keyframes pace-spinner {
62 | 0% { -moz-transform: rotate(0deg); transform: rotate(0deg); }
63 | 100% { -moz-transform: rotate(360deg); transform: rotate(360deg); }
64 | }
65 | @-o-keyframes pace-spinner {
66 | 0% { -o-transform: rotate(0deg); transform: rotate(0deg); }
67 | 100% { -o-transform: rotate(360deg); transform: rotate(360deg); }
68 | }
69 | @-ms-keyframes pace-spinner {
70 | 0% { -ms-transform: rotate(0deg); transform: rotate(0deg); }
71 | 100% { -ms-transform: rotate(360deg); transform: rotate(360deg); }
72 | }
73 | @keyframes pace-spinner {
74 |   0% { transform: rotate(0deg); }
75 |   100% { transform: rotate(360deg); }
76 | }
--------------------------------------------------------------------------------
/utils/get_opt.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import Namespace
3 | import re
4 | from os.path import join as pjoin
5 | from utils.word_vectorizer import POS_enumerator
6 |
7 |
8 | def is_float(numStr):
9 | flag = False
10 | numStr = str(numStr).strip().lstrip('-').lstrip('+')
11 | try:
12 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$')
13 | res = reg.match(str(numStr))
14 | if res:
15 | flag = True
16 | except Exception as ex:
17 | print("is_float() - error: " + str(ex))
18 | return flag
19 |
20 |
21 | def is_number(numStr):
22 | flag = False
23 | numStr = str(numStr).strip().lstrip('-').lstrip('+')
24 | if str(numStr).isdigit():
25 | flag = True
26 | return flag
27 |
28 |
29 | def get_opt(opt_path, device):
30 | opt = Namespace()
31 | opt_dict = vars(opt)
32 |
33 | skip = ('-------------- End ----------------',
34 | '------------ Options -------------',
35 | '\n')
36 | print('Reading', opt_path)
37 | with open(opt_path) as f:
38 | for line in f:
39 | if line.strip() not in skip:
40 | # print(line.strip())
41 | key, value = line.strip().split(': ')
42 | if value in ('True', 'False'):
43 | opt_dict[key] = True if value == 'True' else False
44 | elif is_float(value):
45 | opt_dict[key] = float(value)
46 | elif is_number(value):
47 | opt_dict[key] = int(value)
48 | else:
49 | opt_dict[key] = str(value)
50 |
51 | opt_dict['which_epoch'] = 'latest'
52 | if 'num_layers' not in opt_dict:
53 | opt_dict['num_layers'] = 8
54 | if 'latent_dim' not in opt_dict:
55 | opt_dict['latent_dim'] = 512
56 | if 'diffusion_steps' not in opt_dict:
57 | opt_dict['diffusion_steps'] = 1000
58 | if 'no_clip' not in opt_dict:
59 | opt_dict['no_clip'] = False
60 | if 'no_eff' not in opt_dict:
61 | opt_dict['no_eff'] = False
62 |
63 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
64 | opt.model_dir = pjoin(opt.save_root, 'model')
65 | opt.meta_dir = pjoin(opt.save_root, 'meta')
66 |
67 | if opt.dataset_name == 't2m':
68 | opt.data_root = './data/HumanML3D'
69 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
70 | opt.text_dir = pjoin(opt.data_root, 'texts')
71 | opt.joints_num = 22
72 | opt.dim_pose = 263
73 | opt.max_motion_length = 196
74 | elif opt.dataset_name == 'kit':
75 | opt.data_root = './data/KIT-ML'
76 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
77 | opt.text_dir = pjoin(opt.data_root, 'texts')
78 | opt.joints_num = 21
79 | opt.dim_pose = 251
80 | opt.max_motion_length = 196
81 | else:
82 | raise KeyError('Dataset not recognized')
83 |
84 | opt.dim_word = 300
85 | opt.num_classes = 200 // opt.unit_length
86 | opt.dim_pos_ohot = len(POS_enumerator)
87 | opt.is_train = False
88 | opt.is_continue = False
89 | opt.device = device
90 |
91 | return opt
--------------------------------------------------------------------------------
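
get_opt expects the opt.txt that the options classes dump at training time: one 'key: value' pair per line between the header and footer markers it skips. A hypothetical file and call are shown below (paths and values are illustrative; checkpoints_dir, dataset_name, name, and unit_length must be present for the derived fields, and the word_vectorizer module imported at the top is assumed to exist alongside this file):

# checkpoints/t2m/my_experiment/opt.txt (written by a training run):
#   ------------ Options -------------
#   checkpoints_dir: ./checkpoints
#   dataset_name: t2m
#   name: my_experiment
#   unit_length: 4
#   -------------- End ----------------

from utils.get_opt import get_opt

opt = get_opt('./checkpoints/t2m/my_experiment/opt.txt', device='cuda:0')
print(opt.dim_pose, opt.joints_num)  # 263 22 on the t2m branch
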
/assets/environment.yml:
--------------------------------------------------------------------------------
1 | name: diffsheg
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - ca-certificates=2024.3.11=h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.4.4=h6a678d5_0
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - ncurses=6.4=h6a678d5_0
15 | - openssl=3.0.13=h7f8727e_0
16 | - pip=23.3.1=py39h06a4308_0
17 | - python=3.9.19=h955ad1f_0
18 | - readline=8.2=h5eee18b_0
19 | - setuptools=68.2.2=py39h06a4308_0
20 | - sqlite=3.41.2=h5eee18b_0
21 | - tk=8.6.12=h1ccaba5_0
22 | - tzdata=2024a=h04d1e81_0
23 | - wheel=0.41.2=py39h06a4308_0
24 | - xz=5.4.6=h5eee18b_0
25 | - zlib=1.2.13=h5eee18b_0
26 | - pip:
27 | - addict==2.4.0
28 | - appdirs==1.4.4
29 | - asttokens==2.4.1
30 | - audioread==3.0.1
31 | - certifi==2024.2.2
32 | - cffi==1.16.0
33 | - charset-normalizer==3.3.2
34 | - click==8.1.7
35 | - contourpy==1.2.1
36 | - cycler==0.12.1
37 | - decorator==5.1.1
38 | - docker-pycreds==0.4.0
39 | - exceptiongroup==1.2.1
40 | - executing==2.0.1
41 | - filelock==3.14.0
42 | - fonttools==4.51.0
43 | - fsspec==2024.3.1
44 | - gitdb==4.0.11
45 | - gitpython==3.1.43
46 | - huggingface-hub==0.22.2
47 | - idna==3.7
48 | - importlib-metadata==7.1.0
49 | - importlib-resources==6.4.0
50 | - ipython==8.18.1
51 | - jedi==0.19.1
52 | - joblib==1.4.0
53 | - kiwisolver==1.4.5
54 | - librosa==0.9.2
55 | - llvmlite==0.42.0
56 | - lmdb==0.96
57 | - loguru==0.6.0
58 | - matplotlib==3.7.1
59 | - matplotlib-inline==0.1.7
60 | - mmcv-full==1.6.0
61 | - numba==0.59.1
62 | - numpy==1.26.4
63 | - opencv-python==4.9.0.80
64 | - packaging==24.0
65 | - pandas==1.4.2
66 | - parso==0.8.4
67 | - pathtools==0.1.2
68 | - pexpect==4.9.0
69 | - pillow==10.3.0
70 | - platformdirs==4.2.1
71 | - pooch==1.8.1
72 | - prompt-toolkit==3.0.43
73 | - protobuf==4.25.3
74 | - psutil==5.9.8
75 | - ptyprocess==0.7.0
76 | - pure-eval==0.2.2
77 | - pyarrow==3.0.0
78 | - pycparser==2.22
79 | - pygments==2.17.2
80 | - pyparsing==3.1.2
81 | - python-dateutil==2.9.0.post0
82 | - pytz==2024.1
83 | - pyyaml==6.0.1
84 | - regex==2024.4.28
85 | - requests==2.31.0
86 | - resampy==0.4.3
87 | - safetensors==0.4.3
88 | - scikit-learn==1.4.2
89 | - scipy==1.10.1
90 | - sentry-sdk==2.0.1
91 | - setproctitle==1.3.3
92 | - six==1.16.0
93 | - smmap==5.0.1
94 | - soundfile==0.12.1
95 | - stack-data==0.6.3
96 | - termcolor==2.1.1
97 | - threadpoolctl==3.5.0
98 | - tokenizers==0.13.3
99 | - tomli==2.0.1
100 | - torch==1.13.1+cu117
101 | - torchaudio==0.13.1+cu117
102 | - torchvision==0.14.1+cu117
103 | - tqdm==4.66.2
104 | - traitlets==5.14.3
105 | - transformers==4.31.0
106 | - typing-extensions==4.11.0
107 | - umap==0.1.1
108 | - urllib3==2.2.1
109 | - wandb==0.13.10
110 | - wcwidth==0.2.13
111 | - yapf==0.40.2
112 | - zipp==3.18.1
113 |
--------------------------------------------------------------------------------
/datasets/extract_hubert.py:
--------------------------------------------------------------------------------
1 | from transformers import Wav2Vec2Processor, HubertModel
2 | import soundfile as sf
3 | import numpy as np
4 | import torch
5 |
6 | print("Loading the Wav2Vec2 Processor...")
7 | wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
8 | print("Loading the HuBERT Model...")
9 | hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
10 |
11 |
12 | def get_hubert_from_16k_wav(wav_16k_name):
13 | speech_16k, _ = sf.read(wav_16k_name)
14 | hubert = get_hubert_from_16k_speech(speech_16k)
15 | return hubert
16 |
17 | @torch.no_grad()
18 | def get_hubert_from_16k_speech(speech, device="cuda:0"):
19 | global hubert_model
20 | hubert_model = hubert_model.to(device)
21 | if speech.ndim ==2:
22 | speech = speech[:, 0] # [T, 2] ==> [T,]
23 | input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
24 | input_values_all = input_values_all.to(device)
25 |     # Long audio cannot be processed in one run due to memory limits, so we split it into clips.
26 |     # HuBERT processes the wav with a CNN of strides [5,2,2,2,2,2,2], giving an overall stride of 320,
27 |     # and kernels [10,3,3,3,3,2,2], making 400 samples the fundamental unit for one time step.
28 |     # So the CNN is equal to one big Conv1D with kernel k=400 and stride s=320.
29 |     # The output length is T = floor((t - k) / s) + 1 for t input samples.
30 |     # To prevent overlap, we give each clip a length of k + s*(N-1), where N is the expected frame count of that clip.
31 |     # The start of the next clip rolls back by (kernel - stride), so it begins at stride * N.
32 | kernel = 400
33 | stride = 320
34 | clip_length = stride * 1000
35 | num_iter = input_values_all.shape[1] // clip_length
36 | expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
37 | res_lst = []
38 | for i in range(num_iter):
39 | if i == 0:
40 | start_idx = 0
41 | end_idx = clip_length - stride + kernel
42 | else:
43 | start_idx = clip_length * i
44 | end_idx = start_idx + (clip_length - stride + kernel)
45 | input_values = input_values_all[:, start_idx: end_idx]
46 | hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
47 | res_lst.append(hidden_states[0])
48 | if num_iter > 0:
49 | input_values = input_values_all[:, clip_length * num_iter:]
50 | else:
51 | input_values = input_values_all
52 | # if input_values.shape[1] != 0:
53 | if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
54 | hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
55 | res_lst.append(hidden_states[0])
56 | ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
57 | # assert ret.shape[0] == expected_T
58 | assert abs(ret.shape[0] - expected_T) <= 1
59 | if ret.shape[0] < expected_T:
60 | ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))
61 | else:
62 | ret = ret[:expected_T]
63 | return ret
64 |
65 |
66 | if __name__ == '__main__':
67 | ### Process Single Long Audio for NeRF dataset
68 | # person_id = 'May'
69 | # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
70 | # hubert_npy_name = f"data/processed/videos/{person_id}/hubert.npy"
71 | # speech_16k, _ = sf.read(wav_16k_name)
72 | # hubert_hidden = get_hubert_from_16k_speech(speech_16k)
73 | # np.save(hubert_npy_name, hubert_hidden.detach().numpy())
74 |
75 | ### Process short audio clips for LRS3 dataset
76 | import glob, os, tqdm
77 | lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/'
78 | wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav'))
79 | for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)):
80 | spk_id = wav_16k_name.split("/")[-2]
81 | clip_id = wav_16k_name.split("/")[-1][:-4]
82 | out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy')
83 | if os.path.exists(out_name):
84 | continue
85 | speech_16k, _ = sf.read(wav_16k_name)
86 | hubert_hidden = get_hubert_from_16k_speech(speech_16k)
87 | np.save(out_name, hubert_hidden.detach().numpy())
--------------------------------------------------------------------------------
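
A quick check of the chunking arithmetic from the comments above, with no model involved; the ten-second input length is arbitrary:

# HuBERT's feature extractor behaves like one Conv1D with k=400, s=320,
# so t input samples yield floor((t - k) / s) + 1 frames.
kernel, stride = 400, 320
clip_length = stride * 1000            # samples per forward pass

t = 16000 * 10                         # e.g. 10 s of 16 kHz audio
expected_T = (t - (kernel - stride)) // stride
print(expected_T)                      # 499 frames, i.e. ~50 per second

# A full clip spans clip_length - stride + kernel samples and contributes
# exactly clip_length // stride = 1000 frames, so clips tile with no overlap.
span = clip_length - stride + kernel
print((span - (kernel - stride)) // stride)  # 1000
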
/utils/other_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 | import torch
5 | import csv
6 | import pprint
7 | from loguru import logger
8 | from collections import OrderedDict
9 |
10 | def print_exp_info(args):
11 | logger.info(pprint.pformat(vars(args)))
12 | logger.info(f"# ------------ {args.name} ----------- #")
13 | logger.info("PyTorch version: {}".format(torch.__version__))
14 | logger.info("CUDA version: {}".format(torch.version.cuda))
15 | logger.info("{} GPUs".format(torch.cuda.device_count()))
16 | logger.info(f"Random Seed: {args.random_seed}")
17 |
18 | def args2csv(args, get_head=False, list4print=[]):
19 | for k, v in args.items():
20 | if isinstance(args[k], dict):
21 | args2csv(args[k], get_head, list4print)
22 | else: list4print.append(k) if get_head else list4print.append(v)
23 | return list4print
24 |
25 | def record_trial(args, csv_path, best_metric, best_epoch):
26 | metric_name = []
27 | metric_value = []
28 | metric_epoch = []
29 | list4print = []
30 | name4print = []
31 | for k, v in vars(args).items():
32 | list4print.append(v)
33 | name4print.append(k)
34 |
35 | for k, v in best_metric.items():
36 | metric_name.append(k)
37 | metric_value.append(v)
38 | metric_epoch.append(best_epoch[k])
39 |
40 | if not os.path.exists(csv_path):
41 | with open(csv_path, "a+") as f:
42 | csv_writer = csv.writer(f)
43 |             csv_writer.writerow([*metric_name, *[f"{m}_epoch" for m in metric_name], *name4print])
44 |
45 | with open(csv_path, "a+") as f:
46 | csv_writer = csv.writer(f)
47 | csv_writer.writerow([*metric_value,*metric_epoch, *list4print])
48 |
49 |
50 | def set_random_seed(args):
51 | os.environ['PYTHONHASHSEED'] = str(args.random_seed)
52 | random.seed(args.random_seed)
53 | np.random.seed(args.random_seed)
54 | torch.manual_seed(args.random_seed)
55 | torch.cuda.manual_seed_all(args.random_seed)
56 | torch.cuda.manual_seed(args.random_seed)
57 | torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC
58 | torch.backends.cudnn.benchmark = args.benchmark
59 | torch.backends.cudnn.enabled = args.cudnn_enabled
60 |
61 |
62 | def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None):
63 | if lrs is not None:
64 | states = { 'model_state': model.state_dict(),
65 | 'epoch': epoch + 1,
66 | 'opt_state': opt.state_dict(),
67 | 'lrs':lrs.state_dict(),}
68 | elif opt is not None:
69 | states = { 'model_state': model.state_dict(),
70 | 'epoch': epoch + 1,
71 | 'opt_state': opt.state_dict(),}
72 | else:
73 | states = { 'model_state': model.state_dict(),}
74 | torch.save(states, save_path)
75 |
76 |
77 | def load_checkpoints(model, save_path, load_name='model'):
78 | states = torch.load(save_path)
79 | new_weights = OrderedDict()
80 | flag=False
81 | for k, v in states['model_state'].items():
82 | if "module" not in k:
83 | break
84 | else:
85 | new_weights[k[7:]]=v
86 | flag=True
87 | if flag:
88 | model.load_state_dict(new_weights)
89 | else:
90 | model.load_state_dict(states['model_state'])
91 | logger.info(f"load self-pretrained checkpoints for {load_name}")
92 |
93 |
94 | def model_complexity(model, args):
95 | from ptflops import get_model_complexity_info
96 | flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN),
97 | as_strings=False, print_per_layer_stat=False)
98 |     logger.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9))
99 |     logger.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6))
100 |
101 |
102 | class AverageMeter(object):
103 | """Computes and stores the average and current value"""
104 | def __init__(self, name, fmt=':f'):
105 | self.name = name
106 | self.fmt = fmt
107 | self.reset()
108 |
109 | def reset(self):
110 | self.val = 0
111 | self.avg = 0
112 | self.sum = 0
113 | self.count = 0
114 |
115 | def update(self, val, n=1):
116 | self.val = val
117 | self.sum += val * n
118 | self.count += n
119 | self.avg = self.sum / self.count
120 |
121 | def __str__(self):
122 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
123 | return fmtstr.format(**self.__dict__)
--------------------------------------------------------------------------------
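
A save/load round trip with the checkpoint helpers above; the model, optimizer, and path are placeholders:

import torch
from utils.other_tools import save_checkpoints, load_checkpoints

model = torch.nn.Linear(4, 2)                 # stand-in for a real network
opt = torch.optim.Adam(model.parameters())

save_checkpoints('demo.tar', model, opt=opt, epoch=0)
load_checkpoints(model, 'demo.tar', load_name='demo_model')
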
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | # import cv2
4 | from PIL import Image
5 | from utils import paramUtil
6 | import math
7 | import time
8 | import matplotlib.pyplot as plt
9 | from scipy.ndimage import gaussian_filter
10 |
11 |
12 | def mkdir(path):
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 |
16 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
17 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
18 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
19 |
20 | MISSING_VALUE = -1
21 |
22 | def save_image(image_numpy, image_path):
23 | img_pil = Image.fromarray(image_numpy)
24 | img_pil.save(image_path)
25 |
26 |
27 | def save_logfile(log_loss, save_path):
28 | with open(save_path, 'wt') as f:
29 | for k, v in log_loss.items():
30 | w_line = k
31 | for digit in v:
32 | w_line += ' %.3f' % digit
33 | f.write(w_line + '\n')
34 |
35 |
36 | def print_current_loss(start_time, niter_state, losses, epoch=None, inner_iter=None, mode='train'):
37 |
38 | def as_minutes(s):
39 | m = math.floor(s / 60)
40 | s -= m * 60
41 | return '%dm %ds' % (m, s)
42 |
43 | def time_since(since, percent):
44 | now = time.time()
45 | s = now - since
46 | es = s / percent
47 | rs = es - s
48 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
49 |
50 | if epoch is not None:
51 | if mode == 'train':
52 | print('epoch: %3d niter: %6d inner_iter: %4d' % (epoch, niter_state, inner_iter), end=" ")
53 | elif mode == 'val':
54 | print('[Validation] epoch: %3d niter: %6d inner_iter: %4d' % (epoch, niter_state, inner_iter), end=" ")
55 |
56 | now = time.time()
57 | message = '%s'%(as_minutes(now - start_time))
58 |
59 | for k, v in losses.items():
60 | message += ' %s: %.4f ' % (k, v)
61 | print(message)
62 |
63 |
64 | def compose_gif_img_list(img_list, fp_out, duration):
65 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list]
66 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False,
67 | save_all=True, loop=0, duration=duration)
68 |
69 |
70 | def save_images(visuals, image_path):
71 | if not os.path.exists(image_path):
72 | os.makedirs(image_path)
73 |
74 | for i, (label, img_numpy) in enumerate(visuals.items()):
75 | img_name = '%d_%s.jpg' % (i, label)
76 | save_path = os.path.join(image_path, img_name)
77 | save_image(img_numpy, save_path)
78 |
79 |
80 | def save_images_test(visuals, image_path, from_name, to_name):
81 | if not os.path.exists(image_path):
82 | os.makedirs(image_path)
83 |
84 | for i, (label, img_numpy) in enumerate(visuals.items()):
85 | img_name = "%s_%s_%s" % (from_name, to_name, label)
86 | save_path = os.path.join(image_path, img_name)
87 | save_image(img_numpy, save_path)
88 |
89 |
90 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)):
91 | # print(col, row)
92 | compose_img = compose_image(img_list, col, row, img_size)
93 | if not os.path.exists(save_dir):
94 | os.makedirs(save_dir)
95 | img_path = os.path.join(save_dir, img_name)
96 | # print(img_path)
97 | compose_img.save(img_path)
98 |
99 |
100 | def compose_image(img_list, col, row, img_size):
101 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1]))
102 | for y in range(0, row):
103 | for x in range(0, col):
104 | from_img = Image.fromarray(img_list[y * col + x])
105 | # print((x * img_size[0], y*img_size[1],
106 | # (x + 1) * img_size[0], (y + 1) * img_size[1]))
107 | paste_area = (x * img_size[0], y*img_size[1],
108 | (x + 1) * img_size[0], (y + 1) * img_size[1])
109 | to_image.paste(from_img, paste_area)
110 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img
111 | return to_image
112 |
113 |
114 | def list_cut_average(ll, intervals):
115 | if intervals == 1:
116 | return ll
117 |
118 | bins = math.ceil(len(ll) * 1.0 / intervals)
119 | ll_new = []
120 | for i in range(bins):
121 | l_low = intervals * i
122 | l_high = l_low + intervals
123 | l_high = l_high if l_high < len(ll) else len(ll)
124 | ll_new.append(np.mean(ll[l_low:l_high]))
125 | return ll_new
126 |
127 |
128 | def motion_temporal_filter(motion, sigma=1):
129 | motion = motion.reshape(motion.shape[0], -1)
130 | # print(motion.shape)
131 | for i in range(motion.shape[1]):
132 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
133 | return motion.reshape(motion.shape[0], -1, 3)
134 |
135 |
--------------------------------------------------------------------------------
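
motion_temporal_filter smooths every coordinate channel independently with a Gaussian along the time axis and restores the (frames, joints, 3) shape; a short sketch on random data:

import numpy as np
from utils.utils import motion_temporal_filter

motion = np.random.randn(120, 22, 3)    # hypothetical (frames, joints, xyz)
smoothed = motion_temporal_filter(motion, sigma=2)
print(smoothed.shape)                   # (120, 22, 3)
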
/datasets/pymo/rotation_tools.py:
--------------------------------------------------------------------------------
1 | '''
2 | Tools for Manipulating and Converting 3D Rotations
3 |
4 | By Omid Alemi
5 | Created: June 12, 2017
6 |
7 | Adapted from that matlab file...
8 | '''
9 |
10 | import math
11 | import numpy as np
12 |
13 | def deg2rad(x):
14 | return x/180*math.pi
15 |
16 |
17 | def rad2deg(x):
18 | return x/math.pi*180
19 |
20 | class Rotation():
21 | def __init__(self,rot, param_type, rotation_order, **params):
22 | self.rotmat = []
23 | self.rotation_order = rotation_order
24 | if param_type == 'euler':
25 | self._from_euler(rot[0],rot[1],rot[2], params)
26 | elif param_type == 'expmap':
27 | self._from_expmap(rot[0], rot[1], rot[2], params)
28 |
29 | def _from_euler(self, alpha, beta, gamma, params):
30 |         '''Expecting degrees'''
31 |
32 | if params['from_deg']==True:
33 | alpha = deg2rad(alpha)
34 | beta = deg2rad(beta)
35 | gamma = deg2rad(gamma)
36 |
37 | ca = math.cos(alpha)
38 | cb = math.cos(beta)
39 | cg = math.cos(gamma)
40 | sa = math.sin(alpha)
41 | sb = math.sin(beta)
42 | sg = math.sin(gamma)
43 |
44 | Rx = np.asarray([[1, 0, 0],
45 | [0, ca, sa],
46 | [0, -sa, ca]
47 | ])
48 |
49 | Ry = np.asarray([[cb, 0, -sb],
50 | [0, 1, 0],
51 | [sb, 0, cb]])
52 |
53 | Rz = np.asarray([[cg, sg, 0],
54 | [-sg, cg, 0],
55 | [0, 0, 1]])
56 |
57 | self.rotmat = np.eye(3)
58 |
59 | ############################ inner product rotation matrix in order defined at BVH file #########################
60 | for axis in self.rotation_order :
61 | if axis == 'X' :
62 | self.rotmat = np.matmul(Rx, self.rotmat)
63 | elif axis == 'Y':
64 | self.rotmat = np.matmul(Ry, self.rotmat)
65 | else :
66 | self.rotmat = np.matmul(Rz, self.rotmat)
67 | ################################################################################################################
68 |
69 | def _from_expmap(self, alpha, beta, gamma, params):
70 | if (alpha == 0 and beta == 0 and gamma == 0):
71 | self.rotmat = np.eye(3)
72 | return
73 |
74 | #TODO: Check exp map params
75 |
76 | theta = np.linalg.norm([alpha, beta, gamma])
77 |
78 | expmap = [alpha, beta, gamma] / theta
79 |
80 | x = expmap[0]
81 | y = expmap[1]
82 | z = expmap[2]
83 |
84 | s = math.sin(theta/2)
85 | c = math.cos(theta/2)
86 |
87 | self.rotmat = np.asarray([
88 | [2*(x**2-1)*s**2+1, 2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s],
89 | [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s],
90 | [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s , 2*(z**2-1)*s**2+1]
91 | ])
92 |
93 |
94 |
95 | def get_euler_axis(self):
96 | R = self.rotmat
97 | theta = math.acos((self.rotmat.trace() - 1) / 2)
98 | axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]])
99 | axis = axis/(2*math.sin(theta))
100 | return theta, axis
101 |
102 | def to_expmap(self):
103 | theta, axis = self.get_euler_axis()
104 | rot_arr = theta * axis
105 | if np.isnan(rot_arr).any():
106 | rot_arr = [0, 0, 0]
107 | return rot_arr
108 |
109 | def to_euler(self, use_deg=False):
110 | eulers = np.zeros((2, 3))
111 |
112 | if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12:
113 | #GIMBAL LOCK!
114 | print('Gimbal')
115 |             if np.absolute(self.rotmat[2, 0] - 1) < 1e-12:
116 | eulers[:,0] = math.atan2(-self.rotmat[0,1], -self.rotmat[0,2])
117 | eulers[:,1] = -math.pi/2
118 | else:
119 |                 eulers[:,0] = math.atan2(self.rotmat[0,1], -self.rotmat[0,2])
120 | eulers[:,1] = math.pi/2
121 |
122 | return eulers
123 |
124 | theta = - math.asin(self.rotmat[2,0])
125 | theta2 = math.pi - theta
126 |
127 | # psi1, psi2
128 | eulers[0,0] = math.atan2(self.rotmat[2,1]/math.cos(theta), self.rotmat[2,2]/math.cos(theta))
129 | eulers[1,0] = math.atan2(self.rotmat[2,1]/math.cos(theta2), self.rotmat[2,2]/math.cos(theta2))
130 |
131 | # theta1, theta2
132 | eulers[0,1] = theta
133 | eulers[1,1] = theta2
134 |
135 | # phi1, phi2
136 | eulers[0,2] = math.atan2(self.rotmat[1,0]/math.cos(theta), self.rotmat[0,0]/math.cos(theta))
137 | eulers[1,2] = math.atan2(self.rotmat[1,0]/math.cos(theta2), self.rotmat[0,0]/math.cos(theta2))
138 |
139 | if use_deg:
140 | eulers = rad2deg(eulers)
141 |
142 | return eulers
143 |
144 | def to_quat(self):
145 | #TODO
146 | pass
147 |
148 | def __str__(self):
149 | return "Rotation Matrix: \n " + self.rotmat.__str__()
150 |
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
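
A minimal example of the Rotation class: build a rotation from Euler angles in degrees, using a BVH-style channel order, then convert to an exponential map (note that under this matrix convention a +90 degree Z Euler comes out as -pi/2 about Z):

from datasets.pymo.rotation_tools import Rotation

# alpha, beta, gamma are the X, Y, Z Euler angles; order 'ZXY' mimics a
# BVH file's channel order.
r = Rotation([0.0, 0.0, 90.0], 'euler', 'ZXY', from_deg=True)
print(r.rotmat)        # 3x3 rotation matrix
print(r.to_expmap())   # axis-angle vector, here approximately [0, 0, -pi/2]
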
/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import random
3 | from functools import partial
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | from mmcv.runner import get_dist_info
8 | from mmcv.utils import Registry, build_from_cfg
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.dataset import Dataset
11 |
12 | import torch
13 | from torch.utils.data import DistributedSampler as _DistributedSampler
14 |
15 |
16 | class DistributedSampler(_DistributedSampler):
17 |
18 | def __init__(self,
19 | dataset,
20 | num_replicas=None,
21 | rank=None,
22 | shuffle=True,
23 | round_up=True):
24 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
25 | self.shuffle = shuffle
26 | self.round_up = round_up
27 | if self.round_up:
28 | self.total_size = self.num_samples * self.num_replicas
29 | else:
30 | self.total_size = len(self.dataset)
31 |
32 | def __iter__(self):
33 | # deterministically shuffle based on epoch
34 | if self.shuffle:
35 | g = torch.Generator()
36 | g.manual_seed(self.epoch)
37 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
38 | else:
39 | indices = torch.arange(len(self.dataset)).tolist()
40 |
41 | # add extra samples to make it evenly divisible
42 | if self.round_up:
43 | indices = (
44 | indices *
45 | int(self.total_size / len(indices) + 1))[:self.total_size]
46 | assert len(indices) == self.total_size
47 |
48 | # subsample
49 | indices = indices[self.rank:self.total_size:self.num_replicas]
50 | if self.round_up:
51 | assert len(indices) == self.num_samples
52 |
53 | return iter(indices)
54 |
55 |
56 | def build_dataloader(dataset: Dataset,
57 | samples_per_gpu: int,
58 | workers_per_gpu: int,
59 | num_gpus: Optional[int] = 1,
60 | dist: Optional[bool] = True,
61 | shuffle: Optional[bool] = True,
62 | round_up: Optional[bool] = True,
63 | seed: Optional[Union[int, None]] = None,
64 | persistent_workers: Optional[bool] = True,
65 | **kwargs):
66 | """Build PyTorch DataLoader.
67 | In distributed training, each GPU/process has a dataloader.
68 | In non-distributed training, there is only one dataloader for all GPUs.
69 | Args:
70 | dataset (:obj:`Dataset`): A PyTorch dataset.
71 | samples_per_gpu (int): Number of training samples on each GPU, i.e.,
72 | batch size of each GPU.
73 | workers_per_gpu (int): How many subprocesses to use for data loading
74 | for each GPU.
75 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed
76 | training.
77 | dist (bool, optional): Distributed training/test or not. Default: True.
78 | shuffle (bool, optional): Whether to shuffle the data at every epoch.
79 | Default: True.
80 | round_up (bool, optional): Whether to round up the length of dataset by
81 | adding extra samples to make it evenly divisible. Default: True.
82 | persistent_workers (bool): If True, the data loader will not shut down
83 | the worker processes after a dataset has been consumed once,
84 | keeping the workers' Dataset instances alive.
85 | This argument only takes effect in PyTorch>=1.7.0.
86 | Default: True.
87 | kwargs: any keyword argument to be used to initialize DataLoader
88 | Returns:
89 | DataLoader: A PyTorch dataloader.
90 | """
91 | rank, world_size = get_dist_info()
92 | if dist:
93 | sampler = DistributedSampler(
94 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
95 | shuffle = False
96 | batch_size = samples_per_gpu
97 | num_workers = workers_per_gpu
98 | else:
99 | sampler = None
100 | batch_size = num_gpus * samples_per_gpu
101 | num_workers = num_gpus * workers_per_gpu
102 |
103 | init_fn = partial(
104 | worker_init_fn, num_workers=num_workers, rank=rank,
105 | seed=seed) if seed is not None else None
106 |
107 | data_loader = DataLoader(
108 | dataset,
109 | batch_size=batch_size,
110 | sampler=sampler,
111 | num_workers=num_workers,
112 | pin_memory=False,
113 | shuffle=shuffle,
114 | worker_init_fn=init_fn,
115 | persistent_workers=persistent_workers,
116 | **kwargs)
117 |
118 | return data_loader
119 |
120 |
121 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
122 | """Init random seed for each worker."""
123 | # The seed of each worker equals
124 | # num_workers * rank + worker_id + user_seed
125 | worker_seed = num_workers * rank + worker_id + seed
126 | np.random.seed(worker_seed)
127 | random.seed(worker_seed)
--------------------------------------------------------------------------------
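
A minimal usage sketch for `build_dataloader` above, exercising the non-distributed branch (a toy `TensorDataset` stands in for the real motion datasets; assumes mmcv is installed and the repo root is on `PYTHONPATH`):

```
import torch
from torch.utils.data import TensorDataset
from datasets.dataloader import build_dataloader

# Toy stand-in for the real motion datasets (hypothetical shapes).
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 4, (100,)))

loader = build_dataloader(
    dataset,
    samples_per_gpu=16,
    workers_per_gpu=0,          # 0 workers for a quick local test
    num_gpus=1,
    dist=False,                 # single process: plain shuffling, no sampler
    shuffle=True,
    seed=42,
    persistent_workers=False,   # required when num_workers == 0
)
x, y = next(iter(loader))
print(x.shape)                  # torch.Size([16, 8])
```
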
/models/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 | For example, if there are 300 timesteps and the section counts are [10,15,20],
13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
14 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
15 | If the stride is a string starting with "ddim", then the fixed striding
16 | from the DDIM paper is used, and only one section is allowed.
17 | :param num_timesteps: the number of diffusion steps in the original
18 | process to divide up.
19 | :param section_counts: either a list of numbers, or a string containing
20 | comma-separated numbers, indicating the step count
21 | per section. As a special case, use "ddimN" where N
22 | is a number of steps to use the striding from the
23 | DDIM paper.
24 | :return: a set of diffusion steps from the original process to use.
25 | """
26 | if isinstance(section_counts, str):
27 | if section_counts.startswith("ddim"):
28 | desired_count = int(section_counts[len("ddim") :])
29 | for i in range(1, num_timesteps):
30 | if len(range(0, num_timesteps, i)) == desired_count:
31 | return set(range(0, num_timesteps, i))
32 | raise ValueError(
33 | f"cannot create exactly {num_timesteps} steps with an integer stride"
34 | )
35 | section_counts = [int(x) for x in section_counts.split(",")]
36 | size_per = num_timesteps // len(section_counts)
37 | extra = num_timesteps % len(section_counts)
38 | start_idx = 0
39 | all_steps = []
40 | for i, section_count in enumerate(section_counts):
41 | size = size_per + (1 if i < extra else 0)
42 | if size < section_count:
43 | raise ValueError(
44 | f"cannot divide section of {size} steps into {section_count}"
45 | )
46 | if section_count <= 1:
47 | frac_stride = 1
48 | else:
49 | frac_stride = (size - 1) / (section_count - 1)
50 | cur_idx = 0.0
51 | taken_steps = []
52 | for _ in range(section_count):
53 | taken_steps.append(start_idx + round(cur_idx))
54 | cur_idx += frac_stride
55 | all_steps += taken_steps
56 | start_idx += size
57 | return set(all_steps)
58 |
59 |
60 | class SpacedDiffusion(GaussianDiffusion):
61 | """
62 | A diffusion process which can skip steps in a base diffusion process.
63 | :param use_timesteps: a collection (sequence or set) of timesteps from the
64 | original diffusion process to retain.
65 | :param kwargs: the kwargs to create the base diffusion process.
66 | """
67 |
68 | def __init__(self, use_timesteps, **kwargs):
69 | self.use_timesteps = set(use_timesteps)
70 | self.timestep_map = []
71 | self.original_num_steps = len(kwargs["betas"])
72 |
73 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
74 | last_alpha_cumprod = 1.0
75 | new_betas = []
76 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
77 | if i in self.use_timesteps:
78 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
79 | last_alpha_cumprod = alpha_cumprod
80 | self.timestep_map.append(i)
81 | kwargs["betas"] = np.array(new_betas)
82 | super().__init__(**kwargs)
83 |
84 | def p_mean_variance(
85 | self, model, *args, **kwargs
86 | ): # pylint: disable=signature-differs
87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
88 |
89 | def training_losses(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
93 |
94 | def condition_mean(self, cond_fn, *args, **kwargs):
95 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
96 |
97 | def condition_score(self, cond_fn, *args, **kwargs):
98 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
99 |
100 | def _wrap_model(self, model):
101 | if isinstance(model, _WrappedModel):
102 | return model
103 | return _WrappedModel(
104 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
105 | )
106 |
107 | def _scale_timesteps(self, t):
108 | # Scaling is done by the wrapped model.
109 | return t
110 |
111 |
112 | class _WrappedModel:
113 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
114 | self.model = model
115 | self.timestep_map = timestep_map
116 | self.rescale_timesteps = rescale_timesteps
117 | self.original_num_steps = original_num_steps
118 |
119 | def __call__(self, x, ts, **kwargs):
120 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
121 | new_ts = map_tensor[ts]
122 | if self.rescale_timesteps:
123 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
124 | return self.model(x, new_ts, **kwargs)
125 |
--------------------------------------------------------------------------------
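
To make the striding rules of `space_timesteps` concrete, a small sketch (the printed values follow directly from the docstring's example; assumes the repo's diffusion modules import cleanly):

```
from models.respace import space_timesteps

# Three 100-step sections of a 300-step process, keeping 10/15/20 steps each.
steps = space_timesteps(300, [10, 15, 20])
print(len(steps))                                    # 45

# DDIM-style fixed striding: 25 evenly spaced steps out of 1000.
print(sorted(space_timesteps(1000, "ddim25"))[:5])   # [0, 40, 80, 120, 160]
```
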
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 |
5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
6 | def euclidean_distance_matrix(matrix1, matrix2):
7 | """
8 | Params:
9 | -- matrix1: N1 x D
10 | -- matrix2: N2 x D
11 | Returns:
12 | -- dist: N1 x N2
13 | dist[i, j] == distance(matrix1[i], matrix2[j])
14 | """
15 | assert matrix1.shape[1] == matrix2.shape[1]
16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
19 | dists = np.sqrt(np.maximum(d1 + d2 + d3, 0))  # clip tiny negatives from round-off before sqrt
20 | return dists
21 |
22 | def calculate_top_k(mat, top_k):
23 | size = mat.shape[0]
24 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
25 | bool_mat = (mat == gt_mat)
26 | correct_vec = False
27 | top_k_list = []
28 | for i in range(top_k):
29 | # print(correct_vec, bool_mat[:, i])
30 | correct_vec = (correct_vec | bool_mat[:, i])
31 | # print(correct_vec)
32 | top_k_list.append(correct_vec[:, None])
33 | top_k_mat = np.concatenate(top_k_list, axis=1)
34 | return top_k_mat
35 |
36 |
37 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
38 | dist_mat = euclidean_distance_matrix(embedding1, embedding2)
39 | argmax = np.argsort(dist_mat, axis=1)
40 | top_k_mat = calculate_top_k(argmax, top_k)
41 | if sum_all:
42 | return top_k_mat.sum(axis=0)
43 | else:
44 | return top_k_mat
45 |
46 |
47 | def calculate_matching_score(embedding1, embedding2, sum_all=False):
48 | assert len(embedding1.shape) == 2
49 | assert embedding1.shape[0] == embedding2.shape[0]
50 | assert embedding1.shape[1] == embedding2.shape[1]
51 |
52 | dist = linalg.norm(embedding1 - embedding2, axis=1)
53 | if sum_all:
54 | return dist.sum(axis=0)
55 | else:
56 | return dist
57 |
58 |
59 |
60 | def calculate_activation_statistics(activations):
61 | """
62 | Params:
63 | -- activation: num_samples x dim_feat
64 | Returns:
65 | -- mu: dim_feat
66 | -- sigma: dim_feat x dim_feat
67 | """
68 | mu = np.mean(activations, axis=0)
69 | cov = np.cov(activations, rowvar=False)
70 | return mu, cov
71 |
72 |
73 | def calculate_diversity(activation, diversity_times):
74 | assert len(activation.shape) == 2
75 | assert activation.shape[0] > diversity_times
76 | num_samples = activation.shape[0]
77 |
78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
81 | return dist.mean()
82 |
83 |
84 | def calculate_multimodality(activation, multimodality_times):
85 | assert len(activation.shape) == 3
86 | assert activation.shape[1] > multimodality_times
87 | num_per_sent = activation.shape[1]
88 |
89 | first_indices = np.random.choice(num_per_sent, multimodality_times, replace=False)
90 | second_indices = np.random.choice(num_per_sent, multimodality_times, replace=False)
91 | dist = linalg.norm(activation[:, first_indices] - activation[:, second_indices], axis=2)
92 | return dist.mean()
93 |
94 |
95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
96 | """Numpy implementation of the Frechet Distance.
97 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
98 | and X_2 ~ N(mu_2, C_2) is
99 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
100 | Stable version by Dougal J. Sutherland.
101 | Params:
102 | -- mu1 : Numpy array containing the activations of a layer of the
103 | inception net (like returned by the function 'get_predictions')
104 | for generated samples.
105 | -- mu2 : The sample mean over activations, precalculated on a
106 | representative data set.
107 | -- sigma1: The covariance matrix over activations for generated samples.
108 | -- sigma2: The covariance matrix over activations, precalculated on a
109 | representative data set.
110 | Returns:
111 | -- : The Frechet Distance.
112 | """
113 |
114 | mu1 = np.atleast_1d(mu1)
115 | mu2 = np.atleast_1d(mu2)
116 |
117 | sigma1 = np.atleast_2d(sigma1)
118 | sigma2 = np.atleast_2d(sigma2)
119 |
120 | assert mu1.shape == mu2.shape, \
121 | 'Training and test mean vectors have different lengths'
122 | assert sigma1.shape == sigma2.shape, \
123 | 'Training and test covariances have different dimensions'
124 |
125 | diff = mu1 - mu2
126 |
127 | # Product might be almost singular
128 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
129 | if not np.isfinite(covmean).all():
130 | msg = ('fid calculation produces singular product; '
131 | 'adding %s to diagonal of cov estimates') % eps
132 | print(msg)
133 | offset = np.eye(sigma1.shape[0]) * eps
134 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135 |
136 | # Numerical error might give slight imaginary component
137 | if np.iscomplexobj(covmean):
138 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139 | m = np.max(np.abs(covmean.imag))
140 | raise ValueError('Imaginary component {}'.format(m))
141 | covmean = covmean.real
142 |
143 | tr_covmean = np.trace(covmean)
144 |
145 | return (diff.dot(diff) + np.trace(sigma1) +
146 | np.trace(sigma2) - 2 * tr_covmean)
--------------------------------------------------------------------------------
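
A quick self-contained sanity check for the FID helpers above, using synthetic activations whose distance is known in closed form (approximately ||mu_1 - mu_2||^2 when both covariances are near-identity):

```
import numpy as np
from utils.metrics import calculate_activation_statistics, calculate_frechet_distance

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(2000, 64))   # stand-in "real" activations
fake = rng.normal(0.5, 1.0, size=(2000, 64))   # mean shifted by 0.5 per dimension

mu_r, sigma_r = calculate_activation_statistics(real)
mu_f, sigma_f = calculate_activation_statistics(fake)
print(calculate_frechet_distance(mu_r, sigma_r, mu_f, sigma_f))  # ~64 * 0.25 = 16, up to sampling noise
```
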
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## DiffSHEG: A Diffusion-Based Approach for Real-Time Speech-driven Holistic 3D Expression and Gesture Generation (CVPR 2024 Official Repo)
4 | 
5 | 
6 | [Junming Chen](https://jeremycjm.github.io)<sup>†1,2</sup>, [Yunfei Liu](http://liuyunfei.net/)<sup>2</sup>, [Jianan Wang](https://scholar.google.com/citations?user=mt5mvZ8AAAAJ&hl=en&inst=1381320739207392350)<sup>2</sup>, [Ailing Zeng](https://ailingzeng.site/)<sup>2</sup>, [Yu Li](https://yu-li.github.io/)<sup>*2</sup>, [Qifeng Chen](https://cqf.io)<sup>*1</sup>
7 | 
8 | <sup>1</sup>HKUST&emsp;<sup>2</sup>International Digital Economy Academy (IDEA)
9 | <sup>*</sup>Corresponding authors&emsp;<sup>†</sup>Work done during an internship at IDEA
10 |
11 | #### [Project Page](https://jeremycjm.github.io/proj/DiffSHEG/) · [Paper](https://arxiv.org/abs/2401.04747) · [Video](https://www.youtube.com/watch?v=HFaSd5do-zI)
12 |
13 |
14 |
15 | ![teaser](assets/teaser_for_demo_cvpr.png)
16 |
17 | ## Environment
18 | We have tested on Ubuntu 18.04 and 20.04.
19 | ```
20 | cd assets
21 | ```
22 | - Option 1: conda install
23 | ```
24 | conda env create -f environment.yml
25 | conda activate diffsheg
26 | ```
27 | - Option 2: pip install
28 | ```
29 | conda create -n "diffsheg" python=3.9
30 | conda activate diffsheg
31 | pip install -r requirements.txt
32 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
33 | ```
34 | - Untar data.tar.gz for data statistics
35 | ```
36 | tar zxvf data.tar.gz
37 | mv data ../
38 | ```
39 |
40 | ## Checkpoints
41 | [Google Drive](https://drive.google.com/file/d/1JPoMOcGDrvkFt7QbN6sEyYAPOOWkVN0h/view)
42 |
43 | ## Inference on a Custom Audio
44 | First, set the '--test_audio_path' argument in the bash scripts below to the path of your test audio. Note that the audio must be a .wav file.
45 |
46 | - Use model trained on BEAT dataset:
47 | ```
48 | bash inference_custom_audio_beat.sh
49 | ```
50 |
51 | - Use model trained on SHOW dataset:
52 | ```
53 | bash inference_custom_audio_show.sh
54 | ```
55 | ## Training
56 | Train on BEAT dataset
57 |
58 | ```
59 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
60 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \
61 | --dataset_name beat \
62 | --name beat_diffsheg \
63 | --batch_size 2500 \
64 | --num_epochs 1000 \
65 | --save_every_e 20 \
66 | --eval_every_e 40 \
67 | --n_poses 34 \
68 | --ddim \
69 | --multiprocessing-distributed \
70 | --dist-url 'tcp://127.0.0.1:6666'
71 | ```
72 |
73 |
74 |
75 | Train on SHOW dataset
76 |
77 | ```
78 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
79 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -u runner.py \
80 | --dataset_name talkshow \
81 | --name talkshow_diffsheg \
82 | --batch_size 950 \
83 | --num_epochs 4000 \
84 | --save_every_e 20 \
85 | --eval_every_e 40 \
86 | --n_poses 88 \
87 | --classifier_free \
88 | --multiprocessing-distributed \
89 | --dist-url 'tcp://127.0.0.1:6667' \
90 | --ddim \
91 | --max_eval_samples 200
92 | ```
93 |
94 |
95 | ## Testing
96 |
97 | Test on BEAT dataset
98 |
99 | ```
100 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
101 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \
102 | --dataset_name beat \
103 | --name beat_diffsheg \
104 | --PE pe_sinu \
105 | --n_poses 34 \
106 | --multiprocessing-distributed \
107 | --dist-url 'tcp://127.0.0.1:8889' \
108 | --ckpt ckpt_e2599.tar \
109 | --mode test_arbitrary_len \
110 | --ddim \
111 | --timestep_respacing ddim25 \
112 | --overlap_len 10
113 | ```
116 |
117 |
118 |
119 | Test on SHOW dataset
120 |
121 | ```
122 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
123 | OMP_NUM_THREADS=10 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -u runner.py \
124 | --dataset_name talkshow \
125 | --name talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree \
126 | --PE pe_sinu \
127 | --n_poses 88 \
128 | --multiprocessing-distributed \
129 | --dist-url 'tcp://127.0.0.1:8889' \
130 | --classifier_free \
131 | --cond_scale 1.25 \
132 | --ckpt ckpt_e2599.tar \
133 | --mode test_arbitrary_len \
134 | --ddim \
135 | --timestep_respacing ddim25 \
136 | --overlap_len 10
137 | ```
138 |
139 |
140 | ## Visualization
141 | After running in test or test-custom-audio mode, the gesture and expression results are saved in the ./results directory.
142 | ### BEAT
143 | 1. Open ```assets/beat_visualize.blend``` with the latest Blender on your local computer.
144 | 2. Specify the audio, BVH (for gesture), JSON (for expression), and video saving paths in the script in Blender.
145 | 3. (Optional) Click Window --> Toggle System Console to check the visualization progress.
146 | 4. Run the script in Blender.
147 | ### SHOW
148 | Please refer to the [TalkSHOW](https://github.com/yhw-yhw/TalkSHOW) code for the visualization of our generated motion.
149 |
150 | ## Acknowledgement
151 | Our implementation is partially based on [BEAT](https://github.com/PantoMatrix/BEAT), [TalkSHOW](https://github.com/yhw-yhw/TalkSHOW), and [MotionDiffuse](https://github.com/mingyuan-zhang/MotionDiffuse/tree/main).
152 |
153 | ## Citation
154 | If you use our code or find this repo useful, please consider citing our paper:
155 | ```
156 | @inproceedings{chen2024diffsheg,
157 | title = {DiffSHEG: A Diffusion-Based Approach for Real-Time Speech-driven Holistic 3D Expression and Gesture Generation},
158 | author = {Chen, Junming and Liu, Yunfei and Wang, Jianan and Zeng, Ailing and Li, Yu and Chen, Qifeng},
159 | booktitle = {CVPR},
160 | year = {2024}
161 | }
162 | ```
163 |
164 |
165 |
--------------------------------------------------------------------------------
/datasets/pymo/mocapplayer/js/skeletonFactory.js:
--------------------------------------------------------------------------------
1 | bm_v = new THREE.MeshPhongMaterial({
2 | color: 0x08519c,
3 | emissive: 0x08306b,
4 | specular: 0x08519c,
5 | shininess: 10,
6 | side: THREE.DoubleSide
7 | });
8 |
9 | jm_v = new THREE.MeshPhongMaterial({
10 | color: 0x08306b,
11 | emissive: 0x000000,
12 | specular: 0x111111,
13 | shininess: 90,
14 | side: THREE.DoubleSide
15 | });
16 |
17 | bm_a = new THREE.MeshPhongMaterial({
18 | color: 0x980043,
19 | emissive: 0x67001f,
20 | specular: 0x6a51a3,
21 | shininess: 10,
22 | side: THREE.DoubleSide
23 | });
24 |
25 | jm_a = new THREE.MeshPhongMaterial({
26 | color: 0x67001f,
27 | emissive: 0x000000,
28 | specular: 0x111111,
29 | shininess: 90,
30 | side: THREE.DoubleSide
31 | });
32 |
33 | bm_b = new THREE.MeshPhongMaterial({
34 | color: 0x3f007d,
35 | emissive: 0x3f007d,
36 | specular: 0x807dba,
37 | shininess: 2,
38 | side: THREE.DoubleSide
39 | });
40 |
41 | jm_b = new THREE.MeshPhongMaterial({
42 | color: 0x3f007d,
43 | emissive: 0x000000,
44 | specular: 0x807dba,
45 | shininess: 90,
46 | side: THREE.DoubleSide
47 | });
48 |
49 | //------------------
50 |
51 |
52 | jointmaterial = new THREE.MeshLambertMaterial({
53 | color: 0xc57206,
54 | emissive: 0x271c18,
55 | side: THREE.DoubleSide,
56 | // shading: THREE.FlatShading,
57 | wireframe: false,
58 | // shininess has no effect on MeshLambertMaterial
59 | });
60 |
61 | bonematerial = new THREE.MeshPhongMaterial({
62 | color: 0xbd9a6d,
63 | emissive: 0x271c18,
64 | side: THREE.DoubleSide,
65 | // shading: THREE.FlatShading,
66 | wireframe: false
67 | });
68 |
69 | jointmaterial2 = new THREE.MeshPhongMaterial({
70 | color: 0x1562a2,
71 | emissive: 0x000000,
72 | specular: 0x111111,
73 | shininess: 30,
74 | side: THREE.DoubleSide
75 | });
76 |
77 | bonematerial2 = new THREE.MeshPhongMaterial({
78 | color: 0x552211,
79 | emissive: 0x882211,
80 | // emissive: 0x000000,
81 | specular: 0x111111,
82 | shininess: 30,
83 | side: THREE.DoubleSide
84 | });
85 |
86 | bonematerial3 = new THREE.MeshPhongMaterial({
87 | color: 0x176793,
88 | emissive: 0x000000,
89 | specular: 0x111111,
90 | shininess: 90,
91 | side: THREE.DoubleSide
92 | });
93 |
94 |
95 |
96 | jointmaterial4 = new THREE.MeshPhongMaterial({
97 | color: 0xFF8A00,
98 | emissive: 0x000000,
99 | specular: 0x111111,
100 | shininess: 90,
101 | side: THREE.DoubleSide
102 | });
103 |
104 |
105 | bonematerial4 = new THREE.MeshPhongMaterial({
106 | color: 0x53633D,
107 | emissive: 0x000000,
108 | specular: 0xFFC450,
109 | shininess: 90,
110 | side: THREE.DoubleSide
111 | });
112 |
113 |
114 |
115 | bonematerial44 = new THREE.MeshPhongMaterial({
116 | color: 0x582A72,
117 | emissive: 0x000000,
118 | specular: 0xFFC450,
119 | shininess: 90,
120 | side: THREE.DoubleSide
121 | });
122 |
123 | jointmaterial5 = new THREE.MeshPhongMaterial({
124 | color: 0xAA5533,
125 | emissive: 0x000000,
126 | specular: 0x111111,
127 | shininess: 30,
128 | side: THREE.DoubleSide
129 | });
130 |
131 | bonematerial5 = new THREE.MeshPhongMaterial({
132 | color: 0x552211,
133 | emissive: 0x772211,
134 | specular: 0x111111,
135 | shininess: 30,
136 | side: THREE.DoubleSide
137 | });
138 |
139 |
140 | markermaterial = new THREE.MeshPhongMaterial({
141 | color: 0xc57206,
142 | emissive: 0x271c18,
143 | side: THREE.DoubleSide,
144 | // shading: THREE.FlatShading,
145 | wireframe: false,
146 | shininess: 20,
147 | });
148 |
149 | markermaterial2 = new THREE.MeshPhongMaterial({
150 | color: 0x1562a2,
151 | emissive: 0x271c18,
152 | side: THREE.DoubleSide,
153 | // shading: THREE.FlatShading,
154 | wireframe: false,
155 | shininess: 20,
156 | });
157 |
158 | markermaterial3 = new THREE.MeshPhongMaterial({
159 | color: 0x555555,
160 | emissive: 0x999999,
161 | side: THREE.DoubleSide,
162 | // shading: THREE.FlatShading,
163 | wireframe: false,
164 | shininess: 20,
165 | });
166 |
167 |
168 | var makeMarkerGeometry_Sphere10 = function(markerName, scale) {
169 | return new THREE.SphereGeometry(10, 60, 60);
170 | };
171 |
172 | var makeMarkerGeometry_Sphere3 = function(markerName, scale) {
173 | return new THREE.SphereGeometry(3, 60, 60);
174 | };
175 |
176 | var makeMarkerGeometry_SphereX = function(markerName, scale) {
177 | return new THREE.SphereGeometry(5, 60, 60);
178 | };
179 |
180 | var makeJointGeometry_SphereX = function(X) {
181 | return function(jointName, scale) {
182 | return new THREE.SphereGeometry(X, 60, 60);
183 | };
184 | };
185 |
186 |
187 | var makeJointGeometry_Sphere1 = function(jointName, scale) {
188 | return new THREE.SphereGeometry(2 / scale, 60, 60);
189 | };
190 |
191 | var makeJointGeometry_Sphere2 = function(jointName, scale) {
192 | return new THREE.SphereGeometry(1 / scale, 60, 60);
193 | };
194 |
195 | var makeJointGeometry_Dode = function(jointName, scale) {
196 | return new THREE.DodecahedronGeometry(1 / scale, 0);
197 | };
198 |
199 | var makeBoneGeometry_Cylinder1 = function(joint1Name, joint2Name, length, scale) {
200 | return new THREE.CylinderGeometry(1.5 / scale, 0.7 / scale, length, 40);
201 | };
202 |
203 | var makeBoneGeometry_Cylinder2 = function(joint1Name, joint2Name, length, scale) {
204 | // if (joint1Name.includes("LeftHip"))
205 | // length = 400;
206 | return new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length, 40);
207 | };
208 |
209 | var makeBoneGeometry_Cylinder3 = function(joint1Name, joint2Name, length, scale) {
210 | var c1 = new THREE.CylinderGeometry(1.5 / scale, 0.2 / scale, length / 1, 20);
211 | var c2 = new THREE.CylinderGeometry(0.2 / scale, 1.5 / scale, length / 1, 40);
212 |
213 | var material = new THREE.MeshPhongMaterial({
214 | color: 0xF7FE2E
215 | });
216 | var mmesh = new THREE.Mesh(c1, material);
217 | mmesh.updateMatrix();
218 | c2.merge(mmesh.geometry, mmesh.matrix);
219 | return c2;
220 | };
221 |
222 | var makeBoneGeometry_Box1 = function(joint1Name, joint2Name, length, scale) {
223 | return new THREE.BoxGeometry(1 / scale, length, 1 / scale, 40);
224 | };
225 |
226 |
227 | var makeJointGeometry_Empty = function(jointName, scale) {
228 | return new THREE.SphereGeometry(0.001, 60, 60);
229 | };
230 |
231 | var makeBoneGeometry_Empty = function(joint1Name, joint2Name, length, scale) {
232 | return new THREE.CylinderGeometry(0.001, 0.001, 0.001, 40);
233 | };
234 |
--------------------------------------------------------------------------------
/models/motion_autoencoder.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def reparameterize(mu, logvar):
8 | std = torch.exp(0.5 * logvar)
9 | eps = torch.randn_like(std)
10 | return mu + eps * std
11 |
12 |
13 | def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True):
14 | if not downsample:
15 | k = 3
16 | s = 1
17 | else:
18 | k = 4
19 | s = 2
20 |
21 | conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
22 | norm_block = nn.BatchNorm1d(out_channels)
23 |
24 | if batchnorm:
25 | net = nn.Sequential(
26 | conv_block,
27 | norm_block,
28 | nn.LeakyReLU(0.2, True)
29 | )
30 | else:
31 | net = nn.Sequential(
32 | conv_block,
33 | nn.LeakyReLU(0.2, True)
34 | )
35 | return net
36 |
37 |
38 | class PoseEncoderConv(nn.Module):
39 | def __init__(self, length, dim, feature_length=32):
40 | super().__init__()
41 | self.base = feature_length
42 | self.net = nn.Sequential(
43 | ConvNormRelu(dim, self.base, batchnorm=True), #32
44 | ConvNormRelu(self.base, self.base*2, batchnorm=True), #30
45 | ConvNormRelu(self.base*2, self.base*2, True, batchnorm=True), #14
46 | nn.Conv1d(self.base*2, self.base, 3)
47 | )
48 | if length == 88:
49 | self.out_net = nn.Sequential(
50 | nn.Linear(39*self.base, self.base*12), # for 88 frames
51 | nn.BatchNorm1d(self.base*12),
52 | nn.Linear(12*self.base, self.base*4),
53 | nn.BatchNorm1d(self.base*4),
54 | nn.LeakyReLU(0.2, True),
55 | nn.Linear(self.base*4, self.base*2),
56 | nn.BatchNorm1d(self.base*2),
57 | nn.LeakyReLU(0.2, True),
58 | nn.Linear(self.base*2, self.base),
59 | )
60 |
61 | if length == 64:
62 | self.out_net = nn.Sequential(
63 | nn.Linear(27*self.base, self.base*12), # for 64 frames
64 | nn.BatchNorm1d(self.base*12),
65 | nn.Linear(12*self.base, self.base*4),
66 | nn.BatchNorm1d(self.base*4),
67 | nn.LeakyReLU(0.2, True),
68 | nn.Linear(self.base*4, self.base*2),
69 | nn.BatchNorm1d(self.base*2),
70 | nn.LeakyReLU(0.2, True),
71 | nn.Linear(self.base*2, self.base),
72 | )
73 |
74 | if length == 34:
75 | self.out_net = nn.Sequential(
76 | nn.Linear(12*self.base, self.base*4), # for 34 frames
77 | nn.BatchNorm1d(self.base*4),
78 | nn.LeakyReLU(0.2, True),
79 | nn.Linear(self.base*4, self.base*2),
80 | nn.BatchNorm1d(self.base*2),
81 | nn.LeakyReLU(0.2, True),
82 | nn.Linear(self.base*2, self.base),
83 | )
84 |
85 | self.fc_mu = nn.Linear(self.base, self.base)
86 | self.fc_logvar = nn.Linear(self.base, self.base)
87 |
88 | def forward(self, poses, variational_encoding=None):
89 | # encode
90 | poses = poses.transpose(1, 2)
91 | out = self.net(poses)
92 | out = out.flatten(1)
93 | out = self.out_net(out)
94 | mu = self.fc_mu(out)
95 | logvar = self.fc_logvar(out)
96 | if variational_encoding:
97 | z = reparameterize(mu, logvar)
98 | else:
99 | z = mu
100 | return z, mu, logvar
101 |
102 |
103 | class PoseDecoderConv(nn.Module):
104 | def __init__(self, length, dim, use_pre_poses=False, feature_length=32):
105 | super().__init__()
106 | self.use_pre_poses = use_pre_poses
107 | self.feat_size = feature_length
108 |
109 | if use_pre_poses:
110 | self.pre_pose_net = nn.Sequential(
111 | nn.Linear(dim * 4, 32),
112 | nn.BatchNorm1d(32),
113 | nn.ReLU(),
114 | nn.Linear(32, 32),
115 | )
116 | self.feat_size += 32
117 |
118 | # if length == 64:
119 | # self.pre_net = nn.Sequential(
120 | # nn.Linear(self.feat_size, 128),
121 | # nn.BatchNorm1d(128),
122 | # nn.LeakyReLU(True),
123 | # nn.Linear(128, 256),
124 | # )
125 | # elif length == 34:
126 | # self.pre_net = nn.Sequential(
127 | # nn.Linear(self.feat_size, self.feat_size*2),
128 | # nn.BatchNorm1d(self.feat_size*2),
129 | # nn.LeakyReLU(True),
130 | # nn.Linear(self.feat_size*2, self.feat_size//8*34),
131 | # )
132 | # else:
133 | # assert False
134 | self.pre_net = nn.Sequential(
135 | nn.Linear(self.feat_size, self.feat_size*2),
136 | nn.BatchNorm1d(self.feat_size*2),
137 | nn.LeakyReLU(0.2, True),
138 | nn.Linear(self.feat_size*2, self.feat_size//8*length),
139 | )
140 |
141 |
142 | self.decoder_size = self.feat_size//8
143 | self.net = nn.Sequential(
144 | nn.ConvTranspose1d(self.decoder_size, self.feat_size, 3),
145 | nn.BatchNorm1d(self.feat_size),
146 | nn.LeakyReLU(0.2, True),
147 |
148 | nn.ConvTranspose1d(self.feat_size, self.feat_size, 3),
149 | nn.BatchNorm1d(self.feat_size),
150 | nn.LeakyReLU(0.2, True),
151 | nn.Conv1d(self.feat_size, self.feat_size*2, 3),
152 | nn.Conv1d(self.feat_size*2, dim, 3),
153 | )
154 |
155 | def forward(self, feat, pre_poses=None):
156 | if self.use_pre_poses:
157 | pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1))
158 | feat = torch.cat((pre_pose_feat, feat), dim=1)
159 | #print(feat.shape)
160 | out = self.pre_net(feat)
161 | #print(out.shape)
162 | out = out.view(feat.shape[0], self.decoder_size, -1)
163 | #print(out.shape)
164 | out = self.net(out)
165 | out = out.transpose(1, 2)
166 | return out
167 |
168 |
169 | class EmbeddingNet(nn.Module):
170 | def __init__(self, args):
171 | super().__init__()
172 | n_frames = args.n_poses
173 | # n_frames = 34
174 | pose_dim = args.net_dim_pose
175 | feature_length = args.vae_length
176 | self.pose_encoder = PoseEncoderConv(n_frames, pose_dim, feature_length=feature_length)
177 | self.decoder = PoseDecoderConv(n_frames, pose_dim, feature_length=feature_length)
178 |
179 | def forward(self, pre_poses, poses, variational_encoding=False):
180 | poses_feat, pose_mu, pose_logvar = self.pose_encoder(poses, variational_encoding)
181 | latent_feat = poses_feat
182 | out_poses = self.decoder(latent_feat, pre_poses)
183 | return poses_feat, pose_mu, pose_logvar, out_poses
184 |
185 | def freeze_pose_nets(self):
186 | for param in self.pose_encoder.parameters():
187 | param.requires_grad = False
188 | for param in self.decoder.parameters():
189 | param.requires_grad = False
190 |
191 |
192 | class HalfEmbeddingNet(nn.Module):
193 | def __init__(self, args):
194 | super().__init__()
195 | n_frames = args.n_poses
196 | # n_frames = 34
197 | pose_dim = args.net_dim_pose
198 | feature_length = args.vae_length
199 | self.pose_encoder = PoseEncoderConv(n_frames, pose_dim, feature_length=feature_length)
200 | self.decoder = PoseDecoderConv(n_frames, pose_dim, feature_length=feature_length)
201 |
202 | def forward(self, poses):
203 | poses_feat, _, _ = self.pose_encoder(poses)
204 | return poses_feat
--------------------------------------------------------------------------------
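
A shape sanity check for `EmbeddingNet` (hypothetical sizes: 34 frames is one of the three supported lengths, and the 141-D pose dimension is only illustrative):

```
from types import SimpleNamespace
import torch
from models.motion_autoencoder import EmbeddingNet

args = SimpleNamespace(n_poses=34, net_dim_pose=141, vae_length=32)
net = EmbeddingNet(args).eval()

poses = torch.randn(4, 34, 141)               # (batch, frames, pose_dim)
with torch.no_grad():
    feat, mu, logvar, recon = net(None, poses)
print(feat.shape, recon.shape)                # torch.Size([4, 32]) torch.Size([4, 34, 141])
```
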
/models/scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16 |
17 | def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0):
18 | if n_steplength > 1:
19 | if not n_sample > 1:
20 | raise RuntimeError('n_steplength has no effect if n_sample=1')
21 |
22 | t = t_T
23 | times = [t]
24 | while t >= 0:
25 | t = t - 1
26 | times.append(t)
27 | n_steplength_cur = min(n_steplength, t_T - t)
28 |
29 | for _ in range(n_sample - 1):
30 |
31 | for _ in range(n_steplength_cur):
32 | t = t + 1
33 | times.append(t)
34 | for _ in range(n_steplength_cur):
35 | t = t - 1
36 | times.append(t)
37 |
38 | _check_times(times, t_0, t_T)
39 |
40 | if debug == 2:
41 | for x in [list(range(0, 50)), list(range(-1, -50, -1))]:
42 | _plot_times(x=x, times=[times[i] for i in x])
43 |
44 | return times
45 |
46 |
47 | def _check_times(times, t_0, t_T):
48 | # Check end
49 | assert times[0] > times[1], (times[0], times[1])
50 |
51 | # Check beginning
52 | assert times[-1] == -1, times[-1]
53 |
54 | # Steplength = 1
55 | for t_last, t_cur in zip(times[:-1], times[1:]):
56 | assert abs(t_last - t_cur) == 1, (t_last, t_cur)
57 |
58 | # Value range
59 | for t in times:
60 | assert t >= t_0, (t, t_0)
61 | assert t <= t_T, (t, t_T)
62 |
63 |
64 | def _plot_times(x, times):
65 | import matplotlib.pyplot as plt
66 | plt.plot(x, times)
67 | plt.show()
68 |
69 |
70 | def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample,
71 | jump2_length=1, jump2_n_sample=1,
72 | jump3_length=1, jump3_n_sample=1,
73 | start_resampling=100000000):
74 |
75 | jumps = {}
76 | for j in range(0, t_T - jump_length, jump_length):
77 | jumps[j] = jump_n_sample - 1
78 |
79 | jumps2 = {}
80 | for j in range(0, t_T - jump2_length, jump2_length):
81 | jumps2[j] = jump2_n_sample - 1
82 |
83 | jumps3 = {}
84 | for j in range(0, t_T - jump3_length, jump3_length):
85 | jumps3[j] = jump3_n_sample - 1
86 |
87 | t = t_T
88 | ts = []
89 |
90 | while t >= 1:
91 | t = t-1
92 | ts.append(t)
93 |
94 | if (
95 | t + 1 < t_T - 1 and
96 | t <= start_resampling
97 | ):
98 | for _ in range(n_sample - 1):
99 | t = t + 1
100 | ts.append(t)
101 |
102 | if t >= 0:
103 | t = t - 1
104 | ts.append(t)
105 |
106 | if (
107 | jumps3.get(t, 0) > 0 and
108 | t <= start_resampling - jump3_length
109 | ):
110 | jumps3[t] = jumps3[t] - 1
111 | for _ in range(jump3_length):
112 | t = t + 1
113 | ts.append(t)
114 |
115 | if (
116 | jumps2.get(t, 0) > 0 and
117 | t <= start_resampling - jump2_length
118 | ):
119 | jumps2[t] = jumps2[t] - 1
120 | for _ in range(jump2_length):
121 | t = t + 1
122 | ts.append(t)
123 | jumps3 = {}
124 | for j in range(0, t_T - jump3_length, jump3_length):
125 | jumps3[j] = jump3_n_sample - 1
126 |
127 | if (
128 | jumps.get(t, 0) > 0 and
129 | t <= start_resampling - jump_length
130 | ):
131 | jumps[t] = jumps[t] - 1
132 | for _ in range(jump_length):
133 | t = t + 1
134 | ts.append(t)
135 | jumps2 = {}
136 | for j in range(0, t_T - jump2_length, jump2_length):
137 | jumps2[j] = jump2_n_sample - 1
138 |
139 | jumps3 = {}
140 | for j in range(0, t_T - jump3_length, jump3_length):
141 | jumps3[j] = jump3_n_sample - 1
142 |
143 | ts.append(-1)
144 |
145 | _check_times(ts, -1, t_T)
146 |
147 | return ts
148 |
149 |
150 | def get_schedule_jump_paper():
151 | t_T = 250
152 | jump_length = 10
153 | jump_n_sample = 10
154 |
155 | jumps = {}
156 | for j in range(0, t_T - jump_length, jump_length):
157 | jumps[j] = jump_n_sample - 1
158 |
159 | t = t_T
160 | ts = []
161 |
162 | while t >= 1:
163 | t = t-1
164 | ts.append(t)
165 |
166 | if jumps.get(t, 0) > 0:
167 | jumps[t] = jumps[t] - 1
168 | for _ in range(jump_length):
169 | t = t + 1
170 | ts.append(t)
171 |
172 | ts.append(-1)
173 |
174 | _check_times(ts, -1, t_T)
175 |
176 | return ts
177 |
178 | def get_schedule_jump_cjm_ddim(time_respacing=25, jump_length=1, jump_n_sample=1):
179 | if time_respacing == 25:
180 | t_T = 15
181 | else:
182 | t_T = int(time_respacing * 0.6)
183 |
184 | # t_T = time_respacing
185 | # t_T = 15
186 |
187 | jumps = {}
188 | for j in range(0, t_T - jump_length, jump_length):
189 | jumps[j] = jump_n_sample - 1
190 |
191 | t = t_T
192 | ts = []
193 |
194 | while t >= 1:
195 | t = t-1
196 | ts.append(t)
197 |
198 | if jumps.get(t, 0) > 0:
199 | jumps[t] = jumps[t] - 1
200 | for _ in range(jump_length):
201 | t = t + 1
202 | ts.append(t)
203 |
204 | ts.append(-1)
205 |
206 | _check_times(ts, -1, t_T)
207 |
208 | return ts
209 |
210 |
211 | def get_schedule_jump_test(to_supplement=False):
212 | ts = get_schedule_jump(t_T=250, n_sample=1,
213 | jump_length=10, jump_n_sample=10,
214 | jump2_length=1, jump2_n_sample=1,
215 | jump3_length=1, jump3_n_sample=1,
216 | start_resampling=250)
217 |
218 | import matplotlib.pyplot as plt
219 | SMALL_SIZE = 8*3
220 | MEDIUM_SIZE = 10*3
221 | BIGGER_SIZE = 12*3
222 |
223 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes
224 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
225 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
226 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
227 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
228 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
229 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
230 |
231 | plt.plot(ts)
232 |
233 | fig = plt.gcf()
234 | fig.set_size_inches(20, 10)
235 |
236 | ax = plt.gca()
237 | ax.set_xlabel('Number of Transitions')
238 | ax.set_ylabel('Diffusion time $t$')
239 |
240 | fig.tight_layout()
241 |
242 | if to_supplement:
243 | out_path = "/cluster/home/alugmayr/gdiff/paper/supplement/figures/jump_sched.pdf"
244 | plt.savefig(out_path)
245 |
246 | out_path = "./schedule.png"
247 | plt.savefig(out_path)
248 | print(out_path)
249 |
250 |
251 | def main():
252 | get_schedule_jump_test()
253 |
254 |
255 | if __name__ == "__main__":
256 | main()
--------------------------------------------------------------------------------
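
A sketch of what the jump schedules above produce (the printed values follow from the loop structure; `get_schedule_jump_paper` hard-codes the RePaint paper's settings):

```
from models.scheduler import get_schedule_jump, get_schedule_jump_paper

# RePaint-style schedule from the paper: t_T=250 with resampling jumps.
ts = get_schedule_jump_paper()
print(ts[0], ts[-1])        # 249 -1 (descends from t_T - 1 and ends at -1)
print(len(ts))              # far more than 250 transitions because of resampling

# The generalized version with explicit parameters.
ts2 = get_schedule_jump(t_T=50, n_sample=1, jump_length=10, jump_n_sample=2)
print(max(ts2), min(ts2))   # 49 -1
```
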
/datasets/show.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | import numpy as np
4 | import os
5 | from os.path import join as pjoin
6 | import random
7 | import codecs as cs
8 | from tqdm import tqdm
9 | import pickle
10 | import lmdb
11 | import pyarrow
12 | import torch.nn.functional as F
13 |
14 | class ShowDataset(data.Dataset):
15 | """
16 | SHOW (TalkSHOW) dataset.
17 | Prepares the conditioning information (audio features + speaker identity) and the corresponding poses and expressions."""
18 |
19 | def __init__(self, opt, data_path):
20 | """
21 | Args:
22 | control_data: The control input
23 | joint_data: body pose input
24 | Both with shape (samples, time-slices, features)s
25 | """
26 | self.opt = opt
27 |
28 | if self.opt.mode != "test_custom_audio":
29 | print("Loading data ...")
30 | self.lmdb_env = lmdb.open(data_path, readonly=True, lock=False)
31 | if opt.use_aud_feat or self.opt.expAddHubert or self.opt.addHubert:
32 | self.aud_feat_path = os.path.join(os.path.dirname(os.path.dirname(data_path)), f"cached_aud_hubert/{os.path.basename(data_path).split('_')[-2]}/hubert_large_ls960_ft")
33 | self.aud_lmdb_env = lmdb.open(self.aud_feat_path, readonly=True, lock=False)
34 |
35 | if opt.use_aud_feat or self.opt.addWav2Vec2:
36 | self.aud_feat_path = os.path.join(os.path.dirname(os.path.dirname(data_path)), f"cached_aud_wav2vec2/{os.path.basename(data_path).split('_')[-2]}/wav2vec2_base_960h")
37 | self.aud_lmdb_env = lmdb.open(self.aud_feat_path, readonly=True, lock=False)
38 |
39 | with self.lmdb_env.begin() as txn:
40 | self.n_samples = txn.stat()["entries"]
41 | # dict_keys(['poses', 'expression', 'aud_feat', 'speaker', 'aud_file', 'betas'])
42 |
43 | mean_std_dict = np.load("data/SHOW/talkshow_mean_std.npy", allow_pickle=True)[()]
44 |
45 | self.pose_mean = self.extract_pose(torch.from_numpy(mean_std_dict["pose_mean"]).float())
46 | self.pose_std = self.extract_pose(torch.from_numpy(mean_std_dict["pose_std"]).float())
47 | self.expression_mean = torch.cat([torch.from_numpy(mean_std_dict["pose_mean"][:3]).float(), torch.from_numpy(mean_std_dict["expression_mean"]).float()], dim=-1)
48 | self.expression_std = torch.cat([torch.from_numpy(mean_std_dict["pose_std"][:3]).float(), torch.from_numpy(mean_std_dict["expression_std"]).float()], dim=-1)
49 |
50 | self.motion_mean = torch.cat([self.pose_mean, self.expression_mean], dim=-1)
51 | self.motion_std = torch.cat([self.pose_std, self.expression_std], dim=-1)
52 |
53 | if self.opt.usePredExpr:
54 | face_list = os.listdir(self.opt.usePredExpr)
55 | face_list = [ff for ff in face_list if ff.endswith(".npy")]
56 | face_list.sort()
57 | self.face_list = [os.path.join(self.opt.usePredExpr, pp) for pp in face_list]
58 |
59 |
60 |
61 |
62 | def __len__(self):
63 | return self.n_samples
64 |
65 | def __getitem__(self, idx):
66 | """
67 | Returns poses and conditioning.
68 | """
69 | with self.lmdb_env.begin(write=False) as txn:
70 | key = "{:005}".format(idx).encode("ascii")
71 | sample = txn.get(key)
72 | sample = pyarrow.deserialize(sample)
73 | pose, expression, aud_raw, mfcc, mel, speaker, aud_file, betas = sample
74 | # pose, expression, mfcc = pose.T, expression.T, mfcc.T
75 | pose = torch.from_numpy(pose.copy()).float()
76 | expression = torch.from_numpy(expression.copy()).float()
77 | aud_raw = torch.from_numpy(aud_raw.copy()).float()
78 | mfcc = torch.from_numpy(mfcc.copy()).float()
79 | mel = torch.from_numpy(mel.copy()).float()
80 | speaker = torch.from_numpy(speaker.copy()).float()
81 | betas = torch.from_numpy(betas.copy()).float()
82 |
83 | jaw_pose, leye_pose, reye_pose, global_orient, body_pose, hand_pose = torch.split(pose, [3,3,3,3, 63, 90], dim=-1)
84 | low1, up1, low2, up2, low3, up3, low4, up4 = torch.split(body_pose, [6, 3, 6, 3, 6, 3, 6, 30], dim=-1)
85 | pose = torch.cat([up1, up2, up3, up4, hand_pose], dim=-1)
86 | expression = torch.cat([jaw_pose, expression], dim=-1)
87 |
88 | pose = self.standardize(pose, self.pose_mean, self.pose_std)
89 | expression = self.standardize(expression, self.expression_mean, self.expression_std)
90 |
91 | hubert = None
92 | if self.opt.audio_feat == "hubert" or self.opt.expAddHubert or self.opt.addHubert:
93 | with self.aud_lmdb_env.begin(write=False) as txn_aud:
94 | key = "{:005}".format(idx).encode("ascii")
95 | hubert = txn_aud.get(key)
96 | hubert = pyarrow.deserialize(hubert)
97 | hubert = torch.from_numpy(hubert.copy()).float()
98 | hubert = F.interpolate(hubert.swapaxes(-1,-2).unsqueeze(0), size=pose.shape[0], mode='linear', align_corners=True).swapaxes(-1,-2).squeeze()
99 |
100 | wav2vec2 = None
101 | if self.opt.audio_feat == "wav2vec2" or self.opt.addWav2Vec2:
102 | with self.aud_lmdb_env.begin(write=False) as txn_aud:
103 | key = "{:005}".format(idx).encode("ascii")
104 | wav2vec2 = txn_aud.get(key)
105 | wav2vec2 = pyarrow.deserialize(wav2vec2)
106 | wav2vec2 = torch.from_numpy(wav2vec2.copy()).float()
107 |
108 | if self.opt.audio_feat == "mfcc":
109 | aud_feat = mfcc
110 | elif self.opt.audio_feat == "mel":
111 | aud_feat = mel
112 | elif self.opt.audio_feat == "raw":
113 | aud_feat = aud_raw
114 | elif self.opt.audio_feat == "hubert":
115 | aud_feat = hubert
116 | elif self.opt.audio_feat == "wav2vec2":
117 | aud_feat = wav2vec2
118 |
119 | if self.opt.expAddHubert or self.opt.addHubert:
120 | return {'poses': pose,
121 | 'expression': expression,
122 | 'aud_feat': aud_feat,
123 | 'pretrain_aud_feat': hubert,
124 | 'speaker': speaker,
125 | 'aud_file': aud_file,
126 | 'betas': betas
127 | }
128 | elif self.opt.addWav2Vec2:
129 | return {'poses': pose,
130 | 'expression': expression,
131 | 'aud_feat': aud_feat,
132 | 'pretrain_aud_feat': wav2vec2,
133 | 'speaker': speaker,
134 | 'aud_file': aud_file,
135 | 'betas': betas
136 | }
137 | else:
138 | return {'poses': pose,
139 | 'expression': expression,
140 | 'aud_feat': aud_feat,
141 | 'speaker': speaker,
142 | 'aud_file': aud_file,
143 | 'betas': betas
144 | }
145 |
146 | def extract_pose(self, pose):
147 | jaw_pose, leye_pose, reye_pose, global_orient, body_pose, hand_pose = torch.split(pose, [3,3,3,3, 63, 90], dim=-1)
148 | low1, up1, low2, up2, low3, up3, low4, up4 = torch.split(body_pose, [6, 3, 6, 3, 6, 3, 6, 30], dim=-1)
149 | # pose = torch.cat([jaw_pose, up1, up2, up3, up4, hand_pose], dim=-1)
150 | pose = torch.cat([up1, up2, up3, up4, hand_pose], dim=-1)
151 | return pose
152 |
153 | def standardize(self, data, mean, std):
154 | scaled = (data - mean) / std
155 | return scaled
156 |
157 | def inv_standardize(self, data, mean, std):
158 | if isinstance(data, np.ndarray) and torch.is_tensor(std):
159 | inv_scaled = data * std.numpy() + mean.numpy()
160 | else:
161 | inv_scaled = data * std + mean
162 | return inv_scaled
--------------------------------------------------------------------------------
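
The tensor bookkeeping in `extract_pose` and `__getitem__` above, spelled out on a dummy pose vector (a standalone sketch; the SMPL-X layout is inferred from the split sizes in the code):

```
import torch

# 165-D SMPL-X pose: jaw(3) + left/right eye(3+3) + global orient(3) + body(63) + hands(90)
pose = torch.randn(165)
jaw, leye, reye, orient, body, hands = torch.split(pose, [3, 3, 3, 3, 63, 90], dim=-1)

# The body is split further so that only the upper-body segments (up1..up4) are kept.
low1, up1, low2, up2, low3, up3, low4, up4 = torch.split(body, [6, 3, 6, 3, 6, 3, 6, 30], dim=-1)

gesture = torch.cat([up1, up2, up3, up4, hands], dim=-1)
print(gesture.shape)        # torch.Size([129]): 3+3+3+30 upper-body dims + 90 hand dims
```
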
/datasets/pymo/viz_tools.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import IPython
5 | import os
6 | from .writers import BVHWriter
7 | def save_fig(fig_id, tight_layout=True):
8 | if tight_layout:
9 | plt.tight_layout()
10 | plt.savefig(fig_id + '.png', format='png', dpi=300)
11 |
12 |
13 | def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
14 | if ax is None:
15 | fig = plt.figure(figsize=figsize)
16 | ax = fig.add_subplot(111)
17 |
18 | if joints is None:
19 | joints_to_draw = mocap_track.skeleton.keys()
20 | else:
21 | joints_to_draw = joints
22 |
23 | if data is None:
24 | df = mocap_track.values
25 | else:
26 | df = data
27 |
28 | for joint in joints_to_draw:
29 | ax.scatter(x=df['%s_Xposition'%joint][frame],
30 | y=df['%s_Yposition'%joint][frame],
31 | alpha=0.6, c='b', marker='o')
32 |
33 | parent_x = df['%s_Xposition'%joint][frame]
34 | parent_y = df['%s_Yposition'%joint][frame]
35 |
36 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]
37 |
38 | for c in children_to_draw:
39 | child_x = df['%s_Xposition'%c][frame]
40 | child_y = df['%s_Yposition'%c][frame]
41 | ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2)
42 |
43 | if draw_names:
44 | ax.annotate(joint,
45 | (df['%s_Xposition'%joint][frame] + 0.1,
46 | df['%s_Yposition'%joint][frame] + 0.1))
47 |
48 | return ax
49 |
50 | def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)):
51 | from mpl_toolkits.mplot3d import Axes3D
52 |
53 | if ax is None:
54 | fig = plt.figure(figsize=figsize)
55 | ax = fig.add_subplot(111, projection='3d')
56 |
57 | if joints is None:
58 | joints_to_draw = mocap_track.skeleton.keys()
59 | else:
60 | joints_to_draw = joints
61 |
62 | if data is None:
63 | df = mocap_track.values
64 | else:
65 | df = data
66 |
67 | for joint in joints_to_draw:
68 | parent_x = df['%s_Xposition'%joint][frame]
69 | parent_y = df['%s_Zposition'%joint][frame]
70 | parent_z = df['%s_Yposition'%joint][frame]
71 | # ^ In mocaps, Y is the up-right axis
72 |
73 | ax.scatter(xs=parent_x,
74 | ys=parent_y,
75 | zs=parent_z,
76 | alpha=0.6, c='b', marker='o')
77 |
78 |
79 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw]
80 |
81 | for c in children_to_draw:
82 | child_x = df['%s_Xposition'%c][frame]
83 | child_y = df['%s_Zposition'%c][frame]
84 | child_z = df['%s_Yposition'%c][frame]
85 | # ^ In mocaps, Y is the up-right axis
86 |
87 | ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black')
88 |
89 | if draw_names:
90 | ax.text(x=parent_x + 0.1,
91 | y=parent_y + 0.1,
92 | z=parent_z + 0.1,
93 | s=joint,
94 | color='rgba(0,0,0,0.9)')
95 |
96 | return ax
97 |
98 |
99 | def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)):
100 | if ax is None:
101 | fig = plt.figure(figsize=figsize)
102 | ax = fig.add_subplot(111)
103 |
104 | if data is None:
105 | data = mocap_track.values
106 |
107 | for frame in range(0, data.shape[0], 4):
108 | # draw_stickfigure(mocap_track, f, data=data, ax=ax)
109 |
110 | for joint in mocap_track.skeleton.keys():
111 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]
112 |
113 | parent_x = data['%s_Xposition'%joint][frame]
114 | parent_y = data['%s_Yposition'%joint][frame]
115 |
116 | frame_alpha = frame/data.shape[0]
117 |
118 | for c in children_to_draw:
119 | child_x = data['%s_Xposition'%c][frame]
120 | child_y = data['%s_Yposition'%c][frame]
121 |
122 | ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)
123 |
124 |
125 |
126 | def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25):
127 | fig = plt.figure(figsize=(16,4))
128 | ax = plt.subplot2grid((1,8),(0,0))
129 | ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest')
130 |
131 | ax = plt.subplot2grid((1,8),(0,1), colspan=7)
132 | for frame in range(feature_to_viz.shape[0]):
133 | frame_alpha = 0.2#frame/data.shape[0] * 2 + 0.2
134 |
135 | for joint_i, joint in enumerate(mocap_track.skeleton.keys()):
136 | children_to_draw = [c for c in mocap_track.skeleton[joint]['children']]
137 |
138 | parent_x = data['%s_Xposition'%joint][frame] + frame * gap
139 | parent_y = data['%s_Yposition'%joint][frame]
140 |
141 | ax.scatter(x=parent_x,
142 | y=parent_y,
143 | alpha=0.6,
144 | cmap='RdBu',
145 | c=feature_to_viz[frame][joint_i] * 10000,
146 | marker='o',
147 | s = abs(feature_to_viz[frame][joint_i] * 10000))
148 | plt.axis('off')
149 | for c in children_to_draw:
150 | child_x = data['%s_Xposition'%c][frame] + frame * gap
151 | child_y = data['%s_Yposition'%c][frame]
152 |
153 | ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha)
154 |
155 |
156 | def print_skel(X):
157 | stack = [X.root_name]
158 | tab=0
159 | while stack:
160 | joint = stack.pop()
161 | tab = len(stack)
162 | print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent']))
163 | for c in X.skeleton[joint]['children']:
164 | stack.append(c)
165 |
166 |
167 | def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'):
168 | if mf == 'bvh':
169 | bw = BVHWriter()
170 | with open('test.bvh', 'w') as ofile:
171 | bw.write(mocap, ofile)
172 |
173 | filepath = '../notebooks/test.bvh'
174 | elif mf == 'pos':
175 | c = list(mocap.values.columns)
176 |
177 | # build a filtered copy; removing items while iterating skips elements
178 | c = [cc for cc in c if 'rotation' not in cc]
179 | 
180 | mocap.values.to_csv('test.csv', index=False, columns=c)
181 |
182 | filepath = '../notebooks/test.csv'
183 | else:
184 | return
185 |
186 | url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time)
187 | iframe = '<iframe src="%s" width="100%%" height="500"></iframe>' % url
188 | link = '<a href="%s" target="_blank">New Window</a>' % url
189 | return IPython.display.HTML(iframe+link)
190 |
191 | def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None):
192 | data_template = 'var dataBuffer = `$$DATA$$`;'
193 | data_template += 'var metadata = $$META$$;'
194 | data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);'
195 | dir_path = os.path.dirname(os.path.realpath(__file__))
196 |
197 |
198 | if base_url is None:
199 | base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html')
200 |
201 | # print(dir_path)
202 |
203 | if mf == 'bvh':
204 | pass
205 | elif mf == 'pos':
206 | cols = list(mocap.values.columns)
207 | # build a filtered copy; removing items while iterating skips elements
208 | cols = [c for c in cols if 'rotation' not in c]
209 | 
210 |
211 | data_csv = mocap.values.to_csv(index=False, columns=cols)
212 |
213 | if meta is not None:
214 | lines = [','.join(item) for item in meta.astype('str')]
215 | meta_csv = '[' + ','.join('[%s]'%l for l in lines) +']'
216 | else:
217 | meta_csv = '[]'
218 |
219 | data_assigned = data_template.replace('$$DATA$$', data_csv)
220 | data_assigned = data_assigned.replace('$$META$$', meta_csv)
221 | data_assigned = data_assigned.replace('$$CZ$$', str(camera_z))
222 | data_assigned = data_assigned.replace('$$SCALE$$', str(scale))
223 | data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time))
224 |
225 | else:
226 | return
227 |
228 |
229 |
230 | with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile:
231 | oFile.write(data_assigned)
232 |
233 | url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale)
234 | iframe = '<iframe src="%s" width="100%%" height="500"></iframe>' % url
235 | link = '<a href="%s" target="_blank">New Window</a>' % url
236 | return IPython.display.HTML(iframe+link)
237 |
--------------------------------------------------------------------------------
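
`print_skel` and `draw_stickfigure` only need an object exposing a `skeleton` dict and a `values` DataFrame with `<joint>_Xposition`/`<joint>_Yposition` columns, so a hand-built two-joint track (hypothetical) is enough to exercise them. Note that raw `BVHParser` output stores rotation channels, so real position columns come from a position parameterization step.

```
import pandas as pd
from types import SimpleNamespace
from datasets.pymo.viz_tools import draw_stickfigure, print_skel

skeleton = {
    'Hips':  {'parent': None,   'children': ['Spine']},
    'Spine': {'parent': 'Hips', 'children': []},
}
values = pd.DataFrame({
    'Hips_Xposition':  [0.0], 'Hips_Yposition':  [0.0],
    'Spine_Xposition': [0.0], 'Spine_Yposition': [10.0],
})
track = SimpleNamespace(skeleton=skeleton, values=values, root_name='Hips')

print_skel(track)                                       # dumps the joint tree
ax = draw_stickfigure(track, frame=0, draw_names=True)  # 2D stick figure
ax.figure.savefig('stick.png')
```
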
/datasets/pymo/mocapplayer/playURL.html:
--------------------------------------------------------------------------------
(HTML markup stripped during extraction; only the page title "BVH Player" is recoverable.)
--------------------------------------------------------------------------------
/datasets/pymo/parsers.py:
--------------------------------------------------------------------------------
1 | '''
2 | BVH Parser Class
3 |
4 | By Omid Alemi
5 | Created: June 12, 2017
6 |
7 | Based on: https://gist.github.com/johnfredcee/2007503
8 |
9 | '''
10 | import re
11 | import numpy as np
12 | from .data import Joint, MocapData
13 |
14 | class BVHScanner():
15 | '''
16 | A wrapper class for re.Scanner
17 | '''
18 | def __init__(self):
19 |
20 | def identifier(scanner, token):
21 | return 'IDENT', token
22 |
23 | def operator(scanner, token):
24 | return 'OPERATOR', token
25 |
26 | def digit(scanner, token):
27 | return 'DIGIT', token
28 |
29 | def open_brace(scanner, token):
30 | return 'OPEN_BRACE', token
31 |
32 | def close_brace(scanner, token):
33 | return 'CLOSE_BRACE', token
34 |
35 | self.scanner = re.Scanner([
36 | (r'[a-zA-Z_]\w*', identifier),
37 | #(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34
38 | #(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2
39 | #(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
40 | (r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit),
41 | (r'}', close_brace),
43 | (r'{', open_brace),
44 | (r':', None),
45 | (r'\s+', None)
46 | ])
47 |
48 | def scan(self, stuff):
49 | return self.scanner.scan(stuff)
50 |
51 |
52 |
53 | class BVHParser():
54 | '''
55 | A class to parse a BVH file.
56 |
57 | Extracts the skeleton and channel values
58 | '''
59 | def __init__(self, filename=None):
60 | self.reset()
61 |
62 | def reset(self):
63 | self._skeleton = {}
64 | self.bone_context = []
65 | self._motion_channels = []
66 | self._motions = []
67 | self.current_token = 0
68 | self.framerate = 0.0
69 | self.root_name = ''
70 |
71 | self.scanner = BVHScanner()
72 |
73 | self.data = MocapData()
74 |
75 |
76 | def parse(self, filename, start=0, stop=-1):
77 | self.reset()
78 |
79 | with open(filename, 'r') as bvh_file:
80 | raw_contents = bvh_file.read()
81 | tokens, remainder = self.scanner.scan(raw_contents)
82 | self._parse_hierarchy(tokens)
83 | self.current_token = self.current_token + 1
84 | self._parse_motion(tokens, start, stop)
85 |
86 | self.data.skeleton = self._skeleton
87 | self.data.channel_names = self._motion_channels
88 | self.data.values = self._to_DataFrame()
89 | self.data.root_name = self.root_name
90 | self.data.framerate = self.framerate
91 |
92 | return self.data
93 |
94 | def _to_DataFrame(self):
95 | '''Returns all of the channels parsed from the file as a pandas DataFrame'''
96 |
97 | import pandas as pd
98 | time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s')
99 | frames = [f[1] for f in self._motions]
100 | channels = np.asarray([[channel[2] for channel in frame] for frame in frames])
101 | column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels]
102 |
103 | return pd.DataFrame(data=channels, index=time_index, columns=column_names)
104 |
105 |
106 | def _new_bone(self, parent, name):
107 | bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []}
108 | return bone
109 |
110 | def _push_bone_context(self,name):
111 | self.bone_context.append(name)
112 |
113 | def _get_bone_context(self):
114 | return self.bone_context[len(self.bone_context)-1]
115 |
116 | def _pop_bone_context(self):
117 | self.bone_context = self.bone_context[:-1]
118 | return self.bone_context[len(self.bone_context)-1]
119 |
120 | def _read_offset(self, bvh, token_index):
121 | if bvh[token_index] != ('IDENT', 'OFFSET'):
122 | return None, None
123 | token_index = token_index + 1
124 | offsets = [0.0] * 3
125 | for i in range(3):
126 | offsets[i] = float(bvh[token_index][1])
127 | token_index = token_index + 1
128 | return offsets, token_index
129 |
130 | def _read_channels(self, bvh, token_index):
131 | if bvh[token_index] != ('IDENT', 'CHANNELS'):
132 | return None, None
133 | token_index = token_index + 1
134 | channel_count = int(bvh[token_index][1])
135 | token_index = token_index + 1
136 | channels = [""] * channel_count
137 | order = ""
138 | for i in range(channel_count):
139 | channels[i] = bvh[token_index][1]
140 | token_index = token_index + 1
141 |             if channels[i] == "Xrotation" or channels[i] == "Yrotation" or channels[i] == "Zrotation":
142 |                 order += channels[i][0]
143 |             else:
144 |                 order = ""
145 | return channels, token_index, order
146 |
147 | def _parse_joint(self, bvh, token_index):
148 | end_site = False
149 | joint_id = bvh[token_index][1]
150 | token_index = token_index + 1
151 | joint_name = bvh[token_index][1]
152 | token_index = token_index + 1
153 |
154 | parent_name = self._get_bone_context()
155 |
156 | if (joint_id == "End"):
157 | joint_name = parent_name+ '_Nub'
158 | end_site = True
159 | joint = self._new_bone(parent_name, joint_name)
160 | if bvh[token_index][0] != 'OPEN_BRACE':
161 |             print('Was expecting brace, got ', bvh[token_index])
162 | return None
163 | token_index = token_index + 1
164 | offsets, token_index = self._read_offset(bvh, token_index)
165 | joint['offsets'] = offsets
166 | if not end_site:
167 | channels, token_index, order = self._read_channels(bvh, token_index)
168 | joint['channels'] = channels
169 | joint['order'] = order
170 | for channel in channels:
171 | self._motion_channels.append((joint_name, channel))
172 |
173 | self._skeleton[joint_name] = joint
174 | self._skeleton[parent_name]['children'].append(joint_name)
175 |
176 | while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'):
177 | self._push_bone_context(joint_name)
178 | token_index = self._parse_joint(bvh, token_index)
179 | self._pop_bone_context()
180 |
181 | if bvh[token_index][0] == 'CLOSE_BRACE':
182 | return token_index + 1
183 |
184 | print('Unexpected token ', bvh[token_index])
185 |
186 | def _parse_hierarchy(self, bvh):
187 | self.current_token = 0
188 | if bvh[self.current_token] != ('IDENT', 'HIERARCHY'):
189 | return None
190 | self.current_token = self.current_token + 1
191 | if bvh[self.current_token] != ('IDENT', 'ROOT'):
192 | return None
193 | self.current_token = self.current_token + 1
194 | if bvh[self.current_token][0] != 'IDENT':
195 | return None
196 |
197 | root_name = bvh[self.current_token][1]
198 | root_bone = self._new_bone(None, root_name)
199 | self.current_token = self.current_token + 2 #skipping open brace
200 | offsets, self.current_token = self._read_offset(bvh, self.current_token)
201 | channels, self.current_token, order = self._read_channels(bvh, self.current_token)
202 | root_bone['offsets'] = offsets
203 | root_bone['channels'] = channels
204 | root_bone['order'] = order
205 | self._skeleton[root_name] = root_bone
206 | self._push_bone_context(root_name)
207 |
208 | for channel in channels:
209 | self._motion_channels.append((root_name, channel))
210 |
211 | while bvh[self.current_token][1] == 'JOINT':
212 | self.current_token = self._parse_joint(bvh, self.current_token)
213 |
214 | self.root_name = root_name
215 |
216 | def _parse_motion(self, bvh, start, stop):
217 | if bvh[self.current_token][0] != 'IDENT':
218 | print('Unexpected text')
219 | return None
220 | if bvh[self.current_token][1] != 'MOTION':
221 | print('No motion section')
222 | return None
223 | self.current_token = self.current_token + 1
224 | if bvh[self.current_token][1] != 'Frames':
225 | return None
226 | self.current_token = self.current_token + 1
227 | frame_count = int(bvh[self.current_token][1])
228 |
229 |         if stop < 0 or stop > frame_count:
230 | stop = frame_count
231 |
232 |         assert(start >= 0)
233 |         assert(start < stop)
234 |
235 |         self.current_token = self.current_token + 1
236 |         if bvh[self.current_token][1] != 'Frame':
237 |             return None
238 |         self.current_token = self.current_token + 1
239 |         if bvh[self.current_token][1] != 'Time':
240 |             return None
241 |         self.current_token = self.current_token + 1
242 |         frame_rate = float(bvh[self.current_token][1])
243 |
244 |         self.framerate = frame_rate
245 |
246 |         self.current_token = self.current_token + 1
247 |
248 |         frame_time = 0.0
249 |         self._motions = [()] * (stop - start)
250 |         idx = 0
251 |         for i in range(stop):
252 |             channel_values = []
253 |             for channel in self._motion_channels:
254 |                 channel_values.append((channel[0], channel[1], float(bvh[self.current_token][1])))
255 |                 self.current_token = self.current_token + 1
256 |
257 |             if i >= start:
258 | self._motions[idx] = (frame_time, channel_values)
259 | frame_time = frame_time + frame_rate
260 |                 idx += 1
261 |
--------------------------------------------------------------------------------
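Usage sketch for the parser above (a minimal, illustrative example: the .bvh path is a placeholder, and the printed attributes are exactly those assigned to MocapData in parse()):

    from datasets.pymo.parsers import BVHParser

    parser = BVHParser()
    data = parser.parse('example.bvh')   # placeholder path

    print(data.root_name)                # name of the ROOT joint
    print(data.framerate)                # the 'Frame Time:' value, i.e. seconds per frame
    print(data.values.shape)             # DataFrame: one row per frame,
                                         # one '<joint>_<channel>' column per channel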
/datasets/pymo/mocapplayer/libs/pace.min.js:
--------------------------------------------------------------------------------
1 | /*! pace 1.0.2 */
2 | (function(){var a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,X=[].slice,Y={}.hasOwnProperty,Z=function(a,b){function c(){this.constructor=a}for(var d in b)Y.call(b,d)&&(a[d]=b[d]);return c.prototype=b.prototype,a.prototype=new c,a.__super__=b.prototype,a},$=[].indexOf||function(a){for(var b=0,c=this.length;c>b;b++)if(b in this&&this[b]===a)return b;return-1};for(u={catchupTime:100,initialRate:.03,minTime:250,ghostTime:100,maxProgressPerFrame:20,easeFactor:1.25,startOnPageLoad:!0,restartOnPushState:!0,restartOnRequestAfter:500,target:"body",elements:{checkInterval:100,selectors:["body"]},eventLag:{minSamples:10,sampleCount:3,lagThreshold:3},ajax:{trackMethods:["GET"],trackWebSockets:!0,ignoreURLs:[]}},C=function(){var a;return null!=(a="undefined"!=typeof performance&&null!==performance&&"function"==typeof performance.now?performance.now():void 0)?a:+new Date},E=window.requestAnimationFrame||window.mozRequestAnimationFrame||window.webkitRequestAnimationFrame||window.msRequestAnimationFrame,t=window.cancelAnimationFrame||window.mozCancelAnimationFrame,null==E&&(E=function(a){return setTimeout(a,50)},t=function(a){return clearTimeout(a)}),G=function(a){var b,c;return b=C(),(c=function(){var d;return d=C()-b,d>=33?(b=C(),a(d,function(){return E(c)})):setTimeout(c,33-d)})()},F=function(){var a,b,c;return c=arguments[0],b=arguments[1],a=3<=arguments.length?X.call(arguments,2):[],"function"==typeof c[b]?c[b].apply(c,a):c[b]},v=function(){var a,b,c,d,e,f,g;for(b=arguments[0],d=2<=arguments.length?X.call(arguments,1):[],f=0,g=d.length;g>f;f++)if(c=d[f])for(a in c)Y.call(c,a)&&(e=c[a],null!=b[a]&&"object"==typeof b[a]&&null!=e&&"object"==typeof e?v(b[a],e):b[a]=e);return b},q=function(a){var b,c,d,e,f;for(c=b=0,e=0,f=a.length;f>e;e++)d=a[e],c+=Math.abs(d),b++;return c/b},x=function(a,b){var c,d,e;if(null==a&&(a="options"),null==b&&(b=!0),e=document.querySelector("[data-pace-"+a+"]")){if(c=e.getAttribute("data-pace-"+a),!b)return c;try{return JSON.parse(c)}catch(f){return d=f,"undefined"!=typeof console&&null!==console?console.error("Error parsing inline pace options",d):void 0}}},g=function(){function a(){}return a.prototype.on=function(a,b,c,d){var e;return null==d&&(d=!1),null==this.bindings&&(this.bindings={}),null==(e=this.bindings)[a]&&(e[a]=[]),this.bindings[a].push({handler:b,ctx:c,once:d})},a.prototype.once=function(a,b,c){return this.on(a,b,c,!0)},a.prototype.off=function(a,b){var c,d,e;if(null!=(null!=(d=this.bindings)?d[a]:void 0)){if(null==b)return delete this.bindings[a];for(c=0,e=[];cQ;Q++)K=U[Q],D[K]===!0&&(D[K]=u[K]);i=function(a){function b(){return V=b.__super__.constructor.apply(this,arguments)}return Z(b,a),b}(Error),b=function(){function a(){this.progress=0}return a.prototype.getElement=function(){var a;if(null==this.el){if(a=document.querySelector(D.target),!a)throw new i;this.el=document.createElement("div"),this.el.className="pace pace-active",document.body.className=document.body.className.replace(/pace-done/g,""),document.body.className+=" pace-running",this.el.innerHTML='\n',null!=a.firstChild?a.insertBefore(this.el,a.firstChild):a.appendChild(this.el)}return this.el},a.prototype.finish=function(){var a;return a=this.getElement(),a.className=a.className.replace("pace-active",""),a.className+=" pace-inactive",document.body.className=document.body.className.replace("pace-running",""),document.body.className+=" pace-done"},a.prototype.update=function(a){return 
this.progress=a,this.render()},a.prototype.destroy=function(){try{this.getElement().parentNode.removeChild(this.getElement())}catch(a){i=a}return this.el=void 0},a.prototype.render=function(){var a,b,c,d,e,f,g;if(null==document.querySelector(D.target))return!1;for(a=this.getElement(),d="translate3d("+this.progress+"%, 0, 0)",g=["webkitTransform","msTransform","transform"],e=0,f=g.length;f>e;e++)b=g[e],a.children[0].style[b]=d;return(!this.lastRenderedProgress||this.lastRenderedProgress|0!==this.progress|0)&&(a.children[0].setAttribute("data-progress-text",""+(0|this.progress)+"%"),this.progress>=100?c="99":(c=this.progress<10?"0":"",c+=0|this.progress),a.children[0].setAttribute("data-progress",""+c)),this.lastRenderedProgress=this.progress},a.prototype.done=function(){return this.progress>=100},a}(),h=function(){function a(){this.bindings={}}return a.prototype.trigger=function(a,b){var c,d,e,f,g;if(null!=this.bindings[a]){for(f=this.bindings[a],g=[],d=0,e=f.length;e>d;d++)c=f[d],g.push(c.call(this,b));return g}},a.prototype.on=function(a,b){var c;return null==(c=this.bindings)[a]&&(c[a]=[]),this.bindings[a].push(b)},a}(),P=window.XMLHttpRequest,O=window.XDomainRequest,N=window.WebSocket,w=function(a,b){var c,d,e;e=[];for(d in b.prototype)try{e.push(null==a[d]&&"function"!=typeof b[d]?"function"==typeof Object.defineProperty?Object.defineProperty(a,d,{get:function(){return b.prototype[d]},configurable:!0,enumerable:!0}):a[d]=b.prototype[d]:void 0)}catch(f){c=f}return e},A=[],j.ignore=function(){var a,b,c;return b=arguments[0],a=2<=arguments.length?X.call(arguments,1):[],A.unshift("ignore"),c=b.apply(null,a),A.shift(),c},j.track=function(){var a,b,c;return b=arguments[0],a=2<=arguments.length?X.call(arguments,1):[],A.unshift("track"),c=b.apply(null,a),A.shift(),c},J=function(a){var b;if(null==a&&(a="GET"),"track"===A[0])return"force";if(!A.length&&D.ajax){if("socket"===a&&D.ajax.trackWebSockets)return!0;if(b=a.toUpperCase(),$.call(D.ajax.trackMethods,b)>=0)return!0}return!1},k=function(a){function b(){var a,c=this;b.__super__.constructor.apply(this,arguments),a=function(a){var b;return b=a.open,a.open=function(d,e){return J(d)&&c.trigger("request",{type:d,url:e,request:a}),b.apply(a,arguments)}},window.XMLHttpRequest=function(b){var c;return c=new P(b),a(c),c};try{w(window.XMLHttpRequest,P)}catch(d){}if(null!=O){window.XDomainRequest=function(){var b;return b=new O,a(b),b};try{w(window.XDomainRequest,O)}catch(d){}}if(null!=N&&D.ajax.trackWebSockets){window.WebSocket=function(a,b){var d;return d=null!=b?new N(a,b):new N(a),J("socket")&&c.trigger("request",{type:"socket",url:a,protocols:b,request:d}),d};try{w(window.WebSocket,N)}catch(d){}}}return Z(b,a),b}(h),R=null,y=function(){return null==R&&(R=new k),R},I=function(a){var b,c,d,e;for(e=D.ajax.ignoreURLs,c=0,d=e.length;d>c;c++)if(b=e[c],"string"==typeof b){if(-1!==a.indexOf(b))return!0}else if(b.test(a))return!0;return!1},y().on("request",function(b){var c,d,e,f,g;return f=b.type,e=b.request,g=b.url,I(g)?void 0:j.running||D.restartOnRequestAfter===!1&&"force"!==J(f)?void 0:(d=arguments,c=D.restartOnRequestAfter||0,"boolean"==typeof c&&(c=0),setTimeout(function(){var b,c,g,h,i,k;if(b="socket"===f?e.readyState<2:0<(h=e.readyState)&&4>h){for(j.restart(),i=j.sources,k=[],c=0,g=i.length;g>c;c++){if(K=i[c],K instanceof a){K.watch.apply(K,d);break}k.push(void 0)}return k}},c))}),a=function(){function a(){var a=this;this.elements=[],y().on("request",function(){return a.watch.apply(a,arguments)})}return a.prototype.watch=function(a){var 
b,c,d,e;return d=a.type,b=a.request,e=a.url,I(e)?void 0:(c="socket"===d?new n(b):new o(b),this.elements.push(c))},a}(),o=function(){function a(a){var b,c,d,e,f,g,h=this;if(this.progress=0,null!=window.ProgressEvent)for(c=null,a.addEventListener("progress",function(a){return h.progress=a.lengthComputable?100*a.loaded/a.total:h.progress+(100-h.progress)/2},!1),g=["load","abort","timeout","error"],d=0,e=g.length;e>d;d++)b=g[d],a.addEventListener(b,function(){return h.progress=100},!1);else f=a.onreadystatechange,a.onreadystatechange=function(){var b;return 0===(b=a.readyState)||4===b?h.progress=100:3===a.readyState&&(h.progress=50),"function"==typeof f?f.apply(null,arguments):void 0}}return a}(),n=function(){function a(a){var b,c,d,e,f=this;for(this.progress=0,e=["error","open"],c=0,d=e.length;d>c;c++)b=e[c],a.addEventListener(b,function(){return f.progress=100},!1)}return a}(),d=function(){function a(a){var b,c,d,f;for(null==a&&(a={}),this.elements=[],null==a.selectors&&(a.selectors=[]),f=a.selectors,c=0,d=f.length;d>c;c++)b=f[c],this.elements.push(new e(b))}return a}(),e=function(){function a(a){this.selector=a,this.progress=0,this.check()}return a.prototype.check=function(){var a=this;return document.querySelector(this.selector)?this.done():setTimeout(function(){return a.check()},D.elements.checkInterval)},a.prototype.done=function(){return this.progress=100},a}(),c=function(){function a(){var a,b,c=this;this.progress=null!=(b=this.states[document.readyState])?b:100,a=document.onreadystatechange,document.onreadystatechange=function(){return null!=c.states[document.readyState]&&(c.progress=c.states[document.readyState]),"function"==typeof a?a.apply(null,arguments):void 0}}return a.prototype.states={loading:0,interactive:50,complete:100},a}(),f=function(){function a(){var a,b,c,d,e,f=this;this.progress=0,a=0,e=[],d=0,c=C(),b=setInterval(function(){var g;return g=C()-c-50,c=C(),e.push(g),e.length>D.eventLag.sampleCount&&e.shift(),a=q(e),++d>=D.eventLag.minSamples&&a=100&&(this.done=!0),b===this.last?this.sinceLastUpdate+=a:(this.sinceLastUpdate&&(this.rate=(b-this.last)/this.sinceLastUpdate),this.catchup=(b-this.progress)/D.catchupTime,this.sinceLastUpdate=0,this.last=b),b>this.progress&&(this.progress+=this.catchup*a),c=1-Math.pow(this.progress/100,D.easeFactor),this.progress+=c*this.rate*a,this.progress=Math.min(this.lastProgress+D.maxProgressPerFrame,this.progress),this.progress=Math.max(0,this.progress),this.progress=Math.min(100,this.progress),this.lastProgress=this.progress,this.progress},a}(),L=null,H=null,r=null,M=null,p=null,s=null,j.running=!1,z=function(){return D.restartOnPushState?j.restart():void 0},null!=window.history.pushState&&(T=window.history.pushState,window.history.pushState=function(){return z(),T.apply(window.history,arguments)}),null!=window.history.replaceState&&(W=window.history.replaceState,window.history.replaceState=function(){return z(),W.apply(window.history,arguments)}),l={ajax:a,elements:d,document:c,eventLag:f},(B=function(){var a,c,d,e,f,g,h,i;for(j.sources=L=[],g=["ajax","elements","document","eventLag"],c=0,e=g.length;e>c;c++)a=g[c],D[a]!==!1&&L.push(new l[a](D[a]));for(i=null!=(h=D.extraSources)?h:[],d=0,f=i.length;f>d;d++)K=i[d],L.push(new K(D));return j.bar=r=new b,H=[],M=new m})(),j.stop=function(){return j.trigger("stop"),j.running=!1,r.destroy(),s=!0,null!=p&&("function"==typeof t&&t(p),p=null),B()},j.restart=function(){return j.trigger("restart"),j.stop(),j.start()},j.go=function(){var a;return 
j.running=!0,r.render(),a=C(),s=!1,p=G(function(b,c){var d,e,f,g,h,i,k,l,n,o,p,q,t,u,v,w;for(l=100-r.progress,e=p=0,f=!0,i=q=0,u=L.length;u>q;i=++q)for(K=L[i],o=null!=H[i]?H[i]:H[i]=[],h=null!=(w=K.elements)?w:[K],k=t=0,v=h.length;v>t;k=++t)g=h[k],n=null!=o[k]?o[k]:o[k]=new m(g),f&=n.done,n.done||(e++,p+=n.tick(b));return d=p/e,r.update(M.tick(b,d)),r.done()||f||s?(r.update(100),j.trigger("done"),setTimeout(function(){return r.finish(),j.running=!1,j.trigger("hide")},Math.max(D.ghostTime,Math.max(D.minTime-(C()-a),0)))):c()})},j.start=function(a){v(D,a),j.running=!0;try{r.render()}catch(b){i=b}return document.querySelector(".pace")?(j.trigger("start"),j.go()):setTimeout(j.start,50)},"function"==typeof define&&define.amd?define(["pace"],function(){return j}):"object"==typeof exports?module.exports=j:D.startOnPageLoad&&j.start()}).call(this);
--------------------------------------------------------------------------------
/datasets/pymo/mocapplayer/libs/papaparse.min.js:
--------------------------------------------------------------------------------
1 | /*!
2 | Papa Parse
3 | v4.1.2
4 | https://github.com/mholt/PapaParse
5 | */
6 | !function(e){"use strict";function t(t,r){if(r=r||{},r.worker&&S.WORKERS_SUPPORTED){var n=f();return n.userStep=r.step,n.userChunk=r.chunk,n.userComplete=r.complete,n.userError=r.error,r.step=m(r.step),r.chunk=m(r.chunk),r.complete=m(r.complete),r.error=m(r.error),delete r.worker,void n.postMessage({input:t,config:r,workerId:n.id})}var o=null;return"string"==typeof t?o=r.download?new i(r):new a(r):(e.File&&t instanceof File||t instanceof Object)&&(o=new s(r)),o.stream(t)}function r(e,t){function r(){"object"==typeof t&&("string"==typeof t.delimiter&&1==t.delimiter.length&&-1==S.BAD_DELIMITERS.indexOf(t.delimiter)&&(u=t.delimiter),("boolean"==typeof t.quotes||t.quotes instanceof Array)&&(o=t.quotes),"string"==typeof t.newline&&(h=t.newline))}function n(e){if("object"!=typeof e)return[];var t=[];for(var r in e)t.push(r);return t}function i(e,t){var r="";"string"==typeof e&&(e=JSON.parse(e)),"string"==typeof t&&(t=JSON.parse(t));var n=e instanceof Array&&e.length>0,i=!(t[0]instanceof Array);if(n){for(var a=0;a0&&(r+=u),r+=s(e[a],a);t.length>0&&(r+=h)}for(var o=0;oc;c++){c>0&&(r+=u);var d=n&&i?e[c]:c;r+=s(t[o][d],c)}o-1||" "==e.charAt(0)||" "==e.charAt(e.length-1);return r?'"'+e+'"':e}function a(e,t){for(var r=0;r-1)return!0;return!1}var o=!1,u=",",h="\r\n";if(r(),"string"==typeof e&&(e=JSON.parse(e)),e instanceof Array){if(!e.length||e[0]instanceof Array)return i(null,e);if("object"==typeof e[0])return i(n(e[0]),e)}else if("object"==typeof e)return"string"==typeof e.data&&(e.data=JSON.parse(e.data)),e.data instanceof Array&&(e.fields||(e.fields=e.data[0]instanceof Array?e.fields:n(e.data[0])),e.data[0]instanceof Array||"object"==typeof e.data[0]||(e.data=[e.data])),i(e.fields||[],e.data||[]);throw"exception: Unable to serialize unrecognized input"}function n(t){function r(e){var t=_(e);t.chunkSize=parseInt(t.chunkSize),e.step||e.chunk||(t.chunkSize=null),this._handle=new o(t),this._handle.streamer=this,this._config=t}this._handle=null,this._paused=!1,this._finished=!1,this._input=null,this._baseIndex=0,this._partialLine="",this._rowCount=0,this._start=0,this._nextChunk=null,this.isFirstChunk=!0,this._completeResults={data:[],errors:[],meta:{}},r.call(this,t),this.parseChunk=function(t){if(this.isFirstChunk&&m(this._config.beforeFirstChunk)){var r=this._config.beforeFirstChunk(t);void 0!==r&&(t=r)}this.isFirstChunk=!1;var n=this._partialLine+t;this._partialLine="";var i=this._handle.parse(n,this._baseIndex,!this._finished);if(!this._handle.paused()&&!this._handle.aborted()){var s=i.meta.cursor;this._finished||(this._partialLine=n.substring(s-this._baseIndex),this._baseIndex=s),i&&i.data&&(this._rowCount+=i.data.length);var a=this._finished||this._config.preview&&this._rowCount>=this._config.preview;if(y)e.postMessage({results:i,workerId:S.WORKER_ID,finished:a});else if(m(this._config.chunk)){if(this._config.chunk(i,this._handle),this._paused)return;i=void 0,this._completeResults=void 0}return this._config.step||this._config.chunk||(this._completeResults.data=this._completeResults.data.concat(i.data),this._completeResults.errors=this._completeResults.errors.concat(i.errors),this._completeResults.meta=i.meta),!a||!m(this._config.complete)||i&&i.meta.aborted||this._config.complete(this._completeResults),a||i&&i.meta.paused||this._nextChunk(),i}},this._sendError=function(t){m(this._config.error)?this._config.error(t):y&&this._config.error&&e.postMessage({workerId:S.WORKER_ID,error:t,finished:!1})}}function i(e){function t(e){var t=e.getResponseHeader("Content-Range");return 
parseInt(t.substr(t.lastIndexOf("/")+1))}e=e||{},e.chunkSize||(e.chunkSize=S.RemoteChunkSize),n.call(this,e);var r;this._nextChunk=k?function(){this._readChunk(),this._chunkLoaded()}:function(){this._readChunk()},this.stream=function(e){this._input=e,this._nextChunk()},this._readChunk=function(){if(this._finished)return void this._chunkLoaded();if(r=new XMLHttpRequest,k||(r.onload=g(this._chunkLoaded,this),r.onerror=g(this._chunkError,this)),r.open("GET",this._input,!k),this._config.chunkSize){var e=this._start+this._config.chunkSize-1;r.setRequestHeader("Range","bytes="+this._start+"-"+e),r.setRequestHeader("If-None-Match","webkit-no-cache")}try{r.send()}catch(t){this._chunkError(t.message)}k&&0==r.status?this._chunkError():this._start+=this._config.chunkSize},this._chunkLoaded=function(){if(4==r.readyState){if(r.status<200||r.status>=400)return void this._chunkError();this._finished=!this._config.chunkSize||this._start>t(r),this.parseChunk(r.responseText)}},this._chunkError=function(e){var t=r.statusText||e;this._sendError(t)}}function s(e){e=e||{},e.chunkSize||(e.chunkSize=S.LocalChunkSize),n.call(this,e);var t,r,i="undefined"!=typeof FileReader;this.stream=function(e){this._input=e,r=e.slice||e.webkitSlice||e.mozSlice,i?(t=new FileReader,t.onload=g(this._chunkLoaded,this),t.onerror=g(this._chunkError,this)):t=new FileReaderSync,this._nextChunk()},this._nextChunk=function(){this._finished||this._config.preview&&!(this._rowCount=this._input.size,this.parseChunk(e.target.result)},this._chunkError=function(){this._sendError(t.error)}}function a(e){e=e||{},n.call(this,e);var t,r;this.stream=function(e){return t=e,r=e,this._nextChunk()},this._nextChunk=function(){if(!this._finished){var e=this._config.chunkSize,t=e?r.substr(0,e):r;return r=e?r.substr(e):"",this._finished=!r,this.parseChunk(t)}}}function o(e){function t(){if(b&&d&&(h("Delimiter","UndetectableDelimiter","Unable to auto-detect delimiting character; defaulted to '"+S.DefaultDelimiter+"'"),d=!1),e.skipEmptyLines)for(var t=0;t=y.length?(r.__parsed_extra||(r.__parsed_extra=[]),r.__parsed_extra.push(b.data[t][n])):r[y[n]]=b.data[t][n])}e.header&&(b.data[t]=r,n>y.length?h("FieldMismatch","TooManyFields","Too many fields: expected "+y.length+" fields but parsed "+n,t):n1&&(h+=Math.abs(l-i),i=l):i=l}c.data.length>0&&(f/=c.data.length),("undefined"==typeof n||n>h)&&f>1.99&&(n=h,r=o)}return e.delimiter=r,{successful:!!r,bestDelimiter:r}}function a(e){e=e.substr(0,1048576);var t=e.split("\r");if(1==t.length)return"\n";for(var r=0,n=0;n=t.length/2?"\r\n":"\r"}function o(e){var t=l.test(e);return t?parseFloat(e):e}function h(e,t,r,n){b.errors.push({type:e,code:t,message:r,row:n})}var f,c,d,l=/^\s*-?(\d*\.?\d+|\d+\.?\d*)(e[-+]?\d+)?\s*$/i,p=this,g=0,v=!1,k=!1,y=[],b={data:[],errors:[],meta:{}};if(m(e.step)){var R=e.step;e.step=function(n){if(b=n,r())t();else{if(t(),0==b.data.length)return;g+=n.data.length,e.preview&&g>e.preview?c.abort():R(b,p)}}}this.parse=function(r,n,i){if(e.newline||(e.newline=a(r)),d=!1,!e.delimiter){var o=s(r);o.successful?e.delimiter=o.bestDelimiter:(d=!0,e.delimiter=S.DefaultDelimiter),b.meta.delimiter=e.delimiter}var h=_(e);return e.preview&&e.header&&h.preview++,f=r,c=new u(h),b=c.parse(f,n,i),t(),v?{meta:{paused:!0}}:b||{meta:{paused:!1}}},this.paused=function(){return v},this.pause=function(){v=!0,c.abort(),f=f.substr(c.getCharIndex())},this.resume=function(){v=!1,p.streamer.parseChunk(f)},this.aborted=function(){return 
k},this.abort=function(){k=!0,c.abort(),b.meta.aborted=!0,m(e.complete)&&e.complete(b),f=""}}function u(e){e=e||{};var t=e.delimiter,r=e.newline,n=e.comments,i=e.step,s=e.preview,a=e.fastMode;if(("string"!=typeof t||S.BAD_DELIMITERS.indexOf(t)>-1)&&(t=","),n===t)throw"Comment character same as delimiter";n===!0?n="#":("string"!=typeof n||S.BAD_DELIMITERS.indexOf(n)>-1)&&(n=!1),"\n"!=r&&"\r"!=r&&"\r\n"!=r&&(r="\n");var o=0,u=!1;this.parse=function(e,h,f){function c(e){b.push(e),S=o}function d(t){return f?p():("undefined"==typeof t&&(t=e.substr(o)),w.push(t),o=g,c(w),y&&_(),p())}function l(t){o=t,c(w),w=[],O=e.indexOf(r,o)}function p(e){return{data:b,errors:R,meta:{delimiter:t,linebreak:r,aborted:u,truncated:!!e,cursor:S+(h||0)}}}function _(){i(p()),b=[],R=[]}if("string"!=typeof e)throw"Input must be a string";var g=e.length,m=t.length,v=r.length,k=n.length,y="function"==typeof i;o=0;var b=[],R=[],w=[],S=0;if(!e)return p();if(a||a!==!1&&-1===e.indexOf('"')){for(var C=e.split(r),E=0;E=s)return b=b.slice(0,s),p(!0)}}return p()}for(var x=e.indexOf(t,o),O=e.indexOf(r,o);;)if('"'!=e[o])if(n&&0===w.length&&e.substr(o,k)===n){if(-1==O)return p();o=O+v,O=e.indexOf(r,o),x=e.indexOf(t,o)}else if(-1!==x&&(O>x||-1===O))w.push(e.substring(o,x)),o=x+m,x=e.indexOf(t,o);else{if(-1===O)break;if(w.push(e.substring(o,O)),l(O+v),y&&(_(),u))return p();if(s&&b.length>=s)return p(!0)}else{var I=o;for(o++;;){var I=e.indexOf('"',I+1);if(-1===I)return f||R.push({type:"Quotes",code:"MissingQuotes",message:"Quoted field unterminated",row:b.length,index:o}),d();if(I===g-1){var D=e.substring(o,I).replace(/""/g,'"');return d(D)}if('"'!=e[I+1]){if(e[I+1]==t){w.push(e.substring(o,I).replace(/""/g,'"')),o=I+1+m,x=e.indexOf(t,o),O=e.indexOf(r,o);break}if(e.substr(I+1,v)===r){if(w.push(e.substring(o,I).replace(/""/g,'"')),l(I+1+v),x=e.indexOf(t,o),y&&(_(),u))return p();if(s&&b.length>=s)return p(!0);break}}else I++}}return d()},this.abort=function(){u=!0},this.getCharIndex=function(){return o}}function h(){var e=document.getElementsByTagName("script");return e.length?e[e.length-1].src:""}function f(){if(!S.WORKERS_SUPPORTED)return!1;if(!b&&null===S.SCRIPT_PATH)throw new Error("Script path cannot be determined automatically when Papa Parse is loaded asynchronously. You need to set Papa.SCRIPT_PATH manually.");var t=S.SCRIPT_PATH||v;t+=(-1!==t.indexOf("?")?"&":"?")+"papaworker";var r=new e.Worker(t);return r.onmessage=c,r.id=w++,R[r.id]=r,r}function c(e){var t=e.data,r=R[t.workerId],n=!1;if(t.error)r.userError(t.error,t.file);else if(t.results&&t.results.data){var i=function(){n=!0,d(t.workerId,{data:[],errors:[],meta:{aborted:!0}})},s={abort:i,pause:l,resume:l};if(m(r.userStep)){for(var a=0;a
--------------------------------------------------------------------------------
/runner.py:
--------------------------------------------------------------------------------
74 |     opt.distributed = opt.world_size > 1 or opt.multiprocessing_distributed
75 |
76 | if torch.cuda.is_available():
77 | ngpus_per_node = torch.cuda.device_count()
78 | else:
79 | ngpus_per_node = 1
80 | if opt.multiprocessing_distributed:
81 | # Since we have ngpus_per_node processes per node, the total world_size
82 | # needs to be adjusted accordingly
83 | opt.world_size = ngpus_per_node * opt.world_size
84 | # Use torch.multiprocessing.spawn to launch distributed processes: the
85 | # main_worker process function
86 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
87 | else:
88 | # Simply call main_worker function
89 | main_worker(opt.gpu_id, ngpus_per_node, opt)
90 |
91 |
92 |
93 | def main_worker(gpu_id, ngpus_per_node, opt):
94 | # rank, world_size = get_dist_info()
95 | opt.gpu_id = gpu_id
96 |
97 | if opt.gpu_id is not None:
98 | print("Use GPU: {}".format(opt.gpu_id))
99 |
100 | if opt.distributed:
101 | if opt.dist_url == "env://" and opt.rank == -1:
102 | opt.rank = int(os.environ["RANK"])
103 | if opt.multiprocessing_distributed:
104 | # For multiprocessing distributed training, rank needs to be the
105 | # global rank among all the processes
106 | opt.rank = opt.rank * ngpus_per_node + gpu_id
107 | dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
108 | world_size=opt.world_size, rank=opt.rank)
109 |
110 |
111 | opt.device = torch.device("cuda")
112 | torch.autograd.set_detect_anomaly(True)
113 |
114 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
115 | opt.model_dir = pjoin(opt.save_root, 'model')
116 | opt.meta_dir = pjoin(opt.save_root, 'meta')
117 |
118 | # if opt.rank == 0:
119 | os.makedirs(opt.model_dir, exist_ok=True)
120 | os.makedirs(opt.meta_dir, exist_ok=True)
121 | if opt.world_size > 1:
122 | dist.barrier()
123 |
124 | if opt.dataset_name.lower() == 'beat':
125 | opt.data_root = 'data/BEAT'
126 | opt.fps = 15
127 | opt.net_dim_pose = 192 # body: [16, 34, 141], expression: [16, 34, 51], in_audio: [16, 36266]
128 | opt.split_pos = 141
129 | opt.dim_pose = 141
130 | if opt.remove_hand:
131 | opt.dim_pose = 33
132 | opt.expression_dim = 51
133 |
134 | if opt.expression_only or opt.gesCondition_expression_only:
135 | opt.net_dim_pose = opt.expression_dim # expression
136 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/face_300.bin'
137 |         elif opt.gesture_only or opt.expCondition_gesture_only is not None or \
138 | opt.textExpEmoCondition_gesture_only:
139 | opt.net_dim_pose = opt.dim_pose # gesture
140 | if opt.axis_angle:
141 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/ges_axis_angle_300.bin'
142 | else:
143 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/ae_300.bin'
144 | else:
145 | opt.net_dim_pose = opt.dim_pose + opt.expression_dim # gesture + expression
146 | if opt.axis_angle:
147 | opt.e_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/weights/GesAxisAngle_Face_300.bin'
148 | else:
149 | raise NotImplementedError
150 |
151 | opt.audio_dim = 128
152 | if opt.use_aud_feat:
153 | opt.audio_dim = 1024
154 |         opt.style_dim = 30  # 30 subjects in total
155 | opt.speaker_dim = 30
156 | opt.word_index_num = 5793
157 | opt.word_dims = 300
158 | opt.word_f = 128
159 | opt.emotion_f = 8
160 | opt.emotion_dims = 8
161 | opt.freeze_wordembed = False
162 | opt.hidden_size = 256
163 | opt.n_layer = 4
164 |
165 | if opt.n_poses == 150:
166 | opt.stride = 50
167 | elif opt.n_poses == 34:
168 | opt.stride = 10
169 | opt.pose_fps = 15
170 | opt.vae_length = 300
171 | opt.new_cache = False
172 | opt.audio_norm = False
173 | opt.facial_norm = True
174 | opt.pose_norm = True
175 | opt.train_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/'
176 | opt.val_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/val/'
177 | opt.test_data_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/test/'
178 | opt.mean_pose_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/'
179 | opt.std_pose_path = f'data/BEAT/beat_cache/{opt.beat_cache_name}/train/'
180 | opt.multi_length_training = [1.0]
181 | opt.audio_rep = 'wave16k'
182 | opt.facial_rep = 'facial52'
183 | opt.speaker_id = 'id'
184 | opt.pose_rep = 'bvh_rot'
185 | opt.word_rep = 'text'
186 | opt.sem_rep = 'sem'
187 | opt.emo_rep = 'emo'
188 |
189 | elif opt.dataset_name.lower() == 'talkshow':
190 | opt.talkshow_config = 'options/talkshow_configs/body_pixel.json'
191 | opt.speaker_dim = 4
192 | opt.fps = 30
193 | opt.dim_pose = 129
194 | opt.split_pos = 129
195 | if opt.remove_hand:
196 | opt.dim_pose = 39
197 | opt.expression_dim = 103
198 | if opt.ablation == "reverse_ges2exp":
199 | opt.expression_dim, opt.dim_pose = opt.dim_pose, opt.expression_dim
200 | if opt.expression_only or opt.gesCondition_expression_only:
201 | opt.net_dim_pose = opt.expression_dim # expression
202 | opt.e_path = f'data/SHOW/ae_weights/expression.pth.tar'
203 |         elif opt.gesture_only or opt.expCondition_gesture_only is not None:
204 | opt.net_dim_pose = opt.dim_pose # gesture
205 | opt.e_path = f'data/SHOW/ae_weights/gesture.pth.tar'
206 | else:
207 | opt.net_dim_pose = opt.dim_pose + opt.expression_dim # gesture + expression
208 | opt.e_path = f'data/SHOW/ae_weights/gesture_expression.pth.tar'
209 |
210 | if opt.audio_feat == 'mfcc':
211 | opt.audio_dim = 64
212 | elif opt.audio_feat == 'mel':
213 | opt.audio_dim = 128
214 | elif opt.audio_feat == 'raw':
215 | opt.audio_dim = 1
216 | elif opt.audio_feat == 'hubert':
217 | opt.audio_dim = 1024
218 | opt.style_dim = 4
219 | opt.speaker_dim = 4
220 | opt.n_poses = 88
221 | opt.pose_fps = 30
222 | opt.vae_length = 300
223 |
224 | else:
225 | raise KeyError('Dataset Does Not Exist')
226 |
227 |
228 |
229 | print("=> creating model '{}'".format(opt.model_base))
230 | model = build_models(opt, opt.net_dim_pose, opt.audio_dim, opt.audio_latent_dim, opt.style_dim)
231 |
232 |     if not opt.no_fgd:
233 | eval_model = build_fgd_val_model(opt)
234 | else:
235 | eval_model = None
236 |
237 | if not torch.cuda.is_available() and not torch.backends.mps.is_available():
238 | print('using CPU, this will be slow')
239 | elif opt.distributed:
240 | # For multiprocessing distributed, DistributedDataParallel constructor
241 | # should always set the single device scope, otherwise,
242 | # DistributedDataParallel will use all available devices.
243 | if torch.cuda.is_available():
244 | if opt.gpu_id is not None:
245 | torch.cuda.set_device(opt.gpu_id)
246 | model.cuda(opt.gpu_id)
247 | # When using a single GPU per process and per
248 | # DistributedDataParallel, we need to divide the batch size
249 | # ourselves based on the total number of GPUs of the current node.
250 | opt.batch_size = int(opt.batch_size / ngpus_per_node)
251 | opt.workers = int((opt.workers + ngpus_per_node - 1) / ngpus_per_node)
252 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu_id], find_unused_parameters=False)
253 |
254 | if not opt.no_fgd:
255 | eval_model.cuda(opt.gpu_id)
256 | eval_model = torch.nn.parallel.DistributedDataParallel(eval_model, device_ids=[opt.gpu_id], find_unused_parameters=False)
257 | else:
258 | # DistributedDataParallel will divide and allocate batch_size to all
259 | # available GPUs if device_ids are not set
260 | # model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
261 | model = torch.nn.parallel.DistributedDataParallel(model)
262 | if not opt.no_fgd:
263 | # eval_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(eval_model)
264 | eval_model = torch.nn.parallel.DistributedDataParallel(eval_model, device_ids=[opt.rank], broadcast_buffers=True, find_unused_parameters=False).to(opt.rank)
265 | elif opt.gpu_id is not None and torch.cuda.is_available():
266 | torch.cuda.set_device(opt.gpu_id)
267 | model = model.cuda(opt.gpu_id)
268 | if not opt.no_fgd:
269 | eval_model = eval_model.cuda(opt.gpu_id)
270 | elif torch.backends.mps.is_available():
271 | device = torch.device("mps")
272 | model = model.to(device)
273 | if not opt.no_fgd:
274 | eval_model = eval_model.to(device)
275 | else:
276 | # Use single gpu
277 | model = model.cuda()
278 | if not opt.no_fgd:
279 | eval_model = eval_model.cuda()
280 |
281 | if torch.cuda.is_available():
282 |         if opt.gpu_id is not None:
283 | device = torch.device('cuda:{}'.format(opt.gpu_id))
284 | else:
285 | device = torch.device("cuda")
286 | elif torch.backends.mps.is_available():
287 | device = torch.device("mps")
288 | else:
289 | device = torch.device("cpu")
290 |
291 |     if opt.dataset_name.lower() == 'beat':
292 | runner = DDPMTrainer_beat(opt, model, eval_model=eval_model)
293 |     elif opt.dataset_name.lower() == 'talkshow':
294 | runner = DDPMTrainer_show(opt, model, eval_model=eval_model)
295 | else:
296 | runner = DDPMTrainer(opt, model)
297 |
298 | if opt.mode == "train":
299 | if opt.dataset_name.lower() == 'beat':
300 | train_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "train")
301 | val_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "val")
302 |
303 | elif opt.dataset_name.lower() == 'talkshow':
304 | train_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_train_cache')
305 | val_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_val_cache')
306 |
307 |
308 | runner.train(train_dataset, val_dataset)
309 |
310 | elif "test" in opt.mode:
311 | if opt.dataset_name.lower() == 'beat':
312 | test_dataset = __import__(f"datasets.{opt.dataset_name}", fromlist=["something"]).BeatDataset(opt, "test")
313 |
314 | elif opt.dataset_name.lower() == 'talkshow':
315 | test_dataset = ShowDataset(opt, 'data/SHOW/cached_data/talkshow_test_cache')
316 |
317 | if opt.mode == "test":
318 | results_dir = runner.test(test_dataset)
319 | elif opt.mode == "test_arbitrary_len":
320 | opt.batch_size = 1
321 | results_dir = runner.test_arbitrary_len(test_dataset)
322 | elif opt.mode == "test_custom_audio":
323 | results_dir = runner.test_custom_aud(opt.test_audio_path, test_dataset)
324 | print(results_dir)
325 |
326 |
327 |
328 |
329 | if __name__ == '__main__':
330 | main()
331 |
332 |
--------------------------------------------------------------------------------
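Note on the entry points above: main() scales world_size by the GPU count and spawns one main_worker() per GPU, and each worker joins the process group before the batch size is divided across processes. A stripped-down sketch of that flow (standard PyTorch DDP recipe; the nccl backend and the MASTER_ADDR/MASTER_PORT values are placeholder assumptions, not read from the options files):

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def worker(gpu_id, ngpus_per_node, world_size):
        # single node: the global rank equals the local GPU index
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=world_size, rank=gpu_id)
        torch.cuda.set_device(gpu_id)
        # per-process batch size, mirroring main_worker():
        #   opt.batch_size = int(opt.batch_size / ngpus_per_node)
        dist.barrier()
        dist.destroy_process_group()

    if __name__ == '__main__':
        os.environ.setdefault('MASTER_ADDR', 'localhost')  # placeholder rendezvous settings
        os.environ.setdefault('MASTER_PORT', '29500')
        ngpus = torch.cuda.device_count()
        mp.spawn(worker, nprocs=ngpus, args=(ngpus, ngpus))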
/models/ddpm_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | # from .unet import SuperResModel, UNetModel, EncoderUNetModel  # import disabled: the image-model factories below (create_model, create_classifier, sr_create_model) still reference these names
7 |
8 | NUM_CLASSES = 1000
9 |
10 |
11 | def diffusion_defaults():
12 | """
13 | Defaults for image and classifier training.
14 | """
15 | return dict(
16 | learn_sigma=False,
17 | diffusion_steps=1000,
18 | noise_schedule="linear",
19 | timestep_respacing="",
20 | use_kl=False,
21 | predict_xstart=False,
22 | rescale_timesteps=False,
23 | rescale_learned_sigmas=False,
24 | )
25 |
26 |
27 | def classifier_defaults():
28 | """
29 | Defaults for classifier models.
30 | """
31 | return dict(
32 | image_size=64,
33 | classifier_use_fp16=False,
34 | classifier_width=128,
35 | classifier_depth=2,
36 | classifier_attention_resolutions="32,16,8", # 16
37 | classifier_use_scale_shift_norm=True, # False
38 | classifier_resblock_updown=True, # False
39 | classifier_pool="attention",
40 | )
41 |
42 |
43 | def model_and_diffusion_defaults():
44 | """
45 | Defaults for image training.
46 | """
47 | res = dict(
48 | image_size=64,
49 | num_channels=128,
50 | num_res_blocks=2,
51 | num_heads=4,
52 | num_heads_upsample=-1,
53 | num_head_channels=-1,
54 | attention_resolutions="16,8",
55 | channel_mult="",
56 | dropout=0.0,
57 | class_cond=False,
58 | use_checkpoint=False,
59 | use_scale_shift_norm=True,
60 | resblock_updown=False,
61 | use_fp16=False,
62 | use_new_attention_order=False,
63 | )
64 | res.update(diffusion_defaults())
65 | return res
66 |
67 |
68 | def classifier_and_diffusion_defaults():
69 | res = classifier_defaults()
70 | res.update(diffusion_defaults())
71 | return res
72 |
73 |
74 | def create_model_and_diffusion(
75 | image_size,
76 | class_cond,
77 | learn_sigma,
78 | num_channels,
79 | num_res_blocks,
80 | channel_mult,
81 | num_heads,
82 | num_head_channels,
83 | num_heads_upsample,
84 | attention_resolutions,
85 | dropout,
86 | diffusion_steps,
87 | noise_schedule,
88 | timestep_respacing,
89 | use_kl,
90 | predict_xstart,
91 | rescale_timesteps,
92 | rescale_learned_sigmas,
93 | use_checkpoint,
94 | use_scale_shift_norm,
95 | resblock_updown,
96 | use_fp16,
97 | use_new_attention_order,
98 | ):
99 | model = create_model(
100 | image_size,
101 | num_channels,
102 | num_res_blocks,
103 | channel_mult=channel_mult,
104 | learn_sigma=learn_sigma,
105 | class_cond=class_cond,
106 | use_checkpoint=use_checkpoint,
107 | attention_resolutions=attention_resolutions,
108 | num_heads=num_heads,
109 | num_head_channels=num_head_channels,
110 | num_heads_upsample=num_heads_upsample,
111 | use_scale_shift_norm=use_scale_shift_norm,
112 | dropout=dropout,
113 | resblock_updown=resblock_updown,
114 | use_fp16=use_fp16,
115 | use_new_attention_order=use_new_attention_order,
116 | )
117 | diffusion = create_gaussian_diffusion(
118 | steps=diffusion_steps,
119 | learn_sigma=learn_sigma,
120 | noise_schedule=noise_schedule,
121 | use_kl=use_kl,
122 | predict_xstart=predict_xstart,
123 | rescale_timesteps=rescale_timesteps,
124 | rescale_learned_sigmas=rescale_learned_sigmas,
125 | timestep_respacing=timestep_respacing,
126 | )
127 | return model, diffusion
128 |
129 |
130 | def create_model(
131 | image_size,
132 | num_channels,
133 | num_res_blocks,
134 | channel_mult="",
135 | learn_sigma=False,
136 | class_cond=False,
137 | use_checkpoint=False,
138 | attention_resolutions="16",
139 | num_heads=1,
140 | num_head_channels=-1,
141 | num_heads_upsample=-1,
142 | use_scale_shift_norm=False,
143 | dropout=0,
144 | resblock_updown=False,
145 | use_fp16=False,
146 | use_new_attention_order=False,
147 | ):
148 | if channel_mult == "":
149 | if image_size == 512:
150 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
151 | elif image_size == 256:
152 | channel_mult = (1, 1, 2, 2, 4, 4)
153 | elif image_size == 128:
154 | channel_mult = (1, 1, 2, 3, 4)
155 | elif image_size == 64:
156 | channel_mult = (1, 2, 3, 4)
157 | else:
158 | raise ValueError(f"unsupported image size: {image_size}")
159 | else:
160 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
161 |
162 | attention_ds = []
163 | for res in attention_resolutions.split(","):
164 | attention_ds.append(image_size // int(res))
165 |
166 | return UNetModel(
167 | image_size=image_size,
168 | in_channels=3,
169 | model_channels=num_channels,
170 | out_channels=(3 if not learn_sigma else 6),
171 | num_res_blocks=num_res_blocks,
172 | attention_resolutions=tuple(attention_ds),
173 | dropout=dropout,
174 | channel_mult=channel_mult,
175 | num_classes=(NUM_CLASSES if class_cond else None),
176 | use_checkpoint=use_checkpoint,
177 | use_fp16=use_fp16,
178 | num_heads=num_heads,
179 | num_head_channels=num_head_channels,
180 | num_heads_upsample=num_heads_upsample,
181 | use_scale_shift_norm=use_scale_shift_norm,
182 | resblock_updown=resblock_updown,
183 | use_new_attention_order=use_new_attention_order,
184 | )
185 |
186 |
187 | def create_classifier_and_diffusion(
188 | image_size,
189 | classifier_use_fp16,
190 | classifier_width,
191 | classifier_depth,
192 | classifier_attention_resolutions,
193 | classifier_use_scale_shift_norm,
194 | classifier_resblock_updown,
195 | classifier_pool,
196 | learn_sigma,
197 | diffusion_steps,
198 | noise_schedule,
199 | timestep_respacing,
200 | use_kl,
201 | predict_xstart,
202 | rescale_timesteps,
203 | rescale_learned_sigmas,
204 | ):
205 | classifier = create_classifier(
206 | image_size,
207 | classifier_use_fp16,
208 | classifier_width,
209 | classifier_depth,
210 | classifier_attention_resolutions,
211 | classifier_use_scale_shift_norm,
212 | classifier_resblock_updown,
213 | classifier_pool,
214 | )
215 | diffusion = create_gaussian_diffusion(
216 | steps=diffusion_steps,
217 | learn_sigma=learn_sigma,
218 | noise_schedule=noise_schedule,
219 | use_kl=use_kl,
220 | predict_xstart=predict_xstart,
221 | rescale_timesteps=rescale_timesteps,
222 | rescale_learned_sigmas=rescale_learned_sigmas,
223 | timestep_respacing=timestep_respacing,
224 | )
225 | return classifier, diffusion
226 |
227 |
228 | def create_classifier(
229 | image_size,
230 | classifier_use_fp16,
231 | classifier_width,
232 | classifier_depth,
233 | classifier_attention_resolutions,
234 | classifier_use_scale_shift_norm,
235 | classifier_resblock_updown,
236 | classifier_pool,
237 | ):
238 | if image_size == 512:
239 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
240 | elif image_size == 256:
241 | channel_mult = (1, 1, 2, 2, 4, 4)
242 | elif image_size == 128:
243 | channel_mult = (1, 1, 2, 3, 4)
244 | elif image_size == 64:
245 | channel_mult = (1, 2, 3, 4)
246 | else:
247 | raise ValueError(f"unsupported image size: {image_size}")
248 |
249 | attention_ds = []
250 | for res in classifier_attention_resolutions.split(","):
251 | attention_ds.append(image_size // int(res))
252 |
253 | return EncoderUNetModel(
254 | image_size=image_size,
255 | in_channels=3,
256 | model_channels=classifier_width,
257 | out_channels=1000,
258 | num_res_blocks=classifier_depth,
259 | attention_resolutions=tuple(attention_ds),
260 | channel_mult=channel_mult,
261 | use_fp16=classifier_use_fp16,
262 | num_head_channels=64,
263 | use_scale_shift_norm=classifier_use_scale_shift_norm,
264 | resblock_updown=classifier_resblock_updown,
265 | pool=classifier_pool,
266 | )
267 |
268 |
269 | def sr_model_and_diffusion_defaults():
270 | res = model_and_diffusion_defaults()
271 | res["large_size"] = 256
272 | res["small_size"] = 64
273 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
274 | for k in res.copy().keys():
275 | if k not in arg_names:
276 | del res[k]
277 | return res
278 |
279 |
280 | def sr_create_model_and_diffusion(
281 | large_size,
282 | small_size,
283 | class_cond,
284 | learn_sigma,
285 | num_channels,
286 | num_res_blocks,
287 | num_heads,
288 | num_head_channels,
289 | num_heads_upsample,
290 | attention_resolutions,
291 | dropout,
292 | diffusion_steps,
293 | noise_schedule,
294 | timestep_respacing,
295 | use_kl,
296 | predict_xstart,
297 | rescale_timesteps,
298 | rescale_learned_sigmas,
299 | use_checkpoint,
300 | use_scale_shift_norm,
301 | resblock_updown,
302 | use_fp16,
303 | ):
304 | model = sr_create_model(
305 | large_size,
306 | small_size,
307 | num_channels,
308 | num_res_blocks,
309 | learn_sigma=learn_sigma,
310 | class_cond=class_cond,
311 | use_checkpoint=use_checkpoint,
312 | attention_resolutions=attention_resolutions,
313 | num_heads=num_heads,
314 | num_head_channels=num_head_channels,
315 | num_heads_upsample=num_heads_upsample,
316 | use_scale_shift_norm=use_scale_shift_norm,
317 | dropout=dropout,
318 | resblock_updown=resblock_updown,
319 | use_fp16=use_fp16,
320 | )
321 | diffusion = create_gaussian_diffusion(
322 | steps=diffusion_steps,
323 | learn_sigma=learn_sigma,
324 | noise_schedule=noise_schedule,
325 | use_kl=use_kl,
326 | predict_xstart=predict_xstart,
327 | rescale_timesteps=rescale_timesteps,
328 | rescale_learned_sigmas=rescale_learned_sigmas,
329 | timestep_respacing=timestep_respacing,
330 | )
331 | return model, diffusion
332 |
333 |
334 | def sr_create_model(
335 | large_size,
336 | small_size,
337 | num_channels,
338 | num_res_blocks,
339 | learn_sigma,
340 | class_cond,
341 | use_checkpoint,
342 | attention_resolutions,
343 | num_heads,
344 | num_head_channels,
345 | num_heads_upsample,
346 | use_scale_shift_norm,
347 | dropout,
348 | resblock_updown,
349 | use_fp16,
350 | ):
351 | _ = small_size # hack to prevent unused variable
352 |
353 | if large_size == 512:
354 | channel_mult = (1, 1, 2, 2, 4, 4)
355 | elif large_size == 256:
356 | channel_mult = (1, 1, 2, 2, 4, 4)
357 | elif large_size == 64:
358 | channel_mult = (1, 2, 3, 4)
359 | else:
360 | raise ValueError(f"unsupported large size: {large_size}")
361 |
362 | attention_ds = []
363 | for res in attention_resolutions.split(","):
364 | attention_ds.append(large_size // int(res))
365 |
366 | return SuperResModel(
367 | image_size=large_size,
368 | in_channels=3,
369 | model_channels=num_channels,
370 | out_channels=(3 if not learn_sigma else 6),
371 | num_res_blocks=num_res_blocks,
372 | attention_resolutions=tuple(attention_ds),
373 | dropout=dropout,
374 | channel_mult=channel_mult,
375 | num_classes=(NUM_CLASSES if class_cond else None),
376 | use_checkpoint=use_checkpoint,
377 | num_heads=num_heads,
378 | num_head_channels=num_head_channels,
379 | num_heads_upsample=num_heads_upsample,
380 | use_scale_shift_norm=use_scale_shift_norm,
381 | resblock_updown=resblock_updown,
382 | use_fp16=use_fp16,
383 | )
384 |
385 |
386 | def create_gaussian_diffusion(
387 | *,
388 | steps=1000,
389 | learn_sigma=False,
390 | sigma_small=False,
391 | noise_schedule="linear",
392 | use_kl=False,
393 | predict_xstart=False,
394 | rescale_timesteps=False,
395 | rescale_learned_sigmas=False,
396 | timestep_respacing="",
397 | ):
398 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
399 | if use_kl:
400 | loss_type = gd.LossType.RESCALED_KL
401 | elif rescale_learned_sigmas:
402 | loss_type = gd.LossType.RESCALED_MSE
403 | else:
404 | loss_type = gd.LossType.MSE
405 | if not timestep_respacing:
406 | timestep_respacing = [steps]
407 | return SpacedDiffusion(
408 | use_timesteps=space_timesteps(steps, timestep_respacing),
409 | betas=betas,
410 | model_mean_type=(
411 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
412 | ),
413 | model_var_type=(
414 | (
415 | gd.ModelVarType.FIXED_LARGE
416 | if not sigma_small
417 | else gd.ModelVarType.FIXED_SMALL
418 | )
419 | if not learn_sigma
420 | else gd.ModelVarType.LEARNED_RANGE
421 | ),
422 | loss_type=loss_type,
423 | rescale_timesteps=rescale_timesteps,
424 | )
425 |
426 |
427 | def add_dict_to_argparser(parser, default_dict):
428 | for k, v in default_dict.items():
429 | v_type = type(v)
430 | if v is None:
431 | v_type = str
432 | elif isinstance(v, bool):
433 | v_type = str2bool
434 | parser.add_argument(f"--{k}", default=v, type=v_type)
435 |
436 |
437 | def args_to_dict(args, keys):
438 | return {k: getattr(args, k) for k in keys}
439 |
440 |
441 | def str2bool(v):
442 | """
443 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
444 | """
445 | if isinstance(v, bool):
446 | return v
447 | if v.lower() in ("yes", "true", "t", "y", "1"):
448 | return True
449 | elif v.lower() in ("no", "false", "f", "n", "0"):
450 | return False
451 | else:
452 | raise argparse.ArgumentTypeError("boolean value expected")
--------------------------------------------------------------------------------
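Usage sketch for the argparse helpers above (illustrative only: it assumes gaussian_diffusion and respace behave as in the guided-diffusion codebase this file is adapted from, and the flag values, including the 'ddim25' respacing string, are examples rather than project settings):

    import argparse
    from models.ddpm_utils import (add_dict_to_argparser, args_to_dict,
                                   diffusion_defaults, create_gaussian_diffusion)

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, diffusion_defaults())   # one CLI flag per default
    args = parser.parse_args(['--diffusion_steps', '1000',
                              '--timestep_respacing', 'ddim25'])

    kwargs = args_to_dict(args, diffusion_defaults().keys())
    # the defaults name the step count 'diffusion_steps'; the factory takes 'steps'
    diffusion = create_gaussian_diffusion(steps=kwargs.pop('diffusion_steps'), **kwargs)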
/utils/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | _EPS4 = np.finfo(float).eps * 4.0
12 |
13 | _FLOAT_EPS = np.finfo(float).eps  # np.float was removed in recent NumPy; use the builtin float
14 |
15 | # PyTorch-backed implementations
16 | def qinv(q):
17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18 | mask = torch.ones_like(q)
19 | mask[..., 1:] = -mask[..., 1:]
20 | return q * mask
21 |
22 |
23 | def qinv_np(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | return qinv(torch.from_numpy(q).float()).numpy()
26 |
27 |
28 | def qnormalize(q):
29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30 | return q / torch.norm(q, dim=-1, keepdim=True)
31 |
32 |
33 | def qmul(q, r):
34 | """
35 | Multiply quaternion(s) q with quaternion(s) r.
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37 | Returns q*r as a tensor of shape (*, 4).
38 | """
39 | assert q.shape[-1] == 4
40 | assert r.shape[-1] == 4
41 |
42 | original_shape = q.shape
43 |
44 | # Compute outer product
45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46 |
47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
52 |
53 |
54 | def qrot(q, v):
55 | """
56 | Rotate vector(s) v about the rotation described by quaternion(s) q.
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58 | where * denotes any number of dimensions.
59 | Returns a tensor of shape (*, 3).
60 | """
61 | assert q.shape[-1] == 4
62 | assert v.shape[-1] == 3
63 | assert q.shape[:-1] == v.shape[:-1]
64 |
65 | original_shape = list(v.shape)
66 | # print(q.shape)
67 | q = q.contiguous().view(-1, 4)
68 | v = v.contiguous().view(-1, 3)
69 |
70 | qvec = q[:, 1:]
71 | uv = torch.cross(qvec, v, dim=1)
72 | uuv = torch.cross(qvec, uv, dim=1)
73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74 |
75 |
76 | def qeuler(q, order, epsilon=0, deg=True):
77 | """
78 | Convert quaternion(s) q to Euler angles.
79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80 | Returns a tensor of shape (*, 3).
81 | """
82 | assert q.shape[-1] == 4
83 |
84 | original_shape = list(q.shape)
85 | original_shape[-1] = 3
86 | q = q.view(-1, 4)
87 |
88 | q0 = q[:, 0]
89 | q1 = q[:, 1]
90 | q2 = q[:, 2]
91 | q3 = q[:, 3]
92 |
93 | if order == 'xyz':
94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | elif order == 'yzx':
98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101 | elif order == 'zxy':
102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105 | elif order == 'xzy':
106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109 | elif order == 'yxz':
110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113 | elif order == 'zyx':
114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117 | else:
118 |         raise ValueError('Invalid rotation order: %s' % order)
119 |
120 | if deg:
121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122 | else:
123 | return torch.stack((x, y, z), dim=1).view(original_shape)
124 |
125 |
126 | # Numpy-backed implementations
127 |
128 | def qmul_np(q, r):
129 | q = torch.from_numpy(q).contiguous().float()
130 | r = torch.from_numpy(r).contiguous().float()
131 | return qmul(q, r).numpy()
132 |
133 |
134 | def qrot_np(q, v):
135 | q = torch.from_numpy(q).contiguous().float()
136 | v = torch.from_numpy(v).contiguous().float()
137 | return qrot(q, v).numpy()
138 |
139 |
140 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
141 | if use_gpu:
142 | q = torch.from_numpy(q).cuda().float()
143 | return qeuler(q, order, epsilon).cpu().numpy()
144 | else:
145 | q = torch.from_numpy(q).contiguous().float()
146 | return qeuler(q, order, epsilon).numpy()
147 |
148 |
149 | def qfix(q):
150 | """
151 | Enforce quaternion continuity across the time dimension by selecting
152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153 | between two consecutive frames.
154 |
155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156 | Returns a tensor of the same shape.
157 | """
158 | assert len(q.shape) == 3
159 | assert q.shape[-1] == 4
160 |
161 | result = q.copy()
162 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
163 | mask = dot_products < 0
164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165 | result[1:][mask] *= -1
166 | return result
167 |
168 |
169 | def euler2quat(e, order, deg=True):
170 | """
171 | Convert Euler angles to quaternions.
172 | """
173 | assert e.shape[-1] == 3
174 |
175 | original_shape = list(e.shape)
176 | original_shape[-1] = 4
177 |
178 | e = e.view(-1, 3)
179 |
180 | ## if euler angles in degrees
181 | if deg:
182 | e = e * np.pi / 180.
183 |
184 | x = e[:, 0]
185 | y = e[:, 1]
186 | z = e[:, 2]
187 |
188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191 |
192 | result = None
193 | for coord in order:
194 | if coord == 'x':
195 | r = rx
196 | elif coord == 'y':
197 | r = ry
198 | elif coord == 'z':
199 | r = rz
200 | else:
201 |             raise ValueError('Invalid coordinate: %s' % coord)
202 | if result is None:
203 | result = r
204 | else:
205 | result = qmul(result, r)
206 |
207 | # Reverse antipodal representation to have a non-negative "w"
208 | if order in ['xyz', 'yzx', 'zxy']:
209 | result *= -1
210 |
211 | return result.view(original_shape)
212 |
213 |
214 | def expmap_to_quaternion(e):
215 | """
216 | Convert axis-angle rotations (aka exponential maps) to quaternions.
217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219 | Returns a tensor of shape (*, 4).
220 | """
221 | assert e.shape[-1] == 3
222 |
223 | original_shape = list(e.shape)
224 | original_shape[-1] = 4
225 | e = e.reshape(-1, 3)
226 |
227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228 | w = np.cos(0.5 * theta).reshape(-1, 1)
229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231 |
232 |
233 | def euler_to_quaternion(e, order):
234 | """
235 | Convert Euler angles to quaternions.
236 | """
237 | assert e.shape[-1] == 3
238 |
239 | original_shape = list(e.shape)
240 | original_shape[-1] = 4
241 |
242 | e = e.reshape(-1, 3)
243 |
244 | x = e[:, 0]
245 | y = e[:, 1]
246 | z = e[:, 2]
247 |
248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251 |
252 | result = None
253 | for coord in order:
254 | if coord == 'x':
255 | r = rx
256 | elif coord == 'y':
257 | r = ry
258 | elif coord == 'z':
259 | r = rz
260 | else:
261 |             raise ValueError(f'Invalid axis {coord!r} in rotation order {order!r}')
262 | if result is None:
263 | result = r
264 | else:
265 | result = qmul_np(result, r)
266 |
267 | # Reverse antipodal representation to have a non-negative "w"
268 | if order in ['xyz', 'yzx', 'zxy']:
269 | result *= -1
270 |
271 | return result.reshape(original_shape)
272 |
273 |
274 | def quaternion_to_matrix(quaternions):
275 | """
276 | Convert rotations given as quaternions to rotation matrices.
277 | Args:
278 | quaternions: quaternions with real part first,
279 | as tensor of shape (..., 4).
280 | Returns:
281 | Rotation matrices as tensor of shape (..., 3, 3).
282 | """
283 | r, i, j, k = torch.unbind(quaternions, -1)
284 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
285 |
286 | o = torch.stack(
287 | (
288 | 1 - two_s * (j * j + k * k),
289 | two_s * (i * j - k * r),
290 | two_s * (i * k + j * r),
291 | two_s * (i * j + k * r),
292 | 1 - two_s * (i * i + k * k),
293 | two_s * (j * k - i * r),
294 | two_s * (i * k - j * r),
295 | two_s * (j * k + i * r),
296 | 1 - two_s * (i * i + j * j),
297 | ),
298 | -1,
299 | )
300 | return o.reshape(quaternions.shape[:-1] + (3, 3))
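    |
    | # Sanity check (illustrative): quaternion_to_matrix(torch.tensor([1., 0., 0., 0.]))
    | # returns the 3x3 identity; with r=1 and i=j=k=0, every off-diagonal term
    | # vanishes and the diagonal terms reduce to 1.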
301 |
302 |
303 | def quaternion_to_matrix_np(quaternions):
304 | q = torch.from_numpy(quaternions).contiguous().float()
305 | return quaternion_to_matrix(q).numpy()
306 |
307 |
308 | def quaternion_to_cont6d_np(quaternions):
309 | rotation_mat = quaternion_to_matrix_np(quaternions)
310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311 | return cont_6d
312 |
313 |
314 | def quaternion_to_cont6d(quaternions):
315 | rotation_mat = quaternion_to_matrix(quaternions)
316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317 | return cont_6d
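    |
    | # The 6D values are the first two columns of the rotation matrix (the
    | # continuous rotation representation of Zhou et al., CVPR 2019);
    | # cont6d_to_matrix below rebuilds the orthonormal frame from them with
    | # two cross products.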
318 |
319 |
320 | def cont6d_to_matrix(cont6d):
321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322 | x_raw = cont6d[..., 0:3]
323 | y_raw = cont6d[..., 3:6]
324 |
325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326 | z = torch.cross(x, y_raw, dim=-1)
327 | z = z / torch.norm(z, dim=-1, keepdim=True)
328 |
329 | y = torch.cross(z, x, dim=-1)
330 |
331 | x = x[..., None]
332 | y = y[..., None]
333 | z = z[..., None]
334 |
335 | mat = torch.cat([x, y, z], dim=-1)
336 | return mat
337 |
338 |
339 | def cont6d_to_matrix_np(cont6d):
340 | q = torch.from_numpy(cont6d).contiguous().float()
341 | return cont6d_to_matrix(q).numpy()
342 |
343 |
344 | def qpow(q0, t, dtype=torch.float):
345 | ''' q0 : tensor of quaternions
346 | t: tensor of powers
347 | '''
348 | q0 = qnormalize(q0)
349 | theta0 = torch.acos(q0[..., 0])
350 |
351 |     ## clamp theta0 away from zero (to 1e-9) to avoid NaNs in the division below
352 |     mask = (theta0 <= 1e-9) & (theta0 >= -1e-9)
353 |     theta0 = (~mask) * theta0 + mask * 1e-9
354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355 |
356 | if isinstance(t, torch.Tensor):
357 | q = torch.zeros(t.shape + q0.shape)
358 | theta = t.view(-1, 1) * theta0.view(1, -1)
359 | else: ## if t is a number
360 | q = torch.zeros(q0.shape)
361 | theta = t * theta0
362 |
363 | q[..., 0] = torch.cos(theta)
364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365 |
366 | return q.to(dtype)
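    |
    | # qpow computes q**t = (cos(t*theta), sin(t*theta) * v) for a unit
    | # quaternion q = (cos(theta), sin(theta) * v); qslerp below uses it to
    | # realize slerp(q0, q1, t) = (q1 * q0^-1)**t * q0.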
367 |
368 |
369 | def qslerp(q0, q1, t):
370 | '''
371 | q0: starting quaternion
372 | q1: ending quaternion
373 | t: array of points along the way
374 |
375 | Returns:
376 | Tensor of Slerps: t.shape + q0.shape
377 | '''
378 |
379 | q0 = qnormalize(q0)
380 | q1 = qnormalize(q1)
381 | q_ = qpow(qmul(q1, qinv(q0)), t)
382 |
383 | return qmul(q_,
384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385 |
386 |
387 | def qbetween(v0, v1):
388 | '''
389 | find the quaternion used to rotate v0 to v1
390 | '''
391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393 |
394 | v = torch.cross(v0, v1)
395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396 | keepdim=True)
397 | return qnormalize(torch.cat([w, v], dim=-1))
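    |
    | # "Shortest arc" construction: with w = |v0||v1| + v0.v1 and v = v0 x v1,
    | # the normalized quaternion rotates v0 onto v1 through the angle between
    | # them. Caveat: exactly opposite vectors give a zero quaternion (no unique
    | # axis), which this implementation does not special-case.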
398 |
399 |
400 | def qbetween_np(v0, v1):
401 | '''
402 | find the quaternion used to rotate v0 to v1
403 | '''
404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406 |
407 | v0 = torch.from_numpy(v0).float()
408 | v1 = torch.from_numpy(v1).float()
409 | return qbetween(v0, v1).numpy()
410 |
411 |
412 | def lerp(p0, p1, t):
413 | if not isinstance(t, torch.Tensor):
414 | t = torch.Tensor([t])
415 |
416 | new_shape = t.shape + p0.shape
417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419 | p0 = p0.view(new_view_p).expand(new_shape)
420 | p1 = p1.view(new_view_p).expand(new_shape)
421 | t = t.view(new_view_t).expand(new_shape)
422 |
423 | return p0 + t * (p1 - p0)
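    |
    | # Example (illustrative): with p0, p1 of shape (J, 3) and t of shape (T,),
    | # lerp(p0, p1, t) broadcasts everything to (T, J, 3), one interpolated
    | # pose per element of t; t=0 returns p0 and t=1 returns p1.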
424 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from mmcv.runner.dist_utils import get_dist_info
5 | import torch.distributed as dist
6 |
7 |
8 | class BaseOptions():
9 | def __init__(self):
10 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
11 | self.initialized = False
12 |
13 | def initialize(self):
14 | self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
15 | self.parser.add_argument('--decomp_name', type=str, default="Decomp_SP001_SM001_H512", help='Name of autoencoder model')
16 | self.parser.add_argument('--model_base', type=str, default="transformer_encoder", choices=["transformer_decoder", "transformer_encoder", "st_unet"], help='Model architecture')
17 |         self.parser.add_argument('--model_mean_type', type=str, default="epsilon", choices=["epsilon", "start_x", "previous_x"], help='Choose which type of data the model outputs')
18 |         self.parser.add_argument('--PE', type=str, default='pe_sinu', choices=['learnable', 'ppe_sinu', 'pe_sinu', 'pe_sinu_repeat', 'ppe_sinu_dropout'], help='Choose the type of positional embedding')
19 | self.parser.add_argument("--ddim", action="store_true", help='Use ddim sampling')
20 |         self.parser.add_argument("--timestep_respacing", type=str, default='ddim1000', help="Set DDIM steps in the form 'ddim{STEP}'")
21 | self.parser.add_argument("--cond_projection", type=str, default='mlp_includeX', choices=["linear_includeX", "mlp_includeX", "none", "linear_excludeX", "mlp_excludeX"], help="condition projection choices")
22 |         self.parser.add_argument("--cond_residual", type=bool, default=True, help='Whether to use a residual connection in the condition projection')
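    |         # NOTE: argparse's type=bool calls bool() on the raw string, so any
    |         # non-empty value (including "False") parses as True. The type=bool
    |         # options in this file are therefore effectively fixed at their
    |         # defaults from the command line; store_true/store_false actions or
    |         # a str2bool converter would behave as the help strings suggest.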
23 |
24 |
25 | self.parser.add_argument("--gpu_id", type=int, default=None, help='GPU id')
26 |         self.parser.add_argument("--distributed", action="store_true", help='Whether to use DDP training')
27 |         self.parser.add_argument("--data_parallel", action="store_true", help='Whether to use DP training')
28 | self.parser.add_argument("--max_eval_samples", type=int, default=-1, help='max_eval_samples')
29 | self.parser.add_argument("--n_poses", type=int, help='number of poses for a training sequence')
30 |         self.parser.add_argument("--axis_angle", type=bool, default=True, help='whether to use the axis-angle rotation representation')
31 | self.parser.add_argument("--rename", default=None, help='rename the experiment name during test')
32 |
33 | self.parser.add_argument("--debug", action="store_true", help='debug mode, only run one iteration')
34 | self.parser.add_argument('--mode', type=str, default='train', choices=["train", "val", "test", "test_arbitrary_len", "test_custom_audio"], help='train, val or test')
35 |
36 | self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
37 | self.parser.add_argument('--data_mode', type=str, default='original', choices=['original', 'add_init_state'], help='Data modes')
38 | self.parser.add_argument('--data_type', type=str, default='pos', choices=['pos', 'vel', 'pos_vel'], help='Data types')
39 | self.parser.add_argument('--data_sel', type=str, default='upperbody', choices=['upperbody', 'all', 'upperbody_head', 'upperbody_hands'], help='Data selection')
40 | self.parser.add_argument('--data_root', type=str, default='./Freeform/processed_data_200', help='Dataset path')
41 | self.parser.add_argument('--beat_cache_name', default='beat_4english_15_141', help='Beat cache name')
42 |         self.parser.add_argument('--use_aud_feat', type=str, default=None, choices=["interpolate", "conv"], help='How to incorporate the audio feature (interpolate or conv)')
43 | self.parser.add_argument('--audio_feat', type=str, default='mel', choices=["mel", "mfcc", "raw", "hubert", 'wav2vec2'], help='Audio feature type')
44 | self.parser.add_argument('--test_audio_path', type=str, default=None, help='test audio file or directory path')
45 |         self.parser.add_argument('--same_overlap_noisy', action="store_true", help='During outpainting, reuse the same noisy GT in the overlapping region')
46 | self.parser.add_argument('--no_repaint', action="store_true", help='Do not perform repaint during long-form generation')
47 |
48 |
49 | self.parser.add_argument('--vel_interval', type=int, default=10, help='Interval to compute the velocity')
50 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
51 | self.parser.add_argument('--overlap_len', type=int, default=0, help='Fix the initial N frames for this clip')
52 |         self.parser.add_argument('--addBlend', type=bool, default=True, help='Blend in the overlapping region at the last two denoising steps')
53 | self.parser.add_argument('--fix_very_first', action='store_true', help='Fix the very first {overlap_len} frames for this video to be the same as GT')
54 | self.parser.add_argument('--remove_audio', action='store_true', help='set audio to 0')
55 | self.parser.add_argument('--remove_style', action='store_true', help='set style to 0')
56 | self.parser.add_argument('--remove_hand', action='store_true', help='remove hand rotations from motion data')
57 | self.parser.add_argument('--no_fgd', action='store_true', help='do not compute fgd')
58 | self.parser.add_argument('--ablation', type=str, default=None, choices=["no_x0", "no_detach", "reverse_ges2exp"], help='ablation options')
59 | self.parser.add_argument('--rebuttal', type=str, default=None, choices=["noMelSpec", "noHuBert", "noMidAud"], help='rebuttal ablation options')
60 |         self.parser.add_argument('--visualize_unify_x0_step', type=int, default=None, help='visualize expression x0 in unified mode every N steps')
61 |
62 |
63 | self.parser.add_argument('--audio_dim', type=int, default=128, help='Input Audio feature dimension.')
64 | self.parser.add_argument('--audio_latent_dim', type=int, default=256, help='Audio latent dimension.')
65 | self.parser.add_argument('--style_dim', type=int, default=4, help='Input Style vector dimension. Can be one hot.')
66 | self.parser.add_argument('--dim_text_hidden', type=int, default=512, help='Dimension of hidden unit in text encoder')
67 | self.parser.add_argument('--dim_att_vec', type=int, default=512, help='Dimension of attention vector')
68 | self.parser.add_argument('--dim_z', type=int, default=128, help='Dimension of latent Gaussian vector')
69 |
70 | self.parser.add_argument('--n_layers_pri', type=int, default=1, help='Number of layers in prior network')
71 | self.parser.add_argument('--n_layers_pos', type=int, default=1, help='Number of layers in posterior network')
72 | self.parser.add_argument('--n_layers_dec', type=int, default=1, help='Number of layers in generator')
73 |
74 | self.parser.add_argument('--dim_pri_hidden', type=int, default=1024, help='Dimension of hidden unit in prior network')
75 | self.parser.add_argument('--dim_pos_hidden', type=int, default=1024, help='Dimension of hidden unit in posterior network')
76 | self.parser.add_argument('--dim_dec_hidden', type=int, default=1024, help='Dimension of hidden unit in generator')
77 |
78 | self.parser.add_argument('--dim_movement_enc_hidden', type=int, default=512,
79 | help='Dimension of hidden in AutoEncoder(encoder)')
80 | self.parser.add_argument('--dim_movement_dec_hidden', type=int, default=512,
81 | help='Dimension of hidden in AutoEncoder(decoder)')
82 | self.parser.add_argument('--dim_movement_latent', type=int, default=512, help='Dimension of motion snippet')
83 |
84 | self.parser.add_argument('--embed_net_path', type=str, default="feature_extractor/gesture_autoencoder_checkpoint_best.bin", help='embed_net_path')
85 |
86 |         self.parser.add_argument('--fix_head_var', action="store_true", help='Make expression prediction deterministic')
87 |         self.parser.add_argument('--expression_only', action="store_true", help='train expression only')
88 | self.parser.add_argument('--gesture_only', action="store_true", help='train gesture only')
89 | self.parser.add_argument('--expCondition_gesture_only', type=str, choices=['gt', 'pred'], default=None, help='train gesture only, with expressions as condition')
90 | self.parser.add_argument('--gesCondition_expression_only', action="store_true", help='train expression only, with gesture as condition')
91 | self.parser.add_argument('--textExpEmoCondition_gesture_only', action="store_true", help='use all conditions: audio, text, emo, pid, facial')
92 | self.parser.add_argument('--addTextCond', action="store_true", help='add Text feature to audio feature')
93 | self.parser.add_argument('--addEmoCond', action="store_true", help='add Emo feature to audio feature')
94 | self.parser.add_argument('--expAddHubert', action="store_true", help='concat Hubert feature to encoded audio feature only for expression generation')
95 | self.parser.add_argument('--addHubert', type=bool, default=True, help='concat Hubert feature to encoded audio feature for both expression and gesture generation')
96 | self.parser.add_argument('--addWav2Vec2', action="store_true", help='concat Wav2Vec2 feature to encoded audio feature for both expression and gesture generation')
97 | self.parser.add_argument('--encode_wav2vec2', action="store_true", help='encode the wav2vec2 feature')
98 | self.parser.add_argument('--encode_hubert', type=bool, default=True, help='encode the hubert feature')
99 | self.parser.add_argument('--separate', type=str, choices=['v1', 'v2'], default=None, help='limit information exchange between expression and gestures, v1 share encoder, v2 two independent encoders')
100 | self.parser.add_argument('--usePredExpr', type=str, default=None, help='Path to the predicted expressions.')
101 | self.parser.add_argument('--unidiffuser', type=bool, default=True, help='Use the unified framework for joint expression and gesture generation')
102 |
103 |         self.parser.add_argument('--separate_pure', action="store_true", help='use two fully separate encoders')
104 |
105 | # classifier-free guidance
106 | self.parser.add_argument('--classifier_free', action="store_true", help='Use classifier-free guidance')
107 | self.parser.add_argument('--null_cond_prob', type=float, default=0.2, help='Probability of null condition during classifier-free training')
108 | self.parser.add_argument('--cond_scale', type=float, default=1.0, help='Scale of the condition in classifier-free guidance sampling')
109 |
110 | # Try Expression ID off
111 | self.parser.add_argument('--ExprID_off', action="store_true", help='Turn off the expression ID condition')
112 |         self.parser.add_argument('--ExprID_off_uncond', action="store_true", help='Turn off the expression ID condition in the unconditional branch of classifier-free training')
113 |
114 |
115 |         self.parser.add_argument('--use_joints', action="store_true", help='Whether to convert to joints when using the TED 3D dataset')
116 | self.parser.add_argument('--use_single_style', action="store_true", help='Whether to use single style')
117 | self.parser.add_argument('--test_on_trainset', action="store_true", help='Whether to test on training set')
118 | self.parser.add_argument('--test_on_val', action="store_true", help='Whether to test on validation set')
119 | self.parser.add_argument('--output_gt', action="store_true", help='Directly output GT during test')
120 | self.parser.add_argument('--no_style', action="store_true", help='Do not use style vectors')
121 | self.parser.add_argument('--no_resample', action="store_true", help='Do not use resample during inpainting based sampling')
122 | self.parser.add_argument('--add_vel_loss', type=bool, default=True, help='Add velocity loss')
123 | self.parser.add_argument('--vel_loss_start', type=int, default=-1, help='velocity loss and huber loss start epoch')
124 | self.parser.add_argument('--expr_weight', type=int, default=1, help='expression weight')
125 |
126 | # inference
127 | self.parser.add_argument('--jump_n_sample', type=int, default=5, help='hyperparameter for resampling')
128 | self.parser.add_argument('--jump_length', type=int, default=3, help='hyperparameter for resampling')
129 |
130 |
131 | ## Distributed
132 | self.parser.add_argument('--world-size', default=1, type=int,
133 | help='number of nodes for distributed training')
134 | self.parser.add_argument('--rank', default=0, type=int,
135 | help='node rank for distributed training')
136 | self.parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
137 | help='url used to set up distributed training')
138 | self.parser.add_argument('--dist-backend', default='nccl', type=str,
139 | help='distributed backend')
140 | self.parser.add_argument('--multiprocessing-distributed', action='store_true',
141 | help='Use multi-processing distributed training to launch '
142 | 'N processes per node, which has N GPUs. This is the '
143 | 'fastest way to use PyTorch for either single node or '
144 | 'multi node data parallel training')
145 |         self.parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
146 |                             help='number of data loading workers')
147 |
148 | self.initialized = True
149 |
150 |
151 | def parse(self):
152 | if not self.initialized:
153 | self.initialize()
154 |
155 | self.opt = self.parser.parse_args()
156 |
157 | self.opt.is_train = self.is_train
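    |         # NOTE: `is_train` is never set by BaseOptions itself; a subclass
    |         # (e.g. the train/evaluate options) is expected to define it before
    |         # parse() is called.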
158 |
159 | args = vars(self.opt)
160 |
161 | if self.opt.rank == 0:
162 | print('------------ Options -------------')
163 | for k, v in sorted(args.items()):
164 | print('%s: %s' % (str(k), str(v)))
165 | print('-------------- End ----------------')
166 | if self.is_train:
167 | # save to the disk
168 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
169 | if not os.path.exists(expr_dir):
170 | os.makedirs(expr_dir)
171 | file_name = os.path.join(expr_dir, 'opt.txt')
172 | with open(file_name, 'wt') as opt_file:
173 | opt_file.write('------------ Options -------------\n')
174 | for k, v in sorted(args.items()):
175 | opt_file.write('%s: %s\n' % (str(k), str(v)))
176 | opt_file.write('-------------- End ----------------\n')
177 | if self.opt.world_size > 1:
178 | dist.barrier()
179 | return self.opt
180 |
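    | # Usage sketch (class name assumed from the surrounding repo, not verified
    | # here):
    | #   opt = TrainOptions().parse()  # subclass of BaseOptions
    | # On rank 0 this prints every option and, in training mode, writes them to
    | # checkpoints/<dataset_name>/<name>/opt.txt.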
--------------------------------------------------------------------------------