├── requirements.txt ├── LICENSE ├── CDCD_aist.txt ├── post_process.py ├── utils ├── preprocess.py ├── quaternion.py ├── scaler.py ├── kinetic.py ├── utils.py ├── manual.py └── vis.py ├── genre.py ├── README.md ├── bas_cdcd.py ├── .gitignore ├── audio_to_images.py ├── eval_cdcd.py ├── genremodel └── models.py ├── norm_motion.py └── train_unet_latent.py /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | datasets 3 | accelerate==0.20.3 4 | diffusers==0.18.1 5 | audiodiffusion 6 | torchlibrosa 7 | frechet_audio_distance 8 | pydub -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vanessa Tan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /CDCD_aist.txt: -------------------------------------------------------------------------------- 1 | gPO_sBM_cAll_d11_mPO1_ch02_slice1.wav 2 | gLH_sBM_cAll_d17_mLH4_ch02_slice2.wav 3 | gLO_sBM_cAll_d15_mLO2_ch02_slice1.wav 4 | gMH_sBM_cAll_d24_mMH3_ch02_slice1.wav 5 | gBR_sBM_cAll_d05_mBR0_ch02_slice1.wav 6 | gLH_sBM_cAll_d17_mLH4_ch02_slice1.wav 7 | gPO_sBM_cAll_d10_mPO1_ch02_slice1.wav 8 | gMH_sBM_cAll_d24_mMH3_ch02_slice2.wav 9 | gBR_sBM_cAll_d04_mBR0_ch02_slice1.wav 10 | gKR_sBM_cAll_d28_mKR2_ch02_slice1.wav 11 | gKR_sBM_cAll_d30_mKR2_ch02_slice1.wav 12 | gLO_sBM_cAll_d15_mLO2_ch02_slice2.wav 13 | gLO_sBM_cAll_d13_mLO2_ch02_slice1.wav 14 | gJS_sBM_cAll_d03_mJS3_ch02_slice2.wav 15 | gHO_sBM_cAll_d21_mHO5_ch02_slice1.wav 16 | gLO_sBM_cAll_d13_mLO2_ch02_slice2.wav 17 | gKR_sBM_cAll_d28_mKR2_ch02_slice2.wav 18 | gBR_sBM_cAll_d04_mBR0_ch02_slice2.wav 19 | gJB_sBM_cAll_d08_mJB5_ch02_slice1.wav 20 | gKR_sBM_cAll_d30_mKR2_ch02_slice2.wav 21 | gLH_sBM_cAll_d18_mLH4_ch02_slice1.wav 22 | gJS_sBM_cAll_d01_mJS3_ch02_slice1.wav 23 | gBR_sBM_cAll_d05_mBR0_ch02_slice2.wav 24 | gMH_sBM_cAll_d22_mMH3_ch02_slice2.wav 25 | gMH_sBM_cAll_d22_mMH3_ch02_slice1.wav 26 | gJB_sBM_cAll_d09_mJB5_ch02_slice1.wav 27 | gPO_sBM_cAll_d10_mPO1_ch02_slice2.wav 28 | gLH_sBM_cAll_d18_mLH4_ch02_slice2.wav 29 | gJS_sBM_cAll_d03_mJS3_ch02_slice1.wav 30 | gHO_sBM_cAll_d20_mHO5_ch02_slice1.wav 31 | gJS_sBM_cAll_d01_mJS3_ch02_slice2.wav 32 | -------------------------------------------------------------------------------- /post_process.py: -------------------------------------------------------------------------------- 1 | # Normalize loudness of generated music to match ground truth 2 | import argparse 3 | import glob 4 | import os 5 | from tqdm.auto import tqdm 6 | from pydub import AudioSegment, effects 7 | 8 | 9 | def match_target_amplitude(sound, target_dBFS): 10 | change_in_dBFS = target_dBFS - sound.dBFS 11 | return sound.apply_gain(change_in_dBFS) 12 | 13 | 14 | def main(args): 15 | audio_files = sorted(glob.glob(os.path.join(args.input_dir, "*.wav"))) 16 | for audio_file in tqdm(audio_files): 17 | audio_file = (os.path.basename(audio_file).split('/')[-1][:-4]) 18 | sound = AudioSegment.from_file('{}{}.wav'.format(args.input_dir, audio_file), format="wav", frame_rate=22050) 19 | five_seconds = 5 * 1000 20 | first_5_seconds = sound[:five_seconds] 21 | normalized_sound = match_target_amplitude(first_5_seconds, -20.0) 22 | normalized_sound = normalized_sound + 8 23 | normalized_sound.export('{}{}.wav'.format(args.output_dir, audio_file), format="wav") 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser(description="Normalize Volume") 28 | parser.add_argument("--input_dir", type=str, default="wav/") 29 | parser.add_argument("--output_dir", type=str, default="wav_norm/") 30 | args = parser.parse_args() 31 | 32 | main(args) -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/Stanford-TML/EDGE 2 | # Normalization Code 3 | 4 | import glob 5 | import os 6 | import re 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from scaler import MinMaxScaler 12 | 13 | 14 | def increment_path(path, exist_ok=False, sep="", mkdir=False): 15 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 
16 | path = Path(path) # os-agnostic 17 | if path.exists() and not exist_ok: 18 | suffix = path.suffix 19 | path = path.with_suffix("") 20 | dirs = glob.glob(f"{path}{sep}*") # similar paths 21 | matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] 22 | i = [int(m.groups()[0]) for m in matches if m] # indices 23 | n = max(i) + 1 if i else 2 # increment number 24 | path = Path(f"{path}{sep}{n}{suffix}") # update path 25 | dir = path if path.suffix == "" else path.parent # directory 26 | if not dir.exists() and mkdir: 27 | dir.mkdir(parents=True, exist_ok=True) # make directory 28 | return path 29 | 30 | 31 | class Normalizer: 32 | def __init__(self, data): 33 | flat = data.reshape(-1, data.shape[-1]) 34 | self.scaler = MinMaxScaler((0, 1), clip=True) 35 | self.scaler.fit(flat) 36 | 37 | def normalize(self, x): 38 | batch, seq, ch = x.shape 39 | x = x.reshape(-1, ch) 40 | return self.scaler.transform(x).reshape((batch, seq, ch)) 41 | 42 | def unnormalize(self, x): 43 | batch, seq, ch = x.shape 44 | x = x.reshape(-1, ch) 45 | x = torch.clip(x, 0, 1) # clip to force compatibility 46 | return self.scaler.inverse_transform(x).reshape((batch, seq, ch)) 47 | 48 | 49 | def vectorize_many(data): 50 | # given a list of batch x seqlen x joints? x channels, flatten all to batch x seqlen x -1, concatenate 51 | batch_size = data[0].shape[0] 52 | seq_len = data[0].shape[1] 53 | 54 | out = [x.reshape(batch_size, seq_len, -1).contiguous() for x in data] 55 | 56 | global_pose_vec_gt = torch.cat(out, dim=2) 57 | return global_pose_vec_gt 58 | -------------------------------------------------------------------------------- /utils/quaternion.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/Stanford-TML/EDGE 2 | # Quaternion utilities for motion extraction 3 | 4 | import torch 5 | from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle, 6 | matrix_to_quaternion, matrix_to_rotation_6d, 7 | quaternion_to_matrix, rotation_6d_to_matrix) 8 | 9 | 10 | def quat_to_6v(q): 11 | assert q.shape[-1] == 4 12 | mat = quaternion_to_matrix(q) 13 | mat = matrix_to_rotation_6d(mat) 14 | return mat 15 | 16 | 17 | def quat_from_6v(q): 18 | assert q.shape[-1] == 6 19 | mat = rotation_6d_to_matrix(q) 20 | quat = matrix_to_quaternion(mat) 21 | return quat 22 | 23 | 24 | def ax_to_6v(q): 25 | assert q.shape[-1] == 3 26 | mat = axis_angle_to_matrix(q) 27 | mat = matrix_to_rotation_6d(mat) 28 | return mat 29 | 30 | 31 | def ax_from_6v(q): 32 | assert q.shape[-1] == 6 33 | mat = rotation_6d_to_matrix(q) 34 | ax = matrix_to_axis_angle(mat) 35 | return ax 36 | 37 | 38 | def quat_slerp(x, y, a): 39 | """ 40 | Performs spherical linear interpolation (SLERP) between x and y, with proportion a 41 | 42 | :param x: quaternion tensor (N, S, J, 4) 43 | :param y: quaternion tensor (N, S, J, 4) 44 | :param a: interpolation weight (S, ) 45 | :return: tensor of interpolation results 46 | """ 47 | len = torch.sum(x * y, axis=-1) 48 | 49 | neg = len < 0.0 50 | len[neg] = -len[neg] 51 | y[neg] = -y[neg] 52 | 53 | a = torch.zeros_like(x[..., 0]) + a 54 | 55 | amount0 = torch.zeros_like(a) 56 | amount1 = torch.zeros_like(a) 57 | 58 | linear = (1.0 - len) < 0.01 59 | omegas = torch.arccos(len[~linear]) 60 | sinoms = torch.sin(omegas) 61 | 62 | amount0[linear] = 1.0 - a[linear] 63 | amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms 64 | 65 | amount1[linear] = a[linear] 66 | amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms 67 
| 68 | # reshape 69 | amount0 = amount0[..., None] 70 | amount1 = amount1[..., None] 71 | 72 | res = amount0 * x + amount1 * y 73 | 74 | return res 75 | -------------------------------------------------------------------------------- /genre.py: -------------------------------------------------------------------------------- 1 | # Evaluate Genre KLD on test set of CDCD list 2 | # Model used: https://github.com/PeiChunChang/MS-SincResNet 3 | import argparse 4 | import scipy.io.wavfile 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import numpy.ma as ma 8 | import librosa 9 | import librosa.display 10 | import torch 11 | 12 | from genremodel.models import * 13 | from scipy import signal 14 | from tqdm.auto import tqdm 15 | from scipy.special import kl_div 16 | 17 | 18 | def NormalizeData(data): 19 | return (data - np.min(data)) / (np.max(data) - np.min(data)) 20 | 21 | def main(args): 22 | 23 | # Load MS-SincResNet model 24 | MODEL_PATH = 'MS-SincResNet.tar' 25 | state_dict = torch.load(MODEL_PATH) 26 | model = MS_SincResNet() 27 | model.load_state_dict(state_dict['state_dict']) 28 | model.cuda() 29 | model.eval() 30 | 31 | audio_files = [line.rstrip() for line in open('CDCD_aist.txt')] 32 | kl_array = [] 33 | for audio_file in tqdm(audio_files): 34 | _, data = scipy.io.wavfile.read('{}{}'.format(args.input_dir, audio_file)) 35 | data = signal.resample(data, 16000 * 30) 36 | data = data[24000:72000] 37 | 38 | data = torch.from_numpy(data).float() 39 | data.unsqueeze_(dim=0) 40 | data.unsqueeze_(dim=0) 41 | data = data.cuda() 42 | gt, _, _, _ = model(data) 43 | gt = gt.detach().cpu().numpy().tolist()[0] 44 | gt = NormalizeData(gt) 45 | 46 | _, data = scipy.io.wavfile.read('{}{}'.format(args.output_dir, audio_file)) 47 | data = signal.resample(data, 16000 * 30) 48 | data = data[24000:72000] 49 | 50 | data = torch.from_numpy(data).float() 51 | data.unsqueeze_(dim=0) 52 | data.unsqueeze_(dim=0) 53 | data = data.cuda() 54 | gen, _, _, _ = model(data) 55 | gen = gen.detach().cpu().numpy().tolist()[0] 56 | gen = NormalizeData(gen) 57 | 58 | output = sum(kl_div(np.array(gt), np.array(gen))) 59 | 60 | kl_array.append(output) 61 | print(np.mean(ma.masked_invalid(kl_array))) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser(description="Genre Classifier") 66 | parser.add_argument("--input_dir", type=str, default=r"/host_data/van/edge_aistpp/test/wavs_sliced/") 67 | parser.add_argument("--output_dir", type=str, default=r"/host_data/van/edge_aistpp/outputv2/all_01/normalized/") 68 | args = parser.parse_args() 69 | 70 | main(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Motion to Dance Music Generation 2 | [**"Motion to Dance Music Generation using Latent Diffusion Model"**](https://dmdproject.github.io/) - Official PyTorch Implementation 3 | 4 | ![teaser](https://github.com/DMDproject/DMDproject.github.io/blob/main/static/images/main_figure_dmd.jpg) 5 | 6 | 7 | ## Installation 8 | 9 | This code was tested on `Ubuntu 20.04.2 LTS` and requires: 10 | 11 | * Python 3.8 12 | * CUDA capable GPU 13 | * Download Pre-processed [data and models](https://drive.google.com/file/d/1FRZY-RWiSno_yo7MYYzSri5DWEWSGukG/view?usp=sharing) 14 | 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | ## Dataset 19 | The dataset used was the [AIST++ dataset](https://google.github.io/aistplusplus_dataset/download.html). 
The segmented music data is also provided [here](https://drive.google.com/file/d/1rtEYKFMMC8y5EFkiCAC0GEP6AGOWDbeM/view?usp=drive_link). 20 | 21 | ## Preprocess data 22 | ### Generate mel spectrograms 23 | ```bash 24 | python audio_to_images.py 25 | ``` 26 | 27 | ### Generate concatenated motion and genre features 28 | ```bash 29 | python norm_motion.py 30 | ``` 31 | 32 | ## Training and inference 33 | ### Train latent diffusion model using pre-trained VAE 34 | ```bash 35 | python train_unet_latent.py 36 | ``` 37 | 38 | ### Generate samples then normalize loudness 39 | ```bash 40 | python eval_cdcd.py --gen_audio=True 41 | python post_process.py 42 | ``` 43 | 44 | ## Evaluation 45 | ```bash 46 | python eval_cdcd.py # beat coverage score, beat hit score, and FAD 47 | python bas_cdcd.py # beat align score 48 | python genre.py # genre KLD (get pretrained model from https://github.com/PeiChunChang/MS-SincResNet) 49 | ``` 50 | 51 | ## Attribution 52 | Please include the following citations in any preprints and publications that use this repository. 53 | ``` 54 | @inproceedings{10.1145/3610543.3626164, 55 | author = {Tan, Vanessa and Nam, Junghyun and Nam, Juhan and Noh, Junyong}, 56 | title = {Motion to Dance Music Generation Using Latent Diffusion Model}, 57 | year = {2023}, 58 | isbn = {9798400703140}, 59 | publisher = {Association for Computing Machinery}, 60 | address = {New York, NY, USA}, 61 | url = {https://doi.org/10.1145/3610543.3626164}, 62 | doi = {10.1145/3610543.3626164}, 63 | booktitle = {SIGGRAPH Asia 2023 Technical Communications}, 64 | articleno = {5}, 65 | numpages = {4}, 66 | keywords = {latent diffusion model, 3D motion to music, music generation}, 67 | location = {, Sydney, NSW, Australia, }, 68 | series = {SA Technical Communications '23} 69 | } 70 | ``` 71 | 72 | ## Acknowledgments 73 | 74 | We would like to thank [Joel Casimiro](https://sites.google.com/eee.upd.edu.ph/joelcasimiro) for helping in creating our preview image. 75 | We would also like to thank the following contributors that our code is based on: [Audio-Diffusion](https://github.com/teticio/audio-diffusion/), [EDGE](https://github.com/Stanford-TML/EDGE), [Bailando](https://github.com/lisiyao21/Bailando), [AIST++](https://github.com/google-research/mint), [MS-SincResNet](https://github.com/PeiChunChang/MS-SincResNet). 76 | -------------------------------------------------------------------------------- /utils/scaler.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/Stanford-TML/EDGE 2 | # Scaler Code 3 | 4 | import torch 5 | 6 | 7 | def _handle_zeros_in_scale(scale, copy=True, constant_mask=None): 8 | # if we are fitting on 1D arrays, scale might be a scalar 9 | if constant_mask is None: 10 | # Detect near constant values to avoid dividing by a very small 11 | # value that could lead to surprising results and numerical 12 | # stability issues. 
13 | constant_mask = scale < 10 * torch.finfo(scale.dtype).eps 14 | 15 | if copy: 16 | # New array to avoid side-effects 17 | scale = scale.clone() 18 | scale[constant_mask] = 1.0 19 | return scale 20 | 21 | 22 | class MinMaxScaler: 23 | _parameter_constraints: dict = { 24 | "feature_range": [tuple], 25 | "copy": ["boolean"], 26 | "clip": ["boolean"], 27 | } 28 | 29 | def __init__(self, feature_range=(0, 1), *, copy=True, clip=False): 30 | self.feature_range = feature_range 31 | self.copy = copy 32 | self.clip = clip 33 | 34 | def _reset(self): 35 | """Reset internal data-dependent state of the scaler, if necessary. 36 | __init__ parameters are not touched. 37 | """ 38 | # Checking one attribute is enough, because they are all set together 39 | # in partial_fit 40 | if hasattr(self, "scale_"): 41 | del self.scale_ 42 | del self.min_ 43 | del self.n_samples_seen_ 44 | del self.data_min_ 45 | del self.data_max_ 46 | del self.data_range_ 47 | 48 | def fit(self, X): 49 | # Reset internal state before fitting 50 | self._reset() 51 | return self.partial_fit(X) 52 | 53 | def partial_fit(self, X): 54 | feature_range = self.feature_range 55 | if feature_range[0] >= feature_range[1]: 56 | raise ValueError( 57 | "Minimum of desired feature range must be smaller than maximum. Got %s." 58 | % str(feature_range) 59 | ) 60 | 61 | data_min = torch.min(X, axis=0)[0] 62 | data_max = torch.max(X, axis=0)[0] 63 | 64 | self.n_samples_seen_ = X.shape[0] 65 | 66 | data_range = data_max - data_min 67 | self.scale_ = (feature_range[1] - feature_range[0]) / _handle_zeros_in_scale( 68 | data_range, copy=True 69 | ) 70 | self.min_ = feature_range[0] - data_min * self.scale_ 71 | self.data_min_ = data_min 72 | self.data_max_ = data_max 73 | self.data_range_ = data_range 74 | return self 75 | 76 | def transform(self, X): 77 | X *= self.scale_.to(X.device) 78 | X += self.min_.to(X.device) 79 | if self.clip: 80 | torch.clip(X, self.feature_range[0], self.feature_range[1], out=X) 81 | return X 82 | 83 | def inverse_transform(self, X): 84 | X -= self.min_[-X.shape[1] :].to(X.device) 85 | X /= self.scale_[-X.shape[1] :].to(X.device) 86 | return X 87 | -------------------------------------------------------------------------------- /bas_cdcd.py: -------------------------------------------------------------------------------- 1 | # References: https://github.com/lisiyao21/Bailando 2 | # https://github.com/google-research/mint 3 | # https://github.com/Stanford-TML/EDGE 4 | # Compute the beat aligned score for the test set from CDCD list 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import pickle 9 | import json 10 | import os 11 | import glob 12 | import torch 13 | import librosa 14 | 15 | from scipy.ndimage import gaussian_filter as G 16 | from scipy.signal import argrelextrema 17 | from utils.vis import SMPLSkeleton 18 | from scipy import linalg 19 | from pytorch3d.transforms import RotateAxisAngle 20 | from pytorch3d.transforms.rotation_conversions import (axis_angle_to_quaternion, 21 | quaternion_multiply, 22 | quaternion_to_axis_angle, 23 | quaternion_invert) 24 | 25 | from eval_cdcd import beat_detect 26 | from kinetic import extract_kinetic_features 27 | from manual import extract_manual_features 28 | 29 | 30 | def cal_motion_beat(root_pos : np.ndarray, joint_orn : np.ndarray, fps :int=30): 31 | smpl = SMPLSkeleton() 32 | 33 | root_pos = torch.Tensor(root_pos) 34 | root_pos = root_pos.reshape((1, 300, 3)) 35 | joint_orn = torch.Tensor(joint_orn) 36 | joint_orn = joint_orn.reshape((1, 300, -1, 3)) 37 
| 38 | # AISTPP dataset comes y-up - rotate to z-up to standardize against the pretrain dataset 39 | root_q = joint_orn[:, :, :1, :] # sequence x 1 x 3 40 | root_q_quat = axis_angle_to_quaternion(root_q) 41 | rotation = torch.Tensor([0.7071068, 0.7071068, 0, 0]) # 90 degrees about the x axis 42 | root_q_quat = quaternion_multiply(rotation, root_q_quat) 43 | root_q = quaternion_to_axis_angle(root_q_quat) 44 | joint_orn[:, :, :1, :] = root_q 45 | 46 | pos_rotation = RotateAxisAngle(90, axis="X", degrees=True) 47 | root_pos = pos_rotation.transform_points(root_pos) # basically (y, z) -> (-z, y), expressed as a rotation for readability 48 | 49 | # get joint pos 50 | joint_pos = smpl.forward(joint_orn, root_pos) # batch x sequence x 24 x 3 51 | 52 | joint_pos = np.array(joint_pos).reshape(-1, 24, 3) 53 | kinetic_vel = np.mean(np.sqrt(np.sum((joint_pos[1:] - joint_pos[:-1]) ** 2, axis=2)), axis=1) 54 | kinetic_vel = G(kinetic_vel, 5) 55 | motion_beat = argrelextrema(kinetic_vel, np.less) 56 | 57 | return motion_beat, len(kinetic_vel) 58 | 59 | 60 | def ba_score(music_beats, motion_beats): 61 | ba = 0 62 | for bb in music_beats: 63 | ba += np.exp(-np.min((motion_beats[0] - bb)**2) / 2 / 9) 64 | return (ba / len(music_beats)) 65 | 66 | 67 | def calc_ba_score(motion_dir, music_dir): 68 | 69 | ba_scores = [] 70 | 71 | audio_files = [line.rstrip() for line in open('CDCD_aist.txt')] 72 | for motion in audio_files: 73 | m_name = motion[:-4] 74 | data = pickle.load(open(motion_dir + m_name + '.pkl', "rb")) 75 | 76 | root_pos = data["pos"] 77 | joint_orn = data["q"] 78 | dance_beats, length = cal_motion_beat(root_pos, joint_orn) 79 | 80 | # Beat Extractor: Librosa 81 | music, sr = librosa.load('{}{}.wav'.format(music_dir, m_name)) 82 | onset_env = librosa.onset.onset_strength(y=music) 83 | tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr) 84 | 85 | ba_scores.append(ba_score(beats, dance_beats)) 86 | 87 | return np.mean(ba_scores) 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | motion_dir = r"/host_data/van/edge_aistpp/test/motions_sliced/" 93 | music_dir = r"/host_data/van/edge_aistpp/outputv2/all_01/normalized/" 94 | print(calc_ba_score(motion_dir, music_dir)) 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /audio_to_images.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/teticio/audio-diffusion/ 2 | # Audio waveform to mel spectrogram script 3 | 4 | import argparse 5 | import io 6 | import logging 7 | import os 8 | import re 9 | import glob 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from datasets import Dataset, DatasetDict, Features, Image, Value 14 | from diffusers.pipelines.audio_diffusion import Mel 15 | from tqdm.auto import tqdm 16 | 17 | logging.basicConfig(level=logging.WARN) 18 | logger = logging.getLogger("audio_to_images") 19 | 20 | 21 | def main(args): 22 | 23 | mel = Mel( 24 | x_res=args.resolution[0], 25 | y_res=args.resolution[1], 26 | hop_length=args.hop_length, 27 | sample_rate=args.sample_rate, 28 | n_fft=args.n_fft, 29 | ) 30 | os.makedirs(args.output_dir, exist_ok=True) 31 | 32 | audio_files = sorted(glob.glob(os.path.join(args.input_dir, "*.wav"))) 33 | examples = [] 34 | try: 35 | for audio_file in tqdm(audio_files): 36 | try: 37 | mel.load_audio(audio_file) 38 | except KeyboardInterrupt: 39 | raise 40 | except: 41 | continue 42 | for slice in range(mel.get_number_of_slices()): 43 | image = mel.audio_slice_to_image(slice) 44 | assert image.width == args.resolution[0] and image.height == args.resolution[1], "Wrong resolution" 45 | # skip completely silent slices 46 | if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255): 47 | logger.warn("File %s slice %d is completely silent", audio_file, slice) 48 | continue 49 | with io.BytesIO() as output: 50 | image.save(output, format="PNG") 51 | bytes = output.getvalue() 52 | examples.extend( 53 | [ 54 | { 55 | "image": {"bytes": bytes}, 56 | "audio_file": audio_file, 57 | "slice": slice, 58 | } 59 | ] 60 | ) 61 | except Exception as e: 62 | print(e) 63 | finally: 64 | if len(examples) == 0: 65 | logger.warn("No valid audio files were found.") 66 | return 67 | ds = Dataset.from_pandas( 68 | pd.DataFrame(examples), 69 | features=Features( 70 | { 71 | "image": Image(), 72 | "audio_file": Value(dtype="string"), 73 | "slice": Value(dtype="int16"), 74 | } 75 | ), 76 | ) 77 | dsd = DatasetDict({"train": ds}) 78 | dsd.save_to_disk(os.path.join(args.output_dir)) 79 | if args.push_to_hub: 80 | dsd.push_to_hub(args.push_to_hub) 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser(description="Create dataset of Mel spectrograms from directory of audio files.") 85 | parser.add_argument("--input_dir", type=str, default=r"/host_data/van/edge_aistpp/train/wavs_sliced/") 86 | parser.add_argument("--output_dir", type=str, default="aistpp_256_sorted") 87 | parser.add_argument( 88 | "--resolution", 89 | type=str, 90 | default="256", 91 | help="Either square resolution or width,height.", 92 | ) 93 | parser.add_argument("--hop_length", type=int, default=512) 94 | parser.add_argument("--push_to_hub", type=str, default=None) 95 | parser.add_argument("--sample_rate", type=int, default=22050) 96 | parser.add_argument("--n_fft", type=int, default=2048) 97 | args = parser.parse_args() 98 | 99 | if args.input_dir is None: 100 | raise ValueError("You must specify an input directory for the audio files.") 101 | 102 | # Handle the resolutions. 
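# A single integer yields a square spectrogram ("256" -> (256, 256)); a comma-separated
# value is parsed as width,height ("256,128" -> (256, 128)), matching the --resolution help text.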
103 | try: 104 | args.resolution = (int(args.resolution), int(args.resolution)) 105 | except ValueError: 106 | try: 107 | args.resolution = tuple(int(x) for x in args.resolution.split(",")) 108 | if len(args.resolution) != 2: 109 | raise ValueError 110 | except ValueError: 111 | raise ValueError("Resolution must be a tuple of two integers or a single integer.") 112 | assert isinstance(args.resolution, tuple) 113 | 114 | main(args) 115 | -------------------------------------------------------------------------------- /utils/kinetic.py: -------------------------------------------------------------------------------- 1 | # BSD License 2 | 3 | # For fairmotion software 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | # Modified by Ruilong Li 7 | 8 | # Redistribution and use in source and binary forms, with or without modification, 9 | # are permitted provided that the following conditions are met: 10 | 11 | # * Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | 14 | # * Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | 18 | # * Neither the name Facebook nor the names of its contributors may be used to 19 | # endorse or promote products derived from this software without specific 20 | # prior written permission. 21 | 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 23 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 24 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 26 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 27 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 29 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
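# Kinetic motion features: extract_kinetic_features stacks, per joint, the average horizontal
# and vertical kinetic energy and the average energy expenditure (imported by bas_cdcd.py).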
32 | import numpy as np 33 | import utils as feat_utils 34 | 35 | 36 | def extract_kinetic_features(positions): 37 | assert len(positions.shape) == 3 # (seq_len, n_joints, 3) 38 | features = KineticFeatures(positions) 39 | kinetic_feature_vector = [] 40 | for i in range(positions.shape[1]): 41 | feature_vector = np.hstack( 42 | [ 43 | features.average_kinetic_energy_horizontal(i), 44 | features.average_kinetic_energy_vertical(i), 45 | features.average_energy_expenditure(i), 46 | ] 47 | ) 48 | kinetic_feature_vector.extend(feature_vector) 49 | kinetic_feature_vector = np.array(kinetic_feature_vector, dtype=np.float32) 50 | return kinetic_feature_vector 51 | 52 | 53 | class KineticFeatures: 54 | def __init__( 55 | self, positions, frame_time=1./60, up_vec="y", sliding_window=2 56 | ): 57 | self.positions = positions 58 | self.frame_time = frame_time 59 | self.up_vec = up_vec 60 | self.sliding_window = sliding_window 61 | 62 | def average_kinetic_energy(self, joint): 63 | average_kinetic_energy = 0 64 | for i in range(1, len(self.positions)): 65 | average_velocity = feat_utils.calc_average_velocity( 66 | self.positions, i, joint, self.sliding_window, self.frame_time 67 | ) 68 | average_kinetic_energy += average_velocity ** 2 69 | average_kinetic_energy = average_kinetic_energy / ( 70 | len(self.positions) - 1.0 71 | ) 72 | return average_kinetic_energy 73 | 74 | def average_kinetic_energy_horizontal(self, joint): 75 | val = 0 76 | for i in range(1, len(self.positions)): 77 | average_velocity = feat_utils.calc_average_velocity_horizontal( 78 | self.positions, 79 | i, 80 | joint, 81 | self.sliding_window, 82 | self.frame_time, 83 | self.up_vec, 84 | ) 85 | val += average_velocity ** 2 86 | val = val / (len(self.positions) - 1.0) 87 | return val 88 | 89 | def average_kinetic_energy_vertical(self, joint): 90 | val = 0 91 | for i in range(1, len(self.positions)): 92 | average_velocity = feat_utils.calc_average_velocity_vertical( 93 | self.positions, 94 | i, 95 | joint, 96 | self.sliding_window, 97 | self.frame_time, 98 | self.up_vec, 99 | ) 100 | val += average_velocity ** 2 101 | val = val / (len(self.positions) - 1.0) 102 | return val 103 | 104 | def average_energy_expenditure(self, joint): 105 | val = 0.0 106 | for i in range(1, len(self.positions)): 107 | val += feat_utils.calc_average_acceleration( 108 | self.positions, i, joint, self.sliding_window, self.frame_time 109 | ) 110 | val = val / (len(self.positions) - 1.0) 111 | return val 112 | -------------------------------------------------------------------------------- /eval_cdcd.py: -------------------------------------------------------------------------------- 1 | # Evaluate test set from CDCD list 2 | # Evaluation metrics: beat coverage score, beat hit score, and FAD 3 | import torch 4 | import argparse 5 | import glob 6 | import os 7 | import random 8 | import librosa 9 | import pickle 10 | import time 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | from datasets import load_dataset 15 | from librosa.beat import beat_track 16 | from diffusers import DiffusionPipeline 17 | from scipy.io.wavfile import write 18 | from PIL import Image 19 | from tqdm.auto import tqdm 20 | from frechet_audio_distance import FrechetAudioDistance 21 | 22 | 23 | def beat_detect(x, sr=22050): 24 | onsets = librosa.onset.onset_detect(x, sr=sr, wait=1, delta=0.2, pre_avg=1, post_avg=1, post_max=1, units='time') 25 | n = np.ceil( len(x) / sr) 26 | beats = [0] * int(n) 27 | for time in onsets: 28 | beats[int(np.trunc(time))] = 1 29 | return 
beats 30 | 31 | 32 | def beat_scores(gt, syn): 33 | assert len(gt) == len(syn) 34 | total_beats = sum(gt) 35 | cover_beats = sum(syn) 36 | 37 | hit_beats = 0 38 | for i in range(len(gt)): 39 | if gt[i] == 1 and gt[i] == syn[i]: 40 | hit_beats += 1 41 | return cover_beats/total_beats, hit_beats/total_beats 42 | 43 | 44 | def main(args): 45 | # Generate Audio 46 | if args.gen_audio: 47 | device = "cuda" if torch.cuda.is_available() else "cpu" 48 | generator = torch.Generator(device=device) 49 | model_id = r"/host_data/van/edge_aistpp/modelsv2/all_01" 50 | encode_id = r"/host_data/van/edge_aistpp/test/concat/normalized_all_test_data_01.pkl" 51 | 52 | total_cover_score = 0 53 | total_hit_score = 0 54 | audio_files = [line.rstrip() for line in open('CDCD_aist.txt')] 55 | for audio_file in tqdm(audio_files): 56 | audio_file = audio_file[:-4] 57 | 58 | if args.gen_audio: 59 | #start = time.time() 60 | encodings = pickle.load(open(encode_id, "rb")) 61 | encoding = encodings[audio_file] 62 | print(np.array(encoding).shape) 63 | encoding = np.array(encoding).reshape(1, 150, 226) 64 | encoding = torch.Tensor(encoding).to(device) 65 | 66 | audio_diffusion = DiffusionPipeline.from_pretrained(model_id).to(device) 67 | mel = audio_diffusion.mel 68 | sample_rate = mel.get_sample_rate() 69 | 70 | seed = 2391504374279719 71 | generator.manual_seed(seed) 72 | output = audio_diffusion(generator=generator, eta=0, encoding=encoding) 73 | image = output.images[0] 74 | audio = output.audios[0, 0] 75 | 76 | # 64 x 64 can only output 2s so we outpaint 77 | if args.outpaint: 78 | overlap_secs = 0 79 | start_step = 0 80 | overlap_samples = overlap_secs * sample_rate 81 | track = audio 82 | for variation in range(3): 83 | output = audio_diffusion(raw_audio=audio[-overlap_samples:], start_step=start_step, mask_start_secs=overlap_secs, eta=0, encoding=encoding) 84 | audio2 = output.audios[0, 0] 85 | track = np.concatenate([track, audio2[overlap_samples:]]) 86 | audio = audio2 87 | write('{}{}.wav'.format(args.output_dir, audio_file), sample_rate, track) 88 | else: 89 | write('{}{}.wav'.format(args.output_dir, audio_file), sample_rate, audio) 90 | #end = time.time() 91 | #print(end - start) 92 | 93 | else: 94 | # Beat Evaluation (Librosa) 95 | music, sr = librosa.load('{}{}.wav'.format(args.input_dir, audio_file)) 96 | gt_beats = beat_detect(music) 97 | generated_audio, sr = librosa.load('{}{}.wav'.format(args.output_dir, audio_file)) 98 | syn_beats = beat_detect(generated_audio) 99 | 100 | score_cover, score_hit = beat_scores(gt_beats, syn_beats) 101 | total_cover_score += score_cover 102 | total_hit_score += score_hit 103 | 104 | if not args.gen_audio: 105 | print("Score Summary for cover and hit: ", total_cover_score/len(audio_files), total_hit_score/len(audio_files)) 106 | frechet = FrechetAudioDistance(model_name="vggish", use_pca=False, use_activation=False, verbose=False) 107 | fad_score = frechet.score(args.input_dir, args.output_dir) 108 | print("FAD: ", fad_score) 109 | 110 | 111 | if __name__ == '__main__': 112 | parser = argparse.ArgumentParser(description="Evaluate Dataset") 113 | parser.add_argument("--input_dir", type=str, default=r"/host_data/van/edge_aistpp/test/wavs_sliced/") 114 | parser.add_argument("--output_dir", type=str, default=r"/host_data/van/edge_aistpp/outputv2/all_01/normalized/") 115 | parser.add_argument("--gen_audio", type=bool, default=False) 116 | parser.add_argument("--outpaint", type=bool, default=False) 117 | args = parser.parse_args() 118 | 119 | main(args) 120 | 121 | 122 | 
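For reference, the beat coverage and beat hit metrics above reduce to a comparison of per-second onset indicators. The following is a minimal, self-contained sketch that mirrors the output format of `beat_detect` and the logic of `beat_scores` from `eval_cdcd.py`; the indicator arrays are toy values chosen for illustration, not real measurements:

```python
# Toy walkthrough of the beat coverage / beat hit computation in eval_cdcd.py.
# Each list holds one 0/1 onset flag per second, as produced by beat_detect.

def beat_scores(gt, syn):
    assert len(gt) == len(syn)
    total_beats = sum(gt)    # beats detected in the ground-truth audio
    cover_beats = sum(syn)   # beats detected in the generated audio
    hit_beats = sum(1 for g, s in zip(gt, syn) if g == 1 and s == 1)
    # coverage = generated beat count / ground-truth beat count (can exceed 1.0)
    # hit      = seconds where both signals have a beat / ground-truth beat count
    return cover_beats / total_beats, hit_beats / total_beats


gt_beats = [1, 0, 1, 1, 0, 1]    # hypothetical ground-truth onsets per second
syn_beats = [1, 0, 0, 1, 1, 1]   # hypothetical generated-audio onsets per second
coverage, hit = beat_scores(gt_beats, syn_beats)
print(f"coverage={coverage:.2f}, hit={hit:.2f}")  # coverage=1.00, hit=0.75
```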
-------------------------------------------------------------------------------- /genremodel/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.models as models 8 | 9 | 10 | class SincConv_fast(nn.Module): 11 | """Sinc-based convolution 12 | Parameters 13 | ---------- 14 | in_channels : `int` 15 | Number of input channels. Must be 1. 16 | out_channels : `int` 17 | Number of filters. 18 | kernel_size : `int` 19 | Filter length. 20 | sample_rate : `int`, optional 21 | Sample rate. Defaults to 16000. 22 | Usage 23 | ----- 24 | See `torch.nn.Conv1d` 25 | Reference 26 | --------- 27 | Mirco Ravanelli, Yoshua Bengio, 28 | "Speaker Recognition from raw waveform with SincNet". 29 | https://arxiv.org/abs/1808.00158 30 | """ 31 | 32 | @staticmethod 33 | def to_mel(hz): 34 | return 2595 * np.log10(1 + hz / 700) 35 | 36 | @staticmethod 37 | def to_hz(mel): 38 | return 700 * (10 ** (mel / 2595) - 1) 39 | 40 | def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, 41 | stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50): 42 | super(SincConv_fast,self).__init__() 43 | 44 | if in_channels != 1: 45 | #msg = (f'SincConv only support one input channel ' 46 | # f'(here, in_channels = {in_channels:d}).') 47 | msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels) 48 | raise ValueError(msg) 49 | 50 | self.out_channels = out_channels 51 | self.kernel_size = kernel_size 52 | 53 | # Forcing the filters to be odd (i.e, perfectly symmetrics) 54 | if kernel_size%2==0: 55 | self.kernel_size=self.kernel_size+1 56 | 57 | self.stride = stride 58 | self.padding = padding 59 | self.dilation = dilation 60 | 61 | if bias: 62 | raise ValueError('SincConv does not support bias.') 63 | if groups > 1: 64 | raise ValueError('SincConv does not support groups.') 65 | 66 | self.sample_rate = sample_rate 67 | self.min_low_hz = min_low_hz 68 | self.min_band_hz = min_band_hz 69 | 70 | # initialize filterbanks such that they are equally spaced in Mel scale 71 | low_hz = 30 72 | high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) 73 | 74 | mel = np.linspace(self.to_mel(low_hz), 75 | self.to_mel(high_hz), 76 | self.out_channels + 1) 77 | hz = self.to_hz(mel) 78 | 79 | # filter lower frequency (out_channels, 1) 80 | self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) 81 | 82 | # filter frequency band (out_channels, 1) 83 | self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) 84 | 85 | # Hamming window 86 | #self.window_ = torch.hamming_window(self.kernel_size) 87 | n_lin=torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window 88 | self.window_=0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size); 89 | 90 | # (1, kernel_size/2) 91 | n = (self.kernel_size - 1) / 2.0 92 | self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes 93 | 94 | def forward(self, waveforms): 95 | """ 96 | Parameters 97 | ---------- 98 | waveforms : `torch.Tensor` (batch_size, 1, n_samples) 99 | Batch of waveforms. 100 | Returns 101 | ------- 102 | features : `torch.Tensor` (batch_size, out_channels, n_samples_out) 103 | Batch of sinc filters activations. 
104 | """ 105 | self.n_ = self.n_.to(waveforms.device) 106 | self.window_ = self.window_.to(waveforms.device) 107 | 108 | low = self.min_low_hz + torch.abs(self.low_hz_) 109 | high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2) 110 | band=(high-low)[:,0] 111 | 112 | f_times_t_low = torch.matmul(low, self.n_) 113 | f_times_t_high = torch.matmul(high, self.n_) 114 | 115 | band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. 116 | band_pass_center = 2*band.view(-1,1) 117 | band_pass_right= torch.flip(band_pass_left,dims=[1]) 118 | 119 | band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1) 120 | band_pass = band_pass / (2*band[:,None]) 121 | self.filters = (band_pass).view( 122 | self.out_channels, 1, self.kernel_size) 123 | return F.conv1d(waveforms, self.filters, stride=self.stride, 124 | padding=self.padding, dilation=self.dilation, 125 | bias=None, groups=1) 126 | 127 | class myResnet(nn.Module): 128 | def __init__(self, pretrained=True): 129 | super(myResnet, self).__init__() 130 | self.model = models.resnet18(pretrained=True) 131 | self.model.fc = nn.Linear(512, 10, bias=True) 132 | 133 | def forward(self, x): 134 | x = self.model(x) 135 | return x 136 | 137 | class MS_SincResNet(nn.Module): 138 | def __init__(self): 139 | super(MS_SincResNet, self).__init__() 140 | self.layerNorm = nn.LayerNorm([1, 48000]) 141 | self.sincNet1 = nn.Sequential( 142 | SincConv_fast(out_channels=160, kernel_size=251), 143 | nn.BatchNorm1d(160), 144 | nn.ReLU(inplace=True), 145 | nn.AdaptiveAvgPool1d(1024)) 146 | self.sincNet2 = nn.Sequential( 147 | SincConv_fast(out_channels=160, kernel_size=501), 148 | nn.BatchNorm1d(160), 149 | nn.ReLU(inplace=True), 150 | nn.AdaptiveAvgPool1d(1024)) 151 | self.sincNet3 = nn.Sequential( 152 | SincConv_fast(out_channels=160, kernel_size=1001), 153 | nn.BatchNorm1d(160), 154 | nn.ReLU(inplace=True), 155 | nn.AdaptiveAvgPool1d(1024)) 156 | self.resnet = myResnet(pretrained=True) 157 | def forward(self, x): 158 | x = self.layerNorm(x) 159 | 160 | feat1 = self.sincNet1(x) 161 | feat2 = self.sincNet2(x) 162 | feat3 = self.sincNet3(x) 163 | 164 | x = torch.cat((feat1.unsqueeze_(dim=1), 165 | feat2.unsqueeze_(dim=1), 166 | feat3.unsqueeze_(dim=1)), dim=1) 167 | x = self.resnet(x) 168 | return x, feat1, feat2, feat3 169 | 170 | 171 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # BSD License 2 | 3 | # For fairmotion software 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | # Redistribution and use in source and binary forms, with or without modification, 8 | # are permitted provided that the following conditions are met: 9 | 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 
16 | 17 | # * Neither the name Facebook nor the names of its contributors may be used to 18 | # endorse or promote products derived from this software without specific 19 | # prior written permission. 20 | 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | import numpy as np 32 | 33 | 34 | def distance_between_points(a, b): 35 | return np.linalg.norm(np.array(a) - np.array(b)) 36 | 37 | 38 | def distance_from_plane(a, b, c, p, threshold): 39 | ba = np.array(b) - np.array(a) 40 | ca = np.array(c) - np.array(a) 41 | cross = np.cross(ca, ba) 42 | 43 | pa = np.array(p) - np.array(a) 44 | return np.dot(cross, pa) / np.linalg.norm(cross) > threshold 45 | 46 | 47 | def distance_from_plane_normal(n1, n2, a, p, threshold): 48 | normal = np.array(n2) - np.array(n1) 49 | pa = np.array(p) - np.array(a) 50 | return np.dot(normal, pa) / np.linalg.norm(normal) > threshold 51 | 52 | 53 | def angle_within_range(j1, j2, k1, k2, range): 54 | j = np.array(j2) - np.array(j1) 55 | k = np.array(k2) - np.array(k1) 56 | 57 | angle = np.arccos(np.dot(j, k) / (np.linalg.norm(j) * np.linalg.norm(k))) 58 | angle = np.degrees(angle) 59 | 60 | if angle > range[0] and angle < range[1]: 61 | return True 62 | else: 63 | return False 64 | 65 | 66 | def velocity_direction_above_threshold( 67 | j1, j1_prev, j2, j2_prev, p, p_prev, threshold, time_per_frame=1 / 120.0 68 | ): 69 | velocity = ( 70 | np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev)) 71 | ) 72 | direction = np.array(j2) - np.array(j1) 73 | 74 | velocity_along_direction = np.dot(velocity, direction) / np.linalg.norm( 75 | direction 76 | ) 77 | velocity_along_direction = velocity_along_direction / time_per_frame 78 | return velocity_along_direction > threshold 79 | 80 | 81 | def velocity_direction_above_threshold_normal( 82 | j1, j1_prev, j2, j3, p, p_prev, threshold, time_per_frame=1 / 120.0 83 | ): 84 | velocity = ( 85 | np.array(p) - np.array(j1) - (np.array(p_prev) - np.array(j1_prev)) 86 | ) 87 | j31 = np.array(j3) - np.array(j1) 88 | j21 = np.array(j2) - np.array(j1) 89 | direction = np.cross(j31, j21) 90 | 91 | velocity_along_direction = np.dot(velocity, direction) / np.linalg.norm( 92 | direction 93 | ) 94 | velocity_along_direction = velocity_along_direction / time_per_frame 95 | return velocity_along_direction > threshold 96 | 97 | 98 | def velocity_above_threshold(p, p_prev, threshold, time_per_frame=1 / 120.0): 99 | velocity = np.linalg.norm(np.array(p) - np.array(p_prev)) / time_per_frame 100 | return velocity > threshold 101 | 102 | 103 | def calc_average_velocity(positions, i, joint_idx, sliding_window, frame_time): 104 | current_window = 0 105 | average_velocity = np.zeros(len(positions[0][joint_idx])) 106 | for j in range(-sliding_window, sliding_window + 1): 107 | 
if i + j - 1 < 0 or i + j >= len(positions): 108 | continue 109 | average_velocity += ( 110 | positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] 111 | ) 112 | current_window += 1 113 | return np.linalg.norm(average_velocity / (current_window * frame_time)) 114 | 115 | 116 | def calc_average_acceleration( 117 | positions, i, joint_idx, sliding_window, frame_time 118 | ): 119 | current_window = 0 120 | average_acceleration = np.zeros(len(positions[0][joint_idx])) 121 | for j in range(-sliding_window, sliding_window + 1): 122 | if i + j - 1 < 0 or i + j + 1 >= len(positions): 123 | continue 124 | v2 = ( 125 | positions[i + j + 1][joint_idx] - positions[i + j][joint_idx] 126 | ) / frame_time 127 | v1 = ( 128 | positions[i + j][joint_idx] 129 | - positions[i + j - 1][joint_idx] / frame_time 130 | ) 131 | average_acceleration += (v2 - v1) / frame_time 132 | current_window += 1 133 | return np.linalg.norm(average_acceleration / current_window) 134 | 135 | 136 | def calc_average_velocity_horizontal( 137 | positions, i, joint_idx, sliding_window, frame_time, up_vec="z" 138 | ): 139 | current_window = 0 140 | average_velocity = np.zeros(len(positions[0][joint_idx])) 141 | for j in range(-sliding_window, sliding_window + 1): 142 | if i + j - 1 < 0 or i + j >= len(positions): 143 | continue 144 | average_velocity += ( 145 | positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] 146 | ) 147 | current_window += 1 148 | if up_vec == "y": 149 | average_velocity = np.array( 150 | [average_velocity[0], average_velocity[2]] 151 | ) / (current_window * frame_time) 152 | elif up_vec == "z": 153 | average_velocity = np.array( 154 | [average_velocity[0], average_velocity[1]] 155 | ) / (current_window * frame_time) 156 | else: 157 | raise NotImplementedError 158 | return np.linalg.norm(average_velocity) 159 | 160 | 161 | def calc_average_velocity_vertical( 162 | positions, i, joint_idx, sliding_window, frame_time, up_vec 163 | ): 164 | current_window = 0 165 | average_velocity = np.zeros(len(positions[0][joint_idx])) 166 | for j in range(-sliding_window, sliding_window + 1): 167 | if i + j - 1 < 0 or i + j >= len(positions): 168 | continue 169 | average_velocity += ( 170 | positions[i + j][joint_idx] - positions[i + j - 1][joint_idx] 171 | ) 172 | current_window += 1 173 | if up_vec == "y": 174 | average_velocity = np.array([average_velocity[1]]) / ( 175 | current_window * frame_time 176 | ) 177 | elif up_vec == "z": 178 | average_velocity = np.array([average_velocity[2]]) / ( 179 | current_window * frame_time 180 | ) 181 | else: 182 | raise NotImplementedError 183 | return np.linalg.norm(average_velocity) 184 | -------------------------------------------------------------------------------- /norm_motion.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/Stanford-TML/EDGE 2 | # Extract normalized motion and genre (conditioning signals) 3 | 4 | import os 5 | import glob 6 | import pickle 7 | import numpy as np 8 | import librosa 9 | import torch 10 | 11 | from pathlib import Path 12 | from sklearn.preprocessing import OneHotEncoder 13 | from pytorch3d.transforms import RotateAxisAngle 14 | from pytorch3d.transforms.rotation_conversions import (axis_angle_to_quaternion, 15 | quaternion_multiply, 16 | quaternion_to_axis_angle, 17 | quaternion_invert) 18 | 19 | from utils.quaternion import ax_to_6v 20 | from utils.preprocess import Normalizer 21 | from utils.vis import SMPLSkeleton 22 | 23 | 24 | ABLATION_LIST = ["all", "pos", 
"orn", "linvel", "angvel"] 25 | 26 | 27 | def concat_aistpp(data_path: str, backup_path: str, is_train: bool = True): 28 | # motion data specification 29 | raw_fps = 60 30 | data_fps = 30 31 | assert data_fps <= raw_fps 32 | data_stride = raw_fps // data_fps 33 | 34 | # file save specificiation 35 | split_data_path = os.path.join( 36 | data_path, "train" if is_train else "test") 37 | 38 | backup_path = Path(backup_path) 39 | backup_path.mkdir(parents=True, exist_ok=True) 40 | pickle_name = "processed_train_data.pkl" if is_train else "processed_test_data.pkl" 41 | if pickle_name in os.listdir(backup_path): 42 | return 43 | 44 | # load dataset 45 | print("Loading dataset...") 46 | motion_path = os.path.join(split_data_path, "motions_sliced") 47 | 48 | # sort motions and sounds 49 | motions = sorted(glob.glob(os.path.join(motion_path, "*.pkl"))) 50 | 51 | # stack the motions and features together 52 | all_pos = [] 53 | all_q = [] 54 | all_names = [] 55 | all_genre = [] 56 | 57 | for motion in motions: 58 | # make sure name is matching 59 | m_name = os.path.splitext(os.path.basename(motion))[0] 60 | 61 | # load motion 62 | data = pickle.load(open(motion, "rb")) 63 | pos = data["pos"] 64 | q = data["q"] 65 | all_pos.append(pos) 66 | all_q.append(q) 67 | all_names.append(m_name) 68 | all_genre.append(m_name.split("_")[0]) 69 | 70 | all_pos = np.array(all_pos) # N x seq x 3 71 | all_q = np.array(all_q) # N x seq x (joint * 3) 72 | 73 | # downsample the motions to the data fps 74 | print(f"total concated pos dim : {all_pos.shape}") 75 | 76 | all_pos = all_pos[:, :: data_stride, :] 77 | all_q = all_q[:, :: data_stride, :] 78 | data = {"root_pos": all_pos, "joint_orn": all_q, "filenames": all_names, "genre": all_genre} 79 | 80 | with open(os.path.join(backup_path, pickle_name), "wb") as f: 81 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 82 | 83 | 84 | def cal_joint_ang_vel(joint_orn: torch.Tensor, fps: int=30)-> torch.Tensor: 85 | bs, num_frame, num_joint, _ = joint_orn.size() 86 | joint_ang_vel = torch.zeros((bs, num_frame, num_joint, 3), dtype=torch.float32) 87 | frame_list = np.arange(0, num_frame, dtype=int) 88 | prev_frame = np.maximum(0, frame_list-1) 89 | next_frame = np.minimum(num_frame-1, frame_list+1) 90 | 91 | dframe = next_frame - prev_frame 92 | 93 | frame_list = torch.tensor(frame_list, dtype=torch.int64) 94 | prev_frame = torch.tensor(prev_frame, dtype=torch.int64) 95 | next_frame = torch.tensor(next_frame, dtype=torch.int64) 96 | dframe = torch.tensor(dframe, dtype=torch.int64).view(1, num_frame, 1, 1) 97 | 98 | dorn = quaternion_invert(joint_orn[:,prev_frame, : ,:])*joint_orn[:,next_frame, :, :] 99 | dorn_ax = quaternion_to_axis_angle(dorn) 100 | joint_ang_vel[:, frame_list, :, : ] = fps * torch.div(dorn_ax, dframe) 101 | 102 | return joint_ang_vel 103 | 104 | 105 | def cal_joint_lin_vel(joint_pos: torch.Tensor, fps: int =30)-> torch.Tensor: 106 | _, num_frame, _, _ = joint_pos.size() 107 | joint_lin_vel = torch.zeros_like(joint_pos) 108 | frame_list = np.arange(0, num_frame, dtype=int) 109 | prev_frame = np.maximum(0, frame_list-1) 110 | next_frame = np.minimum(num_frame-1, frame_list+1) 111 | 112 | dframe = next_frame - prev_frame 113 | 114 | frame_list = torch.tensor(frame_list, dtype=torch.int64) 115 | prev_frame = torch.tensor(prev_frame, dtype=torch.int64) 116 | next_frame = torch.tensor(next_frame, dtype=torch.int64) 117 | dframe = torch.tensor(dframe, dtype=torch.int64).view(1, num_frame, 1, 1) 118 | 119 | joint_lin_vel[:, frame_list, :, :] = fps * 
torch.div((joint_pos[:, next_frame, :, :] - joint_pos[:, prev_frame, : ,:]), dframe) 120 | 121 | return joint_lin_vel 122 | 123 | 124 | def extract_motion(root_pos: np.ndarray, joint_orn: np.ndarray, data_sort: str = "all")-> torch.Tensor: 125 | smpl = SMPLSkeleton() 126 | # to Tensor 127 | root_pos = torch.Tensor(root_pos) 128 | joint_orn = torch.Tensor(joint_orn) 129 | # to ax 130 | bs, sq, _ = joint_orn.shape 131 | joint_orn = joint_orn.reshape((bs, sq, -1, 3)) 132 | 133 | # AISTPP dataset comes y-up - rotate to z-up to standardize against the pretrain dataset 134 | root_q = joint_orn[:, :, :1, :] # sequence x 1 x 3 135 | root_q_quat = axis_angle_to_quaternion(root_q) 136 | rotation = torch.Tensor( 137 | [0.7071068, 0.7071068, 0, 0] 138 | ) # 90 degrees about the x axis 139 | root_q_quat = quaternion_multiply(rotation, root_q_quat) 140 | root_q = quaternion_to_axis_angle(root_q_quat) 141 | joint_orn[:, :, :1, :] = root_q 142 | 143 | pos_rotation = RotateAxisAngle(90, axis="X", degrees=True) 144 | root_pos = pos_rotation.transform_points( 145 | root_pos 146 | ) # basically (y, z) -> (-z, y), expressed as a rotation for readability 147 | 148 | # get joint pos 149 | joint_pos = smpl.forward(joint_orn, root_pos) # batch x sequence x 24 x 3 150 | 151 | # get joint linear vel 152 | joint_lin_vel = cal_joint_lin_vel(joint_pos, 30) # batch x sequence x 24 x 3 153 | 154 | # get joint angular vel 155 | joint_ang_vel = cal_joint_ang_vel(axis_angle_to_quaternion(joint_orn), 30) 156 | 157 | # get joint orn with 6D repr 158 | joint_orn = ax_to_6v(joint_orn) # batch x sequence x 24 x 6 159 | 160 | 161 | ## generate motion feature along data_sort 162 | if data_sort == "all": 163 | motion_feature = torch.cat([joint_orn, joint_pos, joint_lin_vel, joint_ang_vel], axis= -1) 164 | print(np.array(motion_feature).shape) 165 | elif data_sort == "pos": 166 | motion_feature = torch.cat([joint_orn, joint_lin_vel, joint_ang_vel], axis= -1) 167 | print(np.array(motion_feature).shape) 168 | elif data_sort == "orn": 169 | motion_feature = torch.cat([joint_pos, joint_lin_vel, joint_ang_vel], axis= -1) 170 | print(np.array(motion_feature).shape) 171 | elif data_sort == "linvel": 172 | motion_feature = torch.cat([joint_orn, joint_pos, joint_ang_vel], axis= -1) 173 | print(np.array(motion_feature).shape) 174 | elif data_sort == "angvel": 175 | motion_feature = torch.cat([joint_orn, joint_pos, joint_lin_vel], axis= -1) 176 | print(np.array(motion_feature).shape) 177 | else: 178 | assert False, f"data_sort is not supported : {data_sort}" 179 | 180 | return motion_feature.view(bs, sq, -1) # batch x sequence x (24 x 15) 181 | 182 | 183 | def preprocess_aistpp(pickle_path: str, is_train: bool = True, data_sort: str = "all"): 184 | pickle_name = "processed_train_data.pkl" if is_train else "processed_test_data.pkl" 185 | with open(os.path.join(pickle_path, pickle_name), "rb") as f: 186 | data = pickle.load(f) 187 | 188 | motion_root_pos = data["root_pos"] 189 | motion_joint_orn = data["joint_orn"] 190 | names = data["filenames"] 191 | genres = data["genre"] 192 | 193 | encoder = OneHotEncoder() 194 | encoded_genres = encoder.fit_transform(np.array(genres).reshape(-1,1)) 195 | 196 | motion_feature_array: torch.Tensor = extract_motion(motion_root_pos, motion_joint_orn, data_sort) # bs x sq x (num_joints x motion_feature(15)) 197 | 198 | # normalizer 199 | normalizer_name = f"normalizer_{data_sort}_01.pkl" 200 | 201 | if is_train: 202 | normalizer = Normalizer(motion_feature_array) 203 | motion_feature_array =
normalizer.normalize(motion_feature_array) 204 | with open(os.path.join(pickle_path, normalizer_name), "wb") as f: 205 | pickle.dump(normalizer, f, pickle.HIGHEST_PROTOCOL) 206 | else: 207 | normalizer: Normalizer = pickle.load(open(os.path.join(pickle_path, normalizer_name), "rb")) 208 | motion_feature_array = normalizer.normalize(motion_feature_array) 209 | 210 | motion_feature_array = motion_feature_array.numpy() 211 | 212 | encodings = {} 213 | print(len(names)) 214 | for i in range(0, len(names)): 215 | genres = np.tile(encoded_genres[i].todense(), (150, 1)) 216 | features = np.concatenate((motion_feature_array[i], genres), axis=-1) 217 | encodings[names[i]] = features 218 | 219 | processed_data_name = f"normalized_{data_sort}_train_data_01.pkl" if is_train else f"normalized_{data_sort}_test_data_01.pkl" 220 | with open(os.path.join(pickle_path, processed_data_name), "wb") as f: 221 | pickle.dump(encodings, f, pickle.HIGHEST_PROTOCOL) 222 | 223 | print(f"Finished data preprocess for {data_sort}") 224 | 225 | 226 | if __name__ == '__main__': 227 | #data_path = f"./data/" 228 | #backup_path = f"./data/merged/" 229 | #concat_aistpp(data_path, backup_path, False) 230 | 231 | # Server Paths 232 | backup_path = r"/host_data/van/edge_aistpp/test/concat/" # test folder 233 | #backup_path = r"/host_data/van/edge_aistpp/encoding/" # train folder 234 | for sort in ABLATION_LIST: 235 | preprocess_aistpp(pickle_path = backup_path, is_train = False, data_sort = sort) 236 | 237 | -------------------------------------------------------------------------------- /utils/manual.py: -------------------------------------------------------------------------------- 1 | # BSD License 2 | 3 | # For fairmotion software 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | # Modified by Ruilong Li 7 | 8 | # Redistribution and use in source and binary forms, with or without modification, 9 | # are permitted provided that the following conditions are met: 10 | 11 | # * Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | 14 | # * Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | 18 | # * Neither the name Facebook nor the names of its contributors may be used to 19 | # endorse or promote products derived from this software without specific 20 | # prior written permission. 21 | 22 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 23 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 24 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 25 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 26 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 27 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 28 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 29 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | import numpy as np 33 | import utils as feat_utils 34 | 35 | 36 | SMPL_JOINT_NAMES = [ 37 | "root", 38 | "lhip", "rhip", "belly", 39 | "lknee", "rknee", "spine", 40 | "lankle", "rankle", "chest", 41 | "ltoes", "rtoes", "neck", 42 | "linshoulder", "rinshoulder", 43 | "head", "lshoulder", "rshoulder", 44 | "lelbow", "relbow", 45 | "lwrist", "rwrist", 46 | "lhand", "rhand", 47 | ] 48 | 49 | 50 | def extract_manual_features(positions): 51 | assert len(positions.shape) == 3 # (seq_len, n_joints, 3) 52 | features = [] 53 | f = ManualFeatures(positions) 54 | for _ in range(1, positions.shape[0]): 55 | pose_features = [] 56 | pose_features.append( 57 | f.f_nmove("neck", "rhip", "lhip", "rwrist", 1.8 * f.hl) 58 | ) 59 | pose_features.append( 60 | f.f_nmove("neck", "lhip", "rhip", "lwrist", 1.8 * f.hl) 61 | ) 62 | pose_features.append( 63 | f.f_nplane("chest", "neck", "neck", "rwrist", 0.2 * f.hl) 64 | ) 65 | pose_features.append( 66 | f.f_nplane("chest", "neck", "neck", "lwrist", 0.2 * f.hl) 67 | ) 68 | pose_features.append( 69 | f.f_move("belly", "chest", "chest", "rwrist", 1.8 * f.hl) 70 | ) 71 | pose_features.append( 72 | f.f_move("belly", "chest", "chest", "lwrist", 1.8 * f.hl) 73 | ) 74 | pose_features.append( 75 | f.f_angle("relbow", "rshoulder", "relbow", "rwrist", [0, 110]) 76 | ) 77 | pose_features.append( 78 | f.f_angle("lelbow", "lshoulder", "lelbow", "lwrist", [0, 110]) 79 | ) 80 | pose_features.append( 81 | f.f_nplane( 82 | "lshoulder", "rshoulder", "lwrist", "rwrist", 2.5 * f.sw 83 | ) 84 | ) 85 | pose_features.append( 86 | f.f_move("lwrist", "rwrist", "rwrist", "lwrist", 1.4 * f.hl) 87 | ) 88 | pose_features.append( 89 | f.f_move("rwrist", "root", "lwrist", "root", 1.4 * f.hl) 90 | ) 91 | pose_features.append( 92 | f.f_move("lwrist", "root", "rwrist", "root", 1.4 * f.hl) 93 | ) 94 | pose_features.append(f.f_fast("rwrist", 2.5 * f.hl)) 95 | pose_features.append(f.f_fast("lwrist", 2.5 * f.hl)) 96 | pose_features.append( 97 | f.f_plane("root", "lhip", "ltoes", "rankle", 0.38 * f.hl) 98 | ) 99 | pose_features.append( 100 | f.f_plane("root", "rhip", "rtoes", "lankle", 0.38 * f.hl) 101 | ) 102 | pose_features.append( 103 | f.f_nplane("zero", "y_unit", "y_min", "rankle", 1.2 * f.hl) 104 | ) 105 | pose_features.append( 106 | f.f_nplane("zero", "y_unit", "y_min", "lankle", 1.2 * f.hl) 107 | ) 108 | pose_features.append( 109 | f.f_nplane("lhip", "rhip", "lankle", "rankle", 2.1 * f.hw) 110 | ) 111 | pose_features.append( 112 | f.f_angle("rknee", "rhip", "rknee", "rankle", [0, 110]) 113 | ) 114 | pose_features.append( 115 | f.f_angle("lknee", "lhip", "lknee", "lankle", [0, 110]) 116 | ) 117 | pose_features.append(f.f_fast("rankle", 2.5 * f.hl)) 118 | pose_features.append(f.f_fast("lankle", 2.5 * f.hl)) 119 | pose_features.append( 120 | f.f_angle("neck", "root", "rshoulder", "relbow", [25, 180]) 121 | ) 122 | pose_features.append( 123 | f.f_angle("neck", "root", "lshoulder", "lelbow", [25, 180]) 124 | ) 125 | pose_features.append( 126 | f.f_angle("neck", "root", "rhip", "rknee", [50, 180]) 127 | ) 128 | pose_features.append( 129 | f.f_angle("neck", "root", "lhip", "lknee", [50, 180]) 130 | ) 131 | pose_features.append( 132 | f.f_plane("rankle", "neck", "lankle", "root", 0.5 * f.hl) 133 | ) 134 | pose_features.append( 135 | f.f_angle("neck", "root", "zero", "y_unit", [70, 110]) 136 | ) 137 | pose_features.append( 138 | f.f_nplane("zero", "minus_y_unit", "y_min", "rwrist", -1.2 * f.hl) 139 | ) 140 | pose_features.append( 141 | f.f_nplane("zero", "minus_y_unit", "y_min", "lwrist", -1.2 * f.hl) 
142 | ) 143 | pose_features.append(f.f_fast("root", 2.3 * f.hl)) 144 | features.append(pose_features) 145 | f.next_frame() 146 | features = np.array(features, dtype=np.float32).mean(axis=0) 147 | return features 148 | 149 | 150 | class ManualFeatures: 151 | def __init__(self, positions, joint_names=SMPL_JOINT_NAMES): 152 | self.positions = positions 153 | self.joint_names = joint_names 154 | self.frame_num = 1 155 | 156 | # humerus length 157 | self.hl = feat_utils.distance_between_points( 158 | [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder", 159 | [4.54445392e-01, 2.21158922e-01, -4.10167128e-02], # "lelbow" 160 | ) 161 | # shoulder width 162 | self.sw = feat_utils.distance_between_points( 163 | [1.99113488e-01, 2.36807942e-01, -1.80702247e-02], # "lshoulder" 164 | [-1.91692337e-01, 2.36928746e-01, -1.23055102e-02,], # "rshoulder" 165 | ) 166 | # hip width 167 | self.hw = feat_utils.distance_between_points( 168 | [5.64076714e-02, -3.23069185e-01, 1.09197125e-02], # "lhip" 169 | [-6.24834076e-02, -3.31302464e-01, 1.50412619e-02], # "rhip" 170 | ) 171 | 172 | def next_frame(self): 173 | self.frame_num += 1 174 | 175 | def transform_and_fetch_position(self, j): 176 | if j == "y_unit": 177 | return [0, 1, 0] 178 | elif j == "minus_y_unit": 179 | return [0, -1, 0] 180 | elif j == "zero": 181 | return [0, 0, 0] 182 | elif j == "y_min": 183 | return [ 184 | 0, 185 | min( 186 | [y for (_, y, _) in self.positions[self.frame_num]] 187 | ), 188 | 0, 189 | ] 190 | return self.positions[self.frame_num][ 191 | self.joint_names.index(j) 192 | ] 193 | 194 | def transform_and_fetch_prev_position(self, j): 195 | return self.positions[self.frame_num - 1][ 196 | self.joint_names.index(j) 197 | ] 198 | 199 | def f_move(self, j1, j2, j3, j4, range): 200 | j1_prev, j2_prev, j3_prev, j4_prev = [ 201 | self.transform_and_fetch_prev_position(j) for j in [j1, j2, j3, j4] 202 | ] 203 | j1, j2, j3, j4 = [ 204 | self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] 205 | ] 206 | return feat_utils.velocity_direction_above_threshold( 207 | j1, j1_prev, j2, j2_prev, j3, j3_prev, range, 208 | ) 209 | 210 | def f_nmove(self, j1, j2, j3, j4, range): 211 | j1_prev, j2_prev, j3_prev, j4_prev = [ 212 | self.transform_and_fetch_prev_position(j) for j in [j1, j2, j3, j4] 213 | ] 214 | j1, j2, j3, j4 = [ 215 | self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] 216 | ] 217 | return feat_utils.velocity_direction_above_threshold_normal( 218 | j1, j1_prev, j2, j3, j4, j4_prev, range 219 | ) 220 | 221 | def f_plane(self, j1, j2, j3, j4, threshold): 222 | j1, j2, j3, j4 = [ 223 | self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] 224 | ] 225 | return feat_utils.distance_from_plane(j1, j2, j3, j4, threshold) 226 | 227 | # 228 | def f_nplane(self, j1, j2, j3, j4, threshold): 229 | j1, j2, j3, j4 = [ 230 | self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] 231 | ] 232 | return feat_utils.distance_from_plane_normal(j1, j2, j3, j4, threshold) 233 | 234 | # relative 235 | def f_angle(self, j1, j2, j3, j4, range): 236 | j1, j2, j3, j4 = [ 237 | self.transform_and_fetch_position(j) for j in [j1, j2, j3, j4] 238 | ] 239 | return feat_utils.angle_within_range(j1, j2, j3, j4, range) 240 | 241 | # non-relative 242 | def f_fast(self, j1, threshold): 243 | j1_prev = self.transform_and_fetch_prev_position(j1) 244 | j1 = self.transform_and_fetch_position(j1) 245 | return feat_utils.velocity_above_threshold(j1, j1_prev, threshold) 246 | 
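
For reference, a minimal usage sketch of extract_manual_features from utils/manual.py (not one of the repository files). It assumes the interpreter is started inside utils/, matching the file's own `import utils as feat_utils`; the random pose array is only a placeholder for real joint positions produced by SMPLSkeleton.forward, and must be shaped (seq_len, n_joints, 3) with joints ordered as in SMPL_JOINT_NAMES.

# Usage sketch only; the random poses are placeholders, not real AIST++ data.
import numpy as np
from manual import extract_manual_features, SMPL_JOINT_NAMES

seq_len = 150  # e.g. 5 seconds at 30 fps, as used elsewhere in this repo
positions = np.random.randn(seq_len, len(SMPL_JOINT_NAMES), 3).astype(np.float32)

# Returns the per-frame geometric pose features averaged over the sequence.
features = extract_manual_features(positions)
print(features.shape)
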
-------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/Stanford-TML/EDGE 2 | # Visualization code 3 | 4 | import os 5 | from pathlib import Path 6 | from tempfile import TemporaryDirectory 7 | 8 | import librosa as lr 9 | import matplotlib.animation as animation 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import soundfile as sf 13 | import torch 14 | from matplotlib import cm 15 | from matplotlib.colors import ListedColormap 16 | from pytorch3d.transforms import (axis_angle_to_quaternion, quaternion_apply, 17 | quaternion_multiply) 18 | from tqdm import tqdm 19 | 20 | smpl_joints = [ 21 | "root", # 0 22 | "lhip", # 1 23 | "rhip", # 2 24 | "belly", # 3 25 | "lknee", # 4 26 | "rknee", # 5 27 | "spine", # 6 28 | "lankle",# 7 29 | "rankle",# 8 30 | "chest", # 9 31 | "ltoes", # 10 32 | "rtoes", # 11 33 | "neck", # 12 34 | "linshoulder", # 13 35 | "rinshoulder", # 14 36 | "head", # 15 37 | "lshoulder", # 16 38 | "rshoulder", # 17 39 | "lelbow", # 18 40 | "relbow", # 19 41 | "lwrist", # 20 42 | "rwrist", # 21 43 | "lhand", # 22 44 | "rhand", # 23 45 | ] 46 | 47 | smpl_parents = [ 48 | -1, 49 | 0, 50 | 0, 51 | 0, 52 | 1, 53 | 2, 54 | 3, 55 | 4, 56 | 5, 57 | 6, 58 | 7, 59 | 8, 60 | 9, 61 | 9, 62 | 9, 63 | 12, 64 | 13, 65 | 14, 66 | 16, 67 | 17, 68 | 18, 69 | 19, 70 | 20, 71 | 21, 72 | ] 73 | 74 | smpl_offsets = [ 75 | [0.0, 0.0, 0.0], 76 | [0.05858135, -0.08228004, -0.01766408], 77 | [-0.06030973, -0.09051332, -0.01354254], 78 | [0.00443945, 0.12440352, -0.03838522], 79 | [0.04345142, -0.38646945, 0.008037], 80 | [-0.04325663, -0.38368791, -0.00484304], 81 | [0.00448844, 0.1379564, 0.02682033], 82 | [-0.01479032, -0.42687458, -0.037428], 83 | [0.01905555, -0.4200455, -0.03456167], 84 | [-0.00226458, 0.05603239, 0.00285505], 85 | [0.04105436, -0.06028581, 0.12204243], 86 | [-0.03483987, -0.06210566, 0.13032329], 87 | [-0.0133902, 0.21163553, -0.03346758], 88 | [0.07170245, 0.11399969, -0.01889817], 89 | [-0.08295366, 0.11247234, -0.02370739], 90 | [0.01011321, 0.08893734, 0.05040987], 91 | [0.12292141, 0.04520509, -0.019046], 92 | [-0.11322832, 0.04685326, -0.00847207], 93 | [0.2553319, -0.01564902, -0.02294649], 94 | [-0.26012748, -0.01436928, -0.03126873], 95 | [0.26570925, 0.01269811, -0.00737473], 96 | [-0.26910836, 0.00679372, -0.00602676], 97 | [0.08669055, -0.01063603, -0.01559429], 98 | [-0.0887537, -0.00865157, -0.01010708], 99 | ] 100 | 101 | 102 | def set_line_data_3d(line, x): 103 | line.set_data(x[:, :2].T) 104 | line.set_3d_properties(x[:, 2]) 105 | 106 | 107 | def set_scatter_data_3d(scat, x, c): 108 | scat.set_offsets(x[:, :2]) 109 | scat.set_3d_properties(x[:, 2], "z") 110 | scat.set_facecolors([c]) 111 | 112 | 113 | def get_axrange(poses): 114 | pose = poses[0] 115 | x_min = pose[:, 0].min() 116 | x_max = pose[:, 0].max() 117 | 118 | y_min = pose[:, 1].min() 119 | y_max = pose[:, 1].max() 120 | 121 | z_min = pose[:, 2].min() 122 | z_max = pose[:, 2].max() 123 | 124 | xdiff = x_max - x_min 125 | ydiff = y_max - y_min 126 | zdiff = z_max - z_min 127 | 128 | biggestdiff = max([xdiff, ydiff, zdiff]) 129 | return biggestdiff 130 | 131 | 132 | def plot_single_pose(num, poses, lines, ax, axrange, scat, contact): 133 | pose = poses[num] 134 | static = contact[num] 135 | indices = [7, 8, 10, 11] 136 | 137 | for i, (point, idx) in enumerate(zip(scat, indices)): 138 | position = pose[idx : idx + 1] 139 | color = "r" 
if static[i] else "g" 140 | set_scatter_data_3d(point, position, color) 141 | 142 | for i, (p, line) in enumerate(zip(smpl_parents, lines)): 143 | # don't plot root 144 | if i == 0: 145 | continue 146 | # stack to create a line 147 | data = np.stack((pose[i], pose[p]), axis=0) 148 | set_line_data_3d(line, data) 149 | 150 | if num == 0: 151 | if isinstance(axrange, int): 152 | axrange = (axrange, axrange, axrange) 153 | xcenter, ycenter, zcenter = 0, 0, 2.5 154 | stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2 155 | 156 | x_min, x_max = xcenter - stepx, xcenter + stepx 157 | y_min, y_max = ycenter - stepy, ycenter + stepy 158 | z_min, z_max = zcenter - stepz, zcenter + stepz 159 | 160 | ax.set_xlim(x_min, x_max) 161 | ax.set_ylim(y_min, y_max) 162 | ax.set_zlim(z_min, z_max) 163 | 164 | 165 | def skeleton_render( 166 | poses, 167 | epoch=0, 168 | out="renders", 169 | name="", 170 | sound=True, 171 | stitch=False, 172 | sound_folder="ood_sliced", 173 | contact=None, 174 | render=True 175 | ): 176 | if render: 177 | # generate the pose with FK 178 | Path(out).mkdir(parents=True, exist_ok=True) 179 | num_steps = poses.shape[0] 180 | 181 | fig = plt.figure() 182 | ax = fig.add_subplot(projection="3d") 183 | 184 | point = np.array([0, 0, 1]) 185 | normal = np.array([0, 0, 1]) 186 | d = -point.dot(normal) 187 | xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 2), np.linspace(-1.5, 1.5, 2)) 188 | z = (-normal[0] * xx - normal[1] * yy - d) * 1.0 / normal[2] 189 | # plot the plane 190 | ax.plot_surface(xx, yy, z, zorder=-11, cmap=cm.twilight) 191 | # Create lines initially without data 192 | lines = [ 193 | ax.plot([], [], [], zorder=10, linewidth=1.5)[0] 194 | for _ in smpl_parents 195 | ] 196 | scat = [ 197 | ax.scatter([], [], [], zorder=10, s=0, cmap=ListedColormap(["r", "g", "b"])) 198 | for _ in range(4) 199 | ] 200 | axrange = 3 201 | 202 | # create contact labels 203 | feet = poses[:, (7, 8, 10, 11)] 204 | feetv = np.zeros(feet.shape[:2]) 205 | feetv[:-1] = np.linalg.norm(feet[1:] - feet[:-1], axis=-1) 206 | if contact is None: 207 | contact = feetv < 0.01 208 | else: 209 | contact = contact > 0.95 210 | 211 | # Creating the Animation object 212 | anim = animation.FuncAnimation( 213 | fig, 214 | plot_single_pose, 215 | num_steps, 216 | fargs=(poses, lines, ax, axrange, scat, contact), 217 | interval=1000 // 30, 218 | ) 219 | if sound: 220 | # make a temporary directory to save the intermediate gif in 221 | if render: 222 | temp_dir = TemporaryDirectory() 223 | gifname = os.path.join(temp_dir.name, f"{epoch}.gif") 224 | anim.save(gifname) 225 | 226 | # stitch wavs 227 | if stitch: 228 | assert type(name) == list # must be a list of names to do stitching 229 | name_ = [os.path.splitext(x)[0] + ".wav" for x in name] 230 | audio, sr = lr.load(name_[0], sr=None) 231 | ll, half = len(audio), len(audio) // 2 232 | total_wav = np.zeros(ll + half * (len(name_) - 1)) 233 | total_wav[:ll] = audio 234 | idx = ll 235 | for n_ in name_[1:]: 236 | audio, sr = lr.load(n_, sr=None) 237 | total_wav[idx : idx + half] = audio[half:] 238 | idx += half 239 | # save a dummy spliced audio 240 | audioname = f"{temp_dir.name}/tempsound.wav" if render else os.path.join(out, f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.wav') 241 | sf.write(audioname, total_wav, sr) 242 | outname = os.path.join( 243 | out, 244 | f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.mp4', 245 | ) 246 | else: 247 | assert type(name) == str 248 | assert 
name != "", "Must provide an audio filename" 249 | audioname = name 250 | outname = os.path.join( 251 | out, f"{epoch}_{os.path.splitext(os.path.basename(name))[0]}.mp4" 252 | ) 253 | if render: 254 | out = os.system( 255 | f"ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}" 256 | ) 257 | else: 258 | if render: 259 | # actually save the gif 260 | path = os.path.normpath(name) 261 | pathparts = path.split(os.sep) 262 | gifname = os.path.join(out, f"{pathparts[-1][:-4]}.gif") 263 | anim.save(gifname, savefig_kwargs={"transparent": True, "facecolor": "none"},) 264 | plt.close() 265 | 266 | 267 | class SMPLSkeleton: 268 | def __init__( 269 | self, device=None, 270 | ): 271 | offsets = smpl_offsets 272 | parents = smpl_parents 273 | assert len(offsets) == len(parents) 274 | 275 | self._offsets = torch.Tensor(offsets).to(device) 276 | self._parents = np.array(parents) 277 | self._compute_metadata() 278 | 279 | def _compute_metadata(self): 280 | self._has_children = np.zeros(len(self._parents)).astype(bool) 281 | for i, parent in enumerate(self._parents): 282 | if parent != -1: 283 | self._has_children[parent] = True 284 | 285 | self._children = [] 286 | for i, parent in enumerate(self._parents): 287 | self._children.append([]) 288 | for i, parent in enumerate(self._parents): 289 | if parent != -1: 290 | self._children[parent].append(i) 291 | 292 | def forward(self, rotations, root_positions): 293 | """ 294 | Perform forward kinematics using the given trajectory and local rotations. 295 | Arguments (where N = batch size, L = sequence length, J = number of joints): 296 | -- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint. 297 | -- root_positions: (N, L, 3) tensor describing the root joint positions. 
298 | """ 299 | assert len(rotations.shape) == 4 300 | assert len(root_positions.shape) == 3 301 | # transform from axis angle to quaternion 302 | rotations = axis_angle_to_quaternion(rotations) 303 | 304 | positions_world = [] 305 | rotations_world = [] 306 | 307 | expanded_offsets = self._offsets.expand( 308 | rotations.shape[0], 309 | rotations.shape[1], 310 | self._offsets.shape[0], 311 | self._offsets.shape[1], 312 | ) 313 | 314 | # Parallelize along the batch and time dimensions 315 | for i in range(self._offsets.shape[0]): 316 | if self._parents[i] == -1: 317 | positions_world.append(root_positions) 318 | rotations_world.append(rotations[:, :, 0]) 319 | else: 320 | positions_world.append( 321 | quaternion_apply( 322 | rotations_world[self._parents[i]], expanded_offsets[:, :, i] 323 | ) 324 | + positions_world[self._parents[i]] 325 | ) 326 | if self._has_children[i]: 327 | rotations_world.append( 328 | quaternion_multiply( 329 | rotations_world[self._parents[i]], rotations[:, :, i] 330 | ) 331 | ) 332 | else: 333 | # This joint is a terminal node -> it would be useless to compute the transformation 334 | rotations_world.append(None) 335 | 336 | return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2) 337 | -------------------------------------------------------------------------------- /train_unet_latent.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py 2 | # Reference: https://github.com/teticio/audio-diffusion/ 3 | # Training code for latent diffusion model 4 | 5 | import argparse 6 | import os 7 | import pickle 8 | import random 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from datasets import load_dataset, load_from_disk 18 | from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler, 19 | UNet2DConditionModel, UNet2DModel) 20 | from diffusers.optimization import get_scheduler 21 | from diffusers.pipelines.audio_diffusion import Mel 22 | from diffusers.training_utils import EMAModel 23 | from huggingface_hub import HfFolder, Repository, whoami 24 | from librosa.util import normalize 25 | from torchvision.transforms import Compose, Normalize, ToTensor 26 | from tqdm.auto import tqdm 27 | 28 | from audiodiffusion.pipeline_audio_diffusion import AudioDiffusionPipeline 29 | 30 | logger = get_logger(__name__) 31 | 32 | 33 | def get_full_repo_name(model_id: str, 34 | organization: Optional[str] = None, 35 | token: Optional[str] = None): 36 | if token is None: 37 | token = HfFolder.get_token() 38 | if organization is None: 39 | username = whoami(token)["name"] 40 | return f"{username}/{model_id}" 41 | else: 42 | return f"{organization}/{model_id}" 43 | 44 | 45 | def main(args): 46 | output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir 47 | logging_dir = os.path.join(output_dir, args.logging_dir) 48 | accelerator = Accelerator( 49 | gradient_accumulation_steps=args.gradient_accumulation_steps, 50 | mixed_precision=args.mixed_precision, 51 | log_with="tensorboard", 52 | project_dir=logging_dir, 53 | ) 54 | 55 | if args.dataset_name is not None: 56 | if os.path.exists(args.dataset_name): 57 | dataset = load_from_disk( 58 | args.dataset_name, 59 | storage_options=args.dataset_config_name)["train"] 60 | else: 61 | dataset = load_dataset( 62 | 
args.dataset_name,
63 |                 args.dataset_config_name,
64 |                 cache_dir=args.cache_dir,
65 |                 use_auth_token=True if args.use_auth_token else None,
66 |                 split="train",
67 |             )
68 |     else:
69 |         dataset = load_dataset(
70 |             "imagefolder",
71 |             data_dir=args.train_data_dir,
72 |             cache_dir=args.cache_dir,
73 |             split="train",
74 |         )
75 |     # Determine image resolution
76 |     resolution = dataset[0]["image"].height, dataset[0]["image"].width
77 | 
78 |     augmentations = Compose([
79 |         ToTensor(),
80 |         Normalize([0.5], [0.5]),
81 |     ])
82 | 
83 |     def transforms(examples):
84 |         if args.vae is not None and vqvae.config["in_channels"] == 3:
85 |             images = [
86 |                 augmentations(image.convert("RGB"))
87 |                 for image in examples["image"]
88 |             ]
89 |         else:
90 |             images = [augmentations(image) for image in examples["image"]]
91 |         if args.encodings is not None:
92 |             encoding = [encodings[os.path.splitext(os.path.basename(file))[0]] for file in examples["audio_file"]]
93 |             print(np.array(encoding).shape)
94 |             encoding = np.array(encoding).reshape(1, 150, 226)
95 |             return {"input": images, "encoding": torch.Tensor(encoding)}
96 |         return {"input": images}
97 | 
98 |     dataset.set_transform(transforms)
99 |     train_dataloader = torch.utils.data.DataLoader(
100 |         dataset, batch_size=args.train_batch_size, shuffle=True)
101 | 
102 |     if args.encodings is not None:
103 |         encodings = pickle.load(open(args.encodings, "rb"))
104 | 
105 |     vqvae = None
106 |     if args.vae is not None:
107 |         try:
108 |             vqvae = AutoencoderKL.from_pretrained(args.vae)
109 |         except EnvironmentError:
110 |             vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
111 |         # Determine latent resolution
112 |         with torch.no_grad():
113 |             latent_resolution = vqvae.encode(
114 |                 torch.zeros((1, 1) +
115 |                             resolution)).latent_dist.sample().shape[2:]
116 | 
117 |     if args.from_pretrained is not None:
118 |         pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
119 |         mel = pipeline.mel
120 |         model = pipeline.unet
121 |         if hasattr(pipeline, "vqvae"):
122 |             vqvae = pipeline.vqvae
123 | 
124 |     else:
125 |         if args.encodings is None:
126 |             model = UNet2DModel(
127 |                 sample_size=resolution if vqvae is None else latent_resolution,
128 |                 in_channels=1
129 |                 if vqvae is None else vqvae.config["latent_channels"],
130 |                 out_channels=1
131 |                 if vqvae is None else vqvae.config["latent_channels"],
132 |                 layers_per_block=2,
133 |                 block_out_channels=(128, 128, 256, 256, 512, 512),
134 |                 down_block_types=(
135 |                     "DownBlock2D",
136 |                     "DownBlock2D",
137 |                     "DownBlock2D",
138 |                     "DownBlock2D",
139 |                     "AttnDownBlock2D",
140 |                     "DownBlock2D",
141 |                 ),
142 |                 up_block_types=(
143 |                     "UpBlock2D",
144 |                     "AttnUpBlock2D",
145 |                     "UpBlock2D",
146 |                     "UpBlock2D",
147 |                     "UpBlock2D",
148 |                     "UpBlock2D",
149 |                 ),
150 |             )
151 | 
152 |         else:
153 |             model = UNet2DConditionModel(
154 |                 sample_size=resolution if vqvae is None else latent_resolution,
155 |                 in_channels=1
156 |                 if vqvae is None else vqvae.config["latent_channels"],
157 |                 out_channels=1
158 |                 if vqvae is None else vqvae.config["latent_channels"],
159 |                 layers_per_block=2,
160 |                 block_out_channels=(128, 256, 512, 512),
161 |                 down_block_types=(
162 |                     "CrossAttnDownBlock2D",
163 |                     "CrossAttnDownBlock2D",
164 |                     "CrossAttnDownBlock2D",
165 |                     "DownBlock2D",
166 |                 ),
167 |                 up_block_types=(
168 |                     "UpBlock2D",
169 |                     "CrossAttnUpBlock2D",
170 |                     "CrossAttnUpBlock2D",
171 |                     "CrossAttnUpBlock2D",
172 |                 ),
173 | 
174 |                 cross_attention_dim=list(encodings.values())[0].shape[-1],
175 |                 #cross_attention_dim = 226,
176 |             )
177 | 
178 |     if args.scheduler == "ddpm":
179 | noise_scheduler = DDPMScheduler( 180 | num_train_timesteps=args.num_train_steps) 181 | else: 182 | noise_scheduler = DDIMScheduler( 183 | num_train_timesteps=args.num_train_steps) 184 | 185 | optimizer = torch.optim.AdamW( 186 | model.parameters(), 187 | lr=args.learning_rate, 188 | betas=(args.adam_beta1, args.adam_beta2), 189 | weight_decay=args.adam_weight_decay, 190 | eps=args.adam_epsilon, 191 | ) 192 | 193 | lr_scheduler = get_scheduler( 194 | args.lr_scheduler, 195 | optimizer=optimizer, 196 | num_warmup_steps=args.lr_warmup_steps, 197 | num_training_steps=(len(train_dataloader) * args.num_epochs) // 198 | args.gradient_accumulation_steps, 199 | ) 200 | 201 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 202 | model, optimizer, train_dataloader, lr_scheduler) 203 | 204 | ema_model = EMAModel( 205 | getattr(model, "module", model), 206 | inv_gamma=args.ema_inv_gamma, 207 | power=args.ema_power, 208 | max_value=args.ema_max_decay, 209 | ) 210 | 211 | if args.push_to_hub: 212 | if args.hub_model_id is None: 213 | repo_name = get_full_repo_name(Path(output_dir).name, 214 | token=args.hub_token) 215 | else: 216 | repo_name = args.hub_model_id 217 | repo = Repository(output_dir, clone_from=repo_name) 218 | 219 | if accelerator.is_main_process: 220 | run = os.path.split(__file__)[-1].split(".")[0] 221 | accelerator.init_trackers(run) 222 | 223 | mel = Mel( 224 | x_res=resolution[1], 225 | y_res=resolution[0], 226 | hop_length=args.hop_length, 227 | sample_rate=args.sample_rate, 228 | n_fft=args.n_fft, 229 | ) 230 | 231 | global_step = 0 232 | for epoch in range(args.num_epochs): 233 | progress_bar = tqdm(total=len(train_dataloader), 234 | disable=not accelerator.is_local_main_process) 235 | progress_bar.set_description(f"Epoch {epoch}") 236 | 237 | if epoch < args.start_epoch: 238 | for step in range(len(train_dataloader)): 239 | optimizer.step() 240 | lr_scheduler.step() 241 | progress_bar.update(1) 242 | global_step += 1 243 | if epoch == args.start_epoch - 1 and args.use_ema: 244 | ema_model.optimization_step = global_step 245 | continue 246 | 247 | model.train() 248 | for step, batch in enumerate(train_dataloader): 249 | clean_images = batch["input"] 250 | 251 | if vqvae is not None: 252 | vqvae.to(clean_images.device) 253 | with torch.no_grad(): 254 | clean_images = vqvae.encode( 255 | clean_images).latent_dist.sample() 256 | # Scale latent images to ensure approximately unit variance 257 | clean_images = clean_images * 0.18215 258 | 259 | # Sample noise that we'll add to the images 260 | noise = torch.randn(clean_images.shape).to(clean_images.device) 261 | bsz = clean_images.shape[0] 262 | # Sample a random timestep for each image 263 | timesteps = torch.randint( 264 | 0, 265 | noise_scheduler.config.num_train_timesteps, 266 | (bsz, ), 267 | device=clean_images.device, 268 | ).long() 269 | 270 | # Add noise to the clean images according to the noise magnitude at each timestep 271 | # (this is the forward diffusion process) 272 | noisy_images = noise_scheduler.add_noise(clean_images, noise, 273 | timesteps) 274 | 275 | with accelerator.accumulate(model): 276 | # Predict the noise residual 277 | if args.encodings is not None: 278 | noise_pred = model(noisy_images, timesteps, 279 | batch["encoding"])["sample"] 280 | else: 281 | noise_pred = model(noisy_images, timesteps)["sample"] 282 | loss = F.mse_loss(noise_pred, noise) 283 | accelerator.backward(loss) 284 | 285 | if accelerator.sync_gradients: 286 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 
287 | optimizer.step() 288 | lr_scheduler.step() 289 | if args.use_ema: 290 | ema_model.step(model) 291 | optimizer.zero_grad() 292 | 293 | progress_bar.update(1) 294 | global_step += 1 295 | 296 | logs = { 297 | "loss": loss.detach().item(), 298 | "lr": lr_scheduler.get_last_lr()[0], 299 | "step": global_step, 300 | } 301 | if args.use_ema: 302 | logs["ema_decay"] = ema_model.decay 303 | progress_bar.set_postfix(**logs) 304 | accelerator.log(logs, step=global_step) 305 | progress_bar.close() 306 | 307 | accelerator.wait_for_everyone() 308 | 309 | # Generate sample images for visual inspection 310 | if accelerator.is_main_process: 311 | if ((epoch + 1) % args.save_model_epochs == 0 312 | or (epoch + 1) % args.save_images_epochs == 0 313 | or epoch == args.num_epochs - 1): 314 | unet = accelerator.unwrap_model(model) 315 | if args.use_ema: 316 | ema_model.copy_to(unet.parameters()) 317 | pipeline = AudioDiffusionPipeline( 318 | vqvae=vqvae, 319 | unet=unet, 320 | mel=mel, 321 | scheduler=noise_scheduler, 322 | ) 323 | 324 | if ( 325 | epoch + 1 326 | ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: 327 | pipeline.save_pretrained(output_dir) 328 | 329 | # save the model 330 | if args.push_to_hub: 331 | repo.push_to_hub( 332 | commit_message=f"Epoch {epoch}", 333 | blocking=False, 334 | auto_lfs_prune=True, 335 | ) 336 | 337 | if (epoch + 1) % args.save_images_epochs == 0: 338 | generator = torch.Generator( 339 | device=clean_images.device).manual_seed(42) 340 | 341 | if args.encodings is not None: 342 | random.seed(42) 343 | encoding = random.sample(list(np.array(list(encodings.values()))), args.eval_batch_size) 344 | encoding = np.array(encoding).reshape(args.eval_batch_size, 150, 226) 345 | encoding = torch.Tensor(encoding).to(clean_images.device) 346 | else: 347 | encoding = None 348 | 349 | # run pipeline in inference (sample random noise and denoise) 350 | images, (sample_rate, audios) = pipeline( 351 | generator=generator, 352 | batch_size=args.eval_batch_size, 353 | return_dict=False, 354 | encoding=encoding, 355 | ) 356 | 357 | # denormalize the images and save to tensorboard 358 | images = np.array([ 359 | np.frombuffer(image.tobytes(), dtype="uint8").reshape( 360 | (len(image.getbands()), image.height, image.width)) 361 | for image in images 362 | ]) 363 | accelerator.trackers[0].writer.add_images( 364 | "test_samples", images, epoch) 365 | for _, audio in enumerate(audios): 366 | accelerator.trackers[0].writer.add_audio( 367 | f"test_audio_{_}", 368 | normalize(audio), 369 | epoch, 370 | sample_rate=sample_rate, 371 | ) 372 | accelerator.wait_for_everyone() 373 | 374 | accelerator.end_training() 375 | 376 | 377 | if __name__ == "__main__": 378 | parser = argparse.ArgumentParser( 379 | description="Simple example of a training script.") 380 | parser.add_argument("--local_rank", type=int, default=-1) 381 | parser.add_argument("--dataset_name", type=str, default="aistpp_256_sorted") 382 | parser.add_argument("--dataset_config_name", type=str, default=None) 383 | parser.add_argument( 384 | "--train_data_dir", 385 | type=str, 386 | default="aistpp_256_sorted", 387 | help="A folder containing the training data.", 388 | ) 389 | parser.add_argument("--output_dir", type=str, default=r"/host_data/van/edge_aistpp/modelsv2/all_01") 390 | parser.add_argument("--overwrite_output_dir", type=bool, default=False) 391 | parser.add_argument("--cache_dir", type=str, default=None) 392 | parser.add_argument("--train_batch_size", type=int, default=8) 393 | 
parser.add_argument("--eval_batch_size", type=int, default=8) 394 | parser.add_argument("--num_epochs", type=int, default=100) 395 | parser.add_argument("--save_images_epochs", type=int, default=10) 396 | parser.add_argument("--save_model_epochs", type=int, default=10) 397 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 398 | parser.add_argument("--learning_rate", type=float, default=1e-4) 399 | parser.add_argument("--lr_scheduler", type=str, default="cosine") 400 | parser.add_argument("--lr_warmup_steps", type=int, default=500) 401 | parser.add_argument("--adam_beta1", type=float, default=0.95) 402 | parser.add_argument("--adam_beta2", type=float, default=0.999) 403 | parser.add_argument("--adam_weight_decay", type=float, default=1e-6) 404 | parser.add_argument("--adam_epsilon", type=float, default=1e-08) 405 | parser.add_argument("--use_ema", type=bool, default=True) 406 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0) 407 | parser.add_argument("--ema_power", type=float, default=3 / 4) 408 | parser.add_argument("--ema_max_decay", type=float, default=0.9999) 409 | parser.add_argument("--push_to_hub", type=bool, default=False) 410 | parser.add_argument("--use_auth_token", type=bool, default=False) 411 | parser.add_argument("--hub_token", type=str, default=None) 412 | parser.add_argument("--hub_model_id", type=str, default=None) 413 | parser.add_argument("--hub_private_repo", type=bool, default=False) 414 | parser.add_argument("--logging_dir", type=str, default="logs") 415 | parser.add_argument( 416 | "--mixed_precision", 417 | type=str, 418 | default="no", 419 | choices=["no", "fp16", "bf16"], 420 | help=( 421 | "Whether to use mixed precision. Choose" 422 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 423 | "and an Nvidia Ampere GPU."), 424 | ) 425 | parser.add_argument("--hop_length", type=int, default=512) 426 | parser.add_argument("--sample_rate", type=int, default=22050) 427 | parser.add_argument("--n_fft", type=int, default=2048) 428 | parser.add_argument("--from_pretrained", type=str, default=None) 429 | parser.add_argument("--start_epoch", type=int, default=0) 430 | parser.add_argument("--num_train_steps", type=int, default=1000) 431 | parser.add_argument("--scheduler", 432 | type=str, 433 | default="ddim", 434 | help="ddpm or ddim") 435 | parser.add_argument( 436 | "--vae", 437 | type=str, 438 | default='teticio/latent-audio-diffusion-ddim-256', 439 | help="pretrained VAE model for latent diffusion", 440 | ) 441 | parser.add_argument( 442 | "--encodings", 443 | type=str, 444 | default=r"/host_data/van/edge_aistpp/encoding/normalized_all_train_data_01.pkl", 445 | help="picked dictionary mapping audio_file to encoding", 446 | ) 447 | 448 | args = parser.parse_args() 449 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 450 | if env_local_rank != -1 and env_local_rank != args.local_rank: 451 | args.local_rank = env_local_rank 452 | 453 | if args.dataset_name is None and args.train_data_dir is None: 454 | raise ValueError( 455 | "You must specify either a dataset name from the hub or a train data directory." 456 | ) 457 | 458 | main(args) 459 | --------------------------------------------------------------------------------