├── LICENSE ├── README.md ├── assets ├── EPIC-KITCHENS │ └── P01 │ │ └── rgb_frames │ │ └── P01_01 │ │ ├── frame_0000016827.jpg │ │ ├── frame_0000016833.jpg │ │ ├── frame_0000016839.jpg │ │ ├── frame_0000016845.jpg │ │ ├── frame_0000016851.jpg │ │ ├── frame_0000016857.jpg │ │ ├── frame_0000016863.jpg │ │ ├── frame_0000016869.jpg │ │ ├── frame_0000016875.jpg │ │ ├── frame_0000016881.jpg │ │ └── frame_0000016887.jpg ├── demo_gen.jpg └── teaser.gif ├── datasets ├── dataloaders.py ├── dataset_utils.py ├── datasetopts.py ├── ho_utils.py ├── holoaders.py └── input_loaders.py ├── demo_gen.py ├── environment.yaml ├── evaluation ├── affordance_eval.py └── traj_eval.py ├── netscripts ├── epoch_feat.py ├── epoch_utils.py ├── get_datasets.py ├── get_network.py ├── get_optimizer.py └── modelio.py ├── networks ├── affordance_decoder.py ├── decoder_modules.py ├── embedding.py ├── layer.py ├── model.py ├── net_utils.py ├── traj_decoder.py └── transformer.py ├── options ├── expopts.py └── netsopts.py ├── preprocess ├── affordance_util.py ├── dataset_util.py ├── ho_types.py ├── obj_util.py ├── traj_util.py ├── types_pb2.py └── vis_util.py ├── requirements.txt └── traineval.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 stevenlsw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HOI-Forecast 2 | 3 | **Joint Hand Motion and Interaction Hotspots Prediction from Egocentric Videos (CVPR 2022)** 4 | 5 | 6 | 7 | #### [[Project Page]](https://stevenlsw.github.io/hoi-forecast/) [[Paper]](https://arxiv.org/abs/2204.01696) [[Training Data]](https://drive.google.com/drive/folders/1llDYFwn2gGQLpcWy6YScp3ej7A3LIPFc) 8 | 9 | Given observation frames of the past, we predict future hand trajectories (green and red lines) and object interaction hotspots (heatmaps) in the egocentric view. We generate training data **automatically** and use this data to train an Object-Centric Transformer (OCT) model for prediction. 10 |
11 |

12 | 13 |

14 | 15 | 16 | ## Installation 17 | - Clone this repository: 18 | ```Shell 19 | git clone https://github.com/stevenlsw/hoi-forecast 20 | cd hoi-forecast 21 | ``` 22 | - Python 3.6 Environment: 23 | ```Shell 24 | conda env create -f environment.yaml 25 | conda activate fhoi 26 | ``` 27 | 28 | ## Quick training data generation 29 | The official Epic-Kitchens dataset has the same layout as `assets/EPIC-KITCHENS`; the rgb frames needed for the demo have been pre-downloaded to `assets/EPIC-KITCHENS/P01/rgb_frames/P01_01`. 30 | 31 | - Download the Epic-Kitchens 55 dataset [annotations](https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-55-annotations/master/EPIC_train_action_labels.csv) and save them in the `assets` folder 32 | 33 | - Download the hand-object detections below 34 | ```Shell 35 | link=https://data.bris.ac.uk/datasets/3l8eci2oqgst92n14w2yqi5ytu/hand-objects/P01/P01_01.pkl 36 | wget -P assets/EPIC-KITCHENS/P01/hand-objects $link 37 | ``` 38 | - Run `python demo_gen.py`; the results [png, pkl] are stored in `figs`, where you can visualize them 39 | 40 |
41 | 42 |
43 | 44 | - For more generated training labels, please visit [google drive](https://drive.google.com/drive/folders/1lNOSXXKbiqYJqC1hp1CIaIeM8u6UVvcg) and run `python example.py`. 45 | 46 | 47 | ## Evaluation on EK100 48 | We manually collect the hand trajectories and interaction hotspots for evaluation. We pre-extract the input video features. 49 | 50 | - Download the processed [files](https://drive.google.com/file/d/1s7qpBa-JjjuGk7v_aiuU2lvRjgP6fi9C/view?usp=sharing) (including collected labels, pre-extracted features, and dataset partitions, 600 MB) and **unzip** them. You will get a structure like: 51 | ``` 52 | hoi-forecast 53 | |-- data 54 | | |-- ek100 55 | | | |-- ek100_eval_labels.pkl 56 | | | |-- video_info.json 57 | | | |-- labels 58 | | | | |-- label_303.pkl 59 | | | | |-- ... 60 | | | |-- feats 61 | | | | |-- data.lmdb (RGB) 62 | |-- common 63 | | |-- epic-kitchens-55-annotations 64 | | |-- epic-kitchens-100-annotations 65 | | |-- rulstm 66 | ``` 67 | 68 | - Download the [pretrained models](https://drive.google.com/file/d/16IkQ4hOQk2_Klhd806J-46hLN-OGokxa/view?usp=sharing) on EK100; the stored model path is referred to as `$resume`. 69 | 70 | - Install PyTorch and dependencies with the following command: 71 | ```Shell 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | - Evaluate future hand trajectory 76 | ```Shell 77 | python traineval.py --evaluate --ek_version=ek100 --resume={path to the model} --traj_only 78 | ``` 79 | 80 | - Evaluate future interaction hotspots 81 | ```Shell 82 | python traineval.py --evaluate --ek_version=ek100 --resume={path to the model} 83 | ``` 84 | 85 | - Results should look like: 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 |
Hand TrajectoryInteraction Hotspots
ADE ↓FDE ↓SIM ↑AUC-J ↑NSS ↑
0.120.110.190.690.72
111 | 112 | 113 | ## Training 114 | - Extract per-frame features of the training set similar to [RULSTM](https://github.com/fpv-iplab/rulstm) and store them in `data/ek100/feats/ek100.lmdb`, with key-value pairs like 115 | ```python 116 | fname = 'P01/rgb_frames/P01_01/frame_0000000720.jpg' 117 | env[fname.encode()] = result_dict # extracted feature results 118 | ``` 119 | 120 | - Start training 121 | ``` 122 | python traineval.py --ek_version=ek100 123 | ``` 124 | 125 | ## Citation 126 | ```latex 127 | @inproceedings{liu2022joint, 128 | title={Joint Hand Motion and Interaction Hotspots Prediction from Egocentric Videos}, 129 | author={Liu, Shaowei and Tripathi, Subarna and Majumdar, Somdeb and Wang, Xiaolong}, 130 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 131 | year={2022} 132 | } 133 | ``` 134 | 135 | ## Acknowledgements 136 | We thank: 137 | * [epic-kitchens](https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes) for hand-object detections on the Epic-Kitchens dataset 138 | * [rulstm](https://github.com/fpv-iplab/rulstm) for feature extraction and action anticipation 139 | * [epic-kitchens-dataset-pytorch](https://github.com/guglielmocamporese/epic-kitchens-dataset-pytorch) for 140 | the epic-kitchens dataloader 141 | * [Miao Liu](https://aptx4869lm.github.io/) for help with prior work 142 | 143 | -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016827.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016827.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016833.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016833.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016839.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016839.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016845.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016845.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016851.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016851.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016857.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016857.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016863.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016863.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016869.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016869.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016875.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016875.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016881.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016881.jpg -------------------------------------------------------------------------------- /assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016887.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016887.jpg -------------------------------------------------------------------------------- /assets/demo_gen.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/demo_gen.jpg -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/teaser.gif -------------------------------------------------------------------------------- /datasets/dataloaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from tqdm import tqdm 3 | import json 4 | import numpy as np 5 | 6 | from datasets.dataset_utils import get_ek55_annotation, get_ek100_annotation 7 | from datasets.input_loaders import get_loaders 8 | 9 | 10 | class EpicAction(object): 11 | def __init__(self, uid, participant_id, video_id, verb, verb_class, 12 | noun, noun_class, all_nouns, all_noun_classes, start_frame, 13 | stop_frame, start_time, stop_time, ori_fps, partition, action, action_class): 14 | self.uid = uid 15 | self.participant_id = participant_id 16 | self.video_id = video_id 17 | self.verb = verb 18 | self.verb_class = verb_class 19 | self.noun = noun 20 | self.noun_class = noun_class 21 | self.all_nouns = all_nouns 22 | self.all_noun_classes = all_noun_classes 23 | self.start_frame = start_frame 24 | self.stop_frame = stop_frame 25 | self.start_time = start_time 26 | self.stop_time = stop_time 27 | self.ori_fps = ori_fps 28 | self.partition = partition 29 | self.action = action 30 | self.action_class = action_class 31 | 32 | self.duration = self.stop_time - self.start_time 33 | 34 | def __repr__(self): 35 | return json.dumps(self.__dict__, indent=4) 36 | 37 | def set_previous_actions(self, actions): 38 | self.actions_prev = actions 39 | 40 | 41 | class EpicVideo(object): 42 | def __init__(self, df_video, ori_fps, partition, t_ant=None): 43 | 
self.df = df_video 44 | self.ori_fps = ori_fps 45 | self.partition = partition 46 | self.t_ant = t_ant 47 | 48 | self.actions, self.actions_invalid = self._get_actions() 49 | self.duration = max([a.stop_time for a in self.actions]) 50 | 51 | def _get_actions(self): 52 | actions = [] 53 | _actions_all = [] 54 | actions_invalid = [] 55 | for _, row in self.df.iterrows(): 56 | action_args = { 57 | 'uid': row.uid, 58 | 'participant_id': row.participant_id, 59 | 'video_id': row.video_id, 60 | 'verb': row.verb if 'test' not in self.partition else None, 61 | 'verb_class': row.verb_class if 'test' not in self.partition else None, 62 | 'noun': row.noun if 'test' not in self.partition else None, 63 | 'noun_class': row.noun_class if 'test' not in self.partition else None, 64 | 'all_nouns': row.all_nouns if 'test' not in self.partition else None, 65 | 'all_noun_classes': row.all_noun_classes if 'test' not in self.partition else None, 66 | 'start_frame': row.start_frame, 67 | 'stop_frame': row.stop_time, 68 | 'start_time': row.start_time, 69 | 'stop_time': row.stop_time, 70 | 'ori_fps': self.ori_fps, 71 | 'partition': self.partition, 72 | 'action': row.action if 'test' not in self.partition else None, 73 | 'action_class': row.action_class if 'test' not in self.partition else None, 74 | } 75 | action = EpicAction(**action_args) 76 | action.set_previous_actions([aa for aa in _actions_all]) 77 | assert self.t_ant is not None 78 | assert self.t_ant > 0.0 79 | if action.start_time - self.t_ant >= 0: 80 | actions += [action] 81 | else: 82 | actions_invalid += [action] 83 | _actions_all += [action] 84 | return actions, actions_invalid 85 | 86 | 87 | class EpicDataset(Dataset): 88 | def __init__(self, df, partition, ori_fps=60.0, fps=4.0, loader=None, t_ant=None, transform=None, 89 | num_actions_prev=None, label_path=None, eval_label_path=None, 90 | annot_path=None, rulstm_annot_path=None, ek_version=None): 91 | super().__init__() 92 | self.partition = partition 93 | self.ori_fps = 
ori_fps 94 | self.fps = fps 95 | self.df = df 96 | self.loader = loader 97 | self.t_ant = t_ant 98 | self.transform = transform 99 | self.num_actions_prev = num_actions_prev 100 | 101 | self.videos = self._get_videos() 102 | self.actions, self.actions_invalid = self._get_actions() 103 | 104 | def _get_videos(self): 105 | video_ids = sorted(list(set(self.df['video_id'].values.tolist()))) 106 | videos = [] 107 | pbar = tqdm(desc=f'Loading {self.partition} samples', total=len(self.df)) 108 | for video_id in video_ids: 109 | video_args = { 110 | 'df_video': self.df[self.df['video_id'] == video_id].copy(), 111 | 'ori_fps': self.ori_fps, 112 | 'partition': self.partition, 113 | 't_ant': self.t_ant 114 | } 115 | video = EpicVideo(**video_args) 116 | videos += [video] 117 | pbar.update(len(video.actions)) 118 | pbar.close() 119 | return videos 120 | 121 | def _get_actions(self): 122 | actions = [] 123 | actions_invalid = [] 124 | for video in self.videos: 125 | actions += video.actions 126 | actions_invalid += video.actions_invalid 127 | return actions, actions_invalid 128 | 129 | def __len__(self): 130 | return len(self.actions) 131 | 132 | def __getitem__(self, idx): 133 | a = self.actions[idx] 134 | sample = {'uid': a.uid} 135 | 136 | inputs = self.loader(a) 137 | sample.update(inputs) 138 | 139 | if 'test' not in self.partition: 140 | sample['verb_class'] = a.verb_class 141 | sample['noun_class'] = a.noun_class 142 | sample['action_class'] = a.action_class 143 | 144 | actions_prev = [-1] + [aa.action_class for aa in a.actions_prev] 145 | actions_prev = actions_prev[-self.num_actions_prev:] 146 | if len(actions_prev) < self.num_actions_prev: 147 | actions_prev = actions_prev[0:1] * (self.num_actions_prev - len(actions_prev)) + actions_prev 148 | actions_prev = np.array(actions_prev, dtype=np.int64) 149 | sample['action_class_prev'] = actions_prev 150 | return sample 151 | 152 | 153 | def get_datasets(args, epic_ds=None, featuresloader=None): 154 | loaders = 
get_loaders(args, featuresloader=featuresloader) 155 | 156 | annotation_args = { 157 | 'annot_path': args.annot_path, 158 | 'label_path': args.label_path, 159 | 'eval_label_path': args.eval_label_path, 160 | 'rulstm_annot_path': args.rulstm_annot_path, 161 | 'validation_ratio': args.validation_ratio, 162 | 'use_rulstm_splits': args.use_rulstm_splits, 163 | 'use_label_only': args.use_label_only 164 | } 165 | 166 | if args.ek_version == 'ek55': 167 | dfs = { 168 | 'train': get_ek55_annotation(partition='train', **annotation_args), 169 | 'validation': get_ek55_annotation(partition='validation', **annotation_args), 170 | 'eval': get_ek55_annotation(partition='eval', **annotation_args), 171 | 'test_s1': get_ek55_annotation(partition='test_s1', **annotation_args), 172 | 'test_s2': get_ek55_annotation(partition='test_s2', **annotation_args), 173 | } 174 | elif args.ek_version == 'ek100': 175 | dfs = { 176 | 'train': get_ek100_annotation(partition='train', **annotation_args), 177 | 'validation': get_ek100_annotation(partition='validation', **annotation_args), 178 | 'eval': get_ek100_annotation(partition='eval', **annotation_args), 179 | 'test': get_ek100_annotation(partition='test', **annotation_args), 180 | } 181 | else: 182 | raise Exception(f'Error. 
EPIC-Kitchens Version "{args.ek_version}" not supported.') 183 | 184 | ds_args = { 185 | 'label_path': args.label_path[args.ek_version], 186 | 'eval_label_path': args.eval_label_path[args.ek_version], 187 | 'annot_path': args.annot_path, 188 | 'rulstm_annot_path': args.rulstm_annot_path[args.ek_version], 189 | 'ek_version': args.ek_version, 190 | 'ori_fps': args.ori_fps, 191 | 'fps': args.fps, 192 | 't_ant': args.t_ant, 193 | 'num_actions_prev': args.num_actions_prev if args.task in ['anticipation'] else None, 194 | } 195 | 196 | if epic_ds is None: 197 | epic_ds = EpicDataset 198 | 199 | if args.mode in ['train', 'training']: 200 | dss = { 201 | 'train': epic_ds(df=dfs['train'], partition='train', loader=loaders['train'], **ds_args), 202 | 'validation': epic_ds(df=dfs['validation'], partition='validation', loader=loaders['validation'], 203 | **ds_args), 204 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args), 205 | } 206 | elif args.mode in ['validation', 'validating', 'validate']: 207 | dss = { 208 | 'validation': epic_ds(df=dfs['validation'], partition='validation', 209 | loader=loaders['validation'], **ds_args), 210 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args), 211 | } 212 | elif args.mode in ['test', 'testing']: 213 | 214 | if args.ek_version == "ek55": 215 | dss = { 216 | 'test_s1': epic_ds(df=dfs['test_s1'], partition='test_s1', loader=loaders['test'], **ds_args), 217 | 'test_s2': epic_ds(df=dfs['test_s2'], partition='test_s2', loader=loaders['test'], **ds_args), 218 | } 219 | elif args.ek_version == "ek100": 220 | dss = { 221 | 'test': epic_ds(df=dfs['test'], partition='test', loader=loaders['test'], **ds_args), 222 | } 223 | else: 224 | raise Exception(f'Error. 
Mode "{args.mode}" not supported.') 225 | 226 | return dss 227 | 228 | 229 | def get_dataloaders(args, epic_ds=None, featuresloader=None): 230 | dss = get_datasets(args, epic_ds=epic_ds,featuresloader=featuresloader) 231 | dl_args = { 232 | 'batch_size': args.batch_size, 233 | 'pin_memory': True, 234 | 'num_workers': args.num_workers, 235 | } 236 | if args.mode in ['train', 'training']: 237 | dls = { 238 | 'train': DataLoader(dss['train'], shuffle=False, **dl_args), 239 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 240 | } 241 | elif args.mode in ['validate', 'validation', 'validating']: 242 | dls = { 243 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 244 | } 245 | elif args.mode == 'test': 246 | if args.ek_version == "ek55": 247 | dls = { 248 | 'test_s1': DataLoader(dss['test_s1'], shuffle=False, **dl_args), 249 | 'test_s2': DataLoader(dss['test_s2'], shuffle=False, **dl_args), 250 | } 251 | elif args.ek_version == "ek100": 252 | dls = { 253 | 'test': DataLoader(dss['test'], shuffle=False, **dl_args), 254 | } 255 | else: 256 | raise Exception(f'Error. 
Mode "{args.mode}" not supported.') 257 | return dls 258 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import pickle 5 | from sklearn.model_selection import train_test_split 6 | import numpy as np 7 | 8 | 9 | def timestr2sec(t_str): 10 | hh, mm, ss = [float(x) for x in t_str.split(':')] 11 | t_sec = hh * 3600.0 + mm * 60.0 + ss 12 | return t_sec 13 | 14 | 15 | def read_rulstm_splits(rulstm_annotation_path): 16 | header = ['uid', 'video_id', 'start_frame', 'stop_frame', 'verb_class', 'noun_class', 'action_class'] 17 | df_train = pd.read_csv(os.path.join(rulstm_annotation_path, 'training.csv'), names=header) 18 | df_validation = pd.read_csv(os.path.join(rulstm_annotation_path, 'validation.csv'), names=header) 19 | return df_train, df_validation 20 | 21 | 22 | def str2list(s, out_type=None): 23 | """ 24 | Convert a string "[i1, i2, ...]" of items into a list [i1, i2, ...] of items. 
25 | """ 26 | s = s.replace('[', '').replace(']', '') 27 | s = s.replace('\'', '') 28 | s = s.split(', ') 29 | if out_type is not None: 30 | s = [out_type(ss) for ss in s] 31 | return s 32 | 33 | 34 | def split_train_val(df, validation_ratio=0.2, use_rulstm_splits=False, 35 | rulstm_annotation_path=None, label_info_path=None, 36 | use_label_only=True): 37 | if label_info_path is not None and use_label_only: 38 | with open(label_info_path, 'r') as f: 39 | uids_label = json.load(f) 40 | df = df.loc[df['uid'].isin(uids_label)] 41 | if use_rulstm_splits: 42 | assert rulstm_annotation_path is not None 43 | df_train_rulstm, df_validation_rulstm = read_rulstm_splits(rulstm_annotation_path) 44 | uids_train = df_train_rulstm['uid'].values.tolist() 45 | uids_validation = df_validation_rulstm['uid'].values.tolist() 46 | df_train = df.loc[df['uid'].isin(uids_train)] 47 | df_validation = df.loc[df['uid'].isin(uids_validation)] 48 | else: 49 | if validation_ratio == 0.0: 50 | df_train = df 51 | df_validation = pd.DataFrame(columns=df.columns) 52 | elif validation_ratio == 1.0: 53 | df_train = pd.DataFrame(columns=df.columns) 54 | df_validation = df 55 | elif 0.0 < validation_ratio < 1.0: 56 | df_train, df_validation = train_test_split(df, test_size=validation_ratio, 57 | random_state=3577, 58 | shuffle=True, stratify=df['participant_id']) 59 | else: 60 | raise Exception(f'Error. 
Validation "{validation_ratio}" not supported.') 61 | return df_train, df_validation 62 | 63 | 64 | def create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, ek_version, out_path='actions.csv', use_rulstm_splits=True): 65 | if use_rulstm_splits: 66 | if ek_version == 'ek55': 67 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek55'], 'actions.csv')) 68 | elif ek_version == 'ek100': 69 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek100'], 'actions.csv')) 70 | df_actions['action'] = df_actions.action.map(lambda x: x.replace(' ', '_')) 71 | 72 | df_actions['verb_class'] = df_actions.verb 73 | df_actions['noun_class'] = df_actions.noun 74 | df_actions['verb'] = df_actions.action.map(lambda x: x.split('_')[0]) 75 | df_actions['noun'] = df_actions.action.map(lambda x: x.split('_')[1]) 76 | df_actions['action'] = df_actions.action 77 | df_actions['action_class'] = df_actions.id 78 | del df_actions['id'] 79 | 80 | else: 81 | if ek_version == 'ek55': 82 | df_train = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train', 83 | use_label_only=False, raw=True) 84 | df_validation = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation', 85 | use_label_only=False, raw=True) 86 | df = pd.concat([df_train, df_validation]) 87 | df.sort_values(by=['uid'], inplace=True) 88 | 89 | elif ek_version == 'ek100': 90 | df_train = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train', 91 | use_label_only=False, raw=True) 92 | df_validation = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation', 93 | use_label_only=False, raw=True) 94 | df = pd.concat([df_train, df_validation]) 95 | df.sort_values(by=['narration_id'], inplace=True) 96 | 97 | noun_classes = df.noun_class.values 98 | nouns = df.noun.values 99 | verb_classes = df.verb_class.values 100 | 
verbs = df.verb.values 101 | 102 | actions_combinations = [f'{v}_{n}' for v, n in zip(verb_classes, noun_classes)] 103 | actions = [f'{v}_{n}' for v, n in zip(verbs, nouns)] 104 | 105 | df_actions = {'verb_class': [], 'noun_class': [], 'verb': [], 'noun': [], 'action': []} 106 | vn_combinations = [] 107 | for i, a in enumerate(actions_combinations): 108 | if a in vn_combinations: 109 | continue 110 | 111 | v, n = a.split('_') 112 | v = int(v) 113 | n = int(n) 114 | df_actions['verb_class'] += [v] 115 | df_actions['noun_class'] += [n] 116 | df_actions['action'] += [actions[i]] 117 | df_actions['verb'] += [verbs[i]] 118 | df_actions['noun'] += [nouns[i]] 119 | vn_combinations += [a] 120 | df_actions = pd.DataFrame(df_actions) 121 | df_actions.sort_values(by=['verb_class', 'noun_class'], inplace=True) 122 | df_actions['action_class'] = range(len(df_actions)) 123 | 124 | df_actions.to_csv(out_path, index=False) 125 | print(f'Saved file at "{out_path}".') 126 | 127 | 128 | def get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2, 129 | use_rulstm_splits=False, use_label_only=True, raw=False): 130 | if partition in ['train', 'validation']: 131 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv') 132 | label_info_path = os.path.join(label_path['ek55'], "video_info.json") 133 | df = pd.read_csv(csv_path) 134 | df_train, df_validation = split_train_val(df, validation_ratio=validation_ratio, 135 | use_rulstm_splits=use_rulstm_splits, 136 | rulstm_annotation_path=rulstm_annot_path['ek55'], 137 | label_info_path=label_info_path, 138 | use_label_only=use_label_only) 139 | 140 | df = df_train if partition == 'train' else df_validation 141 | if not use_rulstm_splits: 142 | df.sort_values(by=['uid'], inplace=True) 143 | 144 | elif partition in ['eval', 'evaluation']: 145 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv') 146 | df = pd.read_csv(csv_path) 147 | with 
open(eval_label_path['ek55'], 'rb') as f: 148 | eval_labels = pickle.load(f) 149 | eval_uids = eval_labels.keys() 150 | df = df.loc[df['uid'].isin(eval_uids)] 151 | 152 | elif partition == 'test_s1': 153 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s1_timestamps.csv') 154 | df = pd.read_csv(csv_path) 155 | 156 | elif partition == 'test_s2': 157 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s2_timestamps.csv') 158 | df = pd.read_csv(csv_path) 159 | else: 160 | raise Exception(f'Error. Partition "{partition}" not supported.') 161 | 162 | if raw: 163 | return df 164 | 165 | actions_df_path = os.path.join(annot_path['ek55'], 'actions.csv') 166 | if not os.path.exists(actions_df_path): 167 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek55', out_path=actions_df_path, use_rulstm_splits=True) 168 | df_actions = pd.read_csv(actions_df_path) 169 | 170 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t)) 171 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t)) 172 | if 'test' not in partition: 173 | action_classes = [] 174 | actions = [] 175 | for _, row in df.iterrows(): 176 | v, n = row.verb_class, row.noun_class 177 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)] 178 | a_cl = df_a_sub['action_class'].values 179 | a = df_a_sub['action'].values 180 | if len(a_cl) > 1: 181 | print(a_cl) 182 | action_classes += [a_cl[0]] 183 | actions += [a[0]] 184 | df['action_class'] = action_classes 185 | df['action'] = actions 186 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x)) 187 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int)) 188 | 189 | return df 190 | 191 | 192 | def get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2, 193 | use_rulstm_splits=False, use_label_only=True, raw=False): 194 | if partition in 'train': 195 | df = 
pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 196 | uids = np.arange(len(df)) 197 | 198 | elif partition in 'validation': 199 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 200 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 201 | uids = np.arange(len(df)) + len(df_train) 202 | 203 | elif partition in 'evaluation': 204 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 205 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 206 | uids = np.arange(len(df)) + len(df_train) 207 | df['uid'] = uids 208 | with open(eval_label_path['ek100'], 'rb') as f: 209 | eval_labels = pickle.load(f) 210 | eval_uids = eval_labels.keys() 211 | df = df.loc[df['uid'].isin(eval_uids)] 212 | 213 | elif partition == 'test': 214 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv')) 215 | df_validation = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv')) 216 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_test_timestamps.csv')) 217 | uids = np.arange(len(df)) + len(df_train) + len(df_validation) 218 | 219 | else: 220 | raise Exception(f'Error. 
Partition "{partition}" not supported.') 221 | if raw: 222 | return df 223 | 224 | actions_df_path = os.path.join(annot_path['ek100'], 'actions.csv') 225 | if not os.path.exists(actions_df_path): 226 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek100', actions_df_path) 227 | df_actions = pd.read_csv(actions_df_path) 228 | 229 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t)) 230 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t)) 231 | if not 'uid' in df: 232 | df['uid'] = uids 233 | 234 | if use_label_only: 235 | label_info_path = os.path.join(label_path['ek100'], "video_info.json") 236 | with open(label_info_path, 'r') as f: 237 | uids_label = json.load(f) 238 | df = df.loc[df['uid'].isin(uids_label)] 239 | 240 | if 'test' not in partition: 241 | action_classes = [] 242 | actions = [] 243 | for _, row in df.iterrows(): 244 | v, n = row.verb_class, row.noun_class 245 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)] 246 | a_cl = df_a_sub['action_class'].values 247 | a = df_a_sub['action'].values 248 | if len(a_cl) > 1: 249 | print(a_cl) 250 | action_classes += [a_cl[0]] 251 | actions += [a[0]] 252 | df['action_class'] = action_classes 253 | df['action'] = actions 254 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x)) 255 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int)) 256 | return df -------------------------------------------------------------------------------- /datasets/datasetopts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class DatasetArgs(object): 6 | def __init__(self, ek_version='ek55', mode="train", use_label_only=True, 7 | base_path="./", batch_size=32, num_workers=0, modalities=['feat'], 8 | fps=4, t_buffer=2.5): 9 | 10 | self.features_paths = { 11 | 'ek55': os.path.join(base_path, 
'data/ek55/feats'), 12 | 'ek100': os.path.join(base_path, 'data/ek100/feats')} 13 | 14 | # generated data labels 15 | self.label_path = { 16 | 'ek55': os.path.join(base_path, 'data/ek55'), 17 | 'ek100': os.path.join(base_path, 'data/ek100')} 18 | 19 | # amazon-annotated eval labels 20 | self.eval_label_path = { 21 | 'ek55': os.path.join(base_path, 'data/ek55/ek55_eval_labels.pkl'), 22 | 'ek100': os.path.join(base_path, 'data/ek100/ek100_eval_labels.pkl') 23 | } 24 | 25 | self.annot_path = { 26 | 'ek55': os.path.join(base_path, 'common/epic-kitchens-55-annotations'), 27 | 'ek100': os.path.join(base_path, 'common/epic-kitchens-100-annotations')} 28 | 29 | self.rulstm_annot_path = { 30 | 'ek55': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek55'), 31 | 'ek100': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek100')} 32 | 33 | self.pretrained_backbone_path = { 34 | 'ek55': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek55', 'TSN-rgb.pth.tar'), 35 | 'ek100': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek100', 'TSN-rgb-ek100.pth.tar'), 36 | } 37 | 38 | # default settings, no need changes 39 | if fps is None: 40 | self.fps = 4 41 | else: 42 | self.fps = fps 43 | 44 | if t_buffer is None: 45 | self.t_buffer = 2.5 46 | else: 47 | self.t_buffer = t_buffer 48 | 49 | self.ori_fps = 60.0 50 | self.t_ant = 1.0 51 | 52 | self.validation_ratio = 0.2 53 | self.use_rulstm_splits = True 54 | 55 | # only preprocess uids that have corresponding labels, in "video_info.json" 56 | self.use_label_only = use_label_only 57 | 58 | self.task = 'anticipation' 59 | self.num_actions_prev = 1 60 | 61 | self.batch_size = batch_size 62 | self.num_workers = num_workers 63 | 64 | self.modalities = modalities 65 | self.ek_version = ek_version # 'ek55' or 'ek100' 66 | self.mode = mode # 'train' 67 | 68 | def add_attr(self, attr_name, attr_value): 69 | setattr(self, attr_name, attr_value) 70 | 71 | def has_attr(self, attr_name): 72 | return hasattr(self, attr_name) 73 | 74 | 
def __repr__(self): 75 | return 'Input Args: ' + json.dumps(self.__dict__, indent=4) 76 | 77 | -------------------------------------------------------------------------------- /datasets/ho_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | def load_video_info(label_path, video_index): 7 | with open(os.path.join(label_path, "label_{}.pkl".format(video_index)), 'rb') as f: 8 | video_info = pickle.load(f) 9 | return video_info 10 | 11 | 12 | def sample_hand_traj(meta, fps, t_ant, shape=(456, 256)): 13 | width, height = shape 14 | traj = meta["traj"] 15 | ori_fps = int((len(traj) - 1) / t_ant) 16 | gap = int(ori_fps // fps) 17 | stop_idx = len(traj) 18 | indices = [0] + list(range(gap, stop_idx, gap)) 19 | hand_traj = [] 20 | for idx in indices: 21 | x, y = traj[idx] 22 | x, y, = x / width, y / height 23 | hand_traj.append(np.array([x, y], dtype=np.float32)) 24 | hand_traj = np.array(hand_traj, dtype=np.float32) 25 | return hand_traj, indices 26 | 27 | 28 | def process_video_info(video_info, fps=4, t_ant=1.0, shape=(456, 256)): 29 | frames_idxs = video_info["frame_indices"] 30 | hand_trajs = video_info["hand_trajs"] 31 | obj_affordance = video_info['affordance']['select_points_homo'] 32 | num_points = obj_affordance.shape[0] 33 | select_idx = np.random.choice(num_points, 1, replace=False) 34 | contact_point = obj_affordance[select_idx] 35 | cx, cy = contact_point[0] 36 | width, height = shape 37 | cx, cy = cx / width, cy/ height 38 | contact_point = np.array([cx, cy], dtype=np.float32) 39 | 40 | valid_mask = [] 41 | if "RIGHT" in hand_trajs: 42 | meta = hand_trajs["RIGHT"] 43 | rhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape) 44 | valid_mask.append(1) 45 | else: 46 | length = int(fps * t_ant + 1) 47 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), length, axis=0) 48 | valid_mask.append(0) 49 | 50 | if "LEFT" in hand_trajs: 51 | 
meta = hand_trajs["LEFT"] 52 | lhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape) 53 | valid_mask.append(1) 54 | else: 55 | length = int(fps * t_ant + 1) 56 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), length, axis=0) 57 | valid_mask.append(0) 58 | 59 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0) 60 | future_valid = np.array(valid_mask, dtype=np.int) 61 | 62 | last_frame_index = frames_idxs[0] 63 | return future_hands, contact_point, future_valid, last_frame_index 64 | 65 | 66 | def process_eval_video_info(video_info, fps=4, t_ant=1.0): 67 | valid_mask = [] 68 | if "RIGHT" in video_info: 69 | rhand_traj = video_info["RIGHT"] 70 | assert rhand_traj.shape[0] == int(fps * t_ant + 1) 71 | valid_mask.append(1) 72 | else: 73 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0) 74 | valid_mask.append(0) 75 | 76 | if "LEFT" in video_info: 77 | lhand_traj = video_info['LEFT'] 78 | assert lhand_traj.shape[0] == int(fps * t_ant + 1) 79 | valid_mask.append(1) 80 | else: 81 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0) 82 | valid_mask.append(0) 83 | 84 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0) 85 | future_valid = np.array(valid_mask, dtype=np.int) 86 | return future_hands, future_valid 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /datasets/holoaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | from lmdbdict import lmdbdict 5 | from torch.utils.data import DataLoader 6 | 7 | from datasets.dataloaders import EpicDataset, get_datasets 8 | from datasets.ho_utils import load_video_info, process_video_info, process_eval_video_info 9 | 10 | 11 | class FeaturesHOLoader(object): 12 | def __init__(self, sampler, feature_base_path, 
fps=4.0, input_name='rgb', 13 | frame_tmpl='frame_{:010d}.jpg', 14 | transform_feat=None, transform_video=None, 15 | t_observe=2.5): 16 | 17 | self.feature_base_path = feature_base_path 18 | self.lmdb_path = os.path.join(self.feature_base_path, "data.lmdb") 19 | self.env = lmdbdict(self.lmdb_path, 'r') 20 | self.fps = fps 21 | self.input_name = input_name 22 | self.frame_tmpl = frame_tmpl 23 | self.transform_feat = transform_feat 24 | self.transform_video = transform_video 25 | self.sampler = sampler 26 | self.t_observe = t_observe 27 | self.num_observe = int(t_observe * self.fps) 28 | 29 | def __call__(self, action): 30 | times, frames_idxs = self.sampler(action) 31 | assert self.num_observe <= len(frames_idxs), \ 32 | "num of observation exceed the limit of {}, set smaller t_observe, current is {}".format(len(frames_idxs), self.t_observe) 33 | frames_names = [self.frame_tmpl.format(i) for i in frames_idxs] 34 | start_frame_idx = len(frames_idxs) - self.num_observe 35 | frames_names = frames_names[start_frame_idx:] 36 | 37 | full_names = [] 38 | global_feats, global_masks = [], [] 39 | rhand_feats, rhand_masks, rhand_bboxs = [], [], [] 40 | lhand_feats, lhand_masks, lhand_bboxs = [], [], [] 41 | robj_feats, robj_masks, robj_bboxs = [], [], [] 42 | lobj_feats, lobj_masks, lobj_bboxs = [], [], [] 43 | 44 | for f_name in frames_names: 45 | # full_name: e.g. 
'P24/rgb_frames/P24_05/frame_0000075700.jpg' 46 | full_name = os.path.join(action.participant_id, "rgb_frames", action.video_id, f_name) 47 | full_names.append(full_name) 48 | # f_dict: 'GLOBAL_FEAT', 49 | # 'HAND_RIGHT_FEAT', 'HAND_RIGHT_BBOX', 'OBJECT_RIGHT_FEAT', 'OBJECT_RIGHT_BBOX', 50 | # 'HAND_LEFT_FEAT', 'HAND_LEFT_BBOX', 'OBJECT_LEFT_FEAT', 'OBJECT_LEFT_BBOX'] 51 | key_enc = full_name.strip().encode('utf-8') 52 | if key_enc not in self.env: 53 | raise KeyError("invalid key {}, check lmdb file in {}".format(full_name.strip(), self.lmdb_path)) 54 | f_dict = self.env[key_enc] 55 | 56 | global_feat = f_dict['GLOBAL_FEAT'] 57 | global_masks.append(1) 58 | global_feats.append(global_feat) 59 | 60 | if 'HAND_RIGHT_FEAT' in f_dict: 61 | rhand_feat = f_dict['HAND_RIGHT_FEAT'] 62 | else: 63 | rhand_feat = np.zeros_like(global_feat, dtype=np.float32) 64 | rhand_feats.append(rhand_feat) 65 | 66 | if 'HAND_LEFT_FEAT' in f_dict: 67 | lhand_feat = f_dict['HAND_LEFT_FEAT'] 68 | else: 69 | lhand_feat = np.zeros_like(global_feat, dtype=np.float32) 70 | lhand_feats.append(lhand_feat) 71 | 72 | if 'OBJECT_RIGHT_FEAT' in f_dict: 73 | robj_feat = f_dict['OBJECT_RIGHT_FEAT'] 74 | else: 75 | robj_feat = np.zeros_like(global_feat, dtype=np.float32) 76 | robj_feats.append(robj_feat) 77 | 78 | if 'OBJECT_LEFT_FEAT' in f_dict: 79 | lobj_feat = f_dict['OBJECT_LEFT_FEAT'] 80 | else: 81 | lobj_feat = np.zeros_like(global_feat, dtype=np.float32) 82 | lobj_feats.append(lobj_feat) 83 | 84 | if 'HAND_RIGHT_BBOX' in f_dict: 85 | rhand_bbox = f_dict['HAND_RIGHT_BBOX'] 86 | rhand_masks.append(1) 87 | else: 88 | cx, cy = (0.75, 1.5) 89 | sx, sy = (0.1, 0.1) 90 | rhand_bbox = np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2]) 91 | rhand_masks.append(0) 92 | rhand_bboxs.append(rhand_bbox) 93 | 94 | if 'HAND_LEFT_BBOX' in f_dict: 95 | lhand_bbox = f_dict['HAND_LEFT_BBOX'] 96 | lhand_masks.append(1) 97 | else: 98 | cx, cy = (0.25, 1.5) 99 | sx, sy = (0.1, 0.1) 100 | lhand_bbox = 
np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2]) 101 | lhand_masks.append(0) 102 | lhand_bboxs.append(lhand_bbox) 103 | 104 | if 'OBJECT_RIGHT_BBOX' in f_dict: 105 | robj_bbox = f_dict['OBJECT_RIGHT_BBOX'] 106 | robj_masks.append(1) 107 | else: 108 | robj_bbox = np.array([0.0, 0.0, 1.0, 1.0]) 109 | robj_masks.append(0) 110 | robj_bboxs.append(robj_bbox) 111 | 112 | if 'OBJECT_LEFT_BBOX' in f_dict: 113 | lobj_bbox = f_dict['OBJECT_LEFT_BBOX'] 114 | lobj_masks.append(1) 115 | else: 116 | lobj_bbox = np.array([0.0, 0.0, 1.0, 1.0]) 117 | lobj_masks.append(0) 118 | lobj_bboxs.append(lobj_bbox) 119 | 120 | global_feats = np.stack(global_feats, axis=0) 121 | rhand_feats = np.stack(rhand_feats, axis=0) 122 | lhand_feats = np.stack(lhand_feats, axis=0) 123 | robj_feats = np.stack(robj_feats, axis=0) 124 | lobj_feats = np.stack(lobj_feats, axis=0) 125 | 126 | feats = np.stack((global_feats, rhand_feats, lhand_feats, robj_feats, lobj_feats), axis=0) 127 | 128 | rhand_bboxs = np.stack(rhand_bboxs, axis=0) 129 | lhand_bboxs = np.stack(lhand_bboxs, axis=0) 130 | robj_bboxs = np.stack(robj_bboxs, axis=0) 131 | lobj_bboxs = np.stack(lobj_bboxs, axis=0) 132 | 133 | bbox_feats = np.stack((rhand_bboxs, lhand_bboxs, robj_bboxs, lobj_bboxs), axis=0) 134 | 135 | global_masks = np.stack(global_masks, axis=0) 136 | rhand_masks = np.stack(rhand_masks, axis=0) 137 | lhand_masks = np.stack(lhand_masks, axis=0) 138 | robj_masks = np.stack(robj_masks, axis=0) 139 | lobj_masks = np.stack(lobj_masks, axis=0) 140 | 141 | valid_masks = np.stack((global_masks, rhand_masks, lhand_masks, robj_masks, lobj_masks), axis=0) 142 | 143 | out = {"name": full_names, "feat": feats, "bbox_feat": bbox_feats, "valid_mask": valid_masks, 'times': times, 144 | 'start_time': action.start_time, 'frames_idxs': frames_idxs} 145 | return out 146 | 147 | 148 | class EpicHODataset(EpicDataset): 149 | def __init__(self, df, partition, ori_fps=60.0, fps=4.0, loader=None, t_ant=None, transform=None, 150 | 
num_actions_prev=None, label_path=None, eval_label_path=None, 151 | annot_path=None, rulstm_annot_path=None, ek_version=None): 152 | super().__init__(df=df, partition=partition, ori_fps=ori_fps, fps=fps, 153 | loader=loader, t_ant=t_ant, transform=transform, 154 | num_actions_prev=num_actions_prev) 155 | self.ek_version = ek_version 156 | self.rulstm_annot_path = rulstm_annot_path 157 | self.annot_path = annot_path 158 | self.shape = (456, 256) 159 | self.discarded_labels, self.discarded_ids = self._get_discarded() 160 | 161 | if 'eval' not in self.partition: 162 | self.label_dir = os.path.join(label_path, "labels") 163 | else: 164 | with open(eval_label_path, 'rb') as f: 165 | self.eval_labels = pickle.load(f) 166 | 167 | def _load_eval_labels(self, uid): 168 | video_info = self.eval_labels[uid] 169 | future_hands, future_valid = process_eval_video_info(video_info, fps=self.fps, t_ant=self.t_ant) 170 | return future_hands, future_valid 171 | 172 | def _load_labels(self, uid): 173 | if os.path.exists(os.path.join(self.label_dir, "label_{}.pkl".format(uid))): 174 | label_valid = True 175 | video_info = load_video_info(self.label_dir, uid) 176 | future_hands, contact_point, future_valid, last_frame_index = process_video_info(video_info, fps=self.fps, 177 | t_ant=self.t_ant, 178 | shape=self.shape) 179 | else: 180 | label_valid = False 181 | length = int(self.fps * self.t_ant + 1) 182 | future_hands = np.zeros((2, length, 2), dtype=np.float32) 183 | contact_point = np.zeros((2,), dtype=np.float32) 184 | future_valid = np.array([0, 0], dtype=np.int) 185 | last_frame_index = None 186 | 187 | return future_hands, contact_point, future_valid, last_frame_index, label_valid 188 | 189 | def _get_discarded(self): 190 | discarded_ids = [] 191 | discarded_labels = [] 192 | if 'train' not in self.partition: 193 | label_type = ['verb', 'noun', 'action'] 194 | else: 195 | label_type = 'action' 196 | if 'test' in self.partition: 197 | challenge = True 198 | else: 199 | challenge = 
False 200 | 201 | for action in self.actions_invalid: 202 | discarded_ids.append(action.uid) 203 | if isinstance(label_type, list): 204 | if challenge: 205 | discarded_labels.append(-1) 206 | else: 207 | verb, noun, action_class = action.verb_class, action.noun_class, action.action_class 208 | label = np.array([verb, noun, action_class], dtype=np.int) 209 | discarded_labels.append(label) 210 | else: 211 | if challenge: 212 | discarded_labels.append(-1) 213 | else: 214 | action_class = action.action_class 215 | discarded_labels.append(action_class) 216 | return discarded_labels, discarded_ids 217 | 218 | def __getitem__(self, idx): 219 | a = self.actions[idx] 220 | sample = {'uid': a.uid} 221 | 222 | inputs = self.loader(a) 223 | sample.update(inputs) 224 | 225 | if 'eval' not in self.partition: 226 | future_hands, contact_point, future_valid, last_frame_index, label_valid = self._load_labels(a.uid) 227 | 228 | sample['future_hands'] = future_hands 229 | sample['contact_point'] = contact_point 230 | sample['future_valid'] = future_valid 231 | sample['label_valid'] = label_valid 232 | 233 | if "frames_idxs" in sample and last_frame_index is not None: 234 | assert last_frame_index == sample["frames_idxs"][-1], \ 235 | "dataloader video clip {} last observation frame mismatch, " \ 236 | "index load from history s {} while load from future is{}!".format(a.uid, sample["frames_idxs"][-1], 237 | last_frame_index) 238 | 239 | else: 240 | future_hands, future_valid = self._load_eval_labels(a.uid) 241 | sample['future_hands'] = future_hands 242 | sample['future_valid'] = future_valid 243 | sample['label_valid'] = True 244 | # 'contact_point' load in evaluate func 245 | 246 | if 'test' not in self.partition: 247 | sample['verb_class'] = a.verb_class 248 | sample['noun_class'] = a.noun_class 249 | sample['action_class'] = a.action_class 250 | sample['label'] = np.array([a.verb_class, a.noun_class, a.action_class], dtype=np.int) 251 | return sample 252 | 253 | 254 | def 
get_dataloaders(args, epic_ds=None, featuresloader=None): 255 | dss = get_datasets(args, epic_ds=epic_ds, featuresloader=featuresloader) 256 | 257 | dl_args = { 258 | 'batch_size': args.batch_size, 259 | 'pin_memory': True, 260 | 'num_workers': args.num_workers, 261 | 'drop_last': False 262 | } 263 | if args.mode in ['train', 'training']: 264 | dls = { 265 | 'train': DataLoader(dss['train'], shuffle=True, **dl_args), 266 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 267 | 'eval': DataLoader(dss['eval'], shuffle=False, **dl_args) 268 | } 269 | elif args.mode in ['validate', 'validation', 'validating']: 270 | dls = { 271 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args), 272 | 'eval': DataLoader(dss['eval'], shuffle=False, **dl_args) 273 | } 274 | elif args.mode == 'test': 275 | if args.ek_version == "ek55": 276 | dls = { 277 | 'test_s1': DataLoader(dss['test_s1'], shuffle=False, **dl_args), 278 | 'test_s2': DataLoader(dss['test_s2'], shuffle=False, **dl_args), 279 | } 280 | elif args.ek_version == "ek100": 281 | dls = { 282 | 'test': DataLoader(dss['test'], shuffle=False, **dl_args), 283 | } 284 | else: 285 | raise Exception(f'Error. 
Mode "{args.mode}" not supported.') 286 | return dls 287 | -------------------------------------------------------------------------------- /datasets/input_loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torchvision import transforms 4 | import torch 5 | import os 6 | import lmdb 7 | 8 | 9 | class ActionAnticipationSampler(object): 10 | def __init__(self, t_buffer, t_ant=1.0, fps=4.0, ori_fps=60.0): 11 | self.t_buffer = t_buffer 12 | self.t_ant = t_ant 13 | self.fps = fps 14 | self.ori_fps = ori_fps 15 | 16 | def __call__(self, action): 17 | times, frames_idxs = sample_history_frames(action.start_frame, self.t_buffer, 18 | self.t_ant, fps=self.fps, 19 | fps_init=self.ori_fps) 20 | return times, frames_idxs 21 | 22 | 23 | def get_sampler(args): 24 | sampler = ActionAnticipationSampler(t_buffer=args.t_buffer, t_ant=args.t_ant, 25 | fps=args.fps, ori_fps=args.ori_fps) 26 | return sampler 27 | 28 | 29 | def sample_history_frames(frame_start, t_buffer=2.5, t_ant=1.0, fps=4.0, fps_init=60.0): 30 | time_start = (frame_start - 1) / fps_init 31 | num_frames = int(np.floor(t_buffer * fps)) 32 | time_ant = time_start - t_ant 33 | times = (np.arange(1, num_frames + 1) - num_frames) / fps + time_ant 34 | times = np.clip(times, 0, np.inf) 35 | times = times.astype(np.float32) 36 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 37 | times = (frames_idxs - 1) / fps_init 38 | return times, frames_idxs 39 | 40 | 41 | def sample_future_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0): 42 | time_start = (frame_start - 1) / fps_init 43 | num_frames = int(np.floor(t_buffer * fps)) 44 | times = (np.arange(num_frames + 1) - num_frames) / fps + time_start 45 | times = np.clip(times, 0, np.inf) 46 | times = times.astype(np.float32) 47 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 48 | if frames_idxs.max() >= 1: 49 | frames_idxs[frames_idxs < 1] = 
frames_idxs[frames_idxs >= 1].min() 50 | return list(frames_idxs) 51 | 52 | 53 | class FeaturesLoader(object): 54 | def __init__(self, sampler, feature_base_path, fps, input_name='rgb', 55 | frame_tmpl='frame_{:010d}.jpg', transform_feat=None, 56 | transform_video=None): 57 | self.feature_base_path = feature_base_path 58 | self.env = lmdb.open(os.path.join(self.feature_base_path, input_name), readonly=True, lock=False) 59 | self.fps = fps 60 | self.input_name = input_name 61 | self.frame_tmpl = frame_tmpl 62 | self.transform_feat = transform_feat 63 | self.transform_video = transform_video 64 | self.sampler = sampler 65 | 66 | def __call__(self, action): 67 | times, frames_idxs = self.sampler(action) 68 | frames_names = [self.frame_tmpl.format(action.video_id, i) for i in frames_idxs] 69 | feats = [] 70 | with self.env.begin() as env: 71 | for f_name in frames_names: 72 | feat = env.get(f_name.strip().encode('utf-8')) 73 | if feat is None: 74 | print(f_name) 75 | feat = np.frombuffer(feat, 'float32') 76 | 77 | if self.transform_feat is not None: 78 | feat = self.transform_feat(feat) 79 | feats += [feat] 80 | 81 | if self.transform_video is not None: 82 | feats = self.transform_video(feats) 83 | out = {self.input_name: feats} 84 | out['times'] = times 85 | out['start_time'] = action.start_time 86 | out['frames_idxs'] = frames_idxs 87 | return out 88 | 89 | 90 | class PipeLoaders(object): 91 | def __init__(self, loader_list): 92 | self.loader_list = loader_list 93 | 94 | def __call__(self, action): 95 | out = {} 96 | for loader in self.loader_list: 97 | out.update(loader(action)) 98 | return out 99 | 100 | 101 | def get_features_loader(args, featuresloader=None): 102 | sampler = get_sampler(args) 103 | feat_in_modalities = list({'feat'}.intersection(args.modalities)) 104 | transform_feat = lambda x: torch.tensor(x.copy()) 105 | transform_video = lambda x: torch.stack(x, 0) 106 | loader_args = { 107 | 'feature_base_path': args.features_paths[args.ek_version], 108 | 
'fps': args.fps, 109 | 'frame_tmpl': 'frame_{:010d}.jpg', 110 | 'transform_feat': transform_feat, 111 | 'transform_video': transform_video, 112 | 'sampler': sampler} 113 | if featuresloader is None: 114 | featuresloader = FeaturesLoader 115 | feat_loader_list = [] 116 | for modality in feat_in_modalities: 117 | feat_loader = featuresloader(input_name=modality, **loader_args) 118 | feat_loader_list += [feat_loader] 119 | feat_loaders = { 120 | 'train': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 121 | 'validation': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 122 | 'test': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None, 123 | } 124 | return feat_loaders 125 | 126 | 127 | def get_loaders(args, featuresloader=None): 128 | loaders = { 129 | 'train': [], 130 | 'validation': [], 131 | 'test': [], 132 | } 133 | 134 | if 'feat' in args.modalities: 135 | feat_loaders = get_features_loader(args, featuresloader=featuresloader) 136 | for k, l in feat_loaders.items(): 137 | if l is not None: 138 | loaders[k] += [l] 139 | 140 | for k, l in loaders.items(): 141 | loaders[k] = PipeLoaders(l) 142 | return loaders 143 | -------------------------------------------------------------------------------- /demo_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import argparse 4 | import pandas as pd 5 | import cv2 6 | 7 | from preprocess.traj_util import compute_hand_traj 8 | from preprocess.dataset_util import FrameDetections, sample_action_anticipation_frames, fetch_data, save_video_info 9 | from preprocess.obj_util import compute_obj_traj 10 | from preprocess.affordance_util import compute_obj_affordance 11 | from preprocess.vis_util import vis_affordance, vis_hand_traj 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--label_path', default="assets/EPIC_train_action_labels.csv", type=str, 
help="dataset annotation") 17 | parser.add_argument('--dataset_path', default="assets/EPIC-KITCHENS", type=str, help='dataset root') 18 | parser.add_argument('--save_path', default="./figs", type=str, help="generated results save path") 19 | parser.add_argument('--fps', default=10, type=int, help="sample frames per second") 20 | parser.add_argument('--hand_threshold', default=0.1, type=float, help="hand detection threshold") 21 | parser.add_argument('--obj_threshold', default=0.1, type=float, help="object detection threshold") 22 | parser.add_argument('--contact_ratio', default=0.4, type=float, help="active obj contact frames ratio") 23 | parser.add_argument('--num_sampling', default=20, type=int, help="sampling points for affordance") 24 | parser.add_argument('--num_points', default=5, type=int, help="selected points for affordance") 25 | 26 | args = parser.parse_args() 27 | os.makedirs(args.save_path, exist_ok=True) 28 | save_path = args.save_path 29 | 30 | uid = 52 31 | 32 | annotations = pd.read_csv(args.label_path) 33 | participant_id = annotations.loc[annotations['uid'] == uid].participant_id.item() 34 | video_id = annotations.loc[annotations['uid'] == uid].video_id.item() 35 | frames_path = os.path.join(args.dataset_path, participant_id, "rgb_frames", video_id) 36 | ho_path = os.path.join(args.dataset_path, participant_id, "hand-objects", "{}.pkl".format(video_id)) 37 | start_act_frame = annotations.loc[annotations['uid'] == uid].start_frame.item() 38 | frames_idxs = sample_action_anticipation_frames(start_act_frame, fps=args.fps) 39 | 40 | with open(ho_path, "rb") as f: 41 | video_detections = [FrameDetections.from_protobuf_str(s) for s in pickle.load(f)] 42 | results = fetch_data(frames_path, video_detections, frames_idxs) 43 | if results is None: 44 | print("data fetch failed") 45 | else: 46 | frames_idxs, frames, annots, hand_sides = results 47 | 48 | results_hand = compute_hand_traj(frames, annots, hand_sides, hand_threshold=args.hand_threshold, 49 | 
obj_threshold=args.obj_threshold) 50 | if results_hand is None: 51 | print("compute traj failed") # homography fails or not enough points 52 | else: 53 | homography_stack, hand_trajs = results_hand 54 | results_obj = compute_obj_traj(frames, annots, hand_sides, homography_stack, 55 | hand_threshold=args.hand_threshold, 56 | obj_threshold=args.obj_threshold, 57 | contact_ratio=args.contact_ratio) 58 | if results_obj is None: 59 | print("compute obj traj failed") 60 | else: 61 | contacts, obj_trajs, active_obj, active_object_idx, obj_bboxs_traj = results_obj 62 | frame, homography = frames[-1], homography_stack[-1] 63 | affordance_info = compute_obj_affordance(frame, annots[-1], active_obj, active_object_idx, homography, 64 | active_obj_traj=obj_trajs['traj'], obj_bboxs_traj=obj_bboxs_traj, 65 | num_points=args.num_points, num_sampling=args.num_sampling) 66 | if affordance_info is not None: 67 | img_vis = vis_hand_traj(frames, hand_trajs) 68 | img_vis = vis_affordance(img_vis, affordance_info) 69 | img = cv2.hconcat([img_vis, frames[-1]]) 70 | cv2.imwrite(os.path.join(save_path, "demo_{}.jpg".format(uid)), img) 71 | save_video_info(save_path, uid, frames_idxs, homography_stack, contacts, hand_trajs, obj_trajs, affordance_info) 72 | print(f"result stored at {save_path}") 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: fhoi 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6 6 | - pip 7 | - pip: 8 | - pandas 9 | - numpy 10 | - matplotlib 11 | - dataclasses 12 | - protobuf 13 | - Scipy 14 | - opencv-contrib-python==3.4.2.16 -------------------------------------------------------------------------------- /evaluation/affordance_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 
import cv2 4 | from sklearn.cluster import KMeans 5 | from joblib import Parallel, delayed 6 | 7 | 8 | def farthest_sampling(pcd, n_samples): 9 | def compute_distance(a, b): 10 | return np.linalg.norm(a - b, ord=2, axis=2) 11 | 12 | n_pts, dim = pcd.shape[0], pcd.shape[1] 13 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim)) 14 | remaining_pts = np.copy(pcd) 15 | 16 | if n_pts > 1: 17 | start_idx = np.random.randint(low=0, high=n_pts - 1) 18 | else: 19 | start_idx = 0 20 | selected_pts_expanded[0] = remaining_pts[start_idx] 21 | n_selected_pts = 1 22 | 23 | for _ in range(1, n_samples): 24 | if n_selected_pts < n_samples: 25 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T 26 | dist_pts_to_selected_min = np.min(dist_pts_to_selected, axis=1, keepdims=True) 27 | res_selected_idx = np.argmax(dist_pts_to_selected_min) 28 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx] 29 | n_selected_pts += 1 30 | 31 | selected_pts = np.squeeze(selected_pts_expanded, axis=1) 32 | return selected_pts 33 | 34 | 35 | def makeGaussian(size, fwhm=3., center=None): 36 | x = np.arange(0, size, 1, float) 37 | y = x[:, np.newaxis] 38 | if center is None: 39 | x0 = y0 = size // 2 40 | else: 41 | x0 = center[0] 42 | y0 = center[1] 43 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2) 44 | 45 | 46 | def compute_heatmap(normalized_points, image_size, k_ratio=3.0, transpose=True, 47 | fps=False, kmeans=False, n_pts=5, gaussian_sigma=0.): 48 | normalized_points = np.asarray(normalized_points) 49 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32) 50 | n_points = normalized_points.shape[0] 51 | if n_points > n_pts and kmeans: 52 | kmeans = KMeans(n_clusters=n_pts, random_state=0).fit(normalized_points) 53 | normalized_points = kmeans.cluster_centers_ 54 | elif n_points > n_pts and fps: 55 | normalized_points = farthest_sampling(normalized_points, n_samples=n_pts) 56 | 
n_points = normalized_points.shape[0] 57 | for i in range(n_points): 58 | x = normalized_points[i, 0] * image_size[0] 59 | y = normalized_points[i, 1] * image_size[1] 60 | col = int(x) 61 | row = int(y) 62 | try: 63 | heatmap[col, row] += 1.0 64 | except: 65 | col = min(max(col, 0), image_size[0] - 1) 66 | row = min(max(row, 0), image_size[1] - 1) 67 | heatmap[col, row] += 1.0 68 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio) 69 | if k_size % 2 == 0: 70 | k_size += 1 71 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), gaussian_sigma) 72 | if heatmap.max() > 0: 73 | heatmap /= heatmap.max() 74 | if transpose: 75 | heatmap = heatmap.transpose() 76 | return heatmap 77 | 78 | 79 | def SIM(map1, map2, eps=1e-12): 80 | map1, map2 = map1 / (map1.sum() + eps), map2 / (map2.sum() + eps) 81 | intersection = np.minimum(map1, map2) 82 | return np.sum(intersection) 83 | 84 | 85 | def AUC_Judd(saliency_map, fixation_map, jitter=True): 86 | saliency_map = np.array(saliency_map, copy=False) 87 | fixation_map = np.array(fixation_map, copy=False) > 0.5 88 | if not np.any(fixation_map): 89 | return np.nan 90 | if saliency_map.shape != fixation_map.shape: 91 | saliency_map = cv2.resize(saliency_map, fixation_map.shape, interpolation=cv2.INTER_AREA) 92 | if jitter: 93 | saliency_map += np.random.rand(*saliency_map.shape) * 1e-7 94 | saliency_map = (saliency_map - np.min(saliency_map)) / (np.max(saliency_map) - np.min(saliency_map) + 1e-12) 95 | 96 | S = saliency_map.ravel() 97 | F = fixation_map.ravel() 98 | S_fix = S[F] 99 | n_fix = len(S_fix) 100 | n_pixels = len(S) 101 | thresholds = sorted(S_fix, reverse=True) 102 | tp = np.zeros(len(thresholds) + 2) 103 | fp = np.zeros(len(thresholds) + 2) 104 | tp[0] = 0; 105 | tp[-1] = 1 106 | fp[0] = 0; 107 | fp[-1] = 1 108 | for k, thresh in enumerate(thresholds): 109 | above_th = np.sum(S >= thresh) 110 | tp[k + 1] = (k + 1) / float(n_fix) 111 | fp[k + 1] = (above_th - (k + 1)) / float(n_pixels - n_fix) 112 | return 
np.trapz(tp, fp)


def NSS(saliency_map, fixation_map):
    """Normalized Scanpath Saliency (NSS) of ``saliency_map`` at fixated pixels.

    The saliency map is standardised to zero mean / unit standard deviation and
    averaged over the pixels where ``fixation_map`` is non-zero. The result is
    NaN when the map is constant (zero std) or when no pixel is fixated.
    """
    normalized = (saliency_map - saliency_map.mean()) / saliency_map.std()
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    mask = fixation_map.astype(bool)
    return normalized[mask].mean()


def compute_score(pred, gt, valid_thresh=0.001):
    """Score a single predicted heatmap against its ground-truth heatmap.

    Args:
        pred: predicted heatmap (numpy array or torch tensor); rescaled so that
            its maximum is 1 before scoring.
        gt: ground-truth heatmap of the same shape (numpy array or torch tensor).
        valid_thresh: ground-truth pixels strictly above this value count as
            positives in the threshold sweep.

    Returns:
        ``(scores, tp, tn, fp, fn)``. ``scores`` maps ``'SIM'``, ``'AUC-J'`` and
        ``'NSS'`` to a value, or to ``None`` when the metric is NaN. The four
        arrays hold pixel counts at each of the 41 thresholds in
        ``np.linspace(0.001, 1.0, 41)`` (the same grid ``evaluate_affordance``
        accumulates over).
    """
    if torch.is_tensor(pred):
        pred = pred.detach().cpu().numpy()
    if torch.is_tensor(gt):
        gt = gt.detach().cpu().numpy()

    pred = pred / (pred.max() + 1e-12)

    all_thresh = np.linspace(0.001, 1.0, 41)
    tp = np.zeros((all_thresh.shape[0],))
    fp = np.zeros((all_thresh.shape[0],))
    fn = np.zeros((all_thresh.shape[0],))
    tn = np.zeros((all_thresh.shape[0],))
    valid_gt = gt > valid_thresh
    for idx, thresh in enumerate(all_thresh):
        mask = pred >= thresh
        tp[idx] += np.sum(mask & valid_gt)
        tn[idx] += np.sum(~mask & ~valid_gt)
        fp[idx] += np.sum(mask & ~valid_gt)
        fn[idx] += np.sum(~mask & valid_gt)

    scores = {}

    # SIM compares against the raw ground truth; an all-zero map is replaced by
    # a uniform distribution so the metric stays defined.
    gt_real = np.array(gt)
    if gt_real.sum() == 0:
        # ``np.prod``: ``np.product`` was removed in NumPy 2.0.
        gt_real = np.ones(gt_real.shape) / np.prod(gt_real.shape)
    score = SIM(pred, gt_real)
    scores['SIM'] = score if not np.isnan(score) else None

    # AUC-Judd and NSS use the ground truth binarised at (about) half its maximum.
    gt_binary = np.array(gt)
    gt_binary = (gt_binary / gt_binary.max() + 1e-12) if gt_binary.max() > 0 else gt_binary
    gt_binary = np.where(gt_binary > 0.5, 1, 0)
    score = AUC_Judd(pred, gt_binary)
    scores['AUC-J'] = score if not np.isnan(score) else None

    score = NSS(pred, gt_binary)
    scores['NSS'] = score if not np.isnan(score) else None

    return scores, tp, tn, fp, fn


def evaluate_affordance(preds_dict, gts_dict, val_log=None,
                        sz=32, fps=False, kmeans=False, n_pts=5,
                        gaussian_sigma=3., gaussian_k_ratio=3.):
    """Evaluate predicted contact points as ``sz`` x ``sz`` interaction heatmaps.

    Predictions and ground truth are both rendered with ``compute_heatmap``; the
    ground truth is rendered with ``gaussian_sigma=0``. Each pair is scored with
    ``compute_score``. Predictions are paired with ground truth by uid, so every
    uid in ``gts_dict`` must be present in ``preds_dict``.

    Prints the mean +/- std/sqrt(n) of SIM, AUC-J and NSS (NaN scores are
    skipped), then precision / recall / F1 at the threshold with the best F1.
    When ``val_log`` is given, both reports are also appended to that file.

    Returns:
        dict mapping ``'SIM'``, ``'AUC-J'`` and ``'NSS'`` to their mean.
    """
    def emit(lines):
        # Print a report and, when requested, append it to the log file.
        text = '\n'.join(lines)
        print(text)
        if val_log is not None:
            with open(val_log, "a") as log_file:
                log_file.write(text + '\n')

    scores = []
    all_thresh = np.linspace(0.001, 1.0, 41)
    tp = np.zeros((all_thresh.shape[0],))
    fp = np.zeros((all_thresh.shape[0],))
    fn = np.zeros((all_thresh.shape[0],))
    tn = np.zeros((all_thresh.shape[0],))

    # Pair by uid instead of relying on both dicts sharing insertion order.
    uids = list(gts_dict.keys())
    pred_hmaps = Parallel(n_jobs=16, verbose=0)(
        delayed(compute_heatmap)(preds_dict[uid], (sz, sz),
                                 fps=fps, kmeans=kmeans, n_pts=n_pts,
                                 gaussian_sigma=gaussian_sigma,
                                 k_ratio=gaussian_k_ratio)
        for uid in uids)
    gt_hmaps = Parallel(n_jobs=16, verbose=0)(
        delayed(compute_heatmap)(gts_dict[uid], (sz, sz),
                                 fps=fps, n_pts=n_pts,
                                 gaussian_sigma=0,
                                 k_ratio=3.)
        for uid in uids)

    for pred_hmap, gt_hmap in zip(pred_hmaps, gt_hmaps):
        score, ctp, ctn, cfp, cfn = compute_score(pred_hmap, gt_hmap)
        scores.append(score)
        tp = tp + ctp
        tn = tn + ctn
        fp = fp + cfp
        fn = fn + cfn

    write_out = []
    metrics = {}
    for key in ['SIM', 'AUC-J', 'NSS']:
        key_score = [s[key] for s in scores if s[key] is not None]
        mean, stderr = np.mean(key_score), np.std(key_score) / (np.sqrt(len(key_score)))
        log_str = '%s: %.3f ± %.3f (%d/%d)' % (key, mean, stderr, len(key_score), len(gts_dict))
        write_out.append(log_str)
        metrics[key] = mean
    emit(write_out)

    # Pick the threshold whose pooled F1 is highest.
    prec = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * prec * recall / (prec + recall + 1e-6)
    idx = np.argmax(f1)
    prec_score = prec[idx]
    f1_score = f1[idx]
    recall_score = recall[idx]

    emit([
        'Precision: {:.3f}'.format(prec_score),
        'Recall: {:0.4f}'.format(recall_score),
        'F1 Score: {:0.4f}'.format(f1_score),
    ])

    return metrics
# --------------------------------------------------------------------------------
# /evaluation/traj_eval.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np


def compute_ade(pred_traj, gt_traj, valid_traj=None, reduction=True):
    """Average displacement error between predicted and ground-truth trajectories.

    Both inputs are 4-D torch tensors or numpy arrays. The last axis holds the
    (x, y) coordinates, and the time average is taken over axis 2.
    Ground-truth points outside [0, 1) have their error zeroed, but they still
    count in the time average.

    Args:
        valid_traj: weights of shape ``pred_traj.shape[:2]``; defaults to ones.
        reduction: when True, average over the weighted trajectories.

    Returns:
        ``(ade, valid)``. With ``reduction`` this is the weighted mean ADE and
        the total weight; otherwise it is the per-trajectory ADE (multiplied by
        the weights) and the weight array.
    """
    valid_loc = (gt_traj[:, :, :, 0] >= 0) & (gt_traj[:, :, :, 1] >= 0) \
        & (gt_traj[:, :, :, 0] < 1) & (gt_traj[:, :, :, 1] < 1)

    error = gt_traj - pred_traj
    error = error * valid_loc[:, :, :, None]

    if torch.is_tensor(error):
        if valid_traj is None:
            # Match the device of the error tensor so GPU inputs work.
            valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1], device=error.device)
        error = error ** 2
        ade = torch.sqrt(error.sum(dim=3)).mean(dim=2) * valid_traj
        if reduction:
            ade = ade.sum() / valid_traj.sum()
            valid_traj = valid_traj.sum()
    else:
        if valid_traj is None:
            valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int)
        error = np.linalg.norm(error, axis=3)
        ade = error.mean(axis=2) * valid_traj
        if reduction:
            ade = ade.sum() / valid_traj.sum()
            valid_traj = valid_traj.sum()

    return ade, valid_traj


def compute_fde(pred_traj, gt_traj, valid_traj=None, reduction=True):
    """Final displacement error: the error at the last index of axis 2.

    Same conventions as ``compute_ade``: 4-D inputs with (x, y) in the last
    axis, ground-truth points outside [0, 1) zeroed, and ``valid_traj``
    weighting each trajectory.
    """
    pred_last = pred_traj[:, :, -1, :]
    gt_last = gt_traj[:, :, -1, :]

    valid_loc = (gt_last[:, :, 0] >= 0) & (gt_last[:, :, 1] >= 0) \
        & (gt_last[:, :, 0] < 1) & (gt_last[:, :, 1] < 1)

    error = gt_last - pred_last
    error = error * valid_loc[:, :, None]

    if torch.is_tensor(error):
        if valid_traj is None:
            valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1], device=error.device)
        error = error ** 2
        fde = torch.sqrt(error.sum(dim=2)) * valid_traj
        if reduction:
            fde = fde.sum() / valid_traj.sum()
            valid_traj = valid_traj.sum()
    else:
        if valid_traj is None:
            valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int)
        error = np.linalg.norm(error, axis=2)
        fde = error * valid_traj
        if reduction:
            fde = fde.sum() / valid_traj.sum()
            valid_traj = valid_traj.sum()

    return fde, valid_traj


def evaluate_traj_stochastic(preds, gts, valids):
    """Report ADE/FDE over several sampled trajectory predictions.

    ``preds`` is indexed ``[item, sample, object, ...]``. Each sample is scored
    against ``gts`` with ``compute_ade`` / ``compute_fde`` (unreduced). The
    function prints the best-of-samples (min) errors and the mean +/- std over
    samples, each averaged over the ``valids``-weighted entries.

    Returns:
        ``(ade_mean, fde_mean)``: the sample-averaged ADE and FDE.
    """
    len_dataset, num_samples, num_obj = preds.shape[0], preds.shape[1], preds.shape[2]
    ade_list, fde_list = [], []
    for idx in range(num_samples):
        # Previously these two calls were swapped, so ADE and FDE were mislabelled.
        ade, _ = compute_ade(preds[:, idx, :, :, :], gts, valids, reduction=False)
        ade_list.append(ade)
        fde, _ = compute_fde(preds[:, idx, :, :, :], gts, valids, reduction=False)
        fde_list.append(fde)

    if torch.is_tensor(preds):
        ade_list = torch.stack(ade_list, dim=0)
        fde_list = torch.stack(fde_list, dim=0)

        ade_err_min, _ = torch.min(ade_list, dim=0)
        ade_err_min = ade_err_min * valids
        fde_err_min, _ = torch.min(fde_list, dim=0)
        fde_err_min = fde_err_min * valids

        ade_err_mean = torch.mean(ade_list, dim=0) * valids
        fde_err_mean = torch.mean(fde_list, dim=0) * valids

        # Population std, matching numpy's default below. Computing it directly
        # avoids the NaN an unbiased std gives for a single sample.
        ade_err_std = torch.std(ade_list, dim=0, unbiased=False) * valids
        fde_err_std = torch.std(fde_list, dim=0, unbiased=False) * valids

    else:
        ade_list = np.array(ade_list, dtype=np.float32)
        fde_list = np.array(fde_list, dtype=np.float32)

        ade_err_min = ade_list.min(axis=0) * valids
        fde_err_min = fde_list.min(axis=0) * valids

        ade_err_mean = ade_list.mean(axis=0) * valids
        fde_err_mean = fde_list.mean(axis=0) * valids

        ade_err_std = ade_list.std(axis=0) * valids
        fde_err_std = fde_list.std(axis=0) * valids

    ade_mean = ade_err_mean.sum() / valids.sum()
    fde_mean = fde_err_mean.sum() / valids.sum()

    ade_std = ade_err_std.sum() / valids.sum()
    fde_std = fde_err_std.sum() / valids.sum()
    ade_mean_info = 'ADE: %.3f ± %.3f (%d/%d)' % (ade_mean, ade_std, valids.sum(), len_dataset * num_obj)
    fde_mean_info = "FDE: %.3f ± %.3f (%d/%d)" % (fde_mean, fde_std, valids.sum(), len_dataset * num_obj)

    ade_min = ade_err_min.sum() / valids.sum()
    fde_min = fde_err_min.sum() / valids.sum()
    ade_min_info = 'min ADE: %.3f (%d/%d)' % (ade_min, valids.sum(), len_dataset * num_obj)
    fde_min_info = "min FDE: %.3f (%d/%d)" % (fde_min, valids.sum(), len_dataset * num_obj)

    print(ade_min_info)
    print(fde_min_info)
    print(ade_mean_info)
    print(fde_mean_info)

    return ade_mean, fde_mean
# --------------------------------------------------------------------------------
# /netscripts/epoch_feat.py:
# --------------------------------------------------------------------------------
import time
import numpy as np
import torch

from netscripts.epoch_utils import progress_bar as bar, AverageMeters
from evaluation.traj_eval import evaluate_traj_stochastic
from evaluation.affordance_eval import evaluate_affordance


def epoch_pass(loader, model, epoch, phase, optimizer=None,
               train=True, use_cuda=False,
               num_samples=5, pred_len=4, num_points=5,
               gaussian_sigma=3., gaussian_k_ratio=3.,
               scheduler=None):
    """Run one training or evaluation pass over ``loader``.

    Training (``train=True``): each batch runs forward/backward and steps
    ``optimizer``, then ``scheduler`` if one is given. The scheduler is
    therefore advanced once per iteration. Returns the ``AverageMeters`` of the
    losses.

    Evaluation (``train=False``): the model draws ``num_samples`` predictions
    per input. Returns a dict with ``traj_ade`` / ``traj_fde`` when
    ``phase == "traj"``. It also includes the affordance metrics when
    ``phase == "affordance"`` and the dataset partition name contains ``'eval'``.
    """
    time_meters = AverageMeters()

    if train:
        print(f"{phase} epoch: {epoch + 1}")
        loss_meters = AverageMeters()
        model.train()
    else:
        print(f"evaluate epoch {epoch}")
        preds_traj, gts_traj, valids_traj = [], [], []
        gts_affordance_dict, preds_affordance_dict = {}, {}
        model.eval()

    if use_cuda and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device("cpu")
    end = time.time()
    for batch_idx, sample in enumerate(loader):
        # ``feat`` rather than ``input`` to avoid shadowing the builtin.
        feat = sample['feat'].float().to(device)
        bbox_feat = sample['bbox_feat'].float().to(device)
        valid_mask = sample['valid_mask'].float().to(device)
        if train:
            future_hands = sample['future_hands'].float().to(device)
            contact_point = sample['contact_point'].float().to(device)
            future_valid = sample['future_valid'].float().to(device)
            time_meters.add_loss_value("data_time", time.time() - end)
            model_loss, model_losses = model(feat, bbox_feat=bbox_feat,
                                             valid_mask=valid_mask, future_hands=future_hands,
                                             contact_point=contact_point, future_valid=future_valid)

            optimizer.zero_grad()
            model_loss.backward()
            optimizer.step()

            for key, val in model_losses.items():
                if val is not None:
                    loss_meters.add_loss_value(key, val)

            time_meters.add_loss_value("batch_time", time.time() - end)

            loss_avgs = loss_meters.average_meters
            suffix = "({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s " \
                     "| Hand Traj Loss: {traj_loss:.3f} " \
                     "| Hand Traj KL Loss: {traj_kl_loss:.3f} " \
                     "| Object Affordance Loss: {obj_loss:.3f} " \
                     "| Object Affordance KL Loss: {obj_kl_loss:.3f} " \
                     "| Total Loss: {total_loss:.3f} ".format(
                         batch=batch_idx + 1, size=len(loader),
                         data=time_meters.average_meters["data_time"].val,
                         bt=time_meters.average_meters["batch_time"].avg,
                         traj_loss=loss_avgs["traj_loss"].avg,
                         traj_kl_loss=loss_avgs["traj_kl_loss"].avg,
                         obj_loss=loss_avgs["obj_loss"].avg,
                         obj_kl_loss=loss_avgs["obj_kl_loss"].avg,
                         total_loss=loss_avgs["total_loss"].avg)
            bar(suffix)
            end = time.time()
            if scheduler is not None:
                scheduler.step()
        else:
            future_valid = sample['future_valid'].float().to(device)

            time_meters.add_loss_value("data_time", time.time() - end)

            with torch.no_grad():
                pred_future_hands, contact_points = model(feat, bbox_feat, valid_mask,
                                                          num_samples=num_samples,
                                                          future_valid=future_valid,
                                                          pred_len=pred_len)

            uids = sample['uid'].numpy()
            # Drops index 0 along axis 2 of the ground-truth hands (presumably the
            # current, already-observed step -- TODO confirm).
            future_hands = sample['future_hands'][:, :, 1:, :].float().numpy()
            future_valid = sample['future_valid'].float().numpy()

            gts_traj.append(future_hands)
            valids_traj.append(future_valid)

            pred_future_hands = pred_future_hands.cpu().numpy()
            preds_traj.append(pred_future_hands)

            if 'eval' in loader.dataset.partition:
                contact_points = contact_points.cpu().numpy()
                for idx, uid in enumerate(uids):
                    gts_affordance_dict[uid] = loader.dataset.eval_labels[uid]['norm_contacts']
                    preds_affordance_dict[uid] = contact_points[idx]

            time_meters.add_loss_value("batch_time", time.time() - end)

            suffix = "({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s" \
                .format(batch=batch_idx + 1, size=len(loader),
                        data=time_meters.average_meters["data_time"].val,
                        bt=time_meters.average_meters["batch_time"].avg)

            bar(suffix)
            end = time.time()

    if train:
        return loss_meters

    val_info = {}
    if phase == "traj":
        gts_traj = np.concatenate(gts_traj)
        preds_traj = np.concatenate(preds_traj)
        valids_traj = np.concatenate(valids_traj)

        ade, fde = evaluate_traj_stochastic(preds_traj, gts_traj, valids_traj)
        val_info.update({"traj_ade": ade, "traj_fde": fde})

    if 'eval' in loader.dataset.partition and phase == "affordance":
        affordance_metrics = evaluate_affordance(preds_affordance_dict,
                                                 gts_affordance_dict,
                                                 n_pts=num_points,
                                                 gaussian_sigma=gaussian_sigma,
                                                 gaussian_k_ratio=gaussian_k_ratio)
        val_info.update(affordance_metrics)

    return val_info
# --------------------------------------------------------------------------------
# /netscripts/epoch_utils.py:
# --------------------------------------------------------------------------------
import sys


def progress_bar(msg=None):
    """Write ``msg`` (or nothing, when it is empty/None) followed by a newline to stdout."""
    sys.stdout.write((msg or '') + '\n')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class AverageMeters:
    """Named collection of ``AverageMeter`` objects, created on first use."""

    def __init__(self):
        self.average_meters = {}

    def add_loss_value(self, loss_name, loss_val, n=1):
        """Update (creating if needed) the meter called ``loss_name``."""
        if loss_name not in self.average_meters:
            self.average_meters[loss_name] = AverageMeter()
        self.average_meters[loss_name].update(loss_val, n=n)
# --------------------------------------------------------------------------------
# /netscripts/get_datasets.py:
# --------------------------------------------------------------------------------
from datasets.datasetopts import DatasetArgs
from datasets.holoaders import EpicHODataset as HODataset, FeaturesHOLoader, get_dataloaders


def get_dataset(args, base_path="./"):
    """Build the data loaders for the current run.

    Returns:
        ``(mode, dls)``: ``mode`` is ``"validation"`` when ``args.evaluate`` is
        set and ``'train'`` otherwise; ``dls`` is whatever ``get_dataloaders``
        returns.
    """
    mode = "validation" if args.evaluate else 'train'

    datasetargs = DatasetArgs(ek_version=args.ek_version, mode=mode,
                              use_label_only=True, base_path=base_path,
                              batch_size=args.batch_size, num_workers=args.workers)

    dls = get_dataloaders(datasetargs, HODataset, featuresloader=FeaturesHOLoader)
    return mode, dls
# --------------------------------------------------------------------------------
# /netscripts/get_network.py:
# --------------------------------------------------------------------------------
import torch
from networks.traj_decoder import TrajCVAE
from networks.affordance_decoder import AffordanceCVAE
from networks.transformer import ObjectTransformer
from networks.model import Model


def get_network(args, num_frames_input=10, num_frames_output=4):
    """Assemble the ObjectTransformer with its hand/object heads and wrap it in ``Model``.

    The transformer is wrapped in ``torch.nn.DataParallel``; ``Model`` relies on
    that (it calls ``net.module.inference`` at evaluation time).
    """
    hand_head = TrajCVAE(in_dim=2, hidden_dim=args.hidden_dim,
                         latent_dim=args.latent_dim, condition_dim=args.embed_dim,
                         coord_dim=args.coord_dim)
    obj_head = AffordanceCVAE(in_dim=2, hidden_dim=args.hidden_dim,
                              latent_dim=args.latent_dim, condition_dim=args.embed_dim)
    net = ObjectTransformer(src_in_features=args.src_in_features,
                            trg_in_features=args.trg_in_features,
                            num_patches=args.num_patches,
                            hand_head=hand_head, obj_head=obj_head,
                            encoder_time_embed_type=args.encoder_time_embed_type,
                            decoder_time_embed_type=args.decoder_time_embed_type,
                            num_frames_input=num_frames_input,
                            num_frames_output=num_frames_output,
                            embed_dim=args.embed_dim, coord_dim=args.coord_dim,
                            num_heads=args.num_heads, enc_depth=args.enc_depth, dec_depth=args.dec_depth)
    net = torch.nn.DataParallel(net)

    return Model(net, lambda_obj=args.lambda_obj, lambda_traj=args.lambda_traj,
                 lambda_obj_kl=args.lambda_obj_kl, lambda_traj_kl=args.lambda_traj_kl)
# --------------------------------------------------------------------------------
# /netscripts/get_optimizer.py:
# --------------------------------------------------------------------------------
import torch


class Warmup(torch.optim.lr_scheduler._LRScheduler):
    """Linear learning-rate warm-up, then hand-off to ``scheduler``.

    For the first ``num_epochs * iters_per_epoch`` steps the learning rate ramps
    linearly from ``init_lr_ratio * base_lr`` towards ``base_lr``. After that,
    every ``step()`` is forwarded to ``scheduler``. ``iters_per_epoch`` must be
    given; the ``None`` default cannot be multiplied.
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler._LRScheduler,
            init_lr_ratio: float = 0.0,
            num_epochs: int = 5,
            last_epoch: int = -1,
            iters_per_epoch: int = None,
    ):
        self.base_scheduler = scheduler
        self.warmup_iters = max(num_epochs * iters_per_epoch, 1)
        # With no real warm-up, start directly at the base learning rate.
        self.init_lr_ratio = init_lr_ratio if self.warmup_iters > 1 else 1.0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Learning rates for the current warm-up step."""
        assert self.last_epoch < self.warmup_iters
        progress = float(self.last_epoch) / self.warmup_iters
        factor = self.init_lr_ratio + (1 - self.init_lr_ratio) * progress
        return [base_lr * factor for base_lr in self.base_lrs]

    def step(self, *args, **kwargs):
        """Advance the warm-up, or the wrapped scheduler once warm-up is over."""
        if self.last_epoch < (self.warmup_iters - 1):
            super().step(*args, **kwargs)
        else:
            self.base_scheduler.step(*args, **kwargs)


def get_optimizer(args, model, train_loader):
    """Build the optimizer and its warm-up-wrapped learning-rate scheduler.

    Trainable parameters whose name contains ``'vae'`` or ``'head'`` form a group
    with no weight decay. All other trainable parameters use ``args.weight_decay``;
    for ``'adamw'`` that value is not forwarded, so AdamW's default applies.
    The scheduler is stepped once per training iteration (``epoch_pass`` calls
    ``scheduler.step()`` per batch), so schedule lengths here are in iterations.

    Returns:
        ``(optimizer, scheduler)`` where ``scheduler`` is a ``Warmup`` wrapping the
        scheduler selected by ``args.scheduler``.
    """
    assert train_loader is not None, "train_loader is None, " \
                                     "warmup or cosine learning rate need number of iterations in dataloader"
    iters_per_epoch = len(train_loader)
    trainable = [(p_name, p) for p_name, p in model.named_parameters() if p.requires_grad]
    vae_params = [p for p_name, p in trainable if 'vae' in p_name or 'head' in p_name]
    other_params = [p for p_name, p in trainable if 'vae' not in p_name and 'head' not in p_name]
    param_groups = [{'params': vae_params, 'weight_decay': 0.0}, {'params': other_params}]

    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "rms":
        optimizer = torch.optim.RMSprop(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
    else:
        raise ValueError("unsupported optimizer type")

    for group in optimizer.param_groups:
        group["lr"] = args.lr
        group["initial_lr"] = args.lr

    if args.scheduler == "step":
        assert isinstance(args.lr_decay_step, int), "learning rate scheduler need integar lr_decay_step"
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)
    elif args.scheduler == "multistep":
        if isinstance(args.lr_decay_step, list):
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_step,
                                                             gamma=args.lr_decay_gamma)
        else:
            # MultiStepLR needs a list of milestones (an int raised TypeError). The
            # scheduler counts iterations, so "half of training" is converted to
            # iterations.
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[(args.epochs // 2) * iters_per_epoch], gamma=0.1)
    elif args.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * iters_per_epoch,
                                                               last_epoch=-1, eta_min=0)
    else:
        raise ValueError("Unrecognized learning rate scheduler {}".format(args.scheduler))

    main_scheduler = Warmup(optimizer, scheduler, init_lr_ratio=0., num_epochs=args.warmup_epochs,
                            iters_per_epoch=iters_per_epoch)
    return optimizer, main_scheduler
# --------------------------------------------------------------------------------
# /netscripts/modelio.py:
# --------------------------------------------------------------------------------
import os
import shutil
import traceback
import warnings
import torch


def load_checkpoint(model, resume_path, strict=True, device=None):
    """Load the checkpoint at ``resume_path`` into ``model`` and return its epoch.

    If the first stored key does not contain ``"module"``, every key is prefixed
    with ``"module."`` so it matches a ``DataParallel``-wrapped model. Missing
    keys trigger a warning before ``load_state_dict`` is called with ``strict``.

    Raises:
        ValueError: if ``resume_path`` is not an existing file.
    """
    if not os.path.isfile(resume_path):
        raise ValueError("=> no checkpoint found at '{}'".format(resume_path))

    print("=> loading checkpoint '{}'".format(resume_path))
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    if device is not None:
        checkpoint = torch.load(resume_path, map_location=device)
    else:
        checkpoint = torch.load(resume_path)
    stored = checkpoint["state_dict"]
    keys = list(stored.keys())
    if keys and "module" in keys[0]:
        state_dict = stored
    else:
        state_dict = {"module.{}".format(key): item for key, item in stored.items()}
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint["epoch"]))
    missing_states = set(model.state_dict().keys()) - set(state_dict.keys())
    if missing_states:
        warnings.warn("Missing keys ! : {}".format(missing_states))
    model.load_state_dict(state_dict, strict=strict)
    return checkpoint["epoch"]


def save_checkpoint(state, checkpoint="checkpoint", filename="checkpoint.pth.tar"):
    """Save ``state`` with ``torch.save`` to ``checkpoint/filename``."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
# --------------------------------------------------------------------------------
# /networks/affordance_decoder.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from networks.decoder_modules import VAE


class AffordanceCVAE(nn.Module):
    """Conditional VAE over object contact points.

    The VAE condition is ``context``. When ``condition_traj`` is True, the
    context is fused with the flattened hand trajectory, which must then flatten
    to ``2 * (pred_len + 1)`` values per sample. At inference, latents are
    drawn as ``z_scale * N(0, I)``.
    """

    def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None,
                 pred_len=4, condition_traj=True, z_scale=2.0):
        super().__init__()
        self.latent_dim = latent_dim
        self.condition_traj = condition_traj
        self.z_scale = z_scale
        if self.condition_traj:
            if coord_dim is None:
                coord_dim = hidden_dim // 2
            self.coord_dim = coord_dim
            self.traj_to_feature = nn.Sequential(
                nn.Linear(2 * (pred_len + 1), coord_dim * (pred_len + 1), bias=False),
                nn.ELU(inplace=True))
            self.traj_context_fusion = nn.Sequential(
                nn.Linear(condition_dim + coord_dim * (pred_len + 1), condition_dim, bias=False),
                nn.ELU(inplace=True))

        self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,
                        conditional=True, condition_dim=condition_dim)

    def _condition(self, context, hand_traj):
        """Build the CVAE condition from ``context`` (fused with ``hand_traj`` if enabled)."""
        if not self.condition_traj:
            return context
        assert hand_traj is not None
        batch_size = context.shape[0]
        traj_feat = self.traj_to_feature(hand_traj.reshape(batch_size, -1))
        fusion_feat = torch.cat([context, traj_feat], dim=1)
        return self.traj_context_fusion(fusion_feat)

    def forward(self, context, contact_point, hand_traj=None, return_pred=False):
        """Return ``(recon_loss, KLD)``, prefixed by the reconstruction when ``return_pred``."""
        condition_context = self._condition(context, hand_traj)
        if not return_pred:
            recon_loss, KLD = self.cvae(contact_point, c=condition_context)
            return recon_loss, KLD
        pred_contact, recon_loss, KLD = self.cvae(contact_point, c=condition_context, return_pred=return_pred)
        return pred_contact, recon_loss, KLD

    def inference(self, context, hand_traj=None):
        """Decode contact points from random latents scaled by ``z_scale``."""
        condition_context = self._condition(context, hand_traj)
        z = self.z_scale * torch.randn([condition_context.shape[0], self.latent_dim], device=condition_context.device)
        return self.cvae.inference(z, c=condition_context)
# --------------------------------------------------------------------------------
# /networks/decoder_modules.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class VAE(nn.Module):
    """Single-hidden-layer (conditional) VAE with MLP encoder and decoder.

    When ``conditional`` is True and ``condition_dim`` is given, the condition
    ``c`` is concatenated to the encoder input and to the latent before
    decoding.
    """

    def __init__(self, in_dim, hidden_dim, latent_dim, conditional=False, condition_dim=None):

        super().__init__()

        self.latent_dim = latent_dim
        self.conditional = conditional

        if self.conditional and condition_dim is not None:
            input_dim = in_dim + condition_dim
            dec_dim = latent_dim + condition_dim
        else:
            input_dim = in_dim
            dec_dim = latent_dim
        self.enc_MLP = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU())
        self.linear_means = nn.Linear(hidden_dim, latent_dim)
        self.linear_log_var = nn.Linear(hidden_dim, latent_dim)
        self.dec_MLP = nn.Sequential(
            nn.Linear(dec_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, in_dim))

    def forward(self, x, c=None, return_pred=False):
        """Encode, sample and decode ``x``; return the per-sample losses (and reconstruction)."""
        if self.conditional and c is not None:
            inp = torch.cat((x, c), dim=-1)
        else:
            inp = x
        h = self.enc_MLP(inp)
        mean = self.linear_means(h)
        log_var = self.linear_log_var(h)
        z = self.reparameterize(mean, log_var)
        if self.conditional and c is not None:
            z = torch.cat((z, c), dim=-1)
        recon_x = self.dec_MLP(z)
        recon_loss, KLD = self.loss_fn(recon_x, x, mean, log_var)
        if not return_pred:
            return recon_loss, KLD
        return recon_x, recon_loss, KLD

    def loss_fn(self, recon_x, x, mean, log_var):
        """Per-sample summed squared error and KL divergence to a standard normal."""
        recon_loss = torch.sum((recon_x - x) ** 2, dim=1)
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1)
        return recon_loss, KLD

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def inference(self, z, c=None):
        """Decode latent ``z`` (concatenated with ``c`` when conditional)."""
        if self.conditional and c is not None:
            z = torch.cat((z, c), dim=-1)
        return self.dec_MLP(z)
# --------------------------------------------------------------------------------
# /networks/embedding.py:
# --------------------------------------------------------------------------------
import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, d_model, max_len=5000): 9 | super(PositionalEncoding, self).__init__() 10 | pe = torch.zeros(max_len, d_model) 11 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 12 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 13 | pe[:, 0::2] = torch.sin(position * div_term) 14 | pe[:, 1::2] = torch.cos(position * div_term) 15 | pe = pe.unsqueeze(0) 16 | 17 | self.register_buffer('pe', pe) 18 | 19 | def forward(self, x): 20 | return x + self.pe[:, :x.shape[1]] 21 | 22 | 23 | class Encoder_PositionalEmbedding(nn.Module): 24 | def __init__(self, d_model, seq_len): 25 | super(Encoder_PositionalEmbedding, self).__init__() 26 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model)) 27 | 28 | def forward(self, x): 29 | B, T = x.shape[:2] 30 | if T != self.position_embedding.size(1): 31 | position_embedding = self.position_embedding.transpose(1, 2) 32 | new_position_embedding = F.interpolate(position_embedding, size=(T), mode='nearest') 33 | new_position_embedding = new_position_embedding.transpose(1, 2) 34 | x = x + new_position_embedding 35 | else: 36 | x = x + self.position_embedding 37 | return x 38 | 39 | 40 | class Decoder_PositionalEmbedding(nn.Module): 41 | def __init__(self, d_model, seq_len): 42 | super(Decoder_PositionalEmbedding, self).__init__() 43 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model)) 44 | 45 | def forward(self, x): 46 | x = x + self.position_embedding[:, :x.shape[1], :] 47 | return x 48 | 49 | -------------------------------------------------------------------------------- /networks/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.net_utils import 
DropPath, get_pad_mask 5 | from einops import rearrange 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | 27 | class ScaledDotProductAttention(nn.Module): 28 | def __init__(self, temperature, attn_dropout=0.1): 29 | super().__init__() 30 | self.temperature = temperature 31 | self.dropout = nn.Dropout(attn_dropout) 32 | 33 | def forward(self, q, k, v, mask=None): 34 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 35 | if mask is not None: 36 | attn = attn.masked_fill(mask == 0, -1e9) 37 | attn = self.dropout(F.softmax(attn, dim=-1)) 38 | output = torch.matmul(attn, v) 39 | 40 | return output, attn 41 | 42 | 43 | class MultiHeadAttention(nn.Module): 44 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True): 45 | super().__init__() 46 | self.num_heads = num_heads 47 | head_dim = dim // num_heads 48 | self.with_qkv = with_qkv 49 | if self.with_qkv: 50 | self.proj_q = nn.Linear(dim, dim, bias=qkv_bias) 51 | self.proj_k = nn.Linear(dim, dim, bias=qkv_bias) 52 | self.proj_v = nn.Linear(dim, dim, bias=qkv_bias) 53 | 54 | self.proj = nn.Linear(dim, dim) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | self.attention = ScaledDotProductAttention(temperature=qk_scale or head_dim ** 0.5) 57 | self.attn_drop = nn.Dropout(attn_drop) 58 | 59 | def forward(self, q, k, v, mask=None): 60 | B, Nq, Nk, Nv, C = q.shape[0], q.shape[1], k.shape[1], v.shape[1], q.shape[2] 61 
| if self.with_qkv: 62 | q = self.proj_q(q).reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2) 63 | k = self.proj_k(k).reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2) 64 | v = self.proj_v(v).reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2) 65 | else: 66 | q = q.reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2) 67 | k = k.reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2) 68 | v = v.reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2) 69 | if mask is not None: 70 | mask = mask.unsqueeze(1) 71 | 72 | x, attn = self.attention(q, k, v, mask=mask) 73 | x = x.transpose(1, 2).reshape(B, Nq, C) 74 | if self.with_qkv: 75 | x = self.proj(x) 76 | x = self.proj_drop(x) 77 | return x 78 | 79 | 80 | class EncoderBlock(nn.Module): 81 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 82 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim) 85 | self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 86 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 87 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 88 | self.norm2 = norm_layer(dim) 89 | mlp_hidden_dim = int(dim * mlp_ratio) 90 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 91 | 92 | def forward(self, x, B, T, N, mask=None): 93 | 94 | if mask is not None: 95 | src_mask = rearrange(mask, 'b n t -> b (n t)', b=B, n=N, t=T) 96 | src_mask = get_pad_mask(src_mask, 0) 97 | else: 98 | src_mask = None 99 | x2 = self.norm1(x) 100 | x = x + self.drop_path(self.attn(q=x2, k=x2, v=x2, mask=src_mask)) 101 | x = x + self.drop_path(self.mlp(self.norm2(x))) 102 | return x 103 | 104 | 105 | class DecoderBlock(nn.Module): 106 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 107 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm): 108 | super().__init__() 109 | self.norm1 = norm_layer(dim) 110 | self.self_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 111 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 112 | 113 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 114 | 115 | self.norm2 = norm_layer(dim) 116 | self.enc_dec_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 117 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 118 | 119 | self.norm3 = nn.LayerNorm(dim) 120 | mlp_hidden_dim = int(dim * mlp_ratio) 121 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 122 | 123 | def forward(self, tgt, memory, memory_mask=None, trg_mask=None): 124 | tgt_2 = self.norm1(tgt) 125 | tgt = tgt + self.drop_path(self.self_attn(q=tgt_2, k=tgt_2, v=tgt_2, mask=trg_mask)) 126 | tgt = tgt + self.drop_path(self.enc_dec_attn(q=self.norm2(tgt), k=memory, v=memory, mask=memory_mask)) 127 | tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) 128 | return tgt 129 | -------------------------------------------------------------------------------- /networks/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Model(nn.Module): 6 | 7 | def __init__(self, net, lambda_obj=None, lambda_traj=None, lambda_obj_kl=None, lambda_traj_kl=None): 8 | super(Model, self).__init__() 9 | self.net = net 10 | self.lambda_obj = lambda_obj 11 | self.lambda_obj_kl = lambda_obj_kl 12 | self.lambda_traj = lambda_traj 13 | self.lambda_traj_kl = lambda_traj_kl 14 | 15 | def forward(self, feat, bbox_feat, valid_mask, future_hands=None, contact_point=None, future_valid=None, 16 | num_samples=5, pred_len=4): 17 | if self.training: 18 | losses = {} 19 | total_loss = 0 20 | traj_loss, traj_kl_loss, obj_loss, obj_kl_loss = self.net(feat, bbox_feat, valid_mask, future_hands, 21 | contact_point, future_valid) 22 | 23 | if self.lambda_traj is not None and traj_loss is not None: 24 | traj_loss = self.lambda_traj * traj_loss.sum() 25 | total_loss += traj_loss 26 | losses['traj_loss'] = traj_loss.detach().cpu() 27 | else: 28 | losses['traj_loss'] = 0. 
29 | 30 | if self.lambda_traj_kl is not None and traj_kl_loss is not None: 31 | traj_kl_loss = self.lambda_traj_kl * traj_kl_loss.sum() 32 | total_loss += traj_kl_loss 33 | losses['traj_kl_loss'] = traj_kl_loss.detach().cpu() 34 | else: 35 | losses['traj_kl_loss'] = 0. 36 | 37 | if self.lambda_obj is not None and obj_loss is not None: 38 | obj_loss = self.lambda_obj * obj_loss.sum() 39 | total_loss += obj_loss 40 | losses['obj_loss'] = obj_loss.detach().cpu() 41 | else: 42 | losses['obj_loss'] = 0. 43 | 44 | if self.lambda_obj_kl is not None and obj_kl_loss is not None: 45 | obj_kl_loss = self.lambda_obj_kl * obj_kl_loss.sum() 46 | total_loss += obj_kl_loss 47 | losses['obj_kl_loss'] = obj_kl_loss.detach().cpu() 48 | else: 49 | losses['obj_kl_loss'] = 0. 50 | 51 | if total_loss is not None: 52 | losses["total_loss"] = total_loss.detach().cpu() 53 | else: 54 | losses["total_loss"] = 0. 55 | return total_loss, losses 56 | else: 57 | future_hands_list = [] 58 | contact_points_list = [] 59 | for i in range(num_samples): 60 | future_hands, contact_point = self.net.module.inference(feat, bbox_feat, valid_mask, 61 | future_valid=future_valid, 62 | pred_len=pred_len) 63 | future_hands_list.append(future_hands) 64 | contact_points_list.append(contact_point) 65 | 66 | contact_points = torch.stack(contact_points_list, dim=0) 67 | 68 | assert len(contact_points.shape) == 3 69 | contact_points = contact_points.transpose(0, 1) 70 | 71 | future_hands_list = torch.stack(future_hands_list, dim=0) 72 | future_hands_list = future_hands_list.transpose(0, 1) 73 | return future_hands_list, contact_points 74 | -------------------------------------------------------------------------------- /networks/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import warnings 5 | 6 | 7 | def get_pad_mask(seq, pad_idx=0): 8 | if seq.dim() != 2: 9 | raise ValueError(" has to be a 2-dimensional 
tensor!") 10 | if not isinstance(pad_idx, int): 11 | raise TypeError(" has to be an int!") 12 | 13 | return (seq != pad_idx).unsqueeze(1) # equivalent (seq != pad_idx).unsqueeze(-2) 14 | 15 | 16 | def get_subsequent_mask(seq, diagonal=1): 17 | if seq.dim() < 2: 18 | raise ValueError(" has to be at least a 2-dimensional tensor!") 19 | 20 | seq_len = seq.size(1) 21 | mask = (1 - torch.triu(torch.ones((1, seq_len, seq_len), device=seq.device), diagonal=diagonal)).bool() 22 | return mask 23 | 24 | 25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 26 | def norm_cdf(x): 27 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 28 | 29 | if (mean < a - 2 * std) or (mean > b + 2 * std): 30 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 31 | "The distribution of values may be incorrect.", 32 | stacklevel=2) 33 | 34 | with torch.no_grad(): 35 | l = norm_cdf((a - mean) / std) 36 | u = norm_cdf((b - mean) / std) 37 | 38 | tensor.uniform_(2 * l - 1, 2 * u - 1) 39 | tensor.erfinv_() 40 | 41 | tensor.mul_(std * math.sqrt(2.)) 42 | tensor.add_(mean) 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 48 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 49 | 50 | 51 | def drop_path(x, drop_prob: float = 0., training: bool = False): 52 | if drop_prob == 0. 
or not training: 53 | return x 54 | keep_prob = 1 - drop_prob 55 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 56 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 57 | random_tensor.floor_() 58 | output = x.div(keep_prob) * random_tensor 59 | return output 60 | 61 | 62 | class DropPath(nn.Module): 63 | def __init__(self, drop_prob=None): 64 | super(DropPath, self).__init__() 65 | self.drop_prob = drop_prob 66 | 67 | def forward(self, x): 68 | return drop_path(x, self.drop_prob, self.training) 69 | 70 | 71 | def traj_affordance_dist(hand_traj, contact_point, future_valid=None, invalid_value=9): 72 | batch_size = contact_point.shape[0] 73 | expand_size = int(hand_traj.shape[0] / batch_size) 74 | contact_point = contact_point.unsqueeze(dim=1).expand(-1, expand_size, 2).reshape(-1, 2) # (B * 2 * Tf, 2) 75 | dist = torch.sum((hand_traj - contact_point) ** 2, dim=1).reshape(batch_size, -1) # (B, 2 * Tf) 76 | if future_valid is None: 77 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) # from small to high 78 | return sorted_dist[:, 0] # (B, ) 79 | else: 80 | dist = dist.reshape(batch_size, 2, -1) # (B, 2, Tf) 81 | future_valid = future_valid > 0 82 | future_invalid = ~future_valid[:, :, None].expand(dist.shape) 83 | dist[future_invalid] = invalid_value # set invalid dist to be very large 84 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) # from small to high 85 | selected_dist = sorted_dist[:, :, 0] # (B, 2) 86 | selected_dist, selected_idx = selected_dist.min(dim=1) # selected_dist, selected_idx (B, ) 87 | valid = torch.gather(future_valid, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1) 88 | selected_dist = selected_dist * valid 89 | 90 | return selected_dist -------------------------------------------------------------------------------- /networks/traj_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import einops 4 | from networks.affordance_decoder import VAE 5 | 6 | 7 | class TrajCVAE(nn.Module): 8 | def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None, 9 | condition_contact=False, z_scale=2.0): 10 | super().__init__() 11 | self.latent_dim = latent_dim 12 | self.condition_contact = condition_contact 13 | self.z_scale = z_scale 14 | if self.condition_contact: 15 | if coord_dim is None: 16 | coord_dim = hidden_dim // 2 17 | self.coord_dim = coord_dim 18 | self.contact_to_feature = nn.Sequential( 19 | nn.Linear(2, coord_dim, bias=False), 20 | nn.ELU(inplace=True)) 21 | self.contact_context_fusion = nn.Sequential( 22 | nn.Linear(condition_dim+coord_dim, condition_dim, bias=False), 23 | nn.ELU(inplace=True)) 24 | 25 | self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, 26 | conditional=True, condition_dim=condition_dim) 27 | 28 | def forward(self, context, target_hand, future_valid, contact_point=None, return_pred=False): 29 | batch_size = future_valid.shape[0] 30 | if self.condition_contact: 31 | assert contact_point is not None 32 | time_steps = int(context.shape[0] / batch_size / 2) 33 | contact_feat = self.contact_to_feature(contact_point) 34 | contact_feat = einops.repeat(contact_feat, 'm n -> m p q n', p=2, q=time_steps) 35 | contact_feat = contact_feat.reshape(-1, self.coord_dim) 36 | fusion_feat = torch.cat([context, contact_feat], dim=1) 37 | condition_context = self.contact_context_fusion(fusion_feat) 38 | else: 39 | condition_context = context 40 | if not return_pred: 41 | recon_loss, KLD = self.cvae(target_hand, c=condition_context) 42 | else: 43 | pred_hand, recon_loss, KLD = self.cvae(target_hand, c=condition_context, return_pred=return_pred) 44 | KLD = KLD.reshape(batch_size, 2, -1).sum(-1) 45 | KLD = (KLD * future_valid).sum(1) 46 | recon_loss = recon_loss.reshape(batch_size, 2, -1).sum(-1) 47 | traj_loss = (recon_loss * future_valid).sum(1) 48 | if not return_pred: 49 | return traj_loss, 
KLD 50 | else: 51 | return pred_hand, traj_loss, KLD 52 | 53 | def inference(self, context, contact_point=None): 54 | if self.condition_contact: 55 | assert contact_point is not None 56 | batch_size = contact_point.shape[0] 57 | time_steps = int(context.shape[0] / batch_size) 58 | contact_feat = self.contact_to_feature(contact_point) 59 | contact_feat = einops.repeat(contact_feat, 'm n -> m p n', p=time_steps) 60 | contact_feat = contact_feat.reshape(-1, self.coord_dim) 61 | fusion_feat = torch.cat([context, contact_feat], dim=1) 62 | condition_context = self.contact_context_fusion(fusion_feat) 63 | else: 64 | condition_context = context 65 | z = self.z_scale * torch.randn([context.shape[0], self.latent_dim], device=context.device) 66 | recon_x = self.cvae.inference(z, c=condition_context) 67 | return recon_x -------------------------------------------------------------------------------- /networks/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from networks.embedding import PositionalEncoding, Encoder_PositionalEmbedding, Decoder_PositionalEmbedding 5 | from networks.layer import EncoderBlock, DecoderBlock 6 | from networks.net_utils import trunc_normal_, get_pad_mask, get_subsequent_mask, traj_affordance_dist 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, num_patches=5, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4., 11 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 12 | drop_path_rate=0.1, norm_layer=nn.LayerNorm, 13 | dropout=0., time_embed_type=None, num_frames=None): 14 | super().__init__() 15 | if time_embed_type is None or num_frames is None: 16 | time_embed_type = 'sin' 17 | self.time_embed_type = time_embed_type 18 | self.num_patches = num_patches # (hand, object global feature patches, default: 5) 19 | self.depth = depth 20 | self.dropout = nn.Dropout(dropout) 21 | self.num_features = 
self.embed_dim = embed_dim 22 | 23 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 24 | self.pos_drop = nn.Dropout(p=drop_rate) 25 | if not self.time_embed_type == "sin" and num_frames is not None: 26 | self.time_embed = Encoder_PositionalEmbedding(embed_dim, seq_len=num_frames) 27 | else: 28 | self.time_embed = PositionalEncoding(embed_dim) 29 | self.time_drop = nn.Dropout(p=drop_rate) 30 | 31 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] 32 | self.encoder_blocks = nn.ModuleList([EncoderBlock( 33 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 34 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 35 | for i in range(self.depth)]) 36 | self.norm = norm_layer(embed_dim) 37 | trunc_normal_(self.pos_embed, std=.02) 38 | 39 | @torch.jit.ignore 40 | def no_weight_decay(self): 41 | return {'pos_embed', 'time_embed'} 42 | 43 | def forward(self, x, mask=None): 44 | B, T, N = x.shape[:3] 45 | 46 | x = rearrange(x, 'b t n m -> (b t) n m', b=B, t=T, n=N) 47 | x = x + self.pos_embed 48 | x = self.pos_drop(x) 49 | 50 | x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T) 51 | x = self.time_embed(x) 52 | x = self.time_drop(x) 53 | x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T) 54 | 55 | mask = mask.transpose(1, 2) 56 | for blk in self.encoder_blocks: 57 | x = blk(x, B, T, N, mask=mask) 58 | 59 | x = rearrange(x, 'b (n t) m -> b t n m', b=B, t=T, n=N) 60 | x = self.norm(x) 61 | return x 62 | 63 | 64 | class Decoder(nn.Module): 65 | def __init__(self, in_features, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4., 66 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 67 | drop_path_rate=0.1, norm_layer=nn.LayerNorm, dropout=0., 68 | time_embed_type=None, num_frames=None): 69 | super().__init__() 70 | self.depth = depth 71 | self.dropout = nn.Dropout(dropout) 72 | self.num_features = self.embed_dim = embed_dim 73 | 74 | 
self.trg_embedding = nn.Linear(in_features, embed_dim) 75 | 76 | if time_embed_type is None or num_frames is None: 77 | time_embed_type = 'sin' 78 | self.time_embed_type = time_embed_type 79 | if not self.time_embed_type == "sin" and num_frames is not None: 80 | self.time_embed = Decoder_PositionalEmbedding(embed_dim, seq_len=num_frames) 81 | else: 82 | self.time_embed = PositionalEncoding(embed_dim) 83 | self.time_drop = nn.Dropout(p=drop_rate) 84 | 85 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] 86 | self.decoder_blocks = nn.ModuleList([DecoderBlock( 87 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 88 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 89 | for i in range(self.depth)]) 90 | self.norm = norm_layer(embed_dim) 91 | 92 | @torch.jit.ignore 93 | def no_weight_decay(self): 94 | return {'pos_embed', 'time_embed'} 95 | 96 | def forward(self, trg, memory, memory_mask=None, trg_mask=None): 97 | trg = self.trg_embedding(trg) 98 | trg = self.time_embed(trg) 99 | trg = self.time_drop(trg) 100 | 101 | for blk in self.decoder_blocks: 102 | trg = blk(trg, memory, memory_mask=memory_mask, trg_mask=trg_mask) 103 | 104 | trg = self.norm(trg) 105 | return trg 106 | 107 | 108 | class ObjectTransformer(nn.Module): 109 | 110 | def __init__(self, src_in_features, trg_in_features, num_patches, 111 | hand_head, obj_head, 112 | embed_dim=512, coord_dim=64, num_heads=8, enc_depth=6, dec_depth=4, 113 | mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 114 | drop_path_rate=0.1, norm_layer=nn.LayerNorm, dropout=0., 115 | encoder_time_embed_type='sin', decoder_time_embed_type='sin', 116 | num_frames_input=None, num_frames_output=None): 117 | super().__init__() 118 | self.num_features = self.embed_dim = embed_dim 119 | self.coord_dim = coord_dim 120 | self.downproject = nn.Linear(src_in_features, embed_dim) 121 | 122 | self.bbox_to_feature = 
nn.Sequential( 123 | nn.Linear(4, self.coord_dim // 2), 124 | nn.ELU(inplace=True), 125 | nn.Linear(self.coord_dim // 2, self.coord_dim), 126 | nn.ELU() 127 | ) 128 | self.feat_fusion = nn.Sequential( 129 | nn.Linear(self.embed_dim + self.coord_dim, self.embed_dim), 130 | nn.ELU(inplace=True)) 131 | 132 | self.encoder = Encoder(num_patches=num_patches, 133 | embed_dim=embed_dim, depth=enc_depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 134 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 135 | drop_path_rate=drop_path_rate, norm_layer=norm_layer, dropout=dropout, 136 | time_embed_type=encoder_time_embed_type, num_frames=num_frames_input) 137 | 138 | self.decoder = Decoder(in_features=trg_in_features, embed_dim=embed_dim, 139 | depth=dec_depth, num_heads=num_heads, mlp_ratio=mlp_ratio, 140 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, 141 | attn_drop_rate=attn_drop_rate, 142 | drop_path_rate=drop_path_rate, norm_layer=norm_layer, dropout=dropout, 143 | time_embed_type=decoder_time_embed_type, num_frames=num_frames_output) 144 | 145 | self.hand_head = hand_head 146 | self.object_head = obj_head 147 | self.apply(self._init_weights) 148 | 149 | def _init_weights(self, m): 150 | if isinstance(m, nn.Linear): 151 | trunc_normal_(m.weight, std=.02) 152 | if isinstance(m, nn.Linear) and m.bias is not None: 153 | nn.init.constant_(m.bias, 0) 154 | elif isinstance(m, nn.LayerNorm): 155 | nn.init.constant_(m.bias, 0) 156 | nn.init.constant_(m.weight, 1.0) 157 | 158 | def encoder_input(self, feat, bbox_feat): 159 | B, T = feat.shape[0], feat.shape[2] 160 | feat = self.downproject(feat) 161 | bbox_feat = bbox_feat.view(-1, 4) 162 | bbox_feat = self.bbox_to_feature(bbox_feat) 163 | bbox_feat = bbox_feat.view(B, -1, T, self.coord_dim) 164 | ho_feat = feat[:, 1:, :, :] 165 | global_feat = feat[:, 0:1, :, :] 166 | feat = torch.cat((ho_feat, bbox_feat), dim=-1) 167 | feat = feat.view(-1, self.embed_dim + self.coord_dim) 168 | 
feat = self.feat_fusion(feat) 169 | feat = feat.view(B, -1, T, self.embed_dim) 170 | feat = torch.cat((global_feat, feat), dim=1) 171 | feat = feat.transpose(1, 2) 172 | return feat 173 | 174 | def forward(self, feat, bbox_feat, valid_mask, future_hands, contact_point, future_valid): 175 | # feat: (B, 5, T, src_in_features), global, hand & obj, T=10 176 | # bbox_feat: (B, 4, T, 4), hand & obj 177 | # valid_mask: (B, 4, T), hand & obj / (B, 5, T), hand & obj & global 178 | # future_hands: (B, 2, T, 2) right & left, T=5 (contain last observation frame) 179 | # contact_points: (B, 2) 180 | # future_valid: (B, 2), right & left traj valid 181 | # return: traj_loss: (B), obj_loss (B) 182 | if not valid_mask.shape[1] == feat.shape[1]: 183 | src_mask = torch.cat( 184 | (torch.ones_like(valid_mask[:, 0:1, :], dtype=valid_mask.dtype, device=valid_mask.device), 185 | valid_mask), dim=1).transpose(1, 2) 186 | else: 187 | src_mask = valid_mask.transpose(1, 2) 188 | feat = self.encoder_input(feat, bbox_feat) 189 | x = self.encoder(feat, mask=src_mask) 190 | 191 | memory = x[:, -1, :, :] 192 | memory_mask = get_pad_mask(src_mask[:, -1, :], pad_idx=0) 193 | 194 | future_rhand, future_lhand = future_hands[:, 0, :, :], future_hands[:, 1, :, :] 195 | 196 | rhand_input = future_rhand[:, :-1, :] 197 | lhand_input = future_lhand[:, :-1, :] 198 | trg_mask = torch.ones_like(rhand_input[:, :, 0]) 199 | trg_mask = get_subsequent_mask(trg_mask) 200 | 201 | x_rhand = self.decoder(rhand_input, memory, 202 | memory_mask=memory_mask, trg_mask=trg_mask) 203 | x_lhand = self.decoder(lhand_input, memory, 204 | memory_mask=memory_mask, trg_mask=trg_mask) 205 | 206 | x_hand = torch.cat((x_rhand, x_lhand), dim=1) 207 | x_hand = x_hand.reshape(-1, self.embed_dim) 208 | 209 | target_hand = future_hands[:, :, 1:, :] 210 | target_hand = target_hand.reshape(-1, target_hand.shape[-1]) 211 | 212 | pred_hand, traj_loss, traj_kl_loss = self.hand_head(x_hand, target_hand, future_valid, contact=None, 213 | 
return_pred=True) 214 | 215 | r_pred_contact, r_obj_loss, r_obj_kl_loss = self.object_head(memory[:, 0, :], contact_point, future_rhand, 216 | return_pred=True) 217 | l_pred_contact, l_obj_loss, l_obj_kl_loss = self.object_head(memory[:, 0, :], contact_point, future_lhand, 218 | return_pred=True) 219 | 220 | obj_loss = torch.stack([r_obj_loss, l_obj_loss], dim=1) 221 | obj_kl_loss = torch.stack([r_obj_kl_loss, l_obj_kl_loss], dim=1) 222 | 223 | obj_loss[~(future_valid > 0)] = 1e9 224 | selected_obj_loss, selected_idx = obj_loss.min(dim=1) 225 | selected_valid = torch.gather(future_valid, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1) 226 | selected_obj_kl_loss = torch.gather(obj_kl_loss, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1) 227 | obj_loss = selected_obj_loss * selected_valid 228 | obj_kl_loss = selected_obj_kl_loss * selected_valid 229 | 230 | return traj_loss, traj_kl_loss, obj_loss, obj_kl_loss 231 | 232 | def inference(self, feat, bbox_feat, valid_mask, future_valid=None, pred_len=4): 233 | # feat: (B, 5, T, src_in_features), hand & obj, T=10 234 | # bbox_feat: (B, 4, T, 4), hand & obj 235 | # valid_mask: (B, 4, T), hand & obj 236 | # future_valid: (B, 2), right & left traj valid 237 | # return: future_hand (B, 2, T, 2), not include last observation frame 238 | # return: contact_point (B, 2) 239 | B, T = feat.shape[0], feat.shape[2] 240 | if not valid_mask.shape[1] == feat.shape[1]: 241 | src_mask = torch.cat( 242 | (torch.ones_like(valid_mask[:, 0:1, :], dtype=valid_mask.dtype, device=valid_mask.device), 243 | valid_mask), dim=1).transpose(1, 2) 244 | else: 245 | src_mask = valid_mask.transpose(1, 2) 246 | feat = self.encoder_input(feat, bbox_feat) 247 | x = self.encoder(feat, mask=src_mask) 248 | 249 | memory = x[:, -1, :, :] 250 | memory_mask = get_pad_mask(src_mask[:, -1, :], pad_idx=0) 251 | 252 | observe_bbox = bbox_feat[:, :2, -1, :] 253 | observe_rhand, observe_lhand = observe_bbox[:, 0, :], observe_bbox[:, 1, :] 254 | 
future_rhand = (observe_rhand[:, :2] + observe_rhand[:, 2:]) / 2 255 | future_lhand = (observe_lhand[:, :2] + observe_lhand[:, 2:]) / 2 256 | future_rhand = future_rhand.unsqueeze(dim=1) 257 | future_lhand = future_lhand.unsqueeze(dim=1) 258 | 259 | pred_contact = None 260 | for i in range(pred_len): 261 | trg_mask = torch.ones_like(future_rhand[:, :, 0]) 262 | trg_mask = get_subsequent_mask(trg_mask) 263 | 264 | x_rhand = self.decoder(future_rhand, memory, 265 | memory_mask=memory_mask, trg_mask=trg_mask) 266 | x_hand = x_rhand.reshape(-1, self.embed_dim) 267 | loc_rhand = self.hand_head.inference(x_hand, pred_contact) 268 | loc_rhand = loc_rhand.reshape(B, -1, 2) 269 | pred_rhand = loc_rhand[:, -1:, :] 270 | future_rhand = torch.cat((future_rhand, pred_rhand), dim=1) 271 | 272 | for i in range(pred_len): 273 | trg_mask = torch.ones_like(future_lhand[:, :, 0]) 274 | trg_mask = get_subsequent_mask(trg_mask) 275 | 276 | x_lhand = self.decoder(future_lhand, memory, 277 | memory_mask=memory_mask, trg_mask=trg_mask) 278 | x_hand = x_lhand.reshape(-1, self.embed_dim) 279 | loc_lhand = self.hand_head.inference(x_hand, pred_contact) 280 | loc_lhand = loc_lhand.reshape(B, -1, 2) 281 | pred_lhand = loc_lhand[:, -1:, :] 282 | future_lhand = torch.cat((future_lhand, pred_lhand), dim=1) 283 | 284 | future_hands = torch.stack((future_rhand[:, 1:, :], future_lhand[:, 1:, :]), dim=1) 285 | 286 | r_pred_contact = self.object_head.inference(memory[:, 0, :], future_rhand) 287 | l_pred_contact = self.object_head.inference(memory[:, 0, :], future_lhand) 288 | pred_contact = torch.stack([r_pred_contact, l_pred_contact], dim=1) 289 | 290 | if future_valid is not None and torch.all(future_valid.sum(dim=1) >= 1): 291 | r_pred_contact_dist = traj_affordance_dist(future_hands.reshape(-1, 2), r_pred_contact, 292 | future_valid) 293 | l_pred_contact_dist = traj_affordance_dist(future_hands.reshape(-1, 2), l_pred_contact, 294 | future_valid) 295 | pred_contact_dist = 
torch.stack((r_pred_contact_dist, l_pred_contact_dist), dim=1) 296 | _, selected_idx = pred_contact_dist.min(dim=1) 297 | selected_idx = selected_idx.unsqueeze(dim=1).unsqueeze(dim=2).expand(pred_contact.shape[0], 1, 298 | pred_contact.shape[2]) 299 | pred_contact = torch.gather(pred_contact, dim=1, index=selected_idx).squeeze(dim=1) 300 | 301 | return future_hands, pred_contact 302 | -------------------------------------------------------------------------------- /options/expopts.py: -------------------------------------------------------------------------------- 1 | def add_exp_opts(parser): 2 | parser.add_argument("--resume", type=str, nargs="+", metavar="PATH", 3 | help="path to latest checkpoint (default: none)") 4 | parser.add_argument("--evaluate", dest="evaluate", action="store_true", 5 | help="evaluate model on validation set") 6 | parser.add_argument("--test_freq", type=int, default=10, 7 | help="testing frequency on evaluation dataset (set specific in traineval.py)") 8 | parser.add_argument("--snapshot", default=10, type=int, metavar="N", 9 | help="How often to take a snapshot of the model (0 = never)") 10 | parser.add_argument("--use_cuda", default=1, type=int, help="use GPU (default: True)") 11 | parser.add_argument('--ek_version', default="ek55", choices=["ek55", "ek100"], help="epic dataset version") 12 | parser.add_argument("--traj_only", action="store_true", help="evaluate traj on validation dataset") -------------------------------------------------------------------------------- /options/netsopts.py: -------------------------------------------------------------------------------- 1 | def add_nets_opts(parser): 2 | parser.add_argument('--src_in_features', type=int, default=1024, help='Network encoder input size') 3 | parser.add_argument('--trg_in_features', type=int, default=2, help='Network decoder input size') 4 | parser.add_argument('--num_patches', type=int, default=5, help='Number of classes') 5 | parser.add_argument('--num_classes', 
type=int, default=2513, help='Number of classes') 6 | 7 | parser.add_argument('--embed_dim', type=int, default=512, help='embedded dimension') 8 | parser.add_argument('--num_heads', type=int, default=8, help='num of heads in transformer') 9 | parser.add_argument('--enc_depth', type=int, default=6, help='transformer encoder depth') 10 | parser.add_argument('--dec_depth', type=int, default=4, help='transformer decoder depth') 11 | 12 | parser.add_argument('--coord_dim', type=int, default=64, help='coordinates feature dimension') 13 | parser.add_argument('--hidden_dim', type=int, default=512, help='stochastic modules hidden dimension') 14 | parser.add_argument('--latent_dim', type=int, default=256, help='stochastic modules latent dimension') 15 | 16 | parser.add_argument("--encoder_time_embed_type", default="sin", 17 | choices=["sin", "param"], help="transformer encoder time position embedding") 18 | parser.add_argument("--decoder_time_embed_type", default="sin", 19 | choices=["sin", "param"], help="transformer decoder time position embedding") 20 | 21 | parser.add_argument("--num_samples", default=20, type=int, help="get number of samples during inference, " 22 | "stochastic model multiple runs") 23 | parser.add_argument("--num_points", default=5, type=int, 24 | help="number of remaining contact points after farthest point " 25 | "sampling for evaluation affordance") 26 | parser.add_argument("--gaussian_sigma", default=3., type=float, 27 | help="predicted contact points gaussian kernel sigma") 28 | parser.add_argument("--gaussian_k_ratio", default=3., type=float, 29 | help="predicted contact points gaussian kernel size") 30 | 31 | # Options for loss supervision 32 | parser.add_argument("--lambda_obj", default=1e-1, type=float, help="Weight to supervise object affordance") 33 | parser.add_argument("--lambda_traj", default=1., type=float, help="Weight to supervise hand traj") 34 | parser.add_argument("--lambda_obj_kl", default=1e-3, type=float, help="Weight to 
supervise object affordance KLD") 35 | parser.add_argument("--lambda_traj_kl", default=1e-3, type=float, help="Weight to supervise hand traj KLD") 36 | 37 | 38 | def add_train_opts(parser): 39 | parser.add_argument("--manual_seed", default=0, type=int, help="manual seed") 40 | parser.add_argument("-j", "--workers", default=16, type=int, help="number of workers") 41 | parser.add_argument("--epochs", default=35, type=int, help="number epochs") 42 | parser.add_argument("--batch_size", default=128, type=int, help="batch size") 43 | 44 | parser.add_argument("--optimizer", default="adam", choices=["rms", "adam", "sgd", "adamw"]) 45 | parser.add_argument("--lr", "--learning-rate", default=1e-4, type=float, metavar="LR", help="initial learning rate") 46 | parser.add_argument("--momentum", default=0.9, type=float) 47 | 48 | parser.add_argument("--scheduler", default="cosine", choices=['cosine', 'step', 'multistep'], 49 | help="learning rate scheduler") 50 | parser.add_argument("--warmup_epochs", default=5, type=int, help="number of warmup epochs to run") 51 | parser.add_argument("--lr_decay_step", nargs="+", default=10, type=int, 52 | help="Epochs after which to decay learning rate") 53 | parser.add_argument( 54 | "--lr_decay_gamma", default=0.5, type=float, help="Factor by which to decay the learning rate") 55 | parser.add_argument("--weight_decay", default=1e-4, type=float) 56 | -------------------------------------------------------------------------------- /preprocess/affordance_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from preprocess.dataset_util import bbox_inter 4 | 5 | 6 | def skin_extract(image): 7 | def color_segmentation(): 8 | lower_HSV_values = np.array([0, 40, 0], dtype="uint8") 9 | upper_HSV_values = np.array([25, 255, 255], dtype="uint8") 10 | lower_YCbCr_values = np.array((0, 138, 67), dtype="uint8") 11 | upper_YCbCr_values = np.array((255, 173, 133), dtype="uint8") 
12 | mask_YCbCr = cv2.inRange(YCbCr_image, lower_YCbCr_values, upper_YCbCr_values) 13 | mask_HSV = cv2.inRange(HSV_image, lower_HSV_values, upper_HSV_values) 14 | binary_mask_image = cv2.add(mask_HSV, mask_YCbCr) 15 | return binary_mask_image 16 | 17 | HSV_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 18 | YCbCr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCR_CB) 19 | binary_mask_image = color_segmentation() 20 | image_foreground = cv2.erode(binary_mask_image, None, iterations=3) 21 | dilated_binary_image = cv2.dilate(binary_mask_image, None, iterations=3) 22 | ret, image_background = cv2.threshold(dilated_binary_image, 1, 128, cv2.THRESH_BINARY) 23 | 24 | image_marker = cv2.add(image_foreground, image_background) 25 | image_marker32 = np.int32(image_marker) 26 | cv2.watershed(image, image_marker32) 27 | m = cv2.convertScaleAbs(image_marker32) 28 | ret, image_mask = cv2.threshold(m, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 29 | kernel = np.ones((20, 20), np.uint8) 30 | image_mask = cv2.morphologyEx(image_mask, cv2.MORPH_CLOSE, kernel) 31 | return image_mask 32 | 33 | 34 | def farthest_sampling(pcd, n_samples, init_pcd=None): 35 | def compute_distance(a, b): 36 | return np.linalg.norm(a - b, ord=2, axis=2) 37 | 38 | n_pts, dim = pcd.shape[0], pcd.shape[1] 39 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim)) 40 | remaining_pts = np.copy(pcd) 41 | 42 | if init_pcd is None: 43 | if n_pts > 1: 44 | start_idx = np.random.randint(low=0, high=n_pts - 1) 45 | else: 46 | start_idx = 0 47 | selected_pts_expanded[0] = remaining_pts[start_idx] 48 | n_selected_pts = 1 49 | else: 50 | num_points = min(init_pcd.shape[0], n_samples) 51 | selected_pts_expanded[:num_points] = init_pcd[:num_points, None, :] 52 | n_selected_pts = num_points 53 | 54 | for _ in range(1, n_samples): 55 | if n_selected_pts < n_samples: 56 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T 57 | dist_pts_to_selected_min = 
np.min(dist_pts_to_selected, axis=1, keepdims=True) 58 | res_selected_idx = np.argmax(dist_pts_to_selected_min) 59 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx] 60 | n_selected_pts += 1 61 | 62 | selected_pts = np.squeeze(selected_pts_expanded, axis=1) 63 | return selected_pts 64 | 65 | 66 | def compute_heatmap(points, image_size, k_ratio=3.0): 67 | points = np.asarray(points) 68 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32) 69 | n_points = points.shape[0] 70 | for i in range(n_points): 71 | x = points[i, 0] 72 | y = points[i, 1] 73 | col = int(x) 74 | row = int(y) 75 | try: 76 | heatmap[col, row] += 1.0 77 | except: 78 | col = min(max(col, 0), image_size[0] - 1) 79 | row = min(max(row, 0), image_size[1] - 1) 80 | heatmap[col, row] += 1.0 81 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio) 82 | if k_size % 2 == 0: 83 | k_size += 1 84 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), 0) 85 | if heatmap.max() > 0: 86 | heatmap /= heatmap.max() 87 | heatmap = heatmap.transpose() 88 | return heatmap 89 | 90 | 91 | def select_points_bbox(bbox, points, tolerance=2): 92 | x1, y1, x2, y2 = bbox 93 | ind_x = np.logical_and(points[:, 0] > x1-tolerance, points[:, 0] < x2+tolerance) 94 | ind_y = np.logical_and(points[:, 1] > y1-tolerance, points[:, 1] < y2+tolerance) 95 | ind = np.logical_and(ind_x, ind_y) 96 | indices = np.where(ind == True)[0] 97 | return points[indices] 98 | 99 | 100 | def find_contour_points(mask): 101 | _, contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 102 | if len(contours) != 0: 103 | c = max(contours, key=cv2.contourArea) 104 | c = c.squeeze(axis=1) 105 | return c 106 | else: 107 | return None 108 | 109 | 110 | def get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj): 111 | # active_obj_traj: active obj traj in last observation frame 112 | # obj_bboxs_traj: active obj bbox traj in last observation frame 113 | 
select_points_homo = np.concatenate((select_points, np.ones((select_points.shape[0], 1), dtype=np.float32)), axis=1) 114 | select_points_homo = np.dot(select_points_homo, homography.T) 115 | select_points_homo = select_points_homo[:, :2] / select_points_homo[:, None, 2] 116 | 117 | obj_point_last_observe = np.array(active_obj_traj[0]) 118 | obj_point_future_start = np.array(active_obj_traj[-1]) 119 | 120 | future2last_trans = obj_point_last_observe - obj_point_future_start 121 | select_points_homo = select_points_homo + future2last_trans 122 | 123 | fill_indices = [idx for idx, points in enumerate(obj_bboxs_traj) if points is not None] 124 | contour_last_observe = obj_bboxs_traj[fill_indices[0]] 125 | contour_future_homo = obj_bboxs_traj[fill_indices[-1]] + future2last_trans 126 | contour_last_observe = contour_last_observe[:, None, :].astype(np.int) 127 | contour_future_homo = contour_future_homo[:, None, :].astype(np.int) 128 | filtered_points = [] 129 | for point in select_points_homo: 130 | if cv2.pointPolygonTest(contour_last_observe, (point[0], point[1]), False) >= 0 \ 131 | or cv2.pointPolygonTest(contour_future_homo, (point[0], point[1]), False) >= 0: 132 | filtered_points.append(point) 133 | filtered_points = np.array(filtered_points) 134 | return filtered_points 135 | 136 | 137 | def compute_affordance(frame, active_hand, active_obj, num_points=5, num_sampling=20): 138 | skin_mask = skin_extract(frame) 139 | hand_bbox = np.array(active_hand.bbox.coords_int).reshape(-1) 140 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1) 141 | obj_center = active_obj.bbox.center 142 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 143 | if not iou > 0: 144 | return None 145 | x1, y1, x2, y2 = hand_bbox 146 | hand_mask = np.zeros_like(skin_mask, dtype=np.uint8) 147 | hand_mask[y1:y2, x1:x2] = 255 148 | hand_mask = cv2.bitwise_and(skin_mask, hand_mask) 149 | select_points, init_points = None, None 150 | contact_points = find_contour_points(hand_mask) 151 
| 152 | if contact_points is not None and contact_points.shape[0] > 0: 153 | contact_points = select_points_bbox((xA, yA, xB, yB), contact_points) 154 | if contact_points.shape[0] >= num_points: 155 | if contact_points.shape[0] > num_sampling: 156 | contact_points = farthest_sampling(contact_points, n_samples=num_sampling) 157 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1) 158 | indices = np.argsort(distance)[:num_points] 159 | select_points = contact_points[indices] 160 | elif contact_points.shape[0] > 0: 161 | print("no enough boundary points detected, sampling points in interaction region") 162 | init_points = contact_points 163 | else: 164 | print("no boundary points detected, use farthest point sampling") 165 | else: 166 | print("no boundary points detected, use farthest point sampling") 167 | if select_points is None: 168 | ho_mask = np.zeros_like(skin_mask, dtype=np.uint8) 169 | ho_mask[yA:yB, xA:xB] = 255 170 | ho_mask = cv2.bitwise_and(skin_mask, ho_mask) 171 | points = np.array(np.where(ho_mask[yA:yB, xA:xB] > 0)).T 172 | if points.shape[0] == 0: 173 | xx, yy = np.meshgrid(np.arange(xB - xA), np.arange(yB - yA)) 174 | xx += xA 175 | yy += yA 176 | points = np.vstack([xx.reshape(-1), yy.reshape(-1)]).T 177 | else: 178 | points = points[:, [1, 0]] 179 | points[:, 0] += xA 180 | points[:, 1] += yA 181 | if not points.shape[0] > 0: 182 | return None 183 | contact_points = farthest_sampling(points, n_samples=min(num_sampling, points.shape[0]), init_pcd=init_points) 184 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1) 185 | indices = np.argsort(distance)[:num_points] 186 | select_points = contact_points[indices] 187 | return select_points 188 | 189 | 190 | def compute_obj_affordance(frame, annot, active_obj, active_obj_idx, homography, 191 | active_obj_traj, obj_bboxs_traj, 192 | num_points=5, num_sampling=20, 193 | hand_threshold=0.1, obj_threshold=0.1): 194 | affordance_info = {} 195 | 
hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 196 | hand_threshold=hand_threshold) 197 | select_points = None 198 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 199 | if object_idx == active_obj_idx: 200 | active_hand = annot.hands[hand_idx] 201 | affordance_info[active_hand.side.name] = np.array(active_hand.bbox.coords_int).reshape(-1) 202 | cmap_points = compute_affordance(frame, active_hand, active_obj, num_points=num_points, num_sampling=num_sampling) 203 | if select_points is None and (cmap_points is not None and cmap_points.shape[0] > 0): 204 | select_points = cmap_points 205 | elif select_points is not None and (cmap_points is not None and cmap_points.shape[0] > 0): 206 | select_points = np.concatenate((select_points, cmap_points), axis=0) 207 | if select_points is None: 208 | print("affordance contact points filtered out") 209 | return None 210 | select_points_homo = get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj) 211 | if len(select_points_homo) == 0: 212 | print("affordance contact points filtered out") 213 | return None 214 | else: 215 | affordance_info["select_points"] = select_points 216 | affordance_info["select_points_homo"] = select_points_homo 217 | 218 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1) 219 | affordance_info["obj_bbox"] = obj_bbox 220 | return affordance_info -------------------------------------------------------------------------------- /preprocess/dataset_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from preprocess.ho_types import FrameDetections, HandDetection, HandSide, HandState, ObjectDetection 5 | 6 | 7 | def sample_action_anticipation_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0): 8 | time_start = (frame_start - 1) / fps_init 9 | num_frames = int(np.floor(t_buffer * fps)) 10 | times = 
(np.arange(num_frames + 1) - num_frames) / fps + time_start 11 | times = np.clip(times, 0, np.inf) 12 | times = times.astype(np.float32) 13 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1 14 | if frames_idxs.max() >= 1: 15 | frames_idxs[frames_idxs < 1] = frames_idxs[frames_idxs >= 1].min() 16 | return list(frames_idxs) 17 | 18 | 19 | def load_ho_annot(video_detections, frame_index, imgW=456, imgH=256): 20 | annot = video_detections[frame_index-1] # frame_index start from 1 21 | assert annot.frame_number == frame_index, "wrong frame index" 22 | annot.scale(width_factor=imgW, height_factor=imgH) 23 | return annot 24 | 25 | 26 | def load_img(frames_path, frame_index): 27 | frame = cv2.imread(os.path.join(frames_path, "frame_{:010d}.jpg".format(frame_index))) 28 | return frame 29 | 30 | 31 | def get_mask(frame, annot, hand_threshold=0.1, obj_threshold=0.1): 32 | msk_img = np.ones((frame.shape[:2]), dtype=frame.dtype) 33 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 34 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 35 | for hand in hands: 36 | (x1, y1), (x2, y2) = hand.bbox.coords_int 37 | msk_img[y1:y2, x1:x2] = 0 38 | 39 | if len(objs) > 0: 40 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 41 | hand_threshold=hand_threshold) 42 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 43 | hand = annot.hands[hand_idx] 44 | object = annot.objects[object_idx] 45 | if not hand.state.value == HandState.STATIONARY_OBJECT.value: 46 | (x1, y1), (x2, y2) = object.bbox.coords_int 47 | msk_img[y1:y2, x1:x2] = 0 48 | return msk_img 49 | 50 | 51 | def bbox_inter(boxA, boxB): 52 | xA = max(boxA[0], boxB[0]) 53 | yA = max(boxA[1], boxB[1]) 54 | xB = min(boxA[2], boxB[2]) 55 | yB = min(boxA[3], boxB[3]) 56 | 57 | interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0)) 58 | if interArea == 0: 59 | return xA, yA, xB, yB, 0 60 | 61 | boxAArea = 
abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1])) 62 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1])) 63 | iou = interArea / float(boxAArea + boxBArea - interArea) 64 | return xA, yA, xB, yB, iou 65 | 66 | 67 | def compute_iou(boxA, boxB): 68 | boxA = np.array(boxA).reshape(-1) 69 | boxB = np.array(boxB).reshape(-1) 70 | xA = max(boxA[0], boxB[0]) 71 | yA = max(boxA[1], boxB[1]) 72 | xB = min(boxA[2], boxB[2]) 73 | yB = min(boxA[3], boxB[3]) 74 | interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0)) 75 | if interArea == 0: 76 | return 0 77 | boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1])) 78 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1])) 79 | iou = interArea / float(boxAArea + boxBArea - interArea) 80 | return iou 81 | 82 | 83 | def points_in_bbox(point, bbox): 84 | (x1, y1), (x2, y2) = bbox 85 | (x, y) = point 86 | return (x1 <= x <= x2) and (y1 <= y <= y2) 87 | 88 | 89 | def valid_point(point, imgW=456, imgH=256): 90 | if point is None: 91 | return False 92 | else: 93 | x, y = point 94 | return (0 <= x < imgW) and (0 <=y < imgH) 95 | 96 | 97 | def valid_traj(traj, imgW=456, imgH=256): 98 | if len(traj) > 0: 99 | num_outlier = np.sum([not valid_point(point, imgW=imgW, imgH=imgH) 100 | for point in traj if point is not None]) 101 | valid_ratio = np.sum([valid_point(point, imgW=imgW, imgH=imgH) for point in traj[1:]]) / len(traj[1:]) 102 | valid_last = valid_point(traj[-1], imgW=imgW, imgH=imgH) 103 | if num_outlier > 1 or valid_ratio < 0.5 or not valid_last: 104 | traj = [] 105 | return traj 106 | 107 | 108 | def get_valid_traj(traj, imgW=456, imgH=256): 109 | try: 110 | traj[traj < 0] = traj[traj >= 0].min() 111 | except: 112 | traj[traj < 0] = 0 113 | try: 114 | traj[:, 0][traj[:, 0] >= imgW] = imgW - 1 115 | except: 116 | traj[:, 0][traj[:, 0] >= imgW] = imgW - 1 117 | try: 118 | traj[:, 1][traj[:, 1] >= imgH] = imgH - 1 119 | except: 120 | traj[:, 1][traj[:, 1] >= imgH] = imgH - 1 121 | return traj 122 | 123 | 124 | def 
fetch_data(frames_path, video_detections, frames_idxs, hand_threshold=0.1, obj_threshold=0.1): 125 | tolerance = frames_idxs[1] - frames_idxs[0] # extend future act frame by tolerance to find ho interaction 126 | frames = [] 127 | annots = [] 128 | 129 | miss_hand = 0 130 | for frame_idx in frames_idxs[:-1]: 131 | frame = load_img(frames_path, frame_idx) 132 | annot = load_ho_annot(video_detections, frame_idx) 133 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 134 | if len(hands) == 0: 135 | miss_hand += 1 136 | frames.append(frame) 137 | annots.append(annot) 138 | if miss_hand == len(frames_idxs[:-1]): 139 | return None 140 | frame_idx = frames_idxs[-1] 141 | frames_idxs = frames_idxs[:-1] 142 | 143 | hand_sides = [] 144 | idx = 0 145 | flag = False 146 | while idx < tolerance: 147 | annot = load_ho_annot(video_detections, frame_idx) 148 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 149 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 150 | if len(hands) > 0 and len(objs) > 0: # at least one hand is contact with obj 151 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 152 | hand_threshold=hand_threshold) 153 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 154 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords).reshape(-1) 155 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords).reshape(-1) 156 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 157 | contact_state = annot.hands[hand_idx].state.value 158 | if iou > 0 and (contact_state == HandState.STATIONARY_OBJECT.value or 159 | contact_state == HandState.PORTABLE_OBJECT.value): 160 | hand_side = annot.hands[hand_idx].side.name 161 | hand_sides.append(hand_side) 162 | flag = True 163 | if flag: 164 | break 165 | else: 166 | idx += 1 167 | frame_idx += 1 168 | else: 169 | idx += 1 170 | frame_idx += 1 171 | if flag: 172 | 
frames_idxs.append(frame_idx) 173 | frames.append(load_img(frames_path, frame_idx)) 174 | annots.append(annot) 175 | return frames_idxs, frames, annots, list(set(hand_sides)) # remove redundant hand sides 176 | else: 177 | return None 178 | 179 | 180 | def save_video_info(save_path, video_index, frames_idxs, homography_stack, contacts, 181 | hand_trajs, obj_trajs, affordance_info): 182 | import pickle 183 | video_info = {"frame_indices": frames_idxs, 184 | "homography": homography_stack, 185 | "contact": contacts} 186 | video_info.update({"hand_trajs": hand_trajs}) 187 | video_info.update({"obj_trajs": obj_trajs}) 188 | video_info.update({"affordance": affordance_info}) 189 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'wb') as f: 190 | pickle.dump(video_info, f) 191 | 192 | 193 | def load_video_info(save_path, video_index): 194 | import pickle 195 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'rb') as f: 196 | video_info = pickle.load(f) 197 | return video_info 198 | -------------------------------------------------------------------------------- /preprocess/ho_types.py: -------------------------------------------------------------------------------- 1 | """The core set of types that represent hand-object detections""" 2 | 3 | from enum import Enum, unique 4 | from itertools import chain 5 | from typing import Dict, Iterator, List, Tuple, cast 6 | 7 | import numpy as np 8 | from dataclasses import dataclass 9 | import preprocess.types_pb2 as pb 10 | 11 | __all__ = [ 12 | "HandSide", 13 | "HandState", 14 | "FloatVector", 15 | "BBox", 16 | "HandDetection", 17 | "ObjectDetection", 18 | "FrameDetections", 19 | ] 20 | 21 | 22 | @unique 23 | class HandSide(Enum): 24 | LEFT = 0 25 | RIGHT = 1 26 | 27 | 28 | @unique 29 | class HandState(Enum): 30 | """An enum describing the different states a hand can be in: 31 | - No contact: The hand isn't touching anything 32 | - Self contact: The hand is touching itself 33 | - 
Another person: The hand is touching another person 34 | - Portable object: The hand is in contact with a portable object 35 | - Stationary object: The hand is in contact with an immovable/stationary object""" 36 | 37 | NO_CONTACT = 0 38 | SELF_CONTACT = 1 39 | ANOTHER_PERSON = 2 40 | PORTABLE_OBJECT = 3 41 | STATIONARY_OBJECT = 4 42 | 43 | 44 | @dataclass 45 | class FloatVector: 46 | """A floating-point 2D vector representation""" 47 | x: np.float32 48 | y: np.float32 49 | 50 | def to_protobuf(self) -> pb.FloatVector: 51 | vector = pb.FloatVector() 52 | vector.x = self.x 53 | vector.y = self.y 54 | assert vector.IsInitialized() 55 | return vector 56 | 57 | @staticmethod 58 | def from_protobuf(vector: pb.FloatVector) -> "FloatVector": 59 | return FloatVector(x=vector.x, y=vector.y) 60 | 61 | def __add__(self, other: "FloatVector") -> "FloatVector": 62 | return FloatVector(x=self.x + other.x, y=self.y + other.y) 63 | 64 | def __mul__(self, scaler: float) -> "FloatVector": 65 | return FloatVector(x=self.x * scaler, y=self.y * scaler) 66 | 67 | def __iter__(self) -> Iterator[float]: 68 | yield from (self.x, self.y) 69 | 70 | @property 71 | def coord(self) -> Tuple[float, float]: 72 | """Return coordinates as a tuple""" 73 | return (self.x, self.y) 74 | 75 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 76 | """Scale x component by ``width_factor`` and y component by ``height_factor``""" 77 | self.x *= width_factor 78 | self.y *= height_factor 79 | 80 | 81 | @dataclass 82 | class BBox: 83 | left: float 84 | top: float 85 | right: float 86 | bottom: float 87 | 88 | def to_protobuf(self) -> pb.BBox: 89 | bbox = pb.BBox() 90 | bbox.left = self.left 91 | bbox.top = self.top 92 | bbox.right = self.right 93 | bbox.bottom = self.bottom 94 | assert bbox.IsInitialized() 95 | return bbox 96 | 97 | @staticmethod 98 | def from_protobuf(bbox: pb.BBox) -> "BBox": 99 | return BBox( 100 | left=bbox.left, 101 | top=bbox.top, 102 | right=bbox.right, 103 | 
bottom=bbox.bottom, 104 | ) 105 | 106 | @property 107 | def center(self) -> Tuple[float, float]: 108 | x = (self.left + self.right) / 2 109 | y = (self.top + self.bottom) / 2 110 | return x, y 111 | 112 | @property 113 | def center_int(self) -> Tuple[int, int]: 114 | """Get center position as a tuple of integers (rounded)""" 115 | x, y = self.center 116 | return (round(x), round(y)) 117 | 118 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 119 | self.left *= width_factor 120 | self.right *= width_factor 121 | self.top *= height_factor 122 | self.bottom *= height_factor 123 | 124 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 125 | x, y = self.center 126 | new_width = self.width * width_factor 127 | new_height = self.height * height_factor 128 | self.left = x - new_width / 2 129 | self.right = x + new_width / 2 130 | self.top = y - new_height / 2 131 | self.bottom = y + new_height / 2 132 | 133 | @property 134 | def coords(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: 135 | return ( 136 | self.top_left, 137 | self.bottom_right, 138 | ) 139 | 140 | @property 141 | def coords_int(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: 142 | return ( 143 | self.top_left_int, 144 | self.bottom_right_int, 145 | ) 146 | 147 | @property 148 | def width(self) -> float: 149 | return self.right - self.left 150 | 151 | @property 152 | def height(self) -> float: 153 | return self.bottom - self.top 154 | 155 | @property 156 | def top_left(self) -> Tuple[float, float]: 157 | return (self.left, self.top) 158 | 159 | @property 160 | def bottom_right(self) -> Tuple[float, float]: 161 | return (self.right, self.bottom) 162 | 163 | @property 164 | def top_left_int(self) -> Tuple[int, int]: 165 | return (round(self.left), round(self.top)) 166 | 167 | @property 168 | def bottom_right_int(self) -> Tuple[int, int]: 169 | return (round(self.right), round(self.bottom)) 170 | 171 | 172 | @dataclass 173 | class 
HandDetection: 174 | """Dataclass representing a hand detection, consisting of a bounding box, 175 | a score (representing the model's confidence this is a hand), the predicted state 176 | of the hand, whether this is a left/right hand, and a predicted offset to the 177 | interacted object if the hand is interacting.""" 178 | 179 | bbox: BBox 180 | score: np.float32 181 | state: HandState 182 | side: HandSide 183 | object_offset: FloatVector 184 | 185 | def to_protobuf(self) -> pb.HandDetection: 186 | detection = pb.HandDetection() 187 | detection.bbox.MergeFrom(self.bbox.to_protobuf()) 188 | detection.score = self.score 189 | detection.state = self.state.value 190 | detection.object_offset.MergeFrom(self.object_offset.to_protobuf()) 191 | detection.side = self.side.value 192 | assert detection.IsInitialized() 193 | return detection 194 | 195 | @staticmethod 196 | def from_protobuf(detection: pb.HandDetection) -> "HandDetection": 197 | return HandDetection( 198 | bbox=BBox.from_protobuf(detection.bbox), 199 | score=detection.score, 200 | state=HandState(detection.state), 201 | object_offset=FloatVector.from_protobuf(detection.object_offset), 202 | side=HandSide(detection.side), 203 | ) 204 | 205 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 206 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor) 207 | self.object_offset.scale(width_factor=width_factor, height_factor=height_factor) 208 | 209 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 210 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor) 211 | 212 | 213 | @dataclass 214 | class ObjectDetection: 215 | """Dataclass representing an object detection, consisting of a bounding box and a 216 | score (the model's confidence this is an object)""" 217 | 218 | bbox: BBox 219 | score: np.float32 220 | 221 | def to_protobuf(self) -> pb.ObjectDetection: 222 | detection = pb.ObjectDetection() 223 | 
detection.bbox.MergeFrom(self.bbox.to_protobuf()) 224 | detection.score = self.score 225 | assert detection.IsInitialized() 226 | return detection 227 | 228 | @staticmethod 229 | def from_protobuf(detection: pb.ObjectDetection) -> "ObjectDetection": 230 | return ObjectDetection( 231 | bbox=BBox.from_protobuf(detection.bbox), score=detection.score 232 | ) 233 | 234 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 235 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor) 236 | 237 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 238 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor) 239 | 240 | 241 | @dataclass 242 | class FrameDetections: 243 | """Dataclass representing hand-object detections for a frame of a video""" 244 | 245 | video_id: str 246 | frame_number: int 247 | objects: List[ObjectDetection] 248 | hands: List[HandDetection] 249 | 250 | def to_protobuf(self) -> pb.Detections: 251 | detections = pb.Detections() 252 | detections.video_id = self.video_id 253 | detections.frame_number = self.frame_number 254 | detections.hands.extend([hand.to_protobuf() for hand in self.hands]) 255 | detections.objects.extend([object.to_protobuf() for object in self.objects]) 256 | assert detections.IsInitialized() 257 | return detections 258 | 259 | @staticmethod 260 | def from_protobuf(detections: pb.Detections) -> "FrameDetections": 261 | return FrameDetections( 262 | video_id=detections.video_id, 263 | frame_number=detections.frame_number, 264 | hands=[HandDetection.from_protobuf(pb) for pb in detections.hands], 265 | objects=[ObjectDetection.from_protobuf(pb) for pb in detections.objects], 266 | ) 267 | 268 | @staticmethod 269 | def from_protobuf_str(pb_str: bytes) -> "FrameDetections": 270 | pb_detection = pb.Detections() 271 | pb_detection.MergeFromString(pb_str) 272 | return FrameDetections.from_protobuf(pb_detection) 273 | 274 | def 
get_hand_object_interactions( 275 | self, object_threshold: float = 0, hand_threshold: float = 0 276 | ) -> Dict[int, int]: 277 | """Match the hands to objects based on the hand offset vector that the model 278 | uses to predict the location of the interacted object. 279 | 280 | Args: 281 | object_threshold: Object score threshold above which to consider objects 282 | for matching 283 | hand_threshold: Hand score threshold above which to consider hands for 284 | matching. 285 | 286 | Returns: 287 | A dictionary mapping hand detections to objects by indices 288 | """ 289 | interactions = dict() 290 | object_idxs = [ 291 | i for i, obj in enumerate(self.objects) if obj.score >= object_threshold 292 | ] 293 | object_centers = np.array( 294 | [self.objects[object_id].bbox.center for object_id in object_idxs] 295 | ) 296 | for hand_idx, hand_detection in enumerate(self.hands): 297 | if ( 298 | hand_detection.state.value == HandState.NO_CONTACT.value 299 | or hand_detection.score <= hand_threshold 300 | ): 301 | continue 302 | estimated_object_position = ( 303 | np.array(hand_detection.bbox.center) + 304 | np.array(hand_detection.object_offset.coord) 305 | ) 306 | distances = ((object_centers - estimated_object_position) ** 2).sum( 307 | axis=-1) 308 | interactions[hand_idx] = object_idxs[cast(int, np.argmin(distances))] 309 | return interactions 310 | 311 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 312 | """ 313 | Scale the coordinates of all the hands/objects. x components are multiplied 314 | by the ``width_factor`` and y components by the ``height_factor`` 315 | """ 316 | for det in chain(self.hands, self.objects): 317 | det.scale(width_factor=width_factor, height_factor=height_factor) 318 | 319 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None: 320 | """ 321 | Scale all the hands/objects about their center points. 
322 | """ 323 | for det in chain(self.hands, self.objects): 324 | det.center_scale(width_factor=width_factor, height_factor=height_factor) -------------------------------------------------------------------------------- /preprocess/obj_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from preprocess.dataset_util import bbox_inter, HandState, compute_iou, \ 3 | valid_traj, get_valid_traj, points_in_bbox 4 | from preprocess.traj_util import get_homo_point, get_homo_bbox_point 5 | 6 | 7 | def find_active_side(annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 8 | if len(hand_sides) == 1: 9 | return hand_sides[0] 10 | else: 11 | hand_counter = {"LEFT": 0, "RIGHT": 0} 12 | for annot in annots: 13 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 14 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 15 | if len(hands) > 0 and len(objs) > 0: 16 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 17 | hand_threshold=hand_threshold) 18 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 19 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords_int).reshape(-1) 20 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords_int).reshape(-1) 21 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox) 22 | if iou > 0: 23 | hand_side = annot.hands[hand_idx].side.name 24 | if annot.hands[hand_idx].state.value == HandState.PORTABLE_OBJECT.value: 25 | hand_counter[hand_side] += 1 26 | elif annot.hands[hand_idx].state.value == HandState.STATIONARY_OBJECT.value: 27 | hand_counter[hand_side] += 0.5 28 | if hand_counter["LEFT"] == hand_counter["RIGHT"]: 29 | return "RIGHT" 30 | else: 31 | return max(hand_counter, key=hand_counter.get) 32 | 33 | 34 | def compute_contact(annots, hand_side, contact_state, hand_threshold=0.1): 35 | contacts = [] 36 | for annot in annots: 37 | hands = [hand for hand in 
annot.hands if hand.score >= hand_threshold 38 | and hand.side.name == hand_side and hand.state.value == contact_state] 39 | if len(hands) > 0: 40 | contacts.append(1) 41 | else: 42 | contacts.append(0) 43 | contacts = np.array(contacts) 44 | padding_contacts = np.pad(contacts, [1, 1], 'edge') 45 | contacts = np.convolve(padding_contacts, [1, 1, 1], 'same') 46 | contacts = contacts[1:-1] / 3 47 | contacts = contacts > 0.5 48 | indices = np.diff(contacts) != 0 49 | if indices.sum() == 0: 50 | return contacts 51 | else: 52 | split = np.where(indices)[0] + 1 53 | contacts_idx = split[-1] 54 | contacts[:contacts_idx] = False 55 | return contacts 56 | 57 | 58 | def find_active_obj_side(annot, hand_side, return_hand=False, return_idx=False, hand_threshold=0.1, obj_threshold=0.1): 59 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 60 | if len(objs) == 0: 61 | return None 62 | else: 63 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold, 64 | hand_threshold=hand_threshold) 65 | for hand_idx, object_idx in hand_object_idx_correspondences.items(): 66 | if annot.hands[hand_idx].side.name == hand_side: 67 | if return_hand and return_idx: 68 | return annot.objects[object_idx], object_idx, annot.hands[hand_idx], hand_idx 69 | elif return_hand: 70 | return annot.objects[object_idx], annot.hands[hand_idx] 71 | elif return_idx: 72 | return annot.objects[object_idx], object_idx 73 | else: 74 | return annot.objects[object_idx] 75 | return None 76 | 77 | 78 | def find_active_obj_iou(objs, bbox): 79 | max_iou = 0 80 | active_obj = None 81 | for obj in objs: 82 | iou = compute_iou(obj.bbox.coords, bbox) 83 | if iou > max_iou: 84 | max_iou = iou 85 | active_obj = obj 86 | return active_obj, max_iou 87 | 88 | 89 | def traj_compute(annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1): 90 | annot = annots[-1] 91 | obj_traj = [] 92 | obj_centers = [] 93 | obj_bboxs =[] 94 | obj_bboxs_traj = [] 95 
| active_hand_side = find_active_side(annots, hand_sides, hand_threshold=hand_threshold, 96 | obj_threshold=obj_threshold) 97 | active_obj, active_object_idx, active_hand, active_hand_idx = find_active_obj_side(annot, 98 | hand_side=active_hand_side, 99 | return_hand=True, return_idx=True, 100 | hand_threshold=hand_threshold, 101 | obj_threshold=obj_threshold) 102 | contact_state = active_hand.state.value 103 | contacts = compute_contact(annots, active_hand_side, contact_state, 104 | hand_threshold=hand_threshold) 105 | obj_center = active_obj.bbox.center 106 | obj_centers.append(obj_center) 107 | obj_point = get_homo_point(obj_center, homography_stack[-1]) 108 | obj_bbox = active_obj.bbox.coords 109 | obj_traj.append(obj_point) 110 | obj_bboxs.append(obj_bbox) 111 | 112 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[-1]) 113 | obj_bboxs_traj.append(obj_points2d) 114 | 115 | for idx in np.arange(len(annots)-2, -1, -1): 116 | annot = annots[idx] 117 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold] 118 | contact = contacts[idx] 119 | if not contact: 120 | obj_centers.append(None) 121 | obj_traj.append(None) 122 | obj_bboxs_traj.append(None) 123 | else: 124 | if len(objs) >= 2: 125 | target_obj, max_iou = find_active_obj_iou(objs, obj_bboxs[-1]) 126 | if target_obj is None: 127 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side, 128 | hand_threshold=hand_threshold, 129 | obj_threshold=obj_threshold) 130 | if target_obj is None: 131 | obj_centers.append(None) 132 | obj_traj.append(None) 133 | obj_bboxs_traj.append(None) 134 | else: 135 | obj_center = target_obj.bbox.center 136 | obj_centers.append(obj_center) 137 | obj_point = get_homo_point(obj_center, homography_stack[idx]) 138 | obj_bbox = target_obj.bbox.coords 139 | obj_traj.append(obj_point) 140 | obj_bboxs.append(obj_bbox) 141 | 142 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx]) 143 | obj_bboxs_traj.append(obj_points2d) 144 | 145 | 
elif len(objs) > 0: 146 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side, 147 | hand_threshold=hand_threshold, 148 | obj_threshold=obj_threshold) 149 | if target_obj is None: 150 | obj_centers.append(None) 151 | obj_traj.append(None) 152 | obj_bboxs_traj.append(None) 153 | else: 154 | obj_center = target_obj.bbox.center 155 | obj_centers.append(obj_center) 156 | obj_point = get_homo_point(obj_center, homography_stack[idx]) 157 | obj_bbox = target_obj.bbox.coords 158 | obj_traj.append(obj_point) 159 | obj_bboxs.append(obj_bbox) 160 | 161 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx]) 162 | obj_bboxs_traj.append(obj_points2d) 163 | else: 164 | obj_centers.append(None) 165 | obj_traj.append(None) 166 | obj_bboxs_traj.append(None) 167 | obj_bboxs.reverse() 168 | obj_traj.reverse() 169 | obj_centers.reverse() 170 | obj_bboxs_traj.reverse() 171 | return obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj 172 | 173 | 174 | def traj_filter(obj_traj, obj_centers, obj_bbox, contacts, homography_stack, contact_ratio=0.4): 175 | assert len(obj_traj) == len(obj_centers), "traj length and center length not equal" 176 | assert len(obj_centers) == len(homography_stack), "center length and homography length not equal" 177 | homo_last2first = homography_stack[-1] 178 | homo_first2last = np.linalg.inv(homo_last2first) 179 | obj_points = [] 180 | obj_inside, obj_detect = [], [] 181 | for idx, obj_center in enumerate(obj_centers): 182 | if obj_center is not None: 183 | homo_current2first = homography_stack[idx] 184 | homo_current2last = homo_current2first.dot(homo_first2last) 185 | obj_point = get_homo_point(obj_center, homo_current2last) 186 | obj_points.append(obj_point) 187 | obj_inside.append(points_in_bbox(obj_point, obj_bbox)) 188 | obj_detect.append(True) 189 | else: 190 | obj_detect.append(False) 191 | obj_inside = np.array(obj_inside) 192 | obj_detect = np.array(obj_detect) 193 | contacts = 
np.bitwise_and(obj_detect, contacts) 194 | if np.sum(obj_inside) == len(obj_inside) and np.sum(contacts) / len(contacts) < contact_ratio: 195 | obj_traj = np.tile(obj_traj[-1], (len(obj_traj), 1)) 196 | return obj_traj, contacts 197 | 198 | 199 | def traj_completion(traj, imgW=456, imgH=256): 200 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None] 201 | full_traj = traj.copy() 202 | if len(fill_indices) == 1: 203 | point = traj[fill_indices[0]] 204 | full_traj = np.array([point] * len(traj), dtype=np.float32) 205 | else: 206 | contact_time = fill_indices[0] 207 | if contact_time > 0: 208 | full_traj[:contact_time] = [traj[contact_time]] * contact_time 209 | for previous_idx, current_idx in zip(fill_indices[:-1], fill_indices[1:]): 210 | start_point, end_point = traj[previous_idx], traj[current_idx] 211 | time_expand = current_idx - previous_idx 212 | for idx in range(previous_idx+1, current_idx): 213 | full_traj[idx] = (idx-previous_idx) / time_expand * end_point + (current_idx-idx) / time_expand * start_point 214 | full_traj = np.array(full_traj, dtype=np.float32) 215 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH) 216 | return full_traj, fill_indices 217 | 218 | 219 | def compute_obj_traj(frames, annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1, 220 | contact_ratio=0.4): 221 | imgH, imgW = frames[0].shape[:2] 222 | obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj = traj_compute(annots, hand_sides, homography_stack, 223 | hand_threshold=hand_threshold, obj_threshold=obj_threshold) 224 | obj_traj, contacts = traj_filter(obj_traj, obj_centers, obj_bboxs[-1], contacts, homography_stack, 225 | contact_ratio=contact_ratio) 226 | obj_traj = valid_traj(obj_traj, imgW=imgW, imgH=imgH) 227 | if len(obj_traj) == 0: 228 | print("object traj filtered out") 229 | return None 230 | else: 231 | complete_traj, fill_indices = traj_completion(obj_traj, imgW=imgW, imgH=imgH) 
232 | obj_trajs = {"traj": complete_traj, "fill_indices": fill_indices, "centers": obj_centers} 233 | return contacts, obj_trajs, active_obj, active_object_idx, obj_bboxs_traj -------------------------------------------------------------------------------- /preprocess/traj_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from preprocess.dataset_util import get_mask, valid_traj 4 | 5 | 6 | def match_keypoints(kpsA, kpsB, featuresA, featuresB, ratio=0.7, reprojThresh=4.0): 7 | matcher = cv2.DescriptorMatcher_create("BruteForce") 8 | rawMatches = matcher.knnMatch(featuresA, featuresB, 2) 9 | matches = [] 10 | 11 | for m in rawMatches: 12 | if len(m) == 2 and m[0].distance < m[1].distance * ratio: 13 | matches.append((m[0])) 14 | 15 | if len(matches) > 4: 16 | ptsA = np.float32([kpsA[m.queryIdx].pt for m in matches]) 17 | ptsB = np.float32([kpsB[m.trainIdx].pt for m in matches]) 18 | (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh) 19 | matchesMask = status.ravel().tolist() 20 | return matches, H, matchesMask 21 | return None 22 | 23 | 24 | def get_pair_homography(frame_1, frame_2, annot_1, annot_2, hand_threshold=0.1, obj_threshold=0.1): 25 | flag = True 26 | descriptor = cv2.xfeatures2d.SURF_create() 27 | msk_img_1 = get_mask(frame_1, annot_1, hand_threshold=hand_threshold, obj_threshold=obj_threshold) 28 | msk_img_2 = get_mask(frame_2, annot_2, hand_threshold=hand_threshold, obj_threshold=obj_threshold) 29 | (kpsA, featuresA) = descriptor.detectAndCompute(frame_1, mask=msk_img_1) 30 | (kpsB, featuresB) = descriptor.detectAndCompute(frame_2, mask=msk_img_2) 31 | matches, matchesMask = None, None 32 | try: 33 | (matches, H_BA, matchesMask) = match_keypoints(kpsB, kpsA, featuresB, featuresA) 34 | except Exception: 35 | print("compute homography failed!") 36 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 37 | flag = False 38 | 39 | NoneType = 
type(None) 40 | if type(H_BA) == NoneType: 41 | print("compute homography failed!") 42 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 43 | flag = False 44 | try: 45 | np.linalg.inv(H_BA) 46 | except Exception: 47 | print("compute homography failed!") 48 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3) 49 | flag = False 50 | return matches, H_BA, matchesMask, flag 51 | 52 | 53 | def get_homo_point(point, homography): 54 | cx, cy = point 55 | center = np.array((cx, cy, 1.0), dtype=np.float32) 56 | x, y, z = np.dot(homography, center) 57 | x, y = x / z, y / z 58 | point = np.array((x, y), dtype=np.float32) 59 | return point 60 | 61 | 62 | def get_homo_bbox_point(bbox, homography): 63 | x1, y1, x2, y2 = np.array(bbox).reshape(-1) 64 | points = np.array([[x1, y1], [x2, y1], [x1, y2], [x2, y2]], dtype=np.float32) 65 | points_homo = np.concatenate((points, np.ones((4, 1), dtype=np.float32)), axis=1) 66 | points_coord = np.dot(points_homo, homography.T) 67 | points_coord2d = points_coord[:, :2] / points_coord[:, None, 2] 68 | return points_coord2d 69 | 70 | 71 | def get_hand_center(annot, hand_threshold=0.1): 72 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold] 73 | hands_center= {} 74 | hands_score = {} 75 | for hand in hands: 76 | side = hand.side.name 77 | score = hand.score 78 | if side not in hands_center or score > hands_score[side]: 79 | hands_center[side] = hand.bbox.center 80 | hands_score[side] = score 81 | return hands_center 82 | 83 | 84 | def get_hand_point(hands_center, homography, side): 85 | point, homo_point = None, None 86 | if side in hands_center: 87 | point = hands_center[side] 88 | homo_point = get_homo_point(point, homography) 89 | return point, homo_point 90 | 91 | 92 | def traj_compute(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 93 | imgH, imgW = frames[0].shape[:2] 94 | left_traj, right_traj = [], [] 95 | left_centers , right_centers= [], [] 96 | homography_stack 
= [np.eye(3)] 97 | for idx in range(1, len(frames)): 98 | matches, H_BA, matchesMask, flag = get_pair_homography(frames[idx - 1], frames[idx], 99 | annots[idx - 1], annots[idx], 100 | hand_threshold=hand_threshold, 101 | obj_threshold=obj_threshold) 102 | if not flag: 103 | return None 104 | else: 105 | homography_stack.append(np.dot(homography_stack[-1], H_BA)) 106 | for idx in range(len(frames)): 107 | hands_center = get_hand_center(annots[idx], hand_threshold=hand_threshold) 108 | if "LEFT" in hand_sides: 109 | left_center, left_point = get_hand_point(hands_center, homography_stack[idx], "LEFT") 110 | left_centers.append(left_center) 111 | left_traj.append(left_point) 112 | if "RIGHT" in hand_sides: 113 | right_center, right_point = get_hand_point(hands_center, homography_stack[idx], "RIGHT") 114 | right_centers.append(right_center) 115 | right_traj.append(right_point) 116 | 117 | left_traj = valid_traj(left_traj, imgW=imgW, imgH=imgH) 118 | right_traj = valid_traj(right_traj, imgW=imgW, imgH=imgH) 119 | return left_traj, left_centers, right_traj, right_centers, homography_stack 120 | 121 | 122 | def traj_completion(traj, side, imgW=456, imgH=256): 123 | from scipy.interpolate import CubicHermiteSpline 124 | 125 | def get_valid_traj(traj, imgW, imgH): 126 | traj[traj < 0] = traj[traj >= 0].min() 127 | traj[:, 0][traj[:, 0] > 1.5 * imgW] = 1.5 * imgW 128 | traj[:, 1][traj[:, 1] > 1.5 * imgH] = 1.5 * imgH 129 | return traj 130 | 131 | def spline_interpolation(axis): 132 | fill_times = np.array(fill_indices, dtype=np.float32) 133 | fill_traj = np.array([traj[idx][axis] for idx in fill_indices], dtype=np.float32) 134 | dt = fill_times[2:] - fill_times[:-2] 135 | dt = np.hstack([fill_times[1] - fill_times[0], dt, fill_times[-1] - fill_times[-2]]) 136 | dx = fill_traj[2:] - fill_traj[:-2] 137 | dx = np.hstack([fill_traj[1] - fill_traj[0], dx, fill_traj[-1] - fill_traj[-2]]) 138 | dxdt = dx / dt 139 | curve = CubicHermiteSpline(fill_times, fill_traj, dxdt) 140 | 
full_traj = curve(np.arange(len(traj), dtype=np.float32)) 141 | return full_traj, curve 142 | 143 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None] 144 | if 0 not in fill_indices: 145 | if side == "LEFT": 146 | traj[0] = np.array((0.25*imgW, 1.5*imgH), dtype=np.float32) 147 | else: 148 | traj[0] = np.array((0.75*imgW, 1.5*imgH), dtype=np.float32) 149 | fill_indices = np.insert(fill_indices, 0, 0).tolist() 150 | fill_indices.sort() 151 | full_traj_x, curve_x = spline_interpolation(axis=0) 152 | full_traj_y, curve_y = spline_interpolation(axis=1) 153 | full_traj = np.stack([full_traj_x, full_traj_y], axis=1) 154 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH) 155 | curve = [curve_x, curve_y] 156 | return full_traj, fill_indices, curve 157 | 158 | 159 | def compute_hand_traj(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1): 160 | imgH, imgW = frames[0].shape[:2] 161 | results = traj_compute(frames, annots, hand_sides, 162 | hand_threshold=hand_threshold, obj_threshold=obj_threshold) 163 | if results is None: 164 | print("compute homography failed") 165 | return None 166 | else: 167 | left_traj, left_centers, right_traj, right_centers, homography_stack = results 168 | if len(left_traj) == 0 and len(right_traj) == 0: 169 | print("compute traj failed") 170 | return None 171 | hand_trajs = {} 172 | if len(left_traj) == 0: 173 | print("left traj filtered out") 174 | else: 175 | left_complete_traj, left_fill_indices, left_curve = traj_completion(left_traj, side="LEFT", 176 | imgW=imgW, imgH=imgH) 177 | hand_trajs["LEFT"] = {"traj": left_complete_traj, "fill_indices": left_fill_indices, 178 | "fit_curve": left_curve, "centers": left_centers} 179 | if len(right_traj) == 0: 180 | print("right traj filtered out") 181 | else: 182 | right_complete_traj, right_fill_indices, right_curve = traj_completion(right_traj, side="RIGHT", 183 | imgW=imgW, imgH=imgH) 184 | hand_trajs["RIGHT"] = {"traj": right_complete_traj, 
"fill_indices": right_fill_indices, 185 | "fit_curve": right_curve, "centers": right_centers} 186 | return homography_stack, hand_trajs 187 | -------------------------------------------------------------------------------- /preprocess/types_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: types.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | DESCRIPTOR = _descriptor.FileDescriptor( 15 | name='types.proto', 16 | package='model.detections', 17 | syntax='proto3', 18 | serialized_options=None, 19 | create_key=_descriptor._internal_create_key, 20 | serialized_pb=b'\n\x0btypes.proto\x12\x10model.detections\"#\n\x0b\x46loatVector\x12\t\n\x01x\x18\x01 \x01(\x02\x12\t\n\x01y\x18\x02 \x01(\x02\"@\n\x04\x42\x42ox\x12\x0c\n\x04left\x18\x01 \x01(\x02\x12\x0b\n\x03top\x18\x02 \x01(\x02\x12\r\n\x05right\x18\x03 \x01(\x02\x12\x0e\n\x06\x62ottom\x18\x04 \x01(\x02\"\xfc\x02\n\rHandDetection\x12$\n\x04\x62\x62ox\x18\x01 \x01(\x0b\x32\x16.model.detections.BBox\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x38\n\x05state\x18\x03 \x01(\x0e\x32).model.detections.HandDetection.HandState\x12\x34\n\robject_offset\x18\x04 \x01(\x0b\x32\x1d.model.detections.FloatVector\x12\x36\n\x04side\x18\x05 \x01(\x0e\x32(.model.detections.HandDetection.HandSide\"m\n\tHandState\x12\x0e\n\nNO_CONTACT\x10\x00\x12\x10\n\x0cSELF_CONTACT\x10\x01\x12\x12\n\x0e\x41NOTHER_PERSON\x10\x02\x12\x13\n\x0fPORTABLE_OBJECT\x10\x03\x12\x15\n\x11STATIONARY_OBJECT\x10\x04\"\x1f\n\x08HandSide\x12\x08\n\x04LEFT\x10\x00\x12\t\n\x05RIGHT\x10\x01\"F\n\x0fObjectDetection\x12$\n\x04\x62\x62ox\x18\x01 
\x01(\x0b\x32\x16.model.detections.BBox\x12\r\n\x05score\x18\x02 \x01(\x02\"\x98\x01\n\nDetections\x12\x10\n\x08video_id\x18\x01 \x01(\t\x12\x14\n\x0c\x66rame_number\x18\x02 \x01(\x05\x12.\n\x05hands\x18\x03 \x03(\x0b\x32\x1f.model.detections.HandDetection\x12\x32\n\x07objects\x18\x04 \x03(\x0b\x32!.model.detections.ObjectDetectionb\x06proto3' 21 | ) 22 | 23 | 24 | 25 | _HANDDETECTION_HANDSTATE = _descriptor.EnumDescriptor( 26 | name='HandState', 27 | full_name='model.detections.HandDetection.HandState', 28 | filename=None, 29 | file=DESCRIPTOR, 30 | create_key=_descriptor._internal_create_key, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='NO_CONTACT', index=0, number=0, 34 | serialized_options=None, 35 | type=None, 36 | create_key=_descriptor._internal_create_key), 37 | _descriptor.EnumValueDescriptor( 38 | name='SELF_CONTACT', index=1, number=1, 39 | serialized_options=None, 40 | type=None, 41 | create_key=_descriptor._internal_create_key), 42 | _descriptor.EnumValueDescriptor( 43 | name='ANOTHER_PERSON', index=2, number=2, 44 | serialized_options=None, 45 | type=None, 46 | create_key=_descriptor._internal_create_key), 47 | _descriptor.EnumValueDescriptor( 48 | name='PORTABLE_OBJECT', index=3, number=3, 49 | serialized_options=None, 50 | type=None, 51 | create_key=_descriptor._internal_create_key), 52 | _descriptor.EnumValueDescriptor( 53 | name='STATIONARY_OBJECT', index=4, number=4, 54 | serialized_options=None, 55 | type=None, 56 | create_key=_descriptor._internal_create_key), 57 | ], 58 | containing_type=None, 59 | serialized_options=None, 60 | serialized_start=375, 61 | serialized_end=484, 62 | ) 63 | _sym_db.RegisterEnumDescriptor(_HANDDETECTION_HANDSTATE) 64 | 65 | _HANDDETECTION_HANDSIDE = _descriptor.EnumDescriptor( 66 | name='HandSide', 67 | full_name='model.detections.HandDetection.HandSide', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | create_key=_descriptor._internal_create_key, 71 | values=[ 72 | _descriptor.EnumValueDescriptor( 
73 | name='LEFT', index=0, number=0, 74 | serialized_options=None, 75 | type=None, 76 | create_key=_descriptor._internal_create_key), 77 | _descriptor.EnumValueDescriptor( 78 | name='RIGHT', index=1, number=1, 79 | serialized_options=None, 80 | type=None, 81 | create_key=_descriptor._internal_create_key), 82 | ], 83 | containing_type=None, 84 | serialized_options=None, 85 | serialized_start=486, 86 | serialized_end=517, 87 | ) 88 | _sym_db.RegisterEnumDescriptor(_HANDDETECTION_HANDSIDE) 89 | 90 | 91 | _FLOATVECTOR = _descriptor.Descriptor( 92 | name='FloatVector', 93 | full_name='model.detections.FloatVector', 94 | filename=None, 95 | file=DESCRIPTOR, 96 | containing_type=None, 97 | create_key=_descriptor._internal_create_key, 98 | fields=[ 99 | _descriptor.FieldDescriptor( 100 | name='x', full_name='model.detections.FloatVector.x', index=0, 101 | number=1, type=2, cpp_type=6, label=1, 102 | has_default_value=False, default_value=float(0), 103 | message_type=None, enum_type=None, containing_type=None, 104 | is_extension=False, extension_scope=None, 105 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 106 | _descriptor.FieldDescriptor( 107 | name='y', full_name='model.detections.FloatVector.y', index=1, 108 | number=2, type=2, cpp_type=6, label=1, 109 | has_default_value=False, default_value=float(0), 110 | message_type=None, enum_type=None, containing_type=None, 111 | is_extension=False, extension_scope=None, 112 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 113 | ], 114 | extensions=[ 115 | ], 116 | nested_types=[], 117 | enum_types=[ 118 | ], 119 | serialized_options=None, 120 | is_extendable=False, 121 | syntax='proto3', 122 | extension_ranges=[], 123 | oneofs=[ 124 | ], 125 | serialized_start=33, 126 | serialized_end=68, 127 | ) 128 | 129 | 130 | _BBOX = _descriptor.Descriptor( 131 | name='BBox', 132 | full_name='model.detections.BBox', 133 | filename=None, 134 | 
file=DESCRIPTOR, 135 | containing_type=None, 136 | create_key=_descriptor._internal_create_key, 137 | fields=[ 138 | _descriptor.FieldDescriptor( 139 | name='left', full_name='model.detections.BBox.left', index=0, 140 | number=1, type=2, cpp_type=6, label=1, 141 | has_default_value=False, default_value=float(0), 142 | message_type=None, enum_type=None, containing_type=None, 143 | is_extension=False, extension_scope=None, 144 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 145 | _descriptor.FieldDescriptor( 146 | name='top', full_name='model.detections.BBox.top', index=1, 147 | number=2, type=2, cpp_type=6, label=1, 148 | has_default_value=False, default_value=float(0), 149 | message_type=None, enum_type=None, containing_type=None, 150 | is_extension=False, extension_scope=None, 151 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 152 | _descriptor.FieldDescriptor( 153 | name='right', full_name='model.detections.BBox.right', index=2, 154 | number=3, type=2, cpp_type=6, label=1, 155 | has_default_value=False, default_value=float(0), 156 | message_type=None, enum_type=None, containing_type=None, 157 | is_extension=False, extension_scope=None, 158 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 159 | _descriptor.FieldDescriptor( 160 | name='bottom', full_name='model.detections.BBox.bottom', index=3, 161 | number=4, type=2, cpp_type=6, label=1, 162 | has_default_value=False, default_value=float(0), 163 | message_type=None, enum_type=None, containing_type=None, 164 | is_extension=False, extension_scope=None, 165 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 166 | ], 167 | extensions=[ 168 | ], 169 | nested_types=[], 170 | enum_types=[ 171 | ], 172 | serialized_options=None, 173 | is_extendable=False, 174 | syntax='proto3', 175 | extension_ranges=[], 176 | oneofs=[ 177 | ], 178 | serialized_start=70, 179 | 
serialized_end=134, 180 | ) 181 | 182 | 183 | _HANDDETECTION = _descriptor.Descriptor( 184 | name='HandDetection', 185 | full_name='model.detections.HandDetection', 186 | filename=None, 187 | file=DESCRIPTOR, 188 | containing_type=None, 189 | create_key=_descriptor._internal_create_key, 190 | fields=[ 191 | _descriptor.FieldDescriptor( 192 | name='bbox', full_name='model.detections.HandDetection.bbox', index=0, 193 | number=1, type=11, cpp_type=10, label=1, 194 | has_default_value=False, default_value=None, 195 | message_type=None, enum_type=None, containing_type=None, 196 | is_extension=False, extension_scope=None, 197 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 198 | _descriptor.FieldDescriptor( 199 | name='score', full_name='model.detections.HandDetection.score', index=1, 200 | number=2, type=2, cpp_type=6, label=1, 201 | has_default_value=False, default_value=float(0), 202 | message_type=None, enum_type=None, containing_type=None, 203 | is_extension=False, extension_scope=None, 204 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 205 | _descriptor.FieldDescriptor( 206 | name='state', full_name='model.detections.HandDetection.state', index=2, 207 | number=3, type=14, cpp_type=8, label=1, 208 | has_default_value=False, default_value=0, 209 | message_type=None, enum_type=None, containing_type=None, 210 | is_extension=False, extension_scope=None, 211 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 212 | _descriptor.FieldDescriptor( 213 | name='object_offset', full_name='model.detections.HandDetection.object_offset', index=3, 214 | number=4, type=11, cpp_type=10, label=1, 215 | has_default_value=False, default_value=None, 216 | message_type=None, enum_type=None, containing_type=None, 217 | is_extension=False, extension_scope=None, 218 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 219 | 
_descriptor.FieldDescriptor( 220 | name='side', full_name='model.detections.HandDetection.side', index=4, 221 | number=5, type=14, cpp_type=8, label=1, 222 | has_default_value=False, default_value=0, 223 | message_type=None, enum_type=None, containing_type=None, 224 | is_extension=False, extension_scope=None, 225 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 226 | ], 227 | extensions=[ 228 | ], 229 | nested_types=[], 230 | enum_types=[ 231 | _HANDDETECTION_HANDSTATE, 232 | _HANDDETECTION_HANDSIDE, 233 | ], 234 | serialized_options=None, 235 | is_extendable=False, 236 | syntax='proto3', 237 | extension_ranges=[], 238 | oneofs=[ 239 | ], 240 | serialized_start=137, 241 | serialized_end=517, 242 | ) 243 | 244 | 245 | _OBJECTDETECTION = _descriptor.Descriptor( 246 | name='ObjectDetection', 247 | full_name='model.detections.ObjectDetection', 248 | filename=None, 249 | file=DESCRIPTOR, 250 | containing_type=None, 251 | create_key=_descriptor._internal_create_key, 252 | fields=[ 253 | _descriptor.FieldDescriptor( 254 | name='bbox', full_name='model.detections.ObjectDetection.bbox', index=0, 255 | number=1, type=11, cpp_type=10, label=1, 256 | has_default_value=False, default_value=None, 257 | message_type=None, enum_type=None, containing_type=None, 258 | is_extension=False, extension_scope=None, 259 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 260 | _descriptor.FieldDescriptor( 261 | name='score', full_name='model.detections.ObjectDetection.score', index=1, 262 | number=2, type=2, cpp_type=6, label=1, 263 | has_default_value=False, default_value=float(0), 264 | message_type=None, enum_type=None, containing_type=None, 265 | is_extension=False, extension_scope=None, 266 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 267 | ], 268 | extensions=[ 269 | ], 270 | nested_types=[], 271 | enum_types=[ 272 | ], 273 | serialized_options=None, 274 | 
is_extendable=False, 275 | syntax='proto3', 276 | extension_ranges=[], 277 | oneofs=[ 278 | ], 279 | serialized_start=519, 280 | serialized_end=589, 281 | ) 282 | 283 | 284 | _DETECTIONS = _descriptor.Descriptor( 285 | name='Detections', 286 | full_name='model.detections.Detections', 287 | filename=None, 288 | file=DESCRIPTOR, 289 | containing_type=None, 290 | create_key=_descriptor._internal_create_key, 291 | fields=[ 292 | _descriptor.FieldDescriptor( 293 | name='video_id', full_name='model.detections.Detections.video_id', index=0, 294 | number=1, type=9, cpp_type=9, label=1, 295 | has_default_value=False, default_value=b"".decode('utf-8'), 296 | message_type=None, enum_type=None, containing_type=None, 297 | is_extension=False, extension_scope=None, 298 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 299 | _descriptor.FieldDescriptor( 300 | name='frame_number', full_name='model.detections.Detections.frame_number', index=1, 301 | number=2, type=5, cpp_type=1, label=1, 302 | has_default_value=False, default_value=0, 303 | message_type=None, enum_type=None, containing_type=None, 304 | is_extension=False, extension_scope=None, 305 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 306 | _descriptor.FieldDescriptor( 307 | name='hands', full_name='model.detections.Detections.hands', index=2, 308 | number=3, type=11, cpp_type=10, label=3, 309 | has_default_value=False, default_value=[], 310 | message_type=None, enum_type=None, containing_type=None, 311 | is_extension=False, extension_scope=None, 312 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 313 | _descriptor.FieldDescriptor( 314 | name='objects', full_name='model.detections.Detections.objects', index=3, 315 | number=4, type=11, cpp_type=10, label=3, 316 | has_default_value=False, default_value=[], 317 | message_type=None, enum_type=None, containing_type=None, 318 | is_extension=False, 
extension_scope=None, 319 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 320 | ], 321 | extensions=[ 322 | ], 323 | nested_types=[], 324 | enum_types=[ 325 | ], 326 | serialized_options=None, 327 | is_extendable=False, 328 | syntax='proto3', 329 | extension_ranges=[], 330 | oneofs=[ 331 | ], 332 | serialized_start=592, 333 | serialized_end=744, 334 | ) 335 | 336 | _HANDDETECTION.fields_by_name['bbox'].message_type = _BBOX 337 | _HANDDETECTION.fields_by_name['state'].enum_type = _HANDDETECTION_HANDSTATE 338 | _HANDDETECTION.fields_by_name['object_offset'].message_type = _FLOATVECTOR 339 | _HANDDETECTION.fields_by_name['side'].enum_type = _HANDDETECTION_HANDSIDE 340 | _HANDDETECTION_HANDSTATE.containing_type = _HANDDETECTION 341 | _HANDDETECTION_HANDSIDE.containing_type = _HANDDETECTION 342 | _OBJECTDETECTION.fields_by_name['bbox'].message_type = _BBOX 343 | _DETECTIONS.fields_by_name['hands'].message_type = _HANDDETECTION 344 | _DETECTIONS.fields_by_name['objects'].message_type = _OBJECTDETECTION 345 | DESCRIPTOR.message_types_by_name['FloatVector'] = _FLOATVECTOR 346 | DESCRIPTOR.message_types_by_name['BBox'] = _BBOX 347 | DESCRIPTOR.message_types_by_name['HandDetection'] = _HANDDETECTION 348 | DESCRIPTOR.message_types_by_name['ObjectDetection'] = _OBJECTDETECTION 349 | DESCRIPTOR.message_types_by_name['Detections'] = _DETECTIONS 350 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 351 | 352 | FloatVector = _reflection.GeneratedProtocolMessageType('FloatVector', (_message.Message,), { 353 | 'DESCRIPTOR' : _FLOATVECTOR, 354 | '__module__' : 'types_pb2' 355 | # @@protoc_insertion_point(class_scope:model.detections.FloatVector) 356 | }) 357 | _sym_db.RegisterMessage(FloatVector) 358 | 359 | BBox = _reflection.GeneratedProtocolMessageType('BBox', (_message.Message,), { 360 | 'DESCRIPTOR' : _BBOX, 361 | '__module__' : 'types_pb2' 362 | # @@protoc_insertion_point(class_scope:model.detections.BBox) 363 | }) 364 | 
_sym_db.RegisterMessage(BBox) 365 | 366 | HandDetection = _reflection.GeneratedProtocolMessageType('HandDetection', (_message.Message,), { 367 | 'DESCRIPTOR' : _HANDDETECTION, 368 | '__module__' : 'types_pb2' 369 | # @@protoc_insertion_point(class_scope:model.detections.HandDetection) 370 | }) 371 | _sym_db.RegisterMessage(HandDetection) 372 | 373 | ObjectDetection = _reflection.GeneratedProtocolMessageType('ObjectDetection', (_message.Message,), { 374 | 'DESCRIPTOR' : _OBJECTDETECTION, 375 | '__module__' : 'types_pb2' 376 | # @@protoc_insertion_point(class_scope:model.detections.ObjectDetection) 377 | }) 378 | _sym_db.RegisterMessage(ObjectDetection) 379 | 380 | Detections = _reflection.GeneratedProtocolMessageType('Detections', (_message.Message,), { 381 | 'DESCRIPTOR' : _DETECTIONS, 382 | '__module__' : 'types_pb2' 383 | # @@protoc_insertion_point(class_scope:model.detections.Detections) 384 | }) 385 | _sym_db.RegisterMessage(Detections) 386 | 387 | 388 | # @@protoc_insertion_point(module_scope) 389 | -------------------------------------------------------------------------------- /preprocess/vis_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | from preprocess.affordance_util import compute_heatmap 5 | 6 | hand_rgb = {"LEFT": (0, 90, 181), "RIGHT": (220, 50, 32)} 7 | object_rgb = (255, 194, 10) 8 | 9 | 10 | def vis_traj(frame_vis, traj, fill_indices=None, side=None, circle_radis=4, circle_thickness=3, line_thickness=2, style='line', gap=5): 11 | for idx in range(len(traj)): 12 | x, y = traj[idx] 13 | if fill_indices is not None and idx in fill_indices: 14 | thickness = -1 15 | else: 16 | thickness = -1 17 | color = hand_rgb[side][::-1] if side is not None else (0, 255, 255) 18 | frame_vis = cv2.circle(frame_vis, (int(round(x)), int(round(y))), radius=circle_radis, color=color, 19 | thickness=thickness) 20 | if idx > 0: 21 | pt1 = 
(int(round(traj[idx-1][0])), int(round(traj[idx-1][1]))) 22 | pt2 = (int(round(traj[idx][0])), int(round(traj[idx][1]))) 23 | dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5 24 | pts = [] 25 | for i in np.arange(0, dist, gap): 26 | r = i / dist 27 | x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5) 28 | y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5) 29 | p = (x, y) 30 | pts.append(p) 31 | if style == 'dotted': 32 | for p in pts: 33 | cv2.circle(frame_vis, p, circle_thickness, color, -1) 34 | else: 35 | if len(pts) > 0: 36 | s = pts[0] 37 | e = pts[0] 38 | i = 0 39 | for p in pts: 40 | s = e 41 | e = p 42 | if i % 2 == 1: 43 | cv2.line(frame_vis, s, e, color, line_thickness) 44 | i += 1 45 | return frame_vis 46 | 47 | 48 | def vis_hand_traj(frames, hand_trajs): 49 | frame_vis = frames[0].copy() 50 | for side in hand_trajs: 51 | meta = hand_trajs[side] 52 | traj, fill_indices = meta["traj"], meta["fill_indices"] 53 | frame_vis = vis_traj(frame_vis, traj, fill_indices, side) 54 | return frame_vis 55 | 56 | 57 | def vis_affordance(frame, affordance_info): 58 | select_points = affordance_info["select_points_homo"] 59 | hmap = compute_heatmap(select_points, (frame.shape[1], frame.shape[0])) 60 | hmap = (hmap * 255).astype(np.uint8) 61 | hmap = cv2.applyColorMap(hmap, colormap=cv2.COLORMAP_JET) 62 | for idx in range((len(select_points))): 63 | point = select_points[idx].astype(np.int) 64 | frame_vis = cv2.circle(frame, (point[0], point[1]), radius=2, color=(255, 0, 255), 65 | thickness=-1) 66 | overlay = (0.7 * frame + 0.3 * hmap).astype(np.uint8) 67 | return overlay 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | torch 3 | lmdbdict 4 | tqdm 5 | scikit-learn 6 | einops -------------------------------------------------------------------------------- /traineval.py: 
import argparse
import os
import random

import numpy as np
import torch
import torch.nn.parallel
import torch.optim

from netscripts.get_datasets import get_dataset
from netscripts.get_network import get_network
from netscripts.get_optimizer import get_optimizer
from netscripts import modelio
from netscripts.epoch_feat import epoch_pass
from options import netsopts, expopts
from datasets.datasetopts import DatasetArgs


def main(args):
    """Train and/or evaluate the HOI forecasting network configured by ``args``.

    Steps:
      1. Seed the torch (CPU and CUDA), numpy and ``random`` RNGs with
         ``args.manual_seed``.
      2. Build the network with ``get_network``; the input/output frame counts
         are ``fps * t_buffer`` and ``fps * t_ant`` from ``DatasetArgs``.
      3. Optionally move it to CUDA and resume from ``args.resume[0]``.
      4. For every epoch in ``[start_epoch, args.epochs)``:
         - train (unless ``args.evaluate``);
         - run affordance evaluation, or trajectory evaluation when
           ``args.traj_only``, on evaluate runs or every ``args.test_freq``
           epochs;
         - save a checkpoint every ``args.snapshot`` epochs (offset by
           ``args.warmup_epochs``) when training.

    In evaluate mode exactly one epoch is run (``args.epochs`` is overwritten
    with ``start_epoch + 1``).

    Args:
        args: ``argparse.Namespace`` produced by the option parsers in
            ``options.netsopts`` / ``options.expopts``.
    """
    # Initialize random seeds
    torch.cuda.manual_seed_all(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    datasetargs = DatasetArgs(ek_version=args.ek_version)
    num_frames_input = int(datasetargs.fps * datasetargs.t_buffer)
    num_frames_output = int(datasetargs.fps * datasetargs.t_ant)
    model = get_network(args, num_frames_input=num_frames_input,
                        num_frames_output=num_frames_output)

    # Decide once whether the GPU is used; the same condition selects the
    # device the checkpoint is loaded onto.
    use_gpu = args.use_cuda and torch.cuda.is_available()
    if use_gpu:
        print("Using {} GPUs !".format(torch.cuda.device_count()))
        model.cuda()

    start_epoch = 0
    if args.resume is not None:
        device = torch.device('cuda') if use_gpu else torch.device('cpu')
        start_epoch = modelio.load_checkpoint(model, resume_path=args.resume[0],
                                              strict=False, device=device)
        print("Loaded checkpoint from epoch {}, starting from there".format(start_epoch))

    _, dls = get_dataset(args, base_path="./")

    if args.evaluate:
        # Evaluation only: a single pass starting at the resumed epoch.
        args.epochs = start_epoch + 1
        traj_val_loader = None
    else:
        train_loader = dls['train']
        traj_val_loader = dls['validation']
        print("training dataset size: {}".format(len(train_loader.dataset)))
        optimizer, scheduler = get_optimizer(args, model=model, train_loader=train_loader)

    if not args.traj_only:
        val_loader = dls['eval']
    else:
        # Trajectory-only evaluation runs on the validation split.
        traj_val_loader = val_loader = dls['validation']
    print("evaluation dataset size: {}".format(len(val_loader.dataset)))

    for epoch in range(start_epoch, args.epochs):
        if not args.evaluate:
            print("Using lr {}".format(optimizer.param_groups[0]["lr"]))
            epoch_pass(
                loader=train_loader,
                model=model,
                phase='train',
                optimizer=optimizer,
                epoch=epoch,
                train=True,
                use_cuda=args.use_cuda,
                scheduler=scheduler)

        if args.evaluate or (epoch + 1) % args.test_freq == 0:
            # One no_grad context covers both evaluation branches.
            with torch.no_grad():
                if not args.traj_only:
                    epoch_pass(
                        loader=val_loader,
                        model=model,
                        epoch=epoch,
                        phase='affordance',
                        optimizer=None,
                        train=False,
                        use_cuda=args.use_cuda,
                        num_samples=args.num_samples,
                        num_points=args.num_points)
                else:
                    epoch_pass(
                        loader=traj_val_loader,
                        model=model,
                        epoch=epoch,
                        phase='traj',
                        optimizer=None,
                        train=False,
                        use_cuda=args.use_cuda,
                        num_samples=args.num_samples)

        # NOTE(review): before warmup ends, (epoch + 1 - warmup_epochs) is
        # negative, and Python's modulo can still yield 0 there, so snapshots
        # may be written during warmup. Confirm this is intended.
        if not args.evaluate and (epoch + 1 - args.warmup_epochs) % args.snapshot == 0:
            checkpoint_dir = os.path.join(args.host_folder, args.exp_id)
            print(f"save epoch {epoch+1} checkpoint to {checkpoint_dir}")
            modelio.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "network": args.network,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                checkpoint=checkpoint_dir,
                filename=f"checkpoint_{epoch+1}.pth.tar")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HOI Forecasting")
    netsopts.add_nets_opts(parser)
    netsopts.add_train_opts(parser)
    expopts.add_exp_opts(parser)
    args = parser.parse_args()

    # Validate option combinations up front. A plain ``assert`` would be
    # stripped under ``python -O`` and let an invalid configuration through.
    if args.traj_only and not args.evaluate:
        parser.error("evaluate trajectory on validation set must set --evaluate")

    if args.use_cuda and torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        # Scale batch size and learning rate linearly with the GPU count.
        # NOTE(review): main() only calls model.cuda(); confirm get_network
        # wraps the model for multi-GPU use, otherwise this inflates the
        # per-device batch.
        args.batch_size = args.batch_size * num_gpus
        args.lr = args.lr * num_gpus

    main(args)
    print("All done !")