├── src ├── tools │ ├── common │ │ ├── __init__.py │ │ ├── skeleton.py │ │ └── quaternion.py │ ├── core_utils.py │ ├── img_gif.py │ ├── calculate_ev_metrics.py │ ├── utils.py │ ├── transformations.py │ └── bookkeeper.py ├── Lindyhop │ ├── LindyHop_dataloader.py │ ├── argUtils.py │ ├── process_LindyHop.py │ ├── train_VanillaTransformer.py │ ├── visualizer.py │ ├── models │ │ ├── MotionDiffuse_body.py │ │ └── MotionDiffusion_hand.py │ ├── train_hand_diffusion.py │ └── train_body_diffusion.py └── Ninjutsu │ ├── Ninjutsu_dataloader.py │ ├── argUtils.py │ └── process_Ninjutsu.py ├── save └── save.txt ├── data └── data.txt ├── requirements.txt └── README.md /src/tools/common/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /save/save.txt: -------------------------------------------------------------------------------- 1 | The pre-trained weights are saved in this folder. 2 | -------------------------------------------------------------------------------- /data/data.txt: -------------------------------------------------------------------------------- 1 | The processed data (train and test splits) will be stored here as pkl files after you run the 'process_ LindyHop.py' and 'process_Ninjutsu.py'. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | accelerate 3 | argunparse 4 | astunparse 5 | async-timeout 6 | bvh-converter 7 | c3d 8 | decorator 9 | denoising-diffusion-pytorch 10 | dill 11 | diskcache 12 | einops 13 | ema-pytorch 14 | h5py 15 | imageio 16 | imageio 17 | imageio-ffmpeg 18 | joblib 19 | keras 20 | matplotlib 21 | numpy 22 | openai 23 | opencv-python 24 | pandas 25 | pathos 26 | pathtools 27 | prettytable 28 | protobuf 29 | psutil 30 | pytorch-lightning 31 | pytorch3d 32 | pytz 33 | pyyaml 34 | scikit-learn 35 | scipy 36 | seaborn 37 | sklearn 38 | tensorboard 39 | torchmetrics 40 | tqdm 41 | wandb 42 | werkzeug -------------------------------------------------------------------------------- /src/tools/core_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | 6 | def send_to_cuda(model): 7 | for key in model.keys(): 8 | model[key].cuda() 9 | 10 | return model 11 | 12 | 13 | class AverageMeter(object): 14 | """Computes and stores the average and current value""" 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | def load_model(path, model, optimizer=None): 31 | pass 32 | 33 | def save_model(path, model, epoch, optimizer=None): 34 | state_dict = model.state_dict() 35 | 36 | data = {'epoch': epoch, 37 | 'state_dict': state_dict} 38 | 39 | if not (optimizer is None): 40 | data['optimzer'] = optimizer.state_dict() 41 | 42 | torch.save(data, path) 43 | 44 | def makepath(desired_path, isfile = False): 45 | ''' 46 | if the path does not exist make it 47 | :param desired_path: can be path to a file or a folder name 48 | :return: 49 | ''' 50 | import os 51 | if isfile: 52 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 
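# isfile=True: the argument is a file path, so only its parent directory is created here; the else branch below handles plain directory paths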
    else:
        if not os.path.exists(desired_path): os.makedirs(desired_path)
    return desired_path
--------------------------------------------------------------------------------
/src/tools/img_gif.py:
--------------------------------------------------------------------------------
import cv2
import glob
import imageio
from PIL import Image
import numpy as np
import os

def img2gif(image_folder):
    seqs = glob.glob(image_folder + '/*.jpg')

    out_filename = image_folder.split('/')[-1]
    int_seq = [int(seqs[i].split('/')[-1].split('.')[0]) for i in range(len(seqs))]
    # int_seq = [int(seqs[i].split('/')[-1].split('.')[0].split('_')[-1]) for i in range(len(seqs))]
    index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
    all_imgs = [seqs[index[i]] for i in range(len(index))]
    gif_name = os.path.join(image_folder, out_filename + '.gif')
    with imageio.get_writer(gif_name, mode='I') as writer:
        for filename in all_imgs:
            image = imageio.imread(filename)
            writer.append_data(image)

def img2gif_compress(fp_in):
    x = 800
    y = 400
    # the compressed GIF is written next to the input frames
    gif_name = os.path.join(fp_in, fp_in.split('/')[-1] + 'compress.gif')
    q = 40  # Quality
    seqs = glob.glob(fp_in + '/*.jpg')
    int_seq = [int(seqs[i].split('/')[-1].split('.')[0]) for i in range(len(seqs))]
    index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
    all_imgs = [seqs[index[i]] for i in range(len(index))]
    img, *imgs = [Image.open(f).resize((x, y), Image.ANTIALIAS) for f in all_imgs]
    img.save(fp=gif_name, format='GIF', append_images=imgs, quality=q,
             save_all=True, loop=0, optimize=True)


def img2video(image_folder, fps, img_type='png'):
    seqs = glob.glob(image_folder + '/*.' + img_type)
    out_filename = image_folder.split('/')[-1]
    int_seq = [int(seqs[i].split('/')[-1].split('.')[0].split('_')[-1]) for i in range(len(seqs))]
    index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
    all_imgs = [seqs[index[i]] for i in range(len(index))]
    img_array = []
    video_name = os.path.join(image_folder, out_filename + '.avi')
    for filename in all_imgs:
        img = cv2.imread(filename)
        height, width, layers = img.shape
        size = (width, height)
        img_array.append(img)
    out = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
    for i in range(len(img_array)):
        out.write(img_array[i])
    out.release()
--------------------------------------------------------------------------------
/src/Lindyhop/LindyHop_dataloader.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import pickle
import pytorch3d.transforms as t3d
import random
import sys
sys.path.append('.')
sys.path.append('..')
import torch
from math import radians, cos, sin
from scipy.spatial.transform import Rotation as R
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.Lindyhop.skeleton import InhouseStudioSkeleton
from src.Lindyhop.visualizer import plot_contacts3D
from src.tools.transformations import *
from src.tools.utils import makepath
from src.Lindyhop.argUtils import argparseNloop


class LindyHopDataset(torch.utils.data.Dataset):
    def __init__(self, args, window_size=10, split='val'):
        self.root = args.data_dir
        self.scale = args.scale
        self.split = split
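        # number of consecutive frames returned for each sample by __getitem__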
self.window_size = int(window_size) 27 | with open(os.path.join(self.root, self.split+'.pkl'), 'rb') as f: 28 | self.annot_dict = pickle.load(f) 29 | self.output_keys = ['seq', 'pose_canon_1', 'pose_canon_2', 30 | 'contacts', 'dofs_1', 'dofs_2', 31 | 'rotmat_1', 'rotmat_2', 32 | 'offsets_1', 'offsets_2', 33 | ] 34 | self.skel = InhouseStudioSkeleton() 35 | 36 | 37 | def __getitem__(self, ind): 38 | index = ind % len(self.annot_dict['pose_canon_1']) 39 | annot = {} 40 | for key in self.output_keys: 41 | annot[key] = self.annot_dict[key][index] 42 | skip = 1 43 | start = np.random.randint(0, len(annot['pose_canon_1']) - self.window_size) 44 | end = start + self.window_size 45 | 46 | annot['contacts'] = annot['contacts'][start:end] # 0.rh-rh, 1: lh-lh, 2: lh-rh , 3: rh-lh) 47 | annot['pose_canon_1'] = annot['pose_canon_1'][start: end: skip] 48 | annot['pose_canon_2'] = annot['pose_canon_2'][start: end: skip] 49 | annot['dofs_1'] = np.pi * (annot['dofs_1'][start:end: skip]) / 180. 50 | annot['dofs_2'] = np.pi * (annot['dofs_2'][start:end: skip]) / 180. 51 | annot['rotmat_1'] = annot['rotmat_1'][start:end: skip] 52 | annot['rotmat_2'] = annot['rotmat_2'][start:end: skip] 53 | annot['global_root_rotation'] = np.linalg.inv(annot['rotmat_1'][:, 0]) 54 | annot['global_root_origin'] = annot['pose_canon_1'][:, 0] 55 | annot['p1_parent_rel'] = annot['pose_canon_1'][ :, 1:] - annot['pose_canon_1'][:, [self.skel.parents_full[x] for x in range(1, 69)]] 56 | annot['p2_parent_rel'] = annot['pose_canon_2'][:, 1:] - annot['pose_canon_2'][:, [self.skel.parents_full[x] for x in range(1, 69)]] 57 | return annot 58 | 59 | def __len__(self): 60 | return len(self.annot_dict['pose_canon_1']) 61 | 62 | -------------------------------------------------------------------------------- /src/Ninjutsu/Ninjutsu_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import pytorch3d.transforms as t3d 5 | import random 6 | import sys 7 | sys.path.append('.') 8 | sys.path.append('..') 9 | import torch 10 | from math import radians, cos, sin 11 | from scipy.spatial.transform import Rotation as R 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | from src.Ninjutsu.skeleton import InhouseStudioSkeleton 15 | from src.Ninjutsu.visualizer import plot_contacts3D 16 | from src.tools.transformations import * 17 | from src.tools.utils import makepath 18 | from src.Ninjutsu.argUtils import argparseNloop 19 | 20 | 21 | class NinjutsuDataset(torch.utils.data.Dataset): 22 | def __init__(self, args, window_size=10, split='val'): 23 | self.root = args.data_dir 24 | self.scale = args.scale 25 | self.split = split 26 | self.window_size = int(window_size) 27 | with open(os.path.join(self.root, self.split+'.pkl'), 'rb') as f: 28 | self.annot_dict = pickle.load(f) 29 | self.output_keys = ['seq', 'pose_canon_1', 'pose_canon_2', 30 | 'dofs_1', 'dofs_2', 31 | 'rotmat_1', 'rotmat_2', 32 | 'offsets_1', 'offsets_2', 33 | 'contacts' 34 | ] 35 | self.skel = InhouseStudioSkeleton() 36 | 37 | 38 | def __getitem__(self, ind): 39 | index = ind % len(self.annot_dict['pose_canon_1']) 40 | annot = {} 41 | for key in self.output_keys: 42 | annot[key] = self.annot_dict[key][index] 43 | skip = 1 44 | start = np.random.randint(0, len(annot['pose_canon_1']) - self.window_size) 45 | end = start + self.window_size 46 | annot['contacts'] = annot['contacts'][start: end: skip] 47 | annot['pose_canon_1'] = annot['pose_canon_1'][start: end: skip] 48 
| annot['pose_canon_2'] = annot['pose_canon_2'][start: end: skip] 49 | annot['dofs_1'] = np.pi * (annot['dofs_1'][start:end: skip]) / 180. 50 | annot['dofs_2'] = np.pi * (annot['dofs_2'][start:end: skip]) / 180. 51 | annot['rotmat_1'] = annot['rotmat_1'][start:end: skip] 52 | annot['rotmat_2'] = annot['rotmat_2'][start:end: skip] 53 | # annot['seq'] = annot['seq'][start:end: skip] 54 | annot['offsets_1'] = annot['offsets_1'] 55 | annot['offsets_2'] = annot['offsets_2'] 56 | annot['global_root_origin'] = annot['pose_canon_1'][:, 0] 57 | annot['p1_parent_rel'] = annot['pose_canon_1'][ :, 1:] - annot['pose_canon_1'][:, [self.skel.parents_full[x] for x in range(1, 69)]] 58 | annot['p2_parent_rel'] = annot['pose_canon_2'][:, 1:] - annot['pose_canon_2'][:, [self.skel.parents_full[x] for x in range(1, 69)]] 59 | 60 | return annot 61 | 62 | def __len__(self): 63 | return len(self.annot_dict['pose_canon_1']) 64 | -------------------------------------------------------------------------------- /src/tools/calculate_ev_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import sys 5 | import torch 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | from scipy import linalg 9 | 10 | def mean_l2di_(reaction, reaction_gt): 11 | x = np.mean(np.sqrt(np.sum((reaction - reaction_gt)**2, -1))) 12 | return x 13 | 14 | def mean_jitter(reaction, reaction_gt, scale=0.1): 15 | a = reaction[:, 1:] - reaction[:, :-1] 16 | b = reaction_gt[:, 1:] - reaction_gt[:, :-1] 17 | x = np.mean(np.sqrt(np.sum((a - b)**2, -1))) * scale 18 | return x 19 | 20 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 21 | def euclidean_distance_matrix(matrix1, matrix2, scale=1.0): 22 | assert matrix1.shape[1] == matrix2.shape[1] 23 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 24 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 25 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 26 | dists = np.sqrt(d1 + d2 + d3) * scale # broadcasting 27 | return dists 28 | 29 | def calculate_top_k(mat, top_k): 30 | size = mat.shape[0] 31 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 32 | bool_mat = (mat == gt_mat) 33 | correct_vec = False 34 | top_k_list = [] 35 | for i in range(top_k): 36 | # print(correct_vec, bool_mat[:, i]) 37 | correct_vec = (correct_vec | bool_mat[:, i]) 38 | # print(correct_vec) 39 | top_k_list.append(correct_vec[:, None]) 40 | top_k_mat = np.concatenate(top_k_list, axis=1) 41 | return top_k_mat 42 | 43 | 44 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): 45 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 46 | argmax = np.argsort(dist_mat, axis=1) 47 | top_k_mat = calculate_top_k(argmax, top_k) 48 | if sum_all: 49 | return top_k_mat.sum(axis=0) 50 | else: 51 | return top_k_mat 52 | 53 | 54 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 55 | assert len(embedding1.shape) == 2 56 | assert embedding1.shape[0] == embedding2.shape[0] 57 | assert embedding1.shape[1] == embedding2.shape[1] 58 | 59 | dist = linalg.norm(embedding1 - embedding2, axis=1) 60 | if sum_all: 61 | return dist.sum(axis=0) 62 | else: 63 | return dist 64 | 65 | 66 | 67 | def calculate_activation_statistics(activations): 68 | mu = np.mean(activations, axis=0) 69 | cov = np.cov(activations, rowvar=False) 70 | return mu, cov 71 | 72 | 73 | def calculate_diversity(activation, diversity_times, 
scale=1.0): 74 | assert len(activation.shape) == 2 75 | assert activation.shape[0] > diversity_times 76 | num_samples = activation.shape[0] 77 | 78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) * scale 81 | return dist.mean() 82 | 83 | 84 | def calculate_multimodality(activation, multimodality_times): 85 | assert len(activation.shape) == 3 86 | assert activation.shape[1] > multimodality_times 87 | num_per_sent = activation.shape[1] 88 | 89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 92 | return dist.mean() 93 | 94 | 95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, scale=1e+1, eps=1e-6): 96 | 97 | mu1 = np.atleast_1d(mu1) 98 | mu2 = np.atleast_1d(mu2) 99 | 100 | sigma1 = np.atleast_2d(sigma1) 101 | sigma2 = np.atleast_2d(sigma2) 102 | 103 | assert mu1.shape == mu2.shape, \ 104 | 'Training and test mean vectors have different lengths' 105 | assert sigma1.shape == sigma2.shape, \ 106 | 'Training and test covariances have different dimensions' 107 | 108 | diff = mu1 - mu2 109 | 110 | # Product might be almost singular 111 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 112 | if not np.isfinite(covmean).all(): 113 | msg = ('fid calculation produces singular product; ' 114 | 'adding %s to diagonal of cov estimates') % eps 115 | print(msg) 116 | offset = np.eye(sigma1.shape[0]) * eps 117 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 118 | 119 | # Numerical error might give slight imaginary component 120 | if np.iscomplexobj(covmean): 121 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 122 | m = np.max(np.abs(covmean.imag)) 123 | raise ValueError('Imaginary component {}'.format(m)) 124 | covmean = covmean.real 125 | 126 | tr_covmean = np.trace(covmean) 127 | 128 | return scale * ((diff.dot(diff) + np.trace(sigma1) + 129 | np.trace(sigma2) - 2 * tr_covmean)) 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReMoS: 3D Motion-Conditioned Reaction Synthesis for Two-Person Interactions 2 | Accepted at the European Conference on Computer Vision (ECCV) 2024. 
[Paper](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/05358.pdf) | [Video](https://vcai.mpi-inf.mpg.de/projects/remos/Remos_ECCV_v2_1.mp4) | [Project Page](https://vcai.mpi-inf.mpg.de/projects/remos/)

teaser image

## Pre-requisites
We have tested our code on the following setups:
* Ubuntu 20.04 LTS
* Windows 10, 11
* Python >= 3.8
* PyTorch >= 1.11
* conda >= 4.9.2 (optional but recommended)

## Getting started

Follow these commands to create a conda environment:
```
conda create -n remos python=3.8
conda activate remos
conda install -c pytorch pytorch=1.11 torchvision cudatoolkit=11.3
pip install -r requirements.txt
```
For PyTorch3D installation, refer to https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md

**Note:** If PyOpenGL installed using `requirements.txt` causes issues on Ubuntu, install PyOpenGL with:
```
apt-get update
apt-get install python3-opengl
```

## Dataset download and preprocess
Download the ReMoCap dataset from the [ReMoS website](https://vcai.mpi-inf.mpg.de/projects/remos/#dataset_section). Unzip and place the dataset under `../DATASETS/ReMoCap`.
The dataset folder should be organized as follows:
```bash
DATASETS
├── ReMoCap
│   ├── LindyHop
│   │   ├── train
│   │   │   ├── seq_3
│   │   │   │   ├── 0    'first person'
│   │   │   │   │   ├── motion.bvh
│   │   │   │   │   ├── motion_worldpose.csv
│   │   │   │   │   ├── motion_rotation.csv
│   │   │   │   │   └── motion_offsets.pkl
│   │   │   │   └── 1    'second person'
│   │   │   │       ├── motion.bvh
│   │   │   │       ├── motion_worldpose.csv
│   │   │   │       ├── motion_rotation.csv
│   │   │   │       └── motion_offsets.pkl
│   │   │   └── ...
│   │   └── test
│   │       └── ...
│   └── Ninjutsu
│       ├── train
│       │   ├── shot_001
│       │   │   ├── 0.bvh
│       │   │   ├── 0_worldpose.csv
│       │   │   ├── 0_rotations.csv
│       │   │   ├── 0_offsets.pkl
│       │   │   ├── 1.bvh
│       │   │   ├── 1_worldpose.csv
│       │   │   ├── 1_rotations.csv
│       │   │   └── 1_offsets.pkl
│       │   ├── shot_002
│       │   │   └── ...
│       │   └── ...
│       └── test
│           └── ...
```

To pre-process the two parts of the dataset for our setting, run:
```
python src/Lindyhop/process_LindyHop.py
python src/Ninjutsu/process_Ninjutsu.py
```
This will create 'train.pkl' and 'test.pkl' under the `data/` folder.

## Training and testing on the Lindy Hop motion data

To train the ReMoS model on the Lindy Hop motions in our setting, run:
```
python src/Lindyhop/train_body_diffusion.py
python src/Lindyhop/train_hand_diffusion.py
```

To test and evaluate the ReMoS model on the Lindy Hop motions, run:
```
python src/Lindyhop/test_full_diffusion.py
```
Set the 'is_eval' flag to True to get the evaluation metrics, and to False to visualize the results.

Download the pre-trained weights for the Lindy Hop motions from [here](https://vcai.mpi-inf.mpg.de/projects/remos/LindyHop_pretrained_weights.zip) and unzip them under `save/LindyHop/`.

## Training and testing on the Ninjutsu motion data

Coming soon!

## License

Copyright (c) 2024, Max Planck Institute for Informatics
All rights reserved.
120 | 121 | Permission is hereby granted, free of charge, to any person or company obtaining a copy of this dataset and associated documentation files (the "Dataset") from the copyright holders to use the Dataset for any non-commercial purpose. Redistribution and (re)selling of the Dataset, of modifications, extensions, and derivates of it, and of other dataset containing portions of the licensed Dataset, are not permitted. The Copyright holder is permitted to publically disclose and advertise the use of the software by any licensee. 122 | 123 | Packaging or distributing parts or whole of the provided software (including code and data) as is or as part of other datasets is prohibited. Commercial use of parts or whole of the provided dataset (including code and data) is strictly prohibited. 124 | 125 | THE DATASET IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE DATASET OR THE USE OR OTHER DEALINGS IN THE DATASET. 126 | 127 | 128 | -------------------------------------------------------------------------------- /src/tools/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import logging 5 | import math 6 | import json 7 | import torch.nn.functional as F 8 | from copy import copy 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | to_cpu = lambda tensor: tensor.detach().cpu().numpy() 11 | 12 | 13 | def parse_npz(npz, allow_pickle=True): 14 | npz = np.load(npz, allow_pickle=allow_pickle) 15 | npz = {k: npz[k].tolist() for k in npz.files} 16 | return DotDict(npz) 17 | 18 | def params2torch(params, dtype = torch.float32): 19 | return {k: torch.from_numpy(v).type(dtype).to(device) for k, v in params.items()} 20 | 21 | def prepare_params(params, frame_mask, rel_trans = None, dtype = np.float32): 22 | n_params = {k: v[frame_mask].astype(dtype) for k, v in params.items()} 23 | if rel_trans is not None: 24 | n_params['transl'] -= rel_trans 25 | return n_params 26 | 27 | def torch2np(item, dtype=np.float32): 28 | out = {} 29 | for k, v in item.items(): 30 | if v ==[] or v=={}: 31 | continue 32 | if isinstance(v, list): 33 | if isinstance(v[0], str): 34 | out[k] = v 35 | else: 36 | if torch.is_tensor(v[0]): 37 | v = [v[i].cpu() for i in range(len(v))] 38 | try: 39 | out[k] = np.array(np.concatenate(v), dtype=dtype) 40 | except: 41 | out[k] = np.array(np.array(v), dtype=dtype) 42 | elif isinstance(v, dict): 43 | out[k] = torch2np(v) 44 | else: 45 | if torch.is_tensor(v): 46 | v = v.cpu() 47 | out[k] = np.array(v, dtype=dtype) 48 | 49 | return out 50 | 51 | def DotDict(in_dict): 52 | out_dict = copy(in_dict) 53 | for k,v in out_dict.items(): 54 | if isinstance(v,dict): 55 | out_dict[k] = DotDict(v) 56 | return dotdict(out_dict) 57 | 58 | class dotdict(dict): 59 | """dot.notation access to dictionary attributes""" 60 | __getattr__ = dict.get 61 | __setattr__ = dict.__setitem__ 62 | __delattr__ = dict.__delitem__ 63 | 64 | def append2dict(source, data): 65 | for k in data.keys(): 66 | if k in source.keys(): 67 | if isinstance(data[k], list): 68 | source[k] += data[k] 69 | else: 70 | source[k].append(data[k]) 71 | 72 | 73 | def append2list(source, data): 
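    # Flattens a dict of per-sequence values into per-element dicts appended to `source`:
    # the length of the first key's value sets the element count; lists of arrays are indexed
    # per element, while strings, lists of strings, and un-listed arrays are copied unchanged into each dict.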
74 | # d = {} 75 | for k in data.keys(): 76 | leng = len(data[k]) 77 | break 78 | for id in range(leng): 79 | d = {} 80 | for k in data.keys(): 81 | if isinstance(data[k], list): 82 | if isinstance(data[k][0], str): 83 | d[k] = data[k] 84 | elif isinstance(data[k][0], np.ndarray): 85 | d[k] = data[k][id] 86 | 87 | elif isinstance(data[k], str): 88 | d[k] = data[k] 89 | elif isinstance(data[k], np.ndarray): 90 | d[k] = data[k] 91 | source.append(d) 92 | 93 | # source[k] += data[k].astype(np.float32) 94 | 95 | # source[k].append(data[k].astype(np.float32)) 96 | 97 | def np2torch(item, dtype=torch.float32): 98 | out = {} 99 | for k, v in item.items(): 100 | if v ==[] : 101 | continue 102 | if isinstance(v, str): 103 | out[k] = v 104 | elif isinstance(v, list): 105 | # if isinstance(v[0], str): 106 | # out[k] = v 107 | try: 108 | out[k] = torch.from_numpy(np.concatenate(v)).to(dtype) 109 | except: 110 | out[k] = v # torch.from_numpy(np.array(v)) 111 | elif isinstance(v, dict): 112 | out[k] = np2torch(v) 113 | else: 114 | out[k] = torch.from_numpy(v).to(dtype) 115 | return out 116 | 117 | def to_tensor(array, dtype=torch.float32): 118 | if not torch.is_tensor(array): 119 | array = torch.tensor(array) 120 | return array.to(dtype).to(device) 121 | 122 | 123 | def to_np(array, dtype=np.float32): 124 | if 'scipy.sparse' in str(type(array)): 125 | array = np.array(array.todencse(), dtype=dtype) 126 | elif torch.is_tensor(array): 127 | array = array.detach().cpu().numpy() 128 | return array 129 | 130 | def makepath(desired_path, isfile = False): 131 | ''' 132 | if the path does not exist make it 133 | :param desired_path: can be path to a file or a folder name 134 | :return: 135 | ''' 136 | import os 137 | if isfile: 138 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 139 | else: 140 | if not os.path.exists(desired_path): os.makedirs(desired_path) 141 | return desired_path 142 | 143 | def lr_decay_step(optimizer, epo, lr, gamma): 144 | if epo % 3 == 0: 145 | lr = lr * gamma 146 | for param_group in optimizer.param_groups: 147 | param_group['lr'] = lr 148 | return lr 149 | 150 | def lr_decay_mine(optimizer, lr_now, gamma): 151 | lr = lr_now * gamma 152 | for param_group in optimizer.param_groups: 153 | param_group['lr'] = lr 154 | return lr 155 | 156 | def get_dct_matrix(N): 157 | dct_m = np.eye(N) 158 | for k in np.arange(N): 159 | for i in np.arange(N): 160 | w = np.sqrt(2 / N) 161 | if k == 0: 162 | w = np.sqrt(1 / N) 163 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N) 164 | idct_m = np.linalg.inv(dct_m) 165 | return dct_m, idct_m 166 | 167 | 168 | def save_csv_log(opt, head, value, is_create=False, file_name='train_log'): 169 | if len(value.shape) < 2: 170 | value = np.expand_dims(value, axis=0) 171 | df = pd.DataFrame(value) 172 | file_path = opt.ckpt + '/{}.csv'.format(file_name) 173 | if not os.path.exists(file_path) or is_create: 174 | df.to_csv(file_path, header=head, index=False) 175 | else: 176 | with open(file_path, 'a') as f: 177 | df.to_csv(f, header=False, index=False) 178 | 179 | 180 | def save_ckpt(state, epo, opt=None): 181 | file_path = os.path.join(opt.ckpt, 'ckpt_last.pth.tar') 182 | torch.save(state, file_path) 183 | # if epo ==24: # % 4 == 0 or epo>22 or epo<5: 184 | if epo % 5 == 0: 185 | file_path = os.path.join(opt.ckpt, 'ckpt_epo'+str(epo)+'.pth.tar') 186 | torch.save(state, file_path) 187 | 188 | 189 | def save_options(opt): 190 | with open('option.json', 'w') as f: 191 | f.write(json.dumps(vars(opt), 
sort_keys=False, indent=4)) 192 | 193 | 194 | -------------------------------------------------------------------------------- /src/tools/transformations.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2022 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 5 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 6 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 7 | # 8 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 9 | # on this computer program. You can only use this computer program if you have closed a license agreement 10 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 11 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 12 | # Contact: ps-license@tuebingen.mpg.de 13 | # 14 | 15 | import sys 16 | sys.path.append('.') 17 | sys.path.append('..') 18 | import numpy as np 19 | import torch 20 | import logging 21 | from copy import copy 22 | from scipy.spatial.transform import Rotation 23 | import torch.nn.functional as F 24 | # import pytorch3d.transforms as t3d 25 | 26 | 27 | LOGGER_DEFAULT_FORMAT = ('{time:YYYY-MM-DD HH:mm:ss.SSS} |' 28 | ' {level: <8} |' 29 | ' {name}:{function}:' 30 | '{line} - {message}') 31 | 32 | 33 | 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | to_cpu = lambda tensor: tensor.detach().cpu().numpy() 36 | 37 | def to_tensor(array, dtype=torch.float32): 38 | if not torch.is_tensor(array): 39 | array = torch.tensor(array) 40 | return array.to(dtype).to(device) 41 | 42 | 43 | def to_np(array, dtype=np.float32): 44 | if 'scipy.sparse' in str(type(array)): 45 | array = np.array(array.todencse(), dtype=dtype) 46 | elif torch.is_tensor(array): 47 | array = array.detach().cpu().numpy() 48 | return array 49 | 50 | def loc2vel(loc,fps): 51 | B = loc.shape[0] 52 | idxs = [0] + list(range(B-1)) 53 | vel = (loc - loc[idxs])/(1/float(fps)) 54 | return vel 55 | 56 | def vel2acc(vel,fps): 57 | B = vel.shape[0] 58 | idxs = [0] + list(range(B - 1)) 59 | acc = (vel - vel[idxs]) / (1 / float(fps)) 60 | return acc 61 | 62 | def loc2acc(loc,fps): 63 | vel = loc2vel(loc,fps) 64 | acc = vel2acc(vel,fps) 65 | return acc, vel 66 | 67 | 68 | def d62rotmat(pose): 69 | pose = torch.tensor(pose) 70 | reshaped_input = pose.reshape(-1, 6) 71 | return t3d.rotation_6d_to_matrix(reshaped_input) 72 | 73 | def rotmat2d6(pose): 74 | pose = torch.tensor(pose) 75 | return np.array(t3d.matrix_to_rotation_6d(pose)) 76 | 77 | def rotmat2d6_tensor(pose): 78 | pose = torch.tensor(pose) 79 | return torch.tensor(t3d.matrix_to_rotation_6d(pose)) 80 | 81 | def aa2rotmat(pose): 82 | pose = to_tensor(pose) 83 | return t3d.axis_angle_to_matrix(pose) 84 | 85 | def rotmat2aa(pose): 86 | pose = to_tensor(pose) 87 | quat = t3d.matrix_to_quaternion(pose) 88 | return t3d.quaternion_to_axis_angle(quat) 89 | # reshaped_input = pose.reshape(-1, 3, 3) 90 | # quat = t3d.matrix_to_quaternion(reshaped_input) 91 | 92 | def d62aa(pose): 93 | pose = to_tensor(pose) 94 | return rotmat2aa(d62rotmat(pose)) 95 | 96 | def aa2d6(pose): 97 | pose = to_tensor(pose) 98 | return rotmat2d6(aa2rotmat(pose)) 99 | 100 | def euler(rots, order='xyz', units='deg'): 101 | 102 | rots = np.asarray(rots) 103 | single_val = False if len(rots.shape)>1 else True 104 | rots = 
rots.reshape(-1,3) 105 | rotmats = [] 106 | 107 | for xyz in rots: 108 | if units == 'deg': 109 | xyz = np.radians(xyz) 110 | r = np.eye(3) 111 | for theta, axis in zip(xyz,order): 112 | c = np.cos(theta) 113 | s = np.sin(theta) 114 | if axis=='x': 115 | r = np.dot(np.array([[1,0,0],[0,c,-s],[0,s,c]]), r) 116 | if axis=='y': 117 | r = np.dot(np.array([[c,0,s],[0,1,0],[-s,0,c]]), r) 118 | if axis=='z': 119 | r = np.dot(np.array([[c,-s,0],[s,c,0],[0,0,1]]), r) 120 | rotmats.append(r) 121 | rotmats = np.stack(rotmats).astype(np.float32) 122 | if single_val: 123 | return rotmats[0] 124 | else: 125 | return rotmats 126 | 127 | def batch_euler_to_rotmat(bxyz, order='xyz', units='deg'): 128 | br = [] 129 | for frame in range(bxyz.shape[0]): 130 | # rotmat = euler(bxyz[frame], order, units) 131 | r1 = Rotation.from_euler('xyz', np.array(bxyz[frame]), degrees=True) 132 | rotmat = r1.as_matrix() 133 | br.append(rotmat) 134 | return np.stack(br).astype(np.float32) 135 | 136 | def batch_rotmat_to_euler(rotmat, order='ZYX'): 137 | 138 | # Convert to Euler angles and permute last dimension from ZYX to XYZ to match data order 139 | eu = t3d.matrix_to_euler_angles(rotmat, order)[..., [2, 1, 0]] 140 | return eu 141 | 142 | def batch_euler_to_6d(bxyz, order='xyz', units='deg'): 143 | br = [] 144 | for frame in range(bxyz.shape[0]): 145 | # rotmat = euler(bxyz[frame], order, units) 146 | r1 = Rotation.from_euler('xyz', np.array(bxyz[frame]), degrees=True) 147 | rotmat = r1.as_matrix() 148 | d6 = rotmat2d6(rotmat) 149 | br.append(d6) 150 | return np.stack(br).astype(np.float32) 151 | 152 | def batch_6d_to_euler(bxyz, order='XYZ'): 153 | br = [] 154 | 155 | for batch in range(bxyz.shape[0]): 156 | br_ = [] 157 | for frame in range(bxyz.shape[1]): 158 | # rotmat = t3d.rotation_6d_to_matrix(bxyz[batch, frame]) 159 | rotmat = d62rotmat(bxyz[batch, frame]) 160 | r = Rotation.from_matrix(np.array(rotmat)) 161 | eu = r.as_euler("xyz", degrees=True) 162 | br_.append(np.array(eu)) 163 | br.append(np.stack(br_).astype(np.float32)) 164 | return np.stack(br).astype(np.float32) 165 | 166 | 167 | def batch_6d_to_euler_tensor(bxyz, order='ZYX'): 168 | rotmat = t3d.rotation_6d_to_matrix(bxyz) 169 | # Convert to Euler angles and permute last dimension from ZYX to XYZ to match data order 170 | eu = t3d.matrix_to_euler_angles(rotmat, order)[..., [2, 1, 0]] 171 | return eu 172 | 173 | 174 | def rotate(points,R): 175 | shape = list(points.shape) 176 | points = to_tensor(points) 177 | R = to_tensor(R) 178 | if len(shape)>3: 179 | points = points.squeeze() 180 | if len(shape)<3: 181 | points = points.unsqueeze(dim=1) 182 | if R.shape[0] > shape[0]: 183 | shape[0] = R.shape[0] 184 | r_points = torch.matmul(points, R.transpose(1,2)) 185 | return r_points.reshape(shape) 186 | 187 | def rotmul(rotmat,R): 188 | if rotmat.ndim>3: 189 | rotmat = to_tensor(rotmat).squeeze() 190 | if R.ndim>3: 191 | R = to_tensor(R).squeeze() 192 | rot = torch.matmul(rotmat, R) 193 | return rot 194 | 195 | 196 | smplx_parents =[-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 197 | 16, 17, 18, 19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 198 | 35, 20, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 199 | 53] 200 | def smplx_loc2glob(local_pose): 201 | 202 | bs = local_pose.shape[0] 203 | local_pose = local_pose.view(bs, -1, 3, 3) 204 | global_pose = local_pose.clone() 205 | 206 | for i in range(1,len(smplx_parents)): 207 | global_pose[:,i] = torch.matmul(global_pose[:, smplx_parents[i]], global_pose[:, i].clone()) 208 
| 209 | return global_pose.reshape(bs,-1,3,3) 210 | 211 | def rot2eul(R): 212 | beta = -np.arcsin(R[2,0]) 213 | alpha = np.arctan2(R[2,1]/np.cos(beta),R[2,2]/np.cos(beta)) 214 | gamma = np.arctan2(R[1,0]/np.cos(beta),R[0,0]/np.cos(beta)) 215 | return np.array((alpha, beta, gamma)) 216 | 217 | def eul2rot(theta) : 218 | 219 | R = np.array([[np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]), np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])], 220 | [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]), np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])], 221 | [-np.sin(theta[1]), np.sin(theta[0])*np.cos(theta[1]), np.cos(theta[0])*np.cos(theta[1])]]) 222 | 223 | return R 224 | 225 | if __name__ == "__main__": 226 | euler_angles = np.array([0.3, -0.5, 0.7], dtype=np.float32) 227 | euler2matrix = t3d.euler_angles_to_matrix(torch.from_numpy(euler_angles), 'XYZ') 228 | matrix2euler = t3d.matrix_to_euler_angles(euler2matrix, 'XYZ') 229 | w = Rotation.from_euler('xyz', euler_angles, degrees=False) 230 | rotmat = w.as_matrix() 231 | r = Rotation.from_matrix(np.array(euler2matrix)) 232 | eu = r.as_euler("xyz", degrees=False) 233 | angrot = eul2rot(euler_angles) 234 | print() -------------------------------------------------------------------------------- /src/tools/common/skeleton.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from src.tools.common.quaternion import * 5 | import scipy.ndimage.filters as filters 6 | 7 | class Skeleton(object): 8 | def __init__(self, offset, kinematic_tree, device): 9 | self.device = device 10 | self._raw_offset_np = offset.numpy() 11 | self._raw_offset = offset.clone().detach().to(device).float() 12 | self._kinematic_tree = kinematic_tree 13 | self._offset = None 14 | self._parents = [0] * len(self._raw_offset) 15 | self._parents[0] = -1 16 | for chain in self._kinematic_tree: 17 | for j in range(1, len(chain)): 18 | self._parents[chain[j]] = chain[j-1] 19 | 20 | def njoints(self): 21 | return len(self._raw_offset) 22 | 23 | def offset(self): 24 | return self._offset 25 | 26 | def set_offset(self, offsets): 27 | self._offset = offsets.clone().detach().to(self.device).float() 28 | 29 | def kinematic_tree(self): 30 | return self._kinematic_tree 31 | 32 | def parents(self): 33 | return self._parents 34 | 35 | # joints (batch_size, joints_num, 3) 36 | def get_offsets_joints_batch(self, joints): 37 | assert len(joints.shape) == 3 38 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 39 | for i in range(1, self._raw_offset.shape[0]): 40 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 41 | 42 | self._offset = _offsets.detach() 43 | return _offsets 44 | 45 | # joints (joints_num, 3) 46 | def get_offsets_joints(self, joints): 47 | assert len(joints.shape) == 2 48 | _offsets = self._raw_offset.clone() 49 | for i in range(1, self._raw_offset.shape[0]): 50 | # print(joints.shape) 51 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 52 | 53 | self._offset = _offsets.detach() 54 | return _offsets 55 | 56 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 57 | # joints (batch_size, joints_num, 3) 58 
| def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 59 | assert len(face_joint_idx) == 4 60 | '''Get Forward Direction''' 61 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 62 | across1 = joints[:, r_hip] - joints[:, l_hip] 63 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 64 | across = across1 + across2 65 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 66 | # print(across1.shape, across2.shape) 67 | 68 | # forward (batch_size, 3) 69 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 70 | if smooth_forward: 71 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 72 | # forward (batch_size, 3) 73 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 74 | 75 | '''Get Root Rotation''' 76 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 77 | root_quat = qbetween_np(forward, target) 78 | 79 | '''Inverse Kinematics''' 80 | # quat_params (batch_size, joints_num, 4) 81 | # print(joints.shape[:-1]) 82 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 83 | # print(quat_params.shape) 84 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 85 | quat_params[:, 0] = root_quat 86 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 87 | for chain in self._kinematic_tree: 88 | R = root_quat 89 | for j in range(len(chain) - 1): 90 | # (batch, 3) 91 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 92 | # print(u.shape) 93 | # (batch, 3) 94 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 95 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 96 | # print(u.shape, v.shape) 97 | rot_u_v = qbetween_np(u, v) 98 | 99 | R_loc = qmul_np(qinv_np(R), rot_u_v) 100 | 101 | quat_params[:,chain[j + 1], :] = R_loc 102 | R = qmul_np(R, R_loc) 103 | 104 | return quat_params 105 | 106 | # Be sure root joint is at the beginning of kinematic chains 107 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 108 | # quat_params (batch_size, joints_num, 4) 109 | # joints (batch_size, joints_num, 3) 110 | # root_pos (batch_size, 3) 111 | if skel_joints is not None: 112 | offsets = self.get_offsets_joints_batch(skel_joints) 113 | if len(self._offset.shape) == 2: 114 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 115 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 116 | joints[:, 0] = root_pos 117 | for chain in self._kinematic_tree: 118 | if do_root_R: 119 | R = quat_params[:, 0] 120 | else: 121 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 122 | for i in range(1, len(chain)): 123 | R = qmul(R, quat_params[:, chain[i]]) 124 | offset_vec = offsets[:, chain[i]] 125 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 126 | return joints 127 | 128 | # Be sure root joint is at the beginning of kinematic chains 129 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 130 | # quat_params (batch_size, joints_num, 4) 131 | # joints (batch_size, joints_num, 3) 132 | # root_pos (batch_size, 3) 133 | if skel_joints is not None: 134 | skel_joints = torch.from_numpy(skel_joints) 135 | offsets = self.get_offsets_joints_batch(skel_joints) 136 | if len(self._offset.shape) == 2: 137 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 138 | offsets = offsets.numpy() 139 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 140 | joints[:, 0] = root_pos 141 | for chain in self._kinematic_tree: 142 | if do_root_R: 143 | R 
= quat_params[:, 0] 144 | else: 145 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 146 | for i in range(1, len(chain)): 147 | R = qmul_np(R, quat_params[:, chain[i]]) 148 | offset_vec = offsets[:, chain[i]] 149 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 150 | return joints 151 | 152 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 153 | # cont6d_params (batch_size, joints_num, 6) 154 | # joints (batch_size, joints_num, 3) 155 | # root_pos (batch_size, 3) 156 | if skel_joints is not None: 157 | skel_joints = torch.from_numpy(skel_joints) 158 | offsets = self.get_offsets_joints_batch(skel_joints) 159 | if len(self._offset.shape) == 2: 160 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 161 | offsets = offsets.numpy() 162 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 163 | joints[:, 0] = root_pos 164 | for chain in self._kinematic_tree: 165 | if do_root_R: 166 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 167 | else: 168 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 169 | for i in range(1, len(chain)): 170 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 171 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 172 | # print(matR.shape, offset_vec.shape) 173 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 174 | return joints 175 | 176 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 177 | # cont6d_params (batch_size, joints_num, 6) 178 | # joints (batch_size, joints_num, 3) 179 | # root_pos (batch_size, 3) 180 | if skel_joints is not None: 181 | # skel_joints = torch.from_numpy(skel_joints) 182 | offsets = self.get_offsets_joints_batch(skel_joints) 183 | if len(self._offset.shape) == 2: 184 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 185 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 186 | joints[..., 0, :] = root_pos 187 | for chain in self._kinematic_tree: 188 | if do_root_R: 189 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 190 | else: 191 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 192 | for i in range(1, len(chain)): 193 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) 194 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 195 | # print(matR.shape, offset_vec.shape) 196 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 197 | return joints 198 | 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /src/Ninjutsu/argUtils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import sys 4 | import os 5 | from ast import literal_eval 6 | 7 | def get_args_update_dict(args): 8 | args_update_dict = {} 9 | for string in sys.argv: 10 | string = ''.join(string.split('-')) 11 | if string in args: 12 | args_update_dict.update({string: args.__dict__[string]}) 13 | return args_update_dict 14 | 15 | 16 | def argparseNloop(): 17 | parser = argparse.ArgumentParser() 18 | 19 | '''Directories and data path''' 20 | parser.add_argument('--work-dir', default = os.path.join('src', 'Ninjutsu'), type=str, 21 | help='The path to the downloaded data') 22 | parser.add_argument('--data-path', default = os.path.join('..', 'DATASETS', 'Ninjutsu_Data'), 
type=str, 23 | help='The path to the folder that contains dataset before pre-processing') 24 | parser.add_argument('--model_path', default = 'smplx_model', type=str, 25 | help='The path to the folder containing SMPLX model') 26 | parser.add_argument('--save_dir', default = os.path.join('save', 'Ninjutsu', 'diffusion'), type=str, 27 | help='The path to the folder to save the processed data') 28 | parser.add_argument('--render_path', default = os.path.join('render', 'Ninjutsu'), type=str, 29 | help='The path to the folder to save the rendered output') 30 | parser.add_argument('--data_dir', default = os.path.join('data', 'Ninjutsu'), type=str, 31 | help='The path to the pre-processed data') 32 | 33 | 34 | '''Dataset Parameters''' 35 | parser.add_argument('-dataset', nargs='+', type=str, default='NinjutsuDataset', 36 | help='name of the dataset') 37 | parser.add_argument('--frames', nargs='+', type=int, default=50, 38 | help='Number of frames taken from each sequence in the dataset for training.') 39 | parser.add_argument('-seedLength', nargs='+', type=int, default=20, 40 | help='initial length of inputs to seed the prediction; used when offset > 0') 41 | parser.add_argument('-exp', nargs='+', type=int, default=0, 42 | help='experiment number') 43 | parser.add_argument('-scale', nargs='+', type=int, default=1000.0, 44 | help='Data scale by this factor') 45 | parser.add_argument('-framerate', nargs='+', type=int, default=20, 46 | help='frame rate after pre-processing.') 47 | parser.add_argument('-seed', nargs='+', type=int, default=4815, 48 | help='manual seed') 49 | parser.add_argument('-load', nargs='+', type=str, default=None, 50 | help='Load weights from this file') 51 | parser.add_argument('-cuda', nargs='+', type=int, default=0, 52 | help='choice of gpu device, -1 for cpu') 53 | parser.add_argument('-overfit', nargs='+', type=int, default=0, 54 | help='disables early stopping and saves models even if the dev loss increases. 
useful for performing an overfitting check') 55 | 56 | '''Diffusion parameters''' 57 | parser.add_argument("--noise_schedule", default='linear', choices=['linear', 'cosine', 'sigmoid'], type=str, 58 | help="Noise schedule type") 59 | parser.add_argument("--diffusion_steps", default=300, type=int, 60 | help="Number of diffusion steps (denoted T in the paper)") 61 | parser.add_argument("--sampler", default='uniform', type=str, 62 | help="Create a Schedule Sampler") 63 | 64 | 65 | '''Diffusion transformer model parameters''' 66 | parser.add_argument('-model', nargs='+', type=str, default='DiffusionTransformer', 67 | help='name of model') 68 | parser.add_argument('-input_feats', nargs='+', type=int, default=3, 69 | help='number of input features ') 70 | parser.add_argument('-out_feats', nargs='+', type=int, default=3, 71 | help='number of output features ') 72 | parser.add_argument('--jt_latent', nargs='+', type=int, default=32, 73 | help='dimensionality of last dimension after GCN') 74 | parser.add_argument('--d_model', nargs='+', type=int, default=256, 75 | help='dimensionality of model embeddings') 76 | parser.add_argument('--d_ff', nargs='+', type=int, default=512, 77 | help='dimensionality of the inner layer in the feed-forward network') 78 | parser.add_argument('--num_layer', nargs='+', type=int, default=6, 79 | help='number of layers in encoder-decoder of model') 80 | parser.add_argument('--num_head', nargs='+', type=int, default=4, 81 | help='number of attention heads in the multi-head attention mechanism.') 82 | parser.add_argument("--activations", default='LeakyReLU', choices=['LeakyReLU', 'SiLU', 'GELU'], type=str, 83 | help="Activation function") 84 | '''Diffusion transformer hand model parameters''' 85 | parser.add_argument('-hand_input_condn_feats', nargs='+', type=int, default=280, 86 | help='number of input features ') 87 | parser.add_argument('-hand_out_feats', nargs='+', type=int, default=3, 88 | help='number of output features ') 89 | parser.add_argument('--d_modelhand', nargs='+', type=int, default=256, 90 | help='dimensionality of model embeddings') 91 | parser.add_argument('--d_ffhand', nargs='+', type=int, default=512, 92 | help='dimensionality of the inner layer in the feed-forward network') 93 | parser.add_argument('--num_layer_hands', nargs='+', type=int, default=6, 94 | help='number of layers in encoder-decoder of model') 95 | parser.add_argument('--num_head_hands', nargs='+', type=int, default=4, 96 | help='number of attention heads in the multi-head attention mechanism.') 97 | 98 | 99 | '''Training parameters''' 100 | parser.add_argument('-batch_size', nargs='+', type=int, default=32, 101 | help='minibatch size.') 102 | parser.add_argument('-num_epochs', nargs='+', type=int, default=5000, 103 | help='number of epochs for training') 104 | parser.add_argument('--skip_train', nargs='+', type=int, default=1, 105 | help='downsampling factor of the training dataset. For example, a value of s indicates floor(D/s) training samples are loaded, ' 106 | 'where D is the total number of training samples (default: 1).') 107 | parser.add_argument('--skip_val', nargs='+', type=int, default=1, 108 | help='downsampling factor of the validation dataset. 
For example, a value of s indicates floor(D/s) validation samples are loaded, ' 109 | 'where D is the total number of validation samples (default: 1).') 110 | parser.add_argument('-early_stopping', nargs='+', type=int, default=0, 111 | help='Use 1 for early stopping') 112 | parser.add_argument('--n_workers', default=0, type=int, 113 | help='Number of PyTorch dataloader workers') 114 | parser.add_argument('-greedy_save', nargs='+', type=int, default=1, 115 | help='save weights after each epoch if 1') 116 | parser.add_argument('-save_model', nargs='+', type=int, default=1, 117 | help='flag to save model at every step') 118 | parser.add_argument('-stop_thresh', nargs='+', type=int, default=3, 119 | help='number of consequetive validation loss increses before stopping') 120 | parser.add_argument('-eps', nargs='+', type=float, default=0, 121 | help='if the decrease in validation is less than eps, it counts for one step in stop_thresh ') 122 | parser.add_argument('--curriculum', nargs='+', type=int, default=0, 123 | help='if 1, learn generating time steps by starting with 2 timesteps upto time, increasing by a power of 2') 124 | parser.add_argument('--use-multigpu', default=False, 125 | type=lambda arg: arg.lower() in ['true', '1'], 126 | help='If to use multiple GPUs for training') 127 | parser.add_argument('--load-on-ram', default=False, 128 | type=lambda arg: arg.lower() in ['true', '1'], 129 | help='This will load all the data on the RAM memory for faster training.' 130 | 'If your RAM capacity is more than 40 Gb, consider using this.') 131 | 132 | '''Optimizer parameters''' 133 | parser.add_argument('--optimizer', default='optim.Adam', type=str, 134 | help='Optimizer') 135 | parser.add_argument('-momentum', default=0.9, type=float, 136 | help='Weight decay for SGD Optimizer') 137 | parser.add_argument('-lr', nargs='+', type=float, default=1e-5, 138 | help='learning rate') 139 | 140 | '''Scheduler parameters''' 141 | parser.add_argument('--scheduler', default='torch.optim.lr_scheduler.StepLR', type=str, 142 | help='Scheduler') 143 | parser.add_argument('--patience', default=3, type=float, 144 | help='Step size for ReduceOnPlateau scheduler') 145 | parser.add_argument('--factor', default=0.99, type=float, 146 | help='Decay rate for ReduceOnPlateau scheduler') 147 | parser.add_argument('--threshold', default=0.05, type=float, 148 | help='THreshold for ReduceOnPlateau scheduler') 149 | 150 | parser.add_argument('--stepsize', default=5, type=float, 151 | help='Step size for StepLR scheduler') 152 | parser.add_argument('--gamma', default=0.99, type=float, 153 | help='Decay rate for StepLR scheduler') 154 | parser.add_argument('--milestones', default=[50, 100], type=float, 155 | help='List of epoch indices. 
Must be increasing for MultiStepLR scheduler') 156 | '''Loss parameters''' 157 | parser.add_argument('--lambda_loss', type=dict, default=None, 158 | help='weight of loss for VAE') 159 | 160 | 161 | args, unknown = parser.parse_known_args() 162 | return args 163 | -------------------------------------------------------------------------------- /src/Lindyhop/argUtils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import sys 4 | import os 5 | from ast import literal_eval 6 | 7 | def get_args_update_dict(args): 8 | args_update_dict = {} 9 | for string in sys.argv: 10 | string = ''.join(string.split('-')) 11 | if string in args: 12 | args_update_dict.update({string: args.__dict__[string]}) 13 | return args_update_dict 14 | 15 | 16 | def argparseNloop(): 17 | parser = argparse.ArgumentParser() 18 | 19 | '''Directories and data path''' 20 | parser.add_argument('--work-dir', default = os.path.join('src', 'Lindyhop'), type=str, 21 | help='The path to the downloaded data') 22 | parser.add_argument('--data-path', default = os.path.join('..', '..', 'DATASETS', 'LindyHop'), type=str, 23 | help='The path to the folder that contains dataset before pre-processing') 24 | parser.add_argument('--model_path', default = 'smplx_model', type=str, 25 | help='The path to the folder containing SMPLX model') 26 | parser.add_argument('--save_dir', default = os.path.join('save', 'Lindyhop', 'diffusion'), type=str, 27 | help='The path to the folder to save the processed data') 28 | parser.add_argument('--render_path', default = os.path.join('render', 'Lindyhop'), type=str, 29 | help='The path to the folder to save the rendered output') 30 | parser.add_argument('--data_dir', default = os.path.join('data', 'Lindyhop'), type=str, 31 | help='The path to the pre-processed data') 32 | 33 | 34 | '''Dataset Parameters''' 35 | parser.add_argument('-dataset', nargs='+', type=str, default='LindyHopDataset', 36 | help='name of the dataset') 37 | parser.add_argument('--frames', nargs='+', type=int, default=20, 38 | help='Number of frames taken from each sequence in the dataset for training.') 39 | parser.add_argument('-seedLength', nargs='+', type=int, default=20, 40 | help='initial length of inputs to seed the prediction; used when offset > 0') 41 | parser.add_argument('-exp', nargs='+', type=int, default=0, 42 | help='experiment number') 43 | parser.add_argument('-scale', nargs='+', type=int, default=1000.0, 44 | help='Data scale by this factor') 45 | parser.add_argument('-framerate', nargs='+', type=int, default=20, 46 | help='frame rate after pre-processing.') 47 | parser.add_argument('-seed', nargs='+', type=int, default=4815, 48 | help='manual seed') 49 | parser.add_argument('-load', nargs='+', type=str, default=None, 50 | help='Load weights from this file') 51 | parser.add_argument('-cuda', nargs='+', type=int, default=0, 52 | help='choice of gpu device, -1 for cpu') 53 | parser.add_argument('-overfit', nargs='+', type=int, default=0, 54 | help='disables early stopping and saves models even if the dev loss increases. 
useful for performing an overfitting check') 55 | 56 | '''Diffusion parameters''' 57 | parser.add_argument("--noise_schedule", default='linear', choices=['linear', 'cosine', 'sigmoid'], type=str, 58 | help="Noise schedule type") 59 | parser.add_argument("--diffusion_steps", default=500, type=int, 60 | help="Number of diffusion steps (denoted T in the paper)") 61 | parser.add_argument("--sampler", default='uniform', type=str, 62 | help="Create a Schedule Sampler") 63 | 64 | 65 | '''Diffusion transformer model parameters''' 66 | parser.add_argument('-model', nargs='+', type=str, default='DiffusionTransformer', 67 | help='name of model') 68 | parser.add_argument('-input_feats', nargs='+', type=int, default=3, 69 | help='number of input features ') 70 | parser.add_argument('-out_feats', nargs='+', type=int, default=3, 71 | help='number of output features ') 72 | parser.add_argument('--jt_latent', nargs='+', type=int, default=32, 73 | help='dimensionality of last dimension after GCN') 74 | parser.add_argument('--d_model', nargs='+', type=int, default=256, 75 | help='dimensionality of model embeddings') 76 | parser.add_argument('--d_ff', nargs='+', type=int, default=512, 77 | help='dimensionality of the inner layer in the feed-forward network') 78 | parser.add_argument('--num_layer', nargs='+', type=int, default=6, 79 | help='number of layers in encoder-decoder of model') 80 | parser.add_argument('--num_head', nargs='+', type=int, default=4, 81 | help='number of attention heads in the multi-head attention mechanism.') 82 | parser.add_argument("--activations", default='LeakyReLU', choices=['LeakyReLU', 'SiLU', 'GELU'], type=str, 83 | help="Activation function") 84 | 85 | '''Diffusion transformer hand model parameters''' 86 | parser.add_argument('-hand_input_condn_feats', nargs='+', type=int, default=280, 87 | help='number of input features ') 88 | parser.add_argument('-hand_out_feats', nargs='+', type=int, default=3, 89 | help='number of output features ') 90 | parser.add_argument('--d_modelhand', nargs='+', type=int, default=256, 91 | help='dimensionality of model embeddings') 92 | parser.add_argument('--d_ffhand', nargs='+', type=int, default=512, 93 | help='dimensionality of the inner layer in the feed-forward network') 94 | parser.add_argument('--num_layer_hands', nargs='+', type=int, default=6, 95 | help='number of layers in encoder-decoder of model') 96 | parser.add_argument('--num_head_hands', nargs='+', type=int, default=4, 97 | help='number of attention heads in the multi-head attention mechanism.') 98 | 99 | 100 | '''Training parameters''' 101 | parser.add_argument('-batch_size', nargs='+', type=int, default=32, 102 | help='minibatch size.') 103 | parser.add_argument('-num_epochs', nargs='+', type=int, default=300, 104 | help='number of epochs for training') 105 | parser.add_argument('--skip_train', nargs='+', type=int, default=1, 106 | help='downsampling factor of the training dataset. For example, a value of s indicates floor(D/s) training samples are loaded, ' 107 | 'where D is the total number of training samples (default: 1).') 108 | parser.add_argument('--skip_val', nargs='+', type=int, default=1, 109 | help='downsampling factor of the validation dataset. 
For example, a value of s indicates floor(D/s) validation samples are loaded, ' 110 | 'where D is the total number of validation samples (default: 1).') 111 | parser.add_argument('-early_stopping', nargs='+', type=int, default=0, 112 | help='Use 1 for early stopping') 113 | parser.add_argument('--n_workers', default=0, type=int, 114 | help='Number of PyTorch dataloader workers') 115 | parser.add_argument('-greedy_save', nargs='+', type=int, default=1, 116 | help='save weights after each epoch if 1') 117 | parser.add_argument('-save_model', nargs='+', type=int, default=1, 118 | help='flag to save model at every step') 119 | parser.add_argument('-stop_thresh', nargs='+', type=int, default=3, 120 | help='number of consecutive validation loss increases before stopping') 121 | parser.add_argument('-eps', nargs='+', type=float, default=0, 122 | help='if the decrease in validation is less than eps, it counts for one step in stop_thresh ') 123 | parser.add_argument('--curriculum', nargs='+', type=int, default=0, 124 | help='if 1, use curriculum learning: start by generating 2 timesteps and increase by powers of 2 up to the full sequence length') 125 | parser.add_argument('--use-multigpu', default=False, 126 | type=lambda arg: arg.lower() in ['true', '1'], 127 | help='Whether to use multiple GPUs for training') 128 | parser.add_argument('--load-on-ram', default=False, 129 | type=lambda arg: arg.lower() in ['true', '1'], 130 | help='This will load all the data into RAM for faster training. ' 131 | 'If your RAM capacity is more than 40 GB, consider using this.') 132 | 133 | '''Optimizer parameters''' 134 | parser.add_argument('--optimizer', default='optim.Adam', type=str, 135 | help='Optimizer') 136 | parser.add_argument('-momentum', default=0.9, type=float, 137 | help='Momentum for the SGD optimizer') 138 | parser.add_argument('-lr', nargs='+', type=float, default=1e-5, 139 | help='learning rate') 140 | 141 | '''Scheduler parameters''' 142 | parser.add_argument('--scheduler', default='torch.optim.lr_scheduler.StepLR', type=str, 143 | help='Scheduler') 144 | parser.add_argument('--patience', default=3, type=float, 145 | help='Patience (epochs without improvement) for ReduceOnPlateau scheduler') 146 | parser.add_argument('--factor', default=0.99, type=float, 147 | help='Decay rate for ReduceOnPlateau scheduler') 148 | parser.add_argument('--threshold', default=0.05, type=float, 149 | help='Threshold for ReduceOnPlateau scheduler') 150 | 151 | parser.add_argument('--stepsize', default=5, type=float, 152 | help='Step size for StepLR scheduler') 153 | parser.add_argument('--gamma', default=0.99, type=float, 154 | help='Decay rate for StepLR scheduler') 155 | parser.add_argument('--milestones', default=[50, 100], type=float, 156 | help='List of epoch indices. 
Must be increasing for MultiStepLR scheduler') 157 | '''Loss parameters''' 158 | parser.add_argument('--lambda_loss', type=dict, default=None, 159 | help='weight of loss for VAE') 160 | 161 | 162 | args, unknown = parser.parse_known_args() 163 | return args 164 | -------------------------------------------------------------------------------- /src/Lindyhop/process_LindyHop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import torch 3 | import os 4 | import glob 5 | import sys 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | import pickle 9 | from src.tools.transformations import batch_euler_to_rotmat 10 | 11 | def makepath(desired_path, isfile = False): 12 | ''' 13 | if the path does not exist make it 14 | :param desired_path: can be path to a file or a folder name 15 | :return: 16 | ''' 17 | import os 18 | if isfile: 19 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 20 | else: 21 | if not os.path.exists(desired_path): os.makedirs(desired_path) 22 | return desired_path 23 | 24 | class PreProcessor(): 25 | def __init__(self, root_dir, fps=20, split='train'): 26 | self.root = root_dir 27 | self.framerate = fps 28 | self.root = os.path.join(root_dir, split) 29 | seq = os.listdir(self.root) 30 | self.sequences = [int(x) for x in seq] 31 | self.total_frames = 0 32 | self.total_contact_frames = 0 33 | self.annot_dict = { 34 | 'cam': [], 35 | 'seq': [], 'contacts': [], 36 | 'pose_canon_1':[], 'pose_canon_2':[], 37 | 'dofs_1': [], 'dofs_2': [], 38 | 'rotmat_1': [], 'rotmat_2': [], 39 | 'offsets_1': [], 'offsets_2': [] 40 | } 41 | 42 | self.bvh_joint_order = { 43 | 'Hips': 0, 44 | 'RightUpLeg': 1, 45 | 'RightLeg': 2, 46 | 'RightFoot': 3, 47 | 'RightToeBase': 4, 48 | 'RightToeBaseEnd': 5, 49 | 'LeftUpLeg': 6, 50 | 'LeftLeg': 7, 51 | 'LeftFoot': 8, 52 | 'LeftToeBase': 9, 53 | 'LeftToeBaseEnd': 10, 54 | 'Spine': 11, 55 | 'Spine1': 12, 56 | 'Spine2': 13, 57 | 'Spine3': 14, 58 | 'RightShoulder': 15, 59 | 'RightArm': 16, 60 | 'RightForeArm': 17, 61 | 'RightHand': 18, 62 | 'RightHandEnd': 19, 63 | 'RightHandPinky1': 20, 64 | 'RightHandPinky2': 21, 65 | 'RightHandPinky3': 22, 66 | 'RightHandPinky3End': 23, 67 | 'RightHandRing1': 24, 68 | 'RightHandRing2': 25, 69 | 'RightHandRing3': 26, 70 | 'RightHandRing3End': 27, 71 | 'RightHandMiddle1': 28, 72 | 'RightHandMiddle2': 29, 73 | 'RightHandMiddle3': 30, 74 | 'RightHandMiddle3End': 31, 75 | 'RightHandIndex1': 32, 76 | 'RightHandIndex2': 33, 77 | 'RightHandIndex3': 34, 78 | 'RightHandIndex3End': 35, 79 | 'RightHandThumb1': 36, 80 | 'RightHandThumb2': 37, 81 | 'RightHandThumb3': 38, 82 | 'RightHandThumb3End': 39, 83 | 'LeftShoulder': 40, 84 | 'LeftArm': 41, 85 | 'LeftForeArm': 42, 86 | 'LeftHand': 43, 87 | 'LeftHandEnd': 44, 88 | 'LeftHandPinky1': 45, 89 | 'LeftHandPinky2': 46, 90 | 'LeftHandPinky3': 47, 91 | 'LeftHandPinky3End': 48, 92 | 'LeftHandRing1': 49, 93 | 'LeftHandRing2': 50, 94 | 'LeftHandRing3': 51, 95 | 'LeftHandRing3End': 52, 96 | 'LeftHandMiddle1': 53, 97 | 'LeftHandMiddle2': 54, 98 | 'LeftHandMiddle3': 55, 99 | 'LeftHandMiddle3End': 56, 100 | 'LeftHandIndex1': 57, 101 | 'LeftHandIndex2': 58, 102 | 'LeftHandIndex3': 59, 103 | 'LeftHandIndex3End': 60, 104 | 'LeftHandThumb1': 61, 105 | 'LeftHandThumb2': 62, 106 | 'LeftHandThumb3': 63, 107 | 'LeftHandThumb3End': 64, 108 | 'Spine4': 65, 109 | 'Neck': 66, 110 | 'Head': 67, 111 | 'HeadEnd': 68 112 | } 113 | 114 | print("creating the annot file") 115 | self.collate_videos() 116 | 
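# collate_videos fills annot_dict from the per-sequence csv/pkl exports;
# save_annot then pickles the result to data/LindyHop/<split>.pkl.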
self.save_annot(split) 117 | 118 | def detect_contact(self, motion1, motion2, thresh=50): 119 | 120 | contact_joints = ['Hand', 'HandEnd', 121 | 'HandPinky1', 'HandPinky2', 'HandPinky3', 'HandPinky3End', 122 | 'HandRing1', 'HandRing2', 'HandRing3','HandRing3End', 123 | 'HandIndex1', 'HandIndex2', 'HandIndex3','HandIndex3End', 124 | 'HandMiddle1', 'HandMiddle2', 'HandMiddle3','HandMiddle3End', 125 | 'HandThumb1', 'HandThumb2', 'HandThumb3','HandThumb3End'] 126 | 127 | n_frames = motion1.shape[0] 128 | 129 | assert motion1.shape == motion2.shape 130 | 131 | ## 0 : no contact, 1: rh-rh, 2: lh-lh, 3: lh-rh , 4: rh-lh 132 | contact = np.zeros((n_frames, 5)) 133 | 134 | def dist(x, y): 135 | return np.sqrt(np.sum((x - y)**2)) 136 | contact_frames = [] 137 | 138 | count = 0 139 | for i in range(n_frames): 140 | for s, sides in enumerate([['Right', 'Right'], ['Left', 'Left'], ['Left', 'Right'], ['Right', 'Left']]): 141 | for j, joint1 in enumerate(contact_joints): 142 | if contact[i, s+1] == 1: 143 | break 144 | for k, joint2 in enumerate(contact_joints): 145 | j1 = sides[0] + joint1 146 | j2 = sides[1] + joint2 147 | 148 | idx1 = self.bvh_joint_order[j1] 149 | idx2 = self.bvh_joint_order[j2] 150 | 151 | d = dist(motion1[i, idx1], motion2[i, idx2]) 152 | if d <= thresh: 153 | contact[i, s+1] = 1 154 | contact_frames.append(i) 155 | count += 1 156 | break 157 | 158 | 159 | print(count) 160 | return contact, contact_frames 161 | 162 | 163 | def save_annot(self, split): 164 | save_path = makepath(os.path.join('data', 'LindyHop', split+'.pkl'), isfile=True) 165 | with open(save_path, 'wb') as f: 166 | pickle.dump(self.annot_dict, f) 167 | 168 | 169 | def _load_files(self, seq, p): 170 | file_basename = os.path.join(self.root, str(seq), str(p)) 171 | path_canon_3d = os.path.join(file_basename, 'motion_worldpos.csv') 172 | path_dofs = os.path.join(file_basename, 'motion_rotations.csv') 173 | path_offsets = os.path.join(file_basename, 'motion_offsets.pkl') 174 | 175 | print(f"loading file {file_basename}") 176 | canon_3d = np.genfromtxt(path_canon_3d, delimiter=',', skip_header=1) # n_c*n_f x n_dof 177 | dofs = np.genfromtxt(path_dofs, delimiter=',', skip_header=1) # n_c*n_f x n_dof 178 | with open(path_offsets, 'rb') as f: 179 | offset_dict = pickle.load(f) 180 | print(f"loading complete") 181 | 182 | n_frames = canon_3d.shape[0] 183 | canon_3d = np.float32(canon_3d[:, 1:].reshape(n_frames, -1, 3)) 184 | dofs = dofs[:, 1:].reshape(n_frames, -1, 3) 185 | 186 | #Downsample the data from 50 fps to given framerate 187 | use_frames = list(np.rint(np.arange(0, n_frames, 50/self.framerate))) 188 | use_frames = [int(a) for a in use_frames] 189 | canon_3d = canon_3d[use_frames] 190 | dofs = np.float32(dofs[use_frames]) 191 | print(canon_3d.shape) 192 | return n_frames, canon_3d, dofs, offset_dict 193 | 194 | 195 | def collate_videos(self): 196 | self.annot_dict['bvh_joint_order'] = self.bvh_joint_order 197 | # self.annot_dict['joint_order'] = self.joint_order 198 | for i, seq in enumerate(self.sequences): 199 | seq_total_frames, canon_3d_1, dofs_1, offsets_1 = self._load_files(seq, 0) 200 | self.total_frames += seq_total_frames 201 | # continue 202 | _, canon_3d_2, dofs_2, offsets_2 = self._load_files(seq, 1) 203 | if canon_3d_2.shape[0] < canon_3d_1.shape[0]: 204 | n_frames = canon_3d_2.shape[0] 205 | else: 206 | n_frames = canon_3d_1.shape[0] 207 | canon_3d_1 = canon_3d_1[:n_frames] 208 | canon_3d_2 = canon_3d_2[:n_frames] 209 | contacts, contact_frames = self.detect_contact(canon_3d_1, canon_3d_2) 210 | 211 | 
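# detect_contact returns an (n_frames, 5) matrix (column 0 is unused here; columns 1-4
# flag rh-rh, lh-lh, lh-rh and rh-lh hand contact) together with the list of frame indices
# at which any hand-hand contact occurred. A frame index is appended once per contacting
# side pair, so contact_frames may contain duplicates and the slicing below can repeat frames.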
n_frames_contact = len(contact_frames) 212 | self.total_contact_frames += n_frames_contact 213 | canon_3d_1 = canon_3d_1[contact_frames] 214 | canon_3d_2 = canon_3d_2[contact_frames] 215 | output_dofs_1 = dofs_1[contact_frames] 216 | output_dofs_2 = dofs_2[contact_frames] 217 | contacts = contacts[contact_frames, 1:] 218 | rotmat_1 = batch_euler_to_rotmat(output_dofs_1) 219 | rotmat_2 = batch_euler_to_rotmat(output_dofs_2) 220 | self.annot_dict['offsets_1'].extend([offsets_1 for i in range(0, n_frames_contact)]) 221 | self.annot_dict['offsets_2'].extend([offsets_2 for i in range(0, n_frames_contact)]) 222 | self.annot_dict['seq'].extend([seq for i in range(0, n_frames_contact)]) 223 | self.annot_dict['pose_canon_1'].extend([canon_3d_1 for i in range(0, n_frames_contact)]) 224 | self.annot_dict['pose_canon_2'].extend([canon_3d_2 for i in range(0, n_frames_contact )]) 225 | self.annot_dict['contacts'].extend([contacts for i in range(0, n_frames_contact )]) 226 | self.annot_dict['dofs_1'].extend([output_dofs_1 for i in range(0, n_frames_contact )]) 227 | self.annot_dict['dofs_2'].extend([output_dofs_2 for i in range(0, n_frames_contact )]) 228 | self.annot_dict['rotmat_1'].extend([rotmat_1 for i in range(0, n_frames_contact)]) 229 | self.annot_dict['rotmat_2'].extend([rotmat_2 for i in range(0, n_frames_contact )]) 230 | 231 | print(self.total_frames) 232 | print(self.total_contact_frames) 233 | 234 | 235 | 236 | if __name__ == "__main__": 237 | root_path = os.path.join('..', 'DATASETS', 'ReMocap', 'LindyHop') 238 | fps = 20 239 | pp = PreProcessor(root_path, fps, 'train') 240 | pp = PreProcessor(root_path, fps, 'test') 241 | 242 | -------------------------------------------------------------------------------- /src/Lindyhop/train_VanillaTransformer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import shutil 7 | import sys 8 | sys.path.append('.') 9 | sys.path.append('..') 10 | import time 11 | import torch 12 | torch.cuda.empty_cache() 13 | import torch.nn as nn 14 | 15 | from cmath import nan 16 | from collections import OrderedDict 17 | from datetime import datetime 18 | from torch import optim 19 | from torch.utils.data import DataLoader 20 | from tqdm import tqdm 21 | 22 | from src.Lindyhop.argUtils import argparseNloop 23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset 24 | from src.Lindyhop.models.transAE import * 25 | 26 | from src.Lindyhop.skeleton import * 27 | from src.tools.bookkeeper import * 28 | from src.tools.transformations import * 29 | from src.tools.utils import makepath 30 | 31 | right_side = [15, 16, 17, 18] 32 | left_side = [19, 20, 21, 22] 33 | # stat_metrics = CalculateMetricsDanceData() 34 | def dist(x, y): 35 | # return torch.mean(x - y) 36 | return torch.mean(torch.cdist(x, y, p=2)) 37 | 38 | def initialize_weights(m): 39 | std_dev = 0.02 40 | if isinstance(m, nn.Linear): 41 | nn.init.normal_(m.weight, std=std_dev) 42 | if m.bias is not None: 43 | nn.init.normal_(m.bias, std=std_dev) 44 | # nn.init.constant_(m.bias.data, 1e-5) 45 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d): 46 | torch.nn.init.normal_(m.weight, std=std_dev) 47 | if m.bias is not None: 48 | torch.nn.init.normal_(m.bias, std=std_dev) 49 | # nn.init.constant_(m.bias.data, 1e-5) 50 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 51 | nn.init.normal_(m.weight, std=std_dev) 52 | if m.bias is not 
None: 53 | nn.init.normal_(m.bias, std=std_dev) 54 | 55 | class Trainer: 56 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69): 57 | torch.manual_seed(args.seed) 58 | self.model_path = args.model_path 59 | makepath(args.work_dir, isfile=False) 60 | use_cuda = torch.cuda.is_available() 61 | if use_cuda: 62 | torch.cuda.empty_cache() 63 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu") 64 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None 65 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1 66 | print('Using %d CUDA cores [%s] for training!' % (gpu_count, gpu_brand)) 67 | args_subset = ['exp', 'model', 'batch_size', 'frames'] 68 | self.book = BookKeeper(args, args_subset) 69 | self.args = self.book.args 70 | self.batch_size = args.batch_size 71 | self.curriculum = args.curriculum 72 | self.scale = args.scale 73 | self.dtype = torch.float32 74 | self.epochs_completed = self.book.last_epoch 75 | self.frames = args.frames 76 | self.model = args.model 77 | self.testtime_split = split 78 | self.num_jts = num_jts 79 | self.model_pose = VanillaTransformer(args).to(self.device).float() 80 | trainable_count_body = sum(p.numel() for p in self.model_pose.parameters() if p.requires_grad) 81 | self.model_pose.apply(initialize_weights) 82 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr) 83 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma) 84 | self.skel = InhouseStudioSkeleton() 85 | 86 | print(args.model, 'Model Created') 87 | if args.load: 88 | print('Loading Model', args.model) 89 | self.book._load_model(self.model_pose, 'model_pose') 90 | print('Loading the data') 91 | if is_train: 92 | self.load_data(args) 93 | else: 94 | self.load_data_testtime(args) 95 | 96 | 97 | def load_data_testtime(self, args): 98 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split) 99 | self.load_ds_data = DataLoader(self.ds_data, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 100 | 101 | 102 | def load_data(self, args): 103 | 104 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train') 105 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) 106 | print('Train set loaded. Size=', len(self.ds_train.dataset)) 107 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test') 108 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 109 | print('Validation set loaded. 
Size=', len(self.ds_val.dataset)) 110 | 111 | 112 | def train(self, num_epoch, ablation=None): 113 | total_train_loss = 0.0 114 | self.model_pose.train() 115 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120) 116 | for count, batch in enumerate(training_tqdm): 117 | self.optimizer_model_pose.zero_grad() 118 | with torch.autograd.detect_anomaly(): 119 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 120 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order, 121 | new_joint_order=self.skel.body_only) 122 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 123 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order, 124 | new_joint_order=self.skel.body_only) 125 | 126 | _, loss_model = self.model_pose(global_pose1, global_pose2) 127 | total_train_loss += loss_model.item() 128 | 129 | if loss_model == float('inf') or torch.isnan(loss_model): 130 | print('Train loss is nan') 131 | exit() 132 | loss_model.backward() 133 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01) 134 | self.optimizer_model_pose.step() 135 | 136 | avg_train_loss = total_train_loss/(count + 1) 137 | return avg_train_loss 138 | 139 | def evaluate(self, num_epoch, ablation=None): 140 | total_eval_loss = 0.0 141 | self.model_pose.eval() 142 | T = self.frames 143 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120) 144 | for count, batch in enumerate(eval_tqdm): 145 | if True: 146 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 147 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order, 148 | new_joint_order=self.skel.body_only) 149 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 150 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order, 151 | new_joint_order=self.skel.body_only) 152 | 153 | _, loss_model = self.model_pose(global_pose1, global_pose2) 154 | total_eval_loss += loss_model.item() 155 | 156 | avg_eval_loss = total_eval_loss/(count + 1) 157 | return avg_eval_loss 158 | 159 | def fit(self, n_epochs=None, ablation=False): 160 | print('*****Inside Trainer.fit *****') 161 | if n_epochs is None: 162 | n_epochs = self.args.num_epochs 163 | starttime = datetime.now().replace(microsecond=0) 164 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs) 165 | save_model_dict = {} 166 | best_eval = 1000 167 | for epoch_num in range(self.epochs_completed, n_epochs + 1): 168 | tqdm.write('--- starting Epoch # %03d' % epoch_num) 169 | train_loss = self.train(epoch_num, ablation) 170 | 171 | if epoch_num % 5 == 0: 172 | eval_loss = self.evaluate(epoch_num, ablation) 173 | else: 174 | eval_loss = 0.0 175 | self.scheduler_pose.step() 176 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0}) 177 | self.book._save_res() 178 | self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr']) 179 | 180 | if epoch_num > 100 and eval_loss < best_eval: 181 | print('Best eval at epoch {}'.format(epoch_num)) 182 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb') 183 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 184 | torch.save(save_model_dict, f) 185 | 
f.close() 186 | best_eval = eval_loss 187 | if epoch_num > 20 and epoch_num % 20 == 0 : 188 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb') 189 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 190 | torch.save(save_model_dict, f) 191 | f.close() 192 | endtime = datetime.now().replace(microsecond=0) 193 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S'))) 194 | print('Training complete in %s!\n' % (endtime - starttime)) 195 | 196 | 197 | 198 | if __name__ == '__main__': 199 | args = argparseNloop() 200 | 201 | is_train = True 202 | ablation = None # if True then ablation: no_IAC_loss 203 | model_trainer = Trainer(args=args, is_train=is_train, split='test', JT_POSITION=True, num_jts=27) 204 | print("** Method Initialization Complete **") 205 | model_trainer.fit(ablation=ablation) 206 | 207 | -------------------------------------------------------------------------------- /src/Ninjutsu/process_Ninjutsu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import torch 3 | import os 4 | import glob 5 | import sys 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | import pickle 9 | from src.tools.transformations import batch_euler_to_rotmat 10 | 11 | def makepath(desired_path, isfile = False): 12 | ''' 13 | if the path does not exist make it 14 | :param desired_path: can be path to a file or a folder name 15 | :return: 16 | ''' 17 | import os 18 | if isfile: 19 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 20 | else: 21 | if not os.path.exists(desired_path): os.makedirs(desired_path) 22 | return desired_path 23 | 24 | class PreProcessor(): 25 | def __init__(self, root_dir, fps=20, split='train'): 26 | self.root = root_dir 27 | self.framerate = fps 28 | self.root = os.path.join(root_dir, split) 29 | self.sequences = os.listdir(self.root) 30 | 31 | 32 | self.annot_dict = { 33 | 'cam': [], 34 | 'seq': [], 35 | 'contacts': [], 36 | 'pose_canon_1':[], 'pose_canon_2':[], 37 | 'dofs_1': [], 'dofs_2': [], 38 | 'rotmat_1': [], 'rotmat_2': [], 39 | 'offsets_1': [], 'offsets_2': [] 40 | } 41 | 42 | self.bvh_joint_order = { 43 | 'Hips': 0, 44 | 'RightUpLeg': 1, 45 | 'RightLeg': 2, 46 | 'RightFoot': 3, 47 | 'RightToeBase': 4, 48 | 'RightToeBaseEnd': 5, 49 | 'LeftUpLeg': 6, 50 | 'LeftLeg': 7, 51 | 'LeftFoot': 8, 52 | 'LeftToeBase': 9, 53 | 'LeftToeBaseEnd': 10, 54 | 'Spine': 11, 55 | 'Spine1': 12, 56 | 'Spine2': 13, 57 | 'Spine3': 14, 58 | 'RightShoulder': 15, 59 | 'RightArm': 16, 60 | 'RightForeArm': 17, 61 | 'RightHand': 18, 62 | 'RightHandEnd': 19, 63 | 'RightHandPinky1': 20, 64 | 'RightHandPinky2': 21, 65 | 'RightHandPinky3': 22, 66 | 'RightHandPinky3End': 23, 67 | 'RightHandRing1': 24, 68 | 'RightHandRing2': 25, 69 | 'RightHandRing3': 26, 70 | 'RightHandRing3End': 27, 71 | 'RightHandMiddle1': 28, 72 | 'RightHandMiddle2': 29, 73 | 'RightHandMiddle3': 30, 74 | 'RightHandMiddle3End': 31, 75 | 'RightHandIndex1': 32, 76 | 'RightHandIndex2': 33, 77 | 'RightHandIndex3': 34, 78 | 'RightHandIndex3End': 35, 79 | 'RightHandThumb1': 36, 80 | 'RightHandThumb2': 37, 81 | 'RightHandThumb3': 38, 82 | 'RightHandThumb3End': 39, 83 | 'LeftShoulder': 40, 84 | 'LeftArm': 41, 85 | 'LeftForeArm': 42, 86 | 'LeftHand': 43, 87 | 'LeftHandEnd': 44, 88 | 'LeftHandPinky1': 45, 89 | 'LeftHandPinky2': 46, 90 | 'LeftHandPinky3': 47, 91 | 'LeftHandPinky3End': 48, 92 | 'LeftHandRing1': 
49, 93 | 'LeftHandRing2': 50, 94 | 'LeftHandRing3': 51, 95 | 'LeftHandRing3End': 52, 96 | 'LeftHandMiddle1': 53, 97 | 'LeftHandMiddle2': 54, 98 | 'LeftHandMiddle3': 55, 99 | 'LeftHandMiddle3End': 56, 100 | 'LeftHandIndex1': 57, 101 | 'LeftHandIndex2': 58, 102 | 'LeftHandIndex3': 59, 103 | 'LeftHandIndex3End': 60, 104 | 'LeftHandThumb1': 61, 105 | 'LeftHandThumb2': 62, 106 | 'LeftHandThumb3': 63, 107 | 'LeftHandThumb3End': 64, 108 | 'Spine4': 65, 109 | 'Neck': 66, 110 | 'Head': 67, 111 | 'HeadEnd': 68 112 | } 113 | 114 | print("creating the annot file") 115 | self.collate_videos() 116 | self.save_annot(split) 117 | 118 | def detect_contact(self, motion1, motion2, thresh=120): 119 | 120 | 121 | contact_joints = ['Hand', 'HandEnd', 122 | 'HandPinky1', 'HandPinky2', 'HandPinky3', 'HandPinky3End', 123 | 'HandRing1', 'HandRing2', 'HandRing3','HandRing3End', 124 | 'HandIndex1', 'HandIndex2', 'HandIndex3','HandIndex3End', 125 | 'HandMiddle1', 'HandMiddle2', 'HandMiddle3','HandMiddle3End', 126 | 'HandThumb1', 'HandThumb2', 'HandThumb3','HandThumb3End'] 127 | 128 | n_frames = motion1.shape[0] 129 | 130 | assert motion1.shape == motion2.shape 131 | 132 | ## 0 : no contact, 1: rh-rh, 2: lh-lh, 3: lh-rh , 4: rh-lh 133 | contact = np.zeros((n_frames, 5)) 134 | 135 | def dist(x, y): 136 | return np.sqrt(np.sum((x - y)**2)) 137 | count = 0 138 | for i in range(n_frames): 139 | for s, sides in enumerate([['Right', 'Right'], ['Left', 'Left'], ['Left', 'Right'], ['Right', 'Left']]): 140 | for j, joint1 in enumerate(contact_joints): 141 | if contact[i, s+1] == 1: 142 | break 143 | for k, joint2 in enumerate(contact_joints): 144 | j1 = sides[0] + joint1 145 | j2 = sides[1] + joint2 146 | 147 | idx1 = self.bvh_joint_order[j1] 148 | idx2 = self.bvh_joint_order[j2] 149 | 150 | d = dist(motion1[i, idx1], motion2[i, idx2]) 151 | if d <= thresh: 152 | contact[i, s+1] = 1 153 | count += 1 154 | break 155 | 156 | 157 | print(count) 158 | return contact[:, 1:] 159 | 160 | 161 | def use_frames(self, motion1, motion2, thresh=1000): 162 | 163 | t_frame = motion1.shape[0] 164 | xx = np.tile(motion1, (1,69,1)) 165 | yy = np.repeat(motion2, 69, axis=1) 166 | 167 | diff = xx - yy 168 | D = np.linalg.norm(diff, axis=-1) 169 | contact = (D <= thresh)*1 170 | sum_contact = np.sum(contact, axis=1) 171 | contact_frames = list(np.nonzero(sum_contact)[0]) 172 | return contact, contact_frames 173 | 174 | def save_annot(self, split): 175 | save_path = makepath(os.path.join('data', 'Ninjutsu', split+'.pkl'), isfile=True) 176 | with open(save_path, 'wb') as f: 177 | pickle.dump(self.annot_dict, f) 178 | 179 | 180 | def load_files(self, seq, fname='0'): 181 | file_basename = os.path.join(self.root, str(seq)) 182 | path_canon_3d = os.path.join(file_basename, fname+'_worldpos.csv') 183 | path_dofs = os.path.join(file_basename, fname+'_rotations.csv') 184 | path_offsets = os.path.join(file_basename, fname+'_offsets.pkl') 185 | 186 | print(f"loading file {file_basename}") 187 | canon_3d = np.genfromtxt(path_canon_3d, delimiter=',', skip_header=1) # n_c*n_f x n_dof 188 | dofs = np.genfromtxt(path_dofs, delimiter=',', skip_header=1) # n_c*n_f x n_dof 189 | with open(path_offsets, 'rb') as f: 190 | offset_dict = pickle.load(f) 191 | print(f"loading complete") 192 | 193 | n_frames = canon_3d.shape[0] 194 | canon_3d = np.float32(canon_3d[:, 1:].reshape(n_frames, -1, 3)) 195 | dofs = dofs[:, 1:].reshape(n_frames, -1, 3) 196 | 197 | use_frames = list(np.rint(np.arange(0, n_frames-1, 25/self.framerate))) 198 | use_frames = [int(a) for a in 
use_frames] 199 | canon_3d = canon_3d[use_frames] 200 | dofs = np.float32(dofs[use_frames]) 201 | print(canon_3d.shape) 202 | return 0, 0, canon_3d, dofs, offset_dict 203 | 204 | 205 | def collate_videos(self): 206 | self.annot_dict['bvh_joint_order'] = self.bvh_joint_order 207 | for i, seq in enumerate(self.sequences): 208 | _, _, canon_3d_1, dofs_1, offsets_1 = self.load_files(seq, '0') 209 | _, _, canon_3d_2, dofs_2, offsets_2 = self.load_files(seq, '1') 210 | if canon_3d_2.shape[0] < canon_3d_1.shape[0]: 211 | n_frames = canon_3d_2.shape[0] 212 | else: 213 | n_frames = canon_3d_1.shape[0] 214 | canon_3d_1 = canon_3d_1[:n_frames] 215 | canon_3d_2 = canon_3d_2[:n_frames] 216 | rotmat_1 = batch_euler_to_rotmat(dofs_1) 217 | rotmat_2 = batch_euler_to_rotmat(dofs_2) 218 | contacts, contact_frames = self.use_frames(canon_3d_1, canon_3d_2) 219 | n_frames_contact = len(contact_frames) 220 | canon_3d_1 = canon_3d_1[contact_frames] 221 | canon_3d_2 = canon_3d_2[contact_frames] 222 | dofs_1 = dofs_1[contact_frames] 223 | dofs_2 = dofs_2[contact_frames] 224 | contacts = contacts[contact_frames].reshape(n_frames_contact, 69, 69) 225 | hand_contact = self.detect_contact(motion1=canon_3d_1, motion2=canon_3d_2) 226 | self.annot_dict['pose_canon_1'].extend(canon_3d_1 ) 227 | self.annot_dict['pose_canon_2'].extend(canon_3d_2 ) 228 | self.annot_dict['dofs_1'].extend(dofs_1 ) 229 | self.annot_dict['dofs_2'].extend(dofs_2 ) 230 | self.annot_dict['contacts'].extend(hand_contact) 231 | self.annot_dict['rotmat_1'].extend(rotmat_1 ) 232 | self.annot_dict['rotmat_2'].extend(rotmat_2 ) 233 | self.annot_dict['offsets_1'].extend([offsets_1 for i in range(0, n_frames_contact)]) 234 | self.annot_dict['offsets_2'].extend([offsets_2 for i in range(0, n_frames_contact)]) 235 | self.annot_dict['seq'].extend([seq for i in range(0, n_frames_contact)]) 236 | 237 | 238 | self.annot_dict['pose_canon_1'] = np.array(self.annot_dict['pose_canon_1']) 239 | self.annot_dict['pose_canon_2'] = np.array(self.annot_dict['pose_canon_2']) 240 | self.annot_dict['dofs_1'] = np.array(self.annot_dict['dofs_1']) 241 | self.annot_dict['dofs_2'] = np.array(self.annot_dict['dofs_2']) 242 | self.annot_dict['rotmat_1'] = np.array(self.annot_dict['rotmat_1']) 243 | self.annot_dict['rotmat_2'] = np.array(self.annot_dict['rotmat_2']) 244 | self.annot_dict['contacts'] = np.array(self.annot_dict['contacts']) 245 | print(len(self.annot_dict.keys())) 246 | print(len(self.annot_dict['seq'])) 247 | print(len(self.annot_dict['pose_canon_1'])) 248 | 249 | 250 | 251 | 252 | if __name__ == "__main__": 253 | root_path = os.path.join('..', 'DATASETS', 'ReMocap', 'Ninjutsu') 254 | fps = 10 255 | pp = PreProcessor(root_path, fps, 'train') 256 | pp = PreProcessor(root_path, fps, 'test') 257 | -------------------------------------------------------------------------------- /src/Lindyhop/visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import sys 8 | sys.path.append('.') 9 | sys.path.append('..') 10 | from mpl_toolkits.mplot3d import Axes3D 11 | from matplotlib.animation import FuncAnimation, PillowWriter 12 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 13 | import mpl_toolkits.mplot3d.axes3d as p3 14 | from PIL import Image 15 | from scipy import interpolate 16 | from src.tools.utils import makepath 17 | from src.tools.img_gif import img2video, img2gif 18 | 19 | LEFT_HANDSIDE = 
list(range(19, 24)) 20 | RIGHT_HANDSIDE = list(range(45, 48)) 21 | LEFT_FOOTSIDE = list(range(1, 5)) 22 | RIGHT_FOOTSIDE = list(range(6, 10)) 23 | 24 | kinematic_chain_full = [[0, 11], [11, 12], [12, 13], [13, 14], [14, 65], [65, 66], [66, 67], [67, 68], #spine, neck and head 25 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], # right leg 26 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], # left leg 27 | [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], # right arm 28 | [14, 40], [40, 41], [41, 42], [42, 43], [43, 44], # left arm 29 | [19, 20], [20, 21], [21, 22], [22, 23], # right pinky 30 | [19, 24], [24, 25], [25, 26], [26, 27], # right ring 31 | [19, 28], [28, 29], [29, 30], [30, 31], # right middle 32 | [19, 32], [32, 33], [33, 34], [34, 35], # right index 33 | [18, 36], [36, 37], [37, 38], [38, 39], # right thumb 34 | [44, 45], [45, 46], [46, 47], [47, 48], # left pinky 35 | [44, 49], [49, 50], [50, 51], [51, 52], # left ring 36 | [44, 53], [53, 54], [54, 55], [55, 56], # left middle 37 | [44, 57], [57, 58], [58, 59], [59, 60], # left index 38 | [43, 61], [61, 62], [62, 63], [63, 64], # left thumb 39 | ] 40 | kinematic_chain_reduced = [[0, 11], [11, 12], [12, 13], [13, 14], [14, 43], [43, 44], [44, 45], [45, 46], #spine, neck and head 41 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], # right leg 42 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], # left leg 43 | [14, 15], [15, 16], [16, 17], [17, 18], # right arm 44 | [14, 29], [29, 30], [30, 31], [31, 32], # left arm 45 | [18, 19], [19, 20], # right pinky 46 | [18, 21], [21, 22], # right ring 47 | [18, 23], [23, 24], # right middle 48 | [18, 25], [25, 26], # right index 49 | [18, 27], [27, 28], # right thumb 50 | [32, 33], [33, 34], # left pinky 51 | [32, 35], [35, 36], # left ring 52 | [32, 37], [37, 38], # left middle 53 | [32, 39], [39, 40], # left index 54 | [32, 41], [41, 42], # left thumb 55 | ] 56 | 57 | kinematic_chain_short = [ 58 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], 59 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], 60 | [0, 11], [11, 12], [12, 13], [13, 14], 61 | [14, 15], [15, 16], [16, 17], [17, 18], 62 | [14, 19], [19, 20], [20, 21], [21, 22], 63 | [14, 23], [23, 24], [24, 25], [25, 26] 64 | ] 65 | kinematic_chain_old = [ 66 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], 67 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], 68 | [0, 11], [11, 12], [12, 13], [13, 14], 69 | [12, 15], [15, 16], [16, 17], 70 | [12, 18], [18, 19], [19, 20], 71 | [17, 21], [21, 22], [22, 23], [23, 24], [22, 25], [25, 26], [22, 27], 72 | [27, 28], [22, 29], [29, 30], [22, 31], [31, 32], 73 | [20, 33], [33, 34], [34, 35], [35, 36], [34, 37], [37, 38], [34, 39], [39, 40], 74 | [34, 41], [41, 42], [34, 43], [43, 44] 75 | ] 76 | def fig2data(fig): 77 | fig.canvas.draw() 78 | w, h = fig.canvas.get_width_height() 79 | buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) 80 | buf.shape = (w, h, 4) 81 | buf = np.roll(buf, 3, axis=2) 82 | return buf 83 | 84 | def fig2img(fig): 85 | buf = fig2data(fig) 86 | w, h, d = buf.shape 87 | return Image.frombytes('RGBA', (w, h), buf.tostring()) 88 | 89 | def plot_contacts3D(pose1, pose2=None, gt_pose2=None, savepath=None, kinematic_chain = 'full', onlyone=False, gif=False): 90 | 91 | def plot_twoperson(pose1, pose2, i, kinematic_chain, savepath, gt_pose2=None): 92 | fig = plt.figure() 93 | 94 | ax = plt.subplot(projection='3d') 95 | 96 | ax.cla() 97 | ax.set_xlabel("x") 98 | ax.set_ylabel("y") 99 | ax.set_zlabel("z") 100 | ax.set_xlim3d([-1000, 2000]) 101 | ax.set_zlim3d([-1000, 2000]) 102 | ax.set_ylim3d([-1000, 2000]) 103 | 
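# Fixed, identical limits on all three axes (presumably to keep both skeletons in frame
# throughout the clip); the axes are then hidden and the camera orientation fixed below.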
ax.axis('off') 104 | ax.view_init(elev=0, azim=0, roll=90) 105 | if kinematic_chain == 'full': 106 | KINEMATIC_CHAIN = kinematic_chain_full 107 | elif kinematic_chain == 'no_fingers': 108 | KINEMATIC_CHAIN = kinematic_chain_short 109 | elif kinematic_chain == 'reduced': 110 | KINEMATIC_CHAIN = kinematic_chain_reduced 111 | elif kinematic_chain == 'old': 112 | KINEMATIC_CHAIN = kinematic_chain_old 113 | 114 | for limb in KINEMATIC_CHAIN: 115 | xs = [pose1[i, limb[0], 0], pose1[i, limb[1], 0]] 116 | ys = [pose1[i, limb[0], 1], pose1[i, limb[1], 1]] 117 | zs = [pose1[i, limb[0], 2], pose1[i, limb[1], 2]] 118 | # if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE: 119 | # ax.plot(xs, ys, zs, 'darkred', linewidth=2.0) 120 | # else: 121 | # ax.plot(xs, ys, zs, 'red', linewidth=2.0) 122 | ax.plot(xs, ys, zs, 'red', linewidth=2.0) 123 | 124 | xs_ = [pose2[i, limb[0], 0], pose2[i, limb[1], 0]] 125 | ys_ = [pose2[i, limb[0], 1], pose2[i, limb[1], 1]] 126 | zs_ = [pose2[i, limb[0], 2], pose2[i, limb[1], 2]] 127 | # if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE: 128 | # ax.plot(xs_, ys_, zs_, 'darkblue', linewidth=2.0) 129 | # else: 130 | # ax.plot(xs_, ys_, zs_, 'blue', linewidth=2.0) 131 | ax.plot(xs_, ys_, zs_, 'blue', linewidth=2.0) 132 | if gt_pose2 is not None: 133 | gt_xs_ = [gt_pose2[i, limb[0], 0], gt_pose2[i, limb[1], 0]] 134 | gt_ys_ = [gt_pose2[i, limb[0], 1], gt_pose2[i, limb[1], 1]] 135 | gt_zs_ = [gt_pose2[i, limb[0], 2], gt_pose2[i, limb[1], 2]] 136 | ax.plot(gt_xs_, gt_ys_, gt_zs_, 'g', linewidth=1.0) 137 | # min_x = min(min(pose1[i, :, 2]), min(pose2[i, :, 2])) - 100 138 | # min_y = min(min(pose1[i, :, 0]), min(pose2[i, :, 0])) - 100 139 | # max_x = max(max(pose1[i, :, 2]), max(pose2[i, :, 2])) + 100 140 | # max_y = max(max(pose1[i, :, 0]), max(pose2[i, :, 0])) + 100 141 | # x_pl, y_pl = np.meshgrid(np.linspace(min_x, max_x, 10), np.linspace(min_y, max_y, 10)) 142 | # foot_ground_contact_p1 =min(pose1[i, :, 1]) 143 | # foot_ground_contact_2 =min(pose2[i, :, 1]) 144 | # ground_plane = min(foot_ground_contact_p1, foot_ground_contact_2) 145 | # z_pl = torch.ones((10, 10)) * ground_plane 146 | # ax.plot_surface(x_pl, y_pl, z_pl, color= 'y', alpha=0.1) 147 | filename = makepath(os.path.join(savepath, str(i) +'.png'), isfile=True) 148 | plt.savefig(filename) 149 | plt.close() 150 | 151 | 152 | def plot_oneperson(pose2, i, kinematic_chain, savepath): 153 | # fig = plt.figure() 154 | ax = plt.subplot(projection='3d') 155 | ax.cla() 156 | ax.set_xlabel("x") 157 | ax.set_ylabel("y") 158 | ax.set_zlabel("z") 159 | 160 | ax.axis('off') 161 | ax.view_init(elev=0, azim=0, roll=0) 162 | if kinematic_chain == 'full': 163 | KINEMATIC_CHAIN = kinematic_chain_full 164 | elif kinematic_chain == 'no_fingers': 165 | KINEMATIC_CHAIN = kinematic_chain_short 166 | 167 | for limb in KINEMATIC_CHAIN: 168 | ys_ = [pose2[i, limb[0], 0], pose2[i, limb[1], 0]] 169 | zs_ = [pose2[i, limb[0], 1], pose2[i, limb[1], 1]] 170 | xs_ = [pose2[i, limb[0], 2], pose2[i, limb[1], 2]] 171 | if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE: 172 | ax.plot(xs_, ys_, zs_, 'darkred', linewidth=3.0) 173 | else: 174 | ax.plot(xs_, ys_, zs_, 'red', linewidth=3.0) 175 | filename = makepath(os.path.join(savepath, str(i) +'.png'), isfile=True) 176 | plt.savefig(filename) 177 | # plt.pause(0.001) 178 | plt.close() 179 | 180 | T = pose1.shape[0] 181 | is_interpolate = 0 182 | if is_interpolate: 183 | T1 = 3*T 184 | p1_x_interp =np.zeros((T1, pose1.shape[1])) 185 | p1_y_interp =np.zeros((T1, pose1.shape[1])) 186 | 
p1_z_interp =np.zeros((T1, pose1.shape[1])) 187 | p2_x_interp =np.zeros((T1, pose2.shape[1])) 188 | p2_y_interp =np.zeros((T1, pose2.shape[1])) 189 | p2_z_interp =np.zeros((T1, pose2.shape[1])) 190 | 191 | x = np.linspace(0, T-1 ,T) 192 | x_new = np.linspace(0, T-1 ,T1) 193 | for v1 in range(0, pose1.shape[1]): 194 | p1_x = pose1[:, v1, 0] 195 | p1_y = pose1[:, v1, 1] 196 | p1_z = pose1[:, v1, 2] 197 | p2_x = pose2[:, v1, 0] 198 | p2_y = pose2[:, v1, 1] 199 | p2_z = pose2[:, v1, 2] 200 | f_p1x = interpolate.interp1d(x, p1_x, kind = 'linear') 201 | f_p1y = interpolate.interp1d(x, p1_y, kind = 'linear') 202 | f_p1z = interpolate.interp1d(x, p1_z, kind = 'linear') 203 | f_p2x = interpolate.interp1d(x, p2_x, kind = 'linear') 204 | f_p2y = interpolate.interp1d(x, p2_y, kind = 'linear') 205 | f_p2z = interpolate.interp1d(x, p2_z, kind = 'linear') 206 | p1_x_interp[:, v1] = f_p1x(x_new) 207 | p1_y_interp[:, v1] = f_p1y(x_new) 208 | p1_z_interp[:, v1] = f_p1z(x_new) 209 | p2_x_interp[:, v1] = f_p2x(x_new) 210 | p2_y_interp[:, v1] = f_p2y(x_new) 211 | p2_z_interp[:, v1] = f_p2z(x_new) 212 | p1_x_interp = torch.from_numpy(p1_x_interp).unsqueeze(2) 213 | p1_y_interp = torch.from_numpy(p1_y_interp).unsqueeze(2) 214 | p1_z_interp = torch.from_numpy(p1_z_interp).unsqueeze(2) 215 | p1_interp = torch.cat((p1_x_interp, p1_y_interp, p1_z_interp), dim=-1) 216 | p2_x_interp = torch.from_numpy(p2_x_interp).unsqueeze(2) 217 | p2_y_interp = torch.from_numpy(p2_y_interp).unsqueeze(2) 218 | p2_z_interp = torch.from_numpy(p2_z_interp).unsqueeze(2) 219 | p2_interp = torch.cat((p2_x_interp, p2_y_interp, p2_z_interp), dim=-1) 220 | # verts_all = torch.cat((torch.from_numpy(verts_all[0]).unsqueeze(0), p1_interp), dim=0) 221 | T = T1 222 | pose1 = p1_interp 223 | pose2 = p2_interp 224 | 225 | 226 | for i in range(pose1.shape[0]): 227 | if onlyone: 228 | plot_oneperson(pose1, i, kinematic_chain, savepath) 229 | else: 230 | plot_twoperson(pose1, pose2, i, kinematic_chain, savepath, gt_pose2) 231 | if gif: 232 | img2gif(savepath) 233 | else: 234 | img2video(savepath, fps=20) 235 | 236 | -------------------------------------------------------------------------------- /src/Lindyhop/models/MotionDiffuse_body.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 S-Lab 3 | """ 4 | 5 | import matplotlib.pylab as plt 6 | import random 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import layer_norm, nn 10 | import numpy as np 11 | from torch.nn import functional 12 | 13 | import math 14 | body = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] 15 | hand = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44] 16 | hand_full = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68] 17 | 18 | def heatmap2d(arr: np.ndarray): 19 | plt.imshow(arr, cmap='bwr') 20 | plt.clim(-10, 10) 21 | plt.colorbar() 22 | plt.show() 23 | 24 | class PositionalEncoding(nn.Module): 25 | def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False): 26 | super().__init__() 27 | self.batch_first = batch_first 28 | 29 | self.dropout = nn.Dropout(p=dropout) 30 | 31 | pe = torch.zeros(max_len, d_model) 32 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 33 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 34 
| pe[:, 0::2] = torch.sin(position * div_term) 35 | pe[:, 1::2] = torch.cos(position * div_term) 36 | 37 | 38 | for pos in range(max_len): 39 | for i in range(0, d_model-1, 2): 40 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model))) 41 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 42 | 43 | 44 | pe = pe.unsqueeze(0) # [1, max_len, d_model] 45 | 46 | self.register_buffer('pe', pe) 47 | 48 | def forward(self, x): 49 | x = x + self.pe[:, :x.shape[1], :] 50 | return self.dropout(x) 51 | 52 | 53 | def timestep_embedding(timesteps, dim, freqs): 54 | """ 55 | Create sinusoidal timestep embeddings. 56 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 57 | These may be fractional. 58 | :param dim: the dimension of the output. 59 | :param max_period: controls the minimum frequency of the embeddings. 60 | :return: an [N x dim] Tensor of positional embeddings. 61 | """ 62 | # freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device) 63 | 64 | # timesteps= timesteps.to('cpu') 65 | args = timesteps[:, None].float() * freqs[None] 66 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 67 | if dim % 2: 68 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 69 | return embedding 70 | 71 | 72 | def set_requires_grad(nets, requires_grad=False): 73 | """Set requies_grad for all the networks. 74 | 75 | Args: 76 | nets (nn.Module | list[nn.Module]): A list of networks or a single 77 | network. 78 | requires_grad (bool): Whether the networks require gradients or not 79 | """ 80 | if not isinstance(nets, list): 81 | nets = [nets] 82 | for net in nets: 83 | if net is not None: 84 | for param in net.parameters(): 85 | param.requires_grad = requires_grad 86 | 87 | 88 | def zero_module(module): 89 | """ 90 | Zero out the parameters of a module and return it. 
91 | """ 92 | for p in module.parameters(): 93 | p.detach().zero_() 94 | return module 95 | 96 | 97 | class StylizationBlock(nn.Module): 98 | 99 | def __init__(self, latent_dim, time_embed_dim, dropout): 100 | super().__init__() 101 | self.emb_layers = nn.Sequential( 102 | nn.SiLU(), 103 | nn.Linear(time_embed_dim, 2 * latent_dim), 104 | ) 105 | self.norm = nn.LayerNorm(latent_dim) 106 | self.out_layers = nn.Sequential( 107 | nn.SiLU(), 108 | nn.Dropout(p=dropout), 109 | zero_module(nn.Linear(latent_dim, latent_dim)), 110 | ) 111 | 112 | def forward(self, h, emb): 113 | """ 114 | h: B, T, D 115 | emb: B, D 116 | """ 117 | # B, 1, 2D 118 | emb_out = self.emb_layers(emb).unsqueeze(1) 119 | # scale: B, 1, D / shift: B, 1, D 120 | scale, shift = torch.chunk(emb_out, 2, dim=2) 121 | h = self.norm(h) * (1 + scale) + shift 122 | h = self.out_layers(h) 123 | return h 124 | 125 | 126 | 127 | class FFN(nn.Module): 128 | 129 | def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim): 130 | super().__init__() 131 | self.linear1 = nn.Linear(latent_dim, ffn_dim) 132 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) 133 | self.activation = nn.GELU() 134 | self.dropout = nn.Dropout(dropout) 135 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 136 | 137 | def forward(self, x, emb): 138 | y = self.linear2(self.dropout(self.activation(self.linear1(x)))) 139 | y = x + self.proj_out(y, emb) 140 | return y 141 | 142 | 143 | 144 | class TemporalSelfAttention(nn.Module): 145 | 146 | def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim): 147 | super().__init__() 148 | self.num_head = num_head 149 | self.norm = nn.LayerNorm(latent_dim) 150 | self.query = nn.Linear(latent_dim, latent_dim) 151 | self.key = nn.Linear(latent_dim, latent_dim) 152 | self.value = nn.Linear(latent_dim, latent_dim) 153 | self.dropout = nn.Dropout(dropout) 154 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 155 | 156 | def forward(self, x, emb, src_mask, eps=1e-8): 157 | """ 158 | x: B, T, D 159 | """ 160 | B, T, D = x.shape 161 | H = self.num_head 162 | # B, T, 1, D 163 | query = self.query(self.norm(x)).unsqueeze(2) 164 | # B, 1, T, D 165 | key = self.key(self.norm(x)).unsqueeze(1) 166 | query = query.view(B, T, H, -1) 167 | key = key.view(B, T, H, -1) 168 | # B, T, T, H 169 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps) 170 | attention = attention * src_mask.unsqueeze(-1) 171 | weight = self.dropout(F.softmax(attention, dim=2)) 172 | value = self.value(self.norm(x)).view(B, T, H, -1) 173 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D) 174 | y = x + self.proj_out(y, emb) 175 | return y, attention 176 | 177 | class TemporalCrossAttention(nn.Module): 178 | 179 | def __init__(self, seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim): 180 | super().__init__() 181 | self.num_head = num_head 182 | self.norm = nn.LayerNorm(latent_dim) 183 | self.mot1_norm = nn.LayerNorm(mot1_latent_dim) 184 | self.query = nn.Linear(latent_dim, latent_dim) 185 | self.key = nn.Linear(mot1_latent_dim, latent_dim) 186 | self.value = nn.Linear(mot1_latent_dim, latent_dim) 187 | self.dropout = nn.Dropout(dropout) 188 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 189 | 190 | def forward(self, x, xf, src_mask, emb, eps=1e-8): 191 | """ 192 | x: B, T, D 193 | xf: B, N, L 194 | """ 195 | B, T, D = x.shape 196 | N = xf.shape[1] 197 | H = self.num_head 198 | # B, T, 1, D 199 | 
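# Cross-attention: queries are computed from the (noisy) motion-2 stream x, while keys and
# values come from the encoded motion-1 condition xf; the lower-triangular src_mask limits
# which motion-1 tokens each position may attend to, and proj_out (a StylizationBlock)
# injects the diffusion-timestep embedding into the output.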
query = self.query(self.norm(x)).unsqueeze(2) 200 | # B, 1, N, D 201 | key = self.key(self.mot1_norm(xf)).unsqueeze(1) 202 | query = query.view(B, T, H, -1) 203 | key = key.view(B, N, H, -1) 204 | # B, T, N, H 205 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps) 206 | attention = attention * src_mask.unsqueeze(-1) 207 | weight = self.dropout(F.softmax(attention, dim=2)) 208 | value = self.value(self.mot1_norm(xf)).view(B, N, H, -1) 209 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D) 210 | y = x + self.proj_out(y, emb) 211 | return y, attention 212 | 213 | class TemporalDiffusionTransformerDecoderLayer(nn.Module): 214 | 215 | def __init__(self, 216 | seq_len=60, 217 | latent_dim=32, 218 | mot1_latent_dim=512, 219 | time_embed_dim=128, 220 | ffn_dim=256, 221 | num_head=4, 222 | dropout=0.1): 223 | super().__init__() 224 | self.sa_block = TemporalSelfAttention( 225 | seq_len, latent_dim, num_head, dropout, time_embed_dim) 226 | self.ca_block = TemporalCrossAttention( 227 | seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim) 228 | self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim) 229 | 230 | def forward(self, x, xf, emb, src_mask): 231 | x, s_attn = self.sa_block(x, emb, src_mask) 232 | x, c_attn = self.ca_block(x, xf, src_mask, emb) 233 | x = self.ffn(x, emb) 234 | return x, s_attn, c_attn 235 | 236 | 237 | class DiffusionTransformer(nn.Module): 238 | def __init__(self, 239 | device= 'cuda', 240 | num_jts=27, 241 | num_frames=100, 242 | input_feats=3, 243 | latent_dim=32, 244 | ff_size=1024, 245 | num_layers=8, 246 | num_heads=4, 247 | dropout=0.05, 248 | activations="gelu", 249 | **kargs): 250 | super().__init__() 251 | 252 | self.num_frames = num_frames 253 | self.num_jts = num_jts 254 | self.latent_dim = latent_dim 255 | self.ff_size = ff_size 256 | self.num_layers = num_layers 257 | self.num_heads = num_heads 258 | self.freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device) 259 | self.dropout = dropout 260 | self.activation = activations 261 | self.input_feats = input_feats 262 | self.time_embed_dim = latent_dim 263 | self.spatio_temp = self.num_frames * self.num_jts 264 | 265 | # encode motion 1 266 | self.motion1_pre_proj = nn.Linear(self.input_feats, self.latent_dim) 267 | self.m1_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp) 268 | mot1TransEncoderLayer = nn.TransformerEncoderLayer( 269 | d_model=latent_dim, 270 | nhead=num_heads, 271 | dim_feedforward=ff_size, 272 | dropout=dropout, 273 | batch_first=True, 274 | activation='gelu') 275 | self.mot1TransEncoder = nn.TransformerEncoder( 276 | mot1TransEncoderLayer, 277 | num_layers=2) 278 | self.mot1_ln = nn.LayerNorm(latent_dim) 279 | #Classifier-free guidance 280 | # self.null_cond = nn.Parameter(torch.randn(self.num_frames * self.num_jts, latent_dim)) 281 | 282 | # Time Embedding 283 | self.time_embed = nn.Sequential( 284 | nn.Linear(self.latent_dim, self.time_embed_dim), 285 | nn.SiLU(), 286 | nn.Linear(self.time_embed_dim, self.time_embed_dim), 287 | ) 288 | 289 | # motion2 decoding 290 | self.motion2_pre_proj = nn.Linear(self.input_feats, self.latent_dim) 291 | self.m2_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp) 292 | self.temporal_decoder_blocks = nn.ModuleList() 293 | for i in range(num_layers): 294 | 
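# Each decoder block applies temporal self-attention over the motion-2 tokens, cross-attention
# into the encoded motion-1 condition, and a feed-forward network; every sub-layer is modulated
# by the timestep embedding through its StylizationBlock.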
self.temporal_decoder_blocks.append( 295 | TemporalDiffusionTransformerDecoderLayer( 296 | seq_len=self.spatio_temp, 297 | latent_dim=latent_dim, 298 | mot1_latent_dim=latent_dim, 299 | time_embed_dim=self.time_embed_dim, 300 | ffn_dim=ff_size, 301 | num_head=num_heads, 302 | dropout=dropout 303 | ) 304 | ) 305 | # Output Module 306 | self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats)) 307 | 308 | 309 | def generate_src_mask(self, tgt): 310 | length = tgt.size(1) 311 | src_mask = (1 - torch.triu(torch.ones(1, length, length), diagonal=1)) 312 | return src_mask 313 | 314 | # def forward(self, motion2, timesteps, length=None, motion1=None, xf_out=None, contact_map=None): 315 | def forward(self, motion2, timesteps, motion1=None, contact_maps=None, spatial_guidance=None, guidance_scale=0): 316 | """ 317 | x: B, T, D 318 | """ 319 | B, T, J, _ = motion1.shape 320 | m1 = self.motion1_pre_proj(motion1) # GCN 321 | m2 = self.motion2_pre_proj(motion2) # GCN 322 | m1 = m1.reshape(B, T*J, -1) 323 | m2 = m2.reshape(B, T*J, -1) 324 | src_mask = self.generate_src_mask(m2).to(m2.device) 325 | 326 | m1_pe = self.m1_temporal_pos_encoder(m1) 327 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe)) 328 | # null_cond = torch.repeat_interleave( 329 | # self.null_cond.to(m2.device).unsqueeze(0), B, dim=0) 330 | # m1_enc = m1_cond if random.random() > 0.25 else null_cond 331 | m1_enc = m1_cond 332 | m2_pe = self.m2_temporal_pos_encoder(m2) 333 | emb = self.time_embed(timestep_embedding( 334 | timesteps, self.latent_dim, self.freqs) ) 335 | h_pe = m2_pe 336 | for module in self.temporal_decoder_blocks: 337 | h_pe, s_attn, c_attn = module(h_pe, m1_enc, emb, src_mask) 338 | output = self.out(h_pe).view(B, T, J, -1).contiguous() 339 | # if timesteps == 1: 340 | # c_attn_map = c_attn[0,:,:,0].cpu().detach().numpy() 341 | # # ax = sns.heatmap(c_attn_map, vmin=-15, vmax=15, linewidth=2) 342 | # # plt.show() 343 | # heatmap2d(c_attn_map) 344 | # tmp=1 345 | return output, s_attn, c_attn 346 | 347 | def get_motion_embedding(self, motion1): 348 | B, T, J, _ = motion1.shape 349 | m1 = self.motion1_pre_proj(motion1) # GCN 350 | m1 = m1.reshape(B, T*J, -1) 351 | m1_pe = self.m1_temporal_pos_encoder(m1) 352 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe)) 353 | return m1_cond -------------------------------------------------------------------------------- /src/tools/common/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(float).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(np.float).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | 
else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | 128 | def qmul_np(q, r): 129 | q = torch.from_numpy(q).contiguous().float() 130 | r = torch.from_numpy(r).contiguous().float() 131 | return qmul(q, r).numpy() 132 | 133 | 134 | def qrot_np(q, v): 135 | q = torch.from_numpy(q).contiguous().float() 136 | v = torch.from_numpy(v).contiguous().float() 137 | return qrot(q, v).numpy() 138 | 139 | 140 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 141 | if use_gpu: 142 | q = torch.from_numpy(q).cuda().float() 143 | return qeuler(q, order, epsilon).cpu().numpy() 144 | else: 145 | q = torch.from_numpy(q).contiguous().float() 146 | return qeuler(q, order, epsilon).numpy() 147 | 148 | 149 | def qfix(q): 150 | """ 151 | Enforce quaternion continuity across the time dimension by selecting 152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 153 | between two consecutive frames. 154 | 155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 156 | Returns a tensor of the same shape. 157 | """ 158 | assert len(q.shape) == 3 159 | assert q.shape[-1] == 4 160 | 161 | result = q.copy() 162 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 163 | mask = dot_products < 0 164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 165 | result[1:][mask] *= -1 166 | return result 167 | 168 | 169 | def euler2quat(e, order, deg=True): 170 | """ 171 | Convert Euler angles to quaternions. 172 | """ 173 | assert e.shape[-1] == 3 174 | 175 | original_shape = list(e.shape) 176 | original_shape[-1] = 4 177 | 178 | e = e.view(-1, 3) 179 | 180 | ## if euler angles in degrees 181 | if deg: 182 | e = e * np.pi / 180. 183 | 184 | x = e[:, 0] 185 | y = e[:, 1] 186 | z = e[:, 2] 187 | 188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 191 | 192 | result = None 193 | for coord in order: 194 | if coord == 'x': 195 | r = rx 196 | elif coord == 'y': 197 | r = ry 198 | elif coord == 'z': 199 | r = rz 200 | else: 201 | raise 202 | if result is None: 203 | result = r 204 | else: 205 | result = qmul(result, r) 206 | 207 | # Reverse antipodal representation to have a non-negative "w" 208 | if order in ['xyz', 'yzx', 'zxy']: 209 | result *= -1 210 | 211 | return result.view(original_shape) 212 | 213 | 214 | def expmap_to_quaternion(e): 215 | """ 216 | Convert axis-angle rotations (aka exponential maps) to quaternions. 217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 219 | Returns a tensor of shape (*, 4). 220 | """ 221 | assert e.shape[-1] == 3 222 | 223 | original_shape = list(e.shape) 224 | original_shape[-1] = 4 225 | e = e.reshape(-1, 3) 226 | 227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 228 | w = np.cos(0.5 * theta).reshape(-1, 1) 229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 231 | 232 | 233 | def euler_to_quaternion(e, order): 234 | """ 235 | Convert Euler angles to quaternions. 
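    NumPy counterpart of euler2quat: expects an array of shape (*, 3) holding angles
    in radians, composed in the axis order given by `order` (e.g. 'xyz'), and returns
    an array of shape (*, 4) with the scalar part first.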
236 | """ 237 | assert e.shape[-1] == 3 238 | 239 | original_shape = list(e.shape) 240 | original_shape[-1] = 4 241 | 242 | e = e.reshape(-1, 3) 243 | 244 | x = e[:, 0] 245 | y = e[:, 1] 246 | z = e[:, 2] 247 | 248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 251 | 252 | result = None 253 | for coord in order: 254 | if coord == 'x': 255 | r = rx 256 | elif coord == 'y': 257 | r = ry 258 | elif coord == 'z': 259 | r = rz 260 | else: 261 | raise 262 | if result is None: 263 | result = r 264 | else: 265 | result = qmul_np(result, r) 266 | 267 | # Reverse antipodal representation to have a non-negative "w" 268 | if order in ['xyz', 'yzx', 'zxy']: 269 | result *= -1 270 | 271 | return result.reshape(original_shape) 272 | 273 | 274 | def quaternion_to_matrix(quaternions): 275 | """ 276 | Convert rotations given as quaternions to rotation matrices. 277 | Args: 278 | quaternions: quaternions with real part first, 279 | as tensor of shape (..., 4). 280 | Returns: 281 | Rotation matrices as tensor of shape (..., 3, 3). 282 | """ 283 | r, i, j, k = torch.unbind(quaternions, -1) 284 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 285 | 286 | o = torch.stack( 287 | ( 288 | 1 - two_s * (j * j + k * k), 289 | two_s * (i * j - k * r), 290 | two_s * (i * k + j * r), 291 | two_s * (i * j + k * r), 292 | 1 - two_s * (i * i + k * k), 293 | two_s * (j * k - i * r), 294 | two_s * (i * k - j * r), 295 | two_s * (j * k + i * r), 296 | 1 - two_s * (i * i + j * j), 297 | ), 298 | -1, 299 | ) 300 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 301 | 302 | 303 | def quaternion_to_matrix_np(quaternions): 304 | q = torch.from_numpy(quaternions).contiguous().float() 305 | return quaternion_to_matrix(q).numpy() 306 | 307 | 308 | def quaternion_to_cont6d_np(quaternions): 309 | rotation_mat = quaternion_to_matrix_np(quaternions) 310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 311 | return cont_6d 312 | 313 | 314 | def quaternion_to_cont6d(quaternions): 315 | rotation_mat = quaternion_to_matrix(quaternions) 316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 317 | return cont_6d 318 | 319 | 320 | def cont6d_to_matrix(cont6d): 321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 322 | x_raw = cont6d[..., 0:3] 323 | y_raw = cont6d[..., 3:6] 324 | 325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 326 | z = torch.cross(x, y_raw, dim=-1) 327 | z = z / torch.norm(z, dim=-1, keepdim=True) 328 | 329 | y = torch.cross(z, x, dim=-1) 330 | 331 | x = x[..., None] 332 | y = y[..., None] 333 | z = z[..., None] 334 | 335 | mat = torch.cat([x, y, z], dim=-1) 336 | return mat 337 | 338 | 339 | def cont6d_to_matrix_np(cont6d): 340 | q = torch.from_numpy(cont6d).contiguous().float() 341 | return cont6d_to_matrix(q).numpy() 342 | 343 | 344 | def qpow(q0, t, dtype=torch.float): 345 | ''' q0 : tensor of quaternions 346 | t: tensor of powers 347 | ''' 348 | q0 = qnormalize(q0) 349 | theta0 = torch.acos(q0[..., 0]) 350 | 351 | ## if theta0 is close to zero, add epsilon to avoid NaNs 352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 353 | theta0 = (1 - mask) * theta0 + mask * 10e-10 354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 355 | 356 | if isinstance(t, torch.Tensor): 357 | q = torch.zeros(t.shape + 
q0.shape) 358 | theta = t.view(-1, 1) * theta0.view(1, -1) 359 | else: ## if t is a number 360 | q = torch.zeros(q0.shape) 361 | theta = t * theta0 362 | 363 | q[..., 0] = torch.cos(theta) 364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 365 | 366 | return q.to(dtype) 367 | 368 | 369 | def qslerp(q0, q1, t): 370 | ''' 371 | q0: starting quaternion 372 | q1: ending quaternion 373 | t: array of points along the way 374 | 375 | Returns: 376 | Tensor of Slerps: t.shape + q0.shape 377 | ''' 378 | 379 | q0 = qnormalize(q0) 380 | q1 = qnormalize(q1) 381 | q_ = qpow(qmul(q1, qinv(q0)), t) 382 | 383 | return qmul(q_, 384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 385 | 386 | 387 | def qbetween(v0, v1): 388 | ''' 389 | find the quaternion used to rotate v0 to v1 390 | ''' 391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 393 | 394 | v = torch.cross(v0, v1) 395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, 396 | keepdim=True) 397 | return qnormalize(torch.cat([w, v], dim=-1)) 398 | 399 | 400 | def qbetween_np(v0, v1): 401 | ''' 402 | find the quaternion used to rotate v0 to v1 403 | ''' 404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 406 | 407 | v0 = torch.from_numpy(v0).float() 408 | v1 = torch.from_numpy(v1).float() 409 | return qbetween(v0, v1).numpy() 410 | 411 | 412 | def lerp(p0, p1, t): 413 | if not isinstance(t, torch.Tensor): 414 | t = torch.Tensor([t]) 415 | 416 | new_shape = t.shape + p0.shape 417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 419 | p0 = p0.view(new_view_p).expand(new_shape) 420 | p1 = p1.view(new_view_p).expand(new_shape) 421 | t = t.view(new_view_t).expand(new_shape) 422 | 423 | return p0 + t * (p1 - p0) 424 | -------------------------------------------------------------------------------- /src/Lindyhop/models/MotionDiffusion_hand.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 S-Lab 3 | """ 4 | 5 | import matplotlib.pylab as plt 6 | import random 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import layer_norm, nn 10 | import numpy as np 11 | from torch.nn import functional 12 | 13 | import math 14 | body = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] 15 | hand = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44] 16 | hand_full = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68] 17 | 18 | def heatmap2d(arr: np.ndarray): 19 | plt.imshow(arr, cmap='viridis') 20 | plt.clim(-10, 10) 21 | plt.colorbar() 22 | plt.show() 23 | 24 | def norm_array(x): 25 | return (x-np.min(x))/(np.max(x)-np.min(x)) 26 | 27 | class PositionalEncoding(nn.Module): 28 | def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False): 29 | super().__init__() 30 | self.batch_first = batch_first 31 | 32 | self.dropout = nn.Dropout(p=dropout) 33 | 34 | pe = torch.zeros(max_len, d_model) 35 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 36 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
(-np.log(10000.0) / d_model)) 37 | pe[:, 0::2] = torch.sin(position * div_term) 38 | pe[:, 1::2] = torch.cos(position * div_term) 39 | 40 | 41 | for pos in range(max_len): 42 | for i in range(0, d_model-1, 2): 43 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model))) 44 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 45 | 46 | 47 | pe = pe.unsqueeze(0) # [1, max_len, d_model] 48 | 49 | self.register_buffer('pe', pe) 50 | 51 | def forward(self, x): 52 | x = x + self.pe[:, :x.shape[1], :] 53 | return self.dropout(x) 54 | 55 | 56 | def timestep_embedding(timesteps, dim, freqs): 57 | """ 58 | Create sinusoidal timestep embeddings. 59 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 60 | These may be fractional. 61 | :param dim: the dimension of the output. 62 | :param max_period: controls the minimum frequency of the embeddings. 63 | :return: an [N x dim] Tensor of positional embeddings. 64 | """ 65 | # freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device) 66 | 67 | # timesteps= timesteps.to('cpu') 68 | args = timesteps[:, None].float() * freqs[None] 69 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 70 | if dim % 2: 71 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 72 | return embedding 73 | 74 | 75 | def set_requires_grad(nets, requires_grad=False): 76 | """Set requies_grad for all the networks. 77 | 78 | Args: 79 | nets (nn.Module | list[nn.Module]): A list of networks or a single 80 | network. 81 | requires_grad (bool): Whether the networks require gradients or not 82 | """ 83 | if not isinstance(nets, list): 84 | nets = [nets] 85 | for net in nets: 86 | if net is not None: 87 | for param in net.parameters(): 88 | param.requires_grad = requires_grad 89 | 90 | 91 | def zero_module(module): 92 | """ 93 | Zero out the parameters of a module and return it. 
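    Used here on the last Linear of StylizationBlock, FFN and the output head, so those
    residual branches start out as identity mappings (and the denoiser output starts at
    zero) at initialization.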
94 | """ 95 | for p in module.parameters(): 96 | p.detach().zero_() 97 | return module 98 | 99 | 100 | class StylizationBlock(nn.Module): 101 | 102 | def __init__(self, latent_dim, time_embed_dim, dropout): 103 | super().__init__() 104 | self.emb_layers = nn.Sequential( 105 | nn.SiLU(), 106 | nn.Linear(time_embed_dim, 2 * latent_dim), 107 | ) 108 | self.norm = nn.LayerNorm(latent_dim) 109 | self.out_layers = nn.Sequential( 110 | nn.SiLU(), 111 | nn.Dropout(p=dropout), 112 | zero_module(nn.Linear(latent_dim, latent_dim)), 113 | ) 114 | 115 | def forward(self, h, emb): 116 | """ 117 | h: B, T, D 118 | emb: B, D 119 | """ 120 | # B, 1, 2D 121 | emb_out = self.emb_layers(emb).unsqueeze(1) 122 | # scale: B, 1, D / shift: B, 1, D 123 | scale, shift = torch.chunk(emb_out, 2, dim=2) 124 | h = self.norm(h) * (1 + scale) + shift 125 | h = self.out_layers(h) 126 | return h 127 | 128 | 129 | 130 | class FFN(nn.Module): 131 | 132 | def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim): 133 | super().__init__() 134 | self.linear1 = nn.Linear(latent_dim, ffn_dim) 135 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) 136 | self.activation = nn.GELU() 137 | self.dropout = nn.Dropout(dropout) 138 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 139 | 140 | def forward(self, x, emb): 141 | y = self.linear2(self.dropout(self.activation(self.linear1(x)))) 142 | y = x + self.proj_out(y, emb) 143 | return y 144 | 145 | 146 | 147 | class TemporalSelfAttention(nn.Module): 148 | 149 | def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim): 150 | super().__init__() 151 | self.num_head = num_head 152 | self.norm = nn.LayerNorm(latent_dim) 153 | self.query = nn.Linear(latent_dim, latent_dim) 154 | self.key = nn.Linear(latent_dim, latent_dim) 155 | self.value = nn.Linear(latent_dim, latent_dim) 156 | self.dropout = nn.Dropout(dropout) 157 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 158 | 159 | def forward(self, x, emb, src_mask, eps=1e-8): 160 | """ 161 | x: B, T, D 162 | """ 163 | B, T, D = x.shape 164 | H = self.num_head 165 | # B, T, 1, D 166 | query = self.query(self.norm(x)).unsqueeze(2) 167 | # B, 1, T, D 168 | key = self.key(self.norm(x)).unsqueeze(1) 169 | query = query.view(B, T, H, -1) 170 | key = key.view(B, T, H, -1) 171 | # B, T, T, H 172 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps) 173 | # attention = attention * src_mask.unsqueeze(-1) 174 | weight = self.dropout(F.softmax(attention, dim=2)) 175 | value = self.value(self.norm(x)).view(B, T, H, -1) 176 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D) 177 | y = x + self.proj_out(y, emb) 178 | return y, attention 179 | 180 | class TemporalCrossAttention(nn.Module): 181 | 182 | def __init__(self, seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim): 183 | super().__init__() 184 | self.num_head = num_head 185 | self.norm = nn.LayerNorm(latent_dim) 186 | self.mot1_norm = nn.LayerNorm(mot1_latent_dim) 187 | self.query = nn.Linear(latent_dim, latent_dim) 188 | self.key = nn.Linear(mot1_latent_dim, latent_dim) 189 | self.value = nn.Linear(mot1_latent_dim, latent_dim) 190 | self.dropout = nn.Dropout(dropout) 191 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 192 | 193 | def forward(self, x, xf, src_mask, emb, eps=1e-8): 194 | """ 195 | x: B, T, D 196 | xf: B, N, L 197 | """ 198 | B, T, D = x.shape 199 | N = xf.shape[1] 200 | H = self.num_head 201 | # B, T, 1, D 202 | 
query = self.query(self.norm(x)).unsqueeze(2) 203 | # B, 1, N, D 204 | key = self.key(self.mot1_norm(xf)).unsqueeze(1) 205 | query = query.view(B, T, H, -1) 206 | key = key.view(B, N, H, -1) 207 | # B, T, N, H 208 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps) 209 | attention = attention * src_mask 210 | weight = self.dropout(F.softmax(attention, dim=2)) 211 | value = self.value(self.mot1_norm(xf)).view(B, N, H, -1) 212 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D) 213 | y = x + self.proj_out(y, emb) 214 | return y, attention 215 | 216 | class TemporalDiffusionTransformerDecoderLayer(nn.Module): 217 | 218 | def __init__(self, 219 | seq_len=60, 220 | latent_dim=32, 221 | mot1_latent_dim=512, 222 | time_embed_dim=128, 223 | ffn_dim=256, 224 | num_head=4, 225 | dropout=0.1): 226 | super().__init__() 227 | self.sa_block = TemporalSelfAttention( 228 | seq_len, latent_dim, num_head, dropout, time_embed_dim) 229 | self.ca_block = TemporalCrossAttention( 230 | seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim) 231 | self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim) 232 | 233 | def forward(self, x, xf, emb, src_mask): 234 | x, s_attn = self.sa_block(x, emb, src_mask) 235 | x, c_attn = self.ca_block(x, xf, src_mask, emb) 236 | x = self.ffn(x, emb) 237 | return x, s_attn, c_attn 238 | 239 | 240 | class DiffusionTransformer(nn.Module): 241 | def __init__(self, 242 | device= 'cuda', 243 | num_frames=100, 244 | num_jts = 11, 245 | input_condn_feats=3, 246 | input_feats=3, 247 | latent_dim=32, 248 | ff_size=1024, 249 | num_layers=8, 250 | num_heads=4, 251 | dropout=0.05, 252 | activations="gelu", 253 | **kargs): 254 | super().__init__() 255 | 256 | self.num_frames = num_frames 257 | self.num_jts = num_jts 258 | self.latent_dim = latent_dim 259 | self.ff_size = ff_size 260 | self.num_layers = num_layers 261 | self.num_heads = num_heads 262 | self.freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device) 263 | self.dropout = dropout 264 | self.activation = activations 265 | self.input_condn_feats = input_condn_feats 266 | self.input_feats = input_feats 267 | self.time_embed_dim = latent_dim 268 | self.spatio_temp = self.num_frames * 2 * self.num_jts 269 | 270 | # encode motion 1 271 | self.motion1_pre_proj = nn.Linear(self.input_feats, self.latent_dim) 272 | self.m1_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp) 273 | mot1TransEncoderLayer = nn.TransformerEncoderLayer( 274 | d_model=latent_dim, 275 | nhead=num_heads, 276 | dim_feedforward=ff_size, 277 | dropout=dropout, 278 | batch_first=True, 279 | activation='gelu') 280 | self.mot1TransEncoder = nn.TransformerEncoder( 281 | mot1TransEncoderLayer, 282 | num_layers=2) 283 | self.mot1_ln = nn.LayerNorm(latent_dim) 284 | #Classifier-free guidance 285 | # self.null_cond = nn.Parameter(torch.randn(self.num_frames * self.num_jts, latent_dim)) 286 | 287 | # Time Embedding 288 | self.time_embed = nn.Sequential( 289 | nn.Linear(self.latent_dim, self.time_embed_dim), 290 | nn.SiLU(), 291 | nn.Linear(self.time_embed_dim, self.time_embed_dim), 292 | ) 293 | 294 | # motion2 decoding 295 | self.motion2_pre_proj = nn.Linear(self.input_feats, self.latent_dim) 296 | self.m2_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp) 297 | self.temporal_decoder_blocks = 
nn.ModuleList() 298 | for i in range(num_layers): 299 | self.temporal_decoder_blocks.append( 300 | TemporalDiffusionTransformerDecoderLayer( 301 | seq_len=self.spatio_temp, 302 | latent_dim=latent_dim, 303 | mot1_latent_dim=latent_dim, 304 | time_embed_dim=self.time_embed_dim, 305 | ffn_dim=ff_size, 306 | num_head=num_heads, 307 | dropout=dropout 308 | ) 309 | ) 310 | # Output Module 311 | self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats)) 312 | 313 | 314 | def generate_src_mask(self, tgt): 315 | length = tgt.size(1) 316 | src_mask = (1 - torch.triu(torch.ones(1, length, length), diagonal=1)) 317 | return src_mask 318 | 319 | # def forward(self, motion2, timesteps, length=None, motion1=None, xf_out=None, contact_map=None): 320 | def forward(self, motion2, timesteps, motion1=None, spatial_guidance=None): 321 | """ 322 | x: B, T, D 323 | """ 324 | B, T, D = motion1.shape 325 | contact_distance = motion1[:, :, -4:] 326 | rh_rh = contact_distance[:,:, 0] == 1 327 | rh_lh = contact_distance[:,:, 1] == 1 328 | lh_rh = contact_distance[:,:, 2] == 1 329 | lh_lh = contact_distance[:,:, 3] == 1 330 | rh_pose1 = torch.logical_or(rh_lh, rh_rh) 331 | lh_pose1 = torch.logical_or(lh_lh, lh_rh) 332 | m1 = self.motion1_pre_proj(motion1[:, :, :-4].reshape(B, T, -1, 3)) # GCN 333 | m2 = self.motion2_pre_proj(motion2.reshape(B, T, -1, 3)) # GCN 334 | m1 = m1.reshape(B, T*2*self.num_jts, -1) 335 | m2 = m2.reshape(B, T*2*self.num_jts, -1) 336 | src_mask = torch.zeros(B, T, self.num_jts*2).to(m1.device).float() 337 | src_mask[:, :, 0] = rh_pose1 338 | src_mask[:, :, 1] = rh_pose1 339 | src_mask[:, :, 2] = rh_pose1 340 | src_mask[:, :, 3] = rh_pose1 341 | src_mask[:, :, 4] = rh_pose1 342 | src_mask[:, :, 5] = rh_pose1 343 | src_mask[:, :, 6] = rh_pose1 344 | src_mask[:, :, 7] = rh_pose1 345 | src_mask[:, :, 8] = rh_pose1 346 | src_mask[:, :, 9] = rh_pose1 347 | src_mask[:, :, 10] = rh_pose1 348 | src_mask[:, :, 11] = lh_pose1 349 | src_mask[:, :, 12] = lh_pose1 350 | src_mask[:, :, 13] = lh_pose1 351 | src_mask[:, :, 14] = lh_pose1 352 | src_mask[:, :, 15] = lh_pose1 353 | src_mask[:, :, 16] = lh_pose1 354 | src_mask[:, :, 17] = lh_pose1 355 | src_mask[:, :, 18] = lh_pose1 356 | src_mask[:, :, 19] = lh_pose1 357 | src_mask[:, :, 20] = lh_pose1 358 | src_mask[:, :, 21] = lh_pose1 359 | src_mask = src_mask.reshape(B, -1) 360 | src_mask = torch.repeat_interleave(src_mask.unsqueeze(-1), src_mask.shape[1], axis=-1) 361 | src_mask = torch.repeat_interleave(src_mask.unsqueeze(-1), self.num_heads, axis=-1) 362 | # src_mask = self.generate_src_mask(m2).to(m2.device) 363 | 364 | m1_pe = self.m1_temporal_pos_encoder(m1) 365 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe)) 366 | # null_cond = torch.repeat_interleave( 367 | # self.null_cond.to(m2.device).unsqueeze(0), B, dim=0) 368 | # m1_enc = m1_cond if random.random() > 0.25 else null_cond 369 | m1_enc = m1_cond 370 | m2_pe = self.m2_temporal_pos_encoder(m2) 371 | emb = self.time_embed(timestep_embedding( 372 | timesteps, self.latent_dim, self.freqs) ) 373 | h_pe = m2_pe 374 | for module in self.temporal_decoder_blocks: 375 | h_pe, s_attn, c_attn = module(h_pe, m1_enc, emb, src_mask) 376 | output = self.out(h_pe).view(B, T, -1).contiguous() 377 | # if timesteps == 400: 378 | # c_attn_map = c_attn[0,:,:,0].cpu().detach().numpy() 379 | # # ax = sns.heatmap(c_attn_map, vmin=-15, vmax=15, linewidth=2) 380 | # # plt.show() 381 | # heatmap2d(c_attn_map) 382 | # tmp=1 383 | return output, s_attn, c_attn 384 | 
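# ------------------------------------------------------------------------------------
# Minimal shape sketch (illustrative only, not part of the training pipeline): it shows
# the tensor layout DiffusionTransformer.forward expects. The batch size, frame count
# and timestep range below are arbitrary; per hand there are num_jts joints, so motion1
# packs 2 * num_jts * 3 conditioning coordinates plus the 4-channel wrist contact map,
# while motion2 packs the 2 * num_jts * 3 coordinates being denoised.
# ------------------------------------------------------------------------------------
if __name__ == '__main__':
    B, T, num_jts = 2, 25, 11                                    # hypothetical sizes
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = DiffusionTransformer(device=device, num_frames=T, num_jts=num_jts,
                                 input_feats=3, latent_dim=32, num_layers=2).to(device)
    motion1 = torch.randn(B, T, 2 * num_jts * 3 + 4, device=device)
    motion1[..., -4:] = (motion1[..., -4:] > 0).float()          # binary contact channels
    motion2 = torch.randn(B, T, 2 * num_jts * 3, device=device)  # noised target hand pose
    t = torch.randint(0, 1000, (B,), device=device)              # diffusion timesteps
    out, s_attn, c_attn = model(motion2, t, motion1=motion1)
    print(out.shape)                                             # (B, T, 2 * num_jts * 3)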
-------------------------------------------------------------------------------- /src/tools/bookkeeper.py: -------------------------------------------------------------------------------- 1 | # import pickle as torch 2 | import json 3 | import os 4 | import sys 5 | from datetime import datetime 6 | from tqdm import tqdm 7 | import copy 8 | import random 9 | import numpy as np 10 | from pathlib import Path 11 | import argparse 12 | import argunparse 13 | import warnings 14 | from prettytable import PrettyTable 15 | 16 | # from tensorboardX import SummaryWriter 17 | import torch 18 | 19 | import pdb 20 | 21 | def get_args_update_dict(args): 22 | args_update_dict = {} 23 | for string in sys.argv: 24 | string = ''.join(string.split('-')) 25 | if string in args: 26 | args_update_dict.update({string: args.__dict__[string]}) 27 | return args_update_dict 28 | 29 | def accumulate_grads(model, grads_list): 30 | if grads_list: 31 | grads_list = [param.grad.data+old_grad.clone() for param, old_grad in zip(model.parameters(), grads_list)] 32 | else: 33 | grads_list += [param.grad.data for param in model.parameters()] 34 | return grads_list 35 | 36 | def save_grads(val, file_path): 37 | torch.save(val, open(file_path, 'wb')) 38 | 39 | def load_grads(file_path): 40 | return torch.load(open(file_path)) 41 | 42 | class TensorboardWrapper(): 43 | ''' 44 | Wrapper to add values to tensorboard using a dictionary of values 45 | ''' 46 | def __init__(self, log_dir): 47 | self.log_dir = log_dir 48 | self.writer = SummaryWriter(log_dir=self.log_dir, comment='NA') 49 | 50 | def __call__(self, write_dict): 51 | for key in write_dict: 52 | for value in write_dict[key]: 53 | getattr(self.writer, 'add_' + key)(*value) 54 | 55 | class BookKeeper(): 56 | '''BookKeeper 57 | if load_pretrained_model = True 58 | bookKeeper will not update args and will also call _new_exp 59 | 60 | TODO: add documentation 61 | TODO: add save_optimizer_args as well 62 | TODO: choice of score kind to decide early-stopping (currently dev is default) 63 | Required properties in args 64 | - load 65 | - seed 66 | - save_dir 67 | - num_epochs 68 | - cuda 69 | - save_model 70 | - greedy_save 71 | - stop_thresh 72 | - eps 73 | - early stopping 74 | ''' 75 | def __init__(self, args, args_subset, 76 | args_ext='args.args', 77 | name_ext='name.name', 78 | weights_ext='weights.p', 79 | res_ext='res.json', 80 | log_ext='log.log', 81 | script_ext='script.sh', 82 | args_dict_update={}, 83 | res={'train':[], 'val':[], 'test':[]}, 84 | tensorboard=None, 85 | load_pretrained_model=False): 86 | 87 | self.args = args 88 | self.save_flag = False 89 | self.args_subset = args_subset 90 | self.args_dict_update = args_dict_update 91 | 92 | self.args_ext = args_ext.split('.') 93 | self.name_ext = name_ext.split('.') 94 | self.weights_ext = weights_ext.split('.') 95 | self.res_ext = res_ext.split('.') 96 | self.log_ext = log_ext.split('.') 97 | self.script_ext = script_ext.split('.') 98 | 99 | ## params for saving/notSaving models 100 | self.stop_count = 0 101 | 102 | ## init empty results 103 | self.res = res 104 | if 'dev_key' in args: 105 | self.dev_key = args.dev_key 106 | self.dev_sign = args.dev_sign 107 | else: 108 | self.dev_key = 'val' 109 | self.dev_sign = 1 110 | self.best_dev_score = np.inf * self.dev_sign 111 | 112 | self.load_pretrained_model = load_pretrained_model 113 | self.last_epoch = 0 114 | if self.args.load: 115 | if os.path.isfile(self.args.load): 116 | ## update the save_dir if the files have moved 117 | self.save_dir = 
Path(args.load).parent.parent.as_posix() 118 | # self.save_dir = args.save_dir 119 | 120 | ## load Name 121 | self.name = self._load_name() 122 | 123 | ## load args 124 | self._load_args(args_dict_update) 125 | 126 | # if not self.load_pretrained_model: 127 | # ## Serialize and save args 128 | # self._save_args() 129 | 130 | ## load results 131 | self.res = self._load_res() 132 | self.last_epoch = self.res['epoch'][-1] 133 | 134 | else: 135 | ## run a new experiment 136 | self._new_exp() 137 | 138 | # if self.load_pretrained_model: 139 | # self._new_exp() 140 | 141 | ## Tensorboard 142 | if tensorboard: 143 | self.tensorboard = TensorboardWrapper(log_dir=(Path(self.save_dir)/Path(self.name.name+'tb')).as_posix()) 144 | else: 145 | self.tensorboard = None 146 | 147 | self._set_seed() 148 | 149 | def _set_seed(self): 150 | ## seed numpy and torch 151 | random.seed(self.args.seed) 152 | np.random.seed(self.args.seed) 153 | torch.manual_seed(self.args.seed) 154 | torch.cuda.manual_seed_all(self.args.seed) 155 | torch.cuda.manual_seed(self.args.seed) 156 | #torch.backends.cudnn.deterministic = True 157 | #torch.backends.cudnn.benchmark = False 158 | 159 | ''' 160 | Stuff to do for a new experiment 161 | ''' 162 | def _new_exp(self): 163 | ## update the experiment number 164 | self._update_exp() 165 | 166 | self.save_dir = self.args.save_dir 167 | self.name = Name(self.args, *self.args_subset) 168 | 169 | ## save name 170 | self._save_name() 171 | 172 | ## update args 173 | self.args.__dict__.update(self.args_dict_update) 174 | 175 | ## Serialize and save args 176 | self._save_args() 177 | 178 | ## save script 179 | #self._save_script() ## not functional yet. needs some work 180 | 181 | ## reinitialize results to empty 182 | self.res = {key:[] for key in self.res} 183 | 184 | def _update_exp(self): 185 | if self.args.exp is not None: 186 | exp = 0 187 | exp_file = '.experiments' 188 | if not os.path.exists(exp_file): 189 | with open(exp_file, 'w') as f: 190 | f.writelines([f'{exp}\n']) 191 | else: 192 | with open(exp_file, 'r') as f: 193 | lines = f.readlines() 194 | exp = int(lines[0].strip()) 195 | exp += 1 196 | with open(exp_file, 'w') as f: 197 | f.writelines([f'{exp}\n']) 198 | else: 199 | exp = 0 200 | print(f'Experiment Number: {exp}') 201 | self.args.__dict__.update({'exp':exp}) 202 | 203 | def _load_name(self): 204 | name_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.name_ext)]) 205 | return torch.load(open(name_filepath, 'rb')) 206 | 207 | def _save_name(self): 208 | name_filepath = self.name(self.name_ext[0], self.name_ext[1], self.save_dir) 209 | torch.save(self.name, open(name_filepath, 'wb')) 210 | 211 | def _load_res(self): 212 | res_filepath = self.name(self.res_ext[0], self.res_ext[1], self.save_dir) 213 | # res_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.res_ext)]) 214 | if os.path.exists(res_filepath): 215 | print('Results Loaded') 216 | return json.load(open(res_filepath)) 217 | else: 218 | warnings.warn('Could not find result file') 219 | return self.res 220 | 221 | def _save_res(self): 222 | res_filepath = self.name(self.res_ext[0], self.res_ext[1], self.save_dir) 223 | json.dump(self.res, open(res_filepath,'w')) 224 | 225 | def update_res(self, res): 226 | for key in res: 227 | if key in self.res: 228 | self.res[key].append(res[key]) 229 | else: 230 | self.res[key] = [res[key]] 231 | 232 | def update_tb(self, write_dict): 233 | if self.tensorboard: 234 | self.tensorboard(write_dict) 235 | else: 236 | 
warnings.warn('TensorboardWrapper not declared') 237 | 238 | def print_res(self, epoch, key_order=['train', 'val', 'test'], metric_order=[], exp=0, lr=None, fmt='{:.16f}'): 239 | print_str = "exp: {}, epoch: {}, lr:{}" 240 | table = PrettyTable([''] + key_order) 241 | table_str = ['loss'] + [fmt.format(self.res[key][-1]) for key in key_order] ## loss 242 | table.add_row(table_str) 243 | for metric in metric_order: 244 | table_str = [metric] + [fmt.format(self.res['{}_{}'.format(key, metric)][-1]) for key in key_order] 245 | table.add_row(table_str) 246 | 247 | if isinstance(lr, list): 248 | lr = lr[0] 249 | tqdm.write(print_str.format(exp, epoch, lr)) 250 | tqdm.write(table.__str__()) 251 | 252 | def print_res_archive(self, epoch, key_order=['train', 'val', 'test'], exp=0, lr=None, fmt='{:.9f}'): 253 | print_str = ', '.join(["exp: {}, epch: {}, lr:{}, "] + ["{}: {}".format(key,fmt) for key in key_order]) 254 | result_list = [self.res[key][-1] for key in key_order] 255 | if isinstance(lr, list): 256 | lr = lr[0] 257 | tqdm.write(print_str.format(exp, epoch, lr, *result_list)) 258 | 259 | def _load_args(self, args_dict_update): 260 | args_filepath = self.name(self.args_ext[0], self.args_ext[1], self.save_dir) 261 | # args_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.args_ext)]) 262 | if os.path.isfile(args_filepath): 263 | args_dict = json.load(open(args_filepath)) 264 | ## update load path and cuda device to use 265 | args_dict.update({'load':self.args.load, 266 | 'cuda':self.args.cuda, 267 | 'save_dir':self.save_dir}) 268 | ## any new argument to be updated 269 | args_dict.update(args_dict_update) 270 | 271 | self.args.__dict__.update(args_dict) 272 | 273 | def _save_args(self): 274 | args_filepath = self.name(self.args_ext[0], self.args_ext[1], self.save_dir) 275 | json.dump(self.args.__dict__, open(args_filepath, 'w')) 276 | 277 | def _save_script(self): 278 | ''' 279 | Not functional 280 | ''' 281 | args_filepath = self.name(self.script_ext[0], self.script_ext[1], self.save_dir) 282 | unparser = argunparse.ArgumentUnparser() 283 | options = get_args_update_dict(self.args)#self.args.__dict__ 284 | args = {} 285 | script = unparser.unparse_to_list(*args, **options) 286 | script = ['python', sys.argv[0]] + script 287 | script = ' '.join(script) 288 | with open(args_filepath, 'w') as fp: 289 | fp.writelines(script) 290 | 291 | def _load_model(self, model, model_id): 292 | # weights_path = self.name(self.weights_ext[0], self.weights_ext[1], self.save_dir) 293 | weights_path = self.name(self.args.load.split('_')[-1].split('.')[0], self.weights_ext[1], self.save_dir) 294 | m = torch.load(open(weights_path, 'rb')) 295 | model.load_state_dict(m[model_id]) 296 | print('Model loaded') 297 | 298 | @staticmethod 299 | def load_pretrained_model(model, path2model): 300 | model.load_state_dict(torch.load(open(path2model, 'rb'))) 301 | return model 302 | 303 | def _save_model(self, model_state_dict, out, model_id='model_pose'): 304 | weights_path = self.name(self.weights_ext[0], self.weights_ext[1], self.save_dir) 305 | f = open(weights_path, 'wb') 306 | out.update({model_id: model_state_dict}) 307 | torch.save(out, f) 308 | f.close() 309 | 310 | def _copy_best_model(self, model): 311 | if isinstance(model, torch.nn.DataParallel): 312 | self.best_model = copy.deepcopy(model.module.state_dict()) 313 | else: 314 | self.best_model = copy.deepcopy(model.state_dict()) 315 | 316 | def _start_log(self): 317 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'w') as f: 
318 | f.write("S: {}\n".format(str(datetime.now()))) 319 | 320 | def _stop_log(self): 321 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'r') as f: 322 | lines = f.readlines() 323 | if len(lines) > 1: ## this has already been sampled before 324 | lines = lines[0:1] + ["E: {}\n".format(str(datetime.now()))] 325 | else: 326 | lines.append("E: {}\n".format(str(datetime.now()))) 327 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'w') as f: 328 | f.writelines(lines) 329 | 330 | def stop_training(self, model, model_id, epoch, out_dict, warmup=False): 331 | ## copy the best model 332 | if self.dev_sign * self.res[self.dev_key][-1] < self.dev_sign * self.best_dev_score and (not warmup): 333 | self._copy_best_model(model) 334 | self.best_dev_score = self.res[self.dev_key][-1] 335 | self.save_flag = True 336 | else: 337 | self.save_flag = False 338 | 339 | ## debug mode with no saving 340 | if not self.args.save_model: 341 | self.save_flag = False 342 | 343 | if self.args.overfit: 344 | self._copy_best_model(model) 345 | self.save_flag = True 346 | 347 | if epoch % 10 == 0: 348 | self.save_flag = True 349 | 350 | if self.save_flag: 351 | tqdm.write('Saving Model at epoch {}'.format(epoch)) 352 | self._copy_best_model(model) 353 | self._save_model(self.best_model, out_dict, model_id) 354 | 355 | ## early_stopping 356 | if self.args.early_stopping and len(self.res['train'])>=2 and not self.args.overfit: 357 | if (self.dev_sign*(self.res[self.dev_key][-2] - self.args.eps) < self.dev_sign * self.res[self.dev_key][-1]): 358 | self.stop_count += 1 359 | else: 360 | self.stop_count = 0 361 | 362 | if self.stop_count >= self.args.stop_thresh: 363 | print('Validation Loss is increasing') 364 | ## save the best model now 365 | if self.args.save_model: 366 | print('Saving Model by early stopping') 367 | self.save_flag = True 368 | self._copy_best_model(model) 369 | self._save_model(self.best_model, out_dict, model_id) 370 | return self.save_flag 371 | 372 | ## end of training loop 373 | if epoch == self.args.num_epochs-1 and self.args.save_model: 374 | print('Saving model after exceeding number of epochs') 375 | self.save_flag = True 376 | self._copy_best_model(model) 377 | self._save_model(self.best_model, out_dict, model_id) 378 | 379 | return self.save_flag 380 | 381 | 382 | class Name(object): 383 | ''' Create a name based on hyper-parameters, other arguments 384 | like number of epochs or error rates 385 | 386 | Arguments: 387 | path2file/...argname_value_..._outputkind.ext 388 | 389 | args: Namespace(argname,value, ....) generally taken from an argparser variable 390 | argname: Hyper-parameters (i.e. 
model structure) 391 | value: Values of the corresponding Hyper-parameters 392 | 393 | path2file: set as './' by default and decides the path where the file is to be stored 394 | outputkind: what is the kind of output 'err', 'vis', 'cpk' or any other acronym given as a string 395 | ext: file type given as a string 396 | 397 | *args_subset: The subset of arguments to be used and its order 398 | 399 | Methods: 400 | Name.dir(path2file): creates a directory at `path2file` with a name derived from arguments 401 | but outputkind and ext are omitted 402 | ''' 403 | 404 | def __init__(self, args, *args_subset): 405 | self.name = '' 406 | args_dict = vars(args) 407 | args_subset = list(args_subset) 408 | 409 | ## if args_subset is not provided take all the keys from args_dict 410 | if not args_subset: 411 | args_subset = list(args_dict.keys()) 412 | 413 | ## if args_subset is derived from an example name 414 | for i, arg_sub in enumerate(args_subset): 415 | for arg in args_dict: 416 | if arg_sub == ''.join(arg.split('_')): 417 | args_subset[i] = arg 418 | 419 | ## If args_subset is empty exit 420 | assert args_subset, 'Subset of arguments to be chosen is empty' 421 | 422 | ## Scan through required arguments in the name 423 | for arg in args_subset: 424 | if arg not in args_dict: 425 | warnings.warn('Key %s does not exist. Skipping...'%(arg)) 426 | else: 427 | self.name += '%s_%s_' % (''.join(arg.split('_')), '-'.join(str(args_dict[arg]).split('.'))) 428 | 429 | def dir(self, path2file='./'): 430 | try: 431 | os.makedirs(os.path.join(path2file, self.name[:-1])) 432 | except OSError: 433 | if not os.path.isdir(path2file): 434 | raise 'Directory could not be created. Check if you have the required permissions to make changes at the given path.' 435 | return os.path.join(path2file, self.name[:-1]) 436 | 437 | 438 | def __call__(self, outputkind, ext, path2file='./'): 439 | try: 440 | os.makedirs(os.path.join(path2file, self.name)) 441 | except OSError: 442 | if not os.path.isdir(path2file): 443 | raise 'Directory could not be created. Check if you have the required permissions to make changes at the given path.' 
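        # The path assembled below has the form
        #     <path2file>/<name>/<name><outputkind>.<ext>
        # e.g. with the BookKeeper defaults this yields '<save_dir>/<name>/<name>weights.p'
        # for checkpoints and '<name>res.json' for the results dictionary.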
444 | return os.path.join(path2file, self.name, self.name + '%s.%s' %(outputkind,ext)) 445 | -------------------------------------------------------------------------------- /src/Lindyhop/train_hand_diffusion.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import shutil 7 | import sys 8 | sys.path.append('.') 9 | sys.path.append('..') 10 | import time 11 | import torch 12 | torch.cuda.empty_cache() 13 | import torch.nn as nn 14 | 15 | from cmath import nan 16 | from collections import OrderedDict 17 | from datetime import datetime 18 | from torch import optim 19 | from torch.utils.data import DataLoader 20 | from tqdm import tqdm 21 | 22 | from src.Lindyhop.argUtils import argparseNloop 23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset 24 | from src.Lindyhop.models.MotionDiffusion_hand import * 25 | from src.Lindyhop.models.Gaussian_diffusion import ( 26 | GaussianDiffusion, 27 | get_named_beta_schedule, 28 | create_named_schedule_sampler, 29 | ModelMeanType, 30 | ModelVarType, 31 | LossType 32 | ) 33 | from src.Lindyhop.skeleton import * 34 | from src.Lindyhop.visualizer import plot_contacts3D 35 | from src.tools.bookkeeper import * 36 | from src.tools.calculate_ev_metrics import * 37 | from src.tools.transformations import * 38 | from src.tools.utils import makepath 39 | 40 | 41 | def dist(x, y): 42 | # return torch.mean(x - y) 43 | return torch.mean(torch.cdist(x, y, p=2)) 44 | 45 | def initialize_weights(m): 46 | std_dev = 0.02 47 | if isinstance(m, nn.Linear): 48 | nn.init.normal_(m.weight, std=std_dev) 49 | if m.bias is not None: 50 | nn.init.normal_(m.bias, std=std_dev) 51 | # nn.init.constant_(m.bias.data, 1e-5) 52 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d): 53 | torch.nn.init.normal_(m.weight, std=std_dev) 54 | if m.bias is not None: 55 | torch.nn.init.normal_(m.bias, std=std_dev) 56 | # nn.init.constant_(m.bias.data, 1e-5) 57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 58 | nn.init.normal_(m.weight, std=std_dev) 59 | if m.bias is not None: 60 | nn.init.normal_(m.bias, std=std_dev) 61 | 62 | class Trainer: 63 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69): 64 | torch.manual_seed(args.seed) 65 | self.model_path = args.model_path 66 | makepath(args.work_dir, isfile=False) 67 | use_cuda = torch.cuda.is_available() 68 | if use_cuda: 69 | torch.cuda.empty_cache() 70 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu") 71 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None 72 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1 73 | print('Using %d CUDA cores [%s] for training!' 
% (gpu_count, gpu_brand)) 74 | args_subset = ['exp', 'model', 'batch_size', 'frames'] 75 | self.book = BookKeeper(args, args_subset) 76 | self.args = self.book.args 77 | self.batch_size = args.batch_size 78 | self.curriculum = args.curriculum 79 | self.scale = args.scale 80 | self.dtype = torch.float32 81 | self.epochs_completed = self.book.last_epoch 82 | self.frames = args.frames 83 | self.model = args.model 84 | self.lambda_loss = args.lambda_loss 85 | self.testtime_split = split 86 | self.num_jts = num_jts 87 | self.model_pose = eval(args.model)(device=self.device, 88 | num_frames=self.frames, 89 | num_jts=self.num_jts, 90 | input_feats=args.hand_out_feats, 91 | latent_dim=args.d_modelhand, 92 | num_heads=args.num_head_hands, 93 | num_layers=args.num_layer_hands, 94 | ff_size=args.d_ffhand, 95 | activations=args.activations 96 | ).to(self.device).float() 97 | self.diffusion_steps = args.diffusion_steps 98 | self.beta_scheduler = args.noise_schedule 99 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps) 100 | self.diffusion = GaussianDiffusion( 101 | betas=self.betas, 102 | model_mean_type=ModelMeanType.START_X, 103 | model_var_type=ModelVarType.FIXED_SMALL, 104 | loss_type=LossType.MSE 105 | ) 106 | self.sampler_name = args.sampler 107 | self.sampler = create_named_schedule_sampler(self.sampler_name, self.diffusion) 108 | self.model_pose.apply(initialize_weights) 109 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr) 110 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma) 111 | # self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, factor=args.factor, patience=args.patience, threshold= args.threshold, min_lr = 2e-7) 112 | self.skel = InhouseStudioSkeleton() 113 | self.mse_criterion = torch.nn.MSELoss() 114 | self.l1_criterion = torch.nn.L1Loss() 115 | self.BCE_criterion = torch.nn.BCELoss() 116 | print(args.model, 'Model Created') 117 | if args.load: 118 | print('Loading Model', args.model) 119 | self.book._load_model(self.model_pose, 'model_pose') 120 | print('Loading the data') 121 | if is_train: 122 | self.load_data(args) 123 | else: 124 | self.load_data_testtime(args) 125 | self.mean_var_norm = torch.load(args.mean_var_norm) 126 | 127 | def load_data_testtime(self, args): 128 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split) 129 | 130 | def load_data(self, args): 131 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train') 132 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) 133 | print('Train set loaded. Size=', len(self.ds_train.dataset)) 134 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test') 135 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 136 | print('Validation set loaded. 
Size=', len(self.ds_val.dataset)) 137 | 138 | 139 | def calc_loss(self, num_epoch): 140 | bs, seq, dim = self.generated.shape 141 | pos_loss = self.lambda_loss['pos'] * self.mse_criterion(self.input, self.generated) 142 | vel_gt = self.input[:, 1:] - self.input[:, :-1] 143 | vel_gen = self.generated[:, 1:] - self.generated[:, :-1] 144 | velocity_loss = self.lambda_loss['vel'] * self.mse_criterion(vel_gt, vel_gen) 145 | acc_gt = vel_gt[:, 1:] - vel_gt[:, :-1] 146 | acc_gen = vel_gen[:, 1:] - vel_gen[:, :-1] 147 | acc_loss = self.lambda_loss['vel'] * self.mse_criterion(acc_gt, acc_gen) 148 | gt_pose = self.input.reshape(bs, seq, -1, 3) 149 | gen_pose = self.generated.reshape(bs, seq, -1, 3) 150 | p1_rhand = self.p1_rhand_pos 151 | p2gt_rhand = self.p2_rhand_pos 152 | p2gen_rhand = gen_pose[:, :, :11] 153 | p1_lhand = self.p1_lhand_pos 154 | p2gt_lhand = self.p2_lhand_pos 155 | p2gen_lhand = gen_pose[:, :, 11:] 156 | 157 | bone_len_gt = (gt_pose[:, :, 1:] - gt_pose[:, :, [self.skel.parent_fingers[x] for x in range(1, 2*self.num_jts)]]).norm(dim=-1) 158 | bone_len_gen = (gen_pose[:, :, 1:] - gen_pose[:, :, [self.skel.parent_fingers[x] for x in range(1, 2*self.num_jts)]]).norm(dim=-1) 159 | bone_len_consistency_loss = self.lambda_loss['bone'] * self.mse_criterion(bone_len_gt, bone_len_gen) 160 | 161 | loss_logs = [pos_loss, velocity_loss, acc_loss, bone_len_consistency_loss] 162 | 163 | # #include the interaction loss 164 | 165 | self.lambda_loss['in'] = 10.0 166 | rh_rh = self.contact_map[:,:, 0] == 1 167 | rh_lh = self.contact_map[:,:, 1] == 1 168 | lh_rh = self.contact_map[:,:, 2] == 1 169 | lh_lh = self.contact_map[:,:, 3] == 1 170 | 171 | interact_loss = self.lambda_loss['in'] * torch.mean( rh_lh * ((p1_rhand - p2gt_lhand).norm(dim=-1) - (p1_rhand - p2gen_lhand).norm(dim=-1)).norm(dim=-1) + 172 | rh_rh * ((p1_rhand - p2gt_rhand).norm(dim=-1) - (p1_rhand - p2gen_rhand).norm(dim=-1)).norm(dim=-1) + 173 | lh_rh * ((p1_lhand - p2gt_rhand).norm(dim=-1) - (p1_lhand - p2gen_rhand).norm(dim=-1)).norm(dim=-1) + 174 | lh_lh * ((p1_lhand - p2gt_lhand).norm(dim=-1) - (p1_lhand - p2gen_lhand).norm(dim=-1)).norm(dim=-1) ) 175 | loss_logs.append(interact_loss) 176 | return loss_logs 177 | 178 | def forward(self, motions1, motions2, t=None): 179 | B, T = motions2.shape[:2] 180 | if t == None: 181 | t, _ = self.sampler.sample(B, motions1.device) 182 | self.diffusion_timestep = t 183 | output = self.diffusion.training_losses( 184 | model=self.model_pose, 185 | x_start=motions2, 186 | t=t, 187 | model_kwargs={"motion1": motions1} 188 | ) 189 | 190 | self.generated = output['pred'] # synthesized pose 2 191 | return t, output['x_noisy'] 192 | 193 | def generate(self, motion1, motion2=None): 194 | B, T, J, dim_pose = motion1.shape 195 | output = self.diffusion.p_sample_loop( 196 | self.model_pose, 197 | (B, T, J, dim_pose), 198 | clip_denoised=False, 199 | progress=True, 200 | pre_seq= motion2, 201 | model_kwargs={ 202 | 'motion1': motion1, 203 | }) 204 | return output 205 | 206 | 207 | def relative_normalization(self, global_pose1, global_pose2, global_rot1, global_rot2): 208 | self.p1_rhand_wrist_pos = global_pose1[:, :, 18] 209 | self.p1_lhand_wrist_pos = global_pose1[:, :, 43] 210 | p1_rhand_wrist_pos = (global_pose1[:, :, 18] - self.p1_rhand_wrist_pos) / self.scale 211 | p1_lhand_wrist_pos = (global_pose1[:, :, 43] - self.p1_lhand_wrist_pos) / self.scale 212 | p2_rhand_wrist_pos = (global_pose2[:, :, 18] - self.p1_rhand_wrist_pos) / self.scale 213 | p2_lhand_wrist_pos = (global_pose2[:, :, 43] - 
self.p1_lhand_wrist_pos) / self.scale 214 | B = p1_rhand_wrist_pos.shape[0] 215 | T = p1_rhand_wrist_pos.shape[1] 216 | p1_rhand_rot = self.skel.select_bvh_joints(global_rot1, original_joint_order=self.skel.bvh_joint_order, 217 | new_joint_order=self.skel.rh_fingers_only).reshape(B, T, -1) 218 | p1_lhand_rot = self.skel.select_bvh_joints(global_rot1, original_joint_order=self.skel.bvh_joint_order, 219 | new_joint_order=self.skel.lh_fingers_only).reshape(B, T, -1) 220 | 221 | p2_rhand_rot = self.skel.select_bvh_joints(global_rot2, original_joint_order=self.skel.bvh_joint_order, 222 | new_joint_order=self.skel.rh_fingers_only).reshape(B, T, -1) 223 | p2_lhand_rot = self.skel.select_bvh_joints(global_rot2, original_joint_order=self.skel.bvh_joint_order, 224 | new_joint_order=self.skel.lh_fingers_only).reshape(B, T, -1) 225 | 226 | # create a contact map based on threshold of wrists 227 | self.contact_dist = torch.zeros(B, T, 4).to(p1_lhand_rot.device).float() 228 | self.contact_dist[:,:, 0] = ((p1_rhand_wrist_pos - p2_rhand_wrist_pos)**2).norm(dim=-1) 229 | self.contact_dist[:,:, 1] = ((p1_rhand_wrist_pos - p2_lhand_wrist_pos)**2).norm(dim=-1) 230 | self.contact_dist[:,:, 2] = ((p1_lhand_wrist_pos - p2_rhand_wrist_pos)**2).norm(dim=-1) 231 | self.contact_dist[:,:, 3] = ((p1_lhand_wrist_pos - p2_lhand_wrist_pos)**2).norm(dim=-1) 232 | 233 | self.input_condn = torch.cat((p1_rhand_wrist_pos, p1_lhand_wrist_pos, 234 | p2_rhand_wrist_pos, p2_lhand_wrist_pos, 235 | p1_rhand_rot, p1_lhand_rot, self.contact_dist), dim=-1) 236 | self.input = torch.cat((p2_rhand_rot, p2_lhand_rot), dim=-1) 237 | 238 | def pose_relative_normalization(self, global_pose1, global_pose2, contact_maps): 239 | p1_rhand_pos = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order, 240 | new_joint_order=self.skel.rh_fingers_only) 241 | p1_lhand_pos = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order, 242 | new_joint_order=self.skel.lh_fingers_only) 243 | 244 | p2_rhand_pos = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order, 245 | new_joint_order=self.skel.rh_fingers_only) 246 | p2_lhand_pos = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order, 247 | new_joint_order=self.skel.lh_fingers_only) 248 | self.p1_rhand_wrist_pos = p1_rhand_pos[:, :, 0] 249 | self.p2_rhand_wrist_pos = p2_rhand_pos[:, :, 0] 250 | self.p1_lhand_wrist_pos = p1_lhand_pos[:, :, 0] 251 | self.p2_lhand_wrist_pos = p2_lhand_pos[:, :, 0] 252 | self.p1_rhand_pos = (p1_rhand_pos - torch.repeat_interleave(self.p1_rhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale 253 | self.p1_lhand_pos = (p1_lhand_pos - torch.repeat_interleave(self.p1_lhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale 254 | self.p2_rhand_pos = (p2_rhand_pos - torch.repeat_interleave(self.p2_rhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale 255 | self.p2_lhand_pos = (p2_lhand_pos - torch.repeat_interleave(self.p2_lhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale 256 | B = self.p1_rhand_wrist_pos.shape[0] 257 | T = self.p1_rhand_wrist_pos.shape[1] 258 | 259 | self.contact_map = contact_maps.to(self.device).float() 260 | self.input_condn = torch.cat((self.p1_rhand_pos.reshape(B, T, -1), self.p1_lhand_pos.reshape(B, T, -1), 261 | self.contact_map), dim=-1) 262 | self.input = torch.cat((self.p2_rhand_pos.reshape(B, T, -1), self.p2_lhand_pos.reshape(B, T, -1)), dim=-1) 263 | 264 | 265 | def 
train(self, num_epoch, ablation=None): 266 | total_train_loss = 0.0 267 | self.model_pose.train() 268 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120) 269 | # self.joint_parent = self.ds_train.dataset.bvh_joint_parents_list 270 | diff_count = [0, 5, 10, 50, 100, 200, 300, 400, 499] 271 | for count, batch in enumerate(training_tqdm): 272 | self.optimizer_model_pose.zero_grad() 273 | 274 | # with torch.autograd.detect_anomaly(): 275 | if True: 276 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 277 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 278 | if global_pose1.shape[1] == 0: 279 | continue 280 | self.pose_relative_normalization(global_pose1, global_pose2, batch['contacts']) 281 | t, noisy = self.forward(self.input_condn, self.input) 282 | 283 | loss_logs = self.calc_loss(num_epoch) 284 | loss_model = sum(loss_logs) 285 | total_train_loss += loss_model.item() 286 | 287 | if loss_model == float('inf') or torch.isnan(loss_model): 288 | print('Train loss is nan') 289 | exit() 290 | loss_model.backward() 291 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01) 292 | self.optimizer_model_pose.step() 293 | 294 | avg_train_loss = total_train_loss/(count + 1) 295 | return avg_train_loss 296 | 297 | def evaluate(self, num_epoch, ablation=None): 298 | total_eval_loss = 0.0 299 | self.model_pose.eval() 300 | T = self.frames 301 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120) 302 | 303 | for count, batch in enumerate(eval_tqdm): 304 | if True: 305 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 306 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 307 | if global_pose1.shape[1] == 0: 308 | continue 309 | self.pose_relative_normalization(global_pose1, global_pose2, batch['contacts']) 310 | t, noisy = self.forward(self.input_condn, self.input) 311 | loss_logs = self.calc_loss(num_epoch) 312 | loss_model = sum(loss_logs) 313 | total_eval_loss += loss_model.item() 314 | 315 | avg_eval_loss = total_eval_loss/(count + 1) 316 | 317 | return avg_eval_loss 318 | 319 | def fit(self, n_epochs=None, ablation=False): 320 | print('*****Inside Trainer.fit *****') 321 | if n_epochs is None: 322 | n_epochs = self.args.num_epochs 323 | starttime = datetime.now().replace(microsecond=0) 324 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs) 325 | save_model_dict = {} 326 | best_eval = 1000 327 | 328 | for epoch_num in range(self.epochs_completed, n_epochs + 1): 329 | tqdm.write('--- starting Epoch # %03d' % epoch_num) 330 | train_loss = self.train(epoch_num, ablation) 331 | if epoch_num % 5 == 0: 332 | eval_loss = self.evaluate(epoch_num, ablation) 333 | else: 334 | eval_loss = 0.0 335 | self.scheduler_pose.step() 336 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0}) 337 | self.book._save_res() 338 | self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr']) 339 | 340 | if epoch_num > 100 and eval_loss < best_eval: 341 | print('Best eval at epoch {}'.format(epoch_num)) 342 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb') 343 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 344 | torch.save(save_model_dict, f) 345 | f.close() 346 | best_eval = eval_loss 347 | if epoch_num > 20 and epoch_num % 20 == 0 : 348 | f = 
open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb') 349 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 350 | torch.save(save_model_dict, f) 351 | f.close() 352 | endtime = datetime.now().replace(microsecond=0) 353 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S'))) 354 | print('Training complete in %s!\n' % (endtime - starttime)) 355 | 356 | 357 | 358 | if __name__ == '__main__': 359 | args = argparseNloop() 360 | args.lambda_loss = { 361 | 'fk': 1.0, 362 | 'fk_vel': 1.0, 363 | 'rot': 1e+3, 364 | 'rot_vel': 1e+1, 365 | 'kldiv': 1.0, 366 | 'pos': 1e+3, 367 | 'vel': 1e+1, 368 | 'bone': 1.0, 369 | 'foot': 0.0, 370 | } 371 | is_train = True 372 | ablation = None # if True then ablation: no_IAC_loss 373 | model_trainer = Trainer(args=args, is_train=is_train, split='train', JT_POSITION=True, num_jts=11) 374 | print("** Method Initialization Complete **") 375 | model_trainer.fit(ablation=ablation) 376 | 377 | -------------------------------------------------------------------------------- /src/Lindyhop/train_body_diffusion.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import shutil 7 | import sys 8 | sys.path.append('.') 9 | sys.path.append('..') 10 | import time 11 | import torch 12 | torch.cuda.empty_cache() 13 | import torch.nn as nn 14 | 15 | from cmath import nan 16 | from collections import OrderedDict 17 | from datetime import datetime 18 | from torch import optim 19 | from torch.utils.data import DataLoader 20 | from tqdm import tqdm 21 | 22 | from src.Lindyhop.argUtils import argparseNloop 23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset 24 | from src.Lindyhop.models.MotionDiffuse_body import * 25 | from src.Lindyhop.models.Gaussian_diffusion import ( 26 | GaussianDiffusion, 27 | get_named_beta_schedule, 28 | create_named_schedule_sampler, 29 | ModelMeanType, 30 | ModelVarType, 31 | LossType 32 | ) 33 | from src.Lindyhop.skeleton import * 34 | from src.Lindyhop.visualizer import plot_contacts3D 35 | from src.tools.bookkeeper import * 36 | from src.tools.transformations import * 37 | from src.tools.utils import makepath 38 | 39 | right_side = [15, 16, 17, 18] 40 | left_side = [19, 20, 21, 22] 41 | # stat_metrics = CalculateMetricsDanceData() 42 | def dist(x, y): 43 | # return torch.mean(x - y) 44 | return torch.mean(torch.cdist(x, y, p=2)) 45 | 46 | def initialize_weights(m): 47 | std_dev = 0.02 48 | if isinstance(m, nn.Linear): 49 | nn.init.normal_(m.weight, std=std_dev) 50 | if m.bias is not None: 51 | nn.init.normal_(m.bias, std=std_dev) 52 | # nn.init.constant_(m.bias.data, 1e-5) 53 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d): 54 | torch.nn.init.normal_(m.weight, std=std_dev) 55 | if m.bias is not None: 56 | torch.nn.init.normal_(m.bias, std=std_dev) 57 | # nn.init.constant_(m.bias.data, 1e-5) 58 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 59 | nn.init.normal_(m.weight, std=std_dev) 60 | if m.bias is not None: 61 | nn.init.normal_(m.bias, std=std_dev) 62 | 63 | class Trainer: 64 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69): 65 | torch.manual_seed(args.seed) 66 | self.model_path = args.model_path 67 | makepath(args.work_dir, isfile=False) 68 | use_cuda = torch.cuda.is_available() 69 | if use_cuda: 70 
| torch.cuda.empty_cache() 71 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu") 72 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None 73 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1 74 | print('Using %d CUDA cores [%s] for training!' % (gpu_count, gpu_brand)) 75 | args_subset = ['exp', 'model', 'batch_size', 'frames'] 76 | self.book = BookKeeper(args, args_subset) 77 | self.args = self.book.args 78 | self.batch_size = args.batch_size 79 | self.curriculum = args.curriculum 80 | self.scale = args.scale 81 | self.dtype = torch.float32 82 | self.epochs_completed = self.book.last_epoch 83 | self.frames = args.frames 84 | self.model = args.model 85 | self.lambda_loss = args.lambda_loss 86 | self.testtime_split = split 87 | self.num_jts = num_jts 88 | self.model_pose = eval(args.model)(device=self.device, 89 | num_jts=self.num_jts, 90 | num_frames=self.frames, 91 | input_feats=args.input_feats, 92 | # jt_latent_dim=args.jt_latent, 93 | latent_dim=args.d_model, 94 | num_heads=args.num_head, 95 | num_layers=args.num_layer, 96 | ff_size=args.d_ff, 97 | activations=args.activations 98 | ).to(self.device).float() 99 | trainable_count_body = sum(p.numel() for p in self.model_pose.parameters() if p.requires_grad) 100 | 101 | self.diffusion_steps = args.diffusion_steps 102 | self.beta_scheduler = args.noise_schedule 103 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps) 104 | self.diffusion = GaussianDiffusion( 105 | betas=self.betas, 106 | model_mean_type=ModelMeanType.START_X, 107 | model_var_type=ModelVarType.FIXED_SMALL, 108 | loss_type=LossType.MSE 109 | ) 110 | self.sampler_name = args.sampler 111 | self.sampler = create_named_schedule_sampler(self.sampler_name, self.diffusion) 112 | self.model_pose.apply(initialize_weights) 113 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr) 114 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma) 115 | self.skel = InhouseStudioSkeleton() 116 | self.mse_criterion = torch.nn.MSELoss() 117 | self.l1_criterion = torch.nn.L1Loss() 118 | 119 | print(args.model, 'Model Created') 120 | if args.load: 121 | print('Loading Model', args.model) 122 | self.book._load_model(self.model_pose, 'model_pose') 123 | print('Loading the data') 124 | if is_train: 125 | self.load_data(args) 126 | else: 127 | self.load_data_testtime(args) 128 | 129 | 130 | def load_data_testtime(self, args): 131 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split) 132 | self.load_ds_data = DataLoader(self.ds_data, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 133 | 134 | 135 | def load_data(self, args): 136 | 137 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train') 138 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) 139 | print('Train set loaded. Size=', len(self.ds_train.dataset)) 140 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test') 141 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True) 142 | print('Validation set loaded. 
Size=', len(self.ds_val.dataset)) 143 | 144 | def calc_kldiv(self, dist_m): 145 | mu_ref = torch.zeros_like(dist_m.loc) 146 | scale_ref = torch.ones_like(dist_m.scale) 147 | dist_ref = torch.distributions.Normal(mu_ref, scale_ref) 148 | return torch.distributions.kl_divergence(dist_m, dist_ref) 149 | 150 | def calc_loss(self, num_epoch): 151 | bs, seq, J, dim = self.generated.shape 152 | pos_loss = self.lambda_loss['pos'] * self.mse_criterion(self.gt_pose2, self.generated) 153 | vel_gt = self.gt_pose2[:, 1:] - self.gt_pose2[:, :-1] 154 | vel_gen = self.generated[:, 1:] - self.generated[:, :-1] 155 | velocity_loss = self.lambda_loss['vel'] * self.mse_criterion(vel_gt, vel_gen) 156 | acc_gt = vel_gt[:, 1:] - vel_gt[:, :-1] 157 | acc_gen = vel_gen[:, 1:] - vel_gen[:, :-1] 158 | acc_loss = self.lambda_loss['vel'] * self.mse_criterion(acc_gt, acc_gen) 159 | bone_len_gt = (self.gt_pose2[:, :, 1:] - self.gt_pose2[:, :, [self.skel.parents_body_only[x] for x in range(1, J)]]).norm(dim=-1) 160 | bone_len_gen = (self.generated[:, :, 1:] - self.generated[:, :, [self.skel.parents_body_only[x] for x in range(1, J)]]).norm(dim=-1) 161 | bone_len_consistency_loss = self.lambda_loss['bone'] * self.mse_criterion(bone_len_gt, bone_len_gen) 162 | if num_epoch > 100: 163 | self.lambda_loss['foot'] = 20.0 164 | else: 165 | self.lambda_loss['foot'] = 0.0 166 | rightfoot_idx = [4, 5] 167 | leftfoot_idx = [9, 10] 168 | gen_leftfoot_joint = self.generated[:, :, leftfoot_idx] 169 | static_left_foot_index = gen_leftfoot_joint[..., 1] <= 0.02 170 | gen_rightfoot_joint = self.generated[:, :, rightfoot_idx] 171 | static_right_foot_index = gen_rightfoot_joint[..., 1] <= 0.02 172 | gen_leftfoot_vel = torch.zeros_like(gen_leftfoot_joint) 173 | gen_leftfoot_vel[:, :-1] = gen_leftfoot_joint[:, 1:] - gen_leftfoot_joint[:, :-1] 174 | gen_leftfoot_vel[~static_left_foot_index] = 0 175 | gen_rightfoot_vel = torch.zeros_like(gen_rightfoot_joint) 176 | gen_rightfoot_vel[:, :-1] = gen_rightfoot_joint[:, 1:] - gen_rightfoot_joint[:, :-1] 177 | gen_rightfoot_vel[~static_right_foot_index] = 0 178 | footskate_loss = self.lambda_loss['foot'] * (self.mse_criterion(gen_leftfoot_vel, torch.zeros_like(gen_leftfoot_vel)) + 179 | self.mse_criterion(gen_rightfoot_vel, torch.zeros_like(gen_rightfoot_vel)) ) 180 | 181 | loss_logs = [pos_loss, velocity_loss, bone_len_consistency_loss, 182 | footskate_loss, acc_loss] 183 | 184 | #include the interaction loss 185 | self.lambda_loss['in'] = 50.0 186 | rh_rh = self.contact_map[:,:, 0] == 1 187 | rh_lh = self.contact_map[:,:, 1] == 1 188 | lh_rh = self.contact_map[:,:, 2] == 1 189 | lh_lh = self.contact_map[:,:, 3] == 1 190 | 191 | arm_interact_loss = self.lambda_loss['in'] * torch.mean( 192 | rh_lh * ((self.pose1[:, :, right_side] - self.gt_pose2[:, :, left_side]).norm(dim=-1) - ( 193 | self.pose1[:, :, right_side] - self.generated[:, :, left_side]).norm(dim=-1)).norm(dim=-1) + rh_rh * ( 194 | (self.pose1[:, :, right_side] - self.gt_pose2[:, :, right_side]).norm(dim=-1) - ( 195 | self.pose1[:, :, right_side] - self.generated[:, :, right_side]).norm(dim=-1)).norm(dim=-1) + lh_rh * (( 196 | self.pose1[:, :, left_side] - self.gt_pose2[:, :, right_side]).norm(dim=-1) - ( 197 | self.pose1[:, :, left_side] - self.generated[:, :, right_side]).norm(dim=-1)).norm(dim=-1) + lh_lh * (( 198 | self.pose1[:, :, left_side] - self.gt_pose2[:, :, left_side]).norm(dim=-1) - ( 199 | self.pose1[:, :, left_side] - self.generated[:, :, left_side]).norm(dim=-1)).norm(dim=-1) ) 200 | 201 | loss_logs.append(arm_interact_loss) 
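# Interaction terms: arm_interact_loss (above) uses the per-frame hand-contact map to match the arm-joint
# distances between pose 1 and the generated pose 2 to the corresponding ground-truth distances, but only
# on frames where that contact is active; interact_loss (below) additionally matches the full
# relative-position field (pose1 - pose2) across all joints.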
202 | interact_loss = self.mse_criterion((self.pose1 - self.gt_pose2), (self.pose1 - self.generated)) 203 | loss_logs.append(interact_loss) 204 | return loss_logs 205 | # forward(): one diffusion training step - noise gt pose 2 (x_start) and predict it back conditioned on pose 1 206 | def forward(self, motions1, motions2, t=None): 207 | B, T = motions2.shape[:2] 208 | if t is None: 209 | t, _ = self.sampler.sample(B, motions1.device) 210 | self.diffusion_timestep = t 211 | output = self.diffusion.training_losses( 212 | model=self.model_pose, 213 | x_start=motions2, 214 | t=t, 215 | model_kwargs={"motion1": motions1} 216 | ) 217 | 218 | self.pose1 = motions1 219 | self.gt_pose2 = motions2 #gt pose 2 220 | self.generated = output['pred'] # synthesized pose 2 221 | return t, output['x_noisy'] 222 | 223 | 224 | # make both poses root-relative: select the body-only joints, subtract the global root origin, divide by the scale 225 | def root_relative_normalization(self, global_pose1, global_pose2): 226 | 227 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order, 228 | new_joint_order=self.skel.body_only) 229 | pose1_root_rel = global_pose1 - torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2) 230 | self.pose1_root_rel = pose1_root_rel / self.scale 231 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order, 232 | new_joint_order=self.skel.body_only) 233 | 234 | pose2_root_rel = global_pose2 - torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2) 235 | self.pose2_root_rel = pose2_root_rel / self.scale 236 | 237 | 238 | def root_relative_unnormalization(self, pose1_normalized, pose2_normalized): 239 | pose1_unnormalized = pose1_normalized * self.scale 240 | pose2_unnormalized = pose2_normalized * self.scale 241 | global_pose1 = pose1_unnormalized + torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2) 242 | global_pose2 = pose2_unnormalized + torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2) 243 | return global_pose1, global_pose2 244 | 245 | def train(self, num_epoch, ablation=None): 246 | total_train_loss = 0.0 247 | total_pos_loss = 0.0 248 | total_vel_loss = 0.0 249 | total_bone_loss = 0.0 250 | total_footskate_loss = 0.0 251 | self.model_pose.train() 252 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120) 253 | diff_count = [0, 5, 10, 50, 100, 200, 300, 400, 499] 254 | for count, batch in enumerate(training_tqdm): 255 | self.optimizer_model_pose.zero_grad() 256 | 257 | with torch.autograd.detect_anomaly(): 258 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 259 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 260 | self.contact_map = batch['contacts'].to(self.device).float() 261 | self.global_root_origin = batch['global_root_origin'].to(self.device).float() 262 | if global_pose1.shape[1] == 0: 263 | continue 264 | self.root_relative_normalization(global_pose1, global_pose2) 265 | t, noisy = self.forward(self.pose1_root_rel, self.pose2_root_rel) 266 | 267 | loss_logs = self.calc_loss(num_epoch) 268 | loss_model = sum(loss_logs) 269 | total_train_loss += loss_model.item() 270 | total_pos_loss += loss_logs[0].item() 271 | total_vel_loss += loss_logs[1].item() 272 | total_bone_loss += loss_logs[2].item() 273 | total_footskate_loss += loss_logs[3].item() 274 | 275 | if loss_model == float('inf') or torch.isnan(loss_model): 276 | print('Train loss is nan or inf') 277 | exit() 278 | loss_model.backward() 279 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01) 280 | self.optimizer_model_pose.step() 281 |
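# Note: the epoch averages below divide by count + 1 (the number of enumerated batches), which still
# includes any batches skipped above with `continue`.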
282 | 283 | avg_train_loss = total_train_loss/(count + 1) 284 | avg_pos_loss = total_pos_loss/(count + 1) 285 | avg_vel_loss = total_vel_loss/(count + 1) 286 | avg_bone_loss = total_bone_loss/(count + 1) 287 | avg_footskate_loss = total_footskate_loss/(count + 1) 288 | 289 | return avg_train_loss, (avg_pos_loss, avg_vel_loss, avg_bone_loss, avg_footskate_loss) 290 | 291 | def evaluate(self, num_epoch, ablation=None): 292 | total_eval_loss = 0.0 293 | total_pos_loss = 0.0 294 | total_vel_loss = 0.0 295 | total_bone_loss = 0.0 296 | total_footskate_loss = 0.0 297 | self.model_pose.eval() 298 | T = self.frames 299 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120) 300 | 301 | for count, batch in enumerate(eval_tqdm): 302 | with torch.no_grad(): # no gradients are needed for validation 303 | global_pose1 = batch['pose_canon_1'].to(self.device).float() 304 | global_pose2 = batch['pose_canon_2'].to(self.device).float() 305 | self.contact_map = batch['contacts'].to(self.device).float() 306 | 307 | self.global_root_origin = batch['global_root_origin'].to(self.device).float() 308 | if global_pose1.shape[1] == 0: 309 | continue 310 | self.root_relative_normalization(global_pose1, global_pose2) 311 | t, noisy = self.forward(self.pose1_root_rel, self.pose2_root_rel) 312 | loss_logs = self.calc_loss(num_epoch) 313 | loss_model = sum(loss_logs) 314 | total_eval_loss += loss_model.item() 315 | total_pos_loss += loss_logs[0].item() 316 | total_vel_loss += loss_logs[1].item() 317 | total_bone_loss += loss_logs[2].item() 318 | total_footskate_loss += loss_logs[3].item() 319 | 320 | avg_eval_loss = total_eval_loss/(count + 1) 321 | avg_pos_loss = total_pos_loss/(count + 1) 322 | avg_vel_loss = total_vel_loss/(count + 1) 323 | avg_bone_loss = total_bone_loss/(count + 1) 324 | avg_footskate_loss = total_footskate_loss/(count + 1) 325 | 326 | return avg_eval_loss, (avg_pos_loss, avg_vel_loss, avg_bone_loss, avg_footskate_loss) 327 | 328 | def fit(self, n_epochs=None, ablation=None): 329 | print('*****Inside Trainer.fit *****') 330 | if n_epochs is None: 331 | n_epochs = self.args.num_epochs 332 | starttime = datetime.now().replace(microsecond=0) 333 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs) 334 | save_model_dict = {} 335 | best_eval = float('inf') 336 | 337 | train_pos_loss = [] 338 | train_vel_loss = [] 339 | train_bone_loss = [] 340 | train_footskate_loss = [] 341 | eval_pos_loss = [] 342 | eval_vel_loss = [] 343 | eval_bone_loss = [] 344 | eval_footskate_loss = [] 345 | for epoch_num in range(self.epochs_completed, n_epochs + 1): 346 | tqdm.write('--- starting Epoch # %03d' % epoch_num) 347 | train_loss, (train_pos_loss_, train_vel_loss_, train_bone_loss_, 348 | train_footskate_loss_) = self.train(epoch_num, ablation) 349 | train_pos_loss.append(train_pos_loss_) 350 | train_vel_loss.append(train_vel_loss_) 351 | train_bone_loss.append(train_bone_loss_) 352 | train_footskate_loss.append(train_footskate_loss_) 353 | if epoch_num % 5 == 0: 354 | eval_loss, (eval_pos_loss_, eval_vel_loss_, eval_bone_loss_, 355 | eval_footskate_loss_) = self.evaluate(epoch_num, ablation) 356 | eval_pos_loss.append(eval_pos_loss_) 357 | eval_vel_loss.append(eval_vel_loss_) 358 | eval_bone_loss.append(eval_bone_loss_) 359 | eval_footskate_loss.append(eval_footskate_loss_) 360 | else: 361 | eval_loss = 0.0 362 | self.scheduler_pose.step() 363 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0}) 364 | self.book._save_res() 365 |
self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr']) 366 | 367 | if epoch_num > 100 and epoch_num % 5 == 0 and eval_loss < best_eval: 368 | print('Best eval at epoch {}'.format(epoch_num)) 369 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb') 370 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 371 | torch.save(save_model_dict, f) 372 | f.close() 373 | best_eval = eval_loss 374 | if epoch_num > 20 and epoch_num % 20 == 0: 375 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb') 376 | save_model_dict.update({'model_pose': self.model_pose.state_dict()}) 377 | torch.save(save_model_dict, f) 378 | f.close() 379 | endtime = datetime.now().replace(microsecond=0) 380 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S'))) 381 | print('Training complete in %s!\n' % (endtime - starttime)) 382 | 383 | 384 | 385 | if __name__ == '__main__': 386 | args = argparseNloop() 387 | args.lambda_loss = { 388 | 'fk': 1.0, 389 | 'fk_vel': 1.0, 390 | 'rot': 1.0, 391 | 'rot_vel': 1.0, 392 | 'kldiv': 1.0, 393 | 'pos': 1e+3, 394 | 'vel': 1e+1, 395 | 'bone': 1.0, 396 | 'foot': 0.0 397 | } 398 | is_train = True 399 | ablation = None # if True then ablation: no_IAC_loss 400 | model_trainer = Trainer(args=args, is_train=is_train, split='test', JT_POSITION=True, num_jts=27) 401 | print("** Method Initialization Complete **") 402 | model_trainer.fit(ablation=ablation) 403 | 404 | --------------------------------------------------------------------------------
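A note on the pose normalization used by the Trainer in train_body_diffusion.py above: root_relative_normalization subtracts the per-frame global root origin from every selected joint and divides by args.scale, and root_relative_unnormalization inverts this exactly. The snippet below is a minimal stand-alone sketch of that round trip, not part of the repository; the shape values, the scale constant, and the root_origin tensor are illustrative placeholders standing in for self.scale and batch['global_root_origin'].

import torch

B, T, J = 2, 24, 27                      # placeholder batch, frame, and joint counts
scale = 100.0                            # stand-in for args.scale
global_pose = torch.randn(B, T, J, 3)    # global joint positions for one dancer
root_origin = global_pose[:, :, 0, :]    # stand-in for batch['global_root_origin'], shape (B, T, 3)

# normalize: subtract the root origin from every joint, then scale down
root_tiled = torch.repeat_interleave(root_origin.unsqueeze(-2), J, dim=-2)
pose_root_rel = (global_pose - root_tiled) / scale

# unnormalize: scale back up and add the root origin again
pose_global_again = pose_root_rel * scale + root_tiled

assert torch.allclose(pose_global_again, global_pose, atol=1e-5)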