├── README.md
├── config
    ├── gesture_autoencoder.yml
    ├── hierarchy.yml
    ├── joint_embed.yml
    ├── multimodal_context.yml
    ├── seq2seq.yml
    └── speech2gesture.yml
├── config_expressive
    ├── gesture_autoencoder.yml
    ├── hierarchy.yml
    ├── joint_embed.yml
    ├── multimodal_context.yml
    ├── seq2seq.yml
    └── speech2gesture.yml
├── dataset_script
    ├── README.md
    ├── requirements.txt
    ├── script
    │   ├── calc_mean.py
    │   ├── clip_filter.py
    │   ├── config.py
    │   ├── data_utils.py
    │   ├── download_video.py
    │   ├── inference.py
    │   ├── main_speaker_selector.py
    │   ├── make_ted_dataset.py
    │   ├── merge_dataset.py
    │   ├── motion_preprocessor.py
    │   ├── review_filtered_clips.py
    │   ├── run_clip_filtering.py
    │   ├── run_expose.py
    │   ├── run_ffmpeg.py
    │   ├── run_gentle.py
    │   ├── run_mp3.py
    │   ├── run_openpose.py
    │   └── run_scenedetect.py
    └── video_ids.txt
├── license
├── misc
    ├── HA2G.png
    ├── sample1.gif
    ├── sample1.mp4
    ├── sample2.gif
    └── sample2.mp4
├── requirements.txt
├── scripts
    ├── calculate_angle_stats.py
    ├── calculate_motion_stats.py
    ├── data_loader
    │   ├── data_preprocessor.py
    │   ├── data_preprocessor_expressive.py
    │   ├── h36m_loader.py
    │   ├── lmdb_data_loader.py
    │   ├── lmdb_data_loader_expressive.py
    │   ├── motion_preprocessor.py
    │   └── motion_preprocessor_expressive.py
    ├── model
    │   ├── ResNetBlocks.py
    │   ├── ResNetSE34V2.py
    │   ├── embedding_net.py
    │   ├── embedding_space_evaluator.py
    │   ├── hierarchy_net.py
    │   ├── motion_ae.py
    │   ├── multimodal_context_net.py
    │   ├── seq2seq_net.py
    │   ├── speech2gesture.py
    │   ├── tcn.py
    │   ├── utils.py
    │   └── vocab.py
    ├── parse_args.py
    ├── synthesize.py
    ├── synthesize_expressive_hierarchy.py
    ├── synthesize_hierarchy.py
    ├── train.py
    ├── train_eval
    │   ├── train_gan.py
    │   ├── train_hierarchy.py
    │   ├── train_hierarchy_expressive.py
    │   ├── train_joint_embed.py
    │   ├── train_seq2seq.py
    │   └── train_speech2gesture.py
    ├── train_expressive.py
    ├── train_feature_extractor.py
    ├── train_feature_extractor_expressive.py
    └── utils
    │   ├── average_meter.py
    │   ├── data_utils.py
    │   ├── data_utils_expressive.py
    │   ├── train_utils.py
    │   ├── train_utils_expressive.py
    │   ├── tts_helper.py
    │   └── vocab_utils.py
└── training_logs
    ├── ted_expressive_new.log
    ├── ted_expressive_original.log
    ├── ted_gesture_new.log
    └── ted_gesture_original.log
/config/gesture_autoencoder.yml: --------------------------------------------------------------------------------
1 | name: gesture_autoencoder 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | model_save_path: output/train_h36m_gesture_autoencoder 8 | random_seed: -1 9 | 10 | # model params 11 | model: gesture_autoencoder 12 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 13 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 14 | 15 | # train params 16 | epochs: 500 17 | batch_size: 128 18 | learning_rate: 0.0005 19 | 20 | # dataset params 21 |
motion_resampling_framerate: 15 22 | n_poses: 34 23 | n_pre_poses: 4 24 | subdivision_stride: 10 25 | loader_workers: 4 26 | 27 | pose_dim: 27 -------------------------------------------------------------------------------- /config/hierarchy.yml: -------------------------------------------------------------------------------- 1 | name: hierarchy 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin # from https://fasttext.cc/docs/en/english-vectors.html 9 | #freeze_wordembed: true 10 | 11 | model_save_path: TED-Gesture-output/train_hierarchy 12 | random_seed: -1 13 | 14 | # model params 15 | model: hierarchy 16 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 17 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 18 | 19 | n_layers: 4 20 | hidden_size: 300 21 | z_type: speaker # speaker, random, none 22 | input_context: both # both, audio, text, none 23 | 24 | # train params 25 | epochs: 100 26 | batch_size: 256 27 | learning_rate: 5e-4 28 | loss_regression_weight: 70.0 29 | loss_gan_weight: 5.0 30 | loss_warmup: 10 31 | loss_kld_weight: 0.1 32 | loss_reg_weight: 0.05 33 | 34 | loss_contrastive_pos_weight: 0.2 35 | loss_contrastive_neg_weight: 0.005 36 | loss_physical_weight: 0.01 37 | 38 | # eval params 39 | eval_net_path: data/train_h36m_gesture_autoencoder_gesture_autoencoder_checkpoint_best.bin 40 | 41 | # dataset params 42 | motion_resampling_framerate: 15 43 | n_poses: 34 44 | n_pre_poses: 4 45 | subdivision_stride: 10 46 | loader_workers: 8 47 | 48 | pose_dim: 27 -------------------------------------------------------------------------------- /config/joint_embed.yml: -------------------------------------------------------------------------------- 1 | name: joint_embedding 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin 9 | 10 | model_save_path: TED-Gesture-output/train_joint_embedding 11 | random_seed: -1 12 | 13 | # model params 14 | model: joint_embedding 15 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 16 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, 
-0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 17 | 18 | # train params 19 | epochs: 100 20 | batch_size: 128 21 | learning_rate: 0.0005 22 | 23 | # eval params 24 | eval_net_path: data/train_h36m_gesture_autoencoder_gesture_autoencoder_checkpoint_best.bin 25 | 26 | # dataset params 27 | motion_resampling_framerate: 15 28 | n_poses: 34 29 | n_pre_poses: 4 30 | subdivision_stride: 10 31 | loader_workers: 3 32 | 33 | pose_dim: 27 -------------------------------------------------------------------------------- /config/multimodal_context.yml: -------------------------------------------------------------------------------- 1 | name: multimodal_context 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin # from https://fasttext.cc/docs/en/english-vectors.html 9 | #freeze_wordembed: true 10 | 11 | model_save_path: TED-Gesture-output/train_multimodal_context 12 | random_seed: -1 13 | 14 | # model params 15 | model: multimodal_context 16 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 17 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 18 | 19 | n_layers: 4 20 | hidden_size: 300 21 | z_type: speaker # speaker, random, none 22 | input_context: both # both, audio, text, none 23 | 24 | # train params 25 | epochs: 100 26 | batch_size: 128 27 | learning_rate: 0.0005 28 | loss_regression_weight: 500 29 | loss_gan_weight: 5.0 30 | loss_warmup: 10 31 | loss_kld_weight: 0.1 32 | loss_reg_weight: 0.05 33 | 34 | # eval params 35 | eval_net_path: data/train_h36m_gesture_autoencoder_gesture_autoencoder_checkpoint_best.bin 36 | 37 | # dataset params 38 | motion_resampling_framerate: 15 39 | n_poses: 34 40 | n_pre_poses: 4 41 | subdivision_stride: 10 42 | loader_workers: 4 43 | 44 | pose_dim: 27 -------------------------------------------------------------------------------- /config/seq2seq.yml: -------------------------------------------------------------------------------- 1 | name: seq2seq 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin 9 | 10 | model_save_path: TED-Gesture-output/train_seq2seq 11 | random_seed: -1 12 | 13 | # model params 14 | model: seq2seq 15 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 
0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 16 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 17 | hidden_size: 200 18 | n_layers: 2 19 | dropout_prob: 0.1 20 | z_type: none 21 | 22 | # train params 23 | epochs: 100 24 | batch_size: 128 25 | learning_rate: 0.0001 26 | loss_regression_weight: 250 27 | loss_kld_weight: 0.1 # weight for continuous motion term 28 | loss_reg_weight: 25 # weight for motion variance term 29 | 30 | # eval params 31 | eval_net_path: data/train_h36m_gesture_autoencoder_gesture_autoencoder_checkpoint_best.bin 32 | 33 | # dataset params 34 | motion_resampling_framerate: 15 35 | n_poses: 34 36 | n_pre_poses: 4 37 | subdivision_stride: 10 38 | loader_workers: 3 39 | 40 | pose_dim: 27 -------------------------------------------------------------------------------- /config/speech2gesture.yml: -------------------------------------------------------------------------------- 1 | name: speech2gesture 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/ted_dataset/lmdb_test 6 | 7 | model_save_path: TED-Gesture-output/train_speech2gesture 8 | random_seed: -1 9 | 10 | # model params 11 | model: speech2gesture 12 | mean_dir_vec: [ 0.0154009, -0.9690125, -0.0884354, -0.0022264, -0.8655276, 0.4342174, -0.0035145, -0.8755367, -0.4121039, -0.9236511, 0.3061306, -0.0012415, -0.5155854, 0.8129665, 0.0871897, 0.2348464, 0.1846561, 0.8091402, 0.9271948, 0.2960011, -0.013189 , 0.5233978, 0.8092403, 0.0725451, -0.2037076, 0.1924306, 0.8196916] 13 | mean_pose: [ 0.0000306, 0.0004946, 0.0008437, 0.0033759, -0.2051629, -0.0143453, 0.0031566, -0.3054764, 0.0411491, 0.0029072, -0.4254303, -0.001311 , -0.1458413, -0.1505532, -0.0138192, -0.2835603, 0.0670333, 0.0107002, -0.2280813, 0.112117 , 0.2087789, 0.1523502, -0.1521499, -0.0161503, 0.291909 , 0.0644232, 0.0040145, 0.2452035, 0.1115339, 0.2051307] 14 | 15 | # train params 16 | epochs: 100 17 | batch_size: 128 18 | learning_rate: 0.001 19 | loss_regression_weight: 100 20 | loss_gan_weight: 10.0 21 | 22 | # eval params 23 | eval_net_path: data/train_h36m_gesture_autoencoder_gesture_autoencoder_checkpoint_best.bin 24 | 25 | # dataset params 26 | motion_resampling_framerate: 15 27 | n_poses: 34 28 | n_pre_poses: 4 29 | subdivision_stride: 10 30 | loader_workers: 3 31 | 32 | pose_dim: 27 -------------------------------------------------------------------------------- /config_expressive/gesture_autoencoder.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_pose_autoencoder 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin 9 | #freeze_wordembed: true 10 | 11 | model_save_path: TED-Expressive-output/AE-cos1e-3 12 | random_seed: -1 13 | 14 | # model params 15 | model: multimodal_context 16 | 
mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 17 | # mean_dir_vec: [-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 , 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 0.7889246, 0.3254174, 0.1292426,0.7991108] 18 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , 
-0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 19 | 20 | n_layers: 4 21 | hidden_size: 300 22 | z_type: speaker # speaker, random, none 23 | input_context: both # both, audio, text, none 24 | 25 | # train params 26 | epochs: 180 27 | batch_size: 64 28 | learning_rate: 0.0005 29 | loss_regression_weight: 500 30 | 31 | mse_loss_weight: 60000 32 | cos_loss_weight: 0.001 33 | static_loss_weight: 0 34 | motion_loss_weight: 900 35 | 36 | g_update_step: 5 37 | 38 | loss_gan_weight: 5.0 39 | loss_warmup: 10 40 | loss_kld_weight: 0.1 41 | loss_reg_weight: 0.05 42 | 43 | # eval params 44 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 45 | 46 | # dataset params 47 | motion_resampling_framerate: 15 48 | n_poses: 34 49 | n_pre_poses: 4 50 | subdivision_stride: 30 51 | loader_workers: 16 52 | 53 | pose_dim: 126 54 | latent_dim: 128 -------------------------------------------------------------------------------- /config_expressive/hierarchy.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_hierarchy 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin # from https://fasttext.cc/docs/en/english-vectors.html 9 | #freeze_wordembed: true 10 | 11 | model_save_path: TED-Expressive-output/train_hierarchy 12 | random_seed: -1 13 | 14 | # model params 15 | model: hierarchy 16 | mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, 
-0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 17 | # mean_dir_vec: [-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 , 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 0.7889246, 0.3254174, 0.1292426,0.7991108] 18 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 19 | 20 | n_layers: 4 21 | hidden_size: 300 22 | z_type: speaker # speaker, random, none 23 | input_context: both # both, audio, text, none 24 | 25 | # train params 26 | epochs: 100 27 | 
batch_size: 96 28 | learning_rate: 0.0005 29 | loss_regression_weight: 250 30 | loss_gan_weight: 5.0 31 | loss_warmup: 10 32 | loss_kld_weight: 0.1 33 | loss_reg_weight: 0.05 34 | 35 | loss_contrastive_pos_weight: 0.2 36 | loss_contrastive_neg_weight: 0.005 37 | loss_physical_weight: 0.01 38 | 39 | # eval params 40 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 41 | 42 | # dataset params 43 | motion_resampling_framerate: 15 44 | n_poses: 34 45 | n_pre_poses: 4 46 | subdivision_stride: 10 47 | loader_workers: 4 48 | 49 | pose_dim: 126 -------------------------------------------------------------------------------- /config_expressive/joint_embed.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_joint_embedding 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin 9 | 10 | model_save_path: TED-Expressive-output/train_joint_embedding 11 | random_seed: -1 12 | 13 | # model params 14 | model: joint_embedding 15 | mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 16 | # mean_dir_vec: [-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 
, 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 0.7889246, 0.3254174, 0.1292426,0.7991108] 17 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 18 | 19 | # train params 20 | epochs: 100 21 | batch_size: 128 22 | learning_rate: 0.0005 23 | 24 | # eval params 25 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 26 | 27 | # dataset params 28 | motion_resampling_framerate: 15 29 | n_poses: 34 30 | n_pre_poses: 4 31 | subdivision_stride: 10 32 | loader_workers: 3 33 | 34 | pose_dim: 126 -------------------------------------------------------------------------------- /config_expressive/multimodal_context.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_multimodal_context 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin # from https://fasttext.cc/docs/en/english-vectors.html 9 | #freeze_wordembed: true 10 | 11 | model_save_path: TED-Expressive-output/train_multimodal_context 12 | random_seed: -1 13 | 14 | # model params 15 | model: multimodal_context 16 | mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, 
-0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 17 | # mean_dir_vec: [-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 , 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 0.7889246, 0.3254174, 0.1292426,0.7991108] 18 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 
0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 19 | 20 | n_layers: 4 21 | hidden_size: 300 22 | z_type: speaker # speaker, random, none 23 | input_context: both # both, audio, text, none 24 | 25 | # train params 26 | epochs: 100 27 | batch_size: 128 28 | learning_rate: 0.0005 29 | loss_regression_weight: 500 30 | loss_gan_weight: 5.0 31 | loss_warmup: 10 32 | loss_kld_weight: 0.1 33 | loss_reg_weight: 0.05 34 | 35 | # eval params 36 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 37 | 38 | # dataset params 39 | motion_resampling_framerate: 15 40 | n_poses: 34 41 | n_pre_poses: 4 42 | subdivision_stride: 10 43 | loader_workers: 4 44 | 45 | pose_dim: 126 -------------------------------------------------------------------------------- /config_expressive/seq2seq.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_seq2seq 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: data/fasttext/crawl-300d-2M-subword.bin 9 | 10 | model_save_path: TED-Expressive-output/train_seq2seq 11 | random_seed: -1 12 | 13 | # model params 14 | model: seq2seq 15 | mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 16 | # mean_dir_vec: 
[-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 , 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 0.7889246, 0.3254174, 0.1292426,0.7991108] 17 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 18 | 19 | hidden_size: 200 20 | n_layers: 2 21 | dropout_prob: 0.1 22 | z_type: none 23 | 24 | # train params 25 | epochs: 100 26 | batch_size: 128 27 | learning_rate: 0.0001 28 | loss_regression_weight: 250 29 | loss_kld_weight: 0.1 # weight for continuous motion term 30 | loss_reg_weight: 25 # weight for motion variance term 31 | 32 | # eval params 33 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 34 | 35 | # dataset params 36 | motion_resampling_framerate: 15 37 | n_poses: 34 38 | n_pre_poses: 4 39 | subdivision_stride: 10 40 | loader_workers: 3 41 | 42 | pose_dim: 126 
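Each of the configs above is plain YAML. For orientation, `pose_dim` is the flattened dimensionality of the direction-vector pose representation (27 = 9 vectors x 3 in the TED Gesture configs, 126 = 42 vectors x 3 in the TED Expressive configs), and `mean_dir_vec` / `mean_pose` are the corresponding flat lists. A minimal sketch of loading one config and reshaping these vectors, assuming PyYAML is available; the path and usage are illustrative and not part of the repository code:
```python
import numpy as np
import yaml  # PyYAML (assumed available)

# Illustrative only: parse one config and reshape the flat mean vectors.
with open("config_expressive/seq2seq.yml") as f:
    cfg = yaml.safe_load(f)

mean_dir_vec = np.asarray(cfg["mean_dir_vec"], dtype=np.float32).reshape(-1, 3)
mean_pose = np.asarray(cfg["mean_pose"], dtype=np.float32).reshape(-1, 3)

# mean_dir_vec holds pose_dim values, i.e. one 3D direction vector per bone.
print(cfg["pose_dim"], mean_dir_vec.shape, mean_pose.shape)
```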
-------------------------------------------------------------------------------- /config_expressive/speech2gesture.yml: -------------------------------------------------------------------------------- 1 | name: ted_expressive_speech2gesture 2 | 3 | train_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/train 4 | val_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/val 5 | test_data_path: /mnt/lustressd/share/liuxian.vendor/complex_dataset/test 6 | 7 | model_save_path: TED-Expressive-output/train_speech2gesture 8 | random_seed: -1 9 | 10 | # model params 11 | model: speech2gesture 12 | mean_pose: [-0.0046788, -0.5397806, 0.007695 , -0.0171913, -0.7060388,-0.0107034, 0.1550734, -0.6823077, -0.0303645, -0.1514748, -0.6819547, -0.0268262, 0.2094328, -0.469447 , -0.0096073, -0.2318253, -0.4680838, -0.0444074, 0.1667382, -0.4643363, -0.1895118, -0.1648597, -0.4552845, -0.2159728, 0.1387546, -0.4859474, -0.2506667, 0.1263615, -0.4856088, -0.2675801, 0.1149031, -0.4804542, -0.267329 , 0.1414847, -0.4727709, -0.2583424, 0.1262482, -0.4686185, -0.2682536, 0.1150217, -0.4633611, -0.2640182, 0.1475897, -0.4415648, -0.2438853, 0.1367996, -0.4383164, -0.248248 , 0.1267222, -0.435534 , -0.2455436, 0.1455485, -0.4557491, -0.2521977, 0.1305471, -0.4535603, -0.2611591, 0.1184687, -0.4495366, -0.257798 , 0.1451682, -0.4802511, -0.2081622, 0.1301337, -0.4865308, -0.2175783, 0.1208341, -0.4932623, -0.2311025, -0.1409241,-0.4742868, -0.2795303, -0.1287992, -0.4724431, -0.2963172,-0.1159225, -0.4676439, -0.2948754, -0.1427748, -0.4589126,-0.2861245, -0.126862 , -0.4547355, -0.2962466, -0.1140265,-0.451308 , -0.2913815, -0.1447202, -0.4260471, -0.2697673,-0.1333492, -0.4239912, -0.2738043, -0.1226859, -0.4238346,-0.2706725, -0.1446909, -0.440342 , -0.2789209, -0.1291436,-0.4391063, -0.2876539, -0.1160435, -0.4376317, -0.2836147,-0.1441438, -0.4729031, -0.2355619, -0.1293268, -0.4793807,-0.2468831, -0.1204146, -0.4847246, -0.2613876, -0.0056085,-0.9224338, -0.1677302, -0.0352157, -0.963936 , -0.1388849,0.0236298, -0.9650772, -0.1385154, -0.0697098, -0.9514691,-0.055632 , 0.0568838, -0.9565502, -0.0567985] 13 | # mean_dir_vec: [-0.0738017, -0.9841649, -0.1082826, 0.8811584, 0.2400715,-0.102548 , -0.8614666, 0.3131963, -0.1039408, 0.2094324,0.9249059, 0.0824935, -0.168881 , -0.035383 , -0.7481456,-0.2794113, -0.2494704, -0.611501 , -0.3870275, 0.0050056,-0.528493 , -0.507089 , 0.2256335, 0.0053114, -0.2392472,-0.1022134, -0.6528637, -0.4960522, 0.1227789, -0.3288104, -0.4734803, 0.2132035, 0.174309 , -0.2060902, 0.2305164, -0.5872575, -0.5408988, 0.1303091, -0.2180242, -0.5190851,0.1211129, 0.1337597, -0.2163111, 0.0743368, -0.6415168,-0.5247244, 0.045758 , -0.3192654, -0.5047838, 0.1537042,0.1365917, -0.4349654, -0.3833296, -0.3847372, -0.4905643, -0.241614 , -0.3053538, -0.3553988, -0.2816598, -0.5148801,-0.3064362, 0.8918868, -0.0671329, 0.27626 , 0.006997 ,-0.7267295, 0.2420451, -0.2257494, -0.6342812, 0.3782409,0.0283403, -0.5432504, 0.571155 , 0.1934389, 0.0632587,0.2121533, -0.0624191, -0.6689178, 0.5177199, 0.1043125,-0.3447495, 0.5414736, 0.1280138, 0.2073257, 0.2196023,0.2821322, -0.578687 , 0.5695349, 0.0786589, -0.2132052,0.5497312, -0.0006156, 0.1598387, 0.2091917, 0.1241196,-0.6456096, 0.542575 , 0.0114165, -0.32 , 0.5477947,0.0489276, 0.1676565, 0.418691 , -0.401504 , -0.3910536,0.4823198, -0.2667319, -0.3554963, 0.341438 , -0.2418989,-0.548876 , 0.0485504, -0.6335777, -0.6837655, -0.470992 ,-0.6394692, 0.4633637, 0.4539758, -0.6504886, 0.4596563,-0.3254217, 0.1883226, 
0.7889246, 0.3254174, 0.1292426,0.7991108] 14 | mean_dir_vec: [-0.0737964, -0.9968923, -0.1082858, 0.9111595, 0.2399522, -0.102547 , -0.8936886, 0.3131501, -0.1039348, 0.2093927, 0.958293 , 0.0824881, -0.1689021, -0.0353824, -0.7588258, -0.2794763, -0.2495191, -0.614666 , -0.3877234, 0.005006 , -0.5301695, -0.5098616, 0.2257808, 0.0053111, -0.2393621, -0.1022204, -0.6583039, -0.4992898, 0.1228059, -0.3292085, -0.4753748, 0.2132857, 0.1742853, -0.2062069, 0.2305175, -0.5897119, -0.5452555, 0.1303197, -0.2181693, -0.5221036, 0.1211322, 0.1337591, -0.2164441, 0.0743345, -0.6464546, -0.5284583, 0.0457585, -0.319634 , -0.5074904, 0.1537192, 0.1365934, -0.4354402, -0.3836682, -0.3850554, -0.4927187, -0.2417618, -0.3054556, -0.3556116, -0.281753 , -0.5164358, -0.3064435, 0.9284261, -0.067134 , 0.2764367, 0.006997 , -0.7365526, 0.2421269, -0.225798 , -0.6387642, 0.3788997, 0.0283412, -0.5451686, 0.5753376, 0.1935219, 0.0632555, 0.2122412, -0.0624179, -0.6755542, 0.5212831, 0.1043523, -0.345288 , 0.5443628, 0.128029 , 0.2073687, 0.2197118, 0.2821399, -0.580695 , 0.573988 , 0.0786667, -0.2133071, 0.5532452, -0.0006157, 0.1598754, 0.2093099, 0.124119, -0.6504359, 0.5465003, 0.0114155, -0.3203954, 0.5512083, 0.0489287, 0.1676814, 0.4190787, -0.4018607, -0.3912126, 0.4841548, -0.2668508, -0.3557675, 0.3416916, -0.2419564, -0.5509825, 0.0485515, -0.6343101, -0.6817347, -0.4705639, -0.6380668, 0.4641643, 0.4540192, -0.6486361, 0.4604001, -0.3256226, 0.1883097, 0.8057457, 0.3257385, 0.1292366, 0.815372] 15 | 16 | # train params 17 | epochs: 100 18 | batch_size: 128 19 | learning_rate: 0.001 20 | loss_regression_weight: 100 21 | loss_gan_weight: 10.0 22 | 23 | # eval params 24 | eval_net_path: TED-Expressive-output/AE-cos1e-3/checkpoint_best.bin 25 | 26 | # dataset params 27 | motion_resampling_framerate: 15 28 | n_poses: 34 29 | n_pre_poses: 4 30 | subdivision_stride: 10 31 | loader_workers: 3 32 | 33 | pose_dim: 126 -------------------------------------------------------------------------------- /dataset_script/README.md: -------------------------------------------------------------------------------- 1 | # TED Expressive Dataset 2 | 3 | This folder contains the scripts to build *TED Expressive Dataset*. 4 | You can download Youtube videos and transcripts, divide the videos into scenes, and extract human poses. Note that this dataset is built upon *TED Gesture Dataset* by Yoon et al., where we extend the pose annotations of 3D finger keypoints. 5 | Please see the project page and paper for more details. 6 | 7 | [Project](https://alvinliu0.github.io/projects/HA2G) | [Paper](https://arxiv.org/pdf/2203.13161.pdf) | [Demo](https://www.youtube.com/watch?v=CG632W-nIWk) 8 | 9 | ## Environment 10 | 11 | The scripts are tested on Ubuntu 16.04 LTS and Python 3.5.2. 12 | 13 | #### Dependencies 14 | * [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) (v1.4) for pose estimation 15 | * [ExPose](https://github.com/vchoutas/expose) for 3d pose estimation 16 | * [PySceneDetect](https://pyscenedetect.readthedocs.io/en/latest/) (v0.5) for video scene segmentation 17 | * [OpenCV](https://pypi.org/project/opencv-python/) (v3.4) for video read 18 | * We use FFMPEG. Use the latest pip version of opencv-python or build OpenCV with FFMPEG. 19 | * [Gentle](https://github.com/lowerquality/gentle) (Jan. 2019 version) for transcript alignment 20 | * Download the source code from Gentle github and run ./install.sh. And then, you can import gentle library by specifying the path to the library. See `run_gentle.py`. 
21 | * Add an option `-vn` to resample.py in gentle as follows:
22 | ```python
23 | cmd = [
24 |     FFMPEG,
25 |     '-loglevel', 'panic',
26 |     '-y',
27 | ] + offset + [
28 |     '-i', infile,
29 | ] + duration + [
30 |     '-vn', # ADDED (it blocks video streams, see the ffmpeg option)
31 |     '-ac', '1', '-ar', '8000',
32 |     '-acodec', 'pcm_s16le',
33 |     outfile
34 | ]
35 | ```
36 |
37 | ## A step-by-step guide
38 |
39 | 1. Set config
40 | * Update the paths and the youtube developer key in `config.py` (the directories will be created if they do not exist).
41 | * Update the target channel ID. The scripts are tested for the TED and LaughFactory channels.
42 |
43 | 2. Execute `download_video.py`
44 | * Download youtube videos, metadata, and subtitles (./videos_ted/*.mp4, *.json, *.vtt).
45 |
46 | 3. Execute `run_mp3.py`
47 | * Extract the audio from the video files with ffmpeg (./audio_ted/*.mp3).
48 |
49 | 4. Execute `run_openpose.py`
50 | * Run [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) to extract body, hand, and face skeletons for all videos (./temp_skeleton_raw/vid/keypoints/*.json).
51 |
52 | 5. Execute `run_ffmpeg.py`
53 | * Since ExPose requires both the raw images and the OpenPose keypoint json files for inference, we first extract all raw image frames via ffmpeg (./temp_skeleton_raw/vid/images/*.png).
54 |
55 | 6. Execute `run_expose.py`
56 | * Run [ExPose](https://github.com/vchoutas/expose) to extract 3D human body, hand (including fingers), and face skeletons for all videos (./expose_ted/vid/*.npz). A short example of reading the resulting `.npz` files is given after this guide.
57 | * Note that during our implementation we failed to set up the open3d environment required by ExPose on the slurm cluster without sudo, so we modified the inference code to avoid this dependency. Besides, the output format of ExPose is slightly changed to better facilitate dataset building (i.e., the estimated camera parameters are additionally saved for the 3D keypoint visualization in Step 10). You can substitute the original `inference.py` under the ExPose directory with the [modified version](https://github.com/alvinliu0/HA2G/blob/main/dataset_script/script/inference.py).
58 |
59 | 7. Execute `run_scenedetect.py`
60 | * Run [PySceneDetect](https://pyscenedetect.readthedocs.io/en/latest/) to divide videos into scene clips (./clip_ted/*.csv).
61 |
62 | 8. Execute `run_gentle.py`
63 | * Run [Gentle](https://github.com/lowerquality/gentle) for word-level alignments (./videos_ted/*_align_results.json).
64 | * You should skip this step if you use auto-generated subtitles. This step is necessary for the TED Talks channel.
65 |
66 | 9. Execute `run_clip_filtering.py`
67 | * Remove inappropriate clips.
68 | * Save clips with body skeletons (./filter_res/vid/*.json).
69 |
70 | 10. *(optional)* Execute `review_filtered_clips.py`
71 | * Review the filtering results. Unlike the original process, which visualizes the 2D keypoints on the image, we additionally support visualizing the 3D keypoints extracted by ExPose from the saved coordinates and camera parameters.
72 |
73 |
75 |
76 | 11. Execute `make_ted_dataset.py`
77 | * Do some post-processing and split the data into train, validation, and test sets (./whole_output/*.pickle).
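For reference, each per-frame `.npz` file written in step 6 can be inspected directly with NumPy. The sketch below mirrors the projection applied in `script/clip_filter.py` (its keys `proj_joints`, `center`, `focal_length_in_mm`, and `focal_length_in_px`); the file path is a placeholder.

```python
import numpy as np

data = np.load("./expose_ted/<vid>/00001.npz")  # placeholder path
proj_joints = data["proj_joints"]               # joints projected by ExPose
center = data["center"]
mm = data["focal_length_in_mm"]
px = data["focal_length_in_px"]

# Convert to pixel coordinates, as done in clip_filter.py.
proj_joints = proj_joints * px * 10 / mm + center

# Upper-body joint subset used by the dataset scripts (see calc_mean.py).
upper_body = np.vstack((proj_joints[9], proj_joints[12], proj_joints[16:22],
                        proj_joints[55:60], proj_joints[66:76]))
print(upper_body.shape)  # 23 selected joints
```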
78 |
79 | Note: Since single-threaded execution of the whole data pre-processing is quite time-consuming, you can run it in a multi-processing manner by splitting the vid range, i.e., processing a subset of the files in each run:
80 |
81 | ```python
82 | all_file_list = sorted(glob.glob(path_to_files_that_you_want_to_process), key=os.path.getmtime)
83 | subset_file_list = all_file_list[start_idx:end_idx]
84 | for each_file in subset_file_list:
85 |     # execute the processing code here
86 | ```
87 |
88 | In this way you obtain multiple dataset subset files; you can merge them into a single pickle file and finally convert it to the lmdb format used in our paper's implementation. A sample merge script is given in `merge_dataset.py`. You may need to modify it according to your dataset split implementation.
89 |
90 | ## Pre-built TED gesture dataset
91 |
92 | Running the whole data collection pipeline is complex and takes several days, so we provide a pre-built dataset for the videos in the TED channel.
93 |
94 | [OneDrive Download Link](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155165198_link_cuhk_edu_hk/EQhOOXYsZDhJs-oEVwA7oyABSrkwcTKC6kwX-A85r0-42g?e=BiIsV1)
95 |
96 | ### Download videos and transcripts
97 | We do not provide the videos and transcripts of TED talks due to copyright issues.
98 | You should download the actual videos and transcripts yourself as follows:
99 | 1. Download the [[video_ids.txt]](https://github.com/alvinliu0/HA2G/blob/main/dataset_script/video_ids.txt) file, which contains the video ids, and copy it into the `./videos_ted` directory.
100 | 2. Run `download_video.py`. It downloads the videos and transcripts listed in `video_ids.txt`.
101 | Some videos may not match the extracted poses that we provide if the videos have been re-uploaded.
102 | Please compare the numbers of frames, just in case.
103 |
104 |
105 | ## Citation
106 |
107 | If you find our code or data useful, please kindly cite our work as:
108 | ```
109 | @inproceedings{liu2022learning,
110 | title={Learning Hierarchical Cross-Modal Association for Co-Speech Gesture Generation},
111 | author={Liu, Xian and Wu, Qianyi and Zhou, Hang and Xu, Yinghao and Qian, Rui and Lin, Xinyi and Zhou, Xiaowei and Wu, Wayne and Dai, Bo and Zhou, Bolei},
112 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
113 | pages={10462--10472},
114 | year={2022}
115 | }
116 | ```
117 |
118 | Since the dataset is built upon the previous work of Yoon et al., we also kindly ask you to cite their great paper:
119 | ```
120 | @INPROCEEDINGS{
121 | yoonICRA19,
122 | title={Robots Learn Social Skills: End-to-End Learning of Co-Speech Gesture Generation for Humanoid Robots},
123 | author={Yoon, Youngwoo and Ko, Woo-Ri and Jang, Minsu and Lee, Jaeyeon and Kim, Jaehong and Lee, Geehyuk},
124 | booktitle={Proc. of The International Conference in Robotics and Automation (ICRA)},
125 | year={2019}
126 | }
127 | ```
128 |
129 |
130 | ## Acknowledgement
131 | * Part of the dataset establishment code is developed based on the [Youtube Gesture Dataset](https://github.com/youngwoo-yoon/youtube-gesture-dataset) by Yoon et al.
132 | * The dataset establishment process involves some existing assets, including [Gentle](https://github.com/lowerquality/gentle), [PySceneDetect](https://pyscenedetect.readthedocs.io/en/latest/), [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose), [ExPose](https://github.com/vchoutas/expose) and [OpenCV](https://pypi.org/project/opencv-python/). Many thanks to the authors' fantastic contributions! -------------------------------------------------------------------------------- /dataset_script/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | scipy 4 | tqdm 5 | Pillow 6 | 7 | google-api-python-client 8 | webvtt-py 9 | youtube-dl 10 | scenedetect 11 | 12 | 13 | -------------------------------------------------------------------------------- /dataset_script/script/calc_mean.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | from sklearn.preprocessing import normalize 6 | 7 | import glob 8 | import os 9 | import pickle 10 | import sys 11 | 12 | import cv2 13 | import math 14 | import lmdb 15 | import numpy as np 16 | from numpy import float32 17 | from tqdm import tqdm 18 | 19 | import unicodedata 20 | 21 | from data_utils import * 22 | 23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | dir_vec_pairs = [ 26 | (0, 1, 0.26), 27 | (1, 2, 0.22), 28 | (1, 3, 0.22), 29 | (2, 4, 0.36), 30 | (4, 6, 0.33), 31 | (6, 13, 0.1), 32 | (6, 14, 0.12), 33 | (6, 15, 0.14), 34 | (6, 16, 0.13), 35 | (6, 17, 0.12), 36 | 37 | (3, 5, 0.36), 38 | (5, 7, 0.33), 39 | (7, 18, 0.1), 40 | (7, 19, 0.12), 41 | (7, 20, 0.14), 42 | (7, 21, 0.13), 43 | (7, 22, 0.12), 44 | 45 | (1, 8, 0.18), 46 | (8, 9, 0.14), 47 | (8, 10, 0.14), 48 | (9, 11, 0.15), 49 | (10, 12, 0.15), 50 | ] 51 | 52 | def calc_mean_dir(): 53 | video_files = sorted(glob.glob(my_config.VIDEO_PATH + "/*.mp4"), key=os.path.getmtime) 54 | # video_files = video_files[:1] 55 | for v_i, video_file in enumerate(tqdm(video_files)): 56 | vid = os.path.split(video_file)[1][-15:-4] 57 | clip_data = load_clip_data(vid) 58 | if clip_data is None: 59 | print('[ERROR] clip data file does not exist!') 60 | break 61 | 62 | video_wrapper = read_video(my_config.VIDEO_PATH, vid) 63 | 64 | for ia, clip in enumerate(clip_data): 65 | # skip FALSE clips 66 | if not clip['clip_info'][2]: 67 | continue 68 | clip_pose_3d = clip['3d'] 69 | for frame in clip_pose_3d[:-1]: 70 | if frame: 71 | joints_full = frame['joints'] 72 | up_joints = np.vstack((joints_full[9], joints_full[12], joints_full[16:22], joints_full[55:60], joints_full[66:76])).astype('float32') 73 | 74 | 75 | def convert_dir_vec_to_pose(vec): 76 | vec = np.array(vec) 77 | 78 | if vec.shape[-1] != 3: 79 | vec = vec.reshape(vec.shape[:-1] + (-1, 3)) 80 | 81 | if len(vec.shape) == 2: 82 | joint_pos = np.zeros((23, 3)) 83 | for j, pair in enumerate(dir_vec_pairs): 84 | joint_pos[pair[1]] = joint_pos[pair[0]] + pair[2] * vec[j] 85 | elif len(vec.shape) == 3: 86 | joint_pos = np.zeros((vec.shape[0], 23, 3)) 87 | for j, pair in enumerate(dir_vec_pairs): 88 | joint_pos[:, pair[1]] = joint_pos[:, pair[0]] + pair[2] * vec[:, j] 89 | elif len(vec.shape) == 4: # (batch, seq, 9, 3) 90 | joint_pos = np.zeros((vec.shape[0], vec.shape[1], 23, 3)) 91 | for j, pair in enumerate(dir_vec_pairs): 92 | joint_pos[:, :, pair[1]] = joint_pos[:, :, pair[0]] + pair[2] * vec[:, :, j] 93 | else: 94 | assert False 95 | 96 | return joint_pos 97 | 
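# --- Illustrative addition (not part of the original calc_mean.py): a quick
# sanity check that convert_dir_vec_to_pose (above) and convert_pose_seq_to_dir_vec
# (below) are inverse up to the fixed bone lengths in dir_vec_pairs.
# The random poses are stand-in inputs only.
def _dir_vec_roundtrip_check(seq_len=4):
    rng = np.random.RandomState(0)
    poses = rng.rand(seq_len, 23, 3).astype(np.float32)   # (seq, 23 joints, xyz)
    vecs = convert_pose_seq_to_dir_vec(poses)             # unit bone directions, (seq, 22, 3)
    recon = convert_dir_vec_to_pose(vecs)                 # skeleton rebuilt from fixed bone lengths
    for parent, child, length in dir_vec_pairs:
        bone_len = np.linalg.norm(recon[:, child] - recon[:, parent], axis=-1)
        assert np.allclose(bone_len, length, atol=1e-5)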
98 | def convert_pose_seq_to_dir_vec(pose): 99 | if pose.shape[-1] != 3: 100 | pose = pose.reshape(pose.shape[:-1] + (-1, 3)) 101 | 102 | if len(pose.shape) == 3: 103 | dir_vec = np.zeros((pose.shape[0], len(dir_vec_pairs), 3)) 104 | for i, pair in enumerate(dir_vec_pairs): 105 | dir_vec[:, i] = pose[:, pair[1]] - pose[:, pair[0]] 106 | dir_vec[:, i, :] = normalize(dir_vec[:, i, :], axis=1) # to unit length 107 | elif len(pose.shape) == 4: # (batch, seq, ...) 108 | dir_vec = np.zeros((pose.shape[0], pose.shape[1], len(dir_vec_pairs), 3)) 109 | for i, pair in enumerate(dir_vec_pairs): 110 | dir_vec[:, :, i] = pose[:, :, pair[1]] - pose[:, :, pair[0]] 111 | for j in range(dir_vec.shape[0]): # batch 112 | for i in range(len(dir_vec_pairs)): 113 | dir_vec[j, :, i, :] = normalize(dir_vec[j, :, i, :], axis=1) # to unit length 114 | else: 115 | assert False 116 | 117 | return dir_vec 118 | 119 | if __name__ == '__main__': 120 | calc_mean_dir() -------------------------------------------------------------------------------- /dataset_script/script/clip_filter.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | import cv2 11 | import math 12 | import os 13 | 14 | from data_utils import get_skeleton_from_frame 15 | from config import my_config 16 | 17 | 18 | class ClipFilter: 19 | def __init__(self, vid, video, start_frame_no, end_frame_no, raw_skeleton, main_speaker_skeletons, height, width): 20 | self.skeleton_data = raw_skeleton 21 | self.main_speaker_skeletons = main_speaker_skeletons 22 | self.start_frame_no = start_frame_no 23 | self.end_frame_no = end_frame_no 24 | self.scene_length = end_frame_no - start_frame_no 25 | self.video = video 26 | self.vid = vid 27 | self.filter_option = my_config.FILTER_OPTION 28 | self.height = height 29 | self.width = width 30 | 31 | # filtering criteria variable 32 | self.filtering_results = [0, 0, 0, 0, 0, 0, 0] # too short, many_people, looking_back, joint_missing, looking_sideways, small, picture 33 | self.message = '' 34 | self.debugging_info = ['None', 'None', 'None', 'None', 'None'] # looking back, joint missing, looking sideways, small, picture 35 | 36 | def is_skeleton_back(self, ratio): 37 | n_incorrect_frame = 0 38 | 39 | for ia, skeleton in enumerate(self.main_speaker_skeletons): # frames 40 | body = get_skeleton_from_frame(skeleton) 41 | if body: 42 | if body[2 * 3] > body[5 * 3]: 43 | n_incorrect_frame += 1 44 | else: 45 | n_incorrect_frame += 1 46 | 47 | self.debugging_info[0] = round(n_incorrect_frame / self.scene_length, 3) 48 | 49 | return n_incorrect_frame / self.scene_length > ratio 50 | 51 | def is_skeleton_sideways(self, ratio): 52 | n_incorrect_frame = 0 53 | 54 | for ia, skeleton in enumerate(self.main_speaker_skeletons): # frames 55 | body = get_skeleton_from_frame(skeleton) 56 | if body: 57 | if (body[0] < min(body[2 * 3], body[5 * 3]) or body[0] > max(body[2 * 3], body[5 * 3])): 58 | n_incorrect_frame += 1 59 | else: 60 | n_incorrect_frame += 1 61 | 62 | self.debugging_info[2] = 
round(n_incorrect_frame / self.scene_length, 3) 63 | 64 | return n_incorrect_frame / self.scene_length > ratio 65 | 66 | def is_skeleton_missing(self, ratio): 67 | n_incorrect_frame = 0 68 | 69 | if self.main_speaker_skeletons == []: 70 | n_incorrect_frame = self.scene_length 71 | else: 72 | for ia, skeleton in enumerate(self.main_speaker_skeletons): # frames 73 | 74 | body = get_skeleton_from_frame(skeleton) 75 | if body: 76 | point_idx = [0, 1, 2, 3, 4, 5, 6, 7] # head and arms 77 | if any(body[idx * 3] == 0 for idx in point_idx): 78 | n_incorrect_frame += 1 79 | else: 80 | # proceed to expose examination 81 | cur_frame = self.start_frame_no + ia 82 | expose_path = my_config.EXPOSE_OUT_PATH + '/' + self.vid + '/' + '%05d'%cur_frame + '.npz' 83 | assert(os.path.exists(expose_path)) 84 | file = np.load(expose_path) 85 | proj_joints = file['proj_joints'] 86 | center = file['center'] 87 | mm = file['focal_length_in_mm'] 88 | px = file['focal_length_in_px'] 89 | proj_joints = proj_joints * px * 10 / mm 90 | proj_joints += center 91 | proj_joints = np.vstack((proj_joints[9], proj_joints[12], proj_joints[16:22], proj_joints[55:60], proj_joints[66:76])) 92 | for point in proj_joints: 93 | x = int(point[0]) 94 | y = int(point[1]) 95 | if (x >= self.width or x <= 0 or y >= self.height or y <= 0): 96 | n_incorrect_frame += 1 97 | break 98 | 99 | else: 100 | n_incorrect_frame += 1 101 | 102 | self.debugging_info[1] = round(n_incorrect_frame / self.scene_length, 3) 103 | return n_incorrect_frame / self.scene_length > ratio 104 | 105 | def is_skeleton_small(self, ratio): 106 | n_incorrect_frame = 0 107 | 108 | def distance(x1, y1, x2, y2): 109 | return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) 110 | 111 | for ia, skeleton in enumerate(self.main_speaker_skeletons): # frames 112 | body = get_skeleton_from_frame(skeleton) 113 | if body: 114 | threshold = self.filter_option['threshold'] # for TED videos in 720p 115 | if distance(body[2 * 3], body[2 * 3 + 1], body[5 * 3], body[5 * 3 + 1]) < threshold: # shoulder length 116 | n_incorrect_frame += 1 117 | else: 118 | n_incorrect_frame += 1 119 | 120 | self.debugging_info[3] = round(n_incorrect_frame / self.scene_length, 3) 121 | return n_incorrect_frame / self.scene_length > ratio 122 | 123 | def is_too_short(self): 124 | MIN_SCENE_LENGTH = 25 * 3 # assumed fps = 25 125 | return self.scene_length < MIN_SCENE_LENGTH 126 | 127 | def is_picture(self): 128 | sampling_interval = int(math.floor(self.scene_length / 5)) 129 | sampling_frames = list(range(self.start_frame_no + sampling_interval, 130 | self.end_frame_no - sampling_interval + 1, sampling_interval)) 131 | frames = [] 132 | for frame_no in sampling_frames: 133 | self.video.set(cv2.CAP_PROP_POS_FRAMES, frame_no) 134 | ret, frame = self.video.read() 135 | frames.append(frame) 136 | 137 | diff = 0 138 | n_diff = 0 139 | for frame, next_frame in zip(frames, frames[1:]): 140 | diff += cv2.norm(frame, next_frame, cv2.NORM_L1) # abs diff 141 | n_diff += 1 142 | diff /= n_diff 143 | self.debugging_info[4] = round(diff, 0) 144 | 145 | return diff < 3000000 146 | 147 | def is_many_people(self): 148 | n_people = [] 149 | for skeleton in self.skeleton_data: 150 | n_people.append(len(skeleton)) 151 | 152 | return len(n_people) > 0 and np.mean(n_people) > 5 153 | 154 | def is_correct_clip(self): 155 | # check if the clip is too short. 
156 | if self.is_too_short(): 157 | self.message = "too Short" 158 | return False 159 | self.filtering_results[0] = 1 160 | 161 | # check if there are too many people on the clip 162 | if self.is_many_people(): 163 | self.message = "too many people" 164 | return False 165 | self.filtering_results[1] = 1 166 | 167 | # check if the ratio of back-facing skeletons in the clip exceeds the reference ratio 168 | if self.is_skeleton_back(0.3): 169 | self.message = "looking behind" 170 | return False 171 | self.filtering_results[2] = 1 172 | 173 | # check if the ratio of skeletons that missing joint in the clip exceeds the reference ratio 174 | if self.is_skeleton_missing(0.5): 175 | self.message = "too many missing joints" 176 | return False 177 | self.filtering_results[3] = 1 178 | 179 | # check if the ratio of sideways skeletons in the clip exceeds the reference ratio 180 | if self.is_skeleton_sideways(0.5): 181 | self.message = "looking sideways" 182 | return False 183 | self.filtering_results[4] = 1 184 | 185 | # check if the ratio of the too small skeleton in the clip exceeds the reference ratio 186 | if self.is_skeleton_small(0.5): 187 | self.message = "too small." 188 | return False 189 | self.filtering_results[5] = 1 190 | 191 | # check if the clip is picture 192 | if self.is_picture(): 193 | self.message = "still picture" 194 | return False 195 | self.filtering_results[6] = 1 196 | 197 | self.message = "PASS" 198 | return True 199 | 200 | def get_filter_variable(self): 201 | return self.filtering_results, self.message, self.debugging_info 202 | -------------------------------------------------------------------------------- /dataset_script/script/config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from datetime import datetime 10 | 11 | 12 | class Config: 13 | DEVELOPER_KEY = "" # your youtube developer id 14 | EXPOSE_BASE_DIR = "/mnt/lustre/liuxian/expose/" 15 | OPENPOSE_BASE_DIR = "/mnt/lustre/liuxian/openpose/" 16 | OPENPOSE_BIN_PATH = "/mnt/lustre/liuxian/openpose/build/examples/openpose/openpose.bin" 17 | 18 | 19 | class TEDConfig(Config): 20 | YOUTUBE_CHANNEL_ID = "UCAuUUnT6oDeKwE6v1NGQxug" 21 | WORK_PATH = '/mnt/lustre/liuxian/final' 22 | CLIP_PATH = WORK_PATH + "/clip_ted" 23 | VIDEO_PATH = WORK_PATH + "/videos_ted" 24 | SKELETON_PATH = WORK_PATH + "/skeleton_ted" 25 | GENTLE_PATH = WORK_PATH + '/gentle_ted' 26 | AUDIO_PATH = WORK_PATH + '/audio_ted' 27 | SUBTITLE_PATH = WORK_PATH + 'sub_ted' 28 | OUTPUT_PATH = WORK_PATH + "/whole_output" 29 | FILTER_PATH = WORK_PATH + '/filter_res' 30 | EXPOSE_OUT_PATH = WORK_PATH + '/expose_ted' 31 | VIDEO_SEARCH_START_DATE = datetime(2011, 3, 1, 0, 0, 0) 32 | LANG = 'en' 33 | SUBTITLE_TYPE = 'gentle' 34 | FILTER_OPTION = {"threshold": 100} 35 | 36 | 37 | class LaughConfig(Config): 38 | YOUTUBE_CHANNEL_ID = "UCxyCzPY2pjAjrxoSYclpuLg" 39 | WORK_PATH = '/mnt/lustre/liuxian/final' 40 | CLIP_PATH = WORK_PATH + "/clip_laugh" 41 | VIDEO_PATH = WORK_PATH + "/videos_laugh" 42 | SKELETON_PATH = WORK_PATH + "/skeleton_laugh" 43 | SUBTITLE_PATH = VIDEO_PATH 44 | OUTPUT_PATH = WORK_PATH + "/output" 45 | VIDEO_SEARCH_START_DATE = datetime(2010, 5, 1, 0, 0, 0) 46 | LANG = 'en' 47 | SUBTITLE_TYPE = 'auto' 48 | FILTER_OPTION = {"threshold": 50} 49 | 50 | 51 | # SET THIS 52 | my_config = TEDConfig 53 | -------------------------------------------------------------------------------- /dataset_script/script/download_video.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from __future__ import unicode_literals 10 | 11 | import glob 12 | import json 13 | import traceback 14 | 15 | import youtube_dl 16 | import urllib.request 17 | import sys 18 | import os 19 | from apiclient.discovery import build 20 | from datetime import datetime, timedelta 21 | from config import my_config 22 | 23 | YOUTUBE_API_SERVICE_NAME = "youtube" 24 | YOUTUBE_API_VERSION = "v3" 25 | 26 | RESUME_VIDEO_ID = "" # resume downloading from this video, set empty string to start over 27 | 28 | 29 | def fetch_video_ids(channel_id, search_start_time): # load video ids in the channel 30 | youtube = build(YOUTUBE_API_SERVICE_NAME, YOUTUBE_API_VERSION, developerKey=my_config.DEVELOPER_KEY) 31 | 32 | start_time = search_start_time 33 | td = timedelta(days=15) 34 | end_time = start_time + td 35 | 36 | res_items = [] 37 | 38 | # multiple quires are necessary to get all results surely 39 | while start_time < datetime.now(): 40 | start_string = str(start_time.isoformat()) + 'Z' 41 | end_string = str(end_time.isoformat()) + 'Z' 42 | 43 | res = youtube.search().list(part="id", channelId=channel_id, maxResults="50", 44 | publishedAfter=start_string, 45 | publishedBefore=end_string).execute() 46 | res_items += res['items'] 47 | 48 | while True: # paging 49 | if len(res['items']) < 50 or 'nextPageToken' not in res: 50 | break 51 | 52 | next_page_token = res['nextPageToken'] 53 | res = youtube.search().list(part="id", channelId=channel_id, maxResults="50", 54 | publishedAfter=start_string, 55 | publishedBefore=end_string, 56 | pageToken=next_page_token).execute() 57 | res_items += res['items'] 58 | 59 | print(' {} to {}, no of videos: {}'.format(start_string, end_string, len(res_items))) 60 | 61 | start_time = end_time 62 | end_time = start_time + td 63 | 64 | # collect video ids 65 | vid_list = [] 66 | for i in res_items: 67 | vid = (i.get('id')).get('videoId') 68 | if vid is not None: 69 | vid_list.append(vid) 70 | 71 | return vid_list 72 | 73 | 74 | def video_filter(info): 75 | passed = True 76 | 77 | exist_proper_format = False 78 | format_data = info.get('formats') 79 | for i in format_data: 80 | if i.get('ext') == 'mp4' and i.get('height') >= 720 and i.get('acodec') != 'none': 81 | exist_proper_format = True 82 | if not exist_proper_format: 83 | passed = False 84 | 85 | if passed: 86 | duration_hours = info.get('duration') / 3600.0 87 | if duration_hours > 1.0: 88 | passed = False 89 | 90 | if passed: 91 | if len(info.get('automatic_captions')) == 0 and len(info.get('subtitles')) == 0: 92 | passed = False 93 | 94 | return passed 95 | 96 | 97 | def download_subtitle(url, filename, postfix): 98 | urllib.request.urlretrieve(url, '{}-{}.vtt'.format(filename, postfix)) 99 | 100 | 101 | def download(vid_list): 102 | ydl_opts = {'format': 'best[height=720,ext=mp4]', 103 | 'writesubtitles': True, 104 | 'writeautomaticsub': True, 105 | 'outtmpl': 'dummy.mp4' 106 | } # download options 107 | language = my_config.LANG 108 | 109 | download_count = 0 110 | skip_count = 0 111 | sub_count = 0 112 | log = open("download_log.txt", 'w', encoding="utf-8") 113 | 114 | if len(RESUME_VIDEO_ID) < 10: 115 | skip_index = 0 116 | else: 117 | skip_index = vid_list.index(RESUME_VIDEO_ID) 118 | 119 | for i in range(len(vid_list)): 120 | error_count = 0 121 | print(vid_list[i]) 122 | if i < skip_index: 123 | 
continue 124 | 125 | # rename video (vid.mp4) 126 | ydl_opts['outtmpl'] = my_config.VIDEO_PATH + '/' + vid_list[i] + '.mp4' 127 | 128 | # check existing file 129 | if os.path.exists(ydl_opts['outtmpl']) and os.path.getsize(ydl_opts['outtmpl']): # existing and not empty 130 | print('video file already exists ({})'.format(vid_list[i])) 131 | continue 132 | 133 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 134 | vid = vid_list[i] 135 | url = "https://youtu.be/{}".format(vid) 136 | 137 | info = ydl.extract_info(url, download=False) 138 | if video_filter(info): 139 | with open("{}.json".format(vid), "w", encoding="utf-8") as js: 140 | json.dump(info, js) 141 | while 1: 142 | if error_count == 3: 143 | print('Exit...') 144 | sys.exit() 145 | try: 146 | ydl.download([url]) 147 | except(youtube_dl.utils.DownloadError, 148 | youtube_dl.utils.ContentTooShortError, 149 | youtube_dl.utils.ExtractorError): 150 | error_count += 1 151 | print(' Retrying... (error count : {})\n'.format(error_count)) 152 | traceback.print_exc() 153 | continue 154 | else: 155 | def get_subtitle_url(subtitles, language, ext): 156 | subtitles = subtitles.get(language) 157 | url = None 158 | for sub in subtitles: 159 | if sub.get('ext') == ext: 160 | url = sub.get('url') 161 | break 162 | return url 163 | 164 | if info.get('subtitles') != {} and (info.get('subtitles')).get(language) != None: 165 | sub_url = get_subtitle_url(info.get('subtitles'), language, 'vtt') 166 | download_subtitle(sub_url, vid, language) 167 | sub_count += 1 168 | if info.get('automatic_captions') != {}: 169 | auto_sub_url = get_subtitle_url(info.get('automatic_captions'), language, 'vtt') 170 | download_subtitle(auto_sub_url, vid, language+'-auto') 171 | 172 | log.write("{} - downloaded\n".format(str(vid))) 173 | download_count += 1 174 | break 175 | else: 176 | log.write("{} - skipped\n".format(str(info.get('id')))) 177 | skip_count += 1 178 | 179 | print(" downloaded: {}, skipped: {}".format(download_count, skip_count)) 180 | 181 | log.write("\nno of subtitles : {}\n".format(sub_count)) 182 | log.write("downloaded: {}, skipped : {}\n".format(download_count, skip_count)) 183 | log.close() 184 | 185 | 186 | def main(): 187 | if not os.path.exists(my_config.VIDEO_PATH): 188 | os.makedirs(my_config.VIDEO_PATH) 189 | 190 | os.chdir(my_config.VIDEO_PATH) 191 | vid_list = [] 192 | 193 | # read video list 194 | try: 195 | rf = open("video_ids.txt", 'r') 196 | except FileNotFoundError: 197 | print("fetching video ids...") 198 | vid_list = fetch_video_ids(my_config.YOUTUBE_CHANNEL_ID, my_config.VIDEO_SEARCH_START_DATE) 199 | wf = open("video_ids.txt", "w") 200 | for j in vid_list: 201 | wf.write(str(j)) 202 | wf.write('\n') 203 | wf.close() 204 | else: 205 | while 1: 206 | value = rf.readline()[:11] 207 | if value == '': 208 | break 209 | vid_list.append(value) 210 | rf.close() 211 | 212 | print("downloading videos...") 213 | download(vid_list) 214 | print("finished downloading videos") 215 | 216 | print("removing unnecessary subtitles...") 217 | for f in glob.glob("*.en.vtt"): 218 | os.remove(f) 219 | 220 | 221 | def test_fetch(): 222 | vid_list = fetch_video_ids(my_config.YOUTUBE_CHANNEL_ID, my_config.VIDEO_SEARCH_START_DATE) 223 | print(vid_list) 224 | print(len(vid_list)) 225 | 226 | 227 | if __name__ == '__main__': 228 | # test_fetch() 229 | main() 230 | 231 | -------------------------------------------------------------------------------- /dataset_script/script/main_speaker_selector.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright 2019 ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | import matplotlib 10 | from config import * 11 | import copy 12 | import os 13 | from data_utils import * 14 | from tqdm import * 15 | from config import * 16 | import numpy as np 17 | 18 | 19 | class MainSpeakerSelector: 20 | def __init__(self, raw_skeleton_chunk): 21 | self.main_speaker_skeletons = self.find_main_speaker_skeletons(raw_skeleton_chunk) 22 | 23 | def get(self): 24 | return self.main_speaker_skeletons 25 | 26 | def find_main_speaker_skeletons(self, raw_skeleton_chunk): 27 | tracked_skeletons = [] 28 | selected_skeletons = [] # reference skeleton 29 | for raw_frame in raw_skeleton_chunk: # frame 30 | tracked_person = [] 31 | if selected_skeletons == []: 32 | # select a main speaker 33 | confidence_list = [] 34 | for person in raw_frame: # people 35 | body = get_skeleton_from_frame(person) 36 | mean_confidence = 0 37 | n_points = 0 38 | 39 | # Calculate the average of confidences of each person 40 | for i in range(8): # upper-body only 41 | x = body[i * 3] 42 | y = body[i * 3 + 1] 43 | confidence = body[i * 3 + 2] 44 | if x > 0 and y > 0 and confidence > 0: 45 | n_points += 1 46 | mean_confidence += confidence 47 | if n_points > 0: 48 | mean_confidence /= n_points 49 | else: 50 | mean_confidence = 0 51 | confidence_list.append(mean_confidence) 52 | 53 | # select main_speaker with the highest average of confidence 54 | if len(confidence_list) > 0: 55 | max_index = confidence_list.index(max(confidence_list)) 56 | selected_skeletons = get_skeleton_from_frame(raw_frame[max_index]) 57 | 58 | if selected_skeletons != []: 59 | # find the closest one to the selected main_speaker's skeleton 60 | tracked_person = self.get_closest_skeleton(raw_frame, selected_skeletons) 61 | 62 | # save 63 | if tracked_person: 64 | skeleton_data = tracked_person 65 | selected_skeletons = get_skeleton_from_frame(tracked_person) 66 | else: 67 | skeleton_data = {} 68 | 69 | tracked_skeletons.append(skeleton_data) 70 | 71 | return tracked_skeletons 72 | 73 | def get_closest_skeleton(self, frame, selected_body): 74 | """ find the closest one to the selected skeleton """ 75 | diff_idx = [i * 3 for i in range(8)] + [i * 3 + 1 for i in range(8)] # upper-body 76 | 77 | min_diff = 10000000 78 | tracked_person = None 79 | for person in frame: # people 80 | body = get_skeleton_from_frame(person) 81 | 82 | diff = 0 83 | n_diff = 0 84 | for i in diff_idx: 85 | if body[i] > 0 and selected_body[i] > 0: 86 | diff += abs(body[i] - selected_body[i]) 87 | n_diff += 1 88 | if n_diff > 0: 89 | diff /= n_diff 90 | if diff < min_diff: 91 | min_diff = diff 92 | tracked_person = person 93 | 94 | base_distance = max(abs(selected_body[0 * 3 + 1] - selected_body[1 * 3 + 1]) * 3, 95 | abs(selected_body[2 * 3] - selected_body[5 * 3]) * 2) 96 | if tracked_person and min_diff > base_distance: # tracking failed 97 | tracked_person = None 98 | 99 | return tracked_person 100 | -------------------------------------------------------------------------------- 
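The script below (`merge_dataset.py`) merges the per-subset pickles into single train/val/test pickles and writes them into LMDB databases keyed by zero-padded decimal indices with pyarrow-serialized values. As a quick check that a merged LMDB was written correctly, one could read a sample back with a sketch like the following; this is illustrative only (the helper name `peek_lmdb_sample` is ours), and it assumes the key/serialization layout of `merge_dataset.py` and the `pyarrow==0.14.1` pin from `requirements.txt`:

```python
import lmdb
import pyarrow


def peek_lmdb_sample(lmdb_dir, idx=0):
    # Open the merged LMDB read-only and fetch the idx-th pyarrow-serialized clip dict.
    env = lmdb.open(lmdb_dir, readonly=True, lock=False)
    with env.begin() as txn:
        buf = txn.get('{:010}'.format(idx).encode('ascii'))
        if buf is None:
            raise KeyError('no sample {} in {}'.format(idx, lmdb_dir))
        sample = pyarrow.deserialize(buf)
    env.close()
    return sample
```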
/dataset_script/script/merge_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | import glob 9 | import os 10 | import pickle 11 | import sys 12 | 13 | import cv2 14 | import math 15 | import lmdb 16 | import numpy as np 17 | from numpy import float32 18 | from tqdm import tqdm 19 | 20 | import unicodedata 21 | import librosa 22 | import pyarrow 23 | 24 | from data_utils import * 25 | 26 | def merge_dataset(): 27 | dataset_train = [] 28 | dataset_val = [] 29 | dataset_test = [] 30 | pickle_folder_list = os.listdir(my_config.OUTPUT_PATH) 31 | 32 | out_lmdb_dir_train = my_config.OUTPUT_PATH + '/train' 33 | out_lmdb_dir_val = my_config.OUTPUT_PATH + '/val' 34 | out_lmdb_dir_test = my_config.OUTPUT_PATH + '/test' 35 | if not os.path.exists(out_lmdb_dir_train): 36 | os.makedirs(out_lmdb_dir_train) 37 | if not os.path.exists(out_lmdb_dir_val): 38 | os.makedirs(out_lmdb_dir_val) 39 | if not os.path.exists(out_lmdb_dir_test): 40 | os.makedirs(out_lmdb_dir_test) 41 | 42 | for dir in pickle_folder_list: 43 | pickle_file = my_config.OUTPUT_PATH + '/' + dir + '/ted_expressive_dataset_train.pickle' 44 | with open(pickle_file, 'rb') as file: 45 | temp_train = pickle.load(file) 46 | dataset_train.extend(temp_train) 47 | 48 | for dir in pickle_folder_list: 49 | pickle_file = my_config.OUTPUT_PATH + '/' + dir + '/ted_expressive_dataset_val.pickle' 50 | with open(pickle_file, 'rb') as file: 51 | temp_val = pickle.load(file) 52 | dataset_val.extend(temp_val) 53 | 54 | for dir in pickle_folder_list: 55 | pickle_file = my_config.OUTPUT_PATH + '/' + dir + '/ted_expressive_dataset_test.pickle' 56 | with open(pickle_file, 'rb') as file: 57 | temp_test = pickle.load(file) 58 | dataset_test.extend(temp_test) 59 | 60 | print('writing to pickle...') 61 | with open(my_config.OUTPUT_PATH + '/' + 'ted_expressive_dataset_train.pickle', 'wb') as f: 62 | pickle.dump(dataset_train, f) 63 | with open(my_config.OUTPUT_PATH + '/' + 'ted_expressive_dataset_val.pickle', 'wb') as f: 64 | pickle.dump(dataset_val, f) 65 | with open(my_config.OUTPUT_PATH + '/' + 'ted_expressive_dataset_test.pickle', 'wb') as f: 66 | pickle.dump(dataset_test, f) 67 | 68 | map_size = 1024 * 100 # in MB 69 | map_size <<= 20 # in B 70 | env_train = lmdb.open(out_lmdb_dir_train, map_size=map_size) 71 | env_val = lmdb.open(out_lmdb_dir_val, map_size=map_size) 72 | env_test = lmdb.open(out_lmdb_dir_test, map_size=map_size) 73 | 74 | # lmdb train 75 | with env_train.begin(write=True) as txn: 76 | for idx, dic in enumerate(dataset_train): 77 | k = '{:010}'.format(idx).encode('ascii') 78 | v = pyarrow.serialize(dic).to_buffer() 79 | txn.put(k, v) 80 | env_train.close() 81 | 82 | # lmdb val 83 | with env_val.begin(write=True) as txn: 84 | for idx, dic in enumerate(dataset_val): 85 | k = '{:010}'.format(idx).encode('ascii') 86 | v = pyarrow.serialize(dic).to_buffer() 87 | txn.put(k, v) 88 | env_val.close() 89 | 90 | # lmdb test 91 | with env_test.begin(write=True) as txn: 92 | for idx, dic in enumerate(dataset_test): 93 | k = 
'{:010}'.format(idx).encode('ascii') 94 | v = pyarrow.serialize(dic).to_buffer() 95 | txn.put(k, v) 96 | env_test.close() 97 | 98 | 99 | if __name__ == '__main__': 100 | merge_dataset() 101 | -------------------------------------------------------------------------------- /dataset_script/script/motion_preprocessor.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from scipy.signal import savgol_filter 10 | import numpy as np 11 | from scipy.stats import circvar 12 | 13 | 14 | def normalize_skeleton(data, resize_factor=None): 15 | def distance(x1, y1, x2, y2): 16 | return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) 17 | 18 | anchor_pt = (data[1 * 2], data[1 * 2 + 1]) # neck 19 | if resize_factor is None: 20 | neck_height = float(abs(data[1] - data[1 * 2 + 1])) 21 | shoulder_length = distance(data[1 * 2], data[1 * 2 + 1], data[2 * 2], data[2 * 2 + 1]) + \ 22 | distance(data[1 * 2], data[1 * 2 + 1], data[5 * 2], data[5 * 2 + 1]) 23 | resized_neck_height = neck_height / float(shoulder_length) 24 | if resized_neck_height > 0.6: 25 | resize_factor = shoulder_length * resized_neck_height / 0.6 26 | else: 27 | resize_factor = shoulder_length 28 | 29 | normalized_data = data.copy() 30 | for i in range(0, len(data), 2): 31 | normalized_data[i] = (data[i] - anchor_pt[0]) / resize_factor 32 | normalized_data[i + 1] = (data[i + 1] - anchor_pt[1]) / resize_factor 33 | 34 | return normalized_data, resize_factor 35 | 36 | 37 | class MotionPreprocessor: 38 | def __init__(self, skeletons): 39 | self.skeletons = np.array(skeletons) 40 | self.filtering_message = "PASS" 41 | 42 | def get(self): 43 | assert (self.skeletons is not None) 44 | 45 | # filtering 46 | if self.has_missing_frames(): 47 | self.skeletons = [] 48 | self.filtering_message = "too many missing frames" 49 | 50 | # fill missing joints 51 | if self.skeletons != []: 52 | self.fill_missing_joints() 53 | if self.skeletons is None or np.isnan(self.skeletons).any(): 54 | self.filtering_message = "failed to fill missing joints" 55 | self.skeletons = [] 56 | 57 | # filtering 58 | if self.skeletons != []: 59 | if self.is_static(): 60 | self.skeletons = [] 61 | self.filtering_message = "static motion" 62 | elif self.has_jumping_joint(): 63 | self.skeletons = [] 64 | self.filtering_message = "jumping joint" 65 | 66 | # preprocessing 67 | if self.skeletons != []: 68 | 69 | self.smooth_motion() 70 | 71 | is_side_view = False 72 | self.skeletons = self.skeletons.tolist() 73 | for i, frame in enumerate(self.skeletons): 74 | del frame[2::3] # remove confidence values 75 | self.skeletons[i], _ = normalize_skeleton(frame) # translate and scale 76 | 77 | # assertion: missing joints 78 | assert not np.isnan(self.skeletons[i]).any() 79 | 80 | # side view check 81 | if (self.skeletons[i][0] < min(self.skeletons[i][2 * 2], 82 | self.skeletons[i][5 * 2]) or 83 | self.skeletons[i][0] > max(self.skeletons[i][2 * 2], 84 | self.skeletons[i][5 * 2])): 85 | is_side_view = True 86 | break 87 | 88 | if len(self.skeletons) == 0 or is_side_view: 89 | 
self.filtering_message = "sideview" 90 | self.skeletons = [] 91 | 92 | return self.skeletons, self.filtering_message 93 | 94 | def is_static(self, verbose=False): 95 | def joint_angle(p1, p2, p3): 96 | v1 = p1 - p2 97 | v2 = p3 - p2 98 | ang1 = np.arctan2(*v1[::-1]) 99 | ang2 = np.arctan2(*v2[::-1]) 100 | return np.rad2deg((ang1 - ang2) % (2 * np.pi)) 101 | 102 | def get_joint_variance(skeleton, index1, index2, index3): 103 | angles = [] 104 | 105 | for i in range(skeleton.shape[0]): 106 | x1, y1 = skeleton[i, index1 * 3], skeleton[i, index1 * 3 + 1] 107 | x2, y2 = skeleton[i, index2 * 3], skeleton[i, index2 * 3 + 1] 108 | x3, y3 = skeleton[i, index3 * 3], skeleton[i, index3 * 3 + 1] 109 | angle = joint_angle(np.array([x1, y1]), np.array([x2, y2]), np.array([x3, y3])) 110 | angles.append(angle) 111 | 112 | variance = circvar(angles, low=0, high=360) 113 | return variance 114 | 115 | left_arm_var = get_joint_variance(self.skeletons, 2, 3, 4) 116 | right_arm_var = get_joint_variance(self.skeletons, 5, 6, 7) 117 | 118 | th = 150 119 | if left_arm_var < th and right_arm_var < th: 120 | print('too static - left var {}, right var {}'.format(left_arm_var, right_arm_var)) 121 | return True 122 | else: 123 | if verbose: 124 | print('not static - left var {}, right var {}'.format(left_arm_var, right_arm_var)) 125 | return False 126 | 127 | def has_jumping_joint(self, verbose=False): 128 | frame_diff = np.squeeze(self.skeletons[1:, :24] - self.skeletons[:-1, :24]) 129 | diffs = abs(frame_diff.flatten()) 130 | width = max(self.skeletons[0, :24:3]) - min(self.skeletons[0, :24:3]) 131 | 132 | if max(diffs) > width / 2.0: 133 | print('jumping joint - diff {}, width {}'.format(max(diffs), width)) 134 | return True 135 | else: 136 | if verbose: 137 | print('no jumping joint - diff {}, width {}'.format(max(diffs), width)) 138 | return False 139 | 140 | def has_missing_frames(self): 141 | n_empty_frames = 0 142 | n_frames = self.skeletons.shape[0] 143 | for i in range(n_frames): 144 | if np.sum(self.skeletons[i]) == 0: 145 | n_empty_frames += 1 146 | 147 | ret = n_empty_frames > n_frames * 0.1 148 | if ret: 149 | print('missing frames - {} / {}'.format(n_empty_frames, n_frames)) 150 | return ret 151 | 152 | def smooth_motion(self): 153 | for i in range(24): 154 | self.skeletons[:, i] = savgol_filter(self.skeletons[:, i], 5, 2) 155 | 156 | def fill_missing_joints(self): 157 | skeletons = self.skeletons 158 | n_joints = 8 # only upper body 159 | 160 | def nan_helper(y): 161 | return np.isnan(y), lambda z: z.nonzero()[0] 162 | 163 | for i in range(n_joints): 164 | xs, ys = skeletons[:, i * 3], skeletons[:, i * 3 + 1] 165 | xs[xs == 0] = np.nan 166 | ys[ys == 0] = np.nan 167 | 168 | if sum(np.isnan(xs)) > len(xs) / 2: 169 | skeletons = None 170 | break 171 | 172 | if sum(np.isnan(ys)) > len(ys) / 2: 173 | skeletons = None 174 | break 175 | 176 | if np.isnan(xs).any(): 177 | nans, t = nan_helper(xs) 178 | xs[nans] = np.interp(t(nans), t(~nans), xs[~nans]) 179 | skeletons[:, i * 3] = xs 180 | 181 | if np.isnan(ys).any(): 182 | nans, t = nan_helper(ys) 183 | ys[nans] = np.interp(t(nans), t(~nans), ys[~nans]) 184 | skeletons[:, i * 3 + 1] = ys 185 | 186 | return skeletons 187 | -------------------------------------------------------------------------------- /dataset_script/script/run_clip_filtering.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 
3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from __future__ import unicode_literals 10 | import csv 11 | from clip_filter import * 12 | from main_speaker_selector import * 13 | from config import my_config 14 | import numpy as np 15 | 16 | RESUME_VID = '' # resume the process from this video 17 | 18 | 19 | def read_sceneinfo(filepath): # reading csv file 20 | with open(filepath, 'r') as csv_file: 21 | frame_list = [0] 22 | for row in csv.reader(csv_file): 23 | if row: 24 | frame_list.append((row[1])) 25 | frame_list[0:3] = [] # skip header 26 | 27 | frame_list = [int(x) for x in frame_list] # str to int 28 | 29 | return frame_list 30 | 31 | 32 | def run_filtering(vid, scene_data, skeleton_wrapper, video_wrapper): 33 | filtered_clip_data = [] 34 | aux_info = [] 35 | video = video_wrapper.get_video_reader() 36 | height = video_wrapper.height 37 | width = video_wrapper.width 38 | 39 | for i in range(len(scene_data) - 1): # note: last scene is not processed 40 | start_frame_no, end_frame_no = scene_data[i], scene_data[i + 1] 41 | raw_skeleton_chunk = skeleton_wrapper.get(start_frame_no, end_frame_no) 42 | main_speaker_skeletons = MainSpeakerSelector(raw_skeleton_chunk=raw_skeleton_chunk).get() 43 | 44 | # run clip filtering 45 | clip_filter = ClipFilter(vid = vid, video=video, start_frame_no=start_frame_no, end_frame_no=end_frame_no, 46 | raw_skeleton=raw_skeleton_chunk, main_speaker_skeletons=main_speaker_skeletons, height = height, width = width) 47 | correct_clip = clip_filter.is_correct_clip() 48 | 49 | filtering_results, message, debugging_info = clip_filter.get_filter_variable() 50 | filter_elem = {'clip_info': [start_frame_no, end_frame_no, correct_clip], 'filtering_results': filtering_results, 51 | 'message': message, 'debugging_info': debugging_info} 52 | aux_info.append(filter_elem) 53 | 54 | # save 55 | elem = {'clip_info': [start_frame_no, end_frame_no, correct_clip], 'frames': [], '3d': []} 56 | 57 | if not correct_clip: 58 | filtered_clip_data.append(elem) 59 | continue 60 | elem['frames'] = main_speaker_skeletons 61 | expose_list = [] 62 | for ii in range(start_frame_no, end_frame_no + 1): 63 | expose_path = my_config.EXPOSE_OUT_PATH + '/' + vid + '/' + '%05d'%ii + '.npz' 64 | if (not os.path.exists(expose_path)): 65 | expose_list.append({}) 66 | continue 67 | else: 68 | new_dict = {} 69 | file = np.load(expose_path) 70 | for key in file.files: 71 | if key not in ['fname', 'full_pose']: 72 | if isinstance(file[key], (np.ndarray)): 73 | new_dict[key] = file[key].tolist() 74 | else: 75 | new_dict[key] = file[key] 76 | expose_list.append(new_dict) 77 | elem['3d'] = expose_list 78 | filtered_clip_data.append(elem) 79 | 80 | return filtered_clip_data, aux_info 81 | 82 | 83 | def main(): 84 | if RESUME_VID == "": 85 | skip_flag = False 86 | else: 87 | skip_flag = True 88 | 89 | file_list = sorted(glob.glob(my_config.CLIP_PATH + "/*.csv"), key=os.path.getmtime) 90 | # file_list = file_list[:1] 91 | 92 | for csv_path in file_list: 93 | 94 | vid = os.path.split(csv_path)[1][0:11] 95 | print(vid) 96 | 97 | # resume check 98 | if skip_flag and vid == RESUME_VID: 99 | skip_flag = False 100 | 101 | if not skip_flag: 102 | scene_data = read_sceneinfo(csv_path) 103 | 
skeleton_wrapper = SkeletonWrapper(my_config.SKELETON_PATH, vid) 104 | video_wrapper = read_video(my_config.VIDEO_PATH, vid) 105 | 106 | # if video_wrapper.height < 720: 107 | # print('[Fatal error] wrong video size (height: {})'.format(video_wrapper.height)) 108 | # assert False 109 | 110 | if abs(video_wrapper.total_frames - len(skeleton_wrapper.skeletons)) > 10: 111 | print('[Fatal error] video and skeleton object have different lengths (video: {}, skeletons: {})'.format 112 | (video_wrapper.total_frames, len(skeleton_wrapper.skeletons))) 113 | assert False 114 | 115 | if skeleton_wrapper.skeletons == [] or video_wrapper is None: 116 | print('[warning] no skeleton or video! skipped this video.') 117 | else: 118 | ############################################################################################### 119 | filtered_clip_data, aux_info = run_filtering(vid, scene_data, skeleton_wrapper, video_wrapper) 120 | ############################################################################################### 121 | 122 | # save filtered clips and aux info 123 | with open("{}/{}.json".format(my_config.FILTER_PATH, vid), 'w') as clip_file: 124 | json.dump(filtered_clip_data, clip_file) 125 | with open("{}/{}_aux_info.json".format(my_config.FILTER_PATH, vid), 'w') as aux_file: 126 | json.dump(aux_info, aux_file) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /dataset_script/script/run_expose.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | """ 10 | Extract pose skeletons by using OpenPose library 11 | Need proper LD_LIBRARY_PATH before run this script 12 | Pycharm: In RUN > Edit Configurations, add LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 13 | """ 14 | 15 | # python inference.py --exp-cfg data/conf.yaml --datasets openpose --exp-opts datasets.body.batch_size 1 datasets.body.openpose.data_folder /mnt/lustre/liuxian/youtube-gesture-dataset/temp_skeleton_raw/-2Dj9M71JAc --show False --output-folder OUTPUT_FOLDER --save-params True --save-vis False --save-mesh False 16 | # python inference.py --exp-cfg data/conf.yaml --datasets openpose --exp-opts datasets.body.batch_size 64 datasets.body.openpose.data_folder /home/yuxi/openpose/youtube-gesture-dataset/temp_skeleton_raw/-2Dj9M71JAc --show False --output-folder OUTPUT_FOLDER --save-params True --save-vis False --save-mesh False 17 | # python demo.py --image-folder samples --exp-cfg data/conf.yaml --show=False --output-folder OUTPUT_FOLDER --save-params True --save-vis False --save-mesh False 18 | 19 | import glob 20 | import json 21 | import os 22 | import pickle 23 | import subprocess 24 | 25 | import shutil 26 | 27 | from config import my_config 28 | 29 | # maximum accuracy, too slow (~1fps) 30 | # OPENPOSE_OPTION = "--net_resolution -1x736 --scale_number 4 --scale_gap 0.25 --hand --hand_scale_number 6 --hand_scale_range 0.4 --face" 31 | # OPENPOSE_OPTION = "--face --hand --number_people_max 1 -model_pose COCO --display 0 --render_pose 0" 32 | 33 | OUTPUT_SKELETON_PATH = my_config.WORK_PATH + "/temp_skeleton_raw" 34 | OUTPUT_3D_PATH = my_config.WORK_PATH + "/expose_ted" 35 | 36 | RESUME_VID = "" # resume from this video 37 | SKIP_EXISTING_SKELETON = True # skip if the skeleton file is existing 38 | 39 | 40 | def get_vid_from_filename(filename): 41 | return filename[-15:-4] 42 | 43 | 44 | def read_skeleton_json(_file): 45 | with open(_file) as json_file: 46 | skeleton_json = json.load(json_file) 47 | return skeleton_json['people'] 48 | 49 | 50 | def save_skeleton_to_pickle(_vid): 51 | files = glob.glob(OUTPUT_SKELETON_PATH + '/' + _vid + '/*.json') 52 | if len(files) > 10: 53 | files = sorted(files) 54 | skeletons = [] 55 | for file in files: 56 | skeletons.append(read_skeleton_json(file)) 57 | with open(my_config.SKELETON_PATH + '/' + _vid + '.pickle', 'wb') as file: 58 | pickle.dump(skeletons, file) 59 | 60 | 61 | if __name__ == '__main__': 62 | if not os.path.exists(OUTPUT_3D_PATH): 63 | os.makedirs(OUTPUT_3D_PATH) 64 | 65 | os.chdir(my_config.EXPOSE_BASE_DIR) 66 | 67 | if RESUME_VID == "": 68 | skip_flag = False 69 | else: 70 | skip_flag = True 71 | 72 | video_files = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 73 | 74 | for file in sorted(video_files, key=os.path.getmtime): 75 | # print(file) 76 | vid = get_vid_from_filename(file) 77 | print(vid) 78 | 79 | skip_iter = False 80 | 81 | # resume check 82 | if skip_flag and vid == RESUME_VID: 83 | skip_flag = False 84 | skip_iter = skip_flag 85 | 86 | # # existing skeleton check 87 | # if SKIP_EXISTING_SKELETON: 88 | # if os.path.exists(my_config.SKELETON_PATH + '/' + vid + '.pickle'): 89 | # print('existing skeleton') 90 | # skip_iter = True 91 | 92 | if not skip_iter: 93 | # create out dir 94 | expose_out = OUTPUT_3D_PATH + "/" + vid 95 | if os.path.exists(expose_out): 96 | 
shutil.rmtree(expose_out) 97 | 98 | os.makedirs(expose_out) 99 | 100 | # call expose 101 | command = "python " + my_config.EXPOSE_BASE_DIR + "inference.py --exp-cfg " + my_config.EXPOSE_BASE_DIR + "data/conf.yaml --datasets openpose --exp-opts datasets.body.batch_size 256 datasets.body.openpose.data_folder " + OUTPUT_SKELETON_PATH + "/" + vid + " --show False --output-folder " + expose_out + " --save-params True --save-vis False --save-mesh False" 102 | print(command) 103 | subprocess.call(command, shell=True) 104 | 105 | # save skeletons to a pickle file 106 | # save_skeleton_to_pickle(vid) -------------------------------------------------------------------------------- /dataset_script/script/run_ffmpeg.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | """ 10 | Extract pose skeletons by using OpenPose library 11 | Need proper LD_LIBRARY_PATH before run this script 12 | Pycharm: In RUN > Edit Configurations, add LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 13 | """ 14 | 15 | import glob 16 | import json 17 | import os 18 | import pickle 19 | import subprocess 20 | 21 | import shutil 22 | 23 | from config import my_config 24 | 25 | OUTPUT_SKELETON_PATH = my_config.WORK_PATH + "/temp_skeleton_raw" 26 | 27 | def get_vid_from_filename(filename): 28 | return filename[-15:-4] 29 | 30 | 31 | if __name__ == '__main__': 32 | 33 | if not os.path.exists(OUTPUT_SKELETON_PATH): 34 | os.makedirs(OUTPUT_SKELETON_PATH) 35 | 36 | video_files = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 37 | for file in sorted(video_files, key=os.path.getmtime): 38 | print(file) 39 | vid = get_vid_from_filename(file) 40 | print(vid) 41 | 42 | # create out dir 43 | skeleton_dir = OUTPUT_SKELETON_PATH + "/" + vid + "/" 44 | if not os.path.exists(skeleton_dir): 45 | os.makedirs(skeleton_dir) 46 | 47 | if os.path.exists(skeleton_dir + "images/"): 48 | shutil.rmtree(skeleton_dir + "images/") 49 | 50 | os.makedirs(skeleton_dir + "images/") 51 | 52 | command_ffmpeg = "ffmpeg -i " + my_config.VIDEO_PATH + "/" + vid + ".mp4 -start_number 0 -f image2 " + OUTPUT_SKELETON_PATH + "/" + vid + "/images/" + vid + "_%012d" + ".png" 53 | print(command_ffmpeg) 54 | subprocess.call(command_ffmpeg, shell=True) 55 | 56 | # import glob 57 | # import json 58 | # import os 59 | # import pickle 60 | # import subprocess 61 | 62 | # import shutil 63 | 64 | # def main(): 65 | # lmk_root = '/mnt/lustressd/share/wangyuxin1/DATASETS/lrw-v1/lipread_ldmk/' 66 | # align_root = '/mnt/lustre/DATAshare3/lrw_srcM/' 67 | # local_root_lmk = '/mnt/lustre/liuxian.vendor/lrw_lmk/' 68 | # local_root_align = '/mnt/lustre/liuxian.vendor/lrw_align/' 69 | # category_list = os.listdir(lmk_root) 70 | # for category in category_list: 71 | # if os.path.exists(local_root_lmk + category): 72 | # shutil.rmtree(local_root_lmk + category) 73 | # os.makedirs(local_root_lmk + category) 74 | # abs_train = local_root_lmk + category + '/train' 75 | # abs_val = local_root_lmk + category + '/val' 76 | # abs_test = local_root_lmk + category + '/test' 77 
| # if os.path.exists(abs_train): 78 | # shutil.rmtree(abs_train) 79 | # if os.path.exists(abs_val): 80 | # shutil.rmtree(abs_val) 81 | # if os.path.exists(abs_test): 82 | # shutil.rmtree(abs_test) 83 | # os.makedirs(abs_train) 84 | # os.makedirs(abs_val) 85 | # os.makedirs(abs_test) 86 | 87 | # fname_list_train = os.listdir(lmk_root + category + '/train/') 88 | # fname_list_val = os.listdir(lmk_root + category + '/val/') 89 | # fname_list_test = os.listdir(lmk_root + category + '/test/') 90 | 91 | # for fname in fname_list_train: 92 | # if os.path.exists(abs_train + '/' + fname): 93 | # shutil.rmtree(abs_train + '/' + fname) 94 | # os.makedirs(abs_train + '/' + fname) 95 | # command = "cp " + lmk_root + category + '/train/' + fname + '/lmk.txt ' + abs_train + '/' + fname + '/lmk.txt' 96 | # subprocess.call(command, shell=True) 97 | 98 | # for fname in fname_list_val: 99 | # if os.path.exists(abs_val + '/' + fname): 100 | # shutil.rmtree(abs_val + '/' + fname) 101 | # os.makedirs(abs_val + '/' + fname) 102 | # command = "cp " + lmk_root + category + '/val/' + fname + '/lmk.txt ' + abs_val + '/' + fname + '/lmk.txt' 103 | # subprocess.call(command, shell=True) 104 | 105 | # for fname in fname_list_test: 106 | # if os.path.exists(abs_test + '/' + fname): 107 | # shutil.rmtree(abs_test + '/' + fname) 108 | # os.makedirs(abs_test + '/' + fname) 109 | # command = "cp " + lmk_root + category + '/test/' + fname + '/lmk.txt ' + abs_test + '/' + fname + '/lmk.txt' 110 | # subprocess.call(command, shell=True) 111 | 112 | # category_list = os.listdir(align_root) 113 | # for category in category_list: 114 | # if os.path.exists(local_root_align + category): 115 | # shutil.rmtree(local_root_align + category) 116 | # os.makedirs(local_root_align + category) 117 | # abs_train = local_root_align + category + '/train' 118 | # abs_test = local_root_align + category + '/test' 119 | # if os.path.exists(abs_train): 120 | # shutil.rmtree(abs_train) 121 | # if os.path.exists(abs_test): 122 | # shutil.rmtree(abs_test) 123 | # os.makedirs(abs_train) 124 | # os.makedirs(abs_test) 125 | 126 | # fname_list_train = glob.glob(align_root + category + '/train/*.txt') 127 | # fname_list_test = glob.glob(align_root + category + '/test/*.txt') 128 | 129 | # for fname in fname_list_train: 130 | # fname = fname[:-4] 131 | # if os.path.exists(abs_train + '/' + fname): 132 | # shutil.rmtree(abs_train + '/' + fname) 133 | # os.makedirs(abs_train + '/' + fname) 134 | # command = "cp " + align_root + category + '/train/' + fname + '.txt ' + abs_train + '/' + fname + '/align.txt' 135 | # subprocess.call(command, shell=True) 136 | 137 | # for fname in fname_list_test: 138 | # fname = fname[:-4] 139 | # if os.path.exists(abs_test + '/' + fname): 140 | # shutil.rmtree(abs_test + '/' + fname) 141 | # os.makedirs(abs_test + '/' + fname) 142 | # command = "cp " + align_root + category + '/test/' + fname + '.txt ' + abs_test + '/' + fname + '/align.txt' 143 | # subprocess.call(command, shell=True) 144 | 145 | # if __name__ == '__main__': 146 | # main() -------------------------------------------------------------------------------- /dataset_script/script/run_gentle.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | import glob 10 | import logging 11 | import multiprocessing 12 | import os 13 | import re 14 | import sys 15 | 16 | from tqdm import tqdm 17 | 18 | from config import * 19 | from make_ted_dataset import read_subtitle 20 | from config import my_config 21 | 22 | sys.path.insert(0, '../../../gentle') 23 | import gentle 24 | 25 | 26 | # prepare gentle 27 | nthreads = multiprocessing.cpu_count() - 2 28 | logging.getLogger().setLevel("WARNING") 29 | disfluencies = set(['uh', 'um']) 30 | resources = gentle.Resources() 31 | 32 | 33 | def run_gentle(video_path, vid, result_path): 34 | vtt_subtitle = read_subtitle(vid) 35 | transcript = '' 36 | for i, sub in enumerate(vtt_subtitle): 37 | transcript += (vtt_subtitle[i].text + ' ') 38 | transcript = re.sub('\n', ' ', transcript) # remove newline characters 39 | 40 | # align 41 | with gentle.resampled(video_path) as wav_file: 42 | aligner = gentle.ForcedAligner(resources, transcript, nthreads=nthreads, disfluency=False, conservative=False, 43 | disfluencies=disfluencies) 44 | result = aligner.transcribe(wav_file, logging=logging) 45 | 46 | # write results 47 | with open(result_path, 'w', encoding="utf-8") as fh: 48 | fh.write(result.to_json(indent=2)) 49 | 50 | 51 | def main(): 52 | videos = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 53 | n_total = len(videos) 54 | for i, file_path in tqdm(enumerate(sorted(videos, key=os.path.getmtime))): 55 | vid = os.path.split(file_path)[1][-15:-4] 56 | print('{}/{} - {}'.format(i+1, n_total, vid)) 57 | result_path = my_config.VIDEO_PATH + '/' + vid + '_align_results.json' 58 | if os.path.exists(result_path) and os.path.getsize(result_path): # existing and not empty 59 | print('JSON file already exists ({})'.format(vid)) 60 | else: 61 | run_gentle(file_path, vid, result_path) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /dataset_script/script/run_mp3.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | """ 10 | Extract pose skeletons by using OpenPose library 11 | Need proper LD_LIBRARY_PATH before run this script 12 | Pycharm: In RUN > Edit Configurations, add LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 13 | """ 14 | 15 | import glob 16 | import json 17 | import os 18 | import pickle 19 | import subprocess 20 | 21 | import shutil 22 | 23 | from config import my_config 24 | 25 | def get_vid_from_filename(filename): 26 | return filename[-15:-4] 27 | 28 | 29 | if __name__ == '__main__': 30 | audiopath = my_config.WORK_PATH + '/audio_ted' 31 | if not os.path.exists(audiopath): 32 | os.makedirs(audiopath) 33 | 34 | video_files = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 35 | for file in sorted(video_files, key=os.path.getmtime): 36 | print(file) 37 | vid = get_vid_from_filename(file) 38 | print(vid) 39 | 40 | command = "ffmpeg -i " + my_config.VIDEO_PATH + "/" + vid + ".mp4 " + audiopath + '/' + vid + ".mp3" 41 | print(command) 42 | subprocess.call(command, shell=True) -------------------------------------------------------------------------------- /dataset_script/script/run_openpose.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | """ 10 | Extract pose skeletons by using OpenPose library 11 | Need proper LD_LIBRARY_PATH before run this script 12 | Pycharm: In RUN > Edit Configurations, add LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 13 | """ 14 | 15 | import glob 16 | import json 17 | import os 18 | import pickle 19 | import subprocess 20 | 21 | import shutil 22 | 23 | from config import my_config 24 | 25 | # maximum accuracy, too slow (~1fps) 26 | #OPENPOSE_OPTION = "--net_resolution -1x736 --scale_number 4 --scale_gap 0.25 --hand --hand_scale_number 6 --hand_scale_range 0.4 --face" 27 | OPENPOSE_OPTION = "--face --hand --number_people_max 1 --display 0 --render_pose 0" 28 | 29 | OUTPUT_SKELETON_PATH = my_config.WORK_PATH + "/temp_skeleton_raw" 30 | OUTPUT_VIDEO_PATH = my_config.WORK_PATH + "/temp_skeleton_video" 31 | 32 | RESUME_VID = "" # resume from this video 33 | SKIP_EXISTING_SKELETON = True # skip if the skeleton file is existing 34 | 35 | 36 | def get_vid_from_filename(filename): 37 | return filename[-15:-4] 38 | 39 | 40 | def read_skeleton_json(_file): 41 | with open(_file) as json_file: 42 | skeleton_json = json.load(json_file) 43 | return skeleton_json['people'] 44 | 45 | 46 | def save_skeleton_to_pickle(_vid): 47 | files = glob.glob(OUTPUT_SKELETON_PATH + '/' + _vid + "/keypoints" + '/*.json') 48 | if len(files) > 10: 49 | files = sorted(files) 50 | skeletons = [] 51 | for file in files: 52 | skeletons.append(read_skeleton_json(file)) 53 | with open(my_config.SKELETON_PATH + '/' + _vid + '.pickle', 'wb') as file: 54 | pickle.dump(skeletons, file) 55 | 56 | 57 | if __name__ == '__main__': 58 | if not 
os.path.exists(my_config.SKELETON_PATH): 59 | os.makedirs(my_config.SKELETON_PATH) 60 | if not os.path.exists(OUTPUT_SKELETON_PATH): 61 | os.makedirs(OUTPUT_SKELETON_PATH) 62 | if not os.path.exists(OUTPUT_VIDEO_PATH): 63 | os.makedirs(OUTPUT_VIDEO_PATH) 64 | 65 | os.chdir(my_config.OPENPOSE_BASE_DIR) 66 | if RESUME_VID == "": 67 | skip_flag = False 68 | else: 69 | skip_flag = True 70 | 71 | video_files = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 72 | for file in sorted(video_files, key=os.path.getmtime): 73 | print(file) 74 | vid = get_vid_from_filename(file) 75 | print(vid) 76 | 77 | skip_iter = False 78 | 79 | # resume check 80 | if skip_flag and vid == RESUME_VID: 81 | skip_flag = False 82 | skip_iter = skip_flag 83 | 84 | # existing skeleton check 85 | if SKIP_EXISTING_SKELETON: 86 | if os.path.exists(my_config.SKELETON_PATH + '/' + vid + '.pickle'): 87 | print('existing skeleton') 88 | skip_iter = True 89 | 90 | if not skip_iter: 91 | # create out dir 92 | skeleton_dir = OUTPUT_SKELETON_PATH + "/" + vid + "/" 93 | if not os.path.exists(skeleton_dir): 94 | os.makedirs(skeleton_dir) 95 | 96 | if os.path.exists(skeleton_dir + "keypoints/"): 97 | shutil.rmtree(skeleton_dir + "keypoints/") 98 | 99 | os.makedirs(skeleton_dir + "keypoints/") 100 | 101 | skeleton_dir += "keypoints/" 102 | 103 | # extract skeleton 104 | command = my_config.OPENPOSE_BIN_PATH + " " + OPENPOSE_OPTION + " --video \"" + file + "\"" 105 | # command += " --write_video " + OUTPUT_VIDEO_PATH + "/" + vid + "_result.avi" # write result video 106 | command += " --write_json " + skeleton_dir 107 | print(command) 108 | subprocess.call(command, shell=True) 109 | 110 | # save skeletons to a pickle file 111 | save_skeleton_to_pickle(vid) 112 | -------------------------------------------------------------------------------- /dataset_script/script/run_scenedetect.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright 2019 ETRI. All rights reserved. 3 | # Licensed under the BSD 3-Clause License. 4 | # This file is part of Youtube-Gesture-Dataset, a sub-project of AIR(AI for Robots) project. 
5 | # You can refer to details of AIR project at https://aiforrobots.github.io 6 | # Written by Youngwoo Yoon (youngwoo@etri.re.kr) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from __future__ import unicode_literals 10 | import subprocess 11 | import glob 12 | import os 13 | from tqdm import tqdm 14 | from config import my_config 15 | 16 | 17 | def run_pyscenedetect(file_path, vid): # using Pyscenedetect 18 | os.chdir(my_config.VIDEO_PATH) 19 | 20 | cmd = 'scenedetect --input "{}" --output "{}" -d 4 detect-content list-scenes'.format(file_path, my_config.CLIP_PATH) 21 | print(' ' + cmd) 22 | subprocess.run(cmd, shell=True, check=True) 23 | subprocess.run("exit", shell=True, check=True) 24 | 25 | 26 | def main(): 27 | if not os.path.exists(my_config.CLIP_PATH): 28 | os.makedirs(my_config.CLIP_PATH) 29 | 30 | videos = glob.glob(my_config.VIDEO_PATH + "/*.mp4") 31 | n_total = len(videos) 32 | for i, file_path in tqdm(enumerate(sorted(videos, key=os.path.getmtime))): 33 | print('{}/{}'.format(i+1, n_total)) 34 | vid = os.path.split(file_path)[1][-15:-4] 35 | 36 | csv_files = glob.glob(my_config.CLIP_PATH + "/{}*.csv".format(vid)) 37 | if len(csv_files) > 0 and os.path.getsize(csv_files[0]): # existing and not empty 38 | print(' CSV file already exists ({})'.format(vid)) 39 | else: 40 | run_pyscenedetect(file_path, vid) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /misc/HA2G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinliu0/HA2G/5e1fd3343db7aa587db68a78397a1bbfea165132/misc/HA2G.png -------------------------------------------------------------------------------- /misc/sample1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinliu0/HA2G/5e1fd3343db7aa587db68a78397a1bbfea165132/misc/sample1.gif -------------------------------------------------------------------------------- /misc/sample1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinliu0/HA2G/5e1fd3343db7aa587db68a78397a1bbfea165132/misc/sample1.mp4 -------------------------------------------------------------------------------- /misc/sample2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinliu0/HA2G/5e1fd3343db7aa587db68a78397a1bbfea165132/misc/sample2.gif -------------------------------------------------------------------------------- /misc/sample2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinliu0/HA2G/5e1fd3343db7aa587db68a78397a1bbfea165132/misc/sample2.mp4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2 2 | numpy 3 | scipy 4 | scikit-learn 5 | matplotlib 6 | librosa 7 | tensorboard>=1.14 8 | future 9 | pyarrow==0.14.1 10 | lmdb==0.96 11 | tqdm 12 | fasttext 13 | configargparse 14 | soundfile 15 | pygame 16 | google-cloud-texttospeech==1.0.1 17 | librosa 18 | umap -------------------------------------------------------------------------------- /scripts/calculate_angle_stats.py: -------------------------------------------------------------------------------- 1 | import 
pprint 2 | import time 3 | from pathlib import Path 4 | import sys 5 | 6 | [sys.path.append(i) for i in ['.', '..']] 7 | 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torch.nn as nn 13 | 14 | from model import speech2gesture, vocab 15 | from utils.average_meter import AverageMeter 16 | from utils.data_utils import convert_dir_vec_to_pose 17 | from utils.vocab_utils import build_vocab 18 | 19 | matplotlib.use('Agg') # we don't use interactive GUI 20 | 21 | from parse_args import parse_args 22 | from model.embedding_space_evaluator import EmbeddingSpaceEvaluator 23 | 24 | import math 25 | 26 | from data_loader.lmdb_data_loader import * 27 | import utils.train_utils 28 | 29 | import librosa 30 | from librosa.feature import melspectrogram 31 | 32 | from tqdm import tqdm 33 | 34 | import warnings 35 | warnings.filterwarnings('ignore') 36 | 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | 39 | def main(config): 40 | args = config['args'] 41 | 42 | # random seed 43 | if args.random_seed >= 0: 44 | utils.train_utils.set_random_seed(args.random_seed) 45 | 46 | collate_fn = default_collate_fn 47 | 48 | # dataset 49 | mean_dir_vec = np.array(args.mean_dir_vec).reshape(-1, 3) 50 | train_dataset = SpeechMotionDataset(args.train_data_path[0], 51 | n_poses=args.n_poses, 52 | subdivision_stride=args.subdivision_stride, 53 | pose_resampling_fps=args.motion_resampling_framerate, 54 | mean_dir_vec=mean_dir_vec, 55 | mean_pose=args.mean_pose, 56 | remove_word_timing=(args.input_context == 'text') 57 | ) 58 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, 59 | shuffle=True, drop_last=True, num_workers=args.loader_workers, pin_memory=True, 60 | collate_fn=collate_fn 61 | ) 62 | 63 | val_dataset = SpeechMotionDataset(args.val_data_path[0], 64 | n_poses=args.n_poses, 65 | subdivision_stride=args.subdivision_stride, 66 | pose_resampling_fps=args.motion_resampling_framerate, 67 | speaker_model=train_dataset.speaker_model, 68 | mean_dir_vec=mean_dir_vec, 69 | mean_pose=args.mean_pose, 70 | remove_word_timing=(args.input_context == 'text') 71 | ) 72 | 73 | test_dataset = SpeechMotionDataset(args.test_data_path[0], 74 | n_poses=args.n_poses, 75 | subdivision_stride=args.subdivision_stride, 76 | pose_resampling_fps=args.motion_resampling_framerate, 77 | speaker_model=train_dataset.speaker_model, 78 | mean_dir_vec=mean_dir_vec, 79 | mean_pose=args.mean_pose) 80 | 81 | # build vocab 82 | vocab_cache_path = os.path.join(os.path.split(args.train_data_path[0])[0], 'vocab_cache.pkl') 83 | lang_model = build_vocab('words', [train_dataset, val_dataset, test_dataset], vocab_cache_path, args.wordembed_path, 84 | args.wordembed_dim) 85 | train_dataset.set_lang_model(lang_model) 86 | val_dataset.set_lang_model(lang_model) 87 | 88 | if args.pose_dim == 27: 89 | angle_pair = [ 90 | (3, 4), 91 | (4, 5), 92 | (6, 7), 93 | (7, 8) 94 | ] 95 | elif args.pose_dim == 126: 96 | angle_pair = [ 97 | (0, 1), 98 | (0, 2), 99 | (1, 3), 100 | (3, 4), 101 | (5, 6), 102 | (6, 7), 103 | (8, 9), 104 | (9, 10), 105 | (11, 12), 106 | (12, 13), 107 | (14, 15), 108 | (15, 16), 109 | (17, 18), 110 | (18, 19), 111 | (17, 5), 112 | (5, 8), 113 | (8, 14), 114 | (14, 11), 115 | (2, 20), 116 | (20, 21), 117 | (22, 23), 118 | (23, 24), 119 | (25, 26), 120 | (26, 27), 121 | (28, 29), 122 | (29, 30), 123 | (31, 32), 124 | (32, 33), 125 | (34, 35), 126 | (35, 36), 127 | (34, 22), 128 | 
(22, 25), 129 | (25, 31), 130 | (31, 28), 131 | (0, 37), 132 | (37, 38), 133 | (37, 39), 134 | (38, 40), 135 | (39, 41), 136 | # palm 137 | (4, 42), 138 | (21, 43) 139 | ] 140 | else: 141 | assert False 142 | 143 | avg_angle = [0] * len(angle_pair) 144 | var_angle = [0] * len(angle_pair) 145 | change_angle = [0] * len(angle_pair) 146 | 147 | cnt_angle = 0 148 | cnt_change = 0 149 | 150 | # stat angle 151 | for data in tqdm(train_loader): 152 | 153 | in_text, text_lengths, in_text_padded, _, target_vec, in_audio, in_spec, aux_info = data 154 | batch_size = target_vec.size(0) 155 | target_vec = target_vec + torch.tensor(args.mean_dir_vec).squeeze(1).unsqueeze(0).unsqueeze(0) 156 | target_vec = target_vec.to(device) 157 | if args.pose_dim == 126: 158 | left_palm = torch.cross(target_vec[:, :, 11 * 3 : 12 * 3], target_vec[:, :, 17 * 3 : 18 * 3], dim = 2) 159 | right_palm = torch.cross(target_vec[:, :, 28 * 3 : 29 * 3], target_vec[:, :, 34 * 3 : 35 * 3], dim = 2) 160 | target_vec = torch.cat((target_vec, left_palm, right_palm), dim = 2) 161 | target_vec = target_vec.reshape(target_vec.shape[0], target_vec.shape[1], -1, 3) 162 | target_vec = F.normalize(target_vec, dim = -1) 163 | 164 | angle_batch = target_vec.shape[0] * target_vec.shape[1] 165 | change_batch = target_vec.shape[0] * (target_vec.shape[1] - 1) 166 | all_vec = target_vec.reshape(target_vec.shape[0] * target_vec.shape[1], -1, 3) 167 | 168 | for idx, pair in enumerate(angle_pair): 169 | vec1 = all_vec[:, pair[0]] 170 | vec2 = all_vec[:, pair[1]] 171 | inner_product = torch.einsum('ij,ij->i', [vec1, vec2]) 172 | inner_product = torch.clamp(inner_product, -1, 1, out=None) 173 | angle = torch.acos(inner_product) / math.pi 174 | angle_time = angle.reshape(batch_size, -1) 175 | angle_diff = torch.mean(torch.abs(angle_time[:, 1:] - angle_time[:, :-1])) 176 | avg_batch = torch.mean(angle) 177 | var_batch = torch.var(angle) 178 | if (torch.isnan(angle_diff)): 179 | angle_diff = change_angle[idx] 180 | if (torch.isnan(avg_batch)): 181 | avg_batch = avg_angle[idx] 182 | if (torch.isnan(var_batch)): 183 | var_batch = var_angle[idx] 184 | history_avg = avg_angle[idx] 185 | change_angle[idx] = (cnt_change * change_angle[idx] + angle_diff * change_batch) / (cnt_change + change_batch) 186 | avg_angle[idx] = (cnt_angle * avg_angle[idx] + angle_batch * avg_batch) / (cnt_angle + angle_batch) 187 | var_angle[idx] = (cnt_angle * (var_angle[idx] + torch.pow((avg_angle[idx] - history_avg), 2)) + angle_batch * (var_batch + torch.pow((avg_angle[idx] - avg_batch), 2))) / (cnt_angle + angle_batch) 188 | 189 | cnt_angle += angle_batch 190 | cnt_change += change_batch 191 | change_angle = [x.item() for x in change_angle] 192 | avg_angle = [x.item() for x in avg_angle] 193 | var_angle = [x.item() for x in var_angle] 194 | 195 | print('change angle: ', change_angle) 196 | print('avg angle: ', avg_angle) 197 | print('var angle: ', var_angle) 198 | 199 | if __name__ == '__main__': 200 | _args = parse_args() 201 | main({'args': _args}) 202 | -------------------------------------------------------------------------------- /scripts/calculate_motion_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import numpy as np 4 | import pyarrow 5 | import torch 6 | 7 | # calculate the motion stats for TED Expressive dataset 8 | import utils.train_utils_expressive 9 | import utils.data_utils_expressive 10 | 11 | 12 | def calculate_data_mean(base_path): 13 | lmdb_path = os.path.join(base_path, 'train') 14 
| lmdb_env = lmdb.open(lmdb_path, readonly=True, lock=False) 15 | with lmdb_env.begin() as txn: 16 | n_videos = txn.stat()['entries'] 17 | src_txn = lmdb_env.begin(write=False) 18 | cursor = src_txn.cursor() 19 | 20 | pose_seq_list = [] 21 | total_duration = 0 22 | 23 | for key, value in cursor: 24 | video = pyarrow.deserialize(value) 25 | vid = video['vid'] 26 | clips = video['clips'] 27 | for clip_idx, clip in enumerate(clips): 28 | poses = clip['skeletons_3d'] 29 | pose_seq_list.append(poses) 30 | total_duration += (clip['end_time'] - clip['start_time']) 31 | 32 | # close db 33 | lmdb_env.close() 34 | 35 | all_poses = np.vstack(pose_seq_list) 36 | mean_pose = np.mean(all_poses, axis=0) 37 | 38 | # mean dir vec 39 | dir_vec = utils.data_utils_expressive.convert_pose_seq_to_dir_vec(torch.from_numpy(all_poses)).numpy() 40 | mean_dir_vec = np.mean(dir_vec, axis=0) 41 | 42 | # mean bone length 43 | bone_lengths = [] 44 | for i, pair in enumerate(utils.data_utils_expressive.dir_vec_pairs): 45 | vec = all_poses[:, pair[1]] - all_poses[:, pair[0]] 46 | bone_lengths.append(np.mean(np.linalg.norm(vec, axis=1))) 47 | 48 | print('mean pose', repr(mean_pose.flatten())) 49 | print('mean directional vector', repr(mean_dir_vec.flatten())) 50 | print('mean bone lengths', repr(bone_lengths)) 51 | print('total duration of the valid clips: {:.1f} h'.format(total_duration/3600)) 52 | 53 | 54 | if __name__ == '__main__': 55 | # import matplotlib 56 | # matplotlib.use('TkAgg') 57 | np.set_printoptions(precision=7, suppress=True) 58 | 59 | lmdb_base_path = '/mnt/lustressd/share/liuxian.vendor/complex_dataset' 60 | calculate_data_mean(lmdb_base_path) -------------------------------------------------------------------------------- /scripts/data_loader/h36m_loader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | from utils.data_utils import convert_pose_seq_to_dir_vec, convert_dir_vec_to_pose 9 | 10 | train_subject = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'] 11 | test_subject = ['S11'] 12 | 13 | 14 | class Human36M(Dataset): 15 | def __init__(self, path, mean_data, is_train=True, augment=False): 16 | n_poses = 34 17 | target_joints = [1, 6, 12, 13, 14, 15, 17, 18, 19, 25, 26, 27] # see https://github.com/kenkra/3d-pose-baseline-vmd/wiki/body 18 | 19 | self.is_train = is_train 20 | self.augment = augment 21 | self.mean_data = mean_data 22 | self.data = [] 23 | 24 | if is_train: 25 | subjects = train_subject 26 | else: 27 | subjects = test_subject 28 | 29 | # loading data and normalize 30 | frame_stride = 2 31 | data = np.load(path, allow_pickle=True)['positions_3d'].item() 32 | for subject, actions in data.items(): 33 | if subject not in subjects: 34 | continue 35 | 36 | for action_name, positions in actions.items(): 37 | positions = positions[:, target_joints] 38 | positions = self.normalize(positions) 39 | for f in range(0, len(positions), 10): 40 | if f+n_poses*frame_stride > len(positions): 41 | break 42 | self.data.append(positions[f:f+n_poses*frame_stride:frame_stride]) 43 | 44 | def __getitem__(self, index): 45 | poses = self.data[index] 46 | dir_vec = convert_pose_seq_to_dir_vec(poses) 47 | poses = convert_dir_vec_to_pose(dir_vec) 48 | 49 | if self.augment: # data augmentation by adding gaussian noises on joints coordinates 50 | rand_val = random.random() 51 | if rand_val < 0.2: 52 | poses = poses.copy() 53 | poses += np.random.normal(0, 0.002 
** 0.5, poses.shape) 54 | else: 55 | poses = poses.copy() 56 | poses += np.random.normal(0, 0.0001 ** 0.5, poses.shape) 57 | 58 | dir_vec = convert_pose_seq_to_dir_vec(poses) 59 | dir_vec = dir_vec.reshape(dir_vec.shape[0], -1) 60 | dir_vec = dir_vec - self.mean_data 61 | 62 | poses = torch.from_numpy(poses).float() 63 | dir_vec = torch.from_numpy(dir_vec).float() 64 | return poses, dir_vec 65 | 66 | def __len__(self): 67 | return len(self.data) 68 | 69 | def normalize(self, data): 70 | 71 | # pose normalization 72 | for f in range(data.shape[0]): 73 | data[f, :] -= data[f, 2] 74 | data[f, :, (0, 1, 2)] = data[f, :, (0, 2, 1)] # xy exchange 75 | data[f, :, 1] = -data[f, :, 1] # invert y 76 | 77 | # frontalize based on hip joints 78 | for f in range(data.shape[0]): 79 | hip_vec = data[f, 1] - data[f, 0] 80 | angle = np.pi - np.math.atan2(hip_vec[2], hip_vec[0]) # angles on XZ plane 81 | if 180 > np.rad2deg(angle) > 0: 82 | pass 83 | elif 180 < np.rad2deg(angle) < 360: 84 | angle = angle - np.deg2rad(360) 85 | 86 | rot = self.rotation_matrix([0, 1, 0], angle) 87 | data[f] = np.matmul(data[f], rot) 88 | 89 | data = data[:, 2:] # exclude hip joints 90 | return data 91 | 92 | @staticmethod 93 | def rotation_matrix(axis, theta): 94 | """ 95 | Return the rotation matrix associated with counterclockwise rotation about 96 | the given axis by theta radians. 97 | """ 98 | axis = np.asarray(axis) 99 | axis = axis / math.sqrt(np.dot(axis, axis)) 100 | a = math.cos(theta / 2.0) 101 | b, c, d = -axis * math.sin(theta / 2.0) 102 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 103 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 104 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 105 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 106 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 107 | 108 | -------------------------------------------------------------------------------- /scripts/data_loader/lmdb_data_loader.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | 7 | import numpy as np 8 | import lmdb as lmdb 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | from torch.utils.data import Dataset, DataLoader 13 | from torch.utils.data.dataloader import default_collate 14 | 15 | import utils.train_utils 16 | import utils.data_utils 17 | from model.vocab import Vocab 18 | from data_loader.data_preprocessor import DataPreprocessor 19 | import pyarrow 20 | import copy 21 | 22 | 23 | def word_seq_collate_fn(data): 24 | """ collate function for loading word sequences in variable lengths """ 25 | # sort a list by sequence length (descending order) to use pack_padded_sequence 26 | data.sort(key=lambda x: len(x[0]), reverse=True) 27 | 28 | # separate source and target sequences 29 | word_seq, text_padded, poses_seq, vec_seq, audio, spectrogram, aux_info = zip(*data) 30 | 31 | # merge sequences 32 | words_lengths = torch.LongTensor([len(x) for x in word_seq]) 33 | word_seq = pad_sequence(word_seq, batch_first=True).long() 34 | 35 | text_padded = default_collate(text_padded) 36 | poses_seq = default_collate(poses_seq) 37 | vec_seq = default_collate(vec_seq) 38 | audio = default_collate(audio) 39 | spectrogram = default_collate(spectrogram) 40 | aux_info = {key: default_collate([d[key] for d in aux_info]) for key in aux_info[0]} 41 | 42 | return word_seq, words_lengths, text_padded, poses_seq, vec_seq, audio, spectrogram, 
aux_info 43 | 44 | 45 | def default_collate_fn(data): 46 | _, text_padded, pose_seq, vec_seq, audio, spectrogram, aux_info = zip(*data) 47 | 48 | text_padded = default_collate(text_padded) 49 | pose_seq = default_collate(pose_seq) 50 | vec_seq = default_collate(vec_seq) 51 | audio = default_collate(audio) 52 | spectrogram = default_collate(spectrogram) 53 | aux_info = {key: default_collate([d[key] for d in aux_info]) for key in aux_info[0]} 54 | 55 | return torch.tensor([0]), torch.tensor([0]), text_padded, pose_seq, vec_seq, audio, spectrogram, aux_info 56 | 57 | 58 | class SpeechMotionDataset(Dataset): 59 | def __init__(self, lmdb_dir, n_poses, subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec, 60 | speaker_model=None, remove_word_timing=False): 61 | 62 | self.lmdb_dir = lmdb_dir 63 | self.n_poses = n_poses 64 | self.subdivision_stride = subdivision_stride 65 | self.skeleton_resampling_fps = pose_resampling_fps 66 | self.mean_dir_vec = mean_dir_vec 67 | self.remove_word_timing = remove_word_timing 68 | 69 | self.expected_audio_length = int(round(n_poses / pose_resampling_fps * 16000)) 70 | self.expected_spectrogram_length = utils.data_utils.calc_spectrogram_length_from_motion_length( 71 | n_poses, pose_resampling_fps) 72 | 73 | self.lang_model = None 74 | 75 | logging.info("Reading data '{}'...".format(lmdb_dir)) 76 | preloaded_dir = lmdb_dir + '_cache' 77 | if not os.path.exists(preloaded_dir): 78 | logging.info('Creating the dataset cache...') 79 | assert mean_dir_vec is not None 80 | if mean_dir_vec.shape[-1] != 3: 81 | mean_dir_vec = mean_dir_vec.reshape(mean_dir_vec.shape[:-1] + (-1, 3)) 82 | n_poses_extended = int(round(n_poses * 1.25)) # some margin 83 | data_sampler = DataPreprocessor(lmdb_dir, preloaded_dir, n_poses_extended, 84 | subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec) 85 | data_sampler.run() 86 | else: 87 | logging.info('Found the cache {}'.format(preloaded_dir)) 88 | 89 | # init lmdb 90 | self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) 91 | with self.lmdb_env.begin() as txn: 92 | self.n_samples = txn.stat()['entries'] 93 | 94 | # make a speaker model 95 | if speaker_model is None or speaker_model == 0: 96 | precomputed_model = lmdb_dir + '_speaker_model.pkl' 97 | if not os.path.exists(precomputed_model): 98 | self._make_speaker_model(lmdb_dir, precomputed_model) 99 | else: 100 | with open(precomputed_model, 'rb') as f: 101 | self.speaker_model = pickle.load(f) 102 | else: 103 | self.speaker_model = speaker_model 104 | 105 | def __len__(self): 106 | return self.n_samples 107 | 108 | def __getitem__(self, idx): 109 | with self.lmdb_env.begin(write=False) as txn: 110 | key = '{:010}'.format(idx).encode('ascii') 111 | sample = txn.get(key) 112 | 113 | sample = pyarrow.deserialize(sample) 114 | word_seq, pose_seq, vec_seq, audio, spectrogram, aux_info = sample 115 | 116 | def extend_word_seq(lang, words, end_time=None): 117 | n_frames = self.n_poses 118 | if end_time is None: 119 | end_time = aux_info['end_time'] 120 | frame_duration = (end_time - aux_info['start_time']) / n_frames 121 | 122 | extended_word_indices = np.zeros(n_frames) # zero is the index of padding token 123 | if self.remove_word_timing: 124 | n_words = 0 125 | for word in words: 126 | idx = max(0, int(np.floor((word[1] - aux_info['start_time']) / frame_duration))) 127 | if idx < n_frames: 128 | n_words += 1 129 | space = int(n_frames / (n_words + 1)) 130 | for i in range(n_words): 131 | idx = (i+1) * space 132 | extended_word_indices[idx] = 
lang.get_word_index(words[i][0]) 133 | else: 134 | prev_idx = 0 135 | for word in words: 136 | idx = max(0, int(np.floor((word[1] - aux_info['start_time']) / frame_duration))) 137 | if idx < n_frames: 138 | extended_word_indices[idx] = lang.get_word_index(word[0]) 139 | # extended_word_indices[prev_idx:idx+1] = lang.get_word_index(word[0]) 140 | prev_idx = idx 141 | return torch.Tensor(extended_word_indices).long() 142 | 143 | def words_to_tensor(lang, words, end_time=None): 144 | indexes = [lang.SOS_token] 145 | for word in words: 146 | if end_time is not None and word[1] > end_time: 147 | break 148 | indexes.append(lang.get_word_index(word[0])) 149 | indexes.append(lang.EOS_token) 150 | return torch.Tensor(indexes).long() 151 | 152 | duration = aux_info['end_time'] - aux_info['start_time'] 153 | do_clipping = True 154 | 155 | if do_clipping: 156 | sample_end_time = aux_info['start_time'] + duration * self.n_poses / vec_seq.shape[0] 157 | audio = utils.data_utils.make_audio_fixed_length(audio, self.expected_audio_length) 158 | spectrogram = spectrogram[:, 0:self.expected_spectrogram_length] 159 | vec_seq = vec_seq[0:self.n_poses] 160 | pose_seq = pose_seq[0:self.n_poses] 161 | else: 162 | sample_end_time = None 163 | 164 | # to tensors 165 | word_seq_tensor = words_to_tensor(self.lang_model, word_seq, sample_end_time) 166 | extended_word_seq = extend_word_seq(self.lang_model, word_seq, sample_end_time) 167 | vec_seq = torch.from_numpy(copy.copy(vec_seq)).reshape((vec_seq.shape[0], -1)).float() 168 | pose_seq = torch.from_numpy(copy.copy(pose_seq)).reshape((pose_seq.shape[0], -1)).float() 169 | audio = torch.from_numpy(copy.copy(audio)).float() 170 | spectrogram = torch.from_numpy(copy.copy(spectrogram)).float() 171 | # mean = torch.mean(spectrogram, dim=1) 172 | # std = torch.std(spectrogram, dim=1) 173 | 174 | # spectrogram = (spectrogram - mean.unsqueeze(1)) / (std + 1e-8).unsqueeze(1) 175 | 176 | return word_seq_tensor, extended_word_seq, pose_seq, vec_seq, audio, spectrogram, aux_info 177 | 178 | def set_lang_model(self, lang_model): 179 | self.lang_model = lang_model 180 | 181 | def _make_speaker_model(self, lmdb_dir, cache_path): 182 | logging.info(' building a speaker model...') 183 | speaker_model = Vocab('vid', insert_default_tokens=False) 184 | 185 | lmdb_env = lmdb.open(lmdb_dir, readonly=True, lock=False) 186 | txn = lmdb_env.begin(write=False) 187 | cursor = txn.cursor() 188 | for key, value in cursor: 189 | video = pyarrow.deserialize(value) 190 | vid = video['vid'] 191 | speaker_model.index_word(vid) 192 | 193 | lmdb_env.close() 194 | logging.info(' indexed %d videos' % speaker_model.n_words) 195 | self.speaker_model = speaker_model 196 | 197 | # cache 198 | with open(cache_path, 'wb') as f: 199 | pickle.dump(self.speaker_model, f) 200 | 201 | -------------------------------------------------------------------------------- /scripts/data_loader/lmdb_data_loader_expressive.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | 7 | import numpy as np 8 | import lmdb as lmdb 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | from torch.utils.data import Dataset, DataLoader 13 | from torch.utils.data.dataloader import default_collate 14 | 15 | import utils.train_utils_expressive 16 | import utils.data_utils_expressive 17 | from model.vocab import Vocab 18 | from data_loader.data_preprocessor_expressive import DataPreprocessor 19 | 
import pyarrow 20 | import copy 21 | 22 | 23 | def word_seq_collate_fn(data): 24 | """ collate function for loading word sequences in variable lengths """ 25 | # sort a list by sequence length (descending order) to use pack_padded_sequence 26 | data.sort(key=lambda x: len(x[0]), reverse=True) 27 | 28 | # separate source and target sequences 29 | word_seq, text_padded, poses_seq, vec_seq, audio, spectrogram, aux_info = zip(*data) 30 | 31 | # merge sequences 32 | words_lengths = torch.LongTensor([len(x) for x in word_seq]) 33 | word_seq = pad_sequence(word_seq, batch_first=True).long() 34 | 35 | text_padded = default_collate(text_padded) 36 | poses_seq = default_collate(poses_seq) 37 | vec_seq = default_collate(vec_seq) 38 | audio = default_collate(audio) 39 | spectrogram = default_collate(spectrogram) 40 | aux_info = {key: default_collate([d[key] for d in aux_info]) for key in aux_info[0]} 41 | 42 | return word_seq, words_lengths, text_padded, poses_seq, vec_seq, audio, spectrogram, aux_info 43 | 44 | 45 | def default_collate_fn(data): 46 | _, text_padded, pose_seq, vec_seq, audio, spectrogram, aux_info = zip(*data) 47 | 48 | text_padded = default_collate(text_padded) 49 | pose_seq = default_collate(pose_seq) 50 | vec_seq = default_collate(vec_seq) 51 | audio = default_collate(audio) 52 | spectrogram = default_collate(spectrogram) 53 | aux_info = {key: default_collate([d[key] for d in aux_info]) for key in aux_info[0]} 54 | 55 | return torch.tensor([0]), torch.tensor([0]), text_padded, pose_seq, vec_seq, audio, spectrogram, aux_info 56 | 57 | 58 | class SpeechMotionDataset(Dataset): 59 | def __init__(self, lmdb_dir, n_poses, subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec, 60 | speaker_model=None, remove_word_timing=False): 61 | 62 | # self.spec_mean = torch.load('/mnt/lustressd/liuxian.vendor/HA2G/spec_mean.pth') 63 | # self.spec_std = torch.load('/mnt/lustressd/liuxian.vendor/HA2G/spec_std.pth') 64 | 65 | self.lmdb_dir = lmdb_dir 66 | self.n_poses = n_poses 67 | self.subdivision_stride = subdivision_stride 68 | self.skeleton_resampling_fps = pose_resampling_fps 69 | self.mean_dir_vec = mean_dir_vec 70 | self.remove_word_timing = remove_word_timing 71 | 72 | self.expected_audio_length = int(round(n_poses / pose_resampling_fps * 16000)) 73 | self.expected_spectrogram_length = utils.data_utils_expressive.calc_spectrogram_length_from_motion_length( 74 | n_poses, pose_resampling_fps) 75 | 76 | self.lang_model = None 77 | 78 | logging.info("Reading data '{}'...".format(lmdb_dir)) 79 | preloaded_dir = lmdb_dir + '_cache' 80 | if not os.path.exists(preloaded_dir): 81 | logging.info('Creating the dataset cache...') 82 | assert mean_dir_vec is not None 83 | if mean_dir_vec.shape[-1] != 3: 84 | mean_dir_vec = mean_dir_vec.reshape(mean_dir_vec.shape[:-1] + (-1, 3)) 85 | n_poses_extended = int(round(n_poses * 1.25)) # some margin 86 | data_sampler = DataPreprocessor(lmdb_dir, preloaded_dir, n_poses_extended, 87 | subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec) 88 | data_sampler.run() 89 | else: 90 | logging.info('Found the cache {}'.format(preloaded_dir)) 91 | 92 | # init lmdb 93 | self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) 94 | with self.lmdb_env.begin() as txn: 95 | self.n_samples = txn.stat()['entries'] 96 | 97 | # make a speaker model 98 | if speaker_model is None or speaker_model == 0: 99 | precomputed_model = lmdb_dir + '_speaker_model.pkl' 100 | if not os.path.exists(precomputed_model): 101 | self._make_speaker_model(lmdb_dir, 
precomputed_model) 102 | else: 103 | with open(precomputed_model, 'rb') as f: 104 | self.speaker_model = pickle.load(f) 105 | else: 106 | self.speaker_model = speaker_model 107 | 108 | def __len__(self): 109 | return self.n_samples 110 | 111 | def __getitem__(self, idx): 112 | with self.lmdb_env.begin(write=False) as txn: 113 | key = '{:010}'.format(idx).encode('ascii') 114 | sample = txn.get(key) 115 | 116 | sample = pyarrow.deserialize(sample) 117 | word_seq, pose_seq, vec_seq, audio, spectrogram, aux_info = sample 118 | 119 | def extend_word_seq(lang, words, end_time=None): 120 | n_frames = self.n_poses 121 | if end_time is None: 122 | end_time = aux_info['end_time'] 123 | frame_duration = (end_time - aux_info['start_time']) / n_frames 124 | 125 | extended_word_indices = np.zeros(n_frames) # zero is the index of padding token 126 | if self.remove_word_timing: 127 | n_words = 0 128 | for word in words: 129 | idx = max(0, int(np.floor((word[1] - aux_info['start_time']) / frame_duration))) 130 | if idx < n_frames: 131 | n_words += 1 132 | space = int(n_frames / (n_words + 1)) 133 | for i in range(n_words): 134 | idx = (i+1) * space 135 | extended_word_indices[idx] = lang.get_word_index(words[i][0]) 136 | else: 137 | prev_idx = 0 138 | for word in words: 139 | idx = max(0, int(np.floor((word[1] - aux_info['start_time']) / frame_duration))) 140 | if idx < n_frames: 141 | extended_word_indices[idx] = lang.get_word_index(word[0]) 142 | # extended_word_indices[prev_idx:idx+1] = lang.get_word_index(word[0]) 143 | prev_idx = idx 144 | return torch.Tensor(extended_word_indices).long() 145 | 146 | def words_to_tensor(lang, words, end_time=None): 147 | indexes = [lang.SOS_token] 148 | for word in words: 149 | if end_time is not None and word[1] > end_time: 150 | break 151 | indexes.append(lang.get_word_index(word[0])) 152 | indexes.append(lang.EOS_token) 153 | return torch.Tensor(indexes).long() 154 | 155 | duration = aux_info['end_time'] - aux_info['start_time'] 156 | do_clipping = True 157 | 158 | if do_clipping: 159 | sample_end_time = aux_info['start_time'] + duration * self.n_poses / vec_seq.shape[0] 160 | audio = utils.data_utils_expressive.make_audio_fixed_length(audio, self.expected_audio_length) 161 | spectrogram = spectrogram[:, 0:self.expected_spectrogram_length] 162 | vec_seq = vec_seq[0:self.n_poses] 163 | pose_seq = pose_seq[0:self.n_poses] 164 | else: 165 | sample_end_time = None 166 | 167 | # to tensors 168 | word_seq_tensor = words_to_tensor(self.lang_model, word_seq, sample_end_time) 169 | extended_word_seq = extend_word_seq(self.lang_model, word_seq, sample_end_time) 170 | vec_seq = torch.from_numpy(copy.copy(vec_seq)).reshape((vec_seq.shape[0], -1)).float() 171 | pose_seq = torch.from_numpy(copy.copy(pose_seq)).reshape((pose_seq.shape[0], -1)).float() 172 | audio = torch.from_numpy(copy.copy(audio)).float() 173 | spectrogram = torch.from_numpy(copy.copy(spectrogram)).float() 174 | 175 | # spectrogram = (spectrogram - self.spec_mean.unsqueeze(1)) / self.spec_std.unsqueeze(1) 176 | 177 | return word_seq_tensor, extended_word_seq, pose_seq, vec_seq, audio, spectrogram, aux_info 178 | 179 | def set_lang_model(self, lang_model): 180 | self.lang_model = lang_model 181 | 182 | def _make_speaker_model(self, lmdb_dir, cache_path): 183 | logging.info(' building a speaker model...') 184 | speaker_model = Vocab('vid', insert_default_tokens=False) 185 | 186 | lmdb_env = lmdb.open(lmdb_dir, readonly=True, lock=False) 187 | txn = lmdb_env.begin(write=False) 188 | cursor = txn.cursor() 189 | for 
key, value in cursor: 190 | video = pyarrow.deserialize(value) 191 | vid = video['vid'] 192 | speaker_model.index_word(vid) 193 | 194 | lmdb_env.close() 195 | logging.info(' indexed %d videos' % speaker_model.n_words) 196 | self.speaker_model = speaker_model 197 | 198 | # cache 199 | with open(cache_path, 'wb') as f: 200 | pickle.dump(self.speaker_model, f) 201 | 202 | -------------------------------------------------------------------------------- /scripts/data_loader/motion_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MotionPreprocessor: 5 | def __init__(self, skeletons, mean_pose): 6 | self.skeletons = np.array(skeletons) 7 | self.mean_pose = np.array(mean_pose).reshape(-1, 3) 8 | self.filtering_message = "PASS" 9 | 10 | def get(self): 11 | assert (self.skeletons is not None) 12 | 13 | # filtering 14 | if self.skeletons != []: 15 | if self.check_pose_diff(): 16 | self.skeletons = [] 17 | self.filtering_message = "pose" 18 | elif self.check_spine_angle(): 19 | self.skeletons = [] 20 | self.filtering_message = "spine angle" 21 | elif self.check_static_motion(): 22 | self.skeletons = [] 23 | self.filtering_message = "motion" 24 | 25 | if self.skeletons != []: 26 | self.skeletons = self.skeletons.tolist() 27 | for i, frame in enumerate(self.skeletons): 28 | assert not np.isnan(self.skeletons[i]).any() # missing joints 29 | 30 | return self.skeletons, self.filtering_message 31 | 32 | def check_static_motion(self, verbose=False): 33 | def get_variance(skeleton, joint_idx): 34 | wrist_pos = skeleton[:, joint_idx] 35 | variance = np.sum(np.var(wrist_pos, axis=0)) 36 | return variance 37 | 38 | left_arm_var = get_variance(self.skeletons, 6) 39 | right_arm_var = get_variance(self.skeletons, 9) 40 | 41 | th = 0.0014 # exclude 13110 42 | # th = 0.002 # exclude 16905 43 | if left_arm_var < th and right_arm_var < th: 44 | if verbose: 45 | print('skip - check_static_motion left var {}, right var {}'.format(left_arm_var, right_arm_var)) 46 | return True 47 | else: 48 | if verbose: 49 | print('pass - check_static_motion left var {}, right var {}'.format(left_arm_var, right_arm_var)) 50 | return False 51 | 52 | def check_pose_diff(self, verbose=False): 53 | diff = np.abs(self.skeletons - self.mean_pose) 54 | diff = np.mean(diff) 55 | 56 | # th = 0.017 57 | th = 0.02 # exclude 3594 58 | if diff < th: 59 | if verbose: 60 | print('skip - check_pose_diff {:.5f}'.format(diff)) 61 | return True 62 | else: 63 | if verbose: 64 | print('pass - check_pose_diff {:.5f}'.format(diff)) 65 | return False 66 | 67 | def check_spine_angle(self, verbose=False): 68 | def angle_between(v1, v2): 69 | v1_u = v1 / np.linalg.norm(v1) 70 | v2_u = v2 / np.linalg.norm(v2) 71 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 72 | 73 | angles = [] 74 | for i in range(self.skeletons.shape[0]): 75 | spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] 76 | angle = angle_between(spine_vec, [0, -1, 0]) 77 | angles.append(angle) 78 | 79 | if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 80 | # if np.rad2deg(max(angles)) > 20: # exclude 8270 81 | if verbose: 82 | print('skip - check_spine_angle {:.5f}, {:.5f}'.format(max(angles), np.mean(angles))) 83 | return True 84 | else: 85 | if verbose: 86 | print('pass - check_spine_angle {:.5f}'.format(max(angles))) 87 | return False -------------------------------------------------------------------------------- 
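(Usage note) The filters in MotionPreprocessor above run in a fixed order (mean-pose similarity, then spine angle, then static arm motion), and a clip is discarded as soon as one check fires. The sketch below shows how the class is typically driven; the clip array and the flat 30-value mean pose are made-up stand-ins for the inputs the dataset preprocessor supplies, so treat it as an illustration rather than project code.

import numpy as np
from data_loader.motion_preprocessor import MotionPreprocessor

# hypothetical clip: 34 frames x 10 upper-body joints x 3 coordinates
clip_skeletons = np.random.normal(0.0, 0.1, size=(34, 10, 3))
mean_pose = np.zeros(30)  # flat (10 joints x 3); reshaped to (-1, 3) inside MotionPreprocessor

skeletons, message = MotionPreprocessor(clip_skeletons, mean_pose).get()
if not skeletons:  # an empty list means the clip was rejected
    print('clip filtered out:', message)  # "pose", "spine angle", or "motion"
else:
    print('clip kept with {} frames'.format(len(skeletons)))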
/scripts/data_loader/motion_preprocessor_expressive.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MotionPreprocessor: 5 | def __init__(self, skeletons, mean_pose): 6 | self.skeletons = np.array(skeletons) 7 | self.mean_pose = np.array(mean_pose).reshape(-1, 3) 8 | self.filtering_message = "PASS" 9 | 10 | def get(self): 11 | assert (self.skeletons is not None) 12 | 13 | # filtering 14 | if self.skeletons != []: 15 | if self.check_pose_diff(): 16 | self.skeletons = [] 17 | self.filtering_message = "pose" 18 | elif self.check_spine_angle(): 19 | self.skeletons = [] 20 | self.filtering_message = "spine angle" 21 | elif self.check_static_motion(): 22 | self.skeletons = [] 23 | self.filtering_message = "motion" 24 | 25 | if self.skeletons != []: 26 | self.skeletons = self.skeletons.tolist() 27 | for i, frame in enumerate(self.skeletons): 28 | assert not np.isnan(self.skeletons[i]).any() # missing joints 29 | 30 | return self.skeletons, self.filtering_message 31 | 32 | def check_static_motion(self, verbose=False): 33 | def get_variance(skeleton, joint_idx): 34 | wrist_pos = skeleton[:, joint_idx] 35 | variance = np.sum(np.var(wrist_pos, axis=0)) 36 | return variance 37 | 38 | left_arm_var = get_variance(self.skeletons, 6) 39 | right_arm_var = get_variance(self.skeletons, 7) 40 | 41 | th = 0.0014 # exclude 13110 42 | # th = 0.002 # exclude 16905 43 | if left_arm_var < th and right_arm_var < th: 44 | if verbose: 45 | print('skip - check_static_motion left var {}, right var {}'.format(left_arm_var, right_arm_var)) 46 | return True 47 | else: 48 | if verbose: 49 | print('pass - check_static_motion left var {}, right var {}'.format(left_arm_var, right_arm_var)) 50 | return False 51 | 52 | def check_pose_diff(self, verbose=False): 53 | diff = np.abs(self.skeletons - self.mean_pose) 54 | diff = np.mean(diff) 55 | 56 | # th = 0.017 57 | th = 0.02 # exclude 3594 58 | if diff < th: 59 | if verbose: 60 | print('skip - check_pose_diff {:.5f}'.format(diff)) 61 | return True 62 | else: 63 | if verbose: 64 | print('pass - check_pose_diff {:.5f}'.format(diff)) 65 | return False 66 | 67 | def check_spine_angle(self, verbose=False): 68 | def angle_between(v1, v2): 69 | v1_u = v1 / np.linalg.norm(v1) 70 | v2_u = v2 / np.linalg.norm(v2) 71 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 72 | 73 | angles = [] 74 | for i in range(self.skeletons.shape[0]): 75 | spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] 76 | angle = angle_between(spine_vec, [0, -1, 0]) 77 | angles.append(angle) 78 | 79 | if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 80 | # if np.rad2deg(max(angles)) > 20: # exclude 8270 81 | if verbose: 82 | print('skip - check_spine_angle {:.5f}, {:.5f}'.format(max(angles), np.mean(angles))) 83 | return True 84 | else: 85 | if verbose: 86 | print('pass - check_spine_angle {:.5f}'.format(max(angles))) 87 | return False -------------------------------------------------------------------------------- /scripts/model/ResNetBlocks.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class SEBasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 11 | super(SEBasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.se = SELayer(planes, reduction) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.relu(out) 26 | out = self.bn1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.se(out) 31 | 32 | if self.downsample is not None: 33 | residual = self.downsample(x) 34 | 35 | out += residual 36 | out = self.relu(out) 37 | return out 38 | 39 | 40 | class SEBottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 44 | super(SEBottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.se = SELayer(planes * 4, reduction) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | residual = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv3(out) 69 | out = self.bn3(out) 70 | out = self.se(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class SELayer(nn.Module): 82 | def __init__(self, channel, reduction=8): 83 | super(SELayer, self).__init__() 84 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 85 | self.fc = nn.Sequential( 86 | nn.Linear(channel, channel // reduction), 87 | nn.ReLU(inplace=True), 88 | nn.Linear(channel // reduction, channel), 89 | nn.Sigmoid() 90 | ) 91 | 92 | def forward(self, x): 93 | b, c, _, _ = x.size() 94 | y = self.avg_pool(x).view(b, c) 95 | y = self.fc(y).view(b, c, 1, 1) 96 | return x * y -------------------------------------------------------------------------------- /scripts/model/motion_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn.parameter import Parameter 6 | import math 7 | 8 | def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): 9 | if not downsample: 10 | k = 3 11 | s = 1 12 | else: 13 | k = 4 14 | s = 2 15 | 16 | conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) 17 | norm_block = nn.BatchNorm1d(out_channels) 18 | 19 | if batchnorm: 20 | net = nn.Sequential( 21 | conv_block, 22 | norm_block, 23 | nn.LeakyReLU(0.2, True) 24 | ) 25 | else: 26 | net = nn.Sequential( 27 | conv_block, 
28 | nn.LeakyReLU(0.2, True) 29 | ) 30 | 31 | return net 32 | 33 | class PoseEncoderConv(nn.Module): 34 | def __init__(self, length, pose_dim, latent_dim): 35 | super().__init__() 36 | 37 | self.net = nn.Sequential( 38 | ConvNormRelu(pose_dim, 32, batchnorm=True), 39 | ConvNormRelu(32, 64, batchnorm=True), 40 | ConvNormRelu(64, 64, True, batchnorm=True), 41 | nn.Conv1d(64, 32, 3) 42 | ) 43 | 44 | self.out_net = nn.Sequential( 45 | # nn.Linear(864, 256), # for 64 frames 46 | nn.Linear(384, 256), # for 34 frames 47 | nn.BatchNorm1d(256), 48 | nn.LeakyReLU(True), 49 | nn.Linear(256, 128), 50 | nn.BatchNorm1d(128), 51 | nn.LeakyReLU(True), 52 | nn.Linear(128, latent_dim), 53 | ) 54 | 55 | def forward(self, poses): 56 | # encode 57 | poses = poses.transpose(1, 2) # to (bs, dim, seq) 58 | out = self.net(poses) 59 | out = out.flatten(1) 60 | z = self.out_net(out) 61 | 62 | return z 63 | 64 | class PoseDecoderConv(nn.Module): 65 | def __init__(self, length, pose_dim, latent_dim, use_pre_poses=False): 66 | super().__init__() 67 | self.use_pre_poses = use_pre_poses 68 | 69 | feat_size = latent_dim 70 | if use_pre_poses: 71 | self.pre_pose_net = nn.Sequential( 72 | nn.Linear(pose_dim * 4, 32), 73 | nn.BatchNorm1d(32), 74 | nn.ReLU(), 75 | nn.Linear(32, 32), 76 | ) 77 | feat_size += 32 78 | 79 | if length == 64: 80 | self.pre_net = nn.Sequential( 81 | nn.Linear(feat_size, 128), 82 | nn.BatchNorm1d(128), 83 | nn.LeakyReLU(True), 84 | nn.Linear(128, 256), 85 | ) 86 | elif length == 34: 87 | self.pre_net = nn.Sequential( 88 | nn.Linear(feat_size, 64), 89 | nn.BatchNorm1d(64), 90 | nn.LeakyReLU(True), 91 | nn.Linear(64, 136), 92 | ) 93 | else: 94 | assert False 95 | 96 | self.net = nn.Sequential( 97 | nn.ConvTranspose1d(4, 32, 3), 98 | nn.BatchNorm1d(32), 99 | nn.LeakyReLU(0.2, True), 100 | nn.ConvTranspose1d(32, 32, 3), 101 | nn.BatchNorm1d(32), 102 | nn.LeakyReLU(0.2, True), 103 | nn.Conv1d(32, 32, 3), 104 | nn.Conv1d(32, pose_dim, 3), 105 | ) 106 | 107 | def forward(self, feat, pre_poses=None): 108 | if self.use_pre_poses: 109 | pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) 110 | feat = torch.cat((pre_pose_feat, feat), dim=1) 111 | 112 | out = self.pre_net(feat) 113 | out = out.view(feat.shape[0], 4, -1) 114 | out = self.net(out) 115 | out = out.transpose(1, 2) 116 | return out 117 | 118 | class MotionAE(nn.Module): 119 | def __init__(self, pose_dim, latent_dim): 120 | super(MotionAE, self).__init__() 121 | 122 | self.encoder = PoseEncoderConv(34, pose_dim, latent_dim) 123 | self.decoder = PoseDecoderConv(34, pose_dim, latent_dim) 124 | 125 | def forward(self, pose): 126 | pose = pose.view(pose.size(0), pose.size(1), -1) 127 | z = self.encoder(pose) 128 | pred = self.decoder(z) 129 | 130 | return pred, z 131 | 132 | 133 | 134 | if __name__ == '__main__': 135 | motion_vae = MotionAE(126, 128) 136 | pose_1 = torch.rand(4, 34, 126) 137 | pose_gt = torch.rand(4, 34, 126) 138 | 139 | pred, z = motion_vae(pose_1) 140 | loss_fn = nn.MSELoss() 141 | print(z.shape) 142 | print(pred.shape) 143 | print(loss_fn(pose_gt, pred)) 144 | -------------------------------------------------------------------------------- /scripts/model/tcn.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py """ 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import weight_norm 5 | 6 | 7 | class Chomp1d(nn.Module): 8 | def __init__(self, chomp_size): 9 | super(Chomp1d, self).__init__() 10 | 
self.chomp_size = chomp_size 11 | 12 | def forward(self, x): 13 | return x[:, :, :-self.chomp_size].contiguous() 14 | 15 | 16 | class TemporalBlock(nn.Module): 17 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): 18 | super(TemporalBlock, self).__init__() 19 | self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, 20 | stride=stride, padding=padding, dilation=dilation)) 21 | self.chomp1 = Chomp1d(padding) 22 | self.relu1 = nn.ReLU() 23 | self.dropout1 = nn.Dropout(dropout) 24 | 25 | self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, 26 | stride=stride, padding=padding, dilation=dilation)) 27 | self.chomp2 = Chomp1d(padding) 28 | self.relu2 = nn.ReLU() 29 | self.dropout2 = nn.Dropout(dropout) 30 | 31 | self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, 32 | self.conv2, self.chomp2, self.relu2, self.dropout2) 33 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 34 | self.relu = nn.ReLU() 35 | self.init_weights() 36 | 37 | def init_weights(self): 38 | self.conv1.weight.data.normal_(0, 0.01) 39 | self.conv2.weight.data.normal_(0, 0.01) 40 | if self.downsample is not None: 41 | self.downsample.weight.data.normal_(0, 0.01) 42 | 43 | def forward(self, x): 44 | out = self.net(x) 45 | res = x if self.downsample is None else self.downsample(x) 46 | return self.relu(out + res) 47 | 48 | 49 | class TemporalConvNet(nn.Module): 50 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): 51 | super(TemporalConvNet, self).__init__() 52 | layers = [] 53 | num_levels = len(num_channels) 54 | for i in range(num_levels): 55 | dilation_size = 2 ** i 56 | in_channels = num_inputs if i == 0 else num_channels[i-1] 57 | out_channels = num_channels[i] 58 | layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, 59 | padding=(kernel_size-1) * dilation_size, dropout=dropout)] 60 | 61 | self.network = nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | return self.network(x) 65 | -------------------------------------------------------------------------------- /scripts/model/utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | """Computes the precision@k for the specified values of k""" 9 | maxk = max(topk) 10 | batch_size = target.size(0) 11 | 12 | _, pred = output.topk(maxk, 1, True, True) 13 | pred = pred.t() 14 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 15 | 16 | res = [] 17 | for k in topk: 18 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 19 | res.append(correct_k.mul_(100.0 / batch_size)) 20 | return res 21 | 22 | class PreEmphasis(torch.nn.Module): 23 | 24 | def __init__(self, coef: float = 0.97): 25 | super().__init__() 26 | self.coef = coef 27 | # make kernel 28 | # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped. 29 | self.register_buffer( 30 | 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 31 | ) 32 | 33 | def forward(self, input: torch.tensor) -> torch.tensor: 34 | assert len(input.size()) == 2, 'The number of dimensions of input tensor must be 2!' 
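# Note (added): the two-tap kernel [-coef, 1] registered in __init__ turns the conv1d
# below into the standard speech pre-emphasis high-pass filter y[t] = x[t] - coef * x[t-1];
# the single reflect-padded sample on the left keeps the output length equal to the input length.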
35 | # reflect padding to match lengths of in/out 36 | input = input.unsqueeze(1) 37 | input = F.pad(input, (1, 0), 'reflect') 38 | return F.conv1d(input, self.flipped_filter).squeeze(1) -------------------------------------------------------------------------------- /scripts/model/vocab.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import numpy as np 5 | import fasttext 6 | 7 | 8 | class Vocab: 9 | PAD_token = 0 10 | SOS_token = 1 11 | EOS_token = 2 12 | UNK_token = 3 13 | 14 | def __init__(self, name, insert_default_tokens=True): 15 | self.name = name 16 | self.trimmed = False 17 | self.word_embedding_weights = None 18 | self.reset_dictionary(insert_default_tokens) 19 | 20 | def reset_dictionary(self, insert_default_tokens=True): 21 | self.word2index = {} 22 | self.word2count = {} 23 | if insert_default_tokens: 24 | self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>", 25 | self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"} 26 | else: 27 | self.index2word = {self.UNK_token: "<UNK>"} 28 | self.n_words = len(self.index2word) # count default tokens 29 | 30 | def index_word(self, word): 31 | if word not in self.word2index: 32 | self.word2index[word] = self.n_words 33 | self.word2count[word] = 1 34 | self.index2word[self.n_words] = word 35 | self.n_words += 1 36 | else: 37 | self.word2count[word] += 1 38 | 39 | def add_vocab(self, other_vocab): 40 | for word, _ in other_vocab.word2count.items(): 41 | self.index_word(word) 42 | 43 | # remove words below a certain count threshold 44 | def trim(self, min_count): 45 | if self.trimmed: 46 | return 47 | self.trimmed = True 48 | 49 | keep_words = [] 50 | 51 | for k, v in self.word2count.items(): 52 | if v >= min_count: 53 | keep_words.append(k) 54 | 55 | logging.info(' word trimming, kept %s / %s = %.4f' % ( 56 | len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) 57 | )) 58 | 59 | # reinitialize dictionary 60 | self.reset_dictionary() 61 | for word in keep_words: 62 | self.index_word(word) 63 | 64 | def get_word_index(self, word): 65 | if word in self.word2index: 66 | return self.word2index[word] 67 | else: 68 | return self.UNK_token 69 | 70 | def load_word_vectors(self, pretrained_path, embedding_dim=300): 71 | logging.info(" loading word vectors from '{}'...".format(pretrained_path)) 72 | 73 | # initialize embeddings to random values for special words 74 | init_sd = 1 / np.sqrt(embedding_dim) 75 | weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) 76 | weights = weights.astype(np.float32) 77 | 78 | # read word vectors 79 | word_model = fasttext.load_model(pretrained_path) 80 | for word, id in self.word2index.items(): 81 | vec = word_model.get_word_vector(word) 82 | weights[id] = vec 83 | 84 | self.word_embedding_weights = weights 85 | 86 | def __get_embedding_weight(self, pretrained_path, embedding_dim=300): 87 | """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """ 88 | logging.info("Loading word embedding '{}'...".format(pretrained_path)) 89 | cache_path = os.path.splitext(pretrained_path)[0] + '_cache.pkl' 90 | weights = None 91 | 92 | # use cached file if it exists 93 | if os.path.exists(cache_path): # 94 | with open(cache_path, 'rb') as f: 95 | logging.info(' using cached result from {}'.format(cache_path)) 96 | weights = pickle.load(f) 97 | if weights.shape != (self.n_words, embedding_dim): 98 | logging.warning(' failed to load word embedding weights.
reinitializing...') 99 | weights = None 100 | 101 | if weights is None: 102 | # initialize embeddings to random values for special and OOV words 103 | init_sd = 1 / np.sqrt(embedding_dim) 104 | weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) 105 | weights = weights.astype(np.float32) 106 | 107 | with open(pretrained_path, encoding="utf-8", mode="r") as textFile: 108 | num_embedded_words = 0 109 | for line_raw in textFile: 110 | # extract the word, and embeddings vector 111 | line = line_raw.split() 112 | try: 113 | word, vector = (line[0], np.array(line[1:], dtype=np.float32)) 114 | # if word == 'love': # debugging 115 | # print(word, vector) 116 | 117 | # if it is in our vocab, then update the corresponding weights 118 | id = self.word2index.get(word, None) 119 | if id is not None: 120 | weights[id] = vector 121 | num_embedded_words += 1 122 | except ValueError: 123 | logging.info(' parsing error at {}...'.format(line_raw[:50])) 124 | continue 125 | logging.info(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index))) 126 | 127 | with open(cache_path, 'wb') as f: 128 | pickle.dump(weights, f) 129 | 130 | return weights 131 | -------------------------------------------------------------------------------- /scripts/parse_args.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | 4 | def str2bool(v): 5 | """ from https://stackoverflow.com/a/43357954/1361529 """ 6 | if isinstance(v, bool): 7 | return v 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | raise configargparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def parse_args(): 17 | parser = configargparse.ArgParser() 18 | parser.add('-c', '--config', required=True, is_config_file=True, help='Config file path') 19 | parser.add("--name", type=str, default="main") 20 | parser.add("--train_data_path", action="append") 21 | parser.add("--val_data_path", action="append") 22 | parser.add("--test_data_path", action="append") 23 | parser.add("--model_save_path", required=True) 24 | parser.add("--pose_representation", type=str, default='3d_vec') 25 | parser.add("--mean_dir_vec", action="append", type=float, nargs='*') 26 | parser.add("--mean_pose", action="append", type=float, nargs='*') 27 | parser.add("--random_seed", type=int, default=-1) 28 | parser.add("--save_result_video", type=str2bool, default=True) 29 | 30 | # word embedding 31 | parser.add("--wordembed_path", type=str, default=None) 32 | parser.add("--wordembed_dim", type=int, default=100) 33 | parser.add("--freeze_wordembed", type=str2bool, default=False) 34 | 35 | # model 36 | parser.add("--model", type=str, required=True) 37 | parser.add("--epochs", type=int, default=10) 38 | parser.add("--batch_size", type=int, default=50) 39 | parser.add("--dropout_prob", type=float, default=0.3) 40 | parser.add("--n_layers", type=int, default=2) 41 | parser.add("--hidden_size", type=int, default=200) 42 | parser.add("--z_type", type=str, default='none') 43 | parser.add("--input_context", type=str, default='both') 44 | 45 | # dataset 46 | parser.add("--motion_resampling_framerate", type=int, default=24) 47 | parser.add("--n_poses", type=int, default=50) 48 | parser.add("--n_pre_poses", type=int, default=5) 49 | parser.add("--subdivision_stride", type=int, default=5) 50 | parser.add("--loader_workers", type=int, default=0) 51 | 52 | 
parser.add("--pose_dim", type=int, required=True) 53 | parser.add("--latent_dim", type=int, default=128) 54 | 55 | # GAN parameter 56 | parser.add("--GAN_noise_size", type=int, default=0) 57 | 58 | # training 59 | parser.add("--learning_rate", type=float, default=0.0005) 60 | parser.add("--discriminator_lr_weight", type=float, default=0.2) 61 | parser.add("--loss_regression_weight", type=float, default=70.0) 62 | parser.add("--loss_gan_weight", type=float, default=1.0) 63 | parser.add("--loss_kld_weight", type=float, default=0.1) 64 | parser.add("--loss_reg_weight", type=float, default=0.01) 65 | parser.add("--loss_warmup", type=int, default=-1) 66 | 67 | parser.add("--loss_contrastive_pos_weight", type=float, default=0.2) 68 | parser.add("--loss_contrastive_neg_weight", type=float, default=0.005) 69 | 70 | parser.add("--loss_physical_weight", type=float, default=0.01) 71 | 72 | parser.add("--mse_loss_weight", type=float, default=50) 73 | parser.add("--cos_loss_weight", type=float, default=50) 74 | parser.add("--static_loss_weight", type=float, default=50) 75 | parser.add("--motion_loss_weight", type=float, default=50) 76 | 77 | parser.add("--g_update_step", type=int, default=5) 78 | 79 | # eval 80 | parser.add("--eval_net_path", type=str, default='') 81 | 82 | args = parser.parse_args() 83 | return args 84 | -------------------------------------------------------------------------------- /scripts/train_eval/train_gan.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def add_noise(data): 9 | noise = torch.randn_like(data) * 0.1 10 | return data + noise 11 | 12 | 13 | def train_iter_gan(args, epoch, in_text, in_audio, target_poses, vid_indices, 14 | pose_decoder, discriminator, 15 | pose_dec_optim, dis_optim): 16 | warm_up_epochs = args.loss_warmup 17 | use_noisy_target = False 18 | 19 | # make pre seq input 20 | pre_seq = target_poses.new_zeros((target_poses.shape[0], target_poses.shape[1], target_poses.shape[2] + 1)) 21 | pre_seq[:, 0:args.n_pre_poses, :-1] = target_poses[:, 0:args.n_pre_poses] 22 | pre_seq[:, 0:args.n_pre_poses, -1] = 1 # indicating bit for constraints 23 | 24 | ########################################################################################### 25 | # train D 26 | dis_error = None 27 | if epoch > warm_up_epochs and args.loss_gan_weight > 0.0: 28 | dis_optim.zero_grad() 29 | 30 | out_dir_vec, *_ = pose_decoder(pre_seq, in_text, in_audio, vid_indices) # out shape (batch x seq x dim) 31 | 32 | if use_noisy_target: 33 | noise_target = add_noise(target_poses) 34 | noise_out = add_noise(out_dir_vec.detach()) 35 | dis_real = discriminator(noise_target, in_text) 36 | dis_fake = discriminator(noise_out, in_text) 37 | else: 38 | dis_real = discriminator(target_poses, in_text) 39 | dis_fake = discriminator(out_dir_vec.detach(), in_text) 40 | 41 | dis_error = torch.sum(-torch.mean(torch.log(dis_real + 1e-8) + torch.log(1 - dis_fake + 1e-8))) # ns-gan 42 | dis_error.backward() 43 | dis_optim.step() 44 | 45 | ########################################################################################### 46 | # train G 47 | pose_dec_optim.zero_grad() 48 | 49 | # decoding 50 | out_dir_vec, z, z_mu, z_logvar = pose_decoder(pre_seq, in_text, in_audio, vid_indices) 51 | 52 | # loss 53 | beta = 0.1 54 | huber_loss = F.smooth_l1_loss(out_dir_vec / beta, target_poses / beta) * beta 55 | dis_output = discriminator(out_dir_vec, in_text) 56 | 
gen_error = -torch.mean(torch.log(dis_output + 1e-8)) 57 | kld = div_reg = None 58 | 59 | if (args.z_type == 'speaker' or args.z_type == 'random') and args.loss_reg_weight > 0.0: 60 | if args.z_type == 'speaker': 61 | # enforcing divergent gestures btw original vid and other vid 62 | rand_idx = torch.randperm(vid_indices.shape[0]) 63 | rand_vids = vid_indices[rand_idx] 64 | else: 65 | rand_vids = None 66 | 67 | out_dir_vec_rand_vid, z_rand_vid, _, _ = pose_decoder(pre_seq, in_text, in_audio, rand_vids) 68 | beta = 0.05 69 | pose_l1 = F.smooth_l1_loss(out_dir_vec / beta, out_dir_vec_rand_vid.detach() / beta, reduction='none') * beta 70 | pose_l1 = pose_l1.sum(dim=1).sum(dim=1) 71 | 72 | pose_l1 = pose_l1.view(pose_l1.shape[0], -1).mean(1) 73 | z_l1 = F.l1_loss(z.detach(), z_rand_vid.detach(), reduction='none') 74 | z_l1 = z_l1.view(z_l1.shape[0], -1).mean(1) 75 | div_reg = -(pose_l1 / (z_l1 + 1.0e-5)) 76 | div_reg = torch.clamp(div_reg, min=-1000) 77 | div_reg = div_reg.mean() 78 | 79 | if args.z_type == 'speaker': 80 | # speaker embedding KLD 81 | kld = -0.5 * torch.mean(1 + z_logvar - z_mu.pow(2) - z_logvar.exp()) 82 | loss = args.loss_regression_weight * huber_loss + args.loss_kld_weight * kld + args.loss_reg_weight * div_reg 83 | else: 84 | loss = args.loss_regression_weight * huber_loss + args.loss_reg_weight * div_reg 85 | else: 86 | loss = args.loss_regression_weight * huber_loss #+ var_loss 87 | 88 | if epoch > warm_up_epochs: 89 | loss += args.loss_gan_weight * gen_error 90 | 91 | loss.backward() 92 | pose_dec_optim.step() 93 | 94 | ret_dict = {'loss': args.loss_regression_weight * huber_loss.item()} 95 | if kld: 96 | ret_dict['KLD'] = args.loss_kld_weight * kld.item() 97 | if div_reg: 98 | ret_dict['DIV_REG'] = args.loss_reg_weight * div_reg.item() 99 | 100 | if epoch > warm_up_epochs and args.loss_gan_weight > 0.0: 101 | ret_dict['gen'] = args.loss_gan_weight * gen_error.item() 102 | ret_dict['dis'] = dis_error.item() 103 | return ret_dict 104 | 105 | -------------------------------------------------------------------------------- /scripts/train_eval/train_joint_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def train_iter_embed(args, epoch, in_text, in_audio, target_data, net, optim, mode=None): 6 | pre_seq = target_data[:, 0:args.n_pre_poses] 7 | 8 | # zero gradients 9 | optim.zero_grad() 10 | 11 | if mode == 'random': # joint embed model 12 | variational_encoding = False # AE 13 | else: # feature extractor in FGD 14 | variational_encoding = False # VAE or AE 15 | 16 | # reconstruction loss 17 | context_feat, context_mu, context_logvar, poses_feat, pose_mu, pose_logvar, recon_data = \ 18 | net(in_text, in_audio, pre_seq, target_data, mode, variational_encoding=variational_encoding) 19 | 20 | recon_loss = F.l1_loss(recon_data, target_data, reduction='none') 21 | recon_loss = torch.mean(recon_loss, dim=(1, 2)) 22 | 23 | if False: # use pose diff 24 | target_diff = target_data[:, 1:] - target_data[:, :-1] 25 | recon_diff = recon_data[:, 1:] - recon_data[:, :-1] 26 | recon_loss += torch.mean(F.l1_loss(recon_diff, target_diff, reduction='none'), dim=(1, 2)) 27 | 28 | recon_loss = torch.sum(recon_loss) 29 | 30 | # KLD 31 | if variational_encoding: 32 | if net.mode == 'speech': 33 | KLD = -0.5 * torch.sum(1 + context_logvar - context_mu.pow(2) - context_logvar.exp()) 34 | else: 35 | KLD = -0.5 * torch.sum(1 + pose_logvar - pose_mu.pow(2) - pose_logvar.exp()) 36 | 37 | if epoch 
< 10: 38 | KLD_weight = 0 39 | else: 40 | KLD_weight = min(1.0, (epoch - 10) * args.loss_kld_weight) 41 | loss = args.loss_regression_weight * recon_loss + KLD_weight * KLD 42 | else: 43 | loss = recon_loss 44 | 45 | loss.backward() 46 | optim.step() 47 | 48 | ret_dict = {'loss': recon_loss.item()} 49 | if variational_encoding: 50 | ret_dict['KLD'] = KLD.item() 51 | return ret_dict 52 | 53 | 54 | def eval_embed(in_text, in_audio, pre_poses, target_poses, net, mode=None): 55 | context_feat, context_mu, context_logvar, poses_feat, pose_mu, pose_logvar, recon_poses = \ 56 | net(in_text, in_audio, pre_poses, target_poses, mode, variational_encoding=False) 57 | 58 | recon_loss = F.l1_loss(recon_poses, target_poses, reduction='none') 59 | recon_loss = torch.mean(recon_loss, dim=(1, 2)) 60 | loss = torch.mean(recon_loss) 61 | 62 | return loss, recon_poses 63 | -------------------------------------------------------------------------------- /scripts/train_eval/train_seq2seq.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | loss_i = 0 6 | def custom_loss(output, target, args, epoch): 7 | n_element = output.numel() 8 | 9 | # mse 10 | mse_loss = F.mse_loss(output, target) 11 | mse_loss *= args.loss_regression_weight 12 | 13 | # continuous motion 14 | diff = [abs(output[:, n, :] - output[:, n-1, :]) for n in range(1, output.shape[1])] 15 | cont_loss = torch.sum(torch.stack(diff)) / n_element 16 | cont_loss *= args.loss_kld_weight 17 | 18 | # motion variance 19 | norm = torch.norm(output, 2, 1) # output shape (batch, seq, dim) 20 | var_loss = -torch.sum(norm) / n_element 21 | var_loss *= args.loss_reg_weight 22 | 23 | loss = mse_loss + cont_loss + var_loss 24 | 25 | # debugging code 26 | global loss_i 27 | if loss_i == 1000: 28 | logging.debug('(custom loss) mse %.5f, cont %.5f, var %.5f' 29 | % (mse_loss.item(), cont_loss.item(), var_loss.item())) 30 | loss_i = 0 31 | loss_i += 1 32 | 33 | return loss 34 | 35 | 36 | def train_iter_seq2seq(args, epoch, in_text, in_lengths, target_poses, net, optim): 37 | # zero gradients 38 | optim.zero_grad() 39 | 40 | # generation 41 | outputs = net(in_text, in_lengths, target_poses, None) 42 | 43 | # loss 44 | loss = custom_loss(outputs, target_poses, args, epoch) 45 | loss.backward() 46 | 47 | # optimize 48 | torch.nn.utils.clip_grad_norm_(net.parameters(), 5) 49 | optim.step() 50 | 51 | return {'loss': loss.item()} 52 | -------------------------------------------------------------------------------- /scripts/train_eval/train_speech2gesture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def train_iter_speech2gesture(args, in_spec, target_poses, pose_decoder, discriminator, 6 | pose_dec_optim, dis_optim, loss_fn): 7 | # generation 8 | pre_poses = target_poses[:, 0:args.n_pre_poses] 9 | out_poses = pose_decoder(in_spec, pre_poses) 10 | 11 | # to motion 12 | target_motion = target_poses[:, 1:] - target_poses[:, :-1] 13 | out_motion = out_poses[:, 1:] - out_poses[:, :-1] 14 | 15 | ########################################################################################### 16 | # train D 17 | dis_optim.zero_grad() 18 | dis_real = discriminator(target_motion) 19 | dis_fake = discriminator(out_motion.detach()) 20 | dis_error = F.mse_loss(torch.ones_like(dis_real), dis_real) + F.mse_loss(torch.zeros_like(dis_fake), dis_fake) 21 | 22 | dis_error.backward() 23 |
dis_optim.step() 24 | 25 | ########################################################################################### 26 | # train G 27 | pose_dec_optim.zero_grad() 28 | l1_loss = loss_fn(out_poses, target_poses) 29 | dis_output = discriminator(out_motion) 30 | gen_error = F.mse_loss(torch.ones_like(dis_output), dis_output) 31 | 32 | loss = args.loss_regression_weight * l1_loss + args.loss_gan_weight * gen_error 33 | loss.backward() 34 | pose_dec_optim.step() 35 | 36 | return {'loss': args.loss_regression_weight * l1_loss.item(), 'gen': args.loss_gan_weight * gen_error.item(), 37 | 'dis': dis_error.item()} 38 | -------------------------------------------------------------------------------- /scripts/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self, name, fmt=':f'): 5 | self.name = name 6 | self.fmt = fmt 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | def __str__(self): 22 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 23 | return fmtstr.format(**self.__dict__) 24 | -------------------------------------------------------------------------------- /scripts/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | from scipy.interpolate import interp1d 7 | from sklearn.preprocessing import normalize 8 | 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | skeleton_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'orange'), (1, 5, 'darkgreen'), 13 | (5, 6, 'limegreen'), (6, 7, 'darkseagreen')] 14 | dir_vec_pairs = [(0, 1, 0.26), (1, 2, 0.18), (2, 3, 0.14), (1, 4, 0.22), (4, 5, 0.36), 15 | (5, 6, 0.33), (1, 7, 0.22), (7, 8, 0.36), (8, 9, 0.33)] # adjacency and bone length 16 | 17 | 18 | def normalize_string(s): 19 | """ lowercase, trim, and remove non-letter characters """ 20 | s = s.lower().strip() 21 | s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks 22 | s = re.sub(r"(['])", r"", s) # remove apostrophe 23 | s = re.sub(r"[^a-zA-Z,.!?]+", r" ", s) # replace other characters with whitespace 24 | s = re.sub(r"\s+", r" ", s).strip() 25 | return s 26 | 27 | 28 | def remove_tags_marks(text): 29 | reg_expr = re.compile('<.*?>|[.,:;!?]+') 30 | clean_text = re.sub(reg_expr, '', text) 31 | return clean_text 32 | 33 | 34 | def extract_melspectrogram(y, sr=16000): 35 | melspec = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, power=2) 36 | log_melspec = librosa.power_to_db(melspec, ref=np.max) # mels x time 37 | log_melspec = log_melspec.astype('float16') 38 | return log_melspec 39 | 40 | 41 | def calc_spectrogram_length_from_motion_length(n_frames, fps): 42 | ret = (n_frames / fps * 16000 - 1024) / 512 + 1 43 | return int(round(ret)) 44 | 45 | 46 | def resample_pose_seq(poses, duration_in_sec, fps): 47 | n = len(poses) 48 | x = np.arange(0, n) 49 | y = poses 50 | f = interp1d(x, y, axis=0, kind='linear', fill_value='extrapolate') 51 | expected_n = duration_in_sec * fps 52 | x_new = np.arange(0, n, n / expected_n) 53 | interpolated_y = f(x_new) 54 | if hasattr(poses, 'dtype'): 55 | 
interpolated_y = interpolated_y.astype(poses.dtype) 56 | return interpolated_y 57 | 58 | 59 | def time_stretch_for_words(words, start_time, speech_speed_rate): 60 | for i in range(len(words)): 61 | if words[i][1] > start_time: 62 | words[i][1] = start_time + (words[i][1] - start_time) / speech_speed_rate 63 | words[i][2] = start_time + (words[i][2] - start_time) / speech_speed_rate 64 | 65 | return words 66 | 67 | 68 | def make_audio_fixed_length(audio, expected_audio_length): 69 | n_padding = expected_audio_length - len(audio) 70 | if n_padding > 0: 71 | audio = np.pad(audio, (0, n_padding), mode='symmetric') 72 | else: 73 | audio = audio[0:expected_audio_length] 74 | return audio 75 | 76 | 77 | def convert_dir_vec_to_pose(vec): 78 | vec = np.array(vec) 79 | 80 | if vec.shape[-1] != 3: 81 | vec = vec.reshape(vec.shape[:-1] + (-1, 3)) 82 | 83 | if len(vec.shape) == 2: 84 | joint_pos = np.zeros((10, 3)) 85 | for j, pair in enumerate(dir_vec_pairs): 86 | joint_pos[pair[1]] = joint_pos[pair[0]] + pair[2] * vec[j] 87 | elif len(vec.shape) == 3: 88 | joint_pos = np.zeros((vec.shape[0], 10, 3)) 89 | for j, pair in enumerate(dir_vec_pairs): 90 | joint_pos[:, pair[1]] = joint_pos[:, pair[0]] + pair[2] * vec[:, j] 91 | elif len(vec.shape) == 4: # (batch, seq, 9, 3) 92 | joint_pos = np.zeros((vec.shape[0], vec.shape[1], 10, 3)) 93 | for j, pair in enumerate(dir_vec_pairs): 94 | joint_pos[:, :, pair[1]] = joint_pos[:, :, pair[0]] + pair[2] * vec[:, :, j] 95 | else: 96 | assert False 97 | 98 | return joint_pos 99 | 100 | 101 | def convert_pose_seq_to_dir_vec(pose): 102 | if pose.shape[-1] != 3: 103 | pose = pose.reshape(pose.shape[:-1] + (-1, 3)) 104 | 105 | if len(pose.shape) == 3: 106 | dir_vec = np.zeros((pose.shape[0], len(dir_vec_pairs), 3)) 107 | for i, pair in enumerate(dir_vec_pairs): 108 | dir_vec[:, i] = pose[:, pair[1]] - pose[:, pair[0]] 109 | dir_vec[:, i, :] = normalize(dir_vec[:, i, :], axis=1) # to unit length 110 | elif len(pose.shape) == 4: # (batch, seq, ...) 
111 | dir_vec = np.zeros((pose.shape[0], pose.shape[1], len(dir_vec_pairs), 3)) 112 | for i, pair in enumerate(dir_vec_pairs): 113 | dir_vec[:, :, i] = pose[:, :, pair[1]] - pose[:, :, pair[0]] 114 | for j in range(dir_vec.shape[0]): # batch 115 | for i in range(len(dir_vec_pairs)): 116 | dir_vec[j, :, i, :] = normalize(dir_vec[j, :, i, :], axis=1) # to unit length 117 | else: 118 | assert False 119 | 120 | return dir_vec 121 | -------------------------------------------------------------------------------- /scripts/utils/data_utils_expressive.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | from scipy.interpolate import interp1d 7 | from torch.nn.functional import normalize 8 | 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | dir_vec_pairs = [ 12 | (0, 1, 0.26), # 0, spine-neck 13 | (1, 2, 0.22), # 1, neck-left shoulder 14 | (1, 3, 0.22), # 2, neck-right shoulder 15 | (2, 4, 0.36), # 3, left shoulder-elbow 16 | (4, 6, 0.33), # 4, left elbow-wrist 17 | 18 | (6, 8, 0.137), # 5 wrist-left index 1 19 | (8, 9, 0.044), # 6 20 | (9, 10, 0.031), # 7 21 | 22 | (6, 11, 0.144), # 8 wrist-left middle 1 23 | (11, 12, 0.042), # 9 24 | (12, 13, 0.033), # 10 25 | 26 | (6, 14, 0.127), # 11 wrist-left pinky 1 27 | (14, 15, 0.027), # 12 28 | (15, 16, 0.026), # 13 29 | 30 | (6, 17, 0.134), # 14 wrist-left ring 1 31 | (17, 18, 0.039), # 15 32 | (18, 19, 0.033), # 16 33 | 34 | (6, 20, 0.068), # 17 wrist-left thumb 1 35 | (20, 21, 0.042), # 18 36 | (21, 22, 0.036), # 19 37 | 38 | (3, 5, 0.36), # 20, right shoulder-elbow 39 | (5, 7, 0.33), # 21, right elbow-wrist 40 | 41 | (7, 23, 0.137), # 22 wrist-right index 1 42 | (23, 24, 0.044), # 23 43 | (24, 25, 0.031), # 24 44 | 45 | (7, 26, 0.144), # 25 wrist-right middle 1 46 | (26, 27, 0.042), # 26 47 | (27, 28, 0.033), # 27 48 | 49 | (7, 29, 0.127), # 28 wrist-right pinky 1 50 | (29, 30, 0.027), # 29 51 | (30, 31, 0.026), # 30 52 | 53 | (7, 32, 0.134), # 31 wrist-right ring 1 54 | (32, 33, 0.039), # 32 55 | (33, 34, 0.033), # 33 56 | 57 | (7, 35, 0.068), # 34 wrist-right thumb 1 58 | (35, 36, 0.042), # 35 59 | (36, 37, 0.036), # 36 60 | 61 | (1, 38, 0.18), # 37, neck-nose 62 | (38, 39, 0.14), # 38, nose-right eye 63 | (38, 40, 0.14), # 39, nose-left eye 64 | (39, 41, 0.15), # 40, right eye-right ear 65 | (40, 42, 0.15), # 41, left eye-left ear 66 | ] 67 | 68 | def normalize_string(s): 69 | """ lowercase, trim, and remove non-letter characters """ 70 | s = s.lower().strip() 71 | s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks 72 | s = re.sub(r"(['])", r"", s) # remove apostrophe 73 | s = re.sub(r"[^a-zA-Z,.!?]+", r" ", s) # replace other characters with whitespace 74 | s = re.sub(r"\s+", r" ", s).strip() 75 | return s 76 | 77 | 78 | def remove_tags_marks(text): 79 | reg_expr = re.compile('<.*?>|[.,:;!?]+') 80 | clean_text = re.sub(reg_expr, '', text) 81 | return clean_text 82 | 83 | 84 | def extract_melspectrogram(y, sr=16000): 85 | melspec = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, power=2) 86 | log_melspec = librosa.power_to_db(melspec, ref=np.max) # mels x time 87 | log_melspec = log_melspec.astype('float16') 88 | return log_melspec 89 | 90 | 91 | def calc_spectrogram_length_from_motion_length(n_frames, fps): 92 | ret = (n_frames / fps * 16000 - 1024) / 512 + 1 93 | return int(round(ret)) 94 | 95 | 96 | def resample_pose_seq(poses, duration_in_sec, fps): 97 | n = len(poses) 
98 | x = np.arange(0, n) 99 | y = poses 100 | f = interp1d(x, y, axis=0, kind='linear', fill_value='extrapolate') 101 | expected_n = duration_in_sec * fps 102 | x_new = np.arange(0, n, n / expected_n) 103 | interpolated_y = f(x_new) 104 | if hasattr(poses, 'dtype'): 105 | interpolated_y = interpolated_y.astype(poses.dtype) 106 | return interpolated_y 107 | 108 | 109 | def time_stretch_for_words(words, start_time, speech_speed_rate): 110 | for i in range(len(words)): 111 | if words[i][1] > start_time: 112 | words[i][1] = start_time + (words[i][1] - start_time) / speech_speed_rate 113 | words[i][2] = start_time + (words[i][2] - start_time) / speech_speed_rate 114 | 115 | return words 116 | 117 | 118 | def make_audio_fixed_length(audio, expected_audio_length): 119 | n_padding = expected_audio_length - len(audio) 120 | if n_padding > 0: 121 | audio = np.pad(audio, (0, n_padding), mode='symmetric') 122 | else: 123 | audio = audio[0:expected_audio_length] 124 | return audio 125 | 126 | 127 | def convert_dir_vec_to_pose(vec): 128 | # vec = np.array(vec) 129 | 130 | if vec.shape[-1] != 3: 131 | vec = vec.reshape(vec.shape[:-1] + (-1, 3)) 132 | 133 | if len(vec.shape) == 2: 134 | joint_pos = torch.zeros((43, 3)).to(device) 135 | for j, pair in enumerate(dir_vec_pairs): 136 | joint_pos[pair[1]] = joint_pos[pair[0]] + torch.Tensor([pair[2]]).to(device) * vec[j] 137 | elif len(vec.shape) == 3: 138 | joint_pos = torch.zeros((vec.shape[0], 43, 3)).to(device) 139 | for j, pair in enumerate(dir_vec_pairs): 140 | joint_pos[:, pair[1]] = joint_pos[:, pair[0]] + torch.Tensor([pair[2]]).to(device) * vec[:, j] 141 | elif len(vec.shape) == 4: # (batch, seq, 42, 3) 142 | joint_pos = torch.zeros((vec.shape[0], vec.shape[1], 43, 3)).to(device) 143 | for j, pair in enumerate(dir_vec_pairs): 144 | joint_pos[:, :, pair[1]] = joint_pos[:, :, pair[0]] + torch.Tensor([pair[2]]).to(device) * vec[:, :, j] 145 | else: 146 | assert False 147 | 148 | return joint_pos 149 | 150 | def convert_pose_seq_to_dir_vec(pose): 151 | if pose.shape[-1] != 3: 152 | pose = pose.reshape(pose.shape[:-1] + (-1, 3)) 153 | 154 | if len(pose.shape) == 3: 155 | dir_vec = torch.zeros((pose.shape[0], len(dir_vec_pairs), 3)) 156 | for i, pair in enumerate(dir_vec_pairs): 157 | dir_vec[:, i] = pose[:, pair[1]] - pose[:, pair[0]] 158 | dir_vec[:, i, :] = normalize(dir_vec[:, i, :], dim=1) # to unit length 159 | elif len(pose.shape) == 4: # (batch, seq, ...) 
160 | dir_vec = torch.zeros((pose.shape[0], pose.shape[1], len(dir_vec_pairs), 3)) 161 | for i, pair in enumerate(dir_vec_pairs): 162 | dir_vec[:, :, i] = pose[:, :, pair[1]] - pose[:, :, pair[0]] 163 | for j in range(dir_vec.shape[0]): # batch 164 | for i in range(len(dir_vec_pairs)): 165 | dir_vec[j, :, i, :] = normalize(dir_vec[j, :, i, :], dim=1) # to unit length 166 | else: 167 | assert False 168 | 169 | return dir_vec 170 | -------------------------------------------------------------------------------- /scripts/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | import subprocess 6 | from collections import defaultdict, namedtuple 7 | from logging.handlers import RotatingFileHandler 8 | from textwrap import wrap 9 | 10 | import numpy as np 11 | import re 12 | import time 13 | import math 14 | import soundfile as sf 15 | import librosa.display 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | 21 | import torch 22 | import matplotlib.ticker as ticker 23 | import matplotlib.animation as animation 24 | from mpl_toolkits import mplot3d 25 | 26 | import utils.data_utils 27 | import train 28 | import data_loader.lmdb_data_loader 29 | 30 | 31 | # only for unicode characters, you may remove these two lines 32 | from model import vocab 33 | 34 | matplotlib.rcParams['axes.unicode_minus'] = False 35 | 36 | 37 | def set_logger(log_path=None, log_filename='log'): 38 | for handler in logging.root.handlers[:]: 39 | logging.root.removeHandler(handler) 40 | handlers = [logging.StreamHandler()] 41 | if log_path is not None: 42 | os.makedirs(log_path, exist_ok=True) 43 | handlers.append( 44 | RotatingFileHandler(os.path.join(log_path, log_filename), maxBytes=10 * 1024 * 1024, backupCount=5)) 45 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s: %(message)s', handlers=handlers) 46 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 47 | 48 | 49 | def as_minutes(s): 50 | m = math.floor(s / 60) 51 | s -= m * 60 52 | return '%dm %ds' % (m, s) 53 | 54 | 55 | def time_since(since): 56 | now = time.time() 57 | s = now - since 58 | return '%s' % as_minutes(s) 59 | 60 | 61 | def create_video_and_save(save_path, epoch, prefix, iter_idx, target, output, mean_data, title, 62 | audio=None, aux_str=None, clipping_to_shortest_stream=False, delete_audio_file=True): 63 | print('rendering a video...') 64 | start = time.time() 65 | 66 | fig = plt.figure(figsize=(8, 4)) 67 | axes = [fig.add_subplot(1, 2, 1, projection='3d'), fig.add_subplot(1, 2, 2, projection='3d')] 68 | axes[0].view_init(elev=20, azim=-60) 69 | axes[1].view_init(elev=20, azim=-60) 70 | fig_title = title 71 | 72 | if aux_str: 73 | fig_title += ('\n' + aux_str) 74 | fig.suptitle('\n'.join(wrap(fig_title, 75)), fontsize='medium') 75 | 76 | # un-normalization and convert to poses 77 | mean_data = mean_data.flatten() 78 | output = output + mean_data 79 | output_poses = utils.data_utils.convert_dir_vec_to_pose(output) 80 | target_poses = None 81 | if target is not None: 82 | target = target + mean_data 83 | target_poses = utils.data_utils.convert_dir_vec_to_pose(target) 84 | 85 | def animate(i): 86 | for k, name in enumerate(['human', 'generated']): 87 | if name == 'human' and target is not None and i < len(target): 88 | pose = target_poses[i] 89 | elif name == 'generated' and i < len(output): 90 | pose = output_poses[i] 91 | else: 92 | pose = None 93 | 94 | if pose is 
not None: 95 | axes[k].clear() 96 | for j, pair in enumerate(utils.data_utils.dir_vec_pairs): 97 | axes[k].plot([pose[pair[0], 0], pose[pair[1], 0]], 98 | [pose[pair[0], 2], pose[pair[1], 2]], 99 | [pose[pair[0], 1], pose[pair[1], 1]], 100 | zdir='z', linewidth=1.5) 101 | axes[k].set_xlim3d(-0.5, 0.5) 102 | axes[k].set_ylim3d(0.5, -0.5) 103 | axes[k].set_zlim3d(0.5, -0.5) 104 | axes[k].set_xlabel('x') 105 | axes[k].set_ylabel('z') 106 | axes[k].set_zlabel('y') 107 | axes[k].set_title('{} ({}/{})'.format(name, i + 1, len(output))) 108 | 109 | if target is not None: 110 | num_frames = max(len(target), len(output)) 111 | else: 112 | num_frames = len(output) 113 | ani = animation.FuncAnimation(fig, animate, interval=30, frames=num_frames, repeat=False) 114 | 115 | # show audio 116 | audio_path = None 117 | if audio is not None: 118 | assert len(audio.shape) == 1 # 1-channel, raw signal 119 | audio = audio.astype(np.float32) 120 | sr = 16000 121 | audio_path = '{}/{}_{:03d}_{}.wav'.format(save_path, prefix, epoch, iter_idx) 122 | sf.write(audio_path, audio, sr) 123 | 124 | # save video 125 | try: 126 | video_path = '{}/temp_{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 127 | ani.save(video_path, fps=15, dpi=80) # dpi 150 for a higher resolution 128 | del ani 129 | plt.close(fig) 130 | except RuntimeError: 131 | assert False, 'RuntimeError' 132 | 133 | # merge audio and video 134 | if audio is not None: 135 | merged_video_path = '{}/{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 136 | cmd = ['ffmpeg', '-loglevel', 'panic', '-y', '-i', video_path, '-i', audio_path, '-strict', '-2', 137 | merged_video_path] 138 | if clipping_to_shortest_stream: 139 | cmd.insert(len(cmd) - 1, '-shortest') 140 | subprocess.call(cmd) 141 | if delete_audio_file: 142 | os.remove(audio_path) 143 | os.remove(video_path) 144 | 145 | print('done, took {:.1f} seconds'.format(time.time() - start)) 146 | return output_poses, target_poses 147 | 148 | 149 | def save_checkpoint(state, filename): 150 | torch.save(state, filename) 151 | logging.info('Saved the checkpoint') 152 | 153 | 154 | def get_speaker_model(net): 155 | try: 156 | if hasattr(net, 'module'): 157 | speaker_model = net.module.z_obj 158 | else: 159 | speaker_model = net.z_obj 160 | except AttributeError: 161 | speaker_model = None 162 | 163 | if not isinstance(speaker_model, vocab.Vocab): 164 | speaker_model = None 165 | 166 | return speaker_model 167 | 168 | 169 | def load_checkpoint_hierarchy(checkpoint_path, _device='cpu'): 170 | print('loading checkpoint {}'.format(checkpoint_path)) 171 | checkpoint = torch.load(checkpoint_path, map_location=_device) 172 | args = checkpoint['args'] 173 | epoch = checkpoint['epoch'] 174 | lang_model = checkpoint['lang_model'] 175 | speaker_model = checkpoint['speaker_model'] 176 | pose_dim = checkpoint['pose_dim'] 177 | print('epoch {}'.format(epoch)) 178 | 179 | generator, discriminator, audio_encoder, text_encoder, loss_fn = train.init_model(args, lang_model, speaker_model, pose_dim, _device) 180 | g1, _, _, _, _ = train.init_model(args, lang_model, speaker_model, 5 * 3, _device) 181 | g2, _, _, _, _ = train.init_model(args, lang_model, speaker_model, 7 * 3, _device) 182 | g3, _, _, _, _ = train.init_model(args, lang_model, speaker_model, 9 * 3, _device) 183 | 184 | g1.load_state_dict(checkpoint['gen_dict_1']) 185 | g2.load_state_dict(checkpoint['gen_dict_2']) 186 | g3.load_state_dict(checkpoint['gen_dict_3']) 187 | audio_encoder.load_state_dict(checkpoint['audio_dict']) 188 | 189 | # set to eval 
mode 190 | g1.train(False) 191 | g2.train(False) 192 | g3.train(False) 193 | audio_encoder.train(False) 194 | 195 | return args, g1, g2, g3, audio_encoder, loss_fn, lang_model, speaker_model, pose_dim 196 | 197 | def load_checkpoint_and_model(checkpoint_path, _device='cpu'): 198 | print('loading checkpoint {}'.format(checkpoint_path)) 199 | checkpoint = torch.load(checkpoint_path, map_location=_device) 200 | args = checkpoint['args'] 201 | epoch = checkpoint['epoch'] 202 | lang_model = checkpoint['lang_model'] 203 | speaker_model = checkpoint['speaker_model'] 204 | pose_dim = checkpoint['pose_dim'] 205 | print('epoch {}'.format(epoch)) 206 | 207 | generator, discriminator, loss_fn = train.init_model(args, lang_model, speaker_model, pose_dim, _device) 208 | generator.load_state_dict(checkpoint['gen_dict']) 209 | 210 | # set to eval mode 211 | generator.train(False) 212 | 213 | return args, generator, loss_fn, lang_model, speaker_model, pose_dim 214 | 215 | 216 | def set_random_seed(seed): 217 | torch.manual_seed(seed) 218 | torch.cuda.manual_seed_all(seed) 219 | np.random.seed(seed) 220 | random.seed(seed) 221 | os.environ['PYTHONHASHSEED'] = str(seed) 222 | -------------------------------------------------------------------------------- /scripts/utils/train_utils_expressive.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | import subprocess 6 | from collections import defaultdict, namedtuple 7 | from logging.handlers import RotatingFileHandler 8 | from textwrap import wrap 9 | 10 | import numpy as np 11 | import re 12 | import time 13 | import math 14 | import soundfile as sf 15 | import librosa.display 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | import torch 21 | import matplotlib.ticker as ticker 22 | import matplotlib.animation as animation 23 | from mpl_toolkits import mplot3d 24 | 25 | import utils.data_utils_expressive 26 | import train_expressive 27 | import data_loader.lmdb_data_loader 28 | 29 | 30 | # only for unicode characters, you may remove these two lines 31 | from model import vocab 32 | 33 | matplotlib.rcParams['axes.unicode_minus'] = False 34 | 35 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 36 | 37 | 38 | def set_logger(log_path=None, log_filename='log'): 39 | for handler in logging.root.handlers[:]: 40 | logging.root.removeHandler(handler) 41 | handlers = [logging.StreamHandler()] 42 | if log_path is not None: 43 | os.makedirs(log_path, exist_ok=True) 44 | handlers.append( 45 | RotatingFileHandler(os.path.join(log_path, log_filename), maxBytes=10 * 1024 * 1024, backupCount=5)) 46 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s: %(message)s', handlers=handlers) 47 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 48 | 49 | 50 | def as_minutes(s): 51 | m = math.floor(s / 60) 52 | s -= m * 60 53 | return '%dm %ds' % (m, s) 54 | 55 | 56 | def time_since(since): 57 | now = time.time() 58 | s = now - since 59 | return '%s' % as_minutes(s) 60 | 61 | 62 | def create_video_and_save(save_path, epoch, prefix, iter_idx, target, output, mean_data, title, 63 | audio=None, aux_str=None, clipping_to_shortest_stream=False, delete_audio_file=True): 64 | print('rendering a video...') 65 | start = time.time() 66 | 67 | fig = plt.figure(figsize=(8, 4)) 68 | axes = [fig.add_subplot(1, 2, 1, projection='3d'), fig.add_subplot(1, 2, 2, projection='3d')] 69 | axes[0].view_init(elev=20, 
azim=-60) 70 | axes[1].view_init(elev=20, azim=-60) 71 | fig_title = title 72 | 73 | if aux_str: 74 | fig_title += ('\n' + aux_str) 75 | fig.suptitle('\n'.join(wrap(fig_title, 75)), fontsize='medium') 76 | 77 | # un-normalization and convert to poses 78 | mean_data = mean_data.flatten() 79 | output = output + mean_data 80 | output_poses = utils.data_utils_expressive.convert_dir_vec_to_pose(torch.Tensor(output).to(device)) 81 | target_poses = None 82 | if target is not None: 83 | target = target + mean_data 84 | target_poses = utils.data_utils_expressive.convert_dir_vec_to_pose(torch.Tensor(target).to(device)) 85 | 86 | def animate(i): 87 | for k, name in enumerate(['human', 'generated']): 88 | if name == 'human' and target is not None and i < len(target): 89 | pose = target_poses[i] 90 | elif name == 'generated' and i < len(output): 91 | pose = output_poses[i] 92 | else: 93 | pose = None 94 | 95 | if pose is not None: 96 | axes[k].clear() 97 | for j, pair in enumerate(utils.data_utils_expressive.dir_vec_pairs): 98 | axes[k].plot([pose[pair[0], 0], pose[pair[1], 0]], 99 | [pose[pair[0], 2], pose[pair[1], 2]], 100 | [pose[pair[0], 1], pose[pair[1], 1]], 101 | zdir='z', linewidth=1.5) 102 | axes[k].set_xlim3d(-0.5, 0.5) 103 | axes[k].set_ylim3d(0.5, -0.5) 104 | axes[k].set_zlim3d(0.5, -0.5) 105 | axes[k].set_xlabel('x') 106 | axes[k].set_ylabel('z') 107 | axes[k].set_zlabel('y') 108 | axes[k].set_title('{} ({}/{})'.format(name, i + 1, len(output))) 109 | 110 | if target is not None: 111 | num_frames = max(len(target), len(output)) 112 | else: 113 | num_frames = len(output) 114 | ani = animation.FuncAnimation(fig, animate, interval=30, frames=num_frames, repeat=False) 115 | 116 | # show audio 117 | audio_path = None 118 | if audio is not None: 119 | assert len(audio.shape) == 1 # 1-channel, raw signal 120 | audio = audio.astype(np.float32) 121 | sr = 16000 122 | audio_path = '{}/{}_{:03d}_{}.wav'.format(save_path, prefix, epoch, iter_idx) 123 | sf.write(audio_path, audio, sr) 124 | 125 | # save video 126 | try: 127 | video_path = '{}/temp_{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 128 | ani.save(video_path, fps=15, dpi=80) # dpi 150 for a higher resolution 129 | del ani 130 | plt.close(fig) 131 | except RuntimeError: 132 | assert False, 'RuntimeError' 133 | 134 | # merge audio and video 135 | if audio is not None: 136 | merged_video_path = '{}/{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 137 | cmd = ['ffmpeg', '-loglevel', 'panic', '-y', '-i', video_path, '-i', audio_path, '-strict', '-2', 138 | merged_video_path] 139 | if clipping_to_shortest_stream: 140 | cmd.insert(len(cmd) - 1, '-shortest') 141 | subprocess.call(cmd) 142 | if delete_audio_file: 143 | os.remove(audio_path) 144 | os.remove(video_path) 145 | 146 | print('done, took {:.1f} seconds'.format(time.time() - start)) 147 | return output_poses, target_poses 148 | 149 | 150 | def save_checkpoint(state, filename): 151 | torch.save(state, filename) 152 | logging.info('Saved the checkpoint') 153 | 154 | 155 | def get_speaker_model(net): 156 | try: 157 | if hasattr(net, 'module'): 158 | speaker_model = net.module.z_obj 159 | else: 160 | speaker_model = net.z_obj 161 | except AttributeError: 162 | speaker_model = None 163 | 164 | if not isinstance(speaker_model, vocab.Vocab): 165 | speaker_model = None 166 | 167 | return speaker_model 168 | 169 | 170 | def load_checkpoint_hierarchy(checkpoint_path, _device='cpu'): 171 | print('loading checkpoint {}'.format(checkpoint_path)) 172 | checkpoint = 
torch.load(checkpoint_path, map_location=_device) 173 | args = checkpoint['args'] 174 | epoch = checkpoint['epoch'] 175 | lang_model = checkpoint['lang_model'] 176 | speaker_model = checkpoint['speaker_model'] 177 | pose_dim = checkpoint['pose_dim'] 178 | print('epoch {}'.format(epoch)) 179 | 180 | generator, discriminator, audio_encoder, text_encoder, loss_fn = train_expressive.init_model(args, lang_model, speaker_model, pose_dim, device) 181 | g1, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 8 * 3, device) 182 | g2, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 10 * 3, device) 183 | g3, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 12 * 3, device) 184 | g4, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 22 * 3, device) 185 | g5, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 32 * 3, device) 186 | g6, _, _, _, _ = train_expressive.init_model(args, lang_model, speaker_model, 42 * 3, device) 187 | 188 | g1.load_state_dict(checkpoint['gen_dict_1']) 189 | g2.load_state_dict(checkpoint['gen_dict_2']) 190 | g3.load_state_dict(checkpoint['gen_dict_3']) 191 | g4.load_state_dict(checkpoint['gen_dict_4']) 192 | g5.load_state_dict(checkpoint['gen_dict_5']) 193 | g6.load_state_dict(checkpoint['gen_dict_6']) 194 | audio_encoder.load_state_dict(checkpoint['audio_dict']) 195 | 196 | # set to eval mode 197 | g1.train(False) 198 | g2.train(False) 199 | g3.train(False) 200 | g4.train(False) 201 | g5.train(False) 202 | g6.train(False) 203 | audio_encoder.train(False) 204 | 205 | return args, g1, g2, g3, g4, g5, g6, audio_encoder, loss_fn, lang_model, speaker_model, pose_dim 206 | 207 | def load_checkpoint_and_model(checkpoint_path, _device='cpu'): 208 | print('loading checkpoint {}'.format(checkpoint_path)) 209 | checkpoint = torch.load(checkpoint_path, map_location=_device) 210 | args = checkpoint['args'] 211 | epoch = checkpoint['epoch'] 212 | lang_model = checkpoint['lang_model'] 213 | speaker_model = checkpoint['speaker_model'] 214 | pose_dim = checkpoint['pose_dim'] 215 | print('epoch {}'.format(epoch)) 216 | 217 | generator, discriminator, loss_fn = train_expressive.init_model(args, lang_model, speaker_model, pose_dim, device) 218 | generator.load_state_dict(checkpoint['gen_dict']) 219 | 220 | # set to eval mode 221 | generator.train(False) 222 | 223 | return args, generator, loss_fn, lang_model, speaker_model, pose_dim 224 | 225 | 226 | def set_random_seed(seed): 227 | torch.manual_seed(seed) 228 | torch.cuda.manual_seed_all(seed) 229 | np.random.seed(seed) 230 | random.seed(seed) 231 | os.environ['PYTHONHASHSEED'] = str(seed) -------------------------------------------------------------------------------- /scripts/utils/tts_helper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | from google.cloud import texttospeech 6 | from pygame import mixer 7 | 8 | 9 | class TTSHelper: 10 | """ helper class for google TTS 11 | set the environment variable GOOGLE_APPLICATION_CREDENTIALS first 12 | GOOGLE_APPLICATION_CREDENTIALS = 'path to json key file' 13 | """ 14 | 15 | cache_folder = './cached_wav/' 16 | 17 | def __init__(self, cache_path=None): 18 | if cache_path is not None: 19 | self.cache_folder = cache_path 20 | 21 | # create cache folder 22 | try: 23 | os.makedirs(self.cache_folder) 24 | except OSError: 25 | pass 26 | 27 | # init tts 28 | self.client = 
texttospeech.TextToSpeechClient() 29 | self.voice_en_female = texttospeech.types.VoiceSelectionParams( 30 | language_code='en-US', name='en-US-Wavenet-F') 31 | self.voice_en_male = texttospeech.types.VoiceSelectionParams( 32 | language_code='en-US', name='en-US-Wavenet-D') 33 | self.audio_config_en = texttospeech.types.AudioConfig( 34 | speaking_rate=1.0, 35 | audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16) 36 | 37 | # init player 38 | mixer.init() 39 | 40 | # clean up cache folder 41 | self._cleanup_cachefolder() 42 | 43 | def _cleanup_cachefolder(self): 44 | """ remove least accessed files in the cache """ 45 | dir_to_search = self.cache_folder 46 | for dirpath, dirnames, filenames in os.walk(dir_to_search): 47 | for file in filenames: 48 | curpath = os.path.join(dirpath, file) 49 | file_accessed = datetime.datetime.fromtimestamp(os.path.getatime(curpath)) 50 | if datetime.datetime.now() - file_accessed > datetime.timedelta(days=30): 51 | os.remove(curpath) 52 | 53 | def _string2numeric_hash(self, text): 54 | import hashlib 55 | return int(hashlib.md5(text.encode('utf-8')).hexdigest()[:16], 16) 56 | 57 | def synthesis(self, ssml_text, voice_name='en-female', verbose=False): 58 | if not ssml_text.startswith(u'<speak>'): 59 | ssml_text = u'<speak>' + ssml_text + u'</speak>' 60 | 61 | filename = os.path.join(self.cache_folder, str(self._string2numeric_hash(voice_name + ssml_text)) + '.wav') 62 | 63 | # load or synthesize audio 64 | if not os.path.exists(filename): 65 | if verbose: 66 | start = time.time() 67 | 68 | # let's synthesize 69 | if voice_name == 'en-female': 70 | voice = self.voice_en_female 71 | audio_config = self.audio_config_en 72 | elif voice_name == 'en-male': 73 | voice = self.voice_en_male 74 | audio_config = self.audio_config_en 75 | else: 76 | raise ValueError 77 | 78 | synthesis_input = texttospeech.types.SynthesisInput(ssml=ssml_text) 79 | response = self.client.synthesize_speech(synthesis_input, voice, audio_config) 80 | 81 | if verbose: 82 | print('TTS took {0:.2f} seconds'.format(time.time() - start)) 83 | start = time.time() 84 | 85 | # save to a file 86 | with open(filename, 'wb') as out: 87 | out.write(response.audio_content) 88 | if verbose: 89 | print('written to a file "{}"'.format(filename)) 90 | else: 91 | if verbose: 92 | print('use the cached wav "{}"'.format(filename)) 93 | 94 | return filename 95 | 96 | def get_sound_obj(self, filename): 97 | # play 98 | sound = mixer.Sound(filename) 99 | length = sound.get_length() 100 | 101 | return sound, length 102 | 103 | def play(self, sound): 104 | sound.play(loops=0) 105 | -------------------------------------------------------------------------------- /scripts/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import lmdb 6 | import pyarrow 7 | 8 | from model.vocab import Vocab 9 | 10 | 11 | def build_vocab(name, dataset_list, cache_path, word_vec_path=None, feat_dim=None): 12 | logging.info(' building a language model...') 13 | if not os.path.exists(cache_path): 14 | lang_model = Vocab(name) 15 | for dataset in dataset_list: 16 | logging.info(' indexing words from {}'.format(dataset.lmdb_dir)) 17 | index_words(lang_model, dataset.lmdb_dir) 18 | 19 | if word_vec_path is not None: 20 | lang_model.load_word_vectors(word_vec_path, feat_dim) 21 | 22 | with open(cache_path, 'wb') as f: 23 | pickle.dump(lang_model, f) 24 | else: 25 | logging.info(' loaded from {}'.format(cache_path)) 26 | with open(cache_path, 'rb') as f:
27 | lang_model = pickle.load(f) 28 | 29 | if word_vec_path is None: 30 | lang_model.word_embedding_weights = None 31 | elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: 32 | logging.warning(' failed to load word embedding weights. check this') 33 | assert False 34 | 35 | return lang_model 36 | 37 | 38 | def index_words(lang_model, lmdb_dir): 39 | lmdb_env = lmdb.open(lmdb_dir, readonly=True, lock=False) 40 | txn = lmdb_env.begin(write=False) 41 | cursor = txn.cursor() 42 | 43 | for key, buf in cursor: 44 | video = pyarrow.deserialize(buf) 45 | 46 | for clip in video['clips']: 47 | for word_info in clip['words']: 48 | word = word_info[0] 49 | lang_model.index_word(word) 50 | 51 | lmdb_env.close() 52 | logging.info(' indexed %d words' % lang_model.n_words) 53 | 54 | # filtering vocab 55 | # MIN_COUNT = 3 56 | # lang_model.trim(MIN_COUNT) 57 | 58 | --------------------------------------------------------------------------------
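A usage note for scripts/utils/average_meter.py: AverageMeter tracks the latest value and a count-weighted running average, which is how the training loops aggregate per-batch losses. A minimal sketch with made-up numbers, assuming it is run from the scripts/ directory so the utils package resolves the same way as in the training code:

from utils.average_meter import AverageMeter

loss_meter = AverageMeter('loss', ':.4f')

# made-up per-batch losses and batch sizes
for batch_loss, batch_size in [(0.52, 128), (0.48, 128), (0.50, 64)]:
    loss_meter.update(batch_loss, n=batch_size)

print(loss_meter)   # "loss 0.5000 (0.5000)" -> latest value, then the count-weighted average
loss_meter.reset()  # clear the accumulators, e.g. at the start of a new epoch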
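The pose helpers in scripts/utils/data_utils.py convert between joint positions and the unit-length bone direction vectors (9 bones x 3 values) that the TED-Gesture models operate on. A minimal round-trip sketch with random stand-in data; note that convert_dir_vec_to_pose roots the skeleton at the origin and applies the fixed bone lengths from dir_vec_pairs, so it is a canonicalized reconstruction rather than an exact inverse:

import numpy as np
from utils.data_utils import convert_pose_seq_to_dir_vec, convert_dir_vec_to_pose, dir_vec_pairs

pose_seq = np.random.randn(34, 10, 3).astype(np.float32)   # stand-in clip: 34 frames, 10 upper-body joints

dir_vec = convert_pose_seq_to_dir_vec(pose_seq)   # (34, 9, 3), one unit vector per bone in dir_vec_pairs
recon = convert_dir_vec_to_pose(dir_vec)          # (34, 10, 3), rebuilt from the root with fixed bone lengths

assert dir_vec.shape == (34, len(dir_vec_pairs), 3)
assert recon.shape == (34, 10, 3)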
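Also in scripts/utils/data_utils.py, calc_spectrogram_length_from_motion_length ties the 15 fps motion clips to 16 kHz audio under the n_fft=1024 / hop_length=512 settings of extract_melspectrogram, while make_audio_fixed_length and resample_pose_seq bring the two streams to a common duration. A small worked sketch with stand-in arrays (the input lengths are illustrative):

import numpy as np
from utils.data_utils import (calc_spectrogram_length_from_motion_length,
                              make_audio_fixed_length, resample_pose_seq)

n_poses, fps, sr = 34, 15, 16000   # a 34-pose clip at 15 fps spans about 2.27 s of speech

n_mel_frames = calc_spectrogram_length_from_motion_length(n_poses, fps)
print(n_mel_frames)                # 70 = round((34 / 15 * 16000 - 1024) / 512 + 1)

audio = np.random.randn(30000).astype(np.float32)                # stand-in waveform, a bit too short
audio = make_audio_fixed_length(audio, int(n_poses / fps * sr))  # symmetrically padded to 36266 samples

poses_30fps = np.random.randn(68, 27).astype(np.float32)          # stand-in poses at roughly 30 fps
poses_15fps = resample_pose_seq(poses_30fps, n_poses / fps, fps)  # interpolated down to 34 frames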
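scripts/utils/tts_helper.py wraps Google Cloud Text-to-Speech behind a small on-disk cache keyed by a hash of the voice name and SSML text. A usage sketch, assuming a valid service-account key, a working audio device for pygame, and a google-cloud-texttospeech version that still exposes the texttospeech.types / texttospeech.enums API used above; the credential path is hypothetical:

import os

from utils.tts_helper import TTSHelper

os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', '/path/to/service_account_key.json')  # hypothetical path

tts = TTSHelper(cache_path='./cached_wav/')
wav_path = tts.synthesis('Here is a short sentence for the gesture demo.',
                         voice_name='en-female', verbose=True)   # returns the path of a cached .wav

sound, length = tts.get_sound_obj(wav_path)   # pygame Sound object and its duration in seconds
tts.play(sound)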
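Finally, build_vocab in scripts/utils/vocab_utils.py builds or loads the cached Vocab used by the text encoders: it walks the LMDB clips of each dataset, indexes every transcript word, optionally attaches pre-trained word vectors, and pickles the result. A sketch of a typical call, assuming train_dataset and val_dataset are already-constructed LMDB datasets exposing .lmdb_dir (as the loaders under scripts/data_loader do) and that both file paths are placeholders:

from utils.vocab_utils import build_vocab

def build_ted_vocab(train_dataset, val_dataset):
    # the first call indexes the LMDB data and pickles the Vocab; later calls load the cache
    return build_vocab(
        'words',
        [train_dataset, val_dataset],
        cache_path='data/vocab_cache.pkl',                          # placeholder cache file
        word_vec_path='data/fasttext/crawl-300d-2M-subword.bin',    # placeholder fastText binary
        feat_dim=300)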