├── LICENSE
├── README.md
├── images
│   ├── caption_generation_model.png
│   ├── captioning_pipeline.png
│   ├── comparative_pipeline.png
│   ├── feedback_generation_model.png
│   ├── generative_model.png
│   ├── main_picture.png
│   ├── pose_editing_model.png
│   ├── retrieval_model.png
│   └── retrieval_modifier_model.png
├── pretrained_models.md
├── requirements.txt
├── setup.py
└── src
    ├── __init__.py
    ├── other_utils
    │   ├── README.md
    │   ├── pair_mining.py
    │   └── pose_mining.py
    └── text2pose
        ├── __init__.py
        ├── config.py
        ├── data.py
        ├── data_augmentations.py
        ├── demo.py
        ├── encoders
        │   ├── __init__.py
        │   ├── modules.py
        │   ├── pose_encoder_decoder.py
        │   ├── text_decoders.py
        │   ├── text_encoders.py
        │   └── tokenizers.py
        ├── evaluate.py
        ├── fid.py
        ├── generative
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_generative.py
        │   ├── evaluate_generative.py
        │   ├── generate_poses.py
        │   ├── look_at_generated_pose_samples.py
        │   ├── model_generative.py
        │   ├── script_generative.sh
        │   └── train_generative.py
        ├── generative_B
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_generative_B.py
        │   ├── evaluate_generative_B.py
        │   ├── model_generative_B.py
        │   ├── script_generative_B.sh
        │   └── train_generative_B.py
        ├── generative_caption
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_generative_caption.py
        │   ├── evaluate_generative_caption.py
        │   ├── model_generative_caption.py
        │   ├── script_generative_caption.sh
        │   └── train_generative_caption.py
        ├── generative_modifier
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_generative_modifier.py
        │   ├── evaluate_generative_modifier.py
        │   ├── model_generative_modifier.py
        │   ├── script_generative_modifier.sh
        │   └── train_generative_modifier.py
        ├── loss.py
        ├── option.py
        ├── posefix
        │   ├── README.md
        │   ├── __init__.py
        │   ├── compute_rotation_change.py
        │   ├── correcting.py
        │   ├── corrective_data.py
        │   ├── explore_posefix.py
        │   └── paircodes.py
        ├── posescript
        │   ├── README.md
        │   ├── __init__.py
        │   ├── action_to_sent_template.json
        │   ├── captioning.py
        │   ├── captioning_data.py
        │   ├── compute_coords.py
        │   ├── explore_posescript.py
        │   ├── format_babel_labels.py
        │   ├── format_contact_info.py
        │   ├── posecodes.py
        │   ├── smplx_custom_semantic_segmentation.json
        │   └── utils.py
        ├── retrieval
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_retrieval.py
        │   ├── evaluate_retrieval.py
        │   ├── model_retrieval.py
        │   ├── script_retrieval.sh
        │   └── train_retrieval.py
        ├── retrieval_modifier
        │   ├── README.md
        │   ├── __init__.py
        │   ├── demo_retrieval_modifier.py
        │   ├── evaluate_retrieval_modifier.py
        │   ├── model_retrieval_modifier.py
        │   ├── script_retrieval_modifier.sh
        │   └── train_retrieval_modifier.py
        ├── shortname_2_model_path.txt
        ├── trainer.py
        ├── utils.py
        ├── utils_logging.py
        ├── utils_visu.py
        └── vocab.py
/images/caption_generation_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/caption_generation_model.png
--------------------------------------------------------------------------------
/images/captioning_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/captioning_pipeline.png
--------------------------------------------------------------------------------
/images/comparative_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/comparative_pipeline.png
--------------------------------------------------------------------------------
/images/feedback_generation_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/feedback_generation_model.png
--------------------------------------------------------------------------------
/images/generative_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/generative_model.png
--------------------------------------------------------------------------------
/images/main_picture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/main_picture.png
--------------------------------------------------------------------------------
/images/pose_editing_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/pose_editing_model.png
--------------------------------------------------------------------------------
/images/retrieval_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/retrieval_model.png
--------------------------------------------------------------------------------
/images/retrieval_modifier_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/images/retrieval_modifier_model.png
--------------------------------------------------------------------------------
/pretrained_models.md:
--------------------------------------------------------------------------------
1 | # Available trained models
2 |
3 | ## Download links
4 |
5 | | Subdirectory | Model (link to README) | Model shortname | Link | Main results|
6 | |---|---|---|---|---|
7 | | `retrieval` | [text-to-pose retrieval model](./src/text2pose/retrieval/README.md) | `ret_distilbert_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/ret_distilbert_dataPSA2ftPSH2.zip) | **mRecall** = 47.92 <br> **R@1 Precision (GT)** = 84.76 |
8 | | `retrieval_modifier` | [pose-pair-to-instruction retrieval model](./src/text2pose/retrieval_modifier/README.md) | `modret_distilbert_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/modret_distilbert_dataPFAftPFH.zip) | **mRecall** = 30.00 <br> **R@1 Precision (GT)** = 68.04 |
9 | | `generative` | [text-conditioned pose generation model](./src/text2pose/generative/README.md) | `gen_distilbert_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/gen_distilbert_dataPSA2ftPSH2.zip) | **ELBO jts/vert/rot** = 1.44 / 1.82 / 0.90 |
10 | | `generative_B` | [text-guided pose editing model](./src/text2pose/generative_B/README.md) | `b_gen_distilbert_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/b_gen_distilbert_dataPFAftPFH.zip) | **ELBO jts/vert/rot** = 1.43 / 1.90 / 1.00 |
11 | | `generative_caption` | [pose description generation model](./src/text2pose/generative_caption/README.md) | `capgen_CAtransfPSA2H2_dataPSA2ftPSH2` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/capgen_CAtransfPSA2H2_dataPSA2ftPSH2.zip) | **R@1 Precision** = 89.38 <br> **MPJE_30** = 202 <br> **ROUGE-L** = 33.95 |
12 | | `generative_modifier` | [pose-based correctional text generation model](./src/text2pose/generative_modifier/README.md) | `modgen_CAtransfPFAHPP_dataPFAftPFH` | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/modgen_CAtransfPFAHPP_dataPFAftPFH.zip) | **R@1 Precision** = 78.85 <br> **MPJE_30** = 186 <br> **ROUGE-L** = 33.53 |
13 |
14 | Unzip the archives and place the content of the resulting directories in ***GENERAL_EXP_OUTPUT_DIR***.
15 |
16 | **Note:** these models are the result of a two-stage training, involving a pretraining stage on automatic texts, and a finetuning stage on human-written annotations.
17 |
18 |
19 | Bash script to download and unzip everything at once:
20 |
21 | ```bash
22 | cd "<GENERAL_EXP_OUTPUT_DIR>" # TODO replace!
23 |
24 | arr=(
25 | ret_distilbert_dataPSA2ftPSH2
26 | modret_distilbert_dataPFAftPFH
27 | gen_distilbert_dataPSA2ftPSH2
28 | b_gen_distilbert_dataPFAftPFH
29 | capgen_CAtransfPSA2H2_dataPSA2ftPSH2
30 | modgen_CAtransfPFAHPP_dataPFAftPFH
31 | )
32 |
33 | for a in "${arr[@]}"; do
34 | echo "Download and extract $a"
35 | wget "https://download.europe.naverlabs.com/ComputerVision/PoseFix/${a}.zip"
36 | unzip "${a}.zip"
37 | rm "${a}.zip"
38 | done
39 | ```
40 |
41 |
42 |
43 |
44 | Differences in results with respect to the papers:
45 |
46 | * *Text-to-pose retrieval*: providing an improved model, pretrained on new automatic captions and trained with a symmetric contrastive loss (vs. the uni-directional contrastive loss in the paper).
47 | * *Instruction-to-pair retrieval*: providing an improved model, trained with a symmetric contrastive loss (vs. the uni-directional contrastive loss in the paper).
48 | * *Pose editing*: the provided model uses a transformer-based text encoder (frozen DistilBert + learned transformer), for consistency with the other provided models (vs. the GloVe+biGRU configuration used to report results in the paper). Note: this model was finetuned using the best setting as per Table 4: with L/R flip and paraphrases. The FID value may also change, as evaluation is carried out with an improved version of the text-to-pose retrieval model.
49 | * *Text generation models*: evaluated with the improved retrieval models; also note that, despite being averaged over 10 repetitions, R-precision metrics show high variability due to the randomized selection of the pool of samples to compare against.
50 |
51 |
52 | ## References in `shortname_2_model_path.txt`
53 |
54 | References should be given using the following format:
55 |
56 | ```
57 | <model shortname><4 spaces><path to model>
58 | ```
59 |
60 | Thus, for the above-mentioned models (simply replace `<GENERAL_EXP_OUTPUT_DIR>` by its proper value):
61 | ```text
62 | ret_distilbert_dataPSA2ftPSH2    <GENERAL_EXP_OUTPUT_DIR>/ret_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth
63 | modret_distilbert_dataPFAftPFH    <GENERAL_EXP_OUTPUT_DIR>/modret_distilbert_dataPFAftPFH/seed1/checkpoint_best.pth
64 | gen_distilbert_dataPSA2ftPSH2    <GENERAL_EXP_OUTPUT_DIR>/gen_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth
65 | b_gen_distilbert_dataPFAftPFH    <GENERAL_EXP_OUTPUT_DIR>/b_gen_distilbert_dataPFAftPFH/seed1/checkpoint_best.pth
66 | capgen_CAtransfPSA2H2_dataPSA2ftPSH2    <GENERAL_EXP_OUTPUT_DIR>/capgen_CAtransfPSA2H2_dataPSA2ftPSH2/seed1/checkpoint_best.pth
67 | modgen_CAtransfPFAHPP_dataPFAftPFH    <GENERAL_EXP_OUTPUT_DIR>/modgen_CAtransfPFAHPP_dataPFAftPFH/seed1/checkpoint_best.pth
68 | ```
69 |
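As a side note, here is a minimal Python sketch (not part of the repository; the helper name `load_shortname_2_model_path` is made up for illustration) of how such a reference file can be parsed into a shortname-to-path dictionary, assuming the `<model shortname><4 spaces><path to model>` format documented above; `src/text2pose/config.py` performs a similar parsing:

```python
# Minimal sketch: parse shortname_2_model_path.txt into {shortname: model path}.
# Assumes each non-empty line follows: <model shortname><4 spaces><path to model>.

def load_shortname_2_model_path(filepath="shortname_2_model_path.txt"):
    mapping = {}
    with open(filepath, "r") as f:
        for line in f:
            if not line.strip():
                continue  # skip empty lines
            shortname, model_path = line.split("    ", 1)  # split on the 4-space separator
            mapping[shortname.strip()] = model_path.strip()
    return mapping

if __name__ == "__main__":
    for shortname, path in load_shortname_2_model_path().items():
        print(f"{shortname} --> {path}")
```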
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 |
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 |
4 | git+https://github.com/nghorbani/body_visualizer
5 | git+https://github.com/MPI-IS/configer.git
6 | git+https://github.com/MPI-IS/mesh.git
7 | git+https://github.com/nghorbani/human_body_prior.git
8 |
9 | torch==1.10.1
10 | torchtext==0.11.1
11 | torchvision==0.11.2
12 | nltk
13 | smplx
14 | matplotlib
15 | opencv-python
16 | transformers
17 | trimesh[easy]
18 | pyrender
19 | roma
20 | streamlit==1.23.1
21 | tabulate
22 | tensorboard==2.11.2
23 | setuptools==59.5.0
24 | bert_score
25 | tqdm
26 | evaluate
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | from setuptools import setup, find_packages
11 |
12 | setup(name='text2pose',
13 | version='2.0',
14 | packages=find_packages('src'),
15 | package_dir={'': 'src'},
16 | include_package_data=True,
17 | author='Ginger Delmas',
18 | author_email='ginger.delmas.pro@gmail.com',
19 | description='PoseScript ECCV22 & PoseFix ICCV23.',
20 | long_description=open("README.md").read(),
21 | long_description_content_type="text/markdown",
22 | install_requires=[],
23 | dependency_links=[],
24 | )
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/other_utils/README.md:
--------------------------------------------------------------------------------
1 | # Other potentially useful code
2 |
3 | This directory presents some code that has been useful at different stages of the project (e.g. for dataset construction), but does not contribute directly to the training or testing of the models proposed in the papers. In particular, this code is not necessarily clean or well documented, and may trigger errors or require some additional setup before running successfully (watch for "TODO" marks) -- apologies! Hopefully, this code can still prove useful for a concrete understanding of some aspects of the project.
4 |
5 | Included:
6 | * pose mining (pose selection in PoseScript; also corresponding to poses "B" in PoseFix)
7 | * pair mining (process to select pairs in PoseFix)
8 |
9 | *Note:* for the reasons mentioned previously, this code is currently kept in this separate directory. It still requires some functions defined in the rest of the repository, so one may need to run `python setup.py develop` first, to make text2pose importable as a package.
--------------------------------------------------------------------------------
/src/other_utils/pose_mining.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2024 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | import torch
12 | from tqdm import tqdm
13 |
14 |
15 | # SETUP
16 | ################################################################################
17 |
18 | DATA_LOCATION = "TODO" # where to save files
19 |
20 |
21 | # UTILS
22 | ################################################################################
23 |
24 | def MPJE(coords_set, coords_one):
25 | """
26 | Args:
27 | coords_set: tensor (B, nb joints, 3)
28 | coords_one: tensor (nb joints, 3)
29 |
30 | Returns:
31 | tensor (B) giving the mean per joint distance between each pose in the
32 | set and the provided pose.
33 | """
34 | return torch.norm(coords_set - coords_one, dim=2).mean(1) # (B)
35 |
36 |
37 | def farther_sampling_resume(data, selected, distance, distance_func, nb_to_select=10):
38 | """
39 | Args:
40 | data (torch.tensor): size (number of poses, number of joints, 3)
41 | or (number of poses, number of features)
42 | selected (list): indices of the poses that were already selected
43 | distance (torch.tensor): size (number of poses), distance between each
44 | pose and the closest pose of the selected set
45 | distance_func (function): computes the distance between 2 poses
46 |         nb_to_select (int): number of data points to select, in addition to the
47 | ones already selected
48 |
49 | Returns:
50 | selected (list): indices of the selected poses
51 | distance (torch.tensor): size (number of poses), distance between each
52 | pose and the closest pose of the selected set
53 | """
54 |
55 | nb_to_select = min(data.size(0)-len(selected), nb_to_select)
56 |
57 | for _ in tqdm(range(0, nb_to_select)):
58 | distance_update = distance_func(data, data[selected[-1]])
59 | distance = torch.amin(torch.cat((distance.view(-1,1), distance_update.view(-1,1)), 1), dim=1)
60 | selected.append(torch.argmax(distance).item())
61 |
62 | return selected, distance
63 |
64 |
65 | def get_diversity_pose_order(coords, suffix, split, nb_select=10, seed=0, resume=False):
66 | """
67 | coords: dict {data_id: 3D coords of the main joints (torch.tensor shape (n_joints, 3))}
68 | suffix: for file naming
69 | """
70 |
71 | # prepare data for the farther sampling
72 | # (resume, if applicable)
73 | data_ids = sorted(coords.keys())
74 | coords = torch.stack([coords[did] for did in data_ids]) # (nb poses, nb joints, 3)
75 | nb_select = min(len(coords), nb_select)
76 |
77 | if resume:
78 | filepath_resume_from = os.path.join(DATA_LOCATION, f"farther_sample_{resume}_{suffix}.pt")
79 | data_ids_, selected, distance = torch.load(filepath_resume_from)
80 | assert data_ids == data_ids_, "Cannot resume. Data changed!"
81 | print("Resuming from:", filepath_resume_from)
82 | else:
83 | selected = [seed] # to make the results somewhat reproducible
84 | distance = torch.ones(len(coords)) * float('inf')
85 | print(f"Farther sampling from seed {seed}.")
86 |
87 | # farther sample
88 | print(f"Sampling from {len(coords)} elements. Number of elements already selected: {len(selected)}.")
89 | selected, distance = farther_sampling_resume(coords, selected, distance, MPJE, nb_select)
90 |
91 | # save
92 | filesave = os.path.join(DATA_LOCATION, f"farther_sample_{split}_{nb_select}_{suffix}.pt")
93 | torch.save([data_ids, selected, distance], filesave)
94 | print("Saved:", filesave)
95 |
96 |
97 | # MAIN
98 | ################################################################################
99 |
100 | if __name__ == "__main__":
101 |
102 | import argparse
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument('--split', type=str, choices=('training', 'validation'), default='validation')
105 | parser.add_argument('--nb_select', type=int, default=50000)
106 | parser.add_argument('--resume', type=int, default=0)
107 | args = parser.parse_args()
108 |
109 | suffix = "try"
110 |     coords = None # TODO: load a dict {data_id: 3D coords of the main joints, torch.tensor of shape (n_joints, 3)} for the chosen split
111 |     get_diversity_pose_order(coords, suffix=suffix, split=args.split, nb_select=args.nb_select, resume=args.resume)
--------------------------------------------------------------------------------
/src/text2pose/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/config.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | # This file serves to store global config parameters and paths
11 | # (those may change depending on the user, provided data, trained models...)
12 |
13 | # default
14 | import os
15 | MAIN_DIR = os.path.realpath(__file__)
16 | MAIN_DIR = os.path.dirname(os.path.dirname(os.path.dirname(MAIN_DIR)))
17 |
18 |
19 | ################################################################################
20 | # Output dir for experiments
21 | ################################################################################
22 |
23 | GENERAL_EXP_OUTPUT_DIR = MAIN_DIR + '/experiments'
24 |
25 |
26 | ################################################################################
27 | # Data
28 | ################################################################################
29 |
30 | POSEFIX_LOCATION = MAIN_DIR + '/data/PoseFix/posefix_release'
31 | POSESCRIPT_LOCATION = MAIN_DIR + '/data/PoseScript/posescript_release'
32 | POSEMIX_LOCATION = MAIN_DIR + '/data/posemix'
33 |
34 | version_suffix = "_100k" # to be used for pipeline-related data (coords, rotation change, babel labels)
35 | file_pose_id_2_dataset_sequence_and_frame_index = f"{POSESCRIPT_LOCATION}/ids_2_dataset_sequence_and_frame_index_100k.json"
36 | file_pair_id_2_pose_ids = f"{POSEFIX_LOCATION}/pair_id_2_pose_ids.json"
37 | file_posescript_split = f"{POSESCRIPT_LOCATION}/%s_ids_100k.json" # %s --> (train|val|test)
38 | file_posefix_split = f"{POSEFIX_LOCATION}/%s_%s_sequence_pair_ids.json" # %s %s --> (train|val|test), (in|out)
39 |
40 |
41 | ### pose config ----------------------------------------------------------------
42 |
43 | POSE_FORMAT = 'smplh'
44 | SMPLH_BODY_MODEL_PATH = MAIN_DIR + '/data/smplh_amass_body_models'
45 | NEUTRAL_BM = f'{SMPLH_BODY_MODEL_PATH}/neutral/model.npz'
46 | NB_INPUT_JOINTS = 52 # default value used when initializing modules, unless specified otherwise
47 | n_betas = 16
48 |
49 | SMPLX_BODY_MODEL_PATH = MAIN_DIR + '/data/smpl_models' # should contain "smplx/SMPLX_NEUTRAL.(npz|pkl)"
50 |
51 | PID_NAN = -99999 # pose fake IDs, used for empty poses
52 |
53 | ### pose data ------------------------------------------------------------------
54 |
55 | AMASS_FILE_LOCATION = MAIN_DIR + f"/data/AMASS/{POSE_FORMAT}/"
56 | supported_datasets = {"AMASS":AMASS_FILE_LOCATION}
57 |
58 | BABEL_LOCATION = MAIN_DIR + "/data/BABEL/babel_v1.0_release"
59 |
60 | generated_pose_path = '%s/generated_poses/posescript_version_{data_version}_split_{split}_gensamples.pth' # %s is for the model directory (obtainable with shortname_2_model_path)
61 |
62 |
63 | ### text data ------------------------------------------------------------------
64 |
65 | MAX_TOKENS = 500 # defined here because it depends on the provided data (only affects the glovebigru configuration)
66 |
67 | vocab_files = {
68 | # IMPORTANT: do not use "_" symbols in the keys of this dictionary
69 | # IMPORTANT: vocabs overlap, but the order of the tokens is not the same; a
70 | # model trained with one vocab can't be finetuned with another
71 | # without risk
72 | "vocPSA2H2": "vocab_posescript_6293_auto100k.pkl",
73 | "vocPFAHPP": "vocab_posefix_6157_pp4284_auto.pkl",
74 | "vocMixPSA2H2PFAHPP": "vocab_posemix_PS6193_PF6157pp4284.pkl",
75 | }
76 |
77 | caption_files = {
78 | # <data version>: (<number of texts per item>, <list of caption files>)
79 | "posescript-A2": (3, [f"{POSESCRIPT_LOCATION}/posescript_auto_100k.json"]),
80 | "posescript-H2": (1, [f"{POSESCRIPT_LOCATION}/posescript_human_6293.json"]),
81 | "posefix-A": (3, [f"{POSEFIX_LOCATION}/posefix_auto_135305.json"]),
82 | "posefix-PP": (1, [f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]), # average is 2 texts/item (min 1, max 6)
83 | "posefix-H": (1, [f"{POSEFIX_LOCATION}/posefix_human_6157.json"]),
84 | "posefix-HPP": (1, [f"{POSEFIX_LOCATION}/posefix_human_6157.json", f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]), # 1 text/item at min, because paraphrases are essentially for the train set
85 | "posemix-PSH2-PFHPP": (1, [f"{POSESCRIPT_LOCATION}/posescript_human_6293.json", f"{POSEFIX_LOCATION}/posefix_human_6157.json", f"{POSEFIX_LOCATION}/posefix_paraphrases_4284.json"]),
86 | }
87 |
88 | # data cache
89 | dirpath_cache_dataset = MAIN_DIR + "/dataset_cache"
90 | cache_file_path = {
91 | "posescript":'%s/PoseScript_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset,
92 | "posefix":'%s/PoseFix_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset,
93 | "posemix":'%s/PoseMix_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset,
94 | "posestream":'%s/PoseStream_version_{data_version}_split_{split}_tokenizer_{tokenizer}.pkl' % dirpath_cache_dataset,
95 | }
96 |
97 |
98 | ################################################################################
99 | # Model cache
100 | ################################################################################
101 |
102 | GLOVE_DIR = MAIN_DIR + '/tools/torch_models/glove' # or None
103 | TRANSFORMER_CACHE_DIR = MAIN_DIR + '/tools/huggingface_models'
104 | SELFCONTACT_ESSENTIALS_DIR = MAIN_DIR + '/tools/selfcontact/essentials'
105 |
106 |
107 | ################################################################################
108 | # Shortnames to checkpoint paths
109 | ################################################################################
110 | # Shortnames are used to refer to:
111 | # - pretrained models
112 | # - models that generated pose files
113 | # - models used for evaluation (fid, recall, reconstruction, r-precision...)
114 |
115 | # shortnames for models are expected to be the same across seed values;
116 | # model paths should contain a specific seed_value field instead of the actual seed value
117 | normalize_model_path = lambda model_path, seed_value: "/".join(model_path.split("/")[:-2]) + f"/seed{seed_value}/"+ model_path.split("/")[-1]
118 |
119 | # shortnames & model paths are stored in shortname_2_model_path.txt (which can be updated by some scripts)
120 | try:
121 | with open("shortname_2_model_path.txt", "r") as f:
122 | # each line has the following format: <model shortname><4 spaces><path to model>
123 | shortname_2_model_path = [l.split("    ") for l in f.readlines() if len(l.strip())]
124 | shortname_2_model_path = {l[0]:normalize_model_path(l[1].strip(), '{seed}') for l in shortname_2_model_path}
125 | except FileNotFoundError:
126 | # print("File not found: shortname_2_model_path.txt - Please ensure you are launching operations from the right directory.")
127 | pass # this file may not even be needed; subsequent errors can be expected otherwise
128 |
129 |
130 | ################################################################################
131 | # Evaluation
132 | ################################################################################
133 |
134 | # NOTE: models used to compute the fid should be specified in `shortname_2_model_path`
135 |
136 | k_recall_values = [1, 5, 10]
137 | nb_sample_reconstruction = 30
138 | k_topk_reconstruction_values = [1, 6] # keep the top-1 and the top-N/4 where N is the nb_sample_reconstruction
139 | k_topk_r_precision = [1,2,3]
140 | r_precision_n_repetitions = 10
141 | sample_size_r_precision = 32
142 |
143 |
144 | ################################################################################
145 | # Visualization settings
146 | ################################################################################
147 |
148 | meshviewer_size = 1600
149 |
150 |
151 | if __name__=="__main__":
152 | import sys
153 | try:
154 | # if the provided model shortname is registered, return the complete model path (with the provided seed value)
155 | if sys.argv[1] in shortname_2_model_path:
156 | print(shortname_2_model_path[sys.argv[1]].format(seed=sys.argv[2]))
157 | except IndexError:
158 | # clean shortname_2_model_path.txt
159 | update = []
160 | for k,p in shortname_2_model_path.items():
161 | update.append(f"{k} {p.format(seed=0)}\n")
162 | with open("shortname_2_model_path.txt", "w") as f:
163 | f.writelines(update)
164 | print("Cleaned shortname_2_model_path.txt (unique entries with seed 0).")
--------------------------------------------------------------------------------
/src/text2pose/data_augmentations.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023, 2024 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch
11 | from copy import deepcopy
12 |
13 | import text2pose.config as config
14 | from text2pose.encoders.tokenizers import Tokenizer
15 | from text2pose.posescript.utils import ALL_JOINT_NAMES
16 |
17 |
18 | class PoseFlip():
19 |
20 | def __init__(self, nb_joints=22):
21 | super(PoseFlip, self).__init__()
22 |
23 | # get joint names (depends on the case)
24 | if nb_joints == 21:
25 | # all main joints, without the root
26 | joint_names = ALL_JOINT_NAMES[1:22]
27 | elif nb_joints == 22:
28 | # all main joints, with the root
29 | joint_names = ALL_JOINT_NAMES[:22]
30 | elif nb_joints == 52:
31 | joint_names = ALL_JOINT_NAMES[:]
32 | else:
33 | raise NotImplementedError
34 |
35 | # build joint correspondence indices
36 | n2i = {n:i for i, n in enumerate(joint_names)}
37 | l2r_j_id = {i:n2i[n.replace("left", "right")] for n,i in n2i.items() if "left" in n} # joint index correspondence between left and right
38 | self.left_joint_inds = torch.tensor(list(l2r_j_id.keys()))
39 | self.right_joint_inds = torch.tensor(list(l2r_j_id.values()))
40 |
41 | def flip_pose_data_LR(self, pose_data):
42 | """
43 | pose_data: shape (batch_size, nb_joint, 3)
44 | """
45 | l_data = deepcopy(pose_data[:,self.left_joint_inds])
46 | r_data = deepcopy(pose_data[:,self.right_joint_inds])
47 | pose_data[:,self.left_joint_inds] = r_data
48 | pose_data[:,self.right_joint_inds] = l_data
49 | pose_data[:,:, 1:3] *= -1
50 | return pose_data
51 |
52 | def __call__(self, pose_data):
53 | return self.flip_pose_data_LR(pose_data.clone())
54 |
55 |
56 | def DataAugmentation(args, mode, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS):
57 | # --- define process
58 | if mode == "posefix":
59 | return PosefixDataAugmentation(args, tokenizer_name, nb_joints)
60 | elif mode == "posescript":
61 | return PosescriptDataAugmentation(args, tokenizer_name, nb_joints)
62 | else:
63 | raise NotImplementedError
64 |
65 |
66 | class GenericDataAugmentation():
67 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS):
68 | super(GenericDataAugmentation, self).__init__()
69 |
70 | self.args = args
71 |
72 | # --- initialize data augmentation tools
73 | if tokenizer_name and (self.args.apply_LR_augmentation or self.args.copy_augmentation):
74 | self.tokenizer = Tokenizer(tokenizer_name)
75 |
76 | if tokenizer_name and self.args.copy_augmentation:
77 | empty_text_tokens = self.tokenizer("")
78 | self.empty_text_length = len(empty_text_tokens) # account for BOS & EOS tokens
79 | self.empty_text_tokens = torch.cat( (empty_text_tokens, self.tokenizer.pad_token_id * torch.ones( self.tokenizer.max_tokens-self.empty_text_length, dtype=empty_text_tokens.dtype) ), dim=0)
80 |
81 | if self.args.apply_LR_augmentation:
82 | self.pose_flip = PoseFlip(nb_joints)
83 |
84 |
85 | class PosescriptDataAugmentation(GenericDataAugmentation):
86 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS):
87 | super(PosescriptDataAugmentation, self).__init__(args, tokenizer_name=tokenizer_name, nb_joints=nb_joints)
88 |
89 | def __call__(self, poses, caption_tokens=None, caption_lengths=None):
90 |
91 | batch_size = poses.size(0) # beware of incomplete batches!
92 |
93 | # random L/R flip
94 | if self.args.apply_LR_augmentation:
95 | flippable = torch.rand(batch_size) < 0.5 # completely random flip
96 | if hasattr(self, "tokenizer"):
97 | caption_tokens, caption_lengths, actually_flipped = self.tokenizer.flip(caption_tokens, flippable)
98 | else:
99 | actually_flipped = flippable
100 | poses[actually_flipped] = self.pose_flip(poses[actually_flipped])
101 |
102 | return poses, caption_tokens, caption_lengths
103 |
104 |
105 | class PosefixDataAugmentation(GenericDataAugmentation):
106 | def __init__(self, args, tokenizer_name=None, nb_joints=config.NB_INPUT_JOINTS):
107 | super(PosefixDataAugmentation, self).__init__(args, tokenizer_name=tokenizer_name, nb_joints=nb_joints)
108 |
109 | def __call__(self, poses_A, caption_tokens=None, caption_lengths=None, poses_B=None, posescript_poses=None):
110 |
111 | batch_size = poses_A.size(0) # beware of incomplete batches!
112 |
113 | # random L/R flip
114 | if self.args.apply_LR_augmentation:
115 | flippable = torch.rand(batch_size) < 0.5 # completely random flip
116 | if hasattr(self, "tokenizer"):
117 | caption_tokens, caption_lengths, actually_flipped = self.tokenizer.flip(caption_tokens, flippable)
118 | else:
119 | actually_flipped = flippable
120 | poses_A[actually_flipped] = self.pose_flip(poses_A[actually_flipped])
121 | poses_B[actually_flipped] = self.pose_flip(poses_B[actually_flipped])
122 |
123 | # remove text cue: learn to copy pose A
124 | if self.args.copy_augmentation > 0:
125 | change = torch.rand(batch_size) < self.args.copy_augmentation # change at most a proportion of args.copy_augmentation poses
126 | if self.args.copyB2A:
127 | copy_B2A = torch.ones(batch_size).bool()
128 | else:
129 | # in the case of PoseScript data, A is "0"; therefore, for such
130 | # elements, we must copy B to A
131 | copy_B2A = posescript_poses # (batch_size)
132 | copy_A2B = ~copy_B2A
133 | poses_A[copy_B2A*change] = deepcopy(poses_B[copy_B2A*change])
134 | poses_B[copy_A2B*change] = deepcopy(poses_A[copy_A2B*change])
135 | # empty text
136 | if hasattr(self, "tokenizer"):
137 | caption_tokens[change] = self.empty_text_tokens[:caption_tokens.shape[1]] # by default, `empty_text_tokens` is very long
138 | caption_lengths[change] = self.empty_text_length
139 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # update truncation, the longest text may have changed
140 |
141 | return poses_A, caption_tokens, caption_lengths, poses_B
--------------------------------------------------------------------------------
/src/text2pose/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | ################################################################################
7 | ## Miscellaneous modules
8 | ################################################################################
9 |
10 |
11 | class L2Norm(nn.Module):
12 | def forward(self, x):
13 | return x / x.norm(dim=-1, keepdim=True)
14 |
15 |
16 | class ConCatModule(nn.Module):
17 |
18 | def __init__(self):
19 | super(ConCatModule, self).__init__()
20 |
21 | def forward(self, x):
22 | x = torch.cat(x, dim=1)
23 | return x
24 |
25 |
26 | class DiffModule(nn.Module):
27 |
28 | def __init__(self):
29 | super(DiffModule, self).__init__()
30 |
31 | def forward(self, x):
32 | return x[1] - x[0]
33 |
34 |
35 | class AddModule(nn.Module):
36 |
37 | def __init__(self, axis=0):
38 | super(AddModule, self).__init__()
39 | self.axis = axis
40 |
41 | def forward(self, x):
42 | return x.sum(self.axis)
43 |
44 |
45 | class SeqModule(nn.Module):
46 |
47 | def __init__(self):
48 | super(SeqModule, self).__init__()
49 |
50 | def forward(self, x):
51 | # input: list of T tensors of size (BS, d)
52 | # output: tensor of size (BS, T, d)
53 | return torch.cat([xx.unsqueeze(1) for xx in x], dim = 1)
54 |
55 |
56 | class MiniMLP(nn.Module):
57 |
58 | def __init__(self, input_dim, hidden_dim, output_dim):
59 | super(MiniMLP, self).__init__()
60 | self.layers = nn.Sequential(
61 | nn.Linear(input_dim, hidden_dim),
62 | nn.ReLU(),
63 | nn.Linear(hidden_dim, output_dim)
64 | )
65 |
66 | def forward(self, x):
67 | return self.layers(x)
68 |
69 |
70 | class PositionalEncoding(nn.Module):
71 |
72 | def __init__(self, d_model, dropout=0.1, max_len=5000):
73 | super(PositionalEncoding, self).__init__()
74 |
75 | pe = torch.zeros(max_len, 1, d_model)
76 | position = torch.arange(max_len).unsqueeze(1)
77 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
78 | pe[:, 0, 0::2] = torch.sin(position * div_term)
79 | pe[:, 0, 1::2] = torch.cos(position * div_term)
80 | self.register_buffer('pe', pe)
81 |
82 | self.dropout = nn.Dropout(p=dropout)
83 |
84 | def forward(self, x):
85 | """
86 | Args:
87 | x: Tensor, shape [seq_len, batch_size, embedding_dim]
88 | """
89 | x = x + self.pe[:x.size(0)]
90 | return self.dropout(x)
91 |
92 |
93 | class TIRG(nn.Module):
94 | """
95 | The TIRG model.
96 | Implementation derived (except for BaseModel inheritance) from
97 | https://github.com/google/tirg (downloaded on July 23rd, 2020).
98 | The method is described in Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia
99 | Li, Li Fei-Fei, James Hays. "Composing Text and Image for Image Retrieval -
100 | An Empirical Odyssey" CVPR 2019. arXiv:1812.07119
101 | """
102 |
103 | def __init__(self, input_dim=[512, 512], output_dim=512, out_l2_normalize=False):
104 | super(TIRG, self).__init__()
105 |
106 | self.input_dim = sum(input_dim)
107 | self.output_dim = output_dim
108 |
109 | # --- modules
110 | self.a = nn.Parameter(torch.tensor([1.0, 1.0])) # changed the second coeff from 10.0 to 1.0
111 | self.gated_feature_composer = nn.Sequential(
112 | ConCatModule(), nn.BatchNorm1d(self.input_dim), nn.ReLU(),
113 | nn.Linear(self.input_dim, self.output_dim))
114 | self.res_info_composer = nn.Sequential(
115 | ConCatModule(), nn.BatchNorm1d(self.input_dim), nn.ReLU(),
116 | nn.Linear(self.input_dim, self.input_dim), nn.ReLU(),
117 | nn.Linear(self.input_dim, self.output_dim))
118 |
119 | if out_l2_normalize:
120 | self.output_layer = L2Norm() # added to the official TIRG code
121 | else:
122 | self.output_layer = nn.Sequential()
123 |
124 | def query_compositional_embedding(self, main_features, modifying_features):
125 | f1 = self.gated_feature_composer((main_features, modifying_features))
126 | f2 = self.res_info_composer((main_features, modifying_features))
127 | f = torch.sigmoid(f1) * main_features * self.a[0] + f2 * self.a[1]
128 | f = self.output_layer(f)
129 | return f
--------------------------------------------------------------------------------
/src/text2pose/encoders/pose_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023, 2024 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch.nn as nn
11 | import roma
12 | from human_body_prior.models.vposer_model import NormalDistDecoder, VPoser
13 |
14 | import text2pose.config as config
15 | from text2pose.encoders.modules import L2Norm
16 |
17 |
18 | ################################################################################
19 | ## Pose encoder / decoder
20 | ################################################################################
21 |
22 |
23 | class Object(object):
24 | pass
25 |
26 |
27 | class PoseEncoder(nn.Module):
28 |
29 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512, num_body_joints=config.NB_INPUT_JOINTS, role=None):
30 | super(PoseEncoder, self).__init__()
31 |
32 | self.num_body_joints = num_body_joints
33 | self.input_dim = self.num_body_joints * 3
34 |
35 | # use VPoser pose encoder architecture...
36 | vposer_params = Object()
37 | vposer_params.model_params = Object()
38 | vposer_params.model_params.num_neurons = num_neurons
39 | vposer_params.model_params.latentD = latentD
40 | vposer = VPoser(vposer_params)
41 | encoder_layers = list(vposer.encoder_net.children())
42 | # change first layers to have the right data input size
43 | encoder_layers[1] = nn.BatchNorm1d(self.input_dim)
44 | encoder_layers[2] = nn.Linear(self.input_dim, num_neurons)
45 | # remove the last layer; the final layer(s) depend on the task/role
46 | encoder_layers = encoder_layers[:-1]
47 |
48 | # output layers
49 | if role == "retrieval":
50 | encoder_layers += [
51 | nn.Linear(num_neurons, num_neurons_mini), # keep the bottleneck while adapting to the joint embedding size
52 | nn.ReLU(),
53 | nn.Linear(num_neurons_mini, latentD),
54 | L2Norm()]
55 | elif role == "generative":
56 | encoder_layers += [ NormalDistDecoder(num_neurons, latentD) ]
57 | elif role == "no_output_layer":
58 | encoder_layers += [ ]
59 | elif role == "modifier":
60 | encoder_layers += [
61 | nn.Linear(num_neurons, latentD)
62 | ]
63 | else:
64 | raise NotImplementedError
65 |
66 | self.encoder = nn.Sequential(*encoder_layers)
67 |
68 | def forward(self, pose):
69 | return self.encoder(pose)
70 |
71 |
72 | class PoseDecoder(nn.Module):
73 |
74 | def __init__(self, num_neurons=512, latentD=32, num_body_joints=config.NB_INPUT_JOINTS):
75 | super(PoseDecoder, self).__init__()
76 |
77 | self.num_body_joints = num_body_joints
78 |
79 | # use VPoser pose decoder architecture...
80 | vposer_params = Object()
81 | vposer_params.model_params = Object()
82 | vposer_params.model_params.num_neurons = num_neurons
83 | vposer_params.model_params.latentD = latentD
84 | vposer = VPoser(vposer_params)
85 | decoder_layers = list(vposer.decoder_net.children())
86 | # change one of the final layers to have the right data output size
87 | decoder_layers[-2] = nn.Linear(num_neurons, self.num_body_joints * 6)
88 |
89 | self.decoder = nn.Sequential(*decoder_layers)
90 |
91 | def forward(self, Zin):
92 | bs = Zin.shape[0]
93 | prec = self.decoder(Zin)
94 | return {
95 | 'pose_body': roma.rotmat_to_rotvec(prec.view(-1, 3, 3)).view(bs, -1, 3), # (batch_size, num_body_joints, 3)
96 | 'pose_body_matrot': prec.view(bs, -1, 9) # (batch_size, num_body_joints, 9)
97 | }
--------------------------------------------------------------------------------
/src/text2pose/fid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from scipy import linalg
5 |
6 | import text2pose.config as config
7 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder
8 |
9 | class FID(nn.Module):
10 |
11 | def __init__(self, version, device=torch.device('cpu'), name_in_batch='pose'):
12 | super().__init__()
13 | assert isinstance(version, tuple), "FID version should follow the format (retrieval_model_shortname, seed), where retrieval_model_shortname is actually provided with --fid as input to the script."
14 | self.version = version
15 | self.device = device
16 | self.name_in_batch = name_in_batch
17 | self._load_model()
18 |
19 | def sstr(self):
20 | return f"FID_{self.version[0]}_seed{self.version[1]}"
21 |
22 | def _load_model(self):
23 | ckpt_path = config.shortname_2_model_path[self.version[0]].format(seed=self.version[1])
24 | ckpt = torch.load(ckpt_path, 'cpu')
25 | print("FID: load", ckpt_path)
26 | self.model = PoseEncoder(latentD=ckpt['args'].latentD, num_body_joints=getattr(ckpt['args'], 'num_body_joints', 52), role="retrieval")
27 | self.model.load_state_dict({k[len('pose_encoder.'):]: v for k,v in ckpt['model'].items() if k.startswith('pose_encoder.encoder.')})
28 | self.model.eval()
29 | self.model.to(self.device)
30 |
31 | def extract_features(self, batchpose):
32 | batchpose = batchpose.to(self.device)
33 | batchpose = batchpose.view(batchpose.size(0),-1)[:,:self.model.input_dim]
34 | features = self.model(batchpose)
35 | return features
36 |
37 | def extract_real_features(self, valdataloader):
38 | real_features = []
39 | with torch.inference_mode():
40 | for batches in valdataloader:
41 | real_features.append( self.extract_features(batches[self.name_in_batch]) )
42 | self.real_features = torch.cat(real_features, dim=0).cpu().numpy()
43 | self.realmu = np.mean(self.real_features, axis=0)
44 | self.realsigma = np.cov(self.real_features, rowvar=False)
45 | print('FID: extracted real features', self.real_features.shape)
46 |
47 | def reset_gen_features(self):
48 | self.gen_features = []
49 |
50 | def add_gen_features(self, batchpose):
51 | with torch.inference_mode():
52 | self.gen_features.append( self.extract_features(batchpose) )
53 |
54 | def compute(self):
55 | gen_features = torch.cat(self.gen_features, dim=0).cpu().numpy()
56 | assert gen_features.shape[0] == self.real_features.shape[0]
57 | mu = np.mean(gen_features, axis=0)
58 | sigma = np.cov(gen_features, rowvar=False)
59 | return calculate_frechet_distance(mu, sigma, self.realmu, self.realsigma)
60 |
61 |
62 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
63 | """Numpy implementation of the Frechet Distance.
64 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
65 | and X_2 ~ N(mu_2, C_2) is
66 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
67 | Stable version by Dougal J. Sutherland.
68 | Params:
69 | -- mu1 : Numpy array containing the activations of a layer of the
70 | inception net (like returned by the function 'get_predictions')
71 | for generated samples.
72 | -- mu2 : The sample mean over activations, precalculated on an
73 | representative data set.
74 | -- sigma1: The covariance matrix over activations for generated samples.
75 | -- sigma2: The covariance matrix over activations, precalculated on an
76 | representative data set.
77 | Returns:
78 | -- : The Frechet Distance.
79 | """
80 |
81 | mu1 = np.atleast_1d(mu1)
82 | mu2 = np.atleast_1d(mu2)
83 |
84 | sigma1 = np.atleast_2d(sigma1)
85 | sigma2 = np.atleast_2d(sigma2)
86 |
87 | assert mu1.shape == mu2.shape, \
88 | 'Training and test mean vectors have different lengths'
89 | assert sigma1.shape == sigma2.shape, \
90 | 'Training and test covariances have different dimensions'
91 |
92 | diff = mu1 - mu2
93 |
94 | # Product might be almost singular
95 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
96 | if not np.isfinite(covmean).all():
97 | msg = ('fid calculation produces singular product; '
98 | 'adding %s to diagonal of cov estimates') % eps
99 | print(msg)
100 | offset = np.eye(sigma1.shape[0]) * eps
101 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
102 |
103 | # Numerical error might give slight imaginary component
104 | if np.iscomplexobj(covmean):
105 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
106 | m = np.max(np.abs(covmean.imag))
107 | raise ValueError('Imaginary component {}'.format(m))
108 | covmean = covmean.real
109 |
110 | tr_covmean = np.trace(covmean)
111 |
112 | return (diff.dot(diff) + np.trace(sigma1)
113 | + np.trace(sigma2) - 2 * tr_covmean)
--------------------------------------------------------------------------------
/src/text2pose/generative/README.md:
--------------------------------------------------------------------------------
1 | # Text-conditioned Generative Model for 3D Human Poses
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | _:warning: The evaluation of this model relies partly on several [text-to-pose retrieval models](../retrieval/README.md); see section **Extra setup** below._
6 |
7 | ## Model overview
8 |
9 | * **Input**: pose description;
10 | * **Output**: 3D human pose.
11 |
12 | 
13 |
14 | ## :crystal_ball: Demo
15 |
16 | To generate poses based on a pretrained model and your own input description, run the following:
17 |
18 | ```
19 | streamlit run generative/demo_generative.py -- --model_paths <model_path_1> [<model_path_2> ...]
20 | ```
21 |
22 | :bulb: Tips: _Specify several model paths to compare models together._
23 |
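For instance, assuming the pretrained `gen_distilbert_dataPSA2ftPSH2` model from [pretrained_models.md](../../../pretrained_models.md) was downloaded and unzipped into ***GENERAL_EXP_OUTPUT_DIR*** (illustrative path below; replace the placeholder with its actual value):

```
streamlit run generative/demo_generative.py -- --model_paths <GENERAL_EXP_OUTPUT_DIR>/gen_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth
```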
24 | ## Extra setup
25 |
26 | At the beginning of the bash script, assign to variable `fid` the shortname of the trained [text-to-pose retrieval model](../retrieval/README.md) to be used for computing the FID.
27 |
28 | Add a line in *shortname_2_model_path.txt* to indicate the path to the model corresponding to the provided shortname.
29 |
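For instance, if using the pretrained `ret_distilbert_dataPSA2ftPSH2` retrieval model listed in [pretrained_models.md](../../../pretrained_models.md), the added line could look as follows (illustrative; replace `<GENERAL_EXP_OUTPUT_DIR>` by its actual value):

```
ret_distilbert_dataPSA2ftPSH2    <GENERAL_EXP_OUTPUT_DIR>/ret_distilbert_dataPSA2ftPSH2/seed1/checkpoint_best.pth
```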
30 | ## :bullettrain_front: Train
31 |
32 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
33 |
34 | Then use the following command:
35 | ```
36 | bash generative/script_generative.sh 'train'
37 | ```
38 |
39 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and write this nickname in front of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model.
40 |
41 | ## :dart: Evaluate
42 |
43 | Use the following command:
44 | ```
45 | bash generative/script_generative.sh 'eval'
46 | ```
47 |
48 | Expected additional arguments include:
49 |
50 | | evaluation phase (`eval_type`) | expected additional arguments | needed to get... |
51 | |---|---|---|
52 | | `regular` | model path | **elbo**, **fid** |
53 | | `generate_poses` | model path, seed number | **mRecall R/G** (step 1/2), **mRecall G/R** (step 1/3) |
54 | | `RG` | seed number, shortname of the evaluated generative model, retrieval model shortname | **mRecall R/G** (step 2/2) |
55 | | `GRa` | seed number, shortname of the evaluated generative model | **mRecall G/R** (step 2/3) |
56 | | `GRb` | seed number, shortname of the model trained in the previous step | **mRecall G/R** (step 3/3) |
57 |
58 | In the above table, "model path" is the path to the generative model to be evaluated.
59 |
60 | _**Important note**: the fid, mRecall R/G, and mRecall G/R rely on trained retrieval models._
61 |
62 | ## Generate and visualize pose samples for the dataset
63 |
64 | For evaluation, *generative/script_generative.sh* makes the model generate pose samples for each caption of the dataset, using the following command:
65 |
66 | ```
67 | python generative/generate_poses.py --model_path <model_path>
68 | ```
69 |
70 | The generated pose samples can be visualized, along with the original pose and the related description, by running the following:
71 |
72 | ```
73 | streamlit run generative/look_at_generated_pose_samples.py -- --model_path <model_path> --dataset_version <dataset_version> --split <split>
74 | ```
75 |
--------------------------------------------------------------------------------
/src/text2pose/generative/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
--------------------------------------------------------------------------------
/src/text2pose/generative/demo_generative.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import argparse
12 | import torch
13 | import numpy as np
14 |
15 | import text2pose.demo as demo
16 | import text2pose.utils as utils
17 | import text2pose.utils_visu as utils_visu
18 | from text2pose.generative.evaluate_generative import load_model
19 |
20 |
21 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
22 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.')
23 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.")
24 | parser.add_argument('--n_generate', type=int, default=12, help="Number of poses to generate (number of samples); if considering only one model.")
25 | args = parser.parse_args()
26 |
27 |
28 | ### INPUT
29 | ################################################################################
30 |
31 | data_version = "posescript-H2"
32 |
33 |
34 | ### SETUP
35 | ################################################################################
36 |
37 | # --- layout
38 | st.markdown("""
39 |
44 | """, unsafe_allow_html=True)
45 |
46 | # correct the number of generated samples depending on the setting
47 | if len(args.model_paths) > 1:
48 | n_generate = 4
49 | else:
50 | n_generate = args.n_generate
51 |
52 | # --- data
53 | available_splits = ['train', 'val', 'test']
54 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model)
55 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version)
56 |
57 | # --- seed
58 | torch.manual_seed(42)
59 | np.random.seed(42)
60 |
61 |
62 | ### MAIN APP
63 | ################################################################################
64 |
65 | # define query input interface
66 | cols_query = st.columns(3)
67 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
68 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'), index=1)
69 | number = cols_query[2].number_input("Split index or ID:", 0)
70 | st.markdown("""---""")
71 |
72 | # get query data
73 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model)
74 |
75 | # show query data
76 | cols_input = st.columns(2)
77 | cols_input[0].image(pose_img, caption="Annotated pose")
78 | if default_description:
79 | cols_input[1].write("Annotated text:")
80 | cols_input[1].write(f"_{default_description}_")
81 | else:
82 | cols_input[1].write("_(Not annotated.)_")
83 |
84 | # get input description
85 | description = cols_input[1].text_area("Pose description:",
86 | value=default_description,
87 | placeholder="The person is...",
88 | height=None, max_chars=None)
89 |
90 | analysis = cols_input[1].checkbox('Analysis') # whether to show the reconstructed pose and the mean sample pose in addition to some samples
91 |
92 | # generate results
93 | if analysis:
94 |
95 | st.markdown("""---""")
96 | st.write("**Generated poses** (*The reconstructed pose is shown in green; the mean pose in red; and samples in grey.*):")
97 | n_generate = 2
98 | nb_cols = 2 + n_generate # reconstructed pose + mean sample pose + n_generate sample poses: all must fit in one row, for each studied model
99 |
100 | for i, model in enumerate(models):
101 | with torch.no_grad():
102 | rec_pose_data = model.forward_autoencoder(pose_data)['pose_body_pose'].view(1, -1)
103 | gen_pose_data_mean = model.sample_str_meanposes(description)['pose_body'].view(1, -1)
104 | gen_pose_data_samples = model.sample_str_nposes(description, n=n_generate)['pose_body'][0,...].view(n_generate, -1)
105 |
106 | # render poses
107 | imgs = utils_visu.image_from_pose_data(rec_pose_data, body_model, color='green', add_ground_plane=True, two_views=60)
108 | imgs += utils_visu.image_from_pose_data(gen_pose_data_mean, body_model, color='red', add_ground_plane=True, two_views=60)
109 | imgs += utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='grey', add_ground_plane=True, two_views=60)
110 |
111 | # display images
112 | cols = st.columns(nb_cols+1) # +1 to display model info
113 | cols[0].markdown(f'{args.model_paths[i]}', unsafe_allow_html=True)
114 | for i in range(nb_cols):
115 | cols[i%nb_cols+1].image(demo.process_img(imgs[i]))
116 | st.markdown("""---""")
117 |
118 | else:
119 |
120 | st.markdown("""---""")
121 | st.write("**Generated poses:**")
122 |
123 | for i, model in enumerate(models):
124 | with torch.no_grad():
125 | gen_pose_data_samples = model.sample_str_nposes(description, n=n_generate)['pose_body'][0,...].view(n_generate, -1)
126 |
127 | # render poses
128 | imgs = utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='grey', add_ground_plane=True, two_views=60)
129 |
130 | # display images
131 | if len(models) > 1:
132 | cols = st.columns(n_generate+1) # +1 to display model info
133 | cols[0].markdown(f'{args.model_paths[i]}', unsafe_allow_html=True)
134 | for i in range(n_generate):
135 | cols[i%n_generate+1].image(demo.process_img(imgs[i]))
136 | st.markdown("""---""")
137 | else:
138 | cols = st.columns(demo.nb_cols)
139 | for i in range(n_generate):
140 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i]))
141 | st.markdown("""---""")
142 | st.write(f"_Results obtained with model: {args.model_paths[0]}_")
--------------------------------------------------------------------------------
/src/text2pose/generative/evaluate_generative.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | from tqdm import tqdm
12 | import torch
13 | import numpy as np
14 | from human_body_prior.body_model.body_model import BodyModel
15 |
16 | import text2pose.config as config
17 | import text2pose.evaluate as evaluate
18 | from text2pose.data import PoseScript
19 | from text2pose.encoders.tokenizers import get_tokenizer_name
20 | from text2pose.generative.model_generative import CondTextPoser
21 | from text2pose.fid import FID
22 |
23 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
24 |
25 | OVERWRITE_RESULT = False
26 |
27 |
28 | ################################################################################
29 |
30 | def load_model(model_path, device):
31 |
32 | assert os.path.isfile(model_path), "File {} not found.".format(model_path)
33 |
34 | # load checkpoint & model info
35 | ckpt = torch.load(model_path, 'cpu')
36 | text_encoder_name = ckpt['args'].text_encoder_name
37 | transformer_topping = getattr(ckpt['args'], 'transformer_topping', None)
38 | latentD = ckpt['args'].latentD
39 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52)
40 |
41 | # load model
42 | model = CondTextPoser(text_encoder_name=text_encoder_name,
43 | transformer_topping=transformer_topping,
44 | latentD=latentD,
45 | num_body_joints=num_body_joints
46 | ).to(device)
47 | model.load_state_dict(ckpt['model'])
48 | model.eval()
49 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path)
50 |
51 | return model, get_tokenizer_name(text_encoder_name)
52 |
53 |
54 | def eval_model(model_path, dataset_version, fid_version, split='val'):
55 |
56 | device = torch.device('cuda:0')
57 |
58 | # set seed for reproducibility (sampling for pose generation)
59 | torch.manual_seed(42)
60 | np.random.seed(42)
61 |
62 | # define result file & get auxiliary info
63 | fid_version, precision = get_evaluation_auxiliary_info(fid_version)
64 | nb_caps = config.caption_files[dataset_version][0]
65 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps)
66 |
67 | # load model if results for at least one caption are missing
68 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps):
69 | model, tokenizer_name = load_model(model_path, device)
70 |
71 | # compute or load results for the given run & caption
72 | results = {}
73 | for cap_ind in range(nb_caps):
74 | filename_res = get_res_file(cap_ind)
75 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT:
76 | d = PoseScript(version=dataset_version, split=split, tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True)
77 | cap_results = compute_eval_metrics(model, d, fid_version, device)
78 | evaluate.save_results_to_file(cap_results, filename_res)
79 | else:
80 | cap_results = evaluate.load_results_from_file(filename_res)
81 | # aggregate results
82 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()}
83 |
84 | # average over captions
85 | results = {k:sum(v)/nb_caps for k,v in results.items()}
86 |
87 | return {k:[v] for k, v in results.items()}
88 |
89 |
90 | def get_evaluation_auxiliary_info(fid_version, seed=1):
91 | # NOTE: default seed=1 for consistent evaluation
92 | precision = ""
93 | if fid_version is not None:
94 | fid_version = (fid_version, seed)
95 | precision += f"_X{fid_version[0]}-{fid_version[1]}X"
96 | return fid_version, precision
97 |
98 |
99 | def compute_eval_metrics(model, dataset, fid_version, device):
100 |
101 | # initialize
102 | data_loader = torch.utils.data.DataLoader(
103 | dataset, sampler=None, shuffle=False,
104 | batch_size=32,
105 | num_workers=8,
106 | pin_memory=True,
107 | drop_last=False
108 | )
109 |
110 | body_model = BodyModel(model_type = config.POSE_FORMAT,
111 | bm_fname = config.NEUTRAL_BM,
112 | num_betas = config.n_betas).to(device)
113 |
114 | fid = FID(version=fid_version, device=device)
115 | fid.extract_real_features(data_loader)
116 | fid.reset_gen_features()
117 |
118 | pose_metrics = {f'{k}_{v}': 0.0 for k in ['v2v', 'jts', 'rot'] for v in ['elbo', 'dist_avg'] \
119 | + [f'dist_top{topk}' for topk in config.k_topk_reconstruction_values]}
120 |
121 | # compute metrics
122 | for batch in tqdm(data_loader):
123 |
124 | # data setup
125 | model_input = dict(
126 | poses = batch['pose'].to(device),
127 | caption_lengths = batch['caption_lengths'].to(device),
128 | captions = batch['caption_tokens'][:,:batch['caption_lengths'].max()].to(device),
129 | )
130 |
131 | with torch.inference_mode():
132 |
133 | pose_metrics, _ = evaluate.add_elbo_and_reconstruction(model_input, pose_metrics, model, body_model, output_distr_key="t_z", reference_pose_key="poses")
134 | fid.add_gen_features( model.sample_nposes(**model_input, n=1)['pose_body'] )
135 |
136 | # average over the dataset
137 | for k in pose_metrics: pose_metrics[k] /= len(dataset)
138 |
139 | # normalize the elbo (the same is done earlier for the reconstruction metrics)
140 | pose_metrics.update({'v2v_elbo':pose_metrics['v2v_elbo']/(body_model.J_regressor.shape[1] * 3),
141 | 'jts_elbo':pose_metrics['jts_elbo']/(body_model.J_regressor.shape[0] * 3),
142 | 'rot_elbo':pose_metrics['rot_elbo']/(model.pose_decoder.num_body_joints * 9)})
143 |
144 | # compute fid metric
145 | results = {'fid': fid.compute()}
146 | results.update(pose_metrics)
147 |
148 | return results
149 |
150 |
151 | def display_results(results):
152 | metric_order = ['fid'] + [f'{x}_elbo' for x in ['jts', 'v2v', 'rot']] \
153 | + [f'{k}_{v}' for k in ['jts', 'v2v', 'rot']
154 | for v in ['dist_avg'] + [f'dist_top{topk}' for topk in config.k_topk_reconstruction_values]]
155 | results = evaluate.scale_and_format_results(results)
156 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n")
157 |
158 |
159 | ################################################################################
160 |
161 | if __name__=="__main__":
162 |
163 | # added special arguments
164 | evaluate.eval_parser.add_argument('--fid', type=str, help='Version of the fid to use for evaluation.')
165 |
166 | args = evaluate.eval_parser.parse_args()
167 | args = evaluate.get_full_model_path(args)
168 |
169 | # compute results
170 | if args.average_over_runs:
171 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, dataset_version=args.dataset, fid_version=args.fid, split=args.split)
172 | else:
173 | ret = eval_model(args.model_path, dataset_version=args.dataset, fid_version=args.fid, split=args.split)
174 |
175 | # display results
176 | print(ret)
177 | display_results(ret)
--------------------------------------------------------------------------------
/src/text2pose/generative/generate_poses.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | import argparse
12 | from tqdm import tqdm
13 | import torch
14 | import numpy as np
15 |
16 | import text2pose.config as config
17 | from text2pose.data import PoseScript
18 | from text2pose.generative.evaluate_generative import load_model
19 |
20 |
21 | parser = argparse.ArgumentParser(description='Parameters to generate poses corresponding to each caption.')
22 | parser.add_argument('--model_path', type=str, help='Path to the model.')
23 | parser.add_argument('--n_generate', type=int, default=5, help="Number of poses to generate for a given caption.")
24 | args = parser.parse_args()
25 |
26 |
27 | ### INPUT
28 | ################################################################################
29 |
30 | device = torch.device('cuda:0')
31 | save_path = config.generated_pose_path % os.path.dirname(args.model_path)
32 | splits = ['train', 'test', 'val']
33 |
34 | torch.manual_seed(42)
35 | np.random.seed(42)
36 |
37 |
38 | ### GENERATE POSES
39 | ################################################################################
40 |
41 | # load model
42 | model, tokenizer_name = load_model(args.model_path, device)
43 | dataset_version = torch.load(args.model_path, 'cpu')['args'].dataset
44 |
45 | # create saving directory
46 | if not os.path.isdir(os.path.dirname(save_path)):
47 | os.mkdir(os.path.dirname(save_path))
48 |
49 | # generate poses
50 | for s in splits:
51 |
52 | # check that the poses were not already generated
53 | filepath = save_path.format(data_version=dataset_version, split=s)
54 | assert not os.path.isfile(filepath), "Poses already generated!"
55 |
56 | d = PoseScript(version=dataset_version, split=s, tokenizer_name=tokenizer_name, num_body_joints=model.pose_decoder.num_body_joints)
57 | ncaptions = config.caption_files[dataset_version][0]
58 | output = torch.empty( (len(d), ncaptions, args.n_generate, model.pose_decoder.num_body_joints, 3), dtype=torch.float32)
59 |
60 | for index in tqdm(range(len(d))):
61 | # look at each available caption in turn
62 | for cidx in range(ncaptions):
63 | item = d.__getitem__(index, cidx=cidx)
64 | caption_tokens = item['caption_tokens'].to(device).unsqueeze(0)
65 | caption_lengths = torch.tensor([item['caption_lengths']]).to(device)
66 | caption_tokens = caption_tokens[:,:caption_lengths.max()]
67 | with torch.no_grad():
68 | genposes = model.sample_nposes(caption_tokens, caption_lengths, n=args.n_generate)['pose_body'][0,...]
69 | output[index,cidx,...] = genposes
70 |
71 | # save
72 | torch.save(output, filepath)
73 | print(filepath)
--------------------------------------------------------------------------------
/src/text2pose/generative/look_at_generated_pose_samples.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import os
12 | import argparse
13 | from human_body_prior.body_model.body_model import BodyModel
14 |
15 | import text2pose.config as config
16 | import text2pose.utils_visu as utils_visu
17 | from text2pose.data import PoseScript
18 |
19 |
20 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
21 | parser.add_argument('--model_path', type=str, help='Path to the model that generated the pose samples to visualize.')
22 | parser.add_argument('--dataset_version', type=str, help='Dataset version (depends on the model)')
23 | parser.add_argument('--split', type=str, help='Split')
24 | args = parser.parse_args()
25 |
26 |
27 | ### SETUP
28 | ################################################################################
29 |
30 | @st.cache_resource
31 | def setup(args):
32 |
33 | # setup data
34 | generated_pose_path = config.generated_pose_path % os.path.dirname(args.model_path)
35 | generated_pose_path = generated_pose_path.format(data_version=args.dataset_version, split=args.split)
36 |
37 | dataset = PoseScript(version=args.dataset_version, split=args.split,
38 | cache=False, generated_pose_samples_path=generated_pose_path)
39 |
40 | # setup body model
41 | body_model = BodyModel(model_type = config.POSE_FORMAT,
42 | bm_fname = config.NEUTRAL_BM,
43 | num_betas = config.n_betas)
44 | body_model.eval()
45 | body_model.to('cpu')
46 |
47 | return generated_pose_path, dataset, body_model
48 |
49 |
50 | generated_pose_path, dataset, body_model = setup(args)
51 |
52 |
53 | ### VISUALIZE
54 | ################################################################################
55 |
56 | st.write(f"**Dataset:** {args.dataset_version}")
57 | st.write(f"**Split:** {args.split}")
58 | st.write(f"**Using pose samples from:** {generated_pose_path}")
59 |
60 | # get input pose index & caption index
61 | st.write("**Choose a data point:**")
62 | index = st.number_input("Index in split:", 0, len(dataset.pose_samples)-1)
63 | cidx = st.number_input("Caption index:", 0, len(dataset.pose_samples[0])-1)
64 |
65 | # display description
66 | st.write("**Description:** "+dataset.captions[dataset.dataIDs[index]][cidx])
67 |
68 | # render poses
69 | nb_samples = dataset.pose_samples.shape[2]
70 | img_original = utils_visu.image_from_pose_data(dataset.get_pose(index).view(1, -1), body_model, color='blue')
71 | imgs_sampled = utils_visu.image_from_pose_data(dataset.pose_samples[index, cidx].view(nb_samples, -1), body_model, color='green')
72 |
73 | # display original pose
74 | st.write("**Original pose:**")
75 | st.image(img_original[0])
76 |
77 | # display generated pose samples
78 | st.write("**Generated pose samples for this description:**")
79 | cols = st.columns(nb_samples)
80 | for i in range(nb_samples):
81 | cols[i].image(imgs_sampled[i])
--------------------------------------------------------------------------------
/src/text2pose/generative/model_generative.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | import text2pose.config as config
6 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name
7 | from text2pose.encoders.pose_encoder_decoder import PoseDecoder, PoseEncoder
8 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder
9 |
10 | class CondTextPoser(nn.Module):
11 |
12 | def __init__(self, num_neurons=512, latentD=32, num_body_joints=config.NB_INPUT_JOINTS, text_encoder_name='distilbertUncased', transformer_topping=None):
13 | super(CondTextPoser, self).__init__()
14 |
15 | self.latentD = latentD
16 |
17 | # Define pose auto-encoder
18 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=latentD, num_body_joints=num_body_joints, role="generative")
19 | self.pose_decoder = PoseDecoder(num_neurons=num_neurons, latentD=latentD, num_body_joints=num_body_joints)
20 |
21 | # Define text encoder
22 | self.text_encoder_name = text_encoder_name
23 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name)
24 | if module_ref in ["glovebigru"]:
25 | self.text_encoder = TextEncoder(self.text_encoder_name, num_neurons=num_neurons, latentD=latentD, role="generative")
26 | elif module_ref in ["glovetransf", "distilbertUncased"]:
27 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, num_neurons=num_neurons, latentD=latentD, topping=transformer_topping, role="generative")
28 | else:
29 | raise NotImplementedError
30 |
31 | # Define learned loss parameters
32 | self.decsigma_v2v = nn.Parameter( torch.zeros(1) ) # logsigma
33 | self.decsigma_jts = nn.Parameter( torch.zeros(1) ) # logsigma
34 | self.decsigma_rot = nn.Parameter( torch.zeros(1) ) # logsigma
35 |
36 |
37 | # FORWARD METHODS ----------------------------------------------------------
38 |
39 |
40 | def encode_text(self, captions, caption_lengths):
41 | return self.text_encoder(captions, caption_lengths)
42 |
43 | def encode_pose(self, pose_body):
44 | return self.pose_encoder(pose_body)
45 |
46 | def decode_pose(self, z):
47 | return self.pose_decoder(z)
48 |
49 | def forward_autoencoder(self, poses):
50 | q_z = self.encode_pose(poses)
51 | q_z_sample = q_z.rsample()
52 | ret = {f"{k}_pose":v for k,v in self.decode_pose(q_z_sample).items()}
53 | ret.update({'q_z': q_z})
54 | return ret
55 |
56 | def forward(self, poses, captions, caption_lengths):
57 | t_z = self.encode_text(captions, caption_lengths)
58 | q_z = self.encode_pose(poses)
59 | q_z_sample = q_z.rsample()
60 | t_z_sample = t_z.rsample()
61 | ret = {f"{k}_pose":v for k,v in self.decode_pose(q_z_sample).items()}
62 | ret.update({f"{k}_text":v for k,v in self.decode_pose(t_z_sample).items()})
63 | ret.update({'q_z': q_z, 't_z': t_z})
64 | return ret
65 |
66 |
67 | # SAMPLE METHODS -----------------------------------------------------------
68 |
69 |
70 | def sample_nposes(self, captions, caption_lengths, n=1, **kwargs):
71 | t_z = self.encode_text(captions, caption_lengths)
72 | z = t_z.sample( [n] ).permute(1,0,2).flatten(0,1)
73 | decode_results = self.decode_pose(z)
74 | return {k: v.view(int(v.shape[0]/n), n, *v.shape[1:]) for k,v in decode_results.items()}
75 |
76 | def sample_str_nposes(self, s, n=1):
77 | device = self.decsigma_v2v.device
78 | # no text provided, sample pose directly from latent space
79 | if len(s)==0:
80 | z = torch.tensor(np.random.normal(0., 1., size=(n, self.latentD)), dtype=torch.float32, device=device)
81 | decode_results = self.decode_pose(z)
82 | return {k: v.view(int(v.shape[0]/n), n, *v.shape[1:]) for k,v in decode_results.items()}
83 | # otherwise, encode the text to sample a pose conditioned on it
84 | if not hasattr(self, 'tokenizer'):
85 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name))
86 | tokens = self.tokenizer(s).to(device=device)
87 | return self.sample_nposes(tokens.view(1, -1), torch.tensor([ len(tokens) ], dtype=tokens.dtype), n=n)
88 |
89 | def sample_str_meanposes(self, s):
90 | device = self.decsigma_v2v.device
91 | assert len(s)>0, "Please provide a non-empty text."
92 | if not hasattr(self, 'tokenizer'):
93 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name))
94 | tokens = self.tokenizer(s).to(device=device)
95 | t_z = self.encode_text(tokens.view(1, -1), torch.tensor([ len(tokens) ], dtype=tokens.dtype))
96 | return self.decode_pose(t_z.mean.view(1, -1))
--------------------------------------------------------------------------------
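The sampling entry points of `CondTextPoser` above can be exercised directly. Below is a minimal sketch, assuming a trained checkpoint loadable with `load_model` from `evaluate_generative.py` (shown earlier); the checkpoint path and the description are placeholders.

```
import torch
from text2pose.generative.evaluate_generative import load_model

# load a trained text-conditioned pose generation model (placeholder path)
model, tokenizer_name = load_model("<path/to/checkpoint>", torch.device("cpu"))

description = "The person is kneeling on their left knee with both arms raised."
with torch.no_grad():
    samples = model.sample_str_nposes(description, n=4)['pose_body'][0]  # 4 sampled poses for this text
    mean_pose = model.sample_str_meanposes(description)['pose_body']     # pose decoded from the mean of t_z
```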
/src/text2pose/generative/script_generative.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2022, 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1 # (train|eval|demo)
17 | eval_data_version="H2"
18 | checkpoint_type="best" # (last|best)
19 |
20 | architecture_args=(
21 | --model CondTextPoser
22 | --latentD 32
23 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp"
24 | # --text_encoder_name 'glovebigru_vocPSA2H2'
25 | )
26 |
27 | loss_args=(
28 | --wloss_v2v 4 --wloss_rot 2 --wloss_jts 2
29 | --wloss_kld 0.1 --wloss_kldnpmul 0.01 --wloss_kldntmul 0.01
30 | )
31 |
32 | bonus_args=(
33 | )
34 |
35 | fid="ret_distilbert_dataPSA2ftPSH2"
36 |
37 | pretrained="gen_distilbert_dataPSA2" # used only if phase=='finetune'
38 |
39 |
40 | ##############################################################
41 | # EXECUTE
42 |
43 | # TRAIN
44 | if [[ "$action" == *"train"* ]]; then
45 |
46 | phase=$2 # (pretrain|finetune)
47 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
48 | seed=$3
49 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
50 |
51 | # PRETRAIN
52 | if [[ "$phase" == *"pretrain"* ]]; then
53 |
54 | python generative/train_generative.py --dataset "posescript-A2" \
55 | "${architecture_args[@]}" \
56 | "${loss_args[@]}" \
57 | "${bonus_args[@]}" \
58 | --lr 1e-05 --wd 0.0001 --batch_size 128 --seed $seed \
59 | --epochs 5000 --log_step 20 --val_every 20 \
60 | --fid $fid
61 |
62 | # FINETUNE
63 | elif [[ "$phase" == *"finetune"* ]]; then
64 |
65 | python generative/train_generative.py --dataset "posescript-H2" \
66 | "${architecture_args[@]}" \
67 | "${loss_args[@]}" \
68 | "${bonus_args[@]}" \
69 | --apply_LR_augmentation \
70 | --lrposemul 0.1 --lrtextmul 1 \
71 | --lr 1e-05 --wd 0.0001 --batch_size 128 --seed $seed \
72 | --epochs 2000 --val_every 10 \
73 | --fid $fid \
74 | --pretrained $pretrained
75 |
76 | fi
77 |
78 | fi
79 |
80 |
81 | # EVAL QUANTITATIVELY
82 | if [[ "$action" == *"eval"* ]]; then
83 |
84 | eval_type=$2 # (regular|generate_poses|RG|GRa|GRb)
85 | echo "NOTE: Expecting as argument the evaluation phase. Got: $eval_type"
86 |
87 | # regular (fid, elbo)
88 | if [[ "$eval_type" == "regular" ]]; then
89 |
90 | # parse additional input
91 | model_path=$3
92 | echo "NOTE: Expecting as argument the path to the model to evaluate. Got: $model_path"
93 |
94 | python generative/evaluate_generative.py \
95 | --dataset "posescript-$eval_data_version" --split "test" \
96 | --model_path ${model_path} --checkpoint $checkpoint_type \
97 | --fid $fid
98 |
99 |
100 | # generate poses for R/G & G/R
101 | elif [[ "$eval_type" == "generate_poses" ]]; then
102 |
103 | # parse additional input
104 | model_path=$3
105 | echo "NOTE: Expecting as argument the path to the model for which to generate sample poses. Got: $model_path"
106 | seed=$4
107 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
108 |
109 | mp=${model_path::-1}$seed/checkpoint_best.pth
110 | python generative/generate_poses.py --model_path $mp
111 |
112 |
113 | # eval R/G
114 | elif [[ "$eval_type" == "RG" ]]; then
115 |
116 | # parse additional input
117 | seed=$3
118 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
119 | model_shortname=$4
120 | echo "NOTE: Expecting as argument the shortname of the generative model to evaluate. Got: $model_shortname"
121 | retrieval_model_shortname=$5
122 | echo "NOTE: Expecting as argument the shortname of the retrieval model used for R/G. Got: $retrieval_model_shortname"
123 |
124 | # evaluate generated poses with the retrieval model
125 | retrieval_model_path=$(python config.py $retrieval_model_shortname $seed)
126 | python retrieval/evaluate_retrieval.py \
127 | --dataset "posescript-"$eval_data_version --split 'test' \
128 | --model_path $retrieval_model_path --checkpoint $checkpoint_type \
129 | --generated_pose_samples $model_shortname
130 |
131 |
132 | # eval G/R, 1st step
133 | elif [[ "$eval_type" == "GRa" ]]; then
134 |
135 | # parse additional input
136 | seed=$3
137 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
138 | model_shortname=$4
139 | echo "NOTE: Expecting as argument the shortname of the generative model to evaluate. Got: $model_shortname"
140 |
141 | # define specificities for the new retrieval model
142 | args_ret=(
143 | --model 'PoseText' --latentD 512
144 | --lr_scheduler "stepLR" --lr 0.0002 --lr_gamma 0.5
145 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp"
146 | # --text_encoder_name 'glovebigru_vocPSA2H2'
147 | )
148 | pret='ret_distilbert_dataPSA2'
149 | if [[ "$eval_data_version" == "A2" ]]; then
150 | args_ret+=(--dataset "posescript-A2"
151 | --lr_step 400
152 | --batch_size 512
153 | --epochs 1000
154 | )
155 | elif [[ "$eval_data_version" == "H2" ]]; then
156 | args_ret+=(--dataset "posescript-H2"
157 | --lr_step 40
158 | --batch_size 32
159 | --epochs 200
160 | --pret $pret
161 | --apply_LR_augmentation
162 | )
163 | fi
164 | echo "NOTE: Expecting in the script the spec of the new retrieval model to train on the generated poses. Got:" "${args_ret[@]}"
165 |
166 | # train a new retrieval model with the generated poses
167 | python retrieval/train_retrieval.py "${args_ret[@]}" --seed $seed \
168 | --generated_pose_samples $model_shortname
169 |
170 | echo "IMPORTANT: please create an entry in shortname_2_model_path.txt for this newly trained retrieval model (providing a shortname and the path to the model), as it will be needed for the next evaluation step of this generative model ($model_shortname)"
171 |
172 |
173 | # eval G/R, 2nd step
174 | elif [[ "$eval_type" == "GRb" ]]; then
175 |
176 | # parse additional input
177 | seed=$3
178 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
179 | spec_retrieval_model_shortname=$4
180 | echo "NOTE: Expecting as argument the shortname of the retrieval model used for G/R. Got: $spec_retrieval_model_shortname"
181 |
182 | # evaluate the retrieval model trained on generated poses, on the original poses
183 | spec_retrieval_model_path=$(python config.py $spec_retrieval_model_shortname $seed)
184 | python retrieval/evaluate_retrieval.py \
185 | --dataset "posescript-"$eval_data_version --split 'test' \
186 | --model_path $spec_retrieval_model_path --checkpoint $checkpoint_type
187 |
188 | fi
189 | fi
190 |
191 |
192 | # EVAL QUALITATIVELY
193 | if [[ "$action" == *"demo"* ]]; then
194 |
195 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
196 | streamlit run generative/demo_generative.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type
197 |
198 | fi
--------------------------------------------------------------------------------
/src/text2pose/generative_B/README.md:
--------------------------------------------------------------------------------
1 | # Text-guided 3D Human Pose Editing Model
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | _:warning: The evaluation of this model relies partly on a [text-to-pose retrieval model](../retrieval/README.md), see section **Extra setup**, below._
6 |
7 | ## Model overview
8 |
9 | * **Inputs (#2)**: 3D human pose + text modifier;
10 | * **Output**: 3D human pose.
11 |
12 | 
13 |
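For reference, here is a minimal inference sketch following the demo code in this folder (`demo_generative_B.py`); the checkpoint path and the modifier text are placeholders, and the data setup of `config.py` is assumed to be in place.

```
import torch
import text2pose.demo as demo
import text2pose.data as data
from text2pose.generative_B.evaluate_generative_B import load_model

model_path = "<path/to/checkpoint>"  # placeholder
models, _, body_model = demo.setup_models([model_path], "best", load_model)

pose_A = data.T_POSE.view(1, -1)  # initial pose (here: the T-pose, as in the demo's "PoseScript mode")
modifier = "Bend your right elbow and lift your left leg."  # placeholder example text
with torch.no_grad():
    # sample 3 candidate edited poses, conditioned on pose A and the modifier
    edited = models[0].sample_str_nposes(pose_A, modifier, n=3)['pose_body'][0]
```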
14 | ## :crystal_ball: Demo
15 |
16 | To edit poses based on a pretrained model and example poses paired with (modifiable) modifier texts, run the following:
17 |
18 | ```
19 | streamlit run generative_B/demo_generative_B.py -- --model_paths <model_path_1> [<model_path_2> ...]
20 | ```
21 |
22 | :bulb: Tips: _Specify several model paths to compare models together._
23 |
24 | ## Extra setup
25 |
26 | At the beginning of the bash script, assign to variable `fid` the shortname of the trained [text-to-pose retrieval model](../retrieval/README.md) to be used for computing the FID.
27 |
28 | Add a line in *shortname_2_model_path.txt* to indicate the path to the model corresponding to the provided shortname.
29 |
30 | ## :bullettrain_front: Train
31 |
32 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
33 |
34 |
35 | Then use the following command:
36 | ```
37 | bash generative_B/script_generative_B.sh 'train'
38 | ```
39 |
40 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and assign this nickname to the `pretrained` variable in the script. The nickname will appear in the path of the finetuned model.
41 |
42 | ## :dart: Evaluate
43 |
44 | Use the following command:
45 | ```
46 | bash generative_B/script_generative_B.sh 'eval'
47 | ```
48 |
--------------------------------------------------------------------------------
/src/text2pose/generative_B/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/generative_B/demo_generative_B.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import argparse
12 | import torch
13 | import numpy as np
14 |
15 | import text2pose.config as config
16 | import text2pose.demo as demo
17 | import text2pose.utils as utils
18 | import text2pose.data as data
19 | import text2pose.utils_visu as utils_visu
20 | from text2pose.generative_B.evaluate_generative_B import load_model
21 |
22 |
23 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
24 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.')
25 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.")
26 | parser.add_argument('--n_generate', type=int, default=12, help="Number of poses to generate (number of samples); if considering only one model.")
27 | args = parser.parse_args()
28 |
29 |
30 | ### INPUT
31 | ################################################################################
32 |
33 | posefix_data_version = "posefix-H"
34 | posescript_data_version = "posescript-H2"
35 |
36 |
37 | ### SETUP
38 | ################################################################################
39 |
40 | # --- layout
41 | st.markdown("""
42 |
47 | """, unsafe_allow_html=True)
48 |
49 | # correct the number of generated samples depending on the setting
50 | if len(args.model_paths) > 1:
51 | n_generate = 4
52 | else:
53 | n_generate = args.n_generate
54 |
55 | # --- data
56 | available_splits = ['train', 'val', 'test']
57 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model)
58 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(posefix_data_version)
59 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
60 | _, captions = demo.setup_posescript_data(posescript_data_version)
61 |
62 | # --- seed
63 | torch.manual_seed(42)
64 | np.random.seed(42)
65 |
66 |
67 | ### MAIN APP
68 | ################################################################################
69 |
70 | # define query input interface
71 | cols_query = st.columns(3)
72 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
73 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'))
74 | number = cols_query[2].number_input("Split index or ID:", 0)
75 | st.markdown("""---""")
76 |
77 | # get query data
78 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model)
79 |
80 | # show query data
81 | st.write(f"**Query data:**")
82 | cols_input = st.columns([1,1,2])
83 | # (enable PoseScript mode: description only)
84 | no_pose_A = cols_input[2].checkbox("PoseScript mode")
85 | if no_pose_A:
86 | pose_A_data = data.T_POSE.view(1, -1)
87 | pose_A_img = utils_visu.image_from_pose_data(pose_A_data, body_model, color="grey", add_ground_plane=True)
88 | pose_A_img = demo.process_img(pose_A_img[0])
89 | pose_B_data, pose_B_img, default_description = demo.get_posescript_datapoint_from_pid(pid_B, captions, dataID_2_pose_info, body_model)
90 | # (actually show)
91 | cols_input[0].image(pose_A_img, caption="T-pose" if no_pose_A else "Pose A")
92 | cols_input[1].image(pose_B_img, caption="Annotated pose" if no_pose_A else "Annotated pose B")
93 | txt = default_description if no_pose_A else default_modifier
94 | if txt:
95 | cols_input[2].write("Annotated text:")
96 | cols_input[2].write(f"_{txt}_")
97 | else:
98 | cols_input[2].write("_(Not annotated.)_")
99 |
100 | # get input modifier
101 | modifier = cols_input[2].text_area("Pose modifier:",
102 | placeholder="Move your right arm... lift your left leg...",
103 | value=txt,
104 | height=None, max_chars=None)
105 |
106 | analysis = cols_input[2].checkbox('Analysis') # whether to show the reconstructed pose and the mean sample pose in addition to some samples
107 |
108 | # generate pose B
109 | if analysis:
110 |
111 | st.markdown("""---""")
112 | st.write("**Generated poses** (*The reconstructed B pose is shown in green; the mean pose in red; and samples in blue.*):")
113 | n_generate = 2
114 | nb_cols = 2 + n_generate # reconstructed pose + mean sample pose + n_generate sample poses: all must fit in one row, for each studied model
115 |
116 | for i, model in enumerate(models):
117 | with torch.no_grad():
118 | rec_pose_data = model.forward_autoencoder(pose_B_data)['pose_body_pose'].view(1, -1)
119 | gen_pose_data_mean = model.sample_str_meanposes(pose_A_data, modifier)['pose_body'].view(1, -1)
120 | gen_pose_data_samples = model.sample_str_nposes(pose_A_data, modifier, n=n_generate)['pose_body'][0,...].view(n_generate, -1)
121 |
122 | # render poses
123 | imgs = utils_visu.image_from_pose_data(rec_pose_data, body_model, color='green', add_ground_plane=True, two_views=60)
124 | imgs += utils_visu.image_from_pose_data(gen_pose_data_mean, body_model, color='red', add_ground_plane=True, two_views=60)
125 | imgs += utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='blue', add_ground_plane=True, two_views=60)
126 |
127 | # display images
128 | cols = st.columns(nb_cols+1) # +1 to display model info
129 | cols[0].markdown(f'{args.model_paths[i]}', unsafe_allow_html=True)
130 | for i in range(nb_cols):
131 | cols[i%nb_cols+1].image(demo.process_img(imgs[i]))
132 | st.markdown("""---""")
133 |
134 | else:
135 |
136 | st.markdown("""---""")
137 | st.write("**Generated poses:**")
138 |
139 | for i, model in enumerate(models):
140 | with torch.no_grad():
141 | gen_pose_data_samples = model.sample_str_nposes(pose_A_data, modifier, n=n_generate)['pose_body'][0,...].view(n_generate, -1)
142 |
143 | # render poses
144 | imgs = utils_visu.image_from_pose_data(gen_pose_data_samples, body_model, color='blue', add_ground_plane=True, two_views=60)
145 |
146 | # display images
147 | if len(models) > 1:
148 | cols = st.columns(n_generate+1) # +1 to display model info
149 | cols[0].markdown(f'{args.model_paths[i]}', unsafe_allow_html=True)
150 | for i in range(n_generate):
151 | cols[i%n_generate+1].image(demo.process_img(imgs[i]))
152 | st.markdown("""---""")
153 | else:
154 | cols = st.columns(demo.nb_cols)
155 | for i in range(n_generate):
156 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i]))
157 | st.markdown("""---""")
158 | st.write(f"_Results obtained with model: {args.model_paths[0]}_")
--------------------------------------------------------------------------------
/src/text2pose/generative_B/script_generative_B.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1
17 | checkpoint_type="best" # (last|best)
18 |
19 | architecture_args=(
20 | --model PoseBGenerator
21 | --correction_module_mode "tirg"
22 | --latentD 32 --special_text_latentD 128
23 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp"
24 | # --text_encoder_name 'glovebigru_vocPFAHPP'
25 | )
26 |
27 | loss_args=(
28 | --wloss_kld 1.0
29 | --kld_epsilon 10.0
30 | --wloss_v2v 1.0 --wloss_rot 1.0 --wloss_jts 1.0
31 | )
32 |
33 | bonus_args=(
34 | )
35 |
36 | fid="ret_distilbert_dataPSA2ftPSH2"
37 |
38 | pretrained="b_gen_distilbert_dataPFA" # used only if phase=='finetune'
39 |
40 |
41 | ##############################################################
42 | # EXECUTE
43 |
44 | # TRAIN
45 | if [[ "$action" == *"train"* ]]; then
46 |
47 | phase=$2 # (pretrain|finetune)
48 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
49 | seed=$3
50 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
51 |
52 | # PRETRAIN
53 | if [[ "$phase" == *"pretrain"* ]]; then
54 |
55 | python generative_B/train_generative_B.py --dataset "posefix-A" \
56 | "${architecture_args[@]}" \
57 | "${loss_args[@]}" \
58 | "${bonus_args[@]}" \
59 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \
60 | --epochs 5000 --log_step 100 --val_every 20 \
61 | --fid $fid
62 |
63 | # FINETUNE
64 | elif [[ "$phase" == *"finetune"* ]]; then
65 |
66 | python generative_B/train_generative_B.py --dataset "posefix-HPP" \
67 | "${architecture_args[@]}" \
68 | "${loss_args[@]}" \
69 | "${bonus_args[@]}" \
70 | --apply_LR_augmentation \
71 | --lrposemul 0.1 --lrtextmul 1 \
72 | --lr 0.000001 --wd 0.00001 --batch_size 128 --seed $seed \
73 | --epochs 5000 --log_step 20 --val_every 20 \
74 | --fid $fid \
75 | --pretrained $pretrained
76 |
77 | fi
78 |
79 | fi
80 |
81 |
82 | # EVAL QUANTITATIVELY
83 | if [[ "$action" == *"eval"* ]]; then
84 |
85 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
86 |
87 | for model_path in "${experiments[@]}"
88 | do
89 | echo $model_path
90 | python generative_B/evaluate_generative_B.py --dataset "posefix-H" \
91 | --model_path ${model_path} --checkpoint $checkpoint_type \
92 | --fid $fid \
93 | --split test
94 | # --special_eval
95 | done
96 | fi
97 |
98 |
99 | # EVAL QUALITATIVELY
100 | if [[ "$action" == *"demo"* ]]; then
101 |
102 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
103 | streamlit run generative_B/demo_generative_B.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type
104 |
105 | fi
--------------------------------------------------------------------------------
/src/text2pose/generative_caption/README.md:
--------------------------------------------------------------------------------
1 | # Pose Description Generation Model
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | _:warning: The evaluation of this model relies partly on a [text-to-pose retrieval model](../retrieval/README.md) and a [text-conditioned pose generation model](../generative/README.md), see section **Extra setup**, below._
6 |
7 | ## Model overview
8 |
9 | * **Input**: 3D human pose;
10 | * **Output**: pose description.
11 |
12 | 
13 |
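For reference, here is a minimal inference sketch following the demo code in this folder (`demo_generative_caption.py`); the checkpoint path is a placeholder, and the PoseScript data setup (version `posescript-H2`, as used by the demo) is assumed.

```
import torch
import text2pose.demo as demo
from text2pose.generative_caption.evaluate_generative_caption import load_model

model_path = "<path/to/checkpoint>"  # placeholder
models, _, body_model = demo.setup_models([model_path], "best", load_model)
dataID_2_pose_info, captions = demo.setup_posescript_data("posescript-H2")

# take the first pose of the test split and generate a description for it
_, pose_data, _, _ = demo.get_posescript_datapoint(0, "Split index", "test", captions, dataID_2_pose_info, body_model)
with torch.no_grad():
    texts, scores = models[0].generate_text(pose_data.view(1, -1, 3))
print(texts[0])
```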
14 | ## :crystal_ball: Demo
15 |
16 | To generate text descriptions based on a pretrained model and example 3D human poses, run the following:
17 |
18 | ```
19 | streamlit run generative_caption/demo_generative_caption.py -- --model_paths <model_path_1> [<model_path_2> ...]
20 | ```
21 |
22 | :bulb: Tips: _Specify several model paths to compare models together._
23 |
24 | ## Extra setup
25 |
26 | At the beginning of the bash script, indicate the shortnames of the trained models used for evaluation:
27 | * `fid`: text-to-pose retrieval model ([info](../retrieval/README.md)),
28 | * `pose_generative_model`: text-to-pose generative model ([info](../generative/README.md)),
29 | * `textret_model`: text-to-pose retrieval model, e.g. the same as for `fid` ([info](../retrieval/README.md)).
30 |
31 | Indicate the paths to the models corresponding to each of these shortnames in *shortname_2_model_path.txt*.
32 |
33 | ## :bullettrain_front: Train
34 |
35 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
36 |
37 | Then use the following command:
38 | ```
39 | bash generative_caption/script_generative_caption.sh 'train'
40 | ```
41 |
42 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and assign this nickname to the `pretrained` variable in the script. The nickname will appear in the path of the finetuned model.
43 |
44 | ## :dart: Evaluate
45 |
46 | Use the following command:
47 | ```
48 | bash generative_caption/script_generative_caption.sh 'eval'
49 | ```
50 |
--------------------------------------------------------------------------------
/src/text2pose/generative_caption/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/generative_caption/demo_generative_caption.py:
--------------------------------------------------------------------------------
1 |
2 | ##############################################################
3 | ## text2pose ##
4 | ## Copyright (c) 2023 ##
5 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
6 | ## and Naver Corporation ##
7 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
8 | ## See project root for license details. ##
9 | ##############################################################
10 |
11 | import streamlit as st
12 | import argparse
13 | import torch
14 |
15 | import text2pose.demo as demo
16 | from text2pose.generative_caption.evaluate_generative_caption import load_model
17 |
18 |
19 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
20 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.')
21 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.')
22 | args = parser.parse_args()
23 |
24 |
25 | ### INPUT
26 | ################################################################################
27 |
28 | data_version = "posescript-H2"
29 |
30 |
31 | ### SETUP
32 | ################################################################################
33 |
34 | # --- layout
35 | st.markdown("""
36 |
41 | """, unsafe_allow_html=True)
42 |
43 | # --- data
44 | available_splits = ['train', 'val', 'test']
45 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model)
46 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version)
47 |
48 |
49 | ### MAIN APP
50 | ################################################################################
51 |
52 | # define query input interface
53 | cols_query = st.columns(3)
54 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
55 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'), index=1)
56 | number = cols_query[2].number_input("Split index or ID:", 0)
57 | st.markdown("""---""")
58 |
59 | # get query data
60 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model)
61 |
62 | # show query data
63 | cols_input = st.columns(2)
64 | cols_input[0].image(pose_img, caption="Annotated pose")
65 | if default_description:
66 | cols_input[1].write("Annotated text:")
67 | cols_input[1].write(f"_{default_description}_")
68 | else:
69 | cols_input[1].write("_(Not annotated.)_")
70 |
71 | # generate text
72 | st.markdown("""---""")
73 | st.write("**Text generation:**")
74 | for i, model in enumerate(models):
75 |
76 | with torch.no_grad():
77 | texts, scores = model.generate_text(pose_data.view(1, -1, 3)) # (1, njoints, 3)
78 |
79 | if len(models) > 1:
80 | cols = st.columns(2)
81 | cols[0].markdown(f'{args.model_paths[i]}', unsafe_allow_html=True)
82 | cols[1].write(texts[0])
83 | st.markdown("""---""")
84 | else:
85 | st.write(texts[0])
86 | st.markdown("""---""")
87 | st.write(f"_Results obtained with model: {args.model_paths[0]}_")
--------------------------------------------------------------------------------
/src/text2pose/generative_caption/model_generative_caption.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch.nn as nn
11 |
12 | import text2pose.config as config
13 | from text2pose.encoders.tokenizers import get_text_encoder_or_decoder_module_name
14 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder
15 | from text2pose.encoders.text_decoders import TransformerTextDecoder, ModalityInputAdapter
16 |
17 | class DescriptionGenerator(nn.Module):
18 |
19 | def __init__(self, num_neurons=512, encoder_latentD=32, decoder_latentD=512,
20 | num_body_joints=config.NB_INPUT_JOINTS,
21 | decoder_nhead=8, decoder_nlayers=4, text_decoder_name="",
22 | transformer_mode="crossattention"):
23 | super(DescriptionGenerator, self).__init__()
24 |
25 | # Define pose encoder
26 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=encoder_latentD, num_body_joints=num_body_joints, role="retrieval")
27 |
28 | # Define modality input adaptor
29 | self.modalityInputAdapter = ModalityInputAdapter(inlatentD=encoder_latentD,
30 | outlatentD=decoder_latentD)
31 |
32 | # Define text decoder
33 | self.text_decoder_name = text_decoder_name
34 | self.transformer_mode = transformer_mode
35 | module_ref = get_text_encoder_or_decoder_module_name(text_decoder_name)
36 | if module_ref == "transformer":
37 | self.text_decoder = TransformerTextDecoder(self.text_decoder_name,
38 | nhead=decoder_nhead,
39 | nlayers=decoder_nlayers,
40 | decoder_latentD=decoder_latentD,
41 | transformer_mode=transformer_mode)
42 | else:
43 | raise NotImplementedError
44 |
45 |
46 | def encode_pose(self, pose_body):
47 | return self.pose_encoder(pose_body)
48 |
49 | def decode_text(self, z, captions, caption_lengths, train=False):
50 | return self.text_decoder(z, captions, caption_lengths, train=train)
51 |
52 | def forward(self, poses, captions, caption_lengths):
53 | z = self.encode_pose(poses)
54 | z = self.modalityInputAdapter(z)
55 | decoded = self.decode_text(z, captions, caption_lengths, train=True)
56 | return dict(z=z, **decoded)
57 |
58 | def generate_text(self, poses):
59 | z = self.encode_pose(poses)
60 | z = self.modalityInputAdapter(z)
61 | decoded_texts, likelihood_scores = self.text_decoder.generate_greedy(z)
62 | return decoded_texts, likelihood_scores
--------------------------------------------------------------------------------
/src/text2pose/generative_caption/script_generative_caption.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1 # (train|eval|demo)
17 | checkpoint_type="best" # (last|best)
18 |
19 | architecture_args=(
20 | --model DescriptionGenerator
21 | --text_decoder_name "transformer_vocPSA2H2"
22 | --transformer_mode 'crossattention'
23 | --decoder_nhead 8 --decoder_nlayers 4
24 | --latentD 32 --decoder_latentD 512
25 | )
26 |
27 | bonus_args=(
28 | )
29 |
30 | fid="ret_distilbert_dataPSA2ftPSH2"
31 | pose_generative_model="gen_distilbert_dataPSA2ftPSH2"
32 | textret_model="ret_distilbert_dataPSA2ftPSH2"
33 |
34 | pretrained="capgen_CAtransfPSA2H2_dataPSA2" # used only if phase=='finetune'
35 |
36 |
37 | ##############################################################
38 | # EXECUTE
39 |
40 | # TRAIN
41 | if [[ "$action" == *"train"* ]]; then
42 |
43 | phase=$2 # (pretrain|finetune)
44 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
45 | seed=$3
46 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
47 |
48 | # PRETRAIN
49 | if [[ "$phase" == *"pretrain"* ]]; then
50 |
51 | python generative_caption/train_generative_caption.py --dataset "posescript-A2" \
52 | "${architecture_args[@]}" \
53 | "${bonus_args[@]}" \
54 | --lr 0.0001 --wd 0.0001 --batch_size 128 --seed $seed \
55 | --epochs 3000 --log_step 100 --val_every 20 \
56 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model
57 |
58 | # FINETUNE
59 | elif [[ "$phase" == *"finetune"* ]]; then
60 |
61 | python generative_caption/train_generative_caption.py --dataset "posescript-H2" \
62 | "${architecture_args[@]}" \
63 | "${bonus_args[@]}" \
64 | --apply_LR_augmentation \
65 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \
66 | --epochs 2000 --log_step 100 --val_every 20 \
67 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \
68 | --pretrained $pretrained
69 |
70 | fi
71 |
72 | fi
73 |
74 |
75 | # EVAL QUANTITATIVELY
76 | if [[ "$action" == *"eval"* ]]; then
77 |
78 | experiments=(
79 | $2
80 | "GT" "random" "auto_posescript-A2_cap1" # control metrics
81 | )
82 |
83 | for model_path in "${experiments[@]}"
84 | do
85 | echo $model_path
86 | python generative_caption/evaluate_generative_caption.py --dataset "posescript-H2" \
87 | --model_path ${model_path} --checkpoint $checkpoint_type \
88 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \
89 | --split test
90 | done
91 | fi
92 |
93 |
94 | # EVAL QUALITATIVELY
95 | if [[ "$action" == *"demo"* ]]; then
96 |
97 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
98 | streamlit run generative_caption/demo_generative_caption.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type
99 |
100 | fi
--------------------------------------------------------------------------------
/src/text2pose/generative_modifier/README.md:
--------------------------------------------------------------------------------
1 | # Pose-based Correctional Text Generation Model
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | _:warning: The evaluation of this model relies partly on a [text-to-pose retrieval model](../retrieval/README.md), a [pair-to-text retrieval model](../retrieval_modifier/README.md) and a [text-guided pose editing model](../generative_B/README.md), see section **Extra setup**, below._
6 |
7 | ## Model overview
8 |
9 | * **Inputs (#2)**: a pair of 3D human poses (pose A + pose B);
10 | * **Output**: textual instruction explaining how to go from pose A to pose B.
11 |
12 | 
13 |
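For reference, here is a minimal inference sketch following the demo code in this folder (`demo_generative_modifier.py`); the checkpoint path is a placeholder, and the PoseFix data setup (version `posefix-H`, as used by the demo) is assumed.

```
import torch
import text2pose.config as config
import text2pose.demo as demo
import text2pose.utils as utils
from text2pose.generative_modifier.evaluate_generative_modifier import load_model

model_path = "<path/to/checkpoint>"  # placeholder
models, _, body_model = demo.setup_models([model_path], "best", load_model)
dataID_2_pose_info, triplet_data = demo.setup_posefix_data("posefix-H")
pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)

# take the first pair of the test split and explain how to go from pose A to pose B
_, _, _, pose_A, pose_B, _, _, _ = demo.get_posefix_datapoint(0, "Split index", "test", triplet_data, pose_pairs, dataID_2_pose_info, body_model)
with torch.no_grad():
    texts, scores = models[0].generate_text(pose_A.view(1, -1, 3), pose_B.view(1, -1, 3))
print(texts[0])
```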
14 | ## :crystal_ball: Demo
15 |
16 | To generate text instructions based on a pretrained model and examples of 3D human pose pairs, run the following:
17 |
18 | ```
19 | streamlit run generative_modifier/demo_generative_modifier.py -- --model_paths <model_path_1> [<model_path_2> ...]
20 | ```
21 |
22 | :bulb: Tips: _Specify several model paths to compare models together._
23 |
24 | ## Extra setup
25 |
26 | At the beginning of the bash script, indicate the shortnames of the trained models used for evaluation:
27 | * `fid`: text-to-pose retrieval model ([info](../retrieval/README.md)),
28 | * `pose_generative_model`: text-guided pose editing model ([info](../generative_B/README.md)),
29 | * `textret_model`: pair-to-text retrieval model ([info](../retrieval_modifier/README.md)).
30 |
31 | Indicate the paths to the models corresponding to each of these shortnames in *shortname_2_model_path.txt*.
32 |
33 | ## :bullettrain_front: Train
34 |
35 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
36 |
37 | Then use the following command:
38 | ```
39 | bash generative_modifier/script_generative_modifier.sh 'train'
40 | ```
41 |
42 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and assign this nickname to the `pretrained` variable in the script. The nickname will appear in the path of the finetuned model.
43 |
44 | ## :dart: Evaluate
45 |
46 | Use the following command:
47 | ```
48 | bash generative_modifier/script_generative_modifier.sh 'eval'
49 | ```
50 |
--------------------------------------------------------------------------------
/src/text2pose/generative_modifier/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/generative_modifier/demo_generative_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import argparse
12 | import torch
13 |
14 | import text2pose.config as config
15 | import text2pose.demo as demo
16 | import text2pose.utils as utils
17 | import text2pose.data as data
18 | import text2pose.utils_visu as utils_visu
19 | from text2pose.generative_modifier.evaluate_generative_modifier import load_model
20 |
21 |
22 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
23 | parser.add_argument('--model_paths', nargs='+', type=str, help='Paths to the models to be compared.')
24 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help="Checkpoint to choose if model path is incomplete.")
25 | args = parser.parse_args()
26 |
27 |
28 | ### INPUT
29 | ################################################################################
30 |
31 | posefix_data_version = "posefix-H"
32 | posescript_data_version = "posescript-H2"
33 |
34 |
35 | ### SETUP
36 | ################################################################################
37 |
38 | # --- layout
39 | st.markdown("""
40 | <style>
41 | .smallfont {
42 | font-size:10px !important;
43 | }
44 | </style>
45 | """, unsafe_allow_html=True)
46 |
47 | # --- data
48 | available_splits = ['train', 'val', 'test']
49 | models, _, body_model = demo.setup_models(args.model_paths, args.checkpoint, load_model)
50 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(posefix_data_version)
51 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
52 | _, captions = demo.setup_posescript_data(posescript_data_version)
53 |
54 |
55 | ### MAIN APP
56 | ################################################################################
57 |
58 | # define query input interface
59 | cols_query = st.columns(3)
60 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
61 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'))
62 | number = cols_query[2].number_input("Split index or ID:", 0)
63 | st.markdown("""---""")
64 |
65 | # get query data
66 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model)
67 |
68 | # show query data
69 | st.write(f"**Query data:**")
70 | cols_input = st.columns([1,1,2])
71 | # (enable PoseScript mode: description only)
72 | no_pose_A = cols_input[2].checkbox("PoseScript mode")
73 | if no_pose_A:
74 | pose_A_data = data.T_POSE.view(1, -1)
75 | pose_A_img = utils_visu.image_from_pose_data(pose_A_data, body_model, color="grey", add_ground_plane=True)
76 | pose_A_img = demo.process_img(pose_A_img[0])
77 | pose_B_data, pose_B_img, default_description = demo.get_posescript_datapoint_from_pid(pid_B, captions, dataID_2_pose_info, body_model)
78 | # (actually show)
79 | cols_input[0].image(pose_A_img, caption="T-pose" if no_pose_A else "Pose A")
80 | cols_input[1].image(pose_B_img, caption="Annotated pose" if no_pose_A else "Annotated pose B")
81 | txt = default_description if no_pose_A else default_modifier
82 | if txt:
83 | cols_input[2].write("Annotated text:")
84 | cols_input[2].write(f"_{txt}_")
85 | else:
86 | cols_input[2].write("_(Not annotated.)_")
87 |
88 | # generate text
89 | st.markdown("""---""")
90 | st.write("**Text generation:**")
91 | for i, model in enumerate(models):
92 |
93 | with torch.no_grad():
94 | texts, scores = model.generate_text(pose_A_data.view(1, -1, 3), pose_B_data.view(1, -1, 3)) # (1, njoints, 3)
95 |
96 | if len(models) > 1:
97 | cols = st.columns(2)
98 | cols[0].markdown(f'<p class="smallfont">{args.model_paths[i]}</p>', unsafe_allow_html=True)
99 | cols[1].write(texts[0])
100 | st.markdown("""---""")
101 | else:
102 | st.write(texts[0])
103 | st.markdown("""---""")
104 | st.write(f"_Results obtained with model: {args.model_paths[0]}_")
--------------------------------------------------------------------------------
/src/text2pose/generative_modifier/model_generative_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch.nn as nn
11 |
12 | import text2pose.config as config
13 | from text2pose.encoders.tokenizers import get_text_encoder_or_decoder_module_name
14 | from text2pose.encoders.modules import TIRG
15 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder
16 | from text2pose.encoders.text_decoders import TransformerTextDecoder, ModalityInputAdapter
17 |
18 |
19 | class FeedbackGenerator(nn.Module):
20 |
21 | def __init__(self, num_neurons=512, encoder_latentD=32, comparison_latentD=32,
22 | num_body_joints=config.NB_INPUT_JOINTS,
23 | decoder_latentD=512, decoder_nhead=8, decoder_nlayers=4, text_decoder_name="",
24 | comparison_module_mode="tirg", transformer_mode="crossattention"):
25 | super(FeedbackGenerator, self).__init__()
26 |
27 | # Define pose encoder
28 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons, latentD=encoder_latentD, num_body_joints=num_body_joints, role="retrieval")
29 |
30 | # Define fusing module
31 | self.comparison_module = ComparisonModule(inlatentD=encoder_latentD,
32 | outlatentD=comparison_latentD,
33 | mode=comparison_module_mode)
34 |
35 | # Define modality input adaptor
36 | self.modalityInputAdapter = ModalityInputAdapter(inlatentD=comparison_latentD,
37 | outlatentD=decoder_latentD)
38 |
39 | # Define text decoder
40 | self.text_decoder_name = text_decoder_name
41 | self.transformer_mode = transformer_mode
42 | module_ref = get_text_encoder_or_decoder_module_name(text_decoder_name)
43 | if module_ref == "transformer":
44 | self.text_decoder = TransformerTextDecoder(self.text_decoder_name,
45 | nhead=decoder_nhead,
46 | nlayers=decoder_nlayers,
47 | decoder_latentD=decoder_latentD,
48 | transformer_mode=transformer_mode)
49 | else:
50 | raise NotImplementedError
51 |
52 |
53 | def encode_pose(self, pose_body):
54 | return self.pose_encoder(pose_body)
55 |
56 | def decode_text(self, z, captions, caption_lengths, train=False):
57 | return self.text_decoder(z, captions, caption_lengths, train=train)
58 |
59 | def fuse_input_poses(self, embed_poses_A, embed_poses_B):
60 | z = self.comparison_module(embed_poses_A, embed_poses_B)
61 | z = self.modalityInputAdapter(z)
62 | return z
63 |
64 | def forward(self, poses_A, captions, caption_lengths, poses_B):
65 | z_a = self.encode_pose(poses_A)
66 | z_b = self.encode_pose(poses_B)
67 | z = self.fuse_input_poses(z_a, z_b)
68 | decoded = self.decode_text(z, captions, caption_lengths, train=True)
69 | return dict(z=z, **decoded)
70 |
71 | def generate_text(self, poses_A, poses_B):
72 | z_a = self.encode_pose(poses_A)
73 | z_b = self.encode_pose(poses_B)
74 | z = self.fuse_input_poses(z_a, z_b)
75 | decoded_texts, likelihood_scores = self.text_decoder.generate_greedy(z)
76 | return decoded_texts, likelihood_scores
77 |
78 |
79 | class ComparisonModule(nn.Module):
80 | """
81 | Given two poses A and B, compute an embedding representing the result from
82 | the comparison of the two poses.
83 | """
84 |
85 | def __init__(self, inlatentD, outlatentD, mode="tirg"):
86 | super(ComparisonModule, self).__init__()
87 |
88 | self.inlatentD = inlatentD
89 | self.outlatentD = outlatentD
90 | self.mode = mode
91 |
92 | if mode == "tirg":
93 | self.tirg = TIRG(input_dim=[inlatentD, inlatentD], output_dim=outlatentD, out_l2_normalize=False)
94 | self.forward = self.forward_tirg
95 | else:
96 | print(f"Unknown mode for the comparison module (provided: {mode}).")
97 | raise NotImplementedError
98 |
99 |
100 | def forward_tirg(self, pose_A_embeddings, pose_B_embeddings):
101 | return self.tirg.query_compositional_embedding(pose_B_embeddings, pose_A_embeddings)
--------------------------------------------------------------------------------
/src/text2pose/generative_modifier/script_generative_modifier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1 # (train|eval|demo)
17 | checkpoint_type="best" # (last|best)
18 |
19 | architecture_args=(
20 | --model FeedbackGenerator
21 | --text_decoder_name "transformer_vocPFAHPP"
22 | --transformer_mode 'crossattention'
23 | --comparison_module_mode 'tirg'
24 | --decoder_nhead 8 --decoder_nlayers 4
25 | --latentD 32 --comparison_latentD 32 --decoder_latentD 512
26 | )
27 |
28 | bonus_args=(
29 | )
30 |
31 | fid="ret_distilbert_dataPSA2ftPSH2"
32 | pose_generative_model="b_gen_distilbert_dataPFAftPFH"
33 | textret_model="modret_distilbert_dataPFAftPFH"
34 |
35 | pretrained="modgen_CAtransfPFAHPP_dataPFA" # used only if phase=='finetune'
36 |
37 |
38 | ##############################################################
39 | # EXECUTE
40 |
41 | # TRAIN
42 | if [[ "$action" == *"train"* ]]; then
43 |
44 | phase=$2 # (pretrain|finetune)
45 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
46 | seed=$3
47 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
48 |
49 | # PRETRAIN
50 | if [[ "$phase" == *"pretrain"* ]]; then
51 |
52 | python generative_modifier/train_generative_modifier.py --dataset "posefix-A" \
53 | "${architecture_args[@]}" \
54 | "${bonus_args[@]}" \
55 | --lr 0.0001 --wd 0.0001 --batch_size 128 --seed $seed \
56 | --epochs 3000 --log_step 100 --val_every 20 \
57 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model
58 |
59 | # FINETUNE
60 | elif [[ "$phase" == *"finetune"* ]]; then
61 |
62 | python generative_modifier/train_generative_modifier.py --dataset "posefix-H" \
63 | "${architecture_args[@]}" \
64 | "${bonus_args[@]}" \
65 | --apply_LR_augmentation \
66 | --lr 0.00001 --wd 0.0001 --batch_size 128 --seed $seed \
67 | --epochs 2000 --log_step 100 --val_every 20 \
68 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \
69 | --pretrained $pretrained
70 |
71 | fi
72 |
73 | fi
74 |
75 |
76 | # EVAL QUANTITATIVELY
77 | if [[ "$action" == *"eval"* ]]; then
78 |
79 | experiments=(
80 | $2
81 | "GT" "random" "auto_posefix-A_cap0" # control metrics
82 | )
83 |
84 | for model_path in "${experiments[@]}"
85 | do
86 | echo $model_path
87 | python generative_modifier/evaluate_generative_modifier.py --dataset "posefix-H" \
88 | --model_path ${model_path} --checkpoint $checkpoint_type \
89 | --fid $fid --pose_generative_model $pose_generative_model --textret_model $textret_model \
90 | --split test
91 | done
92 | fi
93 |
94 |
95 | # EVAL QUALITATIVELY
96 | if [[ "$action" == *"demo"* ]]; then
97 |
98 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
99 | streamlit run generative_modifier/demo_generative_modifier.py -- --model_paths "${experiments[@]}" --checkpoint $checkpoint_type
100 |
101 | fi
--------------------------------------------------------------------------------
/src/text2pose/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | def BBC(scores):
7 | # build the ground truth label tensor: the diagonal corresponds to the
8 | # correct classification
9 | GT_labels = torch.arange(scores.shape[0], device=scores.device).long()
10 | loss = F.cross_entropy(scores, GT_labels) # mean reduction
11 | return loss
12 |
13 |
14 | def symBBC(scores):
15 | x2y_loss = BBC(scores)
16 | y2x_loss = BBC(scores.t())
17 | return (x2y_loss + y2x_loss) / 2.0
18 |
19 |
20 | def laplacian_nll(x_tilde, x, log_sigma):
21 | """ Negative log likelihood of an isotropic Laplacian density """
22 | log_norm = - (np.log(2) + log_sigma)
23 | log_energy = - (torch.abs(x_tilde - x)) / torch.exp(log_sigma)
24 | return - (log_norm + log_energy)
25 |
26 |
27 | def gaussian_nll(x_tilde, x, log_sigma):
28 | """ Negative log-likelihood of an isotropic Gaussian density """
29 | log_norm = - 0.5 * (np.log(2 * np.pi) + log_sigma)
30 | log_energy = - 0.5 * F.mse_loss(x_tilde, x, reduction='none') / torch.exp(log_sigma)
31 | return - (log_norm + log_energy)
--------------------------------------------------------------------------------
/src/text2pose/posefix/README.md:
--------------------------------------------------------------------------------
1 | # About the PoseFix dataset
2 |
3 | ## :inbox_tray: Download
4 |
5 | **License.**
6 | *The PoseFix dataset is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
7 | A summary of the CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
8 | The CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).*
9 |
10 | | Version | Link |
11 | |---|---|
12 | | ICCV23 | [download](https://download.europe.naverlabs.com/ComputerVision/PoseFix/posefix_dataset_release.zip) |
13 |
14 |
15 | Dataset content.
16 |
17 | * a file linking each pose ID to the reference of its corresponding pose sequence in AMASS, and its frame index;
18 | * a file linking each pair ID (index) to a pair of pose IDs;
19 | * a file linking each pair ID with its text modifiers (separate files for automatically generated modifiers, human-written ones and paraphrases);
20 | * files listing pair IDs for each split.
21 |
22 | Please refer to the provided README for more details.
23 |
24 |
25 |
26 | ## :crystal_ball: Take a quick look
27 |
28 | To take a quick look at the data (ie. look at the poses under different viewpoints, the human-written modifier and the paraphrases collected with InstructGPT when available, as well as the automatic modifiers):
29 |
30 | ```
31 | streamlit run posefix/explore_posefix.py
32 | ```
33 |
34 | ## :page_with_curl: Generate automatic modifiers
35 |
36 | ### Overview of the pipeline
37 |
38 | Given a pair of 3D poses, we use paircodes to extract semantic pair-relationship information, and posecodes for each of the two poses to enrich the instructions with "absolute" information. These codes are then selected, merged or combined (when relevant) before being converted into a structural instruction in natural language explaining how to go from the first pose to the second pose. Letters ‘L’ and ‘R’ stand for ‘left’ and ‘right’ respectively.
39 |
40 | 
41 |
42 | Please refer to the paper and the supplementary material for more extensive explanations.
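
As a purely illustrative sketch of the paircode idea (all names, thresholds and values below are made up and do **not** correspond to the repository's actual data structures), a paircode categorizes how a measured quantity changes between pose A and pose B, and the category is then verbalized with a template sentence:

```python
# Toy illustration of the paircode idea; names and thresholds are hypothetical.

def height_change_paircode(y_A: float, y_B: float, threshold: float = 0.10) -> str:
    """Categorize the vertical displacement of a keypoint from pose A to pose B."""
    diff = y_B - y_A
    if diff > threshold:
        return "raise"
    if diff < -threshold:
        return "lower"
    return "ignore"  # change too small to be worth verbalizing

TEMPLATES = {"raise": "raise your {part}", "lower": "lower your {part}"}

code = height_change_paircode(y_A=0.55, y_B=0.30)  # e.g. the left hand moved 25 cm down
if code != "ignore":
    print(TEMPLATES[code].format(part="left hand"))  # -> "lower your left hand"
```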
43 |
44 | ### Generate modifiers
45 |
46 | To generate automatic modifiers, please follow these steps:
47 |
48 | - **compute joint coordinates for all poses**
49 | ```
50 | python posescript/compute_coords.py
51 | ```
52 |
53 | - **compute rotation change for all pairs**
54 | ```
55 | python posefix/compute_rotation_change.py
56 | ```
57 |
58 | - (optional) **modify corrective data as you see fit**:
59 |
60 | (Re)Define posecodes & paircodes (categories, thresholds, tolerable noise levels, eligibility), super-posecodes & super-paircodes, template sentences and so forth by modifying both *posescript/captioning_data.py* and *posefix/corrective_data.py*. The data structures are extensively explained in these files, and one can follow some marks (`ADD_VIRTUAL_JOINT`, `ADD_POSECODE_KIND`, `ADD_SUPER_POSECODE`, `ADD_PAIRCODE_KIND`, `ADD_SUPER_PAIRCODE`) to add new instruction material. Note that some posecodes (initially created for [single pose description](../posescript/README.md)) are also used in this pipeline.
61 |
62 | - **generate automatic modifiers**
63 |
64 | *Possible arguments are:*
65 | - `--saving_dir`: general location for saving generated instructions and data related to them (default: */generated_instructions/*)
66 | - `--version_name`: name of the version. Will be used to create a subdirectory of `--saving_dir` in which to save all files (instructions & intermediary results). Default is 'tmp'.
67 | - `--simplified_instructions`: produce a simplified version of the instructions (basically: no aggregation, no randomly referring to a body part by a substitute word).
68 | - `--random_skip`: randomly skip some non-essential (ie. not so rare) posecodes & paircodes.
69 |
70 |
71 |
72 | For instance, modifiers can be generated using the following command:
73 | ```
74 | python posefix/correcting.py --version_name <version_name> --random_skip
75 | ```
76 |
77 | To work on a small subset, one can use the following command:
78 | ```
79 | python posefix/correcting.py --debug
80 | ```
81 | and specify which pair to study by modifying the list of pair IDs marked with `SPECIFIC_INPUT_PAIR_IDS` in *posefix/correcting.py*.
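
For illustration only (hypothetical IDs and variable name; the list to edit is whichever one is actually marked with `SPECIFIC_INPUT_PAIR_IDS` in *posefix/correcting.py*):

```python
# Hypothetical example: restrict the debug run to a few pair IDs of interest.
pair_ids_to_study = [12, 345, 6789]
```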
82 |
83 |
84 | To generate modifier versions similar to the ICCV 2023 paper (posefix-A):
85 |
86 | | Version | Command |
87 | |---------|---------|
88 | | pfA | `python posefix/correcting.py --version_name captions_pfA` |
89 | | pfB | `python posefix/correcting.py --version_name captions_pfB --random_skip` |
90 | | pfC | `python posefix/correcting.py --version_name captions_pfC --random_skip --simplified_instructions` |
91 |
92 | *Note that some paircodes have been added since then, for the release of PoseEmbroider.*
93 |
94 |
95 | ## Citation
96 |
97 | If you use this code or the PoseFix dataset, please cite the following paper:
98 |
99 | ```bibtex
100 | @inproceedings{delmas2023posefix,
101 | title={{PoseFix: Correcting 3D Human Poses with Natural Language}},
102 | author={{Delmas, Ginger and Weinzaepfel, Philippe and Moreno-Noguer, Francesc and Rogez, Gr\'egory}},
103 | booktitle={{ICCV}},
104 | year={2023}
105 | }
106 | ```
107 |
108 | Please also remember to follow AMASS's citation guidelines if you use the AMASS data.
109 |
--------------------------------------------------------------------------------
/src/text2pose/posefix/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/posefix/compute_rotation_change.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | from tqdm import tqdm
12 | import torch
13 | import roma
14 |
15 | import text2pose.config as config
16 | import text2pose.utils as utils
17 |
18 |
19 | ### SETUP
20 | ################################################################################
21 |
22 | # load data
23 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index)
24 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
25 |
26 |
27 | ### COMPUTE ROTATION CHANGES
28 | ################################################################################
29 |
30 | rad2deg = lambda x: x*180/torch.pi
31 |
32 | # compute rotation changes
33 | rotation_changes = [[0.0, 0.0, 0.0] for _ in pose_pairs]
34 | for pairID in tqdm(range(len(pose_pairs))):
35 |
36 | pidA, pidB = pose_pairs[pairID]
37 | pose_info_A = dataID_2_pose_info[str(pidA)]
38 | pose_info_B = dataID_2_pose_info[str(pidB)]
39 |
40 | # 1) normalize the change in rotation (ie. get the difference in
41 | # rotation between pose A and pose B)
42 | pose_data_A, R_norm = utils.get_pose_data_from_file(pose_info_A, output_rotation=True)
43 | pose_data_B = utils.get_pose_data_from_file(pose_info_B, applied_rotation=R_norm if pose_info_A[1] == pose_info_B[1] else None)
44 |
45 | # 2) get the change of rotation (angle in degree) of B wrt A
46 | # The angle should be positive if turning left (clockwise); and negative
47 | # otherwise (when turning right)
48 | r = roma.rotvec_composition((pose_data_B[0:1,:3], roma.rotvec_inverse(pose_data_A[0:1,:3])))
49 | r = rad2deg(r[0]).tolist() # convert to degrees
50 | r = [r[0], r[2], -r[1]] # reorient (x,y,z) where x is oriented towards the right and y points up
51 | rotation_changes[pairID] = r
52 |
53 | # save
54 | save_filepath = os.path.join(config.POSEFIX_LOCATION, f"ids_2_rotation_change{config.version_suffix}.json")
55 | utils.write_json(rotation_changes, save_filepath, pretty=True)
56 | print("Saved global rotation changes at", save_filepath)
--------------------------------------------------------------------------------
/src/text2pose/posefix/explore_posefix.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | # $ streamlit run posefix/explore_posefix.py
11 |
12 | import streamlit as st
13 | import random
14 |
15 | import text2pose.config as config
16 | import text2pose.demo as demo
17 | import text2pose.utils as utils
18 | import text2pose.utils_visu as utils_visu
19 |
20 |
21 | ### INPUT
22 | ################################################################################
23 |
24 | version_human = 'posefix-H'
25 | version_paraphrases = 'posefix-PP'
26 | version_auto = 'posefix-A'
27 |
28 |
29 | ### SETUP
30 | ################################################################################
31 |
32 | dataID_2_pose_info, triplet_data_human = demo.setup_posefix_data(version_human)
33 | _, triplet_data_pp = demo.setup_posefix_data(version_paraphrases)
34 | _, triplet_data_auto = demo.setup_posefix_data(version_auto)
35 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
36 | body_model = demo.setup_body_model()
37 |
38 |
39 | ### DISPLAY DATA
40 | ################################################################################
41 |
42 | # get input pair id
43 | pair_ID = st.number_input("Pair ID:", 0, len(pose_pairs)-1)
44 | if st.button('Look at a random pose!'):
45 | pair_ID = random.randint(0, len(pose_pairs)-1)
46 | st.write(f"Looking at pair ID: **{pair_ID}**")
47 |
48 | # display information about the pair
49 | pid_A, pid_B = pose_pairs[pair_ID]
50 | in_sequence = dataID_2_pose_info[str(pid_A)][1] == dataID_2_pose_info[str(pid_B)][1]
51 | st.markdown(f"{'In' if in_sequence else 'Out-of'}-sequence pair (pose A to pose B).", unsafe_allow_html=True)
52 |
53 | # load pose data
54 | pose_A_info = dataID_2_pose_info[str(pid_A)]
55 | pose_A_data, rA = utils.get_pose_data_from_file(pose_A_info, output_rotation=True)
56 | pose_B_info = dataID_2_pose_info[str(pid_B)]
57 | pose_B_data = utils.get_pose_data_from_file(pose_B_info, applied_rotation=rA if in_sequence else None)
58 |
59 | # render the pair under the desired viewpoint, and display it
60 | view_angle = st.slider("Point of view:", min_value=-180, max_value=180, step=20, value=0)
61 | viewpoint = [] if view_angle == 0 else (view_angle, (0,1,0))
62 | pair_img = utils_visu.image_from_pair_data(pose_A_data, pose_B_data, body_model, viewpoint=viewpoint, add_ground_plane=True)
63 | st.image(demo.process_img(pair_img))
64 |
65 | # display text annotations
66 | if pair_ID in triplet_data_human:
67 | for k,m in enumerate(triplet_data_human[pair_ID]["modifier"]):
68 | st.markdown(f"Human-written modifier n°{k+1}", unsafe_allow_html=True)
69 | st.write(f"_{m.strip()}_")
70 | if pair_ID in triplet_data_pp:
71 | for k,m in enumerate(triplet_data_pp[pair_ID]["modifier"]):
72 | st.markdown(f"Paraphrase n°{k+1}", unsafe_allow_html=True)
73 | st.write(f"_{m.strip()}_")
74 | if pair_ID in triplet_data_auto:
75 | for k,m in enumerate(triplet_data_auto[pair_ID]["modifier"]):
76 | st.markdown(f"Automatic modifier n°{k+1}", unsafe_allow_html=True)
77 | st.write(f"_{m.strip()}_")
78 | else:
79 | st.write("This pair ID was not annotated.")
--------------------------------------------------------------------------------
/src/text2pose/posescript/README.md:
--------------------------------------------------------------------------------
1 | # About the PoseScript dataset
2 |
3 | ## :inbox_tray: Download
4 |
5 | **License.**
6 | *The PoseScript dataset is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
7 | A summary of the CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/).
8 | The CC BY-NC-SA 4.0 license is located [here](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).*
9 |
10 | | Version | Link | What changed? |
11 | |---|---|---|
12 | | **V2** | [download](https://download.europe.naverlabs.com/ComputerVision/PoseScript/posescript_release_v2.zip) | pose set is 100k instead of 20k; file format updated according to the 2023 code release |
13 | | **V1** | [download](https://download.europe.naverlabs.com/ComputerVision/PoseScript/posescript_dataset_release.zip) | 6293 human-written descriptions instead of 3893 |
14 | | **V0 (ECCV22)** | (as V1) | |
15 |
16 |
17 |
18 | Dataset content.
19 |
20 | * a file linking each pose ID to the reference of its corresponding pose sequence in AMASS, and its frame index;
21 | * a file linking each pose ID with its descriptions (separate files for automatically generated captions and human-written ones);
22 | * files listing pose IDs for each split.
23 |
24 | Please refer to the provided README for more details.
25 |
26 |
27 |
28 | ## :crystal_ball: Take a quick look
29 |
30 | To take a quick look at the data (ie. look at the pose under different viewpoints, the human-written caption when available, and the automatic captions):
31 |
32 | ```
33 | streamlit run posescript/explore_posescript.py
34 | ```
35 |
36 | ## :page_with_curl: Generate automatic captions
37 |
38 | ### Overview of the captioning pipeline
39 |
40 | Given a normalized 3D pose, we use posecodes to extract semantic pose information. These posecodes are then selected, merged or combined (when relevant) before being converted into a structural pose description in natural language. Letters ‘L’ and ‘R’ stand for ‘left’ and ‘right’ respectively.
41 |
42 | 
43 |
44 | Please refer to the paper and the supplementary material for more extensive explanations.
45 |
46 | ### Generate captions
47 |
48 | To generate automatic captions, please follow these steps:
49 |
50 | *Note: the time estimates below are given for 20k poses, but the newest version of PoseScript has 100k poses.*
51 |
52 | - **compute joint coordinates for all poses** _(~ 20 min)_
53 | ```
54 | python posescript/compute_coords.py
55 | ```
56 |
57 | - **get and format BABEL labels for poses in PoseScript** _(~ 5 min)_
58 | ```
59 | python posescript/format_babel_labels.py
60 | ```
61 |
62 | - (optional) **modify captioning data as you see fit**:
63 | - looking at diagrams on posecode statistics can be helpful to decide on posecode eligibility. To do so, run the following:
64 | ```
65 | python posescript/captioning.py --action posecode_stats --version_name posecode_stats
66 | ```
67 | - (re)define posecodes (categories, thresholds, tolerable noise levels, eligibility), super-posecodes, ripple effect rules based on statistics, template sentences and so forth by modifying *posescript/captioning_data.py*. The data structures are extensively explained in this file, and one can follow some marks (`ADD_VIRTUAL_JOINT`, `ADD_POSECODE_KIND`, `ADD_SUPER_POSECODE`) to add new captioning material. Note that some posecodes were added to be used in the [pipeline that generates automatic modifiers](../posefix/README.md) (search for the `ADDED_FOR_MODIFIERS` marks).
68 |
69 |
70 | - **generate automatic captions** _(~ 1 min = 20k captions, with 1 cap/pose)_
71 |
72 | *Possible arguments are:*
73 | - `--saving_dir`: general location for saving generated captions and data related to them (default: */generated_captions/*)
74 | - `--version_name`: name of the caption version. Will be used to create a subdirectory of `--saving_dir` in which to save all files (descriptions & intermediary results). Default is 'tmp'.
75 | - `--simplified_captions`: produce a simplified version of the captions (basically: no aggregation, no omitting of some support keypoints for the sake of flow, no randomly referring to a body part by a substitute word). This configuration is used to generate caption versions E and F from the paper.
76 | - `--apply_transrel_ripple_effect`: discard some posecodes using ripple effect rules based on transitive relations between body parts.
77 | - `--apply_stat_ripple_effect`: discard some posecodes using ripple effect rules based on statistically frequent pairs and triplets of posecodes.
78 | - `--random_skip`: randomly skip some non-essential posecodes (ie. posecodes that were found to be satisfied by more than 6% of the 20k poses considered in PoseScript).
79 | - `--add_babel_info`: add sentences using information extracted from BABEL.
80 | - `--add_dancing_info`: add a sentence stating that the pose is a dancing pose if it comes from DanceDB (provided that `--add_babel_info` is also set).
81 |
82 |
83 |
84 |
85 | To generate caption versions similar to the latest version of the dataset (posescript-A2):
86 |
87 | | Version | Command |
88 | |---------|---------|
89 | | N2 | `python posescript/captioning.py --version_name captions_n2 --random_skip --simplified_captions` |
90 | | N6 | `python posescript/captioning.py --version_name captions_n6 --random_skip --add_babel_info --add_dancing_info` |
91 | | N7 | `python posescript/captioning.py --version_name captions_n7 --random_skip --apply_transrel_ripple_effect --apply_stat_ripple_effect` |
92 |
93 | *Note that some posecodes have been added since then, for the release of PoseFix and PoseEmbroider.*
94 |
95 |
96 |
97 | To generate caption versions similar to the ECCV 2022 paper (posescript-A):
98 |
99 | | Version | Command |
100 | |---------|---------|
101 | | A | `python posescript/captioning.py --version_name captions_A --apply_transrel_ripple_effect --apply_stat_ripple_effect --random_skip --add_babel_info --add_dancing_info` |
102 | | B | `python posescript/captioning.py --version_name captions_B --random_skip --add_babel_info --add_dancing_info` |
103 | | C | `python posescript/captioning.py --version_name captions_C --random_skip --add_babel_info` |
104 | | D | `python posescript/captioning.py --version_name captions_D --random_skip` |
105 | | E | `python posescript/captioning.py --version_name captions_E --random_skip --simplified_captions` |
106 | | F | `python posescript/captioning.py --version_name captions_F --simplified_captions` |
107 |
108 | *Note that some posecodes have been added since then, for the release of PoseFix and PoseEmbroider.*
109 |
110 |
111 |
112 |
113 | ## Citation
114 |
115 | If you use this code or the PoseScript dataset, please cite the following paper:
116 |
117 | ```bibtex
118 | @inproceedings{posescript,
119 | title={{PoseScript: 3D Human Poses from Natural Language}},
120 | author={{Delmas, Ginger and Weinzaepfel, Philippe and Lucas, Thomas and Moreno-Noguer, Francesc and Rogez, Gr\'egory}},
121 | booktitle={{ECCV}},
122 | year={2022}
123 | }
124 | ```
125 |
126 | Please also remember to follow AMASS's and BABEL's respective citation guidelines if you use the AMASS or BABEL data.
--------------------------------------------------------------------------------
/src/text2pose/posescript/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/posescript/compute_coords.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | from tqdm import tqdm
12 | import math
13 | import torch
14 | from human_body_prior.body_model.body_model import BodyModel
15 |
16 | import text2pose.config as config
17 | import text2pose.utils as utils
18 |
19 |
20 | ### INPUT
21 | ################################################################################
22 |
23 | device = 'cpu'
24 |
25 |
26 | ### SETUP
27 | ################################################################################
28 |
29 | # setup body model
30 | body_model = BodyModel(model_type = config.POSE_FORMAT,
31 | bm_fname = config.NEUTRAL_BM,
32 | num_betas = config.n_betas)
33 | body_model.eval()
34 | body_model.to(device)
35 |
36 | # load data
37 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index)
38 |
39 | # rotation transformation to apply so that the coordinates correspond to what we
40 | # actually visualize (ie. from front view)
41 | rotX = lambda theta: torch.tensor([[1, 0, 0], [0, torch.cos(theta), -torch.sin(theta)], [0, torch.sin(theta), torch.cos(theta)]])
42 |
43 | def transf(rotMat, theta_deg, values):
44 | theta_rad = math.pi * torch.tensor(theta_deg).float() / 180.0
45 | return rotMat(theta_rad).mm(values.t()).t()
46 |
47 |
48 | ### COMPUTE COORDINATES
49 | ################################################################################
50 |
51 | coords = []
52 |
53 | # compute all joint coordinates
54 | for dataID in tqdm(range(len(dataID_2_pose_info))):
55 |
56 | # load pose data
57 | pose_info = dataID_2_pose_info[str(dataID)]
58 | pose = utils.get_pose_data_from_file(pose_info)
59 |
60 | # infer coordinates
61 | with torch.no_grad():
62 | j = body_model(**utils.pose_data_as_dict(pose)).Jtr
63 | j = j.detach().cpu()[0]
64 | j = transf(rotX, -90, j)
65 |
66 | # store data
67 | coords.append(j.view(1, -1, 3))
68 | coords = torch.cat(coords)
69 |
70 | # save
71 | save_filepath = os.path.join(config.POSESCRIPT_LOCATION, f"ids_2_coords_correct_orient_adapted{config.version_suffix}.pt")
72 | torch.save(coords, save_filepath)
73 | print("Saved coordinates at", save_filepath)
--------------------------------------------------------------------------------
/src/text2pose/posescript/explore_posescript.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | # $ streamlit run posescript/explore_posescript.py
11 |
12 | import streamlit as st
13 |
14 | import text2pose.demo as demo
15 | import text2pose.utils as utils
16 | import text2pose.utils_visu as utils_visu
17 |
18 |
19 | ### INPUT
20 | ################################################################################
21 |
22 | version_human = 'posescript-H2'
23 | version_auto = 'posescript-A2'
24 |
25 |
26 | ### SETUP
27 | ################################################################################
28 |
29 | dataID_2_pose_info, captions_human = demo.setup_posescript_data(version_human)
30 | _, captions_auto = demo.setup_posescript_data(version_auto)
31 | body_model = demo.setup_body_model()
32 |
33 |
34 | ### DISPLAY DATA
35 | ################################################################################
36 |
37 | # get input pose id
38 | dataID = st.number_input("Pose ID:", 0, len(dataID_2_pose_info)-1)
39 |
40 | # display information about the pose
41 | pose_info = dataID_2_pose_info[str(dataID)]
42 | st.write(f"Pose from the **{pose_info[0]}** dataset, **frame {pose_info[2]}** of sequence *{pose_info[1]}*")
43 |
44 | # load pose data
45 | pose_data = utils.get_pose_data_from_file(pose_info)
46 |
47 | # render the pose under the desired viewpoint, and display it
48 | view_angle = st.slider("Point of view:", min_value=-180, max_value=180, step=20, value=0)
49 | viewpoint = [] if view_angle == 0 else (view_angle, (0,1,0))
50 | img = utils_visu.image_from_pose_data(pose_data, body_model, viewpoints=[viewpoint], add_ground_plane=True)[0] # 1 viewpoint
51 | st.image(demo.process_img(img))
52 |
53 | # display captions
54 | if dataID in captions_human:
55 | for k,c in enumerate(captions_human[dataID]):
56 | st.markdown(f"Human-written description n°{k+1}", unsafe_allow_html=True)
57 | st.write(captions_human[dataID][k])
58 | if dataID in captions_auto:
59 | for k,c in enumerate(captions_auto[dataID]):
60 | st.markdown(f"Automatic description n°{k+1}", unsafe_allow_html=True)
61 | st.write(captions_auto[dataID][k])
62 | else:
63 | st.write("This pose ID was not annotated.")
--------------------------------------------------------------------------------
/src/text2pose/posescript/format_babel_labels.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | from tqdm import tqdm
12 | import json
13 | import numpy as np
14 | import pickle
15 | from tabulate import tabulate
16 |
17 | import text2pose.config as config
18 | import text2pose.utils as utils
19 |
20 |
21 | ### SETUP
22 | ################################################################################
23 |
24 | # load BABEL
25 | l_babel_dense_files = ['train', 'val', 'test']
26 | l_babel_extra_files = ['extra_train', 'extra_val']
27 |
28 | babel = {}
29 | for file in l_babel_dense_files:
30 | babel[file] = json.load(open(os.path.join(config.BABEL_LOCATION, file+'.json')))
31 |
32 | for file in l_babel_extra_files:
33 | babel[file] = json.load(open(os.path.join(config.BABEL_LOCATION, file+'.json')))
34 |
35 | # load PoseScript
36 | dataID_2_pose_info = utils.read_json(config.file_pose_id_2_dataset_sequence_and_frame_index)
37 |
38 | # AMASS/BABEL path adaptation
39 | amass_to_babel_subdir = {
40 | 'ACCAD': 'ACCAD/ACCAD',
41 | 'BMLhandball': '', # not available
42 | 'BMLmovi': 'BMLmovi/BMLmovi',
43 | 'BioMotionLab_NTroje': 'BMLrub/BioMotionLab_NTroje',
44 | 'CMU': 'CMU/CMU',
45 | 'DFaust_67': 'DFaust67/DFaust_67',
46 | 'DanceDB': '', # not available
47 | 'EKUT': 'EKUT/EKUT',
48 | 'Eyes_Japan_Dataset': 'EyesJapanDataset/Eyes_Japan_Dataset',
49 | 'HumanEva': 'HumanEva/HumanEva',
50 | 'KIT': 'KIT/KIT',
51 | 'MPI_HDM05': 'MPIHDM05/MPI_HDM05',
52 | 'MPI_Limits': 'MPILimits/MPI_Limits',
53 | 'MPI_mosh': 'MPImosh/MPI_mosh',
54 | 'SFU': 'SFU/SFU',
55 | 'SSM_synced': 'SSMsynced/SSM_synced',
56 | 'TCD_handMocap': 'TCDhandMocap/TCD_handMocap',
57 | 'TotalCapture': 'TotalCapture/TotalCapture',
58 | 'Transitions_mocap': 'Transitionsmocap/Transitions_mocap',
59 | }
60 |
61 |
62 | ### GET LABELS
63 | ################################################################################
64 |
65 | def get_babel_label(amass_rel_path, frame_id):
66 |
67 | # get path correspondance in BABEL
68 | dname = amass_rel_path.split('/')[0]
69 | bname = amass_to_babel_subdir[dname]
70 | if bname == '':
71 | return '__'+dname+'__'
72 | babel_rel_path = '/'.join([bname]+amass_rel_path.split('/')[1:])
73 |
74 | # look for babel annotations
75 | babelfs = []
76 | for f in babel.keys():
77 | for s in babel[f].keys():
78 | if babel[f][s]['feat_p'] == babel_rel_path:
79 | babelfs.append((f,s))
80 |
81 | if len(babelfs) == 0:
82 | return None
83 |
84 | # convert frame id to second
85 | seqdata = np.load(os.path.join(config.AMASS_FILE_LOCATION, amass_rel_path))
86 | framerate = seqdata['mocap_framerate']
87 | t = frame_id / framerate
88 |
89 | # read babel annotations
90 | labels = []
91 | for f,s in babelfs:
92 | if not 'frame_ann' in babel[f][s]:
93 | continue
94 | if babel[f][s]['frame_ann'] is None:
95 | continue
96 | babel_annots = babel[f][s]['frame_ann']['labels']
97 | for i in range(len(babel_annots)):
98 | if t >= babel_annots[i]['start_t'] and t <= babel_annots[i]['end_t']:
99 | labels.append( (babel_annots[i]['raw_label'], babel_annots[i]['proc_label'], babel_annots[i]['act_cat']) )
100 |
101 | return labels
102 |
103 | # gather labels for all poses in PoseScript that come from AMASS
104 | babel_labels_for_posescript = {}
105 | for dataID in tqdm(dataID_2_pose_info):
106 | pose_info = dataID_2_pose_info[dataID]
107 | if pose_info[0] == "AMASS":
108 | babel_labels_for_posescript[dataID] = get_babel_label(pose_info[1], pose_info[2])
109 |
110 | # display some stats
111 | table = []
112 | table.append(['None', sum([v is None for v in babel_labels_for_posescript.values()])])
113 | table.append(['BMLhandball', sum([v=='__BMLhandball__' for v in babel_labels_for_posescript.values()])])
114 | table.append(['DanceDB', sum([v=='__DanceDB__' for v in babel_labels_for_posescript.values()])])
115 | table.append(['0 label', sum([ (isinstance(v,list) and len(v)==0) for v in babel_labels_for_posescript.values()])])
116 | table.append(['None label', sum([ (isinstance(v,list) and len(v)>=1 and v[0][0] is None) for v in babel_labels_for_posescript.values()])])
117 | table.append(['1 label', sum([ (isinstance(v,list) and len(v)==1 and v[0][0] is not None) for v in babel_labels_for_posescript.values()])])
118 | table.append(['>1 label',sum([ (isinstance(v,list) and len(v)>=2 and v[0][0] is not None) for v in babel_labels_for_posescript.values()])])
119 | print(tabulate(table, headers=["Label", "Number of poses"]))
120 |
121 | # save
122 | save_filepath = os.path.join(config.POSESCRIPT_LOCATION, f"babel_labels_for_posescript{config.version_suffix}.pkl")
123 | with open(save_filepath, 'wb') as f:
124 | pickle.dump(babel_labels_for_posescript, f)
125 | print("Saved", save_filepath)
--------------------------------------------------------------------------------
/src/text2pose/retrieval/README.md:
--------------------------------------------------------------------------------
1 | # {Text :left_right_arrow: Pose} Retrieval Model
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | ## Model overview
6 |
7 | **Possible inputs**: 3D human pose, pose description.
8 |
9 | 
10 |
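As a minimal usage sketch (not a self-contained excerpt of the repo: it assumes `model` has been loaded with `load_model` from *retrieval/evaluate_retrieval.py* and `poses_features` is a `(N, latentD)` tensor of pre-computed pose embeddings, as in *demo_retrieval.py*), text-to-pose retrieval boils down to ranking dot-product similarities in the joint embedding space:

```python
import torch

# `model`: retrieval model loaded via load_model() from retrieval/evaluate_retrieval.py
# `poses_features`: (N, latentD) tensor of pre-computed pose embeddings
with torch.no_grad():
    text_feature = model.encode_raw_text("The person is sitting cross-legged.")

scores = text_feature.view(1, -1).mm(poses_features.t())[0]  # one similarity score per pose
_, ranking = scores.sort(descending=True)                    # best-matching poses first
```
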
11 | ## :crystal_ball: Demo
12 |
13 | To look at a ranking of poses (resp. descriptions) referenced in PoseScript by relevance to your own input description (resp. chosen pose), using a pretrained model, run the following:
14 |
15 | ```
16 | bash retrieval/script_retrieval.sh 'demo'
17 | ```
18 |
19 | ## :bullettrain_front: Train
20 |
21 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
22 |
23 | Then use the following command:
24 | ```
25 | bash retrieval/script_retrieval.sh 'train'
26 | ```
27 |
28 | **Note for the finetuning step**: In the script, `pretrained` defines the nickname of the pretrained model. The mapping between nicknames and actual model paths is given by *shortname_2_model_path.txt*. This means that if you train a model and intend to use its weights to train another, you should first write its path in *shortname_2_model_path.txt*, give it a nickname, and use this nickname as the value of the `pretrained` argument in the script. The nickname will appear in the path of the finetuned model.
29 |
30 | ## :dart: Evaluate
31 |
32 | Use the following command (test on PoseScript-H2):
33 | ```
34 | bash retrieval/script_retrieval.sh 'eval'
35 | ```
36 |
--------------------------------------------------------------------------------
/src/text2pose/retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/retrieval/demo_retrieval.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import argparse
12 | import torch
13 |
14 | import text2pose.demo as demo
15 | import text2pose.utils as utils
16 | import text2pose.utils_visu as utils_visu
17 | from text2pose.retrieval.evaluate_retrieval import load_model
18 |
19 |
20 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
21 | parser.add_argument('--model_path', type=str, help='Path to the model.')
22 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.')
23 | parser.add_argument('--n_retrieve', type=int, default=12, help="Number of elements to retrieve.")
24 | args = parser.parse_args()
25 |
26 |
27 | ### INPUT
28 | ################################################################################
29 |
30 | data_version_annotations = "posescript-H2" # defines what annotations to use as query examples
31 | data_version_poses_collection = "posescript-A2" # defines the set of poses to rank
32 |
33 |
34 | ### SETUP
35 | ################################################################################
36 |
37 | # --- data
38 | available_splits = ['train', 'val', 'test']
39 | model, tokenizer_name, body_model = demo.setup_models([args.model_path], args.checkpoint, load_model)
40 | model, tokenizer_name = model[0], tokenizer_name[0]
41 | dataID_2_pose_info, captions = demo.setup_posescript_data(data_version_annotations)
42 |
43 |
44 | ### MAIN APP
45 | ################################################################################
46 |
47 | # define query input interface: split selection
48 | cols_query = st.columns(3)
49 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
50 |
51 | # precompute features
52 | dataIDs = demo.setup_posescript_split(split_for_research)
53 | pose_dataIDs, poses_features = demo.precompute_posescript_pose_features(data_version_poses_collection, split_for_research, model)
54 | text_dataIDs, text_features = demo.precompute_text_features(data_version_annotations, split_for_research, model, tokenizer_name)
55 |
56 | # define query input interface: example selection
57 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'))
58 | number = cols_query[2].number_input("Split index or ID:", 0)
59 | st.markdown("""---""")
60 |
61 | # get query data
62 | pose_ID, pose_data, pose_img, default_description = demo.get_posescript_datapoint(number, query_type, split_for_research, captions, dataID_2_pose_info, body_model)
63 |
64 | # show query data
65 | cols_input = st.columns(2)
66 | cols_input[0].image(pose_img, caption="Annotated pose")
67 | if default_description:
68 | cols_input[1].write("Annotated text:")
69 | cols_input[1].write(f"_{default_description}_")
70 | else:
71 | cols_input[1].write("_(Not annotated.)_")
72 |
73 | # get retrieval direction
74 | dt2p = "Text-2-Pose"
75 | dp2t = "Pose-2-Text"
76 | retrieval_direction = st.radio("Retrieval direction:", [dt2p, dp2t])
77 |
78 | # TEXT-2-POSE
79 | if retrieval_direction == dt2p:
80 |
81 | # get input description
82 | description = cols_input[1].text_area("Pose description:",
83 | value=default_description,
84 | placeholder="The person is...",
85 | height=None, max_chars=None)
86 |
87 | # encode text
88 | with torch.no_grad():
89 | text_feature = model.encode_raw_text(description)
90 |
91 | # rank poses by relevance and get their pose id
92 | scores = text_feature.view(1, -1).mm(poses_features.t())[0]
93 | _, indices_rank = scores.sort(descending=True)
94 | relevant_pose_ids = [pose_dataIDs[i] for i in indices_rank[:args.n_retrieve]]
95 |
96 | # get corresponding pose data
97 | all_pose_data = []
98 | for pose_id in relevant_pose_ids:
99 | pose_info = dataID_2_pose_info[str(pose_id)]
100 | all_pose_data.append(utils.get_pose_data_from_file(pose_info))
101 | all_pose_data = torch.cat(all_pose_data)
102 |
103 | # render poses
104 | imgs = utils_visu.image_from_pose_data(all_pose_data, body_model, color="blue", add_ground_plane=True)
105 |
106 | # display images
107 | st.markdown("""---""")
108 | st.write(f"**Retrieved poses for this description [{split_for_research} split]:**")
109 | cols = st.columns(demo.nb_cols)
110 | for i in range(args.n_retrieve):
111 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i]))
112 |
113 | # POSE-2-TEXT
114 | elif retrieval_direction == dp2t:
115 |
116 | # rank texts by relevance and get their id
117 | pose_index = pose_dataIDs.index(pose_ID)
118 | scores = poses_features[pose_index].view(1, -1).mm(text_features.t())[0]
119 | _, indices_rank = scores.sort(descending=True)
120 | relevant_pose_ids = [text_dataIDs[i] for i in indices_rank[:args.n_retrieve]]
121 |
122 | # get corresponding text data (the text features were obtained using the first text)
123 | texts = [captions[pose_id][0] for pose_id in relevant_pose_ids]
124 |
125 | # display texts
126 | st.markdown("""---""")
127 | st.write(f"**Retrieved descriptions for this pose [{split_for_research} split]:**")
128 | for i in range(args.n_retrieve):
129 | st.write(f"**({i+1})** {texts[i]}")
130 |
131 | st.markdown("""---""")
132 | st.write(f"_Results obtained with model: {args.model_path}_")
--------------------------------------------------------------------------------
/src/text2pose/retrieval/evaluate_retrieval.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | import torch
12 | from tqdm import tqdm
13 |
14 | import text2pose.config as config
15 | import text2pose.evaluate as evaluate
16 | from text2pose.data import PoseScript
17 | from text2pose.encoders.tokenizers import get_tokenizer_name
18 | from text2pose.retrieval.model_retrieval import PoseText
19 | from text2pose.loss import BBC
20 |
21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
22 |
23 | OVERWRITE_RESULT = False
24 |
25 |
26 | ################################################################################
27 |
28 | def load_model(model_path, device):
29 |
30 | assert os.path.isfile(model_path), "File {} not found.".format(model_path)
31 |
32 | # load checkpoint & model info
33 | ckpt = torch.load(model_path, 'cpu')
34 | text_encoder_name = ckpt['args'].text_encoder_name
35 | transformer_topping = getattr(ckpt['args'], 'transformer_topping', None)
36 | latentD = ckpt['args'].latentD
37 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52)
38 |
39 | # load model
40 | model = PoseText(text_encoder_name=text_encoder_name,
41 | transformer_topping=transformer_topping,
42 | latentD=latentD,
43 | num_body_joints=num_body_joints
44 | ).to(device)
45 | model.load_state_dict(ckpt['model'])
46 | model.eval()
47 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path)
48 |
49 | return model, get_tokenizer_name(text_encoder_name)
50 |
51 |
52 | def eval_model(model_path, dataset_version, split='val', generated_pose_samples=None):
53 |
54 | device = torch.device('cuda:0')
55 |
56 | # define result file & get auxiliary info
57 | generated_pose_samples_path, precision = get_evaluation_auxiliary_info(model_path, generated_pose_samples)
58 | nb_caps = config.caption_files[dataset_version][0]
59 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps)
60 |
61 | # load model if results for at least one caption is missing
62 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps):
63 | model, tokenizer_name = load_model(model_path, device)
64 |
65 | # compute or load results for the given run & caption
66 | results = {}
67 | for cap_ind in range(nb_caps):
68 | filename_res = get_res_file(cap_ind)
69 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT:
70 | if "posescript" in dataset_version:
71 | d = PoseScript(version=dataset_version, split=split, tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True, generated_pose_samples_path=generated_pose_samples_path)
72 | else:
73 | raise NotImplementedError
74 | cap_results = compute_eval_metrics(model, d, device)
75 | evaluate.save_results_to_file(cap_results, filename_res)
76 | else:
77 | cap_results = evaluate.load_results_from_file(filename_res)
78 | # aggregate results
79 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()}
80 |
81 | # average over captions
82 | results = {k:sum(v)/nb_caps for k,v in results.items()}
83 |
84 | return {k:[v] for k, v in results.items()}
85 |
86 |
87 | def get_evaluation_auxiliary_info(model_path, generated_pose_samples):
88 | precision = "" # default
89 | generated_pose_samples_path = None # default
90 | if generated_pose_samples:
91 | precision = f"_gensample_{generated_pose_samples}"
92 | seed = evaluate.get_seed_from_model_path(model_path)
93 | generated_pose_samples_model_path = (config.shortname_2_model_path[generated_pose_samples]).format(seed=seed)
94 | generated_pose_samples_path = config.generated_pose_path % os.path.dirname(generated_pose_samples_model_path)
95 | return generated_pose_samples_path, precision
96 |
97 |
98 | def compute_eval_metrics(model, dataset, device, compute_loss=False):
99 |
100 | # get data features
101 | poses_features, texts_features = infer_features(model, dataset, device)
102 |
103 | # pose-2-text matching
104 | p2t_recalls = evaluate.x2y_recall_metrics(poses_features, texts_features, config.k_recall_values, sstr="p2t_")
105 | # text-2-pose matching
106 | t2p_recalls = evaluate.x2y_recall_metrics(texts_features, poses_features, config.k_recall_values, sstr="t2p_")
107 | # r-precision
108 | rprecisions = evaluate.textret_metrics(texts_features, poses_features)
109 |
110 | # gather metrics
111 | recalls = {"mRecall": (sum(p2t_recalls.values()) + sum(t2p_recalls.values())) / (2 * len(config.k_recall_values))}
112 | recalls.update(p2t_recalls)
113 | recalls.update(t2p_recalls)
114 | recalls.update(rprecisions)
115 |
116 | # loss
117 | if compute_loss:
118 | score_t2p = texts_features.mm(poses_features.t())
119 | loss = BBC(score_t2p*model.loss_weight)
120 | loss_value = loss.item()
121 | return recalls, loss_value
122 |
123 | return recalls
124 |
125 |
126 | def infer_features(model, dataset, device):
127 |
128 | batch_size = 32
129 | data_loader = torch.utils.data.DataLoader(
130 | dataset, sampler=None, shuffle=False,
131 | batch_size=batch_size,
132 | num_workers=8,
133 | pin_memory=True,
134 | drop_last=False
135 | )
136 |
137 | poses_features = torch.zeros(len(dataset), model.latentD).to(device)
138 | texts_features = torch.zeros(len(dataset), model.latentD).to(device)
139 |
140 | for i, batch in tqdm(enumerate(data_loader)):
141 | poses = batch['pose'].to(device)
142 | caption_tokens = batch['caption_tokens'].to(device)
143 | caption_lengths = batch['caption_lengths'].to(device)
144 | caption_tokens = caption_tokens[:,:caption_lengths.max()]
145 | with torch.inference_mode():
146 | pfeat, tfeat = model(poses, caption_tokens, caption_lengths)
147 | poses_features[i*batch_size:i*batch_size+len(poses)] = pfeat
148 | texts_features[i*batch_size:i*batch_size+len(poses)] = tfeat
149 |
150 | return poses_features, texts_features
151 |
152 |
153 | def display_results(results):
154 | metric_order = ['mRecall'] + ['%s_R@%d'%(d, k) for d in ['p2t', 't2p'] for k in config.k_recall_values]
155 | results = evaluate.scale_and_format_results(results)
156 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n")
157 |
158 |
159 | ################################################################################
160 |
161 | if __name__ == '__main__':
162 |
163 | # added special arguments
164 | evaluate.eval_parser.add_argument('--generated_pose_samples', default=None, help="Shortname of the model that generated the pose files to be used (full path registered in config.py).")
165 |
166 | args = evaluate.eval_parser.parse_args()
167 | args = evaluate.get_full_model_path(args)
168 |
169 | # compute results
170 | if args.average_over_runs:
171 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, dataset_version=args.dataset, split=args.split, generated_pose_samples=args.generated_pose_samples)
172 | else:
173 | ret = eval_model(args.model_path, dataset_version=args.dataset, split=args.split, generated_pose_samples=args.generated_pose_samples)
174 |
175 | # display results
176 | print(ret)
177 | display_results(ret)
--------------------------------------------------------------------------------
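The per-caption aggregation in `eval_model` above (lines 79 and 82) builds, for every metric, a list with one value per caption and then averages it. A minimal standalone sketch of that pattern, with hypothetical metric values, may make the one-liners easier to read:

```
# Sketch of the per-caption aggregation used in eval_model (illustrative values only).
def aggregate_over_captions(per_caption_results):
    """per_caption_results: list of dicts, one per caption, all sharing the same keys."""
    results = {}
    for cap_results in per_caption_results:
        # line 79: start the lists on the first caption, then append
        results = ({k: [v] for k, v in cap_results.items()} if not results
                   else {k: results[k] + [v] for k, v in cap_results.items()})
    # line 82: average over captions
    return {k: sum(v) / len(per_caption_results) for k, v in results.items()}

# hypothetical metrics computed for two captions of the same poses
print(aggregate_over_captions([{"mRecall": 40.0, "p2t_R@1": 20.0},
                               {"mRecall": 44.0, "p2t_R@1": 24.0}]))
# -> {'mRecall': 42.0, 'p2t_R@1': 22.0}
```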
/src/text2pose/retrieval/model_retrieval.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch
11 | from torch import nn
12 |
13 | import text2pose.config as config
14 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name
15 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder
16 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder
17 |
18 |
19 | class PoseText(nn.Module):
20 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512,
21 | num_body_joints=config.NB_INPUT_JOINTS,
22 | text_encoder_name='distilbertUncased', transformer_topping=None):
23 | super(PoseText, self).__init__()
24 |
25 | self.latentD = latentD
26 |
27 | # Define pose encoder
28 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons,
29 | num_neurons_mini=num_neurons_mini,
30 | latentD=latentD,
31 | num_body_joints=num_body_joints,
32 | role="retrieval")
33 |
34 | # Define text encoder
35 | self.text_encoder_name = text_encoder_name
36 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name)
37 | if module_ref in ["glovebigru"]:
38 | self.text_encoder = TextEncoder(self.text_encoder_name, latentD=latentD, role="retrieval")
39 | elif module_ref in ["glovetransf", "distilbertUncased"]:
40 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, latentD=latentD, topping=transformer_topping, role="retrieval")
41 | else:
42 | raise NotImplementedError
43 |
44 | # Loss temperature
45 | self.loss_weight = torch.nn.Parameter( torch.FloatTensor((10,)) )
46 | self.loss_weight.requires_grad = True
47 |
48 | def forward(self, pose, captions, caption_lengths):
49 | pose_embs = self.pose_encoder(pose)
50 | text_embs = self.text_encoder(captions, caption_lengths)
51 | return pose_embs, text_embs
52 |
53 | def encode_raw_text(self, raw_text):
54 | if not hasattr(self, 'tokenizer'):
55 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name))
56 | tokens = self.tokenizer(raw_text).to(device=self.loss_weight.device)
57 | length = torch.tensor([ len(tokens) ], dtype=tokens.dtype)
58 | text_embs = self.text_encoder(tokens.view(1, -1), length)
59 | return text_embs
60 |
61 | def encode_pose(self, pose):
62 | return self.pose_encoder(pose)
63 |
64 | def encode_text(self, captions, caption_lengths):
65 | return self.text_encoder(captions, caption_lengths)
--------------------------------------------------------------------------------
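A minimal sketch of how the `PoseText` model defined above can be used for retrieval at inference time, mirroring `compute_eval_metrics` and `encode_raw_text`: encode poses and a raw text query separately, then rank poses by dot-product similarity. The pose tensor layout and the freshly constructed (no checkpoint) model are assumptions for illustration; in practice, weights are loaded with `load_model` from `evaluate_retrieval.py`.

```
import torch
from text2pose.retrieval.model_retrieval import PoseText

# assumption: default architecture, no checkpoint loaded, CPU, illustration only
model = PoseText(latentD=512, text_encoder_name='distilbertUncased', transformer_topping='avgp')
model.eval()

# assumed input layout: (batch, n_joints, 3) axis-angle joint rotations
poses = torch.zeros(4, model.pose_encoder.num_body_joints, 3)
with torch.inference_mode():
    pose_feats = model.encode_pose(poses)                         # (4, latentD)
    text_feat = model.encode_raw_text("the person is crouching")  # (1, latentD)

# same scoring as in compute_eval_metrics: text-to-pose similarity
scores = text_feat.mm(pose_feats.t())                             # (1, 4)
print(scores.argmax(dim=1))  # index of the best-matching pose
```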
/src/text2pose/retrieval/script_retrieval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2022, 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1 # (train|eval|demo)
17 | checkpoint_type="best" # (last|best)
18 |
19 | architecture_args=(
20 | --model PoseText
21 | --latentD 512
22 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp"
23 | # --text_encoder_name 'glovebigru_vocPSA2H2'
24 | )
25 |
26 | loss_args=(
27 | --retrieval_loss 'symBBC'
28 | )
29 |
30 | bonus_args=(
31 | )
32 |
33 | pretrained="ret_distilbert_dataPSA2" # used only if phase=='finetune'
34 |
35 |
36 | ##############################################################
37 | # EXECUTE
38 |
39 | # TRAIN
40 | if [[ "$action" == *"train"* ]]; then
41 |
42 | phase=$2 # (pretrain|finetune)
43 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
44 | seed=$3
45 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
46 |
47 | # PRETRAIN
48 | if [[ "$phase" == *"pretrain"* ]]; then
49 |
50 | python retrieval/train_retrieval.py --dataset "posescript-A2" \
51 | "${architecture_args[@]}" \
52 | "${loss_args[@]}" \
53 | "${bonus_args[@]}" \
54 | --lr_scheduler "stepLR" --lr 0.0002 --lr_step 400 --lr_gamma 0.5 \
55 | --log_step 20 --val_every 20 \
56 | --batch_size 512 --epochs 1000 --seed $seed
57 |
58 | # FINETUNE
59 | elif [[ "$phase" == *"finetune"* ]]; then
60 |
61 | python retrieval/train_retrieval.py --dataset "posescript-H2" \
62 | "${architecture_args[@]}" \
63 | "${loss_args[@]}" \
64 | "${bonus_args[@]}" \
65 | --apply_LR_augmentation \
66 | --lr_scheduler "stepLR" --lr 0.0002 --lr_step 40 --lr_gamma 0.5 \
67 | --batch_size 32 --epochs 200 --seed $seed \
68 | --pretrained $pretrained
69 |
70 | fi
71 |
72 | fi
73 |
74 |
75 | # EVAL QUANTITATIVELY
76 | if [[ "$action" == *"eval"* ]]; then
77 |
78 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
79 |
80 | for model_path in "${experiments[@]}"
81 | do
82 | echo $model_path
83 | python retrieval/evaluate_retrieval.py --dataset "posescript-H2" \
84 | --model_path ${model_path} --checkpoint $checkpoint_type \
85 | --split test
86 | done
87 | fi
88 |
89 |
90 | # EVAL QUALITATIVELY
91 | if [[ "$action" == *"demo"* ]]; then
92 |
93 | experiment=$2 # only one at a time
94 | streamlit run retrieval/demo_retrieval.py -- --model_path $experiment --checkpoint $checkpoint_type
95 |
96 | fi
--------------------------------------------------------------------------------
/src/text2pose/retrieval/train_retrieval.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2022, 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch
11 | import math
12 | import sys
13 | import os
14 | os.umask(0o002)
15 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
16 |
17 | from text2pose.option import get_args_parser
18 | from text2pose.trainer import GenericTrainer
19 | from text2pose.retrieval.model_retrieval import PoseText
20 | from text2pose.retrieval.evaluate_retrieval import compute_eval_metrics
21 | from text2pose.loss import BBC, symBBC
22 | from text2pose.data import PoseScript, PoseFix
23 | from text2pose.encoders.tokenizers import get_tokenizer_name
24 | from text2pose.data_augmentations import DataAugmentation
25 |
26 | import text2pose.config as config
27 | import text2pose.utils_logging as logging
28 |
29 |
30 | ################################################################################
31 |
32 |
33 | class PoseTextTrainer(GenericTrainer):
34 |
35 | def __init__(self, args):
36 | super(PoseTextTrainer, self).__init__(args, retrieval_trainer=True)
37 |
38 |
39 | def load_dataset(self, split, caption_index, tokenizer_name=None):
40 |
41 | if tokenizer_name is None: tokenizer_name = get_tokenizer_name(self.args.text_encoder_name)
42 | data_size = self.args.data_size if split=="train" else None
43 |
44 | if "posescript" in self.args.dataset:
45 | d = PoseScript(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size)
46 | elif "posefix" in self.args.dataset:
47 | d = PoseFix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size, posescript_format=True)
48 | else:
49 | raise NotImplementedError
50 | return d
51 |
52 |
53 | def init_model(self):
54 | print('Load model')
55 | self.model = PoseText(text_encoder_name=self.args.text_encoder_name,
56 | transformer_topping=self.args.transformer_topping,
57 | latentD=self.args.latentD,
58 | num_body_joints=self.args.num_body_joints)
59 | self.model.to(self.device)
60 |
61 |
62 | def get_param_groups(self):
63 | param_groups = []
64 | param_groups.append({'params': self.model.pose_encoder.parameters(), 'lr': self.args.lr*self.args.lrposemul})
65 | param_groups.append({'params': [p for k,p in self.model.text_encoder.named_parameters() if 'pretrained_text_encoder.' not in k], 'lr': self.args.lr*self.args.lrtextmul})
66 | param_groups.append({'params': [self.model.loss_weight]})
67 | return param_groups
68 |
69 |
70 | def init_optimizer(self):
71 | assert self.args.optimizer=='Adam'
72 | param_groups = self.get_param_groups()
73 | self.optimizer = torch.optim.Adam(param_groups, lr=self.args.lr)
74 |
75 |
76 | def init_lr_scheduler(self):
77 | self.lr_scheduler = None
78 | if self.args.lr_scheduler == "stepLR":
79 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
80 | step_size=self.args.lr_step,
81 | gamma=self.args.lr_gamma,
82 | last_epoch=-1)
83 |
84 |
85 | def init_other_training_elements(self):
86 | self.data_augmentation_module = DataAugmentation(self.args, mode="posescript", tokenizer_name=get_tokenizer_name(self.args.text_encoder_name), nb_joints=self.args.num_body_joints)
87 |
88 |
89 | def training_epoch(self, epoch):
90 | train_stats = self.one_epoch(epoch=epoch, is_training=True)
91 | return train_stats
92 |
93 |
94 | def validation_epoch(self, epoch):
95 | val_stats = {}
96 | if self.args.val_every and (epoch+1)%self.args.val_every==0:
97 | val_stats = self.validate(epoch=epoch)
98 | return val_stats
99 |
100 |
101 | def one_epoch(self, epoch, is_training):
102 |
103 | self.model.train(is_training)
104 |
105 | # define loggers
106 | metric_logger = logging.MetricLogger(delimiter=" ")
107 | if is_training:
108 | prefix, sstr = '', 'train'
109 | metric_logger.add_meter(f'{sstr}_lr', logging.SmoothedValue(window_size=1, fmt='{value:.6f}'))
110 | else:
111 | prefix, sstr = '[val] ', 'val'
112 | header = f'{prefix}Epoch: [{epoch}]'
113 |
114 | # define dataloader & other elements
115 | if is_training:
116 | data_loader = self.data_loader_train
117 | if not is_training:
118 | data_loader = self.data_loader_val
119 |
120 | # iterate over the batches
121 | for data_iter_step, item in enumerate(metric_logger.log_every(data_loader, self.args.log_step, header)):
122 |
123 | # get data
124 | poses = item['pose'].to(self.device)
125 | caption_tokens = item['caption_tokens'].to(self.device)
126 | caption_lengths = item['caption_lengths'].to(self.device)
127 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # truncate within the batch, based on the longest text
128 |
129 | # online random augmentations
130 | poses, caption_tokens, caption_lengths = self.data_augmentation_module(poses, caption_tokens, caption_lengths)
131 |
132 | # forward; compute scores
133 | with torch.set_grad_enabled(is_training):
134 | poses_features, texts_features = self.model(poses, caption_tokens, caption_lengths)
135 | score_t2p = texts_features.mm(poses_features.t()) * self.model.loss_weight
136 |
137 | # compute loss
138 | if self.args.retrieval_loss == "BBC":
139 | loss = BBC(score_t2p)
140 | elif self.args.retrieval_loss == "symBBC":
141 | loss = symBBC(score_t2p)
142 | else:
143 | raise NotImplementedError
144 |
145 | loss_value = loss.item()
146 | if not math.isfinite(loss_value):
147 | print("Loss is {}, stopping training".format(loss_value))
148 | sys.exit(1)
149 |
150 | # training step
151 | if is_training:
152 | self.optimizer.zero_grad()
153 | loss.backward()
154 | self.optimizer.step()
155 |
156 | # format data for logging
157 | scalars = [('loss', loss_value)]
158 | if is_training:
159 | lr_value = self.optimizer.param_groups[0]["lr"]
160 | scalars += [('lr', lr_value)]
161 |
162 | # actually log
163 | self.add_data_to_log_writer(epoch, sstr, scalars=scalars, is_training=is_training, data_iter_step=data_iter_step, total_steps=len(data_loader))
164 | self.add_data_to_metric_logger(metric_logger, sstr, scalars)
165 |
166 | print("Averaged stats:", metric_logger)
167 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
168 |
169 |
170 | def validate(self, epoch):
171 |
172 | self.model.eval()
173 |
174 | recalls, loss_value = compute_eval_metrics(self.model, self.data_loader_val.dataset, self.device, compute_loss=True)
175 | val_stats = {"loss": loss_value}
176 | val_stats.update(recalls)
177 |
178 | # log
179 | self.add_data_to_log_writer(epoch, 'val', scalars=[('loss', loss_value), ('validation', recalls)], should_log_data=True)
180 | print(f"[val] Epoch: [{epoch}] Stats: " + " ".join(f"{k}: {round(v, 3)}" for k,v in val_stats.items()) )
181 | return val_stats
182 |
183 |
184 | if __name__ == '__main__':
185 |
186 | argparser = get_args_parser()
187 | args = argparser.parse_args()
188 |
189 | PoseTextTrainer(args)()
--------------------------------------------------------------------------------
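The trainer above picks between `BBC` and `symBBC` from `text2pose.loss`, which is not included in this excerpt. For context, a batch-based contrastive loss of this kind is commonly a cross-entropy over in-batch negatives with the diagonal of the score matrix as positives; the sketch below illustrates that idea under this assumption and is not the repository's exact implementation:

```
import torch
import torch.nn.functional as F

def bbc_sketch(scores):
    """scores: (B, B) similarity matrix, scores[i, j] = text_i . pose_j, already temperature-scaled.
    Assumes the matched (positive) pairs lie on the diagonal."""
    targets = torch.arange(scores.shape[0], device=scores.device)
    return F.cross_entropy(scores, targets)

def sym_bbc_sketch(scores):
    """Symmetric variant: average the text-to-pose and pose-to-text directions."""
    return 0.5 * (bbc_sketch(scores) + bbc_sketch(scores.t()))

# toy check: a strongly diagonal score matrix yields a small loss
scores = 10 * torch.eye(4) + 0.1 * torch.randn(4, 4)
print(bbc_sketch(scores).item(), sym_bbc_sketch(scores).item())
```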
/src/text2pose/retrieval_modifier/README.md:
--------------------------------------------------------------------------------
1 | # {Pose-Pair :left_right_arrow: Instruction} Retrieval Model
2 |
3 | _:warning: In what follows, command lines are assumed to be launched from `./src/text2pose`._
4 |
5 | ## Model overview
6 |
7 | **Possible inputs**: a pair of 3D human poses (i.e. two poses at once), a text instruction.
8 |
9 | 
10 |
11 | ## :crystal_ball: Demo
12 |
13 | To look at a ranking of the text instructions (resp. pose pairs) referenced in PoseFix, sorted by relevance to a chosen pose pair (resp. text instruction) and obtained with a pretrained model, run the following:
14 |
15 | ```
16 | bash retrieval_modifier/script_retrieval_modifier.sh 'demo'
17 | ```
18 |
19 | ## :bullettrain_front: Train
20 |
21 | :memo: Modify the variables at the top of the bash script to specify the desired model & training options.
22 |
23 | Then use the following command:
24 | ```
25 | bash retrieval_modifier/script_retrieval_modifier.sh 'train'
26 | ```
27 |
28 | **Note for the finetuning step**: In the script, `pretrained` holds the nickname of the pretrained model. The mapping between nicknames and actual model paths is given in *shortname_2_model_path.txt*. Thus, if you train a model and intend to use its weights to train another, first register its path in *shortname_2_model_path.txt* under a nickname, then set that nickname as the value of `pretrained` in the script. The nickname will appear in the path of the finetuned model.
29 |
30 | ## :dart: Evaluate
31 |
32 | Use the following command (test on PoseFix-H):
33 | ```
34 | bash retrieval_modifier/script_retrieval_modifier.sh 'eval'
35 | ```
36 |
--------------------------------------------------------------------------------
/src/text2pose/retrieval_modifier/__init__.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
--------------------------------------------------------------------------------
/src/text2pose/retrieval_modifier/demo_retrieval_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import streamlit as st
11 | import argparse
12 | import torch
13 |
14 | import text2pose.config as config
15 | import text2pose.demo as demo
16 | import text2pose.utils as utils
17 | import text2pose.utils_visu as utils_visu
18 | from text2pose.retrieval_modifier.evaluate_retrieval_modifier import load_model
19 |
20 |
21 | parser = argparse.ArgumentParser(description='Parameters for the demo.')
22 | parser.add_argument('--model_path', type=str, help='Path to the model.')
23 | parser.add_argument('--checkpoint', default='best', choices=('best', 'last'), help='Checkpoint to choose if model path is incomplete.')
24 | parser.add_argument('--n_retrieve', type=int, default=12, help="Number of elements to retrieve.")
25 | args = parser.parse_args()
26 |
27 | args.n_retrieve = 12 # NOTE: fixed value; overrides any --n_retrieve passed on the command line
28 |
29 | ### INPUT
30 | ################################################################################
31 |
32 | data_version_annotations = "posefix-H" # defines what annotations to use as query examples
33 | data_version_poses_collection = "posefix-A" # defines the set of poses to rank
34 |
35 |
36 | ### SETUP
37 | ################################################################################
38 |
39 | # --- data
40 | available_splits = ['train', 'val', 'test']
41 | model, tokenizer_name, body_model = demo.setup_models([args.model_path], args.checkpoint, load_model)
42 | model, tokenizer_name = model[0], tokenizer_name[0]
43 | dataID_2_pose_info, triplet_data = demo.setup_posefix_data(data_version_annotations)
44 | pose_pairs = utils.read_json(config.file_pair_id_2_pose_ids)
45 |
46 |
47 | ### MAIN APP
48 | ################################################################################
49 |
50 | # define query input interface: split selection
51 | cols_query = st.columns(3)
52 | split_for_research = cols_query[0].selectbox('Split:', tuple(available_splits), index=available_splits.index('test'))
53 |
54 | # precompute features
55 | dataIDs = demo.setup_posefix_split(split_for_research)
56 | pair_dataIDs, pairs_features = demo.precompute_posefix_pair_features(data_version_poses_collection, split_for_research, model)
57 | text_dataIDs, text_features = demo.precompute_text_features(data_version_annotations, split_for_research, model, tokenizer_name)
58 |
59 | # define query input interface: example selection
60 | query_type = cols_query[1].selectbox("Query type:", ('Split index', 'ID'))
61 | number = cols_query[2].number_input("Split index or ID:", 0)
62 | st.markdown("""---""")
63 |
64 | # get query data
65 | pair_ID, pid_A, pid_B, pose_A_data, pose_B_data, pose_A_img, pose_B_img, default_modifier = demo.get_posefix_datapoint(number, query_type, split_for_research, triplet_data, pose_pairs, dataID_2_pose_info, body_model)
66 |
67 | # show query data
68 | cols_input = st.columns([1,1,2])
69 | cols_input[0].image(pose_A_img, caption="Pose A")
70 | cols_input[1].image(pose_B_img, caption="Pose B")
71 | if default_modifier:
72 | cols_input[2].write("Annotated text:")
73 | cols_input[2].write(f"_{default_modifier}_")
74 | else:
75 | cols_input[2].write("_(Not annotated.)_")
76 |
77 | # get retrieval direction
78 | dt2p = "Text-2-Pair"
79 | dp2t = "Pair-2-Text"
80 | retrieval_direction = st.radio("Retrieval direction:", [dt2p, dp2t])
81 |
82 | # TEXT-2-PAIR
83 | if retrieval_direction == dt2p:
84 |
85 | # get input modifier
86 | modifier = cols_input[2].text_area("Pose modifier:",
87 | placeholder="Move your right arm... lift your left leg...",
88 | value=default_modifier,
89 | height=None, max_chars=None)
90 | # encode text
91 | with torch.no_grad():
92 | text_feature = model.encode_raw_text(modifier)
93 |
94 | # rank poses by relevance and get their pose id
95 | scores = text_feature.view(1, -1).mm(pairs_features.t())[0]
96 | _, indices_rank = scores.sort(descending=True)
97 | relevant_pair_ids = [pair_dataIDs[i] for i in indices_rank[:args.n_retrieve]]
98 |
99 | # get corresponding pair data and render the pairs as images
100 | imgs = []
101 | for pair_id in relevant_pair_ids:
102 | ret_pid_A, ret_pid_B = pose_pairs[pair_id]
103 | pose_A_info = dataID_2_pose_info[str(ret_pid_A)]
104 | pose_A_data, rA = utils.get_pose_data_from_file(pose_A_info, output_rotation=True)
105 | pose_B_info = dataID_2_pose_info[str(ret_pid_B)]
106 | pose_B_data = utils.get_pose_data_from_file(pose_B_info, applied_rotation=rA if pose_A_info[1]==pose_B_info[1] else None)
107 | imgs.append(utils_visu.image_from_pair_data(pose_A_data, pose_B_data, body_model, add_ground_plane=True))
108 |
109 | # display images
110 | st.markdown("""---""")
111 | st.write(f"**Retrieved pairs for this modifier [{split_for_research} split]:**")
112 | cols = st.columns(demo.nb_cols)
113 | for i in range(args.n_retrieve):
114 | cols[i%demo.nb_cols].image(demo.process_img(imgs[i]))
115 |
116 | # PAIR-2-TEXT
117 | elif retrieval_direction == dp2t:
118 |
119 | # rank texts by relevance and get their id
120 | pair_index = pair_dataIDs.index(pair_ID)
121 | scores = pairs_features[pair_index].view(1, -1).mm(text_features.t())[0]
122 | _, indices_rank = scores.sort(descending=True)
123 | relevant_pair_ids = [text_dataIDs[i] for i in indices_rank[:args.n_retrieve]]
124 |
125 | # get corresponding text data (the text features were obtained using the first text)
126 | texts = [triplet_data[pair_id]['modifier'][0] for pair_id in relevant_pair_ids]
127 |
128 | # display texts
129 | st.markdown("""---""")
130 | st.write(f"**Retrieved modifiers for this pair [{split_for_research} split]:**")
131 | for i in range(args.n_retrieve):
132 | st.write(f"**({i+1})** {texts[i]}")
133 |
134 | st.markdown("""---""")
135 | st.write(f"_Results obtained with model: {args.model_path}_")
--------------------------------------------------------------------------------
/src/text2pose/retrieval_modifier/evaluate_retrieval_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import os
11 | import torch
12 | from tqdm import tqdm
13 |
14 | import text2pose.config as config
15 | import text2pose.evaluate as evaluate
16 | from text2pose.data import PoseFix
17 | from text2pose.encoders.tokenizers import get_tokenizer_name
18 | from text2pose.retrieval_modifier.model_retrieval_modifier import PairText
19 | from text2pose.loss import BBC
20 |
21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
22 |
23 | OVERWRITE_RESULT = False
24 |
25 |
26 | ################################################################################
27 |
28 | def load_model(model_path, device):
29 |
30 | assert os.path.isfile(model_path), "File {} not found.".format(model_path)
31 |
32 | # load checkpoint & model info
33 | ckpt = torch.load(model_path, 'cpu')
34 | text_encoder_name = ckpt['args'].text_encoder_name
35 | transformer_topping = ckpt['args'].transformer_topping
36 | latentD = ckpt['args'].latentD
37 | num_body_joints = getattr(ckpt['args'], 'num_body_joints', 52)
38 |
39 | # load model
40 | model = PairText(text_encoder_name=text_encoder_name,
41 | transformer_topping=transformer_topping,
42 | latentD=latentD,
43 | num_body_joints=num_body_joints
44 | ).to(device)
45 | model.load_state_dict(ckpt['model'])
46 | model.eval()
47 | print(f"Loaded model from (epoch {ckpt['epoch']}):", model_path)
48 |
49 | return model, get_tokenizer_name(text_encoder_name)
50 |
51 |
52 | def eval_model(model_path, dataset_version, split='val'):
53 |
54 | device = torch.device('cuda:0')
55 |
56 | # define result file
57 | precision = "" # default
58 | nb_caps = config.caption_files[dataset_version][0]
59 | get_res_file = evaluate.get_result_filepath_func(model_path, split, dataset_version, precision, nb_caps)
60 |
61 | # load the model if results are missing for at least one caption
62 | if OVERWRITE_RESULT or evaluate.one_result_file_is_missing(get_res_file, nb_caps):
63 | model, tokenizer_name = load_model(model_path, device)
64 |
65 | # compute or load results for the given run & caption
66 | results = {}
67 | for cap_ind in range(nb_caps):
68 | filename_res = get_res_file(cap_ind)
69 | if not os.path.isfile(filename_res) or OVERWRITE_RESULT:
70 | if "posefix" in dataset_version:
71 | d = PoseFix(version=dataset_version, split=split, tokenizer_name=tokenizer_name, caption_index=cap_ind, num_body_joints=model.pose_encoder.num_body_joints, cache=True)
72 | else:
73 | raise NotImplementedError
74 | cap_results = compute_eval_metrics(model, d, device)
75 | evaluate.save_results_to_file(cap_results, filename_res)
76 | else:
77 | cap_results = evaluate.load_results_from_file(filename_res)
78 | # aggregate results
79 | results = {k:[v] for k, v in cap_results.items()} if not results else {k:results[k]+[v] for k,v in cap_results.items()}
80 |
81 | # average over captions
82 | results = {k:sum(v)/nb_caps for k,v in results.items()}
83 |
84 | return {k:[v] for k, v in results.items()}
85 |
86 |
87 | def compute_eval_metrics(model, dataset, device, compute_loss=False):
88 |
89 | # get data features
90 | poses_features, texts_features = infer_features(model, dataset, device)
91 |
92 | # poses-2-text matching
93 | p2t_recalls = evaluate.x2y_recall_metrics(poses_features, texts_features, config.k_recall_values, sstr="p2t_")
94 | # text-2-poses matching
95 | t2p_recalls = evaluate.x2y_recall_metrics(texts_features, poses_features, config.k_recall_values, sstr="t2p_")
96 | # r-precision
97 | rprecisions = evaluate.textret_metrics(texts_features, poses_features)
98 |
99 | # gather metrics
100 | recalls = {"mRecall": (sum(p2t_recalls.values()) + sum(t2p_recalls.values())) / (2 * len(config.k_recall_values))}
101 | recalls.update(p2t_recalls)
102 | recalls.update(t2p_recalls)
103 | recalls.update(rprecisions)
104 |
105 | # loss
106 | if compute_loss:
107 | score_t2p = texts_features.mm(poses_features.t())
108 | loss = BBC(score_t2p*model.loss_weight)
109 | loss_value = loss.item()
110 | return recalls, loss_value
111 |
112 | return recalls
113 |
114 |
115 | def infer_features(model, dataset, device):
116 |
117 | batch_size = 32
118 | data_loader = torch.utils.data.DataLoader(
119 | dataset, sampler=None, shuffle=False,
120 | batch_size=batch_size,
121 | num_workers=8,
122 | pin_memory=True,
123 | drop_last=False
124 | )
125 |
126 | poses_features = torch.zeros(len(dataset), model.latentD).to(device)
127 | texts_features = torch.zeros(len(dataset), model.latentD).to(device)
128 |
129 | for i, batch in tqdm(enumerate(data_loader)):
130 | poses_A = batch['poses_A'].to(device)
131 | poses_B = batch['poses_B'].to(device)
132 | caption_tokens = batch['caption_tokens'].to(device)
133 | caption_lengths = batch['caption_lengths'].to(device)
134 | caption_tokens = caption_tokens[:,:caption_lengths.max()]
135 | with torch.inference_mode():
136 | pfeat, tfeat = model(poses_A, caption_tokens, caption_lengths, poses_B)
137 | poses_features[i*batch_size:i*batch_size+len(poses_A)] = pfeat
138 | texts_features[i*batch_size:i*batch_size+len(poses_A)] = tfeat
139 |
140 | return poses_features, texts_features
141 |
142 |
143 | def display_results(results):
144 | metric_order = ['mRecall'] + ['%s_R@%d'%(d, k) for d in ['p2t', 't2p'] for k in config.k_recall_values]
145 | results = evaluate.scale_and_format_results(results)
146 | print(f"\n & {' & '.join([results[m] for m in metric_order])} \\\\\n")
147 |
148 |
149 | ################################################################################
150 |
151 | if __name__ == '__main__':
152 |
153 | args = evaluate.eval_parser.parse_args()
154 | args = evaluate.get_full_model_path(args)
155 |
156 | # compute results
157 | if args.average_over_runs:
158 | ret = evaluate.eval_model_all_runs(eval_model, args.model_path, dataset_version=args.dataset, split=args.split)
159 | else:
160 | ret = eval_model(args.model_path, dataset_version=args.dataset, split=args.split)
161 |
162 | # display results
163 | print(ret)
164 | display_results(ret)
--------------------------------------------------------------------------------
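The metrics above come from `evaluate.x2y_recall_metrics` and `evaluate.textret_metrics`, defined elsewhere in the repository. For readers of this excerpt, recall@K in this paired setting is conventionally the fraction of queries whose ground-truth counterpart (same index in the other modality) appears in the top K retrieved items; the sketch below follows that convention and is only an assumption about the actual implementation:

```
import torch

def recall_at_k_sketch(query_feats, target_feats, k_values):
    """query_feats, target_feats: (N, D); item i in one modality matches item i in the other."""
    scores = query_feats.mm(target_feats.t())          # (N, N) similarity matrix
    ranks = scores.argsort(dim=1, descending=True)     # target indices, best first
    gt = torch.arange(scores.shape[0]).view(-1, 1)
    gt_rank = (ranks == gt).float().argmax(dim=1)      # position of the true match in each ranking
    return {f"R@{k}": 100.0 * (gt_rank < k).float().mean().item() for k in k_values}

# toy example with 3 perfectly matched pairs
feats = torch.eye(3)
print(recall_at_k_sketch(feats, feats, [1, 2]))  # {'R@1': 100.0, 'R@2': 100.0}
```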
/src/text2pose/retrieval_modifier/model_retrieval_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch
11 | from torch import nn
12 |
13 | import text2pose.config as config
14 | from text2pose.encoders.tokenizers import Tokenizer, get_text_encoder_or_decoder_module_name, get_tokenizer_name
15 | from text2pose.encoders.modules import ConCatModule, L2Norm
16 | from text2pose.encoders.pose_encoder_decoder import PoseEncoder
17 | from text2pose.encoders.text_encoders import TextEncoder, TransformerTextEncoder
18 |
19 |
20 | class PairText(nn.Module):
21 | def __init__(self, num_neurons=512, num_neurons_mini=32, latentD=512,
22 | num_body_joints=config.NB_INPUT_JOINTS,
23 | text_encoder_name='distilbertUncased', transformer_topping=None):
24 | super(PairText, self).__init__()
25 |
26 | self.latentD = latentD
27 |
28 | # Define pose encoder
29 | self.pose_encoder = PoseEncoder(num_neurons=num_neurons,
30 | num_neurons_mini=num_neurons_mini,
31 | latentD=latentD,
32 | num_body_joints=num_body_joints,
33 | role="retrieval")
34 |
35 | # Define text encoder
36 | self.text_encoder_name = text_encoder_name
37 | module_ref = get_text_encoder_or_decoder_module_name(text_encoder_name)
38 | if module_ref in ["glovebigru"]:
39 | self.text_encoder = TextEncoder(self.text_encoder_name, latentD=latentD, role="retrieval")
40 | elif module_ref in ["glovetransf", "distilbertUncased"]:
41 | self.text_encoder = TransformerTextEncoder(self.text_encoder_name, latentD=latentD, topping=transformer_topping, role="retrieval")
42 | else:
43 | raise NotImplementedError
44 |
45 | # Define projecting layers
46 | self.pose_mlp = nn.Sequential(
47 | ConCatModule(),
48 | nn.Linear(2 * latentD, 2 * latentD),
49 | nn.LeakyReLU(),
50 | nn.Linear(2 * latentD, latentD),
51 | nn.LeakyReLU(),
52 | nn.Linear(latentD, latentD),
53 | nn.LeakyReLU(),
54 | L2Norm()
55 | )
56 |
57 | # Loss temperature
58 | self.loss_weight = torch.nn.Parameter( torch.FloatTensor((10,)) )
59 | self.loss_weight.requires_grad = True
60 |
61 | def forward(self, poses_A, captions, caption_lengths, poses_B):
62 | embed_AB = self.encode_pose_pair(poses_A, poses_B)
63 | text_embs = self.encode_text(captions, caption_lengths)
64 | return embed_AB, text_embs
65 |
66 | def encode_raw_text(self, raw_text):
67 | if not hasattr(self, 'tokenizer'):
68 | self.tokenizer = Tokenizer(get_tokenizer_name(self.text_encoder_name))
69 | tokens = self.tokenizer(raw_text).to(device=self.loss_weight.device)
70 | length = torch.tensor([ len(tokens) ], dtype=tokens.dtype)
71 | text_embs = self.text_encoder(tokens.view(1, -1), length)
72 | return text_embs
73 |
74 | def encode_pose_pair(self, poses_A, poses_B):
75 | embed_poses_A = self.pose_encoder(poses_A)
76 | embed_poses_B = self.pose_encoder(poses_B)
77 | embed_AB = self.pose_mlp([embed_poses_A, embed_poses_B])
78 | return embed_AB
79 |
80 | def encode_text(self, captions, caption_lengths):
81 | return self.text_encoder(captions, caption_lengths)
--------------------------------------------------------------------------------
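`PairText` fuses the two pose embeddings with `ConCatModule` and `L2Norm` from `text2pose.encoders.modules`, which are not shown in this excerpt. Judging from how `pose_mlp` uses them (a list of two `latentD` embeddings goes in, the first linear layer expects `2 * latentD` features, and the output feeds a retrieval loss), plausible minimal definitions would look as follows; these are assumptions, not the repository's code:

```
import torch
from torch import nn

class ConCatModule(nn.Module):
    """Assumed behaviour: concatenate a list of (B, D) tensors along the feature dimension."""
    def forward(self, tensors):
        return torch.cat(tensors, dim=1)

class L2Norm(nn.Module):
    """Assumed behaviour: L2-normalize each row, so dot products behave like cosine similarities."""
    def forward(self, x):
        return x / x.norm(dim=-1, keepdim=True)

# shape check matching pose_mlp: two latentD embeddings in, one normalized latentD embedding out
latentD = 512
a, b = torch.randn(2, latentD), torch.randn(2, latentD)
fused = ConCatModule()([a, b])
print(fused.shape)                   # torch.Size([2, 1024])
print(L2Norm()(fused).norm(dim=-1))  # ~1.0 per row
```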
/src/text2pose/retrieval_modifier/script_retrieval_modifier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ##############################################################
4 | ## text2pose ##
5 | ## Copyright (c) 2023 ##
6 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
7 | ## and Naver Corporation ##
8 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
9 | ## See project root for license details. ##
10 | ##############################################################
11 |
12 |
13 | ##############################################################
14 | # SCRIPT ARGUMENTS
15 |
16 | action=$1 # (train|eval|demo)
17 | checkpoint_type="best" # (last|best)
18 |
19 | architecture_args=(
20 | --model PairText
21 | --latentD 512
22 | --text_encoder_name 'distilbertUncased' --transformer_topping "avgp"
23 | # --text_encoder_name 'glovebigru_vocPFAHPP'
24 | )
25 |
26 | loss_args=(
27 | --retrieval_loss 'symBBC'
28 | )
29 |
30 | bonus_args=(
31 | )
32 |
33 | pretrained="modret_distilbert_dataPFA" # used only if phase=='finetune'
34 |
35 |
36 | ##############################################################
37 | # EXECUTE
38 |
39 | # TRAIN
40 | if [[ "$action" == *"train"* ]]; then
41 |
42 | phase=$2 # (pretrain|finetune)
43 | echo "NOTE: Expecting as argument the training phase. Got: $phase"
44 | seed=$3
45 | echo "NOTE: Expecting as argument the seed value. Got: $seed"
46 |
47 | # PRETRAIN
48 | if [[ "$phase" == *"pretrain"* ]]; then
49 |
50 | python retrieval_modifier/train_retrieval_modifier.py --dataset "posefix-A" \
51 | "${architecture_args[@]}" \
52 | "${loss_args[@]}" \
53 | "${bonus_args[@]}" \
54 | --lr_scheduler "stepLR" --lr 0.00005 --lr_step 400 --lr_gamma 0.5 \
55 | --log_step 20 --val_every 20 \
56 | --batch_size 128 --epochs 400 --seed $seed
57 |
58 | # FINETUNE
59 | elif [[ "$phase" == *"finetune"* ]]; then
60 |
61 | python retrieval_modifier/train_retrieval_modifier.py --dataset "posefix-H" \
62 | "${architecture_args[@]}" \
63 | "${loss_args[@]}" \
64 | "${bonus_args[@]}" \
65 | --apply_LR_augmentation \
66 | --lr_scheduler "stepLR" --lr 0.00005 --lr_step 40 --lr_gamma 0.5 \
67 | --batch_size 128 --epochs 50 --seed $seed \
68 | --pretrained $pretrained
69 |
70 | fi
71 |
72 | fi
73 |
74 |
75 | # EVAL QUANTITATIVELY
76 | if [[ "$action" == *"eval"* ]]; then
77 |
78 | shift; experiments=( "$@" ) # gets all the arguments starting from the 2nd one
79 |
80 | for model_path in "${experiments[@]}"
81 | do
82 | echo $model_path
83 | python retrieval_modifier/evaluate_retrieval_modifier.py --dataset "posefix-H" \
84 | --model_path ${model_path} --checkpoint $checkpoint_type \
85 | --split test
86 | done
87 | fi
88 |
89 |
90 | # EVAL QUALITATIVELY
91 | if [[ "$action" == *"demo"* ]]; then
92 |
93 | experiment=$2 # only one at a time
94 | streamlit run retrieval_modifier/demo_retrieval_modifier.py -- --model_path $experiment --checkpoint $checkpoint_type
95 |
96 | fi
--------------------------------------------------------------------------------
/src/text2pose/retrieval_modifier/train_retrieval_modifier.py:
--------------------------------------------------------------------------------
1 | ##############################################################
2 | ## text2pose ##
3 | ## Copyright (c) 2023 ##
4 | ## Institut de Robotica i Informatica Industrial, CSIC-UPC ##
5 | ## and Naver Corporation ##
6 | ## Licensed under the CC BY-NC-SA 4.0 license. ##
7 | ## See project root for license details. ##
8 | ##############################################################
9 |
10 | import torch
11 | import math
12 | import sys
13 | import os
14 | os.umask(0o002)
15 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
16 |
17 | from text2pose.option import get_args_parser
18 | from text2pose.trainer import GenericTrainer
19 | from text2pose.retrieval_modifier.model_retrieval_modifier import PairText
20 | from text2pose.retrieval_modifier.evaluate_retrieval_modifier import compute_eval_metrics
21 | from text2pose.loss import BBC, symBBC
22 | from text2pose.data import PoseFix, PoseMix, PoseScript
23 | from text2pose.encoders.tokenizers import get_tokenizer_name
24 | from text2pose.data_augmentations import DataAugmentation
25 |
26 | import text2pose.config as config
27 | import text2pose.utils_logging as logging
28 |
29 |
30 | ################################################################################
31 |
32 |
33 | class PairTextTrainer(GenericTrainer):
34 |
35 | def __init__(self, args):
36 | super(PairTextTrainer, self).__init__(args, retrieval_trainer=True)
37 |
38 |
39 | def load_dataset(self, split, caption_index, tokenizer_name=None):
40 |
41 | if tokenizer_name is None: tokenizer_name = get_tokenizer_name(self.args.text_encoder_name)
42 | data_size = self.args.data_size if split=="train" else None
43 |
44 | if "posefix" in self.args.dataset:
45 | d = PoseFix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size)
46 | elif "posemix" in self.args.dataset:
47 | # NOTE: if specifying data_size: only the first loaded data items
48 | # will be considered (since PoseFix is loaded before PoseScript, if
49 | # data_size < the size of PoseFix, no PoseScript data will be
50 | # loaded)
51 | d = PoseMix(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size)
52 | elif "posescript" in self.args.dataset:
53 | d = PoseScript(version=self.args.dataset, split=split, tokenizer_name=tokenizer_name, caption_index=caption_index, num_body_joints=self.args.num_body_joints, data_size=data_size, posefix_format=True)
54 | else:
55 | raise NotImplementedError
56 | return d
57 |
58 |
59 | def init_model(self):
60 | print('Load model')
61 | self.model = PairText(text_encoder_name=self.args.text_encoder_name,
62 | transformer_topping=self.args.transformer_topping,
63 | latentD=self.args.latentD,
64 | num_body_joints=self.args.num_body_joints)
65 | self.model.to(self.device)
66 |
67 |
68 | def get_param_groups(self):
69 | param_groups = []
70 | param_groups.append({'params': self.model.pose_encoder.parameters(), 'lr': self.args.lr*self.args.lrposemul})
71 | param_groups.append({'params': self.model.pose_mlp.parameters(), 'lr': self.args.lr*self.args.lrposemul})
72 | param_groups.append({'params': [p for k,p in self.model.text_encoder.named_parameters() if 'pretrained_text_encoder.' not in k], 'lr': self.args.lr*self.args.lrtextmul})
73 | param_groups.append({'params': [self.model.loss_weight]})
74 | return param_groups
75 |
76 |
77 | def init_optimizer(self):
78 | assert self.args.optimizer=='Adam'
79 | param_groups = self.get_param_groups()
80 | self.optimizer = torch.optim.Adam(param_groups, lr=self.args.lr)
81 |
82 |
83 | def init_lr_scheduler(self):
84 | self.lr_scheduler = None
85 | if self.args.lr_scheduler == "stepLR":
86 | self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
87 | step_size=self.args.lr_step,
88 | gamma=self.args.lr_gamma,
89 | last_epoch=-1)
90 |
91 |
92 | def init_other_training_elements(self):
93 | self.data_augmentation_module = DataAugmentation(self.args, mode="posefix", tokenizer_name=get_tokenizer_name(self.args.text_encoder_name), nb_joints=self.args.num_body_joints)
94 |
95 |
96 | def training_epoch(self, epoch):
97 | train_stats = self.one_epoch(epoch=epoch, is_training=True)
98 | return train_stats
99 |
100 |
101 | def validation_epoch(self, epoch):
102 | val_stats = {}
103 | if self.args.val_every and (epoch+1)%self.args.val_every==0:
104 | val_stats = self.validate(epoch=epoch)
105 | return val_stats
106 |
107 |
108 | def one_epoch(self, epoch, is_training):
109 |
110 | self.model.train(is_training)
111 |
112 | # define loggers
113 | metric_logger = logging.MetricLogger(delimiter=" ")
114 | if is_training:
115 | prefix, sstr = '', 'train'
116 | metric_logger.add_meter(f'{sstr}_lr', logging.SmoothedValue(window_size=1, fmt='{value:.6f}'))
117 | else:
118 | prefix, sstr = '[val] ', 'val'
119 | header = f'{prefix}Epoch: [{epoch}]'
120 |
121 | # define dataloader & other elements
122 | if is_training:
123 | data_loader = self.data_loader_train
124 | if not is_training:
125 | data_loader = self.data_loader_val
126 |
127 | # iterate over the batches
128 | for data_iter_step, item in enumerate(metric_logger.log_every(data_loader, self.args.log_step, header)):
129 |
130 | # get data
131 | poses_A = item['poses_A'].to(self.device)
132 | poses_B = item['poses_B'].to(self.device)
133 | caption_tokens = item['caption_tokens'].to(self.device)
134 | caption_lengths = item['caption_lengths'].to(self.device)
135 | caption_tokens = caption_tokens[:,:caption_lengths.max()] # truncate within the batch, based on the longest text
136 |
137 | # online random augmentations
138 | posescript_poses = item['poses_A_ids'] == config.PID_NAN
139 | poses_A, caption_tokens, caption_lengths, poses_B = self.data_augmentation_module(poses_A, caption_tokens, caption_lengths, poses_B, posescript_poses)
140 |
141 | # forward; compute scores
142 | with torch.set_grad_enabled(is_training):
143 | poses_features, texts_features = self.model(poses_A, caption_tokens, caption_lengths, poses_B)
144 | score_t2p = texts_features.mm(poses_features.t()) * self.model.loss_weight
145 |
146 | # compute loss
147 | if self.args.retrieval_loss == "BBC":
148 | loss = BBC(score_t2p)
149 | elif self.args.retrieval_loss == "symBBC":
150 | loss = symBBC(score_t2p)
151 | else:
152 | raise NotImplementedError
153 |
154 | loss_value = loss.item()
155 | if not math.isfinite(loss_value):
156 | print("Loss is {}, stopping training".format(loss_value))
157 | sys.exit(1)
158 |
159 | # training step
160 | if is_training:
161 | self.optimizer.zero_grad()
162 | loss.backward()
163 | self.optimizer.step()
164 |
165 | # format data for logging
166 | scalars = [('loss', loss_value)]
167 | if is_training:
168 | lr_value = self.optimizer.param_groups[0]["lr"]
169 | scalars += [('lr', lr_value)]
170 |
171 | # actually log
172 | self.add_data_to_log_writer(epoch, sstr, scalars=scalars, is_training=is_training, data_iter_step=data_iter_step, total_steps=len(data_loader))
173 | self.add_data_to_metric_logger(metric_logger, sstr, scalars)
174 |
175 | print("Averaged stats:", metric_logger)
176 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
177 |
178 |
179 | def validate(self, epoch):
180 |
181 | self.model.eval()
182 |
183 | recalls, loss_value = compute_eval_metrics(self.model, self.data_loader_val.dataset, self.device, compute_loss=True)
184 | val_stats = {"loss": loss_value}
185 | val_stats.update(recalls)
186 |
187 | # log
188 | self.add_data_to_log_writer(epoch, 'val', scalars=[('loss', loss_value), ('validation', recalls)], should_log_data=True)
189 | print(f"[val] Epoch: [{epoch}] Stats: " + " ".join(f"{k}: {round(v, 3)}" for k,v in val_stats.items()) )
190 | return val_stats
191 |
192 |
193 | if __name__ == '__main__':
194 |
195 | argparser = get_args_parser()
196 | args = argparser.parse_args()
197 |
198 | PairTextTrainer(args)()
--------------------------------------------------------------------------------
/src/text2pose/shortname_2_model_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver/posescript/15c8958a7130e6fda225710249324d3b6f5a75de/src/text2pose/shortname_2_model_path.txt
--------------------------------------------------------------------------------
/src/text2pose/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ################################################################################
4 | ## READ/WRITE TO FILES
5 | ################################################################################
6 |
7 | import json
8 |
9 | def read_json(absolute_filepath):
10 | with open(absolute_filepath, 'r') as f:
11 | data = json.load(f)
12 | return data
13 |
14 | def write_json(data, absolute_filepath, pretty=False):
15 | with open(absolute_filepath, "w") as f:
16 | if pretty:
17 | json.dump(data, f, ensure_ascii=False, indent=2)
18 | else:
19 | json.dump(data, f)
20 |
21 |
22 | ################################################################################
23 | ## ANGLE TRANSFORMATION FUNCTIONS
24 | ################################################################################
25 |
26 | import roma
27 |
28 | def rotvec_to_eulerangles(x):
29 | x_rotmat = roma.rotvec_to_rotmat(x)
30 | thetax = torch.atan2(x_rotmat[:,2,1], x_rotmat[:,2,2])
31 | thetay = torch.atan2(-x_rotmat[:,2,0], torch.sqrt(x_rotmat[:,2,1]**2+x_rotmat[:,2,2]**2))
32 | thetaz = torch.atan2(x_rotmat[:,1,0], x_rotmat[:,0,0])
33 | return thetax, thetay, thetaz
34 |
35 | def eulerangles_to_rotmat(thetax, thetay, thetaz):
36 | N = thetax.numel()
37 | # identity matrices, filled below as rotations about the x, y and z axes
38 | rotx = torch.eye(3).to(thetax.device).repeat(N,1,1)
39 | roty = torch.eye(3).to(thetax.device).repeat(N,1,1)
40 | rotz = torch.eye(3).to(thetax.device).repeat(N,1,1)
41 | rotx[:,1,1] = torch.cos(thetax)
42 | rotx[:,2,2] = torch.cos(thetax)
43 | rotx[:,1,2] = -torch.sin(thetax)
44 | rotx[:,2,1] = torch.sin(thetax)
45 | roty[:,0,0] = torch.cos(thetay)
46 | roty[:,2,2] = torch.cos(thetay)
47 | roty[:,0,2] = torch.sin(thetay)
48 | roty[:,2,0] = -torch.sin(thetay)
49 | rotz[:,0,0] = torch.cos(thetaz)
50 | rotz[:,1,1] = torch.cos(thetaz)
51 | rotz[:,0,1] = -torch.sin(thetaz)
52 | rotz[:,1,0] = torch.sin(thetaz)
53 | rotmat = torch.einsum('bij,bjk->bik', rotz, torch.einsum('bij,bjk->bik', roty, rotx))
54 | return rotmat
55 |
56 | def eulerangles_to_rotvec(thetax, thetay, thetaz):
57 | rotmat = eulerangles_to_rotmat(thetax, thetay, thetaz)
58 | return roma.rotmat_to_rotvec(rotmat)
59 |
60 |
61 | ################################################################################
62 | ## LOAD POSE DATA
63 | ################################################################################
64 |
65 | import os
66 | import numpy as np
67 |
68 | import text2pose.config as config
69 |
70 |
71 | def get_pose_data_from_file(pose_info, applied_rotation=None, output_rotation=False):
72 | """
73 | Load pose data and normalize the orientation.
74 |
75 | Args:
76 | pose_info: list [dataset (string), sequence_filepath (string), frame_index (int)]
77 | applied_rotation: rotation to be applied to the pose data. If None, the
78 | normalization rotation is applied.
79 | output_rotation: whether to output the rotation performed for
80 | normalization, in addition to the normalized pose data.
81 |
82 | Returns:
83 | pose data, torch.tensor of size (1, n_joints*3), all joints considered.
84 | (optional) R, torch.tensor representing the rotation of normalization
85 | """
86 |
87 | # load pose data
88 | assert pose_info[0] in config.supported_datasets, f"Expected data from one of the following datasets: {','.join(config.supported_datasets)} (provided dataset: {pose_info[0]})."
89 |
90 | if pose_info[0] == "AMASS":
91 | dp = np.load(os.path.join(config.supported_datasets[pose_info[0]], pose_info[1]))
92 | pose = dp['poses'][pose_info[2],:].reshape(-1,3) # (n_joints, 3)
93 | pose = torch.as_tensor(pose).to(dtype=torch.float32)
94 |
95 | # normalize the global orient
96 | initial_rotation = pose[:1,:].clone()
97 | if applied_rotation is None:
98 | thetax, thetay, thetaz = rotvec_to_eulerangles( initial_rotation )
99 | zeros = torch.zeros_like(thetaz)
100 | pose[0:1,:] = eulerangles_to_rotvec(thetax, thetay, zeros)
101 | else:
102 | pose[0:1,:] = roma.rotvec_composition((applied_rotation, initial_rotation))
103 | if output_rotation:
104 | # a = A.u, after normalization, becomes a' = A'.u
105 | # we look for the normalization rotation R such that: a' = R.a
106 | # since a = A.u ==> u = A^-1.a
107 | # a' = A'.u = A'.A^-1.a ==> R = A'.A^-1
108 | R = roma.rotvec_composition((pose[0:1,:], roma.rotvec_inverse(initial_rotation)))
109 | return pose.reshape(1, -1), R
110 |
111 | return pose.reshape(1, -1)
112 |
113 |
114 | def pose_data_as_dict(pose_data, code_base='human_body_prior'):
115 | """
116 | Args:
117 | pose_data, torch.tensor of shape (*, n_joints*3) or (*, n_joints, 3),
118 | all joints considered.
119 | Returns:
120 | dict
121 | """
122 | # reshape to (*, n_joints*3) if necessary
123 | if len(pose_data.shape) == 3:
124 | # shape (batch_size, n_joints, 3)
125 | pose_data = pose_data.flatten(1,2)
126 | if len(pose_data.shape) == 2 and pose_data.shape[1] == 3:
127 | # shape (n_joints, 3)
128 | pose_data = pose_data.view(1, -1)
129 | # provide as a dict, with different keys, depending on the code base
130 | if code_base == 'human_body_prior':
131 | d = {"root_orient":pose_data[:,:3],
132 | "pose_body":pose_data[:,3:66]}
133 | if pose_data.shape[1] > 66:
134 | d["pose_hand"] = pose_data[:,66:]
135 | elif code_base == 'smplx':
136 | d = {"global_orient":pose_data[:,:3],
137 | "body_pose":pose_data[:,3:66]}
138 | if pose_data.shape[1] > 66:
139 | d.update({"left_hand_pose":pose_data[:,66:111],
140 | "right_hand_pose":pose_data[:,111:]})
141 | return d
--------------------------------------------------------------------------------
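The comment block inside `get_pose_data_from_file` derives the normalization rotation as R = A'.A^-1, where A is the original global orientation and A' its normalized version. A small numerical check of that identity, using the same `roma` calls as above on a random rotation, may help readers follow the derivation (the random A' stands in for any normalization target):

```
import torch
import roma

A = torch.randn(1, 3)        # original global orientation, axis-angle
A_prime = torch.randn(1, 3)  # stand-in for the normalized orientation

# R = A'.A^-1, as in the comment of get_pose_data_from_file
R = roma.rotvec_composition((A_prime, roma.rotvec_inverse(A)))

# applying R to A recovers A' (compare as rotation matrices to avoid axis-angle ambiguity)
recovered = roma.rotvec_composition((R, A))
print(torch.allclose(roma.rotvec_to_rotmat(recovered),
                     roma.rotvec_to_rotmat(A_prime), atol=1e-5))  # True
```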
/src/text2pose/utils_logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import datetime
4 | import time
5 | from collections import defaultdict, deque
6 | import torch
7 |
8 |
9 | class HiddenPrints:
10 | def __enter__(self):
11 | self._original_stdout = sys.stdout
12 | sys.stdout = open(os.devnull, 'w')
13 |
14 | def __exit__(self, exc_type, exc_val, exc_tb):
15 | sys.stdout.close()
16 | sys.stdout = self._original_stdout
17 |
18 |
19 | class SmoothedValue(object):
20 | """
21 | Track a series of values and provide access to smoothed values over a window
22 | or the global series average.
23 | """
24 |
25 | def __init__(self, window_size=20, fmt=None):
26 | if fmt is None:
27 | fmt = "{median:.4f} ({global_avg:.4f})"
28 | self.deque = deque(maxlen=window_size)
29 | self.total = 0.0
30 | self.count = 0
31 | self.fmt = fmt
32 |
33 | def update(self, value, n=1):
34 | self.deque.append(value)
35 | self.count += n
36 | self.total += value * n
37 |
38 | @property
39 | def median(self):
40 | d = torch.tensor(list(self.deque))
41 | return d.median().item()
42 |
43 | @property
44 | def avg(self):
45 | d = torch.tensor(list(self.deque), dtype=torch.float32)
46 | return d.mean().item()
47 |
48 | @property
49 | def global_avg(self):
50 | return self.total / self.count
51 |
52 | @property
53 | def max(self):
54 | return max(self.deque)
55 |
56 | @property
57 | def value(self):
58 | return self.deque[-1]
59 |
60 | def __str__(self):
61 | return self.fmt.format(
62 | median=self.median,
63 | avg=self.avg,
64 | global_avg=self.global_avg,
65 | max=self.max,
66 | value=self.value)
67 |
68 |
69 | class MetricLogger(object):
70 |
71 | def __init__(self, delimiter="\t"):
72 | self.meters = defaultdict(SmoothedValue)
73 | self.delimiter = delimiter
74 |
75 | def update(self, **kwargs):
76 | for k, v in kwargs.items():
77 | if v is None:
78 | continue
79 | if isinstance(v, torch.Tensor):
80 | v = v.item()
81 | assert isinstance(v, (float, int))
82 | self.meters[k].update(v)
83 |
84 | def __getattr__(self, attr):
85 | if attr in self.meters:
86 | return self.meters[attr]
87 | if attr in self.__dict__:
88 | return self.__dict__[attr]
89 | raise AttributeError("'{}' object has no attribute '{}'".format(
90 | type(self).__name__, attr))
91 |
92 | def __str__(self):
93 | loss_str = []
94 | for name, meter in self.meters.items():
95 | loss_str.append(
96 | "{}: {}".format(name, str(meter))
97 | )
98 | return self.delimiter.join(loss_str)
99 |
100 | def add_meter(self, name, meter):
101 | self.meters[name] = meter
102 |
103 | def log_every(self, iterable, print_freq, header=None):
104 | i = 0
105 | if not header:
106 | header = ''
107 | start_time = time.time()
108 | end = time.time()
109 | iter_time = SmoothedValue(fmt='{avg:.4f}')
110 | data_time = SmoothedValue(fmt='{avg:.4f}')
111 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
112 | log_msg = [
113 | header,
114 | '[{0' + space_fmt + '}/{1}]',
115 | 'eta: {eta}',
116 | '{meters}',
117 | 'time: {time}',
118 | 'data: {data}'
119 | ]
120 | if torch.cuda.is_available():
121 | log_msg.append('max mem: {memory:.0f}')
122 | log_msg = self.delimiter.join(log_msg)
123 | MB = 1024.0 * 1024.0
124 | for obj in iterable:
125 | data_time.update(time.time() - end)
126 | yield obj
127 | iter_time.update(time.time() - end)
128 | if i % print_freq == 0 or i == len(iterable) - 1:
129 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
130 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
131 | if torch.cuda.is_available():
132 | print(log_msg.format(
133 | i, len(iterable), eta=eta_string,
134 | meters=str(self),
135 | time=str(iter_time), data=str(data_time),
136 | memory=torch.cuda.max_memory_allocated() / MB))
137 | else:
138 | print(log_msg.format(
139 | i, len(iterable), eta=eta_string,
140 | meters=str(self),
141 | time=str(iter_time), data=str(data_time)))
142 | i += 1
143 | end = time.time()
144 | total_time = time.time() - start_time
145 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
146 | print('{} Total time: {} ({:.4f} s / it)'.format(
147 | header, total_time_str, total_time / len(iterable)))
148 |
--------------------------------------------------------------------------------
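A minimal usage sketch of `MetricLogger` and `SmoothedValue` outside the trainer, showing the pattern followed in the training loops above (iterate through `log_every`, push scalars with `update`, read averaged statistics from `meters`); the data loader and loss computation are placeholders:

```
import torch
from text2pose.utils_logging import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter('train_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

dummy_loader = [torch.randn(8, 3) for _ in range(100)]  # stand-in for a DataLoader
for batch in logger.log_every(dummy_loader, print_freq=20, header='Epoch: [0]'):
    loss_value = batch.pow(2).mean().item()             # placeholder computation
    logger.update(train_loss=loss_value, train_lr=1e-4)

print({name: meter.global_avg for name, meter in logger.meters.items()})
```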