├── data_loaders ├── humanml │ ├── data │ │ └── __init__.py │ ├── networks │ │ ├── __init__.py │ │ └── evaluator_wrapper.py │ ├── motion_loaders │ │ ├── __init__.py │ │ ├── dataset_motion_loader.py │ │ └── model_motion_loaders.py │ ├── README.md │ └── utils │ │ ├── paramUtil.py │ │ ├── get_opt.py │ │ ├── word_vectorizer.py │ │ ├── metrics.py │ │ ├── utils.py │ │ └── plot_script.py ├── multidataset.py ├── get_data.py ├── humanml_utils.py ├── ucf101.py └── tensors.py ├── visualize ├── blender │ ├── __init__.py │ ├── data.py │ ├── vertices.py │ ├── sampler.py │ ├── tools.py │ ├── camera.py │ ├── video.py │ ├── meshes.py │ ├── floor.py │ ├── scene.py │ ├── render.py │ └── materials.py ├── joints2smpl │ ├── smpl_models │ │ ├── faces.npy │ │ ├── SMPL_downsample_index.pkl │ │ └── neutral_smpl_mean_params.h5 │ ├── environment.yaml │ ├── src │ │ ├── config.py │ │ ├── prior.py │ │ └── customloss.py │ ├── README.md │ └── fit_seq.py ├── render_mesh.py ├── vis_utils.py ├── motions2hik.py └── simplify_loc2rot.py ├── assets └── teaser.png ├── configs ├── crossdiff_pre.yaml ├── crossdiff_finetune.yaml ├── crossdiff_ucf.yaml └── base.yaml ├── prepare ├── download_t2m_evaluators.sh ├── download_smpl_files.sh ├── download_glove.sh ├── download_pretrained_models.sh └── project.py ├── utils ├── fixseed.py ├── config.py ├── misc.py ├── filter.py ├── parser_util.py ├── PYTORCH3D_LICENSE ├── model_util.py └── load_utils.py ├── fit_smpl.py ├── requirements.txt ├── LICENSE ├── train └── train_platforms.py ├── render_blender.py ├── test.py ├── train.py ├── model ├── cfg_sampler.py ├── rotation2xyz.py ├── smpl.py └── crossdiff.py ├── .gitignore ├── diffusion ├── losses.py ├── respace.py ├── resample.py ├── nn.py └── fp16_util.py ├── generate.py └── README.md /data_loaders/humanml/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_loaders/humanml/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_loaders/humanml/motion_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /visualize/blender/__init__.py: -------------------------------------------------------------------------------- 1 | from .render import render 2 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderNo/crossdiff/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /visualize/blender/data.py: -------------------------------------------------------------------------------- 1 | class Data: 2 | def __len__(self): 3 | return self.N 4 | -------------------------------------------------------------------------------- /data_loaders/humanml/README.md: -------------------------------------------------------------------------------- 1 | This code is based on https://github.com/EricGuo5513/text-to-motion.git -------------------------------------------------------------------------------- /configs/crossdiff_pre.yaml: -------------------------------------------------------------------------------- 1 | NAME: crossdiff_pre 2 | 3 | 
batch_size: 32 4 | num_epochs: 4000 5 | lr: 1e-4 6 | -------------------------------------------------------------------------------- /visualize/joints2smpl/smpl_models/faces.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderNo/crossdiff/HEAD/visualize/joints2smpl/smpl_models/faces.npy -------------------------------------------------------------------------------- /visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderNo/crossdiff/HEAD/visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl -------------------------------------------------------------------------------- /visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonderNo/crossdiff/HEAD/visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 -------------------------------------------------------------------------------- /configs/crossdiff_finetune.yaml: -------------------------------------------------------------------------------- 1 | NAME: crossdiff_finetune 2 | 3 | batch_size: 32 4 | num_epochs: 2000 5 | mode: 'finetune' 6 | lr: 1e-5 7 | save_interval: 300 8 | 9 | w_m2j: 0.1 10 | w_j2m: 0.1 11 | 12 | resume_checkpoint: ~ 13 | test_checkpoint: ~ -------------------------------------------------------------------------------- /configs/crossdiff_ucf.yaml: -------------------------------------------------------------------------------- 1 | NAME: crossdiff_ufc 2 | 3 | batch_size: 8 4 | num_epochs: 2000 5 | mode: 'finetune_ucf' 6 | lr: 1e-5 7 | save_interval: 300 8 | 9 | dataset: 'multi' 10 | ucf_ratio: 0.5 11 | 12 | w_m2j: 0.1 13 | w_j2m: 0.1 14 | 15 | resume_checkpoint: '' -------------------------------------------------------------------------------- /prepare/download_t2m_evaluators.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading T2M evaluators" 2 | cd ./data 3 | 4 | gdown "https://drive.google.com/uc?id=1rcYjuawHqq5Z229rIR_dgfTmELNgst0O" 5 | 6 | unzip t2m.zip 7 | echo -e "Cleaning\n" 8 | rm t2m.zip 9 | 10 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /prepare/download_smpl_files.sh: -------------------------------------------------------------------------------- 1 | cd ./data 2 | echo -e "The smpl files will be stored in the './data/smpl/' folder\n" 3 | gdown "https://drive.google.com/uc?id=1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2" 4 | rm -rf smpl 5 | 6 | unzip smpl.zip 7 | echo -e "Cleaning\n" 8 | rm smpl.zip 9 | 10 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /prepare/download_glove.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading glove (in use by the evaluators, not by CrossDiff itself)" 2 | cd ./data 3 | gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing 4 | rm -rf glove 5 | 6 | unzip glove.zip 7 | echo -e "Cleaning\n" 8 | rm glove.zip 9 | 10 | echo -e "Downloading done!" 
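A minimal usage sketch for the prepare scripts above. This is an assumption inferred from the `cd ./data` at the top of each script rather than anything documented here: they are presumably run from the repository root with a `data/` directory present, and they rely on `gdown` and `unzip` being available on the PATH (`gdown` is not listed in requirements.txt).

```bash
# Assumed invocation from the repository root; each script cd's into ./data,
# fetches an archive with gdown, unpacks it, and deletes the zip afterwards.
mkdir -p data
bash prepare/download_glove.sh
bash prepare/download_t2m_evaluators.sh
bash prepare/download_smpl_files.sh
```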
-------------------------------------------------------------------------------- /prepare/download_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading pre-trained models" 2 | cd ./data 3 | mkdir ./checkpoints 4 | cd checkpoints 5 | 6 | gdown "https://drive.google.com/uc?id=1pKSpIuYES6-ToJPPowwps9LIb_xzTMLE" 7 | gdown "https://drive.google.com/uc?id=13C26tAg2aBU60mwU63DR4dbU_bWvLDWP" 8 | gdown "https://drive.google.com/uc?id=1oMdt1Z8jBulXTqjm8y9or5jBx0IMmzwp" 9 | 10 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /utils/fixseed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def fixseed(seed): 7 | torch.backends.cudnn.benchmark = False 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | 12 | 13 | # SEED = 10 14 | # EVALSEED = 0 15 | # # Provoc warning: not fully functionnal yet 16 | # # torch.set_deterministic(True) 17 | # torch.backends.cudnn.benchmark = False 18 | # fixseed(SEED) 19 | -------------------------------------------------------------------------------- /visualize/blender/vertices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def prepare_vertices(vertices, canonicalize=True): 5 | data = vertices 6 | # Swap axis (gravity=Z instead of Y) 7 | # data = data[..., [2, 0, 1]] 8 | 9 | # Make left/right correct 10 | # data[..., [1]] = -data[..., [1]] 11 | 12 | # Center the first root to the first frame 13 | data -= data[[0], [0], :] 14 | 15 | # Remove the floor 16 | data[..., 2] -= np.min(data[..., 2]) 17 | return data 18 | -------------------------------------------------------------------------------- /fit_smpl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from visualize.simplify_loc2rot import joints2smpl 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--file", '-f',required=True, type=str) 7 | args = parser.parse_args() 8 | print(f'dealing file {args.file}') 9 | motion = np.load(args.file) 10 | 11 | j2s = joints2smpl(num_frames=motion.shape[0], device_id=0) 12 | 13 | meshes, vertices = j2s.joint2smpl(motion) 14 | np.save(args.file[:-4] + '_mesh.npy', vertices) 15 | 16 | print('done') -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | SMPL_DATA_PATH = "./body_models/smpl" 4 | 5 | SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") 6 | SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") 7 | JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') 8 | 9 | ROT_CONVENTION_TO_ROT_NUMBER = { 10 | 'legacy': 23, 11 | 'no_hands': 21, 12 | 'full_hands': 51, 13 | 'mitten_hands': 33, 14 | } 15 | 16 | GENDERS = ['neutral', 'male', 'female'] 17 | NUM_BETAS = 10 -------------------------------------------------------------------------------- /visualize/joints2smpl/environment.yaml: -------------------------------------------------------------------------------- 1 | name: fit3d 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | - pytorch3d 7 | - open3d-admin 8 | - anaconda 9 | dependencies: 10 | - pip=21.1.3 11 | - numpy=1.20.3 12 | 
- numpy-base=1.20.3 13 | - matplotlib=3.4.2 14 | - matplotlib-base=3.4.2 15 | - pandas=1.3.1 16 | - python=3.7.6 17 | - pytorch=1.7.1 18 | - tensorboardx=2.2 19 | - cudatoolkit=10.2.89 20 | - torchvision=0.8.2 21 | - einops=0.3.0 22 | - pytorch3d=0.4.0 23 | - tqdm=4.61.2 24 | - trimesh=3.9.24 25 | - joblib=1.0.1 26 | - open3d=0.13.0 27 | - pip: 28 | - h5py==2.9.0 29 | - chumpy==0.70 30 | - smplx==0.1.28 31 | -------------------------------------------------------------------------------- /visualize/blender/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_frameidx(*, mode, nframes, exact_frame, frames_to_keep, pre_idx): 4 | if mode == "sequence": 5 | if pre_idx is not None: 6 | frameidx = pre_idx 7 | else: 8 | frameidx = np.linspace(0, nframes - 1, frames_to_keep) 9 | frameidx = np.round(frameidx).astype(int) 10 | frameidx = list(frameidx) 11 | elif mode == "frame": 12 | if pre_idx is not None: 13 | frameidx = pre_idx 14 | else: 15 | index_frame = int(exact_frame*nframes) 16 | frameidx = [index_frame] 17 | elif mode == "video": 18 | frameidx = range(0, nframes) 19 | else: 20 | raise ValueError(f"Not support {mode} render mode") 21 | return frameidx 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | antlr4-python3-runtime 3 | cachetools 4 | certifi 5 | charset-normalizer 6 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 7 | contourpy 8 | cycler 9 | decorator 10 | fonttools 11 | ftfy 12 | google-auth 13 | google-auth-oauthlib 14 | grpcio 15 | idna 16 | imageio 17 | imageio-ffmpeg 18 | importlib-metadata 19 | importlib-resources 20 | kiwisolver 21 | MarkupSafe 22 | matplotlib 23 | moviepy 24 | numpy 25 | oauthlib 26 | omegaconf 27 | packaging 28 | pandas 29 | Pillow 30 | proglog 31 | protobuf 32 | pyasn1 33 | pyasn1-modules 34 | pyparsing 35 | python-dateutil 36 | pytz 37 | PyYAML 38 | regex 39 | requests 40 | requests-oauthlib 41 | rsa 42 | scipy 43 | six 44 | tensorboard 45 | tensorboard-data-server 46 | tqdm 47 | typing_extensions 48 | tzdata 49 | urllib3 50 | wcwidth 51 | Werkzeug 52 | zipp 53 | -------------------------------------------------------------------------------- /data_loaders/multidataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from data_loaders.humanml.data.dataset import HumanML3D 3 | from data_loaders.ucf101 import UCF101 4 | 5 | class MultiDataset(HumanML3D): 6 | 7 | def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", args=None): 8 | super(MultiDataset, self).__init__(mode, datapath, split, args) 9 | 10 | self.ucf101dataset = UCF101(args) 11 | 12 | self.cut_length = len(self.ucf101dataset) 13 | self.little = args.ucf_ratio > 0 14 | 15 | if self.little: 16 | self.length = int(len(self.ucf101dataset) / args.ucf_ratio) 17 | else: 18 | self.length = len(self.t2m_dataset) + len(self.ucf101dataset) 19 | 20 | 21 | def __getitem__(self, index): 22 | if index < self.cut_length: 23 | return self.ucf101dataset[index] 24 | elif self.little: 25 | return self.t2m_dataset[np.random.randint(0, len(self.t2m_dataset))] 26 | else: 27 | return self.t2m_dataset[index - self.cut_length] 28 | 29 | def __len__(self): 30 | return self.length -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CrossDiff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data_loaders/humanml/motion_loaders/dataset_motion_loader.py: -------------------------------------------------------------------------------- 1 | from t2m.data.dataset import Text2MotionDatasetV2, collate_fn 2 | from t2m.utils.word_vectorizer import WordVectorizer 3 | import numpy as np 4 | from os.path import join as pjoin 5 | from torch.utils.data import DataLoader 6 | from t2m.utils.get_opt import get_opt 7 | 8 | def get_dataset_motion_loader(opt_path, batch_size, device): 9 | opt = get_opt(opt_path, device) 10 | 11 | # Configurations of T2M dataset and KIT dataset is almost the same 12 | if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': 13 | print('Loading dataset %s ...' 
% opt.dataset_name) 14 | 15 | mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) 16 | std = np.load(pjoin(opt.meta_dir, 'std.npy')) 17 | 18 | w_vectorizer = WordVectorizer('./glove', 'our_vab') 19 | split_file = pjoin(opt.data_root, 'test.txt') 20 | dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer) 21 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True, 22 | collate_fn=collate_fn, shuffle=True) 23 | else: 24 | raise KeyError('Dataset not Recognized !!') 25 | 26 | print('Ground Truth Dataset Loading Completed!!!') 27 | return dataloader, dataset -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_numpy(tensor): 5 | if torch.is_tensor(tensor): 6 | return tensor.cpu().numpy() 7 | elif type(tensor).__module__ != 'numpy': 8 | raise ValueError("Cannot convert {} to numpy array".format( 9 | type(tensor))) 10 | return tensor 11 | 12 | 13 | def to_torch(ndarray): 14 | if type(ndarray).__module__ == 'numpy': 15 | return torch.from_numpy(ndarray) 16 | elif not torch.is_tensor(ndarray): 17 | raise ValueError("Cannot convert {} to torch tensor".format( 18 | type(ndarray))) 19 | return ndarray 20 | 21 | 22 | def cleanexit(): 23 | import sys 24 | import os 25 | try: 26 | sys.exit(0) 27 | except SystemExit: 28 | os._exit(0) 29 | 30 | def load_model_wo_clip(model, state_dict): 31 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 32 | assert len(unexpected_keys) == 0 33 | assert all([k.startswith('clip_model.') for k in missing_keys]) 34 | 35 | def freeze_joints(x, joints_to_freeze): 36 | # Freezes selected joint *rotations* as they appear in the first frame 37 | # x [bs, [root+n_joints], joint_dim(6), seqlen] 38 | frozen = x.detach().clone() 39 | frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] 40 | return frozen 41 | -------------------------------------------------------------------------------- /utils/filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | class OneEuroFilter: 5 | 6 | def __init__(self,te=1, min_cutoff=0.05, beta=0.004, d_cutoff=1.0): 7 | # The parameters. 8 | self.te = te 9 | self.min_cutoff = min_cutoff 10 | self.beta = beta 11 | self.d_cutoff = d_cutoff 12 | 13 | # Previous values. 14 | self.x_prev = None 15 | self.dx_prev = None 16 | self.a_d = self.smoothing_factor(self.d_cutoff) 17 | 18 | def smoothing_factor(self, cutoff): 19 | r = 2 * torch.pi * cutoff * self.te 20 | return r / (r + 1) 21 | 22 | def exponential_smoothing(self, a, x, x_prev): 23 | return a * x + (1 - a) * x_prev 24 | 25 | def filter_signal(self, x): 26 | if self.x_prev is None: 27 | self.x_prev = copy.deepcopy(x) 28 | self.dx_prev = torch.zeros_like(x) 29 | return self.x_prev 30 | 31 | dx = torch.zeros_like(x) 32 | dx = (x - self.x_prev) / self.te 33 | self.dx_prev = self.exponential_smoothing(self.a_d, dx, self.dx_prev) 34 | 35 | # The filtered signal. 
36 | cutoff = self.min_cutoff + self.beta * abs(self.dx_prev) 37 | a = self.smoothing_factor(cutoff) 38 | self.x_prev = self.exponential_smoothing(a, x, self.x_prev) 39 | 40 | return self.x_prev -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | NAME: 'test' 2 | 3 | # model 4 | model_type: 'crossdiff' 5 | resume_checkpoint: ~ 6 | test_checkpoint: ~ 7 | layers: 4 8 | layers2: 6 9 | latent_dim: 512 10 | cond_mask_prob: 0.1 11 | w_m2j: 1 12 | w_j2m: 1 13 | 14 | # train 15 | num_epochs: 2000 16 | save_interval: 200 17 | eval_interval: 500 18 | log_interval: 50 19 | mode: 'pretrain' 20 | save_dir: 'save' 21 | running_mode: 'train' 22 | seed: 233 23 | batch_size: 32 24 | eval_during_train: True 25 | train_platform_type: 'TensorboardPlatform' 26 | lr: 1e-4 27 | weight_decay: 0 28 | lr_anneal_steps: 0 29 | eval_batch_size: 32 30 | unconstrained: False 31 | 32 | # data 33 | cut_2d: False 34 | motion_ratio: 1 35 | dataset: 'humanml' 36 | data_root: '/apdcephfs_cq3/share_1290939/zepingren/humanml3d' 37 | use_mean_joint: True 38 | joint_mask_ratio: 0 39 | ucf_ratio: 0 40 | ucf_root: '/apdcephfs_cq3/share_1290939/zepingren/ufc101' 41 | ucf_keys: [] 42 | 43 | # test 44 | test_generatefrom2d: False 45 | change_idx: 300 46 | classifier_free: False 47 | eval_part: 'all' 48 | test_mm: False 49 | 50 | # diffusion 51 | noise_schedule: 'cosine' 52 | diffusion_steps: 1000 53 | sigma_small: True 54 | 55 | # generate 56 | generate_3d: 1 57 | generate_2d: 0 58 | sample_times: 1 59 | captions: 60 | - 'a person is swimming' 61 | - 'person is dancing eloquently' 62 | - 'he is punching in a fight.' 63 | -------------------------------------------------------------------------------- /utils/parser_util.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import json 4 | from omegaconf import OmegaConf 5 | 6 | 7 | def parse_args(): 8 | parser = ArgumentParser() 9 | 10 | parser.add_argument("--debug", action='store_true') 11 | parser.add_argument("--cfg", 12 | default='configs/crossdiff_pre.yaml', type=str) 13 | 14 | args_ori, extras = parser.parse_known_args() 15 | 16 | args = OmegaConf.load('./configs/base.yaml') 17 | args = OmegaConf.merge(args, OmegaConf.load(args_ori.cfg), OmegaConf.from_cli(extras)) 18 | if 'LOCAL_RANK' in os.environ.keys(): 19 | args.local_rank = int(os.environ['LOCAL_RANK']) 20 | else: 21 | args.local_rank = -1 22 | args.cfg = args_ori.cfg 23 | args.debug = args_ori.debug 24 | 25 | args.save_dir = os.path.join(args.save_dir, args.NAME) 26 | args.ck_save_dir = os.path.join(args.save_dir, 'checkpoint') 27 | args.eval_dir = os.path.join(args.save_dir, 'eval') 28 | 29 | if args.local_rank > 0: 30 | args.train_platform_type = 'NoPlatform' 31 | 32 | if args.debug: 33 | args.train_platform_type = 'NoPlatform' 34 | 35 | args.batch_size = 4 36 | args.device = [0] 37 | 38 | if args.cond_mask_prob == 0: 39 | args.guidance_param = 1 40 | else: 41 | args.guidance_param = 2.5 42 | 43 | 44 | 45 | return args 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /visualize/joints2smpl/src/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Map joints Name to SMPL joints idx 4 | JOINT_MAP = { 5 | 'MidHip': 0, 6 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 7 | 'RHip': 
2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 8 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, 9 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, 10 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 11 | 'LCollar':13, 'Rcollar' :14, 12 | 'Nose':24, 'REye':26, 'LEye':26, 'REar':27, 'LEar':28, 13 | 'LHeel': 31, 'RHeel': 34, 14 | 'OP RShoulder': 17, 'OP LShoulder': 16, 15 | 'OP RHip': 2, 'OP LHip': 1, 16 | 'OP Neck': 12, 17 | } 18 | 19 | full_smpl_idx = range(24) 20 | key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] 21 | 22 | 23 | AMASS_JOINT_MAP = { 24 | 'MidHip': 0, 25 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 26 | 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 27 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 28 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 29 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 30 | 'LCollar':13, 'Rcollar' :14, 31 | } 32 | amass_idx = range(22) 33 | amass_smpl_idx = range(22) 34 | 35 | 36 | SMPL_MODEL_DIR = "./body_models/" 37 | GMM_MODEL_DIR = "./visualize/joints2smpl/smpl_models/" 38 | SMPL_MEAN_FILE = "./visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5" 39 | # for collsion 40 | Part_Seg_DIR = "./visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl" -------------------------------------------------------------------------------- /visualize/blender/tools.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import numpy as np 3 | 4 | 5 | def mesh_detect(data): 6 | # heuristic 7 | if data.shape[1] > 1000: 8 | return True 9 | return False 10 | 11 | 12 | # see this for more explanation 13 | # https://gist.github.com/iyadahmed/7c7c0fae03c40bd87e75dc7059e35377 14 | # This should be solved with new version of blender 15 | class ndarray_pydata(np.ndarray): 16 | def __bool__(self) -> bool: 17 | return len(self) > 0 18 | 19 | 20 | def load_numpy_vertices_into_blender(vertices, faces, name, mat): 21 | mesh = bpy.data.meshes.new(name) 22 | mesh.from_pydata(vertices, [], faces.view(ndarray_pydata)) 23 | mesh.validate() 24 | 25 | obj = bpy.data.objects.new(name, mesh) 26 | bpy.context.scene.collection.objects.link(obj) 27 | 28 | bpy.ops.object.select_all(action='DESELECT') 29 | obj.select_set(True) 30 | obj.active_material = mat 31 | bpy.context.view_layer.objects.active = obj 32 | bpy.ops.object.shade_smooth() 33 | bpy.ops.object.select_all(action='DESELECT') 34 | return True 35 | 36 | 37 | def delete_objs(names): 38 | if not isinstance(names, list): 39 | names = [names] 40 | # bpy.ops.object.mode_set(mode='OBJECT') 41 | bpy.ops.object.select_all(action='DESELECT') 42 | for obj in bpy.context.scene.objects: 43 | for name in names: 44 | if obj.name.startswith(name) or obj.name.endswith(name): 45 | obj.select_set(True) 46 | bpy.ops.object.delete() 47 | bpy.ops.object.select_all(action='DESELECT') 48 | -------------------------------------------------------------------------------- /visualize/blender/camera.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | 4 | class Camera: 5 | def __init__(self, *, first_root, mode, is_mesh): 6 | camera = bpy.data.objects['Camera'] 7 | 8 | ## initial position 9 | camera.location.x = 7.36 10 | camera.location.y = -6.93 11 | if is_mesh: 12 | # camera.location.z = 5.45 13 | camera.location.z = 5.6 14 | else: 15 | camera.location.z = 5.2 16 | 17 | # wider point of view 18 | if mode == "sequence": 19 | if is_mesh: 20 | camera.data.lens = 65 21 | 
else: 22 | camera.data.lens = 85 23 | elif mode == "frame": 24 | if is_mesh: 25 | camera.data.lens = 130 26 | else: 27 | camera.data.lens = 85 28 | elif mode == "video": 29 | if is_mesh: 30 | camera.data.lens = 110 31 | else: 32 | # avoid cutting person 33 | camera.data.lens = 85 34 | # camera.data.lens = 140 35 | 36 | # camera.location.x += 0.75 37 | 38 | self.mode = mode 39 | self.camera = camera 40 | 41 | self.camera.location.x += first_root[0] 42 | self.camera.location.y += first_root[1] 43 | 44 | self._root = first_root 45 | 46 | def update(self, newroot): 47 | delta_root = newroot - self._root 48 | 49 | self.camera.location.x += delta_root[0] 50 | self.camera.location.y += delta_root[1] 51 | 52 | self._root = newroot 53 | -------------------------------------------------------------------------------- /train/train_platforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class TrainPlatform: 4 | def __init__(self, save_dir): 5 | pass 6 | 7 | def report_scalar(self, name, value, iteration, group_name=None): 8 | pass 9 | 10 | def report_args(self, args, name): 11 | pass 12 | 13 | def close(self): 14 | pass 15 | 16 | 17 | class ClearmlPlatform(TrainPlatform): 18 | def __init__(self, save_dir): 19 | from clearml import Task 20 | path, name = os.path.split(save_dir) 21 | self.task = Task.init(project_name='motion_diffusion', 22 | task_name=name, 23 | output_uri=path) 24 | self.logger = self.task.get_logger() 25 | 26 | def report_scalar(self, name, value, iteration, group_name): 27 | self.logger.report_scalar(title=group_name, series=name, iteration=iteration, value=value) 28 | 29 | def report_args(self, args, name): 30 | self.task.connect(args, name=name) 31 | 32 | def close(self): 33 | self.task.close() 34 | 35 | 36 | class TensorboardPlatform(TrainPlatform): 37 | def __init__(self, save_dir): 38 | from torch.utils.tensorboard import SummaryWriter 39 | self.writer = SummaryWriter(log_dir=save_dir) 40 | 41 | def report_scalar(self, name, value, iteration, group_name=None): 42 | self.writer.add_scalar(f'{group_name}/{name}', value, iteration) 43 | 44 | def close(self): 45 | self.writer.close() 46 | 47 | 48 | class NoPlatform(TrainPlatform): 49 | def __init__(self, save_dir): 50 | pass 51 | 52 | 53 | -------------------------------------------------------------------------------- /visualize/render_mesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from visualize import vis_utils 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input_path", type=str, required=True, help='stick figure mp4 file to be rendered.') 10 | parser.add_argument("--cuda", type=bool, default=True, help='') 11 | parser.add_argument("--device", type=int, default=0, help='') 12 | params = parser.parse_args() 13 | 14 | assert params.input_path.endswith('.mp4') 15 | parsed_name = os.path.basename(params.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '') 16 | sample_i, rep_i = [int(e) for e in parsed_name.split('_')] 17 | npy_path = os.path.join(os.path.dirname(params.input_path), 'results.npy') 18 | out_npy_path = params.input_path.replace('.mp4', '_smpl_params.npy') 19 | assert os.path.exists(npy_path) 20 | results_dir = params.input_path.replace('.mp4', '_obj') 21 | if os.path.exists(results_dir): 22 | shutil.rmtree(results_dir) 23 | os.makedirs(results_dir) 24 | 25 | 
npy2obj = vis_utils.npy2obj(npy_path, sample_i, rep_i, 26 | device=params.device, cuda=params.cuda) 27 | 28 | print('Saving obj files to [{}]'.format(os.path.abspath(results_dir))) 29 | for frame_i in tqdm(range(npy2obj.real_num_frames)): 30 | npy2obj.save_obj(os.path.join(results_dir, 'frame{:03d}.obj'.format(frame_i)), frame_i) 31 | 32 | print('Saving SMPL params to [{}]'.format(os.path.abspath(out_npy_path))) 33 | npy2obj.save_npy(out_npy_path) 34 | -------------------------------------------------------------------------------- /render_blender.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import sys 4 | sys.path.insert(0,'/root/.local/lib/python3.9/site-packages') 5 | 6 | 7 | try: 8 | import bpy 9 | sys.path.append(os.path.dirname(bpy.data.filepath)) 10 | except ImportError: 11 | raise ImportError("Blender is not properly installed or not launch properly. See README.md to have instruction on how to install and use blender.") 12 | 13 | from visualize.blender.render import render 14 | from visualize.blender.video import Video 15 | import numpy as np 16 | import argparse 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--file", '-f', required=True, type=str) 22 | args = parser.parse_args() 23 | 24 | print(f'visualizing file {args.file}...') 25 | data = np.load(args.file) 26 | frames_folder = args.file[:-4] + '_frames' 27 | mp4_file = args.file[:-4] + '_blender.mp4' 28 | mode = 'video' #'sequence' #'video' 'frame 29 | if mode == 'video': 30 | frames_folder = args.file[:-4] + '_frames' 31 | mp4_file = args.file[:-4] + '_blender.mp4' 32 | elif mode == 'sequence': 33 | frames_folder = args.file[:-4] + '_blender.png' 34 | elif mode == 'frame': 35 | frames_folder = args.file[:-4] + '_blender2.png' 36 | 37 | 38 | render(data, frames_folder, mode=mode,) 39 | 40 | if mode == 'video': 41 | video = Video(frames_folder, fps=20) 42 | video.save(out_path=mp4_file) 43 | shutil.rmtree(frames_folder) 44 | print(f"remove tmp fig folder and save video in {mp4_file}") 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /utils/PYTORCH3D_LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For PyTorch3D software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Evaluate a trained diffusion model on motion data. 4 | """ 5 | import os 6 | 7 | os.environ['CUDA_LAUNCH_BLOCKING']="1" 8 | 9 | import json 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from utils.fixseed import fixseed 14 | from utils.parser_util import parse_args 15 | from train.training_loop import TrainLoop 16 | from data_loaders.get_data import get_dataset_loader 17 | from utils.model_util import create_model_and_diffusion 18 | from train.train_platforms import NoPlatform 19 | 20 | from diffusion import logger 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | args.running_mode = 'test' 26 | 27 | logger.configure(args.eval_dir, debug=args.debug, rank=args.local_rank) 28 | fixseed(args.seed + args.local_rank) 29 | 30 | train_platform = NoPlatform(args.eval_dir) 31 | 32 | if args.local_rank != -1: 33 | torch.cuda.set_device(args.local_rank) 34 | dist.init_process_group(backend='nccl') 35 | 36 | logger.log("creating data loader...") 37 | data = None 38 | test_data = get_dataset_loader('humanml', batch_size=32, split='test', args=args) 39 | 40 | 41 | logger.log("creating model and diffusion...") 42 | model, diffusion = create_model_and_diffusion(args) 43 | 44 | logger.log('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 45 | logger.log("Testing init...") 46 | try: 47 | train_loop = TrainLoop(args, train_platform, model, diffusion, data, test_data) 48 | logger.log(f"Start testing... ") 49 | if not args.test_mm: 50 | train_loop.multi_eval(test_limit=20, replication_times=20, test_main=True, test_mm=False) 51 | else: 52 | train_loop.multi_eval(test_limit=5, replication_times=5, test_main=False, test_mm=True) 53 | except Exception as e: 54 | logger.error(e) 55 | raise e 56 | 57 | logger.log('!JOB COMPLETED!') 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /visualize/joints2smpl/README.md: -------------------------------------------------------------------------------- 1 | # joints2smpl 2 | fit SMPL model using 3D joints 3 | 4 | ## Prerequisites 5 | We have tested the code on Ubuntu 18.04/20.04 with CUDA 10.2/11.3. 6 | 7 | ## Installation 8 | First you have to make sure that you have all dependencies in place. 9 | The simplest way to do so is to use [Anaconda](https://www.anaconda.com/).
10 | 11 | You can create an anaconda environment called `fit3d` using 12 | ``` 13 | conda env create -f environment.yaml 14 | conda activate fit3d 15 | ``` 16 | 17 | ## Download SMPL models 18 | Download [SMPL Female and Male](https://smpl.is.tue.mpg.de/) and [SMPL Neutral](https://smplify.is.tue.mpg.de/), rename the files, and extract them to `/smpl_models/smpl/`; eventually, the `/smpl_models` folder should have the following structure: 19 | ``` 20 | smpl_models 21 | └-- smpl 22 | └-- SMPL_FEMALE.pkl 23 | └-- SMPL_MALE.pkl 24 | └-- SMPL_NEUTRAL.pkl 25 | ``` 26 | 27 | ## Demo 28 | ### Demo for sequences 29 | python fit_seq.py --files test_motion2.npy 30 | 31 | The results will be located in ./demo/demo_results/ 32 | 33 | ## Citation 34 | If you find this project useful for your research, please consider citing: 35 | ``` 36 | @article{zuo2021sparsefusion, 37 | title={Sparsefusion: Dynamic human avatar modeling from sparse rgbd images}, 38 | author={Zuo, Xinxin and Wang, Sen and Zheng, Jiangbin and Yu, Weiwei and Gong, Minglun and Yang, Ruigang and Cheng, Li}, 39 | journal={IEEE Transactions on Multimedia}, 40 | volume={23}, 41 | pages={1617--1629}, 42 | year={2021} 43 | } 44 | ``` 45 | 46 | ## References 47 | We indicate if a function or script is borrowed externally inside each file. Here are some great resources we 48 | benefit from: 49 | 50 | - Shape/Pose prior and some functions are borrowed from [VIBE](https://github.com/mkocabas/VIBE). 51 | - SMPL models and layer are from the [SMPL-X model](https://github.com/vchoutas/smplx). 52 | - Some functions are borrowed from [HMR-pytorch](https://github.com/MandyMo/pytorch_HMR). 53 | -------------------------------------------------------------------------------- /data_loaders/get_data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from data_loaders.tensors import t2m_collate, simple_collate 3 | from diffusion import logger 4 | 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | def get_dataset_class(name): 8 | if name in ["humanml"]: 9 | from data_loaders.humanml.data.dataset import HumanML3D 10 | return HumanML3D 11 | elif name == 'multi': 12 | from data_loaders.multidataset import MultiDataset 13 | return MultiDataset 14 | elif name == 'ufc': 15 | from data_loaders.ucf101 import UCF101 16 | return UCF101 17 | else: 18 | raise ValueError(f'Unsupported dataset name [{name}]') 19 | 20 | def get_collate_fn(split='train'): 21 | if split == 'train': 22 | return simple_collate 23 | else: 24 | return t2m_collate 25 | 26 | 27 | def get_dataset(name, split='train', args=None): 28 | DATA = get_dataset_class(name) 29 | if name in ["humanml", "multi"]: 30 | dataset = DATA(split=split, args=args) 31 | elif name in ['ufc']: 32 | dataset = DATA(args) 33 | 34 | return dataset 35 | 36 | 37 | def get_dataset_loader(name, batch_size=1, split='train', args=None): 38 | dataset = get_dataset(name, split, args=args) 39 | collate = get_collate_fn(split) 40 | 41 | if split == 'generate': 42 | return dataset 43 | 44 | if args.local_rank != -1: 45 | sampler = DistributedSampler(dataset) 46 | loader = DataLoader( 47 | dataset, batch_size=batch_size, 48 | num_workers=8, drop_last=True, collate_fn=collate, 49 | sampler=sampler 50 | ) 51 | logger.info(f'{name} has {len(dataset)} samples with {len(loader)} batch in ddp mode..') 52 | else: 53 | loader = DataLoader( 54 | dataset, batch_size=batch_size, shuffle=True, 55 | num_workers=8, drop_last=True, collate_fn=collate, 56 | ) 57
| 58 | logger.info(f'{name} has {len(dataset)} samples with {len(loader)} batch in single gpu mode..') 59 | 60 | 61 | 62 | 63 | return loader -------------------------------------------------------------------------------- /data_loaders/humanml/utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- /data_loaders/humanml_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | HML_JOINT_NAMES = [ 4 | 'pelvis', 5 | 'left_hip', 6 | 'right_hip', 7 | 'spine1', 8 | 'left_knee', 9 | 'right_knee', 10 | 'spine2', 11 | 'left_ankle', 12 | 'right_ankle', 13 | 'spine3', 14 | 'left_foot', 15 | 'right_foot', 16 | 'neck', 17 | 'left_collar', 18 | 'right_collar', 19 | 'head', 20 | 'left_shoulder', 21 | 'right_shoulder', 22 | 'left_elbow', 23 | 'right_elbow', 24 | 'left_wrist', 25 | 'right_wrist', 26 | ] 27 | 28 | NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints 29 | 30 | HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]] 31 | SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS] 32 | 33 | 34 | # Recover global angle and positions for rotation data 35 | # root_rot_velocity (B, seq_len, 1) 36 | # root_linear_velocity (B, seq_len, 2) 37 | # root_y (B, seq_len, 1) 38 | # ric_data (B, seq_len, (joint_num - 1)*3) 39 | # rot_data (B, seq_len, (joint_num - 1)*6) 40 | # local_velocity (B, seq_len, joint_num*3) 41 | # foot contact (B, seq_len, 4) 42 | HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1)) 43 | HML_ROOT_MASK = np.concatenate(([True]*(1+2+1), 44 | HML_ROOT_BINARY[1:].repeat(3), 45 | HML_ROOT_BINARY[1:].repeat(6), 46 | HML_ROOT_BINARY.repeat(3), 47 | [False] * 4)) 48 | HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) 49 | HML_LOWER_BODY_MASK = 
np.concatenate(([True]*(1+2+1), 50 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3), 51 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6), 52 | HML_LOWER_BODY_JOINTS_BINARY.repeat(3), 53 | [True]*4)) 54 | HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Train a diffusion model on images. 4 | """ 5 | import os 6 | import json 7 | import torch 8 | import torch.distributed as dist 9 | from omegaconf import OmegaConf 10 | from utils.fixseed import fixseed 11 | from utils.parser_util import parse_args 12 | from train.training_loop import TrainLoop 13 | from data_loaders.get_data import get_dataset_loader 14 | from utils.model_util import create_model_and_diffusion 15 | from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform # required for the eval operation 16 | 17 | from diffusion import logger 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | 23 | 24 | logger.configure(args.save_dir, debug=args.debug, rank=args.local_rank) 25 | 26 | fixseed(args.seed + args.local_rank) 27 | train_platform_type = eval(args.train_platform_type) 28 | train_platform = train_platform_type(args.save_dir) 29 | train_platform.report_args(args, name='Args') 30 | 31 | args_path = os.path.join(args.save_dir, 'config.yaml') 32 | 33 | if not os.path.exists(args_path) and args.local_rank < 1 and not args.debug: 34 | OmegaConf.save(config=args, f=args_path) 35 | 36 | if args.local_rank != -1: 37 | torch.cuda.set_device(args.local_rank) 38 | dist.init_process_group(backend='nccl') 39 | 40 | 41 | logger.log("creating data loader...") 42 | data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, args=args) 43 | test_data = get_dataset_loader('humanml', batch_size=32, split='test', args=args) 44 | 45 | logger.log("creating model and diffusion...") 46 | model, diffusion = create_model_and_diffusion(args) 47 | 48 | 49 | logger.log('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 50 | logger.log("Training init...") 51 | try: 52 | tain_loop = TrainLoop(args, train_platform, model, diffusion, data, test_data) 53 | logger.log("Start training...") 54 | tain_loop.run_loop() 55 | except Exception as e: 56 | logger.error(e) 57 | raise e 58 | train_platform.close() 59 | 60 | logger.log('!JOB COMPLETED!') 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /visualize/blender/video.py: -------------------------------------------------------------------------------- 1 | import moviepy.editor as mp 2 | import moviepy.video.fx.all as vfx 3 | import os 4 | import imageio 5 | 6 | 7 | def mask_png(frames): 8 | for frame in frames: 9 | im = imageio.imread(frame) 10 | im[im[:, :, 3] < 1, :] = 255 11 | imageio.imwrite(frame, im[:, :, 0:3]) 12 | return 13 | 14 | 15 | class Video: 16 | def __init__(self, frame_path: str, fps: float = 12.5, res="high"): 17 | frame_path = str(frame_path) 18 | self.fps = fps 19 | 20 | self._conf = {"codec": "libx264", 21 | "fps": self.fps, 22 | "audio_codec": "aac", 23 | "temp_audiofile": "temp-audio.m4a", 24 | "remove_temp": True} 25 | 26 | if res == "low": 27 | bitrate = "500k" 28 | else: 29 | bitrate = "5000k" 30 | 31 | self._conf = {"bitrate": bitrate, 32 | "fps": self.fps} 33 | 34 | # Load video 35 | # video = 
mp.VideoFileClip(video1_path, audio=False) 36 | # Load with frames 37 | frames = [os.path.join(frame_path, x) 38 | for x in sorted(os.listdir(frame_path))] 39 | 40 | # mask background white for videos 41 | mask_png(frames) 42 | 43 | video = mp.ImageSequenceClip(frames, fps=fps) 44 | self.video = video 45 | self.duration = video.duration 46 | 47 | def add_text(self, text): 48 | # needs ImageMagick 49 | video_text = mp.TextClip(text, 50 | font='Amiri', 51 | color='white', 52 | method='caption', 53 | align="center", 54 | size=(self.video.w, None), 55 | fontsize=30) 56 | video_text = video_text.on_color(size=(self.video.w, video_text.h + 5), 57 | color=(0, 0, 0), 58 | col_opacity=0.6) 59 | # video_text = video_text.set_pos('bottom') 60 | video_text = video_text.set_pos('top') 61 | 62 | self.video = mp.CompositeVideoClip([self.video, video_text]) 63 | 64 | def save(self, out_path): 65 | out_path = str(out_path) 66 | self.video.subclip(0, self.duration).write_videofile( 67 | out_path, **self._conf) 68 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/get_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import re 4 | from os.path import join as pjoin 5 | from data_loaders.humanml.utils.word_vectorizer import POS_enumerator 6 | 7 | 8 | def is_float(numStr): 9 | flag = False 10 | numStr = str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 11 | try: 12 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 13 | res = reg.match(str(numStr)) 14 | if res: 15 | flag = True 16 | except Exception as ex: 17 | print("is_float() - error: " + str(ex)) 18 | return flag 19 | 20 | 21 | def is_number(numStr): 22 | flag = False 23 | numStr = str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 24 | if str(numStr).isdigit(): 25 | flag = True 26 | return flag 27 | 28 | 29 | def get_opt(opt_path, device): 30 | opt = Namespace() 31 | opt_dict = vars(opt) 32 | 33 | skip = ('-------------- End ----------------', 34 | '------------ Options -------------', 35 | '\n') 36 | with open(opt_path) as f: 37 | for line in f: 38 | if line.strip() not in skip: 39 | # print(line.strip()) 40 | key, value = line.strip().split(': ') 41 | if value in ('True', 'False'): 42 | opt_dict[key] = bool(value) 43 | elif is_float(value): 44 | opt_dict[key] = float(value) 45 | elif is_number(value): 46 | opt_dict[key] = int(value) 47 | else: 48 | opt_dict[key] = str(value) 49 | 50 | # print(opt) 51 | opt_dict['which_epoch'] = 'latest' 52 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 53 | opt.model_dir = pjoin(opt.save_root, 'model') 54 | opt.meta_dir = pjoin(opt.save_root, 'meta') 55 | 56 | if opt.dataset_name == 't2m': 57 | opt.joints_num = 22 58 | opt.dim_pose = 263 59 | opt.max_motion_length = 196 60 | elif opt.dataset_name == 'kit': 61 | opt.joints_num = 21 62 | opt.dim_pose = 251 63 | opt.max_motion_length = 196 64 | else: 65 | raise KeyError('Dataset not recognized') 66 | 67 | opt.dim_word = 300 68 | opt.num_classes = 200 // opt.unit_length 69 | opt.dim_pos_ohot = len(POS_enumerator) 70 | opt.is_train = False 71 | opt.is_continue = False 72 | opt.device = device 73 | 74 | return opt -------------------------------------------------------------------------------- /model/cfg_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from copy import 
deepcopy 5 | 6 | # A wrapper model for Classifier-free guidance **SAMPLING** only 7 | # https://arxiv.org/abs/2207.12598 8 | class ClassifierFreeSampleModel(nn.Module): 9 | 10 | def __init__(self, model): 11 | super().__init__() 12 | self.model = model # model is the actual model to run 13 | 14 | # assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions' 15 | # if args.mode in ['motion','motion_uselift']: 16 | 17 | # # pointers to inner model 18 | # # self.rot2xyz = self.model.rot2xyz 19 | # self.translation = self.model.translation 20 | # self.njoints = self.model.njoints 21 | # self.nfeats = self.model.nfeats 22 | # self.data_rep = self.model.data_rep 23 | # self.cond_mode = self.model.cond_mode 24 | # else: 25 | # self.rot2xyz = self.model.mdm3d.rot2xyz 26 | # self.translation = self.model.mdm3d.translation 27 | # self.njoints = self.model.mdm3d.njoints 28 | # self.nfeats = self.model.mdm3d.nfeats 29 | # self.data_rep = self.model.mdm3d.data_rep 30 | # self.cond_mode = self.model.mdm3d.cond_mode 31 | 32 | def forward(self, x, timesteps, y=None, 33 | return_m=True, return_j=False): 34 | # cond_mode = self.model.cond_mode 35 | # assert cond_mode in ['text', 'action'] 36 | # y_uncond = deepcopy(y) 37 | # y_uncond['uncond'] = True 38 | if 'force_mask' in y.keys(): 39 | out = self.model(x, timesteps, y['enc_text'], return_m=return_m, 40 | return_j=return_j, force_mask=True) 41 | else: 42 | out = self.model(x, timesteps, y['enc_text'], return_m=return_m, 43 | return_j=return_j) 44 | if 'scale' in y.keys(): 45 | out_uncond = self.model(x, timesteps, y['enc_text'], return_m=return_m, 46 | return_j=return_j, force_mask=True) 47 | if return_m: 48 | out['m'] = out_uncond['m'] + (y['scale'].view(-1, 1, 1, 1) * (out['m'] - out_uncond['m'])) 49 | if return_j: 50 | out['j'] = out_uncond['j'] + (y['scale'].view(-1, 1, 1, 1) * (out['j'] - out_uncond['j'])) 51 | return out 52 | 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | /data 132 | /.vscode 133 | /save -------------------------------------------------------------------------------- /visualize/blender/meshes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | 4 | from .materials import body_material 5 | 6 | # green 7 | GT_SMPL = body_material(0.0, 0.392, 0.158) 8 | 9 | # blue 10 | GEN_SMPL = body_material(0.022, 0.129, 0.439) 11 | 12 | 13 | class Meshes: 14 | def __init__(self, data, *, ours, mode, canonicalize, always_on_floor, oldrender=True, **kwargs): 15 | data = prepare_meshes(data, canonicalize=canonicalize, 16 | always_on_floor=always_on_floor) 17 | 18 | self.faces = np.load('./visualize/joints2smpl/smpl_models/faces.npy') 19 | self.data = data 20 | self.mode = mode 21 | self.oldrender = oldrender 22 | 23 | self.N = len(data) 24 | self.trajectory = data[:, :, [0, 1]].mean(1) 25 | 26 | if ours: 27 | self.mat = GT_SMPL 28 | else: 29 | self.mat = GEN_SMPL 30 | 31 | def get_sequence_mat(self, frac, ours): 32 | if ours: 33 | cmap = matplotlib.cm.get_cmap('Greens') 34 | else: 35 | cmap = matplotlib.cm.get_cmap('Blues') 36 | 37 | # cmap = matplotlib.cm.get_cmap('Oranges') 38 | # begin = 0.60 39 | # end = 0.90 40 | begin = 0.50 41 | end = 0.90 42 | rgbcolor = cmap(begin + (end-begin)*frac) 43 | mat = body_material(*rgbcolor, oldrender=self.oldrender) 44 | # mat = body_material(156/255, 156/255, 156/255) 45 | return mat 46 | 47 | def get_root(self, index): 48 | return self.data[index].mean(0) 49 | 50 | def get_mean_root(self): 51 | return self.data.mean((0, 1)) 52 | 53 | def load_in_blender(self, index, mat): 54 | vertices = self.data[index] 55 | faces = self.faces 56 | name = 
f"{str(index).zfill(4)}" 57 | 58 | from .tools import load_numpy_vertices_into_blender 59 | load_numpy_vertices_into_blender(vertices, faces, name, mat) 60 | 61 | return name 62 | 63 | def __len__(self): 64 | return self.N 65 | 66 | 67 | def prepare_meshes(data, canonicalize=True, always_on_floor=False): 68 | if canonicalize: 69 | print("No canonicalization for now") 70 | 71 | # fitted mesh do not need fixing axis 72 | # # fix axis 73 | # data[..., 1] = - data[..., 1] 74 | # data[..., 0] = - data[..., 0] 75 | 76 | # Swap axis (gravity=Z instead of Y) 77 | data = data[..., [2, 0, 1]] 78 | 79 | # Remove the floor 80 | data[..., 2] -= data[..., 2].min() 81 | 82 | # Put all the body on the floor 83 | if always_on_floor: 84 | data[..., 2] -= data[..., 2].min(1)[:, None] 85 | 86 | return data 87 | -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | from model.mdm import MDM 2 | from diffusion import gaussian_diffusion as gd 3 | from diffusion.respace import SpacedDiffusion, space_timesteps 4 | from model.crossdiff import CrossDiff 5 | 6 | 7 | def create_model_and_diffusion(args): 8 | if args.model_type == 'MDM': 9 | model = MDM(**get_model_args(args)) 10 | elif args.model_type == 'crossdiff': 11 | model = CrossDiff(args) 12 | else: 13 | raise NotImplementedError(f'not implement {args.model_type}') 14 | 15 | diffusion = create_gaussian_diffusion(args) 16 | return model, diffusion 17 | 18 | 19 | def get_model_args(args): 20 | 21 | # default args 22 | clip_version = 'ViT-B/32' 23 | action_emb = 'tensor' 24 | cond_mode = 'text' 25 | num_actions = 1 26 | 27 | # SMPL defaults 28 | data_rep = 'rot6d' 29 | njoints = 25 30 | nfeats = 6 31 | 32 | if args.dataset in ['humanml', 'humanmlcut']: 33 | data_rep = 'hml_vec' 34 | njoints = 263 35 | nfeats = 1 36 | elif args.dataset == 'kit': 37 | data_rep = 'hml_vec' 38 | njoints = 251 39 | nfeats = 1 40 | 41 | return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions, 42 | 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True, 43 | 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4, 44 | 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode, 45 | 'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch, 46 | 'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset} 47 | 48 | 49 | def create_gaussian_diffusion(args): 50 | # default params 51 | predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal! 52 | steps = 1000 53 | scale_beta = 1. # no scaling 54 | timestep_respacing = '' # can be used for ddim sampling, we don't use it. 
55 | learn_sigma = False 56 | rescale_timesteps = False 57 | 58 | betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) 59 | loss_type = gd.LossType.MSE 60 | 61 | if not timestep_respacing: 62 | timestep_respacing = [steps] 63 | 64 | return SpacedDiffusion( 65 | use_timesteps=space_timesteps(steps, timestep_respacing), 66 | betas=betas, 67 | model_mean_type=( 68 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 69 | ), 70 | model_var_type=gd.ModelVarType.FIXED_SMALL, 71 | loss_type=loss_type, 72 | rescale_timesteps=rescale_timesteps, 73 | ) -------------------------------------------------------------------------------- /diffusion/losses.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Helpers for various likelihood-based losses. These are ported from the original 4 | Ho et al. diffusion models codebase: 5 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 6 | """ 7 | 8 | import numpy as np 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec 
-------------------------------------------------------------------------------- /visualize/blender/floor.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | from .materials import floor_mat 3 | 4 | 5 | def get_trajectory(data, is_mesh): 6 | if is_mesh: 7 | # mean of the vertices 8 | trajectory = data[:, :, [0, 1]].mean(1) 9 | else: 10 | # get the root joint 11 | trajectory = data[:, 0, [0, 1]] 12 | return trajectory 13 | 14 | 15 | def plot_floor(data, big_plane=True): 16 | # Create a floor 17 | minx, miny, _ = data.min(axis=(0, 1)) 18 | maxx, maxy, _ = data.max(axis=(0, 1)) 19 | minz = 0 20 | 21 | location = ((maxx + minx)/2, (maxy + miny)/2, 0) 22 | # a little bit bigger 23 | scale = (1.08*(maxx - minx)/2, 1.08*(maxy - miny)/2, 1) 24 | 25 | bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) 26 | 27 | bpy.ops.transform.resize(value=scale, orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', 28 | constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, 29 | proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, 30 | use_proportional_projected=False, release_confirm=True) 31 | obj = bpy.data.objects["Plane"] 32 | obj.name = "SmallPlane" 33 | obj.data.name = "SmallPlane" 34 | 35 | if not big_plane: 36 | obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) 37 | else: 38 | obj.active_material = floor_mat(color=(0.1, 0.1, 0.1, 1)) 39 | 40 | if big_plane: 41 | location = ((maxx + minx)/2, (maxy + miny)/2, -0.01) 42 | bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) 43 | 44 | bpy.ops.transform.resize(value=[2*x for x in scale], orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', 45 | constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, 46 | proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, 47 | use_proportional_projected=False, release_confirm=True) 48 | 49 | obj = bpy.data.objects["Plane"] 50 | obj.name = "BigPlane" 51 | obj.data.name = "BigPlane" 52 | obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) 53 | 54 | 55 | def show_traj(coords): 56 | # None 57 | # create the Curve Datablock 58 | curveData = bpy.data.curves.new('myCurve', type='CURVE') 59 | curveData.dimensions = '3D' 60 | curveData.resolution_u = 2 61 | 62 | # map coords to spline 63 | polyline = curveData.splines.new('POLY') 64 | polyline.points.add(len(coords)-1) 65 | for i, coord in enumerate(coords): 66 | x, y = coord 67 | polyline.points[i].co = (x, y, 0.001, 1) 68 | 69 | # create Object 70 | curveOB = bpy.data.objects.new('myCurve', curveData) 71 | curveData.bevel_depth = 0.01 72 | 73 | bpy.context.collection.objects.link(curveOB) 74 | -------------------------------------------------------------------------------- /visualize/vis_utils.py: -------------------------------------------------------------------------------- 1 | from model.rotation2xyz import Rotation2xyz 2 | import numpy as np 3 | from trimesh import Trimesh 4 | import os 5 | import torch 6 | from visualize.simplify_loc2rot import joints2smpl 7 | 8 | class npy2obj: 9 | def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True): 10 | self.npy_path = npy_path 11 | self.motions = np.load(self.npy_path, allow_pickle=True) 12 | if 
self.npy_path.endswith('.npz'): 13 | self.motions = self.motions['arr_0'] 14 | self.motions = self.motions[None][0] 15 | self.rot2xyz = Rotation2xyz(device='cpu') 16 | self.faces = self.rot2xyz.smpl_model.faces 17 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 18 | self.opt_cache = {} 19 | self.sample_idx = sample_idx 20 | self.total_num_samples = self.motions['num_samples'] 21 | self.rep_idx = rep_idx 22 | self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx 23 | self.num_frames = self.motions['motion'][self.absl_idx].shape[-1] 24 | self.j2s = joints2smpl(num_frames=self.num_frames, device_id=device, cuda=cuda) 25 | 26 | if self.nfeats == 3: 27 | print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.') 28 | motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3] 29 | self.motions['motion'] = motion_tensor.cpu().numpy() 30 | elif self.nfeats == 6: 31 | self.motions['motion'] = self.motions['motion'][[self.absl_idx]] 32 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 33 | self.real_num_frames = self.motions['lengths'][self.absl_idx] 34 | 35 | self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None, 36 | pose_rep='rot6d', translation=True, glob=True, 37 | jointstype='vertices', 38 | # jointstype='smpl', # for joint locations 39 | vertstrans=True) 40 | self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1) 41 | self.vertices += self.root_loc 42 | 43 | def get_vertices(self, sample_i, frame_i): 44 | return self.vertices[sample_i, :, :, frame_i].squeeze().tolist() 45 | 46 | def get_trimesh(self, sample_i, frame_i): 47 | return Trimesh(vertices=self.get_vertices(sample_i, frame_i), 48 | faces=self.faces) 49 | 50 | def save_obj(self, save_path, frame_i): 51 | mesh = self.get_trimesh(0, frame_i) 52 | with open(save_path, 'w') as fw: 53 | mesh.export(fw, 'obj') 54 | return save_path 55 | 56 | def save_npy(self, save_path): 57 | data_dict = { 58 | 'motion': self.motions['motion'][0, :, :, :self.real_num_frames], 59 | 'thetas': self.motions['motion'][0, :-1, :, :self.real_num_frames], 60 | 'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames], 61 | 'faces': self.faces, 62 | 'vertices': self.vertices[0, :, :, :self.real_num_frames], 63 | 'text': self.motions['text'][0], 64 | 'length': self.real_num_frames, 65 | } 66 | np.save(save_path, data_dict) 67 | -------------------------------------------------------------------------------- /utils/load_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from diffusion import logger 3 | import clip 4 | import torch 5 | 6 | def sum_flat(tensor): 7 | """ 8 | Take the sum over all non-batch dimensions. 
9 | """ 10 | return tensor.sum(dim=list(range(1, len(tensor.shape)))) 11 | 12 | def masked_l2(a, b, mask): 13 | # assuming a.shape == b.shape == bs, J, Jdim, seqlen 14 | # assuming mask.shape == bs, 1, 1, seqlen 15 | loss = (a - b) ** 2 16 | loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements 17 | if mask.shape[1] == 1: 18 | n_entries = a.shape[1] * a.shape[2] 19 | else: 20 | n_entries = 1 21 | non_zero_elements = sum_flat(mask) * n_entries 22 | mse_loss_val = loss / (non_zero_elements + 1e-8) 23 | return mse_loss_val 24 | 25 | 26 | def l2(a, b): 27 | loss = (a - b) ** 2 28 | loss = sum_flat(loss) # gives \sigma_euclidean over unmasked elements 29 | n_entries = a.shape[1] * a.shape[2] * a.shape[3] 30 | mse_loss_val = loss / (n_entries + 1e-8) 31 | return mse_loss_val 32 | 33 | 34 | def find_resume_checkpoint(dir): 35 | if not os.path.exists(dir): 36 | return '', -1 37 | checkpoints = sorted(os.listdir(dir), 38 | key=lambda x: int(x[5:-3]), 39 | reverse=True) 40 | if len(checkpoints) == 0: 41 | return '', -1 42 | else: 43 | start_epoch = int(checkpoints[0][5:9]) 44 | return os.path.join(dir, checkpoints[0]), start_epoch 45 | 46 | 47 | def log_loss_dict(losses): 48 | for key, values in losses.items(): 49 | logger.logkv_mean(key, values.mean().item()) 50 | # Log the quantiles (four quartiles, in particular). 51 | # for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 52 | # quartile = int(4 * sub_t / diffusion.num_timesteps) 53 | # logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 54 | 55 | def load_and_freeze_clip(): 56 | clip_model, clip_preprocess = clip.load('ViT-B/32', device='cpu', 57 | jit=False, download_root='/apdcephfs_cq3/share_1290939/zepingren/CLIP') # Must set jit=False for training 58 | clip.model.convert_weights( 59 | clip_model) # Actually this line is unnecessary since clip by default already on float16 60 | 61 | # Freeze CLIP weights 62 | clip_model.eval() 63 | for p in clip_model.parameters(): 64 | p.requires_grad = False 65 | 66 | return clip_model 67 | 68 | def encode_text(clip_model, raw_text, device): 69 | # raw_text - list (batch_size length) of strings with input text prompts 70 | max_text_len = 20 # Specific hardcoding for humanml dataset 71 | if max_text_len is not None: 72 | default_context_length = 77 73 | context_length = max_text_len + 2 # start_token + 20 + end_token 74 | assert context_length < default_context_length 75 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate 76 | # print('texts', texts.shape) 77 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) 78 | texts = torch.cat([texts, zero_pad], dim=1) 79 | # print('texts after pad', texts.shape, texts) 80 | else: 81 | texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate 82 | return clip_model.encode_text(texts).float() -------------------------------------------------------------------------------- /visualize/blender/scene.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | from .materials import plane_mat # noqa 3 | 4 | 5 | def setup_renderer(denoising=True, oldrender=True, accelerator="gpu", device=[0]): 6 | bpy.context.scene.render.engine = "CYCLES" 7 | bpy.data.scenes[0].render.engine = "CYCLES" 8 | if accelerator.lower() == "gpu": 9 | 
bpy.context.preferences.addons[ 10 | "cycles" 11 | ].preferences.compute_device_type = "CUDA" 12 | bpy.context.scene.cycles.device = "GPU" 13 | i = 0 14 | bpy.context.preferences.addons["cycles"].preferences.get_devices() 15 | for d in bpy.context.preferences.addons["cycles"].preferences.devices: 16 | if i in device: # gpu id 17 | d["use"] = 1 18 | print(d["name"], "".join(str(i) for i in device)) 19 | else: 20 | d["use"] = 0 21 | i += 1 22 | 23 | if denoising: 24 | bpy.context.scene.cycles.use_denoising = True 25 | 26 | bpy.context.scene.render.tile_x = 256 27 | bpy.context.scene.render.tile_y = 256 28 | bpy.context.scene.cycles.samples = 64 29 | # bpy.context.scene.cycles.denoiser = 'OPTIX' 30 | 31 | if not oldrender: 32 | bpy.context.scene.view_settings.view_transform = "Standard" 33 | bpy.context.scene.render.film_transparent = True 34 | bpy.context.scene.display_settings.display_device = "sRGB" 35 | bpy.context.scene.view_settings.gamma = 1.2 36 | bpy.context.scene.view_settings.exposure = -0.75 37 | 38 | 39 | # Setup scene 40 | def setup_scene( 41 | res="high", denoising=True, oldrender=True, accelerator="gpu", device=[0] 42 | ): 43 | scene = bpy.data.scenes["Scene"] 44 | assert res in ["ultra", "high", "med", "low"] 45 | if res == "high": 46 | scene.render.resolution_x = 1280 47 | scene.render.resolution_y = 1024 48 | elif res == "med": 49 | scene.render.resolution_x = 1280 // 2 50 | scene.render.resolution_y = 1024 // 2 51 | elif res == "low": 52 | scene.render.resolution_x = 1280 // 4 53 | scene.render.resolution_y = 1024 // 4 54 | elif res == "ultra": 55 | scene.render.resolution_x = 1280 * 2 56 | scene.render.resolution_y = 1024 * 2 57 | 58 | scene.render.film_transparent= True 59 | world = bpy.data.worlds["World"] 60 | world.use_nodes = True 61 | bg = world.node_tree.nodes["Background"] 62 | bg.inputs[0].default_value[:3] = (1.0, 1.0, 1.0) 63 | bg.inputs[1].default_value = 1.0 64 | 65 | # Remove default cube 66 | if "Cube" in bpy.data.objects: 67 | bpy.data.objects["Cube"].select_set(True) 68 | bpy.ops.object.delete() 69 | 70 | bpy.ops.object.light_add( 71 | type="SUN", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) 72 | ) 73 | bpy.data.objects["Sun"].data.energy = 1.5 74 | 75 | # rotate camera 76 | bpy.ops.object.empty_add( 77 | type="PLAIN_AXES", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) 78 | ) 79 | bpy.ops.transform.resize( 80 | value=(10, 10, 10), 81 | orient_type="GLOBAL", 82 | orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), 83 | orient_matrix_type="GLOBAL", 84 | mirror=True, 85 | use_proportional_edit=False, 86 | proportional_edit_falloff="SMOOTH", 87 | proportional_size=1, 88 | use_proportional_connected=False, 89 | use_proportional_projected=False, 90 | ) 91 | bpy.ops.object.select_all(action="DESELECT") 92 | 93 | setup_renderer( 94 | denoising=denoising, oldrender=oldrender, accelerator=accelerator, device=device 95 | ) 96 | return scene 97 | -------------------------------------------------------------------------------- /model/rotation2xyz.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import torch 3 | import utils.rotation_conversions as geometry 4 | 5 | 6 | from model.smpl import SMPL, JOINTSTYPE_ROOT 7 | # from .get_model import JOINTSTYPES 8 | JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] 9 | 10 | 11 | class Rotation2xyz: 12 | def __init__(self, device, dataset='amass'): 13 | self.device = device 14 | self.dataset = dataset 15 | 
self.smpl_model = SMPL().eval().to(device) 16 | 17 | def __call__(self, x, mask, pose_rep, translation, glob, 18 | jointstype, vertstrans, betas=None, beta=0, 19 | glob_rot=None, get_rotations_back=False, **kwargs): 20 | if pose_rep == "xyz": 21 | return x 22 | 23 | if mask is None: 24 | mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) 25 | 26 | if not glob and glob_rot is None: 27 | raise TypeError("You must specify global rotation if glob is False") 28 | 29 | if jointstype not in JOINTSTYPES: 30 | raise NotImplementedError("This jointstype is not implemented.") 31 | 32 | if translation: 33 | x_translations = x[:, -1, :3] 34 | x_rotations = x[:, :-1] 35 | else: 36 | x_rotations = x 37 | 38 | x_rotations = x_rotations.permute(0, 3, 1, 2) 39 | nsamples, time, njoints, feats = x_rotations.shape 40 | 41 | # Compute rotations (convert only masked sequences output) 42 | if pose_rep == "rotvec": 43 | rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) 44 | elif pose_rep == "rotmat": 45 | rotations = x_rotations[mask].view(-1, njoints, 3, 3) 46 | elif pose_rep == "rotquat": 47 | rotations = geometry.quaternion_to_matrix(x_rotations[mask]) 48 | elif pose_rep == "rot6d": 49 | rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) 50 | else: 51 | raise NotImplementedError("No geometry for this one.") 52 | 53 | if not glob: 54 | global_orient = torch.tensor(glob_rot, device=x.device) 55 | global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) 56 | global_orient = global_orient.repeat(len(rotations), 1, 1, 1) 57 | else: 58 | global_orient = rotations[:, 0] 59 | rotations = rotations[:, 1:] 60 | 61 | if betas is None: 62 | betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], 63 | dtype=rotations.dtype, device=rotations.device) 64 | betas[:, 1] = beta 65 | # import ipdb; ipdb.set_trace() 66 | out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) 67 | 68 | # get the desirable joints 69 | joints = out[jointstype] 70 | 71 | x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) 72 | x_xyz[~mask] = 0 73 | x_xyz[mask] = joints 74 | 75 | x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() 76 | 77 | # the first translation root at the origin on the prediction 78 | if jointstype != "vertices": 79 | rootindex = JOINTSTYPE_ROOT[jointstype] 80 | x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] 81 | 82 | if translation and vertstrans: 83 | # the first translation root at the origin 84 | x_translations = x_translations - x_translations[:, :, [0]] 85 | 86 | # add the translation to all the joints 87 | x_xyz = x_xyz + x_translations[:, None, :, :] 88 | 89 | if get_rotations_back: 90 | return x_xyz, rotations, global_orient 91 | else: 92 | return x_xyz 93 | -------------------------------------------------------------------------------- /visualize/motions2hik.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.rotation_conversions import rotation_6d_to_matrix, matrix_to_euler_angles 5 | from visualize.simplify_loc2rot import joints2smpl 6 | 7 | """ 8 | Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder 9 | by converting joint positions to joint rotations in degrees. 
Based on visualize.vis_utils.npy2obj 10 | """ 11 | 12 | # Mapping of SMPL joint index to HIK joint Name 13 | JOINT_MAP = [ 14 | 'Hips', 15 | 'LeftUpLeg', 16 | 'RightUpLeg', 17 | 'Spine', 18 | 'LeftLeg', 19 | 'RightLeg', 20 | 'Spine1', 21 | 'LeftFoot', 22 | 'RightFoot', 23 | 'Spine2', 24 | 'LeftToeBase', 25 | 'RightToeBase', 26 | 'Neck', 27 | 'LeftShoulder', 28 | 'RightShoulder', 29 | 'Head', 30 | 'LeftArm', 31 | 'RightArm', 32 | 'LeftForeArm', 33 | 'RightForeArm', 34 | 'LeftHand', 35 | 'RightHand' 36 | ] 37 | 38 | 39 | def motions2hik(motions, device=0, cuda=True): 40 | """ 41 | Utility function to convert model output to a representation used by HumanIK skeletons in Maya and Motion Builder 42 | by converting joint positions to joint rotations in degrees. Based on visualize.vis_utils.npy2obj 43 | 44 | :param motions: numpy array containing MDM model output [num_reps, num_joints, num_params (xyz), num_frames 45 | :param device: 46 | :param cuda: 47 | 48 | :returns: JSON serializable dict to be used with the Replicate API implementation 49 | """ 50 | 51 | nreps, njoints, nfeats, nframes = motions.shape 52 | j2s = joints2smpl(num_frames=nframes, device_id=device, cuda=cuda) 53 | 54 | thetas = [] 55 | root_translation = [] 56 | for rep_idx in range(nreps): 57 | rep_motions = motions[rep_idx].transpose(2, 0, 1) # [nframes, njoints, 3] 58 | 59 | if nfeats == 3: 60 | print(f'Running SMPLify for repetition [{rep_idx + 1}] of {nreps}, it may take a few minutes.') 61 | motion_tensor, opt_dict = j2s.joint2smpl(rep_motions) # [nframes, njoints, 3] 62 | motion = motion_tensor.cpu().numpy() 63 | 64 | elif nfeats == 6: 65 | motion = rep_motions 66 | thetas.append(rep_motions) 67 | 68 | # Convert 6D rotation representation to Euler angles 69 | thetas_6d = motion[0, :-1, :, :nframes].transpose(2, 0, 1) # [nframes, njoints, 6] 70 | thetas_deg = [] 71 | for frame, d6 in enumerate(thetas_6d): 72 | thetas_deg.append([_rotation_6d_to_euler(d6)]) 73 | 74 | thetas.append([np.concatenate(thetas_deg, axis=0)]) 75 | root_translation.append([motion[0, -1, :3, :nframes].transpose(1, 0)]) # [nframes, 3] 76 | 77 | thetas = np.concatenate(thetas, axis=0)[:nframes] 78 | root_translation = np.concatenate(root_translation, axis=0)[:nframes] 79 | 80 | data_dict = { 81 | 'joint_map': JOINT_MAP, 82 | 'thetas': thetas.tolist(), # [nreps, nframes, njoints, 3 (deg)] 83 | 'root_translation': root_translation.tolist(), # [nreps, nframes, 3 (xyz)] 84 | } 85 | 86 | return data_dict 87 | 88 | 89 | def _rotation_6d_to_euler(d6): 90 | """ 91 | Converts 6D rotation representation by Zhou et al. [1] to euler angles 92 | using Gram--Schmidt orthogonalisation per Section B of [1]. 
93 | 94 | :param d6: numpy Array 6D rotation representation, of size (*, 6) 95 | :returns: JSON serializable dict to be used with the Replicate API implementation 96 | :returns: euler angles in degrees as a numpy array with shape (*, 3) 97 | """ 98 | rot_mat = rotation_6d_to_matrix(torch.tensor(d6)) 99 | rot_eul_rad = matrix_to_euler_angles(rot_mat, 'XYZ') 100 | eul_deg = torch.rad2deg(rot_eul_rad).numpy() 101 | 102 | return eul_deg 103 | 104 | -------------------------------------------------------------------------------- /model/smpl.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import numpy as np 3 | import torch 4 | 5 | import contextlib 6 | 7 | from smplx import SMPLLayer as _SMPLLayer 8 | from smplx.lbs import vertices2joints 9 | 10 | 11 | # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] 12 | # change 0 and 8 13 | action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] 14 | 15 | from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA 16 | 17 | JOINTSTYPE_ROOT = {"a2m": 0, # action2motion 18 | "smpl": 0, 19 | "a2mpl": 0, # set(smpl, a2m) 20 | "vibe": 8} # 0 is the 8 position: OP MidHip below 21 | 22 | JOINT_MAP = { 23 | 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 24 | 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, 25 | 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 26 | 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, 27 | 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 28 | 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, 29 | 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 30 | 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 31 | 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 32 | 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 33 | 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 34 | 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 35 | 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 36 | 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 37 | 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 38 | 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 39 | 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 40 | } 41 | 42 | JOINT_NAMES = [ 43 | 'OP Nose', 'OP Neck', 'OP RShoulder', 44 | 'OP RElbow', 'OP RWrist', 'OP LShoulder', 45 | 'OP LElbow', 'OP LWrist', 'OP MidHip', 46 | 'OP RHip', 'OP RKnee', 'OP RAnkle', 47 | 'OP LHip', 'OP LKnee', 'OP LAnkle', 48 | 'OP REye', 'OP LEye', 'OP REar', 49 | 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 50 | 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 51 | 'Right Ankle', 'Right Knee', 'Right Hip', 52 | 'Left Hip', 'Left Knee', 'Left Ankle', 53 | 'Right Wrist', 'Right Elbow', 'Right Shoulder', 54 | 'Left Shoulder', 'Left Elbow', 'Left Wrist', 55 | 'Neck (LSP)', 'Top of Head (LSP)', 56 | 'Pelvis (MPII)', 'Thorax (MPII)', 57 | 'Spine (H36M)', 'Jaw (H36M)', 58 | 'Head (H36M)', 'Nose', 'Left Eye', 59 | 'Right Eye', 'Left Ear', 'Right Ear' 60 | ] 61 | 62 | 63 | # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints 64 | class SMPL(_SMPLLayer): 65 | """ Extension of the official SMPL implementation to support more joints """ 66 | 67 | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): 68 | kwargs["model_path"] = model_path 69 | 70 | # remove the verbosity for the 10-shapes beta parameters 71 | with contextlib.redirect_stdout(None): 72 | super(SMPL, self).__init__(**kwargs) 73 | 74 | J_regressor_extra = 
np.load(JOINT_REGRESSOR_TRAIN_EXTRA) 75 | self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) 76 | vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) 77 | a2m_indexes = vibe_indexes[action2motion_joints] 78 | smpl_indexes = np.arange(24) 79 | a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) 80 | 81 | self.maps = {"vibe": vibe_indexes, 82 | "a2m": a2m_indexes, 83 | "smpl": smpl_indexes, 84 | "a2mpl": a2mpl_indexes} 85 | 86 | def forward(self, *args, **kwargs): 87 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 88 | 89 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 90 | all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 91 | 92 | output = {"vertices": smpl_output.vertices} 93 | 94 | for joinstype, indexes in self.maps.items(): 95 | output[joinstype] = all_joints[:, indexes] 96 | 97 | return output -------------------------------------------------------------------------------- /data_loaders/humanml/motion_loaders/model_motion_loaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | from data_loaders.humanml.utils.get_opt import get_opt 3 | from data_loaders.humanml.motion_loaders.comp_v6_model_dataset import CompMDMGeneratedDataset 4 | from data_loaders.humanml.utils.word_vectorizer import WordVectorizer 5 | import numpy as np 6 | from torch.utils.data._utils.collate import default_collate 7 | 8 | 9 | def collate_fn(batch): 10 | batch.sort(key=lambda x: x[3], reverse=True) 11 | return default_collate(batch) 12 | 13 | 14 | class MMGeneratedDataset(Dataset): 15 | def __init__(self, opt, motion_dataset, w_vectorizer): 16 | self.opt = opt 17 | self.dataset = motion_dataset.mm_generated_motion 18 | self.w_vectorizer = w_vectorizer 19 | 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def __getitem__(self, item): 24 | data = self.dataset[item] 25 | mm_motions = data['mm_motions'] 26 | m_lens = [] 27 | motions = [] 28 | for mm_motion in mm_motions: 29 | m_lens.append(mm_motion['length']) 30 | motion = mm_motion['motion'] 31 | # We don't need the following logic because our sample func generates the full tensor anyway: 32 | # if len(motion) < self.opt.max_motion_length: 33 | # motion = np.concatenate([motion, 34 | # np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1])) 35 | # ], axis=0) 36 | motion = motion[None, :] 37 | motions.append(motion) 38 | m_lens = np.array(m_lens, dtype=np.int) 39 | motions = np.concatenate(motions, axis=0) 40 | sort_indx = np.argsort(m_lens)[::-1].copy() 41 | # print(m_lens) 42 | # print(sort_indx) 43 | # print(m_lens[sort_indx]) 44 | m_lens = m_lens[sort_indx] 45 | motions = motions[sort_indx] 46 | return motions, m_lens 47 | 48 | 49 | 50 | # def get_motion_loader(opt_path, batch_size, ground_truth_dataset, mm_num_samples, mm_num_repeats, device): 51 | # opt = get_opt(opt_path, device) 52 | 53 | # # Currently the configurations of two datasets are almost the same 54 | # if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': 55 | # w_vectorizer = WordVectorizer('./glove', 'our_vab') 56 | # else: 57 | # raise KeyError('Dataset not recognized!!') 58 | # print('Generating %s ...' 
% opt.name) 59 | 60 | # if 'v6' in opt.name: 61 | # dataset = CompV6GeneratedDataset(opt, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats) 62 | # else: 63 | # raise KeyError('Dataset not recognized!!') 64 | 65 | # mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer) 66 | 67 | # motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4) 68 | # mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) 69 | 70 | # print('Generated Dataset Loading Completed!!!') 71 | 72 | # return motion_loader, mm_motion_loader 73 | 74 | # our loader 75 | def get_mdm_loader(model, diffusion, batch_size, ground_truth_loader, 76 | mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale): 77 | opt = { 78 | 'name': 'test', # FIXME 79 | } 80 | print('Generating %s ...' % opt['name']) 81 | # dataset = CompMDMGeneratedDataset(opt, ground_truth_dataset, ground_truth_dataset.w_vectorizer, mm_num_samples, mm_num_repeats) 82 | dataset = CompMDMGeneratedDataset(model, diffusion, ground_truth_loader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale) 83 | 84 | mm_dataset = MMGeneratedDataset(opt, dataset, ground_truth_loader.dataset.w_vectorizer) 85 | 86 | # NOTE: bs must not be changed! this will cause a bug in R precision calc! 87 | motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4) 88 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) 89 | 90 | print('Generated Dataset Loading Completed!!!') 91 | 92 | return motion_loader, mm_motion_loader -------------------------------------------------------------------------------- /data_loaders/ucf101.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | from torch.utils import data 4 | import numpy as np 5 | import os 6 | import json 7 | import random 8 | 9 | from diffusion import logger 10 | 11 | class UCF101(data.Dataset): 12 | def __init__(self, args): 13 | 14 | self.dataset_name = 'ucf101' 15 | self.max_motion_length = 196 16 | self.use_mean_joint = args.use_mean_joint 17 | self.joint_mask_ratio = args.joint_mask_ratio 18 | self.mask_joint = self.joint_mask_ratio > 0 19 | 20 | joints_dir = os.path.join(args.ucf_root, 'complicate_2d') 21 | text_json = os.path.join(args.ucf_root, 'text.json') 22 | self.mean_joint = np.load(os.path.join(args.data_root, 'Mean_complicate2d.npy')) 23 | self.std_joint = np.load(os.path.join(args.data_root, 'Std_complicate2d.npy')) 24 | 25 | with open(text_json) as f: 26 | self.text_dict = json.load(f) 27 | 28 | file_list = os.listdir(joints_dir) 29 | 30 | self.data = [] 31 | if not args.ucf_keys: 32 | select = False 33 | else: 34 | select = True 35 | 36 | # count = 0 37 | for file in file_list: 38 | tag = file.split('_')[1] 39 | if (select and tag not in args.ucf_keys) or tag not in self.text_dict.keys(): 40 | continue 41 | 42 | joint = np.load(os.path.join(joints_dir, file)).astype(np.float32) 43 | if len(joint) < 40 or len(joint) > 196: 44 | continue 45 | 46 | # mask[:,[0,1,86,87]] = 0 47 | if joint.shape[-1] == 134: 48 | mask = np.ones_like(joint) 49 | else: 50 | mask = joint[:,134:] 51 | joint = joint[:,:134] 52 | self.data.append({'joint': joint, 53 | 'tag': tag, 54 | 'm_length': len(joint), 55 | 'mask':mask 56 | }) 57 | # count += 1 58 | # if count > 5: 59 | # break 60 | 61 | logger.info(f'UCF101 dataset has {self.__len__()} samples..') 62 | 63 | 
def __len__(self): 64 | return len(self.data) 65 | 66 | def __getitem__(self, index): 67 | joint_dict = self.data[index] 68 | joint, tag, m_length = joint_dict['joint'], joint_dict['tag'], joint_dict['m_length'] 69 | joint_mask = joint_dict['mask'] 70 | text = random.choice(self.text_dict[tag]) 71 | 72 | # coin2 = np.random.choice(['single', 'single', 'double']) 73 | 74 | # if coin2 == 'double': 75 | # m_length = (m_length // 4 - 1) * 4 76 | # elif coin2 == 'single': 77 | # m_length = (m_length // 4) * 4 78 | # idx = random.randint(0, len(joint) - m_length) 79 | # joint = joint[idx:idx + m_length] 80 | # joint_mask = joint_mask[idx:idx + m_length] 81 | 82 | "Z Normalization" 83 | # joint_mask = np.ones_like(joint) 84 | 85 | if self.use_mean_joint: 86 | joint = (joint - self.mean_joint) / self.std_joint 87 | 88 | motion_mask = np.ones((m_length, 263), dtype=np.float32) 89 | 90 | if m_length < self.max_motion_length: 91 | joint = np.concatenate([joint, 92 | np.zeros((self.max_motion_length - m_length, joint.shape[1]), dtype=np.float32) 93 | ], axis=0) 94 | joint_mask = np.concatenate([joint_mask, 95 | np.zeros((self.max_motion_length - m_length, joint_mask.shape[1]), dtype=np.float32) 96 | ], axis=0) 97 | motion_mask = np.concatenate([motion_mask, 98 | np.zeros((self.max_motion_length - m_length, motion_mask.shape[1]), dtype=np.float32) 99 | ], axis=0) 100 | motion = np.zeros((self.max_motion_length, 263), dtype=np.float32) 101 | 102 | return (text, motion, m_length, joint, False, joint_mask, motion_mask) -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from utils.parser_util import parse_args 6 | from model.cfg_sampler import ClassifierFreeSampleModel 7 | from data_loaders.get_data import get_dataset_loader 8 | from utils.model_util import create_model_and_diffusion 9 | from data_loaders.humanml.scripts.motion_process import recover_from_ric 10 | from utils.load_utils import encode_text, find_resume_checkpoint, load_and_freeze_clip 11 | import data_loaders.humanml.utils.paramUtil as paramUtil 12 | from data_loaders.humanml.utils.plot_script import plot_3d_motion, plot_2d_motion 13 | from diffusion import logger 14 | 15 | 16 | def main(): 17 | args = parse_args() 18 | args.debug = True 19 | 20 | logger.configure(args.save_dir, debug=True) 21 | 22 | device = torch.device('cuda') 23 | 24 | dataset = get_dataset_loader('humanml', split='generate', args=args) 25 | 26 | model, diffusion = create_model_and_diffusion(args) 27 | print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 28 | 29 | if args.test_checkpoint: 30 | resume_checkpoint = args.test_checkpoint 31 | else: 32 | resume_checkpoint, last_epoch = find_resume_checkpoint(args.ck_save_dir) 33 | print(f"loading model from checkpoint: {resume_checkpoint}...") 34 | model.load_state_dict(torch.load(resume_checkpoint, map_location='cpu')) 35 | 36 | 37 | 38 | clip_model = load_and_freeze_clip() 39 | clip_model.cuda() 40 | 41 | os.makedirs(args.eval_dir, exist_ok=True) 42 | captions = list(args.captions) * args.sample_times 43 | 44 | 45 | args.nsamples = len(captions) 46 | 47 | model_kwargs = {} 48 | model_kwargs['enc_text'] = encode_text(clip_model, captions, device) 49 | 50 | if args.guidance_param != 1: 51 | model = ClassifierFreeSampleModel(model) 52 | model.cuda() 53 | model.eval() 54 | 55 | 56 | skeleton = 
paramUtil.t2m_kinematic_chain 57 | fps = 20 58 | os.makedirs(args.eval_dir, exist_ok=True) 59 | 60 | # directly generate 3D motion 61 | if args.generate_3d: 62 | sample = diffusion.p_sample_loop( 63 | model, 64 | (args.nsamples, 263, 1, 196), 65 | clip_denoised=False, 66 | model_kwargs=model_kwargs, 67 | skip_timesteps=0, 68 | init_image=None, 69 | progress=True, 70 | dump_steps=None, 71 | noise=None, 72 | const_noise=False, 73 | ) 74 | 75 | sample = sample[:,:,0].permute(0,2,1) 76 | sample = dataset.inv_transform(sample) 77 | sample = recover_from_ric(sample, 22).cpu().numpy() 78 | 79 | 80 | for i in range(args.nsamples): 81 | print(f'generating {i}') 82 | animation_save_path = os.path.join(args.eval_dir, f'{i}.mp4') 83 | motion = sample[i] 84 | caption = captions[i] 85 | np.save(os.path.join(args.eval_dir, f'{i}.npy'), motion) 86 | plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps) 87 | 88 | 89 | 90 | # sample first in 2D domain and then in 3D domain 91 | if args.generate_2d: 92 | sample = diffusion.p_sample_loop( 93 | model, 94 | (args.nsamples, 134, 1, 196), 95 | clip_denoised=False, 96 | model_kwargs=model_kwargs, 97 | skip_timesteps=0, 98 | init_image=None, 99 | progress=True, 100 | dump_steps=None, 101 | noise=None, 102 | const_noise=False, 103 | ) 104 | 105 | t = torch.tensor([0]*args.nsamples, device=device) 106 | joint_predict = model.model(sample, t, model_kwargs['enc_text'], return_m=True, return_j=True, force_mask=True) 107 | 108 | sample = joint_predict['m'].detach() 109 | sample = sample[:,:,0].permute(0,2,1) 110 | sample = dataset.inv_transform(sample) 111 | sample = recover_from_ric(sample, 22).cpu().numpy() 112 | 113 | 114 | for i in range(args.nsamples): 115 | print(f'generating {i}') 116 | animation_save_path = os.path.join(args.eval_dir, f'{i}_3dfrom2d.mp4') 117 | motion = sample[i] 118 | caption = captions[i] 119 | np.save(os.path.join(args.eval_dir, f'{i}_3dfrom2d.npy'), motion) 120 | plot_3d_motion(animation_save_path, skeleton, motion, dataset=args.dataset, title=caption, fps=fps) 121 | 122 | 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /visualize/blender/render.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | 5 | import bpy 6 | import numpy as np 7 | 8 | from .camera import Camera 9 | from .floor import get_trajectory, plot_floor, show_traj 10 | from .sampler import get_frameidx 11 | from .scene import setup_scene # noqa 12 | from .tools import delete_objs, load_numpy_vertices_into_blender, mesh_detect 13 | from .vertices import prepare_vertices 14 | from .meshes import Meshes 15 | 16 | 17 | def prune_begin_end(data, perc): 18 | to_remove = int(len(data)*perc) 19 | if to_remove == 0: 20 | return data 21 | return data[to_remove:-to_remove] 22 | 23 | 24 | def render_current_frame(path): 25 | bpy.context.scene.render.filepath = path 26 | bpy.ops.render.render(use_viewport=True, write_still=True) 27 | 28 | def render(npydata, frames_folder, *, mode, gt=False, 29 | exact_frame=None, num=8, downsample=True, 30 | canonicalize=True, always_on_floor=False, denoising=True, 31 | oldrender=True,jointstype="mmm", res="high", init=True, 32 | accelerator='gpu',device=[0], pre_idx=None, ours=True): 33 | if init: 34 | # Setup the scene (lights / render engine / resolution etc) 35 | setup_scene(res=res, denoising=denoising, 
oldrender=oldrender,accelerator=accelerator,device=device) 36 | 37 | is_mesh = mesh_detect(npydata) 38 | 39 | # Put everything in this folder 40 | if mode == "video": 41 | if always_on_floor: 42 | frames_folder += "_of" 43 | os.makedirs(frames_folder, exist_ok=True) 44 | # if it is a mesh, it is already downsampled 45 | if downsample and not is_mesh: 46 | npydata = npydata[::8] 47 | elif mode == "sequence": 48 | img_name, ext = os.path.splitext(frames_folder) 49 | if always_on_floor: 50 | img_name += "_of" 51 | img_path = f"{img_name}{ext}" 52 | 53 | elif mode == "frame": 54 | img_name, ext = os.path.splitext(frames_folder) 55 | if always_on_floor: 56 | img_name += "_of" 57 | img_path = f"{img_name}_{exact_frame}{ext}" 58 | 59 | # remove X% of begining and end 60 | # as it is almost always static 61 | # in this part 62 | # if mode == "sequence": 63 | # perc = 0.2 64 | # npydata = prune_begin_end(npydata, perc) 65 | 66 | data = Meshes(npydata, ours=ours, mode=mode, 67 | canonicalize=canonicalize, 68 | always_on_floor=always_on_floor) 69 | 70 | 71 | # Number of frames possible to render 72 | nframes = len(data) 73 | 74 | # Show the trajectory 75 | # show_traj(data.trajectory) 76 | 77 | # Create a floor 78 | plot_floor(data.data, big_plane=False) 79 | 80 | # initialize the camera 81 | camera = Camera(first_root=data.get_root(0), mode=mode, is_mesh=is_mesh) 82 | 83 | frameidx = get_frameidx(mode=mode, nframes=nframes, 84 | exact_frame=exact_frame, 85 | frames_to_keep=num, pre_idx=pre_idx) 86 | 87 | nframes_to_render = len(frameidx) 88 | 89 | # center the camera to the middle 90 | if mode == "sequence": 91 | camera.update(data.get_mean_root()) 92 | 93 | imported_obj_names = [] 94 | for index, frameidx in enumerate(frameidx): 95 | if mode == "sequence": 96 | if nframes_to_render == 1: 97 | frac = 1 98 | else: 99 | frac = index / (nframes_to_render-1) 100 | mat = data.get_sequence_mat(frac, ours) 101 | else: 102 | mat = data.mat 103 | camera.update(data.get_root(frameidx)) 104 | 105 | islast = index == (nframes_to_render-1) 106 | 107 | objname = data.load_in_blender(frameidx, mat) 108 | name = f"{str(index).zfill(4)}" 109 | 110 | if mode == "video": 111 | path = os.path.join(frames_folder, f"frame_{name}.png") 112 | else: 113 | path = img_path 114 | 115 | if mode == "sequence": 116 | imported_obj_names.extend(objname) 117 | elif mode == "frame": 118 | camera.update(data.get_root(frameidx)) 119 | 120 | if mode != "sequence" or islast: 121 | render_current_frame(path) 122 | delete_objs(objname) 123 | 124 | # bpy.ops.wm.save_as_mainfile(filepath="/Users/mathis/TEMOS_github/male_line_test.blend") 125 | # exit() 126 | 127 | # remove every object created 128 | delete_objs(imported_obj_names) 129 | delete_objs(["Plane", "myCurve", "Cylinder"]) 130 | 131 | if mode == "video": 132 | return frames_folder 133 | else: 134 | return img_path 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /data_loaders/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def lengths_to_mask(lengths, max_len): 4 | # max_len = max(lengths) 5 | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) 6 | return mask 7 | 8 | 9 | def collate_tensors(batch): 10 | dims = batch[0].dim() 11 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)] 12 | size = (len(batch),) + tuple(max_size) 13 | canvas = batch[0].new_zeros(size=size) 14 | for i, b in 
enumerate(batch): 15 | sub_tensor = canvas[i] 16 | for d in range(dims): 17 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) 18 | sub_tensor.add_(b) 19 | return canvas 20 | 21 | 22 | 23 | def t2m_collate(batch): 24 | notnone_batches = [b for b in batch if b is not None] 25 | notnone_batches.sort(key=lambda x: x[3], reverse=True) 26 | 27 | databatch = [torch.tensor(b[4].T).float().unsqueeze(1) for b in notnone_batches] 28 | jointbatch = [torch.tensor(b[7].T).float().unsqueeze(1) for b in notnone_batches] 29 | lenbatch = [b[5] for b in notnone_batches] 30 | 31 | databatchTensor = collate_tensors(databatch) 32 | jointbatchTensor = collate_tensors(jointbatch) 33 | lenbatchTensor = torch.as_tensor(lenbatch) 34 | maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) 35 | 36 | word_embeddings = torch.stack([torch.as_tensor(b[0]).float() for b in notnone_batches], dim=0) 37 | pos_one_hots = torch.stack([torch.as_tensor(b[1]).float() for b in notnone_batches], dim=0) 38 | 39 | new_batch = {'motion':databatchTensor, 40 | 'mask':maskbatchTensor, 41 | 'lengths': lenbatchTensor, 42 | 'joint': jointbatchTensor, 43 | 'text': [b[2] for b in notnone_batches], 44 | 'tokens': [b[6] for b in notnone_batches], 45 | 'valid': torch.as_tensor([b[8] for b in notnone_batches]), 46 | 'word_embeddings': word_embeddings, 47 | 'pos_one_hots': pos_one_hots, 48 | 'sent_len': torch.as_tensor([b[3] for b in notnone_batches]), 49 | } 50 | 51 | return new_batch 52 | 53 | def kit_collate(batch): 54 | notnone_batches = [b for b in batch if b is not None] 55 | notnone_batches.sort(key=lambda x: x[3], reverse=True) 56 | 57 | databatch = [torch.tensor(b[4].T).float().unsqueeze(1) for b in notnone_batches] 58 | lenbatch = [b[5] for b in notnone_batches] 59 | 60 | databatchTensor = collate_tensors(databatch) 61 | lenbatchTensor = torch.as_tensor(lenbatch) 62 | maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) 63 | 64 | word_embeddings = torch.stack([torch.as_tensor(b[0]).float() for b in notnone_batches], dim=0) 65 | pos_one_hots = torch.stack([torch.as_tensor(b[1]).float() for b in notnone_batches], dim=0) 66 | 67 | new_batch = {'motion':databatchTensor, 68 | 'mask':maskbatchTensor, 69 | 'lengths': lenbatchTensor, 70 | 'text': [b[2] for b in notnone_batches], 71 | 'tokens': [b[6] for b in notnone_batches], 72 | 'word_embeddings': word_embeddings, 73 | 'pos_one_hots': pos_one_hots, 74 | 'sent_len': torch.as_tensor([b[3] for b in notnone_batches]), 75 | } 76 | 77 | return new_batch 78 | 79 | def simple_collate(batch): 80 | notnone_batches = [b for b in batch if b is not None] 81 | 82 | databatch = [torch.tensor(b[1].T).float().unsqueeze(1) for b in notnone_batches] 83 | jointbatch = [torch.tensor(b[3].T).float().unsqueeze(1) for b in notnone_batches] 84 | lenbatch = [b[2] for b in notnone_batches] 85 | jointmaskbatch = [torch.tensor(b[5].T).float().unsqueeze(1) for b in notnone_batches] 86 | motionmaskbatch = [torch.tensor(b[6].T).float().unsqueeze(1) for b in notnone_batches] 87 | 88 | databatchTensor = collate_tensors(databatch) 89 | jointbatchTensor = collate_tensors(jointbatch) 90 | jointmaskbatchTensor = collate_tensors(jointmaskbatch) 91 | motionmaskbatchTensor = collate_tensors(motionmaskbatch) 92 | lenbatchTensor = torch.as_tensor(lenbatch) 93 | maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) 94 | 95 | 96 | new_batch = {'motion':databatchTensor, 97 | 
'mask':maskbatchTensor, 98 | 'lengths': lenbatchTensor, 99 | 'joint': jointbatchTensor, 100 | 'joint_mask': jointmaskbatchTensor, 101 | 'motion_mask': motionmaskbatchTensor, 102 | 'text': [b[0] for b in notnone_batches], 103 | 'valid': torch.as_tensor([b[4] for b in notnone_batches]), 104 | } 105 | 106 | return new_batch 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CrossDiff(ECCV2024) 2 | 3 | ### [Project Page](https://wonderno.github.io/CrossDiff-webpage/) | [Arxiv](https://arxiv.org/abs/2312.10993) 4 | 5 | This is the official PyTorch implementation of the paper "Realistic Human Motion Generation with Cross-Diffusion Models". Our method leverages intricate 2D motion knowledge and builds a cross-diffusion mechanism to enhance 3D motion generation. 6 | 7 | ![teaser](https://github.com/wonderNo/crossdiff/blob/master/assets/teaser.png) 8 | 9 | ## 1 Setup 10 | 11 | ### 1.1 Environment 12 | 13 | This code has been tested with Python 3.8 and PyTorch 1.11. 14 | 15 | ```shell 16 | conda create -n crossdiff python=3.8 17 | conda activate crossdiff 18 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### 1.2 Dependencies 23 | 24 | Execute the following script to download the necessary materials: 25 | 26 | ```shell 27 | mkdir data/ 28 | bash prepare/download_smpl_files.sh 29 | bash prepare/download_glove.sh 30 | bash prepare/download_t2m_evaluators.sh 31 | ``` 32 | 33 | ### 1.3 Pre-train model 34 | 35 | Run the script below to download the pre-trained model: 36 | 37 | ```shell 38 | bash prepare/download_pretrained_models.sh 39 | ``` 40 | 41 | ## 2 Train 42 | 43 | ### 2.1 Prepare data 44 | 45 | **HumanML3D** - Follow the instructions provided in the [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git). Afterward, execute the following command to obtain the corresponding 2D motion: 46 | ```shell 47 | python prepare/project.py --data_root YOUR_DATA_ROOT 48 | ``` 49 | Additionally, please set the `data_root` in the configuration file `configs/base.yaml` for subsequent training. 50 | 51 | **UCF101** - This dataset is used to train the model with real-world 2D motion. 52 | 53 | Download the original data from the [UCF101 project page](https://www.crcv.ucf.edu/data/UCF101.php#Results_on_UCF101). Then, estimate the 2D pose using the off-the-shelf model [ViTPose](https://github.com/ViTAE-Transformer/ViTPose) and process the 2D data in the same manner as HumanML3D. 54 | 55 | ### 2.2 Train the model 56 | 57 | For the first stage, execute the following command: 58 | 59 | ```shell 60 | python train.py --cfg configs/crossdiff_pre.yaml 61 | ``` 62 | The results will be stored in `./save/crossdiff_pre`. Locate the best checkpoint and set the `resume_checkpoint` in `configs/crossdiff_finetune.yaml`. 63 | 64 | For the second stage, run: 65 | ```shell 66 | python train.py --cfg configs/crossdiff_finetune.yaml 67 | ``` 68 | The final results will be saved in `./save/crossdiff_finetune` 69 | 70 | ## 3 Test 71 | 72 | After training, run the following command to test the model: 73 | ```shell 74 | python test.py --cfg configs/crossdiff_finetune.yaml 75 | ``` 76 | By default, the code will use the final model for testing. Alternatively, you can set the `test_checkpoint` in the configuration file to test a specific model. 
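For example, assuming `test.py` parses the same `key=value` overrides used with `generate.py` below, you can point the run at a specific checkpoint without editing the YAML (the path here is only a placeholder; use one of your own saved checkpoints):

```shell
python test.py --cfg configs/crossdiff_finetune.yaml test_checkpoint=PATH_TO_YOUR_CHECKPOINT.pt
```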
77 | 78 | You may also configure the following options: 79 | * `test_mm`: Test Multimodality. 80 | * `eval_part`: Choose from `all`,`upper`, or `lower` to test metrics for different body parts. 81 | 82 | ## 4 Generate 83 | 84 | To generate motion from text, use: 85 | 86 | ```shell 87 | python generate.py --cfg configs/crossdiff_finetune.yaml test_checkpoint=./data/checkpoints/pretrain.pt 88 | ``` 89 | 90 | You can edit the text in the configuration file using the `captions` parameter. The output will be saved in `./save/crossdiff_finetune/eval`. Then, execute: 91 | 92 | ```shell 93 | python fit_smpl.py -f YOUR_KEYPOINT_FILE 94 | ``` 95 | This will fit the selected `.npy` file of body keypoints, and you will obtain the mesh file `_mesh.npy`. 96 | 97 | For visualizing SMPL results, refer to [MLD-Visualization](https://github.com/ChenFengYe/motion-latent-diffusion) and [TEMOS-Rendering motions](https://github.com/Mathux/TEMOS) for Blender setup. 98 | 99 | Run the following command to visualize SMPL: 100 | 101 | ```shell 102 | blender --background --python render_blender.py -- --file=YOUR_MESH_FILE 103 | ``` 104 | 105 | ## Acknowledgments 106 | 107 | We express our gratitude to [MDM](https://github.com/GuyTevet/motion-diffusion-model), [MLD](https://github.com/ChenFengYe/motion-latent-diffusion), [T2M-GPT](https://github.com/Mael-zys/T2M-GPT), [TEMOS](https://github.com/Mathux/TEMOS). Our code is partially adapted from their work. 108 | 109 | ## Bibtex 110 | 111 | If you find this code useful in your research, please cite: 112 | 113 | ``` 114 | @inproceedings{ren2025realistic, 115 | title={Realistic Human Motion Generation with Cross-Diffusion Models}, 116 | author={Ren, Zeping and Huang, Shaoli and Li, Xiu}, 117 | booktitle={European Conference on Computer Vision}, 118 | pages={345--362}, 119 | year={2025}, 120 | organization={Springer} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /visualize/joints2smpl/fit_seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import torch 4 | import os,sys 5 | from os import walk, listdir 6 | from os.path import isfile, join 7 | import numpy as np 8 | import joblib 9 | import smplx 10 | import trimesh 11 | import h5py 12 | from tqdm import tqdm 13 | 14 | sys.path.append(os.path.join(os.path.dirname(__file__), "src")) 15 | from smplify import SMPLify3D 16 | import config 17 | 18 | # parsing argmument 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batchSize', type=int, default=1, 21 | help='input batch size') 22 | parser.add_argument('--num_smplify_iters', type=int, default=100, 23 | help='num of smplify iters') 24 | parser.add_argument('--cuda', type=bool, default=False, 25 | help='enables cuda') 26 | parser.add_argument('--gpu_ids', type=int, default=0, 27 | help='choose gpu ids') 28 | parser.add_argument('--num_joints', type=int, default=22, 29 | help='joint number') 30 | parser.add_argument('--joint_category', type=str, default="AMASS", 31 | help='use correspondence') 32 | parser.add_argument('--fix_foot', type=str, default="False", 33 | help='fix foot or not') 34 | parser.add_argument('--data_folder', type=str, default="./demo/demo_data/", 35 | help='data in the folder') 36 | parser.add_argument('--save_folder', type=str, default="./demo/demo_results/", 37 | help='results save folder') 38 | parser.add_argument('--files', type=str, default="test_motion.npy", 39 | 
help='files use') 40 | opt = parser.parse_args() 41 | print(opt) 42 | 43 | # ---load predefined something 44 | device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") 45 | print(config.SMPL_MODEL_DIR) 46 | smplmodel = smplx.create(config.SMPL_MODEL_DIR, 47 | model_type="smpl", gender="neutral", ext="pkl", 48 | batch_size=opt.batchSize).to(device) 49 | 50 | # ## --- load the mean pose as original ---- 51 | smpl_mean_file = config.SMPL_MEAN_FILE 52 | 53 | file = h5py.File(smpl_mean_file, 'r') 54 | init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).float() 55 | init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).float() 56 | cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).to(device) 57 | # 58 | pred_pose = torch.zeros(opt.batchSize, 72).to(device) 59 | pred_betas = torch.zeros(opt.batchSize, 10).to(device) 60 | pred_cam_t = torch.zeros(opt.batchSize, 3).to(device) 61 | keypoints_3d = torch.zeros(opt.batchSize, opt.num_joints, 3).to(device) 62 | 63 | # # #-------------initialize SMPLify 64 | smplify = SMPLify3D(smplxmodel=smplmodel, 65 | batch_size=opt.batchSize, 66 | joints_category=opt.joint_category, 67 | num_iters=opt.num_smplify_iters, 68 | device=device) 69 | #print("initialize SMPLify3D done!") 70 | 71 | 72 | purename = os.path.splitext(opt.files)[0] 73 | # --- load data --- 74 | data = np.load(opt.data_folder + "/" + purename + ".npy") # [nframes, njoints, 3] 75 | 76 | dir_save = os.path.join(opt.save_folder, purename) 77 | if not os.path.isdir(dir_save): 78 | os.makedirs(dir_save, exist_ok=True) 79 | 80 | # run the whole seqs 81 | num_seqs = data.shape[0] 82 | 83 | for idx in tqdm(range(num_seqs)): 84 | #print(idx) 85 | 86 | joints3d = data[idx] #*1.2 #scale problem [check first] 87 | keypoints_3d[0, :, :] = torch.Tensor(joints3d).to(device).float() 88 | 89 | if idx == 0: 90 | pred_betas[0, :] = init_mean_shape 91 | pred_pose[0, :] = init_mean_pose 92 | pred_cam_t[0, :] = cam_trans_zero 93 | else: 94 | data_param = joblib.load(dir_save + "/" + "%04d"%(idx-1) + ".pkl") 95 | pred_betas[0, :] = torch.from_numpy(data_param['beta']).unsqueeze(0).float() 96 | pred_pose[0, :] = torch.from_numpy(data_param['pose']).unsqueeze(0).float() 97 | pred_cam_t[0, :] = torch.from_numpy(data_param['cam']).unsqueeze(0).float() 98 | 99 | if opt.joint_category =="AMASS": 100 | confidence_input = torch.ones(opt.num_joints) 101 | # make sure the foot and ankle 102 | if opt.fix_foot == True: 103 | confidence_input[7] = 1.5 104 | confidence_input[8] = 1.5 105 | confidence_input[10] = 1.5 106 | confidence_input[11] = 1.5 107 | else: 108 | print("Such category not settle down!") 109 | 110 | # ----- from initial to fitting ------- 111 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 112 | new_opt_cam_t, new_opt_joint_loss = smplify( 113 | pred_pose.detach(), 114 | pred_betas.detach(), 115 | pred_cam_t.detach(), 116 | keypoints_3d, 117 | conf_3d=confidence_input.to(device), 118 | seq_ind=idx 119 | ) 120 | 121 | # # -- save the results to ply--- 122 | outputp = smplmodel(betas=new_opt_betas, global_orient=new_opt_pose[:, :3], body_pose=new_opt_pose[:, 3:], 123 | transl=new_opt_cam_t, return_verts=True) 124 | mesh_p = trimesh.Trimesh(vertices=outputp.vertices.detach().cpu().numpy().squeeze(), faces=smplmodel.faces, process=False) 125 | mesh_p.export(dir_save + "/" + "%04d"%idx + ".ply") 126 | 127 | # save the pkl 128 | param = {} 129 | param['beta'] = new_opt_betas.detach().cpu().numpy() 130 | param['pose'] = new_opt_pose.detach().cpu().numpy() 131 | 
param['cam'] = new_opt_cam_t.detach().cpu().numpy() 132 | joblib.dump(param, dir_save + "/" + "%04d"%idx + ".pkl", compress=3) 133 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | import numpy as np 3 | import torch as th 4 | 5 | from .gaussian_diffusion import GaussianDiffusion 6 | 7 | 8 | def space_timesteps(num_timesteps, section_counts): 9 | """ 10 | Create a list of timesteps to use from an original diffusion process, 11 | given the number of timesteps we want to take from equally-sized portions 12 | of the original process. 13 | 14 | For example, if there's 300 timesteps and the section counts are [10,15,20] 15 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 16 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 17 | 18 | If the stride is a string starting with "ddim", then the fixed striding 19 | from the DDIM paper is used, and only one section is allowed. 20 | 21 | :param num_timesteps: the number of diffusion steps in the original 22 | process to divide up. 23 | :param section_counts: either a list of numbers, or a string containing 24 | comma-separated numbers, indicating the step count 25 | per section. As a special case, use "ddimN" where N 26 | is a number of steps to use the striding from the 27 | DDIM paper. 28 | :return: a set of diffusion steps from the original process to use. 29 | """ 30 | if isinstance(section_counts, str): 31 | if section_counts.startswith("ddim"): 32 | desired_count = int(section_counts[len("ddim") :]) 33 | for i in range(1, num_timesteps): 34 | if len(range(0, num_timesteps, i)) == desired_count: 35 | return set(range(0, num_timesteps, i)) 36 | raise ValueError( 37 | f"cannot create exactly {num_timesteps} steps with an integer stride" 38 | ) 39 | section_counts = [int(x) for x in section_counts.split(",")] 40 | size_per = num_timesteps // len(section_counts) 41 | extra = num_timesteps % len(section_counts) 42 | start_idx = 0 43 | all_steps = [] 44 | for i, section_count in enumerate(section_counts): 45 | size = size_per + (1 if i < extra else 0) 46 | if size < section_count: 47 | raise ValueError( 48 | f"cannot divide section of {size} steps into {section_count}" 49 | ) 50 | if section_count <= 1: 51 | frac_stride = 1 52 | else: 53 | frac_stride = (size - 1) / (section_count - 1) 54 | cur_idx = 0.0 55 | taken_steps = [] 56 | for _ in range(section_count): 57 | taken_steps.append(start_idx + round(cur_idx)) 58 | cur_idx += frac_stride 59 | all_steps += taken_steps 60 | start_idx += size 61 | return set(all_steps) 62 | 63 | 64 | class SpacedDiffusion(GaussianDiffusion): 65 | """ 66 | A diffusion process which can skip steps in a base diffusion process. 67 | 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | if self.rescale_timesteps: 128 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 6 | def euclidean_distance_matrix(matrix1, matrix2): 7 | """ 8 | Params: 9 | -- matrix1: N1 x D 10 | -- matrix2: N2 x D 11 | Returns: 12 | -- dist: N1 x N2 13 | dist[i, j] == distance(matrix1[i], matrix2[j]) 14 | """ 15 | assert matrix1.shape[1] == matrix2.shape[1] 16 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 17 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 18 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 19 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 20 | return dists 21 | 22 | def calculate_top_k(mat, top_k): 23 | size = mat.shape[0] 24 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 25 | bool_mat = (mat == gt_mat) 26 | correct_vec = False 27 | top_k_list = [] 28 | for i in range(top_k): 29 | # print(correct_vec, bool_mat[:, i]) 30 | correct_vec = (correct_vec | bool_mat[:, i]) 31 | # print(correct_vec) 32 | top_k_list.append(correct_vec[:, None]) 33 | top_k_mat = 
np.concatenate(top_k_list, axis=1) 34 | return top_k_mat 35 | 36 | 37 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 38 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 39 | argmax = np.argsort(dist_mat, axis=1) 40 | top_k_mat = calculate_top_k(argmax, top_k) 41 | if sum_all: 42 | return top_k_mat.sum(axis=0) 43 | else: 44 | return top_k_mat 45 | 46 | 47 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 48 | assert len(embedding1.shape) == 2 49 | assert embedding1.shape[0] == embedding2.shape[0] 50 | assert embedding1.shape[1] == embedding2.shape[1] 51 | 52 | dist = linalg.norm(embedding1 - embedding2, axis=1) 53 | if sum_all: 54 | return dist.sum(axis=0) 55 | else: 56 | return dist 57 | 58 | 59 | 60 | def calculate_activation_statistics(activations): 61 | """ 62 | Params: 63 | -- activation: num_samples x dim_feat 64 | Returns: 65 | -- mu: dim_feat 66 | -- sigma: dim_feat x dim_feat 67 | """ 68 | mu = np.mean(activations, axis=0) 69 | cov = np.cov(activations, rowvar=False) 70 | return mu, cov 71 | 72 | 73 | def calculate_diversity(activation, diversity_times): 74 | assert len(activation.shape) == 2 75 | assert activation.shape[0] > diversity_times 76 | num_samples = activation.shape[0] 77 | 78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 81 | return dist.mean() 82 | 83 | 84 | def calculate_multimodality(activation, multimodality_times): 85 | assert len(activation.shape) == 3 86 | assert activation.shape[1] > multimodality_times 87 | num_per_sent = activation.shape[1] 88 | 89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 92 | return dist.mean() 93 | 94 | 95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 96 | """Numpy implementation of the Frechet Distance. 97 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 98 | and X_2 ~ N(mu_2, C_2) is 99 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 100 | Stable version by Dougal J. Sutherland. 101 | Params: 102 | -- mu1 : Numpy array containing the activations of a layer of the 103 | inception net (like returned by the function 'get_predictions') 104 | for generated samples. 105 | -- mu2 : The sample mean over activations, precalculated on an 106 | representative dataset set. 107 | -- sigma1: The covariance matrix over activations for generated samples. 108 | -- sigma2: The covariance matrix over activations, precalculated on an 109 | representative dataset set. 110 | Returns: 111 | -- : The Frechet Distance. 
112 | """ 113 | 114 | mu1 = np.atleast_1d(mu1) 115 | mu2 = np.atleast_1d(mu2) 116 | 117 | sigma1 = np.atleast_2d(sigma1) 118 | sigma2 = np.atleast_2d(sigma2) 119 | 120 | assert mu1.shape == mu2.shape, \ 121 | 'Training and test mean vectors have different lengths' 122 | assert sigma1.shape == sigma2.shape, \ 123 | 'Training and test covariances have different dimensions' 124 | 125 | diff = mu1 - mu2 126 | 127 | # Product might be almost singular 128 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 129 | if not np.isfinite(covmean).all(): 130 | msg = ('fid calculation produces singular product; ' 131 | 'adding %s to diagonal of cov estimates') % eps 132 | print(msg) 133 | offset = np.eye(sigma1.shape[0]) * eps 134 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 135 | 136 | # Numerical error might give slight imaginary component 137 | if np.iscomplexobj(covmean): 138 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 139 | m = np.max(np.abs(covmean.imag)) 140 | raise ValueError('Imaginary component {}'.format(m)) 141 | covmean = covmean.real 142 | 143 | tr_covmean = np.trace(covmean) 144 | 145 | return (diff.dot(diff) + np.trace(sigma1) + 146 | np.trace(sigma2) - 2 * tr_covmean) -------------------------------------------------------------------------------- /visualize/blender/materials.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | 3 | 4 | def clear_material(material): 5 | if material.node_tree: 6 | material.node_tree.links.clear() 7 | material.node_tree.nodes.clear() 8 | 9 | 10 | def colored_material_diffuse_BSDF(r, g, b, a=1, roughness=0.127451): 11 | materials = bpy.data.materials 12 | material = materials.new(name="body") 13 | material.use_nodes = True 14 | clear_material(material) 15 | nodes = material.node_tree.nodes 16 | links = material.node_tree.links 17 | output = nodes.new(type='ShaderNodeOutputMaterial') 18 | diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') 19 | diffuse.inputs["Color"].default_value = (r, g, b, a) 20 | diffuse.inputs["Roughness"].default_value = roughness 21 | links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) 22 | return material 23 | 24 | def colored_material_relection_BSDF(r, g, b, a=1, roughness=0.127451, saturation_factor=1): 25 | materials = bpy.data.materials 26 | material = materials.new(name="body") 27 | material.use_nodes = True 28 | # clear_material(material) 29 | nodes = material.node_tree.nodes 30 | links = material.node_tree.links 31 | output = nodes.new(type='ShaderNodeOutputMaterial') 32 | # diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') 33 | diffuse = nodes["Principled BSDF"] 34 | diffuse.inputs["Base Color"].default_value = (r*saturation_factor, g*saturation_factor, b*saturation_factor, a) 35 | diffuse.inputs["Roughness"].default_value = roughness 36 | links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) 37 | return material 38 | 39 | # keys: 40 | # ['Base Color', 'Subsurface', 'Subsurface Radius', 'Subsurface Color', 'Metallic', 'Specular', 'Specular Tint', 'Roughness', 'Anisotropic', 'Anisotropic Rotation', 'Sheen', 1Sheen Tint', 'Clearcoat', 'Clearcoat Roughness', 'IOR', 'Transmission', 'Transmission Roughness', 'Emission', 'Emission Strength', 'Alpha', 'Normal', 'Clearcoat Normal', 'Tangent'] 41 | DEFAULT_BSDF_SETTINGS = {"Subsurface": 0.15, 42 | "Subsurface Radius": [1.1, 0.2, 0.1], 43 | "Metallic": 0.3, 44 | "Specular": 0.5, 45 | "Specular Tint": 0.5, 46 | "Roughness": 0.75, 47 | "Anisotropic": 0.25, 48 | "Anisotropic 
Rotation": 0.25, 49 | "Sheen": 0.75, 50 | "Sheen Tint": 0.5, 51 | "Clearcoat": 0.5, 52 | "Clearcoat Roughness": 0.5, 53 | "IOR": 1.450, 54 | "Transmission": 0.1, 55 | "Transmission Roughness": 0.1, 56 | "Emission": (0, 0, 0, 1), 57 | "Emission Strength": 0.0, 58 | "Alpha": 1.0} 59 | 60 | def body_material(r, g, b, a=1, name="body", oldrender=True): 61 | if oldrender: 62 | material = colored_material_diffuse_BSDF(r, g, b, a=a) 63 | else: 64 | materials = bpy.data.materials 65 | material = materials.new(name=name) 66 | material.use_nodes = True 67 | nodes = material.node_tree.nodes 68 | diffuse = nodes["Principled BSDF"] 69 | inputs = diffuse.inputs 70 | 71 | settings = DEFAULT_BSDF_SETTINGS.copy() 72 | settings["Base Color"] = (r, g, b, a) 73 | settings["Subsurface Color"] = (r, g, b, a) 74 | settings["Subsurface"] = 0.0 75 | 76 | for setting, val in settings.items(): 77 | inputs[setting].default_value = val 78 | 79 | return material 80 | 81 | 82 | def colored_material_bsdf(name, **kwargs): 83 | materials = bpy.data.materials 84 | material = materials.new(name=name) 85 | material.use_nodes = True 86 | nodes = material.node_tree.nodes 87 | diffuse = nodes["Principled BSDF"] 88 | inputs = diffuse.inputs 89 | 90 | settings = DEFAULT_BSDF_SETTINGS.copy() 91 | for key, val in kwargs.items(): 92 | settings[key] = val 93 | 94 | for setting, val in settings.items(): 95 | inputs[setting].default_value = val 96 | 97 | return material 98 | 99 | 100 | def floor_mat(name="floor_mat", color=(0.1, 0.1, 0.1, 1), roughness=0.127451): 101 | return colored_material_diffuse_BSDF(color[0], color[1], color[2], a=color[3], roughness=roughness) 102 | 103 | 104 | def plane_mat(): 105 | materials = bpy.data.materials 106 | material = materials.new(name="plane") 107 | material.use_nodes = True 108 | clear_material(material) 109 | nodes = material.node_tree.nodes 110 | links = material.node_tree.links 111 | output = nodes.new(type='ShaderNodeOutputMaterial') 112 | diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') 113 | checker = nodes.new(type="ShaderNodeTexChecker") 114 | checker.inputs["Scale"].default_value = 1024 115 | checker.inputs["Color1"].default_value = (0.8, 0.8, 0.8, 1) 116 | checker.inputs["Color2"].default_value = (0.3, 0.3, 0.3, 1) 117 | links.new(checker.outputs["Color"], diffuse.inputs['Color']) 118 | links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) 119 | diffuse.inputs["Roughness"].default_value = 0.127451 120 | return material 121 | 122 | 123 | def plane_mat_uni(): 124 | materials = bpy.data.materials 125 | material = materials.new(name="plane_uni") 126 | material.use_nodes = True 127 | clear_material(material) 128 | nodes = material.node_tree.nodes 129 | links = material.node_tree.links 130 | output = nodes.new(type='ShaderNodeOutputMaterial') 131 | diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') 132 | diffuse.inputs["Color"].default_value = (0.8, 0.8, 0.8, 1) 133 | diffuse.inputs["Roughness"].default_value = 0.127451 134 | links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) 135 | return material 136 | -------------------------------------------------------------------------------- /data_loaders/humanml/networks/evaluator_wrapper.py: -------------------------------------------------------------------------------- 1 | from data_loaders.humanml.networks.modules import * 2 | from data_loaders.humanml.utils.word_vectorizer import POS_enumerator 3 | from os.path import join as pjoin 4 | 5 | 6 | def build_evaluators(opt): 7 | 8 | 9 | if opt['eval_part'] == 'all': 10 | eval_index = 
list(range(opt['dim_pose']-4)) 11 | else: 12 | lower_joint = np.array([1,2,4,5,7,8,10,11]) - 1 13 | lower_joint2 = np.array([0, 1,2,4,5,7,8,10,11]) 14 | lower_index1 = np.array([0,1,2,3]) 15 | lower_index2 = np.stack([4 + lower_joint * 3, 5 + lower_joint * 3, 6 + lower_joint * 3], axis=1).reshape(-1) 16 | lower_index3 = np.stack([67 + lower_joint * 6, 68 + lower_joint * 6, 69 + lower_joint * 6, 17 | 70 + lower_joint * 6, 71 + lower_joint * 6, 72 + lower_joint * 6,], axis=1).reshape(-1) 18 | lower_index4 = np.stack([193 + lower_joint2 * 3, 194 + lower_joint2 * 3, 195 + lower_joint2 * 3], axis=1).reshape(-1) 19 | lower_index5 = np.array([259,260,261,262]) 20 | lower_index = np.concatenate([lower_index1, lower_index2, lower_index3, lower_index4, lower_index5]).tolist() # 107 21 | upper_index = [i for i in range(263) if i not in lower_index] # 156 22 | lower_index = lower_index[:-4] # 103 23 | 24 | if opt['eval_part'] == 'upper': 25 | eval_index = upper_index 26 | elif opt['eval_part'] == 'lower': 27 | eval_index = lower_index 28 | else: 29 | raise NotImplementedError('Unsupported evaluation part.') 30 | 31 | movement_enc = MovementConvEncoder(len(eval_index), opt['dim_movement_enc_hidden'], opt['dim_movement_latent']) 32 | text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'], 33 | pos_size=opt['dim_pos_ohot'], 34 | hidden_size=opt['dim_text_hidden'], 35 | output_size=opt['dim_coemb_hidden'], 36 | device=opt['device']) 37 | 38 | motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'], 39 | hidden_size=opt['dim_motion_hidden'], 40 | output_size=opt['dim_coemb_hidden'], 41 | device=opt['device']) 42 | 43 | 44 | checkpoint = torch.load(pjoin(opt['checkpoints_dir'], 't2m', opt["eval_part"] + '.tar'), 45 | map_location=opt['device']) 46 | movement_enc.load_state_dict(checkpoint['movement_encoder']) 47 | text_enc.load_state_dict(checkpoint['text_encoder']) 48 | motion_enc.load_state_dict(checkpoint['motion_encoder']) 49 | # print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) 50 | return text_enc, motion_enc, movement_enc, eval_index 51 | 52 | class EvaluatorWrapper(object): 53 | 54 | def __init__(self, device, eval_part='all'): 55 | opt = { 56 | 'dataset_name': 'humanml', 57 | 'device': device, 58 | 'dim_word': 300, 59 | 'max_motion_length': 196, 60 | 'dim_pos_ohot': len(POS_enumerator), 61 | 'dim_motion_hidden': 1024, 62 | 'max_text_len': 20, 63 | 'dim_text_hidden': 512, 64 | 'dim_coemb_hidden': 512, 65 | 'dim_pose': 263, 66 | 'dim_movement_enc_hidden': 512, 67 | 'dim_movement_latent': 512, 68 | 'checkpoints_dir': './data', 69 | 'unit_length': 4, 70 | 'eval_part': eval_part 71 | } 72 | 73 | self.text_encoder, self.motion_encoder, self.movement_encoder, self.eval_index = build_evaluators(opt) 74 | self.opt = opt 75 | self.device = opt['device'] 76 | 77 | self.text_encoder.to(opt['device']) 78 | self.motion_encoder.to(opt['device']) 79 | self.movement_encoder.to(opt['device']) 80 | 81 | self.text_encoder.eval() 82 | self.motion_encoder.eval() 83 | self.movement_encoder.eval() 84 | 85 | # Please note that the results does not following the order of inputs 86 | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): 87 | with torch.no_grad(): 88 | word_embs = word_embs.detach().to(self.device).float() 89 | pos_ohot = pos_ohot.detach().to(self.device).float() 90 | motions = motions[..., self.eval_index].detach().to(self.device).float() 91 | 92 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 93 | motions = motions[align_idx] 94 | m_lens = m_lens[align_idx] 95 | 96 | '''Movement Encoding''' 97 | movements = self.movement_encoder(motions).detach() 98 | m_lens = torch.div(m_lens, self.opt['unit_length'], rounding_mode='floor') 99 | motion_embedding = self.motion_encoder(movements, m_lens) 100 | 101 | '''Text Encoding''' 102 | text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) 103 | text_embedding = text_embedding[align_idx] 104 | return text_embedding, motion_embedding 105 | 106 | 107 | # Please note that the results does not following the order of inputs 108 | def get_motion_embeddings(self, motions, m_lens): 109 | with torch.no_grad(): 110 | motions = motions[..., self.eval_index].detach().to(self.device).float() 111 | 112 | align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() 113 | motions = motions[align_idx] 114 | m_lens = m_lens[align_idx] 115 | 116 | '''Movement Encoding''' 117 | movements = self.movement_encoder(motions).detach() 118 | m_lens = m_lens // self.opt['unit_length'] 119 | motion_embedding = self.motion_encoder(movements, m_lens) 120 | return motion_embedding 121 | 122 | -------------------------------------------------------------------------------- /visualize/simplify_loc2rot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from visualize.joints2smpl.src import config 5 | import smplx 6 | import h5py 7 | from visualize.joints2smpl.src.smplify import SMPLify3D 8 | from tqdm import tqdm 9 | import utils.rotation_conversions as geometry 10 | import argparse 11 | from trimesh import Trimesh 12 | 13 | class joints2smpl: 14 | 15 | def __init__(self, num_frames, device_id, cuda=True): 16 | self.device = torch.device("cuda:" + str(device_id) if cuda else "cpu") 17 | # self.device = torch.device("cpu") 18 | self.batch_size = num_frames 19 | self.num_joints = 22 # for HumanML3D 20 | self.joint_category = "AMASS" 21 | self.num_smplify_iters = 150 22 | self.fix_foot = False 23 | smplmodel = 
smplx.create(config.SMPL_MODEL_DIR, 24 | model_type="smpl", gender="neutral", ext="pkl", 25 | batch_size=self.batch_size).to(self.device) 26 | 27 | # ## --- load the mean pose as original ---- 28 | smpl_mean_file = config.SMPL_MEAN_FILE 29 | 30 | file = h5py.File(smpl_mean_file, 'r') 31 | self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 32 | self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 33 | self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device) 34 | # 35 | 36 | # # #-------------initialize SMPLify 37 | self.smplify = SMPLify3D(smplxmodel=smplmodel, 38 | batch_size=self.batch_size, 39 | joints_category=self.joint_category, 40 | num_iters=self.num_smplify_iters, 41 | device=self.device) 42 | 43 | 44 | def npy2smpl(self, npy_path): 45 | out_path = npy_path.replace('.npy', '_rot.npy') 46 | motions = np.load(npy_path, allow_pickle=True)[None][0] 47 | # print_batch('', motions) 48 | n_samples = motions['motion'].shape[0] 49 | all_thetas = [] 50 | for sample_i in tqdm(range(n_samples)): 51 | thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3] 52 | all_thetas.append(thetas.cpu().numpy()) 53 | motions['motion'] = np.concatenate(all_thetas, axis=0) 54 | print('motions', motions['motion'].shape) 55 | 56 | print(f'Saving [{out_path}]') 57 | np.save(out_path, motions) 58 | exit() 59 | 60 | 61 | 62 | def joint2smpl(self, input_joints, init_params=None): 63 | _smplify = self.smplify # if init_params is None else self.smplify_fast 64 | pred_pose = torch.zeros(self.batch_size, 72).to(self.device) 65 | pred_betas = torch.zeros(self.batch_size, 10).to(self.device) 66 | pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device) 67 | keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device) 68 | 69 | # run the whole seqs 70 | num_seqs = input_joints.shape[0] 71 | 72 | 73 | # joints3d = input_joints[idx] # *1.2 #scale problem [check first] 74 | keypoints_3d = torch.Tensor(input_joints).to(self.device).float() 75 | 76 | # if idx == 0: 77 | if init_params is None: 78 | pred_betas = self.init_mean_shape 79 | pred_pose = self.init_mean_pose 80 | pred_cam_t = self.cam_trans_zero 81 | else: 82 | pred_betas = init_params['betas'] 83 | pred_pose = init_params['pose'] 84 | pred_cam_t = init_params['cam'] 85 | 86 | if self.joint_category == "AMASS": 87 | confidence_input = torch.ones(self.num_joints) 88 | # make sure the foot and ankle 89 | if self.fix_foot == True: 90 | confidence_input[7] = 1.5 91 | confidence_input[8] = 1.5 92 | confidence_input[10] = 1.5 93 | confidence_input[11] = 1.5 94 | else: 95 | print("Such category not settle down!") 96 | 97 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 98 | new_opt_cam_t, new_opt_joint_loss = _smplify( 99 | pred_pose.detach(), 100 | pred_betas.detach(), 101 | pred_cam_t.detach(), 102 | keypoints_3d, 103 | conf_3d=confidence_input.to(self.device), 104 | # seq_ind=idx 105 | ) 106 | 107 | # thetas = new_opt_pose.reshape(self.batch_size, 24, 3) 108 | # thetas = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(thetas)) # [bs, 24, 6] 109 | # root_loc = torch.tensor(keypoints_3d[:, 0]) # [bs, 3] 110 | # root_loc = torch.cat([root_loc, torch.zeros_like(root_loc)], dim=-1).unsqueeze(1) # [bs, 1, 6] 111 | # thetas = torch.cat([thetas, root_loc], dim=1).unsqueeze(0).permute(0, 2, 3, 1) # [1, 25, 6, 196] 112 | 113 
| new_opt_vertices = new_opt_vertices + new_opt_cam_t 114 | meshes = [] 115 | for i in range(self.batch_size): 116 | mesh = Trimesh(new_opt_vertices[i].detach().cpu(), self.smplify.smpl.faces) 117 | # mesh.export(os.path.join(self.save_path, 'test.obj')) 118 | meshes.append(mesh) 119 | 120 | return meshes, new_opt_vertices.detach().cpu().numpy() 121 | 122 | 123 | if __name__ == '__main__': 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--input_path", type=str, required=True, help='Blender file or dir with blender files') 126 | parser.add_argument("--cuda", type=bool, default=True, help='') 127 | parser.add_argument("--device", type=int, default=0, help='') 128 | params = parser.parse_args() 129 | 130 | simplify = joints2smpl(device_id=params.device, cuda=params.cuda) 131 | 132 | if os.path.isfile(params.input_path) and params.input_path.endswith('.npy'): 133 | simplify.npy2smpl(params.input_path) 134 | elif os.path.isdir(params.input_path): 135 | files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')] 136 | for f in files: 137 | simplify.npy2smpl(f) -------------------------------------------------------------------------------- /diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # import cv2 4 | from PIL import Image 5 | from data_loaders.humanml.utils import paramUtil 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | from scipy.ndimage import gaussian_filter 10 | 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 17 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 18 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 19 | 20 | MISSING_VALUE = -1 21 | 22 | def save_image(image_numpy, image_path): 23 | img_pil = Image.fromarray(image_numpy) 24 | img_pil.save(image_path) 25 | 26 | 27 | def save_logfile(log_loss, save_path): 28 | with open(save_path, 'wt') as f: 29 | for k, v in log_loss.items(): 30 | w_line = k 31 | for digit in v: 32 | w_line += ' %.3f' % digit 33 | f.write(w_line + '\n') 34 | 35 | 36 | def print_current_loss(start_time, niter_state, losses, epoch=None, sub_epoch=None, 37 | inner_iter=None, tf_ratio=None, sl_steps=None): 38 | 39 | def as_minutes(s): 40 | m = math.floor(s / 60) 41 | s -= m * 60 42 | return '%dm %ds' % (m, s) 43 | 44 | def time_since(since, percent): 45 | now = time.time() 46 | s = now - since 47 | es = s / percent 48 | rs = es - s 49 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 50 | 51 | if epoch is not None: 52 | print('epoch: %3d niter: %6d sub_epoch: %2d inner_iter: %4d' % (epoch, niter_state, sub_epoch, inner_iter), end=" ") 53 | 54 | # message = '%s niter: %d completed: %3d%%)' % (time_since(start_time, niter_state / total_niters), 55 | # niter_state, niter_state / total_niters * 100) 56 | now = time.time() 57 | message = '%s'%(as_minutes(now - start_time)) 58 | 59 | for k, v in losses.items(): 60 | message += ' %s: %.4f ' % (k, v) 61 | message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio) 62 | print(message) 63 | 64 | def print_current_loss_decomp(start_time, 
niter_state, total_niters, losses, epoch=None, inner_iter=None): 65 | 66 | def as_minutes(s): 67 | m = math.floor(s / 60) 68 | s -= m * 60 69 | return '%dm %ds' % (m, s) 70 | 71 | def time_since(since, percent): 72 | now = time.time() 73 | s = now - since 74 | es = s / percent 75 | rs = es - s 76 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 77 | 78 | print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ") 79 | # now = time.time() 80 | message = '%s niter: %07d completed: %3d%%)'%(time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100) 81 | for k, v in losses.items(): 82 | message += ' %s: %.4f ' % (k, v) 83 | print(message) 84 | 85 | 86 | def compose_gif_img_list(img_list, fp_out, duration): 87 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] 88 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, 89 | save_all=True, loop=0, duration=duration) 90 | 91 | 92 | def save_images(visuals, image_path): 93 | if not os.path.exists(image_path): 94 | os.makedirs(image_path) 95 | 96 | for i, (label, img_numpy) in enumerate(visuals.items()): 97 | img_name = '%d_%s.jpg' % (i, label) 98 | save_path = os.path.join(image_path, img_name) 99 | save_image(img_numpy, save_path) 100 | 101 | 102 | def save_images_test(visuals, image_path, from_name, to_name): 103 | if not os.path.exists(image_path): 104 | os.makedirs(image_path) 105 | 106 | for i, (label, img_numpy) in enumerate(visuals.items()): 107 | img_name = "%s_%s_%s" % (from_name, to_name, label) 108 | save_path = os.path.join(image_path, img_name) 109 | save_image(img_numpy, save_path) 110 | 111 | 112 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): 113 | # print(col, row) 114 | compose_img = compose_image(img_list, col, row, img_size) 115 | if not os.path.exists(save_dir): 116 | os.makedirs(save_dir) 117 | img_path = os.path.join(save_dir, img_name) 118 | # print(img_path) 119 | compose_img.save(img_path) 120 | 121 | 122 | def compose_image(img_list, col, row, img_size): 123 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) 124 | for y in range(0, row): 125 | for x in range(0, col): 126 | from_img = Image.fromarray(img_list[y * col + x]) 127 | # print((x * img_size[0], y*img_size[1], 128 | # (x + 1) * img_size[0], (y + 1) * img_size[1])) 129 | paste_area = (x * img_size[0], y*img_size[1], 130 | (x + 1) * img_size[0], (y + 1) * img_size[1]) 131 | to_image.paste(from_img, paste_area) 132 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img 133 | return to_image 134 | 135 | 136 | def plot_loss_curve(losses, save_path, intervals=500): 137 | plt.figure(figsize=(10, 5)) 138 | plt.title("Loss During Training") 139 | for key in losses.keys(): 140 | plt.plot(list_cut_average(losses[key], intervals), label=key) 141 | plt.xlabel("Iterations/" + str(intervals)) 142 | plt.ylabel("Loss") 143 | plt.legend() 144 | plt.savefig(save_path) 145 | plt.show() 146 | 147 | 148 | def list_cut_average(ll, intervals): 149 | if intervals == 1: 150 | return ll 151 | 152 | bins = math.ceil(len(ll) * 1.0 / intervals) 153 | ll_new = [] 154 | for i in range(bins): 155 | l_low = intervals * i 156 | l_high = l_low + intervals 157 | l_high = l_high if l_high < len(ll) else len(ll) 158 | ll_new.append(np.mean(ll[l_low:l_high])) 159 | return ll_new 160 | 161 | 162 | def motion_temporal_filter(motion, sigma=1): 163 | motion = motion.reshape(motion.shape[0], -1) 164 | # 
print(motion.shape)
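    # Smooth each flattened coordinate channel independently over time; sigma is in frames and mode="nearest" pads the sequence ends.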
 165 | for i in range(motion.shape[1]): 166 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 167 | return motion.reshape(motion.shape[0], -1, 3) 168 | 169 | -------------------------------------------------------------------------------- /diffusion/nn.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Various utilities for neural networks. 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 13 | class SiLU(nn.Module): 14 | def forward(self, x): 15 | return x * th.sigmoid(x) 16 | 17 | 18 | class GroupNorm32(nn.GroupNorm): 19 | def forward(self, x): 20 | return super().forward(x.float()).type(x.dtype) 21 | 22 | 23 | def conv_nd(dims, *args, **kwargs): 24 | """ 25 | Create a 1D, 2D, or 3D convolution module. 26 | """ 27 | if dims == 1: 28 | return nn.Conv1d(*args, **kwargs) 29 | elif dims == 2: 30 | return nn.Conv2d(*args, **kwargs) 31 | elif dims == 3: 32 | return nn.Conv3d(*args, **kwargs) 33 | raise ValueError(f"unsupported dimensions: {dims}") 34 | 35 | 36 | def linear(*args, **kwargs): 37 | """ 38 | Create a linear module. 39 | """ 40 | return nn.Linear(*args, **kwargs) 41 | 42 | 43 | def avg_pool_nd(dims, *args, **kwargs): 44 | """ 45 | Create a 1D, 2D, or 3D average pooling module. 46 | """ 47 | if dims == 1: 48 | return nn.AvgPool1d(*args, **kwargs) 49 | elif dims == 2: 50 | return nn.AvgPool2d(*args, **kwargs) 51 | elif dims == 3: 52 | return nn.AvgPool3d(*args, **kwargs) 53 | raise ValueError(f"unsupported dimensions: {dims}") 54 | 55 | 56 | def update_ema(target_params, source_params, rate=0.99): 57 | """ 58 | Update target parameters to be closer to those of source parameters using 59 | an exponential moving average. 60 | 61 | :param target_params: the target parameter sequence. 62 | :param source_params: the source parameter sequence. 63 | :param rate: the EMA rate (closer to 1 means slower). 64 | """ 65 | for targ, src in zip(target_params, source_params): 66 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 67 | 68 | 69 | def zero_module(module): 70 | """ 71 | Zero out the parameters of a module and return it. 72 | """ 73 | for p in module.parameters(): 74 | p.detach().zero_() 75 | return module 76 | 77 | 78 | def scale_module(module, scale): 79 | """ 80 | Scale the parameters of a module and return it. 81 | """ 82 | for p in module.parameters(): 83 | p.detach().mul_(scale) 84 | return module 85 | 86 | 87 | def mean_flat(tensor): 88 | """ 89 | Take the mean over all non-batch dimensions. 90 | """ 91 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 92 | 93 | def sum_flat(tensor): 94 | """ 95 | Take the sum over all non-batch dimensions. 96 | """ 97 | return tensor.sum(dim=list(range(1, len(tensor.shape)))) 98 | 99 | 100 | def normalization(channels): 101 | """ 102 | Make a standard normalization layer. 103 | 104 | :param channels: number of input channels. 105 | :return: an nn.Module for normalization. 106 | """ 107 | return GroupNorm32(32, channels) 108 | 109 | 110 | def timestep_embedding(timesteps, dim, max_period=10000): 111 | """ 112 | Create sinusoidal timestep embeddings. 113 | 114 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 115 | These may be fractional. 116 | :param dim: the dimension of the output. 117 | :param max_period: controls the minimum frequency of the embeddings. 
118 | :return: an [N x dim] Tensor of positional embeddings. 119 | """ 120 | half = dim // 2 121 | freqs = th.exp( 122 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 123 | ).to(device=timesteps.device) 124 | args = timesteps[:, None].float() * freqs[None] 125 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 126 | if dim % 2: 127 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 128 | return embedding 129 | 130 | 131 | def checkpoint(func, inputs, params, flag): 132 | """ 133 | Evaluate a function without caching intermediate activations, allowing for 134 | reduced memory at the expense of extra compute in the backward pass. 135 | :param func: the function to evaluate. 136 | :param inputs: the argument sequence to pass to `func`. 137 | :param params: a sequence of parameters `func` depends on but does not 138 | explicitly take as arguments. 139 | :param flag: if False, disable gradient checkpointing. 140 | """ 141 | if flag: 142 | args = tuple(inputs) + tuple(params) 143 | return CheckpointFunction.apply(func, len(inputs), *args) 144 | else: 145 | return func(*inputs) 146 | 147 | 148 | class CheckpointFunction(th.autograd.Function): 149 | @staticmethod 150 | @th.cuda.amp.custom_fwd 151 | def forward(ctx, run_function, length, *args): 152 | ctx.run_function = run_function 153 | ctx.input_length = length 154 | ctx.save_for_backward(*args) 155 | with th.no_grad(): 156 | output_tensors = ctx.run_function(*args[:length]) 157 | return output_tensors 158 | 159 | @staticmethod 160 | @th.cuda.amp.custom_bwd 161 | def backward(ctx, *output_grads): 162 | args = list(ctx.saved_tensors) 163 | 164 | # Filter for inputs that require grad. If none, exit early. 165 | input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] 166 | if not input_indices: 167 | return (None, None) + tuple(None for _ in args) 168 | 169 | with th.enable_grad(): 170 | for i in input_indices: 171 | if i < ctx.input_length: 172 | # Not sure why the OAI code does this little 173 | # dance. It might not be necessary. 174 | args[i] = args[i].detach().requires_grad_() 175 | args[i] = args[i].view_as(args[i]) 176 | output_tensors = ctx.run_function(*args[:ctx.input_length]) 177 | 178 | if isinstance(output_tensors, th.Tensor): 179 | output_tensors = [output_tensors] 180 | 181 | # Filter for outputs that require grad. If none, exit early. 182 | out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad] 183 | if not out_and_grads: 184 | return (None, None) + tuple(None for _ in args) 185 | 186 | # Compute gradients on the filtered tensors. 187 | computed_grads = th.autograd.grad( 188 | [o for (o, g) in out_and_grads], 189 | [args[i] for i in input_indices], 190 | [g for (o, g) in out_and_grads] 191 | ) 192 | 193 | # Reassemble the complete gradient tuple. 
194 | input_grads = [None for _ in args] 195 | for (i, g) in zip(input_indices, computed_grads): 196 | input_grads[i] = g 197 | return (None, None) + tuple(input_grads) 198 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/plot_script.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import numpy as np 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | from mpl_toolkits.mplot3d import Axes3D 7 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter 8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 9 | import mpl_toolkits.mplot3d.axes3d as p3 10 | # import cv2 11 | from textwrap import wrap 12 | 13 | def plot_2d_motion(save_path, kinematic_tree, joints, title='', figsize=(10, 10), fps=20, kp_thr=0.3): 14 | fig, ax = plt.subplots() 15 | title = '\n'.join(wrap(title, 20)) 16 | 17 | 18 | # MINS = joints.min(axis=(0,1)) 19 | # MAXS = joints.max(axis=(0,1)) 20 | # width = (MAXS - MINS).max() / 2 21 | # middle = (MAXS + MINS) / 2 22 | # MAXS = middle + width 23 | # MINS = middle - width 24 | MINS = [-1, -1] 25 | MAXS = [1, 1] 26 | 27 | def init(): 28 | ax.set_xlim(MINS[0], MAXS[0]) 29 | ax.set_ylim(MINS[1], MAXS[1]) 30 | fig.suptitle(title, fontsize=10) 31 | 32 | colors = ["#DD5A37", # 红 33 | "#D69E00", 34 | "#ef1f9b", 35 | "#DD5A37", 36 | "#D69E00", 37 | '#a73dfa', # 紫 38 | '#fca6f7', # 粉 39 | ] 40 | # colors = ['red', 'blue', 'black', 'red', 'blue', 41 | # 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', 42 | # 'darkred', 'darkred','darkred','darkred','darkred'] 43 | frame_number = joints.shape[0] 44 | def update(frame): 45 | ax.clear() 46 | init() 47 | 48 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): 49 | linewidth = 2.0 50 | if joints.shape[2] == 3: 51 | for j in range(len(chain) - 1): 52 | if joints[frame, chain[j], 2] > kp_thr and joints[frame, chain[j+1], 2] > kp_thr: 53 | ax.plot(joints[frame, chain[j:j+2], 0], joints[frame, chain[j:j+2], 1], linewidth=linewidth, color=color) 54 | else: 55 | ax.plot(joints[frame, chain, 0], joints[frame, chain, 1], linewidth=linewidth, color=color) 56 | 57 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False) 58 | 59 | ani.save(save_path, fps=fps) 60 | plt.close() 61 | 62 | def list_cut_average(ll, intervals): 63 | if intervals == 1: 64 | return ll 65 | 66 | bins = math.ceil(len(ll) * 1.0 / intervals) 67 | ll_new = [] 68 | for i in range(bins): 69 | l_low = intervals * i 70 | l_high = l_low + intervals 71 | l_high = l_high if l_high < len(ll) else len(ll) 72 | ll_new.append(np.mean(ll[l_low:l_high])) 73 | return ll_new 74 | 75 | 76 | def plot_3d_motion(save_path, kinematic_tree, joints, title='', dataset=None, figsize=(10, 10), fps=20, radius=4, 77 | vis_mode='default', gt_frames=[], view_angle=(120, -90)): 78 | matplotlib.use('Agg') 79 | 80 | if os.path.isdir(save_path): 81 | title = '' 82 | else: 83 | title = '\n'.join(wrap(title, 20)) 84 | 85 | def init(): 86 | # ax.set_xlim3d([-radius / 2, radius / 2]) 87 | # ax.set_ylim3d([0, radius]) 88 | # ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) 89 | ax.set_xlim3d([-radius / 2, radius / 2]) 90 | ax.set_ylim3d([-radius / 2, radius / 2]) 91 | ax.set_zlim3d([-radius / 2, radius / 2]) 92 | # print(title) 93 | fig.suptitle(title, fontsize=10) 94 | ax.grid(b=False) 95 | 96 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 97 | ## Plot a plane XZ 98 | verts = [ 99 | [minx, miny, minz], 100 
| [minx, miny, maxz], 101 | [maxx, miny, maxz], 102 | [maxx, miny, minz] 103 | ] 104 | xz_plane = Poly3DCollection([verts]) 105 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 106 | ax.add_collection3d(xz_plane) 107 | 108 | # return ax 109 | 110 | # (seq_len, joints_num, 3) 111 | data = joints.copy().reshape(len(joints), -1, 3) 112 | 113 | # preparation related to specific datasets 114 | data *= 1.3 # scale for visualization 115 | 116 | 117 | fig = plt.figure(figsize=figsize) 118 | # plt.tight_layout() 119 | ax = p3.Axes3D(fig) 120 | init() 121 | MINS = data.min(axis=0).min(axis=0) 122 | MAXS = data.max(axis=0).max(axis=0) 123 | colors_blue = ["#4D84AA", "#5B9965", "#61CEB9", "#34C1E2", "#80B79A"] # GT color 124 | colors_orange = ["#DD5A37", # 红 125 | "#D69E00", # 黄 126 | "#b7fa3d", # 绿 127 | "#3dfae3", # 浅蓝 128 | "#3d55fa", # 深蓝 129 | '#a73dfa', # 紫 130 | '#fca6f7', # 粉 131 | ] # Generation color 132 | colors = ["#DD5A37", # 红 133 | "#D69E00", 134 | "#ef1f9b", 135 | "#DD5A37", 136 | "#D69E00", 137 | '#a73dfa', # 紫 138 | '#fca6f7', # 粉 139 | ] 140 | if vis_mode == 'upper_body': # lower body taken fixed to input motion 141 | colors[0] = colors_blue[0] 142 | colors[1] = colors_blue[1] 143 | elif vis_mode == 'gt': 144 | colors = colors_blue 145 | 146 | frame_number = data.shape[0] 147 | # print(dataset.shape) 148 | 149 | # height_offset = MINS[1] 150 | # data[:, :, 1] -= height_offset 151 | trajec = data[:, 0, [0, 2]] 152 | 153 | # data[..., 0] -= data[:, 0:1, 0] 154 | # data[..., 1] -= data[:, 0:1, 1] 155 | # data[..., 2] -= data[:, 0:1, 2] 156 | 157 | # print(trajec.shape) 158 | 159 | def update(index): 160 | # print(index) 161 | ax.clear() 162 | init() 163 | # ax.view_init(elev=120, azim=-90) 164 | ax.view_init(*view_angle) 165 | ax.dist = 14 166 | # ax = 167 | # plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], 168 | # MAXS[2] - trajec[index, 1]) 169 | # ax.scatter(dataset[index, :22, 0], dataset[index, :22, 1], dataset[index, :22, 2], color='black', s=3) 170 | 171 | # if index > 1: 172 | # ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), 173 | # trajec[:index, 1] - trajec[index, 1], linewidth=1.0, 174 | # color='blue') 175 | # # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2]) 176 | 177 | used_colors = colors_blue if index in gt_frames else colors 178 | for i, (chain, color) in enumerate(zip(kinematic_tree, used_colors)): 179 | if i < 5: 180 | linewidth = 4.0 181 | else: 182 | linewidth = 2.0 183 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, 184 | color=color) 185 | # print(trajec[:index, 0].shape) 186 | 187 | plt.axis('off') 188 | ax.set_xticklabels([]) 189 | ax.set_yticklabels([]) 190 | ax.set_zticklabels([]) 191 | 192 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) 193 | 194 | # writer = FFMpegFileWriter(fps=fps) 195 | 196 | # ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False, init_func=init) 197 | # ani.save(save_path, writer='pillow', fps=1000 / fps) 198 | if os.path.isdir(save_path): 199 | for i in range(len(list(ani.new_frame_seq()))): 200 | # ani.save(os.path.join(save_path,f"frame_{i}.png"), writer='pillow', savefig_kwargs={'facecolor': 'white'}) 201 | ani._draw_next_frame(i, blit=False) 202 | ax.figure.savefig(os.path.join(save_path,f"frame_{i}.png")) 203 | else: 204 | ani.save(save_path, fps=fps) 205 | 206 | plt.close() 207 | 208 | if __name__ == 
'__main__': 209 | npy_file = '/apdcephfs_cq3/share_1290939/zepingren/ufc101/annot2/t2m2/v_BaseballPitch_g09_c02.npy' 210 | motion = np.load(npy_file) 211 | skeleton = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 212 | # save_dir = npy_file[:-4] 213 | # os.makedirs(save_dir,exist_ok=True) 214 | # plot_3d_motion(save_dir, skeleton,motion) 215 | # plot_3d_motion('test.mp4', skeleton,motion) 216 | plot_2d_motion('test.mp4', skeleton, motion, kp_thr=-1) 217 | # plot_3d_motion('test.mp4', skeleton,motion,view_angle=(180,-30)) -------------------------------------------------------------------------------- /model/crossdiff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import ModuleList 5 | import copy 6 | from model.mdm import PositionalEncoding, TimestepEmbedder 7 | 8 | class CrossDiff(nn.Module): 9 | def __init__(self, args): 10 | super().__init__() 11 | 12 | 13 | if args.dataset == 'kit': 14 | self.njoints_m = 251 15 | self.njoints_j = 128 16 | else: 17 | self.njoints_m = 263 18 | self.njoints_j = 134 19 | 20 | 21 | 22 | self.latent_dim = args.latent_dim # 512 23 | 24 | self.ff_size = 1024 25 | self.num_layers = args.layers 26 | self.num_layers2 = args.layers2 27 | self.num_heads = 4 28 | self.dropout = 0.1 29 | 30 | self.activation = 'gelu' 31 | self.clip_dim = 512 32 | self.action_emb = 'tensor' 33 | 34 | self.input_feats_m = self.njoints_m 35 | self.input_feats_j = self.njoints_j 36 | 37 | 38 | self.cond_mode = 'text' 39 | self.cond_mask_prob = args.cond_mask_prob 40 | 41 | # pose pipeline 42 | self.p_linear_m = nn.Linear(self.input_feats_m-4, self.latent_dim) 43 | self.p_linear_j = nn.Linear(self.input_feats_j-2, self.latent_dim) 44 | 45 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) 46 | 47 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, 48 | nhead=self.num_heads, 49 | dim_feedforward=self.ff_size, 50 | dropout=self.dropout, 51 | activation=self.activation) 52 | seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, 53 | nhead=self.num_heads, 54 | dim_feedforward=self.ff_size, 55 | dropout=self.dropout, 56 | activation=self.activation) 57 | 58 | self.seqTransEncoder_m = nn.TransformerEncoder(seqTransEncoderLayer, 59 | num_layers=self.num_layers) 60 | self.seqTransEncoder_j = nn.TransformerEncoder(seqTransEncoderLayer, 61 | num_layers=self.num_layers) 62 | 63 | self.motion_token = nn.Parameter(torch.randn((196, 1, self.latent_dim))) 64 | self.joint_token = nn.Parameter(torch.randn((196, 1, self.latent_dim))) 65 | self.decoder_m = ModuleList([copy.deepcopy(seqTransDecoderLayer) for i in range(self.num_layers2)]) 66 | self.decoder_j = ModuleList([copy.deepcopy(seqTransDecoderLayer) for i in range(self.num_layers2)]) 67 | self.bridge = ModuleList([copy.deepcopy(seqTransEncoderLayer) for i in range(self.num_layers2)]) 68 | 69 | # root pipeline 70 | r_EncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, 71 | nhead=self.num_heads, 72 | dim_feedforward=512, 73 | dropout=self.dropout, 74 | activation=self.activation) 75 | r_DecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, 76 | nhead=self.num_heads, 77 | dim_feedforward=512, 78 | dropout=self.dropout, 79 | activation=self.activation) 80 | self.r_linear_m = nn.Linear(4, self.latent_dim) 81 | self.r_encoder_m = nn.TransformerEncoder(r_EncoderLayer, 82 | 
num_layers=2) 83 | self.r_decoder_m = ModuleList([copy.deepcopy(r_DecoderLayer) for i in range(2)]) 84 | self.r_linear2_m = nn.Linear(self.latent_dim, 4) 85 | self.r_linear_j = nn.Linear(2, self.latent_dim) 86 | self.r_encoder_j = nn.TransformerEncoder(r_EncoderLayer, 87 | num_layers=2) 88 | self.r_decoder_j = ModuleList([copy.deepcopy(r_DecoderLayer) for i in range(2)]) 89 | self.r_linear2_j = nn.Linear(self.latent_dim, 2) 90 | 91 | 92 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 93 | self.embed_text = nn.Linear(self.clip_dim, self.latent_dim) 94 | 95 | self.p_linear2_m = nn.Linear(self.latent_dim, self.input_feats_m-4) 96 | self.p_linear2_j = nn.Linear(self.latent_dim, self.input_feats_j-2) 97 | 98 | 99 | 100 | 101 | def mask_cond(self, cond, force_mask=False): 102 | bs, d = cond.shape 103 | if force_mask: 104 | return torch.zeros_like(cond) 105 | elif self.training and self.cond_mask_prob > 0.: 106 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond 107 | return cond * (1. - mask) 108 | else: 109 | return cond 110 | 111 | 112 | def forward(self, x, timesteps, y, force_mask=False, return_m=True, return_j=False): 113 | """ 114 | x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper 115 | timesteps: [batch_size] (int) 116 | """ 117 | bs, njoints, nfeats, nframes = x.shape 118 | x = x[:,:,0].permute(2,0,1) 119 | 120 | emb = self.embed_timestep(timesteps) # [1, bs, d] 121 | enc_text = y 122 | emb += self.embed_text(self.mask_cond(enc_text, force_mask=force_mask)) 123 | 124 | if njoints == 263 or njoints == 251: 125 | x_root = self.r_linear_m(x[:,:,:4]) 126 | x_pose = self.p_linear_m(x[:,:,4:]) 127 | else: 128 | x_root = self.r_linear_j(x[:,:,:2]) 129 | x_pose = self.p_linear_j(x[:,:,2:]) 130 | 131 | xseq_pose = torch.cat((emb, x_pose), axis=0) # [seqlen+1, bs, d] 132 | xseq_pose = self.sequence_pos_encoder(xseq_pose) # [seqlen+1, bs, d] 133 | xseq_root = torch.cat((emb, x_root), axis=0) 134 | xseq_root = self.sequence_pos_encoder(xseq_root) 135 | 136 | if njoints == 263 or njoints == 251: 137 | xseq_pose = self.seqTransEncoder_m(xseq_pose) # , src_key_padding_mask=~maskseq) # [seqlen+1, bs, d] 138 | xseq_root = self.r_encoder_m(xseq_root) 139 | else: 140 | xseq_pose = self.seqTransEncoder_j(xseq_pose) 141 | xseq_root = self.r_encoder_j(xseq_root) 142 | 143 | middle_infos = [] 144 | for mod in self.bridge: 145 | xseq_pose = mod(xseq_pose) 146 | middle_infos.append(xseq_pose) 147 | 148 | output_m = None 149 | output_j = None 150 | if return_m: 151 | output_m = self.motion_token.expand((-1, bs, -1)) 152 | r_output_m = xseq_root[1:] 153 | for i, (info, mod) in enumerate(zip(middle_infos, self.decoder_m)): 154 | output_m = mod(output_m, info) 155 | if i >= self.num_layers2 - 2: 156 | r_output_m = self.r_decoder_m[i + 2 - self.num_layers2](r_output_m, output_m) 157 | output_m = self.p_linear2_m(output_m) # [nframes, bs, nfeats] 158 | r_output_m = self.r_linear2_m(r_output_m) 159 | output_m = torch.cat([r_output_m, output_m], dim=-1) 160 | output_m = output_m.permute(1,2,0).unsqueeze(2) 161 | 162 | if return_j: 163 | output_j = self.joint_token.expand((-1, bs, -1)) 164 | r_output_j = xseq_root[1:] 165 | for i, (info, mod) in enumerate(zip(middle_infos, self.decoder_j)): 166 | output_j = mod(output_j, info) 167 | if i >= self.num_layers2 - 2: 168 | r_output_j = self.r_decoder_j[i + 2 - self.num_layers2](r_output_j, output_j) 169 | output_j = 
self.p_linear2_j(output_j) # [nframes, bs, nfeats] 170 | r_output_j = self.r_linear2_j(r_output_j) 171 | output_j = torch.cat([r_output_j, output_j], dim=-1) 172 | output_j = output_j.permute(1,2,0).unsqueeze(2) 173 | 174 | return {'m': output_m, 'j': output_j} 175 | 176 | -------------------------------------------------------------------------------- /diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from diffusion import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
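    # Each entry of param_groups_and_shapes pairs a list of (name, param)
    # tensors with the shape of its flat master tensor: get_param_groups_and_shapes()
    # flattens scalars/vectors to (-1) and matrices to (1, -1).  Here the flat
    # master tensor is split back into per-parameter views by
    # unflatten_master_params() and copied into the model parameters in place.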
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /visualize/joints2smpl/src/prior.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import sys 22 | import os 23 | 24 | import time 25 | import pickle 26 | 27 | import numpy as np 28 | 29 | import torch 30 | import torch.nn as nn 31 | 32 | DEFAULT_DTYPE = torch.float32 33 | 34 | 35 | def create_prior(prior_type, **kwargs): 36 | if prior_type == 'gmm': 37 | prior = MaxMixturePrior(**kwargs) 38 | elif prior_type == 'l2': 39 | return L2Prior(**kwargs) 40 | elif prior_type == 'angle': 41 | return SMPLifyAnglePrior(**kwargs) 42 | elif prior_type == 'none' or prior_type is None: 43 | # Don't use any pose prior 44 | def no_prior(*args, **kwargs): 45 | return 0.0 46 | prior = no_prior 47 | else: 48 | raise ValueError('Prior {}'.format(prior_type) + ' is not implemented') 49 | return prior 50 | 51 | 52 | class SMPLifyAnglePrior(nn.Module): 53 | def __init__(self, dtype=torch.float32, **kwargs): 54 | super(SMPLifyAnglePrior, self).__init__() 55 | 56 | # Indices for the rotation angle of 57 | # 55: left elbow, 90deg bend at -np.pi/2 58 | # 58: right elbow, 90deg bend at np.pi/2 59 | # 12: left knee, 90deg bend at np.pi/2 60 | # 15: right knee, 90deg bend at np.pi/2 61 | angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64) 62 | angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long) 63 | self.register_buffer('angle_prior_idxs', angle_prior_idxs) 64 | 65 | angle_prior_signs = np.array([1, -1, -1, -1], 66 | dtype=np.float32 if dtype == torch.float32 67 | else np.float64) 68 | angle_prior_signs = torch.tensor(angle_prior_signs, 69 | dtype=dtype) 70 | self.register_buffer('angle_prior_signs', angle_prior_signs) 71 | 72 | def forward(self, pose, with_global_pose=False): 73 | ''' Returns the angle prior loss for the given pose 74 | 75 | Args: 76 | pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle 77 | representation of the rotations of the joints of the SMPL model. 78 | Kwargs: 79 | with_global_pose: Whether the pose vector also contains the global 80 | orientation of the SMPL model. If not then the indices must be 81 | corrected. 82 | Returns: 83 | A size (B) tensor containing the angle prior loss for each element 84 | in the batch.
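                The value is exp(pose[:, idxs] * signs) ** 2 for the two elbow
                and two knee rotations, so bending in the anatomically
                implausible direction is penalized exponentially while natural
                flexion contributes almost nothing.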
85 | ''' 86 | angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 87 | return torch.exp(pose[:, angle_prior_idxs] * 88 | self.angle_prior_signs).pow(2) 89 | 90 | 91 | class L2Prior(nn.Module): 92 | def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): 93 | super(L2Prior, self).__init__() 94 | 95 | def forward(self, module_input, *args): 96 | return torch.sum(module_input.pow(2)) 97 | 98 | 99 | class MaxMixturePrior(nn.Module): 100 | 101 | def __init__(self, prior_folder='prior', 102 | num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, 103 | use_merged=True, 104 | **kwargs): 105 | super(MaxMixturePrior, self).__init__() 106 | 107 | if dtype == DEFAULT_DTYPE: 108 | np_dtype = np.float32 109 | elif dtype == torch.float64: 110 | np_dtype = np.float64 111 | else: 112 | print('Unknown float type {}, exiting!'.format(dtype)) 113 | sys.exit(-1) 114 | 115 | self.num_gaussians = num_gaussians 116 | self.epsilon = epsilon 117 | self.use_merged = use_merged 118 | gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) 119 | 120 | full_gmm_fn = os.path.join(prior_folder, gmm_fn) 121 | if not os.path.exists(full_gmm_fn): 122 | print('The path to the mixture prior "{}"'.format(full_gmm_fn) + 123 | ' does not exist, exiting!') 124 | sys.exit(-1) 125 | 126 | with open(full_gmm_fn, 'rb') as f: 127 | gmm = pickle.load(f, encoding='latin1') 128 | 129 | if type(gmm) == dict: 130 | means = gmm['means'].astype(np_dtype) 131 | covs = gmm['covars'].astype(np_dtype) 132 | weights = gmm['weights'].astype(np_dtype) 133 | elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): 134 | means = gmm.means_.astype(np_dtype) 135 | covs = gmm.covars_.astype(np_dtype) 136 | weights = gmm.weights_.astype(np_dtype) 137 | else: 138 | print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) 139 | sys.exit(-1) 140 | 141 | self.register_buffer('means', torch.tensor(means, dtype=dtype)) 142 | 143 | self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) 144 | 145 | precisions = [np.linalg.inv(cov) for cov in covs] 146 | precisions = np.stack(precisions).astype(np_dtype) 147 | 148 | self.register_buffer('precisions', 149 | torch.tensor(precisions, dtype=dtype)) 150 | 151 | # The constant term: 152 | sqrdets = np.array([(np.sqrt(np.linalg.det(c))) 153 | for c in gmm['covars']]) 154 | const = (2 * np.pi)**(69 / 2.) 
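        # const is the Gaussian normalization factor (2*pi)**(D/2) for the
        # D = 69 dimensional SMPL body pose (23 joints x 3 axis-angle values).
        # Below, it is folded together with the mixture weights and the
        # (relative) sqrt-determinants into nll_weights, so that
        # merged_log_likelihood() only has to add the Mahalanobis term to
        # -log(nll_weights).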
155 | 156 | nll_weights = np.asarray(gmm['weights'] / (const * 157 | (sqrdets / sqrdets.min()))) 158 | nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) 159 | self.register_buffer('nll_weights', nll_weights) 160 | 161 | weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) 162 | self.register_buffer('weights', weights) 163 | 164 | self.register_buffer('pi_term', 165 | torch.log(torch.tensor(2 * np.pi, dtype=dtype))) 166 | 167 | cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) 168 | for cov in covs] 169 | self.register_buffer('cov_dets', 170 | torch.tensor(cov_dets, dtype=dtype)) 171 | 172 | # The dimensionality of the random variable 173 | self.random_var_dim = self.means.shape[1] 174 | 175 | def get_mean(self): 176 | ''' Returns the mean of the mixture ''' 177 | mean_pose = torch.matmul(self.weights, self.means) 178 | return mean_pose 179 | 180 | def merged_log_likelihood(self, pose, betas): 181 | diff_from_mean = pose.unsqueeze(dim=1) - self.means 182 | 183 | prec_diff_prod = torch.einsum('mij,bmj->bmi', 184 | [self.precisions, diff_from_mean]) 185 | diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) 186 | 187 | curr_loglikelihood = 0.5 * diff_prec_quadratic - \ 188 | torch.log(self.nll_weights) 189 | # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + 190 | # self.random_var_dim * self.pi_term + 191 | # diff_prec_quadratic 192 | # ) - torch.log(self.weights) 193 | 194 | min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) 195 | return min_likelihood 196 | 197 | def log_likelihood(self, pose, betas, *args, **kwargs): 198 | ''' Create graph operation for negative log-likelihood calculation 199 | ''' 200 | likelihoods = [] 201 | 202 | for idx in range(self.num_gaussians): 203 | mean = self.means[idx] 204 | prec = self.precisions[idx] 205 | cov = self.covs[idx] 206 | diff_from_mean = pose - mean 207 | 208 | curr_loglikelihood = torch.einsum('bj,ji->bi', 209 | [diff_from_mean, prec]) 210 | curr_loglikelihood = torch.einsum('bi,bi->b', 211 | [curr_loglikelihood, 212 | diff_from_mean]) 213 | cov_term = torch.log(torch.det(cov) + self.epsilon) 214 | curr_loglikelihood += 0.5 * (cov_term + 215 | self.random_var_dim * 216 | self.pi_term) 217 | likelihoods.append(curr_loglikelihood) 218 | 219 | log_likelihoods = torch.stack(likelihoods, dim=1) 220 | min_idx = torch.argmin(log_likelihoods, dim=1) 221 | weight_component = self.nll_weights[:, min_idx] 222 | weight_component = -torch.log(weight_component) 223 | 224 | return weight_component + log_likelihoods[:, min_idx] 225 | 226 | def forward(self, pose, betas): 227 | if self.use_merged: 228 | return self.merged_log_likelihood(pose, betas) 229 | else: 230 | return self.log_likelihood(pose, betas) -------------------------------------------------------------------------------- /visualize/joints2smpl/src/customloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from visualize.joints2smpl.src import config 4 | 5 | # Guassian 6 | def gmof(x, sigma): 7 | """ 8 | Geman-McClure error function 9 | """ 10 | x_squared = x ** 2 11 | sigma_squared = sigma ** 2 12 | return (sigma_squared * x_squared) / (sigma_squared + x_squared) 13 | 14 | # angle prior 15 | def angle_prior(pose): 16 | """ 17 | Angle prior that penalizes unnatural bending of the knees and elbows 18 | """ 19 | # We subtract 3 because pose does not include the global rotation of the model 20 | return torch.exp( 21 | 
pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 22 | 23 | 24 | def perspective_projection(points, rotation, translation, 25 | focal_length, camera_center): 26 | """ 27 | This function computes the perspective projection of a set of points. 28 | Input: 29 | points (bs, N, 3): 3D points 30 | rotation (bs, 3, 3): Camera rotation 31 | translation (bs, 3): Camera translation 32 | focal_length (bs,) or scalar: Focal length 33 | camera_center (bs, 2): Camera center 34 | """ 35 | batch_size = points.shape[0] 36 | K = torch.zeros([batch_size, 3, 3], device=points.device) 37 | K[:, 0, 0] = focal_length 38 | K[:, 1, 1] = focal_length 39 | K[:, 2, 2] = 1. 40 | K[:, :-1, -1] = camera_center 41 | 42 | # Transform points 43 | points = torch.einsum('bij,bkj->bki', rotation, points) 44 | points = points + translation.unsqueeze(1) 45 | 46 | # Apply perspective distortion 47 | projected_points = points / points[:, :, -1].unsqueeze(-1) 48 | 49 | # Apply camera intrinsics 50 | projected_points = torch.einsum('bij,bkj->bki', K, projected_points) 51 | 52 | return projected_points[:, :, :-1] 53 | 54 | 55 | def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, 56 | joints_2d, joints_conf, pose_prior, 57 | focal_length=5000, sigma=100, pose_prior_weight=4.78, 58 | shape_prior_weight=5, angle_prior_weight=15.2, 59 | output='sum'): 60 | """ 61 | Loss function for body fitting 62 | """ 63 | batch_size = body_pose.shape[0] 64 | rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1) 65 | 66 | projected_joints = perspective_projection(model_joints, rotation, camera_t, 67 | focal_length, camera_center) 68 | 69 | # Weighted robust reprojection error 70 | reprojection_error = gmof(projected_joints - joints_2d, sigma) 71 | reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1) 72 | 73 | # Pose prior loss 74 | pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) 75 | 76 | # Angle prior for knees and elbows 77 | angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) 78 | 79 | # Regularizer to prevent betas from taking large values 80 | shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) 81 | 82 | total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss 83 | 84 | if output == 'sum': 85 | return total_loss.sum() 86 | elif output == 'reprojection': 87 | return reprojection_loss 88 | 89 | 90 | # --- get camera fitting loss ----- 91 | def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, 92 | joints_2d, joints_conf, 93 | focal_length=5000, depth_loss_weight=100): 94 | """ 95 | Loss function for camera optimization. 
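    The camera translation is fitted so that the projected hip and shoulder
    joints match their 2D detections: the OpenPose joints are used when all
    four of their confidences are positive, otherwise the ground-truth set,
    and an extra term keeps the depth component of camera_t close to the
    initial estimate camera_t_est.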
96 | """ 97 | # Project model joints 98 | batch_size = model_joints.shape[0] 99 | rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1) 100 | projected_joints = perspective_projection(model_joints, rotation, camera_t, 101 | focal_length, camera_center) 102 | 103 | # get the indexed four 104 | op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] 105 | op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] 106 | gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] 107 | gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] 108 | 109 | reprojection_error_op = (joints_2d[:, op_joints_ind] - 110 | projected_joints[:, op_joints_ind]) ** 2 111 | reprojection_error_gt = (joints_2d[:, gt_joints_ind] - 112 | projected_joints[:, gt_joints_ind]) ** 2 113 | 114 | # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections 115 | # OpenPose joints are more reliable for this task, so we prefer to use them if possible 116 | is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float() 117 | reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2)) 118 | 119 | # Loss that penalizes deviation from depth estimate 120 | depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2 121 | 122 | total_loss = reprojection_loss + depth_loss 123 | return total_loss.sum() 124 | 125 | 126 | 127 | # #####--- body fitiing loss ----- 128 | def body_fitting_loss_3d(body_pose, preserve_pose, 129 | betas, model_joints, camera_translation, 130 | j3d, pose_prior, 131 | joints3d_conf, 132 | sigma=100, pose_prior_weight=4.78*1.5, 133 | shape_prior_weight=5.0, angle_prior_weight=15.2, 134 | joint_loss_weight=500.0, 135 | pose_preserve_weight=0.0, 136 | use_collision=False, 137 | model_vertices=None, model_faces=None, 138 | search_tree=None, pen_distance=None, filter_faces=None, 139 | collision_loss_weight=1000 140 | ): 141 | """ 142 | Loss function for body fitting 143 | """ 144 | batch_size = body_pose.shape[0] 145 | 146 | #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1) 147 | 148 | joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) 149 | 150 | joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) 151 | joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1) 152 | 153 | # Pose prior loss 154 | pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) 155 | # Angle prior for knees and elbows 156 | angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) 157 | # Regularizer to prevent betas from taking large values 158 | shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) 159 | 160 | collision_loss = 0.0 161 | # Calculate the loss due to interpenetration 162 | if use_collision: 163 | triangles = torch.index_select( 164 | model_vertices, 1, 165 | model_faces).view(batch_size, -1, 3, 3) 166 | 167 | with torch.no_grad(): 168 | collision_idxs = search_tree(triangles) 169 | 170 | # Remove unwanted collisions 171 | if filter_faces is not None: 172 | collision_idxs = filter_faces(collision_idxs) 173 | 174 | if collision_idxs.ge(0).sum().item() > 0: 175 | collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) 176 | 177 | pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 
2).sum(dim=-1) 178 | 179 | # print('joint3d_loss', joint3d_loss.shape) 180 | # print('pose_prior_loss', pose_prior_loss.shape) 181 | # print('angle_prior_loss', angle_prior_loss.shape) 182 | # print('shape_prior_loss', shape_prior_loss.shape) 183 | # print('collision_loss', collision_loss) 184 | # print('pose_preserve_loss', pose_preserve_loss.shape) 185 | 186 | total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss 187 | 188 | return total_loss.sum() 189 | 190 | 191 | # #####--- get camera fitting loss ----- 192 | def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, 193 | j3d, joints_category="orig", depth_loss_weight=100.0): 194 | """ 195 | Loss function for camera optimization. 196 | """ 197 | model_joints = model_joints + camera_t 198 | # # get the indexed four 199 | # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] 200 | # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] 201 | # 202 | # j3d_error_loss = (j3d[:, op_joints_ind] - 203 | # model_joints[:, op_joints_ind]) ** 2 204 | 205 | gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] 206 | gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] 207 | 208 | if joints_category=="orig": 209 | select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] 210 | elif joints_category=="AMASS": 211 | select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] 212 | else: 213 | print("NO SUCH JOINTS CATEGORY!") 214 | 215 | j3d_error_loss = (j3d[:, select_joints_ind] - 216 | model_joints[:, gt_joints_ind]) ** 2 217 | 218 | # Loss that penalizes deviation from depth estimate 219 | depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2 220 | 221 | total_loss = j3d_error_loss + depth_loss 222 | return total_loss.sum() 223 | -------------------------------------------------------------------------------- /prepare/project.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from matplotlib.animation import FuncAnimation 7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 8 | import mpl_toolkits.mplot3d.axes3d as p3 9 | 10 | 11 | 12 | 13 | def easy_project(joints3d, cam_t, tan_fov): 14 | joints3d_cam = joints3d + cam_t 15 | joints2d = - joints3d_cam[...,:2] / (joints3d_cam[...,2:] * tan_fov) 16 | return joints2d 17 | 18 | def project_np(joints3d, cam_t, tan_fov): 19 | return easy_project(torch.from_numpy(joints3d), cam_t, tan_fov).numpy() 20 | 21 | def plot_2d_motion(save_path, joints, figsize=(10, 10), fps=20, radius=4, kinematic_tree=None): 22 | fig, ax = plt.subplots() 23 | # MINS = joints.min(axis=(0,1)) 24 | # MAXS = joints.max(axis=(0,1)) 25 | MINS = [-1, -1] 26 | MAXS = [1, 1] 27 | 28 | colors = ['red', 'blue', 'black', 'red', 'blue', 29 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', 30 | 'darkred', 'darkred','darkred','darkred','darkred'] 31 | frame_number = joints.shape[0] 32 | def update(frame): 33 | ax.clear() 34 | ax.set_xlim(MINS[0], MAXS[0]) 35 | ax.set_ylim(MINS[1], MAXS[1]) 36 | 37 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): 38 | if i < 5: 39 | linewidth = 4.0 40 | else: 41 | linewidth = 2.0 42 | ax.plot(joints[frame, chain, 0], joints[frame, chain, 1], linewidth=linewidth, color=color) 43 | 44 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False) 45 
| 46 | ani.save(save_path, fps=fps) 47 | plt.close() 48 | 49 | 50 | def plot_3d_motion(save_path, kinematic_tree, joints, figsize=(10, 10), fps=20, radius=4): 51 | # matplotlib.use('Agg') 52 | 53 | # title_sp = title.split(' ') 54 | # if len(title_sp) > 10: 55 | # title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) 56 | def init(): 57 | ax.set_xlim3d([-radius / 2, radius / 2]) 58 | ax.set_ylim3d([0, radius]) 59 | ax.set_zlim3d([0, radius]) 60 | # print(title) 61 | # fig.suptitle(title, fontsize=20) 62 | # ax.grid(b=False) 63 | 64 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 65 | ## Plot a plane XZ 66 | verts = [ 67 | [minx, miny, minz], 68 | [minx, miny, maxz], 69 | [maxx, miny, maxz], 70 | [maxx, miny, minz] 71 | ] 72 | xz_plane = Poly3DCollection([verts]) 73 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 74 | ax.add_collection3d(xz_plane) 75 | 76 | # return ax 77 | 78 | # (seq_len, joints_num, 3) 79 | data = joints.copy().reshape(len(joints), -1, 3) 80 | fig = plt.figure(figsize=figsize) 81 | ax = p3.Axes3D(fig) 82 | init() 83 | MINS = data.min(axis=0).min(axis=0) 84 | MAXS = data.max(axis=0).max(axis=0) 85 | colors = ['red', 'blue', 'black', 'yellow','green','red', 'blue', 86 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', 87 | 'darkred', 'darkred','darkred','darkred','darkred'] 88 | frame_number = data.shape[0] 89 | # print(data.shape) 90 | 91 | # height_offset = MINS[1] 92 | # data[:, :, 1] -= height_offset 93 | trajec = data[:, 0, [0, 2]] 94 | 95 | # data[..., 0] -= data[:, 0:1, 0] 96 | # data[..., 2] -= data[:, 0:1, 2] 97 | 98 | # print(trajec.shape) 99 | 100 | def update(index): 101 | # print(index) 102 | ax.clear() 103 | # ax.collections = [] 104 | ax.view_init(elev=120, azim=-90) 105 | ax.dist = 7.5 106 | # ax = 107 | plot_xzPlane(MINS[0]-trajec[index, 0], MAXS[0]-trajec[index, 0], 0, MINS[2]-trajec[index, 1], MAXS[2]-trajec[index, 1]) 108 | # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3) 109 | 110 | if index > 1: 111 | ax.plot3D(trajec[:index, 0]-trajec[index, 0], np.zeros_like(trajec[:index, 0]), trajec[:index, 1]-trajec[index, 1], linewidth=1.0, 112 | color='blue') 113 | # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2]) 114 | 115 | 116 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): 117 | # print(color) 118 | if i < 5: 119 | linewidth = 4.0 120 | else: 121 | linewidth = 2.0 122 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, color=color) 123 | # print(trajec[:index, 0].shape) 124 | 125 | plt.axis('off') 126 | ax.set_xticklabels([]) 127 | ax.set_yticklabels([]) 128 | ax.set_zticklabels([]) 129 | 130 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False) 131 | 132 | ani.save(save_path, fps=fps) 133 | plt.close() 134 | 135 | def get_matrix_from_vec(vec): 136 | # (B,2) 137 | vec = vec / np.sqrt(np.sum(vec ** 2, axis=-1, keepdims=True)) 138 | mat = np.stack([vec[:, 0], -vec[:,1], vec[:,1], vec[:,0]], axis=1).reshape((-1,2,2)) 139 | return mat 140 | 141 | def process_file(positions, feet_thre=0.002): 142 | # (seq_len, joints_num, 2) 143 | fid_r, fid_l = [8, 11], [7, 10] 144 | # fid_r, fid_l = [14, 15], [19, 20] #KIT 145 | T = positions.shape[0] 146 | positions = positions - positions[:1,:1] 147 | 148 | global_positions = positions.copy() 149 | 150 | 151 | def foot_detect(positions, thres): 152 | velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) 153 | 154 | 
feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 155 | feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 156 | feet_l = ((feet_l_x + feet_l_y) < velfactor).astype(np.float32) 157 | 158 | feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 159 | feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 160 | feet_r = (((feet_r_x + feet_r_y) < velfactor)).astype(np.float32) 161 | return feet_l, feet_r 162 | # 163 | feet_l, feet_r = foot_detect(positions, feet_thre) 164 | 165 | 166 | 167 | rot_params = np.zeros((positions.shape[0], positions.shape[1] - 1, 2, 2)) 168 | for chain in kinematic_chain: 169 | R = np.eye(2)[None,].repeat(len(positions), axis=0) 170 | for j in range(len(chain) - 1): 171 | # (batch, 3) 172 | v = positions[:, chain[j+1]] - positions[:, chain[j]] 173 | rot_mat = get_matrix_from_vec(v) 174 | 175 | R_loc = np.einsum('bij,bjk->bik', R.transpose(0,2,1), rot_mat) 176 | 177 | rot_params[:,chain[j + 1] - 1] = R_loc 178 | R = rot_mat 179 | 180 | cont_2d_params = rot_params[:,:,0] 181 | 182 | root_v = global_positions[1:,0] - global_positions[:1,0] 183 | 184 | positions -= positions[:, 0:1] 185 | positions = positions[:,1:].reshape(T, -1) 186 | 187 | rot_data = cont_2d_params.reshape(T, -1) 188 | 189 | local_vel = global_positions[1:] - global_positions[:-1] 190 | local_vel = local_vel.reshape(len(local_vel), -1) 191 | 192 | # root_v:2, loc_position:21*2, rot_2d:21*2, local_vel:22*2, feet_contact:4--[B,134] 193 | data = np.concatenate([root_v, positions[:-1], rot_data[:-1], local_vel, feet_l, feet_r], axis=-1) 194 | 195 | return data 196 | 197 | if __name__ == '__main__': 198 | from argparse import ArgumentParser 199 | parser = ArgumentParser() 200 | 201 | parser.add_argument("--data_root", required=True, type=str) 202 | 203 | args = parser.parse_args() 204 | joint3d_dir = os.path.join(args.data_root, 'new_joints') 205 | joint2d_dir = os.path.join(args.data_root, 'new_joints2d') 206 | movie_dir = os.path.join(args.data_root, 'animations') 207 | joint2d_complicate_dir = os.path.join(args.data_root, 'new_joints2d_complicate') 208 | 209 | # kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] kit 210 | kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 211 | FOV = 10 212 | tan_fov = np.tan(np.radians(FOV/2.)) 213 | cam_t = torch.tensor([0, -1, -1 / tan_fov-10], dtype=torch.float32) 214 | # cam_t = torch.tensor([0, -1, -1 / tan_fov-10], dtype=torch.float32) * 520 215 | Rs = torch.tensor([[[1,0,0],[0,1,0],[0,0,1]], [[0,0,1],[0,1,0],[-1,0,0]], 216 | [[-1,0,0],[0,1,0],[0,0,-1]],[[0,0,-1],[0,1,0],[1,0,0]]],dtype=torch.float32) 217 | 218 | os.makedirs(joint2d_dir, exist_ok=True) 219 | os.makedirs(movie_dir, exist_ok=True) 220 | 221 | files = os.listdir(joint3d_dir) 222 | try: 223 | for i, file in enumerate(tqdm(files)): 224 | joint3d_file = os.path.join(joint3d_dir, file) 225 | 226 | 227 | joints3d = np.load(joint3d_file) 228 | joints3d = torch.from_numpy(joints3d) 229 | for j,R in enumerate(Rs): 230 | # project 231 | save_file = os.path.join(joint2d_dir, file[:-4] + f'-{j}.npy') 232 | save_file_complicate = os.path.join(joint2d_complicate_dir, file[:-4] + f'-{j}.npy') 233 | if os.path.exists(save_file_complicate): 234 | continue 235 | if len(joints3d.shape) != 3: 236 | print(f'wrong in {joint3d_file}') 237 | np.save(save_file, np.zeros((1,21,2), dtype=np.float32)) 238 | continue 239 | joint3d 
= torch.einsum('tjk,pk->tjp', joints3d, R) 240 | joints2d = easy_project(joint3d, cam_t, tan_fov).numpy() 241 | joints2d = joints2d - joints2d[:1,:1] 242 | if i < 10: 243 | if j == 0: 244 | movie3d_file = os.path.join(movie_dir, file[:-4] + '-3d.mp4') 245 | plot_3d_motion(movie3d_file, kinematic_chain, joints3d.numpy()) 246 | movie2d_file = os.path.join(movie_dir, file[:-4] + f'-2d-{j}.mp4') 247 | plot_2d_motion(movie2d_file, joints2d, kinematic_tree=kinematic_chain) 248 | np.save(save_file, joints2d) 249 | 250 | 251 | # process 252 | if joints2d.shape[1] != 22: 253 | raise NameError('not 22 joint') 254 | if len(joints2d) == 1: 255 | # data = np.zeros((1,128), dtype=np.float32) 256 | joints2d_complicate = np.zeros((1,134), dtype=np.float32) 257 | else: 258 | joints2d_complicate = process_file(joints2d) 259 | np.save(save_file_complicate, joints2d_complicate.astype(np.float32)) 260 | 261 | except Exception as e: 262 | print(file) 263 | print(e) 264 | 265 | 266 | print('all done!') 267 | --------------------------------------------------------------------------------
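The following is a minimal usage sketch for the 2D projection pipeline in prepare/project.py, added here for illustration and not part of the repository. The import path, the dummy motion, and the injected kinematic_chain are assumptions; the camera values mirror the script's __main__ block.

import numpy as np
import torch

import prepare.project as project  # assumes the repository root is on PYTHONPATH

# HumanML3D / SMPL 22-joint kinematic chain, copied from the script's __main__.
# process_file() reads it as a module-level global, so it is injected here.
project.kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10],
                           [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
                           [9, 13, 16, 18, 20]]

T = 60
joints3d = torch.randn(T, 22, 3)                  # dummy 3D motion (seq_len, 22, 3)

fov = 10.0                                        # same camera setup as the script
tan_fov = np.tan(np.radians(fov / 2.0))
cam_t = torch.tensor([0.0, -1.0, -1.0 / tan_fov - 10.0])

joints2d = project.easy_project(joints3d, cam_t, tan_fov).numpy()   # (T, 22, 2)
feats = project.process_file(joints2d)                              # (T - 1, 134)

# 134 = root velocity (2) + 21 local joint positions (42) + 21 2D rotations (42)
#     + 22 joint velocities (44) + 4 foot contacts
assert feats.shape == (T - 1, 134)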