├── LICENSE ├── README.md ├── images ├── caption_generation_model.png ├── captioning_pipeline.png ├── comparative_pipeline.png ├── feedback_generation_model.png ├── generative_model.png ├── main_picture.png ├── pose_editing_model.png ├── retrieval_model.png └── retrieval_modifier_model.png ├── pretrained_models.md ├── requirements.txt ├── setup.py └── src ├── __init__.py ├── other_utils ├── README.md ├── pair_mining.py └── pose_mining.py └── text2pose ├── __init__.py ├── config.py ├── data.py ├── data_augmentations.py ├── demo.py ├── encoders ├── __init__.py ├── modules.py ├── pose_encoder_decoder.py ├── text_decoders.py ├── text_encoders.py └── tokenizers.py ├── evaluate.py ├── fid.py ├── generative ├── README.md ├── __init__.py ├── demo_generative.py ├── evaluate_generative.py ├── generate_poses.py ├── look_at_generated_pose_samples.py ├── model_generative.py ├── script_generative.sh └── train_generative.py ├── generative_B ├── README.md ├── __init__.py ├── demo_generative_B.py ├── evaluate_generative_B.py ├── model_generative_B.py ├── script_generative_B.sh └── train_generative_B.py ├── generative_caption ├── README.md ├── __init__.py ├── demo_generative_caption.py ├── evaluate_generative_caption.py ├── model_generative_caption.py ├── script_generative_caption.sh └── train_generative_caption.py ├── generative_modifier ├── README.md ├── __init__.py ├── demo_generative_modifier.py ├── evaluate_generative_modifier.py ├── model_generative_modifier.py ├── script_generative_modifier.sh └── train_generative_modifier.py ├── loss.py ├── option.py ├── posefix ├── README.md ├── __init__.py ├── compute_rotation_change.py ├── correcting.py ├── corrective_data.py ├── explore_posefix.py └── paircodes.py ├── posescript ├── README.md ├── __init__.py ├── action_to_sent_template.json ├── captioning.py ├── captioning_data.py ├── compute_coords.py ├── explore_posescript.py ├── format_babel_labels.py ├── format_contact_info.py ├── posecodes.py ├── smplx_custom_semantic_segmentation.json └── utils.py ├── retrieval ├── README.md ├── __init__.py ├── demo_retrieval.py ├── evaluate_retrieval.py ├── model_retrieval.py ├── script_retrieval.sh └── train_retrieval.py ├── retrieval_modifier ├── README.md ├── __init__.py ├── demo_retrieval_modifier.py ├── evaluate_retrieval_modifier.py ├── model_retrieval_modifier.py ├── script_retrieval_modifier.sh └── train_retrieval_modifier.py ├── shortname_2_model_path.txt ├── trainer.py ├── utils.py ├── utils_logging.py ├── utils_visu.py └── vocab.py /images/caption_generation_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/caption_generation_model.png -------------------------------------------------------------------------------- /images/captioning_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/captioning_pipeline.png -------------------------------------------------------------------------------- /images/comparative_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/comparative_pipeline.png -------------------------------------------------------------------------------- /images/feedback_generation_model.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/feedback_generation_model.png -------------------------------------------------------------------------------- /images/generative_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/generative_model.png -------------------------------------------------------------------------------- /images/main_picture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/main_picture.png -------------------------------------------------------------------------------- /images/pose_editing_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/pose_editing_model.png -------------------------------------------------------------------------------- /images/retrieval_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/retrieval_model.png -------------------------------------------------------------------------------- /images/retrieval_modifier_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/retrieval_modifier_model.png -------------------------------------------------------------------------------- /pretrained_models.md: -------------------------------------------------------------------------------- 1 | # Available trained models 2 | 3 | ## Download links 4 | 5 | | Subdirectory | Model (link to README) | Model shortname | Link | Main results| 6 | |---|---|---|---|---| 7 | | `retrieval` | [text-to-pose retrieval model](./src/text2pose/retrieval/README.md) | `ret_distilbert_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/ret_distilbert_dataPSA2ftPSH2.zip) | **mRecall** = 47.92
**R@1 Precision (GT)** = 84.76 | 8 | | `retrieval_modifier` | [pose-pair-to-instruction retrieval model](./src/text2pose/retrieval_modifier/README.md) | `modret_distilbert_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/modret_distilbert_dataPFAftPFH.zip) | **mRecall** = 30.00
**R@1 Precision (GT)** = 68.04 | 9 | | `generative` | [text-conditioned pose generation model](./src/text2pose/generative/README.md) | `gen_distilbert_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/gen_distilbert_dataPSA2ftPSH2.zip) | **ELBO jts/vert/rot** = 1.44 / 1.82 / 0.90 | 10 | | `generative_B` | [text-guided pose editing model](./src/text2pose/generative_B/README.md) | `b_gen_distilbert_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/b_gen_distilbert_dataPFAftPFH.zip) | **ELBO jts/vert/rot** = 1.43 / 1.90 / 1.00 | 11 | | `generative_caption` | [pose description generation model](./src/text2pose/generative_caption/README.md) | `capgen_CAtransfPSA2H2_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/capgen_CAtransfPSA2H2_dataPSA2ftPSH2.zip) | **R@1 Precision** = 89.38
**MPJE_30** = 202
**ROUGE-L** = 33.95 | 12 | | `generative_modifier` | [pose-based correctional text generation model](./src/text2pose/generative_modifier/README.md) | `modgen_CAtransfPFAHPP_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/modgen_CAtransfPFAHPP_dataPFAftPFH.zip) | **R@1 Precision** = 78.85
**MPJE_30** = 186
**ROUGE-L** = 33.53 | 13 | 14 | Unzip the archives and place the content of the resulting directories in ***GENERAL_EXP_OUTPUT_DIR***. 15 | 16 | **Note:** these models are the result of two-stage training: a pretraining stage on automatic texts followed by a finetuning stage on human-written annotations. 17 | 18 |
19 | Bash script to download & unzip everything all at once. 20 | 21 | ```bash 22 | cd "<GENERAL_EXP_OUTPUT_DIR>" # TODO replace! 23 | 24 | arr=( 25 | ret_distilbert_dataPSA2ftPSH2 26 | modret_distilbert_dataPFAftPFH 27 | gen_distilbert_dataPSA2ftPSH2 28 | b_gen_distilbert_dataPFAftPFH 29 | capgen_CAtransfPSA2H2_dataPSA2ftPSH2 30 | modgen_CAtransfPFAHPP_dataPFAftPFH 31 | ) 32 | 33 | for a in "${arr[@]}"; do 34 | echo "Download and extract $a" 35 | wget "https://download.europe.naverlabs.com/ComputerVision/PoseFix/${a}.zip" 36 | unzip "${a}.zip" 37 | rm "${a}.zip" 38 | done 39 | ``` 40 | 41 |
42 | 43 |
44 | Differences in results with the papers. 45 | 46 | * *Text-to-pose retrieval*: providing an improved model, pretrained on new automatic captions and trained with a symmetric contrastive loss (vs. the uni-directional contrastive loss in the paper). 47 | * *Instruction-to-pair retrieval*: providing an improved model trained with a symmetric contrastive loss (vs. the uni-directional contrastive loss in the paper). 48 | * *Pose editing:* the provided model uses a transformer-based text encoder (frozen DistilBert + learned transformer), for consistency with the other provided models (vs. the GloVe+biGRU configuration used to report results in the paper). Note: this model was finetuned using the best setting as per Table 4: with L/R flip and paraphrases. The FID value may also change, as evaluation is carried out with an improved version of the text-to-pose retrieval model. 49 | * *Text generation models*: evaluated with improved retrieval models; also note that, despite being averaged over 10 repetitions, R-precision metrics come with high variability due to the randomized selection of the pool of samples to compare against. 50 |
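As a quick sanity check after unzipping, a downloaded checkpoint can be inspected directly. The sketch below uses an example path (adapt it to your own ***GENERAL_EXP_OUTPUT_DIR***); the keys it reads (`epoch`, `args`, `model`) are the ones loaded by the evaluation scripts of this repository.

```python
import torch

# example checkpoint path, to adapt to your setup
ckpt_path = "<GENERAL_EXP_OUTPUT_DIR>/ret_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")
print("epoch:", ckpt["epoch"])                      # epoch of the best checkpoint
print("latentD:", ckpt["args"].latentD)             # training arguments stored with the weights
print("state dict entries:", len(ckpt["model"]))    # model weights
```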
51 | 52 | ## References in `shortname_2_model_path.txt` 53 | 54 | References should be given using the following format: 55 | 56 | ``` 57 | <4 spaces> 58 | ``` 59 | 60 | Thus, for the above-mentioned models (simply replace `` by its proper value): 61 | ```text 62 | ret_distilbert_dataPSA2ftPSH2 /ret_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth 63 | modret_distilbert_dataPFAftPFH /modret_distilbert_dataPFAftPFH/seed1/checkpoint_best.pth 64 | gen_distilbert_dataPSA2ftPSH2 /gen_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth 65 | b_gen_distilbert_dataPFAftPFH /b_gen_distilbert_dataPFAftPFH/seed1/checkpoint_best.pth 66 | capgen_CAtransfPSA2H2_dataPSA2ftPSH2 /capgen_CAtransfPSA2H2_dataPSA2ftPSH2/seed1/checkpoint_best.pth 67 | modgen_CAtransfPFAHPP_dataPFAftPFH /modgen_CAtransfPFAHPP_dataPFAftPFH/seed1/checkpoint_best.pth 68 | ``` 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | --find-links https://download.pytorch.org/whl/torch_stable.html 3 | 4 | git+https://github.com/nghorbani/body_visualizer 5 | git+https://github.com/MPI-IS/configer.git 6 | git+https://github.com/MPI-IS/mesh.git 7 | git+https://github.com/nghorbani/human_body_prior.git 8 | 9 | torch==1.10.1 10 | torchtext==0.11.1 11 | torchvision==0.11.2 12 | nltk 13 | smplx 14 | matplotlib 15 | opencv-python 16 | transformers 17 | trimesh[easy] 18 | pyrender 19 | roma 20 | streamlit==1.23.1 21 | tabulate 22 | tensorboard==2.11.2 23 | setuptools==59.5.0 24 | bert_score 25 | tqdm 26 | evaluate -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | from setuptools import setup, find_packages 11 | 12 | setup(name='text2pose', 13 | version='2.0', 14 | packages=find_packages('src'), 15 | package_dir={'': 'src'}, 16 | include_package_data=True, 17 | author='Ginger Delmas', 18 | author_email='ginger.delmas.pro@gmail.com', 19 | description='PoseScript ECCV22 & PoseFix ICCV23.', 20 | long_description=open("README.md").read(), 21 | long_description_content_type="text/markdown", 22 | install_requires=[], 23 | dependency_links=[], 24 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## -------------------------------------------------------------------------------- /src/other_utils/README.md: -------------------------------------------------------------------------------- 1 | # Other potentially useful code 2 | 3 | This directory presents some code that has been useful in different stages of the project (eg. for dataset construction), but does not contribute directly to the training or testing of the models proposed in the papers. In particular, this code is not necessarily cleaned nor well documented, and may trigger a bunch of errors / require some additional setup before running successfully (watch for "TODO" marks) -- apologies! Hopefully, this code can still prove useful for concrete understanding of some aspects of the project. 4 | 5 | Included: 6 | * pose mining (pose selection in PoseScript; also corresponding to poses "B" in PoseFix) 7 | * pair mining (process to select pairs in PoseFix) 8 | 9 | *Note:* for the reasons mentioned previously, this code is currently seccluded in this separate directory. Note that it requires some functions defined in the rest of the repository, so one may need to run `python setup.py develop` to import text2pose as a package first... -------------------------------------------------------------------------------- /src/other_utils/pose_mining.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2024 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | import torch 12 | from tqdm import tqdm 13 | 14 | 15 | # SETUP 16 | ################################################################################ 17 | 18 | DATA_LOCATION = "TODO" # where to save files 19 | 20 | 21 | # UTILS 22 | ################################################################################ 23 | 24 | def MPJE(coords_set, coords_one): 25 | """ 26 | Args: 27 | coords_set: tensor (B, nb joints, 3) 28 | coords_one: tensor (nb joints, 3) 29 | 30 | Returns: 31 | tensor (B) giving the mean per joint distance between each pose in the 32 | set and the provided pose. 
33 | """ 34 | return torch.norm(coords_set - coords_one, dim=2).mean(1) # (B) 35 | 36 | 37 | def farther_sampling_resume(data, selected, distance, distance_func, nb_to_select=10): 38 | """ 39 | Args: 40 | data (torch.tensor): size (number of poses, number of joints, 3) 41 | or (number of poses, number of features) 42 | selected (list): indices of the poses that were already selected 43 | distance (torch.tensor): size (number of poses), distance between each 44 | pose and the closest pose of the selected set 45 | distance_func (function): computes the distance between 2 poses 46 | nb_to_select (int): number of data points to select, in addition of the 47 | ones already selected 48 | 49 | Returns: 50 | selected (list): indices of the selected poses 51 | distance (torch.tensor): size (number of poses), distance between each 52 | pose and the closest pose of the selected set 53 | """ 54 | 55 | nb_to_select = min(data.size(0)-len(selected), nb_to_select) 56 | 57 | for _ in tqdm(range(0, nb_to_select)): 58 | distance_update = distance_func(data, data[selected[-1]]) 59 | distance = torch.amin(torch.cat((distance.view(-1,1), distance_update.view(-1,1)), 1), dim=1) 60 | selected.append(torch.argmax(distance).item()) 61 | 62 | return selected, distance 63 | 64 | 65 | def get_diversity_pose_order(coords, suffix, split, nb_select=10, seed=0, resume=False): 66 | """ 67 | coords: dict {data_id: 3D coords of the main joints (torch.tensor shape (n_joints, 3))} 68 | suffix: for file naming 69 | """ 70 | 71 | # prepare data for the farther sampling 72 | # (resume, if applicable) 73 | data_ids = sorted(coords.keys()) 74 | coords = torch.stack([coords[did] for did in data_ids]) # (nb poses, nb joints, 3) 75 | nb_select = min(len(coords), nb_select) 76 | 77 | if resume: 78 | filepath_resume_from = os.path.join(DATA_LOCATION, f"farther_sample_{resume}_{suffix}.pt") 79 | data_ids_, selected, distance = torch.load(filepath_resume_from) 80 | assert data_ids == data_ids_, "Cannot resume. Data changed!" 81 | print("Resuming from:", filepath_resume_from) 82 | else: 83 | selected = [seed] # to make the results somewhat reproducible 84 | distance = torch.ones(len(coords)) * float('inf') 85 | print(f"Farther sampling from seed {seed}.") 86 | 87 | # farther sample 88 | print(f"Sampling from {len(coords)} elements. 
Number of elements already selected: {len(selected)}.") 89 | selected, distance = farther_sampling_resume(coords, selected, distance, MPJE, nb_select) 90 | 91 | # save 92 | filesave = os.path.join(DATA_LOCATION, f"farther_sample_{split}_{nb_select}_{suffix}.pt") 93 | torch.save([data_ids, selected, distance], filesave) 94 | print("Saved:", filesave) 95 | 96 | 97 | # MAIN 98 | ################################################################################ 99 | 100 | if __name__ == "__main__": 101 | 102 | import argparse 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument('--split', type=str, choices=('training', 'validation'), default='validation') 105 | parser.add_argument('--nb_select', type=int, default=50000) 106 | parser.add_argument('--resume', type=int, default=0) 107 | args = parser.parse_args() 108 | 109 | suffix = "try" 110 | get_diversity_pose_order(split=args.split, suffix=suffix, nb_select=args.nb_select, resume=args.resume) -------------------------------------------------------------------------------- /src/text2pose/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/config.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | # This file serves to store global config parameters and paths 11 | # (those may change depending on the user, provided data, trained models...) 
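# (note: MAIN_DIR below is derived from this file's location and resolves to the root of the repository)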
12 | 13 | # default 14 | import os 15 | MAIN_DIR = os.path.realpath(__file__) 16 | MAIN_DIR = os.path.dirname(os.path.dirname(os.path.dirname(MAIN_DIR))) 17 | 18 | 19 | ################################################################################ 20 | # Output dir for experiments 21 | ################################################################################ 22 | 23 | GENERAL_EXP_OUTPUT_DIR = MAIN_DIR + '/experiments' 24 | 25 | 26 | ################################################################################ 27 | # Data 28 | ################################################################################ 29 | 30 | POSEFIX_LOCATION = MAIN_DIR + '/data/PoseFix/posefix_release' 31 | POSESCRIPT_LOCATION = MAIN_DIR + '/data/PoseScript/posescript_release' 32 | POSEMIX_LOCATION = MAIN_DIR + '/data/posemix' 33 | 34 | version_suffix = "_100k" # to be used for pipeline-related data (coords, rotation change, babel labels) 35 | file_pose_id_2_dataset_sequence_and_frame_index = f"{POSESCRIPT_LOCATION}/ids_2_dataset_sequence_and_frame_index_100k.json" 36 | file_pair_id_2_pose_ids = f"{POSEFIX_LOCATION}/pair_id_2_pose_ids.json" 37 | file_posescript_split = f"{POSESCRIPT_LOCATION}/%s_ids_100k.json" # %s --> (train|val|test) 38 | file_posefix_split = f"{POSEFIX_LOCATION}/%s_%s_sequence_pair_ids.json" # %s %s --> (train|val|test), (in|out) 39 | 40 | 41 | ### pose config ---------------------------------------------------------------- 42 | 43 | POSE_FORMAT = 'smplh' 44 | SMPLH_BODY_MODEL_PATH = MAIN_DIR + '/data/smplh_amass_body_models' 45 | NEUTRAL_BM = f'{SMPLH_BODY_MODEL_PATH}/neutral/model.npz' 46 | NB_INPUT_JOINTS = 52 # default value used when initializing modules, unless specified otherwise 47 | n_betas = 16 48 | 49 | SMPLX_BODY_MODEL_PATH = MAIN_DIR + '/data/smpl_models' # should contain "smplx/SMPLX_NEUTRAL.(npz|pkl)" 50 | 51 | PID_NAN = -99999 # pose fake IDs, used for empty poses 52 | 53 | ### pose data ------------------------------------------------------------------ 54 | 55 | AMASS_FILE_LOCATION = MAIN_DIR + f"/data/AMASS/{POSE_FORMAT}/" 56 | supported_datasets = {"AMASS":AMASS_FILE_LOCATION} 57 | 58 | BABEL_LOCATION = MAIN_DIR + "/data/BABEL/babel_v1.0_release" 59 | 60 | generated_pose_path = '%s/generated_poses/posescript_version_{data_version}_split_{split}_gensamples.pth' # %s is for the model directory (obtainable with shortname_2_model_path) 61 | 62 | 63 | ### text data ------------------------------------------------------------------ 64 | 65 | MAX_TOKENS = 500 # defined here because it depends on the provided data (only affects the glovebigru configuration) 66 | 67 | vocab_files = { 68 | # IMPORTANT: do not use "_" symbols in the keys of this dictionary 69 | # IMPORTANT: vocabs overlap, but the order of the tokens is not the same; a 70 | # model trained with one vocab can't be finetuned with another 71 | # without risk 72 | "vocPSA2H2": "vocab_posescript_6293_auto100k.pkl", 73 | "vocPFAHPP": "vocab_posefix_6157_pp4284_auto.pkl", 74 | "vocMixPSA2H2PFAHPP": "vocab_posemix_PS6193_PF6157pp4284.pkl", 75 | } 76 | 77 | caption_files = { 78 | # : (, ) 79 | "posescript-A2": (3, [f"{POSESCRIPT_LOCATION}/posescript_auto_100k.json"]), 80 | "posescript-H2": (1, [f"{POSESCRIPT_LOCATION}/posescript_human_6293.json"]), 81 | "posefix-A": (3, [f"{POSEFIX_LOCATION}/posefix_auto_135305.json"]), 82 | "posefix-PP": (1, [f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]), # average is 2 texts/item (min 1, max 6) 83 | "posefix-H": (1, [f"{POSEFIX_LOCATION}/posefix_human_6157.json"]), 84 | 
"posefix-HPP": (1, [f"{POSEFIX_LOCATION}/posefix_human_6157.json", f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]), # 1 text/item at min, because paraphrases are essentially for the train set 85 | "posemix-PSH2-PFHPP": (1, [f"{POSESCRIPT_LOCATION}/posescript_human_6293.json", f"{POSEFIX_LOCATION}/posefix_human_6157.json", f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]), 86 | } 87 | 88 | # data cache 89 | dirpath_cache_dataset = MAIN_DIR + "/dataset_cache" 90 | cache_file_path = { 91 | "posescript":'%s/PoseScript_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset, 92 | "posefix":'%s/PoseFix_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset, 93 | "posemix":'%s/PoseMix_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset, 94 | "posestream":'%s/PoseStream_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset, 95 | } 96 | 97 | 98 | ################################################################################ 99 | # Model cache 100 | ################################################################################ 101 | 102 | GLOVE_DIR = MAIN_DIR + '/tools/torch_models/glove' # or None 103 | TRANSFORMER_CACHE_DIR = MAIN_DIR + '/tools/huggingface_models' 104 | SELFCONTACT_ESSENTIALS_DIR = MAIN_DIR + '/tools/selfcontact/essentials' 105 | 106 | 107 | ################################################################################ 108 | # Shortnames to checkpoint paths 109 | ################################################################################ 110 | # Shortnames are used to refer to: 111 | # - pretrained models 112 | # - models that generated pose files 113 | # - models used for evaluation (fid, recall, reconstruction, r-precision...) 
114 | 115 | # shortnames for models are expected to be the same accross seed values; 116 | # model paths should contain a specific seed_value field instead of the actual seed value 117 | normalize_model_path = lambda model_path, seed_value: "/".join(model_path.split("/")[:-2]) + f"/seed{seed_value}/"+ model_path.split("/")[-1] 118 | 119 | # shortname & model paths are stored in shortname_2_model_path.json (which can be updated by some scripts) 120 | try: 121 | with open("shortname_2_model_path.txt", "r") as f: 122 | # each line has the following format: <4 spaces> 123 | shortname_2_model_path = [l.split(" ") for l in f.readlines() if len(l.strip())] 124 | shortname_2_model_path = {l[0]:normalize_model_path(l[1].strip(), '{seed}') for l in shortname_2_model_path} 125 | except FileNotFoundError: 126 | # print("File not found: shortname_2_model_path.txt - Please ensure you are launching operations from the right directory.") 127 | pass # this file may not even be needed; subsequent errors can be expected otherwise 128 | 129 | 130 | ################################################################################ 131 | # Evaluation 132 | ################################################################################ 133 | 134 | # NOTE: models used to compute the fid should be specified in `shortname_2_model_path` 135 | 136 | k_recall_values = [1, 5, 10] 137 | nb_sample_reconstruction = 30 138 | k_topk_reconstruction_values = [1, 6] # keep the top-1 and the top-N/4 where N is the nb_sample_reconstruction 139 | k_topk_r_precision = [1,2,3] 140 | r_precision_n_repetitions = 10 141 | sample_size_r_precision = 32 142 | 143 | 144 | ################################################################################ 145 | # Visualization settings 146 | ################################################################################ 147 | 148 | meshviewer_size = 1600 149 | 150 | 151 | if __name__=="__main__": 152 | import sys 153 | try: 154 | # if the provided model shortname is registered, return the complete model path (with the provided seed value) 155 | if sys.argv[1] in shortname_2_model_path: 156 | print(shortname_2_model_path[sys.argv[1]].format(seed=sys.argv[2])) 157 | except IndexError: 158 | # clean shortname_2_model_path.txt 159 | update = [] 160 | for k,p in shortname_2_model_path.items(): 161 | update.append(f"{k} {p.format(seed=0)}\n") 162 | with open("shortname_2_model_path.txt", "w") as f: 163 | f.writelines(update) 164 | print("Cleaned shortname_2_model_path.txt (unique entries with seed 0).") -------------------------------------------------------------------------------- /src/text2pose/data_augmentations.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023, 2024 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import torch 11 | from copy import deepcopy 12 | 13 | import text2pose.config as config 14 | from text2pose.encoders.tokenizers import Tokenizer 15 | from text2pose.posescript.utils import ALL_JOINT_NAMES 16 | 17 | 18 | class PoseFlip(): 19 | 20 | def __init__(self, nb_joints=22): 21 | super(PoseFlip, self).__init__() 22 | 23 | # get joint names (depends on the case) 24 | if nb_joints == 21: 25 | # all main joints, without the root 26 | joint_names = ALL_JOINT_NAMES[1:22] 27 | elif nb_joints == 22: 28 | # all main joints, with the root 29 | joint_names = ALL_JOINT_NAMES[:22] 30 | elif nb_joints == 52: 31 | joint_names = ALL_JOINT_NAMES[:] 32 | else: 33 | raise NotImplementedError 34 | 35 | # build joint correspondance indices 36 | n2i = {n:i for i, n in enumerate(joint_names)} 37 | l2r_j_id = {i:n2i[n.replace("left", "right")] for n,i in n2i.items() if "left" in n} # joint index correspondance between left and right 38 | self.left_joint_inds = torch.tensor(list(l2r_j_id.keys())) 39 | self.right_joint_inds = torch.tensor(list(l2r_j_id.values())) 40 | 41 | def flip_pose_data_LR(self, pose_data): 42 | """ 43 | pose_data: shape (batch_size, nb_joint, 3) 44 | """ 45 | l_data = deepcopy(pose_data[:,self.left_joint_inds]) 46 | r_data = deepcopy(pose_data[:,self.right_joint_inds]) 47 | pose_data[:,self.left_joint_inds] = r_data 48 | pose_data[:,self.right_joint_inds] = l_data 49 | pose_data[:,:, 1:3] *= -1 50 | return pose_data 51 | 52 | def __call__(self, pose_data): 53 | return self.flip_pose_data_LR(pose_data.clone()) 54 | 55 | 56 | def DataAugmentation(args, mode, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS): 57 | # --- define process 58 | if mode == "posefix": 59 | return PosefixDataAugmentation(args, tokenizer_name, nb_joints) 60 | elif mode == "posescript": 61 | return PosescriptDataAugmentation(args, tokenizer_name, nb_joints) 62 | else: 63 | raise NotImplementedError 64 | 65 | 66 | class GenericDataAugmentation(): 67 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS): 68 | super(GenericDataAugmentation, self).__init__() 69 | 70 | self.args = args 71 | 72 | # --- initialize data augmentation tools 73 | if tokenizer_name and (self.args.apply_LR_augmentation or self.args.copy_augmentation): 74 | self.tokenizer = Tokenizer(tokenizer_name) 75 | 76 | if tokenizer_name and self.args.copy_augmentation: 77 | empty_text_tokens = self.tokenizer("") 78 | self.empty_text_length = len(empty_text_tokens) # account for BOS & EOS tokens 79 | self.empty_text_tokens = torch.cat( (empty_text_tokens, self.tokenizer.pad_token_id * torch.ones( self.tokenizer.max_tokens-self.empty_text_length, dtype=empty_text_tokens.dtype) ), dim=0) 80 | 81 | if self.args.apply_LR_augmentation: 82 | self.pose_flip = PoseFlip(nb_joints) 83 | 84 | 85 | class PosescriptDataAugmentation(GenericDataAugmentation): 86 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS): 87 | super(PosescriptDataAugmentation, self).__init__(args, tokenizer_name=tokenizer_name, nb_joints=nb_joints) 88 | 89 | def __call__(self, poses, caption_tokens=None, caption_lengths=None): 90 | 91 | batch_size = poses.size(0) # beware of incomplete batches! 
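# (expected shapes, as used below: poses (batch_size, nb_joints, 3); caption_tokens (batch_size, padded_length); caption_lengths (batch_size))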
92 | 93 | # random L/R flip 94 | if self.args.apply_LR_augmentation: 95 | flippable = torch.rand(batch_size) < 0.5 # completely random flip 96 | if hasattr(self, "tokenizer"): 97 | caption_tokens, caption_lengths, actually_flipped = self.tokenizer.flip(caption_tokens, flippable) 98 | else: 99 | actually_flipped = flippable 100 | poses[actually_flipped] = self.pose_flip(poses[actually_flipped]) 101 | 102 | return poses, caption_tokens, caption_lengths 103 | 104 | 105 | class PosefixDataAugmentation(GenericDataAugmentation): 106 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS): 107 | super(PosefixDataAugmentation, self).__init__(args, tokenizer_name=tokenizer_name, nb_joints=nb_joints) 108 | 109 | def __call__(self, poses_A, caption_tokens=None, caption_lengths=None, poses_B=None, posescript_poses=None): 110 | 111 | batch_size = poses_A.size(0) # beware of incomplete batches! 112 | 113 | # random L/R flip 114 | if self.args.apply_LR_augmentation: 115 | flippable = torch.rand(batch_size) < 0.5 # completely random flip 116 | if hasattr(self, "tokenizer"): 117 | caption_tokens, caption_lengths, actually_flipped = self.tokenizer.flip(caption_tokens, flippable) 118 | else: 119 | actually_flipped = flippable 120 | poses_A[actually_flipped] = self.pose_flip(poses_A[actually_flipped]) 121 | poses_B[actually_flipped] = self.pose_flip(poses_B[actually_flipped]) 122 | 123 | # remove text cue: learn to copy pose A 124 | if self.args.copy_augmentation > 0: 125 | change = torch.rand(batch_size) < self.args.copy_augmentation # change at most a proportion of args.copy_augmentation poses 126 | if self.args.copyB2A: 127 | copy_B2A = torch.ones(batch_size).bool() 128 | else: 129 | # in the case of PoseScript data, A is "0"; therefore, for such 130 | # elements, we must copy B to A 131 | copy_B2A = posescript_poses # (batch_size) 132 | copy_A2B = ~copy_B2A 133 | poses_A[copy_B2A*change] = deepcopy(poses_B[copy_B2A*change]) 134 | poses_B[copy_A2B*change] = deepcopy(poses_A[copy_A2B*change]) 135 | # empty text 136 | if hasattr(self, "tokenizer"): 137 | caption_tokens[change] = self.empty_text_tokens[:caption_tokens.shape[1]] # by default, `empty_text_tokens` is very long 138 | caption_lengths[change] = self.empty_text_length 139 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # update truncation, the longest text may have changed 140 | 141 | return poses_A, caption_tokens, caption_lengths, poses_B -------------------------------------------------------------------------------- /src/text2pose/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | ################################################################################ 7 | ## Miscellaneous modules 8 | ################################################################################ 9 | 10 | 11 | class L2Norm(nn.Module): 12 | def forward(self, x): 13 | return x / x.norm(dim=-1, keepdim=True) 14 | 15 | 16 | class ConCatModule(nn.Module): 17 | 18 | def __init__(self): 19 | super(ConCatModule, self).__init__() 20 | 21 | def forward(self, x): 22 | x = torch.cat(x, dim=1) 23 | return x 24 | 25 | 26 | class DiffModule(nn.Module): 27 | 28 | def __init__(self): 29 | super(DiffModule, self).__init__() 30 | 31 | def forward(self, x): 32 | return x[1] - x[0] 33 | 34 | 35 | class AddModule(nn.Module): 36 | 37 | def __init__(self, axis=0): 38 | super(AddModule, self).__init__() 39 | self.axis = axis 40 | 41 | def forward(self, x): 42 | return x.sum(self.axis) 43 | 44 | 45 | class SeqModule(nn.Module): 46 | 47 | def __init__(self): 48 | super(SeqModule, self).__init__() 49 | 50 | def forward(self, x): 51 | # input: list of T tensors of size (BS, d) 52 | # output: tensor of size (BS, T, d) 53 | return torch.cat([xx.unsqueeze(1) for xx in x], dim = 1) 54 | 55 | 56 | class MiniMLP(nn.Module): 57 | 58 | def __init__(self, input_dim, hidden_dim, output_dim): 59 | super(MiniMLP, self).__init__() 60 | self.layers = nn.Sequential( 61 | nn.Linear(input_dim, hidden_dim), 62 | nn.ReLU(), 63 | nn.Linear(hidden_dim, output_dim) 64 | ) 65 | 66 | def forward(self, x): 67 | return self.layers(x) 68 | 69 | 70 | class PositionalEncoding(nn.Module): 71 | 72 | def __init__(self, d_model, dropout=0.1, max_len=5000): 73 | super(PositionalEncoding, self).__init__() 74 | 75 | pe = torch.zeros(max_len, 1, d_model) 76 | position = torch.arange(max_len).unsqueeze(1) 77 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)) 78 | pe[:, 0, 0::2] = torch.sin(position * div_term) 79 | pe[:, 0, 1::2] = torch.cos(position * div_term) 80 | self.register_buffer('pe', pe) 81 | 82 | self.dropout = nn.Dropout(p=dropout) 83 | 84 | def forward(self, x): 85 | """ 86 | Args: 87 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 88 | """ 89 | x = x + self.pe[:x.size(0)] 90 | return self.dropout(x) 91 | 92 | 93 | class TIRG(nn.Module): 94 | """ 95 | The TIRG model. 96 | Implementation derived (except for BaseModel-inherence) from 97 | https://github.com/google/tirg (downloaded on July 23th 2020). 98 | The method is described in Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia 99 | Li, Li Fei-Fei, James Hays. "Composing Text and Image for Image Retrieval - 100 | An Empirical Odyssey" CVPR 2019. 
arXiv:1812.07119 101 | """ 102 | 103 | def __init__(self, input_dim=[512, 512], output_dim=512, out_l2_normalize=False): 104 | super(TIRG, self).__init__() 105 | 106 | self.input_dim = sum(input_dim) 107 | self.output_dim = output_dim 108 | 109 | # --- modules 110 | self.a = nn.Parameter(torch.tensor([1.0, 1.0])) # changed the second coeff from 10.0 to 1.0 111 | self.gated_feature_composer = nn.Sequential( 112 | ConCatModule(), nn.BatchNorm1d(self.input_dim), nn.ReLU(), 113 | nn.Linear(self.input_dim, self.output_dim)) 114 | self.res_info_composer = nn.Sequential( 115 | ConCatModule(), nn.BatchNorm1d(self.input_dim), nn.ReLU(), 116 | nn.Linear(self.input_dim, self.input_dim), nn.ReLU(), 117 | nn.Linear(self.input_dim, self.output_dim)) 118 | 119 | if out_l2_normalize: 120 | self.output_layer = L2Norm() # added to the official TIRG code 121 | else: 122 | self.output_layer = nn.Sequential() 123 | 124 | def query_compositional_embedding(self, main_features, modifying_features): 125 | f1 = self.gated_feature_composer((main_features, modifying_features)) 126 | f2 = self.res_info_composer((main_features, modifying_features)) 127 | f = torch.sigmoid(f1) * main_features * self.a[0] + f2 * self.a[1] 128 | f = self.output_layer(f) 129 | return f -------------------------------------------------------------------------------- /src/text2pose/encoders/pose_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023, 2024 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import torch.nn as nn 11 | import roma 12 | from human_body_prior.models.vposer_model import NormalDistDecoder, VPoser 13 | 14 | import text2pose.config as config 15 | from text2pose.encoders.modules import L2Norm 16 | 17 | 18 | ################################################################################ 19 | ## Pose encoder / decoder 20 | ################################################################################ 21 | 22 | 23 | class Object(object): 24 | pass 25 | 26 | 27 | class PoseEncoder(nn.Module): 28 | 29 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512, num_body_joints=config.NB_INPUT_JOINTS, role=None): 30 | super(PoseEncoder, self).__init__() 31 | 32 | self.num_body_joints = num_body_joints 33 | self.input_dim = self.num_body_joints * 3 34 | 35 | # use VPoser pose encoder architecture... 
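# (a VPoser instance is created only so that its encoder layers can be reused: the first layers are then resized to the flattened pose input of size num_body_joints * 3, and the output head is replaced according to the role)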
36 | vposer_params = Object() 37 | vposer_params.model_params = Object() 38 | vposer_params.model_params.num_neurons = num_neurons 39 | vposer_params.model_params.latentD = latentD 40 | vposer = VPoser(vposer_params) 41 | encoder_layers = list(vposer.encoder_net.children()) 42 | # change first layers to have the right data input size 43 | encoder_layers[1] = nn.BatchNorm1d(self.input_dim) 44 | encoder_layers[2] = nn.Linear(self.input_dim, num_neurons) 45 | # remove last layer; the last layer.s depend on the task/role 46 | encoder_layers = encoder_layers[:-1] 47 | 48 | # output layers 49 | if role == "retrieval": 50 | encoder_layers += [ 51 | nn.Linear(num_neurons, num_neurons_mini), # keep the bottleneck while adapting to the joint embedding size 52 | nn.ReLU(), 53 | nn.Linear(num_neurons_mini, latentD), 54 | L2Norm()] 55 | elif role == "generative": 56 | encoder_layers += [ NormalDistDecoder(num_neurons, latentD) ] 57 | elif role == "no_output_layer": 58 | encoder_layers += [ ] 59 | elif role == "modifier": 60 | encoder_layers += [ 61 | nn.Linear(num_neurons, latentD) 62 | ] 63 | else: 64 | raise NotImplementedError 65 | 66 | self.encoder = nn.Sequential(*encoder_layers) 67 | 68 | def forward(self, pose): 69 | return self.encoder(pose) 70 | 71 | 72 | class PoseDecoder(nn.Module): 73 | 74 | def __init__(self, num_neurons=512, latentD=32, num_body_joints=config.NB_INPUT_JOINTS): 75 | super(PoseDecoder, self).__init__() 76 | 77 | self.num_body_joints = num_body_joints 78 | 79 | # use VPoser pose decoder architecture... 80 | vposer_params = Object() 81 | vposer_params.model_params = Object() 82 | vposer_params.model_params.num_neurons = num_neurons 83 | vposer_params.model_params.latentD = latentD 84 | vposer = VPoser(vposer_params) 85 | decoder_layers = list(vposer.decoder_net.children()) 86 | # change one of the final layers to have the right data output size 87 | decoder_layers[-2] = nn.Linear(num_neurons, self.num_body_joints * 6) 88 | 89 | self.decoder = nn.Sequential(*decoder_layers) 90 | 91 | def forward(self, Zin): 92 | bs = Zin.shape[0] 93 | prec = self.decoder(Zin) 94 | return { 95 | 'pose_body': roma.rotmat_to_rotvec(prec.view(-1, 3, 3)).view(bs, -1, 3), # (batch_size, num_body_joints, 3) 96 | 'pose_body_matrot': prec.view(bs, -1, 9) # (batch_size, num_body_joints, 9) 97 | } -------------------------------------------------------------------------------- /src/text2pose/fid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | import text2pose.config as config 7 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder 8 | 9 | class FID(nn.Module): 10 | 11 | def __init__(self, version, device=torch.device('cpu'), name_in_batch='pose'): 12 | super().__init__() 13 | assert isinstance(version, tuple), "FID version should follow the format (retrieval_model_shortname, seed), where retrieval_model_shortname is actually provided with --fid as input to the script." 
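# e.g. version = ("ret_distilbert_dataPSA2ftPSH2", 1), i.e. the shortname of a registered retrieval model and a seed value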
14 | self.version = version 15 | self.device = device 16 | self.name_in_batch = name_in_batch 17 | self._load_model() 18 | 19 | def sstr(self): 20 | return f"FID_{self.version[0]}_seed{self.version[1]}" 21 | 22 | def _load_model(self): 23 | ckpt_path = config.shortname_2_model_path[self.version[0]].format(seed=self.version[1]) 24 | ckpt = torch.load(ckpt_path, 'cpu') 25 | print("FID: load", ckpt_path) 26 | self.model = PoseEncoder(latentD=ckpt['args'].latentD, num_body_joints=getattr(ckpt['args'], 'num_body_joints', 52), role="retrieval") 27 | self.model.load_state_dict({k[len('pose_encoder.'):]: v for k,v in ckpt['model'].items() if k.startswith('pose_encoder.encoder.')}) 28 | self.model.eval() 29 | self.model.to(self.device) 30 | 31 | def extract_features(self, batchpose): 32 | batchpose = batchpose.to(self.device) 33 | batchpose = batchpose.view(batchpose.size(0),-1)[:,:self.model.input_dim] 34 | features = self.model(batchpose) 35 | return features 36 | 37 | def extract_real_features(self, valdataloader): 38 | real_features = [] 39 | with torch.inference_mode(): 40 | for batches in valdataloader: 41 | real_features.append( self.extract_features(batches[self.name_in_batch]) ) 42 | self.real_features = torch.cat(real_features, dim=0).cpu().numpy() 43 | self.realmu = np.mean(self.real_features, axis=0) 44 | self.realsigma = np.cov(self.real_features, rowvar=False) 45 | print('FID: extracted real features', self.real_features.shape) 46 | 47 | def reset_gen_features(self): 48 | self.gen_features = [] 49 | 50 | def add_gen_features(self, batchpose): 51 | with torch.inference_mode(): 52 | self.gen_features.append( self.extract_features(batchpose) ) 53 | 54 | def compute(self): 55 | gen_features = torch.cat(self.gen_features, dim=0).cpu().numpy() 56 | assert gen_features.shape[0] == self.real_features.shape[0] 57 | mu = np.mean(gen_features, axis=0) 58 | sigma = np.cov(gen_features, rowvar=False) 59 | return calculate_frechet_distance(mu, sigma, self.realmu, self.realsigma) 60 | 61 | 62 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 63 | """Numpy implementation of the Frechet Distance. 64 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 65 | and X_2 ~ N(mu_2, C_2) is 66 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 67 | Stable version by Dougal J. Sutherland. 68 | Params: 69 | -- mu1 : Numpy array containing the activations of a layer of the 70 | inception net (like returned by the function 'get_predictions') 71 | for generated samples. 72 | -- mu2 : The sample mean over activations, precalculated on an 73 | representative data set. 74 | -- sigma1: The covariance matrix over activations for generated samples. 75 | -- sigma2: The covariance matrix over activations, precalculated on an 76 | representative data set. 77 | Returns: 78 | -- : The Frechet Distance. 
79 | """ 80 | 81 | mu1 = np.atleast_1d(mu1) 82 | mu2 = np.atleast_1d(mu2) 83 | 84 | sigma1 = np.atleast_2d(sigma1) 85 | sigma2 = np.atleast_2d(sigma2) 86 | 87 | assert mu1.shape == mu2.shape, \ 88 | 'Training and test mean vectors have different lengths' 89 | assert sigma1.shape == sigma2.shape, \ 90 | 'Training and test covariances have different dimensions' 91 | 92 | diff = mu1 - mu2 93 | 94 | # Product might be almost singular 95 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 96 | if not np.isfinite(covmean).all(): 97 | msg = ('fid calculation produces singular product; ' 98 | 'adding %s to diagonal of cov estimates') % eps 99 | print(msg) 100 | offset = np.eye(sigma1.shape[0]) * eps 101 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 102 | 103 | # Numerical error might give slight imaginary component 104 | if np.iscomplexobj(covmean): 105 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 106 | m = np.max(np.abs(covmean.imag)) 107 | raise ValueError('Imaginary component {}'.format(m)) 108 | covmean = covmean.real 109 | 110 | tr_covmean = np.trace(covmean) 111 | 112 | return (diff.dot(diff) + np.trace(sigma1) 113 | + np.trace(sigma2) - 2 * tr_covmean) -------------------------------------------------------------------------------- /src/text2pose/generative/README.md: -------------------------------------------------------------------------------- 1 | # Text-conditioned Generative Model for 3D Human Poses 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | _:warning: The evaluation of this model relies partly on several [text-to-pose retrieval model](../retrieval/README.md), see section **Extra setup**, below._ 6 | 7 | ## Model overview 8 | 9 | * **Input**: pose description; 10 | * **Output**: 3D human pose. 11 | 12 | ![Generative model](../../../images/generative_model.png) 13 | 14 | ## :crystal_ball: Demo 15 | 16 | To generate poses based on a pretrained model and your own input description, run the following: 17 | 18 | ``` 19 | streamlit run generative/demo_generative.py -- --model_paths 20 | ``` 21 | 22 | :bulb: Tips: _Specify several model paths to compare models together._ 23 | 24 | ## Extra setup 25 | 26 | At the beginning of the bash script, assign to variable `fid` the shortname of the trained [text-to-pose retrieval model](../retrieval/README.md) to be used for computing the FID. 27 | 28 | Add a line in *shortname_2_model_path.txt* to indicate the path to the model corresponding to the provided shortname. 29 | 30 | ## :bullettrain_front: Train 31 | 32 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 33 | 34 | Then use the following command: 35 | ``` 36 | bash generative/script_generative.sh 'train' 37 | ``` 38 | 39 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model. 
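For illustration, here is a minimal sketch of how such a nickname is resolved at run time (see `config.py`). `my_pretrained_gen` is a hypothetical nickname, and the snippet assumes it has been added to *shortname_2_model_path.txt* and that the code is run from the directory containing that file:

```python
import text2pose.config as config

# "my_pretrained_gen" is a made-up nickname, assumed to be registered in
# shortname_2_model_path.txt; config.py stores each path with a {seed}
# placeholder, filled in here for seed value 1.
checkpoint_path = config.shortname_2_model_path["my_pretrained_gen"].format(seed=1)
print(checkpoint_path)  # .../seed1/checkpoint_best.pth
```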
40 | 41 | ## :dart: Evaluate 42 | 43 | Use the following command: 44 | ``` 45 | bash generative/script_generative.sh 'eval' 46 | ``` 47 | 48 | Expected additional arguments include: 49 | 50 | | evaluation phase (`eval_type`) | expected additional arguments | needed to get... | 51 | |---|---|---| 52 | | `regular` | model path | **elbo**, **fid** | 53 | | `generate_poses` | model path, seed number | **mRecall R/G** (step 1/2), **mRecall G/R** (step 1/3) | 54 | | `RG` | seed number, shortname of the evaluated generative model, retrieval model shortname | **mRecall R/G** (step 2/2) | 55 | | `GRa` | seed number, shortname of the evaluated generative model | **mRecall G/R** (step 2/3) | 56 | | `GRb` | seed number, shortname to the model trained in previous step | **mRecall G/R** (step 3/3) | 57 | 58 | In the above table, "model path" is the path to the generative model to be evaluated. 59 | 60 | _**Important note**: the fid, mRecall R/G, and mRecall G/R rely on trained retrieval models._ 61 | 62 | ## Generate and visualize pose samples for the dataset 63 | 64 | For evaluation, *generative/script_generative.sh* makes the model generate pose samples for each caption of the dataset, thanks to the following command: 65 | 66 | ``` 67 | python generative/generate_poses.py --model_path 68 | ``` 69 | 70 | The generated pose samples can be visualized, along with the original pose and the related description, by running the following: 71 | 72 | ``` 73 | streamlit run generative/look_at_generated_pose_samples.py -- --model_path --dataset_version --split 74 | ``` 75 | -------------------------------------------------------------------------------- /src/text2pose/generative/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | -------------------------------------------------------------------------------- /src/text2pose/generative/demo_generative.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import argparse 12 | import torch 13 | import numpy as np 14 | 15 | import text2pose.demo as demo 16 | import text2pose.utils as utils 17 | import text2pose.utils_visu as utils_visu 18 | from text2pose.generative.evaluate_generative import load_model 19 | 20 | 21 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 22 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.') 23 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.") 24 | parser.add_argument('--n_generate', type=int, default=12, help="Number of poses to generate (number of samples); if considering only one model.") 25 | args = parser.parse_args() 26 | 27 | 28 | ### INPUT 29 | ################################################################################ 30 | 31 | data_version = "posescript-H2" 32 | 33 | 34 | ### SETUP 35 | ################################################################################ 36 | 37 | # --- layout 38 | st.markdown(""" 39 | 44 | """, unsafe_allow_html=True) 45 | 46 | # correct the number of generated sample depending on the setting 47 | if len(args.model_paths) > 1: 48 | n_generate = 4 49 | else: 50 | n_generate = args.n_generate 51 | 52 | # --- data 53 | available_splits = ['train', 'val', 'test'] 54 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model) 55 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version) 56 | 57 | # --- seed 58 | torch.manual_seed(42) 59 | np.random.seed(42) 60 | 61 | 62 | ### MAIN APP 63 | ################################################################################ 64 | 65 | # define query input interface 66 | cols_query = st.columns(3) 67 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 68 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'), index=1) 69 | number = cols_query[2].number_input("Split index or ID:", 0) 70 | st.markdown("""---""") 71 | 72 | # get query data 73 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model) 74 | 75 | # show query data 76 | cols_input = st.columns(2) 77 | cols_input[0].image(pose_img, caption="Annotated pose") 78 | if default_description: 79 | cols_input[1].write("Annotated text:") 80 | cols_input[1].write(f"_{default_description}_") 81 | else: 82 | cols_input[1].write("_(Not annotated.)_") 83 | 84 | # get input description 85 | description = cols_input[1].text_area("Pose description:", 86 | value=default_description, 87 | placeholder="The person is...", 88 | height=None, max_chars=None) 89 | 90 | analysis = cols_input[1].checkbox('Analysis') # whether to show the reconstructed pose and the mean sample pose in addition of some samples 91 | 92 | # generate results 93 | if analysis: 94 | 95 | st.markdown("""---""") 96 | st.write("**Generated poses** (*The reconstructed pose is shown in green; the mean pose in red; and samples in grey.*):") 97 | n_generate = 2 98 | nb_cols = 2 + n_generate # reconstructed pose + mean sample pose + n_generate sample poses: all must fit in one row, for each studied model 99 | 100 | for i, model in enumerate(models): 101 | with torch.no_grad(): 102 | rec_pose_data = 
model.forward_autoencoder(pose_data)['pose_body_pose'].view(1, -1) 103 | gen_pose_data_mean = model.sample_str_meanposes(description)['pose_body'].view(1, -1) 104 | gen_pose_data_samples = model.sample_str_nposes(description, n=n_generate)['pose_body'][0,...].view(n_generate, -1) 105 | 106 | # render poses 107 | imgs = utils_visu.image_from_pose_data(rec_pose_data, body_model, color='green', add_ground_plane=True, two_views=60) 108 | imgs += utils_visu.image_from_pose_data(gen_pose_data_mean, body_model, color='red', add_ground_plane=True, two_views=60) 109 | imgs += utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='grey', add_ground_plane=True, two_views=60) 110 | 111 | # display images 112 | cols = st.columns(nb_cols+1) # +1 to display model info 113 | cols[0].markdown(f'

{args.model_paths[i]}

', unsafe_allow_html=True) 114 | for i in range(nb_cols): 115 | cols[i%nb_cols+1].image(demo.process_img(imgs[i])) 116 | st.markdown("""---""") 117 | 118 | else: 119 | 120 | st.markdown("""---""") 121 | st.write("**Generated poses:**") 122 | 123 | for i, model in enumerate(models): 124 | with torch.no_grad(): 125 | gen_pose_data_samples = model.sample_str_nposes(description, n=n_generate)['pose_body'][0,...].view(n_generate, -1) 126 | 127 | # render poses 128 | imgs = utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='grey', add_ground_plane=True, two_views=60) 129 | 130 | # display images 131 | if len(models) > 1: 132 | cols = st.columns(n_generate+1) # +1 to display model info 133 | cols[0].markdown(f'

{args.model_paths[i]}

', unsafe_allow_html=True) 134 | for i in range(n_generate): 135 | cols[i%n_generate+1].image(demo.process_img(imgs[i])) 136 | st.markdown("""---""") 137 | else: 138 | cols = st.columns(demo.nb_cols) 139 | for i in range(n_generate): 140 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i])) 141 | st.markdown("""---""") 142 | st.write(f"_Results obtained with model: {args.model_paths[0]}_") -------------------------------------------------------------------------------- /src/text2pose/generative/evaluate_generative.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | from tqdm import tqdm 12 | import torch 13 | import numpy as np 14 | from human_body_prior.body_model.body_model import BodyModel 15 | 16 | import text2pose.config as config 17 | import text2pose.evaluate as evaluate 18 | from text2pose.data import PoseScript 19 | from text2pose.encoders.tokenizers import get_tokenizer_name 20 | from text2pose.generative.model_generative import CondTextPoser 21 | from text2pose.fid import FID 22 | 23 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 24 | 25 | OVERWRITE_RESULT = False 26 | 27 | 28 | ################################################################################ 29 | 30 | def load_model(model_path, device): 31 | 32 | assert os.path.isfile(model_path), "File {} not found.".format(model_path) 33 | 34 | # load checkpoint & model info 35 | ckpt = torch.load(model_path, 'cpu') 36 | text_encoder_name = ckpt['args'].text_encoder_name 37 | transformer_topping = getattr(ckpt['args'], 'transformer_topping', None) 38 | latentD = ckpt['args'].latentD 39 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52) 40 | 41 | # load model 42 | model = CondTextPoser(text_encoder_name=text_encoder_name, 43 | transformer_topping=transformer_topping, 44 | latentD=latentD, 45 | num_body_joints=num_body_joints 46 | ).to(device) 47 | model.load_state_dict(ckpt['model']) 48 | model.eval() 49 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path) 50 | 51 | return model, get_tokenizer_name(text_encoder_name) 52 | 53 | 54 | def eval_model(model_path, dataset_version, fid_version, split='val'): 55 | 56 | device = torch.device('cuda:0') 57 | 58 | # set seed for reproducibility (sampling for pose generation) 59 | torch.manual_seed(42) 60 | np.random.seed(42) 61 | 62 | # define result file & get auxiliary info 63 | fid_version, precision = get_evaluation_auxiliary_info(fid_version) 64 | nb_caps = config.caption_files[dataset_version][0] 65 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps) 66 | 67 | # load model if results for at least one caption is missing 68 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps): 69 | model, tokenizer_name = load_model(model_path, device) 70 | 71 | # compute or load results for the given run & caption 72 | results = {} 73 | for cap_ind in range(nb_caps): 74 | filename_res = get_res_file(cap_ind) 75 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT: 76 | d = PoseScript(version=dataset_version, split=split, 
tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True) 77 | cap_results = compute_eval_metrics(model, d, fid_version, device) 78 | evaluate.save_results_to_file(cap_results, filename_res) 79 | else: 80 | cap_results = evaluate.load_results_from_file(filename_res) 81 | # aggregate results 82 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()} 83 | 84 | # average over captions 85 | results = {k:sum(v)/nb_caps for k,v in results.items()} 86 | 87 | return {k:[v] for k, v in results.items()} 88 | 89 | 90 | def get_evaluation_auxiliary_info(fid_version, seed=1): 91 | # NOTE: default seed=1 for consistent evaluation 92 | precision = "" 93 | if fid_version is not None: 94 | fid_version = (fid_version, seed) 95 | precision += f"_X{fid_version[0]}-{fid_version[1]}X" 96 | return fid_version, precision 97 | 98 | 99 | def compute_eval_metrics(model, dataset, fid_version, device): 100 | 101 | # initialize 102 | data_loader = torch.utils.data.DataLoader( 103 | dataset, sampler=None, shuffle=False, 104 | batch_size=32, 105 | num_workers=8, 106 | pin_memory=True, 107 | drop_last=False 108 | ) 109 | 110 | body_model = BodyModel(model_type = config.POSE_FORMAT, 111 | bm_fname = config.NEUTRAL_BM, 112 | num_betas = config.n_betas).to(device) 113 | 114 | fid = FID(version=fid_version, device=device) 115 | fid.extract_real_features(data_loader) 116 | fid.reset_gen_features() 117 | 118 | pose_metrics = {f'{k}_{v}': 0.0 for k in ['v2v', 'jts', 'rot'] for v in ['elbo', 'dist_avg'] \ 119 | + [f'dist_top{topk}' for topk in config.k_topk_reconstruction_values]} 120 | 121 | # compute metrics 122 | for batch in tqdm(data_loader): 123 | 124 | # data setup 125 | model_input = dict( 126 | poses = batch['pose'].to(device), 127 | caption_lengths = batch['caption_lengths'].to(device), 128 | captions = batch['caption_tokens'][:,:batch['caption_lengths'].max()].to(device), 129 | ) 130 | 131 | with torch.inference_mode(): 132 | 133 | pose_metrics, _ = evaluate.add_elbo_and_reconstruction(model_input, pose_metrics, model, body_model, output_distr_key="t_z", reference_pose_key="poses") 134 | fid.add_gen_features( model.sample_nposes(**model_input, n=1)['pose_body'] ) 135 | 136 | # average over the dataset 137 | for k in pose_metrics: pose_metrics[k] /= len(dataset) 138 | 139 | # normalize the elbo (the same is done earlier for the reconstruction metrics) 140 | pose_metrics.update({'v2v_elbo':pose_metrics['v2v_elbo']/(body_model.J_regressor.shape[1] * 3), 141 | 'jts_elbo':pose_metrics['jts_elbo']/(body_model.J_regressor.shape[0] * 3), 142 | 'rot_elbo':pose_metrics['rot_elbo']/(model.pose_decoder.num_body_joints * 9)}) 143 | 144 | # compute fid metric 145 | results = {'fid': fid.compute()} 146 | results.update(pose_metrics) 147 | 148 | return results 149 | 150 | 151 | def display_results(results): 152 | metric_order = ['fid'] + [f'{x}_elbo' for x in ['jts', 'v2v', 'rot']] \ 153 | + [f'{k}_{v}' for k in ['jts', 'v2v', 'rot'] 154 | for v in ['dist_avg'] + [f'dist_top{topk}' for topk in config.k_topk_reconstruction_values]] 155 | results = evaluate.scale_and_format_results(results) 156 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n") 157 | 158 | 159 | ################################################################################ 160 | 161 | if __name__=="__main__": 162 | 163 | # added special arguments 164 | evaluate.eval_parser.add_argument('--fid', type=str, help='Version of 
the fid to use for evaluation.') 165 | 166 | args = evaluate.eval_parser.parse_args() 167 | args = evaluate.get_full_model_path(args) 168 | 169 | # compute results 170 | if args.average_over_runs: 171 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, dataset_version=args.dataset, fid_version=args.fid, split=args.split) 172 | else: 173 | ret = eval_model(args.model_path, dataset_version=args.dataset, fid_version=args.fid, split=args.split) 174 | 175 | # display results 176 | print(ret) 177 | display_results(ret) -------------------------------------------------------------------------------- /src/text2pose/generative/generate_poses.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | import argparse 12 | from tqdm import tqdm 13 | import torch 14 | import numpy as np 15 | 16 | import text2pose.config as config 17 | from text2pose.data import PoseScript 18 | from text2pose.generative.evaluate_generative import load_model 19 | 20 | 21 | parser = argparse.ArgumentParser(description='Parameters to generate poses corresponding to each caption.') 22 | parser.add_argument('--model_path', type=str, help='Path to the model.') 23 | parser.add_argument('--n_generate', type=int, default=5, help="Number of poses to generate for a given caption.") 24 | args = parser.parse_args() 25 | 26 | 27 | ### INPUT 28 | ################################################################################ 29 | 30 | device = torch.device('cuda:0') 31 | save_path = config.generated_pose_path % os.path.dirname(args.model_path) 32 | splits = ['train', 'test', 'val'] 33 | 34 | torch.manual_seed(42) 35 | np.random.seed(42) 36 | 37 | 38 | ### GENERATE POSES 39 | ################################################################################ 40 | 41 | # load model 42 | model, tokenizer_name = load_model(args.model_path, device) 43 | dataset_version = torch.load(args.model_path, 'cpu')['args'].dataset 44 | 45 | # create saving directory 46 | if not os.path.isdir(os.path.dirname(save_path)): 47 | os.mkdir(os.path.dirname(save_path)) 48 | 49 | # generate poses 50 | for s in splits: 51 | 52 | # check that the poses were not already generated 53 | filepath = save_path.format(data_version=dataset_version, split=s) 54 | assert not os.path.isfile(filepath), "Poses already generated!" 
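# For every pose of the split and each of its captions, n_generate poses are sampled
# from the text-conditioned distribution; the output tensor below is indexed as
# (pose index, caption index, sample index, joint, 3). These samples are the ones
# later loaded by look_at_generated_pose_samples.py and by the R/G and G/R
# evaluation steps of script_generative.sh.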
55 | 56 | d = PoseScript(version=dataset_version, split=s, tokenizer_name=tokenizer_name, num_body_joints=model.pose_decoder.num_body_joints) 57 | ncaptions = config.caption_files[dataset_version][0] 58 | output = torch.empty( (len(d), ncaptions, args.n_generate, model.pose_decoder.num_body_joints, 3), dtype=torch.float32) 59 | 60 | for index in tqdm(range(len(d))): 61 | # look at each available caption in turn 62 | for cidx in range(ncaptions): 63 | item = d.__getitem__(index, cidx=cidx) 64 | caption_tokens = item['caption_tokens'].to(device).unsqueeze(0) 65 | caption_lengths = torch.tensor([item['caption_lengths']]).to(device) 66 | caption_tokens = caption_tokens[:,:caption_lengths.max()] 67 | with torch.no_grad(): 68 | genposes = model.sample_nposes(caption_tokens, caption_lengths, n=args.n_generate)['pose_body'][0,...] 69 | output[index,cidx,...] = genposes 70 | 71 | # save 72 | torch.save(output, filepath) 73 | print(filepath) -------------------------------------------------------------------------------- /src/text2pose/generative/look_at_generated_pose_samples.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import os 12 | import argparse 13 | from human_body_prior.body_model.body_model import BodyModel 14 | 15 | import text2pose.config as config 16 | import text2pose.utils_visu as utils_visu 17 | from text2pose.data import PoseScript 18 | 19 | 20 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 21 | parser.add_argument('--model_path', type=str, help='Path to the model that generated the pose samples to visualize.') 22 | parser.add_argument('--dataset_version', type=str, help='Dataset version (depends on the model)') 23 | parser.add_argument('--split', type=str, help='Split') 24 | args = parser.parse_args() 25 | 26 | 27 | ### SETUP 28 | ################################################################################ 29 | 30 | @st.cache_resource 31 | def setup(args): 32 | 33 | # setup data 34 | generated_pose_path = config.generated_pose_path % os.path.dirname(args.model_path) 35 | generated_pose_path = generated_pose_path.format(data_version=args.dataset_version, split=args.split) 36 | 37 | dataset = PoseScript(version=args.dataset_version, split=args.split, 38 | cache=False, generated_pose_samples_path=generated_pose_path) 39 | 40 | # setup body model 41 | body_model = BodyModel(model_type = config.POSE_FORMAT, 42 | bm_fname = config.NEUTRAL_BM, 43 | num_betas = config.n_betas) 44 | body_model.eval() 45 | body_model.to('cpu') 46 | 47 | return generated_pose_path, dataset, body_model 48 | 49 | 50 | generated_pose_path, dataset, body_model = setup(args) 51 | 52 | 53 | ### VISUALIZE 54 | ################################################################################ 55 | 56 | st.write(f"**Dataset:** {args.dataset_version}") 57 | st.write(f"**Split:** {args.split}") 58 | st.write(f"**Using pose samples from:** {generated_pose_path}") 59 | 60 | # get input pose index & caption index 61 | st.write("**Choose a data point:**") 62 | index = st.number_input("Index in split:", 0, len(dataset.pose_samples)-1) 63 | cidx 
= st.number_input("Caption index:", 0, len(dataset.pose_samples[0])-1) 64 | 65 | # display description 66 | st.write("**Description:** "+dataset.captions[dataset.dataIDs[index]][cidx]) 67 | 68 | # render poses 69 | nb_samples = dataset.pose_samples.shape[2] 70 | img_original = utils_visu.image_from_pose_data(dataset.get_pose(index).view(1, -1), body_model, color='blue') 71 | imgs_sampled = utils_visu.image_from_pose_data(dataset.pose_samples[index, cidx].view(nb_samples, -1), body_model, color='green') 72 | 73 | # display original pose 74 | st.write("**Original pose:**") 75 | st.image(img_original[0]) 76 | 77 | # display generated pose samples 78 | st.write("**Generated pose samples for this description:**") 79 | cols = st.columns(nb_samples) 80 | for i in range(nb_samples): 81 | cols[i].image(imgs_sampled[i]) -------------------------------------------------------------------------------- /src/text2pose/generative/model_generative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import text2pose.config as config 6 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name 7 | from text2pose.encoders.pose_encoder_decoder import PoseDecoder, PoseEncoder 8 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder 9 | 10 | class CondTextPoser(nn.Module): 11 | 12 | def __init__(self, num_neurons=512, latentD=32, num_body_joints=config.NB_INPUT_JOINTS, text_encoder_name='distilbertUncased', transformer_topping=None): 13 | super(CondTextPoser, self).__init__() 14 | 15 | self.latentD = latentD 16 | 17 | # Define pose auto-encoder 18 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=latentD, num_body_joints=num_body_joints, role="generative") 19 | self.pose_decoder = PoseDecoder(num_neurons=num_neurons, latentD=latentD, num_body_joints=num_body_joints) 20 | 21 | # Define text encoder 22 | self.text_encoder_name = text_encoder_name 23 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name) 24 | if module_ref in ["glovebigru"]: 25 | self.text_encoder = TextEncoder(self.text_encoder_name, num_neurons=num_neurons, latentD=latentD, role="generative") 26 | elif module_ref in ["glovetransf", "distilbertUncased"]: 27 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, num_neurons=num_neurons, latentD=latentD, topping=transformer_topping, role="generative") 28 | else: 29 | raise NotImplementedError 30 | 31 | # Define learned loss parameters 32 | self.decsigma_v2v = nn.Parameter( torch.zeros(1) ) # logsigma 33 | self.decsigma_jts = nn.Parameter( torch.zeros(1) ) # logsigma 34 | self.decsigma_rot = nn.Parameter( torch.zeros(1) ) # logsigma 35 | 36 | 37 | # FORWARD METHODS ---------------------------------------------------------- 38 | 39 | 40 | def encode_text(self, captions, caption_lengths): 41 | return self.text_encoder(captions, caption_lengths) 42 | 43 | def encode_pose(self, pose_body): 44 | return self.pose_encoder(pose_body) 45 | 46 | def decode_pose(self, z): 47 | return self.pose_decoder(z) 48 | 49 | def forward_autoencoder(self, poses): 50 | q_z = self.encode_pose(poses) 51 | q_z_sample = q_z.rsample() 52 | ret = {f"{k}_pose":v for k,v in self.decode_pose(q_z_sample).items()} 53 | ret.update({'q_z': q_z}) 54 | return ret 55 | 56 | def forward(self, poses, captions, caption_lengths): 57 | t_z = self.encode_text(captions, caption_lengths) 58 | q_z = 
self.encode_pose(poses) 59 | q_z_sample = q_z.rsample() 60 | t_z_sample = t_z.rsample() 61 | ret = {f"{k}_pose":v for k,v in self.decode_pose(q_z_sample).items()} 62 | ret.update({f"{k}_text":v for k,v in self.decode_pose(t_z_sample).items()}) 63 | ret.update({'q_z': q_z, 't_z': t_z}) 64 | return ret 65 | 66 | 67 | # SAMPLE METHODS ----------------------------------------------------------- 68 | 69 | 70 | def sample_nposes(self, captions, caption_lengths, n=1, **kwargs): 71 | t_z = self.encode_text(captions, caption_lengths) 72 | z = t_z.sample( [n] ).permute(1,0,2).flatten(0,1) 73 | decode_results = self.decode_pose(z) 74 | return {k: v.view(int(v.shape[0]/n), n, *v.shape[1:]) for k,v in decode_results.items()} 75 | 76 | def sample_str_nposes(self, s, n=1): 77 | device = self.decsigma_v2v.device 78 | # no text provided, sample pose directly from latent space 79 | if len(s)==0: 80 | z = torch.tensor(np.random.normal(0., 1., size=(n, self.latentD)), dtype=torch.float32, device=device) 81 | decode_results = self.decode_pose(z) 82 | return {k: v.view(int(v.shape[0]/n), n, *v.shape[1:]) for k,v in decode_results.items()} 83 | # otherwise, encode the text to sample a pose conditioned on it 84 | if not hasattr(self, 'tokenizer'): 85 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name)) 86 | tokens = self.tokenizer(s).to(device=device) 87 | return self.sample_nposes(tokens.view(1, -1), torch.tensor([ len(tokens) ], dtype=tokens.dtype), n=n) 88 | 89 | def sample_str_meanposes(self, s): 90 | device = self.decsigma_v2v.device 91 | assert len(s)>0, "Please provide a non-empty text." 92 | if not hasattr(self, 'tokenizer'): 93 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name)) 94 | tokens = self.tokenizer(s).to(device=device) 95 | t_z = self.encode_text(tokens.view(1, -1), torch.tensor([ len(tokens) ], dtype=tokens.dtype)) 96 | return self.decode_pose(t_z.mean.view(1, -1)) -------------------------------------------------------------------------------- /src/text2pose/generative/script_generative.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 | ## text2pose ## 5 | ## Copyright (c) 2022, 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 9 | ## See project root for license details. ## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 # (train|eval|demo) 17 | eval_data_version="H2" 18 | checkpoint_type="best" # (last|best) 19 | 20 | architecture_args=( 21 | --model CondTextPoser 22 | --latentD 32 23 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp" 24 | # --text_encoder_name 'glovebigru_vocPSA2H2' 25 | ) 26 | 27 | loss_args=( 28 | --wloss_v2v 4 --wloss_rot 2 --wloss_jts 2 29 | --wloss_kld 0.1 --wloss_kldnpmul 0.01 --wloss_kldntmul 0.01 30 | ) 31 | 32 | bonus_args=( 33 | ) 34 | 35 | fid="ret_distilbert_dataPSA2ftPSH2" 36 | 37 | pretrained="gen_distilbert_dataPSA2" # used only if phase=='finetune' 38 | 39 | 40 | ############################################################## 41 | # EXECUTE 42 | 43 | # TRAIN 44 | if [[ "$action" == *"train"* ]]; then 45 | 46 | phase=$2 # (pretrain|finetune) 47 | echo "NOTE: Expecting as argument the training phase. 
Got: $phase" 48 | seed=$3 49 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 50 | 51 | # PRETRAIN 52 | if [[ "$phase" == *"pretrain"* ]]; then 53 | 54 | python generative/train_generative.py --dataset "posescript-A2" \ 55 | "${architecture_args[@]}" \ 56 | "${loss_args[@]}" \ 57 | "${bonus_args[@]}" \ 58 | --lr 1e-05 --wd 0.0001 --batch_size 128 --seed $seed \ 59 | --epochs 5000 --log_step 20 --val_every 20 \ 60 | --fid $fid 61 | 62 | # FINETUNE 63 | elif [[ "$phase" == *"finetune"* ]]; then 64 | 65 | python generative/train_generative.py --dataset "posescript-H2" \ 66 | "${architecture_args[@]}" \ 67 | "${loss_args[@]}" \ 68 | "${bonus_args[@]}" \ 69 | --apply_LR_augmentation \ 70 | --lrposemul 0.1 --lrtextmul 1 \ 71 | --lr 1e-05 --wd 0.0001 --batch_size 128 --seed $seed \ 72 | --epochs 2000 --val_every 10 \ 73 | --fid $fid \ 74 | --pretrained $pretrained 75 | 76 | fi 77 | 78 | fi 79 | 80 | 81 | # EVAL QUANTITATIVELY 82 | if [[ "$action" == *"eval"* ]]; then 83 | 84 | eval_type=$2 # (regular|generate_poses|RG|GRa|GRb) 85 | echo "NOTE: Expecting as argument the evaluation phase. Got: $eval_type" 86 | 87 | # regular (fid, elbo) 88 | if [[ "$eval_type" == "regular" ]]; then 89 | 90 | # parse additional input 91 | model_path=$3 92 | echo "NOTE: Expecting as argument the path to the model to evaluate. Got: $model_path" 93 | 94 | python generative/evaluate_generative.py \ 95 | --dataset "posescript-$eval_data_version" --split "test" \ 96 | --model_path ${model_path} --checkpoint $checkpoint_type \ 97 | --fid $fid 98 | 99 | 100 | # generate poses for R/G & G/R 101 | elif [[ "$eval_type" == "generate_poses" ]]; then 102 | 103 | # parse additional input 104 | model_path=$3 105 | echo "NOTE: Expecting as argument the path to the model for which to generate sample poses. Got: $model_path" 106 | seed=$4 107 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 108 | 109 | mp=${model_path::-1}$seed/checkpoint_best.pth 110 | python generative/generate_poses.py --model_path $mp 111 | 112 | 113 | # eval R/G 114 | elif [[ "$eval_type" == "RG" ]]; then 115 | 116 | # parse additional input 117 | seed=$3 118 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 119 | model_shortname=$4 120 | echo "NOTE: Expecting as argument the shortname of the generative model to evaluate. Got: $model_shortname" 121 | retrieval_model_shortname=$5 122 | echo "NOTE: Expecting as argument the shortname of the retrieval model used for R/G. Got: $retrieval_model_shortname" 123 | 124 | # evaluate generated poses with the retrieval model 125 | retrieval_model_path=$(python config.py $retrieval_model_shortname $seed) 126 | python retrieval/evaluate_retrieval.py \ 127 | --dataset "posescript-"$eval_data_version --split 'test' \ 128 | --model_path $retrieval_model_path --checkpoint $checkpoint_type \ 129 | --generated_pose_samples $model_shortname 130 | 131 | 132 | # eval G/R, 1st step 133 | elif [[ "$eval_type" == "GRa" ]]; then 134 | 135 | # parse additional input 136 | seed=$3 137 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 138 | model_shortname=$4 139 | echo "NOTE: Expecting as argument the shortname of the generative model to evaluate. 
Got: $model_shortname" 140 | 141 | # define specificities for the new retrieval model 142 | args_ret=( 143 | --model 'PoseText' --latentD 512 144 | --lr_scheduler "stepLR" --lr 0.0002 --lr_gamma 0.5 145 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp" 146 | # --text_encoder_name 'glovebigru_vocPSA2H2' 147 | ) 148 | pret='ret_distilbert_dataPSA2' 149 | if [[ "$eval_data_version" == "A2" ]]; then 150 | args_ret+=(--dataset "posescript-A2" 151 | --lr_step 400 152 | --batch_size 512 153 | --epochs 1000 154 | ) 155 | elif [[ "$eval_data_version" == "H2" ]]; then 156 | args_ret+=(--dataset "posescript-H2" 157 | --lr_step 40 158 | --batch_size 32 159 | --epochs 200 160 | --pret $pret 161 | --apply_LR_augmentation 162 | ) 163 | fi 164 | echo "NOTE: Expecting in the script the spec of the new retrieval model to train on the generated poses. Got:" "${args_ret[@]}" 165 | 166 | # train a new retrieval model with the generated poses 167 | python retrieval/train_retrieval.py "${args_ret[@]}" --seed $seed \ 168 | --generated_pose_samples $model_shortname 169 | 170 | echo "IMPORTANT: please create an entry in shortname_2_model_path.txt for this newly trained retrieval model (providing a shortname and the path to the model), as it will be needed for the next evaluation step of this generative model ($model_shortname)" 171 | 172 | 173 | # eval G/R, 2st step 174 | elif [[ "$eval_type" == "GRb" ]]; then 175 | 176 | # parse additional input 177 | seed=$3 178 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 179 | spec_retrieval_model_shortname=$4 180 | echo "NOTE: Expecting as argument the shortname of the retrieval model used for G/R. Got: $spec_retrieval_model_shortname" 181 | 182 | # evaluate the retrieval model trained on generated poses, on the original poses 183 | spec_retrieval_model_path=$(python config.py $spec_retrieval_model_shortname $seed) 184 | python retrieval/evaluate_retrieval.py \ 185 | --dataset "posescript-"$eval_data_version --split 'test' \ 186 | --model_path $spec_retrieval_model_path --checkpoint $checkpoint_type 187 | 188 | fi 189 | fi 190 | 191 | 192 | # EVAL QUALITATIVELY 193 | if [[ "$action" == *"demo"* ]]; then 194 | 195 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 196 | streamlit run generative/demo_generative.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type 197 | 198 | fi -------------------------------------------------------------------------------- /src/text2pose/generative_B/README.md: -------------------------------------------------------------------------------- 1 | # Text-guided 3D Human Pose Editing Model 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | _:warning: The evaluation of this model relies partly on a [text-to-pose retrieval model](../retrieval/README.md), see section **Extra setup**, below._ 6 | 7 | ## Model overview 8 | 9 | * **Inputs (#2)**: 3D human pose + text modifier; 10 | * **Output**: 3D human pose. 
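For reference, the core editing call used throughout this module is `model.sample_str_nposes(pose_A, modifier, n=...)`, as in `demo_generative_B.py`. Below is a minimal sketch of programmatic use; the checkpoint path is hypothetical, and `load_model` from `generative_B/evaluate_generative_B.py` is assumed to take `(model_path, device)` and return `(model, tokenizer_name)` like its counterpart in `generative/evaluate_generative.py`.

```
import torch

import text2pose.data as data
from text2pose.generative_B.evaluate_generative_B import load_model

device = torch.device('cpu')
# hypothetical path: point this to your own trained checkpoint
model, _ = load_model("<experiment_dir>/checkpoint_best.pth", device)

pose_A = data.T_POSE.view(1, -1)  # initial pose A (flattened joint rotations)
modifier = "Move your right arm up and lift your left leg."

with torch.no_grad():
    # sample 4 edited poses conditioned on pose A and the text modifier
    edited = model.sample_str_nposes(pose_A, modifier, n=4)['pose_body'][0]  # (4, n_joints, 3)
```

The resulting poses can then be rendered with `utils_visu.image_from_pose_data`, as done in the demo.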
11 | 12 | ![Pose Editing model](../../../images/pose_editing_model.png) 13 | 14 | ## :crystal_ball: Demo 15 | 16 | To edit poses based on a pretrained model and example pairs of pose and (modifyable) modifier texts, run the following: 17 | 18 | ``` 19 | streamlit run generative_B/demo_generative_B.py -- --model_paths 20 | ``` 21 | 22 | :bulb: Tips: _Specify several model paths to compare models together._ 23 | 24 | ## Extra setup 25 | 26 | At the beginning of the bash script, assign to variable `fid` the shortname of the trained [text-to-pose retrieval model](../retrieval/README.md) to be used for computing the FID. 27 | 28 | Add a line in *shortname_2_model_path.txt* to indicate the path to the model corresponding to the provided shortname. 29 | 30 | ## :bullettrain_front: Train 31 | 32 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 33 | 34 | 35 | Then use the following command: 36 | ``` 37 | bash generative_B/script_generative_B.sh 'train' 38 | ``` 39 | 40 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model. 41 | 42 | ## :dart: Evaluate 43 | 44 | Use the following command: 45 | ``` 46 | bash generative_B/script_generative_B.sh 'eval' 47 | ``` 48 | -------------------------------------------------------------------------------- /src/text2pose/generative_B/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/generative_B/demo_generative_B.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import argparse 12 | import torch 13 | import numpy as np 14 | 15 | import text2pose.config as config 16 | import text2pose.demo as demo 17 | import text2pose.utils as utils 18 | import text2pose.data as data 19 | import text2pose.utils_visu as utils_visu 20 | from text2pose.generative_B.evaluate_generative_B import load_model 21 | 22 | 23 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 24 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.') 25 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.") 26 | parser.add_argument('--n_generate', type=int, default=12, help="Number of poses to generate (number of samples); if considering only one model.") 27 | args = parser.parse_args() 28 | 29 | 30 | ### INPUT 31 | ################################################################################ 32 | 33 | posefix_data_version = "posefix-H" 34 | posescript_data_version = "posescript-H2" 35 | 36 | 37 | ### SETUP 38 | ################################################################################ 39 | 40 | # --- layout 41 | st.markdown(""" 42 | 47 | """, unsafe_allow_html=True) 48 | 49 | # correct the number of generated sample depending on the setting 50 | if len(args.model_paths) > 1: 51 | n_generate = 4 52 | else: 53 | n_generate = args.n_generate 54 | 55 | # --- data 56 | available_splits = ['train', 'val', 'test'] 57 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model) 58 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(posefix_data_version) 59 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids) 60 | _, captions = demo.setup_posescript_data(posescript_data_version) 61 | 62 | # --- seed 63 | torch.manual_seed(42) 64 | np.random.seed(42) 65 | 66 | 67 | ### MAIN APP 68 | ################################################################################ 69 | 70 | # define query input interface 71 | cols_query = st.columns(3) 72 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 73 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID')) 74 | number = cols_query[2].number_input("Split index or ID:", 0) 75 | st.markdown("""---""") 76 | 77 | # get query data 78 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model) 79 | 80 | # show query data 81 | st.write(f"**Query data:**") 82 | cols_input = st.columns([1,1,2]) 83 | # (enable PoseScript mode: description only) 84 | no_pose_A = cols_input[2].checkbox("PoseScript mode") 85 | if no_pose_A: 86 | pose_A_data = data.T_POSE.view(1, -1) 87 | pose_A_img = utils_visu.image_from_pose_data(pose_A_data, body_model, color="grey", add_ground_plane=True) 88 | pose_A_img = demo.process_img(pose_A_img[0]) 89 | pose_B_data, pose_B_img, default_description = demo.get_posescript_datapoint_from_pid(pid_B, captions, dataID_2_pose_info, body_model) 90 | # (actually show) 91 | cols_input[0].image(pose_A_img, caption="T-pose" if no_pose_A else "Pose A") 92 | cols_input[1].image(pose_B_img, caption="Annotated pose" if no_pose_A else "Annotated pose B") 93 | txt = default_description if no_pose_A else 
default_modifier 94 | if txt: 95 | cols_input[2].write("Annotated text:") 96 | cols_input[2].write(f"_{txt}_") 97 | else: 98 | cols_input[2].write("_(Not annotated.)_") 99 | 100 | # get input modifier 101 | modifier = cols_input[2].text_area("Pose modifier:", 102 | placeholder="Move your right arm... lift your left leg...", 103 | value=txt, 104 | height=None, max_chars=None) 105 | 106 | analysis = cols_input[2].checkbox('Analysis') # whether to show the reconstructed pose and the mean sample pose in addition of some samples 107 | 108 | # generate pose B 109 | if analysis: 110 | 111 | st.markdown("""---""") 112 | st.write("**Generated poses** (*The reconstructed B pose is shown in green; the mean pose in red; and samples in blue.*):") 113 | n_generate = 2 114 | nb_cols = 2 + n_generate # reconstructed pose + mean sample pose + n_generate sample poses: all must fit in one row, for each studied model 115 | 116 | for i, model in enumerate(models): 117 | with torch.no_grad(): 118 | rec_pose_data = model.forward_autoencoder(pose_B_data)['pose_body_pose'].view(1, -1) 119 | gen_pose_data_mean = model.sample_str_meanposes(pose_A_data, modifier)['pose_body'].view(1, -1) 120 | gen_pose_data_samples = model.sample_str_nposes(pose_A_data, modifier, n=n_generate)['pose_body'][0,...].view(n_generate, -1) 121 | 122 | # render poses 123 | imgs = utils_visu.image_from_pose_data(rec_pose_data, body_model, color='green', add_ground_plane=True, two_views=60) 124 | imgs += utils_visu.image_from_pose_data(gen_pose_data_mean, body_model, color='red', add_ground_plane=True, two_views=60) 125 | imgs += utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='blue', add_ground_plane=True, two_views=60) 126 | 127 | # display images 128 | cols = st.columns(nb_cols+1) # +1 to display model info 129 | cols[0].markdown(f'

{args.model_paths[i]}

', unsafe_allow_html=True) 130 | for i in range(nb_cols): 131 | cols[i%nb_cols+1].image(demo.process_img(imgs[i])) 132 | st.markdown("""---""") 133 | 134 | else: 135 | 136 | st.markdown("""---""") 137 | st.write("**Generated poses:**") 138 | 139 | for i, model in enumerate(models): 140 | with torch.no_grad(): 141 | gen_pose_data_samples = model.sample_str_nposes(pose_A_data, modifier, n=n_generate)['pose_body'][0,...].view(n_generate, -1) 142 | 143 | # render poses 144 | imgs = utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='blue', add_ground_plane=True, two_views=60) 145 | 146 | # display images 147 | if len(models) > 1: 148 | cols = st.columns(n_generate+1) # +1 to display model info 149 | cols[0].markdown(f'

{args.model_paths[i]}

', unsafe_allow_html=True) 150 | for i in range(n_generate): 151 | cols[i%n_generate+1].image(demo.process_img(imgs[i])) 152 | st.markdown("""---""") 153 | else: 154 | cols = st.columns(demo.nb_cols) 155 | for i in range(n_generate): 156 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i])) 157 | st.markdown("""---""") 158 | st.write(f"_Results obtained with model: {args.model_paths[0]}_") -------------------------------------------------------------------------------- /src/text2pose/generative_B/script_generative_B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 | ## text2pose ## 5 | ## Copyright (c) 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 9 | ## See project root for license details. ## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 17 | checkpoint_type="best" # (last|best) 18 | 19 | architecture_args=( 20 | --model PoseBGenerator 21 | --correction_module_mode "tirg" 22 | --latentD 32 --special_text_latentD 128 23 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp" 24 | # --text_encoder_name 'glovebigru_vocPFAHPP' 25 | ) 26 | 27 | loss_args=( 28 | --wloss_kld 1.0 29 | --kld_epsilon 10.0 30 | --wloss_v2v 1.0 --wloss_rot 1.0 --wloss_jts 1.0 31 | ) 32 | 33 | bonus_args=( 34 | ) 35 | 36 | fid="ret_distilbert_dataPSA2ftPSH2" 37 | 38 | pretrained="b_gen_distilbert_dataPFA" # used only if phase=='finetune' 39 | 40 | 41 | ############################################################## 42 | # EXECUTE 43 | 44 | # TRAIN 45 | if [[ "$action" == *"train"* ]]; then 46 | 47 | phase=$2 # (pretrain|finetune) 48 | echo "NOTE: Expecting as argument the training phase. Got: $phase" 49 | seed=$3 50 | echo "NOTE: Expecting as argument the seed value. 
Got: $seed" 51 | 52 | # PRETRAIN 53 | if [[ "$phase" == *"pretrain"* ]]; then 54 | 55 | python generative_B/train_generative_B.py --dataset "posefix-A" \ 56 | "${architecture_args[@]}" \ 57 | "${loss_args[@]}" \ 58 | "${bonus_args[@]}" \ 59 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \ 60 | --epochs 5000 --log_step 100 --val_every 20 \ 61 | --fid $fid 62 | 63 | # FINETUNE 64 | elif [[ "$phase" == *"finetune"* ]]; then 65 | 66 | python generative_B/train_generative_B.py --dataset "posefix-HPP" \ 67 | "${architecture_args[@]}" \ 68 | "${loss_args[@]}" \ 69 | "${bonus_args[@]}" \ 70 | --apply_LR_augmentation \ 71 | --lrposemul 0.1 --lrtextmul 1 \ 72 | --lr 0.000001 --wd 0.00001 --batch_size 128 --seed $seed \ 73 | --epochs 5000 --log_step 20 --val_every 20 \ 74 | --fid $fid \ 75 | --pretrained $pretrained 76 | 77 | fi 78 | 79 | fi 80 | 81 | 82 | # EVAL QUANTITATIVELY 83 | if [[ "$action" == *"eval"* ]]; then 84 | 85 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 86 | 87 | for model_path in "${experiments[@]}" 88 | do 89 | echo $model_path 90 | python generative_B/evaluate_generative_B.py --dataset "posefix-H" \ 91 | --model_path ${model_path} --checkpoint $checkpoint_type \ 92 | --fid $fid \ 93 | --split test 94 | # --special_eval 95 | done 96 | fi 97 | 98 | 99 | # EVAL QUALITATIVELY 100 | if [[ "$action" == *"demo"* ]]; then 101 | 102 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 103 | streamlit run generative_B/demo_generative_B.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type 104 | 105 | fi -------------------------------------------------------------------------------- /src/text2pose/generative_caption/README.md: -------------------------------------------------------------------------------- 1 | # Pose Description Generation Model 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | _:warning: The evaluation of this model relies partly on a [text-to-pose retrieval model](../retrieval/README.md) and a [text-conditionned pose generation model](../generative/README.md), see section **Extra setup**, below._ 6 | 7 | ## Model overview 8 | 9 | * **Input**: 3D human pose; 10 | * **Output**: pose description. 11 | 12 | ![Description generation model](../../../images/caption_generation_model.png) 13 | 14 | ## :crystal_ball: Demo 15 | 16 | To generate text descriptions based on a pretrained model and example 3D human poses, run the following: 17 | 18 | ``` 19 | streamlit run generative_caption/demo_generative_caption.py -- --model_paths 20 | ``` 21 | 22 | :bulb: Tips: _Specify several model paths to compare models together._ 23 | 24 | ## Extra setup 25 | 26 | At the beginning of the bash script, indicate the shortnames of the trained models used for evaluation: 27 | * `fid`: text-to-pose retrieval model ([info](../retrieval/README.md)), 28 | * `pose_generative_model`: text-to-pose generative model ([info](../generative/README.md)), 29 | * `textret_model`: text-to-pose retrieval model, eg. the same as for `fid` ([info](../retrieval/README.md)). 30 | 31 | Indicate the paths to the models corresponding to each of these shortnames in *shortname_2_model_path.txt*. 32 | 33 | ## :bullettrain_front: Train 34 | 35 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 
36 | 37 | Then use the following command: 38 | ``` 39 | bash generative_caption/script_generative_caption.sh 'train' 40 | ``` 41 | 42 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model. 43 | 44 | ## :dart: Evaluate 45 | 46 | Use the following command: 47 | ``` 48 | bash generative_caption/script_generative_caption.sh 'eval' 49 | ``` 50 | -------------------------------------------------------------------------------- /src/text2pose/generative_caption/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/generative_caption/demo_generative_caption.py: -------------------------------------------------------------------------------- 1 | 2 | ############################################################## 3 | ## text2pose ## 4 | ## Copyright (c) 2023 ## 5 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 6 | ## and Naver Corporation ## 7 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 8 | ## See project root for license details. 
## 9 | ############################################################## 10 | 11 | import streamlit as st 12 | import argparse 13 | import torch 14 | 15 | import text2pose.demo as demo 16 | from text2pose.generative_caption.evaluate_generative_caption import load_model 17 | 18 | 19 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 20 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.') 21 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.') 22 | args = parser.parse_args() 23 | 24 | 25 | ### INPUT 26 | ################################################################################ 27 | 28 | data_version = "posescript-H2" 29 | 30 | 31 | ### SETUP 32 | ################################################################################ 33 | 34 | # --- layout 35 | st.markdown(""" 36 | 41 | """, unsafe_allow_html=True) 42 | 43 | # --- data 44 | available_splits = ['train', 'val', 'test'] 45 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model) 46 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version) 47 | 48 | 49 | ### MAIN APP 50 | ################################################################################ 51 | 52 | # define query input interface 53 | cols_query = st.columns(3) 54 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 55 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'), index=1) 56 | number = cols_query[2].number_input("Split index or ID:", 0) 57 | st.markdown("""---""") 58 | 59 | # get query data 60 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model) 61 | 62 | # show query data 63 | cols_input = st.columns(2) 64 | cols_input[0].image(pose_img, caption="Annotated pose") 65 | if default_description: 66 | cols_input[1].write("Annotated text:") 67 | cols_input[1].write(f"_{default_description}_") 68 | else: 69 | cols_input[1].write("_(Not annotated.)_") 70 | 71 | # generate text 72 | st.markdown("""---""") 73 | st.write("**Text generation:**") 74 | for i, model in enumerate(models): 75 | 76 | with torch.no_grad(): 77 | texts, scores = model.generate_text(pose_data.view(1, -1, 3)) # (1, njoints, 3) 78 | 79 | if len(models) > 1: 80 | cols = st.columns(2) 81 | cols[0].markdown(f'

{args.model_paths[i]}

', unsafe_allow_html=True) 82 | cols[1].write(texts[0]) 83 | st.markdown("""---""") 84 | else: 85 | st.write(texts[0]) 86 | st.markdown("""---""") 87 | st.write(f"_Results obtained with model: {args.model_paths[0]}_") -------------------------------------------------------------------------------- /src/text2pose/generative_caption/model_generative_caption.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import torch.nn as nn 11 | 12 | import text2pose.config as config 13 | from text2pose.encoders.tokenizers import get_text_encoder_or_decoder_module_name 14 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder 15 | from text2pose.encoders.text_decoders import TransformerTextDecoder, ModalityInputAdapter 16 | 17 | class DescriptionGenerator(nn.Module): 18 | 19 | def __init__(self, num_neurons=512, encoder_latentD=32, decoder_latentD=512, 20 | num_body_joints=config.NB_INPUT_JOINTS, 21 | decoder_nhead=8, decoder_nlayers=4, text_decoder_name="", 22 | transformer_mode="crossattention"): 23 | super(DescriptionGenerator, self).__init__() 24 | 25 | # Define pose encoder 26 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=encoder_latentD, num_body_joints=num_body_joints, role="retrieval") 27 | 28 | # Define modality input adaptor 29 | self.modalityInputAdapter = ModalityInputAdapter(inlatentD=encoder_latentD, 30 | outlatentD=decoder_latentD) 31 | 32 | # Define text decoder 33 | self.text_decoder_name = text_decoder_name 34 | self.transformer_mode = transformer_mode 35 | module_ref = get_text_encoder_or_decoder_module_name(text_decoder_name) 36 | if module_ref == "transformer": 37 | self.text_decoder = TransformerTextDecoder(self.text_decoder_name, 38 | nhead=decoder_nhead, 39 | nlayers=decoder_nlayers, 40 | decoder_latentD=decoder_latentD, 41 | transformer_mode=transformer_mode) 42 | else: 43 | raise NotImplementedError 44 | 45 | 46 | def encode_pose(self, pose_body): 47 | return self.pose_encoder(pose_body) 48 | 49 | def decode_text(self, z, captions, caption_lengths, train=False): 50 | return self.text_decoder(z, captions, caption_lengths, train=train) 51 | 52 | def forward(self, poses, captions, caption_lengths): 53 | z = self.encode_pose(poses) 54 | z = self.modalityInputAdapter(z) 55 | decoded = self.decode_text(z, captions, caption_lengths, train=True) 56 | return dict(z=z, **decoded) 57 | 58 | def generate_text(self, poses): 59 | z = self.encode_pose(poses) 60 | z = self.modalityInputAdapter(z) 61 | decoded_texts, likelihood_scores = self.text_decoder.generate_greedy(z) 62 | return decoded_texts, likelihood_scores -------------------------------------------------------------------------------- /src/text2pose/generative_caption/script_generative_caption.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 | ## text2pose ## 5 | ## Copyright (c) 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. 
## 9 | ## See project root for license details. ## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 # (train|eval|demo) 17 | checkpoint_type="best" # (last|best) 18 | 19 | architecture_args=( 20 | --model DescriptionGenerator 21 | --text_decoder_name "transformer_vocPSA2H2" 22 | --transformer_mode 'crossattention' 23 | --decoder_nhead 8 --decoder_nlayers 4 24 | --latentD 32 --decoder_latentD 512 25 | ) 26 | 27 | bonus_args=( 28 | ) 29 | 30 | fid="ret_distilbert_dataPSA2ftPSH2" 31 | pose_generative_model="gen_distilbert_dataPSA2ftPSH2" 32 | textret_model="ret_distilbert_dataPSA2ftPSH2" 33 | 34 | pretrained="capgen_CAtransfPSA2H2_dataPSA2" # used only if phase=='finetune' 35 | 36 | 37 | ############################################################## 38 | # EXECUTE 39 | 40 | # TRAIN 41 | if [[ "$action" == *"train"* ]]; then 42 | 43 | phase=$2 # (pretrain|finetune) 44 | echo "NOTE: Expecting as argument the training phase. Got: $phase" 45 | seed=$3 46 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 47 | 48 | # PRETRAIN 49 | if [[ "$phase" == *"pretrain"* ]]; then 50 | 51 | python generative_caption/train_generative_caption.py --dataset "posescript-A2" \ 52 | "${architecture_args[@]}" \ 53 | "${bonus_args[@]}" \ 54 | --lr 0.0001 --wd 0.0001 --batch_size 128 --seed $seed \ 55 | --epochs 3000 --log_step 100 --val_every 20 \ 56 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model 57 | 58 | # FINETUNE 59 | elif [[ "$phase" == *"finetune"* ]]; then 60 | 61 | python generative_caption/train_generative_caption.py --dataset "posescript-H2" \ 62 | "${architecture_args[@]}" \ 63 | "${bonus_args[@]}" \ 64 | --apply_LR_augmentation \ 65 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \ 66 | --epochs 2000 --log_step 100 --val_every 20 \ 67 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \ 68 | --pretrained $pretrained 69 | 70 | fi 71 | 72 | fi 73 | 74 | 75 | # EVAL QUANTITATIVELY 76 | if [[ "$action" == *"eval"* ]]; then 77 | 78 | experiments=( 79 | $2 80 | "GT" "random" "auto_posescript-A2_cap1" # control metrics 81 | ) 82 | 83 | for model_path in "${experiments[@]}" 84 | do 85 | echo $model_path 86 | python generative_caption/evaluate_generative_caption.py --dataset "posescript-H2" \ 87 | --model_path ${model_path} --checkpoint $checkpoint_type \ 88 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \ 89 | --split test 90 | done 91 | fi 92 | 93 | 94 | # EVAL QUALITATIVELY 95 | if [[ "$action" == *"demo"* ]]; then 96 | 97 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 98 | streamlit run generative_caption/demo_generative_caption.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type 99 | 100 | fi -------------------------------------------------------------------------------- /src/text2pose/generative_modifier/README.md: -------------------------------------------------------------------------------- 1 | # Pose-based Correctional Text Generation Model 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | _:warning: The evaluation of this model relies partly on a [pose-to-text retrieval model](../retrieval/README.md), a [pair-to-text retrieval model](../retrieval_modifier/README.md) and a [text-guided pose editing 
model](../generative_B/README.md), see section **Extra setup**, below._ 6 | 7 | ## Model overview 8 | 9 | * **Inputs (#2)**: a pair of 3D human poses (pose A + pose B); 10 | * **Output**: textual instruction explaining how to go from pose A to pose B. 11 | 12 | ![Modifier generation model](../../../images/feedback_generation_model.png) 13 | 14 | ## :crystal_ball: Demo 15 | 16 | To generate text instructions based on a pretrained model and examples of 3D human pose pairs, run the following: 17 | 18 | ``` 19 | streamlit run generative_modifier/demo_generative_modifier.py -- --model_paths 20 | ``` 21 | 22 | :bulb: Tips: _Specify several model paths to compare models together._ 23 | 24 | ## Extra setup 25 | 26 | At the beginning of the bash script, indicate the shortnames of the trained models used for evaluation: 27 | * `fid`: text-to-pose retrieval model ([info](../retrieval/README.md)), 28 | * `pose_generative_model`: text-guided pose editing model ([info](../generative_B/README.md)), 29 | * `textret_model`: pair-to-text retrieval model ([info](../retrieval_modifier/README.md)). 30 | 31 | Indicate the paths to the models corresponding to each of these shortnames in *shortname_2_model_path.txt*. 32 | 33 | ## :bullettrain_front: Train 34 | 35 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 36 | 37 | Then use the following command: 38 | ``` 39 | bash generative_modifier/script_generative_modifier.sh 'train' 40 | ``` 41 | 42 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model. 43 | 44 | ## :dart: Evaluate 45 | 46 | Use the following command: 47 | ``` 48 | bash generative_modifier/script_generative_modifier.sh 'eval' 49 | ``` 50 | -------------------------------------------------------------------------------- /src/text2pose/generative_modifier/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/generative_modifier/demo_generative_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import argparse 12 | import torch 13 | 14 | import text2pose.config as config 15 | import text2pose.demo as demo 16 | import text2pose.utils as utils 17 | import text2pose.data as data 18 | import text2pose.utils_visu as utils_visu 19 | from text2pose.generative_modifier.evaluate_generative_modifier import load_model 20 | 21 | 22 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 23 | parser.add_argument('--model_paths', nargs='+', type=str, help='Path to the models to be compared.') 24 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.") 25 | args = parser.parse_args() 26 | 27 | 28 | ### INPUT 29 | ################################################################################ 30 | 31 | posefix_data_version = "posefix-H" 32 | posescript_data_version = "posescript-H2" 33 | 34 | 35 | ### SETUP 36 | ################################################################################ 37 | 38 | # --- layout 39 | st.markdown(""" 40 | 45 | """, unsafe_allow_html=True) 46 | 47 | # --- data 48 | available_splits = ['train', 'val', 'test'] 49 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model) 50 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(posefix_data_version) 51 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids) 52 | _, captions = demo.setup_posescript_data(posescript_data_version) 53 | 54 | 55 | ### MAIN APP 56 | ################################################################################ 57 | 58 | # define query input interface 59 | cols_query = st.columns(3) 60 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 61 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID')) 62 | number = cols_query[2].number_input("Split index or ID:", 0) 63 | st.markdown("""---""") 64 | 65 | # get query data 66 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model) 67 | 68 | # show query data 69 | st.write(f"**Query data:**") 70 | cols_input = st.columns([1,1,2]) 71 | # (enable PoseScript mode: description only) 72 | no_pose_A = cols_input[2].checkbox("PoseScript mode") 73 | if no_pose_A: 74 | pose_A_data = data.T_POSE.view(1, -1) 75 | pose_A_img = utils_visu.image_from_pose_data(pose_A_data, body_model, color="grey", add_ground_plane=True) 76 | pose_A_img = demo.process_img(pose_A_img[0]) 77 | pose_B_data, pose_B_img, default_description = demo.get_posescript_datapoint_from_pid(pid_B, captions, dataID_2_pose_info, body_model) 78 | # (actually show) 79 | cols_input[0].image(pose_A_img, caption="T-pose" if no_pose_A else "Pose A") 80 | cols_input[1].image(pose_B_img, caption="Annotated pose" if no_pose_A else "Annotated pose B") 81 | txt = default_description if no_pose_A else default_modifier 82 | if txt: 83 | cols_input[2].write("Annotated text:") 84 | cols_input[2].write(f"_{txt}_") 85 | else: 86 | cols_input[2].write("_(Not annotated.)_") 87 | 88 | # generate text 89 | st.markdown("""---""") 90 | st.write("**Text generation:**") 91 | for i, model in enumerate(models): 92 | 93 | with torch.no_grad(): 94 | texts, scores = model.generate_text(pose_A_data.view(1, -1, 3), 
pose_B_data.view(1, -1, 3)) # (1, njoints, 3) 95 | 96 | if len(models) > 1: 97 | cols = st.columns(2) 98 | cols[0].markdown(f'
{args.model_paths[i]}
', unsafe_allow_html=True) 99 | cols[1].write(texts[0]) 100 | st.markdown("""---""") 101 | else: 102 | st.write(texts[0]) 103 | st.markdown("""---""") 104 | st.write(f"_Results obtained with model: {args.model_paths[0]}_") -------------------------------------------------------------------------------- /src/text2pose/generative_modifier/model_generative_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import torch.nn as nn 11 | 12 | import text2pose.config as config 13 | from text2pose.encoders.tokenizers import get_text_encoder_or_decoder_module_name 14 | from text2pose.encoders.modules import TIRG 15 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder 16 | from text2pose.encoders.text_decoders import TransformerTextDecoder, ModalityInputAdapter 17 | 18 | 19 | class FeedbackGenerator(nn.Module): 20 | 21 | def __init__(self, num_neurons=512, encoder_latentD=32, comparison_latentD=32, 22 | num_body_joints=config.NB_INPUT_JOINTS, 23 | decoder_latentD=512, decoder_nhead=8, decoder_nlayers=4, text_decoder_name="", 24 | comparison_module_mode="tirg", transformer_mode="crossattention"): 25 | super(FeedbackGenerator, self).__init__() 26 | 27 | # Define pose encoder 28 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=encoder_latentD, num_body_joints=num_body_joints, role="retrieval") 29 | 30 | # Define fusing module 31 | self.comparison_module = ComparisonModule(inlatentD=encoder_latentD, 32 | outlatentD=comparison_latentD, 33 | mode=comparison_module_mode) 34 | 35 | # Define modality input adaptor 36 | self.modalityInputAdapter = ModalityInputAdapter(inlatentD=comparison_latentD, 37 | outlatentD=decoder_latentD) 38 | 39 | # Define text decoder 40 | self.text_decoder_name = text_decoder_name 41 | self.transformer_mode = transformer_mode 42 | module_ref = get_text_encoder_or_decoder_module_name(text_decoder_name) 43 | if module_ref == "transformer": 44 | self.text_decoder = TransformerTextDecoder(self.text_decoder_name, 45 | nhead=decoder_nhead, 46 | nlayers=decoder_nlayers, 47 | decoder_latentD=decoder_latentD, 48 | transformer_mode=transformer_mode) 49 | else: 50 | raise NotImplementedError 51 | 52 | 53 | def encode_pose(self, pose_body): 54 | return self.pose_encoder(pose_body) 55 | 56 | def decode_text(self, z, captions, caption_lengths, train=False): 57 | return self.text_decoder(z, captions, caption_lengths, train=train) 58 | 59 | def fuse_input_poses(self, embed_poses_A, embed_poses_B): 60 | z = self.comparison_module(embed_poses_A, embed_poses_B) 61 | z = self.modalityInputAdapter(z) 62 | return z 63 | 64 | def forward(self, poses_A, captions, caption_lengths, poses_B): 65 | z_a = self.encode_pose(poses_A) 66 | z_b = self.encode_pose(poses_B) 67 | z = self.fuse_input_poses(z_a, z_b) 68 | decoded = self.decode_text(z, captions, caption_lengths, train=True) 69 | return dict(z=z, **decoded) 70 | 71 | def generate_text(self, poses_A, poses_B): 72 | z_a = self.encode_pose(poses_A) 73 | z_b = self.encode_pose(poses_B) 74 | z = self.fuse_input_poses(z_a, z_b) 75 | decoded_texts, likelihood_scores = self.text_decoder.generate_greedy(z) 76 
| return decoded_texts, likelihood_scores 77 | 78 | 79 | class ComparisonModule(nn.Module): 80 | """ 81 | Given two poses A and B, compute an embedding representing the result from 82 | the comparison of the two poses. 83 | """ 84 | 85 | def __init__(self, inlatentD, outlatentD, mode="tirg"): 86 | super(ComparisonModule, self).__init__() 87 | 88 | self.inlatentD = inlatentD 89 | self.outlatentD = outlatentD 90 | self.mode = mode 91 | 92 | if mode == "tirg": 93 | self.tirg = TIRG(input_dim=[inlatentD, inlatentD], output_dim=outlatentD, out_l2_normalize=False) 94 | self.forward = self.forward_tirg 95 | else: 96 | print(f"Name for the mode of the comparison module is unknown (provided {mode}).") 97 | raise NotImplementedError 98 | 99 | 100 | def forward_tirg(self, pose_A_embeddings, pose_B_embeddings): 101 | return self.tirg.query_compositional_embedding(pose_B_embeddings, pose_A_embeddings) -------------------------------------------------------------------------------- /src/text2pose/generative_modifier/script_generative_modifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 | ## text2pose ## 5 | ## Copyright (c) 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 9 | ## See project root for license details. ## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 # (train|eval|demo) 17 | checkpoint_type="best" # (last|best) 18 | 19 | architecture_args=( 20 | --model FeedbackGenerator 21 | --text_decoder_name "transformer_vocPFAHPP" 22 | --transformer_mode 'crossattention' 23 | --comparison_module_mode 'tirg' 24 | --decoder_nhead 8 --decoder_nlayers 4 25 | --latentD 32 --comparison_latentD 32 --decoder_latentD 512 26 | ) 27 | 28 | bonus_args=( 29 | ) 30 | 31 | fid="ret_distilbert_dataPSA2ftPSH2" 32 | pose_generative_model="b_gen_distilbert_dataPFAftPFH" 33 | textret_model="modret_distilbert_dataPFAftPFH" 34 | 35 | pretrained="modgen_CAtransfPFAHPP_dataPFA" # used only if phase=='finetune' 36 | 37 | 38 | ############################################################## 39 | # EXECUTE 40 | 41 | # TRAIN 42 | if [[ "$action" == *"train"* ]]; then 43 | 44 | phase=$2 # (pretrain|finetune) 45 | echo "NOTE: Expecting as argument the training phase. Got: $phase" 46 | seed=$3 47 | echo "NOTE: Expecting as argument the seed value. 
Got: $seed" 48 | 49 | # PRETRAIN 50 | if [[ "$phase" == *"pretrain"* ]]; then 51 | 52 | python generative_modifier/train_generative_modifier.py --dataset "posefix-A" \ 53 | "${architecture_args[@]}" \ 54 | "${bonus_args[@]}" \ 55 | --lr 0.0001 --wd 0.0001 --batch_size 128 --seed $seed \ 56 | --epochs 3000 --log_step 100 --val_every 20 \ 57 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model 58 | 59 | # FINETUNE 60 | elif [[ "$phase" == *"finetune"* ]]; then 61 | 62 | python generative_modifier/train_generative_modifier.py --dataset "posefix-H" \ 63 | "${architecture_args[@]}" \ 64 | "${bonus_args[@]}" \ 65 | --apply_LR_augmentation \ 66 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \ 67 | --epochs 2000 --log_step 100 --val_every 20 \ 68 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \ 69 | --pretrained $pretrained 70 | 71 | fi 72 | 73 | fi 74 | 75 | 76 | # EVAL QUANTITATIVELY 77 | if [[ "$action" == *"eval"* ]]; then 78 | 79 | experiments=( 80 | $2 81 | "GT" "random" "auto_posefix-A_cap0" # control metrics 82 | ) 83 | 84 | for model_path in "${experiments[@]}" 85 | do 86 | echo $model_path 87 | python generative_modifier/evaluate_generative_modifier.py --dataset "posefix-H" \ 88 | --model_path ${model_path} --checkpoint $checkpoint_type \ 89 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \ 90 | --split test 91 | done 92 | fi 93 | 94 | 95 | # EVAL QUALITATIVELY 96 | if [[ "$action" == *"demo"* ]]; then 97 | 98 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 99 | streamlit run generative_modifier/demo_generative_modifier.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type 100 | 101 | fi -------------------------------------------------------------------------------- /src/text2pose/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def BBC(scores): 7 | # build the ground truth label tensor: the diagonal corresponds to the 8 | # correct classification 9 | GT_labels = torch.arange(scores.shape[0], device=scores.device).long() 10 | loss = F.cross_entropy(scores, GT_labels) # mean reduction 11 | return loss 12 | 13 | 14 | def symBBC(scores): 15 | x2y_loss = BBC(scores) 16 | y2x_loss = BBC(scores.t()) 17 | return (x2y_loss + y2x_loss) / 2.0 18 | 19 | 20 | def laplacian_nll(x_tilde, x, log_sigma): 21 | """ Negative log likelihood of an isotropic Laplacian density """ 22 | log_norm = - (np.log(2) + log_sigma) 23 | log_energy = - (torch.abs(x_tilde - x)) / torch.exp(log_sigma) 24 | return - (log_norm + log_energy) 25 | 26 | 27 | def gaussian_nll(x_tilde, x, log_sigma): 28 | """ Negative log-likelihood of an isotropic Gaussian density """ 29 | log_norm = - 0.5 * (np.log(2 * np.pi) + log_sigma) 30 | log_energy = - 0.5 * F.mse_loss(x_tilde, x, reduction='none') / torch.exp(log_sigma) 31 | return - (log_norm + log_energy) -------------------------------------------------------------------------------- /src/text2pose/posefix/README.md: -------------------------------------------------------------------------------- 1 | # About the PoseFix dataset 2 | 3 | ## :inbox_tray: Download 4 | 5 | **License.** 6 | *The PoseScript dataset is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. 
7 | A summary of the CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/). 8 | The CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).* 9 | 10 | | Version | Link | 11 | |---|---| 12 | | ICCV23 | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/posefix_dataset_release.zip) | 13 | 14 |
15 | Dataset content. 16 | 17 | * a file linking each pose ID to the reference of its corresponding pose sequence in AMASS, and its frame index; 18 | * a file linking each pair ID (index) to a pair of pose IDs; 19 | * a file linking each pair ID with its text modifiers (separate files for automatically generated modifiers, human-written ones and paraphrases); 20 | * files listing pair IDs for each split. 21 | 22 | Please refer to the provided README for more details. 23 |
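Once the archive is downloaded and referenced in *src/text2pose/config.py*, these files can be combined into (pose A, pose B, modifier) triplets with the same helpers used by *posefix/explore_posefix.py*. Below is a minimal sketch of that assembly (it assumes the PoseFix release and the AMASS poses are installed where `config.py` expects them):

```python
import text2pose.config as config
import text2pose.demo as demo
import text2pose.utils as utils

# pair ID -> (pose ID A, pose ID B); pose ID -> (dataset, AMASS sequence, frame index)
pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
dataID_2_pose_info, triplet_data = demo.setup_posefix_data("posefix-H")  # human-written modifiers

pair_ID = 0
pid_A, pid_B = pose_pairs[pair_ID]

# load the pose parameters of the two referenced AMASS frames
pose_A_data = utils.get_pose_data_from_file(dataID_2_pose_info[str(pid_A)])
pose_B_data = utils.get_pose_data_from_file(dataID_2_pose_info[str(pid_B)])

# human-written modifier(s) for this pair (empty if the pair was not annotated)
modifiers = triplet_data[pair_ID]["modifier"] if pair_ID in triplet_data else []
```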
24 | 25 | 26 | ## :crystal_ball: Take a quick look 27 | 28 | To take a quick look at the data (ie. look at the poses under different viewpoints, the human-written modifier and the paraphrases collected with InstructGPT when available, as well as the automatic modifiers): 29 | 30 | ``` 31 | streamlit run posefix/explore_posefix.py 32 | ``` 33 | 34 | ## :page_with_curl: Generate automatic modifiers 35 | 36 | ### Overview of the pipeline 37 | 38 | Given a pair of 3D poses, we use paircodes to extract semantic pair-relationship information, and posecodes for each of the two poses to enrich the instructions with "absolute" information. These codes are then selected, merged or combined (when relevant) before being converted into a structural instruction in natural language explaining how to go from the first pose to the second pose. Letters ‘L’ and ‘R’ stand for ‘left’ and ‘right’ respectively. 39 | 40 | ![Correcting pipeline](../../../images/comparative_pipeline.png) 41 | 42 | Please refer to the paper and the supplementary material for more extensive explanations. 43 | 44 | ### Generate modifiers 45 | 46 | To generate automatic modifiers, please follow these steps: 47 | 48 | - **compute joint coordinates for all poses** 49 | ``` 50 | python posescript/compute_coords.py 51 | ``` 52 | 53 | - **compute rotation change for all pairs** 54 | ``` 55 | python posefix/compute_rotation_change.py 56 | ``` 57 | 58 | - (optional) **modify corrective data as you see fit**: 59 | 60 | (Re)Define posecodes & paircodes (categories, thresholds, tolerable noise levels, eligibility), super-posecodes & super-paircodes, template sentences and so forth by modifiying both *posescript/captioning_data.py* and *posefix/corrective_data.py*. The data structures are extensively explained in these files, and one can follow some marks (`ADD_VIRTUAL_JOINT`, `ADD_POSECODE_KIND`, `ADD_SUPER_POSECODE`, `ADD_PAIRCODE_KIND`, `ADD_SUPER_PAIRCODE`) to add new instruction material. Note that some posecodes (initially created for [single pose description](../posescript/README.md)) are also used in this pipeline. 61 | 62 | - **generate automatic modifiers** 63 | 64 | *Possible arguments are:* 65 | - `--saving_dir`: general location for saving generated instructions and data related to them (default: */generated_instructions/*) 66 | - `--version_name`: name of the version. Will be used to create a subdirectory of `--saving_dir` in which to save all files (instructions & intermediary results). Default is 'tmp'. 67 | - `--simplified_instructions`: produce a simplified version of the instructions (basically: no aggregation, no randomly referring to a body part by a substitute word). 68 | - `--random_skip`: randomly skip some non-essential (ie. not so rare) posecodes & paircodes. 69 | 70 |
71 | 72 | For instance, modifiers can be generated using the following command: 73 | ``` 74 | python posefix/correcting.py --version_name <version_name> --random_skip 75 | ``` 76 | 77 | To work on a small subset, one can use the following command: 78 | ``` 79 | python posefix/correcting.py --debug 80 | ``` 81 | and specify which pair to study by modifying the list of pair IDs marked with `SPECIFIC_INPUT_PAIR_IDS` in *posefix/correcting.py*. 82 | 83 |
84 | To generate caption versions similar to the ICCV 2023 paper (posefix-A): 85 | 86 | | Version | Command | 87 | |---------|---------| 88 | | pfA | `python posefix/correcting.py --version_name captions_pfA` | 89 | | pfB | `python posefix/correcting.py --version_name captions_pfB --random_skip` | 90 | | pfC | `python posefix/correcting.py --version_name captions_pfC --random_skip --simplified_captions` | 91 | 92 | *Note that some paircodes were added since, for the release of PoseEmbroider.* 93 |
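To make the paircode idea above more concrete, here is a deliberately simplified, self-contained toy example: it compares a single joint angle between pose A and pose B and turns the change into one correction sentence. The joint indices, threshold and wording are invented for illustration; the actual paircode kinds, categories and template sentences are defined in *posefix/corrective_data.py* (and *posescript/captioning_data.py* for the posecodes they reuse).

```python
import numpy as np

# Illustrative joint indices only (the real pipeline works on the SMPL-X joint set,
# see posescript/compute_coords.py).
L_SHOULDER, L_ELBOW, L_WRIST = 16, 18, 20

def joint_angle(joints, a, b, c):
    """Angle (in degrees) at joint b, formed by the segments b->a and b->c."""
    u, v = joints[a] - joints[b], joints[c] - joints[b]
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8)
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

def toy_elbow_paircode(joints_A, joints_B, tolerance_deg=20.0):
    """joints_A, joints_B: (n_joints, 3) arrays of joint coordinates for pose A and pose B."""
    angle_A = joint_angle(joints_A, L_SHOULDER, L_ELBOW, L_WRIST)
    angle_B = joint_angle(joints_B, L_SHOULDER, L_ELBOW, L_WRIST)
    change = angle_B - angle_A
    if abs(change) < tolerance_deg:
        return None  # negligible change: this code would be dropped as uninformative
    return "Straighten your left arm a bit." if change > 0 else "Bend your left elbow some more."
```

In the actual pipeline, many such codes (plus posecodes carrying absolute information) are extracted per pair, then selected, merged and aggregated before being phrased as a fluent instruction.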
94 | 95 | ## Citation 96 | 97 | If you use this code or the PoseFix dataset, please cite the following paper: 98 | 99 | ```bibtex 100 | @inproceedings{delmas2023posefix, 101 | title={{PoseFix: Correcting 3D Human Poses with Natural Language}}, 102 | author={{Delmas, Ginger and Weinzaepfel, Philippe and Moreno-Noguer, Francesc and Rogez, Gr\'egory}}, 103 | booktitle={{ICCV}}, 104 | year={2023} 105 | } 106 | ``` 107 | 108 | Please also remember to follow AMASS's respective citation guideline if you use the AMASS data. 109 | -------------------------------------------------------------------------------- /src/text2pose/posefix/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/posefix/compute_rotation_change.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | from tqdm import tqdm 12 | import torch 13 | import roma 14 | 15 | import text2pose.config as config 16 | import text2pose.utils as utils 17 | 18 | 19 | ### SETUP 20 | ################################################################################ 21 | 22 | # load data 23 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index) 24 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids) 25 | 26 | 27 | ### COMPUTE ROTATION CHANGES 28 | ################################################################################ 29 | 30 | rad2deg = lambda x: x*180/torch.pi 31 | 32 | # compute rotation changes 33 | rotation_changes = [[0.0, 0.0, 0.0] for _ in pose_pairs] 34 | for pairID in tqdm(range(len(pose_pairs))): 35 | 36 | pidA, pidB = pose_pairs[pairID] 37 | pose_info_A = dataID_2_pose_info[str(pidA)] 38 | pose_info_B = dataID_2_pose_info[str(pidB)] 39 | 40 | # 1) normalize the change in rotation (ie. 
get the difference in 41 | # rotation between pose A and pose B) 42 | pose_data_A, R_norm = utils.get_pose_data_from_file(pose_info_A, output_rotation=True) 43 | pose_data_B = utils.get_pose_data_from_file(pose_info_B, applied_rotation=R_norm if pose_info_A[1] == pose_info_B[1] else None) 44 | 45 | # 2) get the change of rotation (angle in degree) of B wrt A 46 | # The angle should be positive if turning left (clockwise); and negative 47 | # otherwise (when turning right) 48 | r = roma.rotvec_composition((pose_data_B[0:1,:3], roma.rotvec_inverse(pose_data_A[0:1,:3]))) 49 | r = rad2deg(r[0]).tolist() # convert to degrees 50 | r = [r[0], r[2], -r[1]] # reorient (x,y,z) where x is oriented towards the right and y points up 51 | rotation_changes[pairID] = r 52 | 53 | # save 54 | save_filepath = os.path.join(config.POSEFIX_LOCATION, f"ids_2_rotation_change{config.version_suffix}.json") 55 | utils.write_json(rotation_changes, save_filepath, pretty=True) 56 | print("Save global rotation changes at", save_filepath) -------------------------------------------------------------------------------- /src/text2pose/posefix/explore_posefix.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | # $ streamlit run posefix/explore_posefix.py 11 | 12 | import streamlit as st 13 | import random 14 | 15 | import text2pose.config as config 16 | import text2pose.demo as demo 17 | import text2pose.utils as utils 18 | import text2pose.utils_visu as utils_visu 19 | 20 | 21 | ### INPUT 22 | ################################################################################ 23 | 24 | version_human = 'posefix-H' 25 | version_paraphrases = 'posefix-PP' 26 | version_auto = 'posefix-A' 27 | 28 | 29 | ### SETUP 30 | ################################################################################ 31 | 32 | dataID_2_pose_info, triplet_data_human = demo.setup_posefix_data(version_human) 33 | _, triplet_data_pp = demo.setup_posefix_data(version_paraphrases) 34 | _, triplet_data_auto = demo.setup_posefix_data(version_auto) 35 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids) 36 | body_model = demo.setup_body_model() 37 | 38 | 39 | ### DISPLAY DATA 40 | ################################################################################ 41 | 42 | # get input pair id 43 | pair_ID = st.number_input("Pair ID:", 0, len(pose_pairs)) 44 | if st.button('Look at a random pose!'): 45 | pair_ID = random.randint(0, len(pose_pairs)) 46 | st.write(f"Looking at pair ID: **{pair_ID}**") 47 | 48 | # display information about the pair 49 | pid_A, pid_B = pose_pairs[pair_ID] 50 | in_sequence = dataID_2_pose_info[str(pid_A)][1] == dataID_2_pose_info[str(pid_B)][1] 51 | st.markdown(f"{'In' if in_sequence else 'Out-of'}-sequence pair (pose A to pose B).", unsafe_allow_html=True) 52 | 53 | # load pose data 54 | pose_A_info = dataID_2_pose_info[str(pid_A)] 55 | pose_A_data, rA = utils.get_pose_data_from_file(pose_A_info, output_rotation=True) 56 | pose_B_info = dataID_2_pose_info[str(pid_B)] 57 | pose_B_data = utils.get_pose_data_from_file(pose_B_info, applied_rotation=rA if in_sequence else None) 58 | 59 | # render the pair under 
the desired viewpoint, and display it 60 | view_angle = st.slider("Point of view:", min_value=-180, max_value=180, step=20, value=0) 61 | viewpoint = [] if view_angle == 0 else (view_angle, (0,1,0)) 62 | pair_img = utils_visu.image_from_pair_data(pose_A_data, pose_B_data, body_model, viewpoint=viewpoint, add_ground_plane=True) 63 | st.image(demo.process_img(pair_img)) 64 | 65 | # display text annotations 66 | if pair_ID in triplet_data_human: 67 | for k,m in enumerate(triplet_data_human[pair_ID]["modifier"]): 68 | st.markdown(f"Human-written modifier n°{k+1}", unsafe_allow_html=True) 69 | st.write(f"_{m.strip()}_") 70 | if pair_ID in triplet_data_pp: 71 | for k,m in enumerate(triplet_data_pp[pair_ID]["modifier"]): 72 | st.markdown(f"Paraphrase n°{k+1}", unsafe_allow_html=True) 73 | st.write(f"_{m.strip()}_") 74 | if pair_ID in triplet_data_auto: 75 | for k,m in enumerate(triplet_data_auto[pair_ID]["modifier"]): 76 | st.markdown(f"Automatic modifier n°{k+1}", unsafe_allow_html=True) 77 | st.write(f"_{m.strip()}_") 78 | else: 79 | st.write("This pair ID was not annotated.") -------------------------------------------------------------------------------- /src/text2pose/posescript/README.md: -------------------------------------------------------------------------------- 1 | # About the PoseScript dataset 2 | 3 | ## :inbox_tray: Download 4 | 5 | **License.** 6 | *The PoseScript dataset is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. 7 | A summary of the CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/). 8 | The CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).* 9 | 10 | | Version | Link | What changed ? | 11 | |---|---|---| 12 | | **V2** | [download](https://download.europe.naverlabs.com/ComputerVision/PoseScript/posescript_release_v2.zip) | pose set is 100k instead of 20k ; file format updated according to the 2023 code release | 13 | | **V1** | [download](https://download.europe.naverlabs.com/ComputerVision/PoseScript/posescript_dataset_release.zip) | 6293 human-written pairs instead of 3893 | 14 | | **V0 (ECCV22)** | (as V1) | | 15 | 16 | 17 |
18 | Dataset content. 19 | 20 | * a file linking each pose ID to the reference of its corresponding pose sequence in AMASS, and its frame index; 21 | * a file linking each pose ID with its descriptions (separate files for automatically generated captions and human-written ones); 22 | * files listing pose IDs for each split. 23 | 24 | Please refer to the provided README for more details. 25 |
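Once downloaded and referenced in *src/text2pose/config.py*, these files can be loaded with the same helpers as *posescript/explore_posescript.py*. A minimal sketch (assuming the release and the AMASS source poses are installed where `config.py` expects them):

```python
import text2pose.demo as demo
import text2pose.utils as utils

# pose ID -> (dataset, AMASS sequence, frame index); pose ID -> human-written description(s)
dataID_2_pose_info, captions = demo.setup_posescript_data("posescript-H2")

dataID = 0
pose_info = dataID_2_pose_info[str(dataID)]
pose_data = utils.get_pose_data_from_file(pose_info)    # pose parameters of the referenced AMASS frame
texts = captions[dataID] if dataID in captions else []  # empty if this pose was not annotated
```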
26 | 27 | 28 | ## :crystal_ball: Take a quick look 29 | 30 | To take a quick look at the data (ie. look at the pose under different viewpoints, the human-written caption when available, and the automatic captions): 31 | 32 | ``` 33 | streamlit run posescript/explore_posescript.py 34 | ``` 35 | 36 | ## :page_with_curl: Generate automatic captions 37 | 38 | ### Overview of the captioning pipeline 39 | 40 | Given a normalized 3D pose, we use posecodes to extract semantic pose information. These posecodes are then selected, merged or combined (when relevant) before being converted into a structural pose description in natural language. Letters ‘L’ and ‘R’ stand for ‘left’ and ‘right’ respectively. 41 | 42 | ![Captioning pipeline](../../../images/captioning_pipeline.png) 43 | 44 | Please refer to the paper and the supplementary material for more extensive explanations. 45 | 46 | ### Generate captions 47 | 48 | To generate automatic captions, please follow these steps: 49 | 50 | *Note: the time estimations below are given for 20k poses, but the newest version of PoseScript has 100k poses.* 51 | 52 | - **compute joint coordinates for all poses** _(~ 20 min)_ 53 | ``` 54 | python posescript/compute_coords.py 55 | ``` 56 | 57 | - **get and format BABEL labels for poses in PoseScript** _(~ 5 min)_ 58 | ``` 59 | python posescript/format_babel_labels.py 60 | ``` 61 | 62 | - (optional) **modify captioning data as you see fit**: 63 | - looking at diagrams on posecode statistics can be helpful to decide on posecode eligibility. To do so, run the following: 64 | ``` 65 | python posescript/captioning.py --action posecode_stats --version_name posecode_stats 66 | ``` 67 | - (re)define posecodes (categories, thresholds, tolerable noise levels, eligibility), super-posecodes, ripple effect rules based on statistics, template sentences and so forth by modifiying *posescript/captioning_data.py*. The data structures are extensively explained in this file, and one can follow some marks (`ADD_VIRTUAL_JOINT`, `ADD_POSECODE_KIND`, `ADD_SUPER_POSECODE`) to add new captioning material. Note that some posecodes were added to be used in the [pipeline that generates automatic modifiers](../posefix/README.md) (search for the `ADDED_FOR_MODIFIERS` marks). 68 | 69 | 70 | - **generate automatic captions** _(~ 1 min = 20k captions, with 1 cap/pose)_ 71 | 72 | *Possible arguments are:* 73 | - `--saving_dir`: general location for saving generated captions and data related to them (default: */generated_captions/*) 74 | - `--version_name`: name of the caption version. Will be used to create a subdirectory of `--saving_dir` in which to save all files (descriptions & intermediary results). Default is 'tmp'. 75 | - `--simplified_captions`: produce a simplified version of the captions (basically: no aggregation, no omitting of some support keypoints for the sake of flow, no randomly referring to a body part by a substitute word). This configuration is used to generate caption versions E and F from the paper. 76 | - `--apply_transrel_ripple_effect`: discard some posecodes using ripple effect rules based on transitive relations between body parts. 77 | - `--apply_stat_ripple_effect`: discard some posecodes using ripple effect rules based on statistically frequent pairs and triplets of posecodes. 78 | - `--random_skip`: randomly skip some non-essential posecodes (ie. posecodes that were found to be satisfied by more than 6% of the 20k poses considered in PoseScript). 
79 | - `--add_babel_info`: add sentences using information extracted from BABEL. 80 | - `--add_dancing_info`: add a sentence stating that the pose is a dancing pose if it comes from DanceDB (provided that `--add_babel_info` is also set to True). 81 | 82 |
83 | 84 |
85 | To generate caption versions similar to the latest version of the dataset (posescript-A2): 86 | 87 | | Version | Command | 88 | |---------|---------| 89 | | N2 | `python posescript/captioning.py --version_name captions_n2 --random_skip --simplified_captions` | 90 | | N6 | `python posescript/captioning.py --version_name captions_n6 --random_skip --add_babel_info --add_dancing_info` | 91 | | N7 | `python posescript/captioning.py --version_name captions_n7 --random_skip --apply_transrel_ripple_effect --apply_stat_ripple_effect` | 92 | 93 | *Note that some posecodes were added since, for the release of PoseFix and PoseEmbroider.* 94 |
95 | 96 |
97 | To generate caption versions similar to the ECCV 2022 paper (posescript-A): 98 | 99 | | Version | Command | 100 | |---------|---------| 101 | | A | `python posescript/captioning.py --version_name captions_A --apply_transrel_ripple_effect --apply_stat_ripple_effect --random_skip --add_babel_info --add_dancing_info` | 102 | | B | `python posescript/captioning.py --version_name captions_B --random_skip --add_babel_info --add_dancing_info` | 103 | | C | `python posescript/captioning.py --version_name captions_C --random_skip --add_babel_info` | 104 | | D | `python posescript/captioning.py --version_name captions_D --random_skip` | 105 | | E | `python posescript/captioning.py --version_name captions_E --random_skip --simplified_captions` | 106 | | F | `python posescript/captioning.py --version_name captions_F --simplified_captions` | 107 | 108 | *Note that some posecodes were added since, for the release of PoseFix and PoseEmbroider.* 109 |
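As a complement to the commands above, here is a deliberately simplified, self-contained toy example of what a single posecode boils down to: categorize one geometric measurement of a normalized pose and plug the category into a template sentence. The joint indices, thresholds and wording are invented for illustration; the actual posecode categories, eligibility rules and templates are defined in *posescript/captioning_data.py*.

```python
import numpy as np

# Illustrative joint indices only (the real pipeline works on the SMPL-X joint set,
# with coordinates computed by posescript/compute_coords.py; y points up).
L_WRIST, NECK = 20, 12

def toy_wrist_height_posecode(joints):
    """joints: (n_joints, 3) array of 3D joint coordinates for one pose."""
    diff = joints[L_WRIST, 1] - joints[NECK, 1]
    if diff > 0.10:        # toy bins, not the actual thresholds
        category = "above neck level"
    elif diff < -0.10:
        category = "below neck level"
    else:
        category = "at about neck level"
    return f"The left hand is {category}."
```

A full caption is obtained by extracting many such posecodes, discarding the uninformative ones, and merging and aggregating the rest into sentences.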
110 | 111 | 112 | 113 | ## Citation 114 | 115 | If you use this code or the PoseScript dataset, please cite the following paper: 116 | 117 | ```bibtex 118 | @inproceedings{posescript, 119 | title={{PoseScript: 3D Human Poses from Natural Language}}, 120 | author={{Delmas, Ginger and Weinzaepfel, Philippe and Lucas, Thomas and Moreno-Noguer, Francesc and Rogez, Gr\'egory}}, 121 | booktitle={{ECCV}}, 122 | year={2022} 123 | } 124 | ``` 125 | 126 | Please also remember to follow AMASS' and BABEL's respective citation guideline if you use the AMASS or BABEL data respectively. -------------------------------------------------------------------------------- /src/text2pose/posescript/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/posescript/compute_coords.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | from tqdm import tqdm 12 | import math 13 | import torch 14 | from human_body_prior.body_model.body_model import BodyModel 15 | 16 | import text2pose.config as config 17 | import text2pose.utils as utils 18 | 19 | 20 | ### INPUT 21 | ################################################################################ 22 | 23 | device = 'cpu' 24 | 25 | 26 | ### SETUP 27 | ################################################################################ 28 | 29 | # setup body model 30 | body_model = BodyModel(model_type = config.POSE_FORMAT, 31 | bm_fname = config.NEUTRAL_BM, 32 | num_betas = config.n_betas) 33 | body_model.eval() 34 | body_model.to(device) 35 | 36 | # load data 37 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index) 38 | 39 | # rotation transformation to apply so that the coordinates correspond to what we 40 | # actually visualize (ie. 
from front view) 41 | rotX = lambda theta: torch.tensor([[1, 0, 0], [0, torch.cos(theta), -torch.sin(theta)], [0, torch.sin(theta), torch.cos(theta)]]) 42 | 43 | def transf(rotMat, theta_deg, values): 44 | theta_rad = math.pi * torch.tensor(theta_deg).float() / 180.0 45 | return rotMat(theta_rad).mm(values.t()).t() 46 | 47 | 48 | ### COMPUTE COORDINATES 49 | ################################################################################ 50 | 51 | coords = [] 52 | 53 | # compute all joint coordinates 54 | for dataID in tqdm(range(len(dataID_2_pose_info))): 55 | 56 | # load pose data 57 | pose_info = dataID_2_pose_info[str(dataID)] 58 | pose = utils.get_pose_data_from_file(pose_info) 59 | 60 | # infer coordinates 61 | with torch.no_grad(): 62 | j = body_model(**utils.pose_data_as_dict(pose)).Jtr 63 | j = j.detach().cpu()[0] 64 | j = transf(rotX, -90, j) 65 | 66 | # store data 67 | coords.append(j.view(1, -1, 3)) 68 | coords = torch.cat(coords) 69 | 70 | # save 71 | save_filepath = os.path.join(config.POSESCRIPT_LOCATION, f"ids_2_coords_correct_orient_adapted{config.version_suffix}.pt") 72 | torch.save(coords, save_filepath) 73 | print("Save coordinates at", save_filepath) -------------------------------------------------------------------------------- /src/text2pose/posescript/explore_posescript.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | # $ streamlit run posescript/explore_posescript.py 11 | 12 | import streamlit as st 13 | 14 | import text2pose.demo as demo 15 | import text2pose.utils as utils 16 | import text2pose.utils_visu as utils_visu 17 | 18 | 19 | ### INPUT 20 | ################################################################################ 21 | 22 | version_human = 'posescript-H2' 23 | version_auto = 'posescript-A2' 24 | 25 | 26 | ### SETUP 27 | ################################################################################ 28 | 29 | dataID_2_pose_info, captions_human = demo.setup_posescript_data(version_human) 30 | _, captions_auto = demo.setup_posescript_data(version_auto) 31 | body_model = demo.setup_body_model() 32 | 33 | 34 | ### DISPLAY DATA 35 | ################################################################################ 36 | 37 | # get input pose id 38 | dataID = st.number_input("Pose ID:", 0, len(dataID_2_pose_info)-1) 39 | 40 | # display information about the pose 41 | pose_info = dataID_2_pose_info[str(dataID)] 42 | st.write(f"Pose from the **{pose_info[0]}** dataset, **frame {pose_info[2]}** of sequence *{pose_info[1]}*") 43 | 44 | # load pose data 45 | pose_data = utils.get_pose_data_from_file(pose_info) 46 | 47 | # render the pose under the desired viewpoint, and display it 48 | view_angle = st.slider("Point of view:", min_value=-180, max_value=180, step=20, value=0) 49 | viewpoint = [] if view_angle == 0 else (view_angle, (0,1,0)) 50 | img = utils_visu.image_from_pose_data(pose_data, body_model, viewpoints=[viewpoint], add_ground_plane=True)[0] # 1 viewpoint 51 | st.image(demo.process_img(img)) 52 | 53 | # display captions 54 | if dataID in captions_human: 55 | for k,c in enumerate(captions_human[dataID]): 56 | 
st.markdown(f"Human-written description n°{k+1}", unsafe_allow_html=True) 57 | st.write(captions_human[dataID][k]) 58 | if dataID in captions_auto: 59 | for k,c in enumerate(captions_auto[dataID]): 60 | st.markdown(f"Automatic description n°{k+1}", unsafe_allow_html=True) 61 | st.write(captions_auto[dataID][k]) 62 | else: 63 | st.write("This pose ID was not annotated.") -------------------------------------------------------------------------------- /src/text2pose/posescript/format_babel_labels.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | from tqdm import tqdm 12 | import json 13 | import numpy as np 14 | import pickle 15 | from tabulate import tabulate 16 | 17 | import text2pose.config as config 18 | import text2pose.utils as utils 19 | 20 | 21 | ### SETUP 22 | ################################################################################ 23 | 24 | # load BABEL 25 | l_babel_dense_files = ['train', 'val', 'test'] 26 | l_babel_extra_files = ['extra_train', 'extra_val'] 27 | 28 | babel = {} 29 | for file in l_babel_dense_files: 30 | babel[file] = json.load(open(os.path.join(config.BABEL_LOCATION, file+'.json'))) 31 | 32 | for file in l_babel_extra_files: 33 | babel[file] = json.load(open(os.path.join(config.BABEL_LOCATION, file+'.json'))) 34 | 35 | # load PoseScript 36 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index) 37 | 38 | # AMASS/BABEL path adaptation 39 | amass_to_babel_subdir = { 40 | 'ACCAD': 'ACCAD/ACCAD', 41 | 'BMLhandball': '', # not available 42 | 'BMLmovi': 'BMLmovi/BMLmovi', 43 | 'BioMotionLab_NTroje': 'BMLrub/BioMotionLab_NTroje', 44 | 'CMU': 'CMU/CMU', 45 | 'DFaust_67': 'DFaust67/DFaust_67', 46 | 'DanceDB': '', # not available 47 | 'EKUT': 'EKUT/EKUT', 48 | 'Eyes_Japan_Dataset': 'EyesJapanDataset/Eyes_Japan_Dataset', 49 | 'HumanEva': 'HumanEva/HumanEva', 50 | 'KIT': 'KIT/KIT', 51 | 'MPI_HDM05': 'MPIHDM05/MPI_HDM05', 52 | 'MPI_Limits': 'MPILimits/MPI_Limits', 53 | 'MPI_mosh': 'MPImosh/MPI_mosh', 54 | 'SFU': 'SFU/SFU', 55 | 'SSM_synced': 'SSMsynced/SSM_synced', 56 | 'TCD_handMocap': 'TCDhandMocap/TCD_handMocap', 57 | 'TotalCapture': 'TotalCapture/TotalCapture', 58 | 'Transitions_mocap': 'Transitionsmocap/Transitions_mocap', 59 | } 60 | 61 | 62 | ### GET LABELS 63 | ################################################################################ 64 | 65 | def get_babel_label(amass_rel_path, frame_id): 66 | 67 | # get path correspondance in BABEL 68 | dname = amass_rel_path.split('/')[0] 69 | bname = amass_to_babel_subdir[dname] 70 | if bname == '': 71 | return '__'+dname+'__' 72 | babel_rel_path = '/'.join([bname]+amass_rel_path.split('/')[1:]) 73 | 74 | # look for babel annotations 75 | babelfs = [] 76 | for f in babel.keys(): 77 | for s in babel[f].keys(): 78 | if babel[f][s]['feat_p'] == babel_rel_path: 79 | babelfs.append((f,s)) 80 | 81 | if len(babelfs) == 0: 82 | return None 83 | 84 | # convert frame id to second 85 | seqdata = np.load(os.path.join(config.AMASS_FILE_LOCATION, amass_rel_path)) 86 | framerate = seqdata['mocap_framerate'] 87 | t = frame_id / framerate 88 | 89 | # read 
babel annotations 90 | labels = [] 91 | for f,s in babelfs: 92 | if not 'frame_ann' in babel[f][s]: 93 | continue 94 | if babel[f][s]['frame_ann'] is None: 95 | continue 96 | babel_annots = babel[f][s]['frame_ann']['labels'] 97 | for i in range(len(babel_annots)): 98 | if t >= babel_annots[i]['start_t'] and t <= babel_annots[i]['end_t']: 99 | labels.append( (babel_annots[i]['raw_label'], babel_annots[i]['proc_label'], babel_annots[i]['act_cat']) ) 100 | 101 | return labels 102 | 103 | # gather labels for all poses in PoseScript that come from AMASS 104 | babel_labels_for_posescript = {} 105 | for dataID in tqdm(dataID_2_pose_info): 106 | pose_info = dataID_2_pose_info[dataID] 107 | if pose_info[0] == "AMASS": 108 | babel_labels_for_posescript[dataID] = get_babel_label(pose_info[1], pose_info[2]) 109 | 110 | # display some stats 111 | table = [] 112 | table.append(['None', sum([v is None for v in babel_labels_for_posescript.values()])]) 113 | table.append(['BMLhandball', sum([v=='__BMLhandball__' for v in babel_labels_for_posescript.values()])]) 114 | table.append(['DanceDB', sum([v=='__DanceDB__' for v in babel_labels_for_posescript.values()])]) 115 | table.append(['0 label', sum([ (isinstance(v,list) and len(v)==0) for v in babel_labels_for_posescript.values()])]) 116 | table.append(['None label', sum([ (isinstance(v,list) and len(v)>=1 and v[0][0] is None) for v in babel_labels_for_posescript.values()])]) 117 | table.append(['1 label', sum([ (isinstance(v,list) and len(v)==1 and v[0][0] is not None) for v in babel_labels_for_posescript.values()])]) 118 | table.append(['>1 label',sum([ (isinstance(v,list) and len(v)>=2 and v[0][0] is not None) for v in babel_labels_for_posescript.values()])]) 119 | print(tabulate(table, headers=["Label", "Number of poses"])) 120 | 121 | # save 122 | save_filepath = os.path.join(config.POSESCRIPT_LOCATION, f"babel_labels_for_posescript{config.version_suffix}.pkl") 123 | with open(save_filepath, 'wb') as f: 124 | pickle.dump(babel_labels_for_posescript, f) 125 | print("Saved", save_filepath) -------------------------------------------------------------------------------- /src/text2pose/retrieval/README.md: -------------------------------------------------------------------------------- 1 | # {Text :left_right_arrow: Pose} Retrieval Model 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | ## Model overview 6 | 7 | **Possible inputs**: 3D human pose, pose description. 8 | 9 | ![Retrieval model](../../../images/retrieval_model.png) 10 | 11 | ## :crystal_ball: Demo 12 | 13 | To look at a ranking of poses (resp. descriptions) referenced in PoseScript by relevance to your own input description (resp. chosen pose), using a pretrained model, run the following: 14 | 15 | ``` 16 | bash retrieval/script_retrieval.sh 'demo' 17 | ``` 18 | 19 | ## :bullettrain_front: Train 20 | 21 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 22 | 23 | Then use the following command: 24 | ``` 25 | bash retrieval/script_retrieval.sh 'train' 26 | ``` 27 | 28 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. 
This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model. 29 | 30 | ## :dart: Evaluate 31 | 32 | Use the following command (test on PoseScript-H2): 33 | ``` 34 | bash retrieval/script_retrieval.sh 'eval' 35 | ``` 36 | -------------------------------------------------------------------------------- /src/text2pose/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/retrieval/demo_retrieval.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import argparse 12 | import torch 13 | 14 | import text2pose.demo as demo 15 | import text2pose.utils as utils 16 | import text2pose.utils_visu as utils_visu 17 | from text2pose.retrieval.evaluate_retrieval import load_model 18 | 19 | 20 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 21 | parser.add_argument('--model_path', type=str, help='Path to the model.') 22 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.') 23 | parser.add_argument('--n_retrieve', type=int, default=12, help="Number of elements to retrieve.") 24 | args = parser.parse_args() 25 | 26 | 27 | ### INPUT 28 | ################################################################################ 29 | 30 | data_version_annotations = "posescript-H2" # defines what annotations to use as query examples 31 | data_version_poses_collection = "posescript-A2" # defines the set of poses to rank 32 | 33 | 34 | ### SETUP 35 | ################################################################################ 36 | 37 | # --- data 38 | available_splits = ['train', 'val', 'test'] 39 | model, tokenizer_name, body_model = demo.setup_models([args.model_path], args.checkpoint, load_model) 40 | model, tokenizer_name = model[0], tokenizer_name[0] 41 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version_annotations) 42 | 43 | 44 | ### MAIN APP 45 | ################################################################################ 46 | 47 | # define query input interface: split selection 48 | cols_query = st.columns(3) 49 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 50 | 51 | # precompute features 52 | dataIDs = demo.setup_posescript_split(split_for_research) 53 | pose_dataIDs, poses_features = 
demo.precompute_posescript_pose_features(data_version_poses_collection, split_for_research, model) 54 | text_dataIDs, text_features = demo.precompute_text_features(data_version_annotations, split_for_research, model, tokenizer_name) 55 | 56 | # define query input interface: example selection 57 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID')) 58 | number = cols_query[2].number_input("Split index or ID:", 0) 59 | st.markdown("""---""") 60 | 61 | # get query data 62 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model) 63 | 64 | # show query data 65 | cols_input = st.columns(2) 66 | cols_input[0].image(pose_img, caption="Annotated pose") 67 | if default_description: 68 | cols_input[1].write("Annotated text:") 69 | cols_input[1].write(f"_{default_description}_") 70 | else: 71 | cols_input[1].write("_(Not annotated.)_") 72 | 73 | # get retrieval direction 74 | dt2p = "Text-2-Pose" 75 | dp2t = "Pose-2-Text" 76 | retrieval_direction = st.radio("Retrieval direction:", [dt2p, dp2t]) 77 | 78 | # TEXT-2-POSE 79 | if retrieval_direction == dt2p: 80 | 81 | # get input description 82 | description = cols_input[1].text_area("Pose description:", 83 | value=default_description, 84 | placeholder="The person is...", 85 | height=None, max_chars=None) 86 | 87 | # encode text 88 | with torch.no_grad(): 89 | text_feature = model.encode_raw_text(description) 90 | 91 | # rank poses by relevance and get their pose id 92 | scores = text_feature.view(1, -1).mm(poses_features.t())[0] 93 | _, indices_rank = scores.sort(descending=True) 94 | relevant_pose_ids = [pose_dataIDs[i] for i in indices_rank[:args.n_retrieve]] 95 | 96 | # get corresponding pose data 97 | all_pose_data = [] 98 | for pose_id in relevant_pose_ids: 99 | pose_info = dataID_2_pose_info[str(pose_id)] 100 | all_pose_data.append(utils.get_pose_data_from_file(pose_info)) 101 | all_pose_data = torch.cat(all_pose_data) 102 | 103 | # render poses 104 | imgs = utils_visu.image_from_pose_data(all_pose_data, body_model, color="blue", add_ground_plane=True) 105 | 106 | # display images 107 | st.markdown("""---""") 108 | st.write(f"**Retrieved poses for this description [{split_for_research} split]:**") 109 | cols = st.columns(demo.nb_cols) 110 | for i in range(args.n_retrieve): 111 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i])) 112 | 113 | # POSE-2-TEXT 114 | elif retrieval_direction == dp2t: 115 | 116 | # rank texts by relevance and get their id 117 | pose_index = pose_dataIDs.index(pose_ID) 118 | scores = poses_features[pose_index].view(1, -1).mm(text_features.t())[0] 119 | _, indices_rank = scores.sort(descending=True) 120 | relevant_pose_ids = [text_dataIDs[i] for i in indices_rank[:args.n_retrieve]] 121 | 122 | # get corresponding text data (the text features were obtained using the first text) 123 | texts = [captions[pose_id][0] for pose_id in relevant_pose_ids] 124 | 125 | # display texts 126 | st.markdown("""---""") 127 | st.write(f"**Retrieved descriptions for this pose [{split_for_research} split]:**") 128 | for i in range(args.n_retrieve): 129 | st.write(f"**({i+1})** {texts[i]}") 130 | 131 | st.markdown("""---""") 132 | st.write(f"_Results obtained with model: {args.model_path}_") -------------------------------------------------------------------------------- /src/text2pose/retrieval/evaluate_retrieval.py: -------------------------------------------------------------------------------- 1 | 
############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import os 11 | import torch 12 | from tqdm import tqdm 13 | 14 | import text2pose.config as config 15 | import text2pose.evaluate as evaluate 16 | from text2pose.data import PoseScript 17 | from text2pose.encoders.tokenizers import get_tokenizer_name 18 | from text2pose.retrieval.model_retrieval import PoseText 19 | from text2pose.loss import BBC 20 | 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 22 | 23 | OVERWRITE_RESULT = False 24 | 25 | 26 | ################################################################################ 27 | 28 | def load_model(model_path, device): 29 | 30 | assert os.path.isfile(model_path), "File {} not found.".format(model_path) 31 | 32 | # load checkpoint & model info 33 | ckpt = torch.load(model_path, 'cpu') 34 | text_encoder_name = ckpt['args'].text_encoder_name 35 | transformer_topping = getattr(ckpt['args'], 'transformer_topping', None) 36 | latentD = ckpt['args'].latentD 37 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52) 38 | 39 | # load model 40 | model = PoseText(text_encoder_name=text_encoder_name, 41 | transformer_topping=transformer_topping, 42 | latentD=latentD, 43 | num_body_joints=num_body_joints 44 | ).to(device) 45 | model.load_state_dict(ckpt['model']) 46 | model.eval() 47 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path) 48 | 49 | return model, get_tokenizer_name(text_encoder_name) 50 | 51 | 52 | def eval_model(model_path, dataset_version, split='val', generated_pose_samples=None): 53 | 54 | device = torch.device('cuda:0') 55 | 56 | # define result file & get auxiliary info 57 | generated_pose_samples_path, precision = get_evaluation_auxiliary_info(model_path, generated_pose_samples) 58 | nb_caps = config.caption_files[dataset_version][0] 59 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps) 60 | 61 | # load model if results for at least one caption is missing 62 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps): 63 | model, tokenizer_name = load_model(model_path, device) 64 | 65 | # compute or load results for the given run & caption 66 | results = {} 67 | for cap_ind in range(nb_caps): 68 | filename_res = get_res_file(cap_ind) 69 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT: 70 | if "posescript" in dataset_version: 71 | d = PoseScript(version=dataset_version, split=split, tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True, generated_pose_samples_path=generated_pose_samples_path) 72 | else: 73 | raise NotImplementedError 74 | cap_results = compute_eval_metrics(model, d, device) 75 | evaluate.save_results_to_file(cap_results, filename_res) 76 | else: 77 | cap_results = evaluate.load_results_from_file(filename_res) 78 | # aggregate results 79 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()} 80 | 81 | # average over captions 82 | results = {k:sum(v)/nb_caps for k,v in results.items()} 83 | 84 | return {k:[v] for k, v in results.items()} 85 | 86 | 87 | def 
get_evaluation_auxiliary_info(model_path, generated_pose_samples): 88 | precision = "" # default 89 | generated_pose_samples_path = None # default 90 | if generated_pose_samples: 91 | precision = f"_gensample_{generated_pose_samples}" 92 | seed = evaluate.get_seed_from_model_path(model_path) 93 | generated_pose_samples_model_path = (config.shortname_2_model_path[generated_pose_samples]).format(seed=seed) 94 | generated_pose_samples_path = config.generated_pose_path % os.path.dirname(generated_pose_samples_model_path) 95 | return generated_pose_samples_path, precision 96 | 97 | 98 | def compute_eval_metrics(model, dataset, device, compute_loss=False): 99 | 100 | # get data features 101 | poses_features, texts_features = infer_features(model, dataset, device) 102 | 103 | # pose-2-text matching 104 | p2t_recalls = evaluate.x2y_recall_metrics(poses_features, texts_features, config.k_recall_values, sstr="p2t_") 105 | # text-2-pose matching 106 | t2p_recalls = evaluate.x2y_recall_metrics(texts_features, poses_features, config.k_recall_values, sstr="t2p_") 107 | # r-precision 108 | rprecisions = evaluate.textret_metrics(texts_features, poses_features) 109 | 110 | # gather metrics 111 | recalls = {"mRecall": (sum(p2t_recalls.values()) + sum(t2p_recalls.values())) / (2 * len(config.k_recall_values))} 112 | recalls.update(p2t_recalls) 113 | recalls.update(t2p_recalls) 114 | recalls.update(rprecisions) 115 | 116 | # loss 117 | if compute_loss: 118 | score_t2p = texts_features.mm(poses_features.t()) 119 | loss = BBC(score_t2p*model.loss_weight) 120 | loss_value = loss.item() 121 | return recalls, loss_value 122 | 123 | return recalls 124 | 125 | 126 | def infer_features(model, dataset, device): 127 | 128 | batch_size = 32 129 | data_loader = torch.utils.data.DataLoader( 130 | dataset, sampler=None, shuffle=False, 131 | batch_size=batch_size, 132 | num_workers=8, 133 | pin_memory=True, 134 | drop_last=False 135 | ) 136 | 137 | poses_features = torch.zeros(len(dataset), model.latentD).to(device) 138 | texts_features = torch.zeros(len(dataset), model.latentD).to(device) 139 | 140 | for i, batch in tqdm(enumerate(data_loader)): 141 | poses = batch['pose'].to(device) 142 | caption_tokens = batch['caption_tokens'].to(device) 143 | caption_lengths = batch['caption_lengths'].to(device) 144 | caption_tokens = caption_tokens[:,:caption_lengths.max()] 145 | with torch.inference_mode(): 146 | pfeat, tfeat = model(poses, caption_tokens, caption_lengths) 147 | poses_features[i*batch_size:i*batch_size+len(poses)] = pfeat 148 | texts_features[i*batch_size:i*batch_size+len(poses)] = tfeat 149 | 150 | return poses_features, texts_features 151 | 152 | 153 | def display_results(results): 154 | metric_order = ['mRecall'] + ['%s_R@%d'%(d, k) for d in ['p2t', 't2p'] for k in config.k_recall_values] 155 | results = evaluate.scale_and_format_results(results) 156 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n") 157 | 158 | 159 | ################################################################################ 160 | 161 | if __name__ == '__main__': 162 | 163 | # added special arguments 164 | evaluate.eval_parser.add_argument('--generated_pose_samples', default=None, help="Shortname for the model that generated the pose files to be used (full path registered in config.py") 165 | 166 | args = evaluate.eval_parser.parse_args() 167 | args = evaluate.get_full_model_path(args) 168 | 169 | # compute results 170 | if args.average_over_runs: 171 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, 
dataset_version=args.dataset, split=args.split, generated_pose_samples=args.generated_pose_samples) 172 | else: 173 | ret = eval_model(args.model_path, dataset_version=args.dataset, split=args.split, generated_pose_samples=args.generated_pose_samples) 174 | 175 | # display results 176 | print(ret) 177 | display_results(ret) -------------------------------------------------------------------------------- /src/text2pose/retrieval/model_retrieval.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import text2pose.config as config 14 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name 15 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder 16 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder 17 | 18 | 19 | class PoseText(nn.Module): 20 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512, 21 | num_body_joints=config.NB_INPUT_JOINTS, 22 | text_encoder_name='distilbertUncased', transformer_topping=None): 23 | super(PoseText, self).__init__() 24 | 25 | self.latentD = latentD 26 | 27 | # Define pose encoder 28 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, 29 | num_neurons_mini=num_neurons_mini, 30 | latentD=latentD, 31 | num_body_joints=num_body_joints, 32 | role="retrieval") 33 | 34 | # Define text encoder 35 | self.text_encoder_name = text_encoder_name 36 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name) 37 | if module_ref in ["glovebigru"]: 38 | self.text_encoder = TextEncoder(self.text_encoder_name, latentD=latentD, role="retrieval") 39 | elif module_ref in ["glovetransf", "distilbertUncased"]: 40 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, latentD=latentD, topping=transformer_topping, role="retrieval") 41 | else: 42 | raise NotImplementedError 43 | 44 | # Loss temperature 45 | self.loss_weight = torch.nn.Parameter( torch.FloatTensor((10,)) ) 46 | self.loss_weight.requires_grad = True 47 | 48 | def forward(self, pose, captions, caption_lengths): 49 | pose_embs = self.pose_encoder(pose) 50 | text_embs = self.text_encoder(captions, caption_lengths) 51 | return pose_embs, text_embs 52 | 53 | def encode_raw_text(self, raw_text): 54 | if not hasattr(self, 'tokenizer'): 55 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name)) 56 | tokens = self.tokenizer(raw_text).to(device=self.loss_weight.device) 57 | length = torch.tensor([ len(tokens) ], dtype=tokens.dtype) 58 | text_embs = self.text_encoder(tokens.view(1, -1), length) 59 | return text_embs 60 | 61 | def encode_pose(self, pose): 62 | return self.pose_encoder(pose) 63 | 64 | def encode_text(self, captions, caption_lengths): 65 | return self.text_encoder(captions, caption_lengths) -------------------------------------------------------------------------------- /src/text2pose/retrieval/script_retrieval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 
| ## text2pose ## 5 | ## Copyright (c) 2022, 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 9 | ## See project root for license details. ## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 # (train|eval|demo) 17 | checkpoint_type="best" # (last|best) 18 | 19 | architecture_args=( 20 | --model PoseText 21 | --latentD 512 22 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp" 23 | # --text_encoder_name 'glovebigru_vocPSA2H2' 24 | ) 25 | 26 | loss_args=( 27 | --retrieval_loss 'symBBC' 28 | ) 29 | 30 | bonus_args=( 31 | ) 32 | 33 | pretrained="ret_distilbert_dataPSA2" # used only if phase=='finetune' 34 | 35 | 36 | ############################################################## 37 | # EXECUTE 38 | 39 | # TRAIN 40 | if [[ "$action" == *"train"* ]]; then 41 | 42 | phase=$2 # (pretrain|finetune) 43 | echo "NOTE: Expecting as argument the training phase. Got: $phase" 44 | seed=$3 45 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 46 | 47 | # PRETRAIN 48 | if [[ "$phase" == *"pretrain"* ]]; then 49 | 50 | python retrieval/train_retrieval.py --dataset "posescript-A2" \ 51 | "${architecture_args[@]}" \ 52 | "${loss_args[@]}" \ 53 | "${bonus_args[@]}" \ 54 | --lr_scheduler "stepLR" --lr 0.0002 --lr_step 400 --lr_gamma 0.5 \ 55 | --log_step 20 --val_every 20 \ 56 | --batch_size 512 --epochs 1000 --seed $seed 57 | 58 | # FINETUNE 59 | elif [[ "$phase" == *"finetune"* ]]; then 60 | 61 | python retrieval/train_retrieval.py --dataset "posescript-H2" \ 62 | "${architecture_args[@]}" \ 63 | "${loss_args[@]}" \ 64 | "${bonus_args[@]}" \ 65 | --apply_LR_augmentation \ 66 | --lr_scheduler "stepLR" --lr 0.0002 --lr_step 40 --lr_gamma 0.5 \ 67 | --batch_size 32 --epochs 200 --seed $seed \ 68 | --pretrained $pretrained 69 | 70 | fi 71 | 72 | fi 73 | 74 | 75 | # EVAL QUANTITATIVELY 76 | if [[ "$action" == *"eval"* ]]; then 77 | 78 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 79 | 80 | for model_path in "${experiments[@]}" 81 | do 82 | echo $model_path 83 | python retrieval/evaluate_retrieval.py --dataset "posescript-H2" \ 84 | --model_path ${model_path} --checkpoint $checkpoint_type \ 85 | --split test 86 | done 87 | fi 88 | 89 | 90 | # EVAL QUALITATIVELY 91 | if [[ "$action" == *"demo"* ]]; then 92 | 93 | experiment=$2 # only one at a time 94 | streamlit run retrieval/demo_retrieval.py -- --model_path $experiment --checkpoint $checkpoint_type 95 | 96 | fi -------------------------------------------------------------------------------- /src/text2pose/retrieval/train_retrieval.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2022, 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import torch 11 | import math 12 | import sys 13 | import os 14 | os.umask(0x0002) 15 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 16 | 17 | from text2pose.option import get_args_parser 18 | from text2pose.trainer import GenericTrainer 19 | from text2pose.retrieval.model_retrieval import PoseText 20 | from text2pose.retrieval.evaluate_retrieval import compute_eval_metrics 21 | from text2pose.loss import BBC, symBBC 22 | from text2pose.data import PoseScript, PoseFix 23 | from text2pose.encoders.tokenizers import get_tokenizer_name 24 | from text2pose.data_augmentations import DataAugmentation 25 | 26 | import text2pose.config as config 27 | import text2pose.utils_logging as logging 28 | 29 | 30 | ################################################################################ 31 | 32 | 33 | class PoseTextTrainer(GenericTrainer): 34 | 35 | def __init__(self, args): 36 | super(PoseTextTrainer, self).__init__(args, retrieval_trainer=True) 37 | 38 | 39 | def load_dataset(self, split, caption_index, tokenizer_name=None): 40 | 41 | if tokenizer_name is None: tokenizer_name = get_tokenizer_name(self.args.text_encoder_name) 42 | data_size = self.args.data_size if split=="train" else None 43 | 44 | if "posescript" in self.args.dataset: 45 | d = PoseScript(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size) 46 | elif "posefix" in self.args.dataset: 47 | d = PoseFix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size, posescript_format=True) 48 | else: 49 | raise NotImplementedError 50 | return d 51 | 52 | 53 | def init_model(self): 54 | print('Load model') 55 | self.model = PoseText(text_encoder_name=self.args.text_encoder_name, 56 | transformer_topping=self.args.transformer_topping, 57 | latentD=self.args.latentD, 58 | num_body_joints=self.args.num_body_joints) 59 | self.model.to(self.device) 60 | 61 | 62 | def get_param_groups(self): 63 | param_groups = [] 64 | param_groups.append({'params': self.model.pose_encoder.parameters(), 'lr': self.args.lr*self.args.lrposemul}) 65 | param_groups.append({'params': [p for k,p in self.model.text_encoder.named_parameters() if 'pretrained_text_encoder.' 
not in k], 'lr': self.args.lr*self.args.lrtextmul}) 66 | param_groups.append({'params': [self.model.loss_weight]}) 67 | return param_groups 68 | 69 | 70 | def init_optimizer(self): 71 | assert self.args.optimizer=='Adam' 72 | param_groups = self.get_param_groups() 73 | self.optimizer = torch.optim.Adam(param_groups, lr=self.args.lr) 74 | 75 | 76 | def init_lr_scheduler(self): 77 | self.lr_scheduler = None 78 | if self.args.lr_scheduler == "stepLR": 79 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 80 | step_size=self.args.lr_step, 81 | gamma=self.args.lr_gamma, 82 | last_epoch=-1) 83 | 84 | 85 | def init_other_training_elements(self): 86 | self.data_augmentation_module = DataAugmentation(self.args, mode="posescript", tokenizer_name=get_tokenizer_name(self.args.text_encoder_name), nb_joints=self.args.num_body_joints) 87 | 88 | 89 | def training_epoch(self, epoch): 90 | train_stats = self.one_epoch(epoch=epoch, is_training=True) 91 | return train_stats 92 | 93 | 94 | def validation_epoch(self, epoch): 95 | val_stats = {} 96 | if self.args.val_every and (epoch+1)%self.args.val_every==0: 97 | val_stats = self.validate(epoch=epoch) 98 | return val_stats 99 | 100 | 101 | def one_epoch(self, epoch, is_training): 102 | 103 | self.model.train(is_training) 104 | 105 | # define loggers 106 | metric_logger = logging.MetricLogger(delimiter=" ") 107 | if is_training: 108 | prefix, sstr = '', 'train' 109 | metric_logger.add_meter(f'{sstr}_lr', logging.SmoothedValue(window_size=1, fmt='{value:.6f}')) 110 | else: 111 | prefix, sstr = '[val] ', 'val' 112 | header = f'{prefix}Epoch: [{epoch}]' 113 | 114 | # define dataloader & other elements 115 | if is_training: 116 | data_loader = self.data_loader_train 117 | if not is_training: 118 | data_loader = self.data_loader_val 119 | 120 | # iterate over the batches 121 | for data_iter_step, item in enumerate(metric_logger.log_every(data_loader, self.args.log_step, header)): 122 | 123 | # get data 124 | poses = item['pose'].to(self.device) 125 | caption_tokens = item['caption_tokens'].to(self.device) 126 | caption_lengths = item['caption_lengths'].to(self.device) 127 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # truncate within the batch, based on the longest text 128 | 129 | # online random augmentations 130 | poses, caption_tokens, caption_lengths = self.data_augmentation_module(poses, caption_tokens, caption_lengths) 131 | 132 | # forward; compute scores 133 | with torch.set_grad_enabled(is_training): 134 | poses_features, texts_features = self.model(poses, caption_tokens, caption_lengths) 135 | score_t2p = texts_features.mm(poses_features.t()) * self.model.loss_weight 136 | 137 | # compute loss 138 | if self.args.retrieval_loss == "BBC": 139 | loss = BBC(score_t2p) 140 | elif self.args.retrieval_loss == "symBBC": 141 | loss = symBBC(score_t2p) 142 | else: 143 | raise NotImplementedError 144 | 145 | loss_value = loss.item() 146 | if not math.isfinite(loss_value): 147 | print("Loss is {}, stopping training".format(loss_value)) 148 | sys.exit(1) 149 | 150 | # training step 151 | if is_training: 152 | self.optimizer.zero_grad() 153 | loss.backward() 154 | self.optimizer.step() 155 | 156 | # format data for logging 157 | scalars = [('loss', loss_value)] 158 | if is_training: 159 | lr_value = self.optimizer.param_groups[0]["lr"] 160 | scalars += [('lr', lr_value)] 161 | 162 | # actually log 163 | self.add_data_to_log_writer(epoch, sstr, scalars=scalars, is_training=is_training, data_iter_step=data_iter_step, 
total_steps=len(data_loader)) 164 | self.add_data_to_metric_logger(metric_logger, sstr, scalars) 165 | 166 | print("Averaged stats:", metric_logger) 167 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 168 | 169 | 170 | def validate(self, epoch): 171 | 172 | self.model.eval() 173 | 174 | recalls, loss_value = compute_eval_metrics(self.model, self.data_loader_val.dataset, self.device, compute_loss=True) 175 | val_stats = {"loss": loss_value} 176 | val_stats.update(recalls) 177 | 178 | # log 179 | self.add_data_to_log_writer(epoch, 'val', scalars=[('loss', loss_value), ('validation', recalls)], should_log_data=True) 180 | print(f"[val] Epoch: [{epoch}] Stats: " + " ".join(f"{k}: {round(v, 3)}" for k,v in val_stats.items()) ) 181 | return val_stats 182 | 183 | 184 | if __name__ == '__main__': 185 | 186 | argparser = get_args_parser() 187 | args = argparser.parse_args() 188 | 189 | PoseTextTrainer(args)() -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/README.md: -------------------------------------------------------------------------------- 1 | # {Pose-Pair :left_right_arrow: Instruction} Retrieval Model 2 | 3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._ 4 | 5 | ## Model overview 6 | 7 | **Possible inputs**: a pair of 3D human poses (i.e. two poses at once) and a text instruction. 8 | 9 | ![PairText retrieval model](../../../images/retrieval_modifier_model.png) 10 | 11 | ## :crystal_ball: Demo 12 | 13 | To look at a ranking of the text instructions (resp. pose pairs) referenced in PoseFix, sorted by relevance to a chosen pose pair (resp. text instruction), using a pretrained model, run the following: 14 | 15 | ``` 16 | bash retrieval_modifier/script_retrieval_modifier.sh 'demo' 17 | ``` 18 | 19 | ## :bullettrain_front: Train 20 | 21 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options. 22 | 23 | Then use the following command: 24 | ``` 25 | bash retrieval_modifier/script_retrieval_modifier.sh 'train' 26 | ``` 27 | 28 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and set this nickname as the value of the `pretrained` variable in the script. The nickname will appear in the path of the finetuned model. 29 | 30 | ## :dart: Evaluate 31 | 32 | Use the following command (test on PoseFix-H): 33 | ``` 34 | bash retrieval_modifier/script_retrieval_modifier.sh 'eval' 35 | ``` 36 | -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/__init__.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/demo_retrieval_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. ## 8 | ############################################################## 9 | 10 | import streamlit as st 11 | import argparse 12 | import torch 13 | 14 | import text2pose.config as config 15 | import text2pose.demo as demo 16 | import text2pose.utils as utils 17 | import text2pose.utils_visu as utils_visu 18 | from text2pose.retrieval_modifier.evaluate_retrieval_modifier import load_model 19 | 20 | 21 | parser = argparse.ArgumentParser(description='Parameters for the demo.') 22 | parser.add_argument('--model_path', type=str, help='Path to the model.') 23 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.') 24 | parser.add_argument('--n_retrieve', type=int, default=12, help="Number of elements to retrieve.") 25 | args = parser.parse_args() 26 | 27 | args.n_retrieve = 12 28 | 29 | ### INPUT 30 | ################################################################################ 31 | 32 | data_version_annotations = "posefix-H" # defines what annotations to use as query examples 33 | data_version_poses_collection = "posefix-A" # defines the set of poses to rank 34 | 35 | 36 | ### SETUP 37 | ################################################################################ 38 | 39 | # --- data 40 | available_splits = ['train', 'val', 'test'] 41 | model, tokenizer_name, body_model = demo.setup_models([args.model_path], args.checkpoint, load_model) 42 | model, tokenizer_name = model[0], tokenizer_name[0] 43 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(data_version_annotations) 44 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids) 45 | 46 | 47 | ### MAIN APP 48 | ################################################################################ 49 | 50 | # define query input interface: split selection 51 | cols_query = st.columns(3) 52 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test')) 53 | 54 | # precompute features 55 | dataIDs = demo.setup_posefix_split(split_for_research) 56 | pair_dataIDs, pairs_features = demo.precompute_posefix_pair_features(data_version_poses_collection, split_for_research, model) 57 | text_dataIDs, text_features = demo.precompute_text_features(data_version_annotations, split_for_research, model, tokenizer_name) 58 | 59 | # define query input interface: example selection 60 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID')) 61 | number = cols_query[2].number_input("Split index or ID:", 0) 62 | st.markdown("""---""") 63 | 64 | # get query data 65 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model) 66 | 67 | # show query data 68 | cols_input = st.columns([1,1,2]) 69 | cols_input[0].image(pose_A_img, caption="Pose A") 70 | 
cols_input[1].image(pose_B_img, caption="Pose B") 71 | if default_modifier: 72 | cols_input[2].write("Annotated text:") 73 | cols_input[2].write(f"_{default_modifier}_") 74 | else: 75 | cols_input[2].write("_(Not annotated.)_") 76 | 77 | # get retrieval direction 78 | dt2p = "Text-2-Pair" 79 | dp2t = "Pair-2-Text" 80 | retrieval_direction = st.radio("Retrieval direction:", [dt2p, dp2t]) 81 | 82 | # TEXT-2-PAIR 83 | if retrieval_direction == dt2p: 84 | 85 | # get input modifier 86 | modifier = cols_input[2].text_area("Pose modifier:", 87 | placeholder="Move your right arm... lift your left leg...", 88 | value=default_modifier, 89 | height=None, max_chars=None) 90 | # encode text 91 | with torch.no_grad(): 92 | text_feature = model.encode_raw_text(modifier) 93 | 94 | # rank poses by relevance and get their pose id 95 | scores = text_feature.view(1, -1).mm(pairs_features.t())[0] 96 | _, indices_rank = scores.sort(descending=True) 97 | relevant_pair_ids = [pair_dataIDs[i] for i in indices_rank[:args.n_retrieve]] 98 | 99 | # get corresponding pair data and render the pairs as images 100 | imgs = [] 101 | for pair_id in relevant_pair_ids: 102 | ret_pid_A, ret_pid_B = pose_pairs[pair_id] 103 | pose_A_info = dataID_2_pose_info[str(ret_pid_A)] 104 | pose_A_data, rA = utils.get_pose_data_from_file(pose_A_info, output_rotation=True) 105 | pose_B_info = dataID_2_pose_info[str(ret_pid_B)] 106 | pose_B_data = utils.get_pose_data_from_file(pose_B_info, applied_rotation=rA if pose_A_info[1]==pose_B_info[1] else None) 107 | imgs.append(utils_visu.image_from_pair_data(pose_A_data, pose_B_data, body_model, add_ground_plane=True)) 108 | 109 | # display images 110 | st.markdown("""---""") 111 | st.write(f"**Retrieved pairs for this modifier [{split_for_research} split]:**") 112 | cols = st.columns(demo.nb_cols) 113 | for i in range(args.n_retrieve): 114 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i])) 115 | 116 | # PAIR-2-TEXT 117 | elif retrieval_direction == dp2t: 118 | 119 | # rank texts by relevance and get their id 120 | pair_index = pair_dataIDs.index(pair_ID) 121 | scores = pairs_features[pair_index].view(1, -1).mm(text_features.t())[0] 122 | _, indices_rank = scores.sort(descending=True) 123 | relevant_pair_ids = [text_dataIDs[i] for i in indices_rank[:args.n_retrieve]] 124 | 125 | # get corresponding text data (the text features were obtained using the first text) 126 | texts = [triplet_data[pair_id]['modifier'][0] for pair_id in relevant_pair_ids] 127 | 128 | # display texts 129 | st.markdown("""---""") 130 | st.write(f"**Retrieved modifiers for this pair [{split_for_research} split]:**") 131 | for i in range(args.n_retrieve): 132 | st.write(f"**({i+1})** {texts[i]}") 133 | 134 | st.markdown("""---""") 135 | st.write(f"_Results obtained with model: {args.model_path}_") -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/evaluate_retrieval_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import os 11 | import torch 12 | from tqdm import tqdm 13 | 14 | import text2pose.config as config 15 | import text2pose.evaluate as evaluate 16 | from text2pose.data import PoseFix 17 | from text2pose.encoders.tokenizers import get_tokenizer_name 18 | from text2pose.retrieval_modifier.model_retrieval_modifier import PairText 19 | from text2pose.loss import BBC 20 | 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 22 | 23 | OVERWRITE_RESULT = False 24 | 25 | 26 | ################################################################################ 27 | 28 | def load_model(model_path, device): 29 | 30 | assert os.path.isfile(model_path), "File {} not found.".format(model_path) 31 | 32 | # load checkpoint & model info 33 | ckpt = torch.load(model_path, 'cpu') 34 | text_encoder_name = ckpt['args'].text_encoder_name 35 | transformer_topping = ckpt['args'].transformer_topping 36 | latentD = ckpt['args'].latentD 37 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52) 38 | 39 | # load model 40 | model = PairText(text_encoder_name=text_encoder_name, 41 | transformer_topping=transformer_topping, 42 | latentD=latentD, 43 | num_body_joints=num_body_joints 44 | ).to(device) 45 | model.load_state_dict(ckpt['model']) 46 | model.eval() 47 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path) 48 | 49 | return model, get_tokenizer_name(text_encoder_name) 50 | 51 | 52 | def eval_model(model_path, dataset_version, split='val'): 53 | 54 | device = torch.device('cuda:0') 55 | 56 | # define result file 57 | precision = "" # default 58 | nb_caps = config.caption_files[dataset_version][0] 59 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps) 60 | 61 | # load model if results for at least one caption is missing 62 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps): 63 | model, tokenizer_name = load_model(model_path, device) 64 | 65 | # compute or load results for the given run & caption 66 | results = {} 67 | for cap_ind in range(nb_caps): 68 | filename_res = get_res_file(cap_ind) 69 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT: 70 | if "posefix" in dataset_version: 71 | d = PoseFix(version=dataset_version, split=split, tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True) 72 | else: 73 | raise NotImplementedError 74 | cap_results = compute_eval_metrics(model, d, device) 75 | evaluate.save_results_to_file(cap_results, filename_res) 76 | else: 77 | cap_results = evaluate.load_results_from_file(filename_res) 78 | # aggregate results 79 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()} 80 | 81 | # average over captions 82 | results = {k:sum(v)/nb_caps for k,v in results.items()} 83 | 84 | return {k:[v] for k, v in results.items()} 85 | 86 | 87 | def compute_eval_metrics(model, dataset, device, compute_loss=False): 88 | 89 | # get data features 90 | poses_features, texts_features = infer_features(model, dataset, device) 91 | 92 | # poses-2-text matching 93 | p2t_recalls = evaluate.x2y_recall_metrics(poses_features, texts_features, config.k_recall_values, sstr="p2t_") 94 | # text-2-poses matching 95 | t2p_recalls = evaluate.x2y_recall_metrics(texts_features, poses_features, config.k_recall_values, sstr="t2p_") 96 | # r-precision 97 | rprecisions = 
evaluate.textret_metrics(texts_features, poses_features) 98 | 99 | # gather metrics 100 | recalls = {"mRecall": (sum(p2t_recalls.values()) + sum(t2p_recalls.values())) / (2 * len(config.k_recall_values))} 101 | recalls.update(p2t_recalls) 102 | recalls.update(t2p_recalls) 103 | recalls.update(rprecisions) 104 | 105 | # loss 106 | if compute_loss: 107 | score_t2p = texts_features.mm(poses_features.t()) 108 | loss = BBC(score_t2p*model.loss_weight) 109 | loss_value = loss.item() 110 | return recalls, loss_value 111 | 112 | return recalls 113 | 114 | 115 | def infer_features(model, dataset, device): 116 | 117 | batch_size = 32 118 | data_loader = torch.utils.data.DataLoader( 119 | dataset, sampler=None, shuffle=False, 120 | batch_size=batch_size, 121 | num_workers=8, 122 | pin_memory=True, 123 | drop_last=False 124 | ) 125 | 126 | poses_features = torch.zeros(len(dataset), model.latentD).to(device) 127 | texts_features = torch.zeros(len(dataset), model.latentD).to(device) 128 | 129 | for i, batch in tqdm(enumerate(data_loader)): 130 | poses_A = batch['poses_A'].to(device) 131 | poses_B = batch['poses_B'].to(device) 132 | caption_tokens = batch['caption_tokens'].to(device) 133 | caption_lengths = batch['caption_lengths'].to(device) 134 | caption_tokens = caption_tokens[:,:caption_lengths.max()] 135 | with torch.inference_mode(): 136 | pfeat, tfeat = model(poses_A, caption_tokens, caption_lengths, poses_B) 137 | poses_features[i*batch_size:i*batch_size+len(poses_A)] = pfeat 138 | texts_features[i*batch_size:i*batch_size+len(poses_A)] = tfeat 139 | 140 | return poses_features, texts_features 141 | 142 | 143 | def display_results(results): 144 | metric_order = ['mRecall'] + ['%s_R@%d'%(d, k) for d in ['p2t', 't2p'] for k in config.k_recall_values] 145 | results = evaluate.scale_and_format_results(results) 146 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n") 147 | 148 | 149 | ################################################################################ 150 | 151 | if __name__ == '__main__': 152 | 153 | args = evaluate.eval_parser.parse_args() 154 | args = evaluate.get_full_model_path(args) 155 | 156 | # compute results 157 | if args.average_over_runs: 158 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, dataset_version=args.dataset, split=args.split) 159 | else: 160 | ret = eval_model(args.model_path, dataset_version=args.dataset, split=args.split) 161 | 162 | # display results 163 | print(ret) 164 | display_results(ret) -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/model_retrieval_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import text2pose.config as config 14 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name 15 | from text2pose.encoders.modules import ConCatModule, L2Norm 16 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder 17 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder 18 | 19 | 20 | class PairText(nn.Module): 21 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512, 22 | num_body_joints=config.NB_INPUT_JOINTS, 23 | text_encoder_name='distilbertUncased', transformer_topping=None): 24 | super(PairText, self).__init__() 25 | 26 | self.latentD = latentD 27 | 28 | # Define pose encoder 29 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, 30 | num_neurons_mini=num_neurons_mini, 31 | latentD=latentD, 32 | num_body_joints=num_body_joints, 33 | role="retrieval") 34 | 35 | # Define text encoder 36 | self.text_encoder_name = text_encoder_name 37 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name) 38 | if module_ref in ["glovebigru"]: 39 | self.text_encoder = TextEncoder(self.text_encoder_name, latentD=latentD, role="retrieval") 40 | elif module_ref in ["glovetransf", "distilbertUncased"]: 41 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, latentD=latentD, topping=transformer_topping, role="retrieval") 42 | else: 43 | raise NotImplementedError 44 | 45 | # Define projecting layers 46 | self.pose_mlp = nn.Sequential( 47 | ConCatModule(), 48 | nn.Linear(2 * latentD, 2 * latentD), 49 | nn.LeakyReLU(), 50 | nn.Linear(2 * latentD, latentD), 51 | nn.LeakyReLU(), 52 | nn.Linear(latentD, latentD), 53 | nn.LeakyReLU(), 54 | L2Norm() 55 | ) 56 | 57 | # Loss temperature 58 | self.loss_weight = torch.nn.Parameter( torch.FloatTensor((10,)) ) 59 | self.loss_weight.requires_grad = True 60 | 61 | def forward(self, poses_A, captions, caption_lengths, poses_B): 62 | embed_AB = self.encode_pose_pair(poses_A, poses_B) 63 | text_embs = self.encode_text(captions, caption_lengths) 64 | return embed_AB, text_embs 65 | 66 | def encode_raw_text(self, raw_text): 67 | if not hasattr(self, 'tokenizer'): 68 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name)) 69 | tokens = self.tokenizer(raw_text).to(device=self.loss_weight.device) 70 | length = torch.tensor([ len(tokens) ], dtype=tokens.dtype) 71 | text_embs = self.text_encoder(tokens.view(1, -1), length) 72 | return text_embs 73 | 74 | def encode_pose_pair(self, poses_A, poses_B): 75 | embed_poses_A = self.pose_encoder(poses_A) 76 | embed_poses_B = self.pose_encoder(poses_B) 77 | embed_AB = self.pose_mlp([embed_poses_A, embed_poses_B]) 78 | return embed_AB 79 | 80 | def encode_text(self, captions, caption_lengths): 81 | return self.text_encoder(captions, caption_lengths) -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/script_retrieval_modifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################################## 4 | ## text2pose ## 5 | ## Copyright (c) 2023 ## 6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 7 | ## and Naver Corporation ## 8 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 9 | ## See project root for license details. 
## 10 | ############################################################## 11 | 12 | 13 | ############################################################## 14 | # SCRIPT ARGUMENTS 15 | 16 | action=$1 # (train|eval|demo) 17 | checkpoint_type="best" # (last|best) 18 | 19 | architecture_args=( 20 | --model PairText 21 | --latentD 512 22 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp" 23 | # --text_encoder_name 'glovebigru_vocPFAHPP' 24 | ) 25 | 26 | loss_args=( 27 | --retrieval_loss 'symBBC' 28 | ) 29 | 30 | bonus_args=( 31 | ) 32 | 33 | pretrained="modret_distilbert_dataPFA" # used only if phase=='finetune' 34 | 35 | 36 | ############################################################## 37 | # EXECUTE 38 | 39 | # TRAIN 40 | if [[ "$action" == *"train"* ]]; then 41 | 42 | phase=$2 # (pretrain|finetune) 43 | echo "NOTE: Expecting as argument the training phase. Got: $phase" 44 | seed=$3 45 | echo "NOTE: Expecting as argument the seed value. Got: $seed" 46 | 47 | # PRETRAIN 48 | if [[ "$phase" == *"pretrain"* ]]; then 49 | 50 | python retrieval_modifier/train_retrieval_modifier.py --dataset "posefix-A" \ 51 | "${architecture_args[@]}" \ 52 | "${loss_args[@]}" \ 53 | "${bonus_args[@]}" \ 54 | --lr_scheduler "stepLR" --lr 0.00005 --lr_step 400 --lr_gamma 0.5 \ 55 | --log_step 20 --val_every 20 \ 56 | --batch_size 128 --epochs 400 --seed $seed 57 | 58 | # FINETUNE 59 | elif [[ "$phase" == *"finetune"* ]]; then 60 | 61 | python retrieval_modifier/train_retrieval_modifier.py --dataset "posefix-H" \ 62 | "${architecture_args[@]}" \ 63 | "${loss_args[@]}" \ 64 | "${bonus_args[@]}" \ 65 | --apply_LR_augmentation \ 66 | --lr_scheduler "stepLR" --lr 0.00005 --lr_step 40 --lr_gamma 0.5 \ 67 | --batch_size 128 --epochs 50 --seed $seed \ 68 | --pretrained $pretrained 69 | 70 | fi 71 | 72 | fi 73 | 74 | 75 | # EVAL QUANTITATIVELY 76 | if [[ "$action" == *"eval"* ]]; then 77 | 78 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one 79 | 80 | for model_path in "${experiments[@]}" 81 | do 82 | echo $model_path 83 | python retrieval_modifier/evaluate_retrieval_modifier.py --dataset "posefix-H" \ 84 | --model_path ${model_path} --checkpoint $checkpoint_type \ 85 | --split test 86 | done 87 | fi 88 | 89 | 90 | # EVAL QUALITATIVELY 91 | if [[ "$action" == *"demo"* ]]; then 92 | 93 | experiment=$2 # only one at a time 94 | streamlit run retrieval_modifier/demo_retrieval_modifier.py -- --model_path $experiment --checkpoint $checkpoint_type 95 | 96 | fi -------------------------------------------------------------------------------- /src/text2pose/retrieval_modifier/train_retrieval_modifier.py: -------------------------------------------------------------------------------- 1 | ############################################################## 2 | ## text2pose ## 3 | ## Copyright (c) 2023 ## 4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ## 5 | ## and Naver Corporation ## 6 | ## Licensed under the CC BY-NC-SA 4.0 license. ## 7 | ## See project root for license details. 
## 8 | ############################################################## 9 | 10 | import torch 11 | import math 12 | import sys 13 | import os 14 | os.umask(0x0002) 15 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 16 | 17 | from text2pose.option import get_args_parser 18 | from text2pose.trainer import GenericTrainer 19 | from text2pose.retrieval_modifier.model_retrieval_modifier import PairText 20 | from text2pose.retrieval_modifier.evaluate_retrieval_modifier import compute_eval_metrics 21 | from text2pose.loss import BBC, symBBC 22 | from text2pose.data import PoseFix, PoseMix, PoseScript 23 | from text2pose.encoders.tokenizers import get_tokenizer_name 24 | from text2pose.data_augmentations import DataAugmentation 25 | 26 | import text2pose.config as config 27 | import text2pose.utils_logging as logging 28 | 29 | 30 | ################################################################################ 31 | 32 | 33 | class PairTextTrainer(GenericTrainer): 34 | 35 | def __init__(self, args): 36 | super(PairTextTrainer, self).__init__(args, retrieval_trainer=True) 37 | 38 | 39 | def load_dataset(self, split, caption_index, tokenizer_name=None): 40 | 41 | if tokenizer_name is None: tokenizer_name = get_tokenizer_name(self.args.text_encoder_name) 42 | data_size = self.args.data_size if split=="train" else None 43 | 44 | if "posefix" in self.args.dataset: 45 | d = PoseFix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size) 46 | elif "posemix" in self.args.dataset: 47 | # NOTE: if specifying data_size: only the first loaded data items 48 | # will be considered (since PoseFix is loaded before PoseScript, if 49 | # data_size < the size of PoseFix, no PoseScript data will be 50 | # loaded) 51 | d = PoseMix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size) 52 | elif "posescript" in self.args.dataset: 53 | d = PoseScript(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size, posefix_format=True) 54 | else: 55 | raise NotImplementedError 56 | return d 57 | 58 | 59 | def init_model(self): 60 | print('Load model') 61 | self.model = PairText(text_encoder_name=self.args.text_encoder_name, 62 | transformer_topping=self.args.transformer_topping, 63 | latentD=self.args.latentD, 64 | num_body_joints=self.args.num_body_joints) 65 | self.model.to(self.device) 66 | 67 | 68 | def get_param_groups(self): 69 | param_groups = [] 70 | param_groups.append({'params': self.model.pose_encoder.parameters(), 'lr': self.args.lr*self.args.lrposemul}) 71 | param_groups.append({'params': self.model.pose_mlp.parameters(), 'lr': self.args.lr*self.args.lrposemul}) 72 | param_groups.append({'params': [p for k,p in self.model.text_encoder.named_parameters() if 'pretrained_text_encoder.' 
not in k], 'lr': self.args.lr*self.args.lrtextmul}) 73 | param_groups.append({'params': [self.model.loss_weight]}) 74 | return param_groups 75 | 76 | 77 | def init_optimizer(self): 78 | assert self.args.optimizer=='Adam' 79 | param_groups = self.get_param_groups() 80 | self.optimizer = torch.optim.Adam(param_groups, lr=self.args.lr) 81 | 82 | 83 | def init_lr_scheduler(self): 84 | self.lr_scheduler = None 85 | if self.args.lr_scheduler == "stepLR": 86 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 87 | step_size=self.args.lr_step, 88 | gamma=self.args.lr_gamma, 89 | last_epoch=-1) 90 | 91 | 92 | def init_other_training_elements(self): 93 | self.data_augmentation_module = DataAugmentation(self.args, mode="posefix", tokenizer_name=get_tokenizer_name(self.args.text_encoder_name), nb_joints=self.args.num_body_joints) 94 | 95 | 96 | def training_epoch(self, epoch): 97 | train_stats = self.one_epoch(epoch=epoch, is_training=True) 98 | return train_stats 99 | 100 | 101 | def validation_epoch(self, epoch): 102 | val_stats = {} 103 | if self.args.val_every and (epoch+1)%self.args.val_every==0: 104 | val_stats = self.validate(epoch=epoch) 105 | return val_stats 106 | 107 | 108 | def one_epoch(self, epoch, is_training): 109 | 110 | self.model.train(is_training) 111 | 112 | # define loggers 113 | metric_logger = logging.MetricLogger(delimiter=" ") 114 | if is_training: 115 | prefix, sstr = '', 'train' 116 | metric_logger.add_meter(f'{sstr}_lr', logging.SmoothedValue(window_size=1, fmt='{value:.6f}')) 117 | else: 118 | prefix, sstr = '[val] ', 'val' 119 | header = f'{prefix}Epoch: [{epoch}]' 120 | 121 | # define dataloader & other elements 122 | if is_training: 123 | data_loader = self.data_loader_train 124 | if not is_training: 125 | data_loader = self.data_loader_val 126 | 127 | # iterate over the batches 128 | for data_iter_step, item in enumerate(metric_logger.log_every(data_loader, self.args.log_step, header)): 129 | 130 | # get data 131 | poses_A = item['poses_A'].to(self.device) 132 | poses_B = item['poses_B'].to(self.device) 133 | caption_tokens = item['caption_tokens'].to(self.device) 134 | caption_lengths = item['caption_lengths'].to(self.device) 135 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # truncate within the batch, based on the longest text 136 | 137 | # online random augmentations 138 | posescript_poses = item['poses_A_ids'] == config.PID_NAN 139 | poses_A, caption_tokens, caption_lengths, poses_B = self.data_augmentation_module(poses_A, caption_tokens, caption_lengths, poses_B, posescript_poses) 140 | 141 | # forward; compute scores 142 | with torch.set_grad_enabled(is_training): 143 | poses_features, texts_features = self.model(poses_A, caption_tokens, caption_lengths, poses_B) 144 | score_t2p = texts_features.mm(poses_features.t()) * self.model.loss_weight 145 | 146 | # compute loss 147 | if self.args.retrieval_loss == "BBC": 148 | loss = BBC(score_t2p) 149 | elif self.args.retrieval_loss == "symBBC": 150 | loss = symBBC(score_t2p) 151 | else: 152 | raise NotImplementedError 153 | 154 | loss_value = loss.item() 155 | if not math.isfinite(loss_value): 156 | print("Loss is {}, stopping training".format(loss_value)) 157 | sys.exit(1) 158 | 159 | # training step 160 | if is_training: 161 | self.optimizer.zero_grad() 162 | loss.backward() 163 | self.optimizer.step() 164 | 165 | # format data for logging 166 | scalars = [('loss', loss_value)] 167 | if is_training: 168 | lr_value = self.optimizer.param_groups[0]["lr"] 169 | scalars += [('lr', 
lr_value)] 170 | 171 | # actually log 172 | self.add_data_to_log_writer(epoch, sstr, scalars=scalars, is_training=is_training, data_iter_step=data_iter_step, total_steps=len(data_loader)) 173 | self.add_data_to_metric_logger(metric_logger, sstr, scalars) 174 | 175 | print("Averaged stats:", metric_logger) 176 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 177 | 178 | 179 | def validate(self, epoch): 180 | 181 | self.model.eval() 182 | 183 | recalls, loss_value = compute_eval_metrics(self.model, self.data_loader_val.dataset, self.device, compute_loss=True) 184 | val_stats = {"loss": loss_value} 185 | val_stats.update(recalls) 186 | 187 | # log 188 | self.add_data_to_log_writer(epoch, 'val', scalars=[('loss', loss_value), ('validation', recalls)], should_log_data=True) 189 | print(f"[val] Epoch: [{epoch}] Stats: " + " ".join(f"{k}: {round(v, 3)}" for k,v in val_stats.items()) ) 190 | return val_stats 191 | 192 | 193 | if __name__ == '__main__': 194 | 195 | argparser = get_args_parser() 196 | args = argparser.parse_args() 197 | 198 | PairTextTrainer(args)() -------------------------------------------------------------------------------- /src/text2pose/shortname_2_model_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/src/text2pose/shortname_2_model_path.txt -------------------------------------------------------------------------------- /src/text2pose/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################################################################################ 4 | ## READ/WRITE TO FILES 5 | ################################################################################ 6 | 7 | import json 8 | 9 | def read_json(absolute_filepath): 10 | with open(absolute_filepath, 'r') as f: 11 | data = json.load(f) 12 | return data 13 | 14 | def write_json(data, absolute_filepath, pretty=False): 15 | with open(absolute_filepath, "w") as f: 16 | if pretty: 17 | json.dump(data, f, ensure_ascii=False, indent=2) 18 | else: 19 | json.dump(data, f) 20 | 21 | 22 | ################################################################################ 23 | ## ANGLE TRANSFORMATION FUNCTIONS 24 | ################################################################################ 25 | 26 | import roma 27 | 28 | def rotvec_to_eulerangles(x): 29 | x_rotmat = roma.rotvec_to_rotmat(x) 30 | thetax = torch.atan2(x_rotmat[:,2,1], x_rotmat[:,2,2]) 31 | thetay = torch.atan2(-x_rotmat[:,2,0], torch.sqrt(x_rotmat[:,2,1]**2+x_rotmat[:,2,2]**2)) 32 | thetaz = torch.atan2(x_rotmat[:,1,0], x_rotmat[:,0,0]) 33 | return thetax, thetay, thetaz 34 | 35 | def eulerangles_to_rotmat(thetax, thetay, thetaz): 36 | N = thetax.numel() 37 | # rotx 38 | rotx = torch.eye( (3) ).to(thetax.device).repeat(N,1,1) 39 | roty = torch.eye( (3) ).to(thetax.device).repeat(N,1,1) 40 | rotz = torch.eye( (3) ).to(thetax.device).repeat(N,1,1) 41 | rotx[:,1,1] = torch.cos(thetax) 42 | rotx[:,2,2] = torch.cos(thetax) 43 | rotx[:,1,2] = -torch.sin(thetax) 44 | rotx[:,2,1] = torch.sin(thetax) 45 | roty[:,0,0] = torch.cos(thetay) 46 | roty[:,2,2] = torch.cos(thetay) 47 | roty[:,0,2] = torch.sin(thetay) 48 | roty[:,2,0] = -torch.sin(thetay) 49 | rotz[:,0,0] = torch.cos(thetaz) 50 | rotz[:,1,1] = torch.cos(thetaz) 51 | rotz[:,0,1] = -torch.sin(thetaz) 52 | rotz[:,1,0] = torch.sin(thetaz) 53 | rotmat = torch.einsum('bij,bjk->bik', 
rotz, torch.einsum('bij,bjk->bik', roty, rotx)) 54 | return rotmat 55 | 56 | def eulerangles_to_rotvec(thetax, thetay, thetaz): 57 | rotmat = eulerangles_to_rotmat(thetax, thetay, thetaz) 58 | return roma.rotmat_to_rotvec(rotmat) 59 | 60 | 61 | ################################################################################ 62 | ## LOAD POSE DATA 63 | ################################################################################ 64 | 65 | import os 66 | import numpy as np 67 | 68 | import text2pose.config as config 69 | 70 | 71 | def get_pose_data_from_file(pose_info, applied_rotation=None, output_rotation=False): 72 | """ 73 | Load pose data and normalize the orientation. 74 | 75 | Args: 76 | pose_info: list [dataset (string), sequence_filepath (string), frame_index (int)] 77 | applied_rotation: rotation to be applied to the pose data. If None, the 78 | normalization rotation is applied. 79 | output_rotation: whether to output the rotation performed for 80 | normalization, in addition to the normalized pose data. 81 | 82 | Returns: 83 | pose data, torch.tensor of size (1, n_joints*3), all joints considered. 84 | (optional) R, torch.tensor representing the rotation of normalization 85 | """ 86 | 87 | # load pose data 88 | assert pose_info[0] in config.supported_datasets, f"Expected data from one of the following datasets: {','.join(config.supported_datasets)} (provided dataset: {pose_info[0]})." 89 | 90 | if pose_info[0] == "AMASS": 91 | dp = np.load(os.path.join(config.supported_datasets[pose_info[0]], pose_info[1])) 92 | pose = dp['poses'][pose_info[2],:].reshape(-1,3) # (n_joints, 3) 93 | pose = torch.as_tensor(pose).to(dtype=torch.float32) 94 | 95 | # normalize the global orient 96 | initial_rotation = pose[:1,:].clone() 97 | if applied_rotation is None: 98 | thetax, thetay, thetaz = rotvec_to_eulerangles( initial_rotation ) 99 | zeros = torch.zeros_like(thetaz) 100 | pose[0:1,:] = eulerangles_to_rotvec(thetax, thetay, zeros) 101 | else: 102 | pose[0:1,:] = roma.rotvec_composition((applied_rotation, initial_rotation)) 103 | if output_rotation: 104 | # a = A.u, after normalization, becomes a' = A'.u 105 | # we look for the normalization rotation R such that: a' = R.a 106 | # since a = A.u ==> u = A^-1.a 107 | # a' = A'.u = A'.A^-1.a ==> R = A'.A^-1 108 | R = roma.rotvec_composition((pose[0:1,:], roma.rotvec_inverse(initial_rotation))) 109 | return pose.reshape(1, -1), R 110 | 111 | return pose.reshape(1, -1) 112 | 113 | 114 | def pose_data_as_dict(pose_data, code_base='human_body_prior'): 115 | """ 116 | Args: 117 | pose_data, torch.tensor of shape (*, n_joints*3) or (*, n_joints, 3), 118 | all joints considered. 
119 | Returns: 120 | dict 121 | """ 122 | # reshape to (*, n_joints*3) if necessary 123 | if len(pose_data.shape) == 3: 124 | # shape (batch_size, n_joints, 3) 125 | pose_data = pose_data.flatten(1,2) 126 | if len(pose_data.shape) == 2 and pose_data.shape[1] == 3: 127 | # shape (n_joints, 3) 128 | pose_data = pose_data.view(1, -1) 129 | # provide as a dict, with different keys, depending on the code base 130 | if code_base == 'human_body_prior': 131 | d = {"root_orient":pose_data[:,:3], 132 | "pose_body":pose_data[:,3:66]} 133 | if pose_data.shape[1] > 66: 134 | d["pose_hand"] = pose_data[:,66:] 135 | elif code_base == 'smplx': 136 | d = {"global_orient":pose_data[:,:3], 137 | "body_pose":pose_data[:,3:66]} 138 | if pose_data.shape[1] > 66: 139 | d.update({"left_hand_pose":pose_data[:,66:111], 140 | "right_hand_pose":pose_data[:,111:]}) 141 | return d -------------------------------------------------------------------------------- /src/text2pose/utils_logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import datetime 4 | import time 5 | from collections import defaultdict, deque 6 | import torch 7 | 8 | 9 | class HiddenPrints: 10 | def __enter__(self): 11 | self._original_stdout = sys.stdout 12 | sys.stdout = open(os.devnull, 'w') 13 | 14 | def __exit__(self, exc_type, exc_val, exc_tb): 15 | sys.stdout.close() 16 | sys.stdout = self._original_stdout 17 | 18 | 19 | class SmoothedValue(object): 20 | """ 21 | Track a series of values and provide access to smoothed values over a window 22 | or the global series average. 23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | @property 39 | def median(self): 40 | d = torch.tensor(list(self.deque)) 41 | return d.median().item() 42 | 43 | @property 44 | def avg(self): 45 | d = torch.tensor(list(self.deque), dtype=torch.float32) 46 | return d.mean().item() 47 | 48 | @property 49 | def global_avg(self): 50 | return self.total / self.count 51 | 52 | @property 53 | def max(self): 54 | return max(self.deque) 55 | 56 | @property 57 | def value(self): 58 | return self.deque[-1] 59 | 60 | def __str__(self): 61 | return self.fmt.format( 62 | median=self.median, 63 | avg=self.avg, 64 | global_avg=self.global_avg, 65 | max=self.max, 66 | value=self.value) 67 | 68 | 69 | class MetricLogger(object): 70 | 71 | def __init__(self, delimiter="\t"): 72 | self.meters = defaultdict(SmoothedValue) 73 | self.delimiter = delimiter 74 | 75 | def update(self, **kwargs): 76 | for k, v in kwargs.items(): 77 | if v is None: 78 | continue 79 | if isinstance(v, torch.Tensor): 80 | v = v.item() 81 | assert isinstance(v, (float, int)) 82 | self.meters[k].update(v) 83 | 84 | def __getattr__(self, attr): 85 | if attr in self.meters: 86 | return self.meters[attr] 87 | if attr in self.__dict__: 88 | return self.__dict__[attr] 89 | raise AttributeError("'{}' object has no attribute '{}'".format( 90 | type(self).__name__, attr)) 91 | 92 | def __str__(self): 93 | loss_str = [] 94 | for name, meter in self.meters.items(): 95 | loss_str.append( 96 | "{}: {}".format(name, str(meter)) 97 | ) 98 | return self.delimiter.join(loss_str) 99 | 100 | def add_meter(self, name, meter): 101 | self.meters[name] = meter 
102 | 103 | def log_every(self, iterable, print_freq, header=None): 104 | i = 0 105 | if not header: 106 | header = '' 107 | start_time = time.time() 108 | end = time.time() 109 | iter_time = SmoothedValue(fmt='{avg:.4f}') 110 | data_time = SmoothedValue(fmt='{avg:.4f}') 111 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 112 | log_msg = [ 113 | header, 114 | '[{0' + space_fmt + '}/{1}]', 115 | 'eta: {eta}', 116 | '{meters}', 117 | 'time: {time}', 118 | 'data: {data}' 119 | ] 120 | if torch.cuda.is_available(): 121 | log_msg.append('max mem: {memory:.0f}') 122 | log_msg = self.delimiter.join(log_msg) 123 | MB = 1024.0 * 1024.0 124 | for obj in iterable: 125 | data_time.update(time.time() - end) 126 | yield obj 127 | iter_time.update(time.time() - end) 128 | if i % print_freq == 0 or i == len(iterable) - 1: 129 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 130 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 131 | if torch.cuda.is_available(): 132 | print(log_msg.format( 133 | i, len(iterable), eta=eta_string, 134 | meters=str(self), 135 | time=str(iter_time), data=str(data_time), 136 | memory=torch.cuda.max_memory_allocated() / MB)) 137 | else: 138 | print(log_msg.format( 139 | i, len(iterable), eta=eta_string, 140 | meters=str(self), 141 | time=str(iter_time), data=str(data_time))) 142 | i += 1 143 | end = time.time() 144 | total_time = time.time() - start_time 145 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 146 | print('{} Total time: {} ({:.4f} s / it)'.format( 147 | header, total_time_str, total_time / len(iterable))) 148 | --------------------------------------------------------------------------------