├── src
│   ├── tools
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── skeleton.py
│   │   │   └── quaternion.py
│   │   ├── core_utils.py
│   │   ├── img_gif.py
│   │   ├── calculate_ev_metrics.py
│   │   ├── utils.py
│   │   ├── transformations.py
│   │   └── bookkeeper.py
│   ├── Lindyhop
│   │   ├── LindyHop_dataloader.py
│   │   ├── argUtils.py
│   │   ├── process_LindyHop.py
│   │   ├── train_VanillaTransformer.py
│   │   ├── visualizer.py
│   │   ├── models
│   │   │   ├── MotionDiffuse_body.py
│   │   │   └── MotionDiffusion_hand.py
│   │   ├── train_hand_diffusion.py
│   │   └── train_body_diffusion.py
│   └── Ninjutsu
│       ├── Ninjutsu_dataloader.py
│       ├── argUtils.py
│       └── process_Ninjutsu.py
├── save
│   └── save.txt
├── data
│   └── data.txt
├── requirements.txt
└── README.md
/src/tools/common/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/save/save.txt:
--------------------------------------------------------------------------------
1 | The pre-trained weights are saved in this folder.
2 |
--------------------------------------------------------------------------------
/data/data.txt:
--------------------------------------------------------------------------------
1 | The processed data (train and test splits) will be stored here as .pkl files after you run 'process_LindyHop.py' and 'process_Ninjutsu.py'.
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | accelerate
3 | argunparse
4 | astunparse
5 | async-timeout
6 | bvh-converter
7 | c3d
8 | decorator
9 | denoising-diffusion-pytorch
10 | dill
11 | diskcache
12 | einops
13 | ema-pytorch
14 | h5py
15 | imageio
16 | imageio-ffmpeg
17 | joblib
18 | keras
19 | matplotlib
20 | numpy
21 | openai
22 | opencv-python
23 | pandas
24 | pathos
25 | pathtools
26 | prettytable
27 | protobuf
28 | psutil
29 | pytorch-lightning
30 | pytorch3d
31 | pytz
32 | pyyaml
33 | scikit-learn
34 | scipy
35 | seaborn
36 | tensorboard
37 | torchmetrics
38 | tqdm
39 | wandb
40 | werkzeug
--------------------------------------------------------------------------------
/src/tools/core_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import torch
5 |
6 | def send_to_cuda(model):
7 | for key in model.keys():
8 | model[key].cuda()
9 |
10 | return model
11 |
12 |
13 | class AverageMeter(object):
14 | """Computes and stores the average and current value"""
15 | def __init__(self):
16 | self.reset()
17 |
18 | def reset(self):
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 | self.val = val
26 | self.sum += val * n
27 | self.count += n
28 | self.avg = self.sum / self.count
29 |
30 | def load_model(path, model, optimizer=None):
31 | pass
32 |
33 | def save_model(path, model, epoch, optimizer=None):
34 | state_dict = model.state_dict()
35 |
36 | data = {'epoch': epoch,
37 | 'state_dict': state_dict}
38 |
39 | if not (optimizer is None):
40 | data['optimizer'] = optimizer.state_dict()
41 |
42 | torch.save(data, path)
43 |
44 | def makepath(desired_path, isfile = False):
45 | '''
46 | if the path does not exist make it
47 | :param desired_path: can be path to a file or a folder name
48 | :return:
49 | '''
50 | import os
51 | if isfile:
52 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
53 | else:
54 | if not os.path.exists(desired_path): os.makedirs(desired_path)
55 | return desired_path
--------------------------------------------------------------------------------
/src/tools/img_gif.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import imageio
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 |
8 | def img2gif(image_folder):
9 | seqs = glob.glob(image_folder + '/*.jpg')
10 |
11 | out_filename = image_folder.split('/')[-1]
12 | int_seq = [int(seqs[i].split('/')[-1].split('.')[0]) for i in range(len(seqs))]
13 | # int_seq = [int(seqs[i].split('/')[-1].split('.')[0].split('_')[-1]) for i in range(len(seqs))]
14 | index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
15 | all_imgs = [seqs[index[i]] for i in range(len(index))]
16 | gif_name = os.path.join(image_folder, out_filename+ '.gif')
17 | with imageio.get_writer(gif_name, mode='I') as writer:
18 | for filename in all_imgs:
19 | image = imageio.imread(filename)
20 | writer.append_data(image)
21 |
22 | def img2gif_compress(fp_in):
23 | x = 800
24 | y = 400
25 | gif_name = os.path.join(fp_in, fp_in.split('/')[-1] + 'compress.gif')
26 | q = 40 # Quality
27 | seqs = glob.glob(fp_in + '/*.jpg')
28 | int_seq = [int(seqs[i].split('/')[-1].split('.')[0]) for i in range(len(seqs))]
29 | index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
30 | all_imgs = [seqs[index[i]] for i in range(len(index))]
31 | img, *imgs = [Image.open(f).resize((x, y), Image.LANCZOS) for f in all_imgs]  # ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
32 | img.save(fp=gif_name, format='GIF', append_images=imgs,quality=q,
33 | save_all=True, loop=0, optimize=True)
34 |
35 |
36 | def img2video(image_folder, fps, img_type='png'):
37 | seqs = glob.glob(image_folder + '/*.'+ img_type)
38 | out_filename = image_folder.split('/')[-1]
39 | int_seq = [int(seqs[i].split('/')[-1].split('.')[0].split('_')[-1]) for i in range(len(seqs))]
40 | index = sorted(range(len(int_seq)), key=lambda k: int_seq[k])
41 | all_imgs = [seqs[index[i]] for i in range(len(index))]
42 | img_array = []
43 | video_name =os.path.join(image_folder, out_filename+ '.avi')
44 | for filename in all_imgs:
45 | img = cv2.imread(filename)
46 | height, width, layers = img.shape
47 | size = (width, height)
48 | img_array.append(img)
49 | out = cv2.VideoWriter(video_name ,cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
50 | for i in range(len(img_array)):
51 | out.write(img_array[i])
52 | out.release()
53 |
--------------------------------------------------------------------------------
/src/Lindyhop/LindyHop_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pickle
4 | import pytorch3d.transforms as t3d
5 | import random
6 | import sys
7 | sys.path.append('.')
8 | sys.path.append('..')
9 | import torch
10 | from math import radians, cos, sin
11 | from scipy.spatial.transform import Rotation as R
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | from src.Lindyhop.skeleton import InhouseStudioSkeleton
15 | from src.Lindyhop.visualizer import plot_contacts3D
16 | from src.tools.transformations import *
17 | from src.tools.utils import makepath
18 | from src.Lindyhop.argUtils import argparseNloop
19 |
20 |
21 | class LindyHopDataset(torch.utils.data.Dataset):
22 | def __init__(self, args, window_size=10, split='val'):
23 | self.root = args.data_dir
24 | self.scale = args.scale
25 | self.split = split
26 | self.window_size = int(window_size)
27 | with open(os.path.join(self.root, self.split+'.pkl'), 'rb') as f:
28 | self.annot_dict = pickle.load(f)
29 | self.output_keys = ['seq', 'pose_canon_1', 'pose_canon_2',
30 | 'contacts', 'dofs_1', 'dofs_2',
31 | 'rotmat_1', 'rotmat_2',
32 | 'offsets_1', 'offsets_2',
33 | ]
34 | self.skel = InhouseStudioSkeleton()
35 |
36 |
37 | def __getitem__(self, ind):
38 | index = ind % len(self.annot_dict['pose_canon_1'])
39 | annot = {}
40 | for key in self.output_keys:
41 | annot[key] = self.annot_dict[key][index]
42 | skip = 1
43 | start = np.random.randint(0, len(annot['pose_canon_1']) - self.window_size)
44 | end = start + self.window_size
45 |
46 | annot['contacts'] = annot['contacts'][start:end]  # contact labels: 0: rh-rh, 1: lh-lh, 2: lh-rh, 3: rh-lh
47 | annot['pose_canon_1'] = annot['pose_canon_1'][start: end: skip]
48 | annot['pose_canon_2'] = annot['pose_canon_2'][start: end: skip]
49 | annot['dofs_1'] = np.pi * (annot['dofs_1'][start:end: skip]) / 180.
50 | annot['dofs_2'] = np.pi * (annot['dofs_2'][start:end: skip]) / 180.
51 | annot['rotmat_1'] = annot['rotmat_1'][start:end: skip]
52 | annot['rotmat_2'] = annot['rotmat_2'][start:end: skip]
53 | annot['global_root_rotation'] = np.linalg.inv(annot['rotmat_1'][:, 0])
54 | annot['global_root_origin'] = annot['pose_canon_1'][:, 0]
55 | annot['p1_parent_rel'] = annot['pose_canon_1'][ :, 1:] - annot['pose_canon_1'][:, [self.skel.parents_full[x] for x in range(1, 69)]]
56 | annot['p2_parent_rel'] = annot['pose_canon_2'][:, 1:] - annot['pose_canon_2'][:, [self.skel.parents_full[x] for x in range(1, 69)]]
57 | return annot
58 |
59 | def __len__(self):
60 | return len(self.annot_dict['pose_canon_1'])
61 |
62 |
--------------------------------------------------------------------------------
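A minimal usage sketch for `LindyHopDataset` above (illustrative; it assumes the processed `train.pkl` already exists under the default `data_dir` of `data/Lindyhop` and that the script is run from the repository root):

```python
# Sketch: wrap LindyHopDataset in a standard PyTorch DataLoader and fetch one batch.
from torch.utils.data import DataLoader

from src.Lindyhop.argUtils import argparseNloop
from src.Lindyhop.LindyHop_dataloader import LindyHopDataset

args = argparseNloop()                    # defaults: data_dir='data/Lindyhop', frames=20, batch_size=32
dataset = LindyHopDataset(args, window_size=args.frames, split='train')
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

batch = next(iter(loader))
# pose_canon_* hold per-frame joint positions of the two dancers; contacts holds the hand-contact labels.
print(batch['pose_canon_1'].shape, batch['contacts'].shape)
```
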
/src/Ninjutsu/Ninjutsu_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pickle
4 | import pytorch3d.transforms as t3d
5 | import random
6 | import sys
7 | sys.path.append('.')
8 | sys.path.append('..')
9 | import torch
10 | from math import radians, cos, sin
11 | from scipy.spatial.transform import Rotation as R
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | from src.Ninjutsu.skeleton import InhouseStudioSkeleton
15 | from src.Ninjutsu.visualizer import plot_contacts3D
16 | from src.tools.transformations import *
17 | from src.tools.utils import makepath
18 | from src.Ninjutsu.argUtils import argparseNloop
19 |
20 |
21 | class NinjutsuDataset(torch.utils.data.Dataset):
22 | def __init__(self, args, window_size=10, split='val'):
23 | self.root = args.data_dir
24 | self.scale = args.scale
25 | self.split = split
26 | self.window_size = int(window_size)
27 | with open(os.path.join(self.root, self.split+'.pkl'), 'rb') as f:
28 | self.annot_dict = pickle.load(f)
29 | self.output_keys = ['seq', 'pose_canon_1', 'pose_canon_2',
30 | 'dofs_1', 'dofs_2',
31 | 'rotmat_1', 'rotmat_2',
32 | 'offsets_1', 'offsets_2',
33 | 'contacts'
34 | ]
35 | self.skel = InhouseStudioSkeleton()
36 |
37 |
38 | def __getitem__(self, ind):
39 | index = ind % len(self.annot_dict['pose_canon_1'])
40 | annot = {}
41 | for key in self.output_keys:
42 | annot[key] = self.annot_dict[key][index]
43 | skip = 1
44 | start = np.random.randint(0, len(annot['pose_canon_1']) - self.window_size)
45 | end = start + self.window_size
46 | annot['contacts'] = annot['contacts'][start: end: skip]
47 | annot['pose_canon_1'] = annot['pose_canon_1'][start: end: skip]
48 | annot['pose_canon_2'] = annot['pose_canon_2'][start: end: skip]
49 | annot['dofs_1'] = np.pi * (annot['dofs_1'][start:end: skip]) / 180.
50 | annot['dofs_2'] = np.pi * (annot['dofs_2'][start:end: skip]) / 180.
51 | annot['rotmat_1'] = annot['rotmat_1'][start:end: skip]
52 | annot['rotmat_2'] = annot['rotmat_2'][start:end: skip]
53 | # annot['seq'] = annot['seq'][start:end: skip]
54 | annot['offsets_1'] = annot['offsets_1']
55 | annot['offsets_2'] = annot['offsets_2']
56 | annot['global_root_origin'] = annot['pose_canon_1'][:, 0]
57 | annot['p1_parent_rel'] = annot['pose_canon_1'][ :, 1:] - annot['pose_canon_1'][:, [self.skel.parents_full[x] for x in range(1, 69)]]
58 | annot['p2_parent_rel'] = annot['pose_canon_2'][:, 1:] - annot['pose_canon_2'][:, [self.skel.parents_full[x] for x in range(1, 69)]]
59 |
60 | return annot
61 |
62 | def __len__(self):
63 | return len(self.annot_dict['pose_canon_1'])
64 |
--------------------------------------------------------------------------------
/src/tools/calculate_ev_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pickle
4 | import sys
5 | import torch
6 | sys.path.append('.')
7 | sys.path.append('..')
8 | from scipy import linalg
9 |
10 | def mean_l2di_(reaction, reaction_gt):
11 | x = np.mean(np.sqrt(np.sum((reaction - reaction_gt)**2, -1)))
12 | return x
13 |
14 | def mean_jitter(reaction, reaction_gt, scale=0.1):
15 | a = reaction[:, 1:] - reaction[:, :-1]
16 | b = reaction_gt[:, 1:] - reaction_gt[:, :-1]
17 | x = np.mean(np.sqrt(np.sum((a - b)**2, -1))) * scale
18 | return x
19 |
20 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
21 | def euclidean_distance_matrix(matrix1, matrix2, scale=1.0):
22 | assert matrix1.shape[1] == matrix2.shape[1]
23 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
24 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
25 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
26 | dists = np.sqrt(d1 + d2 + d3) * scale # broadcasting
27 | return dists
28 |
29 | def calculate_top_k(mat, top_k):
30 | size = mat.shape[0]
31 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
32 | bool_mat = (mat == gt_mat)
33 | correct_vec = False
34 | top_k_list = []
35 | for i in range(top_k):
36 | # print(correct_vec, bool_mat[:, i])
37 | correct_vec = (correct_vec | bool_mat[:, i])
38 | # print(correct_vec)
39 | top_k_list.append(correct_vec[:, None])
40 | top_k_mat = np.concatenate(top_k_list, axis=1)
41 | return top_k_mat
42 |
43 |
44 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
45 | dist_mat = euclidean_distance_matrix(embedding1, embedding2)
46 | argmax = np.argsort(dist_mat, axis=1)
47 | top_k_mat = calculate_top_k(argmax, top_k)
48 | if sum_all:
49 | return top_k_mat.sum(axis=0)
50 | else:
51 | return top_k_mat
52 |
53 |
54 | def calculate_matching_score(embedding1, embedding2, sum_all=False):
55 | assert len(embedding1.shape) == 2
56 | assert embedding1.shape[0] == embedding2.shape[0]
57 | assert embedding1.shape[1] == embedding2.shape[1]
58 |
59 | dist = linalg.norm(embedding1 - embedding2, axis=1)
60 | if sum_all:
61 | return dist.sum(axis=0)
62 | else:
63 | return dist
64 |
65 |
66 |
67 | def calculate_activation_statistics(activations):
68 | mu = np.mean(activations, axis=0)
69 | cov = np.cov(activations, rowvar=False)
70 | return mu, cov
71 |
72 |
73 | def calculate_diversity(activation, diversity_times, scale=1.0):
74 | assert len(activation.shape) == 2
75 | assert activation.shape[0] > diversity_times
76 | num_samples = activation.shape[0]
77 |
78 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
79 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
80 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) * scale
81 | return dist.mean()
82 |
83 |
84 | def calculate_multimodality(activation, multimodality_times):
85 | assert len(activation.shape) == 3
86 | assert activation.shape[1] > multimodality_times
87 | num_per_sent = activation.shape[1]
88 |
89 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
90 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
91 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
92 | return dist.mean()
93 |
94 |
95 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, scale=1e+1, eps=1e-6):
96 |
97 | mu1 = np.atleast_1d(mu1)
98 | mu2 = np.atleast_1d(mu2)
99 |
100 | sigma1 = np.atleast_2d(sigma1)
101 | sigma2 = np.atleast_2d(sigma2)
102 |
103 | assert mu1.shape == mu2.shape, \
104 | 'Training and test mean vectors have different lengths'
105 | assert sigma1.shape == sigma2.shape, \
106 | 'Training and test covariances have different dimensions'
107 |
108 | diff = mu1 - mu2
109 |
110 | # Product might be almost singular
111 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
112 | if not np.isfinite(covmean).all():
113 | msg = ('fid calculation produces singular product; '
114 | 'adding %s to diagonal of cov estimates') % eps
115 | print(msg)
116 | offset = np.eye(sigma1.shape[0]) * eps
117 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
118 |
119 | # Numerical error might give slight imaginary component
120 | if np.iscomplexobj(covmean):
121 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
122 | m = np.max(np.abs(covmean.imag))
123 | raise ValueError('Imaginary component {}'.format(m))
124 | covmean = covmean.real
125 |
126 | tr_covmean = np.trace(covmean)
127 |
128 | return scale * ((diff.dot(diff) + np.trace(sigma1) +
129 | np.trace(sigma2) - 2 * tr_covmean))
130 |
--------------------------------------------------------------------------------
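A short sketch of how the metric helpers above are typically combined (illustrative; in practice the embeddings come from a motion feature extractor, which is model-specific and not part of this file):

```python
# Sketch: compute FID, diversity, and R-precision on placeholder embeddings.
import numpy as np

from src.tools.calculate_ev_metrics import (
    calculate_R_precision, calculate_activation_statistics,
    calculate_diversity, calculate_frechet_distance)

rng = np.random.default_rng(0)
gt_emb = rng.standard_normal((256, 64))    # placeholder ground-truth embeddings
gen_emb = rng.standard_normal((256, 64))   # placeholder generated embeddings

mu_gt, cov_gt = calculate_activation_statistics(gt_emb)
mu_gen, cov_gen = calculate_activation_statistics(gen_emb)

print('FID:', calculate_frechet_distance(mu_gt, cov_gt, mu_gen, cov_gen))
print('Diversity:', calculate_diversity(gen_emb, diversity_times=100))
print('R-precision (top 3):', calculate_R_precision(gt_emb, gen_emb, top_k=3, sum_all=True) / len(gt_emb))
```
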
/README.md:
--------------------------------------------------------------------------------
1 | # ReMoS: 3D Motion-Conditioned Reaction Synthesis for Two-Person Interactions
2 | Accepted at the European Conference on Computer Vision (ECCV) 2024.
3 |
4 | [Paper](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/05358.pdf) |
5 | [Video](https://vcai.mpi-inf.mpg.de/projects/remos/Remos_ECCV_v2_1.mp4) |
6 | [Project Page](https://vcai.mpi-inf.mpg.de/projects/remos/)
7 |
8 |
9 |
10 |
11 |
12 | ## Pre-requisites
13 | We have tested our code on the following setups:
14 | * Ubuntu 20.04 LTS
15 | * Windows 10, 11
16 | * Python >= 3.8
17 | * Pytorch >= 1.11
18 | * conda >= 4.9.2 (optional but recommended)
19 |
20 | ## Getting started
21 |
22 | Follow these commands to create a conda environment:
23 | ```
24 | conda create -n remos python=3.8
25 | conda activate remos
26 | conda install -c pytorch pytorch=1.11 torchvision cudatoolkit=11.3
27 | pip install -r requirements.txt
28 | ```
29 | For PyTorch3D installation, refer to https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md
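For example, a typical conda-based install (verify the exact channels and versions against the linked guide for your PyTorch/CUDA combination) is:
```
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install pytorch3d -c pytorch3d
```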
30 |
31 | **Note:** If the PyOpenGL version installed via `requirements.txt` causes issues on Ubuntu, install PyOpenGL with:
32 | ```
33 | apt-get update
34 | apt-get install python3-opengl
35 | ```
36 |
37 | ## Dataset download and preprocess
38 | Download the ReMoCap dataset from the [ReMoS website](https://vcai.mpi-inf.mpg.de/projects/remos/#dataset_section). Unzip and place the dataset under `../DATASETS/ReMoCap`.
39 | The format of the dataset folder should be as follows:
40 | ```bash
41 | DATASETS
42 | └── ReMoCap
43 |     │
44 |     ├── LindyHop
45 |     │   │
46 |     │   ├── train
47 |     │   │   │
48 |     │   │   ├── seq_3
49 |     │   │   │   │
50 |     │   │   │   ├── 0    'first person'
51 |     │   │   │   │   ├── motion.bvh
52 |     │   │   │   │   ├── motion_worldpose.csv
53 |     │   │   │   │   ├── motion_rotation.csv
54 |     │   │   │   │   └── motion_offsets.pkl
55 |     │   │   │   └── 1    'second person'
56 |     │   │   │       ├── motion.bvh
57 |     │   │   │       ├── motion_worldpose.csv
58 |     │   │   │       ├── motion_rotation.csv
59 |     │   │   │       └── motion_offsets.pkl
60 |     │   │   │
61 |     │   │   └── ...
62 |     │   └── test
63 |     │       │
64 |     │       └── ...
65 |     │
66 |     └── Ninjutsu
67 |         │
68 |         ├── train
69 |         │   │
70 |         │   ├── shot_001
71 |         │   │   │
72 |         │   │   ├── 0.bvh
73 |         │   │   ├── 0_worldpose.csv
74 |         │   │   ├── 0_rotations.csv
75 |         │   │   ├── 0_offsets.pkl
76 |         │   │   ├── 1.bvh
77 |         │   │   ├── 1_worldpose.csv
78 |         │   │   ├── 1_rotations.csv
79 |         │   │   └── 1_offsets.pkl
80 |         │   ├── shot_002
81 |         │   │   └── ...
82 |         │   └── ...
83 |         └── test
84 |             │
85 |             └── ...
86 |
87 | ```
88 |
89 | 3. To pre-process the two parts of the dataset for our setting, run:
90 | ```
91 | python src/Lindyhop/process_LindyHop.py
92 | python src/Ninjutsu/process_Ninjutsu.py
93 | ```
94 | This will create `train.pkl` and `test.pkl` under the `data/` folder.
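To sanity-check the processed files, you can load them directly; for example (illustrative, assuming the default `data_dir` of `data/Lindyhop` from `argUtils.py`):
```
python -c "import pickle; d = pickle.load(open('data/Lindyhop/train.pkl', 'rb')); print(len(d['pose_canon_1']), 'sequences')"
```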
95 |
96 | ## Training and testing on the Lindy Hop motion data
97 |
98 | 4. To train the ReMoS model on the Lindy Hop motions in our setting, run:
99 | ```
100 | python src/Lindyhop/train_body_diffusion.py
101 | python src/Lindyhop/train_hand_diffusion.py
102 | ```
103 |
104 | 5. To test and evaluate the ReMoS model on the Lindy Hop motions, run:
105 | ```
106 | python src/Lindyhop/test_full_diffusion.py
107 | ```
108 | Set the `is_eval` flag to `True` to compute the evaluation metrics, and set it to `False` to visualize the results.
109 |
110 | Download the pre-trained weights for the Lindy Hop motions from [here](https://vcai.mpi-inf.mpg.de/projects/remos/LindyHop_pretrained_weights.zip) and unzip them under `save/LindyHop/`.
111 |
112 | ## Training and testing on the Ninjutsu motion data
113 |
114 | coming soon!
115 |
116 | ## License
117 |
118 | Copyright (c) 2024, Max Planck Institute for Informatics
119 | All rights reserved.
120 |
121 | Permission is hereby granted, free of charge, to any person or company obtaining a copy of this dataset and associated documentation files (the "Dataset") from the copyright holders to use the Dataset for any non-commercial purpose. Redistribution and (re)selling of the Dataset, of modifications, extensions, and derivatives of it, and of other datasets containing portions of the licensed Dataset, are not permitted. The Copyright holder is permitted to publicly disclose and advertise the use of the software by any licensee.
122 |
123 | Packaging or distributing parts or whole of the provided software (including code and data) as is or as part of other datasets is prohibited. Commercial use of parts or whole of the provided dataset (including code and data) is strictly prohibited.
124 |
125 | THE DATASET IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE DATASET OR THE USE OR OTHER DEALINGS IN THE DATASET.
126 |
127 |
128 |
--------------------------------------------------------------------------------
/src/tools/utils.py:
--------------------------------------------------------------------------------
1 | import os, pandas as pd  # used by save_csv_log and save_ckpt below
2 | import numpy as np
3 | import torch
4 | import logging
5 | import math
6 | import json
7 | import torch.nn.functional as F
8 | from copy import copy
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 | to_cpu = lambda tensor: tensor.detach().cpu().numpy()
11 |
12 |
13 | def parse_npz(npz, allow_pickle=True):
14 | npz = np.load(npz, allow_pickle=allow_pickle)
15 | npz = {k: npz[k].tolist() for k in npz.files}
16 | return DotDict(npz)
17 |
18 | def params2torch(params, dtype = torch.float32):
19 | return {k: torch.from_numpy(v).type(dtype).to(device) for k, v in params.items()}
20 |
21 | def prepare_params(params, frame_mask, rel_trans = None, dtype = np.float32):
22 | n_params = {k: v[frame_mask].astype(dtype) for k, v in params.items()}
23 | if rel_trans is not None:
24 | n_params['transl'] -= rel_trans
25 | return n_params
26 |
27 | def torch2np(item, dtype=np.float32):
28 | out = {}
29 | for k, v in item.items():
30 | if v ==[] or v=={}:
31 | continue
32 | if isinstance(v, list):
33 | if isinstance(v[0], str):
34 | out[k] = v
35 | else:
36 | if torch.is_tensor(v[0]):
37 | v = [v[i].cpu() for i in range(len(v))]
38 | try:
39 | out[k] = np.array(np.concatenate(v), dtype=dtype)
40 | except:
41 | out[k] = np.array(np.array(v), dtype=dtype)
42 | elif isinstance(v, dict):
43 | out[k] = torch2np(v)
44 | else:
45 | if torch.is_tensor(v):
46 | v = v.cpu()
47 | out[k] = np.array(v, dtype=dtype)
48 |
49 | return out
50 |
51 | def DotDict(in_dict):
52 | out_dict = copy(in_dict)
53 | for k,v in out_dict.items():
54 | if isinstance(v,dict):
55 | out_dict[k] = DotDict(v)
56 | return dotdict(out_dict)
57 |
58 | class dotdict(dict):
59 | """dot.notation access to dictionary attributes"""
60 | __getattr__ = dict.get
61 | __setattr__ = dict.__setitem__
62 | __delattr__ = dict.__delitem__
63 |
64 | def append2dict(source, data):
65 | for k in data.keys():
66 | if k in source.keys():
67 | if isinstance(data[k], list):
68 | source[k] += data[k]
69 | else:
70 | source[k].append(data[k])
71 |
72 |
73 | def append2list(source, data):
74 | # d = {}
75 | for k in data.keys():
76 | leng = len(data[k])
77 | break
78 | for id in range(leng):
79 | d = {}
80 | for k in data.keys():
81 | if isinstance(data[k], list):
82 | if isinstance(data[k][0], str):
83 | d[k] = data[k]
84 | elif isinstance(data[k][0], np.ndarray):
85 | d[k] = data[k][id]
86 |
87 | elif isinstance(data[k], str):
88 | d[k] = data[k]
89 | elif isinstance(data[k], np.ndarray):
90 | d[k] = data[k]
91 | source.append(d)
92 |
93 | # source[k] += data[k].astype(np.float32)
94 |
95 | # source[k].append(data[k].astype(np.float32))
96 |
97 | def np2torch(item, dtype=torch.float32):
98 | out = {}
99 | for k, v in item.items():
100 | if v ==[] :
101 | continue
102 | if isinstance(v, str):
103 | out[k] = v
104 | elif isinstance(v, list):
105 | # if isinstance(v[0], str):
106 | # out[k] = v
107 | try:
108 | out[k] = torch.from_numpy(np.concatenate(v)).to(dtype)
109 | except:
110 | out[k] = v # torch.from_numpy(np.array(v))
111 | elif isinstance(v, dict):
112 | out[k] = np2torch(v)
113 | else:
114 | out[k] = torch.from_numpy(v).to(dtype)
115 | return out
116 |
117 | def to_tensor(array, dtype=torch.float32):
118 | if not torch.is_tensor(array):
119 | array = torch.tensor(array)
120 | return array.to(dtype).to(device)
121 |
122 |
123 | def to_np(array, dtype=np.float32):
124 | if 'scipy.sparse' in str(type(array)):
125 | array = np.array(array.todense(), dtype=dtype)
126 | elif torch.is_tensor(array):
127 | array = array.detach().cpu().numpy()
128 | return array
129 |
130 | def makepath(desired_path, isfile = False):
131 | '''
132 | if the path does not exist make it
133 | :param desired_path: can be path to a file or a folder name
134 | :return:
135 | '''
136 | import os
137 | if isfile:
138 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
139 | else:
140 | if not os.path.exists(desired_path): os.makedirs(desired_path)
141 | return desired_path
142 |
143 | def lr_decay_step(optimizer, epo, lr, gamma):
144 | if epo % 3 == 0:
145 | lr = lr * gamma
146 | for param_group in optimizer.param_groups:
147 | param_group['lr'] = lr
148 | return lr
149 |
150 | def lr_decay_mine(optimizer, lr_now, gamma):
151 | lr = lr_now * gamma
152 | for param_group in optimizer.param_groups:
153 | param_group['lr'] = lr
154 | return lr
155 |
156 | def get_dct_matrix(N):
157 | dct_m = np.eye(N)
158 | for k in np.arange(N):
159 | for i in np.arange(N):
160 | w = np.sqrt(2 / N)
161 | if k == 0:
162 | w = np.sqrt(1 / N)
163 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
164 | idct_m = np.linalg.inv(dct_m)
165 | return dct_m, idct_m
166 |
167 |
168 | def save_csv_log(opt, head, value, is_create=False, file_name='train_log'):
169 | if len(value.shape) < 2:
170 | value = np.expand_dims(value, axis=0)
171 | df = pd.DataFrame(value)
172 | file_path = opt.ckpt + '/{}.csv'.format(file_name)
173 | if not os.path.exists(file_path) or is_create:
174 | df.to_csv(file_path, header=head, index=False)
175 | else:
176 | with open(file_path, 'a') as f:
177 | df.to_csv(f, header=False, index=False)
178 |
179 |
180 | def save_ckpt(state, epo, opt=None):
181 | file_path = os.path.join(opt.ckpt, 'ckpt_last.pth.tar')
182 | torch.save(state, file_path)
183 | # if epo ==24: # % 4 == 0 or epo>22 or epo<5:
184 | if epo % 5 == 0:
185 | file_path = os.path.join(opt.ckpt, 'ckpt_epo'+str(epo)+'.pth.tar')
186 | torch.save(state, file_path)
187 |
188 |
189 | def save_options(opt):
190 | with open('option.json', 'w') as f:
191 | f.write(json.dumps(vars(opt), sort_keys=False, indent=4))
192 |
193 |
194 |
--------------------------------------------------------------------------------
/src/tools/transformations.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2022 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
5 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the
6 | # Max Planck Institute for Biological Cybernetics. All rights reserved.
7 | #
8 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
9 | # on this computer program. You can only use this computer program if you have closed a license agreement
10 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
11 | # Any use of the computer program without a valid license is prohibited and liable to prosecution.
12 | # Contact: ps-license@tuebingen.mpg.de
13 | #
14 |
15 | import sys
16 | sys.path.append('.')
17 | sys.path.append('..')
18 | import numpy as np
19 | import torch
20 | import logging
21 | from copy import copy
22 | from scipy.spatial.transform import Rotation
23 | import torch.nn.functional as F
24 | import pytorch3d.transforms as t3d  # required by the rotation-conversion helpers below
25 |
26 |
27 | LOGGER_DEFAULT_FORMAT = ('{time:YYYY-MM-DD HH:mm:ss.SSS} |'
28 | ' {level: <8} |'
29 | ' {name}:{function}:'
30 | '{line} - {message}')
31 |
32 |
33 |
34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35 | to_cpu = lambda tensor: tensor.detach().cpu().numpy()
36 |
37 | def to_tensor(array, dtype=torch.float32):
38 | if not torch.is_tensor(array):
39 | array = torch.tensor(array)
40 | return array.to(dtype).to(device)
41 |
42 |
43 | def to_np(array, dtype=np.float32):
44 | if 'scipy.sparse' in str(type(array)):
45 | array = np.array(array.todense(), dtype=dtype)
46 | elif torch.is_tensor(array):
47 | array = array.detach().cpu().numpy()
48 | return array
49 |
50 | def loc2vel(loc,fps):
51 | B = loc.shape[0]
52 | idxs = [0] + list(range(B-1))
53 | vel = (loc - loc[idxs])/(1/float(fps))
54 | return vel
55 |
56 | def vel2acc(vel,fps):
57 | B = vel.shape[0]
58 | idxs = [0] + list(range(B - 1))
59 | acc = (vel - vel[idxs]) / (1 / float(fps))
60 | return acc
61 |
62 | def loc2acc(loc,fps):
63 | vel = loc2vel(loc,fps)
64 | acc = vel2acc(vel,fps)
65 | return acc, vel
66 |
67 |
68 | def d62rotmat(pose):
69 | pose = torch.tensor(pose)
70 | reshaped_input = pose.reshape(-1, 6)
71 | return t3d.rotation_6d_to_matrix(reshaped_input)
72 |
73 | def rotmat2d6(pose):
74 | pose = torch.tensor(pose)
75 | return np.array(t3d.matrix_to_rotation_6d(pose))
76 |
77 | def rotmat2d6_tensor(pose):
78 | pose = torch.tensor(pose)
79 | return torch.tensor(t3d.matrix_to_rotation_6d(pose))
80 |
81 | def aa2rotmat(pose):
82 | pose = to_tensor(pose)
83 | return t3d.axis_angle_to_matrix(pose)
84 |
85 | def rotmat2aa(pose):
86 | pose = to_tensor(pose)
87 | quat = t3d.matrix_to_quaternion(pose)
88 | return t3d.quaternion_to_axis_angle(quat)
89 | # reshaped_input = pose.reshape(-1, 3, 3)
90 | # quat = t3d.matrix_to_quaternion(reshaped_input)
91 |
92 | def d62aa(pose):
93 | pose = to_tensor(pose)
94 | return rotmat2aa(d62rotmat(pose))
95 |
96 | def aa2d6(pose):
97 | pose = to_tensor(pose)
98 | return rotmat2d6(aa2rotmat(pose))
99 |
100 | def euler(rots, order='xyz', units='deg'):
101 |
102 | rots = np.asarray(rots)
103 | single_val = False if len(rots.shape)>1 else True
104 | rots = rots.reshape(-1,3)
105 | rotmats = []
106 |
107 | for xyz in rots:
108 | if units == 'deg':
109 | xyz = np.radians(xyz)
110 | r = np.eye(3)
111 | for theta, axis in zip(xyz,order):
112 | c = np.cos(theta)
113 | s = np.sin(theta)
114 | if axis=='x':
115 | r = np.dot(np.array([[1,0,0],[0,c,-s],[0,s,c]]), r)
116 | if axis=='y':
117 | r = np.dot(np.array([[c,0,s],[0,1,0],[-s,0,c]]), r)
118 | if axis=='z':
119 | r = np.dot(np.array([[c,-s,0],[s,c,0],[0,0,1]]), r)
120 | rotmats.append(r)
121 | rotmats = np.stack(rotmats).astype(np.float32)
122 | if single_val:
123 | return rotmats[0]
124 | else:
125 | return rotmats
126 |
127 | def batch_euler_to_rotmat(bxyz, order='xyz', units='deg'):
128 | br = []
129 | for frame in range(bxyz.shape[0]):
130 | # rotmat = euler(bxyz[frame], order, units)
131 | r1 = Rotation.from_euler('xyz', np.array(bxyz[frame]), degrees=True)
132 | rotmat = r1.as_matrix()
133 | br.append(rotmat)
134 | return np.stack(br).astype(np.float32)
135 |
136 | def batch_rotmat_to_euler(rotmat, order='ZYX'):
137 |
138 | # Convert to Euler angles and permute last dimension from ZYX to XYZ to match data order
139 | eu = t3d.matrix_to_euler_angles(rotmat, order)[..., [2, 1, 0]]
140 | return eu
141 |
142 | def batch_euler_to_6d(bxyz, order='xyz', units='deg'):
143 | br = []
144 | for frame in range(bxyz.shape[0]):
145 | # rotmat = euler(bxyz[frame], order, units)
146 | r1 = Rotation.from_euler('xyz', np.array(bxyz[frame]), degrees=True)
147 | rotmat = r1.as_matrix()
148 | d6 = rotmat2d6(rotmat)
149 | br.append(d6)
150 | return np.stack(br).astype(np.float32)
151 |
152 | def batch_6d_to_euler(bxyz, order='XYZ'):
153 | br = []
154 |
155 | for batch in range(bxyz.shape[0]):
156 | br_ = []
157 | for frame in range(bxyz.shape[1]):
158 | # rotmat = t3d.rotation_6d_to_matrix(bxyz[batch, frame])
159 | rotmat = d62rotmat(bxyz[batch, frame])
160 | r = Rotation.from_matrix(np.array(rotmat))
161 | eu = r.as_euler("xyz", degrees=True)
162 | br_.append(np.array(eu))
163 | br.append(np.stack(br_).astype(np.float32))
164 | return np.stack(br).astype(np.float32)
165 |
166 |
167 | def batch_6d_to_euler_tensor(bxyz, order='ZYX'):
168 | rotmat = t3d.rotation_6d_to_matrix(bxyz)
169 | # Convert to Euler angles and permute last dimension from ZYX to XYZ to match data order
170 | eu = t3d.matrix_to_euler_angles(rotmat, order)[..., [2, 1, 0]]
171 | return eu
172 |
173 |
174 | def rotate(points,R):
175 | shape = list(points.shape)
176 | points = to_tensor(points)
177 | R = to_tensor(R)
178 | if len(shape)>3:
179 | points = points.squeeze()
180 | if len(shape)<3:
181 | points = points.unsqueeze(dim=1)
182 | if R.shape[0] > shape[0]:
183 | shape[0] = R.shape[0]
184 | r_points = torch.matmul(points, R.transpose(1,2))
185 | return r_points.reshape(shape)
186 |
187 | def rotmul(rotmat,R):
188 | if rotmat.ndim>3:
189 | rotmat = to_tensor(rotmat).squeeze()
190 | if R.ndim>3:
191 | R = to_tensor(R).squeeze()
192 | rot = torch.matmul(rotmat, R)
193 | return rot
194 |
195 |
196 | smplx_parents =[-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14,
197 | 16, 17, 18, 19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34,
198 | 35, 20, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52,
199 | 53]
200 | def smplx_loc2glob(local_pose):
201 |
202 | bs = local_pose.shape[0]
203 | local_pose = local_pose.view(bs, -1, 3, 3)
204 | global_pose = local_pose.clone()
205 |
206 | for i in range(1,len(smplx_parents)):
207 | global_pose[:,i] = torch.matmul(global_pose[:, smplx_parents[i]], global_pose[:, i].clone())
208 |
209 | return global_pose.reshape(bs,-1,3,3)
210 |
211 | def rot2eul(R):
212 | beta = -np.arcsin(R[2,0])
213 | alpha = np.arctan2(R[2,1]/np.cos(beta),R[2,2]/np.cos(beta))
214 | gamma = np.arctan2(R[1,0]/np.cos(beta),R[0,0]/np.cos(beta))
215 | return np.array((alpha, beta, gamma))
216 |
217 | def eul2rot(theta) :
218 |
219 | R = np.array([[np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]), np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])],
220 | [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]), np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])],
221 | [-np.sin(theta[1]), np.sin(theta[0])*np.cos(theta[1]), np.cos(theta[0])*np.cos(theta[1])]])
222 |
223 | return R
224 |
225 | if __name__ == "__main__":
226 | euler_angles = np.array([0.3, -0.5, 0.7], dtype=np.float32)
227 | euler2matrix = t3d.euler_angles_to_matrix(torch.from_numpy(euler_angles), 'XYZ')
228 | matrix2euler = t3d.matrix_to_euler_angles(euler2matrix, 'XYZ')
229 | w = Rotation.from_euler('xyz', euler_angles, degrees=False)
230 | rotmat = w.as_matrix()
231 | r = Rotation.from_matrix(np.array(euler2matrix))
232 | eu = r.as_euler("xyz", degrees=False)
233 | angrot = eul2rot(euler_angles)
234 | print()
--------------------------------------------------------------------------------
/src/tools/common/skeleton.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | sys.path.append('..')
4 | from src.tools.common.quaternion import *
5 | import scipy.ndimage.filters as filters
6 |
7 | class Skeleton(object):
8 | def __init__(self, offset, kinematic_tree, device):
9 | self.device = device
10 | self._raw_offset_np = offset.numpy()
11 | self._raw_offset = offset.clone().detach().to(device).float()
12 | self._kinematic_tree = kinematic_tree
13 | self._offset = None
14 | self._parents = [0] * len(self._raw_offset)
15 | self._parents[0] = -1
16 | for chain in self._kinematic_tree:
17 | for j in range(1, len(chain)):
18 | self._parents[chain[j]] = chain[j-1]
19 |
20 | def njoints(self):
21 | return len(self._raw_offset)
22 |
23 | def offset(self):
24 | return self._offset
25 |
26 | def set_offset(self, offsets):
27 | self._offset = offsets.clone().detach().to(self.device).float()
28 |
29 | def kinematic_tree(self):
30 | return self._kinematic_tree
31 |
32 | def parents(self):
33 | return self._parents
34 |
35 | # joints (batch_size, joints_num, 3)
36 | def get_offsets_joints_batch(self, joints):
37 | assert len(joints.shape) == 3
38 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
39 | for i in range(1, self._raw_offset.shape[0]):
40 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
41 |
42 | self._offset = _offsets.detach()
43 | return _offsets
44 |
45 | # joints (joints_num, 3)
46 | def get_offsets_joints(self, joints):
47 | assert len(joints.shape) == 2
48 | _offsets = self._raw_offset.clone()
49 | for i in range(1, self._raw_offset.shape[0]):
50 | # print(joints.shape)
51 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
52 |
53 | self._offset = _offsets.detach()
54 | return _offsets
55 |
56 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
57 | # joints (batch_size, joints_num, 3)
58 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
59 | assert len(face_joint_idx) == 4
60 | '''Get Forward Direction'''
61 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
62 | across1 = joints[:, r_hip] - joints[:, l_hip]
63 | across2 = joints[:, sdr_r] - joints[:, sdr_l]
64 | across = across1 + across2
65 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
66 | # print(across1.shape, across2.shape)
67 |
68 | # forward (batch_size, 3)
69 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
70 | if smooth_forward:
71 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
72 | # forward (batch_size, 3)
73 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
74 |
75 | '''Get Root Rotation'''
76 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
77 | root_quat = qbetween_np(forward, target)
78 |
79 | '''Inverse Kinematics'''
80 | # quat_params (batch_size, joints_num, 4)
81 | # print(joints.shape[:-1])
82 | quat_params = np.zeros(joints.shape[:-1] + (4,))
83 | # print(quat_params.shape)
84 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
85 | quat_params[:, 0] = root_quat
86 | # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
87 | for chain in self._kinematic_tree:
88 | R = root_quat
89 | for j in range(len(chain) - 1):
90 | # (batch, 3)
91 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
92 | # print(u.shape)
93 | # (batch, 3)
94 | v = joints[:, chain[j+1]] - joints[:, chain[j]]
95 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
96 | # print(u.shape, v.shape)
97 | rot_u_v = qbetween_np(u, v)
98 |
99 | R_loc = qmul_np(qinv_np(R), rot_u_v)
100 |
101 | quat_params[:,chain[j + 1], :] = R_loc
102 | R = qmul_np(R, R_loc)
103 |
104 | return quat_params
105 |
106 | # Be sure root joint is at the beginning of kinematic chains
107 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
108 | # quat_params (batch_size, joints_num, 4)
109 | # joints (batch_size, joints_num, 3)
110 | # root_pos (batch_size, 3)
111 | if skel_joints is not None:
112 | offsets = self.get_offsets_joints_batch(skel_joints)
113 | if len(self._offset.shape) == 2:
114 | offsets = self._offset.expand(quat_params.shape[0], -1, -1)
115 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
116 | joints[:, 0] = root_pos
117 | for chain in self._kinematic_tree:
118 | if do_root_R:
119 | R = quat_params[:, 0]
120 | else:
121 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
122 | for i in range(1, len(chain)):
123 | R = qmul(R, quat_params[:, chain[i]])
124 | offset_vec = offsets[:, chain[i]]
125 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
126 | return joints
127 |
128 | # Be sure root joint is at the beginning of kinematic chains
129 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
130 | # quat_params (batch_size, joints_num, 4)
131 | # joints (batch_size, joints_num, 3)
132 | # root_pos (batch_size, 3)
133 | if skel_joints is not None:
134 | skel_joints = torch.from_numpy(skel_joints)
135 | offsets = self.get_offsets_joints_batch(skel_joints)
136 | if len(self._offset.shape) == 2:
137 | offsets = self._offset.expand(quat_params.shape[0], -1, -1)
138 | offsets = offsets.numpy()
139 | joints = np.zeros(quat_params.shape[:-1] + (3,))
140 | joints[:, 0] = root_pos
141 | for chain in self._kinematic_tree:
142 | if do_root_R:
143 | R = quat_params[:, 0]
144 | else:
145 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
146 | for i in range(1, len(chain)):
147 | R = qmul_np(R, quat_params[:, chain[i]])
148 | offset_vec = offsets[:, chain[i]]
149 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
150 | return joints
151 |
152 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
153 | # cont6d_params (batch_size, joints_num, 6)
154 | # joints (batch_size, joints_num, 3)
155 | # root_pos (batch_size, 3)
156 | if skel_joints is not None:
157 | skel_joints = torch.from_numpy(skel_joints)
158 | offsets = self.get_offsets_joints_batch(skel_joints)
159 | if len(self._offset.shape) == 2:
160 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
161 | offsets = offsets.numpy()
162 | joints = np.zeros(cont6d_params.shape[:-1] + (3,))
163 | joints[:, 0] = root_pos
164 | for chain in self._kinematic_tree:
165 | if do_root_R:
166 | matR = cont6d_to_matrix_np(cont6d_params[:, 0])
167 | else:
168 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
169 | for i in range(1, len(chain)):
170 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
171 | offset_vec = offsets[:, chain[i]][..., np.newaxis]
172 | # print(matR.shape, offset_vec.shape)
173 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
174 | return joints
175 |
176 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
177 | # cont6d_params (batch_size, joints_num, 6)
178 | # joints (batch_size, joints_num, 3)
179 | # root_pos (batch_size, 3)
180 | if skel_joints is not None:
181 | # skel_joints = torch.from_numpy(skel_joints)
182 | offsets = self.get_offsets_joints_batch(skel_joints)
183 | if len(self._offset.shape) == 2:
184 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
185 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
186 | joints[..., 0, :] = root_pos
187 | for chain in self._kinematic_tree:
188 | if do_root_R:
189 | matR = cont6d_to_matrix(cont6d_params[:, 0])
190 | else:
191 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
192 | for i in range(1, len(chain)):
193 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
194 | offset_vec = offsets[:, chain[i]].unsqueeze(-1)
195 | # print(matR.shape, offset_vec.shape)
196 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
197 | return joints
198 |
199 |
200 |
201 |
202 |
203 |
--------------------------------------------------------------------------------
/src/Ninjutsu/argUtils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import sys
4 | import os
5 | from ast import literal_eval
6 |
7 | def get_args_update_dict(args):
8 | args_update_dict = {}
9 | for string in sys.argv:
10 | string = ''.join(string.split('-'))
11 | if string in args:
12 | args_update_dict.update({string: args.__dict__[string]})
13 | return args_update_dict
14 |
15 |
16 | def argparseNloop():
17 | parser = argparse.ArgumentParser()
18 |
19 | '''Directories and data path'''
20 | parser.add_argument('--work-dir', default = os.path.join('src', 'Ninjutsu'), type=str,
21 | help='The path to the downloaded data')
22 | parser.add_argument('--data-path', default = os.path.join('..', 'DATASETS', 'Ninjutsu_Data'), type=str,
23 | help='The path to the folder that contains dataset before pre-processing')
24 | parser.add_argument('--model_path', default = 'smplx_model', type=str,
25 | help='The path to the folder containing SMPLX model')
26 | parser.add_argument('--save_dir', default = os.path.join('save', 'Ninjutsu', 'diffusion'), type=str,
27 | help='The path to the folder to save the processed data')
28 | parser.add_argument('--render_path', default = os.path.join('render', 'Ninjutsu'), type=str,
29 | help='The path to the folder to save the rendered output')
30 | parser.add_argument('--data_dir', default = os.path.join('data', 'Ninjutsu'), type=str,
31 | help='The path to the pre-processed data')
32 |
33 |
34 | '''Dataset Parameters'''
35 | parser.add_argument('-dataset', nargs='+', type=str, default='NinjutsuDataset',
36 | help='name of the dataset')
37 | parser.add_argument('--frames', nargs='+', type=int, default=50,
38 | help='Number of frames taken from each sequence in the dataset for training.')
39 | parser.add_argument('-seedLength', nargs='+', type=int, default=20,
40 | help='initial length of inputs to seed the prediction; used when offset > 0')
41 | parser.add_argument('-exp', nargs='+', type=int, default=0,
42 | help='experiment number')
43 | parser.add_argument('-scale', nargs='+', type=float, default=1000.0,
44 | help='Scale the data by this factor')
45 | parser.add_argument('-framerate', nargs='+', type=int, default=20,
46 | help='frame rate after pre-processing.')
47 | parser.add_argument('-seed', nargs='+', type=int, default=4815,
48 | help='manual seed')
49 | parser.add_argument('-load', nargs='+', type=str, default=None,
50 | help='Load weights from this file')
51 | parser.add_argument('-cuda', nargs='+', type=int, default=0,
52 | help='choice of gpu device, -1 for cpu')
53 | parser.add_argument('-overfit', nargs='+', type=int, default=0,
54 | help='disables early stopping and saves models even if the dev loss increases. useful for performing an overfitting check')
55 |
56 | '''Diffusion parameters'''
57 | parser.add_argument("--noise_schedule", default='linear', choices=['linear', 'cosine', 'sigmoid'], type=str,
58 | help="Noise schedule type")
59 | parser.add_argument("--diffusion_steps", default=300, type=int,
60 | help="Number of diffusion steps (denoted T in the paper)")
61 | parser.add_argument("--sampler", default='uniform', type=str,
62 | help="Create a Schedule Sampler")
63 |
64 |
65 | '''Diffusion transformer model parameters'''
66 | parser.add_argument('-model', nargs='+', type=str, default='DiffusionTransformer',
67 | help='name of model')
68 | parser.add_argument('-input_feats', nargs='+', type=int, default=3,
69 | help='number of input features ')
70 | parser.add_argument('-out_feats', nargs='+', type=int, default=3,
71 | help='number of output features ')
72 | parser.add_argument('--jt_latent', nargs='+', type=int, default=32,
73 | help='dimensionality of last dimension after GCN')
74 | parser.add_argument('--d_model', nargs='+', type=int, default=256,
75 | help='dimensionality of model embeddings')
76 | parser.add_argument('--d_ff', nargs='+', type=int, default=512,
77 | help='dimensionality of the inner layer in the feed-forward network')
78 | parser.add_argument('--num_layer', nargs='+', type=int, default=6,
79 | help='number of layers in encoder-decoder of model')
80 | parser.add_argument('--num_head', nargs='+', type=int, default=4,
81 | help='number of attention heads in the multi-head attention mechanism.')
82 | parser.add_argument("--activations", default='LeakyReLU', choices=['LeakyReLU', 'SiLU', 'GELU'], type=str,
83 | help="Activation function")
84 | '''Diffusion transformer hand model parameters'''
85 | parser.add_argument('-hand_input_condn_feats', nargs='+', type=int, default=280,
86 | help='number of input features ')
87 | parser.add_argument('-hand_out_feats', nargs='+', type=int, default=3,
88 | help='number of output features ')
89 | parser.add_argument('--d_modelhand', nargs='+', type=int, default=256,
90 | help='dimensionality of model embeddings')
91 | parser.add_argument('--d_ffhand', nargs='+', type=int, default=512,
92 | help='dimensionality of the inner layer in the feed-forward network')
93 | parser.add_argument('--num_layer_hands', nargs='+', type=int, default=6,
94 | help='number of layers in encoder-decoder of model')
95 | parser.add_argument('--num_head_hands', nargs='+', type=int, default=4,
96 | help='number of attention heads in the multi-head attention mechanism.')
97 |
98 |
99 | '''Training parameters'''
100 | parser.add_argument('-batch_size', nargs='+', type=int, default=32,
101 | help='minibatch size.')
102 | parser.add_argument('-num_epochs', nargs='+', type=int, default=5000,
103 | help='number of epochs for training')
104 | parser.add_argument('--skip_train', nargs='+', type=int, default=1,
105 | help='downsampling factor of the training dataset. For example, a value of s indicates floor(D/s) training samples are loaded, '
106 | 'where D is the total number of training samples (default: 1).')
107 | parser.add_argument('--skip_val', nargs='+', type=int, default=1,
108 | help='downsampling factor of the validation dataset. For example, a value of s indicates floor(D/s) validation samples are loaded, '
109 | 'where D is the total number of validation samples (default: 1).')
110 | parser.add_argument('-early_stopping', nargs='+', type=int, default=0,
111 | help='Use 1 for early stopping')
112 | parser.add_argument('--n_workers', default=0, type=int,
113 | help='Number of PyTorch dataloader workers')
114 | parser.add_argument('-greedy_save', nargs='+', type=int, default=1,
115 | help='save weights after each epoch if 1')
116 | parser.add_argument('-save_model', nargs='+', type=int, default=1,
117 | help='flag to save model at every step')
118 | parser.add_argument('-stop_thresh', nargs='+', type=int, default=3,
119 | help='number of consecutive validation loss increases before stopping')
120 | parser.add_argument('-eps', nargs='+', type=float, default=0,
121 | help='if the decrease in validation is less than eps, it counts for one step in stop_thresh ')
122 | parser.add_argument('--curriculum', nargs='+', type=int, default=0,
123 | help='if 1, learn generating time steps by starting with 2 timesteps up to time, increasing by a power of 2')
124 | parser.add_argument('--use-multigpu', default=False,
125 | type=lambda arg: arg.lower() in ['true', '1'],
126 | help='If to use multiple GPUs for training')
127 | parser.add_argument('--load-on-ram', default=False,
128 | type=lambda arg: arg.lower() in ['true', '1'],
129 | help='This will load all the data into RAM for faster training. '
130 | 'If your RAM capacity is more than 40 GB, consider using this.')
131 |
132 | '''Optimizer parameters'''
133 | parser.add_argument('--optimizer', default='optim.Adam', type=str,
134 | help='Optimizer')
135 | parser.add_argument('-momentum', default=0.9, type=float,
136 | help='Momentum for the SGD optimizer')
137 | parser.add_argument('-lr', nargs='+', type=float, default=1e-5,
138 | help='learning rate')
139 |
140 | '''Scheduler parameters'''
141 | parser.add_argument('--scheduler', default='torch.optim.lr_scheduler.StepLR', type=str,
142 | help='Scheduler')
143 | parser.add_argument('--patience', default=3, type=float,
144 | help='Step size for ReduceOnPlateau scheduler')
145 | parser.add_argument('--factor', default=0.99, type=float,
146 | help='Decay rate for ReduceOnPlateau scheduler')
147 | parser.add_argument('--threshold', default=0.05, type=float,
148 | help='Threshold for ReduceOnPlateau scheduler')
149 |
150 | parser.add_argument('--stepsize', default=5, type=float,
151 | help='Step size for StepLR scheduler')
152 | parser.add_argument('--gamma', default=0.99, type=float,
153 | help='Decay rate for StepLR scheduler')
154 | parser.add_argument('--milestones', default=[50, 100], type=float,
155 | help='List of epoch indices. Must be increasing for MultiStepLR scheduler')
156 | '''Loss parameters'''
157 | parser.add_argument('--lambda_loss', type=dict, default=None,
158 | help='weight of loss for VAE')
159 |
160 |
161 | args, unknown = parser.parse_known_args()
162 | return args
163 |
--------------------------------------------------------------------------------
/src/Lindyhop/argUtils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import sys
4 | import os
5 | from ast import literal_eval
6 |
7 | def get_args_update_dict(args):
8 | args_update_dict = {}
9 | for string in sys.argv:
10 | string = ''.join(string.split('-'))
11 | if string in args:
12 | args_update_dict.update({string: args.__dict__[string]})
13 | return args_update_dict
14 |
15 |
16 | def argparseNloop():
17 | parser = argparse.ArgumentParser()
18 |
19 | '''Directories and data path'''
20 | parser.add_argument('--work-dir', default = os.path.join('src', 'Lindyhop'), type=str,
21 | help='The path to the downloaded data')
22 | parser.add_argument('--data-path', default = os.path.join('..', '..', 'DATASETS', 'LindyHop'), type=str,
23 | help='The path to the folder that contains dataset before pre-processing')
24 | parser.add_argument('--model_path', default = 'smplx_model', type=str,
25 | help='The path to the folder containing SMPLX model')
26 | parser.add_argument('--save_dir', default = os.path.join('save', 'Lindyhop', 'diffusion'), type=str,
27 | help='The path to the folder to save the processed data')
28 | parser.add_argument('--render_path', default = os.path.join('render', 'Lindyhop'), type=str,
29 | help='The path to the folder to save the rendered output')
30 | parser.add_argument('--data_dir', default = os.path.join('data', 'Lindyhop'), type=str,
31 | help='The path to the pre-processed data')
32 |
33 |
34 | '''Dataset Parameters'''
35 | parser.add_argument('-dataset', nargs='+', type=str, default='LindyHopDataset',
36 | help='name of the dataset')
37 | parser.add_argument('--frames', nargs='+', type=int, default=20,
38 | help='Number of frames taken from each sequence in the dataset for training.')
39 | parser.add_argument('-seedLength', nargs='+', type=int, default=20,
40 | help='initial length of inputs to seed the prediction; used when offset > 0')
41 | parser.add_argument('-exp', nargs='+', type=int, default=0,
42 | help='experiment number')
43 | parser.add_argument('-scale', nargs='+', type=float, default=1000.0,
44 | help='Scale the data by this factor')
45 | parser.add_argument('-framerate', nargs='+', type=int, default=20,
46 | help='frame rate after pre-processing.')
47 | parser.add_argument('-seed', nargs='+', type=int, default=4815,
48 | help='manual seed')
49 | parser.add_argument('-load', nargs='+', type=str, default=None,
50 | help='Load weights from this file')
51 | parser.add_argument('-cuda', nargs='+', type=int, default=0,
52 | help='choice of gpu device, -1 for cpu')
53 | parser.add_argument('-overfit', nargs='+', type=int, default=0,
54 | help='disables early stopping and saves models even if the dev loss increases. useful for performing an overfitting check')
55 |
56 | '''Diffusion parameters'''
57 | parser.add_argument("--noise_schedule", default='linear', choices=['linear', 'cosine', 'sigmoid'], type=str,
58 | help="Noise schedule type")
59 | parser.add_argument("--diffusion_steps", default=500, type=int,
60 | help="Number of diffusion steps (denoted T in the paper)")
61 | parser.add_argument("--sampler", default='uniform', type=str,
62 | help="Create a Schedule Sampler")
63 |
64 |
65 | '''Diffusion transformer model parameters'''
66 | parser.add_argument('-model', nargs='+', type=str, default='DiffusionTransformer',
67 | help='name of model')
68 | parser.add_argument('-input_feats', nargs='+', type=int, default=3,
69 | help='number of input features ')
70 | parser.add_argument('-out_feats', nargs='+', type=int, default=3,
71 | help='number of output features ')
72 | parser.add_argument('--jt_latent', nargs='+', type=int, default=32,
73 | help='dimensionality of last dimension after GCN')
74 | parser.add_argument('--d_model', nargs='+', type=int, default=256,
75 | help='dimensionality of model embeddings')
76 | parser.add_argument('--d_ff', nargs='+', type=int, default=512,
77 | help='dimensionality of the inner layer in the feed-forward network')
78 | parser.add_argument('--num_layer', nargs='+', type=int, default=6,
79 | help='number of layers in encoder-decoder of model')
80 | parser.add_argument('--num_head', nargs='+', type=int, default=4,
81 | help='number of attention heads in the multi-head attention mechanism.')
82 | parser.add_argument("--activations", default='LeakyReLU', choices=['LeakyReLU', 'SiLU', 'GELU'], type=str,
83 | help="Activation function")
84 |
85 | '''Diffusion transformer hand model parameters'''
86 | parser.add_argument('-hand_input_condn_feats', nargs='+', type=int, default=280,
87 | help='number of input features ')
88 | parser.add_argument('-hand_out_feats', nargs='+', type=int, default=3,
89 | help='number of output features ')
90 | parser.add_argument('--d_modelhand', nargs='+', type=int, default=256,
91 | help='dimensionality of model embeddings')
92 | parser.add_argument('--d_ffhand', nargs='+', type=int, default=512,
93 | help='dimensionality of the inner layer in the feed-forward network')
94 | parser.add_argument('--num_layer_hands', nargs='+', type=int, default=6,
95 | help='number of layers in encoder-decoder of model')
96 | parser.add_argument('--num_head_hands', nargs='+', type=int, default=4,
97 | help='number of attention heads in the multi-head attention mechanism.')
98 |
99 |
100 | '''Training parameters'''
101 | parser.add_argument('-batch_size', nargs='+', type=int, default=32,
102 | help='minibatch size.')
103 | parser.add_argument('-num_epochs', nargs='+', type=int, default=300,
104 | help='number of epochs for training')
105 | parser.add_argument('--skip_train', nargs='+', type=int, default=1,
106 | help='downsampling factor of the training dataset. For example, a value of s indicates floor(D/s) training samples are loaded, '
107 | 'where D is the total number of training samples (default: 1).')
108 | parser.add_argument('--skip_val', nargs='+', type=int, default=1,
109 | help='downsampling factor of the validation dataset. For example, a value of s indicates floor(D/s) validation samples are loaded, '
110 | 'where D is the total number of validation samples (default: 1).')
111 | parser.add_argument('-early_stopping', nargs='+', type=int, default=0,
112 | help='Use 1 for early stopping')
113 | parser.add_argument('--n_workers', default=0, type=int,
114 | help='Number of PyTorch dataloader workers')
115 | parser.add_argument('-greedy_save', nargs='+', type=int, default=1,
116 | help='save weights after each epoch if 1')
117 | parser.add_argument('-save_model', nargs='+', type=int, default=1,
118 | help='flag to save model at every step')
119 | parser.add_argument('-stop_thresh', nargs='+', type=int, default=3,
120 |                         help='number of consecutive validation loss increases before stopping')
121 | parser.add_argument('-eps', nargs='+', type=float, default=0,
122 |                         help='if the decrease in validation loss is less than eps, it counts as one step towards stop_thresh')
123 | parser.add_argument('--curriculum', nargs='+', type=int, default=0,
124 |                         help='if 1, use a curriculum: start by generating 2 timesteps and double the number of timesteps until the full sequence length is reached')
125 | parser.add_argument('--use-multigpu', default=False,
126 | type=lambda arg: arg.lower() in ['true', '1'],
127 |                         help='Whether to use multiple GPUs for training')
128 | parser.add_argument('--load-on-ram', default=False,
129 | type=lambda arg: arg.lower() in ['true', '1'],
130 |                     help='Load the entire dataset into RAM for faster training. '
131 |                          'Consider enabling this if you have more than 40 GB of RAM.')
132 |
133 | '''Optimizer parameters'''
134 | parser.add_argument('--optimizer', default='optim.Adam', type=str,
135 | help='Optimizer')
136 | parser.add_argument('-momentum', default=0.9, type=float,
137 |                         help='Momentum for the SGD optimizer')
138 | parser.add_argument('-lr', nargs='+', type=float, default=1e-5,
139 | help='learning rate')
140 |
141 | '''Scheduler parameters'''
142 | parser.add_argument('--scheduler', default='torch.optim.lr_scheduler.StepLR', type=str,
143 | help='Scheduler')
144 | parser.add_argument('--patience', default=3, type=float,
145 |                         help='Patience (epochs without improvement) for the ReduceLROnPlateau scheduler')
146 | parser.add_argument('--factor', default=0.99, type=float,
147 | help='Decay rate for ReduceOnPlateau scheduler')
148 | parser.add_argument('--threshold', default=0.05, type=float,
149 |                         help='Threshold for the ReduceLROnPlateau scheduler')
150 |
151 | parser.add_argument('--stepsize', default=5, type=float,
152 | help='Step size for StepLR scheduler')
153 | parser.add_argument('--gamma', default=0.99, type=float,
154 | help='Decay rate for StepLR scheduler')
155 |     parser.add_argument('--milestones', nargs='+', type=int, default=[50, 100],
156 |                         help='List of epoch indices. Must be increasing for the MultiStepLR scheduler')
157 | '''Loss parameters'''
158 | parser.add_argument('--lambda_loss', type=dict, default=None,
159 | help='weight of loss for VAE')
160 |
161 |
162 | args, unknown = parser.parse_known_args()
163 | return args
164 |
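165 | # Hedged usage sketch (illustrative, not part of the original file): argparseNloop()
166 | # returns an argparse.Namespace, so the training scripts simply read attributes such as
167 | # args.batch_size or args.diffusion_steps, and get_args_update_dict() records which of
168 | # those defaults were overridden on the command line.
169 | if __name__ == '__main__':
170 |     args = argparseNloop()
171 |     print(args.dataset, args.batch_size, args.d_model, args.diffusion_steps)
172 |     print(get_args_update_dict(args))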
--------------------------------------------------------------------------------
/src/Lindyhop/process_LindyHop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import torch
3 | import os
4 | import glob
5 | import sys
6 | sys.path.append('.')
7 | sys.path.append('..')
8 | import pickle
9 | from src.tools.transformations import batch_euler_to_rotmat
10 |
11 | def makepath(desired_path, isfile = False):
12 | '''
13 | if the path does not exist make it
14 | :param desired_path: can be path to a file or a folder name
15 | :return:
16 | '''
17 | import os
18 | if isfile:
19 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
20 | else:
21 | if not os.path.exists(desired_path): os.makedirs(desired_path)
22 | return desired_path
23 |
24 | class PreProcessor():
25 | def __init__(self, root_dir, fps=20, split='train'):
26 | self.root = root_dir
27 | self.framerate = fps
28 | self.root = os.path.join(root_dir, split)
29 | seq = os.listdir(self.root)
30 | self.sequences = [int(x) for x in seq]
31 | self.total_frames = 0
32 | self.total_contact_frames = 0
33 | self.annot_dict = {
34 | 'cam': [],
35 | 'seq': [], 'contacts': [],
36 | 'pose_canon_1':[], 'pose_canon_2':[],
37 | 'dofs_1': [], 'dofs_2': [],
38 | 'rotmat_1': [], 'rotmat_2': [],
39 | 'offsets_1': [], 'offsets_2': []
40 | }
41 |
42 | self.bvh_joint_order = {
43 | 'Hips': 0,
44 | 'RightUpLeg': 1,
45 | 'RightLeg': 2,
46 | 'RightFoot': 3,
47 | 'RightToeBase': 4,
48 | 'RightToeBaseEnd': 5,
49 | 'LeftUpLeg': 6,
50 | 'LeftLeg': 7,
51 | 'LeftFoot': 8,
52 | 'LeftToeBase': 9,
53 | 'LeftToeBaseEnd': 10,
54 | 'Spine': 11,
55 | 'Spine1': 12,
56 | 'Spine2': 13,
57 | 'Spine3': 14,
58 | 'RightShoulder': 15,
59 | 'RightArm': 16,
60 | 'RightForeArm': 17,
61 | 'RightHand': 18,
62 | 'RightHandEnd': 19,
63 | 'RightHandPinky1': 20,
64 | 'RightHandPinky2': 21,
65 | 'RightHandPinky3': 22,
66 | 'RightHandPinky3End': 23,
67 | 'RightHandRing1': 24,
68 | 'RightHandRing2': 25,
69 | 'RightHandRing3': 26,
70 | 'RightHandRing3End': 27,
71 | 'RightHandMiddle1': 28,
72 | 'RightHandMiddle2': 29,
73 | 'RightHandMiddle3': 30,
74 | 'RightHandMiddle3End': 31,
75 | 'RightHandIndex1': 32,
76 | 'RightHandIndex2': 33,
77 | 'RightHandIndex3': 34,
78 | 'RightHandIndex3End': 35,
79 | 'RightHandThumb1': 36,
80 | 'RightHandThumb2': 37,
81 | 'RightHandThumb3': 38,
82 | 'RightHandThumb3End': 39,
83 | 'LeftShoulder': 40,
84 | 'LeftArm': 41,
85 | 'LeftForeArm': 42,
86 | 'LeftHand': 43,
87 | 'LeftHandEnd': 44,
88 | 'LeftHandPinky1': 45,
89 | 'LeftHandPinky2': 46,
90 | 'LeftHandPinky3': 47,
91 | 'LeftHandPinky3End': 48,
92 | 'LeftHandRing1': 49,
93 | 'LeftHandRing2': 50,
94 | 'LeftHandRing3': 51,
95 | 'LeftHandRing3End': 52,
96 | 'LeftHandMiddle1': 53,
97 | 'LeftHandMiddle2': 54,
98 | 'LeftHandMiddle3': 55,
99 | 'LeftHandMiddle3End': 56,
100 | 'LeftHandIndex1': 57,
101 | 'LeftHandIndex2': 58,
102 | 'LeftHandIndex3': 59,
103 | 'LeftHandIndex3End': 60,
104 | 'LeftHandThumb1': 61,
105 | 'LeftHandThumb2': 62,
106 | 'LeftHandThumb3': 63,
107 | 'LeftHandThumb3End': 64,
108 | 'Spine4': 65,
109 | 'Neck': 66,
110 | 'Head': 67,
111 | 'HeadEnd': 68
112 | }
113 |
114 | print("creating the annot file")
115 | self.collate_videos()
116 | self.save_annot(split)
117 |
118 | def detect_contact(self, motion1, motion2, thresh=50):
119 |
120 | contact_joints = ['Hand', 'HandEnd',
121 | 'HandPinky1', 'HandPinky2', 'HandPinky3', 'HandPinky3End',
122 | 'HandRing1', 'HandRing2', 'HandRing3','HandRing3End',
123 | 'HandIndex1', 'HandIndex2', 'HandIndex3','HandIndex3End',
124 | 'HandMiddle1', 'HandMiddle2', 'HandMiddle3','HandMiddle3End',
125 | 'HandThumb1', 'HandThumb2', 'HandThumb3','HandThumb3End']
126 |
127 | n_frames = motion1.shape[0]
128 |
129 | assert motion1.shape == motion2.shape
130 |
131 | ## 0 : no contact, 1: rh-rh, 2: lh-lh, 3: lh-rh , 4: rh-lh
132 | contact = np.zeros((n_frames, 5))
133 |
134 | def dist(x, y):
135 | return np.sqrt(np.sum((x - y)**2))
136 | contact_frames = []
137 |
138 | count = 0
139 | for i in range(n_frames):
140 | for s, sides in enumerate([['Right', 'Right'], ['Left', 'Left'], ['Left', 'Right'], ['Right', 'Left']]):
141 | for j, joint1 in enumerate(contact_joints):
142 | if contact[i, s+1] == 1:
143 | break
144 | for k, joint2 in enumerate(contact_joints):
145 | j1 = sides[0] + joint1
146 | j2 = sides[1] + joint2
147 |
148 | idx1 = self.bvh_joint_order[j1]
149 | idx2 = self.bvh_joint_order[j2]
150 |
151 | d = dist(motion1[i, idx1], motion2[i, idx2])
152 | if d <= thresh:
153 | contact[i, s+1] = 1
154 | contact_frames.append(i)
155 | count += 1
156 | break
157 |
158 |
159 | print(count)
160 | return contact, contact_frames
161 |
162 |
163 | def save_annot(self, split):
164 | save_path = makepath(os.path.join('data', 'LindyHop', split+'.pkl'), isfile=True)
165 | with open(save_path, 'wb') as f:
166 | pickle.dump(self.annot_dict, f)
167 |
168 |
169 | def _load_files(self, seq, p):
170 | file_basename = os.path.join(self.root, str(seq), str(p))
171 | path_canon_3d = os.path.join(file_basename, 'motion_worldpos.csv')
172 | path_dofs = os.path.join(file_basename, 'motion_rotations.csv')
173 | path_offsets = os.path.join(file_basename, 'motion_offsets.pkl')
174 |
175 | print(f"loading file {file_basename}")
176 | canon_3d = np.genfromtxt(path_canon_3d, delimiter=',', skip_header=1) # n_c*n_f x n_dof
177 | dofs = np.genfromtxt(path_dofs, delimiter=',', skip_header=1) # n_c*n_f x n_dof
178 | with open(path_offsets, 'rb') as f:
179 | offset_dict = pickle.load(f)
180 | print(f"loading complete")
181 |
182 | n_frames = canon_3d.shape[0]
183 | canon_3d = np.float32(canon_3d[:, 1:].reshape(n_frames, -1, 3))
184 | dofs = dofs[:, 1:].reshape(n_frames, -1, 3)
185 |
186 | #Downsample the data from 50 fps to given framerate
187 | use_frames = list(np.rint(np.arange(0, n_frames, 50/self.framerate)))
188 | use_frames = [int(a) for a in use_frames]
189 | canon_3d = canon_3d[use_frames]
190 | dofs = np.float32(dofs[use_frames])
191 | print(canon_3d.shape)
192 | return n_frames, canon_3d, dofs, offset_dict
193 |
194 |
195 | def collate_videos(self):
196 | self.annot_dict['bvh_joint_order'] = self.bvh_joint_order
197 | # self.annot_dict['joint_order'] = self.joint_order
198 | for i, seq in enumerate(self.sequences):
199 | seq_total_frames, canon_3d_1, dofs_1, offsets_1 = self._load_files(seq, 0)
200 | self.total_frames += seq_total_frames
201 | # continue
202 | _, canon_3d_2, dofs_2, offsets_2 = self._load_files(seq, 1)
203 | if canon_3d_2.shape[0] < canon_3d_1.shape[0]:
204 | n_frames = canon_3d_2.shape[0]
205 | else:
206 | n_frames = canon_3d_1.shape[0]
207 | canon_3d_1 = canon_3d_1[:n_frames]
208 | canon_3d_2 = canon_3d_2[:n_frames]
209 | contacts, contact_frames = self.detect_contact(canon_3d_1, canon_3d_2)
210 |
211 | n_frames_contact = len(contact_frames)
212 | self.total_contact_frames += n_frames_contact
213 | canon_3d_1 = canon_3d_1[contact_frames]
214 | canon_3d_2 = canon_3d_2[contact_frames]
215 | output_dofs_1 = dofs_1[contact_frames]
216 | output_dofs_2 = dofs_2[contact_frames]
217 | contacts = contacts[contact_frames, 1:]
218 | rotmat_1 = batch_euler_to_rotmat(output_dofs_1)
219 | rotmat_2 = batch_euler_to_rotmat(output_dofs_2)
220 | self.annot_dict['offsets_1'].extend([offsets_1 for i in range(0, n_frames_contact)])
221 | self.annot_dict['offsets_2'].extend([offsets_2 for i in range(0, n_frames_contact)])
222 | self.annot_dict['seq'].extend([seq for i in range(0, n_frames_contact)])
223 | self.annot_dict['pose_canon_1'].extend([canon_3d_1 for i in range(0, n_frames_contact)])
224 | self.annot_dict['pose_canon_2'].extend([canon_3d_2 for i in range(0, n_frames_contact )])
225 | self.annot_dict['contacts'].extend([contacts for i in range(0, n_frames_contact )])
226 | self.annot_dict['dofs_1'].extend([output_dofs_1 for i in range(0, n_frames_contact )])
227 | self.annot_dict['dofs_2'].extend([output_dofs_2 for i in range(0, n_frames_contact )])
228 | self.annot_dict['rotmat_1'].extend([rotmat_1 for i in range(0, n_frames_contact)])
229 | self.annot_dict['rotmat_2'].extend([rotmat_2 for i in range(0, n_frames_contact )])
230 |
231 | print(self.total_frames)
232 | print(self.total_contact_frames)
233 |
234 |
235 |
236 | if __name__ == "__main__":
237 | root_path = os.path.join('..', 'DATASETS', 'ReMocap', 'LindyHop')
238 | fps = 20
239 | pp = PreProcessor(root_path, fps, 'train')
240 | pp = PreProcessor(root_path, fps, 'test')
241 |
242 |
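243 | # Note on detect_contact(): for every frame it compares each hand joint of dancer 1 with
244 | # each hand joint of dancer 2 (21 x 21 joint pairs per side combination) and sets the
245 | # corresponding contact label as soon as any pair is closer than `thresh` (50, in the
246 | # dataset's position units). A hedged, vectorised sketch of the rh-rh label would be:
247 | #     idx = [self.bvh_joint_order['Right' + j] for j in contact_joints]
248 | #     d = np.linalg.norm(motion1[:, idx, None] - motion2[:, None, idx], axis=-1)  # (T, 21, 21)
249 | #     contact[:, 1] = (d.min(axis=(1, 2)) <= thresh).astype(contact.dtype)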
--------------------------------------------------------------------------------
/src/Lindyhop/train_VanillaTransformer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import shutil
7 | import sys
8 | sys.path.append('.')
9 | sys.path.append('..')
10 | import time
11 | import torch
12 | torch.cuda.empty_cache()
13 | import torch.nn as nn
14 |
15 | from cmath import nan
16 | from collections import OrderedDict
17 | from datetime import datetime
18 | from torch import optim
19 | from torch.utils.data import DataLoader
20 | from tqdm import tqdm
21 |
22 | from src.Lindyhop.argUtils import argparseNloop
23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset
24 | from src.Lindyhop.models.transAE import *
25 |
26 | from src.Lindyhop.skeleton import *
27 | from src.tools.bookkeeper import *
28 | from src.tools.transformations import *
29 | from src.tools.utils import makepath
30 |
31 | right_side = [15, 16, 17, 18]
32 | left_side = [19, 20, 21, 22]
33 | # stat_metrics = CalculateMetricsDanceData()
34 | def dist(x, y):
35 | # return torch.mean(x - y)
36 | return torch.mean(torch.cdist(x, y, p=2))
37 |
38 | def initialize_weights(m):
39 | std_dev = 0.02
40 | if isinstance(m, nn.Linear):
41 | nn.init.normal_(m.weight, std=std_dev)
42 | if m.bias is not None:
43 | nn.init.normal_(m.bias, std=std_dev)
44 | # nn.init.constant_(m.bias.data, 1e-5)
45 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
46 | torch.nn.init.normal_(m.weight, std=std_dev)
47 | if m.bias is not None:
48 | torch.nn.init.normal_(m.bias, std=std_dev)
49 | # nn.init.constant_(m.bias.data, 1e-5)
50 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
51 | nn.init.normal_(m.weight, std=std_dev)
52 | if m.bias is not None:
53 | nn.init.normal_(m.bias, std=std_dev)
54 |
55 | class Trainer:
56 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69):
57 | torch.manual_seed(args.seed)
58 | self.model_path = args.model_path
59 | makepath(args.work_dir, isfile=False)
60 | use_cuda = torch.cuda.is_available()
61 | if use_cuda:
62 | torch.cuda.empty_cache()
63 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")
64 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None
65 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1
66 | print('Using %d GPU device(s) [%s] for training!' % (gpu_count, gpu_brand))
67 | args_subset = ['exp', 'model', 'batch_size', 'frames']
68 | self.book = BookKeeper(args, args_subset)
69 | self.args = self.book.args
70 | self.batch_size = args.batch_size
71 | self.curriculum = args.curriculum
72 | self.scale = args.scale
73 | self.dtype = torch.float32
74 | self.epochs_completed = self.book.last_epoch
75 | self.frames = args.frames
76 | self.model = args.model
77 | self.testtime_split = split
78 | self.num_jts = num_jts
79 | self.model_pose = VanillaTransformer(args).to(self.device).float()
80 | trainable_count_body = sum(p.numel() for p in self.model_pose.parameters() if p.requires_grad)
81 | self.model_pose.apply(initialize_weights)
82 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr)
83 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma)
84 | self.skel = InhouseStudioSkeleton()
85 |
86 | print(args.model, 'Model Created')
87 | if args.load:
88 | print('Loading Model', args.model)
89 | self.book._load_model(self.model_pose, 'model_pose')
90 | print('Loading the data')
91 | if is_train:
92 | self.load_data(args)
93 | else:
94 | self.load_data_testtime(args)
95 |
96 |
97 | def load_data_testtime(self, args):
98 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split)
99 | self.load_ds_data = DataLoader(self.ds_data, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
100 |
101 |
102 | def load_data(self, args):
103 |
104 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train')
105 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
106 | print('Train set loaded. Size=', len(self.ds_train.dataset))
107 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test')
108 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
109 | print('Validation set loaded. Size=', len(self.ds_val.dataset))
110 |
111 |
112 | def train(self, num_epoch, ablation=None):
113 | total_train_loss = 0.0
114 | self.model_pose.train()
115 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120)
116 | for count, batch in enumerate(training_tqdm):
117 | self.optimizer_model_pose.zero_grad()
118 | with torch.autograd.detect_anomaly():
119 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
120 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order,
121 | new_joint_order=self.skel.body_only)
122 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
123 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order,
124 | new_joint_order=self.skel.body_only)
125 |
126 | _, loss_model = self.model_pose(global_pose1, global_pose2)
127 | total_train_loss += loss_model.item()
128 |
129 | if loss_model == float('inf') or torch.isnan(loss_model):
130 | print('Train loss is NaN or Inf')
131 | exit()
132 | loss_model.backward()
133 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01)
134 | self.optimizer_model_pose.step()
135 |
136 | avg_train_loss = total_train_loss/(count + 1)
137 | return avg_train_loss
138 |
139 | def evaluate(self, num_epoch, ablation=None):
140 | total_eval_loss = 0.0
141 | self.model_pose.eval()
142 | T = self.frames
143 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120)
144 | for count, batch in enumerate(eval_tqdm):
145 | with torch.no_grad():
146 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
147 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order,
148 | new_joint_order=self.skel.body_only)
149 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
150 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order,
151 | new_joint_order=self.skel.body_only)
152 |
153 | _, loss_model = self.model_pose(global_pose1, global_pose2)
154 | total_eval_loss += loss_model.item()
155 |
156 | avg_eval_loss = total_eval_loss/(count + 1)
157 | return avg_eval_loss
158 |
159 | def fit(self, n_epochs=None, ablation=False):
160 | print('*****Inside Trainer.fit *****')
161 | if n_epochs is None:
162 | n_epochs = self.args.num_epochs
163 | starttime = datetime.now().replace(microsecond=0)
164 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs)
165 | save_model_dict = {}
166 | best_eval = 1000
167 | for epoch_num in range(self.epochs_completed, n_epochs + 1):
168 | tqdm.write('--- starting Epoch # %03d' % epoch_num)
169 | train_loss = self.train(epoch_num, ablation)
170 |
171 | if epoch_num % 5 == 0:
172 | eval_loss = self.evaluate(epoch_num, ablation)
173 | else:
174 | eval_loss = 0.0
175 | self.scheduler_pose.step()
176 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0})
177 | self.book._save_res()
178 | self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr'])
179 |
180 | if epoch_num > 100 and epoch_num % 5 == 0 and eval_loss < best_eval:
181 | print('Best eval at epoch {}'.format(epoch_num))
182 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb')
183 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
184 | torch.save(save_model_dict, f)
185 | f.close()
186 | best_eval = eval_loss
187 | if epoch_num > 20 and epoch_num % 20 == 0 :
188 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb')
189 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
190 | torch.save(save_model_dict, f)
191 | f.close()
192 | endtime = datetime.now().replace(microsecond=0)
193 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S')))
194 | print('Training complete in %s!\n' % (endtime - starttime))
195 |
196 |
197 |
198 | if __name__ == '__main__':
199 | args = argparseNloop()
200 |
201 | is_train = True
202 | ablation = None # if True then ablation: no_IAC_loss
203 | model_trainer = Trainer(args=args, is_train=is_train, split='test', JT_POSITION=True, num_jts=27)
204 | print("** Method Initialization Complete **")
205 | model_trainer.fit(ablation=ablation)
206 |
207 |
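208 | # Hedged test-time sketch (assumes a checkpoint is supplied via the -load flag): building
209 | # the Trainer with is_train=False calls load_data_testtime(), which wraps the requested
210 | # split in a batch-size-1 DataLoader so predictions can be iterated sequentially:
211 | #     args = argparseNloop()
212 | #     tester = Trainer(args=args, is_train=False, split='test', JT_POSITION=True, num_jts=27)
213 | #     for batch in tester.load_ds_data:
214 | #         pose1 = batch['pose_canon_1']  # assumed shape: (1, frames, n_joints, 3)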
--------------------------------------------------------------------------------
/src/Ninjutsu/process_Ninjutsu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import torch
3 | import os
4 | import glob
5 | import sys
6 | sys.path.append('.')
7 | sys.path.append('..')
8 | import pickle
9 | from src.tools.transformations import batch_euler_to_rotmat
10 |
11 | def makepath(desired_path, isfile = False):
12 | '''
13 | if the path does not exist make it
14 | :param desired_path: can be path to a file or a folder name
15 | :return:
16 | '''
17 | import os
18 | if isfile:
19 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
20 | else:
21 | if not os.path.exists(desired_path): os.makedirs(desired_path)
22 | return desired_path
23 |
24 | class PreProcessor():
25 | def __init__(self, root_dir, fps=20, split='train'):
26 | self.root = root_dir
27 | self.framerate = fps
28 | self.root = os.path.join(root_dir, split)
29 | self.sequences = os.listdir(self.root)
30 |
31 |
32 | self.annot_dict = {
33 | 'cam': [],
34 | 'seq': [],
35 | 'contacts': [],
36 | 'pose_canon_1':[], 'pose_canon_2':[],
37 | 'dofs_1': [], 'dofs_2': [],
38 | 'rotmat_1': [], 'rotmat_2': [],
39 | 'offsets_1': [], 'offsets_2': []
40 | }
41 |
42 | self.bvh_joint_order = {
43 | 'Hips': 0,
44 | 'RightUpLeg': 1,
45 | 'RightLeg': 2,
46 | 'RightFoot': 3,
47 | 'RightToeBase': 4,
48 | 'RightToeBaseEnd': 5,
49 | 'LeftUpLeg': 6,
50 | 'LeftLeg': 7,
51 | 'LeftFoot': 8,
52 | 'LeftToeBase': 9,
53 | 'LeftToeBaseEnd': 10,
54 | 'Spine': 11,
55 | 'Spine1': 12,
56 | 'Spine2': 13,
57 | 'Spine3': 14,
58 | 'RightShoulder': 15,
59 | 'RightArm': 16,
60 | 'RightForeArm': 17,
61 | 'RightHand': 18,
62 | 'RightHandEnd': 19,
63 | 'RightHandPinky1': 20,
64 | 'RightHandPinky2': 21,
65 | 'RightHandPinky3': 22,
66 | 'RightHandPinky3End': 23,
67 | 'RightHandRing1': 24,
68 | 'RightHandRing2': 25,
69 | 'RightHandRing3': 26,
70 | 'RightHandRing3End': 27,
71 | 'RightHandMiddle1': 28,
72 | 'RightHandMiddle2': 29,
73 | 'RightHandMiddle3': 30,
74 | 'RightHandMiddle3End': 31,
75 | 'RightHandIndex1': 32,
76 | 'RightHandIndex2': 33,
77 | 'RightHandIndex3': 34,
78 | 'RightHandIndex3End': 35,
79 | 'RightHandThumb1': 36,
80 | 'RightHandThumb2': 37,
81 | 'RightHandThumb3': 38,
82 | 'RightHandThumb3End': 39,
83 | 'LeftShoulder': 40,
84 | 'LeftArm': 41,
85 | 'LeftForeArm': 42,
86 | 'LeftHand': 43,
87 | 'LeftHandEnd': 44,
88 | 'LeftHandPinky1': 45,
89 | 'LeftHandPinky2': 46,
90 | 'LeftHandPinky3': 47,
91 | 'LeftHandPinky3End': 48,
92 | 'LeftHandRing1': 49,
93 | 'LeftHandRing2': 50,
94 | 'LeftHandRing3': 51,
95 | 'LeftHandRing3End': 52,
96 | 'LeftHandMiddle1': 53,
97 | 'LeftHandMiddle2': 54,
98 | 'LeftHandMiddle3': 55,
99 | 'LeftHandMiddle3End': 56,
100 | 'LeftHandIndex1': 57,
101 | 'LeftHandIndex2': 58,
102 | 'LeftHandIndex3': 59,
103 | 'LeftHandIndex3End': 60,
104 | 'LeftHandThumb1': 61,
105 | 'LeftHandThumb2': 62,
106 | 'LeftHandThumb3': 63,
107 | 'LeftHandThumb3End': 64,
108 | 'Spine4': 65,
109 | 'Neck': 66,
110 | 'Head': 67,
111 | 'HeadEnd': 68
112 | }
113 |
114 | print("creating the annot file")
115 | self.collate_videos()
116 | self.save_annot(split)
117 |
118 | def detect_contact(self, motion1, motion2, thresh=120):
119 |
120 |
121 | contact_joints = ['Hand', 'HandEnd',
122 | 'HandPinky1', 'HandPinky2', 'HandPinky3', 'HandPinky3End',
123 | 'HandRing1', 'HandRing2', 'HandRing3','HandRing3End',
124 | 'HandIndex1', 'HandIndex2', 'HandIndex3','HandIndex3End',
125 | 'HandMiddle1', 'HandMiddle2', 'HandMiddle3','HandMiddle3End',
126 | 'HandThumb1', 'HandThumb2', 'HandThumb3','HandThumb3End']
127 |
128 | n_frames = motion1.shape[0]
129 |
130 | assert motion1.shape == motion2.shape
131 |
132 | ## 0 : no contact, 1: rh-rh, 2: lh-lh, 3: lh-rh , 4: rh-lh
133 | contact = np.zeros((n_frames, 5))
134 |
135 | def dist(x, y):
136 | return np.sqrt(np.sum((x - y)**2))
137 | count = 0
138 | for i in range(n_frames):
139 | for s, sides in enumerate([['Right', 'Right'], ['Left', 'Left'], ['Left', 'Right'], ['Right', 'Left']]):
140 | for j, joint1 in enumerate(contact_joints):
141 | if contact[i, s+1] == 1:
142 | break
143 | for k, joint2 in enumerate(contact_joints):
144 | j1 = sides[0] + joint1
145 | j2 = sides[1] + joint2
146 |
147 | idx1 = self.bvh_joint_order[j1]
148 | idx2 = self.bvh_joint_order[j2]
149 |
150 | d = dist(motion1[i, idx1], motion2[i, idx2])
151 | if d <= thresh:
152 | contact[i, s+1] = 1
153 | count += 1
154 | break
155 |
156 |
157 | print(count)
158 | return contact[:, 1:]
159 |
160 |
161 | def use_frames(self, motion1, motion2, thresh=1000):
162 |
163 | t_frame = motion1.shape[0]
164 | xx = np.tile(motion1, (1,69,1))
165 | yy = np.repeat(motion2, 69, axis=1)
166 |
167 | diff = xx - yy
168 | D = np.linalg.norm(diff, axis=-1)
169 | contact = (D <= thresh)*1
170 | sum_contact = np.sum(contact, axis=1)
171 | contact_frames = list(np.nonzero(sum_contact)[0])
172 | return contact, contact_frames
173 |
174 | def save_annot(self, split):
175 | save_path = makepath(os.path.join('data', 'Ninjutsu', split+'.pkl'), isfile=True)
176 | with open(save_path, 'wb') as f:
177 | pickle.dump(self.annot_dict, f)
178 |
179 |
180 | def load_files(self, seq, fname='0'):
181 | file_basename = os.path.join(self.root, str(seq))
182 | path_canon_3d = os.path.join(file_basename, fname+'_worldpos.csv')
183 | path_dofs = os.path.join(file_basename, fname+'_rotations.csv')
184 | path_offsets = os.path.join(file_basename, fname+'_offsets.pkl')
185 |
186 | print(f"loading file {file_basename}")
187 | canon_3d = np.genfromtxt(path_canon_3d, delimiter=',', skip_header=1) # n_c*n_f x n_dof
188 | dofs = np.genfromtxt(path_dofs, delimiter=',', skip_header=1) # n_c*n_f x n_dof
189 | with open(path_offsets, 'rb') as f:
190 | offset_dict = pickle.load(f)
191 | print(f"loading complete")
192 |
193 | n_frames = canon_3d.shape[0]
194 | canon_3d = np.float32(canon_3d[:, 1:].reshape(n_frames, -1, 3))
195 | dofs = dofs[:, 1:].reshape(n_frames, -1, 3)
196 |
197 | use_frames = list(np.rint(np.arange(0, n_frames-1, 25/self.framerate)))  # downsample (source assumed 25 fps) to the target framerate
198 | use_frames = [int(a) for a in use_frames]
199 | canon_3d = canon_3d[use_frames]
200 | dofs = np.float32(dofs[use_frames])
201 | print(canon_3d.shape)
202 | return 0, 0, canon_3d, dofs, offset_dict
203 |
204 |
205 | def collate_videos(self):
206 | self.annot_dict['bvh_joint_order'] = self.bvh_joint_order
207 | for i, seq in enumerate(self.sequences):
208 | _, _, canon_3d_1, dofs_1, offsets_1 = self.load_files(seq, '0')
209 | _, _, canon_3d_2, dofs_2, offsets_2 = self.load_files(seq, '1')
210 | if canon_3d_2.shape[0] < canon_3d_1.shape[0]:
211 | n_frames = canon_3d_2.shape[0]
212 | else:
213 | n_frames = canon_3d_1.shape[0]
214 | canon_3d_1 = canon_3d_1[:n_frames]
215 | canon_3d_2 = canon_3d_2[:n_frames]
216 | rotmat_1 = batch_euler_to_rotmat(dofs_1)
217 | rotmat_2 = batch_euler_to_rotmat(dofs_2)
218 | contacts, contact_frames = self.use_frames(canon_3d_1, canon_3d_2)
219 | n_frames_contact = len(contact_frames)
220 | canon_3d_1 = canon_3d_1[contact_frames]
221 | canon_3d_2 = canon_3d_2[contact_frames]
222 | dofs_1 = dofs_1[contact_frames]
223 | dofs_2 = dofs_2[contact_frames]
224 | contacts = contacts[contact_frames].reshape(n_frames_contact, 69, 69)
225 | hand_contact = self.detect_contact(motion1=canon_3d_1, motion2=canon_3d_2)
226 | self.annot_dict['pose_canon_1'].extend(canon_3d_1 )
227 | self.annot_dict['pose_canon_2'].extend(canon_3d_2 )
228 | self.annot_dict['dofs_1'].extend(dofs_1 )
229 | self.annot_dict['dofs_2'].extend(dofs_2 )
230 | self.annot_dict['contacts'].extend(hand_contact)
231 | self.annot_dict['rotmat_1'].extend(rotmat_1 )
232 | self.annot_dict['rotmat_2'].extend(rotmat_2 )
233 | self.annot_dict['offsets_1'].extend([offsets_1 for i in range(0, n_frames_contact)])
234 | self.annot_dict['offsets_2'].extend([offsets_2 for i in range(0, n_frames_contact)])
235 | self.annot_dict['seq'].extend([seq for i in range(0, n_frames_contact)])
236 |
237 |
238 | self.annot_dict['pose_canon_1'] = np.array(self.annot_dict['pose_canon_1'])
239 | self.annot_dict['pose_canon_2'] = np.array(self.annot_dict['pose_canon_2'])
240 | self.annot_dict['dofs_1'] = np.array(self.annot_dict['dofs_1'])
241 | self.annot_dict['dofs_2'] = np.array(self.annot_dict['dofs_2'])
242 | self.annot_dict['rotmat_1'] = np.array(self.annot_dict['rotmat_1'])
243 | self.annot_dict['rotmat_2'] = np.array(self.annot_dict['rotmat_2'])
244 | self.annot_dict['contacts'] = np.array(self.annot_dict['contacts'])
245 | print(len(self.annot_dict.keys()))
246 | print(len(self.annot_dict['seq']))
247 | print(len(self.annot_dict['pose_canon_1']))
248 |
249 |
250 |
251 |
252 | if __name__ == "__main__":
253 | root_path = os.path.join('..', 'DATASETS', 'ReMocap', 'Ninjutsu')
254 | fps = 10
255 | pp = PreProcessor(root_path, fps, 'train')
256 | pp = PreProcessor(root_path, fps, 'test')
257 |
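258 | # Note on use_frames(): np.tile repeats all 69 joints of person 1 and np.repeat duplicates
259 | # each joint of person 2 69 times, so row k of the flattened joint axis pairs joint (k % 69)
260 | # of person 1 with joint (k // 69) of person 2. D therefore has shape (T, 69*69) and is
261 | # reshaped to (T, 69, 69) in collate_videos(). A hedged shape check:
262 | #     m1, m2 = np.random.rand(4, 69, 3), np.random.rand(4, 69, 3)
263 | #     xx, yy = np.tile(m1, (1, 69, 1)), np.repeat(m2, 69, axis=1)
264 | #     assert xx.shape == yy.shape == (4, 69 * 69, 3)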
--------------------------------------------------------------------------------
/src/Lindyhop/visualizer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 | import torch
7 | import sys
8 | sys.path.append('.')
9 | sys.path.append('..')
10 | from mpl_toolkits.mplot3d import Axes3D
11 | from matplotlib.animation import FuncAnimation, PillowWriter
12 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
13 | import mpl_toolkits.mplot3d.axes3d as p3
14 | from PIL import Image
15 | from scipy import interpolate
16 | from src.tools.utils import makepath
17 | from src.tools.img_gif import img2video, img2gif
18 |
19 | LEFT_HANDSIDE = list(range(19, 24))
20 | RIGHT_HANDSIDE = list(range(45, 48))
21 | LEFT_FOOTSIDE = list(range(1, 5))
22 | RIGHT_FOOTSIDE = list(range(6, 10))
23 |
24 | kinematic_chain_full = [[0, 11], [11, 12], [12, 13], [13, 14], [14, 65], [65, 66], [66, 67], [67, 68], #spine, neck and head
25 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], # right leg
26 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], # left leg
27 | [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], # right arm
28 | [14, 40], [40, 41], [41, 42], [42, 43], [43, 44], # left arm
29 | [19, 20], [20, 21], [21, 22], [22, 23], # right pinky
30 | [19, 24], [24, 25], [25, 26], [26, 27], # right ring
31 | [19, 28], [28, 29], [29, 30], [30, 31], # right middle
32 | [19, 32], [32, 33], [33, 34], [34, 35], # right index
33 | [18, 36], [36, 37], [37, 38], [38, 39], # right thumb
34 | [44, 45], [45, 46], [46, 47], [47, 48], # left pinky
35 | [44, 49], [49, 50], [50, 51], [51, 52], # left ring
36 | [44, 53], [53, 54], [54, 55], [55, 56], # left middle
37 | [44, 57], [57, 58], [58, 59], [59, 60], # left index
38 | [43, 61], [61, 62], [62, 63], [63, 64], # left thumb
39 | ]
40 | kinematic_chain_reduced = [[0, 11], [11, 12], [12, 13], [13, 14], [14, 43], [43, 44], [44, 45], [45, 46], #spine, neck and head
41 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], # right leg
42 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10], # left leg
43 | [14, 15], [15, 16], [16, 17], [17, 18], # right arm
44 | [14, 29], [29, 30], [30, 31], [31, 32], # left arm
45 | [18, 19], [19, 20], # right pinky
46 | [18, 21], [21, 22], # right ring
47 | [18, 23], [23, 24], # right middle
48 | [18, 25], [25, 26], # right index
49 | [18, 27], [27, 28], # right thumb
50 | [32, 33], [33, 34], # left pinky
51 | [32, 35], [35, 36], # left ring
52 | [32, 37], [37, 38], # left middle
53 | [32, 39], [39, 40], # left index
54 | [32, 41], [41, 42], # left thumb
55 | ]
56 |
57 | kinematic_chain_short = [
58 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5],
59 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10],
60 | [0, 11], [11, 12], [12, 13], [13, 14],
61 | [14, 15], [15, 16], [16, 17], [17, 18],
62 | [14, 19], [19, 20], [20, 21], [21, 22],
63 | [14, 23], [23, 24], [24, 25], [25, 26]
64 | ]
65 | kinematic_chain_old = [
66 | [0, 1], [1, 2], [2, 3], [3, 4], [4, 5],
67 | [0, 6], [6, 7], [7, 8], [8, 9], [9, 10],
68 | [0, 11], [11, 12], [12, 13], [13, 14],
69 | [12, 15], [15, 16], [16, 17],
70 | [12, 18], [18, 19], [19, 20],
71 | [17, 21], [21, 22], [22, 23], [23, 24], [22, 25], [25, 26], [22, 27],
72 | [27, 28], [22, 29], [29, 30], [22, 31], [31, 32],
73 | [20, 33], [33, 34], [34, 35], [35, 36], [34, 37], [37, 38], [34, 39], [39, 40],
74 | [34, 41], [41, 42], [34, 43], [43, 44]
75 | ]
76 | def fig2data(fig):
77 | fig.canvas.draw()
78 | w, h = fig.canvas.get_width_height()
79 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
80 | buf.shape = (w, h, 4)
81 | buf = np.roll(buf, 3, axis=2)
82 | return buf
83 |
84 | def fig2img(fig):
85 | buf = fig2data(fig)
86 | w, h, d = buf.shape
87 | return Image.frombytes('RGBA', (w, h), buf.tobytes())
88 |
89 | def plot_contacts3D(pose1, pose2=None, gt_pose2=None, savepath=None, kinematic_chain = 'full', onlyone=False, gif=False):
90 |
91 | def plot_twoperson(pose1, pose2, i, kinematic_chain, savepath, gt_pose2=None):
92 | fig = plt.figure()
93 |
94 | ax = plt.subplot(projection='3d')
95 |
96 | ax.cla()
97 | ax.set_xlabel("x")
98 | ax.set_ylabel("y")
99 | ax.set_zlabel("z")
100 | ax.set_xlim3d([-1000, 2000])
101 | ax.set_zlim3d([-1000, 2000])
102 | ax.set_ylim3d([-1000, 2000])
103 | ax.axis('off')
104 | ax.view_init(elev=0, azim=0, roll=90)
105 | if kinematic_chain == 'full':
106 | KINEMATIC_CHAIN = kinematic_chain_full
107 | elif kinematic_chain == 'no_fingers':
108 | KINEMATIC_CHAIN = kinematic_chain_short
109 | elif kinematic_chain == 'reduced':
110 | KINEMATIC_CHAIN = kinematic_chain_reduced
111 | elif kinematic_chain == 'old':
112 | KINEMATIC_CHAIN = kinematic_chain_old
113 |
114 | for limb in KINEMATIC_CHAIN:
115 | xs = [pose1[i, limb[0], 0], pose1[i, limb[1], 0]]
116 | ys = [pose1[i, limb[0], 1], pose1[i, limb[1], 1]]
117 | zs = [pose1[i, limb[0], 2], pose1[i, limb[1], 2]]
118 | # if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE:
119 | # ax.plot(xs, ys, zs, 'darkred', linewidth=2.0)
120 | # else:
121 | # ax.plot(xs, ys, zs, 'red', linewidth=2.0)
122 | ax.plot(xs, ys, zs, 'red', linewidth=2.0)
123 |
124 | xs_ = [pose2[i, limb[0], 0], pose2[i, limb[1], 0]]
125 | ys_ = [pose2[i, limb[0], 1], pose2[i, limb[1], 1]]
126 | zs_ = [pose2[i, limb[0], 2], pose2[i, limb[1], 2]]
127 | # if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE:
128 | # ax.plot(xs_, ys_, zs_, 'darkblue', linewidth=2.0)
129 | # else:
130 | # ax.plot(xs_, ys_, zs_, 'blue', linewidth=2.0)
131 | ax.plot(xs_, ys_, zs_, 'blue', linewidth=2.0)
132 | if gt_pose2 is not None:
133 | gt_xs_ = [gt_pose2[i, limb[0], 0], gt_pose2[i, limb[1], 0]]
134 | gt_ys_ = [gt_pose2[i, limb[0], 1], gt_pose2[i, limb[1], 1]]
135 | gt_zs_ = [gt_pose2[i, limb[0], 2], gt_pose2[i, limb[1], 2]]
136 | ax.plot(gt_xs_, gt_ys_, gt_zs_, 'g', linewidth=1.0)
137 | # min_x = min(min(pose1[i, :, 2]), min(pose2[i, :, 2])) - 100
138 | # min_y = min(min(pose1[i, :, 0]), min(pose2[i, :, 0])) - 100
139 | # max_x = max(max(pose1[i, :, 2]), max(pose2[i, :, 2])) + 100
140 | # max_y = max(max(pose1[i, :, 0]), max(pose2[i, :, 0])) + 100
141 | # x_pl, y_pl = np.meshgrid(np.linspace(min_x, max_x, 10), np.linspace(min_y, max_y, 10))
142 | # foot_ground_contact_p1 =min(pose1[i, :, 1])
143 | # foot_ground_contact_2 =min(pose2[i, :, 1])
144 | # ground_plane = min(foot_ground_contact_p1, foot_ground_contact_2)
145 | # z_pl = torch.ones((10, 10)) * ground_plane
146 | # ax.plot_surface(x_pl, y_pl, z_pl, color= 'y', alpha=0.1)
147 | filename = makepath(os.path.join(savepath, str(i) +'.png'), isfile=True)
148 | plt.savefig(filename)
149 | plt.close()
150 |
151 |
152 | def plot_oneperson(pose2, i, kinematic_chain, savepath):
153 | # fig = plt.figure()
154 | ax = plt.subplot(projection='3d')
155 | ax.cla()
156 | ax.set_xlabel("x")
157 | ax.set_ylabel("y")
158 | ax.set_zlabel("z")
159 |
160 | ax.axis('off')
161 | ax.view_init(elev=0, azim=0, roll=0)
162 | if kinematic_chain == 'full':
163 | KINEMATIC_CHAIN = kinematic_chain_full
164 | elif kinematic_chain == 'no_fingers':
165 | KINEMATIC_CHAIN = kinematic_chain_short
166 |
167 | for limb in KINEMATIC_CHAIN:
168 | ys_ = [pose2[i, limb[0], 0], pose2[i, limb[1], 0]]
169 | zs_ = [pose2[i, limb[0], 1], pose2[i, limb[1], 1]]
170 | xs_ = [pose2[i, limb[0], 2], pose2[i, limb[1], 2]]
171 | if limb[0] in LEFT_FOOTSIDE or limb[0] in LEFT_HANDSIDE:
172 | ax.plot(xs_, ys_, zs_, 'darkred', linewidth=3.0)
173 | else:
174 | ax.plot(xs_, ys_, zs_, 'red', linewidth=3.0)
175 | filename = makepath(os.path.join(savepath, str(i) +'.png'), isfile=True)
176 | plt.savefig(filename)
177 | # plt.pause(0.001)
178 | plt.close()
179 |
180 | T = pose1.shape[0]
181 | is_interpolate = 0
182 | if is_interpolate:
183 | T1 = 3*T
184 | p1_x_interp =np.zeros((T1, pose1.shape[1]))
185 | p1_y_interp =np.zeros((T1, pose1.shape[1]))
186 | p1_z_interp =np.zeros((T1, pose1.shape[1]))
187 | p2_x_interp =np.zeros((T1, pose2.shape[1]))
188 | p2_y_interp =np.zeros((T1, pose2.shape[1]))
189 | p2_z_interp =np.zeros((T1, pose2.shape[1]))
190 |
191 | x = np.linspace(0, T-1 ,T)
192 | x_new = np.linspace(0, T-1 ,T1)
193 | for v1 in range(0, pose1.shape[1]):
194 | p1_x = pose1[:, v1, 0]
195 | p1_y = pose1[:, v1, 1]
196 | p1_z = pose1[:, v1, 2]
197 | p2_x = pose2[:, v1, 0]
198 | p2_y = pose2[:, v1, 1]
199 | p2_z = pose2[:, v1, 2]
200 | f_p1x = interpolate.interp1d(x, p1_x, kind = 'linear')
201 | f_p1y = interpolate.interp1d(x, p1_y, kind = 'linear')
202 | f_p1z = interpolate.interp1d(x, p1_z, kind = 'linear')
203 | f_p2x = interpolate.interp1d(x, p2_x, kind = 'linear')
204 | f_p2y = interpolate.interp1d(x, p2_y, kind = 'linear')
205 | f_p2z = interpolate.interp1d(x, p2_z, kind = 'linear')
206 | p1_x_interp[:, v1] = f_p1x(x_new)
207 | p1_y_interp[:, v1] = f_p1y(x_new)
208 | p1_z_interp[:, v1] = f_p1z(x_new)
209 | p2_x_interp[:, v1] = f_p2x(x_new)
210 | p2_y_interp[:, v1] = f_p2y(x_new)
211 | p2_z_interp[:, v1] = f_p2z(x_new)
212 | p1_x_interp = torch.from_numpy(p1_x_interp).unsqueeze(2)
213 | p1_y_interp = torch.from_numpy(p1_y_interp).unsqueeze(2)
214 | p1_z_interp = torch.from_numpy(p1_z_interp).unsqueeze(2)
215 | p1_interp = torch.cat((p1_x_interp, p1_y_interp, p1_z_interp), dim=-1)
216 | p2_x_interp = torch.from_numpy(p2_x_interp).unsqueeze(2)
217 | p2_y_interp = torch.from_numpy(p2_y_interp).unsqueeze(2)
218 | p2_z_interp = torch.from_numpy(p2_z_interp).unsqueeze(2)
219 | p2_interp = torch.cat((p2_x_interp, p2_y_interp, p2_z_interp), dim=-1)
220 | # verts_all = torch.cat((torch.from_numpy(verts_all[0]).unsqueeze(0), p1_interp), dim=0)
221 | T = T1
222 | pose1 = p1_interp
223 | pose2 = p2_interp
224 |
225 |
226 | for i in range(pose1.shape[0]):
227 | if onlyone:
228 | plot_oneperson(pose1, i, kinematic_chain, savepath)
229 | else:
230 | plot_twoperson(pose1, pose2, i, kinematic_chain, savepath, gt_pose2)
231 | if gif:
232 | img2gif(savepath)
233 | else:
234 | img2video(savepath, fps=20)
235 |
236 |
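237 | # Hedged usage sketch (illustrative values only): plot_contacts3D expects pose arrays of
238 | # shape (T, J, 3) matching the chosen kinematic chain (J=69 for 'full', J=27 for
239 | # 'no_fingers'); each frame is written to '<savepath>/<i>.png' and then stitched together
240 | # with img2gif or img2video.
241 | if __name__ == '__main__':
242 |     T = 8
243 |     demo_pose1 = np.random.rand(T, 69, 3) * 1000
244 |     demo_pose2 = np.random.rand(T, 69, 3) * 1000
245 |     plot_contacts3D(demo_pose1, demo_pose2, savepath=os.path.join('render', 'demo'),
246 |                     kinematic_chain='full', gif=True)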
--------------------------------------------------------------------------------
/src/Lindyhop/models/MotionDiffuse_body.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 S-Lab
3 | """
4 |
5 | import matplotlib.pylab as plt
6 | import random
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import layer_norm, nn
10 | import numpy as np
11 | from torch.nn import functional
12 |
13 | import math
14 | body = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
15 | hand = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
16 | hand_full = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]
17 |
18 | def heatmap2d(arr: np.ndarray):
19 | plt.imshow(arr, cmap='bwr')
20 | plt.clim(-10, 10)
21 | plt.colorbar()
22 | plt.show()
23 |
24 | class PositionalEncoding(nn.Module):
25 | def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
26 | super().__init__()
27 | self.batch_first = batch_first
28 |
29 | self.dropout = nn.Dropout(p=dropout)
30 |
31 | pe = torch.zeros(max_len, d_model)
32 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
33 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
34 | pe[:, 0::2] = torch.sin(position * div_term)
35 | pe[:, 1::2] = torch.cos(position * div_term)
36 |
37 |
38 | for pos in range(max_len):
39 | for i in range(0, d_model-1, 2):
40 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
41 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
42 |
43 |
44 | pe = pe.unsqueeze(0) # [1, max_len, d_model]
45 |
46 | self.register_buffer('pe', pe)
47 |
48 | def forward(self, x):
49 | x = x + self.pe[:, :x.shape[1], :]
50 | return self.dropout(x)
51 |
52 |
53 | def timestep_embedding(timesteps, dim, freqs):
54 | """
55 | Create sinusoidal timestep embeddings.
56 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
57 | These may be fractional.
58 | :param dim: the dimension of the output.
59 | :param freqs: precomputed sinusoidal frequencies used to build the embedding.
60 | :return: an [N x dim] Tensor of positional embeddings.
61 | """
62 | # freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device)
63 |
64 | # timesteps= timesteps.to('cpu')
65 | args = timesteps[:, None].float() * freqs[None]
66 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67 | if dim % 2:
68 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
69 | return embedding
70 |
71 |
72 | def set_requires_grad(nets, requires_grad=False):
73 | Set requires_grad for all the networks.
74 |
75 | Args:
76 | nets (nn.Module | list[nn.Module]): A list of networks or a single
77 | network.
78 | requires_grad (bool): Whether the networks require gradients or not
79 | """
80 | if not isinstance(nets, list):
81 | nets = [nets]
82 | for net in nets:
83 | if net is not None:
84 | for param in net.parameters():
85 | param.requires_grad = requires_grad
86 |
87 |
88 | def zero_module(module):
89 | """
90 | Zero out the parameters of a module and return it.
91 | """
92 | for p in module.parameters():
93 | p.detach().zero_()
94 | return module
95 |
96 |
97 | class StylizationBlock(nn.Module):
98 |
99 | def __init__(self, latent_dim, time_embed_dim, dropout):
100 | super().__init__()
101 | self.emb_layers = nn.Sequential(
102 | nn.SiLU(),
103 | nn.Linear(time_embed_dim, 2 * latent_dim),
104 | )
105 | self.norm = nn.LayerNorm(latent_dim)
106 | self.out_layers = nn.Sequential(
107 | nn.SiLU(),
108 | nn.Dropout(p=dropout),
109 | zero_module(nn.Linear(latent_dim, latent_dim)),
110 | )
111 |
112 | def forward(self, h, emb):
113 | """
114 | h: B, T, D
115 | emb: B, D
116 | """
117 | # B, 1, 2D
118 | emb_out = self.emb_layers(emb).unsqueeze(1)
119 | # scale: B, 1, D / shift: B, 1, D
120 | scale, shift = torch.chunk(emb_out, 2, dim=2)
121 | h = self.norm(h) * (1 + scale) + shift
122 | h = self.out_layers(h)
123 | return h
124 |
125 |
126 |
127 | class FFN(nn.Module):
128 |
129 | def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim):
130 | super().__init__()
131 | self.linear1 = nn.Linear(latent_dim, ffn_dim)
132 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
133 | self.activation = nn.GELU()
134 | self.dropout = nn.Dropout(dropout)
135 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
136 |
137 | def forward(self, x, emb):
138 | y = self.linear2(self.dropout(self.activation(self.linear1(x))))
139 | y = x + self.proj_out(y, emb)
140 | return y
141 |
142 |
143 |
144 | class TemporalSelfAttention(nn.Module):
145 |
146 | def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
147 | super().__init__()
148 | self.num_head = num_head
149 | self.norm = nn.LayerNorm(latent_dim)
150 | self.query = nn.Linear(latent_dim, latent_dim)
151 | self.key = nn.Linear(latent_dim, latent_dim)
152 | self.value = nn.Linear(latent_dim, latent_dim)
153 | self.dropout = nn.Dropout(dropout)
154 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
155 |
156 | def forward(self, x, emb, src_mask, eps=1e-8):
157 | """
158 | x: B, T, D
159 | """
160 | B, T, D = x.shape
161 | H = self.num_head
162 | # B, T, 1, D
163 | query = self.query(self.norm(x)).unsqueeze(2)
164 | # B, 1, T, D
165 | key = self.key(self.norm(x)).unsqueeze(1)
166 | query = query.view(B, T, H, -1)
167 | key = key.view(B, T, H, -1)
168 | # B, T, T, H
169 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps)
170 | attention = attention * src_mask.unsqueeze(-1)
171 | weight = self.dropout(F.softmax(attention, dim=2))
172 | value = self.value(self.norm(x)).view(B, T, H, -1)
173 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
174 | y = x + self.proj_out(y, emb)
175 | return y, attention
176 |
177 | class TemporalCrossAttention(nn.Module):
178 |
179 | def __init__(self, seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim):
180 | super().__init__()
181 | self.num_head = num_head
182 | self.norm = nn.LayerNorm(latent_dim)
183 | self.mot1_norm = nn.LayerNorm(mot1_latent_dim)
184 | self.query = nn.Linear(latent_dim, latent_dim)
185 | self.key = nn.Linear(mot1_latent_dim, latent_dim)
186 | self.value = nn.Linear(mot1_latent_dim, latent_dim)
187 | self.dropout = nn.Dropout(dropout)
188 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
189 |
190 | def forward(self, x, xf, src_mask, emb, eps=1e-8):
191 | """
192 | x: B, T, D
193 | xf: B, N, L
194 | """
195 | B, T, D = x.shape
196 | N = xf.shape[1]
197 | H = self.num_head
198 | # B, T, 1, D
199 | query = self.query(self.norm(x)).unsqueeze(2)
200 | # B, 1, N, D
201 | key = self.key(self.mot1_norm(xf)).unsqueeze(1)
202 | query = query.view(B, T, H, -1)
203 | key = key.view(B, N, H, -1)
204 | # B, T, N, H
205 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps)
206 | attention = attention * src_mask.unsqueeze(-1)
207 | weight = self.dropout(F.softmax(attention, dim=2))
208 | value = self.value(self.mot1_norm(xf)).view(B, N, H, -1)
209 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
210 | y = x + self.proj_out(y, emb)
211 | return y, attention
212 |
213 | class TemporalDiffusionTransformerDecoderLayer(nn.Module):
214 |
215 | def __init__(self,
216 | seq_len=60,
217 | latent_dim=32,
218 | mot1_latent_dim=512,
219 | time_embed_dim=128,
220 | ffn_dim=256,
221 | num_head=4,
222 | dropout=0.1):
223 | super().__init__()
224 | self.sa_block = TemporalSelfAttention(
225 | seq_len, latent_dim, num_head, dropout, time_embed_dim)
226 | self.ca_block = TemporalCrossAttention(
227 | seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim)
228 | self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)
229 |
230 | def forward(self, x, xf, emb, src_mask):
231 | x, s_attn = self.sa_block(x, emb, src_mask)
232 | x, c_attn = self.ca_block(x, xf, src_mask, emb)
233 | x = self.ffn(x, emb)
234 | return x, s_attn, c_attn
235 |
236 |
237 | class DiffusionTransformer(nn.Module):
238 | def __init__(self,
239 | device= 'cuda',
240 | num_jts=27,
241 | num_frames=100,
242 | input_feats=3,
243 | latent_dim=32,
244 | ff_size=1024,
245 | num_layers=8,
246 | num_heads=4,
247 | dropout=0.05,
248 | activations="gelu",
249 | **kargs):
250 | super().__init__()
251 |
252 | self.num_frames = num_frames
253 | self.num_jts = num_jts
254 | self.latent_dim = latent_dim
255 | self.ff_size = ff_size
256 | self.num_layers = num_layers
257 | self.num_heads = num_heads
258 | self.freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device)
259 | self.dropout = dropout
260 | self.activation = activations
261 | self.input_feats = input_feats
262 | self.time_embed_dim = latent_dim
263 | self.spatio_temp = self.num_frames * self.num_jts
264 |
265 | # encode motion 1
266 | self.motion1_pre_proj = nn.Linear(self.input_feats, self.latent_dim)
267 | self.m1_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp)
268 | mot1TransEncoderLayer = nn.TransformerEncoderLayer(
269 | d_model=latent_dim,
270 | nhead=num_heads,
271 | dim_feedforward=ff_size,
272 | dropout=dropout,
273 | batch_first=True,
274 | activation='gelu')
275 | self.mot1TransEncoder = nn.TransformerEncoder(
276 | mot1TransEncoderLayer,
277 | num_layers=2)
278 | self.mot1_ln = nn.LayerNorm(latent_dim)
279 | #Classifier-free guidance
280 | # self.null_cond = nn.Parameter(torch.randn(self.num_frames * self.num_jts, latent_dim))
281 |
282 | # Time Embedding
283 | self.time_embed = nn.Sequential(
284 | nn.Linear(self.latent_dim, self.time_embed_dim),
285 | nn.SiLU(),
286 | nn.Linear(self.time_embed_dim, self.time_embed_dim),
287 | )
288 |
289 | # motion2 decoding
290 | self.motion2_pre_proj = nn.Linear(self.input_feats, self.latent_dim)
291 | self.m2_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp)
292 | self.temporal_decoder_blocks = nn.ModuleList()
293 | for i in range(num_layers):
294 | self.temporal_decoder_blocks.append(
295 | TemporalDiffusionTransformerDecoderLayer(
296 | seq_len=self.spatio_temp,
297 | latent_dim=latent_dim,
298 | mot1_latent_dim=latent_dim,
299 | time_embed_dim=self.time_embed_dim,
300 | ffn_dim=ff_size,
301 | num_head=num_heads,
302 | dropout=dropout
303 | )
304 | )
305 | # Output Module
306 | self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))
307 |
308 |
309 | def generate_src_mask(self, tgt):
310 | length = tgt.size(1)
311 | src_mask = (1 - torch.triu(torch.ones(1, length, length), diagonal=1))
312 | return src_mask
313 |
314 | # def forward(self, motion2, timesteps, length=None, motion1=None, xf_out=None, contact_map=None):
315 | def forward(self, motion2, timesteps, motion1=None, contact_maps=None, spatial_guidance=None, guidance_scale=0):
316 | """
317 | x: B, T, D
318 | """
319 | B, T, J, _ = motion1.shape
320 | m1 = self.motion1_pre_proj(motion1) # GCN
321 | m2 = self.motion2_pre_proj(motion2) # GCN
322 | m1 = m1.reshape(B, T*J, -1)
323 | m2 = m2.reshape(B, T*J, -1)
324 | src_mask = self.generate_src_mask(m2).to(m2.device)
325 |
326 | m1_pe = self.m1_temporal_pos_encoder(m1)
327 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe))
328 | # null_cond = torch.repeat_interleave(
329 | # self.null_cond.to(m2.device).unsqueeze(0), B, dim=0)
330 | # m1_enc = m1_cond if random.random() > 0.25 else null_cond
331 | m1_enc = m1_cond
332 | m2_pe = self.m2_temporal_pos_encoder(m2)
333 | emb = self.time_embed(timestep_embedding(
334 | timesteps, self.latent_dim, self.freqs) )
335 | h_pe = m2_pe
336 | for module in self.temporal_decoder_blocks:
337 | h_pe, s_attn, c_attn = module(h_pe, m1_enc, emb, src_mask)
338 | output = self.out(h_pe).view(B, T, J, -1).contiguous()
339 | # if timesteps == 1:
340 | # c_attn_map = c_attn[0,:,:,0].cpu().detach().numpy()
341 | # # ax = sns.heatmap(c_attn_map, vmin=-15, vmax=15, linewidth=2)
342 | # # plt.show()
343 | # heatmap2d(c_attn_map)
344 | # tmp=1
345 | return output, s_attn, c_attn
346 |
347 | def get_motion_embedding(self, motion1):
348 | B, T, J, _ = motion1.shape
349 | m1 = self.motion1_pre_proj(motion1) # GCN
350 | m1 = m1.reshape(B, T*J, -1)
351 | m1_pe = self.m1_temporal_pos_encoder(m1)
352 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe))
353 | return m1_cond
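354 |
355 |
356 | # Hedged smoke-test sketch (illustrative shapes, not part of the original model file): the
357 | # decoder denoises person 2's motion conditioned on person 1's motion, both passed as
358 | # (B, T, J, 3) joint positions with one diffusion timestep index per batch element.
359 | if __name__ == "__main__":
360 |     model = DiffusionTransformer(device='cpu', num_jts=27, num_frames=20, latent_dim=32)
361 |     B, T, J = 2, 20, 27
362 |     motion1 = torch.randn(B, T, J, 3)
363 |     motion2 = torch.randn(B, T, J, 3)
364 |     timesteps = torch.randint(0, 500, (B,))
365 |     out, s_attn, c_attn = model(motion2, timesteps, motion1=motion1)
366 |     print(out.shape)  # expected: torch.Size([2, 20, 27, 3])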
--------------------------------------------------------------------------------
/src/tools/common/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | _EPS4 = np.finfo(float).eps * 4.0
12 |
13 | _FLOAT_EPS = np.finfo(np.float64).eps
14 |
15 | # PyTorch-backed implementations
16 | def qinv(q):
17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18 | mask = torch.ones_like(q)
19 | mask[..., 1:] = -mask[..., 1:]
20 | return q * mask
21 |
22 |
23 | def qinv_np(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | return qinv(torch.from_numpy(q).float()).numpy()
26 |
27 |
28 | def qnormalize(q):
29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30 | return q / torch.norm(q, dim=-1, keepdim=True)
31 |
32 |
33 | def qmul(q, r):
34 | """
35 | Multiply quaternion(s) q with quaternion(s) r.
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37 | Returns q*r as a tensor of shape (*, 4).
38 | """
39 | assert q.shape[-1] == 4
40 | assert r.shape[-1] == 4
41 |
42 | original_shape = q.shape
43 |
44 | # Compute outer product
45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46 |
47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
52 |
53 |
54 | def qrot(q, v):
55 | """
56 | Rotate vector(s) v about the rotation described by quaternion(s) q.
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58 | where * denotes any number of dimensions.
59 | Returns a tensor of shape (*, 3).
60 | """
61 | assert q.shape[-1] == 4
62 | assert v.shape[-1] == 3
63 | assert q.shape[:-1] == v.shape[:-1]
64 |
65 | original_shape = list(v.shape)
66 | # print(q.shape)
67 | q = q.contiguous().view(-1, 4)
68 | v = v.contiguous().view(-1, 3)
69 |
70 | qvec = q[:, 1:]
71 | uv = torch.cross(qvec, v, dim=1)
72 | uuv = torch.cross(qvec, uv, dim=1)
73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74 |
75 |
76 | def qeuler(q, order, epsilon=0, deg=True):
77 | """
78 | Convert quaternion(s) q to Euler angles.
79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80 | Returns a tensor of shape (*, 3).
81 | """
82 | assert q.shape[-1] == 4
83 |
84 | original_shape = list(q.shape)
85 | original_shape[-1] = 3
86 | q = q.view(-1, 4)
87 |
88 | q0 = q[:, 0]
89 | q1 = q[:, 1]
90 | q2 = q[:, 2]
91 | q3 = q[:, 3]
92 |
93 | if order == 'xyz':
94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | elif order == 'yzx':
98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101 | elif order == 'zxy':
102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105 | elif order == 'xzy':
106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109 | elif order == 'yxz':
110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113 | elif order == 'zyx':
114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117 | else:
118 |         raise ValueError('Unknown Euler angle order: ' + order)
119 |
120 | if deg:
121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122 | else:
123 | return torch.stack((x, y, z), dim=1).view(original_shape)
124 |
125 |
126 | # Numpy-backed implementations
127 |
128 | def qmul_np(q, r):
129 | q = torch.from_numpy(q).contiguous().float()
130 | r = torch.from_numpy(r).contiguous().float()
131 | return qmul(q, r).numpy()
132 |
133 |
134 | def qrot_np(q, v):
135 | q = torch.from_numpy(q).contiguous().float()
136 | v = torch.from_numpy(v).contiguous().float()
137 | return qrot(q, v).numpy()
138 |
139 |
140 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
141 | if use_gpu:
142 | q = torch.from_numpy(q).cuda().float()
143 | return qeuler(q, order, epsilon).cpu().numpy()
144 | else:
145 | q = torch.from_numpy(q).contiguous().float()
146 | return qeuler(q, order, epsilon).numpy()
147 |
148 |
149 | def qfix(q):
150 | """
151 | Enforce quaternion continuity across the time dimension by selecting
152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153 | between two consecutive frames.
154 |
155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156 | Returns a tensor of the same shape.
157 | """
158 | assert len(q.shape) == 3
159 | assert q.shape[-1] == 4
160 |
161 | result = q.copy()
162 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
163 | mask = dot_products < 0
164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165 | result[1:][mask] *= -1
166 | return result
167 |
168 |
169 | def euler2quat(e, order, deg=True):
170 | """
171 | Convert Euler angles to quaternions.
172 | """
173 | assert e.shape[-1] == 3
174 |
175 | original_shape = list(e.shape)
176 | original_shape[-1] = 4
177 |
178 | e = e.view(-1, 3)
179 |
180 | ## if euler angles in degrees
181 | if deg:
182 | e = e * np.pi / 180.
183 |
184 | x = e[:, 0]
185 | y = e[:, 1]
186 | z = e[:, 2]
187 |
188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191 |
192 | result = None
193 | for coord in order:
194 | if coord == 'x':
195 | r = rx
196 | elif coord == 'y':
197 | r = ry
198 | elif coord == 'z':
199 | r = rz
200 | else:
201 |             raise ValueError('Unknown axis: ' + coord)
202 | if result is None:
203 | result = r
204 | else:
205 | result = qmul(result, r)
206 |
207 | # Reverse antipodal representation to have a non-negative "w"
208 | if order in ['xyz', 'yzx', 'zxy']:
209 | result *= -1
210 |
211 | return result.view(original_shape)
212 |
213 |
214 | def expmap_to_quaternion(e):
215 | """
216 | Convert axis-angle rotations (aka exponential maps) to quaternions.
217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219 | Returns a tensor of shape (*, 4).
220 | """
221 | assert e.shape[-1] == 3
222 |
223 | original_shape = list(e.shape)
224 | original_shape[-1] = 4
225 | e = e.reshape(-1, 3)
226 |
227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228 | w = np.cos(0.5 * theta).reshape(-1, 1)
229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231 |
232 |
233 | def euler_to_quaternion(e, order):
234 | """
235 | Convert Euler angles to quaternions.
236 | """
237 | assert e.shape[-1] == 3
238 |
239 | original_shape = list(e.shape)
240 | original_shape[-1] = 4
241 |
242 | e = e.reshape(-1, 3)
243 |
244 | x = e[:, 0]
245 | y = e[:, 1]
246 | z = e[:, 2]
247 |
248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251 |
252 | result = None
253 | for coord in order:
254 | if coord == 'x':
255 | r = rx
256 | elif coord == 'y':
257 | r = ry
258 | elif coord == 'z':
259 | r = rz
260 | else:
261 |             raise ValueError('Unknown axis: ' + coord)
262 | if result is None:
263 | result = r
264 | else:
265 | result = qmul_np(result, r)
266 |
267 | # Reverse antipodal representation to have a non-negative "w"
268 | if order in ['xyz', 'yzx', 'zxy']:
269 | result *= -1
270 |
271 | return result.reshape(original_shape)
272 |
273 |
274 | def quaternion_to_matrix(quaternions):
275 | """
276 | Convert rotations given as quaternions to rotation matrices.
277 | Args:
278 | quaternions: quaternions with real part first,
279 | as tensor of shape (..., 4).
280 | Returns:
281 | Rotation matrices as tensor of shape (..., 3, 3).
282 | """
283 | r, i, j, k = torch.unbind(quaternions, -1)
284 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
285 |
286 | o = torch.stack(
287 | (
288 | 1 - two_s * (j * j + k * k),
289 | two_s * (i * j - k * r),
290 | two_s * (i * k + j * r),
291 | two_s * (i * j + k * r),
292 | 1 - two_s * (i * i + k * k),
293 | two_s * (j * k - i * r),
294 | two_s * (i * k - j * r),
295 | two_s * (j * k + i * r),
296 | 1 - two_s * (i * i + j * j),
297 | ),
298 | -1,
299 | )
300 | return o.reshape(quaternions.shape[:-1] + (3, 3))
301 |
302 |
303 | def quaternion_to_matrix_np(quaternions):
304 | q = torch.from_numpy(quaternions).contiguous().float()
305 | return quaternion_to_matrix(q).numpy()
306 |
307 |
308 | def quaternion_to_cont6d_np(quaternions):
309 | rotation_mat = quaternion_to_matrix_np(quaternions)
310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311 | return cont_6d
312 |
313 |
314 | def quaternion_to_cont6d(quaternions):
315 | rotation_mat = quaternion_to_matrix(quaternions)
316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317 | return cont_6d
318 |
319 |
320 | def cont6d_to_matrix(cont6d):
321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322 | x_raw = cont6d[..., 0:3]
323 | y_raw = cont6d[..., 3:6]
324 |
325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326 | z = torch.cross(x, y_raw, dim=-1)
327 | z = z / torch.norm(z, dim=-1, keepdim=True)
328 |
329 | y = torch.cross(z, x, dim=-1)
330 |
331 | x = x[..., None]
332 | y = y[..., None]
333 | z = z[..., None]
334 |
335 | mat = torch.cat([x, y, z], dim=-1)
336 | return mat
337 |
338 |
339 | def cont6d_to_matrix_np(cont6d):
340 | q = torch.from_numpy(cont6d).contiguous().float()
341 | return cont6d_to_matrix(q).numpy()
342 |
343 |
344 | def qpow(q0, t, dtype=torch.float):
345 | ''' q0 : tensor of quaternions
346 | t: tensor of powers
347 | '''
348 | q0 = qnormalize(q0)
349 | theta0 = torch.acos(q0[..., 0])
350 |
351 | ## if theta0 is close to zero, add epsilon to avoid NaNs
352 |     mask = ((theta0 <= 10e-10) * (theta0 >= -10e-10)).float()  # cast the boolean mask so (1 - mask) below works on newer PyTorch
353 | theta0 = (1 - mask) * theta0 + mask * 10e-10
354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355 |
356 | if isinstance(t, torch.Tensor):
357 | q = torch.zeros(t.shape + q0.shape)
358 | theta = t.view(-1, 1) * theta0.view(1, -1)
359 | else: ## if t is a number
360 | q = torch.zeros(q0.shape)
361 | theta = t * theta0
362 |
363 | q[..., 0] = torch.cos(theta)
364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365 |
366 | return q.to(dtype)
367 |
368 |
369 | def qslerp(q0, q1, t):
370 | '''
371 | q0: starting quaternion
372 | q1: ending quaternion
373 | t: array of points along the way
374 |
375 | Returns:
376 | Tensor of Slerps: t.shape + q0.shape
377 | '''
378 |
379 | q0 = qnormalize(q0)
380 | q1 = qnormalize(q1)
381 | q_ = qpow(qmul(q1, qinv(q0)), t)
382 |
383 | return qmul(q_,
384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385 |
386 |
387 | def qbetween(v0, v1):
388 | '''
389 | find the quaternion used to rotate v0 to v1
390 | '''
391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393 |
394 | v = torch.cross(v0, v1)
395 |     w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + \
396 |         (v0 * v1).sum(dim=-1, keepdim=True)
397 | return qnormalize(torch.cat([w, v], dim=-1))
398 |
399 |
400 | def qbetween_np(v0, v1):
401 | '''
402 | find the quaternion used to rotate v0 to v1
403 | '''
404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406 |
407 | v0 = torch.from_numpy(v0).float()
408 | v1 = torch.from_numpy(v1).float()
409 | return qbetween(v0, v1).numpy()
410 |
411 |
412 | def lerp(p0, p1, t):
413 | if not isinstance(t, torch.Tensor):
414 | t = torch.Tensor([t])
415 |
416 | new_shape = t.shape + p0.shape
417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419 | p0 = p0.view(new_view_p).expand(new_shape)
420 | p1 = p1.view(new_view_p).expand(new_shape)
421 | t = t.view(new_view_t).expand(new_shape)
422 |
423 | return p0 + t * (p1 - p0)
424 |
--------------------------------------------------------------------------------
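A minimal usage sketch for the quaternion helpers above (the tensor shapes follow the docstrings; the random inputs and the import path are assumptions for illustration): qmul composes two rotations, qrot applies a rotation to 3-D vectors, and qeuler converts back to Euler angles.

import torch
from src.tools.common.quaternion import qmul, qrot, qeuler, qnormalize  # assumes the repo root is on PYTHONPATH

q1 = qnormalize(torch.randn(8, 4))           # batch of unit quaternions (w, x, y, z)
q2 = qnormalize(torch.randn(8, 4))
v = torch.randn(8, 3)                        # batch of 3-D points

q12 = qmul(q1, q2)                           # composed rotation, shape (8, 4)
v_rot = qrot(q12, v)                         # rotated points, shape (8, 3)
angles = qeuler(q12, order='xyz', deg=True)  # Euler angles in degrees, shape (8, 3)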
/src/Lindyhop/models/MotionDiffusion_hand.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2021 S-Lab
3 | """
4 |
5 | import matplotlib.pylab as plt
6 | import random
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import layer_norm, nn
10 | import numpy as np
11 | from torch.nn import functional
12 |
13 | import math
14 | body = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
15 | hand = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
16 | hand_full = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]
17 |
18 | def heatmap2d(arr: np.ndarray):
19 | plt.imshow(arr, cmap='viridis')
20 | plt.clim(-10, 10)
21 | plt.colorbar()
22 | plt.show()
23 |
24 | def norm_array(x):
25 | return (x-np.min(x))/(np.max(x)-np.min(x))
26 |
27 | class PositionalEncoding(nn.Module):
28 | def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
29 | super().__init__()
30 | self.batch_first = batch_first
31 |
32 | self.dropout = nn.Dropout(p=dropout)
33 |
34 | pe = torch.zeros(max_len, d_model)
35 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
36 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
37 | pe[:, 0::2] = torch.sin(position * div_term)
38 | pe[:, 1::2] = torch.cos(position * div_term)
39 |
40 |
41 |         for pos in range(max_len):  # note: this loop recomputes and overwrites the vectorized encoding built above
42 | for i in range(0, d_model-1, 2):
43 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
44 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
45 |
46 |
47 | pe = pe.unsqueeze(0) # [1, max_len, d_model]
48 |
49 | self.register_buffer('pe', pe)
50 |
51 | def forward(self, x):
52 | x = x + self.pe[:, :x.shape[1], :]
53 | return self.dropout(x)
54 |
55 |
56 | def timestep_embedding(timesteps, dim, freqs):
57 | """
58 | Create sinusoidal timestep embeddings.
59 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
60 | These may be fractional.
61 | :param dim: the dimension of the output.
62 |     :param freqs: precomputed sinusoidal frequencies of length dim // 2 (see self.freqs in DiffusionTransformer).
63 | :return: an [N x dim] Tensor of positional embeddings.
64 | """
65 | # freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device)
66 |
67 | # timesteps= timesteps.to('cpu')
68 | args = timesteps[:, None].float() * freqs[None]
69 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
70 | if dim % 2:
71 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
72 | return embedding
73 |
74 |
75 | def set_requires_grad(nets, requires_grad=False):
76 | """Set requies_grad for all the networks.
77 |
78 | Args:
79 | nets (nn.Module | list[nn.Module]): A list of networks or a single
80 | network.
81 | requires_grad (bool): Whether the networks require gradients or not
82 | """
83 | if not isinstance(nets, list):
84 | nets = [nets]
85 | for net in nets:
86 | if net is not None:
87 | for param in net.parameters():
88 | param.requires_grad = requires_grad
89 |
90 |
91 | def zero_module(module):
92 | """
93 | Zero out the parameters of a module and return it.
94 | """
95 | for p in module.parameters():
96 | p.detach().zero_()
97 | return module
98 |
99 |
100 | class StylizationBlock(nn.Module):
101 |
102 | def __init__(self, latent_dim, time_embed_dim, dropout):
103 | super().__init__()
104 | self.emb_layers = nn.Sequential(
105 | nn.SiLU(),
106 | nn.Linear(time_embed_dim, 2 * latent_dim),
107 | )
108 | self.norm = nn.LayerNorm(latent_dim)
109 | self.out_layers = nn.Sequential(
110 | nn.SiLU(),
111 | nn.Dropout(p=dropout),
112 | zero_module(nn.Linear(latent_dim, latent_dim)),
113 | )
114 |
115 | def forward(self, h, emb):
116 | """
117 | h: B, T, D
118 | emb: B, D
119 | """
120 | # B, 1, 2D
121 | emb_out = self.emb_layers(emb).unsqueeze(1)
122 | # scale: B, 1, D / shift: B, 1, D
123 | scale, shift = torch.chunk(emb_out, 2, dim=2)
124 | h = self.norm(h) * (1 + scale) + shift
125 | h = self.out_layers(h)
126 | return h
127 |
128 |
129 |
130 | class FFN(nn.Module):
131 |
132 | def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim):
133 | super().__init__()
134 | self.linear1 = nn.Linear(latent_dim, ffn_dim)
135 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
136 | self.activation = nn.GELU()
137 | self.dropout = nn.Dropout(dropout)
138 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
139 |
140 | def forward(self, x, emb):
141 | y = self.linear2(self.dropout(self.activation(self.linear1(x))))
142 | y = x + self.proj_out(y, emb)
143 | return y
144 |
145 |
146 |
147 | class TemporalSelfAttention(nn.Module):
148 |
149 | def __init__(self, seq_len, latent_dim, num_head, dropout, time_embed_dim):
150 | super().__init__()
151 | self.num_head = num_head
152 | self.norm = nn.LayerNorm(latent_dim)
153 | self.query = nn.Linear(latent_dim, latent_dim)
154 | self.key = nn.Linear(latent_dim, latent_dim)
155 | self.value = nn.Linear(latent_dim, latent_dim)
156 | self.dropout = nn.Dropout(dropout)
157 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
158 |
159 | def forward(self, x, emb, src_mask, eps=1e-8):
160 | """
161 | x: B, T, D
162 | """
163 | B, T, D = x.shape
164 | H = self.num_head
165 | # B, T, 1, D
166 | query = self.query(self.norm(x)).unsqueeze(2)
167 | # B, 1, T, D
168 | key = self.key(self.norm(x)).unsqueeze(1)
169 | query = query.view(B, T, H, -1)
170 | key = key.view(B, T, H, -1)
171 | # B, T, T, H
172 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps)
173 | # attention = attention * src_mask.unsqueeze(-1)
174 | weight = self.dropout(F.softmax(attention, dim=2))
175 | value = self.value(self.norm(x)).view(B, T, H, -1)
176 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
177 | y = x + self.proj_out(y, emb)
178 | return y, attention
179 |
180 | class TemporalCrossAttention(nn.Module):
181 |
182 | def __init__(self, seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim):
183 | super().__init__()
184 | self.num_head = num_head
185 | self.norm = nn.LayerNorm(latent_dim)
186 | self.mot1_norm = nn.LayerNorm(mot1_latent_dim)
187 | self.query = nn.Linear(latent_dim, latent_dim)
188 | self.key = nn.Linear(mot1_latent_dim, latent_dim)
189 | self.value = nn.Linear(mot1_latent_dim, latent_dim)
190 | self.dropout = nn.Dropout(dropout)
191 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
192 |
193 | def forward(self, x, xf, src_mask, emb, eps=1e-8):
194 | """
195 | x: B, T, D
196 | xf: B, N, L
197 | """
198 | B, T, D = x.shape
199 | N = xf.shape[1]
200 | H = self.num_head
201 | # B, T, 1, D
202 | query = self.query(self.norm(x)).unsqueeze(2)
203 | # B, 1, N, D
204 | key = self.key(self.mot1_norm(xf)).unsqueeze(1)
205 | query = query.view(B, T, H, -1)
206 | key = key.view(B, N, H, -1)
207 | # B, T, N, H
208 | attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / max(math.sqrt(D // H), eps)
209 | attention = attention * src_mask
210 | weight = self.dropout(F.softmax(attention, dim=2))
211 | value = self.value(self.mot1_norm(xf)).view(B, N, H, -1)
212 | y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
213 | y = x + self.proj_out(y, emb)
214 | return y, attention
215 |
216 | class TemporalDiffusionTransformerDecoderLayer(nn.Module):
217 |
218 | def __init__(self,
219 | seq_len=60,
220 | latent_dim=32,
221 | mot1_latent_dim=512,
222 | time_embed_dim=128,
223 | ffn_dim=256,
224 | num_head=4,
225 | dropout=0.1):
226 | super().__init__()
227 | self.sa_block = TemporalSelfAttention(
228 | seq_len, latent_dim, num_head, dropout, time_embed_dim)
229 | self.ca_block = TemporalCrossAttention(
230 | seq_len, latent_dim, mot1_latent_dim, num_head, dropout, time_embed_dim)
231 | self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)
232 |
233 | def forward(self, x, xf, emb, src_mask):
234 | x, s_attn = self.sa_block(x, emb, src_mask)
235 | x, c_attn = self.ca_block(x, xf, src_mask, emb)
236 | x = self.ffn(x, emb)
237 | return x, s_attn, c_attn
238 |
239 |
240 | class DiffusionTransformer(nn.Module):
241 | def __init__(self,
242 | device= 'cuda',
243 | num_frames=100,
244 | num_jts = 11,
245 | input_condn_feats=3,
246 | input_feats=3,
247 | latent_dim=32,
248 | ff_size=1024,
249 | num_layers=8,
250 | num_heads=4,
251 | dropout=0.05,
252 | activations="gelu",
253 | **kargs):
254 | super().__init__()
255 |
256 | self.num_frames = num_frames
257 | self.num_jts = num_jts
258 | self.latent_dim = latent_dim
259 | self.ff_size = ff_size
260 | self.num_layers = num_layers
261 | self.num_heads = num_heads
262 | self.freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=self.latent_dim//2, dtype=torch.float32) / (self.latent_dim//2)).to(device)
263 | self.dropout = dropout
264 | self.activation = activations
265 | self.input_condn_feats = input_condn_feats
266 | self.input_feats = input_feats
267 | self.time_embed_dim = latent_dim
268 | self.spatio_temp = self.num_frames * 2 * self.num_jts
269 |
270 | # encode motion 1
271 | self.motion1_pre_proj = nn.Linear(self.input_feats, self.latent_dim)
272 | self.m1_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp)
273 | mot1TransEncoderLayer = nn.TransformerEncoderLayer(
274 | d_model=latent_dim,
275 | nhead=num_heads,
276 | dim_feedforward=ff_size,
277 | dropout=dropout,
278 | batch_first=True,
279 | activation='gelu')
280 | self.mot1TransEncoder = nn.TransformerEncoder(
281 | mot1TransEncoderLayer,
282 | num_layers=2)
283 | self.mot1_ln = nn.LayerNorm(latent_dim)
284 | #Classifier-free guidance
285 | # self.null_cond = nn.Parameter(torch.randn(self.num_frames * self.num_jts, latent_dim))
286 |
287 | # Time Embedding
288 | self.time_embed = nn.Sequential(
289 | nn.Linear(self.latent_dim, self.time_embed_dim),
290 | nn.SiLU(),
291 | nn.Linear(self.time_embed_dim, self.time_embed_dim),
292 | )
293 |
294 | # motion2 decoding
295 | self.motion2_pre_proj = nn.Linear(self.input_feats, self.latent_dim)
296 | self.m2_temporal_pos_encoder = PositionalEncoding(d_model=self.latent_dim, dropout=self.dropout, max_len=self.spatio_temp)
297 | self.temporal_decoder_blocks = nn.ModuleList()
298 | for i in range(num_layers):
299 | self.temporal_decoder_blocks.append(
300 | TemporalDiffusionTransformerDecoderLayer(
301 | seq_len=self.spatio_temp,
302 | latent_dim=latent_dim,
303 | mot1_latent_dim=latent_dim,
304 | time_embed_dim=self.time_embed_dim,
305 | ffn_dim=ff_size,
306 | num_head=num_heads,
307 | dropout=dropout
308 | )
309 | )
310 | # Output Module
311 | self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))
312 |
313 |
314 | def generate_src_mask(self, tgt):
315 | length = tgt.size(1)
316 | src_mask = (1 - torch.triu(torch.ones(1, length, length), diagonal=1))
317 | return src_mask
318 |
319 | # def forward(self, motion2, timesteps, length=None, motion1=None, xf_out=None, contact_map=None):
320 | def forward(self, motion2, timesteps, motion1=None, spatial_guidance=None):
321 | """
322 | x: B, T, D
323 | """
324 | B, T, D = motion1.shape
325 | contact_distance = motion1[:, :, -4:]
326 | rh_rh = contact_distance[:,:, 0] == 1
327 | rh_lh = contact_distance[:,:, 1] == 1
328 | lh_rh = contact_distance[:,:, 2] == 1
329 | lh_lh = contact_distance[:,:, 3] == 1
330 | rh_pose1 = torch.logical_or(rh_lh, rh_rh)
331 | lh_pose1 = torch.logical_or(lh_lh, lh_rh)
332 | m1 = self.motion1_pre_proj(motion1[:, :, :-4].reshape(B, T, -1, 3)) # GCN
333 | m2 = self.motion2_pre_proj(motion2.reshape(B, T, -1, 3)) # GCN
334 | m1 = m1.reshape(B, T*2*self.num_jts, -1)
335 | m2 = m2.reshape(B, T*2*self.num_jts, -1)
336 | src_mask = torch.zeros(B, T, self.num_jts*2).to(m1.device).float()
337 | src_mask[:, :, 0] = rh_pose1
338 | src_mask[:, :, 1] = rh_pose1
339 | src_mask[:, :, 2] = rh_pose1
340 | src_mask[:, :, 3] = rh_pose1
341 | src_mask[:, :, 4] = rh_pose1
342 | src_mask[:, :, 5] = rh_pose1
343 | src_mask[:, :, 6] = rh_pose1
344 | src_mask[:, :, 7] = rh_pose1
345 | src_mask[:, :, 8] = rh_pose1
346 | src_mask[:, :, 9] = rh_pose1
347 | src_mask[:, :, 10] = rh_pose1
348 | src_mask[:, :, 11] = lh_pose1
349 | src_mask[:, :, 12] = lh_pose1
350 | src_mask[:, :, 13] = lh_pose1
351 | src_mask[:, :, 14] = lh_pose1
352 | src_mask[:, :, 15] = lh_pose1
353 | src_mask[:, :, 16] = lh_pose1
354 | src_mask[:, :, 17] = lh_pose1
355 | src_mask[:, :, 18] = lh_pose1
356 | src_mask[:, :, 19] = lh_pose1
357 | src_mask[:, :, 20] = lh_pose1
358 | src_mask[:, :, 21] = lh_pose1
359 | src_mask = src_mask.reshape(B, -1)
360 | src_mask = torch.repeat_interleave(src_mask.unsqueeze(-1), src_mask.shape[1], axis=-1)
361 | src_mask = torch.repeat_interleave(src_mask.unsqueeze(-1), self.num_heads, axis=-1)
362 | # src_mask = self.generate_src_mask(m2).to(m2.device)
363 |
364 | m1_pe = self.m1_temporal_pos_encoder(m1)
365 | m1_cond = self.mot1_ln(self.mot1TransEncoder(m1_pe))
366 | # null_cond = torch.repeat_interleave(
367 | # self.null_cond.to(m2.device).unsqueeze(0), B, dim=0)
368 | # m1_enc = m1_cond if random.random() > 0.25 else null_cond
369 | m1_enc = m1_cond
370 | m2_pe = self.m2_temporal_pos_encoder(m2)
371 | emb = self.time_embed(timestep_embedding(
372 | timesteps, self.latent_dim, self.freqs) )
373 | h_pe = m2_pe
374 | for module in self.temporal_decoder_blocks:
375 | h_pe, s_attn, c_attn = module(h_pe, m1_enc, emb, src_mask)
376 | output = self.out(h_pe).view(B, T, -1).contiguous()
377 | # if timesteps == 400:
378 | # c_attn_map = c_attn[0,:,:,0].cpu().detach().numpy()
379 | # # ax = sns.heatmap(c_attn_map, vmin=-15, vmax=15, linewidth=2)
380 | # # plt.show()
381 | # heatmap2d(c_attn_map)
382 | # tmp=1
383 | return output, s_attn, c_attn
384 |
--------------------------------------------------------------------------------
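A small sketch of how the sinusoidal timestep embedding above is driven inside DiffusionTransformer (latent_dim=32 and the random timesteps are assumptions for illustration; freqs is built the same way as self.freqs in the constructor, and the result is what forward() passes through self.time_embed).

import math
import torch
# timestep_embedding as defined in MotionDiffusion_hand.py above

latent_dim = 32
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=latent_dim // 2, dtype=torch.float32) / (latent_dim // 2))
timesteps = torch.randint(0, 1000, (8,))  # one diffusion step index per batch element
emb = timestep_embedding(timesteps, latent_dim, freqs)
print(emb.shape)                          # torch.Size([8, 32]); fed to self.time_embed in forward()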
/src/tools/bookkeeper.py:
--------------------------------------------------------------------------------
1 | # import pickle as torch
2 | import json
3 | import os
4 | import sys
5 | from datetime import datetime
6 | from tqdm import tqdm
7 | import copy
8 | import random
9 | import numpy as np
10 | from pathlib import Path
11 | import argparse
12 | import argunparse
13 | import warnings
14 | from prettytable import PrettyTable
15 |
16 | from torch.utils.tensorboard import SummaryWriter  # needed by TensorboardWrapper below; tensorboardX's SummaryWriter is a drop-in alternative
17 | import torch
18 |
19 | import pdb
20 |
21 | def get_args_update_dict(args):
22 | args_update_dict = {}
23 | for string in sys.argv:
24 | string = ''.join(string.split('-'))
25 | if string in args:
26 | args_update_dict.update({string: args.__dict__[string]})
27 | return args_update_dict
28 |
29 | def accumulate_grads(model, grads_list):
30 | if grads_list:
31 | grads_list = [param.grad.data+old_grad.clone() for param, old_grad in zip(model.parameters(), grads_list)]
32 | else:
33 | grads_list += [param.grad.data for param in model.parameters()]
34 | return grads_list
35 |
36 | def save_grads(val, file_path):
37 | torch.save(val, open(file_path, 'wb'))
38 |
39 | def load_grads(file_path):
40 |     return torch.load(open(file_path, 'rb'))  # grads are saved in binary mode
41 |
42 | class TensorboardWrapper():
43 | '''
44 | Wrapper to add values to tensorboard using a dictionary of values
45 | '''
46 | def __init__(self, log_dir):
47 | self.log_dir = log_dir
48 | self.writer = SummaryWriter(log_dir=self.log_dir, comment='NA')
49 |
50 | def __call__(self, write_dict):
51 | for key in write_dict:
52 | for value in write_dict[key]:
53 | getattr(self.writer, 'add_' + key)(*value)
54 |
55 | class BookKeeper():
56 | '''BookKeeper
57 |     If load_pretrained_model is True, BookKeeper will not update args
58 |     and will also call _new_exp.
59 |
60 | TODO: add documentation
61 | TODO: add save_optimizer_args as well
62 | TODO: choice of score kind to decide early-stopping (currently dev is default)
63 | Required properties in args
64 | - load
65 | - seed
66 | - save_dir
67 | - num_epochs
68 | - cuda
69 | - save_model
70 | - greedy_save
71 | - stop_thresh
72 | - eps
73 |     - early_stopping
74 | '''
75 | def __init__(self, args, args_subset,
76 | args_ext='args.args',
77 | name_ext='name.name',
78 | weights_ext='weights.p',
79 | res_ext='res.json',
80 | log_ext='log.log',
81 | script_ext='script.sh',
82 | args_dict_update={},
83 | res={'train':[], 'val':[], 'test':[]},
84 | tensorboard=None,
85 | load_pretrained_model=False):
86 |
87 | self.args = args
88 | self.save_flag = False
89 | self.args_subset = args_subset
90 | self.args_dict_update = args_dict_update
91 |
92 | self.args_ext = args_ext.split('.')
93 | self.name_ext = name_ext.split('.')
94 | self.weights_ext = weights_ext.split('.')
95 | self.res_ext = res_ext.split('.')
96 | self.log_ext = log_ext.split('.')
97 | self.script_ext = script_ext.split('.')
98 |
99 | ## params for saving/notSaving models
100 | self.stop_count = 0
101 |
102 | ## init empty results
103 | self.res = res
104 | if 'dev_key' in args:
105 | self.dev_key = args.dev_key
106 | self.dev_sign = args.dev_sign
107 | else:
108 | self.dev_key = 'val'
109 | self.dev_sign = 1
110 | self.best_dev_score = np.inf * self.dev_sign
111 |
112 | self.load_pretrained_model = load_pretrained_model
113 | self.last_epoch = 0
114 | if self.args.load:
115 | if os.path.isfile(self.args.load):
116 | ## update the save_dir if the files have moved
117 | self.save_dir = Path(args.load).parent.parent.as_posix()
118 | # self.save_dir = args.save_dir
119 |
120 | ## load Name
121 | self.name = self._load_name()
122 |
123 | ## load args
124 | self._load_args(args_dict_update)
125 |
126 | # if not self.load_pretrained_model:
127 | # ## Serialize and save args
128 | # self._save_args()
129 |
130 | ## load results
131 | self.res = self._load_res()
132 | self.last_epoch = self.res['epoch'][-1]
133 |
134 | else:
135 | ## run a new experiment
136 | self._new_exp()
137 |
138 | # if self.load_pretrained_model:
139 | # self._new_exp()
140 |
141 | ## Tensorboard
142 | if tensorboard:
143 | self.tensorboard = TensorboardWrapper(log_dir=(Path(self.save_dir)/Path(self.name.name+'tb')).as_posix())
144 | else:
145 | self.tensorboard = None
146 |
147 | self._set_seed()
148 |
149 | def _set_seed(self):
150 | ## seed numpy and torch
151 | random.seed(self.args.seed)
152 | np.random.seed(self.args.seed)
153 | torch.manual_seed(self.args.seed)
154 | torch.cuda.manual_seed_all(self.args.seed)
155 | torch.cuda.manual_seed(self.args.seed)
156 | #torch.backends.cudnn.deterministic = True
157 | #torch.backends.cudnn.benchmark = False
158 |
159 | '''
160 | Stuff to do for a new experiment
161 | '''
162 | def _new_exp(self):
163 | ## update the experiment number
164 | self._update_exp()
165 |
166 | self.save_dir = self.args.save_dir
167 | self.name = Name(self.args, *self.args_subset)
168 |
169 | ## save name
170 | self._save_name()
171 |
172 | ## update args
173 | self.args.__dict__.update(self.args_dict_update)
174 |
175 | ## Serialize and save args
176 | self._save_args()
177 |
178 | ## save script
179 | #self._save_script() ## not functional yet. needs some work
180 |
181 | ## reinitialize results to empty
182 | self.res = {key:[] for key in self.res}
183 |
184 | def _update_exp(self):
185 | if self.args.exp is not None:
186 | exp = 0
187 | exp_file = '.experiments'
188 | if not os.path.exists(exp_file):
189 | with open(exp_file, 'w') as f:
190 | f.writelines([f'{exp}\n'])
191 | else:
192 | with open(exp_file, 'r') as f:
193 | lines = f.readlines()
194 | exp = int(lines[0].strip())
195 | exp += 1
196 | with open(exp_file, 'w') as f:
197 | f.writelines([f'{exp}\n'])
198 | else:
199 | exp = 0
200 | print(f'Experiment Number: {exp}')
201 | self.args.__dict__.update({'exp':exp})
202 |
203 | def _load_name(self):
204 | name_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.name_ext)])
205 | return torch.load(open(name_filepath, 'rb'))
206 |
207 | def _save_name(self):
208 | name_filepath = self.name(self.name_ext[0], self.name_ext[1], self.save_dir)
209 | torch.save(self.name, open(name_filepath, 'wb'))
210 |
211 | def _load_res(self):
212 | res_filepath = self.name(self.res_ext[0], self.res_ext[1], self.save_dir)
213 | # res_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.res_ext)])
214 | if os.path.exists(res_filepath):
215 | print('Results Loaded')
216 | return json.load(open(res_filepath))
217 | else:
218 | warnings.warn('Could not find result file')
219 | return self.res
220 |
221 | def _save_res(self):
222 | res_filepath = self.name(self.res_ext[0], self.res_ext[1], self.save_dir)
223 | json.dump(self.res, open(res_filepath,'w'))
224 |
225 | def update_res(self, res):
226 | for key in res:
227 | if key in self.res:
228 | self.res[key].append(res[key])
229 | else:
230 | self.res[key] = [res[key]]
231 |
232 | def update_tb(self, write_dict):
233 | if self.tensorboard:
234 | self.tensorboard(write_dict)
235 | else:
236 | warnings.warn('TensorboardWrapper not declared')
237 |
238 | def print_res(self, epoch, key_order=['train', 'val', 'test'], metric_order=[], exp=0, lr=None, fmt='{:.16f}'):
239 | print_str = "exp: {}, epoch: {}, lr:{}"
240 | table = PrettyTable([''] + key_order)
241 | table_str = ['loss'] + [fmt.format(self.res[key][-1]) for key in key_order] ## loss
242 | table.add_row(table_str)
243 | for metric in metric_order:
244 | table_str = [metric] + [fmt.format(self.res['{}_{}'.format(key, metric)][-1]) for key in key_order]
245 | table.add_row(table_str)
246 |
247 | if isinstance(lr, list):
248 | lr = lr[0]
249 | tqdm.write(print_str.format(exp, epoch, lr))
250 | tqdm.write(table.__str__())
251 |
252 | def print_res_archive(self, epoch, key_order=['train', 'val', 'test'], exp=0, lr=None, fmt='{:.9f}'):
253 |         print_str = ', '.join(["exp: {}, epoch: {}, lr:{}, "] + ["{}: {}".format(key,fmt) for key in key_order])
254 | result_list = [self.res[key][-1] for key in key_order]
255 | if isinstance(lr, list):
256 | lr = lr[0]
257 | tqdm.write(print_str.format(exp, epoch, lr, *result_list))
258 |
259 | def _load_args(self, args_dict_update):
260 | args_filepath = self.name(self.args_ext[0], self.args_ext[1], self.save_dir)
261 | # args_filepath = '_'.join(self.args.load.split('_')[:-1] + ['.'.join(self.args_ext)])
262 | if os.path.isfile(args_filepath):
263 | args_dict = json.load(open(args_filepath))
264 | ## update load path and cuda device to use
265 | args_dict.update({'load':self.args.load,
266 | 'cuda':self.args.cuda,
267 | 'save_dir':self.save_dir})
268 | ## any new argument to be updated
269 | args_dict.update(args_dict_update)
270 |
271 | self.args.__dict__.update(args_dict)
272 |
273 | def _save_args(self):
274 | args_filepath = self.name(self.args_ext[0], self.args_ext[1], self.save_dir)
275 | json.dump(self.args.__dict__, open(args_filepath, 'w'))
276 |
277 | def _save_script(self):
278 | '''
279 | Not functional
280 | '''
281 | args_filepath = self.name(self.script_ext[0], self.script_ext[1], self.save_dir)
282 | unparser = argunparse.ArgumentUnparser()
283 | options = get_args_update_dict(self.args)#self.args.__dict__
284 | args = {}
285 | script = unparser.unparse_to_list(*args, **options)
286 | script = ['python', sys.argv[0]] + script
287 | script = ' '.join(script)
288 | with open(args_filepath, 'w') as fp:
289 | fp.writelines(script)
290 |
291 | def _load_model(self, model, model_id):
292 | # weights_path = self.name(self.weights_ext[0], self.weights_ext[1], self.save_dir)
293 | weights_path = self.name(self.args.load.split('_')[-1].split('.')[0], self.weights_ext[1], self.save_dir)
294 | m = torch.load(open(weights_path, 'rb'))
295 | model.load_state_dict(m[model_id])
296 | print('Model loaded')
297 |
298 | @staticmethod
299 | def load_pretrained_model(model, path2model):
300 | model.load_state_dict(torch.load(open(path2model, 'rb')))
301 | return model
302 |
303 | def _save_model(self, model_state_dict, out, model_id='model_pose'):
304 | weights_path = self.name(self.weights_ext[0], self.weights_ext[1], self.save_dir)
305 | f = open(weights_path, 'wb')
306 | out.update({model_id: model_state_dict})
307 | torch.save(out, f)
308 | f.close()
309 |
310 | def _copy_best_model(self, model):
311 | if isinstance(model, torch.nn.DataParallel):
312 | self.best_model = copy.deepcopy(model.module.state_dict())
313 | else:
314 | self.best_model = copy.deepcopy(model.state_dict())
315 |
316 | def _start_log(self):
317 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'w') as f:
318 | f.write("S: {}\n".format(str(datetime.now())))
319 |
320 | def _stop_log(self):
321 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'r') as f:
322 | lines = f.readlines()
323 | if len(lines) > 1: ## this has already been sampled before
324 | lines = lines[0:1] + ["E: {}\n".format(str(datetime.now()))]
325 | else:
326 | lines.append("E: {}\n".format(str(datetime.now())))
327 | with open(self.name(self.log_ext[0],self.log_ext[1], self.save_dir), 'w') as f:
328 | f.writelines(lines)
329 |
330 | def stop_training(self, model, model_id, epoch, out_dict, warmup=False):
331 | ## copy the best model
332 | if self.dev_sign * self.res[self.dev_key][-1] < self.dev_sign * self.best_dev_score and (not warmup):
333 | self._copy_best_model(model)
334 | self.best_dev_score = self.res[self.dev_key][-1]
335 | self.save_flag = True
336 | else:
337 | self.save_flag = False
338 |
339 | ## debug mode with no saving
340 | if not self.args.save_model:
341 | self.save_flag = False
342 |
343 | if self.args.overfit:
344 | self._copy_best_model(model)
345 | self.save_flag = True
346 |
347 | if epoch % 10 == 0:
348 | self.save_flag = True
349 |
350 | if self.save_flag:
351 | tqdm.write('Saving Model at epoch {}'.format(epoch))
352 | self._copy_best_model(model)
353 | self._save_model(self.best_model, out_dict, model_id)
354 |
355 | ## early_stopping
356 | if self.args.early_stopping and len(self.res['train'])>=2 and not self.args.overfit:
357 | if (self.dev_sign*(self.res[self.dev_key][-2] - self.args.eps) < self.dev_sign * self.res[self.dev_key][-1]):
358 | self.stop_count += 1
359 | else:
360 | self.stop_count = 0
361 |
362 | if self.stop_count >= self.args.stop_thresh:
363 | print('Validation Loss is increasing')
364 | ## save the best model now
365 | if self.args.save_model:
366 | print('Saving Model by early stopping')
367 | self.save_flag = True
368 | self._copy_best_model(model)
369 | self._save_model(self.best_model, out_dict, model_id)
370 | return self.save_flag
371 |
372 | ## end of training loop
373 | if epoch == self.args.num_epochs-1 and self.args.save_model:
374 | print('Saving model after exceeding number of epochs')
375 | self.save_flag = True
376 | self._copy_best_model(model)
377 | self._save_model(self.best_model, out_dict, model_id)
378 |
379 | return self.save_flag
380 |
381 |
382 | class Name(object):
383 | ''' Create a name based on hyper-parameters, other arguments
384 | like number of epochs or error rates
385 |
386 | Arguments:
387 | path2file/...argname_value_..._outputkind.ext
388 |
389 | args: Namespace(argname,value, ....) generally taken from an argparser variable
390 | argname: Hyper-parameters (i.e. model structure)
391 | value: Values of the corresponding Hyper-parameters
392 |
393 | path2file: set as './' by default and decides the path where the file is to be stored
394 | outputkind: what is the kind of output 'err', 'vis', 'cpk' or any other acronym given as a string
395 | ext: file type given as a string
396 |
397 | *args_subset: The subset of arguments to be used and its order
398 |
399 | Methods:
400 | Name.dir(path2file): creates a directory at `path2file` with a name derived from arguments
401 | but outputkind and ext are omitted
402 | '''
403 |
404 | def __init__(self, args, *args_subset):
405 | self.name = ''
406 | args_dict = vars(args)
407 | args_subset = list(args_subset)
408 |
409 | ## if args_subset is not provided take all the keys from args_dict
410 | if not args_subset:
411 | args_subset = list(args_dict.keys())
412 |
413 | ## if args_subset is derived from an example name
414 | for i, arg_sub in enumerate(args_subset):
415 | for arg in args_dict:
416 | if arg_sub == ''.join(arg.split('_')):
417 | args_subset[i] = arg
418 |
419 | ## If args_subset is empty exit
420 | assert args_subset, 'Subset of arguments to be chosen is empty'
421 |
422 | ## Scan through required arguments in the name
423 | for arg in args_subset:
424 | if arg not in args_dict:
425 | warnings.warn('Key %s does not exist. Skipping...'%(arg))
426 | else:
427 | self.name += '%s_%s_' % (''.join(arg.split('_')), '-'.join(str(args_dict[arg]).split('.')))
428 |
429 | def dir(self, path2file='./'):
430 | try:
431 | os.makedirs(os.path.join(path2file, self.name[:-1]))
432 | except OSError:
433 | if not os.path.isdir(path2file):
434 |                 raise OSError('Directory could not be created. Check if you have the required permissions to make changes at the given path.')
435 | return os.path.join(path2file, self.name[:-1])
436 |
437 |
438 | def __call__(self, outputkind, ext, path2file='./'):
439 | try:
440 | os.makedirs(os.path.join(path2file, self.name))
441 | except OSError:
442 | if not os.path.isdir(path2file):
443 |                 raise OSError('Directory could not be created. Check if you have the required permissions to make changes at the given path.')
444 | return os.path.join(path2file, self.name, self.name + '%s.%s' %(outputkind,ext))
445 |
--------------------------------------------------------------------------------
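A rough sketch of what the Name helper above produces, using hypothetical argument values (the trainers below use the same subset ['exp', 'model', 'batch_size', 'frames']): the name string is built from the chosen arguments, and calling the Name object returns the full path for a given output kind and extension.

import argparse
# Name as defined in bookkeeper.py above; the values here are purely illustrative
args = argparse.Namespace(exp=0, model='DiffusionTransformer', batch_size=32, frames=100)
name = Name(args, 'exp', 'model', 'batch_size', 'frames')
print(name.name)
# exp_0_model_DiffusionTransformer_batchsize_32_frames_100_
weights_path = name('weights', 'p', './save')
# ./save/<name>/<name>weights.p  (the call also creates the directory if it does not exist)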
/src/Lindyhop/train_hand_diffusion.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import shutil
7 | import sys
8 | sys.path.append('.')
9 | sys.path.append('..')
10 | import time
11 | import torch
12 | torch.cuda.empty_cache()
13 | import torch.nn as nn
14 |
15 | from cmath import nan
16 | from collections import OrderedDict
17 | from datetime import datetime
18 | from torch import optim
19 | from torch.utils.data import DataLoader
20 | from tqdm import tqdm
21 |
22 | from src.Lindyhop.argUtils import argparseNloop
23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset
24 | from src.Lindyhop.models.MotionDiffusion_hand import *
25 | from src.Lindyhop.models.Gaussian_diffusion import (
26 | GaussianDiffusion,
27 | get_named_beta_schedule,
28 | create_named_schedule_sampler,
29 | ModelMeanType,
30 | ModelVarType,
31 | LossType
32 | )
33 | from src.Lindyhop.skeleton import *
34 | from src.Lindyhop.visualizer import plot_contacts3D
35 | from src.tools.bookkeeper import *
36 | from src.tools.calculate_ev_metrics import *
37 | from src.tools.transformations import *
38 | from src.tools.utils import makepath
39 |
40 |
41 | def dist(x, y):
42 | # return torch.mean(x - y)
43 | return torch.mean(torch.cdist(x, y, p=2))
44 |
45 | def initialize_weights(m):
46 | std_dev = 0.02
47 | if isinstance(m, nn.Linear):
48 | nn.init.normal_(m.weight, std=std_dev)
49 | if m.bias is not None:
50 | nn.init.normal_(m.bias, std=std_dev)
51 | # nn.init.constant_(m.bias.data, 1e-5)
52 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
53 | torch.nn.init.normal_(m.weight, std=std_dev)
54 | if m.bias is not None:
55 | torch.nn.init.normal_(m.bias, std=std_dev)
56 | # nn.init.constant_(m.bias.data, 1e-5)
57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
58 | nn.init.normal_(m.weight, std=std_dev)
59 | if m.bias is not None:
60 | nn.init.normal_(m.bias, std=std_dev)
61 |
62 | class Trainer:
63 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69):
64 | torch.manual_seed(args.seed)
65 | self.model_path = args.model_path
66 | makepath(args.work_dir, isfile=False)
67 | use_cuda = torch.cuda.is_available()
68 | if use_cuda:
69 | torch.cuda.empty_cache()
70 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")
71 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None
72 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1
73 |         print('Using %d CUDA device(s) [%s] for training!' % (gpu_count, gpu_brand))
74 | args_subset = ['exp', 'model', 'batch_size', 'frames']
75 | self.book = BookKeeper(args, args_subset)
76 | self.args = self.book.args
77 | self.batch_size = args.batch_size
78 | self.curriculum = args.curriculum
79 | self.scale = args.scale
80 | self.dtype = torch.float32
81 | self.epochs_completed = self.book.last_epoch
82 | self.frames = args.frames
83 | self.model = args.model
84 | self.lambda_loss = args.lambda_loss
85 | self.testtime_split = split
86 | self.num_jts = num_jts
87 | self.model_pose = eval(args.model)(device=self.device,
88 | num_frames=self.frames,
89 | num_jts=self.num_jts,
90 | input_feats=args.hand_out_feats,
91 | latent_dim=args.d_modelhand,
92 | num_heads=args.num_head_hands,
93 | num_layers=args.num_layer_hands,
94 | ff_size=args.d_ffhand,
95 | activations=args.activations
96 | ).to(self.device).float()
97 | self.diffusion_steps = args.diffusion_steps
98 | self.beta_scheduler = args.noise_schedule
99 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps)
100 | self.diffusion = GaussianDiffusion(
101 | betas=self.betas,
102 | model_mean_type=ModelMeanType.START_X,
103 | model_var_type=ModelVarType.FIXED_SMALL,
104 | loss_type=LossType.MSE
105 | )
106 | self.sampler_name = args.sampler
107 | self.sampler = create_named_schedule_sampler(self.sampler_name, self.diffusion)
108 | self.model_pose.apply(initialize_weights)
109 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr)
110 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma)
111 | # self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, factor=args.factor, patience=args.patience, threshold= args.threshold, min_lr = 2e-7)
112 | self.skel = InhouseStudioSkeleton()
113 | self.mse_criterion = torch.nn.MSELoss()
114 | self.l1_criterion = torch.nn.L1Loss()
115 | self.BCE_criterion = torch.nn.BCELoss()
116 | print(args.model, 'Model Created')
117 | if args.load:
118 | print('Loading Model', args.model)
119 | self.book._load_model(self.model_pose, 'model_pose')
120 | print('Loading the data')
121 | if is_train:
122 | self.load_data(args)
123 | else:
124 | self.load_data_testtime(args)
125 | self.mean_var_norm = torch.load(args.mean_var_norm)
126 |
127 | def load_data_testtime(self, args):
128 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split)
129 |
130 | def load_data(self, args):
131 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train')
132 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
133 | print('Train set loaded. Size=', len(self.ds_train.dataset))
134 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test')
135 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
136 | print('Validation set loaded. Size=', len(self.ds_val.dataset))
137 |
138 |
139 | def calc_loss(self, num_epoch):
140 | bs, seq, dim = self.generated.shape
141 | pos_loss = self.lambda_loss['pos'] * self.mse_criterion(self.input, self.generated)
142 | vel_gt = self.input[:, 1:] - self.input[:, :-1]
143 | vel_gen = self.generated[:, 1:] - self.generated[:, :-1]
144 | velocity_loss = self.lambda_loss['vel'] * self.mse_criterion(vel_gt, vel_gen)
145 | acc_gt = vel_gt[:, 1:] - vel_gt[:, :-1]
146 | acc_gen = vel_gen[:, 1:] - vel_gen[:, :-1]
147 | acc_loss = self.lambda_loss['vel'] * self.mse_criterion(acc_gt, acc_gen)
148 | gt_pose = self.input.reshape(bs, seq, -1, 3)
149 | gen_pose = self.generated.reshape(bs, seq, -1, 3)
150 | p1_rhand = self.p1_rhand_pos
151 | p2gt_rhand = self.p2_rhand_pos
152 | p2gen_rhand = gen_pose[:, :, :11]
153 | p1_lhand = self.p1_lhand_pos
154 | p2gt_lhand = self.p2_lhand_pos
155 | p2gen_lhand = gen_pose[:, :, 11:]
156 |
157 | bone_len_gt = (gt_pose[:, :, 1:] - gt_pose[:, :, [self.skel.parent_fingers[x] for x in range(1, 2*self.num_jts)]]).norm(dim=-1)
158 | bone_len_gen = (gen_pose[:, :, 1:] - gen_pose[:, :, [self.skel.parent_fingers[x] for x in range(1, 2*self.num_jts)]]).norm(dim=-1)
159 | bone_len_consistency_loss = self.lambda_loss['bone'] * self.mse_criterion(bone_len_gt, bone_len_gen)
160 |
161 | loss_logs = [pos_loss, velocity_loss, acc_loss, bone_len_consistency_loss]
162 |
163 | # #include the interaction loss
164 |
165 | self.lambda_loss['in'] = 10.0
166 | rh_rh = self.contact_map[:,:, 0] == 1
167 | rh_lh = self.contact_map[:,:, 1] == 1
168 | lh_rh = self.contact_map[:,:, 2] == 1
169 | lh_lh = self.contact_map[:,:, 3] == 1
170 |
171 | interact_loss = self.lambda_loss['in'] * torch.mean( rh_lh * ((p1_rhand - p2gt_lhand).norm(dim=-1) - (p1_rhand - p2gen_lhand).norm(dim=-1)).norm(dim=-1) +
172 | rh_rh * ((p1_rhand - p2gt_rhand).norm(dim=-1) - (p1_rhand - p2gen_rhand).norm(dim=-1)).norm(dim=-1) +
173 | lh_rh * ((p1_lhand - p2gt_rhand).norm(dim=-1) - (p1_lhand - p2gen_rhand).norm(dim=-1)).norm(dim=-1) +
174 | lh_lh * ((p1_lhand - p2gt_lhand).norm(dim=-1) - (p1_lhand - p2gen_lhand).norm(dim=-1)).norm(dim=-1) )
175 | loss_logs.append(interact_loss)
176 | return loss_logs
177 |
178 | def forward(self, motions1, motions2, t=None):
179 | B, T = motions2.shape[:2]
180 |         if t is None:
181 | t, _ = self.sampler.sample(B, motions1.device)
182 | self.diffusion_timestep = t
183 | output = self.diffusion.training_losses(
184 | model=self.model_pose,
185 | x_start=motions2,
186 | t=t,
187 | model_kwargs={"motion1": motions1}
188 | )
189 |
190 | self.generated = output['pred'] # synthesized pose 2
191 | return t, output['x_noisy']
192 |
193 | def generate(self, motion1, motion2=None):
194 | B, T, J, dim_pose = motion1.shape
195 | output = self.diffusion.p_sample_loop(
196 | self.model_pose,
197 | (B, T, J, dim_pose),
198 | clip_denoised=False,
199 | progress=True,
200 | pre_seq= motion2,
201 | model_kwargs={
202 | 'motion1': motion1,
203 | })
204 | return output
205 |
206 |
207 | def relative_normalization(self, global_pose1, global_pose2, global_rot1, global_rot2):
208 | self.p1_rhand_wrist_pos = global_pose1[:, :, 18]
209 | self.p1_lhand_wrist_pos = global_pose1[:, :, 43]
210 | p1_rhand_wrist_pos = (global_pose1[:, :, 18] - self.p1_rhand_wrist_pos) / self.scale
211 | p1_lhand_wrist_pos = (global_pose1[:, :, 43] - self.p1_lhand_wrist_pos) / self.scale
212 | p2_rhand_wrist_pos = (global_pose2[:, :, 18] - self.p1_rhand_wrist_pos) / self.scale
213 | p2_lhand_wrist_pos = (global_pose2[:, :, 43] - self.p1_lhand_wrist_pos) / self.scale
214 | B = p1_rhand_wrist_pos.shape[0]
215 | T = p1_rhand_wrist_pos.shape[1]
216 | p1_rhand_rot = self.skel.select_bvh_joints(global_rot1, original_joint_order=self.skel.bvh_joint_order,
217 | new_joint_order=self.skel.rh_fingers_only).reshape(B, T, -1)
218 | p1_lhand_rot = self.skel.select_bvh_joints(global_rot1, original_joint_order=self.skel.bvh_joint_order,
219 | new_joint_order=self.skel.lh_fingers_only).reshape(B, T, -1)
220 |
221 | p2_rhand_rot = self.skel.select_bvh_joints(global_rot2, original_joint_order=self.skel.bvh_joint_order,
222 | new_joint_order=self.skel.rh_fingers_only).reshape(B, T, -1)
223 | p2_lhand_rot = self.skel.select_bvh_joints(global_rot2, original_joint_order=self.skel.bvh_joint_order,
224 | new_joint_order=self.skel.lh_fingers_only).reshape(B, T, -1)
225 |
226 | # create a contact map based on threshold of wrists
227 | self.contact_dist = torch.zeros(B, T, 4).to(p1_lhand_rot.device).float()
228 | self.contact_dist[:,:, 0] = ((p1_rhand_wrist_pos - p2_rhand_wrist_pos)**2).norm(dim=-1)
229 | self.contact_dist[:,:, 1] = ((p1_rhand_wrist_pos - p2_lhand_wrist_pos)**2).norm(dim=-1)
230 | self.contact_dist[:,:, 2] = ((p1_lhand_wrist_pos - p2_rhand_wrist_pos)**2).norm(dim=-1)
231 | self.contact_dist[:,:, 3] = ((p1_lhand_wrist_pos - p2_lhand_wrist_pos)**2).norm(dim=-1)
232 |
233 | self.input_condn = torch.cat((p1_rhand_wrist_pos, p1_lhand_wrist_pos,
234 | p2_rhand_wrist_pos, p2_lhand_wrist_pos,
235 | p1_rhand_rot, p1_lhand_rot, self.contact_dist), dim=-1)
236 | self.input = torch.cat((p2_rhand_rot, p2_lhand_rot), dim=-1)
237 |
238 | def pose_relative_normalization(self, global_pose1, global_pose2, contact_maps):
239 | p1_rhand_pos = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order,
240 | new_joint_order=self.skel.rh_fingers_only)
241 | p1_lhand_pos = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order,
242 | new_joint_order=self.skel.lh_fingers_only)
243 |
244 | p2_rhand_pos = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order,
245 | new_joint_order=self.skel.rh_fingers_only)
246 | p2_lhand_pos = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order,
247 | new_joint_order=self.skel.lh_fingers_only)
248 | self.p1_rhand_wrist_pos = p1_rhand_pos[:, :, 0]
249 | self.p2_rhand_wrist_pos = p2_rhand_pos[:, :, 0]
250 | self.p1_lhand_wrist_pos = p1_lhand_pos[:, :, 0]
251 | self.p2_lhand_wrist_pos = p2_lhand_pos[:, :, 0]
252 | self.p1_rhand_pos = (p1_rhand_pos - torch.repeat_interleave(self.p1_rhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale
253 | self.p1_lhand_pos = (p1_lhand_pos - torch.repeat_interleave(self.p1_lhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale
254 | self.p2_rhand_pos = (p2_rhand_pos - torch.repeat_interleave(self.p2_rhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale
255 | self.p2_lhand_pos = (p2_lhand_pos - torch.repeat_interleave(self.p2_lhand_wrist_pos.unsqueeze(-2), self.num_jts, axis=-2))/self.scale
256 | B = self.p1_rhand_wrist_pos.shape[0]
257 | T = self.p1_rhand_wrist_pos.shape[1]
258 |
259 | self.contact_map = contact_maps.to(self.device).float()
260 | self.input_condn = torch.cat((self.p1_rhand_pos.reshape(B, T, -1), self.p1_lhand_pos.reshape(B, T, -1),
261 | self.contact_map), dim=-1)
262 | self.input = torch.cat((self.p2_rhand_pos.reshape(B, T, -1), self.p2_lhand_pos.reshape(B, T, -1)), dim=-1)
263 |
264 |
265 | def train(self, num_epoch, ablation=None):
266 | total_train_loss = 0.0
267 | self.model_pose.train()
268 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120)
269 | # self.joint_parent = self.ds_train.dataset.bvh_joint_parents_list
270 | diff_count = [0, 5, 10, 50, 100, 200, 300, 400, 499]
271 | for count, batch in enumerate(training_tqdm):
272 | self.optimizer_model_pose.zero_grad()
273 |
274 | # with torch.autograd.detect_anomaly():
275 | if True:
276 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
277 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
278 | if global_pose1.shape[1] == 0:
279 | continue
280 | self.pose_relative_normalization(global_pose1, global_pose2, batch['contacts'])
281 | t, noisy = self.forward(self.input_condn, self.input)
282 |
283 | loss_logs = self.calc_loss(num_epoch)
284 | loss_model = sum(loss_logs)
285 | total_train_loss += loss_model.item()
286 |
287 | if loss_model == float('inf') or torch.isnan(loss_model):
288 | print('Train loss is nan')
289 | exit()
290 | loss_model.backward()
291 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01)
292 | self.optimizer_model_pose.step()
293 |
294 | avg_train_loss = total_train_loss/(count + 1)
295 | return avg_train_loss
296 |
297 | def evaluate(self, num_epoch, ablation=None):
298 | total_eval_loss = 0.0
299 | self.model_pose.eval()
300 | T = self.frames
301 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120)
302 |
303 | for count, batch in enumerate(eval_tqdm):
304 | if True:
305 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
306 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
307 | if global_pose1.shape[1] == 0:
308 | continue
309 | self.pose_relative_normalization(global_pose1, global_pose2, batch['contacts'])
310 | t, noisy = self.forward(self.input_condn, self.input)
311 | loss_logs = self.calc_loss(num_epoch)
312 | loss_model = sum(loss_logs)
313 | total_eval_loss += loss_model.item()
314 |
315 | avg_eval_loss = total_eval_loss/(count + 1)
316 |
317 | return avg_eval_loss
318 |
319 | def fit(self, n_epochs=None, ablation=False):
320 | print('*****Inside Trainer.fit *****')
321 | if n_epochs is None:
322 | n_epochs = self.args.num_epochs
323 | starttime = datetime.now().replace(microsecond=0)
324 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs)
325 | save_model_dict = {}
326 | best_eval = 1000
327 |
328 | for epoch_num in range(self.epochs_completed, n_epochs + 1):
329 | tqdm.write('--- starting Epoch # %03d' % epoch_num)
330 | train_loss = self.train(epoch_num, ablation)
331 | if epoch_num % 5 == 0:
332 | eval_loss = self.evaluate(epoch_num, ablation)
333 | else:
334 | eval_loss = 0.0
335 | self.scheduler_pose.step()
336 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0})
337 | self.book._save_res()
338 | self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr'])
339 |
340 |             if epoch_num > 100 and epoch_num % 5 == 0 and eval_loss < best_eval:
341 | print('Best eval at epoch {}'.format(epoch_num))
342 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb')
343 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
344 | torch.save(save_model_dict, f)
345 | f.close()
346 | best_eval = eval_loss
347 | if epoch_num > 20 and epoch_num % 20 == 0 :
348 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb')
349 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
350 | torch.save(save_model_dict, f)
351 | f.close()
352 | endtime = datetime.now().replace(microsecond=0)
353 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S')))
354 | print('Training complete in %s!\n' % (endtime - starttime))
355 |
356 |
357 |
358 | if __name__ == '__main__':
359 | args = argparseNloop()
360 | args.lambda_loss = {
361 | 'fk': 1.0,
362 | 'fk_vel': 1.0,
363 | 'rot': 1e+3,
364 | 'rot_vel': 1e+1,
365 | 'kldiv': 1.0,
366 | 'pos': 1e+3,
367 | 'vel': 1e+1,
368 | 'bone': 1.0,
369 | 'foot': 0.0,
370 | }
371 | is_train = True
372 | ablation = None # if True then ablation: no_IAC_loss
373 | model_trainer = Trainer(args=args, is_train=is_train, split='train', JT_POSITION=True, num_jts=11)
374 | print("** Method Initialization Complete **")
375 | model_trainer.fit(ablation=ablation)
376 |
377 |
--------------------------------------------------------------------------------
/src/Lindyhop/train_body_diffusion.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import shutil
7 | import sys
8 | sys.path.append('.')
9 | sys.path.append('..')
10 | import time
11 | import torch
12 | torch.cuda.empty_cache()
13 | import torch.nn as nn
14 |
15 | from cmath import nan
16 | from collections import OrderedDict
17 | from datetime import datetime
18 | from torch import optim
19 | from torch.utils.data import DataLoader
20 | from tqdm import tqdm
21 |
22 | from src.Lindyhop.argUtils import argparseNloop
23 | from src.Lindyhop.LindyHop_dataloader import LindyHopDataset
24 | from src.Lindyhop.models.MotionDiffuse_body import *
25 | from src.Lindyhop.models.Gaussian_diffusion import (
26 | GaussianDiffusion,
27 | get_named_beta_schedule,
28 | create_named_schedule_sampler,
29 | ModelMeanType,
30 | ModelVarType,
31 | LossType
32 | )
33 | from src.Lindyhop.skeleton import *
34 | from src.Lindyhop.visualizer import plot_contacts3D
35 | from src.tools.bookkeeper import *
36 | from src.tools.transformations import *
37 | from src.tools.utils import makepath
38 |
39 | right_side = [15, 16, 17, 18]
40 | left_side = [19, 20, 21, 22]
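   | # Assumption: these index the right/left arm-and-hand joints of the reduced 27-joint
   | # body-only skeleton used by this trainer (see num_jts=27 in __main__).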
41 | # stat_metrics = CalculateMetricsDanceData()
42 | def dist(x, y):
43 | # return torch.mean(x - y)
44 | return torch.mean(torch.cdist(x, y, p=2))
45 |
46 | def initialize_weights(m):
47 | std_dev = 0.02
48 | if isinstance(m, nn.Linear):
49 | nn.init.normal_(m.weight, std=std_dev)
50 | if m.bias is not None:
51 | nn.init.normal_(m.bias, std=std_dev)
52 | # nn.init.constant_(m.bias.data, 1e-5)
53 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d):
54 | torch.nn.init.normal_(m.weight, std=std_dev)
55 | if m.bias is not None:
56 | torch.nn.init.normal_(m.bias, std=std_dev)
57 | # nn.init.constant_(m.bias.data, 1e-5)
58 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
59 | nn.init.normal_(m.weight, std=std_dev)
60 | if m.bias is not None:
61 | nn.init.normal_(m.bias, std=std_dev)
62 |
63 | class Trainer:
64 | def __init__(self, args, is_train=True, split='test', JT_POSITION=False, num_jts = 69):
65 | torch.manual_seed(args.seed)
66 | self.model_path = args.model_path
67 | makepath(args.work_dir, isfile=False)
68 | use_cuda = torch.cuda.is_available()
69 | if use_cuda:
70 | torch.cuda.empty_cache()
71 | self.device = torch.device("cuda:%d" % args.cuda if torch.cuda.is_available() else "cpu")
72 | gpu_brand = torch.cuda.get_device_name(args.cuda) if use_cuda else None
73 | gpu_count = torch.cuda.device_count() if args.use_multigpu else 1
74 |         print('Using %d GPU(s) [%s] for training!' % (gpu_count, gpu_brand))
75 | args_subset = ['exp', 'model', 'batch_size', 'frames']
76 | self.book = BookKeeper(args, args_subset)
77 | self.args = self.book.args
78 | self.batch_size = args.batch_size
79 | self.curriculum = args.curriculum
80 | self.scale = args.scale
81 | self.dtype = torch.float32
82 | self.epochs_completed = self.book.last_epoch
83 | self.frames = args.frames
84 | self.model = args.model
85 | self.lambda_loss = args.lambda_loss
86 | self.testtime_split = split
87 | self.num_jts = num_jts
88 | self.model_pose = eval(args.model)(device=self.device,
89 | num_jts=self.num_jts,
90 | num_frames=self.frames,
91 | input_feats=args.input_feats,
92 | # jt_latent_dim=args.jt_latent,
93 | latent_dim=args.d_model,
94 | num_heads=args.num_head,
95 | num_layers=args.num_layer,
96 | ff_size=args.d_ff,
97 | activations=args.activations
98 | ).to(self.device).float()
99 | trainable_count_body = sum(p.numel() for p in self.model_pose.parameters() if p.requires_grad)
100 |
101 | self.diffusion_steps = args.diffusion_steps
102 | self.beta_scheduler = args.noise_schedule
103 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps)
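    |         # Diffusion setup: with ModelMeanType.START_X the denoiser predicts the clean sample x_0
    |         # directly (read back below as output['pred']), the posterior variance is fixed
    |         # (FIXED_SMALL), and the training objective is an MSE on that prediction.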
104 | self.diffusion = GaussianDiffusion(
105 | betas=self.betas,
106 | model_mean_type=ModelMeanType.START_X,
107 | model_var_type=ModelVarType.FIXED_SMALL,
108 | loss_type=LossType.MSE
109 | )
110 | self.sampler_name = args.sampler
111 | self.sampler = create_named_schedule_sampler(self.sampler_name, self.diffusion)
112 | self.model_pose.apply(initialize_weights)
113 | self.optimizer_model_pose = eval(args.optimizer)(self.model_pose.parameters(), lr = args.lr)
114 | self.scheduler_pose = eval(args.scheduler)(self.optimizer_model_pose, step_size=args.stepsize, gamma=args.gamma)
115 | self.skel = InhouseStudioSkeleton()
116 | self.mse_criterion = torch.nn.MSELoss()
117 | self.l1_criterion = torch.nn.L1Loss()
118 |
119 | print(args.model, 'Model Created')
120 | if args.load:
121 | print('Loading Model', args.model)
122 | self.book._load_model(self.model_pose, 'model_pose')
123 | print('Loading the data')
124 | if is_train:
125 | self.load_data(args)
126 | else:
127 | self.load_data_testtime(args)
128 |
129 |
130 | def load_data_testtime(self, args):
131 | self.ds_data = LindyHopDataset(args, window_size=self.frames, split=self.testtime_split)
132 | self.load_ds_data = DataLoader(self.ds_data, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
133 |
134 |
135 | def load_data(self, args):
136 |
137 | ds_train = LindyHopDataset(args, window_size=self.frames, split='train')
138 | self.ds_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
139 | print('Train set loaded. Size=', len(self.ds_train.dataset))
140 | ds_val = LindyHopDataset(args, window_size=self.frames, split='test')
141 | self.ds_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
142 | print('Validation set loaded. Size=', len(self.ds_val.dataset))
143 |
144 | def calc_kldiv(self, dist_m):
145 | mu_ref = torch.zeros_like(dist_m.loc)
146 | scale_ref = torch.ones_like(dist_m.scale)
147 | dist_ref = torch.distributions.Normal(mu_ref, scale_ref)
148 | return torch.distributions.kl_divergence(dist_m, dist_ref)
149 |
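    |     # Loss terms, weighted by self.lambda_loss: position MSE, velocity and acceleration MSE on
    |     # frame differences, bone-length consistency against the skeleton parents, a foot-skate
    |     # penalty enabled after epoch 100, and two interaction terms that tie the generated
    |     # person-2 pose to the conditioning person-1 pose via the hand-contact map.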
150 | def calc_loss(self, num_epoch):
151 | bs, seq, J, dim = self.generated.shape
152 | pos_loss = self.lambda_loss['pos'] * self.mse_criterion(self.gt_pose2, self.generated)
153 | vel_gt = self.gt_pose2[:, 1:] - self.gt_pose2[:, :-1]
154 | vel_gen = self.generated[:, 1:] - self.generated[:, :-1]
155 | velocity_loss = self.lambda_loss['vel'] * self.mse_criterion(vel_gt, vel_gen)
156 | acc_gt = vel_gt[:, 1:] - vel_gt[:, :-1]
157 | acc_gen = vel_gen[:, 1:] - vel_gen[:, :-1]
158 | acc_loss = self.lambda_loss['vel'] * self.mse_criterion(acc_gt, acc_gen)
159 | bone_len_gt = (self.gt_pose2[:, :, 1:] - self.gt_pose2[:, :, [self.skel.parents_body_only[x] for x in range(1, J)]]).norm(dim=-1)
160 | bone_len_gen = (self.generated[:, :, 1:] - self.generated[:, :, [self.skel.parents_body_only[x] for x in range(1, J)]]).norm(dim=-1)
161 | bone_len_consistency_loss = self.lambda_loss['bone'] * self.mse_criterion(bone_len_gt, bone_len_gen)
162 | if num_epoch > 100:
163 | self.lambda_loss['foot'] = 20.0
164 | else:
165 | self.lambda_loss['foot'] = 0.0
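    |         # Foot-skate penalty: foot joints whose vertical (index-1) coordinate is at most 0.02 in
    |         # the normalized space are treated as planted; only their per-frame velocities are kept
    |         # (all other frames are zeroed) and driven towards zero. Assumption: indices 4, 5 and
    |         # 9, 10 are the right and left foot joints of the reduced skeleton.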
166 | rightfoot_idx = [4, 5]
167 | leftfoot_idx = [9, 10]
168 | gen_leftfoot_joint = self.generated[:, :, leftfoot_idx]
169 | static_left_foot_index = gen_leftfoot_joint[..., 1] <= 0.02
170 | gen_rightfoot_joint = self.generated[:, :, rightfoot_idx]
171 | static_right_foot_index = gen_rightfoot_joint[..., 1] <= 0.02
172 | gen_leftfoot_vel = torch.zeros_like(gen_leftfoot_joint)
173 | gen_leftfoot_vel[:, :-1] = gen_leftfoot_joint[:, 1:] - gen_leftfoot_joint[:, :-1]
174 | gen_leftfoot_vel[~static_left_foot_index] = 0
175 | gen_rightfoot_vel = torch.zeros_like(gen_rightfoot_joint)
176 | gen_rightfoot_vel[:, :-1] = gen_rightfoot_joint[:, 1:] - gen_rightfoot_joint[:, :-1]
177 | gen_rightfoot_vel[~static_right_foot_index] = 0
178 | footskate_loss = self.lambda_loss['foot'] * (self.mse_criterion(gen_leftfoot_vel, torch.zeros_like(gen_leftfoot_vel)) +
179 | self.mse_criterion(gen_rightfoot_vel, torch.zeros_like(gen_rightfoot_vel)) )
180 |
181 | loss_logs = [pos_loss, velocity_loss, bone_len_consistency_loss,
182 | footskate_loss, acc_loss]
183 |
184 | #include the interaction loss
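    |         # The contact map holds one binary channel per hand pairing, person 1's hand first:
    |         # channel 0 = right-right, 1 = right-left, 2 = left-right, 3 = left-left. Wherever a
    |         # contact is flagged, the distance from person 1's arm joints to the generated person-2
    |         # arm joints is pushed towards the corresponding ground-truth distance.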
185 | self.lambda_loss['in'] = 50.0
186 | rh_rh = self.contact_map[:,:, 0] == 1
187 | rh_lh = self.contact_map[:,:, 1] == 1
188 | lh_rh = self.contact_map[:,:, 2] == 1
189 | lh_lh = self.contact_map[:,:, 3] == 1
190 |
191 |         def contact_dist_gap(p1_side, p2_side):
192 |             d_gt = (self.pose1[:, :, p1_side] - self.gt_pose2[:, :, p2_side]).norm(dim=-1)
193 |             d_gen = (self.pose1[:, :, p1_side] - self.generated[:, :, p2_side]).norm(dim=-1)
194 |             return (d_gt - d_gen).norm(dim=-1)
195 |         arm_interact_loss = self.lambda_loss['in'] * torch.mean(
196 |             rh_rh * contact_dist_gap(right_side, right_side) + rh_lh * contact_dist_gap(right_side, left_side)
197 |             + lh_rh * contact_dist_gap(left_side, right_side) + lh_lh * contact_dist_gap(left_side, left_side))
200 |
201 | loss_logs.append(arm_interact_loss)
202 | interact_loss = self.mse_criterion((self.pose1 - self.gt_pose2), (self.pose1 - self.generated))
203 | loss_logs.append(interact_loss)
204 | return loss_logs
205 |
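    |     # One diffusion training step: a timestep t is drawn from the schedule sampler (unless given),
    |     # person 2's motion is noised to that timestep, and the denoiser is conditioned on person 1's
    |     # motion via model_kwargs; training_losses returns the noised input ('x_noisy') and the model's
    |     # x_0 prediction ('pred'), which calc_loss compares against the ground truth.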
206 | def forward(self, motions1, motions2, t=None):
207 | B, T = motions2.shape[:2]
208 |         if t is None:
209 | t, _ = self.sampler.sample(B, motions1.device)
210 | self.diffusion_timestep = t
211 | output = self.diffusion.training_losses(
212 | model=self.model_pose,
213 | x_start=motions2,
214 | t=t,
215 | model_kwargs={"motion1": motions1}
216 | )
217 |
218 | self.pose1 = motions1
219 | self.gt_pose2 = motions2 #gt pose 2
220 | self.generated = output['pred'] # synthesized pose 2
221 | return t, output['x_noisy']
222 |
223 |
224 |
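    |     # Both dancers are expressed relative to the shared global root origin and divided by self.scale;
    |     # root_relative_unnormalization below inverts this exactly. A minimal round-trip sketch with
    |     # hypothetical shapes (not part of the pipeline):
    |     #   pose = torch.rand(2, 40, 27, 3); root = torch.rand(2, 40, 3); scale = 100.0
    |     #   rel = (pose - root.unsqueeze(-2)) / scale
    |     #   assert torch.allclose(rel * scale + root.unsqueeze(-2), pose)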
225 | def root_relative_normalization(self, global_pose1, global_pose2):
226 |
227 | global_pose1 = self.skel.select_bvh_joints(global_pose1, original_joint_order=self.skel.bvh_joint_order,
228 | new_joint_order=self.skel.body_only)
229 | pose1_root_rel = global_pose1 - torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2)
230 | self.pose1_root_rel = pose1_root_rel / self.scale
231 | global_pose2 = self.skel.select_bvh_joints(global_pose2, original_joint_order=self.skel.bvh_joint_order,
232 | new_joint_order=self.skel.body_only)
233 |
234 | pose2_root_rel = global_pose2 - torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2)
235 | self.pose2_root_rel = pose2_root_rel / self.scale
237 |
238 | def root_relative_unnormalization(self, pose1_normalized, pose2_normalized):
239 | pose1_unnormalized = pose1_normalized * self.scale
240 | pose2_unnormalized = pose2_normalized * self.scale
241 | global_pose1 = pose1_unnormalized + torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2)
242 | global_pose2 = pose2_unnormalized + torch.repeat_interleave(self.global_root_origin.unsqueeze(-2), self.num_jts, axis=-2)
243 | return global_pose1, global_pose2
244 |
245 | def train(self, num_epoch, ablation=None):
246 | total_train_loss = 0.0
247 | total_pos_loss = 0.0
248 | total_vel_loss = 0.0
249 | total_bone_loss = 0.0
250 | total_footskate_loss = 0.0
251 | self.model_pose.train()
252 | training_tqdm = tqdm(self.ds_train, desc='train' + ' {:.10f}'.format(0), leave=False, ncols=120)
253 | diff_count = [0, 5, 10, 50, 100, 200, 300, 400, 499]
254 | for count, batch in enumerate(training_tqdm):
255 | self.optimizer_model_pose.zero_grad()
256 |
257 | with torch.autograd.detect_anomaly():
258 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
259 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
260 | self.contact_map = batch['contacts'].to(self.device).float()
261 |                 self.global_root_origin = batch['global_root_origin'].to(self.device).float()
262 | if global_pose1.shape[1] == 0:
263 | continue
264 | self.root_relative_normalization(global_pose1, global_pose2)
265 | t, noisy = self.forward(self.pose1_root_rel, self.pose2_root_rel)
266 |
267 | loss_logs = self.calc_loss(num_epoch)
268 | loss_model = sum(loss_logs)
269 | total_train_loss += loss_model.item()
270 | total_pos_loss += loss_logs[0].item()
271 | total_vel_loss += loss_logs[1].item()
272 | total_bone_loss += loss_logs[2].item()
273 | total_footskate_loss += loss_logs[3].item()
274 |
275 |                 if torch.isinf(loss_model) or torch.isnan(loss_model):
276 |                     print('Train loss is NaN or Inf; aborting training')
277 |                     exit()
278 | loss_model.backward()
279 | torch.nn.utils.clip_grad_value_(self.model_pose.parameters(), 0.01)
280 | self.optimizer_model_pose.step()
281 |
282 |
283 | avg_train_loss = total_train_loss/(count + 1)
284 | avg_pos_loss = total_pos_loss/(count + 1)
285 | avg_vel_loss = total_vel_loss/(count + 1)
286 | avg_bone_loss = total_bone_loss/(count + 1)
287 | avg_footskate_loss = total_footskate_loss/(count + 1)
288 |
289 | return avg_train_loss, (avg_pos_loss, avg_vel_loss, avg_bone_loss, avg_footskate_loss)
290 |
291 | def evaluate(self, num_epoch, ablation=None):
292 | total_eval_loss = 0.0
293 | total_pos_loss = 0.0
294 | total_vel_loss = 0.0
295 | total_bone_loss = 0.0
296 | total_footskate_loss = 0.0
297 | self.model_pose.eval()
298 | T = self.frames
299 | eval_tqdm = tqdm(self.ds_val, desc='eval' + ' {:.10f}'.format(0), leave=False, ncols=120)
300 |
301 | for count, batch in enumerate(eval_tqdm):
302 |             with torch.no_grad():
303 | global_pose1 = batch['pose_canon_1'].to(self.device).float()
304 | global_pose2 = batch['pose_canon_2'].to(self.device).float()
305 | self.contact_map = batch['contacts'].to(self.device).float()
306 |
307 |                 self.global_root_origin = batch['global_root_origin'].to(self.device).float()
308 | if global_pose1.shape[1] == 0:
309 | continue
310 | self.root_relative_normalization(global_pose1, global_pose2)
311 | t, noisy = self.forward(self.pose1_root_rel, self.pose2_root_rel)
312 | loss_logs = self.calc_loss(num_epoch)
313 | loss_model = sum(loss_logs)
314 | total_eval_loss += loss_model.item()
315 | total_pos_loss += loss_logs[0].item()
316 | total_vel_loss += loss_logs[1].item()
317 | total_bone_loss += loss_logs[2].item()
318 | total_footskate_loss += loss_logs[3].item()
319 |
320 | avg_eval_loss = total_eval_loss/(count + 1)
321 | avg_pos_loss = total_pos_loss/(count + 1)
322 | avg_vel_loss = total_vel_loss/(count + 1)
323 | avg_bone_loss = total_bone_loss/(count + 1)
324 | avg_footskate_loss = total_footskate_loss/(count + 1)
325 |
326 | return avg_eval_loss, (avg_pos_loss, avg_vel_loss, avg_bone_loss, avg_footskate_loss)
327 |
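    |     # Training driver: besides the total train/val losses, the per-term averages (pos, vel, bone,
    |     # footskate) returned by train()/evaluate() are collected in the lists below; evaluation, and
    |     # therefore best-checkpoint selection, runs every 5th epoch.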
328 | def fit(self, n_epochs=None, ablation=False):
329 | print('*****Inside Trainer.fit *****')
330 | if n_epochs is None:
331 | n_epochs = self.args.num_epochs
332 | starttime = datetime.now().replace(microsecond=0)
333 | print('Started Training at', datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S'), 'Total epochs: ', n_epochs)
334 | save_model_dict = {}
335 | best_eval = 1000
336 |
337 | train_pos_loss = []
338 | train_vel_loss = []
339 | train_bone_loss = []
340 | train_footskate_loss = []
341 | eval_pos_loss = []
342 | eval_vel_loss = []
343 | eval_bone_loss = []
344 | eval_footskate_loss = []
345 | for epoch_num in range(self.epochs_completed, n_epochs + 1):
346 | tqdm.write('--- starting Epoch # %03d' % epoch_num)
347 | train_loss, (train_pos_loss_, train_vel_loss_, train_bone_loss_,
348 | train_footskate_loss_) = self.train(epoch_num, ablation)
349 | train_pos_loss.append(train_pos_loss_)
350 | train_vel_loss.append(train_vel_loss_)
351 | train_bone_loss.append(train_bone_loss_)
352 | train_footskate_loss.append(train_footskate_loss_)
353 | if epoch_num % 5 == 0:
354 | eval_loss, (eval_pos_loss_, eval_vel_loss_, eval_bone_loss_,
355 | eval_footskate_loss_) = self.evaluate(epoch_num, ablation)
356 | eval_pos_loss.append(eval_pos_loss_)
357 | eval_vel_loss.append(eval_vel_loss_)
358 | eval_bone_loss.append(eval_bone_loss_)
359 | eval_footskate_loss.append(eval_footskate_loss_)
360 | else:
361 | eval_loss = 0.0
362 | self.scheduler_pose.step()
363 | self.book.update_res({'epoch': epoch_num, 'train': train_loss, 'val': eval_loss, 'test': 0.0})
364 | self.book._save_res()
365 | self.book.print_res(epoch_num, key_order=['train', 'val', 'test'], lr=self.optimizer_model_pose.param_groups[0]['lr'])
366 |
367 |             if epoch_num > 100 and epoch_num % 5 == 0 and eval_loss < best_eval:
368 | print('Best eval at epoch {}'.format(epoch_num))
369 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + 'best.p'), 'wb')
370 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
371 | torch.save(save_model_dict, f)
372 | f.close()
373 | best_eval = eval_loss
374 | if epoch_num > 20 and epoch_num % 20 == 0 :
375 | f = open(os.path.join(self.args.save_dir, self.book.name.name, self.book.name.name + '{:06d}'.format(epoch_num) + '.p'), 'wb')
376 | save_model_dict.update({'model_pose': self.model_pose.state_dict()})
377 | torch.save(save_model_dict, f)
378 | f.close()
379 | endtime = datetime.now().replace(microsecond=0)
380 | print('Finished Training at %s\n' % (datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S')))
381 | print('Training complete in %s!\n' % (endtime - starttime))
382 |
383 |
384 |
385 | if __name__ == '__main__':
386 | args = argparseNloop()
387 | args.lambda_loss = {
388 | 'fk': 1.0,
389 | 'fk_vel': 1.0,
390 | 'rot': 1.0,
391 | 'rot_vel': 1.0,
392 | 'kldiv': 1.0,
393 | 'pos': 1e+3,
394 | 'vel': 1e+1,
395 | 'bone': 1.0,
396 | 'foot': 0.0
397 | }
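    |     # Of these weights, calc_loss in this trainer only reads 'pos', 'vel', 'bone' and 'foot'
    |     # (plus 'in', which it sets internally); the remaining keys are unused here.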
398 | is_train = True
399 | ablation = None # if True then ablation: no_IAC_loss
400 | model_trainer = Trainer(args=args, is_train=is_train, split='test', JT_POSITION=True, num_jts=27)
401 | print("** Method Initialization Complete **")
402 | model_trainer.fit(ablation=ablation)
403 |
404 |
--------------------------------------------------------------------------------