├── README.md ├── configs ├── 0_default.json ├── 2_zscore.json ├── __init__.py ├── generate_config_mixste.py ├── generate_config_motionagformer.py ├── generate_config_motionbert.py ├── generate_config_poseformer.py ├── generate_config_poseformerv2.py ├── mixste │ └── default.json ├── motionagformer │ ├── 1_xsmall.json │ ├── 2_small.json │ ├── 3_base.json │ ├── 4_large.json │ └── default.json ├── motionbert │ └── 11_pd,json ├── motionbert_backup │ ├── 3_Aug_rotation.json │ ├── 4_Aug_translation.json │ ├── 5_Aug_rotation_range20.json │ └── 6_Aug_translation_frac0.1.json ├── poseformer │ ├── 1_minmax.json │ ├── 3_uncentered_unnorm.json │ ├── 4_uncentered_minmax.json │ └── 5_uncentered_zscore.json └── poseformerv2 │ └── test.json ├── const ├── __init__.py ├── const.py └── path.py ├── data ├── Visualize_reconst3d.py ├── __init__.py ├── augmentations.py ├── data_augmentation.py ├── dataloaders.py ├── pd │ ├── Removed_sequences.csv │ ├── const_pd.py │ └── preprocess_pd.py ├── public_pd_datareader.py └── utility.py ├── eval_encoder.py ├── learning ├── __init__.py ├── criterion.py ├── criterions │ └── __init__.py ├── optimizer.py └── utils.py ├── model ├── __init__.py ├── backbone_loader.py ├── mixste │ ├── model_cross.py │ └── rela.py ├── motion_encoder.py ├── motionagformer │ ├── MotionAGFormer.py │ ├── __init__.py │ └── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── graph.py │ │ ├── mlp.py │ │ └── tcn.py ├── motionbert │ ├── DSTformer.py │ └── drop.py ├── poseformer │ ├── Conv1DEncoder.py │ ├── PoseEncoderDecoder.py │ ├── PoseGCN.py │ ├── PoseTransformer.py │ ├── PositionEncodings.py │ ├── Transformer.py │ ├── TransformerEncoder.py │ ├── __init__.py │ └── seq2seq_model_fn.py ├── poseformerv2 │ └── model_poseformer.py └── utils.py ├── requirements.txt ├── stat_analysis └── get_stats.py ├── test.py ├── train.py └── utility ├── __init__.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # [FG 2024] Benchmarking Skeleton-based Motion Encoder Models for Clinical Applications 2 | 3 | 4 | ![License](https://img.shields.io/badge/license-MIT-blue) [![arXiv](https://img.shields.io/badge/arXiv-2405.17817-b31b1b.svg)](https://arxiv.org/abs/2405.17817) 5 | 6 | 7 | ## Introduction 8 | This repository contains the code for the paper "[Benchmarking Skeleton-based Motion Encoder Models for Clinical Applications: Estimating Parkinson’s Disease Severity in Walking Sequences](https://arxiv.org/abs/2405.17817)", accepted at the IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2024). 9 | 10 | ## Installation 11 | ```bash 12 | git clone https://github.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark.git 13 | cd MotionEncoders_parkinsonism_benchmark 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Data 18 | Dataloaders will be added soon. 19 | 20 | ## Demo 21 | Demo will be added soon.
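In the meantime, the reader for the public PD gait data can already be used on its own. A minimal sketch (both paths are placeholders for wherever the preprocessed `.npy` poses from `data/pd/preprocess_pd.py` and the labels spreadsheet live; see `const/path.py` for the paths used in the experiments):

```python
# Minimal sketch: load the preprocessed Public PD walking sequences.
# Both paths are placeholders; the labels file is loaded with pandas.read_excel.
from data.public_pd_datareader import PDReader

reader = PDReader('/path/to/C3Dfiles_processed', '/path/to/PDGinfo.xlsx')
print(len(reader.video_names), 'walking sequences from', len(reader.participant_ID), 'participants')
```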
22 | 23 | ## Leaderboard 24 | 25 | | Model | F1 Score | Paper/Source | 26 | | ---------------|----------|--------------| 27 | | MixSTE | 0.41 | [Link](https://paperswithcode.com/paper/mixste-seq2seq-mixed-spatio-temporal-encoder) | 28 | | MotionAGFormer | 0.42 | [Link](https://paperswithcode.com/paper/motionagformer-enhancing-3d-human-pose) | 29 | | MotionBERT-LITE | 0.43 | [Link](https://paperswithcode.com/paper/motionbert-unified-pretraining-for-human) | 30 | | POTR | 0.46 | [Link](https://paperswithcode.com/paper/pose-transformers-potr-human-motion) | 31 | | MotionBERT | 0.47 | [Link](https://paperswithcode.com/paper/motionbert-unified-pretraining-for-human) | 32 | | PD STGCN | 0.48 | [Link](https://paperswithcode.com/paper/abc) | 33 | | PoseFormerV2 | 0.59 | [Link](https://paperswithcode.com/paper/poseformerv2-exploring-frequency-domain-for) | 34 | | PoseFormerV2-Finetuned | 0.62 | [Link](https://paperswithcode.com/paper/abc) | 35 | 36 | 37 | For detailed rankings, visit the [Papers with Code leaderboard](https://paperswithcode.com/sota/classification-on-full-body-parkinsons). 38 | 39 | 40 | ## Acknowledgement 41 | Special thanks to the creators of the dataset for making their clinical data publicly available: 42 | - [A public data set of walking full-body](https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2023.992585/full) 43 | 44 | Our code also builds on the following repositories. We thank the authors for releasing their code. 45 | 46 | - [PoseFormerV2](https://github.com/QitaoZhao/PoseFormerV2) 47 | - [MotionBERT](https://github.com/Walter0807/MotionBERT) 48 | - [MixSTE](https://github.com/JinluZhang1126/MixSTE) 49 | - [POTR](https://github.com/idiap/potr) 50 | - [MotionAGFormer](https://github.com/TaatiTeam/MotionAGFormer/tree/master) 51 | - [stgcn_parkinsonism_prediction](https://github.com/TaatiTeam/stgcn_parkinsonism_prediction) 52 | 53 | 54 | 55 | 56 | ## Citation 57 | Please cite our paper if this repository helps your research: 58 | ```bibtex 59 | @inproceedings{PDmotionBenchmark2024, 60 | title = {Benchmarking Skeleton-based Motion Encoder Models for Clinical Applications: Estimating Parkinson’s Disease Severity in Walking Sequences}, 61 | author = {Vida Adeli and Soroush Mehraban and Yasamin Zarghami and Irene Ballester and Andrea Sabo and Andrea Iaboni and Babak Taati}, 62 | booktitle = {2024 18th IEEE International Conference on Automatic Face \& Gesture Recognition (FG 2024)}, 63 | year = {2024} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /configs/0_default.json: -------------------------------------------------------------------------------- 1 | { 2 | } 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /configs/2_zscore.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_norm": "zscore" 3 | } 4 | 5 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/configs/__init__.py -------------------------------------------------------------------------------- /configs/generate_config_mixste.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from const import path 4 | 5 | 6 | def generate_config(param, f_name):
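    """
    Build the full MixSTE configuration dictionary.

    Starts from the data/model/learning defaults defined below, then overrides
    them with the JSON file ./configs/mixste/<f_name>. Any key in the JSON that
    does not already exist as a parameter raises a ValueError, which guards
    against typos in config files. Returns the merged `params` dict together
    with the raw `new_param` overrides.
    """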
data_params = { 8 | 'data_type': 'PD', # options: "Kinect", "GastNet", "PCT", "ViTPose" 9 | 'data_dim': 3, 10 | 'in_data_dim': 2, 11 | 'data_centered': True, 12 | 'merge_last_dim': False, 13 | 'use_validation': True, 14 | 'simulate_confidence_score': False, # It doesn't have confidence score. So we just need to discard the last dim. 15 | 'pretrained_dataset_name': 'h36m', 16 | 'model_prefix': 'mixste_', 17 | # options: mirror_reflection, random_rotation, random_translation 18 | # 'augmentation': [], 19 | 'rotation_range': [-10, 10], 20 | 'rotation_prob': 0.5, 21 | 'mirror_prob': 0.5, 22 | 'noise_prob': 0.5, 23 | 'axis_mask_prob': 0.5, 24 | 'translation_frac': 0.05, 25 | 'data_norm': "rescaling", 26 | 'select_middle': False, 27 | 'exclude_non_rgb_sequences': False 28 | } 29 | model_params = { 30 | 'source_seq_len': 81, 31 | 'num_joints': 17, 32 | 'embed_dim_ratio': 512, 33 | 'depth': 8, 34 | 'merge_joints': False, 35 | 'classifier_hidden_dims': [2048], 36 | 'classifier_dropout': 0.5, 37 | 'model_checkpoint_path': f"{path.PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH}/mixste/best_epoch_cpn_81f.bin" 38 | } 39 | learning_params = { 40 | 'wandb_name': 'mixste', 41 | 'experiment_name': '', 42 | 'batch_size': 256, 43 | 'criterion': 'CrossEntropyLoss', 44 | 'optimizer': 'AdamW', 45 | 'lr_backbone': 0.0001, 46 | 'lr_head': 0.001, 47 | 'weight_decay': 0.01, 48 | 'lambda_l1': 0.0, 49 | 'scheduler': "StepLR", 50 | 'lr_decay': 0.99, 51 | 'epochs': 20, 52 | 'stopping_tolerance': 10, 53 | 'lr_step_size': 1 54 | } 55 | 56 | params = {**param, **data_params, **model_params, **learning_params} 57 | 58 | f = open("./configs/mixste/" + f_name, "rb") 59 | new_param = json.load(f) 60 | 61 | for p in new_param: 62 | if not p in params.keys(): 63 | raise ValueError( 64 | "Error: One of the config parameters in " + "./Configs/" + f_name + " does not match code!") 65 | params[p] = new_param[p] 66 | 67 | params['labels_path'] = params['data_path'] # Data Path is the path to csv files by default 68 | 69 | params['model_prefix'] = params['model_prefix'] + f_name.split('.json')[0] 70 | return params, new_param 71 | -------------------------------------------------------------------------------- /configs/generate_config_motionagformer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from colorama import Fore 5 | 6 | from const import path 7 | 8 | 9 | def generate_config(param, f_name): 10 | data_params = { 11 | 'data_type': 'Kinect', # options: "Kinect", "GastNet", "PCT", "ViTPose" 12 | 'data_dim': 3, 13 | 'in_data_dim': 2, 14 | 'data_centered': True, 15 | 'merge_last_dim': False, 16 | 'use_validation': True, 17 | 'simulate_confidence_score': True, 18 | 'pretrained_dataset_name': 'h36m', 19 | 'model_prefix': 'MotionAGFormer_', 20 | # options: mirror_reflection, random_rotation, random_translation 21 | # 'augmentation': [], 22 | 'rotation_range': [-10, 10], 23 | 'rotation_prob': 0.5, 24 | 'mirror_prob': 0.5, 25 | 'noise_prob': 0.5, 26 | 'axis_mask_prob': 0.5, 27 | 'translation_frac': 0.05, 28 | 'data_norm': "rescaling", 29 | 'select_middle': True, 30 | 'exclude_non_rgb_sequences': False 31 | } 32 | model_params = { 33 | 'source_seq_len': 243, 34 | 'n_layers': 16, 35 | 'dim_in': 3, 36 | 'dim_feat': 128, 37 | 'dim_rep': 512, 38 | 'dim_out': 3, 39 | 'mlp_ratio': 4, 40 | 'attn_drop': 0.0, 41 | 'drop': 0.0, 42 | "drop_path": 0.0, 43 | "use_layer_scale": True, 44 | "layer_scale_init_value": 0.00001, 45 | "use_adaptive_fusion": True, 46 | "num_heads": 8, 47 | 
"qkv_bias": False, 48 | "qkv_scale": None, 49 | "hierarchical": False, 50 | "use_temporal_similarity": True, 51 | "neighbour_num": 2, 52 | "temporal_connection_len": 1, 53 | "use_tcn": False, 54 | "graph_only": False, 55 | 'classifier_dropout': 0.0, 56 | 'merge_joints': True, 57 | 'classifier_hidden_dims': [1024], 58 | 'model_checkpoint_path': f"{path.PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH}/motionagformer/checkpoint.pth.tr" 59 | } 60 | learning_params = { 61 | 'wandb_name': 'MotionAGFormer', 62 | 'experiment_name': '', 63 | 'batch_size': 256, 64 | 'criterion': 'CrossEntropyLoss', 65 | 'optimizer': 'AdamW', 66 | 'lr_backbone': 0.0001, 67 | 'lr_head': 0.001, 68 | 'weight_decay': 0.01, 69 | 'lambda_l1': 0.0001, 70 | 'scheduler': "StepLR", 71 | 'lr_decay': 0.99, 72 | 'epochs': 20, 73 | 'stopping_tolerance': 10, 74 | 'lr_step_size': 1 75 | } 76 | 77 | params = {**param, **data_params, **model_params, **learning_params} 78 | 79 | f = open("./configs/motionagformer/" + f_name, "rb") 80 | new_param = json.load(f) 81 | 82 | for p in new_param: 83 | if not p in params.keys(): 84 | raise ValueError( 85 | "Error: One of the config parameters in " + "./Configs/" + f_name + " does not match code!") 86 | params[p] = new_param[p] 87 | 88 | params['labels_path'] = params['data_path'] # Data Path is the path to csv files by default 89 | 90 | params['model_prefix'] = params['model_prefix'] + f_name.split('.json')[0] 91 | return params, new_param 92 | -------------------------------------------------------------------------------- /configs/generate_config_motionbert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from colorama import Fore 5 | 6 | from const import path 7 | 8 | 9 | def generate_config(param, f_name): 10 | data_params = { 11 | 'data_type': 'PD', # options: "Kinect", "GastNet", "PCT", "ViTPose, "PD" 12 | 'data_dim': 3, 13 | 'in_data_dim': 2, 14 | 'data_centered': True, 15 | 'merge_last_dim': False, 16 | 'use_validation': True, 17 | 'simulate_confidence_score': True, 18 | 'pretrained_dataset_name': 'h36m', 19 | 'model_prefix': 'motionbert_', 20 | # options: mirror_reflection, random_rotation, random_translation 21 | # 'augmentation': [], 22 | 'rotation_range': [-10, 10], 23 | 'rotation_prob': 0.5, 24 | 'mirror_prob': 0.5, 25 | 'noise_prob': 0.5, 26 | 'axis_mask_prob': 0.5, 27 | 'translation_frac': 0.05, 28 | 'data_norm': "rescaling", 29 | 'select_middle': True, 30 | 'exclude_non_rgb_sequences': False 31 | } 32 | model_params = { 33 | 'source_seq_len': 243, 34 | 'num_joints': 17, 35 | 'dim_feat': 512, #MotionBERT 512, Lite 256 36 | 'dim_rep': 512, 37 | 'depth': 5, 38 | 'num_heads': 8, 39 | 'mlp_ratio': 2, #MotionBERT 2, Lite 4 40 | 'maxlen': 243, 41 | 'classifier_dropout': 0.5, 42 | 'merge_joints': False, 43 | 'classifier_hidden_dims': [2048], 44 | 'model_checkpoint_path': f"{path.PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH}/motionbert/motionbert.bin" 45 | } 46 | learning_params = { 47 | 'wandb_name': 'motionbert', 48 | 'experiment_name': '', 49 | 'batch_size': 256, 50 | 'criterion': 'CrossEntropyLoss', 51 | 'optimizer': 'AdamW', 52 | 'lr_backbone': 0.0001, 53 | 'lr_head': 0.001, 54 | 'weight_decay': 0.01, 55 | 'lambda_l1': 0.0, 56 | 'scheduler': "StepLR", 57 | 'lr_decay': 0.99, 58 | 'epochs': 20, 59 | 'stopping_tolerance': 10, 60 | 'lr_step_size': 1 61 | } 62 | 63 | params = {**param, **data_params, **model_params, **learning_params} 64 | 65 | f = open("./configs/motionbert/" + f_name, "rb") 66 | new_param = json.load(f) 67 
| 68 | for p in new_param: 69 | if not p in params.keys(): 70 | raise ValueError( 71 | "Error: One of the config parameters in " + "./Configs/" + f_name + " does not match code!") 72 | params[p] = new_param[p] 73 | 74 | 75 | 76 | params['labels_path'] = params['data_path'] # Data Path is the path to csv files by default 77 | 78 | 79 | if params['data_type'] == "PD": 80 | print("path.PD_PATH_LABELS", path.PD_PATH_LABELS) 81 | print("path.PD_PATH_POSES" , path.PD_PATH_POSES) 82 | params['labels_path'] = path.PD_PATH_LABELS 83 | params['data_path'] = path.PD_PATH_POSES 84 | 85 | 86 | params['model_prefix'] = params['model_prefix'] + f_name.split('.json')[0] 87 | return params, new_param 88 | -------------------------------------------------------------------------------- /configs/generate_config_poseformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from colorama import Fore 5 | 6 | from const import path 7 | 8 | 9 | _NSEEDS = 8 10 | import json 11 | 12 | 13 | def generate_config(param, f_name): 14 | data_params = { 15 | 'data_dim': 3, 16 | 'in_data_dim': 3, 17 | 'data_type': 'PD', 18 | 'data_centered': False, 19 | 'merge_last_dim': True, 20 | 'use_validation': True, 21 | 'simulate_confidence_score': False, 22 | 'pretrained_dataset_name': 'NTU', 23 | 'voting': False, 24 | 'model_prefix': 'POTR_', 25 | 'data_norm': 'unnorm', # [minmax, unnorm, zscore] 26 | 'source_seq_len': 80, 27 | 'interpolate': True, 28 | "select_middle": False, 29 | 'rotation_range': [-10, 10], 30 | 'rotation_prob': 0.5, 31 | 'mirror_prob': 0.5, 32 | 'noise_prob': 0.5, 33 | 'axis_mask_prob': 0.5, 34 | 'translation_frac': 0.05, 35 | 'augmentation': [] 36 | } 37 | 38 | model_params = { 39 | 'model_dim': 128, 40 | 'num_encoder_layers': 4, 41 | 'num_heads': 4, 42 | 'dim_ffn': 2048, 43 | 'init_fn': 'xavier_init', 44 | 'pose_embedding_type': 'gcn_enc', 45 | 'pos_enc_alpha': 10, 46 | 'pos_enc_beta': 500, 47 | 'downstream_strategy': 'both', # ['both', 'class', 'both_then_class'], 48 | 'model_checkpoint_path': f"{path.PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH}/poseforemer/pre-trained_NTU_ckpt_epoch_199_enc_80_dec_20.pt", 49 | 'pose_format': None, 50 | 'classifier_dropout': 0.5, 51 | 'classifier_hidden_dim': 2048, 52 | 'preclass_rem_T': True 53 | } 54 | 55 | learning_params = { 56 | 'wandb_name': 'poseforemer', 57 | 'batch_size': 256, 58 | 'lr_backbone': 0.0001, 59 | 'lr_head': 0.001, 60 | 'epochs': 100, 61 | 'steps_per_epoch': 200, 62 | 'dropout': 0.3, 63 | 'max_gradient_norm': 0.1, 64 | 'lr_step_size': 1, 65 | 'learning_rate_fn': 'step', 66 | 'criterion': 'WCELoss', 67 | 'warmup_epochs': 10, 68 | 'smoothing_scale': 0.1, 69 | 'optimizer': 'AdamW', 70 | 'stopping_tolerance': 10, 71 | 'weight_decay': 0.00001, 72 | 'lr_decay': 0.99, 73 | 'experiment_name': '', 74 | } 75 | 76 | params = {**param, **data_params, **model_params, **learning_params} 77 | 78 | f = open("./configs/poseformer/" + f_name, "rb") 79 | new_param = json.load(f) 80 | 81 | for p in new_param: 82 | if not p in params.keys(): 83 | print("Error: One of the config parameters in " + "./Configs/" + f_name + " does not match code!") 84 | print(Fore.RED + 'Configuration mismatch at:' + p) 85 | sys.exit(1) 86 | params[p] = new_param[p] 87 | 88 | 89 | if params['dataset'] == 'PD': 90 | params['labels_path'] = path.PD_PATH_LABELS 91 | params['data_path'] = path.PD_PATH_POSES 92 | 93 | params['model_prefix'] = params['model_prefix'] + f_name.split('.json')[0] 94 | return params, new_param 95 | 
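All of the `generate_config_*` modules follow the same pattern: hard-coded defaults, a small JSON override file, and a hard failure on unknown keys. A minimal usage sketch for the POTR/PoseFormer generator above (run from the repository root, since the JSON is opened via a relative path; the `param` dict normally comes from `vars()` of the entry script's argparse namespace and is trimmed here to the only key this generator reads from it):

```python
from configs import generate_config_poseformer

# 'dataset' == 'PD' makes the generator substitute the PD pose and label
# paths from const/path.py after the JSON overrides are applied.
param = {'dataset': 'PD'}
params, overrides = generate_config_poseformer.generate_config(param, '1_minmax.json')

print(params['model_prefix'])  # POTR_1_minmax
print(params['data_norm'])     # minmax (overridden by 1_minmax.json)
```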
-------------------------------------------------------------------------------- /configs/generate_config_poseformerv2.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from const import path 4 | 5 | 6 | def generate_config(param, f_name): 7 | data_params = { 8 | 'data_type': 'PD', # options: "Kinect", "GastNet", "PCT", "ViTPose" 9 | 'data_dim': 3, 10 | 'in_data_dim': 2, 11 | 'data_centered': True, 12 | 'merge_last_dim': False, 13 | 'use_validation': True, 14 | 'simulate_confidence_score': False, # Since poseformerv2 doesn't require confidence score, we just ignore last dim. 15 | 'pretrained_dataset_name': 'h36m', 16 | 'model_prefix': 'poseformerv2_', 17 | # options: mirror_reflection, random_rotation, random_translation 18 | # 'augmentation': [], 19 | 'rotation_range': [-10, 10], 20 | 'rotation_prob': 0.5, 21 | 'mirror_prob': 0.5, 22 | 'noise_prob': 0.5, 23 | 'axis_mask_prob': 0.5, 24 | 'translation_frac': 0.05, 25 | 'data_norm': "rescaling", 26 | 'select_middle': False, 27 | 'exclude_non_rgb_sequences': False 28 | } 29 | model_params = { 30 | 'source_seq_len': 81, 31 | 'num_joints': 17, 32 | 'embed_dim_ratio': 32, 33 | 'depth': 4, 34 | 'number_of_kept_frames': 9, 35 | 'number_of_kept_coeffs': 9, 36 | 'classifier_dropout': 0.5, 37 | 'merge_joints': False, 38 | 'classifier_hidden_dims': [2048], 39 | 'model_checkpoint_path': f"{path.PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH}/poseformerv2/9_81_46.0.bin" 40 | } 41 | learning_params = { 42 | 'wandb_name': 'poseformerv2', 43 | 'experiment_name': '', 44 | 'batch_size': 256, 45 | 'criterion': 'CrossEntropyLoss', 46 | 'optimizer': 'AdamW', 47 | 'lr_backbone': 0.0001, 48 | 'lr_head': 0.001, 49 | 'weight_decay': 0.01, 50 | 'lambda_l1': 0.0, 51 | 'scheduler': "StepLR", 52 | 'lr_decay': 0.99, 53 | 'epochs': 20, 54 | 'stopping_tolerance': 10, 55 | 'lr_step_size': 1 56 | } 57 | 58 | params = {**param, **data_params, **model_params, **learning_params} 59 | 60 | f = open("./configs/poseformerv2/" + f_name, "rb") 61 | new_param = json.load(f) 62 | 63 | for p in new_param: 64 | if not p in params.keys(): 65 | raise ValueError( 66 | "Error: One of the config parameters in " + "./Configs/" + f_name + " does not match code!") 67 | params[p] = new_param[p] 68 | 69 | params['labels_path'] = params['data_path'] # Data Path is the path to csv files by default 70 | 71 | params['model_prefix'] = params['model_prefix'] + f_name.split('.json')[0] 72 | return params, new_param 73 | -------------------------------------------------------------------------------- /configs/mixste/default.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /configs/motionagformer/1_xsmall.json: -------------------------------------------------------------------------------- 1 | { 2 | "merge_joints": true, 3 | "classifier_hidden_dims": [ 4 | 1024 5 | ], 6 | "classifier_dropout": 0.0, 7 | "weight_decay": 0.0, 8 | "lambda_l1": 0.0001, 9 | "select_middle": false, 10 | "mirror_prob": 0.5, 11 | "rotation_prob": 0.5, 12 | "noise_prob": 0.5, 13 | "data_type": "Kinect", 14 | "source_seq_len": 27, 15 | "dim_feat": 64, 16 | "n_layers": 12, 17 | "model_checkpoint_path": "/cluster/projects/taati/vida/motion_evaluator_data/Pretrained_checkpoints/motionagformer/motionagformer-xs-h36m.pth.tr" 18 | } -------------------------------------------------------------------------------- /configs/motionagformer/2_small.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "merge_joints": true, 3 | "classifier_hidden_dims": [ 4 | 1024 5 | ], 6 | "classifier_dropout": 0.0, 7 | "weight_decay": 0.0, 8 | "lambda_l1": 0.0001, 9 | "select_middle": false, 10 | "mirror_prob": 0.5, 11 | "rotation_prob": 0.5, 12 | "noise_prob": 0.5, 13 | "data_type": "Kinect", 14 | "source_seq_len": 81, 15 | "dim_feat": 64, 16 | "n_layers": 26, 17 | "model_checkpoint_path": "/cluster/projects/taati/vida/motion_evaluator_data/Pretrained_checkpoints/motionagformer/motionagformer-s-h36m.pth.tr" 18 | } -------------------------------------------------------------------------------- /configs/motionagformer/3_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "merge_joints": true, 3 | "classifier_hidden_dims": [ 4 | 1024 5 | ], 6 | "classifier_dropout": 0.0, 7 | "weight_decay": 0.0, 8 | "lambda_l1": 0.0001, 9 | "select_middle": false, 10 | "mirror_prob": 0.5, 11 | "rotation_prob": 0.5, 12 | "noise_prob": 0.5, 13 | "data_type": "Kinect", 14 | "source_seq_len": 243, 15 | "dim_feat": 128, 16 | "n_layers": 16, 17 | "model_checkpoint_path": "/cluster/projects/taati/vida/motion_evaluator_data/Pretrained_checkpoints/motionagformer/motionagformer-b-h36m.pth.tr" 18 | } -------------------------------------------------------------------------------- /configs/motionagformer/4_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "merge_joints": true, 3 | "classifier_hidden_dims": [ 4 | 1024 5 | ], 6 | "classifier_dropout": 0.0, 7 | "weight_decay": 0.0, 8 | "lambda_l1": 0.0001, 9 | "select_middle": false, 10 | "mirror_prob": 0.5, 11 | "rotation_prob": 0.5, 12 | "noise_prob": 0.5, 13 | "data_type": "Kinect", 14 | "source_seq_len": 243, 15 | "dim_feat": 128, 16 | "n_layers": 26, 17 | "model_checkpoint_path": "/cluster/projects/taati/vida/motion_evaluator_data/Pretrained_checkpoints/motionagformer/motionagformer-l-h36m.pth.tr" 18 | } -------------------------------------------------------------------------------- /configs/motionagformer/default.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /configs/motionbert/11_pd,json: -------------------------------------------------------------------------------- 1 | { 2 | "merge_joints": true, 3 | "classifier_hidden_dims": [ 4 | 1024 5 | ], 6 | "classifier_dropout": 0.0, 7 | "weight_decay": 0.0, 8 | "lambda_l1": 0.0001, 9 | "select_middle": false, 10 | "mirror_prob": 0.5, 11 | "rotation_prob": 0.5, 12 | "noise_prob": 0.5, 13 | "data_type": "PD", 14 | "use_validation": true, 15 | "source_seq_len": 90, 16 | "stopping_tolerance": 50, 17 | "epochs": 50, 18 | "criterion": "WCELoss" 19 | } -------------------------------------------------------------------------------- /configs/motionbert_backup/3_Aug_rotation.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentation": ["random_rotation"] 3 | } 4 | 5 | -------------------------------------------------------------------------------- /configs/motionbert_backup/4_Aug_translation.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentation": ["random_translation"] 3 | } 4 | 5 | -------------------------------------------------------------------------------- /configs/motionbert_backup/5_Aug_rotation_range20.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "augmentation": ["random_rotation"], 3 | "rotation_range": [-20, 20] 4 | } 5 | 6 | -------------------------------------------------------------------------------- /configs/motionbert_backup/6_Aug_translation_frac0.1.json: -------------------------------------------------------------------------------- 1 | { 2 | "augmentation": ["random_translation"], 3 | "translation_frac": 0.1 4 | } 5 | 6 | -------------------------------------------------------------------------------- /configs/poseformer/1_minmax.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_norm": "minmax" 3 | } 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /configs/poseformer/3_uncentered_unnorm.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_norm": "unnorm", 3 | "data_centered": false 4 | } 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /configs/poseformer/4_uncentered_minmax.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_norm": "minmax", 3 | "data_centered": false 4 | } 5 | 6 | -------------------------------------------------------------------------------- /configs/poseformer/5_uncentered_zscore.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_norm": "zscore", 3 | "data_centered": false 4 | } 5 | 6 | -------------------------------------------------------------------------------- /configs/poseformerv2/test.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /const/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/const/__init__.py -------------------------------------------------------------------------------- /const/const.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -------------------------------------------------------------------------------- /const/path.py: -------------------------------------------------------------------------------- 1 | NDRIVE_PROJECT_ROOT = '/data/iballester/motion_evaluator' 2 | 3 | PRETRAINEDD_MODEL_CHECKPOINTS_ROOT_PATH = '/data/iballester/motion_evaluator/Pretrained_checkpoints' 4 | 5 | OUT_PATH = '/caa/Homes01/iballester/log/motion_encoder/out/' 6 | 7 | # KINECT 8 | PREPROCESSED_DATA_ROOT_PATH = f'{NDRIVE_PROJECT_ROOT}/data' 9 | 10 | # PD 11 | PD_PATH_POSES='/data/iballester/datasets/Public_PD/C3Dfiles_processed' 12 | PD_PATH_LABELS='/data/iballester/datasets/Public_PD/PDGinfo.csv' 13 | 14 | CHECKPOINT_ROOT_PATH = '/caa/Homes01/iballester/log/motion_encoder/out/motionbert/finetune_6_pd,json/1/models' -------------------------------------------------------------------------------- /data/Visualize_reconst3d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | 5 | from matplotlib import pyplot as plt 6 | import matplotlib 7 | from matplotlib.animation import FuncAnimation 8 | 9 | this_path = 
os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.insert(0, this_path + "/../") 11 | 12 | from const import path 13 | from data.dataloaders import * 14 | from model.backbone_loader import load_pretrained_backbone 15 | from configs import generate_config_poseformer, generate_config_motionbert, generate_config_poseformerv2, generate_config_mixste, generate_config_motionagformer 16 | 17 | _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | VIEWS = { 20 | "pd": { 21 | "best": (45, 20, 100), 22 | "best2": (0, 0, 0), 23 | "side": (90, 0, 90), 24 | }, 25 | "tmp": { 26 | "best": (45, 20, 100), 27 | "side": (90, 0, 90), 28 | } 29 | } 30 | 31 | H36M_FULL = { 32 | 'B.TORSO': 0, 33 | 'L.HIP': 1, 34 | 'L.KNEE': 2, 35 | 'L.FOOT': 3, 36 | 'R.HIP': 4, 37 | 'R.KNEE': 5, 38 | 'R.FOOT': 6, 39 | 'C.TORSO': 7, 40 | 'U.TORSO': 8, 41 | 'NECK': 9, 42 | 'HEAD': 10, 43 | 'R.SHOULDER': 11, 44 | 'R.ELBOW': 12, 45 | 'R.HAND': 13, 46 | 'L.SHOULDER': 14, 47 | 'L.ELBOW': 15, 48 | 'L.HAND': 16 49 | } 50 | 51 | H36M_CONNECTIONS_FULL = { 52 | (H36M_FULL['B.TORSO'], H36M_FULL['L.HIP']), 53 | (H36M_FULL['B.TORSO'], H36M_FULL['R.HIP']), 54 | (H36M_FULL['R.HIP'], H36M_FULL['R.KNEE']), 55 | (H36M_FULL['R.KNEE'], H36M_FULL['R.FOOT']), 56 | (H36M_FULL['L.HIP'], H36M_FULL['L.KNEE']), 57 | (H36M_FULL['L.KNEE'], H36M_FULL['L.FOOT']), 58 | (H36M_FULL['B.TORSO'], H36M_FULL['C.TORSO']), 59 | (H36M_FULL['C.TORSO'], H36M_FULL['U.TORSO']), 60 | (H36M_FULL['U.TORSO'], H36M_FULL['L.SHOULDER']), 61 | (H36M_FULL['L.SHOULDER'], H36M_FULL['L.ELBOW']), 62 | (H36M_FULL['L.ELBOW'], H36M_FULL['L.HAND']), 63 | (H36M_FULL['U.TORSO'], H36M_FULL['R.SHOULDER']), 64 | (H36M_FULL['R.SHOULDER'], H36M_FULL['R.ELBOW']), 65 | (H36M_FULL['R.ELBOW'], H36M_FULL['R.HAND']), 66 | (H36M_FULL['U.TORSO'], H36M_FULL['NECK']), 67 | (H36M_FULL['NECK'], H36M_FULL['HEAD']) 68 | } 69 | 70 | def rotate_around_z_axis(points, theta): 71 | c, s = np.cos(np.radians(theta)), np.sin(np.radians(theta)) 72 | R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) 73 | return np.dot(points, R.T) 74 | 75 | 76 | def visualize_sequence(seq, name): 77 | VIEWS = { 78 | "pd": { 79 | "best": (45, 20, 100), 80 | "best2": (0, 0, 0), 81 | "side": (90, 0, 90), 82 | }, 83 | "tmp": { 84 | "best": (45, 20, 100), 85 | "side": (90, 0, 90), 86 | } 87 | } 88 | elev, azim, roll = VIEWS["pd"]["side"] 89 | # Apply the rotation to each point in the sequence 90 | for i in range(seq.shape[1]): 91 | seq[:, i, :] = rotate_around_z_axis(seq[:, i, :], roll) 92 | 93 | def update(frame): 94 | ax.clear() 95 | 96 | ax.set_xlim3d([min_x, max_x]) 97 | ax.set_ylim3d([min_y, max_y]) 98 | ax.set_zlim3d([min_z, max_z]) 99 | 100 | # print(VIEWS[data_type][view_type]) 101 | # ax.view_init(*VIEWS[data_type][view_type]) 102 | elev, azim, roll = VIEWS["pd"]["best"] 103 | ax.view_init(elev=elev, azim=azim) 104 | ax.set_box_aspect(aspect_ratio) 105 | ax.set_title(f'Frame: {frame}') 106 | 107 | x = seq[frame, :, 0] 108 | y = seq[frame, :, 1] 109 | z = seq[frame, :, 2] 110 | 111 | for connection in H36M_CONNECTIONS_FULL: 112 | start = seq[frame, connection[0], :] 113 | end = seq[frame, connection[1], :] 114 | xs = [start[0], end[0]] 115 | ys = [start[1], end[1]] 116 | zs = [start[2], end[2]] 117 | 118 | ax.plot(xs, ys, zs) 119 | ax.scatter(x, y, z) 120 | 121 | 122 | print(f"Number of frames: {seq.shape[0]}") 123 | 124 | min_x, min_y, min_z = np.min(seq, axis=(0, 1)) 125 | max_x, max_y, max_z = np.max(seq, axis=(0, 1)) 126 | 127 | x_range = max_x - min_x 128 | y_range = max_y - min_y 129 | z_range = 
max_z - min_z 130 | aspect_ratio = [x_range, y_range, z_range] 131 | 132 | 133 | fig = plt.figure() 134 | ax = fig.add_subplot(111, projection='3d') 135 | 136 | # create the animation 137 | ani = FuncAnimation(fig, update, frames=seq.shape[0], interval=1) 138 | ani.save(f'{name}.gif', writer='pillow') 139 | 140 | plt.close(fig) 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | 146 | parser.add_argument('--backbone', type=str, default='motionbert', help='model name ( poseformer, ' 147 | 'motionbert )') 148 | parser.add_argument('--train_mode', type=str, default='classifier_only', help='train mode( end2end, classifier_only )') 149 | parser.add_argument('--dataset', type=str, default='PD', 150 | help='**currently code only works for PD') 151 | parser.add_argument('--data_path', type=str, 152 | default=path.PD_PATH_POSES) 153 | parser.add_argument('--seed', default=0, type=int, help='random seed') 154 | parser.add_argument('--tune_fresh', default=1, type=int, help='start a new tuning process or cont. on a previous study') 155 | parser.add_argument('--ntrials', default=30, type=int, help='number of hyper-param tuning trials') 156 | parser.add_argument('--last_run_foldnum', type=str, default='1') 157 | parser.add_argument('--readstudyfrom', default=5, type=int) 158 | 159 | parser.add_argument('--hypertune', default=1, type=int, help='perform hyper parameter tuning [0 or 1]') 160 | 161 | args = parser.parse_args() 162 | 163 | param = vars(args) 164 | 165 | backbone_name = param['backbone'] 166 | 167 | if backbone_name == 'poseformer': 168 | conf_path = './configs/poseformer/' 169 | elif backbone_name == 'motionbert': 170 | conf_path = './configs/motionbert/' 171 | elif backbone_name == 'poseformerv2': 172 | conf_path = "./configs/poseformerv2" 173 | elif backbone_name == 'mixste': 174 | conf_path = "./configs/mixste", 175 | elif backbone_name == 'motionagformer': 176 | conf_path = "./configs/motionagformer" 177 | else: 178 | raise NotImplementedError(f"Backbone '{backbone_name}' is not supported") 179 | 180 | for fi in sorted(os.listdir(conf_path)): 181 | if backbone_name == 'poseformer': 182 | params, new_params = generate_config_poseformer.generate_config(param, fi) 183 | elif backbone_name == 'motionbert': 184 | params, new_params = generate_config_motionbert.generate_config(param, fi) 185 | elif backbone_name == "poseformerv2": 186 | params, new_params = generate_config_poseformerv2.generate_config(param, fi) 187 | elif backbone_name == "mixste": 188 | params, new_params = generate_config_mixste.generate_config(param, fi) 189 | elif backbone_name == "motionagformer": 190 | params, new_params = generate_config_motionagformer.generate_config(param, fi) 191 | else: 192 | raise NotImplementedError(f"Backbone '{param['backbone']}' does not exist.") 193 | 194 | train_dataset_fn, test_dataset_fn, val_dataset_fn, class_weights = dataset_factory(params, backbone_name, 1) 195 | 196 | params['input_dim'] = train_dataset_fn.dataset._pose_dim 197 | params['pose_dim'] = train_dataset_fn.dataset._pose_dim 198 | params['num_joints'] = train_dataset_fn.dataset._NMAJOR_JOINTS 199 | 200 | 201 | model_backbone = load_pretrained_backbone(params, backbone_name) 202 | model_backbone = model_backbone.to(_DEVICE) 203 | for param in model_backbone.parameters(): 204 | param.requires_grad = False 205 | print("[INFO - MotionEncoder] Backbone parameters are frozen") 206 | 207 | 208 | 209 | for x, _, video_idx in train_dataset_fn: 210 | x = x.to(_DEVICE) 211 | 212 | batch_size = 
x.shape[0] 213 | 214 | pose3D = model_backbone(x, return_rep=False) 215 | pose3D = pose3D.cpu().numpy() 216 | for b in range(batch_size): 217 | visualize_sequence(pose3D[b,:,:,:], f'./data/pd/pd_reconst/video{video_idx[b].cpu().numpy()}') 218 | ppp=1 219 | 220 | 221 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/data/__init__.py -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | class MirrorReflection: 7 | """ 8 | Do horizontal flipping for each frame of the sequence. 9 | 10 | Args: 11 | format (str): Skeleton format. By default it expects it to be h36m (as motion encoders are mostly trained on that) 12 | """ 13 | 14 | def __init__(self, format='h36m', data_dim=2): 15 | if format == 'h36m': 16 | self.left = [14, 15, 16, 1, 2, 3] 17 | self.right = [11, 12, 13, 4, 5, 6] 18 | else: 19 | raise NotImplementedError("Skeleton format is not supported.") 20 | self.data_dim = data_dim 21 | 22 | def __call__(self, sample): 23 | sequence, label, labels_str = sample['encoder_inputs'], sample['label'], sample['labels_str'] 24 | if isinstance(sequence, np.ndarray): 25 | sequence = torch.from_numpy(sequence) 26 | if self.data_dim == 3: 27 | merge_last_dim = 0 28 | if sequence.ndim == 2: 29 | sequence = sequence.view(-1, 17, 3) # Reshape sequence back to N x 17 x 3 30 | merge_last_dim = 1 31 | mirrored_sequence = sequence.clone() 32 | mirrored_sequence[:, :, 0] *= -1 33 | mirrored_sequence[:, self.left + self.right, :] = mirrored_sequence[:, self.right + self.left, :] 34 | 35 | if self.data_dim == 3 and merge_last_dim: # Reshape sequence back to N x 51 36 | N = np.shape(mirrored_sequence)[0] 37 | mirrored_sequence = mirrored_sequence.reshape(N, -1) 38 | return { 39 | 'encoder_inputs': mirrored_sequence, 40 | 'label': label, 41 | 'labels_str': labels_str 42 | } 43 | 44 | 45 | class RandomRotation: 46 | """ 47 | Rotate randomly all the joints in all the frames. 48 | 49 | Args: 50 | min_rotate (int): Minimum degree of rotation angle. 51 | max_rotate (int): Maximum degree of rotation angle. 
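        data_dim (int): Dimensionality of the poses. For 2D data a single random
            in-plane rotation is applied (any confidence-score channel is left
            untouched); for 3D data the sequence is rotated about all three axes,
            drawing an angle from the full range for one randomly chosen main axis
            and from one tenth of the range for the other two.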
52 | """ 53 | 54 | def __init__(self, min_rotate, max_rotate, data_dim=2): 55 | self.min_rotate, self.max_rotate = min_rotate, max_rotate 56 | self.data_dim = data_dim 57 | 58 | def _create_3d_rotation_matrix(self, axis, rotation_angle): 59 | theta = rotation_angle * (torch.pi / 180) 60 | if axis == 0: # x-axis 61 | rotation_matrix = torch.tensor([[1, 0, 0], 62 | [0, torch.cos(theta), torch.sin(theta)], 63 | [0, -torch.sin(theta), torch.cos(theta)]]) 64 | elif axis == 1: # y-axis 65 | rotation_matrix = torch.tensor([[torch.cos(theta), 0, -torch.sin(theta)], 66 | [0, 1, 0], 67 | [torch.sin(theta), 0, torch.cos(theta)]]) 68 | elif axis == 2: # z-axis 69 | rotation_matrix = torch.tensor([[torch.cos(theta), torch.sin(theta), 0], 70 | [-torch.sin(theta), torch.cos(theta), 0], 71 | [0, 0, 1]]) 72 | return rotation_matrix 73 | 74 | def _create_rotation_matrix(self): 75 | rotation_angle = torch.FloatTensor(1).uniform_(self.min_rotate, self.max_rotate) 76 | theta = rotation_angle * (torch.pi / 180) 77 | 78 | rotation_matrix = torch.tensor([[torch.cos(theta), -torch.sin(theta)], 79 | [torch.sin(theta), torch.cos(theta)]]) 80 | return rotation_matrix 81 | 82 | def __call__(self, sample): 83 | sequence, label, labels_str = sample['encoder_inputs'], sample['label'], sample['labels_str'] 84 | if isinstance(sequence, np.ndarray): 85 | sequence = torch.from_numpy(sequence) 86 | 87 | if self.data_dim == 2: 88 | rotation_matrix = self._create_rotation_matrix() 89 | 90 | sequence_has_confidence_score = sequence.shape[-1] == 3 91 | if sequence_has_confidence_score: 92 | rotated_sequence = sequence[..., :2] @ rotation_matrix 93 | rotated_sequence = torch.cat((rotated_sequence, sequence[..., 2].unsqueeze(-1)), dim=-1) 94 | else: 95 | rotated_sequence = sequence @ rotation_matrix 96 | else: 97 | merge_last_dim = 0 98 | if sequence.ndim == 2: 99 | sequence = sequence.view(-1, 17, 3) # Reshape sequence back to N x 17 x 3 100 | merge_last_dim = 1 101 | rotated_sequence = sequence.clone() 102 | total_axis = [0, 1, 2] 103 | main_axis = random.randint(0, 2) 104 | for axis in total_axis: 105 | if axis == main_axis: 106 | rotation_angle = torch.FloatTensor(1).uniform_(self.min_rotate, self.max_rotate) 107 | rotation_matrix = self._create_3d_rotation_matrix(axis, rotation_angle) 108 | else: 109 | rotation_angle = torch.FloatTensor(1).uniform_(self.min_rotate/10, self.max_rotate/10) 110 | rotation_matrix = self._create_3d_rotation_matrix(axis, rotation_angle) 111 | rotated_sequence = rotated_sequence @ rotation_matrix 112 | if merge_last_dim: # Reshape sequence back to N x 51 113 | N = np.shape(rotated_sequence)[0] 114 | rotated_sequence = rotated_sequence.reshape(N, -1) 115 | 116 | return { 117 | 'encoder_inputs': rotated_sequence, 118 | 'label': label, 119 | 'labels_str': labels_str 120 | } 121 | 122 | 123 | class RandomNoise: 124 | """ 125 | Adds noise randomly to each join separately from normal distribution. 
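    Args:
        mean (float): Mean of the Gaussian noise (default 0).
        std (float): Standard deviation of the Gaussian noise (default 0.01).
        data_dim (int): Kept only for interface consistency with the other
            transforms; the noise is drawn independently for every element of
            the sequence regardless of dimensionality.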
126 | """ 127 | def __init__(self, mean=0, std=0.01, data_dim=2): 128 | self.mean = mean 129 | self.std = std 130 | self.data_dim = data_dim 131 | 132 | def __call__(self, sample): 133 | sequence, label, labels_str = sample['encoder_inputs'], sample['label'], sample['labels_str'] 134 | if isinstance(sequence, np.ndarray): 135 | sequence = torch.from_numpy(sequence) 136 | noise = torch.normal(self.mean, self.std, size=sequence.shape) 137 | noise_sequence = sequence + noise 138 | 139 | return { 140 | 'encoder_inputs': noise_sequence, 141 | 'label': label, 142 | 'labels_str': labels_str 143 | } 144 | 145 | 146 | class axis_mask: 147 | def __init__(self, data_dim=3): 148 | self.data_dim = data_dim 149 | 150 | def Zero_out_axis(self, sequence): 151 | axis_next = random.randint(0, self.data_dim-1) 152 | temp = sequence.clone() 153 | T, J, C = sequence.shape 154 | x_new = torch.zeros(T, J, device=temp.device) 155 | temp[:, :, axis_next] = x_new 156 | return temp 157 | 158 | def __call__(self, sample): 159 | 160 | sequence, label, labels_str = sample['encoder_inputs'], sample['label'], sample['labels_str'] 161 | if isinstance(sequence, np.ndarray): 162 | sequence = torch.from_numpy(sequence) 163 | 164 | if self.data_dim > 2 : 165 | if self.data_dim == 3: 166 | merge_last_dim = 0 167 | if sequence.ndim == 2: 168 | sequence = sequence.view(-1, 17, 3) # Reshape sequence back to N x 17 x 3 169 | merge_last_dim = 1 170 | masked_sequence = self.Zero_out_axis(sequence) 171 | 172 | if self.data_dim == 3 and merge_last_dim: # Reshape sequence back to N x 51 173 | N = np.shape(masked_sequence)[0] 174 | masked_sequence = masked_sequence.reshape(N, -1) 175 | 176 | return { 177 | 'encoder_inputs': masked_sequence, 178 | 'label': label, 179 | 'labels_str': labels_str 180 | } 181 | else: 182 | return { 183 | 'encoder_inputs': sequence, 184 | 'label': label, 185 | 'labels_str': labels_str 186 | } 187 | -------------------------------------------------------------------------------- /data/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from data.visualize import visualize_sequence 5 | 6 | 7 | class PoseSequenceAugmentation: 8 | def __init__(self, params): 9 | self.augmentation_methods = { 10 | "mirror_reflection": self.mirror_reflection, 11 | "joint_dropout": self.joint_dropout, 12 | "random_rotation": self.random_rotation, 13 | "random_translation": self.random_translation 14 | } 15 | self.params = params 16 | 17 | def augment_data(self, raw_data, augmentation_list, visualize_only=False): 18 | if "random_translation" in augmentation_list: 19 | self.estimate_translation_range(raw_data.pose_dict) 20 | 21 | augmented_data_dict = {"pose_dict": {}, "labels_dict": {}} 22 | 23 | augmented_video_names = [] 24 | for video_name, pose_sequence in raw_data.pose_dict.items(): 25 | augmented_sequences = {} 26 | 27 | for augmentation_name in augmentation_list: 28 | if augmentation_name in self.augmentation_methods: 29 | augmented_sequence = self.augmentation_methods[augmentation_name](pose_sequence) 30 | augmented_sequences[augmentation_name] = augmented_sequence 31 | if visualize_only: 32 | visualize_sequence(pose_sequence, augmented_sequence, video_name + '_org') 33 | else: 34 | print(f"Warning: Unknown augmentation technique '{augmentation_name}'") 35 | 36 | if visualize_only: 37 | exit() 38 | 39 | for augmentation_name, augmented_sequence in augmented_sequences.items(): 40 | augmented_video_name = 
f"{video_name}_{augmentation_name}" 41 | augmented_video_names.append(augmentation_name) 42 | 43 | augmented_data_dict["pose_dict"][augmented_video_name] = augmented_sequence 44 | augmented_data_dict["labels_dict"][augmented_video_name] = raw_data.labels_dict[ 45 | video_name] 46 | 47 | return self.update_datareader(raw_data, augmented_data_dict, augmented_video_names) 48 | 49 | @staticmethod 50 | def update_datareader(raw_data, augmented_data_dict, augmented_video_names): 51 | raw_data_augmented = copy.deepcopy(raw_data) 52 | raw_data_augmented.labels = raw_data_augmented.labels + list( 53 | augmented_data_dict['labels_dict'].values()) # TODO: Should we also remove this? 54 | raw_data_augmented.video_names = raw_data_augmented.video_names + augmented_video_names 55 | raw_data_augmented.labels_dict.update(augmented_data_dict['labels_dict']) 56 | raw_data_augmented.pose_dict.update(augmented_data_dict['pose_dict']) 57 | return raw_data_augmented 58 | 59 | @staticmethod 60 | def mirror_reflection(pose_sequence): 61 | mirrored_sequence = pose_sequence.copy() 62 | left = [4, 5, 6, 10, 11, 12] 63 | right = [7, 8, 9, 13, 14, 15] 64 | mirrored_sequence[:, :, 0] *= -1 65 | mirrored_sequence[:, left + right, :] = mirrored_sequence[:, right + left, :] 66 | return mirrored_sequence 67 | 68 | @staticmethod 69 | def joint_dropout(pose_sequence, dropout_prob): 70 | # Randomly remove certain joints from the pose sequence 71 | dropout_mask = np.random.choice([0, 1], size=pose_sequence.shape[1], p=[dropout_prob, 1 - dropout_prob]) 72 | dropped_sequence = pose_sequence * dropout_mask 73 | return dropped_sequence 74 | 75 | def random_rotation(self, pose_sequence): 76 | # Randomly rotate the pose sequence 77 | rotation_angles = np.random.uniform(self.params['rotation_range'][0], self.params['rotation_range'][1], size=3) 78 | rotation_matrix = self.rotation_matrix(rotation_angles) 79 | rotated_sequence = np.matmul(pose_sequence, rotation_matrix) 80 | return rotated_sequence 81 | 82 | def random_translation(self, pose_sequence): 83 | noise_scale = 0 84 | # # Randomly translate the pose sequence with added noise 85 | translation = np.random.uniform(self.translation_range[0], self.translation_range[1], size=3) 86 | noise = np.random.normal(scale=noise_scale, size=pose_sequence.shape) 87 | translated_sequence = pose_sequence + translation + noise 88 | return translated_sequence 89 | 90 | def estimate_translation_range(self, pose_dict): 91 | min_values = np.min([np.min(pose) for pose in pose_dict.values()]) 92 | max_values = np.max([np.max(pose) for pose in pose_dict.values()]) 93 | overall_range = max_values - min_values 94 | self.translation_range = (-self.params['translation_frac'] * overall_range, 95 | self.params['translation_frac'] * overall_range) # Adjust the fraction (0.1) 96 | 97 | def estimate_noise_scale(pose_dict): 98 | min_values = np.min([np.min(pose) for pose in pose_dict.values()]) 99 | max_values = np.max([np.max(pose) for pose in pose_dict.values()]) 100 | overall_range = max_values - min_values 101 | noise_scale = 0.1 * overall_range # Adjust the fraction (0.1) as desired 102 | return noise_scale 103 | 104 | @staticmethod 105 | def rotation_matrix(angles): 106 | radians = angles * (np.pi / 180) 107 | # Helper function to generate a rotation matrix 108 | alpha, beta, gamma = radians 109 | Rx = np.array([[1, 0, 0], 110 | [0, np.cos(alpha), -np.sin(alpha)], 111 | [0, np.sin(alpha), np.cos(alpha)]]) 112 | Ry = np.array([[np.cos(beta), 0, np.sin(beta)], 113 | [0, 1, 0], 114 | [-np.sin(beta), 0, 
np.cos(beta)]]) 115 | Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0], 116 | [np.sin(gamma), np.cos(gamma), 0], 117 | [0, 0, 1]]) 118 | rotation_matrix = np.matmul(Rz, np.matmul(Ry, Rx)) 119 | return rotation_matrix 120 | -------------------------------------------------------------------------------- /data/pd/const_pd.py: -------------------------------------------------------------------------------- 1 | H36M_FULL = { 2 | 'B.TORSO': 0, 3 | 'L.HIP': 1, 4 | 'L.KNEE': 2, 5 | 'L.FOOT': 3, 6 | 'R.HIP': 4, 7 | 'R.KNEE': 5, 8 | 'R.FOOT': 6, 9 | 'C.TORSO': 7, 10 | 'U.TORSO': 8, 11 | 'NECK': 9, 12 | 'HEAD': 10, 13 | 'R.SHOULDER': 11, 14 | 'R.ELBOW': 12, 15 | 'R.HAND': 13, 16 | 'L.SHOULDER': 14, 17 | 'L.ELBOW': 15, 18 | 'L.HAND': 16 19 | } 20 | 21 | PD = { 22 | 'CLAV': 0, 23 | 'STRN': 1, 24 | 'C7': 2, 25 | 'T10': 3, 26 | 'R.SHO': 4, 27 | 'L.SHO': 5, 28 | 'R.UPA': 6, 29 | 'R.EL': 7, 30 | 'R.EM': 8, 31 | 'R.FRA': 9, 32 | 'R.WL': 10, 33 | 'R.WM': 11, 34 | 'L.UPA': 12, 35 | 'L.EL': 13, 36 | 'L.EM': 14, 37 | 'L.FRA': 15, 38 | 'L.WL': 16, 39 | 'L.WM': 17, 40 | 'R.ASIS': 18, 41 | 'L.ASIS': 19, 42 | 'R.PSIS': 20, 43 | 'L.PSIS': 21, 44 | 'R.GTR': 22, 45 | 'R.KNEE': 23, 46 | 'R.HF': 24, 47 | 'R.TT': 25, 48 | 'R.ANKLE': 26, 49 | 'R.HEEL': 27, 50 | 'R.MT1': 28, 51 | 'R.MT5': 29, 52 | 'L.GTR': 30, 53 | 'L.KNEE': 31, 54 | 'L.HF': 32, 55 | 'L.TT': 33, 56 | 'L.ANKLE': 34, 57 | 'L.HEEL': 35, 58 | 'L.MT1': 36, 59 | 'L.MT5': 37, 60 | 'R.KNEE.MEDIAL': 38, 61 | 'R.ANKLE.MEDIAL': 39, 62 | 'R.MT2': 40, 63 | 'L.KNEE.MEDIAL': 41, 64 | 'L.ANKLE.MEDIAL': 42, 65 | 'L.MT2': 43 66 | } 67 | -------------------------------------------------------------------------------- /data/pd/preprocess_pd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pandas as pd 4 | import numpy as np 5 | import c3d 6 | import csv 7 | 8 | from const_pd import H36M_FULL, PD 9 | 10 | from matplotlib import pyplot as plt 11 | import matplotlib 12 | from matplotlib.animation import FuncAnimation 13 | 14 | 15 | def convert_pd_h36m(sequence): 16 | new_keyponts = np.zeros((sequence.shape[0], 17, 3)) 17 | new_keyponts[..., H36M_FULL['B.TORSO'], :] = (sequence[..., PD['L.ASIS'], :] + 18 | sequence[..., PD['R.ASIS'], :] + 19 | sequence[..., PD['L.PSIS'], :] + 20 | sequence[..., PD['R.PSIS'], :]) / 4 21 | new_keyponts[..., H36M_FULL['L.HIP'], :] = (sequence[..., PD['L.ASIS'], :] + 22 | sequence[..., PD['L.PSIS'], :]) / 2 23 | new_keyponts[..., H36M_FULL['L.KNEE'], :] = sequence[..., PD['L.KNEE'], :] 24 | new_keyponts[..., H36M_FULL['L.FOOT'], :] = sequence[..., PD['L.ANKLE'], :] 25 | new_keyponts[..., H36M_FULL['R.HIP'], :] = (sequence[..., PD['R.ASIS'], :] + 26 | sequence[..., PD['R.PSIS'], :]) / 2 27 | new_keyponts[..., H36M_FULL['R.KNEE'], :] = sequence[..., PD['R.KNEE'], :] 28 | new_keyponts[..., H36M_FULL['R.FOOT'], :] = sequence[..., PD['R.ANKLE'], :] 29 | new_keyponts[..., H36M_FULL['U.TORSO'], :] = (sequence[..., PD['C7'], :] + 30 | sequence[..., PD['CLAV'], :]) / 2 31 | new_keyponts[..., H36M_FULL['C.TORSO'], :] = (sequence[..., PD['STRN'], :] + 32 | sequence[..., PD['T10'], :]) / 2 33 | new_keyponts[..., H36M_FULL['R.SHOULDER'], :] = sequence[..., PD['R.SHO'], :] 34 | new_keyponts[..., H36M_FULL['R.ELBOW'], :] = (sequence[..., PD['R.EL'], :] + 35 | sequence[..., PD['R.EM'], :]) / 2 36 | new_keyponts[..., H36M_FULL['R.HAND'], :] = (sequence[..., PD['R.WL'], :] + 37 | sequence[..., PD['R.WM'], :]) / 2 38 | new_keyponts[..., H36M_FULL['L.SHOULDER'], :] = sequence[..., 
PD['L.SHO'], :] 39 | new_keyponts[..., H36M_FULL['L.ELBOW'], :] = (sequence[..., PD['L.EL'], :] + 40 | sequence[..., PD['L.EM'], :]) / 2 41 | new_keyponts[..., H36M_FULL['L.HAND'], :] = (sequence[..., PD['L.WL'], :] + 42 | sequence[..., PD['L.WM'], :]) / 2 43 | new_keyponts[..., H36M_FULL['NECK'], :] = new_keyponts[..., H36M_FULL['U.TORSO'], :] + [0.27, 57.48, 11.44] 44 | new_keyponts[..., H36M_FULL['HEAD'], :] = new_keyponts[..., H36M_FULL['U.TORSO'], :] + [-2.07, 165.23, 34.02] 45 | 46 | return new_keyponts 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument('--input_path', default='/mnt/Ndrive/AMBIENT/Vida/Public_datasets/pd/14896881/', type=str, help='Path to the input folder') 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | def rotate_around_z_axis(points, theta): 57 | c, s = np.cos(np.radians(theta)), np.sin(np.radians(theta)) 58 | R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) 59 | return np.dot(points, R.T) 60 | 61 | 62 | def visualize_sequence(seq, name): 63 | VIEWS = { 64 | "pd": { 65 | "best": (45, 20, 100), 66 | "best2": (0, 0, 0), 67 | "side": (90, 0, 90), 68 | }, 69 | "tmp": { 70 | "best": (45, 20, 100), 71 | "side": (90, 0, 90), 72 | } 73 | } 74 | elev, azim, roll = VIEWS["pd"]["best"] 75 | # Apply the rotation to each point in the sequence 76 | for i in range(seq.shape[1]): 77 | seq[:, i, :] = rotate_around_z_axis(seq[:, i, :], roll) 78 | 79 | def update(frame): 80 | ax.clear() 81 | 82 | ax.set_xlim3d([min_x, max_x]) 83 | ax.set_ylim3d([min_y, max_y]) 84 | ax.set_zlim3d([min_z, max_z]) 85 | 86 | # print(VIEWS[data_type][view_type]) 87 | # ax.view_init(*VIEWS[data_type][view_type]) 88 | elev, azim, roll = VIEWS["pd"]["best"] 89 | ax.view_init(elev=elev, azim=azim) 90 | ax.set_box_aspect(aspect_ratio) 91 | ax.set_title(f'Frame: {frame}') 92 | 93 | x = seq[frame, :, 0] 94 | y = seq[frame, :, 1] 95 | z = seq[frame, :, 2] 96 | 97 | # for connection in connections: 98 | # start = seq[frame, connection[0], :] 99 | # end = seq[frame, connection[1], :] 100 | # xs = [start[0], end[0]] 101 | # ys = [start[1], end[1]] 102 | # zs = [start[2], end[2]] 103 | 104 | # ax.plot(xs, ys, zs) 105 | ax.scatter(x, y, z) 106 | 107 | 108 | print(f"Number of frames: {seq.shape[0]}") 109 | 110 | min_x, min_y, min_z = np.min(seq, axis=(0, 1)) 111 | max_x, max_y, max_z = np.max(seq, axis=(0, 1)) 112 | 113 | x_range = max_x - min_x 114 | y_range = max_y - min_y 115 | z_range = max_z - min_z 116 | aspect_ratio = [x_range, y_range, z_range] 117 | 118 | 119 | fig = plt.figure() 120 | ax = fig.add_subplot(111, projection='3d') 121 | 122 | # create the animation 123 | ani = FuncAnimation(fig, update, frames=seq.shape[0], interval=1) 124 | ani.save(f'{name}.gif', writer='pillow') 125 | 126 | plt.close(fig) 127 | 128 | def read_pd(sequence_path, start_index, step): 129 | """ 130 | Read points data from a .c3d file and create a sequence of selected frames. 131 | 132 | Parameters: 133 | sequence_path (str): The file path for the .c3d file. 134 | start_index (int): The frame index at which to start reading the data. 135 | step (int): The number of frames to skip between reads. A step of n reads every nth frame. 136 | 137 | Returns: 138 | numpy.ndarray: An array containing the processed sequence of points data from the .c3d file. 
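    Notes:
        Frames in which any of the first 44 markers is all zeros are skipped as
        corrupted. If every frame is skipped, the sequence path is logged to
        ./data/pd/Removed_sequences.csv and an empty list is returned instead of
        an array; otherwise the marker set is mapped to the 17-joint H3.6M layout
        via convert_pd_h36m before being returned.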
139 | 140 | """ 141 | reader = c3d.Reader(open(sequence_path, 'rb')) 142 | sequence = [] 143 | for i, points, analog in reader.read_frames(): 144 | if i >= start_index and (i - start_index) % step == 0: 145 | if np.any(np.all(points[:44, :3] == 0, axis=1)): #Removed frames with corrupted joints 146 | continue 147 | sequence.append(points[None, :44, :3]) 148 | if len(sequence) == 0: 149 | print(sequence_path) 150 | with open('./data/pd/Removed_sequences.csv', 'a', newline='') as file: 151 | writer = csv.writer(file) 152 | writer.writerow([sequence_path]) 153 | return sequence 154 | # sequence2 = [] 155 | # for i, points, analog in reader.read_frames(): 156 | # if i >= start_index and (i - start_index) % step == 0: 157 | # sequence2.append(points[None, :44, :3]) 158 | # sequence2 = np.concatenate(sequence2) 159 | # visualize_sequence(sequence2, './data/pd/orig_allremoved') 160 | sequence = np.concatenate(sequence) 161 | 162 | sequence = convert_pd_h36m(sequence) 163 | # visualize_sequence(sequence, './data/pd/orig_all') 164 | return sequence 165 | 166 | def main(): 167 | args = parse_args() 168 | 169 | input_path_c3dfiles = os.path.join(args.input_path, 'C3Dfiles') 170 | output_path_c3dfiles = os.path.join(args.input_path, 'C3Dfiles_processed_new') 171 | 172 | if not os.path.exists(input_path_c3dfiles): 173 | raise FileNotFoundError(f"Input folder '{input_path_c3dfiles}' not found.") 174 | 175 | os.makedirs(output_path_c3dfiles, exist_ok=True) 176 | for root, dirs, files in os.walk(input_path_c3dfiles): 177 | for file in files: 178 | if file.endswith('.c3d') and "walk" in file and file.startswith("SUB"): 179 | sequence_path = os.path.join(root, file) 180 | try: 181 | for start_index in range(3): 182 | sequence = read_pd(sequence_path, start_index, 3) 183 | if len(sequence) == 0: 184 | continue 185 | output_sequence_path = os.path.join(output_path_c3dfiles, f"{file[:-4]}_{start_index}") 186 | print(output_sequence_path) 187 | np.save(output_sequence_path + '.npy', sequence) 188 | except Exception as e: 189 | print(f"Error reading {sequence_path}: {str(e)}") 190 | 191 | if __name__ == "__main__": 192 | main() -------------------------------------------------------------------------------- /data/public_pd_datareader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from tqdm import tqdm 5 | from datetime import * 6 | 7 | class PDReader(): 8 | """ 9 | Reads the data from the Parkinson's Disease dataset 10 | """ 11 | 12 | ON_LABEL_COLUMN = 'ON - UPDRS-III - walking' 13 | OFF_LABEL_COLUMN = 'OFF - UPDRS-III - walking' 14 | DELIMITER = ';' 15 | 16 | def __init__(self, joints_path, labels_path): 17 | self.joints_path = joints_path 18 | self.labels_path = labels_path 19 | self.pose_dict, self.labels_dict, self.video_names, self.participant_ID, self.metadata_dict = self.read_keypoints_and_labels() 20 | 21 | def read_sequence(self, path_file): 22 | """ 23 | Reads skeletons from npy files 24 | """ 25 | if os.path.exists(path_file): 26 | body = np.load(path_file) 27 | body = body/1000 #convert mm to m 28 | else: 29 | body = None 30 | return body 31 | 32 | def read_label(self, file_name): 33 | subject_id, on_or_off = file_name.split("_")[:2] 34 | df = pd.read_excel(self.labels_path) 35 | # df = pd.read_csv(self.labels_path, delimiter=self.DELIMITER) 36 | df = df[['ID', self.ON_LABEL_COLUMN, self.OFF_LABEL_COLUMN]] 37 | subject_rows = df[df['ID'] == subject_id] 38 | if on_or_off == "on": 39 | label = 
subject_rows[self.ON_LABEL_COLUMN].values[0] 40 | else: 41 | label = subject_rows[self.OFF_LABEL_COLUMN].values[0] 42 | return int(label) 43 | 44 | def read_metadata(self, file_name): 45 | #If you change this function make sure to adjust the METADATA_MAP in the dataloaders.py accordingly 46 | subject_id = file_name.split("_")[0] 47 | df = pd.read_excel(self.labels_path) 48 | df = df[['ID', 'Gender', 'Age', 'Height (cm)', 'Weight (kg)', 'BMI (kg/m2)']] 49 | df.rename(columns={ 50 | "Gender": "gender", 51 | "Age": "age", 52 | "Height (cm)": "height", 53 | "Weight (kg)": "weight", 54 | "BMI (kg/m2)": "bmi"}, inplace=True) 55 | df.loc[:, 'gender'] = df['gender'].map({'M': 0, 'F': 1}) 56 | 57 | # Using Min-Max normalization 58 | df['age'] = (df['age'] - df['age'].min()) / (df['age'].max() - df['age'].min()) 59 | df['height'] = (df['height'] - df['height'].min()) / (df['height'].max() - df['height'].min()) 60 | df['weight'] = (df['weight'] - df['weight'].min()) / (df['weight'].max() - df['weight'].min()) 61 | df['bmi'] = (df['bmi'] - df['bmi'].min()) / (df['bmi'].max() - df['bmi'].min()) 62 | 63 | subject_rows = df[df['ID'] == subject_id] 64 | return subject_rows.values[:, 1:] 65 | 66 | def read_keypoints_and_labels(self): 67 | """ 68 | Read npy files in given directory into arrays of pose keypoints. 69 | :return: dictionary with 70 | """ 71 | pose_dict = {} 72 | labels_dict = {} 73 | metadata_dict = {} 74 | video_names_list = [] 75 | participant_ID = [] 76 | 77 | print('[INFO - PublicPDReader] Reading body keypoints from npy') 78 | 79 | print(self.joints_path) 80 | 81 | for file_name in tqdm(os.listdir(self.joints_path)): 82 | path_file = os.path.join(self.joints_path, file_name) 83 | joints = self.read_sequence(path_file) 84 | label = self.read_label(file_name) 85 | metadata = self.read_metadata(file_name) 86 | if joints is None: 87 | print(f"[WARN - PublicPDReader] Numpy file {file_name} does not exist") 88 | continue 89 | file_name = file_name.split(".")[0] 90 | pose_dict[file_name] = joints 91 | labels_dict[file_name] = label 92 | metadata_dict[file_name] = metadata 93 | video_names_list.append(file_name) 94 | participant_ID.append(file_name.split("_")[0]) 95 | 96 | participant_ID = self.select_unique_entries(participant_ID) 97 | 98 | return pose_dict, labels_dict, video_names_list, participant_ID, metadata_dict 99 | 100 | @staticmethod 101 | def select_unique_entries(a_list): 102 | return sorted(list(set(a_list))) 103 | 104 | 105 | def __len__(self): 106 | return len(self.labels) 107 | 108 | def __getitem__(self, idx): 109 | """Get item for the training mode.""" 110 | 111 | # Based on index, get the video name 112 | video_name = self.video_names[idx] 113 | 114 | x = self.poses[video_name] 115 | label = self.labels[video_name] 116 | 117 | 118 | x = np.array(x, dtype=np.float32) 119 | 120 | sample = { 121 | 'encoder_inputs': x, 122 | 'label': label, 123 | 124 | } 125 | #if self.transform: 126 | # sample = self.transform(sample) 127 | 128 | return sample 129 | 130 | # raw_data = PDReader('/data/iballester/datasets/Public_PD/C3Dfiles_processed/', '/data/iballester/datasets/Public_PD/PDGinfo.csv') -------------------------------------------------------------------------------- /data/utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def walkid_to_AMBID(cur_walk_id): 4 | # Extract AMBID from the walk 5 | raw_id = cur_walk_id 6 | if raw_id >= 60: 7 | id = raw_id - 3 8 | else: 9 | id = raw_id - 2 10 | return id 11 | 12 | 13 | 
def get_AMBID_from_Videoname(path_file): 14 | # The pattern is YYYY_MM_DD_hh_mm_ss_ID_XX_state_X.csv 15 | AMBID = walkid_to_AMBID(int(path_file[24:26])) 16 | AMBID = 'AMB' + str(AMBID).zfill(2) 17 | 18 | return AMBID 19 | 20 | 21 | def extract_unique_subs(dataset): 22 | if dataset is None: 23 | return [] 24 | unique_subs = set() 25 | for name in dataset.video_names: 26 | sub = name.split('_')[0] # Assuming SUBXX is always the first part of the video name 27 | unique_subs.add(sub) 28 | return list(unique_subs) 29 | 30 | def count_labels(dataset, all_labels): 31 | label_counts = {lbl: 0 for lbl in all_labels} # Initialize all labels with count 0 32 | if dataset is not None: 33 | labels, counts = np.unique(dataset.labels, return_counts=True) 34 | label_counts.update(dict(zip(labels, counts))) 35 | return label_counts -------------------------------------------------------------------------------- /eval_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | 7 | import wandb 8 | 9 | from configs import generate_config_poseformer, generate_config_motionbert, generate_config_poseformerv2, generate_config_mixste, generate_config_motionagformer 10 | 11 | from data.dataloaders import * 12 | from const import path 13 | from utility.utils import set_random_seed 14 | from test import * 15 | 16 | this_path = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.insert(0, this_path + "/../") 18 | 19 | 20 | _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | def log_results(rep, confusion, rep_name, conf_name, out_p): 23 | print(rep) 24 | fig, ax = plt.subplots(figsize=(10, 8)) 25 | sns.heatmap(confusion, annot=True, ax=ax, cmap="Blues", fmt='g', annot_kws={"size": 26}) 26 | ax.set_xlabel('Predicted labels', fontsize=28) 27 | ax.set_ylabel('True labels', fontsize=28) 28 | ax.set_title('Confusion Matrix', fontsize=30) 29 | ax.xaxis.set_ticklabels(['class 0', 'class 1', 'class 2'], fontsize=22) # Modify class names as needed 30 | ax.yaxis.set_ticklabels(['class 0', 'class 1', 'class 2'], fontsize=22) 31 | # Save the figure 32 | plt.savefig(os.path.join(out_p, conf_name)) 33 | plt.close(fig) 34 | with open(os.path.join(out_p, rep_name), "w") as text_file: 35 | text_file.write(rep) 36 | 37 | artifact = wandb.Artifact(f'confusion_matrices', type='image-results') 38 | artifact.add_file(os.path.join(out_p, conf_name)) 39 | wandb.log_artifact(artifact) 40 | 41 | artifact = wandb.Artifact('reports', type='txtfile-results') 42 | artifact.add_file(os.path.join(out_p, rep_name)) 43 | wandb.log_artifact(artifact) 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument('--backbone', type=str, default='motionbert', help='model name ( poseformer, ''motionbert )') 50 | parser.add_argument('--train_mode', type=str, default='classifier_only', help='train mode( end2end, classifier_only )') 51 | parser.add_argument('--dataset', type=str, default='PD',help='**currently code only works for PD') 52 | parser.add_argument('--data_path', type=str,default=path.PD_PATH_POSES) 53 | parser.add_argument('--seed', default=0, type=int, help='random seed') 54 | parser.add_argument('--tune_fresh', default=1, type=int, help='start a new tuning process or cont. 
on a previous study') 55 | parser.add_argument('--last_run_foldnum', default='7', type=str) 56 | parser.add_argument('--readstudyfrom', default=1, type=int) 57 | 58 | parser.add_argument('--medication', default=0, type=int, help='add medication prob to the training [0 or 1]') 59 | parser.add_argument('--metadata', default='', type=str, help="add metadata prob to the training 'gender,age,bmi,height,weight'") 60 | 61 | args = parser.parse_args() 62 | 63 | param = vars(args) 64 | 65 | param['metadata'] = param['metadata'].split(',') if param['metadata'] else [] 66 | 67 | torch.backends.cudnn.benchmark = False 68 | 69 | backbone_name = param['backbone'] 70 | 71 | # TODO: Make it scalable 72 | if backbone_name == 'poseformer': 73 | conf_path = './configs/poseformer/' 74 | elif backbone_name == 'motionbert': 75 | conf_path = './configs/motionbert/' 76 | elif backbone_name == 'poseformerv2': 77 | conf_path = './configs/poseformerv2' 78 | elif backbone_name == 'mixste': 79 | conf_path = './configs/mixste' 80 | elif backbone_name == 'motionagformer': 81 | conf_path = './configs/motionagformer' 82 | else: 83 | raise NotImplementedError(f"Backbone '{backbone_name}' is not supported") 84 | 85 | for fi in sorted(os.listdir(conf_path)): 86 | 87 | if backbone_name == 'poseformer': 88 | params, new_params = generate_config_poseformer.generate_config(param, fi) 89 | elif backbone_name == 'motionbert': 90 | params, new_params = generate_config_motionbert.generate_config(param, fi) 91 | elif backbone_name == 'poseformerv2': 92 | params, new_params = generate_config_poseformerv2.generate_config(param, fi) 93 | elif backbone_name == 'mixste': 94 | params, new_params = generate_config_mixste.generate_config(param, fi) 95 | elif backbone_name == 'motionagformer': 96 | params, new_params = generate_config_motionagformer.generate_config(param, fi) 97 | else: 98 | raise NotImplementedError(f"Backbone '{param['backbone']}' does not exist.") 99 | 100 | if param['dataset'] == 'PD': 101 | num_folds = 23 102 | params['num_classes'] = 3 103 | else: 104 | raise NotImplementedError(f"dataset '{param['dataset']}' is not supported.") 105 | 106 | all_folds = range(1, num_folds + 1) 107 | set_random_seed(param['seed']) 108 | 109 | test_and_report(params, new_params, all_folds, backbone_name, _DEVICE) 110 | -------------------------------------------------------------------------------- /learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/learning/__init__.py -------------------------------------------------------------------------------- /learning/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from const import const 5 | 6 | 7 | def choose_criterion(key, params, class_weights): 8 | if key == 'CrossEntropyLoss': 9 | return nn.CrossEntropyLoss() 10 | elif key == 'WCELoss': 11 | class_weights_tensor = torch.FloatTensor(class_weights).to(const._DEVICE) 12 | return nn.CrossEntropyLoss(weight=class_weights_tensor) 13 | elif key == 'WCELoss+smoothing': 14 | #TODO 15 | weights = class_weights #torch.tensor([88., 131., 180.]) 16 | weights = weights / weights.sum() # turn into percentage 17 | weights = 1.0 / weights # inverse 18 | weights = weights / weights.sum() 19 | loss_weights = weights.to(const._DEVICE) 20 | print('Using a weighted *Smoothing CE loss* for gait 
impairment score prediction.') 21 | return WeightedCrossEntropyLossWithLabelSmoothing(weight=loss_weights, smoothing=params['smoothing_scale']) 22 | else: 23 | raise ModuleNotFoundError("Criterion does not exist") 24 | 25 | 26 | # Define cross-entropy loss with label smoothing 27 | class CrossEntropyLossWithLabelSmoothing(nn.Module): 28 | def __init__(self, smoothing=0.1): 29 | super(CrossEntropyLossWithLabelSmoothing, self).__init__() 30 | self.smoothing = smoothing 31 | def forward(self, inputs, targets): 32 | num_classes = inputs.size()[-1] 33 | log_preds = nn.functional.log_softmax(inputs, dim=-1) 34 | targets = torch.zeros_like(log_preds).scatter_(-1, targets.unsqueeze(-1), 1) 35 | targets = (1 - self.smoothing) * targets + self.smoothing / num_classes 36 | loss = nn.functional.kl_div(log_preds, targets, reduction='batchmean') 37 | return loss 38 | 39 | 40 | class WeightedCrossEntropyLossWithLabelSmoothing(nn.Module): 41 | def __init__(self, weight, smoothing=0.1): 42 | super(WeightedCrossEntropyLossWithLabelSmoothing, self).__init__() 43 | self.smoothing = smoothing 44 | self.weight = weight 45 | 46 | def forward(self, inputs, targets): 47 | num_classes = inputs.size()[-1] 48 | log_preds = nn.functional.log_softmax(inputs, dim=-1) 49 | targets = torch.zeros_like(log_preds).scatter_(-1, targets.unsqueeze(-1), 1) 50 | targets = (1 - self.smoothing) * targets + self.smoothing / num_classes 51 | loss = nn.functional.kl_div(log_preds, targets, reduction='none') 52 | loss = loss * self.weight.unsqueeze(0) 53 | loss = loss.sum(dim=-1).mean() 54 | return loss -------------------------------------------------------------------------------- /learning/criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/learning/criterions/__init__.py -------------------------------------------------------------------------------- /learning/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.optim.lr_scheduler import StepLR 3 | 4 | 5 | def choose_scheduler(optimizer, params): 6 | scheduler_name = params.get('scheduler') 7 | if scheduler_name is None: 8 | print("[WARN] LR Scheduler is not used") 9 | return None 10 | 11 | if scheduler_name == "StepLR": 12 | scheduler = StepLR(optimizer, step_size=params['lr_step_size'], gamma=params['lr_decay']) 13 | else: 14 | raise ModuleNotFoundError("Scheduler is not defined") 15 | 16 | return scheduler 17 | 18 | 19 | def choose_optimizer(model, params): 20 | optimizer_name = params['optimizer'] 21 | try: 22 | backbone_params = set(model.module.backbone.parameters()) 23 | head_params = set(model.module.head.parameters()) 24 | except AttributeError: 25 | backbone_params = set(model.backbone.parameters()) 26 | head_params = set(model.head.parameters()) 27 | 28 | all_params = set(model.parameters()) 29 | other_params = all_params - backbone_params - head_params 30 | 31 | param_groups = [ 32 | {"params": filter(lambda p: p.requires_grad, backbone_params), "lr": params['lr_backbone']}, 33 | {"params": filter(lambda p: p.requires_grad, head_params), "lr": params['lr_head']}, 34 | {"params": filter(lambda p: p.requires_grad, other_params), "lr": params['lr_head']} 35 | ] 36 | 37 | if optimizer_name == "AdamW": 38 | optimizer = optim.AdamW(param_groups, weight_decay=params['weight_decay']) 39 | elif optimizer_name == "RMSprop": 40 | 
optimizer = optim.RMSprop(param_groups, weight_decay=params['weight_decay']) 41 | elif optimizer_name == "SGD": 42 | optimizer = optim.SGD(param_groups, momentum=params.get('momentum', 0.9)) 43 | else: 44 | raise ModuleNotFoundError("Optimizer not found") 45 | 46 | return optimizer 47 | -------------------------------------------------------------------------------- /learning/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | 6 | from collections import Counter 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | 12 | def __init__(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | with torch.no_grad(): 28 | maxk = max(topk) 29 | batch_size = target.size(0) 30 | _, pred = output.topk(maxk, 1, True, True) 31 | pred = pred.t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 36 | res.append(correct_k.mul_(100.0 / batch_size).item()) 37 | return res 38 | 39 | 40 | def upload_checkpoints_to_wandb(latest_epoch_path, best_epoch_path): 41 | artifact = wandb.Artifact(f'model', type='model') 42 | artifact.add_file(latest_epoch_path) 43 | artifact.add_file(best_epoch_path) 44 | wandb.log_artifact(artifact) 45 | 46 | 47 | def save_checkpoint(checkpoint_root_path, epoch, lr, optimizer, model, best_accuracy, fold, latest): 48 | checkpoint_path_fold = os.path.join(checkpoint_root_path, f"fold{fold}") 49 | if not os.path.exists(checkpoint_path_fold): 50 | os.makedirs(checkpoint_path_fold) 51 | checkpoint_path = os.path.join(checkpoint_path_fold, 52 | 'latest_epoch.pth.tr' if latest else 'best_epoch.pth.tr') 53 | torch.save({ 54 | 'epoch': epoch + 1, 55 | 'lr': lr, 56 | 'optimizer': optimizer.state_dict(), 57 | 'model': model.state_dict(), 58 | 'best_accuracy': best_accuracy 59 | }, checkpoint_path) 60 | 61 | 62 | def assert_learning_params(params): 63 | """Makes sure the learning parameters is set as parameters (To avoid raising error during training)""" 64 | learning_params = ['batch_size', 'criterion', 'optimizer', 'lr_backbone', 'lr_head', 'weight_decay', 'epochs', 65 | 'stopping_tolerance'] 66 | for learning_param in learning_params: 67 | assert learning_param in params, f'"{learning_param}" is not set in params.' 
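# Illustrative example (hypothetical values, not taken from any config file): assert_learning_params()
# passes only when every key listed in `learning_params` above is present, e.g.
#   params = {'batch_size': 32, 'criterion': 'WCELoss', 'optimizer': 'AdamW', 'lr_backbone': 1e-4,
#             'lr_head': 1e-3, 'weight_decay': 0.01, 'epochs': 50, 'stopping_tolerance': 10}
#   assert_learning_params(params)  # raises AssertionError if any required key is missing
# In practice these values come from the JSON files under configs/.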
68 | 69 | def compute_class_weights(data_loader): 70 | class_counts = Counter() 71 | total_samples = 0 72 | num_classes = 0 73 | 74 | for _, targets, _, _ in data_loader: 75 | class_counts.update(targets.tolist()) 76 | total_samples += len(targets) 77 | 78 | class_weights = [] 79 | 80 | num_classes = len(class_counts) 81 | for i in range(num_classes): 82 | count = class_counts[i] 83 | weight = 0.0 if count == 0 else total_samples / (num_classes * count) 84 | class_weights.append(weight) 85 | 86 | total_weights = sum(class_weights) 87 | normalized_class_weights = [weight / total_weights for weight in class_weights] 88 | 89 | return normalized_class_weights 90 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/model/__init__.py -------------------------------------------------------------------------------- /model/backbone_loader.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from model.motionbert.DSTformer import DSTformer 7 | from model.poseformer import PoseTransformer 8 | from model.poseformer import PoseEncoderDecoder 9 | from model.poseformerv2.model_poseformer import PoseTransformerV2 10 | from model.mixste.model_cross import MixSTE2 11 | from model.motionagformer.MotionAGFormer import MotionAGFormer 12 | 13 | 14 | def count_parameters(model): 15 | model_params = 0 16 | for parameter in model.parameters(): 17 | model_params += parameter.numel() 18 | return model_params 19 | 20 | 21 | def load_pretrained_weights(model, checkpoint): 22 | """ 23 | Load pretrained weights to model 24 | Incompatible layers (unmatched in name or size) will be ignored 25 | Args: 26 | - model (nn.Module): network model, which must not be nn.DataParallel 27 | - checkpoint (dict): the checkpoint 28 | """ 29 | import collections 30 | if 'state_dict' in checkpoint: 31 | state_dict = checkpoint['state_dict'] 32 | else: 33 | state_dict = checkpoint 34 | model_dict = model.state_dict() 35 | model_first_key = next(iter(model_dict)) 36 | new_state_dict = collections.OrderedDict() 37 | matched_layers, discarded_layers = [], [] 38 | for k, v in state_dict.items(): 39 | # If the pretrained state_dict was saved as nn.DataParallel, 40 | # keys would contain "module.", which should be ignored. 41 | if not 'module.' 
in model_first_key: 42 | if k.startswith('module.'): 43 | k = k[7:] 44 | if k in model_dict: 45 | new_state_dict[k] = v 46 | matched_layers.append(k) 47 | else: 48 | discarded_layers.append(k) 49 | model_dict.update(new_state_dict) 50 | model.load_state_dict(model_dict, strict=True) 51 | print(f'[INFO] (load_pretrained_weights) {len(matched_layers)} layers are loaded') 52 | print(f'[INFO] (load_pretrained_weights) {len(discarded_layers)} layers are discared') 53 | if len(matched_layers) == 0: 54 | print ("--------------------------model_dict------------------") 55 | print (model_dict.keys()) 56 | print ("--------------------------discarded_layers------------------") 57 | print (discarded_layers) 58 | raise NotImplementedError(f"Loading problem!!!!!!") 59 | 60 | 61 | 62 | def load_pretrained_backbone(params, backbone_name): 63 | if backbone_name == 'motionbert': 64 | model_backbone = DSTformer(dim_in=3, 65 | dim_out=3, 66 | dim_feat=params['dim_feat'], 67 | dim_rep=params['dim_rep'], 68 | depth=params['depth'], 69 | num_heads=params['num_heads'], 70 | mlp_ratio=params['mlp_ratio'], 71 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 72 | maxlen=params['maxlen'], 73 | num_joints=params['num_joints']) 74 | checkpoint = torch.load(params['model_checkpoint_path'], map_location=lambda storage, loc: storage)['model_pos'] 75 | elif backbone_name == 'poseformer': 76 | pose_encoder_fn, _ = PoseEncoderDecoder.select_pose_encoder_decoder_fn(params) 77 | model_backbone = PoseTransformer.model_factory(params, pose_encoder_fn) 78 | checkpoint = torch.load(params['model_checkpoint_path'], map_location=lambda storage, loc: storage) 79 | elif backbone_name == "poseformerv2": 80 | model_backbone = PoseTransformerV2(num_joints=params['num_joints'], 81 | embed_dim_ratio=params['embed_dim_ratio'], 82 | depth=params['depth'], 83 | number_of_kept_frames=params['number_of_kept_frames'], 84 | number_of_kept_coeffs=params['number_of_kept_coeffs'], 85 | in_chans=2, 86 | num_heads=8, 87 | mlp_ratio=2, 88 | qkv_bias=True, 89 | qk_scale=None, 90 | drop_path_rate=0, 91 | ) 92 | checkpoint = torch.load(params['model_checkpoint_path'], map_location=lambda storage, loc: storage)['model_pos'] 93 | elif backbone_name == 'mixste': 94 | model_backbone = MixSTE2( 95 | num_frame=params['source_seq_len'], 96 | num_joints=params['num_joints'], 97 | in_chans=2, 98 | embed_dim_ratio=params['embed_dim_ratio'], 99 | depth=params['depth'], 100 | num_heads=8, 101 | mlp_ratio=2., 102 | qkv_bias=True, 103 | qk_scale=None, 104 | drop_path_rate=0 105 | ) 106 | checkpoint = torch.load(params['model_checkpoint_path'], map_location=lambda storage, loc: storage)['model_pos'] 107 | elif backbone_name == "motionagformer": 108 | model_backbone = MotionAGFormer(n_layers=params['n_layers'], 109 | dim_in=params['dim_in'], 110 | dim_feat=params['dim_feat'], 111 | dim_rep=params['dim_rep'], 112 | dim_out=params['dim_out'], 113 | mlp_ratio=params['mlp_ratio'], 114 | act_layer=nn.GELU, 115 | attn_drop=params['attn_drop'], 116 | drop=params['drop'], 117 | drop_path=params['drop_path'], 118 | use_layer_scale=params['use_layer_scale'], 119 | layer_scale_init_value=params['layer_scale_init_value'], 120 | use_adaptive_fusion=params['use_adaptive_fusion'], 121 | num_heads=params['num_heads'], 122 | qkv_bias=params['qkv_bias'], 123 | qkv_scale=params['qkv_scale'], 124 | hierarchical=params['hierarchical'], 125 | num_joints=params['num_joints'], 126 | use_temporal_similarity=params['use_temporal_similarity'], 127 | 
temporal_connection_len=params['temporal_connection_len'], 128 | use_tcn=params['use_tcn'], 129 | graph_only=params['graph_only'], 130 | neighbour_num=params['neighbour_num'], 131 | n_frames=params['source_seq_len']) 132 | checkpoint = torch.load(params['model_checkpoint_path'], map_location=lambda storage, loc: storage)['model'] 133 | else: 134 | raise Exception("Undefined backbone type.") 135 | 136 | load_pretrained_weights(model_backbone, checkpoint) 137 | return model_backbone 138 | -------------------------------------------------------------------------------- /model/mixste/rela.py: -------------------------------------------------------------------------------- 1 | # Following the implementation of RELA(https://github.com/rishikksh20/rectified-linear-attention/blob/master/attention.py) 2 | import torch 3 | import math 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | 10 | class Residual(nn.Module): 11 | def __init__(self, fn): 12 | super().__init__() 13 | self.fn = fn 14 | def forward(self, x, **kwargs): 15 | return self.fn(x, **kwargs) + x 16 | 17 | class PreNorm(nn.Module): 18 | def __init__(self, dim, fn): 19 | super().__init__() 20 | self.norm = nn.LayerNorm(dim) 21 | self.fn = fn 22 | def forward(self, x, **kwargs): 23 | return self.fn(self.norm(x), **kwargs) 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim, dropout = 0.): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.Linear(dim, hidden_dim), 30 | nn.GELU(), 31 | nn.Dropout(dropout), 32 | nn.Linear(hidden_dim, dim), 33 | nn.Dropout(dropout) 34 | ) 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class RectifiedLinearAttention(nn.Module): 39 | def __init__(self, dim, num_heads = 8, attn_drop = 0., proj_drop=0., qk_scale=None, qkv_bias=False, comb=False, vis=False): 40 | super().__init__() 41 | dim_head = dim // num_heads 42 | inner_dim = dim_head * num_heads 43 | project_out = not (num_heads == 1 and dim_head == dim) 44 | 45 | self.heads = num_heads 46 | self.scale = qk_scale or dim_head ** -0.5 47 | 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, qkv_bias) 49 | 50 | self.norm = nn.LayerNorm(inner_dim) 51 | 52 | self.to_out = nn.Sequential( 53 | nn.Linear(inner_dim, dim), 54 | nn.Dropout(proj_drop) 55 | ) if project_out else nn.Identity() 56 | 57 | def forward(self, x, vis=False): 58 | b, n, _, h = *x.shape, self.heads 59 | qkv = self.to_qkv(x).chunk(3, dim = -1) 60 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 61 | 62 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 63 | 64 | attn = F.relu(dots) 65 | 66 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 67 | out = rearrange(out, 'b h n d -> b n (h d)') 68 | 69 | out = self.to_out(self.norm(out)) 70 | return out -------------------------------------------------------------------------------- /model/motion_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ClassifierHead(nn.Module): 6 | def __init__(self, params, num_classes=3, num_joints=17): 7 | super(ClassifierHead, self).__init__() 8 | self.params = params 9 | input_dim = self._get_input_dim(num_joints) 10 | if self.params['medication']: 11 | input_dim += 1 12 | if len(self.params['metadata']) > 0: 13 | input_dim += len(self.params['metadata']) 14 | self.dims = [input_dim, 
*self.params['classifier_hidden_dims'], num_classes] 15 | 16 | self.fc_layers = self._create_fc_layers() 17 | self.batch_norms = self._create_batch_norms() 18 | self.dropout = nn.Dropout(p=self.params['classifier_dropout']) 19 | self.activation = nn.ReLU() 20 | 21 | def _create_fc_layers(self): 22 | fc_layers = nn.ModuleList() 23 | mlp_size = len(self.dims) 24 | 25 | for i in range(mlp_size - 1): 26 | fc_layer = nn.Linear(in_features=self.dims[i], 27 | out_features=self.dims[i+1]) 28 | fc_layers.append(fc_layer) 29 | 30 | return fc_layers 31 | 32 | def _create_batch_norms(self): 33 | batch_norms = nn.ModuleList() 34 | n_batchnorms = len(self.dims) - 2 35 | if n_batchnorms == 0: 36 | return batch_norms 37 | 38 | for i in range(n_batchnorms): 39 | batch_norm = nn.BatchNorm1d(self.dims[i+1], momentum=0.1) 40 | batch_norms.append(batch_norm) 41 | 42 | return batch_norms 43 | 44 | def _get_input_dim(self, num_joints): 45 | backbone = self.params['backbone'] 46 | if backbone == 'poseformer': 47 | if self.params['preclass_rem_T']: 48 | return self.params['model_dim'] 49 | else: 50 | return self.params['model_dim'] * self.params['source_seq_len'] 51 | elif backbone == "motionbert": 52 | if self.params['merge_joints']: 53 | return self.params['dim_rep'] 54 | else: 55 | return self.params['dim_rep'] * num_joints 56 | elif backbone == 'poseformerv2': 57 | return self.params['embed_dim_ratio'] * num_joints * 2 58 | elif backbone == "mixste": 59 | if self.params['merge_joints']: 60 | return self.params['embed_dim_ratio'] 61 | else: 62 | return self.params['embed_dim_ratio'] * num_joints 63 | elif backbone == "motionagformer": 64 | if self.params['merge_joints']: 65 | return self.params['dim_rep'] 66 | else: 67 | return self.params['dim_rep'] * num_joints 68 | 69 | def forward(self, feat): 70 | feat = self.dropout(feat) 71 | if self.params['backbone'] == 'motionbert': 72 | return self._forward_motionbert(feat) 73 | elif self.params['backbone'] == 'poseformer': 74 | return self._forward_poseforemer(feat) 75 | elif self.params['backbone'] == 'poseformerv2': 76 | return self._forward_poseformerv2(feat) 77 | elif self.params['backbone'] == "mixste": 78 | return self._forward_mixste(feat) 79 | elif self.params['backbone'] == "motionagformer": 80 | return self._forward_motionagformer(feat) 81 | 82 | def _forward_fc_layers(self, feat): 83 | mlp_size = len(self.dims) 84 | for i in range(mlp_size - 2): 85 | fc_layer = self.fc_layers[i] 86 | batch_norm = self.batch_norms[i] 87 | 88 | feat = self.activation(batch_norm(fc_layer(feat))) 89 | 90 | last_fc_layer = self.fc_layers[-1] 91 | feat = last_fc_layer(feat) 92 | return feat 93 | 94 | def _forward_motionagformer(self, feat): 95 | B, T, J, C = feat.shape 96 | feat = feat.permute(0, 2, 3, 1) # (B, T, J, C) -> (B, J, C, T) 97 | feat = feat.mean(dim=-1) # (B, J, C, T) -> (B, J, C) 98 | if self.params['merge_joints']: 99 | feat = feat.mean(dim=-2) # (B, J, C) -> (B, C) 100 | else: 101 | feat = feat.reshape(B, -1) # (B, J * C) 102 | feat = self._forward_fc_layers(feat) 103 | return feat 104 | 105 | def _forward_mixste(self, feat): 106 | """ 107 | x: Tensor with shape (batch_size, n_frames, n_joints, dim_representation) 108 | """ 109 | B, T, J, C = feat.shape 110 | feat = feat.permute(0, 2, 3, 1) # (B, T, J, C) -> (B, J, C, T) 111 | feat = feat.mean(dim=-1) # (B, J, C, T) -> (B, J, C) 112 | if self.params['merge_joints']: 113 | feat = feat.mean(dim=-2) # (B, J, C) -> (B, C) 114 | else: 115 | feat = feat.reshape(B, -1) # (B, J * C) 116 | feat = 
self._forward_fc_layers(feat) 117 | return feat 118 | 119 | def _forward_poseformerv2(self, feat): 120 | """ 121 | x: Tensor with shape (batch_size, 1, embed_dim_ratio * num_joints * 2) 122 | """ 123 | B, _, C = feat.shape 124 | feat = feat.reshape(B, C) # (B, 1, C) -> (B, C) 125 | feat = self._forward_fc_layers(feat) 126 | return feat 127 | 128 | def _forward_motionbert(self, feat): 129 | """ 130 | x: Tensor with shape (batch_size, n_frames, n_joints, dim_representation) 131 | """ 132 | B, T, J, C = feat.shape 133 | feat = feat.permute(0, 2, 3, 1) # (B, T, J, C) -> (B, J, C, T) 134 | feat = feat.mean(dim=-1) # (B, J, C, T) -> (B, J, C) 135 | if self.params['merge_joints']: 136 | feat = feat.mean(dim=-2) # (B, J, C) -> (B, C) 137 | else: 138 | feat = feat.reshape(B, -1) # (B, J * C) 139 | feat = self._forward_fc_layers(feat) 140 | return feat 141 | 142 | def _forward_poseforemer(self, feat): 143 | """ 144 | x: Tensor with shape (batch_size, n_frames, dim_representation) 145 | """ 146 | T, B, C = feat.shape 147 | if self.params['preclass_rem_T']: 148 | # Reshape the tensor to (B, 1, C, T) J=1 149 | feat = feat.permute(1, 2, 0).unsqueeze(1) 150 | feat = feat.mean(dim=-1) # (B, J, C, T) -> (B, J, C) 151 | else: 152 | feat = feat.permute(1, 0, 2) # (B, T, C) 153 | 154 | feat = feat.reshape(B, -1) # (B, J * C) or (B, T * C) 155 | feat = self._forward_fc_layers(feat) 156 | return feat 157 | 158 | 159 | class MotionEncoder(nn.Module): 160 | def __init__(self, backbone, params, num_classes=4, num_joints=17, train_mode='end2end'): 161 | super(MotionEncoder, self).__init__() 162 | assert train_mode in ['end2end', 'classifier_only'], "train_mode should be either end2end or classifier_only." \ 163 | f" Found {train_mode}" 164 | self.backbone = backbone 165 | if train_mode == 'classifier_only': 166 | self.freeze_backbone() 167 | self.head = ClassifierHead(params, num_classes=num_classes, num_joints=num_joints) 168 | self.num_classes = num_classes 169 | self.medprob = params['medication'] 170 | self.metadata = params['metadata'] 171 | 172 | def freeze_backbone(self): 173 | for param in self.backbone.parameters(): 174 | param.requires_grad = False 175 | print("[INFO - MotionEncoder] Backbone parameters are frozen") 176 | 177 | def forward(self, x, metadata, med=None): 178 | """ 179 | x: Tensor with shape (batch_size, n_frames, n_joints, C=3) 180 | """ 181 | feat = self.backbone(x) 182 | if self.medprob and med is not None: 183 | med = med.to(feat.device) 184 | med = med.view(*[-1] + [1] * (feat.dim() - 1)) 185 | s = list(feat.shape) 186 | s[-1] = 1 # Set the last dimension to 1 187 | med = med.expand(*s) 188 | feat = torch.cat((feat, med), dim=-1) 189 | if len(self.metadata) > 0: 190 | metadata = metadata.view(metadata.shape[0], *([1] * (feat.dim() - 2)), metadata.shape[-1]) 191 | metadata = metadata.expand(*feat.shape[:-1], metadata.shape[-1]) 192 | feat = torch.cat((feat, metadata), dim=-1) 193 | out = self.head(feat) 194 | return out 195 | 196 | 197 | def _test_classifier_head(): 198 | params = { 199 | "backbone": "motionbert", 200 | "dim_rep": 512, 201 | "classifier_hidden_dims": [], 202 | 'classifier_dropout': 0.5, 'medication': 0, 'metadata': [], 'merge_joints': False # keys ClassifierHead also reads for the motionbert backbone 203 | } 204 | head = ClassifierHead(params, num_classes=3, num_joints=17) 205 | 206 | B, T, J, C = 4, 243, 17, 512 207 | feat = torch.randn(B, T, J, C) 208 | out = head(feat) 209 | assert out.shape == (4, 3) 210 | 211 | if __name__ == "__main__": 212 | _test_classifier_head() -------------------------------------------------------------------------------- 
/model/motionagformer/MotionAGFormer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from timm.models.layers import DropPath 6 | 7 | import sys 8 | import os 9 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd()))) 10 | 11 | from model.motionagformer.modules.attention import Attention 12 | from model.motionagformer.modules.graph import GCN 13 | from model.motionagformer.modules.mlp import MLP 14 | from model.motionagformer.modules.tcn import MultiScaleTCN 15 | 16 | 17 | 18 | class AGFormerBlock(nn.Module): 19 | """ 20 | Implementation of AGFormer block. 21 | """ 22 | 23 | def __init__(self, dim, mlp_ratio=4., act_layer=nn.GELU, attn_drop=0., drop=0., drop_path=0., 24 | num_heads=8, qkv_bias=False, qk_scale=None, use_layer_scale=True, layer_scale_init_value=1e-5, 25 | mode='spatial', mixer_type="attention", use_temporal_similarity=True, 26 | temporal_connection_len=1, neighbour_num=4, n_frames=243): 27 | super().__init__() 28 | self.norm1 = nn.LayerNorm(dim) 29 | if mixer_type == 'attention': 30 | self.mixer = Attention(dim, dim, num_heads, qkv_bias, qk_scale, attn_drop, 31 | proj_drop=drop, mode=mode) 32 | elif mixer_type == 'graph': 33 | self.mixer = GCN(dim, dim, 34 | num_nodes=17 if mode == 'spatial' else n_frames, 35 | neighbour_num=neighbour_num, 36 | mode=mode, 37 | use_temporal_similarity=use_temporal_similarity, 38 | temporal_connection_len=temporal_connection_len) 39 | elif mixer_type == "ms-tcn": 40 | self.mixer = MultiScaleTCN(in_channels=dim, out_channels=dim) 41 | else: 42 | raise NotImplementedError("AGFormer mixer_type is either attention or graph") 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | mlp_hidden_dim = int(dim * mlp_ratio) 46 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, 47 | act_layer=act_layer, drop=drop) 48 | 49 | # The following two techniques are useful to train deep GraphFormers. 50 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 51 | self.use_layer_scale = use_layer_scale 52 | if use_layer_scale: 53 | self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 54 | self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 55 | 56 | def forward(self, x): 57 | """ 58 | x: tensor with shape [B, T, J, C] 59 | """ 60 | if self.use_layer_scale: 61 | x = x + self.drop_path( 62 | self.layer_scale_1.unsqueeze(0).unsqueeze(0) 63 | * self.mixer(self.norm1(x))) 64 | x = x + self.drop_path( 65 | self.layer_scale_2.unsqueeze(0).unsqueeze(0) 66 | * self.mlp(self.norm2(x))) 67 | else: 68 | x = x + self.drop_path(self.mixer(self.norm1(x))) 69 | x = x + self.drop_path(self.mlp(self.norm2(x))) 70 | return x 71 | 72 | 73 | class MotionAGFormerBlock(nn.Module): 74 | """ 75 | Implementation of MotionAGFormer block. It has two ST and TS branches followed by adaptive fusion. 
76 | """ 77 | 78 | def __init__(self, dim, mlp_ratio=4., act_layer=nn.GELU, attn_drop=0., drop=0., drop_path=0., 79 | num_heads=8, use_layer_scale=True, qkv_bias=False, qk_scale=None, layer_scale_init_value=1e-5, 80 | use_adaptive_fusion=True, hierarchical=False, use_temporal_similarity=True, 81 | temporal_connection_len=1, use_tcn=False, graph_only=False, neighbour_num=4, n_frames=243): 82 | super().__init__() 83 | self.hierarchical = hierarchical 84 | dim = dim // 2 if hierarchical else dim 85 | 86 | # ST Attention branch 87 | self.att_spatial = AGFormerBlock(dim, mlp_ratio, act_layer, attn_drop, drop, drop_path, num_heads, qkv_bias, 88 | qk_scale, use_layer_scale, layer_scale_init_value, 89 | mode='spatial', mixer_type="attention", 90 | use_temporal_similarity=use_temporal_similarity, 91 | neighbour_num=neighbour_num, 92 | n_frames=n_frames) 93 | self.att_temporal = AGFormerBlock(dim, mlp_ratio, act_layer, attn_drop, drop, drop_path, num_heads, qkv_bias, 94 | qk_scale, use_layer_scale, layer_scale_init_value, 95 | mode='temporal', mixer_type="attention", 96 | use_temporal_similarity=use_temporal_similarity, 97 | neighbour_num=neighbour_num, 98 | n_frames=n_frames) 99 | 100 | # ST Graph branch 101 | if graph_only: 102 | self.graph_spatial = GCN(dim, dim, 103 | num_nodes=17, 104 | mode='spatial') 105 | if use_tcn: 106 | self.graph_temporal = MultiScaleTCN(in_channels=dim, out_channels=dim) 107 | else: 108 | self.graph_temporal = GCN(dim, dim, 109 | num_nodes=n_frames, 110 | neighbour_num=neighbour_num, 111 | mode='temporal', 112 | use_temporal_similarity=use_temporal_similarity, 113 | temporal_connection_len=temporal_connection_len) 114 | else: 115 | self.graph_spatial = AGFormerBlock(dim, mlp_ratio, act_layer, attn_drop, drop, drop_path, num_heads, 116 | qkv_bias, 117 | qk_scale, use_layer_scale, layer_scale_init_value, 118 | mode='spatial', mixer_type="graph", 119 | use_temporal_similarity=use_temporal_similarity, 120 | temporal_connection_len=temporal_connection_len, 121 | neighbour_num=neighbour_num, 122 | n_frames=n_frames) 123 | self.graph_temporal = AGFormerBlock(dim, mlp_ratio, act_layer, attn_drop, drop, drop_path, num_heads, 124 | qkv_bias, 125 | qk_scale, use_layer_scale, layer_scale_init_value, 126 | mode='temporal', mixer_type="ms-tcn" if use_tcn else 'graph', 127 | use_temporal_similarity=use_temporal_similarity, 128 | temporal_connection_len=temporal_connection_len, 129 | neighbour_num=neighbour_num, 130 | n_frames=n_frames) 131 | 132 | self.use_adaptive_fusion = use_adaptive_fusion 133 | if self.use_adaptive_fusion: 134 | self.fusion = nn.Linear(dim * 2, 2) 135 | self._init_fusion() 136 | 137 | def _init_fusion(self): 138 | self.fusion.weight.data.fill_(0) 139 | self.fusion.bias.data.fill_(0.5) 140 | 141 | def forward(self, x): 142 | """ 143 | x: tensor with shape [B, T, J, C] 144 | """ 145 | if self.hierarchical: 146 | B, T, J, C = x.shape 147 | x_attn, x_graph = x[..., :C // 2], x[..., C // 2:] 148 | 149 | x_attn = self.att_temporal(self.att_spatial(x_attn)) 150 | x_graph = self.graph_temporal(self.graph_spatial(x_graph + x_attn)) 151 | else: 152 | x_attn = self.att_temporal(self.att_spatial(x)) 153 | x_graph = self.graph_temporal(self.graph_spatial(x)) 154 | 155 | if self.hierarchical: 156 | x = torch.cat((x_attn, x_graph), dim=-1) 157 | elif self.use_adaptive_fusion: 158 | alpha = torch.cat((x_attn, x_graph), dim=-1) 159 | alpha = self.fusion(alpha) 160 | alpha = alpha.softmax(dim=-1) 161 | x = x_attn * alpha[..., 0:1] + x_graph * alpha[..., 1:2] 162 | else: 163 | x = 
(x_attn + x_graph) * 0.5 164 | 165 | return x 166 | 167 | 168 | def create_layers(dim, n_layers, mlp_ratio=4., act_layer=nn.GELU, attn_drop=0., drop_rate=0., drop_path_rate=0., 169 | num_heads=8, use_layer_scale=True, qkv_bias=False, qkv_scale=None, layer_scale_init_value=1e-5, 170 | use_adaptive_fusion=True, hierarchical=False, use_temporal_similarity=True, 171 | temporal_connection_len=1, use_tcn=False, graph_only=False, neighbour_num=4, n_frames=243): 172 | """ 173 | generates MotionAGFormer layers 174 | """ 175 | layers = [] 176 | for _ in range(n_layers): 177 | layers.append(MotionAGFormerBlock(dim=dim, 178 | mlp_ratio=mlp_ratio, 179 | act_layer=act_layer, 180 | attn_drop=attn_drop, 181 | drop=drop_rate, 182 | drop_path=drop_path_rate, 183 | num_heads=num_heads, 184 | use_layer_scale=use_layer_scale, 185 | layer_scale_init_value=layer_scale_init_value, 186 | qkv_bias=qkv_bias, 187 | qk_scale=qkv_scale, 188 | use_adaptive_fusion=use_adaptive_fusion, 189 | hierarchical=hierarchical, 190 | use_temporal_similarity=use_temporal_similarity, 191 | temporal_connection_len=temporal_connection_len, 192 | use_tcn=use_tcn, 193 | graph_only=graph_only, 194 | neighbour_num=neighbour_num, 195 | n_frames=n_frames)) 196 | layers = nn.Sequential(*layers) 197 | 198 | return layers 199 | 200 | 201 | class MotionAGFormer(nn.Module): 202 | """ 203 | MotionAGFormer, the main class of our model. 204 | """ 205 | 206 | def __init__(self, n_layers, dim_in, dim_feat, dim_rep=512, dim_out=3, mlp_ratio=4, act_layer=nn.GELU, attn_drop=0., 207 | drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5, use_adaptive_fusion=True, 208 | num_heads=4, qkv_bias=False, qkv_scale=None, hierarchical=False, num_joints=17, 209 | use_temporal_similarity=True, temporal_connection_len=1, use_tcn=False, graph_only=False, 210 | neighbour_num=4, n_frames=243): 211 | """ 212 | :param n_layers: Number of layers. 213 | :param dim_in: Input dimension. 214 | :param dim_feat: Feature dimension. 215 | :param dim_rep: Motion representation dimension 216 | :param dim_out: output dimension. For 3D pose lifting it is set to 3 217 | :param mlp_ratio: MLP ratio. 218 | :param act_layer: Activation layer. 219 | :param drop: Dropout rate. 220 | :param drop_path: Stochastic drop probability. 221 | :param use_layer_scale: Whether to use layer scaling or not. 222 | :param layer_scale_init_value: Layer scale init value in case of using layer scaling. 223 | :param use_adaptive_fusion: Whether to use adaptive fusion or not. 224 | :param num_heads: Number of attention heads in attention branch 225 | :param qkv_bias: Whether to include bias in the linear layers that create query, key, and value or not. 226 | :param qkv_scale: scale factor to multiply after outer product of query and key. If None, it's set to 227 | 1 / sqrt(dim_feature // num_heads) 228 | :param hierarchical: Whether to use hierarchical structure or not. 229 | :param num_joints: Number of joints. 230 | :param use_temporal_similarity: If true, for temporal GCN uses top-k similarity between nodes 231 | :param temporal_connection_len: Connects joint to itself within next `temporal_connection_len` frames 232 | :param use_tcn: If true, uses MS-TCN for temporal part of the graph branch. 233 | :param graph_only: Uses GCN instead of GraphFormer in the graph branch. 234 | :param neighbour_num: Number of neighbors for temporal GCN similarity. 235 | :param n_frames: Number of frames. 
Default is 243 236 | """ 237 | super().__init__() 238 | 239 | self.joints_embed = nn.Linear(dim_in, dim_feat) 240 | self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat)) 241 | self.norm = nn.LayerNorm(dim_feat) 242 | 243 | self.layers = create_layers(dim=dim_feat, 244 | n_layers=n_layers, 245 | mlp_ratio=mlp_ratio, 246 | act_layer=act_layer, 247 | attn_drop=attn_drop, 248 | drop_rate=drop, 249 | drop_path_rate=drop_path, 250 | num_heads=num_heads, 251 | use_layer_scale=use_layer_scale, 252 | qkv_bias=qkv_bias, 253 | qkv_scale=qkv_scale, 254 | layer_scale_init_value=layer_scale_init_value, 255 | use_adaptive_fusion=use_adaptive_fusion, 256 | hierarchical=hierarchical, 257 | use_temporal_similarity=use_temporal_similarity, 258 | temporal_connection_len=temporal_connection_len, 259 | use_tcn=use_tcn, 260 | graph_only=graph_only, 261 | neighbour_num=neighbour_num, 262 | n_frames=n_frames) 263 | 264 | self.rep_logit = nn.Sequential(OrderedDict([ 265 | ('fc', nn.Linear(dim_feat, dim_rep)), 266 | ('act', nn.Tanh()) 267 | ])) 268 | 269 | self.head = nn.Linear(dim_rep, dim_out) 270 | 271 | def forward(self, x, return_rep=True): 272 | """ 273 | :param x: tensor with shape [B, T, J, C] (T=243, J=17, C=3) 274 | :param return_rep: Returns motion representation feature volume (In case of using this as backbone) 275 | """ 276 | x = self.joints_embed(x) 277 | x = x + self.pos_embed 278 | 279 | for layer in self.layers: 280 | x = layer(x) 281 | 282 | x = self.norm(x) 283 | x = self.rep_logit(x) 284 | if return_rep: 285 | return x 286 | 287 | x = self.head(x) 288 | 289 | return x 290 | -------------------------------------------------------------------------------- /model/motionagformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/model/motionagformer/__init__.py -------------------------------------------------------------------------------- /model/motionagformer/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/model/motionagformer/modules/__init__.py -------------------------------------------------------------------------------- /model/motionagformer/modules/attention.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Attention(nn.Module): 5 | """ 6 | A simplified version of attention from DSTFormer that also considers x tensor to be (B, T, J, C) instead of 7 | (B * T, J, C) 8 | """ 9 | 10 | def __init__(self, dim_in, dim_out, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 11 | mode='spatial'): 12 | super().__init__() 13 | self.num_heads = num_heads 14 | head_dim = dim_in // num_heads 15 | self.scale = qk_scale or head_dim ** -0.5 16 | 17 | self.attn_drop = nn.Dropout(attn_drop) 18 | self.proj = nn.Linear(dim_in, dim_out) 19 | self.mode = mode 20 | self.qkv = nn.Linear(dim_in, dim_in * 3, bias=qkv_bias) 21 | self.proj_drop = nn.Dropout(proj_drop) 22 | 23 | def forward(self, x): 24 | B, T, J, C = x.shape 25 | 26 | qkv = self.qkv(x).reshape(B, T, J, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 27 | 5) # (3, B, H, T, J, C) 28 | if self.mode == 'temporal': 29 | q, k, v = qkv[0], qkv[1], qkv[2] 30 | x = self.forward_temporal(q, k, v) 31 
| elif self.mode == 'spatial': 32 | q, k, v = qkv[0], qkv[1], qkv[2] 33 | x = self.forward_spatial(q, k, v) 34 | else: 35 | raise NotImplementedError(self.mode) 36 | x = self.proj(x) 37 | x = self.proj_drop(x) 38 | return x 39 | 40 | def forward_spatial(self, q, k, v): 41 | B, H, T, J, C = q.shape 42 | attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, J, J) 43 | attn = attn.softmax(dim=-1) 44 | attn = self.attn_drop(attn) 45 | 46 | x = attn @ v # (B, H, T, J, C) 47 | x = x.permute(0, 2, 3, 1, 4).reshape(B, T, J, C * self.num_heads) 48 | return x # (B, T, J, C) 49 | 50 | def forward_temporal(self, q, k, v): 51 | B, H, T, J, C = q.shape 52 | qt = q.transpose(2, 3) # (B, H, J, T, C) 53 | kt = k.transpose(2, 3) # (B, H, J, T, C) 54 | vt = v.transpose(2, 3) # (B, H, J, T, C) 55 | 56 | attn = (qt @ kt.transpose(-2, -1)) * self.scale # (B, H, J, T, T) 57 | attn = attn.softmax(dim=-1) 58 | attn = self.attn_drop(attn) 59 | 60 | x = attn @ vt # (B, H, J, T, C) 61 | x = x.permute(0, 3, 2, 1, 4).reshape(B, T, J, C * self.num_heads) 62 | return x # (B, T, J, C) 63 | -------------------------------------------------------------------------------- /model/motionagformer/modules/graph.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | CONNECTIONS = {10: [9], 9: [8, 10], 8: [7, 9], 14: [15, 8], 15: [16, 14], 11: [12, 8], 12: [13, 11], 7 | 7: [0, 8], 0: [1, 7], 1: [2, 0], 2: [3, 1], 4: [5, 0], 5: [6, 4], 16: [15], 13: [12], 3: [2], 6: [5]} 8 | 9 | 10 | class GCN(nn.Module): 11 | def __init__(self, dim_in, dim_out, num_nodes, neighbour_num=4, mode='spatial', use_temporal_similarity=True, 12 | temporal_connection_len=1, connections=None): 13 | """ 14 | :param dim_in: Channel input dimension 15 | :param dim_out: Channel output dimension 16 | :param num_nodes: Number of nodes 17 | :param neighbour_num: Neighbor numbers. Used in temporal GCN to create edges 18 | :param mode: Either 'spatial' or 'temporal' 19 | :param use_temporal_similarity: If true, for temporal GCN uses top-k similarity between nodes 20 | :param temporal_connection_len: Connects joint to itself within next `temporal_connection_len` frames 21 | :param connections: Spatial connections for graph edges (Optional) 22 | """ 23 | super().__init__() 24 | assert mode in ['spatial', 'temporal'], "Mode is undefined" 25 | 26 | self.relu = nn.ReLU() 27 | self.neighbour_num = neighbour_num 28 | self.dim_in = dim_in 29 | self.dim_out = dim_out 30 | self.mode = mode 31 | self.use_temporal_similarity = use_temporal_similarity 32 | self.num_nodes = num_nodes 33 | self.connections = connections 34 | 35 | self.U = nn.Linear(self.dim_in, self.dim_out) 36 | self.V = nn.Linear(self.dim_in, self.dim_out) 37 | self.batch_norm = nn.BatchNorm1d(self.num_nodes) 38 | 39 | self._init_gcn() 40 | 41 | if mode == 'spatial': 42 | self.adj = self._init_spatial_adj() 43 | elif mode == 'temporal' and not self.use_temporal_similarity: 44 | self.adj = self._init_temporal_adj(temporal_connection_len) 45 | 46 | def _init_gcn(self): 47 | self.U.weight.data.normal_(0, math.sqrt(2. / self.dim_in)) 48 | self.V.weight.data.normal_(0, math.sqrt(2. 
/ self.dim_in)) 49 | self.batch_norm.weight.data.fill_(1) 50 | self.batch_norm.bias.data.zero_() 51 | 52 | def _init_spatial_adj(self): 53 | adj = torch.zeros((self.num_nodes, self.num_nodes)) 54 | connections = self.connections if self.connections is not None else CONNECTIONS 55 | 56 | for i in range(self.num_nodes): 57 | connected_nodes = connections[i] 58 | for j in connected_nodes: 59 | adj[i, j] = 1 60 | return adj 61 | 62 | def _init_temporal_adj(self, connection_length): 63 | """Connects each joint to itself and the same joint withing next `connection_length` frames.""" 64 | adj = torch.zeros((self.num_nodes, self.num_nodes)) 65 | 66 | for i in range(self.num_nodes): 67 | try: 68 | for j in range(connection_length + 1): 69 | adj[i, i + j] = 1 70 | except IndexError: # next j frame does not exist 71 | pass 72 | return adj 73 | 74 | @staticmethod 75 | def normalize_digraph(adj): 76 | b, n, c = adj.shape 77 | 78 | node_degrees = adj.detach().sum(dim=-1) 79 | deg_inv_sqrt = node_degrees ** -0.5 80 | norm_deg_matrix = torch.eye(n) 81 | dev = adj.get_device() 82 | if dev >= 0: 83 | norm_deg_matrix = norm_deg_matrix.to(dev) 84 | norm_deg_matrix = norm_deg_matrix.view(1, n, n) * deg_inv_sqrt.view(b, n, 1) 85 | norm_adj = torch.bmm(torch.bmm(norm_deg_matrix, adj), norm_deg_matrix) 86 | 87 | return norm_adj 88 | 89 | def change_adj_device_to_cuda(self, adj): 90 | dev = self.V.weight.get_device() 91 | if dev >= 0 and adj.get_device() < 0: 92 | adj = adj.to(dev) 93 | return adj 94 | 95 | def forward(self, x): 96 | """ 97 | x: tensor with shape [B, T, J, C] 98 | """ 99 | b, t, j, c = x.shape 100 | if self.mode == 'temporal': 101 | x = x.transpose(1, 2) # (B, T, J, C) -> (B, J, T, C) 102 | x = x.reshape(-1, t, c) 103 | if self.use_temporal_similarity: 104 | similarity = x @ x.transpose(1, 2) 105 | threshold = similarity.topk(k=self.neighbour_num, dim=-1, largest=True)[0][..., -1].view(b * j, t, 1) 106 | adj = (similarity >= threshold).float() 107 | else: 108 | adj = self.adj 109 | adj = self.change_adj_device_to_cuda(adj) 110 | adj = adj.repeat(b * j, 1, 1) 111 | 112 | else: 113 | x = x.reshape(-1, j, c) 114 | adj = self.adj 115 | adj = self.change_adj_device_to_cuda(adj) 116 | adj = adj.repeat(b * t, 1, 1) 117 | 118 | norm_adj = self.normalize_digraph(adj) 119 | aggregate = norm_adj @ self.V(x) 120 | 121 | if self.dim_in == self.dim_out: 122 | x = self.relu(x + self.batch_norm(aggregate + self.U(x))) 123 | else: 124 | x = self.relu(self.batch_norm(aggregate + self.U(x))) 125 | 126 | x = x.reshape(-1, t, j, self.dim_out) if self.mode == 'spatial' \ 127 | else x.reshape(-1, j, t, self.dim_out).transpose(1, 2) 128 | return x 129 | -------------------------------------------------------------------------------- /model/motionagformer/modules/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., 6 | channel_first=False): 7 | """ 8 | :param channel_first: if True, during forward the tensor shape is [B, C, T, J] and fc layers are performed with 9 | 1x1 convolutions. 
10 | """ 11 | super().__init__() 12 | out_features = out_features or in_features 13 | hidden_features = hidden_features or in_features 14 | self.act = act_layer() 15 | self.drop = nn.Dropout(drop) 16 | 17 | if channel_first: 18 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 19 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 20 | else: 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | -------------------------------------------------------------------------------- /model/motionagformer/modules/tcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TemporalConv(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): 7 | super(TemporalConv, self).__init__() 8 | pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 9 | self.conv = nn.Conv2d( 10 | in_channels, 11 | out_channels, 12 | kernel_size=(kernel_size, 1), 13 | padding=(pad, 0), 14 | stride=(stride, 1), 15 | dilation=(dilation, 1)) 16 | 17 | self.bn = nn.BatchNorm2d(out_channels) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | x = self.bn(x) 22 | return x 23 | 24 | 25 | class MultiScaleTCN(nn.Module): 26 | def __init__(self, 27 | in_channels, 28 | out_channels, 29 | kernel_size=5, 30 | stride=1, 31 | dilations=(1, 2), 32 | residual=True, 33 | residual_kernel_size=1): 34 | 35 | super().__init__() 36 | assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches (6x)' 37 | 38 | # Multiple branches of temporal convolution 39 | self.num_branches = len(dilations) + 2 40 | branch_channels = out_channels // self.num_branches 41 | if isinstance(kernel_size, list): 42 | assert len(kernel_size) == len(dilations) 43 | else: 44 | kernel_size = [kernel_size] * len(dilations) 45 | 46 | # Temporal Convolution branches 47 | self.branches = nn.ModuleList([ 48 | nn.Sequential( 49 | nn.Conv2d( 50 | in_channels, 51 | branch_channels, 52 | kernel_size=1, 53 | padding=0), 54 | nn.BatchNorm2d(branch_channels), 55 | nn.ReLU(inplace=True), 56 | TemporalConv( 57 | branch_channels, 58 | branch_channels, 59 | kernel_size=ks, 60 | stride=stride, 61 | dilation=dilation), 62 | ) 63 | for ks, dilation in zip(kernel_size, dilations) 64 | ]) 65 | 66 | # Additional Max & 1x1 branch 67 | self.branches.append(nn.Sequential( 68 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0), 69 | nn.BatchNorm2d(branch_channels), 70 | nn.ReLU(inplace=True), 71 | nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)), 72 | nn.BatchNorm2d(branch_channels) 73 | )) 74 | 75 | self.branches.append(nn.Sequential( 76 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)), 77 | nn.BatchNorm2d(branch_channels) 78 | )) 79 | 80 | # Residual connection 81 | if not residual: 82 | self.residual = lambda x: 0 83 | elif (in_channels == out_channels) and (stride == 1): 84 | self.residual = lambda x: x 85 | else: 86 | self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride) 87 | 88 | def forward(self, x): 89 | """ 90 | x: tensor with shape [B, T, J, C] 91 | """ 92 | x = x.permute(0, 3, 1, 2) # (B, T, J, C) -> (B, C, T, J) 93 | 94 | res = self.residual(x) 95 | 
branch_outs = [] 96 | for temp_conv in self.branches: 97 | out = temp_conv(x) 98 | branch_outs.append(out) 99 | 100 | out = torch.cat(branch_outs, dim=1) 101 | out += res 102 | 103 | out = out.permute(0, 2, 3, 1) # (B, C, T, J) -> (B, T, J, C) 104 | return out 105 | 106 | 107 | if __name__ == "__main__": 108 | ms_tcn = MultiScaleTCN(528, 528) 109 | x = torch.randn(8, 243, 17, 528) 110 | ms_tcn.forward(x) 111 | for name, param in ms_tcn.named_parameters(): 112 | print(f'{name}: {param.numel()}') 113 | print(sum(p.numel() for p in ms_tcn.parameters() if p.requires_grad)) 114 | -------------------------------------------------------------------------------- /model/motionbert/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 3 | Papers: 4 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 5 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 6 | Code: 7 | DropBlock impl inspired by two Tensorflow impl that I liked: 8 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 9 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | def drop_path(x, drop_prob: float = 0., training: bool = False): 18 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 19 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 20 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 21 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 22 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 23 | 'survival rate' as the argument. 24 | """ 25 | if drop_prob == 0. or not training: 26 | return x 27 | keep_prob = 1 - drop_prob 28 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 29 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 30 | random_tensor.floor_() # binarize 31 | output = x.div(keep_prob) * random_tensor 32 | return output 33 | 34 | 35 | class DropPath(nn.Module): 36 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
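    An illustrative sketch of the behaviour (assumes `torch` is importable): in eval mode the
    layer is the identity, while in training mode each sample is zeroed with probability
    `drop_prob` and the surviving samples are rescaled by 1 / (1 - drop_prob).

        dp = DropPath(drop_prob=0.2)
        dp.eval()
        x = torch.ones(4, 8)
        assert torch.equal(dp(x), x)  # identity at inference time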
37 | """ 38 | def __init__(self, drop_prob=None): 39 | super(DropPath, self).__init__() 40 | self.drop_prob = drop_prob 41 | 42 | def forward(self, x): 43 | return drop_path(x, self.drop_prob, self.training) -------------------------------------------------------------------------------- /model/poseformer/Conv1DEncoder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | """Model of 1D convolutions for encoding pose sequences.""" 25 | 26 | import numpy as np 27 | import os 28 | import sys 29 | 30 | thispath = os.path.dirname(os.path.abspath(__file__)) 31 | sys.path.insert(0, thispath + "/../") 32 | 33 | import torch 34 | import torch.nn as nn 35 | 36 | 37 | class Pose1DEncoder(nn.Module): 38 | def __init__(self, input_channels=3, output_channels=128, n_joints=21): 39 | super(Pose1DEncoder, self).__init__() 40 | self._input_channels = input_channels 41 | self._output_channels = output_channels 42 | self._n_joints = n_joints 43 | self.init_model() 44 | 45 | def init_model(self): 46 | self._model = nn.Sequential( 47 | nn.Conv1d(in_channels=self._input_channels, out_channels=32, kernel_size=7), 48 | nn.BatchNorm1d(32), 49 | nn.ReLU(True), 50 | nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3), 51 | nn.BatchNorm1d(32), 52 | nn.ReLU(True), 53 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3), 54 | nn.BatchNorm1d(64), 55 | nn.ReLU(True), 56 | nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3), 57 | nn.BatchNorm1d(64), 58 | nn.ReLU(True), 59 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3), 60 | nn.BatchNorm1d(128), 61 | nn.ReLU(True), 62 | nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3), 63 | nn.BatchNorm1d(128), 64 | nn.ReLU(True), 65 | nn.Conv1d(in_channels=128, out_channels=self._output_channels, kernel_size=3), 66 | nn.BatchNorm1d(self._output_channels), 67 | nn.ReLU(True), 68 | nn.Conv1d(in_channels=self._output_channels, out_channels=self._output_channels, kernel_size=3) 69 | ) 70 | 71 | def forward(self, x): 72 | """ 73 | Args: 74 | x: [batch_size, seq_len, skeleton_dim]. 
75 | """ 76 | # inputs to model is [batch_size, channels, n_joints] 77 | # transform the batch to [batch_size*seq_len, dof, n_joints] 78 | bs, seq_len, dim = x.size() 79 | dof = dim // self._n_joints 80 | x = x.view(bs * seq_len, dof, self._n_joints) 81 | 82 | # [batch_size*seq_len, dof, n_joints] 83 | x = self._model(x) 84 | # [batch_size, seq_len, output_channels] 85 | x = x.view(bs, seq_len, self._output_channels) 86 | 87 | return x 88 | 89 | 90 | class Pose1DTemporalEncoder(nn.Module): 91 | def __init__(self, input_channels, output_channels): 92 | super(Pose1DTemporalEncoder, self).__init__() 93 | self._input_channels = input_channels 94 | self._output_channels = output_channels 95 | self.init_model() 96 | 97 | def init_model(self): 98 | self._model = nn.Sequential( 99 | nn.Conv1d( 100 | in_channels=self._input_channels, out_channels=32, kernel_size=3, padding=1), 101 | nn.BatchNorm1d(32), 102 | nn.ReLU(True), 103 | nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding=1), 104 | nn.BatchNorm1d(32), 105 | nn.ReLU(True), 106 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 107 | nn.BatchNorm1d(64), 108 | nn.ReLU(True), 109 | nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, padding=1), 110 | nn.BatchNorm1d(64), 111 | nn.ReLU(True), 112 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 113 | nn.BatchNorm1d(128), 114 | nn.ReLU(True), 115 | nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 116 | nn.BatchNorm1d(128), 117 | nn.ReLU(True), 118 | nn.Conv1d(in_channels=128, out_channels=self._output_channels, kernel_size=3, padding=1), 119 | nn.BatchNorm1d(self._output_channels), 120 | nn.ReLU(True), 121 | nn.Conv1d(in_channels=self._output_channels, out_channels=self._output_channels, kernel_size=3, padding=1) 122 | ) 123 | 124 | def forward(self, x): 125 | """ 126 | Args: 127 | x: [batch_size, seq_len, skeleton_dim]. 128 | """ 129 | # batch_size, skeleton_dim, seq_len 130 | x = torch.transpose(x, 1, 2) 131 | x = self._model(x) 132 | # batch_size, seq_len, skeleton_dim 133 | x = torch.transpose(x, 1, 2) 134 | return x 135 | 136 | 137 | if __name__ == '__main__': 138 | dof = 9 139 | output_channels = 128 140 | n_joints = 21 141 | seq_len = 49 142 | 143 | model = Pose1DTemporalEncoder(input_channels=dof * n_joints, output_channels=output_channels) 144 | inputs = torch.FloatTensor(10, seq_len, dof * n_joints) 145 | X = model(inputs) 146 | print(X.size()) 147 | 148 | # model = Pose1DEncoder(input_channels=dof, output_channels=output_channels) 149 | # inputs = torch.FloatTensor(10, seq_len, dof*n_joints) 150 | # X = model(inputs) 151 | # print(X.size()) 152 | -------------------------------------------------------------------------------- /model/poseformer/PoseEncoderDecoder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 
15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | """Definition of pose encoder and encoder embeddings and model factory.""" 25 | 26 | import numpy as np 27 | import os 28 | import sys 29 | 30 | import torch 31 | import torch.nn as nn 32 | 33 | thispath = os.path.dirname(os.path.abspath(__file__)) 34 | sys.path.insert(0, thispath + "/../") 35 | 36 | from model import utils 37 | import model.poseformer.PoseGCN as GCN 38 | import model.poseformer.Conv1DEncoder as Conv1DEncoder 39 | 40 | 41 | def pose_encoder_mlp(params): 42 | # These two encoders should be experimented with a graph NN and 43 | # a prior based pose decoder using also the graph 44 | init_fn = utils.normal_init_ \ 45 | if params['init_fn'] == 'normal_init' else utils.xavier_init_ 46 | pose_embedding = nn.Sequential( 47 | nn.Linear(params['input_dim'], params['model_dim']), 48 | nn.Dropout(0.1) 49 | ) 50 | utils.weight_init(pose_embedding, init_fn_=init_fn) 51 | return pose_embedding 52 | 53 | 54 | def pose_decoder_mlp(params): 55 | init_fn = utils.normal_init_ \ 56 | if params['init_fn'] == 'normal_init' else utils.xavier_init_ 57 | pose_decoder = nn.Linear(params['model_dim'], params['pose_dim']) 58 | utils.weight_init(pose_decoder, init_fn_=init_fn) 59 | return pose_decoder 60 | 61 | 62 | def pose_decoder_gcn(params): 63 | decoder = GCN.PoseGCN( 64 | input_features=params['model_dim'], 65 | output_features=9 if params['pose_format'] == 'rotmat' else 3, 66 | model_dim=params['model_dim'], 67 | output_nodes=params['n_joints'], 68 | p_dropout=params['dropout'], 69 | num_stage=1 70 | ) 71 | 72 | return decoder 73 | 74 | 75 | def pose_encoder_gcn(params): 76 | encoder = GCN.SimpleEncoder( 77 | n_nodes=params['n_joints'], 78 | input_features=9 if params['pose_format'] == 'rotmat' else 3, 79 | # n_nodes=params['pose_dim'], 80 | # input_features=1, 81 | model_dim=params['model_dim'], 82 | p_dropout=params['dropout'] 83 | ) 84 | 85 | return encoder 86 | 87 | 88 | def pose_encoder_conv1d(params): 89 | encoder = Conv1DEncoder.Pose1DEncoder( 90 | input_channels=9 if params['pose_format'] == 'rotmat' else 3, 91 | output_channels=params['model_dim'], 92 | n_joints=params['n_joints'] 93 | ) 94 | return encoder 95 | 96 | 97 | def pose_encoder_conv1dtemporal(params): 98 | dof = 9 if params['pose_format'] == 'rotmat' else 3 99 | encoder = Conv1DEncoder.Pose1DTemporalEncoder( 100 | input_channels=dof * params['n_joints'], 101 | output_channels=params['model_dim'] 102 | ) 103 | return encoder 104 | 105 | 106 | def select_pose_encoder_decoder_fn(params): 107 | if params['pose_embedding_type'].lower() == 'simple': 108 | return pose_encoder_mlp, pose_decoder_mlp 109 | if params['pose_embedding_type'].lower() == 'conv1d_enc': 110 | return pose_encoder_conv1d, pose_decoder_mlp 111 | if params['pose_embedding_type'].lower() == 'convtemp_enc': 112 | return pose_encoder_conv1dtemporal, pose_decoder_mlp 113 | if params['pose_embedding_type'].lower() == 'gcn_dec': 114 | return pose_encoder_mlp, pose_decoder_gcn 115 | if params['pose_embedding_type'].lower() == 'gcn_enc': 116 | return pose_encoder_gcn, pose_decoder_mlp 117 | if 
params['pose_embedding_type'].lower() == 'gcn_full': 118 | return pose_encoder_gcn, pose_decoder_gcn 119 | elif params['pose_embedding_type'].lower() == 'vae': 120 | return pose_encoder_vae, pose_decoder_mlp 121 | else: 122 | raise ValueError('Unknown pose embedding {}'.format(params['pose_embedding_type'])) -------------------------------------------------------------------------------- /model/poseformer/PoseGCN.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | """Graph Convolutional Neural Network implementation. 25 | 26 | Code adapted from [1]. 27 | 28 | [1] https://github.com/wei-mao-2019/HisRepItself 29 | [2] https://github.com/tkipf/gcn/blob/92600c39797c2bfb61a508e52b88fb554df30177/gcn/layers.py#L132 30 | """ 31 | 32 | import os 33 | import sys 34 | import torch.nn as nn 35 | import torch 36 | from torch.nn.parameter import Parameter 37 | import math 38 | import numpy as np 39 | 40 | thispath = os.path.dirname(os.path.abspath(__file__)) 41 | sys.path.insert(0, thispath + "/../") 42 | 43 | from model import utils 44 | 45 | 46 | class GraphConvolution(nn.Module): 47 | """Implements graph convolutions.""" 48 | 49 | def __init__(self, in_features, out_features, output_nodes=48, bias=False): 50 | """Constructor. 51 | 52 | The graph convolutions can be defined as \sigma(AxHxW), where A is the 53 | adjacency matrix, H is the feature representation from previous layer 54 | and W is the wegith of the current layer. The dimensions of such martices 55 | A\in R^{NxN}, H\in R^{NxM} and W\in R^{MxO} where 56 | - N is the number of nodes 57 | - M is the number of input features per node 58 | - O is the number of output features per node 59 | 60 | Args: 61 | in_features: Number of input features per node. 62 | out_features: Number of output features per node. 63 | output_nodes: Number of nodes in the graph. 64 | """ 65 | super(GraphConvolution, self).__init__() 66 | self.in_features = in_features 67 | self.out_features = out_features 68 | self._output_nodes = output_nodes 69 | # W\in R^{MxO} 70 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 71 | # A\in R^{NxN} 72 | self.att = Parameter(torch.FloatTensor(output_nodes, output_nodes)) 73 | if bias: 74 | self.bias = Parameter(torch.FloatTensor(out_features)) 75 | else: 76 | self.register_parameter('bias', None) 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 81 | self.weight.data.uniform_(-stdv, stdv) 82 | self.att.data.uniform_(-stdv, stdv) 83 | if self.bias is not None: 84 | self.bias.data.uniform_(-stdv, stdv) 85 | 86 | def forward(self, x): 87 | """Forward pass. 88 | 89 | Args: 90 | x: [batch_size, n_nodes, input_features] 91 | Returns: 92 | Feature representation computed from inputs. 93 | Shape is [batch_size, n_nodes, output_features]. 94 | """ 95 | # [batch_size, input_dim, output_features] 96 | # HxW = {NxM}x{MxO} = {NxO} 97 | support = torch.matmul(x, self.weight) 98 | # [batch_size, n_nodes, output_features] 99 | # = {NxN}x{NxO} = {NxO} 100 | output = torch.matmul(self.att, support) 101 | 102 | if self.bias is not None: 103 | return output + self.bias 104 | else: 105 | return output 106 | 107 | def __repr__(self): 108 | return self.__class__.__name__ + ' (' \ 109 | + str(self.in_features) + ' -> ' \ 110 | + str(self.out_features) + ')' 111 | 112 | 113 | class GC_Block(nn.Module): 114 | """Residual block with graph convolutions. 115 | 116 | The implementation uses the same number of input features for outputs. 117 | """ 118 | 119 | def __init__(self, in_features, p_dropout, output_nodes=48, bias=False): 120 | """Constructor. 121 | 122 | Args: 123 | in_features: Number of input and output features. 124 | p_dropout: Dropout used in the layers. 125 | output_nodes: Number of output nodes in the graph. 126 | """ 127 | super(GC_Block, self).__init__() 128 | self.in_features = in_features 129 | self.out_features = in_features 130 | 131 | self.gc1 = GraphConvolution( 132 | in_features, in_features, 133 | output_nodes=output_nodes, 134 | bias=bias 135 | ) 136 | self.bn1 = nn.BatchNorm1d(output_nodes * in_features) 137 | self.gc2 = GraphConvolution( 138 | in_features, in_features, 139 | output_nodes=output_nodes, 140 | bias=bias 141 | ) 142 | self.bn2 = nn.BatchNorm1d(output_nodes * in_features) 143 | 144 | self.do = nn.Dropout(p_dropout) 145 | self.act_f = nn.Tanh() 146 | 147 | def forward(self, x): 148 | """Forward pass of the residual module""" 149 | y = self.gc1(x) 150 | b, n, f = y.shape 151 | y = self.bn1(y.view(b, -1)).view(b, n, f) 152 | y = self.act_f(y) 153 | y = self.do(y) 154 | 155 | y = self.gc2(y) 156 | b, n, f = y.shape 157 | y = self.bn2(y.view(b, -1)).view(b, n, f) 158 | y = self.act_f(y) 159 | y = self.do(y) 160 | 161 | return y + x 162 | 163 | def __repr__(self): 164 | return self.__class__.__name__ + ' (' \ 165 | + str(self.in_features) + ' -> ' \ 166 | + str(self.out_features) + ')' 167 | 168 | 169 | class PoseGCN(nn.Module): 170 | def __init__(self, 171 | input_features=128, 172 | output_features=3, 173 | model_dim=128, 174 | output_nodes=21, 175 | p_dropout=0.1, 176 | num_stage=1): 177 | """Constructor. 178 | 179 | Args: 180 | input_feature: num of input feature of the graph nodes. 181 | model_dim: num of hidden features of the generated embeddings. 182 | p_dropout: dropout probability 183 | num_stage: number of residual blocks in the network. 184 | output_nodes: number of nodes in graph. 
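        Example (an illustrative sketch mirroring `test_decoder` at the bottom of this file;
        assumes `torch` is importable):

            gcn = PoseGCN(input_features=128, output_features=3, model_dim=128,
                          output_nodes=21, p_dropout=0.1, num_stage=2)
            y = gcn(torch.rand(16, 128))  # -> [16, 21 * 3]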
185 | """ 186 | super(PoseGCN, self).__init__() 187 | self.num_stage = num_stage 188 | self._n_nodes = output_nodes 189 | self._model_dim = model_dim 190 | self._output_features = output_features 191 | self._hidden_dim = 512 192 | 193 | self._front = nn.Sequential( 194 | nn.Linear(model_dim, output_nodes * self._hidden_dim), 195 | nn.Dropout(p_dropout) 196 | ) 197 | utils.weight_init(self._front, init_fn_=utils.xavier_init_) 198 | 199 | self.gc1 = GraphConvolution( 200 | self._hidden_dim, 201 | self._hidden_dim, 202 | output_nodes=output_nodes 203 | ) 204 | self.bn1 = nn.BatchNorm1d(output_nodes * self._hidden_dim) 205 | 206 | self.gcbs = [] 207 | for i in range(num_stage): 208 | self.gcbs.append(GC_Block( 209 | self._hidden_dim, 210 | p_dropout=p_dropout, 211 | output_nodes=output_nodes) 212 | ) 213 | 214 | self.gcbs = nn.ModuleList(self.gcbs) 215 | 216 | self.gc7 = GraphConvolution( 217 | self._hidden_dim, 218 | output_features, 219 | output_nodes=output_nodes 220 | ) 221 | self.do = nn.Dropout(p_dropout) 222 | self.act_f = nn.Tanh() 223 | 224 | gcn_params = filter(lambda p: p.requires_grad, self.parameters()) 225 | nparams = sum([np.prod(p.size()) for p in gcn_params]) 226 | print('[INFO] ({}) GCN has {} params!'.format(self.__class__.__name__, nparams)) 227 | 228 | def preprocess(self, x): 229 | if len(x.size()) < 3: 230 | _, D = x.size() 231 | # seq_len, batch_size, input_dim 232 | x = x.view(self._seq_len, -1, D) 233 | # [batch_size, seq_len, input_dim] 234 | x = torch.transpose(x, 0, 1) 235 | # [batch_size, input_dim, seq_len] 236 | x = torch.transpose(x, 1, 2) 237 | return x 238 | 239 | return x 240 | 241 | def postprocess(self, y): 242 | """Flattents the input tensor. 243 | Args: 244 | y: Input tensor of shape [batch_size, n_nodes, output_features]. 245 | """ 246 | y = y.view(-1, self._n_nodes * self._output_features) 247 | return y 248 | 249 | def forward(self, x): 250 | """Forward pass of network. 251 | 252 | Args: 253 | x: [batch_size, model_dim]. 254 | """ 255 | # [batch_size, model_dim*n_nodes] 256 | x = self._front(x) 257 | x = x.view(-1, self._n_nodes, self._hidden_dim) 258 | 259 | # [batch_size, n_joints, model_dim] 260 | y = self.gc1(x) 261 | b, n, f = y.shape 262 | y = self.bn1(y.view(b, -1)).view(b, n, f) 263 | y = self.act_f(y) 264 | y = self.do(y) 265 | 266 | for i in range(self.num_stage): 267 | y = self.gcbs[i](y) 268 | 269 | # [batch_size, n_joints, output_features] 270 | y = self.gc7(y) 271 | # y = y + x 272 | 273 | # [seq_len*batch_size, input_dim] 274 | y = self.postprocess(y) 275 | 276 | return y 277 | 278 | 279 | class SimpleEncoder(nn.Module): 280 | def __init__(self, 281 | n_nodes=63, 282 | input_features=1, 283 | model_dim=128, 284 | p_dropout=0.1): 285 | """Constructor. 286 | 287 | Args: 288 | input_dim: Dimension of the input vector. This will be equivalent to 289 | the number of nodes in the graph, each node with 1 feature each. 290 | model_dim: Dimension of the output vector to produce. 291 | p_dropout: Dropout to be applied for regularization. 
292 | """ 293 | super(SimpleEncoder, self).__init__() 294 | # The graph convolutions can be defined as \sigma(AxHxW), where A is the 295 | # A\in R^{NxN} x H\in R^{NxM} x W\in R ^{MxO} 296 | self._input_features = input_features 297 | self._output_nodes = n_nodes 298 | self._hidden_dim = 512 299 | self._model_dim = model_dim 300 | self._num_stage = 1 301 | 302 | print('[INFO] ({}) Hidden dimension: {}!'.format( 303 | self.__class__.__name__, self._hidden_dim)) 304 | self.gc1 = GraphConvolution( 305 | in_features=self._input_features, 306 | out_features=self._hidden_dim, 307 | output_nodes=self._output_nodes 308 | ) 309 | self.bn1 = nn.BatchNorm1d(self._output_nodes * self._hidden_dim) 310 | self.gc2 = GraphConvolution( 311 | in_features=self._hidden_dim, 312 | out_features=model_dim, 313 | output_nodes=self._output_nodes 314 | ) 315 | 316 | self.gcbs = [] 317 | for i in range(self._num_stage): 318 | self.gcbs.append(GC_Block( 319 | self._hidden_dim, 320 | p_dropout=p_dropout, 321 | output_nodes=self._output_nodes) 322 | ) 323 | self.gcbs = nn.ModuleList(self.gcbs) 324 | 325 | self.do = nn.Dropout(p_dropout) 326 | self.act_f = nn.Tanh() 327 | 328 | self._back = nn.Sequential( 329 | nn.Linear(model_dim * self._output_nodes, model_dim), 330 | nn.Dropout(p_dropout) 331 | ) 332 | utils.weight_init(self._back, init_fn_=utils.xavier_init_) 333 | 334 | gcn_params = filter(lambda p: p.requires_grad, self.parameters()) 335 | nparams = sum([np.prod(p.size()) for p in gcn_params]) 336 | print('[INFO] ({}) GCN has {} params!'.format(self.__class__.__name__, nparams)) 337 | 338 | def forward(self, x): 339 | """Forward pass of network. 340 | 341 | Args: 342 | x: [batch_size, n_poses, pose_dim/input_dim]. 343 | """ 344 | B, S, D = x.size() 345 | # [batch_size, n_joints, model_dim] 346 | y = self.gc1(x.view(-1, self._output_nodes, self._input_features)) 347 | b, n, f = y.shape 348 | y = self.bn1(y.view(b, -1)).view(b, n, f) 349 | y = self.act_f(y) 350 | y = self.do(y) 351 | 352 | for i in range(self._num_stage): 353 | y = self.gcbs[i](y) 354 | 355 | # [batch_size, n_joints, model_dim] 356 | y = self.gc2(y) 357 | 358 | # [batch_size, model_dim] 359 | y = self._back(y.view(-1, self._model_dim * self._output_nodes)) 360 | 361 | # [batch_size, n_poses, model_dim] 362 | y = y.view(B, S, self._model_dim) 363 | 364 | return y 365 | 366 | 367 | def test_decoder(): 368 | seq_len = 25 369 | input_size = 63 370 | model_dim = 128 371 | dropout = 0.3 372 | n_stages = 2 373 | output_nodes = 21 374 | 375 | joint_dof = 1 376 | n_joints = model_dim 377 | layer = GraphConvolution( 378 | in_features=joint_dof, 379 | out_features=model_dim, 380 | output_nodes=n_joints 381 | ) 382 | 383 | X = torch.FloatTensor(10, n_joints, joint_dof) 384 | print(layer(X).size()) 385 | 386 | gcn = PoseGCN( 387 | input_features=model_dim, 388 | output_features=3, 389 | model_dim=model_dim, 390 | output_nodes=output_nodes, 391 | p_dropout=0.1, 392 | num_stage=2 393 | ) 394 | 395 | X = torch.FloatTensor(10 * seq_len, model_dim) 396 | print(gcn(X).size()) 397 | 398 | 399 | def test_encoder(): 400 | input_size = 63 401 | model_dim = 128 402 | dropout = 0.3 403 | n_stages = 2 404 | output_nodes = 21 405 | dof = 9 406 | 407 | encoder = SimpleEncoder( 408 | n_nodes=output_nodes, 409 | model_dim=model_dim, 410 | input_features=dof, 411 | p_dropout=0.1 412 | ) 413 | X = torch.FloatTensor(10, 25, output_nodes * dof) 414 | 415 | print(encoder(X).size()) 416 | 417 | 418 | if __name__ == '__main__': 419 | # test_decoder() 420 | test_encoder() 421 | 
-------------------------------------------------------------------------------- /model/poseformer/PoseTransformer.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | 25 | """Implementation of the Transformer for sequence-to-sequence decoding. 26 | 27 | Implementation of the transformer for sequence to sequence prediction as in 28 | [1] and [2]. 29 | 30 | [1] https://arxiv.org/pdf/1706.03762.pdf 31 | [2] https://arxiv.org/pdf/2005.12872.pdf 32 | """ 33 | 34 | 35 | import numpy as np 36 | import os 37 | import sys 38 | import copy 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | 43 | thispath = os.path.dirname(os.path.abspath(__file__)) 44 | sys.path.insert(0, thispath+"/../") 45 | 46 | from model import utils 47 | from model.poseformer import PositionEncodings 48 | from model.poseformer.Transformer import Transformer 49 | 50 | 51 | _SOURCE_LENGTH = 110 52 | _POSE_DIM = 54 53 | _PAD_LENGTH = _SOURCE_LENGTH 54 | 55 | 56 | class PoseTransformer(nn.Module): 57 | """Implements the sequence-to-sequence Transformer .model for pose prediction.""" 58 | def __init__(self, 59 | pose_dim=_POSE_DIM, 60 | source_seq_length=_SOURCE_LENGTH, 61 | model_dim=256, 62 | num_encoder_layers=6, 63 | num_heads=8, 64 | dim_ffn=2048, 65 | dropout=0.1, 66 | input_dim=None, 67 | init_fn=utils.xavier_init_, 68 | pre_normalization=False, 69 | pose_embedding=None, 70 | copy_method='uniform_scan', 71 | pos_encoding_params=(10000, 1)): 72 | """Initialization of pose transformers.""" 73 | super(PoseTransformer, self).__init__() 74 | self._source_seq_length = source_seq_length 75 | self._pose_dim = pose_dim 76 | self._input_dim = pose_dim if input_dim is None else input_dim 77 | self._model_dim = model_dim 78 | self._use_class_token = False 79 | 80 | self._mlp_dim = model_dim 81 | self._pose_embedding = pose_embedding 82 | thisname = self.__class__.__name__ 83 | self._copy_method = copy_method 84 | self._pos_encoding_params = pos_encoding_params 85 | 86 | self._transformer = Transformer( 87 | num_encoder_layers=num_encoder_layers, 88 | model_dim=model_dim, 89 | num_heads=num_heads, 90 | dim_ffn=dim_ffn, 91 | dropout=dropout, 92 | init_fn=init_fn, 93 | pre_normalization=pre_normalization, 94 | ) 95 | 96 | self._pos_encoder = PositionEncodings.PositionEncodings1D( 97 | num_pos_feats=self._model_dim, 98 | temperature=self._pos_encoding_params[0], 99 | alpha=self._pos_encoding_params[1] 100 | ) 101 | 102 | 
self.init_position_encodings()
103 | 
104 | 
105 |     def init_position_encodings(self):
106 |         src_len = self._source_seq_length
107 |         # when using a token we need an extra element in the sequence
108 |         if self._use_class_token:
109 |             src_len = src_len + 1
110 |         encoder_pos_encodings = self._pos_encoder(src_len).view(
111 |             src_len, 1, self._model_dim)
112 |         self._encoder_pos_encodings = nn.Parameter(
113 |             encoder_pos_encodings, requires_grad=False)
114 | 
115 | 
116 |     def forward(self,
117 |                 input_pose_seq,
118 |                 get_attn_weights=False,
119 |                 fold=None,
120 |                 eval_step=None):
121 |         """Performs the forward pass of the pose transformer encoder.
122 | 
123 |         Args:
124 |           input_pose_seq: Shape [batch_size, src_sequence_length, dim_pose].
125 |           get_attn_weights, fold, eval_step: Unused here; kept for interface compatibility.
126 | 
127 |         Returns:
128 |           The encoder memory, a tensor of shape [src_sequence_length,
129 |           batch_size, model_dim].
130 |         """
131 |         # 1) Encode the sequence with given pose encoder
132 |         # [batch_size, sequence_length, model_dim]
133 |         input_pose_seq = input_pose_seq
134 |         if self._pose_embedding is not None:
135 |             input_pose_seq = self._pose_embedding(input_pose_seq)
136 | 
137 |         # 2) arrange the sequence for the encoder (no look-ahead mask is needed here)
138 |         # [sequence_length, batch_size, model_dim]
139 |         input_pose_seq = torch.transpose(input_pose_seq, 0, 1)
140 | 
141 |         # 3) compute the attention weights using the transformer
142 |         # [source_sequence_length, batch_size, model_dim]
143 |         memory = self._transformer(
144 |             input_pose_seq,
145 |             encoder_position_encodings=self._encoder_pos_encodings,
146 |         )
147 | 
148 |         return memory
149 | 
150 | 
151 | 
152 | def model_factory(params, pose_embedding_fn):
153 |     init_fn = utils.normal_init_ \
154 |         if params['init_fn'] == 'normal_init' else utils.xavier_init_
155 |     return PoseTransformer(
156 |         pose_dim=params['pose_dim'],
157 |         input_dim=params['input_dim'],
158 |         source_seq_length=params['source_seq_len'],
159 |         model_dim=params['model_dim'],
160 |         num_encoder_layers=params['num_encoder_layers'],
161 |         num_heads=params['num_heads'],
162 |         dim_ffn=params['dim_ffn'],
163 |         dropout=params['dropout'],
164 |         init_fn=init_fn,
165 |         pose_embedding=pose_embedding_fn(params),
166 |         pos_encoding_params=(params['pos_enc_beta'], params['pos_enc_alpha'])
167 |     )
168 | 
169 | 
170 | if __name__ == '__main__':
171 |     transformer = PoseTransformer(model_dim=_POSE_DIM, num_heads=6)
172 |     transformer.eval()
173 |     batch_size = 8
174 |     model_dim = 256
175 |     # the encoder-only forward takes just a source sequence of _SOURCE_LENGTH frames
176 |     src_seq = torch.FloatTensor(batch_size, _SOURCE_LENGTH, _POSE_DIM).fill_(1)
177 | 
178 |     outputs = transformer(src_seq)
179 |     print(outputs[-1].size())
180 | 
181 | 
-------------------------------------------------------------------------------- /model/poseformer/PositionEncodings.py: --------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive
3 | # Transformers
4 | #
5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/
6 | # Written by
7 | # Angel Martinez ,
8 | #
9 | # This file is part of
10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers
11 | #
12 | # POTR is free software: you can redistribute it and/or modify
13 | # it under the terms of the GNU General Public License version 3 as
14 | # published by the Free
Software Foundation. 15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | 25 | """Implementation of the 2D positional encodings used in [1]. 26 | 27 | Position encodings gives a signature to each pixel in the image by a set 28 | of sine frequecies computed with a 2D sine function. 29 | 30 | [1] https://arxiv.org/abs/2005.12872 31 | [2] https://arxiv.org/pdf/1706.03762.pdf 32 | """ 33 | 34 | import numpy as np 35 | import math 36 | import torch 37 | from torch import nn 38 | 39 | 40 | class PositionEncodings2D(object): 41 | """Implementation of 2d masked position encodings as a NN layer. 42 | 43 | This is a more general version of the position embedding, very similar 44 | to the one used by the Attention is all you need paper, but generalized 45 | to work on images as used in [1]. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_pos_feats=64, 51 | temperature=10000, 52 | normalize=False, 53 | scale=None): 54 | """Constructs position embeding layer. 55 | 56 | Args: 57 | num_pos_feats: An integer for the depth of the encoding signature per 58 | pixel for each axis `x` and `y`. 59 | temperature: Value of the exponential temperature. 60 | normalize: Bool indicating if the encodings shuld be normalized by number 61 | of pixels in each image row. 62 | scale: Use for scaling factor. Normally None is used for 2*pi scaling. 63 | """ 64 | super().__init__() 65 | self._num_pos_feats = num_pos_feats 66 | self._temperature = temperature 67 | self._normalize = normalize 68 | if scale is not None and normalize is False: 69 | raise ValueError("normalize should be True if scale is passed") 70 | if scale is None: 71 | scale = 2 * math.pi 72 | self._scale = scale 73 | 74 | def __call__(self, mask): 75 | """Generates the positional encoding given image boolean mask. 76 | 77 | Args: 78 | mask: Boolean tensor of shape [batch_size, width, height] with ones 79 | in pixels that belong to the padding and zero in valid pixels. 80 | 81 | Returns: 82 | Sine position encodings. Shape [batch_size, num_pos_feats*2, width, height] 83 | """ 84 | # the positional encodings are generated for valid pixels hence 85 | # we need to take the negation of the boolean mask 86 | not_mask = ~mask 87 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 88 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 89 | if self._normalize: 90 | eps = 1e-6 91 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self._scale 92 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self._scale 93 | 94 | dim_t = torch.arange( 95 | self._num_pos_feats, dtype=torch.float32) 96 | dim_t = self._temperature ** (2 * (dim_t // 2) / self._num_pos_feats) 97 | 98 | pos_x = x_embed[:, :, :, None] / dim_t 99 | pos_y = y_embed[:, :, :, None] / dim_t 100 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), 101 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 102 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), 103 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 104 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 105 | 106 | return pos 107 | 108 | 109 | class PositionEncodings1D(object): 110 | """Positional encodings for `1D` sequences. 
111 | 112 | Implements the following equations: 113 | 114 | PE_{(pos, 2i)} = sin(pos/10000^{2i/d_model}) 115 | PE_{(pos, 2i+1)} = cos(pos/10000^{2i/d_model}) 116 | 117 | Where d_model is the number of positional features. Also known as the 118 | depth of the positional encodings. These are the positional encodings 119 | proposed in [2]. 120 | """ 121 | 122 | def __init__(self, num_pos_feats=512, temperature=10000, alpha=1): 123 | self._num_pos_feats = num_pos_feats 124 | self._temperature = temperature 125 | self._alpha = alpha 126 | 127 | def __call__(self, seq_length): 128 | angle_rads = self.get_angles( 129 | np.arange(seq_length)[:, np.newaxis], 130 | np.arange(self._num_pos_feats)[np.newaxis, :] 131 | ) 132 | 133 | # apply sin to even indices in the array; 2i 134 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 135 | # apply cos to odd indices in the array; 2i+1 136 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 137 | pos_encoding = angle_rads[np.newaxis, ...] 138 | pos_encoding = pos_encoding.astype(np.float32) 139 | 140 | return torch.from_numpy(pos_encoding) 141 | 142 | def get_angles(self, pos, i): 143 | angle_rates = 1 / np.power( 144 | self._temperature, (2 * (i // 2)) / np.float32(self._num_pos_feats)) 145 | return self._alpha * pos * angle_rates 146 | 147 | 148 | def visualize_2d_encodings(): 149 | import cv2 150 | import numpy as np 151 | import matplotlib.pyplot as pplt 152 | 153 | # Create a mask where pixels are all valid 154 | mask = torch.BoolTensor(1, 32, 32).fill_(False) 155 | # position encodigns with a signature of depth per pixel 156 | # the efective pixel signature is num_pos_feats*2 (128 for each axis) 157 | pos_encodings_gen = PositionEncodings2D(num_pos_feats=128, normalize=True) 158 | 159 | encodings = pos_encodings_gen(mask).numpy() 160 | print('Shape of encodings', encodings.shape) 161 | # visualize the first frequency channel for x and y 162 | y_encodings = encodings[0, 0, :, :] 163 | x_encodings = encodings[0, 128, :, :] 164 | 165 | pplt.matshow(x_encodings, cmap=pplt.get_cmap('jet')) 166 | pplt.matshow(y_encodings, cmap=pplt.get_cmap('jet')) 167 | pplt.show() 168 | 169 | 170 | def visualize_1d_encodings(): 171 | import matplotlib.pyplot as plt 172 | pos_encoder_gen = PositionEncodings1D() 173 | 174 | pos_encoding = pos_encoder_gen(50).numpy() 175 | print(pos_encoding.shape) 176 | 177 | plt.pcolormesh(pos_encoding[0], cmap='RdBu') 178 | plt.xlabel('Depth') 179 | plt.xlim((0, 512)) 180 | plt.ylabel('position in sequence') 181 | plt.colorbar() 182 | plt.show() 183 | 184 | 185 | if __name__ == "__main__": 186 | visualize_2d_encodings() 187 | # visualize_1d_encodings() 188 | -------------------------------------------------------------------------------- /model/poseformer/Transformer.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 
15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | 25 | 26 | """Implementation of the Transformer for sequence-to-sequence decoding. 27 | 28 | Implementation of the transformer for sequence to sequence prediction as in 29 | [1] and [2]. 30 | 31 | [1] https://arxiv.org/pdf/1706.03762.pdf 32 | [2] https://arxiv.org/pdf/2005.12872.pdf 33 | """ 34 | 35 | 36 | import numpy as np 37 | import os 38 | import sys 39 | import copy 40 | 41 | import torch 42 | import torch.nn as nn 43 | import torch.nn.functional as F 44 | 45 | from scipy.optimize import linear_sum_assignment 46 | 47 | thispath = os.path.dirname(os.path.abspath(__file__)) 48 | sys.path.insert(0, thispath+"/../") 49 | 50 | from model import utils 51 | import model.poseformer.TransformerEncoder as Encoder 52 | 53 | 54 | class Transformer(nn.Module): 55 | def __init__(self, 56 | num_encoder_layers=6, 57 | model_dim=256, 58 | num_heads=8, 59 | dim_ffn=2048, 60 | dropout=0.1, 61 | init_fn=utils.normal_init_, 62 | pre_normalization=False): 63 | """Implements the Transformer model for sequence-to-sequence modeling.""" 64 | super(Transformer, self).__init__() 65 | self._model_dim = model_dim 66 | self._num_heads = num_heads 67 | self._dim_ffn = dim_ffn 68 | self._dropout = dropout 69 | 70 | self._encoder = Encoder.TransformerEncoder( 71 | num_layers=num_encoder_layers, 72 | model_dim=model_dim, 73 | num_heads=num_heads, 74 | dim_ffn=dim_ffn, 75 | dropout=dropout, 76 | init_fn=init_fn, 77 | pre_normalization=pre_normalization 78 | ) 79 | 80 | 81 | def forward(self, 82 | source_seq, 83 | encoder_position_encodings=None): 84 | 85 | 86 | memory, enc_weights = self._encoder(source_seq, encoder_position_encodings) 87 | 88 | # Save encoder outputs 89 | # if fold is not None: 90 | # encoder_output_dir = 'encoder_outputs' 91 | # if not os.path.exists(f'{encoder_output_dir}f{fold}/'): 92 | # os.makedirs(f'{encoder_output_dir}f{fold}/') 93 | # outpath = f'{encoder_output_dir}f{fold}/{eval_step}.npy' 94 | # encoder_output = memory.detach().cpu().numpy() 95 | # np.save(outpath, encoder_output) 96 | 97 | return memory 98 | -------------------------------------------------------------------------------- /model/poseformer/TransformerEncoder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Pose Transformers (POTR): Human Motion Prediction with Non-Autoregressive 3 | # Transformers 4 | # 5 | # Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/ 6 | # Written by 7 | # Angel Martinez , 8 | # 9 | # This file is part of 10 | # POTR: Human Motion Prediction with Non-Autoregressive Transformers 11 | # 12 | # POTR is free software: you can redistribute it and/or modify 13 | # it under the terms of the GNU General Public License version 3 as 14 | # published by the Free Software Foundation. 15 | # 16 | # POTR is distributed in the hope that it will be useful, 17 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 18 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 19 | # GNU General Public License for more details. 20 | # 21 | # You should have received a copy of the GNU General Public License 22 | # along with POTR. If not, see . 23 | ############################################################################### 24 | 25 | """Implementation of Transformer encoder and encoder layer with self attention. 26 | 27 | Implementation of the encoder layer as in [1] and [2] for sequence to 28 | sequence modeling. 29 | 30 | [1] https://arxiv.org/pdf/1706.03762.pdf 31 | [2] https://arxiv.org/pdf/2005.12872.pdf 32 | """ 33 | 34 | import numpy as np 35 | import sys 36 | import os 37 | 38 | import torch 39 | import torch.nn as nn 40 | 41 | # thispath = os.path.dirname(os.path.abspath(__file__)) 42 | # sys.path.insert(0, thispath + "/../") 43 | from model import utils 44 | 45 | 46 | class EncoderLayer(nn.Module): 47 | """Implements the transformer encoder Layer.""" 48 | 49 | def __init__(self, 50 | model_dim=256, 51 | num_heads=8, 52 | dim_ffn=2048, 53 | dropout=0.1, 54 | init_fn=utils.normal_init_, 55 | pre_normalization=False): 56 | """Encoder layer initialization. 57 | 58 | Args: 59 | model_dim: 60 | num_heads: 61 | dim_ffn: 62 | dropout: 63 | """ 64 | super(EncoderLayer, self).__init__() 65 | self._model_dim = model_dim 66 | self._num_heads = num_heads 67 | self._dim_ffn = dim_ffn 68 | self._dropout = dropout 69 | self._pre_normalization = pre_normalization 70 | 71 | self._self_attn = nn.MultiheadAttention(model_dim, num_heads, dropout) 72 | self._relu = nn.ReLU() 73 | self._dropout_layer = nn.Dropout(self._dropout) 74 | 75 | self._linear1 = nn.Linear(model_dim, self._dim_ffn) 76 | self._linear2 = nn.Linear(self._dim_ffn, self._model_dim) 77 | self._norm1 = nn.LayerNorm(model_dim, eps=1e-5) 78 | self._norm2 = nn.LayerNorm(model_dim, eps=1e-5) 79 | 80 | utils.weight_init(self._linear1, init_fn_=init_fn) 81 | utils.weight_init(self._linear2, init_fn_=init_fn) 82 | 83 | def forward(self, source_seq, pos_encodings): 84 | """Computes forward pass according. 85 | 86 | Args: 87 | source_seq: [sequence_length, batch_size, model_dim]. 88 | pos_encodings: [sequence_length, model_dim]. 89 | 90 | Returns: 91 | Tensor of shape [sequence_length, batch_size, model_dim]. 92 | """ 93 | if self._pre_normalization: 94 | return self.forward_pre(source_seq, pos_encodings) 95 | 96 | return self.forward_post(source_seq, pos_encodings) 97 | 98 | def forward_post(self, source_seq, pos_encodings): 99 | """Computes decoder layer forward pass with pre normalization. 100 | 101 | Args: 102 | source_seq: [sequence_length, batch_size, model_dim]. 103 | pos_encodings: [sequence_length, model_dim]. 104 | 105 | Returns: 106 | Tensor of shape [sequence_length, batch_size, model_dim]. 
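          This is the post-normalization path (used when `pre_normalization` is False). The
          averaged self-attention weights, of shape
          [batch_size, sequence_length, sequence_length], are returned alongside the output,
          so an illustrative call is:

              output, attn_weights = layer.forward_post(source_seq, pos_encodings)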
107 | """ 108 | # add positional encodings to the input sequence 109 | # for self attention query is the same as key 110 | query = source_seq + pos_encodings 111 | key = query 112 | value = source_seq 113 | 114 | attn_output, attn_weights = self._self_attn( 115 | query, 116 | key, 117 | value, 118 | need_weights=True 119 | ) 120 | 121 | norm_attn = self._dropout_layer(attn_output) + source_seq 122 | norm_attn = self._norm1(norm_attn) 123 | 124 | output = self._linear1(norm_attn) 125 | output = self._relu(output) 126 | output = self._dropout_layer(output) 127 | output = self._linear2(output) 128 | output = self._dropout_layer(output) + norm_attn 129 | output = self._norm2(output) 130 | 131 | return output, attn_weights 132 | 133 | def forward_pre(self, source_seq_, pos_encodings): 134 | """Computes decoder layer forward pass with pre normalization. 135 | 136 | Args: 137 | source_seq: [sequence_length, batch_size, model_dim]. 138 | pos_encodings: [sequence_length, model_dim]. 139 | 140 | Returns: 141 | Tensor of shape [sequence_length, batch_size, model_dim]. 142 | """ 143 | # add positional encodings to the input sequence 144 | # for self attention query is the same as key 145 | source_seq = self._norm1(source_seq_) 146 | query = source_seq + pos_encodings 147 | key = query 148 | value = source_seq 149 | 150 | attn_output, attn_weights = self._self_attn( 151 | query, 152 | key, 153 | value, 154 | need_weights=True 155 | ) 156 | 157 | norm_attn_ = self._dropout_layer(attn_output) + source_seq_ 158 | norm_attn = self._norm2(norm_attn_) 159 | 160 | output = self._linear1(norm_attn) 161 | output = self._relu(output) 162 | output = self._dropout_layer(output) 163 | output = self._linear2(output) 164 | output = self._dropout_layer(output) + norm_attn_ 165 | 166 | return output, attn_weights 167 | 168 | 169 | class TransformerEncoder(nn.Module): 170 | def __init__(self, 171 | num_layers=6, 172 | model_dim=256, 173 | num_heads=8, 174 | dim_ffn=2048, 175 | dropout=0.1, 176 | init_fn=utils.normal_init_, 177 | pre_normalization=False): 178 | super(TransformerEncoder, self).__init__() 179 | """Transforme encoder initialization.""" 180 | self._num_layers = num_layers 181 | self._model_dim = model_dim 182 | self._num_heads = num_heads 183 | self._dim_ffn = dim_ffn 184 | self._dropout = dropout 185 | # self._norm = norm 186 | self._pre_normalization = pre_normalization 187 | 188 | self._encoder_stack = self.init_encoder_stack(init_fn) 189 | 190 | def init_encoder_stack(self, init_fn): 191 | """Create the stack of encoder layers.""" 192 | stack = nn.ModuleList() 193 | for s in range(self._num_layers): 194 | layer = EncoderLayer( 195 | model_dim=self._model_dim, 196 | num_heads=self._num_heads, 197 | dim_ffn=self._dim_ffn, 198 | dropout=self._dropout, 199 | init_fn=init_fn, 200 | pre_normalization=self._pre_normalization 201 | ) 202 | stack.append(layer) 203 | return stack 204 | 205 | def forward(self, input_sequence, pos_encodings): 206 | """Computes decoder forward pass. 207 | 208 | Args: 209 | source_seq: [sequence_length, batch_size, model_dim]. 210 | pos_encodings: [sequence_length, model_dim]. 211 | 212 | Returns: 213 | Tensor of shape [sequence_length, batch_size, model_dim]. 
214 | """ 215 | outputs = input_sequence 216 | 217 | for l in range(self._num_layers): 218 | outputs, attn_weights = self._encoder_stack[l](outputs, pos_encodings) 219 | 220 | # if self._norm: 221 | # outputs = self._norm(outputs) 222 | 223 | return outputs, attn_weights 224 | 225 | 226 | if __name__ == '__main__': 227 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 228 | seq_length = 50 229 | 230 | pos_encodings = torch.FloatTensor(seq_length, 1, 256).uniform_(0, 1) 231 | seq = torch.FloatTensor(seq_length, 8, 256).fill_(1.0) 232 | 233 | pos_encodings = pos_encodings.to(device) 234 | seq = seq.to(device) 235 | 236 | encoder = TransformerEncoder(num_layers=6) 237 | encoder.to(device) 238 | encoder.eval() 239 | 240 | print(encoder(seq, pos_encodings).size()) 241 | -------------------------------------------------------------------------------- /model/poseformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/model/poseformer/__init__.py -------------------------------------------------------------------------------- /model/poseformerv2/model_poseformer.py: -------------------------------------------------------------------------------- 1 | ## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 2 | # Written by Ce Zheng (cezheng@knights.ucf.edu) 3 | # Modified by Qitao Zhao (qitaozhao@mail.sdu.edu.cn) 4 | 5 | import math 6 | import logging 7 | from functools import partial 8 | from einops import rearrange 9 | 10 | import torch 11 | import torch_dct as dct 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | import numpy as np 16 | from timm.models.layers import DropPath 17 | 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.fc1 = nn.Linear(in_features, hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.act(x) 32 | x = self.drop(x) 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | class FreqMlp(nn.Module): 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or in_features 43 | self.fc1 = nn.Linear(in_features, hidden_features) 44 | self.act = act_layer() 45 | self.fc2 = nn.Linear(hidden_features, out_features) 46 | self.drop = nn.Dropout(drop) 47 | 48 | def forward(self, x): 49 | b, f, _ = x.shape 50 | x = dct.dct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous() 51 | x = self.fc1(x) 52 | x = self.act(x) 53 | x = self.drop(x) 54 | x = self.fc2(x) 55 | x = self.drop(x) 56 | x = dct.idct(x.permute(0, 2, 1)).permute(0, 2, 1).contiguous() 57 | return x 58 | 59 | 60 | class Attention(nn.Module): 61 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 62 | super().__init__() 63 | self.num_heads = num_heads 64 | head_dim = dim // num_heads 65 | # NOTE scale factor was wrong in my original version, can set 
manually to be compat with prev weights 66 | self.scale = qk_scale or head_dim ** -0.5 67 | 68 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 69 | self.attn_drop = nn.Dropout(attn_drop) 70 | self.proj = nn.Linear(dim, dim) 71 | self.proj_drop = nn.Dropout(proj_drop) 72 | 73 | def forward(self, x): 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 76 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 77 | 78 | attn = (q @ k.transpose(-2, -1)) * self.scale 79 | attn = attn.softmax(dim=-1) 80 | attn = self.attn_drop(attn) 81 | 82 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | return x 86 | 87 | 88 | class Block(nn.Module): 89 | 90 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 91 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 92 | super().__init__() 93 | self.norm1 = norm_layer(dim) 94 | self.attn = Attention( 95 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 96 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 97 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 98 | self.norm2 = norm_layer(dim) 99 | mlp_hidden_dim = int(dim * mlp_ratio) 100 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 101 | 102 | def forward(self, x): 103 | x = x + self.drop_path(self.attn(self.norm1(x))) 104 | x = x + self.drop_path(self.mlp(self.norm2(x))) 105 | return x 106 | 107 | 108 | class MixedBlock(nn.Module): 109 | 110 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 111 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 112 | super().__init__() 113 | self.norm1 = norm_layer(dim) 114 | self.attn = Attention( 115 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 116 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 117 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 118 | self.norm2 = norm_layer(dim) 119 | mlp_hidden_dim = int(dim * mlp_ratio) 120 | self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 121 | self.norm3 = norm_layer(dim) 122 | self.mlp2 = FreqMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 123 | 124 | def forward(self, x): 125 | b, f, c = x.shape 126 | x = x + self.drop_path(self.attn(self.norm1(x))) 127 | x1 = x[:, :f//2] + self.drop_path(self.mlp1(self.norm2(x[:, :f//2]))) 128 | x2 = x[:, f//2:] + self.drop_path(self.mlp2(self.norm3(x[:, f//2:]))) 129 | return torch.cat((x1, x2), dim=1) 130 | 131 | 132 | class PoseTransformerV2(nn.Module): 133 | def __init__(self, num_joints=17, in_chans=2, embed_dim_ratio=2, depth=1, 134 | num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, 135 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, 136 | number_of_kept_frames=1, number_of_kept_coeffs=1): 137 | """ ##########hybrid_backbone=None, representation_size=None, 138 | Args: 139 | num_joints (int, tuple): joints number 140 | in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) 141 | embed_dim_ratio (int): embedding dimension ratio 142 | depth (int): depth of transformer 143 | num_heads (int): number of attention heads 144 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 145 | qkv_bias (bool): enable bias for qkv if True 146 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 147 | drop_rate (float): dropout rate 148 | attn_drop_rate (float): attention dropout rate 149 | drop_path_rate (float): stochastic depth rate 150 | norm_layer: (nn.Module): normalization layer 151 | """ 152 | super().__init__() 153 | 154 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 155 | embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio 156 | out_dim = num_joints * 3 #### output dimension is num_joints * 3 157 | self.num_frame_kept = number_of_kept_frames 158 | self.num_coeff_kept = number_of_kept_coeffs 159 | 160 | ### spatial patch embedding 161 | self.Joint_embedding = nn.Linear(in_chans, embed_dim_ratio) 162 | self.Freq_embedding = nn.Linear(in_chans*num_joints, embed_dim) 163 | 164 | self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) 165 | self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, self.num_frame_kept, embed_dim)) 166 | self.Temporal_pos_embed_ = nn.Parameter(torch.zeros(1, self.num_coeff_kept, embed_dim)) 167 | self.pos_drop = nn.Dropout(p=drop_rate) 168 | 169 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 170 | 171 | self.Spatial_blocks = nn.ModuleList([ 172 | Block( 173 | dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 175 | for i in range(depth)]) 176 | 177 | self.blocks = nn.ModuleList([ 178 | MixedBlock( 179 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 180 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 181 | for i in range(depth)]) 182 | 183 | self.Spatial_norm = norm_layer(embed_dim_ratio) 184 | self.Temporal_norm = norm_layer(embed_dim) 185 | 186 | ####### A easy way to implement weighted mean 187 | self.weighted_mean = torch.nn.Conv1d(in_channels=self.num_coeff_kept, out_channels=1, 
kernel_size=1) 188 | self.weighted_mean_ = torch.nn.Conv1d(in_channels=self.num_frame_kept, out_channels=1, kernel_size=1) 189 | 190 | self.head = nn.Sequential( 191 | nn.LayerNorm(embed_dim*2), 192 | nn.Linear(embed_dim*2, out_dim), 193 | ) 194 | 195 | def Spatial_forward_features(self, x): 196 | b, f, p, _ = x.shape ##### b is batch size, f is number of frames, p is number of joints 197 | num_frame_kept = self.num_frame_kept 198 | 199 | index = torch.arange((f-1)//2-num_frame_kept//2, (f-1)//2+num_frame_kept//2+1) 200 | 201 | x = self.Joint_embedding(x[:, index].view(b*num_frame_kept, p, -1)) 202 | x += self.Spatial_pos_embed 203 | x = self.pos_drop(x) 204 | 205 | for blk in self.Spatial_blocks: 206 | x = blk(x) 207 | x = self.Spatial_norm(x) 208 | x = rearrange(x, '(b f) p c -> b f (p c)', f=num_frame_kept) 209 | return x 210 | 211 | def forward_features(self, x, Spatial_feature): 212 | b, f, p, _ = x.shape 213 | num_coeff_kept = self.num_coeff_kept 214 | 215 | x = dct.dct(x.permute(0, 2, 3, 1))[:, :, :, :num_coeff_kept] 216 | x = x.permute(0, 3, 1, 2).contiguous().view(b, num_coeff_kept, -1) 217 | x = self.Freq_embedding(x) 218 | 219 | Spatial_feature += self.Temporal_pos_embed 220 | x += self.Temporal_pos_embed_ 221 | x = torch.cat((x, Spatial_feature), dim=1) 222 | 223 | for blk in self.blocks: 224 | x = blk(x) 225 | 226 | x = self.Temporal_norm(x) 227 | return x 228 | 229 | def forward(self, x, return_rep=True): 230 | b, f, p, _ = x.shape 231 | x_ = x.clone() 232 | Spatial_feature = self.Spatial_forward_features(x) 233 | x = self.forward_features(x_, Spatial_feature) 234 | x = torch.cat((self.weighted_mean(x[:, :self.num_coeff_kept]), self.weighted_mean_(x[:, self.num_coeff_kept:])), dim=-1) 235 | 236 | if return_rep: 237 | return x 238 | x = self.head(x).view(b, 1, p, -1) 239 | return x 240 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normal_init_(layer, mean_, sd_, bias, norm_bias=True): 6 | """Intialization of layers with normal distribution with mean and bias""" 7 | classname = layer.__class__.__name__ 8 | # Only use the convolutional layers of the module 9 | # if (classname.find('Conv') != -1 ) or (classname.find('Linear')!=-1): 10 | if classname.find('Linear') != -1: 11 | print('[INFO] (normal_init) Initializing layer {}'.format(classname)) 12 | layer.weight.data.normal_(mean_, sd_) 13 | if norm_bias: 14 | layer.bias.data.normal_(bias, 0.05) 15 | else: 16 | layer.bias.data.fill_(bias) 17 | 18 | 19 | def weight_init( 20 | module, 21 | mean_=0, 22 | sd_=0.004, 23 | bias=0.0, 24 | norm_bias=False, 25 | init_fn_=normal_init_): 26 | """Initialization of layers with normal distribution""" 27 | moduleclass = module.__class__.__name__ 28 | try: 29 | for layer in module: 30 | if layer.__class__.__name__ == 'Sequential': 31 | for l in layer: 32 | init_fn_(l, mean_, sd_, bias, norm_bias) 33 | else: 34 | init_fn_(layer, mean_, sd_, bias, norm_bias) 35 | except TypeError: 36 | init_fn_(module, mean_, sd_, bias, norm_bias) 37 | 38 | 39 | def xavier_init_(layer, mean_, sd_, bias, norm_bias=True): 40 | classname = layer.__class__.__name__ 41 | if classname.find('Linear') != -1: 42 | print('[INFO] (xavier_init) Initializing layer {}'.format(classname)) 43 | nn.init.xavier_uniform_(layer.weight.data) 44 | # nninit.xavier_normal(layer.bias.data) 45 | if norm_bias: 46 | layer.bias.data.normal_(0, 0.05) 47 
| else: 48 | layer.bias.data.zero_() 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torch-dct 3 | numpyencoder 4 | matplotlib 5 | colorama 6 | pandas 7 | numpy -------------------------------------------------------------------------------- /stat_analysis/get_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pandas as pd 4 | from scipy import stats 5 | 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | import numpy as np 9 | 10 | def get_stats(video_names, predictions, rep_out, type): 11 | with open(os.path.join(rep_out, f'stats_output_{type}.txt'), 'w') as f: 12 | original_stdout = sys.stdout 13 | sys.stdout = f 14 | data = pd.DataFrame({'VideoName': video_names, 'predScore': predictions}) 15 | 16 | data['Participant'] = data['VideoName'].apply(lambda x: x.split('_')[0]) 17 | data['MedicationStatus'] = data['VideoName'].apply(lambda x: x.split('_')[1]) 18 | 19 | on_medication_data = data[data['MedicationStatus'] == 'on'] 20 | mean_scores_on = on_medication_data.groupby('Participant')['predScore'].mean() 21 | 22 | off_medication_data = data[data['MedicationStatus'] == 'off'] 23 | mean_scores_off = off_medication_data.groupby('Participant')['predScore'].mean() 24 | 25 | print(mean_scores_on) 26 | print(mean_scores_off) 27 | 28 | participants_on = set(mean_scores_on.index) 29 | participants_off = set(mean_scores_off.index) 30 | common_participants = participants_on.intersection(participants_off) 31 | mean_scores_on_paired = mean_scores_on[mean_scores_on.index.isin(common_participants)] 32 | mean_scores_off_paired = mean_scores_off[mean_scores_off.index.isin(common_participants)] 33 | 34 | stat, p = stats.shapiro(mean_scores_on_paired) 35 | print('Shapiro-Wilk Test statistics=%.3f, p=%.3f' % (stat, p)) 36 | alpha = 0.05 37 | if p > alpha: 38 | print('ON Sample looks Gaussian (fail to reject H0)') 39 | else: 40 | print('ON Sample does not look Gaussian (reject H0)') 41 | 42 | stat, p = stats.shapiro(mean_scores_off_paired) 43 | print('Shapiro-Wilk Test statistics=%.3f, p=%.3f' % (stat, p)) 44 | alpha = 0.05 45 | if p > alpha: 46 | print('OFF Sample looks Gaussian (fail to reject H0)') 47 | else: 48 | print('OFF Sample does not look Gaussian (reject H0)') 49 | 50 | if mean_scores_on_paired.index.equals(mean_scores_off_paired.index): 51 | print("Datasets are properly paired.") 52 | else: 53 | print("Datasets are not properly paired.") 54 | 55 | stat, p_value = stats.wilcoxon(mean_scores_on_paired, mean_scores_off_paired) 56 | print(f'Wilcoxon Signed-Rank Test Statistic: {stat}') 57 | print(f'P-Value: {p_value}') 58 | 59 | t_statistic, p_value = stats.ttest_rel(mean_scores_on_paired, mean_scores_off_paired) 60 | print(f"t-test: T-Statistic: {t_statistic}") 61 | print(f"P-Value: {p_value}") 62 | 63 | sys.stdout = original_stdout 64 | 65 | 66 | fig, axes = plt.subplots(1, 2, figsize=(14, 7)) 67 | 68 | # Histogram for ON medication on the left subplot 69 | sns.histplot(mean_scores_on_paired, kde=True, bins=[-0.3, 0,0.3,0.7,1,1.3, 1.7, 2,2.3], color='blue', label='ON Medication', ax=axes[0]) 70 | axes[0].set_title('ON Medication') 71 | axes[0].set_xlabel('UPDRS Scores') 72 | axes[0].set_ylabel('Frequency') 73 | axes[0].legend() 74 | 75 | # Histogram for OFF medication on the right subplot 76 | sns.histplot(mean_scores_off_paired, kde=True, bins=[-0.3, 
0,0.3,0.7,1,1.3, 1.7, 2,2.3], color='red', label='OFF Medication', ax=axes[1]) 77 | axes[1].set_title('OFF Medication') 78 | axes[1].set_xlabel('UPDRS Scores') 79 | axes[1].set_ylabel('Frequency') 80 | axes[1].legend() 81 | 82 | # Adjust the layout 83 | plt.tight_layout() 84 | 85 | plt.savefig(os.path.join(rep_out, f'pred_distribution_histogram_{type}.png')) 86 | plt.close() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import pickle 4 | import pandas as pd 5 | import torch 6 | from torch import nn 7 | import wandb 8 | 9 | import pkg_resources 10 | from sklearn.metrics import classification_report, confusion_matrix 11 | 12 | from data.dataloaders import dataset_factory 13 | from model.motion_encoder import MotionEncoder 14 | from model.backbone_loader import load_pretrained_backbone, count_parameters, load_pretrained_weights 15 | from train import train_model, final_test 16 | from utility import utils 17 | from const import path 18 | from eval_encoder import log_results 19 | from stat_analysis.get_stats import get_stats 20 | 21 | 22 | def setup_experiment_path(params): 23 | exp_path = path.OUT_PATH + os.path.join(params['model_prefix'], str(params['last_run_foldnum'])) 24 | if not os.path.exists(exp_path): 25 | os.makedirs(exp_path) 26 | params['model_prefix'] = os.path.join(params['model_prefix'], str(params['last_run_foldnum'])) 27 | rep_out = path.OUT_PATH + os.path.join(params['model_prefix']) 28 | return params, rep_out 29 | 30 | 31 | def initialize_wandb(params): 32 | wandb.init(name=params['wandb_name'], project='MotionEncoderEvaluator_PD', settings=wandb.Settings(start_method='fork')) 33 | installed_packages = {d.project_name: d.version for d in pkg_resources.working_set} 34 | wandb.config.update(params) 35 | wandb.config.update({'installed_packages': installed_packages}) 36 | 37 | 38 | def map_to_classifier_dim(backbone_name, option): 39 | classifier_dims = { 40 | 'poseformer': {'option1': []}, 41 | 'motionbert': {'option1': []}, 42 | 'poseformerv2': {'option1': []}, 43 | 'mixste': {'option1': []}, 44 | 'motionagformer': {'option1': []} 45 | } 46 | return classifier_dims[backbone_name][option] 47 | 48 | 49 | def configure_params_for_best_model(params, backbone_name): 50 | best_params = { 51 | "lr": 1e-05, 52 | "num_epochs": 20, 53 | "num_hidden_layers": 2, 54 | "layer_sizes": [256, 50, 16, 3], 55 | "optimizer": 'RMSprop', 56 | "use_weighted_loss": True, 57 | "batch_size": 128, 58 | "dropout_rate": 0.1, 59 | 'weight_decay': 0.00057, 60 | 'momentum': 0.66 61 | } 62 | print_best_model_configuration(best_params, backbone_name) 63 | update_params_with_best(params, best_params, backbone_name) 64 | return params 65 | 66 | 67 | def print_best_model_configuration(best_params, backbone_name): 68 | print("====================================BEST MODEL====================================================") 69 | print(f"Trial {best_params['best_trial_number']}, lr: {best_params['lr']}, num_epochs: {best_params['num_epochs']}") 70 | print(f"classifier_hidden_dims: {map_to_classifier_dim(backbone_name, 'option1')}") 71 | print(f"optimizer_name: {best_params['optimizer']}, use_weighted_loss: {best_params['use_weighted_loss']}") 72 | print("========================================================================================") 73 | 74 | 75 | def update_params_with_best(params, best_params, backbone_name): 76 | 
params['classifier_dropout'] = best_params['dropout_rate'] 77 | params['classifier_hidden_dims'] = map_to_classifier_dim(backbone_name, 'option1') 78 | params['optimizer'] = best_params['optimizer'] 79 | params['lr_head'] = best_params['lr'] 80 | params['lambda_l1'] = best_params['lambda_l1'] 81 | params['epochs'] = best_params['num_epochs'] 82 | params['criterion'] = 'WCELoss' if best_params['use_weighted_loss'] else 'CrossEntropyLoss' 83 | if params['optimizer'] in ['AdamW', 'Adam', 'RMSprop']: 84 | params['weight_decay'] = best_params['weight_decay'] 85 | if params['optimizer'] == 'SGD': 86 | params['momentum'] = best_params['momentum'] 87 | params['wandb_name'] = params['wandb_name'] + '_test' + str(params['last_run_foldnum']) 88 | 89 | 90 | def run_fold_tests(params, all_folds, backbone_name, device, rep_out): 91 | splits = setup_datasets(params, backbone_name, all_folds) 92 | return run_tests_for_each_fold(params, splits, backbone_name, device, rep_out) 93 | 94 | 95 | def setup_datasets(params, backbone_name, all_folds): 96 | splits = [] 97 | for fold in all_folds: 98 | train_dataset_fn, test_dataset_fn, val_dataset_fn, class_weights = dataset_factory(params, backbone_name, fold) 99 | splits.append((train_dataset_fn, val_dataset_fn, test_dataset_fn, class_weights)) 100 | return splits 101 | 102 | 103 | def run_tests_for_each_fold(params, splits, backbone_name, device, rep_out): 104 | total_outs_best, total_outs_last, total_gts, total_logits, total_states, total_video_names = [], [], [], [], [], [] 105 | for fold, (train_dataset_fn, val_dataset_fn, test_dataset_fn, class_weights) in enumerate(splits): 106 | process_fold(fold, params, backbone_name, train_dataset_fn, val_dataset_fn, test_dataset_fn, class_weights, device, total_outs_best, total_gts, total_logits, total_states, total_video_names, total_outs_last, rep_out) 107 | return total_outs_best, total_gts, total_states, total_video_names, total_outs_last 108 | 109 | 110 | def process_fold(fold, params, backbone_name, train_dataset_fn, val_dataset_fn, test_dataset_fn, class_weights, device, total_outs_best, total_gts, total_logits, total_states, total_video_names, total_outs_last, rep_out): 111 | start_time = datetime.datetime.now() 112 | params['input_dim'] = train_dataset_fn.dataset._pose_dim 113 | params['pose_dim'] = train_dataset_fn.dataset._pose_dim 114 | params['num_joints'] = train_dataset_fn.dataset._NMAJOR_JOINTS 115 | 116 | model_backbone = load_pretrained_backbone(params, backbone_name) 117 | model = MotionEncoder(backbone=model_backbone, 118 | params=params, 119 | num_classes=params['num_classes'], 120 | num_joints=params['num_joints'], 121 | train_mode=params['train_mode']) 122 | model = model.to(device) 123 | if torch.cuda.device_count() > 1: 124 | print("Using", torch.cuda.device_count(), "GPUs!") 125 | model = nn.DataParallel(model) 126 | if fold == 1: 127 | model_params = count_parameters(model) 128 | print(f"[INFO] Model has {model_params} parameters.") 129 | 130 | train_model(params, class_weights, train_dataset_fn, val_dataset_fn, model, fold, backbone_name) 131 | 132 | checkpoint_root_path = os.path.join(path.OUT_PATH, params['model_prefix'],'models', f"fold{fold}") 133 | best_ckpt_path = os.path.join(checkpoint_root_path, 'best_epoch.pth.tr') 134 | load_pretrained_weights(model, checkpoint=torch.load(best_ckpt_path)['model']) 135 | model.cuda() 136 | outs, gts, logits, states, video_names = final_test(model, test_dataset_fn, params) 137 | total_outs_best.extend(outs) 138 | total_gts.extend(gts) 139 | 
total_states.extend(states) 140 | total_video_names.extend(video_names) 141 | print(f'fold # of test samples: {len(video_names)}') 142 | print(f'current sum # of test samples: {len(total_video_names)}') 143 | attributes = [total_outs_best, total_gts] 144 | names = ['predicted_classes', 'true_labels'] 145 | res_dir = path.OUT_PATH + os.path.join(params['model_prefix'], 'results') 146 | if not os.path.exists(res_dir): 147 | os.makedirs(res_dir) 148 | utils.save_json(os.path.join(res_dir, 'results_Best_fold{}.json'.format(fold)), attributes, names) 149 | 150 | total_logits.extend(logits) 151 | attributes = [total_logits, total_gts] 152 | 153 | logits_dir = path.OUT_PATH + os.path.join(params['model_prefix'], 'logits') 154 | if not os.path.exists(logits_dir): 155 | os.makedirs(logits_dir) 156 | utils.save_json(os.path.join(logits_dir, 'logits_Best_fold{}.json'.format(fold)), attributes, names) 157 | 158 | last_ckpt_path = os.path.join(checkpoint_root_path, 'latest_epoch.pth.tr') 159 | load_pretrained_weights(model, checkpoint=torch.load(last_ckpt_path)['model']) 160 | model.cuda() 161 | outs_last, gts, logits, states, video_names = final_test(model, test_dataset_fn, params) 162 | total_outs_last.extend(outs_last) 163 | attributes = [total_outs_last, total_gts] 164 | utils.save_json(os.path.join(res_dir, 'results_last_fold{}.json'.format(fold)), attributes, names) 165 | 166 | res = pd.DataFrame({'total_video_names': total_video_names, 'total_outs_best': total_outs_best, 'total_outs_last': total_outs_last, 'total_gts':total_gts, 'total_states':total_states}) 167 | with open(os.path.join(rep_out, f'total_results_fold{fold}.pkl'), 'wb') as file: 168 | pickle.dump(res, file) 169 | 170 | end_time = datetime.datetime.now() 171 | 172 | duration = end_time - start_time 173 | print(f"Fold {fold} run time:", duration) 174 | 175 | 176 | def calculate_metrics(outputs, targets, states, phase, report_prefix, output_dir): 177 | # Filter outputs and targets based on the phase ('ON' or 'OFF') 178 | filtered_gts = [gt for gt, state in zip(targets, states) if state == phase] 179 | filtered_outs = [out for out, state in zip(outputs, states) if state == phase] 180 | 181 | report = classification_report(filtered_gts, filtered_outs) 182 | confusion = confusion_matrix(filtered_gts, filtered_outs) 183 | 184 | log_results( 185 | report, confusion, 186 | f'{report_prefix}_allfolds_{phase}.txt', 187 | f'{report_prefix}_confusion_matrix_allfolds_{phase}.png', 188 | output_dir 189 | ) 190 | 191 | def process_reports(outputs_best, outputs_last, targets, states, output_dir): 192 | # Process reports for 'best' and 'last' data 193 | for prefix, outputs in [('best', outputs_best), ('last', outputs_last)]: 194 | print(f"=========={prefix.upper()} REPORTS============") 195 | # Full dataset metrics 196 | report_final = classification_report(targets, outputs) 197 | confusion_final = confusion_matrix(targets, outputs) 198 | log_results(report_final, confusion_final, f'{prefix}_report_allfolds.txt', f'{prefix}_confusion_matrix_allfolds.png', output_dir) 199 | 200 | # 'ON' and 'OFF' group metrics 201 | for phase in ['ON', 'OFF']: 202 | calculate_metrics(outputs, targets, states, phase, prefix, output_dir) 203 | 204 | def save_and_load_results(video_names, outputs_best, outputs_last, targets, output_dir): 205 | results = pd.DataFrame({ 206 | 'total_video_names': video_names, 207 | 'total_outs_best': outputs_best, 208 | 'total_outs_last': outputs_last, 209 | 'total_gts': targets 210 | }) 211 | results_path = os.path.join(output_dir, 
'final_results.pkl') 212 | with open(results_path, 'wb') as file: 213 | pickle.dump(results, file) 214 | 215 | with open(results_path, 'rb') as file: 216 | loaded_results = pickle.load(file) 217 | 218 | total_video_names = loaded_results['total_video_names'] 219 | total_outs_best = loaded_results['total_outs_best'] 220 | total_outs_last = loaded_results['total_outs_last'] 221 | 222 | get_stats(total_video_names, total_outs_best, output_dir, 'best') 223 | get_stats(total_video_names, total_outs_last, output_dir, 'last') 224 | 225 | 226 | def test_and_report(params, new_params, all_folds, backbone_name, device): 227 | params, rep_out = setup_experiment_path(params) 228 | params = configure_params_for_best_model(params, backbone_name) 229 | initialize_wandb(params) 230 | total_outs_best, total_gts, total_states, total_video_names, total_outs_last = run_fold_tests(params, all_folds, backbone_name, device, rep_out) 231 | process_reports(total_outs_best, total_outs_last, total_gts, total_states, rep_out) 232 | save_and_load_results(total_video_names, total_outs_best, total_outs_last, total_gts, rep_out) 233 | wandb.finish() 234 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import wandb 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from learning.criterion import choose_criterion 9 | from learning.optimizer import choose_optimizer, choose_scheduler 10 | from learning.utils import AverageMeter, accuracy, save_checkpoint, assert_learning_params 11 | from const import path, const 12 | from utility.utils import is_substring, check_uniformity_and_get_first_elements 13 | 14 | from collections import Counter, defaultdict 15 | from sklearn.metrics import f1_score 16 | 17 | import time 18 | from collections import Counter 19 | 20 | device = const._DEVICE 21 | 22 | def final_test(model, test_loader, params): 23 | model.eval() 24 | video_logits = defaultdict(list) 25 | video_predclasses = defaultdict(list) 26 | video_labels = defaultdict(list) 27 | video_indices = defaultdict(list) 28 | video_states = defaultdict(list) 29 | video_names = defaultdict(list) 30 | 31 | loop = tqdm(test_loader) 32 | with torch.no_grad(): 33 | for x, y, video_idx, metadata in loop: 34 | x, y = x.to(device), y.to(device) 35 | metadata = metadata.to(device) 36 | if params['medication']: 37 | vi = video_idx.tolist() 38 | vn = [test_loader.dataset.video_names[i] for i in vi] 39 | on_off = [1 if 'on' in name else 0 for name in vn] 40 | on_off = torch.tensor(on_off, dtype=torch.float32, device=device) 41 | out = model(x, metadata, on_off) 42 | else: 43 | out = model(x, metadata) 44 | 45 | # Assuming out is a single tensor representing the output of the model for all clips 46 | summed_logits = torch.sum(out, dim=0).cpu().numpy() 47 | 48 | # Get the predicted class 49 | predicted_class = torch.argmax(torch.sum(out, dim=0).cpu()).item() 50 | 51 | # Append the logits, predicted class, and ground truth label for the video 52 | video_logits[video_idx.item()].append(summed_logits) 53 | video_predclasses[video_idx.item()].append(predicted_class) 54 | video_labels[video_idx.item()].append(y[0].item()) 55 | video_indices[video_idx.item()].append(video_idx) 56 | 57 | # Retrieve and store the video name using video_idx 58 | video_name = test_loader.dataset.video_names[video_idx.item()] 59 | # video_names.append(video_name) 60 | 
video_names[video_idx.item()].append(video_name) 61 | 62 | if is_substring('on', test_loader.dataset.video_names[video_idx]): 63 | video_states[video_idx.item()].append('ON') 64 | else: 65 | video_states[video_idx.item()].append('OFF') 66 | 67 | #Just to make sure everything is ok with the process of gathering clips 68 | video_labels = check_uniformity_and_get_first_elements(list(video_labels.values())) 69 | video_indices = check_uniformity_and_get_first_elements(list(video_indices.values())) 70 | video_states = check_uniformity_and_get_first_elements(list(video_states.values())) 71 | video_names = check_uniformity_and_get_first_elements(list(video_names.values())) 72 | 73 | # Summing logits in each clip 74 | summed_video_logits = {idx: np.sum(logits, axis=0) for idx, logits in video_logits.items()} 75 | # Majority vote for predicted classes 76 | majority_vote_classes = {} 77 | for idx, classes in video_predclasses.items(): 78 | class_counts = Counter(classes) 79 | majority_class = class_counts.most_common(1)[0][0] 80 | majority_vote_classes[idx] = majority_class 81 | 82 | return list(majority_vote_classes.values()), video_labels, list(summed_video_logits.values()), video_states, video_names 83 | 84 | 85 | def validate_model(model, validation_loader, params, class_weights): 86 | criterion = choose_criterion(params['criterion'], params, class_weights) 87 | 88 | if torch.cuda.is_available(): 89 | model = model.to(device) 90 | criterion = criterion.to(device) 91 | else: 92 | raise Exception("Cuda is not enabled") 93 | 94 | model.eval() 95 | accuracies = AverageMeter() 96 | losses = AverageMeter() 97 | all_preds = [] 98 | all_labels = [] 99 | with torch.no_grad(): 100 | video_predictions = defaultdict(list) 101 | video_predictions_labels = defaultdict(list) 102 | video_labels = {} 103 | 104 | for x, y, video_idx, metadata in validation_loader: 105 | x, y = x.to(device), y.to(device) 106 | metadata = metadata.to(device) 107 | batch_size = x.shape[0] 108 | 109 | if params['medication']: 110 | vi = video_idx.tolist() 111 | vn = [validation_loader.dataset.video_names[i] for i in vi] 112 | on_off = [1 if 'on' in name else 0 for name in vn] 113 | on_off = torch.tensor(on_off, dtype=torch.float32, device=device) 114 | out = model(x, metadata, on_off) 115 | else: 116 | out = model(x, metadata) 117 | _, out_label = torch.max(out, 1) 118 | 119 | loss = criterion(out, y) 120 | losses.update(loss.item(), batch_size) 121 | 122 | for i, idx in enumerate(video_idx): 123 | video_predictions_labels[idx.item()].append(out_label[i].detach()) 124 | video_predictions[idx.item()].append(out[i].detach()) 125 | video_labels[idx.item()] = y[i].item() 126 | 127 | total_correct = 0 128 | total_videos = 0 129 | for video_idx in video_predictions: 130 | predictions = video_predictions[video_idx] 131 | label_predictions = video_predictions_labels[video_idx] 132 | label_predictions = [label.item() for label in label_predictions] 133 | 134 | video_prediction = torch.stack(predictions).mean(dim=0).unsqueeze(0) 135 | video_label = torch.tensor([video_labels[video_idx]], device=video_prediction.device) 136 | label_counts = Counter(label_predictions) 137 | video_prediction_label, _ = label_counts.most_common(1)[0] 138 | 139 | acc, = accuracy(video_prediction, video_label) 140 | total_correct += acc 141 | 142 | total_videos += 1 143 | all_preds.append(video_prediction_label) 144 | all_labels.extend(video_label.cpu().numpy()) 145 | 146 | video_accuracy = total_correct / total_videos 147 | accuracies.update(video_accuracy, 
total_videos) 148 | val_f1_score = f1_score(all_labels, all_preds, average='weighted') 149 | 150 | return losses.avg, accuracies.avg, val_f1_score 151 | 152 | 153 | def train_model(params, class_weights, train_loader, val_loader, model, fold, backbone_name, mode="RUN"): 154 | assert_learning_params(params) 155 | 156 | criterion = choose_criterion(params['criterion'], params, class_weights) 157 | 158 | if torch.cuda.is_available(): 159 | model = model.to(device) 160 | criterion = criterion.to(device) 161 | else: 162 | raise Exception("Cuda is not enabled") 163 | 164 | optimizer = choose_optimizer(model, params) 165 | scheduler = choose_scheduler(optimizer, params) 166 | 167 | checkpoint_root_path = os.path.join(path.OUT_PATH, params['model_prefix'],'models') 168 | if not os.path.exists(checkpoint_root_path): os.mkdir(checkpoint_root_path) 169 | 170 | loop = tqdm(range(params['epochs']), desc=f'Training (fold{fold})', unit="epoch") 171 | for epoch in loop: 172 | # print(f"[INFO] epoch {epoch}") 173 | train_acc = AverageMeter() 174 | train_loss = AverageMeter() 175 | 176 | model.train() 177 | 178 | video_predictions = defaultdict(list) 179 | video_labels = {} 180 | 181 | epoch_start_time = time.time() 182 | for x, y, video_idx, metadata in train_loader: 183 | x, y = x.to(device), y.to(device) 184 | metadata = metadata.to(device) 185 | 186 | 187 | batch_size = x.shape[0] 188 | optimizer.zero_grad() 189 | 190 | if params['medication']: 191 | vi = video_idx.tolist() 192 | vn = [train_loader.dataset.video_names[i] for i in vi] 193 | on_off = [1 if 'on' in name else 0 for name in vn] 194 | on_off = torch.tensor(on_off, dtype=torch.float32, device=device) 195 | out = model(x, metadata, on_off) 196 | else: 197 | out = model(x, metadata) 198 | 199 | loss = criterion(out, y) 200 | train_loss.update(loss.item(), batch_size) 201 | 202 | for i, idx in enumerate(video_idx): 203 | video_predictions[idx.item()].append(out[i].detach()) 204 | video_labels[idx.item()] = y[i].item() 205 | 206 | if params['lambda_l1'] > 0: 207 | learnable_params = torch.cat([param.view(-1) for param in model.parameters() if param.requires_grad]) 208 | l1_regularization = torch.norm(learnable_params, p=1) 209 | 210 | loss += params['lambda_l1'] * l1_regularization 211 | 212 | loss.backward() 213 | optimizer.step() 214 | 215 | 216 | epoch_time = time.time() - epoch_start_time 217 | print(f"Epoch {epoch} completed in {epoch_time:.2f}s") 218 | # Compute accuracy per video 219 | total_correct = 0 220 | total_videos = 0 221 | for video_idx, predictions in video_predictions.items(): 222 | video_prediction = torch.stack(predictions).mean(dim=0).unsqueeze(0) 223 | video_label = torch.tensor([video_labels[video_idx]], device=video_prediction.device) 224 | 225 | acc, = accuracy(video_prediction, video_label) 226 | total_correct += acc 227 | total_videos += 1 228 | 229 | video_accuracy = total_correct / total_videos 230 | train_acc.update(video_accuracy, total_videos) 231 | 232 | val_loss, val_acc, val_f1_score = validate_model(model, val_loader, params, class_weights) 233 | 234 | lr_backbone = optimizer.param_groups[0]['lr'] 235 | 236 | if scheduler: 237 | scheduler.step() 238 | 239 | loop.set_postfix(train_loss=train_loss.avg, train_accuracy=train_acc.avg, 240 | val_loss=val_loss, val_accuracy=val_acc, val_f1_score=val_f1_score) 241 | 242 | log_wandb(epoch, fold, lr_backbone, train_acc, train_loss, 1, val_acc, 243 | val_loss, val_f1_score) 244 | 245 | if mode == "RUN": 246 | save_checkpoint(checkpoint_root_path, epoch, lr_backbone, 
optimizer, model, None, fold, latest=True) 247 | print(f'Checkpoint saved at: {checkpoint_root_path}') 248 | 249 | 250 | 251 | def log_wandb(epoch, fold, lr_backbone, train_acc, train_loss, use_validation, validation_acc, 252 | validation_loss, validation_f1): 253 | log_dict = { 254 | f'train/fold{fold}_lr': lr_backbone, 255 | f'train_loss/fold{fold}_loss': train_loss.avg, 256 | f'train_accuracy/fold{fold}_accuracy': train_acc.avg, 257 | f'epoch': epoch, 258 | } 259 | if use_validation: 260 | log_dict[f'eval_loss/fold{fold}_loss'] = validation_loss 261 | log_dict[f'eval_acc/fold{fold}_accuracy'] = validation_acc 262 | log_dict[f'eval_f1/fold{fold}_f1'] = validation_f1 263 | wandb.log(log_dict) -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaatiTeam/MotionEncoders_parkinsonism_benchmark/ce1d617169348ff0f29fb8d1a3e1469c11a87c7e/utility/__init__.py -------------------------------------------------------------------------------- /utility/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import json 5 | import random 6 | import torch 7 | from numpyencoder import NumpyEncoder 8 | import re 9 | 10 | 11 | def set_random_seed(seed): 12 | """Sets random seed for training reproducibility""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | 17 | 18 | def create_dir_tree(base_dir, numfolds): 19 | dir_tree = ['models', 'config', 'std_log'] 20 | last_run = 1 21 | for dir_ in dir_tree: 22 | if dir_ == dir_tree[0]: 23 | if not os.path.exists(os.path.join(base_dir)): 24 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 25 | else: 26 | last_run = np.max(list(map(int, os.listdir(base_dir)))) 27 | last_run += 1 28 | if not os.path.exists( 29 | os.path.join(base_dir, str(last_run - 1), 'classification_report_last.txt')): 30 | last_run -= 1 31 | shutil.rmtree(os.path.join(base_dir, str(last_run))) 32 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 33 | else: 34 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 35 | return last_run 36 | 37 | def create_dir_tree2(base_dir, last_run): 38 | dir_tree = ['models', 'config'] 39 | for dir_ in dir_tree: 40 | if dir_ == dir_tree[0]: 41 | if not os.path.exists(os.path.join(base_dir, str(last_run))): 42 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 43 | else: 44 | shutil.rmtree(os.path.join(base_dir, str(last_run))) 45 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 46 | else: 47 | os.makedirs(os.path.join(base_dir, str(last_run), dir_)) 48 | 49 | def save_json(filename, attributes, names): 50 | """ 51 | Save training parameters and evaluation results to json file. 
52 | :param filename: save filename 53 | :param attributes: attributes to save 54 | :param names: name of attributes to save in json file 55 | """ 56 | with open(filename, "w", encoding="utf8") as outfile: 57 | d = {} 58 | for i in range(len(attributes)): 59 | name = names[i] 60 | attribute = attributes[i] 61 | d[name] = attribute 62 | json.dump(d, outfile, indent=4, cls=NumpyEncoder) 63 | 64 | 65 | def is_substring(str1, str2): 66 | return str1.lower() in str2.lower() 67 | 68 | 69 | 70 | def check_and_get_first_elements(list_of_lists): 71 | """ 72 | check that all elements within each inner list are the same 73 | and also retrieve the first element of each inner list 74 | 75 | Parameters: list_of_lists (list of lists): A list containing inner lists to be checked. 76 | Returns: list: A list of the first elements from each uniform inner list. 77 | Raises: ValueError: If any inner list is empty or contains non-uniform elements. 78 | """ 79 | first_elements = [] 80 | 81 | for inner_list in list_of_lists: 82 | if not inner_list: 83 | raise ValueError("One of the inner lists is empty.") 84 | 85 | first_element = inner_list[0] 86 | if all(element == first_element for element in inner_list): 87 | first_elements.append(first_element) 88 | else: 89 | raise ValueError(f"Elements in the inner list {inner_list} differ.") 90 | 91 | return first_elements 92 | 93 | def check_uniformity_and_get_first_elements(mainlist): 94 | try: 95 | mainlist = check_and_get_first_elements(mainlist) 96 | # print("First elements of each uniform inner list:", mainlist) 97 | return mainlist 98 | except ValueError as e: 99 | print(f"Error: {e}") 100 | 101 | 102 | 103 | def natural_sort_key(s): 104 | return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] --------------------------------------------------------------------------------
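
The PoseFormerV2 frequency branch above (`forward_features` in `model/poseformerv2/model_poseformer.py`) keeps only the lowest `num_coeff_kept` DCT coefficients of each joint trajectory before projecting them with `Freq_embedding`. The sketch below isolates that truncation step; it uses the `torch-dct` package already listed in `requirements.txt`, and the tensor shapes are illustrative assumptions rather than values from the paper.

```python
import torch
import torch_dct as dct

# Illustrative shapes only: 2 clips, 81 frames, 17 joints, (x, y) inputs.
b, f, p, c = 2, 81, 17, 2
num_coeff_kept = 27

x = torch.randn(b, f, p, c)

# DCT along the time axis (the last dim after the permute), then keep the
# lowest-frequency coefficients, as in PoseTransformerV2.forward_features.
x_freq = dct.dct(x.permute(0, 2, 3, 1))[:, :, :, :num_coeff_kept]   # (b, p, c, K)

# Flatten joints and channels per kept coefficient before Freq_embedding.
x_freq = x_freq.permute(0, 3, 1, 2).contiguous().view(b, num_coeff_kept, -1)
print(x_freq.shape)  # torch.Size([2, 27, 34])
```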
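
The `weighted_mean` layers in the same model are plain `nn.Conv1d(..., out_channels=1, kernel_size=1)` modules applied over the token axis, i.e. a learnable weighted mean of the kept tokens. A minimal sketch of that equivalence (sizes are arbitrary):

```python
import torch
from torch import nn

num_tokens, embed_dim = 3, 8
weighted_mean = nn.Conv1d(in_channels=num_tokens, out_channels=1, kernel_size=1)

x = torch.randn(4, num_tokens, embed_dim)        # (batch, tokens, channels)
out = weighted_mean(x)                           # (batch, 1, channels)

# The same computation by hand: one scalar weight per token plus a bias,
# shared across every embedding channel.
w = weighted_mean.weight.view(1, num_tokens, 1)
manual = (x * w).sum(dim=1, keepdim=True) + weighted_mean.bias.view(1, 1, 1)
print(torch.allclose(out, manual, atol=1e-6))    # True
```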
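
`stat_analysis/get_stats.py` pairs the mean predicted score per participant in the ON and OFF medication states, checks normality with Shapiro-Wilk, and then runs a paired Wilcoxon signed-rank test (plus a paired t-test). The condensed sketch below walks through the same pipeline on made-up scores; the participant IDs and values are hypothetical.

```python
import pandas as pd
from scipy import stats

# Hypothetical predictions; video names follow the "<participant>_<on|off>_..."
# pattern that get_stats() parses.
data = pd.DataFrame({
    "VideoName": ["p01_on_w1", "p01_off_w1", "p02_on_w1", "p02_off_w1",
                  "p03_on_w1", "p03_off_w1", "p04_on_w1", "p04_off_w1"],
    "predScore": [1.0, 2.0, 0.5, 1.1, 0.0, 1.2, 2.0, 2.5],
})
data["Participant"] = data["VideoName"].str.split("_").str[0]
data["MedicationStatus"] = data["VideoName"].str.split("_").str[1]

# Mean score per participant and state, kept only where both states exist (paired data).
means = data.groupby(["Participant", "MedicationStatus"])["predScore"].mean().unstack()
paired = means.dropna()

for state in ("on", "off"):
    stat, p = stats.shapiro(paired[state])
    print(f"Shapiro-Wilk ({state}): stat={stat:.3f}, p={p:.3f}")

stat, p = stats.wilcoxon(paired["on"], paired["off"])
print(f"Wilcoxon signed-rank: stat={stat}, p={p:.3f}")
```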
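
`final_test` in `train.py` turns per-clip model outputs into one prediction per walking video in two ways: a majority vote over the per-clip argmax classes and an argmax over the summed per-clip logits. A stripped-down sketch of that aggregation with fabricated clip logits:

```python
from collections import Counter

import numpy as np

# Fabricated per-clip logits for two videos and three UPDRS-gait classes.
clip_logits = {
    "p01_on_w1": [np.array([0.1, 0.7, 0.2]), np.array([0.2, 0.5, 0.3]), np.array([0.6, 0.3, 0.1])],
    "p02_off_w1": [np.array([0.1, 0.1, 0.8]), np.array([0.0, 0.2, 0.8])],
}

majority_vote, summed_argmax = {}, {}
for video, logits in clip_logits.items():
    clip_classes = [int(np.argmax(l)) for l in logits]             # per-clip argmax
    majority_vote[video] = Counter(clip_classes).most_common(1)[0][0]
    summed_argmax[video] = int(np.argmax(np.sum(logits, axis=0)))  # sum, then argmax

print(majority_vote)   # {'p01_on_w1': 1, 'p02_off_w1': 2}
print(summed_argmax)   # {'p01_on_w1': 1, 'p02_off_w1': 2}
```

In `test.py`, the majority-vote classes are what reach `classification_report`; the summed logits are written separately under the `logits/` output directory.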