├── LICENSE
├── README.md
├── assets
│   ├── EPIC-KITCHENS
│   │   └── P01
│   │       └── rgb_frames
│   │           └── P01_01
│   │               ├── frame_0000016827.jpg
│   │               ├── frame_0000016833.jpg
│   │               ├── frame_0000016839.jpg
│   │               ├── frame_0000016845.jpg
│   │               ├── frame_0000016851.jpg
│   │               ├── frame_0000016857.jpg
│   │               ├── frame_0000016863.jpg
│   │               ├── frame_0000016869.jpg
│   │               ├── frame_0000016875.jpg
│   │               ├── frame_0000016881.jpg
│   │               └── frame_0000016887.jpg
│   ├── demo_gen.jpg
│   └── teaser.gif
├── datasets
│   ├── dataloaders.py
│   ├── dataset_utils.py
│   ├── datasetopts.py
│   ├── ho_utils.py
│   ├── holoaders.py
│   └── input_loaders.py
├── demo_gen.py
├── environment.yaml
├── evaluation
│   ├── affordance_eval.py
│   └── traj_eval.py
├── netscripts
│   ├── epoch_feat.py
│   ├── epoch_utils.py
│   ├── get_datasets.py
│   ├── get_network.py
│   ├── get_optimizer.py
│   └── modelio.py
├── networks
│   ├── affordance_decoder.py
│   ├── decoder_modules.py
│   ├── embedding.py
│   ├── layer.py
│   ├── model.py
│   ├── net_utils.py
│   ├── traj_decoder.py
│   └── transformer.py
├── options
│   ├── expopts.py
│   └── netsopts.py
├── preprocess
│   ├── affordance_util.py
│   ├── dataset_util.py
│   ├── ho_types.py
│   ├── obj_util.py
│   ├── traj_util.py
│   ├── types_pb2.py
│   └── vis_util.py
├── requirements.txt
└── traineval.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 stevenlsw
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HOI-Forecast
2 |
3 | **Joint Hand Motion and Interaction Hotspots Prediction from Egocentric Videos (CVPR 2022)**
4 |
5 |
6 |
7 | #### [[Project Page]](https://stevenlsw.github.io/hoi-forecast/) [[Paper]](https://arxiv.org/abs/2204.01696) [[Training Data]](https://drive.google.com/drive/folders/1llDYFwn2gGQLpcWy6YScp3ej7A3LIPFc)
8 |
9 | Given observation frames of the past, we predict future hand trajectories (green and red lines) and object interaction hotspots (heatmaps) in the egocentric view. We generate training data **automatically** and use this data to train an Object-Centric Transformer (OCT) model for prediction.
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Installation
17 | - Clone this repository:
18 | ```Shell
19 | git clone https://github.com/stevenlsw/hoi-forecast
20 | cd hoi-forecast
21 | ```
22 | - Python 3.6 Environment:
23 | ```Shell
24 | conda env create -f environment.yaml
25 | conda activate fhoi
26 | ```
27 |
28 | ## Quick training data generation
29 | The official EPIC-Kitchens dataset follows the same layout as `assets/EPIC-KITCHENS`; the RGB frames needed for the demo are already included in `assets/EPIC-KITCHENS/P01/rgb_frames/P01_01`.
30 |
31 | - Download the EPIC-Kitchens 55 dataset [annotations](https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-55-annotations/master/EPIC_train_action_labels.csv) and save them in the `assets` folder
32 |
33 | - Download hand-object detections below
34 | ```Shell
35 | link=https://data.bris.ac.uk/datasets/3l8eci2oqgst92n14w2yqi5ytu/hand-objects/P01/P01_01.pkl
36 | wget -P assets/EPIC-KITCHENS/P01/hand-objects $link
37 | ```
38 | - Run `python demo_gen.py`; the results (image and pkl) are stored in `figs`, and you can visualize the generated labels (a short loading sketch follows this list)
39 |
40 |
41 |
42 |
43 |
44 | - For more generated training labels, please visit [google drive](https://drive.google.com/drive/folders/1lNOSXXKbiqYJqC1hp1CIaIeM8u6UVvcg) and run `python example.py`.
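To sanity-check the generated labels, here is a minimal loading sketch; it assumes `demo_gen.py` has written a pickle into `figs` (the exact filename and dictionary layout come from `preprocess.dataset_util.save_video_info`, so adjust as needed):

```python
import glob
import pickle

# Load whichever .pkl file demo_gen.py wrote into `figs`
pkl_path = sorted(glob.glob("figs/*.pkl"))[0]
with open(pkl_path, "rb") as f:
    video_info = pickle.load(f)

# Inspect the stored entries (e.g. frame indices, hand trajectories, affordance info)
print(pkl_path)
print(list(video_info.keys()) if isinstance(video_info, dict) else type(video_info))
```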
45 |
46 |
47 | ## Evaluation on EK100
48 | We manually collect hand trajectories and interaction hotspots for evaluation and pre-extract the input video features.
49 |
50 | - Download the processed [files](https://drive.google.com/file/d/1s7qpBa-JjjuGk7v_aiuU2lvRjgP6fi9C/view?usp=sharing) (including collected labels, pre-extracted features, and dataset partitions; 600 MB) and **unzip** them. You will get a structure like:
51 | ```
52 | hoi-forecast
53 | |-- data
54 | | |-- ek100
55 | | | |-- ek100_eval_labels.pkl
56 | | | |-- video_info.json
57 | | | |-- labels
58 | | | | |-- label_303.pkl
59 | | | | |-- ...
60 | | | |-- feats
61 | | | | |-- data.lmdb (RGB)
62 | |-- common
63 | | |-- epic-kitchens-55-annotations
64 | | |-- epic-kitchens-100-annotations
65 | | |-- rulstm
66 | ```
67 |
68 | - Download the [pretrained models](https://drive.google.com/file/d/16IkQ4hOQk2_Klhd806J-46hLN-OGokxa/view?usp=sharing) trained on EK100; the stored model path is referred to as `$resume`.
69 |
70 | - Install PyTorch and the remaining dependencies:
71 | ```Shell
72 | pip install -r requirements.txt
73 | ```
74 |
75 | - Evaluate future hand trajectory
76 | ```Shell
77 | python traineval.py --evaluate --ek_version=ek100 --resume={path to the model} --traj_only
78 | ```
79 |
80 | - Evaluate future interaction hotspots
81 | ```Shell
82 | python traineval.py --evaluate --ek_version=ek100 --resume={path to the model}
83 | ```
84 |
85 | - Results should look like:
86 |
87 | <table>
88 |   <tr>
89 |     <th colspan="2">Hand Trajectory</th>
90 |     <th colspan="3">Interaction Hotspots</th>
91 |   </tr>
92 |   <tr>
93 |     <td>ADE ↓</td>
94 |     <td>FDE ↓</td>
95 |     <td>SIM ↑</td>
96 |     <td>AUC-J ↑</td>
97 |     <td>NSS ↑</td>
98 |   </tr>
99 |   <tr>
100 |     <td>0.12</td>
101 |     <td>0.11</td>
102 |     <td>0.19</td>
103 |     <td>0.69</td>
104 |     <td>0.72</td>
105 |   </tr>
106 | </table>
107 |
108 |
109 |
110 |
111 |
112 |
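For reference, a minimal sketch of the standard ADE/FDE definitions behind the hand-trajectory numbers above; the repository's own implementation lives in `evaluation/traj_eval.py` and may differ in details such as masking invalid hands:

```python
import numpy as np

def ade_fde(pred, gt):
    """pred, gt: (T, 2) arrays of future hand positions in normalized image coordinates."""
    dist = np.linalg.norm(pred - gt, axis=-1)  # per-step L2 displacement
    return dist.mean(), dist[-1]               # ADE = average error, FDE = final-step error
```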
113 | ## Training
114 | - Extract per-frame features of the training set, similar to [RULSTM](https://github.com/fpv-iplab/rulstm), and store them in `data/ek100/feats/ek100.lmdb`; each key-value pair looks like the following (a write-side sketch follows this list):
115 | ```python
116 | fname = 'P01/rgb_frames/P01_01/frame_0000000720.jpg'
117 | env[fname.encode()] = result_dict # extracted feature results
118 | ```
119 |
120 | - Start training
121 | ```
122 | python traineval.py --ek_version=ek100
123 | ```
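
A write-side sketch for the feature lmdb referenced above, assuming the dict-style API of `lmdbdict` (the package `datasets/holoaders.py` uses for reading); the per-entity keys such as `GLOBAL_FEAT` and `HAND_RIGHT_FEAT` mirror what `FeaturesHOLoader` expects, and the 1024-dim zero features are placeholders only:

```python
import numpy as np
from lmdbdict import lmdbdict

env = lmdbdict('data/ek100/feats/ek100.lmdb', 'w')  # open the lmdb for writing
fname = 'P01/rgb_frames/P01_01/frame_0000000720.jpg'
result_dict = {
    'GLOBAL_FEAT': np.zeros(1024, dtype=np.float32),      # placeholder global frame feature
    'HAND_RIGHT_FEAT': np.zeros(1024, dtype=np.float32),  # optional per-entity features
}
env[fname.encode()] = result_dict
env.flush()  # commit pending writes to disk
```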
124 |
125 | ## Citation
126 | ```latex
127 | @inproceedings{liu2022joint,
128 | title={Joint Hand Motion and Interaction Hotspots Prediction from Egocentric Videos},
129 | author={Liu, Shaowei and Tripathi, Subarna and Majumdar, Somdeb and Wang, Xiaolong},
130 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
131 | year={2022}
132 | }
133 | ```
134 |
135 | ## Acknowledgements
136 | We thank:
137 | * [epic-kitchens](https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes) for hand-object detections on the EPIC-Kitchens dataset
138 | * [rulstm](https://github.com/fpv-iplab/rulstm) for feature extraction and action anticipation
139 | * [epic-kitchens-dataset-pytorch](https://github.com/guglielmocamporese/epic-kitchens-dataset-pytorch) for the
140 | EPIC-Kitchens dataloader
141 | * [Miao Liu](https://aptx4869lm.github.io/) for help with prior work
142 |
143 |
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016827.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016827.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016833.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016833.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016839.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016839.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016845.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016845.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016851.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016851.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016857.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016857.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016863.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016863.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016869.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016869.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016875.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016875.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016881.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016881.jpg
--------------------------------------------------------------------------------
/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016887.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/EPIC-KITCHENS/P01/rgb_frames/P01_01/frame_0000016887.jpg
--------------------------------------------------------------------------------
/assets/demo_gen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/demo_gen.jpg
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/hoi-forecast/fbc85a17cda21c29974994abcdb055e943f98dcc/assets/teaser.gif
--------------------------------------------------------------------------------
/datasets/dataloaders.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from tqdm import tqdm
3 | import json
4 | import numpy as np
5 |
6 | from datasets.dataset_utils import get_ek55_annotation, get_ek100_annotation
7 | from datasets.input_loaders import get_loaders
8 |
9 |
10 | class EpicAction(object):
11 | def __init__(self, uid, participant_id, video_id, verb, verb_class,
12 | noun, noun_class, all_nouns, all_noun_classes, start_frame,
13 | stop_frame, start_time, stop_time, ori_fps, partition, action, action_class):
14 | self.uid = uid
15 | self.participant_id = participant_id
16 | self.video_id = video_id
17 | self.verb = verb
18 | self.verb_class = verb_class
19 | self.noun = noun
20 | self.noun_class = noun_class
21 | self.all_nouns = all_nouns
22 | self.all_noun_classes = all_noun_classes
23 | self.start_frame = start_frame
24 | self.stop_frame = stop_frame
25 | self.start_time = start_time
26 | self.stop_time = stop_time
27 | self.ori_fps = ori_fps
28 | self.partition = partition
29 | self.action = action
30 | self.action_class = action_class
31 |
32 | self.duration = self.stop_time - self.start_time
33 |
34 | def __repr__(self):
35 | return json.dumps(self.__dict__, indent=4)
36 |
37 | def set_previous_actions(self, actions):
38 | self.actions_prev = actions
39 |
40 |
41 | class EpicVideo(object):
42 | def __init__(self, df_video, ori_fps, partition, t_ant=None):
43 | self.df = df_video
44 | self.ori_fps = ori_fps
45 | self.partition = partition
46 | self.t_ant = t_ant
47 |
48 | self.actions, self.actions_invalid = self._get_actions()
49 | self.duration = max([a.stop_time for a in self.actions])
50 |
51 | def _get_actions(self):
52 | actions = []
53 | _actions_all = []
54 | actions_invalid = []
55 | for _, row in self.df.iterrows():
56 | action_args = {
57 | 'uid': row.uid,
58 | 'participant_id': row.participant_id,
59 | 'video_id': row.video_id,
60 | 'verb': row.verb if 'test' not in self.partition else None,
61 | 'verb_class': row.verb_class if 'test' not in self.partition else None,
62 | 'noun': row.noun if 'test' not in self.partition else None,
63 | 'noun_class': row.noun_class if 'test' not in self.partition else None,
64 | 'all_nouns': row.all_nouns if 'test' not in self.partition else None,
65 | 'all_noun_classes': row.all_noun_classes if 'test' not in self.partition else None,
66 | 'start_frame': row.start_frame,
67 | 'stop_frame': row.stop_frame,
68 | 'start_time': row.start_time,
69 | 'stop_time': row.stop_time,
70 | 'ori_fps': self.ori_fps,
71 | 'partition': self.partition,
72 | 'action': row.action if 'test' not in self.partition else None,
73 | 'action_class': row.action_class if 'test' not in self.partition else None,
74 | }
75 | action = EpicAction(**action_args)
76 | action.set_previous_actions([aa for aa in _actions_all])
77 | assert self.t_ant is not None
78 | assert self.t_ant > 0.0
79 | if action.start_time - self.t_ant >= 0:
80 | actions += [action]
81 | else:
82 | actions_invalid += [action]
83 | _actions_all += [action]
84 | return actions, actions_invalid
85 |
86 |
87 | class EpicDataset(Dataset):
88 | def __init__(self, df, partition, ori_fps=60.0, fps=4.0, loader=None, t_ant=None, transform=None,
89 | num_actions_prev=None, label_path=None, eval_label_path=None,
90 | annot_path=None, rulstm_annot_path=None, ek_version=None):
91 | super().__init__()
92 | self.partition = partition
93 | self.ori_fps = ori_fps
94 | self.fps = fps
95 | self.df = df
96 | self.loader = loader
97 | self.t_ant = t_ant
98 | self.transform = transform
99 | self.num_actions_prev = num_actions_prev
100 |
101 | self.videos = self._get_videos()
102 | self.actions, self.actions_invalid = self._get_actions()
103 |
104 | def _get_videos(self):
105 | video_ids = sorted(list(set(self.df['video_id'].values.tolist())))
106 | videos = []
107 | pbar = tqdm(desc=f'Loading {self.partition} samples', total=len(self.df))
108 | for video_id in video_ids:
109 | video_args = {
110 | 'df_video': self.df[self.df['video_id'] == video_id].copy(),
111 | 'ori_fps': self.ori_fps,
112 | 'partition': self.partition,
113 | 't_ant': self.t_ant
114 | }
115 | video = EpicVideo(**video_args)
116 | videos += [video]
117 | pbar.update(len(video.actions))
118 | pbar.close()
119 | return videos
120 |
121 | def _get_actions(self):
122 | actions = []
123 | actions_invalid = []
124 | for video in self.videos:
125 | actions += video.actions
126 | actions_invalid += video.actions_invalid
127 | return actions, actions_invalid
128 |
129 | def __len__(self):
130 | return len(self.actions)
131 |
132 | def __getitem__(self, idx):
133 | a = self.actions[idx]
134 | sample = {'uid': a.uid}
135 |
136 | inputs = self.loader(a)
137 | sample.update(inputs)
138 |
139 | if 'test' not in self.partition:
140 | sample['verb_class'] = a.verb_class
141 | sample['noun_class'] = a.noun_class
142 | sample['action_class'] = a.action_class
143 |
144 | actions_prev = [-1] + [aa.action_class for aa in a.actions_prev]
145 | actions_prev = actions_prev[-self.num_actions_prev:]
146 | if len(actions_prev) < self.num_actions_prev:
147 | actions_prev = actions_prev[0:1] * (self.num_actions_prev - len(actions_prev)) + actions_prev
148 | actions_prev = np.array(actions_prev, dtype=np.int64)
149 | sample['action_class_prev'] = actions_prev
150 | return sample
151 |
152 |
153 | def get_datasets(args, epic_ds=None, featuresloader=None):
154 | loaders = get_loaders(args, featuresloader=featuresloader)
155 |
156 | annotation_args = {
157 | 'annot_path': args.annot_path,
158 | 'label_path': args.label_path,
159 | 'eval_label_path': args.eval_label_path,
160 | 'rulstm_annot_path': args.rulstm_annot_path,
161 | 'validation_ratio': args.validation_ratio,
162 | 'use_rulstm_splits': args.use_rulstm_splits,
163 | 'use_label_only': args.use_label_only
164 | }
165 |
166 | if args.ek_version == 'ek55':
167 | dfs = {
168 | 'train': get_ek55_annotation(partition='train', **annotation_args),
169 | 'validation': get_ek55_annotation(partition='validation', **annotation_args),
170 | 'eval': get_ek55_annotation(partition='eval', **annotation_args),
171 | 'test_s1': get_ek55_annotation(partition='test_s1', **annotation_args),
172 | 'test_s2': get_ek55_annotation(partition='test_s2', **annotation_args),
173 | }
174 | elif args.ek_version == 'ek100':
175 | dfs = {
176 | 'train': get_ek100_annotation(partition='train', **annotation_args),
177 | 'validation': get_ek100_annotation(partition='validation', **annotation_args),
178 | 'eval': get_ek100_annotation(partition='eval', **annotation_args),
179 | 'test': get_ek100_annotation(partition='test', **annotation_args),
180 | }
181 | else:
182 | raise Exception(f'Error. EPIC-Kitchens Version "{args.ek_version}" not supported.')
183 |
184 | ds_args = {
185 | 'label_path': args.label_path[args.ek_version],
186 | 'eval_label_path': args.eval_label_path[args.ek_version],
187 | 'annot_path': args.annot_path,
188 | 'rulstm_annot_path': args.rulstm_annot_path[args.ek_version],
189 | 'ek_version': args.ek_version,
190 | 'ori_fps': args.ori_fps,
191 | 'fps': args.fps,
192 | 't_ant': args.t_ant,
193 | 'num_actions_prev': args.num_actions_prev if args.task in ['anticipation'] else None,
194 | }
195 |
196 | if epic_ds is None:
197 | epic_ds = EpicDataset
198 |
199 | if args.mode in ['train', 'training']:
200 | dss = {
201 | 'train': epic_ds(df=dfs['train'], partition='train', loader=loaders['train'], **ds_args),
202 | 'validation': epic_ds(df=dfs['validation'], partition='validation', loader=loaders['validation'],
203 | **ds_args),
204 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args),
205 | }
206 | elif args.mode in ['validation', 'validating', 'validate']:
207 | dss = {
208 | 'validation': epic_ds(df=dfs['validation'], partition='validation',
209 | loader=loaders['validation'], **ds_args),
210 | 'eval': epic_ds(df=dfs['eval'], partition='eval', loader=loaders['validation'], **ds_args),
211 | }
212 | elif args.mode in ['test', 'testing']:
213 |
214 | if args.ek_version == "ek55":
215 | dss = {
216 | 'test_s1': epic_ds(df=dfs['test_s1'], partition='test_s1', loader=loaders['test'], **ds_args),
217 | 'test_s2': epic_ds(df=dfs['test_s2'], partition='test_s2', loader=loaders['test'], **ds_args),
218 | }
219 | elif args.ek_version == "ek100":
220 | dss = {
221 | 'test': epic_ds(df=dfs['test'], partition='test', loader=loaders['test'], **ds_args),
222 | }
223 | else:
224 | raise Exception(f'Error. Mode "{args.mode}" not supported.')
225 |
226 | return dss
227 |
228 |
229 | def get_dataloaders(args, epic_ds=None, featuresloader=None):
230 | dss = get_datasets(args, epic_ds=epic_ds,featuresloader=featuresloader)
231 | dl_args = {
232 | 'batch_size': args.batch_size,
233 | 'pin_memory': True,
234 | 'num_workers': args.num_workers,
235 | }
236 | if args.mode in ['train', 'training']:
237 | dls = {
238 | 'train': DataLoader(dss['train'], shuffle=False, **dl_args),
239 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args),
240 | }
241 | elif args.mode in ['validate', 'validation', 'validating']:
242 | dls = {
243 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args),
244 | }
245 | elif args.mode == 'test':
246 | if args.ek_version == "ek55":
247 | dls = {
248 | 'test_s1': DataLoader(dss['test_s1'], shuffle=False, **dl_args),
249 | 'test_s2': DataLoader(dss['test_s2'], shuffle=False, **dl_args),
250 | }
251 | elif args.ek_version == "ek100":
252 | dls = {
253 | 'test': DataLoader(dss['test'], shuffle=False, **dl_args),
254 | }
255 | else:
256 | raise Exception(f'Error. Mode "{args.mode}" not supported.')
257 | return dls
258 |
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | import pickle
5 | from sklearn.model_selection import train_test_split
6 | import numpy as np
7 |
8 |
9 | def timestr2sec(t_str):
10 | hh, mm, ss = [float(x) for x in t_str.split(':')]
11 | t_sec = hh * 3600.0 + mm * 60.0 + ss
12 | return t_sec
13 |
14 |
15 | def read_rulstm_splits(rulstm_annotation_path):
16 | header = ['uid', 'video_id', 'start_frame', 'stop_frame', 'verb_class', 'noun_class', 'action_class']
17 | df_train = pd.read_csv(os.path.join(rulstm_annotation_path, 'training.csv'), names=header)
18 | df_validation = pd.read_csv(os.path.join(rulstm_annotation_path, 'validation.csv'), names=header)
19 | return df_train, df_validation
20 |
21 |
22 | def str2list(s, out_type=None):
23 | """
24 | Convert a string "[i1, i2, ...]" of items into a list [i1, i2, ...] of items.
25 | """
26 | s = s.replace('[', '').replace(']', '')
27 | s = s.replace('\'', '')
28 | s = s.split(', ')
29 | if out_type is not None:
30 | s = [out_type(ss) for ss in s]
31 | return s
32 |
33 |
34 | def split_train_val(df, validation_ratio=0.2, use_rulstm_splits=False,
35 | rulstm_annotation_path=None, label_info_path=None,
36 | use_label_only=True):
37 | if label_info_path is not None and use_label_only:
38 | with open(label_info_path, 'r') as f:
39 | uids_label = json.load(f)
40 | df = df.loc[df['uid'].isin(uids_label)]
41 | if use_rulstm_splits:
42 | assert rulstm_annotation_path is not None
43 | df_train_rulstm, df_validation_rulstm = read_rulstm_splits(rulstm_annotation_path)
44 | uids_train = df_train_rulstm['uid'].values.tolist()
45 | uids_validation = df_validation_rulstm['uid'].values.tolist()
46 | df_train = df.loc[df['uid'].isin(uids_train)]
47 | df_validation = df.loc[df['uid'].isin(uids_validation)]
48 | else:
49 | if validation_ratio == 0.0:
50 | df_train = df
51 | df_validation = pd.DataFrame(columns=df.columns)
52 | elif validation_ratio == 1.0:
53 | df_train = pd.DataFrame(columns=df.columns)
54 | df_validation = df
55 | elif 0.0 < validation_ratio < 1.0:
56 | df_train, df_validation = train_test_split(df, test_size=validation_ratio,
57 | random_state=3577,
58 | shuffle=True, stratify=df['participant_id'])
59 | else:
60 | raise Exception(f'Error. Validation "{validation_ratio}" not supported.')
61 | return df_train, df_validation
62 |
63 |
64 | def create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, ek_version, out_path='actions.csv', use_rulstm_splits=True):
65 | if use_rulstm_splits:
66 | if ek_version == 'ek55':
67 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek55'], 'actions.csv'))
68 | elif ek_version == 'ek100':
69 | df_actions = pd.read_csv(os.path.join(rulstm_annot_path['ek100'], 'actions.csv'))
70 | df_actions['action'] = df_actions.action.map(lambda x: x.replace(' ', '_'))
71 |
72 | df_actions['verb_class'] = df_actions.verb
73 | df_actions['noun_class'] = df_actions.noun
74 | df_actions['verb'] = df_actions.action.map(lambda x: x.split('_')[0])
75 | df_actions['noun'] = df_actions.action.map(lambda x: x.split('_')[1])
76 | df_actions['action'] = df_actions.action
77 | df_actions['action_class'] = df_actions.id
78 | del df_actions['id']
79 |
80 | else:
81 | if ek_version == 'ek55':
82 | df_train = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train',
83 | use_label_only=False, raw=True)
84 | df_validation = get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation',
85 | use_label_only=False, raw=True)
86 | df = pd.concat([df_train, df_validation])
87 | df.sort_values(by=['uid'], inplace=True)
88 |
89 | elif ek_version == 'ek100':
90 | df_train = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path=None, partition='train',
91 | use_label_only=False, raw=True)
92 | df_validation = get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition='validation',
93 | use_label_only=False, raw=True)
94 | df = pd.concat([df_train, df_validation])
95 | df.sort_values(by=['narration_id'], inplace=True)
96 |
97 | noun_classes = df.noun_class.values
98 | nouns = df.noun.values
99 | verb_classes = df.verb_class.values
100 | verbs = df.verb.values
101 |
102 | actions_combinations = [f'{v}_{n}' for v, n in zip(verb_classes, noun_classes)]
103 | actions = [f'{v}_{n}' for v, n in zip(verbs, nouns)]
104 |
105 | df_actions = {'verb_class': [], 'noun_class': [], 'verb': [], 'noun': [], 'action': []}
106 | vn_combinations = []
107 | for i, a in enumerate(actions_combinations):
108 | if a in vn_combinations:
109 | continue
110 |
111 | v, n = a.split('_')
112 | v = int(v)
113 | n = int(n)
114 | df_actions['verb_class'] += [v]
115 | df_actions['noun_class'] += [n]
116 | df_actions['action'] += [actions[i]]
117 | df_actions['verb'] += [verbs[i]]
118 | df_actions['noun'] += [nouns[i]]
119 | vn_combinations += [a]
120 | df_actions = pd.DataFrame(df_actions)
121 | df_actions.sort_values(by=['verb_class', 'noun_class'], inplace=True)
122 | df_actions['action_class'] = range(len(df_actions))
123 |
124 | df_actions.to_csv(out_path, index=False)
125 | print(f'Saved file at "{out_path}".')
126 |
127 |
128 | def get_ek55_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2,
129 | use_rulstm_splits=False, use_label_only=True, raw=False):
130 | if partition in ['train', 'validation']:
131 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv')
132 | label_info_path = os.path.join(label_path['ek55'], "video_info.json")
133 | df = pd.read_csv(csv_path)
134 | df_train, df_validation = split_train_val(df, validation_ratio=validation_ratio,
135 | use_rulstm_splits=use_rulstm_splits,
136 | rulstm_annotation_path=rulstm_annot_path['ek55'],
137 | label_info_path=label_info_path,
138 | use_label_only=use_label_only)
139 |
140 | df = df_train if partition == 'train' else df_validation
141 | if not use_rulstm_splits:
142 | df.sort_values(by=['uid'], inplace=True)
143 |
144 | elif partition in ['eval', 'evaluation']:
145 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_train_action_labels.csv')
146 | df = pd.read_csv(csv_path)
147 | with open(eval_label_path['ek55'], 'rb') as f:
148 | eval_labels = pickle.load(f)
149 | eval_uids = eval_labels.keys()
150 | df = df.loc[df['uid'].isin(eval_uids)]
151 |
152 | elif partition == 'test_s1':
153 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s1_timestamps.csv')
154 | df = pd.read_csv(csv_path)
155 |
156 | elif partition == 'test_s2':
157 | csv_path = os.path.join(annot_path['ek55'], 'EPIC_test_s2_timestamps.csv')
158 | df = pd.read_csv(csv_path)
159 | else:
160 | raise Exception(f'Error. Partition "{partition}" not supported.')
161 |
162 | if raw:
163 | return df
164 |
165 | actions_df_path = os.path.join(annot_path['ek55'], 'actions.csv')
166 | if not os.path.exists(actions_df_path):
167 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek55', out_path=actions_df_path, use_rulstm_splits=True)
168 | df_actions = pd.read_csv(actions_df_path)
169 |
170 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t))
171 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t))
172 | if 'test' not in partition:
173 | action_classes = []
174 | actions = []
175 | for _, row in df.iterrows():
176 | v, n = row.verb_class, row.noun_class
177 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)]
178 | a_cl = df_a_sub['action_class'].values
179 | a = df_a_sub['action'].values
180 | if len(a_cl) > 1:
181 | print(a_cl)
182 | action_classes += [a_cl[0]]
183 | actions += [a[0]]
184 | df['action_class'] = action_classes
185 | df['action'] = actions
186 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x))
187 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int))
188 |
189 | return df
190 |
191 |
192 | def get_ek100_annotation(annot_path, rulstm_annot_path, label_path, eval_label_path, partition, validation_ratio=0.2,
193 | use_rulstm_splits=False, use_label_only=True, raw=False):
194 | if partition == 'train':
195 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv'))
196 | uids = np.arange(len(df))
197 |
198 | elif partition == 'validation':
199 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv'))
200 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv'))
201 | uids = np.arange(len(df)) + len(df_train)
202 |
203 | elif partition in ['eval', 'evaluation']:
204 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv'))
205 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv'))
206 | uids = np.arange(len(df)) + len(df_train)
207 | df['uid'] = uids
208 | with open(eval_label_path['ek100'], 'rb') as f:
209 | eval_labels = pickle.load(f)
210 | eval_uids = eval_labels.keys()
211 | df = df.loc[df['uid'].isin(eval_uids)]
212 |
213 | elif partition == 'test':
214 | df_train = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_train.csv'))
215 | df_validation = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_validation.csv'))
216 | df = pd.read_csv(os.path.join(annot_path['ek100'], 'EPIC_100_test_timestamps.csv'))
217 | uids = np.arange(len(df)) + len(df_train) + len(df_validation)
218 |
219 | else:
220 | raise Exception(f'Error. Partition "{partition}" not supported.')
221 | if raw:
222 | return df
223 |
224 | actions_df_path = os.path.join(annot_path['ek100'], 'actions.csv')
225 | if not os.path.exists(actions_df_path):
226 | create_actions_df(annot_path, rulstm_annot_path, label_path, eval_label_path, 'ek100', actions_df_path)
227 | df_actions = pd.read_csv(actions_df_path)
228 |
229 | df['start_time'] = df['start_timestamp'].map(lambda t: timestr2sec(t))
230 | df['stop_time'] = df['stop_timestamp'].map(lambda t: timestr2sec(t))
231 | if 'uid' not in df:
232 | df['uid'] = uids
233 |
234 | if use_label_only:
235 | label_info_path = os.path.join(label_path['ek100'], "video_info.json")
236 | with open(label_info_path, 'r') as f:
237 | uids_label = json.load(f)
238 | df = df.loc[df['uid'].isin(uids_label)]
239 |
240 | if 'test' not in partition:
241 | action_classes = []
242 | actions = []
243 | for _, row in df.iterrows():
244 | v, n = row.verb_class, row.noun_class
245 | df_a_sub = df_actions[(df_actions['verb_class'] == v) & (df_actions['noun_class'] == n)]
246 | a_cl = df_a_sub['action_class'].values
247 | a = df_a_sub['action'].values
248 | if len(a_cl) > 1:
249 | print(a_cl)
250 | action_classes += [a_cl[0]]
251 | actions += [a[0]]
252 | df['action_class'] = action_classes
253 | df['action'] = actions
254 | df['all_nouns'] = df['all_nouns'].map(lambda x: str2list(x))
255 | df['all_noun_classes'] = df['all_noun_classes'].map(lambda x: str2list(x, out_type=int))
256 | return df
--------------------------------------------------------------------------------
/datasets/datasetopts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 |
5 | class DatasetArgs(object):
6 | def __init__(self, ek_version='ek55', mode="train", use_label_only=True,
7 | base_path="./", batch_size=32, num_workers=0, modalities=['feat'],
8 | fps=4, t_buffer=2.5):
9 |
10 | self.features_paths = {
11 | 'ek55': os.path.join(base_path, 'data/ek55/feats'),
12 | 'ek100': os.path.join(base_path, 'data/ek100/feats')}
13 |
14 | # generated data labels
15 | self.label_path = {
16 | 'ek55': os.path.join(base_path, 'data/ek55'),
17 | 'ek100': os.path.join(base_path, 'data/ek100')}
18 |
19 | # amazon-annotated eval labels
20 | self.eval_label_path = {
21 | 'ek55': os.path.join(base_path, 'data/ek55/ek55_eval_labels.pkl'),
22 | 'ek100': os.path.join(base_path, 'data/ek100/ek100_eval_labels.pkl')
23 | }
24 |
25 | self.annot_path = {
26 | 'ek55': os.path.join(base_path, 'common/epic-kitchens-55-annotations'),
27 | 'ek100': os.path.join(base_path, 'common/epic-kitchens-100-annotations')}
28 |
29 | self.rulstm_annot_path = {
30 | 'ek55': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek55'),
31 | 'ek100': os.path.join(base_path, 'common/rulstm/RULSTM/data/ek100')}
32 |
33 | self.pretrained_backbone_path = {
34 | 'ek55': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek55', 'TSN-rgb.pth.tar'),
35 | 'ek100': os.path.join(base_path, 'common/rulstm/FEATEXT/models/ek100', 'TSN-rgb-ek100.pth.tar'),
36 | }
37 |
38 | # default settings, no need changes
39 | if fps is None:
40 | self.fps = 4
41 | else:
42 | self.fps = fps
43 |
44 | if t_buffer is None:
45 | self.t_buffer = 2.5
46 | else:
47 | self.t_buffer = t_buffer
48 |
49 | self.ori_fps = 60.0
50 | self.t_ant = 1.0
51 |
52 | self.validation_ratio = 0.2
53 | self.use_rulstm_splits = True
54 |
55 | # only preprocess uids that have corresponding labels, in "video_info.json"
56 | self.use_label_only = use_label_only
57 |
58 | self.task = 'anticipation'
59 | self.num_actions_prev = 1
60 |
61 | self.batch_size = batch_size
62 | self.num_workers = num_workers
63 |
64 | self.modalities = modalities
65 | self.ek_version = ek_version # 'ek55' or 'ek100'
66 | self.mode = mode # 'train'
67 |
68 | def add_attr(self, attr_name, attr_value):
69 | setattr(self, attr_name, attr_value)
70 |
71 | def has_attr(self, attr_name):
72 | return hasattr(self, attr_name)
73 |
74 | def __repr__(self):
75 | return 'Input Args: ' + json.dumps(self.__dict__, indent=4)
76 |
77 |
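A hypothetical usage sketch (not part of this file), showing how `DatasetArgs` can be wired into the hand-object dataloaders from `datasets/holoaders.py`; it assumes the processed `data/` and `common/` folders described in the README are in place:

```python
from datasets.datasetopts import DatasetArgs
from datasets.holoaders import EpicHODataset, FeaturesHOLoader, get_dataloaders

args = DatasetArgs(ek_version='ek100', mode='train', modalities=['feat'])
dls = get_dataloaders(args, epic_ds=EpicHODataset, featuresloader=FeaturesHOLoader)
for batch in dls['train']:
    print(batch['uid'], batch['feat'].shape)  # features stacked as (batch, entities, time, dim)
    break
```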
--------------------------------------------------------------------------------
/datasets/ho_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 |
5 |
6 | def load_video_info(label_path, video_index):
7 | with open(os.path.join(label_path, "label_{}.pkl".format(video_index)), 'rb') as f:
8 | video_info = pickle.load(f)
9 | return video_info
10 |
11 |
12 | def sample_hand_traj(meta, fps, t_ant, shape=(456, 256)):
13 | width, height = shape
14 | traj = meta["traj"]
15 | ori_fps = int((len(traj) - 1) / t_ant)
16 | gap = int(ori_fps // fps)
17 | stop_idx = len(traj)
18 | indices = [0] + list(range(gap, stop_idx, gap))
19 | hand_traj = []
20 | for idx in indices:
21 | x, y = traj[idx]
22 | x, y = x / width, y / height
23 | hand_traj.append(np.array([x, y], dtype=np.float32))
24 | hand_traj = np.array(hand_traj, dtype=np.float32)
25 | return hand_traj, indices
26 |
27 |
28 | def process_video_info(video_info, fps=4, t_ant=1.0, shape=(456, 256)):
29 | frames_idxs = video_info["frame_indices"]
30 | hand_trajs = video_info["hand_trajs"]
31 | obj_affordance = video_info['affordance']['select_points_homo']
32 | num_points = obj_affordance.shape[0]
33 | select_idx = np.random.choice(num_points, 1, replace=False)
34 | contact_point = obj_affordance[select_idx]
35 | cx, cy = contact_point[0]
36 | width, height = shape
37 | cx, cy = cx / width, cy/ height
38 | contact_point = np.array([cx, cy], dtype=np.float32)
39 |
40 | valid_mask = []
41 | if "RIGHT" in hand_trajs:
42 | meta = hand_trajs["RIGHT"]
43 | rhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape)
44 | valid_mask.append(1)
45 | else:
46 | length = int(fps * t_ant + 1)
47 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), length, axis=0)
48 | valid_mask.append(0)
49 |
50 | if "LEFT" in hand_trajs:
51 | meta = hand_trajs["LEFT"]
52 | lhand_traj, indices = sample_hand_traj(meta, fps, t_ant, shape)
53 | valid_mask.append(1)
54 | else:
55 | length = int(fps * t_ant + 1)
56 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), length, axis=0)
57 | valid_mask.append(0)
58 |
59 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0)
60 | future_valid = np.array(valid_mask, dtype=np.int)
61 |
62 | last_frame_index = frames_idxs[0]
63 | return future_hands, contact_point, future_valid, last_frame_index
64 |
65 |
66 | def process_eval_video_info(video_info, fps=4, t_ant=1.0):
67 | valid_mask = []
68 | if "RIGHT" in video_info:
69 | rhand_traj = video_info["RIGHT"]
70 | assert rhand_traj.shape[0] == int(fps * t_ant + 1)
71 | valid_mask.append(1)
72 | else:
73 | rhand_traj = np.repeat(np.array([[0.75, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0)
74 | valid_mask.append(0)
75 |
76 | if "LEFT" in video_info:
77 | lhand_traj = video_info['LEFT']
78 | assert lhand_traj.shape[0] == int(fps * t_ant + 1)
79 | valid_mask.append(1)
80 | else:
81 | lhand_traj = np.repeat(np.array([[0.25, 1.5]], dtype=np.float32), int(fps * t_ant + 1), axis=0)
82 | valid_mask.append(0)
83 |
84 | future_hands = np.stack((rhand_traj, lhand_traj), axis=0)
85 | future_valid = np.array(valid_mask, dtype=np.int)
86 | return future_hands, future_valid
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/datasets/holoaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | from lmdbdict import lmdbdict
5 | from torch.utils.data import DataLoader
6 |
7 | from datasets.dataloaders import EpicDataset, get_datasets
8 | from datasets.ho_utils import load_video_info, process_video_info, process_eval_video_info
9 |
10 |
11 | class FeaturesHOLoader(object):
12 | def __init__(self, sampler, feature_base_path, fps=4.0, input_name='rgb',
13 | frame_tmpl='frame_{:010d}.jpg',
14 | transform_feat=None, transform_video=None,
15 | t_observe=2.5):
16 |
17 | self.feature_base_path = feature_base_path
18 | self.lmdb_path = os.path.join(self.feature_base_path, "data.lmdb")
19 | self.env = lmdbdict(self.lmdb_path, 'r')
20 | self.fps = fps
21 | self.input_name = input_name
22 | self.frame_tmpl = frame_tmpl
23 | self.transform_feat = transform_feat
24 | self.transform_video = transform_video
25 | self.sampler = sampler
26 | self.t_observe = t_observe
27 | self.num_observe = int(t_observe * self.fps)
28 |
29 | def __call__(self, action):
30 | times, frames_idxs = self.sampler(action)
31 | assert self.num_observe <= len(frames_idxs), \
32 | "num of observation exceed the limit of {}, set smaller t_observe, current is {}".format(len(frames_idxs), self.t_observe)
33 | frames_names = [self.frame_tmpl.format(i) for i in frames_idxs]
34 | start_frame_idx = len(frames_idxs) - self.num_observe
35 | frames_names = frames_names[start_frame_idx:]
36 |
37 | full_names = []
38 | global_feats, global_masks = [], []
39 | rhand_feats, rhand_masks, rhand_bboxs = [], [], []
40 | lhand_feats, lhand_masks, lhand_bboxs = [], [], []
41 | robj_feats, robj_masks, robj_bboxs = [], [], []
42 | lobj_feats, lobj_masks, lobj_bboxs = [], [], []
43 |
44 | for f_name in frames_names:
45 | # full_name: e.g. 'P24/rgb_frames/P24_05/frame_0000075700.jpg'
46 | full_name = os.path.join(action.participant_id, "rgb_frames", action.video_id, f_name)
47 | full_names.append(full_name)
48 | # f_dict: 'GLOBAL_FEAT',
49 | # 'HAND_RIGHT_FEAT', 'HAND_RIGHT_BBOX', 'OBJECT_RIGHT_FEAT', 'OBJECT_RIGHT_BBOX',
50 | # 'HAND_LEFT_FEAT', 'HAND_LEFT_BBOX', 'OBJECT_LEFT_FEAT', 'OBJECT_LEFT_BBOX']
51 | key_enc = full_name.strip().encode('utf-8')
52 | if key_enc not in self.env:
53 | raise KeyError("invalid key {}, check lmdb file in {}".format(full_name.strip(), self.lmdb_path))
54 | f_dict = self.env[key_enc]
55 |
56 | global_feat = f_dict['GLOBAL_FEAT']
57 | global_masks.append(1)
58 | global_feats.append(global_feat)
59 |
60 | if 'HAND_RIGHT_FEAT' in f_dict:
61 | rhand_feat = f_dict['HAND_RIGHT_FEAT']
62 | else:
63 | rhand_feat = np.zeros_like(global_feat, dtype=np.float32)
64 | rhand_feats.append(rhand_feat)
65 |
66 | if 'HAND_LEFT_FEAT' in f_dict:
67 | lhand_feat = f_dict['HAND_LEFT_FEAT']
68 | else:
69 | lhand_feat = np.zeros_like(global_feat, dtype=np.float32)
70 | lhand_feats.append(lhand_feat)
71 |
72 | if 'OBJECT_RIGHT_FEAT' in f_dict:
73 | robj_feat = f_dict['OBJECT_RIGHT_FEAT']
74 | else:
75 | robj_feat = np.zeros_like(global_feat, dtype=np.float32)
76 | robj_feats.append(robj_feat)
77 |
78 | if 'OBJECT_LEFT_FEAT' in f_dict:
79 | lobj_feat = f_dict['OBJECT_LEFT_FEAT']
80 | else:
81 | lobj_feat = np.zeros_like(global_feat, dtype=np.float32)
82 | lobj_feats.append(lobj_feat)
83 |
84 | if 'HAND_RIGHT_BBOX' in f_dict:
85 | rhand_bbox = f_dict['HAND_RIGHT_BBOX']
86 | rhand_masks.append(1)
87 | else:
88 | cx, cy = (0.75, 1.5)
89 | sx, sy = (0.1, 0.1)
90 | rhand_bbox = np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2])
91 | rhand_masks.append(0)
92 | rhand_bboxs.append(rhand_bbox)
93 |
94 | if 'HAND_LEFT_BBOX' in f_dict:
95 | lhand_bbox = f_dict['HAND_LEFT_BBOX']
96 | lhand_masks.append(1)
97 | else:
98 | cx, cy = (0.25, 1.5)
99 | sx, sy = (0.1, 0.1)
100 | lhand_bbox = np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2])
101 | lhand_masks.append(0)
102 | lhand_bboxs.append(lhand_bbox)
103 |
104 | if 'OBJECT_RIGHT_BBOX' in f_dict:
105 | robj_bbox = f_dict['OBJECT_RIGHT_BBOX']
106 | robj_masks.append(1)
107 | else:
108 | robj_bbox = np.array([0.0, 0.0, 1.0, 1.0])
109 | robj_masks.append(0)
110 | robj_bboxs.append(robj_bbox)
111 |
112 | if 'OBJECT_LEFT_BBOX' in f_dict:
113 | lobj_bbox = f_dict['OBJECT_LEFT_BBOX']
114 | lobj_masks.append(1)
115 | else:
116 | lobj_bbox = np.array([0.0, 0.0, 1.0, 1.0])
117 | lobj_masks.append(0)
118 | lobj_bboxs.append(lobj_bbox)
119 |
120 | global_feats = np.stack(global_feats, axis=0)
121 | rhand_feats = np.stack(rhand_feats, axis=0)
122 | lhand_feats = np.stack(lhand_feats, axis=0)
123 | robj_feats = np.stack(robj_feats, axis=0)
124 | lobj_feats = np.stack(lobj_feats, axis=0)
125 |
126 | feats = np.stack((global_feats, rhand_feats, lhand_feats, robj_feats, lobj_feats), axis=0)
127 |
128 | rhand_bboxs = np.stack(rhand_bboxs, axis=0)
129 | lhand_bboxs = np.stack(lhand_bboxs, axis=0)
130 | robj_bboxs = np.stack(robj_bboxs, axis=0)
131 | lobj_bboxs = np.stack(lobj_bboxs, axis=0)
132 |
133 | bbox_feats = np.stack((rhand_bboxs, lhand_bboxs, robj_bboxs, lobj_bboxs), axis=0)
134 |
135 | global_masks = np.stack(global_masks, axis=0)
136 | rhand_masks = np.stack(rhand_masks, axis=0)
137 | lhand_masks = np.stack(lhand_masks, axis=0)
138 | robj_masks = np.stack(robj_masks, axis=0)
139 | lobj_masks = np.stack(lobj_masks, axis=0)
140 |
141 | valid_masks = np.stack((global_masks, rhand_masks, lhand_masks, robj_masks, lobj_masks), axis=0)
142 |
143 | out = {"name": full_names, "feat": feats, "bbox_feat": bbox_feats, "valid_mask": valid_masks, 'times': times,
144 | 'start_time': action.start_time, 'frames_idxs': frames_idxs}
145 | return out
146 |
147 |
148 | class EpicHODataset(EpicDataset):
149 | def __init__(self, df, partition, ori_fps=60.0, fps=4.0, loader=None, t_ant=None, transform=None,
150 | num_actions_prev=None, label_path=None, eval_label_path=None,
151 | annot_path=None, rulstm_annot_path=None, ek_version=None):
152 | super().__init__(df=df, partition=partition, ori_fps=ori_fps, fps=fps,
153 | loader=loader, t_ant=t_ant, transform=transform,
154 | num_actions_prev=num_actions_prev)
155 | self.ek_version = ek_version
156 | self.rulstm_annot_path = rulstm_annot_path
157 | self.annot_path = annot_path
158 | self.shape = (456, 256)
159 | self.discarded_labels, self.discarded_ids = self._get_discarded()
160 |
161 | if 'eval' not in self.partition:
162 | self.label_dir = os.path.join(label_path, "labels")
163 | else:
164 | with open(eval_label_path, 'rb') as f:
165 | self.eval_labels = pickle.load(f)
166 |
167 | def _load_eval_labels(self, uid):
168 | video_info = self.eval_labels[uid]
169 | future_hands, future_valid = process_eval_video_info(video_info, fps=self.fps, t_ant=self.t_ant)
170 | return future_hands, future_valid
171 |
172 | def _load_labels(self, uid):
173 | if os.path.exists(os.path.join(self.label_dir, "label_{}.pkl".format(uid))):
174 | label_valid = True
175 | video_info = load_video_info(self.label_dir, uid)
176 | future_hands, contact_point, future_valid, last_frame_index = process_video_info(video_info, fps=self.fps,
177 | t_ant=self.t_ant,
178 | shape=self.shape)
179 | else:
180 | label_valid = False
181 | length = int(self.fps * self.t_ant + 1)
182 | future_hands = np.zeros((2, length, 2), dtype=np.float32)
183 | contact_point = np.zeros((2,), dtype=np.float32)
184 | future_valid = np.array([0, 0], dtype=np.int)
185 | last_frame_index = None
186 |
187 | return future_hands, contact_point, future_valid, last_frame_index, label_valid
188 |
189 | def _get_discarded(self):
190 | discarded_ids = []
191 | discarded_labels = []
192 | if 'train' not in self.partition:
193 | label_type = ['verb', 'noun', 'action']
194 | else:
195 | label_type = 'action'
196 | if 'test' in self.partition:
197 | challenge = True
198 | else:
199 | challenge = False
200 |
201 | for action in self.actions_invalid:
202 | discarded_ids.append(action.uid)
203 | if isinstance(label_type, list):
204 | if challenge:
205 | discarded_labels.append(-1)
206 | else:
207 | verb, noun, action_class = action.verb_class, action.noun_class, action.action_class
208 | label = np.array([verb, noun, action_class], dtype=np.int)
209 | discarded_labels.append(label)
210 | else:
211 | if challenge:
212 | discarded_labels.append(-1)
213 | else:
214 | action_class = action.action_class
215 | discarded_labels.append(action_class)
216 | return discarded_labels, discarded_ids
217 |
218 | def __getitem__(self, idx):
219 | a = self.actions[idx]
220 | sample = {'uid': a.uid}
221 |
222 | inputs = self.loader(a)
223 | sample.update(inputs)
224 |
225 | if 'eval' not in self.partition:
226 | future_hands, contact_point, future_valid, last_frame_index, label_valid = self._load_labels(a.uid)
227 |
228 | sample['future_hands'] = future_hands
229 | sample['contact_point'] = contact_point
230 | sample['future_valid'] = future_valid
231 | sample['label_valid'] = label_valid
232 |
233 | if "frames_idxs" in sample and last_frame_index is not None:
234 | assert last_frame_index == sample["frames_idxs"][-1], \
235 | "dataloader video clip {} last observation frame mismatch, " \
236 | "index load from history s {} while load from future is{}!".format(a.uid, sample["frames_idxs"][-1],
237 | last_frame_index)
238 |
239 | else:
240 | future_hands, future_valid = self._load_eval_labels(a.uid)
241 | sample['future_hands'] = future_hands
242 | sample['future_valid'] = future_valid
243 | sample['label_valid'] = True
244 | # 'contact_point' load in evaluate func
245 |
246 | if 'test' not in self.partition:
247 | sample['verb_class'] = a.verb_class
248 | sample['noun_class'] = a.noun_class
249 | sample['action_class'] = a.action_class
250 | sample['label'] = np.array([a.verb_class, a.noun_class, a.action_class], dtype=np.int)
251 | return sample
252 |
253 |
254 | def get_dataloaders(args, epic_ds=None, featuresloader=None):
255 | dss = get_datasets(args, epic_ds=epic_ds, featuresloader=featuresloader)
256 |
257 | dl_args = {
258 | 'batch_size': args.batch_size,
259 | 'pin_memory': True,
260 | 'num_workers': args.num_workers,
261 | 'drop_last': False
262 | }
263 | if args.mode in ['train', 'training']:
264 | dls = {
265 | 'train': DataLoader(dss['train'], shuffle=True, **dl_args),
266 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args),
267 | 'eval': DataLoader(dss['eval'], shuffle=False, **dl_args)
268 | }
269 | elif args.mode in ['validate', 'validation', 'validating']:
270 | dls = {
271 | 'validation': DataLoader(dss['validation'], shuffle=False, **dl_args),
272 | 'eval': DataLoader(dss['eval'], shuffle=False, **dl_args)
273 | }
274 | elif args.mode == 'test':
275 | if args.ek_version == "ek55":
276 | dls = {
277 | 'test_s1': DataLoader(dss['test_s1'], shuffle=False, **dl_args),
278 | 'test_s2': DataLoader(dss['test_s2'], shuffle=False, **dl_args),
279 | }
280 | elif args.ek_version == "ek100":
281 | dls = {
282 | 'test': DataLoader(dss['test'], shuffle=False, **dl_args),
283 | }
284 | else:
285 | raise Exception(f'Error. Mode "{args.mode}" not supported.')
286 | return dls
287 |
--------------------------------------------------------------------------------
/datasets/input_loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from torchvision import transforms
4 | import torch
5 | import os
6 | import lmdb
7 |
8 |
9 | class ActionAnticipationSampler(object):
10 | def __init__(self, t_buffer, t_ant=1.0, fps=4.0, ori_fps=60.0):
11 | self.t_buffer = t_buffer
12 | self.t_ant = t_ant
13 | self.fps = fps
14 | self.ori_fps = ori_fps
15 |
16 | def __call__(self, action):
17 | times, frames_idxs = sample_history_frames(action.start_frame, self.t_buffer,
18 | self.t_ant, fps=self.fps,
19 | fps_init=self.ori_fps)
20 | return times, frames_idxs
21 |
22 |
23 | def get_sampler(args):
24 | sampler = ActionAnticipationSampler(t_buffer=args.t_buffer, t_ant=args.t_ant,
25 | fps=args.fps, ori_fps=args.ori_fps)
26 | return sampler
27 |
28 |
29 | def sample_history_frames(frame_start, t_buffer=2.5, t_ant=1.0, fps=4.0, fps_init=60.0):
30 | time_start = (frame_start - 1) / fps_init  # action start time in seconds (frames are 1-indexed)
31 | num_frames = int(np.floor(t_buffer * fps))  # number of observation frames at the target fps
32 | time_ant = time_start - t_ant  # last observable timestamp, t_ant seconds before the action starts
33 | times = (np.arange(1, num_frames + 1) - num_frames) / fps + time_ant  # timestamps spaced 1/fps apart, ending at time_ant
34 | times = np.clip(times, 0, np.inf)  # clamp timestamps that fall before the video start
35 | times = times.astype(np.float32)
36 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1  # map timestamps to 1-indexed frame numbers
37 | times = (frames_idxs - 1) / fps_init  # snap timestamps to the selected frames
38 | return times, frames_idxs
39 |
40 |
41 | def sample_future_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0):
42 | time_start = (frame_start - 1) / fps_init
43 | num_frames = int(np.floor(t_buffer * fps))
44 | times = (np.arange(num_frames + 1) - num_frames) / fps + time_start
45 | times = np.clip(times, 0, np.inf)
46 | times = times.astype(np.float32)
47 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1
48 | if frames_idxs.max() >= 1:
49 | frames_idxs[frames_idxs < 1] = frames_idxs[frames_idxs >= 1].min()
50 | return list(frames_idxs)
51 |
52 |
53 | class FeaturesLoader(object):
54 | def __init__(self, sampler, feature_base_path, fps, input_name='rgb',
55 | frame_tmpl='frame_{:010d}.jpg', transform_feat=None,
56 | transform_video=None):
57 | self.feature_base_path = feature_base_path
58 | self.env = lmdb.open(os.path.join(self.feature_base_path, input_name), readonly=True, lock=False)
59 | self.fps = fps
60 | self.input_name = input_name
61 | self.frame_tmpl = frame_tmpl
62 | self.transform_feat = transform_feat
63 | self.transform_video = transform_video
64 | self.sampler = sampler
65 |
66 | def __call__(self, action):
67 | times, frames_idxs = self.sampler(action)
68 | frames_names = [self.frame_tmpl.format(action.video_id, i) for i in frames_idxs]
69 | feats = []
70 | with self.env.begin() as env:
71 | for f_name in frames_names:
72 | feat = env.get(f_name.strip().encode('utf-8'))
73 | if feat is None:
74 | print(f_name)
75 | feat = np.frombuffer(feat, 'float32')
76 |
77 | if self.transform_feat is not None:
78 | feat = self.transform_feat(feat)
79 | feats += [feat]
80 |
81 | if self.transform_video is not None:
82 | feats = self.transform_video(feats)
83 | out = {self.input_name: feats}
84 | out['times'] = times
85 | out['start_time'] = action.start_time
86 | out['frames_idxs'] = frames_idxs
87 | return out
88 |
89 |
90 | class PipeLoaders(object):
91 | def __init__(self, loader_list):
92 | self.loader_list = loader_list
93 |
94 | def __call__(self, action):
95 | out = {}
96 | for loader in self.loader_list:
97 | out.update(loader(action))
98 | return out
99 |
100 |
101 | def get_features_loader(args, featuresloader=None):
102 | sampler = get_sampler(args)
103 | feat_in_modalities = list({'feat'}.intersection(args.modalities))
104 | transform_feat = lambda x: torch.tensor(x.copy())
105 | transform_video = lambda x: torch.stack(x, 0)
106 | loader_args = {
107 | 'feature_base_path': args.features_paths[args.ek_version],
108 | 'fps': args.fps,
109 | 'frame_tmpl': 'frame_{:010d}.jpg',
110 | 'transform_feat': transform_feat,
111 | 'transform_video': transform_video,
112 | 'sampler': sampler}
113 | if featuresloader is None:
114 | featuresloader = FeaturesLoader
115 | feat_loader_list = []
116 | for modality in feat_in_modalities:
117 | feat_loader = featuresloader(input_name=modality, **loader_args)
118 | feat_loader_list += [feat_loader]
119 | feat_loaders = {
120 | 'train': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None,
121 | 'validation': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None,
122 | 'test': PipeLoaders(feat_loader_list) if len(feat_loader_list) else None,
123 | }
124 | return feat_loaders
125 |
126 |
127 | def get_loaders(args, featuresloader=None):
128 | loaders = {
129 | 'train': [],
130 | 'validation': [],
131 | 'test': [],
132 | }
133 |
134 | if 'feat' in args.modalities:
135 | feat_loaders = get_features_loader(args, featuresloader=featuresloader)
136 | for k, l in feat_loaders.items():
137 | if l is not None:
138 | loaders[k] += [l]
139 |
140 | for k, l in loaders.items():
141 | loaders[k] = PipeLoaders(l)
142 | return loaders
143 |
--------------------------------------------------------------------------------
/demo_gen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import argparse
4 | import pandas as pd
5 | import cv2
6 |
7 | from preprocess.traj_util import compute_hand_traj
8 | from preprocess.dataset_util import FrameDetections, sample_action_anticipation_frames, fetch_data, save_video_info
9 | from preprocess.obj_util import compute_obj_traj
10 | from preprocess.affordance_util import compute_obj_affordance
11 | from preprocess.vis_util import vis_affordance, vis_hand_traj
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--label_path', default="assets/EPIC_train_action_labels.csv", type=str, help="dataset annotation")
17 | parser.add_argument('--dataset_path', default="assets/EPIC-KITCHENS", type=str, help='dataset root')
18 | parser.add_argument('--save_path', default="./figs", type=str, help="generated results save path")
19 | parser.add_argument('--fps', default=10, type=int, help="sample frames per second")
20 | parser.add_argument('--hand_threshold', default=0.1, type=float, help="hand detection threshold")
21 | parser.add_argument('--obj_threshold', default=0.1, type=float, help="object detection threshold")
22 | parser.add_argument('--contact_ratio', default=0.4, type=float, help="active obj contact frames ratio")
23 | parser.add_argument('--num_sampling', default=20, type=int, help="sampling points for affordance")
24 | parser.add_argument('--num_points', default=5, type=int, help="selected points for affordance")
25 |
26 | args = parser.parse_args()
27 | os.makedirs(args.save_path, exist_ok=True)
28 | save_path = args.save_path
29 |
30 | uid = 52
31 |
32 | annotations = pd.read_csv(args.label_path)
33 | participant_id = annotations.loc[annotations['uid'] == uid].participant_id.item()
34 | video_id = annotations.loc[annotations['uid'] == uid].video_id.item()
35 | frames_path = os.path.join(args.dataset_path, participant_id, "rgb_frames", video_id)
36 | ho_path = os.path.join(args.dataset_path, participant_id, "hand-objects", "{}.pkl".format(video_id))
37 | start_act_frame = annotations.loc[annotations['uid'] == uid].start_frame.item()
38 | frames_idxs = sample_action_anticipation_frames(start_act_frame, fps=args.fps)
39 |
40 | with open(ho_path, "rb") as f:
41 | video_detections = [FrameDetections.from_protobuf_str(s) for s in pickle.load(f)]
42 | results = fetch_data(frames_path, video_detections, frames_idxs)
43 | if results is None:
44 | print("data fetch failed")
45 | else:
46 | frames_idxs, frames, annots, hand_sides = results
47 |
48 | results_hand = compute_hand_traj(frames, annots, hand_sides, hand_threshold=args.hand_threshold,
49 | obj_threshold=args.obj_threshold)
50 | if results_hand is None:
51 | print("compute traj failed") # homography fails or not enough points
52 | else:
53 | homography_stack, hand_trajs = results_hand
54 | results_obj = compute_obj_traj(frames, annots, hand_sides, homography_stack,
55 | hand_threshold=args.hand_threshold,
56 | obj_threshold=args.obj_threshold,
57 | contact_ratio=args.contact_ratio)
58 | if results_obj is None:
59 | print("compute obj traj failed")
60 | else:
61 | contacts, obj_trajs, active_obj, active_object_idx, obj_bboxs_traj = results_obj
62 | frame, homography = frames[-1], homography_stack[-1]
63 | affordance_info = compute_obj_affordance(frame, annots[-1], active_obj, active_object_idx, homography,
64 | active_obj_traj=obj_trajs['traj'], obj_bboxs_traj=obj_bboxs_traj,
65 | num_points=args.num_points, num_sampling=args.num_sampling)
66 | if affordance_info is not None:
67 | img_vis = vis_hand_traj(frames, hand_trajs)
68 | img_vis = vis_affordance(img_vis, affordance_info)
69 | img = cv2.hconcat([img_vis, frames[-1]])
70 | cv2.imwrite(os.path.join(save_path, "demo_{}.jpg".format(uid)), img)
71 | save_video_info(save_path, uid, frames_idxs, homography_stack, contacts, hand_trajs, obj_trajs, affordance_info)
72 | print(f"result stored at {save_path}")
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: fhoi
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.6
6 | - pip
7 | - pip:
8 | - pandas
9 | - numpy
10 | - matplotlib
11 | - dataclasses
12 | - protobuf
13 | - Scipy
14 | - opencv-contrib-python==3.4.2.16
--------------------------------------------------------------------------------
/evaluation/affordance_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 | from sklearn.cluster import KMeans
5 | from joblib import Parallel, delayed
6 |
7 |
8 | def farthest_sampling(pcd, n_samples):
9 | def compute_distance(a, b):
10 | return np.linalg.norm(a - b, ord=2, axis=2)
11 |
12 | n_pts, dim = pcd.shape[0], pcd.shape[1]
13 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim))
14 | remaining_pts = np.copy(pcd)
15 |
16 | if n_pts > 1:
17 | start_idx = np.random.randint(low=0, high=n_pts - 1)
18 | else:
19 | start_idx = 0
20 | selected_pts_expanded[0] = remaining_pts[start_idx]
21 | n_selected_pts = 1
22 |
23 | for _ in range(1, n_samples):
24 | if n_selected_pts < n_samples:
25 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T
26 | dist_pts_to_selected_min = np.min(dist_pts_to_selected, axis=1, keepdims=True)
27 | res_selected_idx = np.argmax(dist_pts_to_selected_min)
28 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx]
29 | n_selected_pts += 1
30 |
31 | selected_pts = np.squeeze(selected_pts_expanded, axis=1)
32 | return selected_pts
33 |
34 |
35 | def makeGaussian(size, fwhm=3., center=None):
36 | x = np.arange(0, size, 1, float)
37 | y = x[:, np.newaxis]
38 | if center is None:
39 | x0 = y0 = size // 2
40 | else:
41 | x0 = center[0]
42 | y0 = center[1]
43 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2)
44 |
45 |
46 | def compute_heatmap(normalized_points, image_size, k_ratio=3.0, transpose=True,
47 | fps=False, kmeans=False, n_pts=5, gaussian_sigma=0.):
48 | normalized_points = np.asarray(normalized_points)
49 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32)
50 | n_points = normalized_points.shape[0]
51 | if n_points > n_pts and kmeans:
52 | kmeans = KMeans(n_clusters=n_pts, random_state=0).fit(normalized_points)
53 | normalized_points = kmeans.cluster_centers_
54 | elif n_points > n_pts and fps:
55 | normalized_points = farthest_sampling(normalized_points, n_samples=n_pts)
56 | n_points = normalized_points.shape[0]
57 | for i in range(n_points):
58 | x = normalized_points[i, 0] * image_size[0]
59 | y = normalized_points[i, 1] * image_size[1]
60 | col = int(x)
61 | row = int(y)
62 | try:
63 | heatmap[col, row] += 1.0
64 |         except IndexError:  # point fell outside the heatmap; clamp to the border
65 | col = min(max(col, 0), image_size[0] - 1)
66 | row = min(max(row, 0), image_size[1] - 1)
67 | heatmap[col, row] += 1.0
68 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio)
69 | if k_size % 2 == 0:
70 | k_size += 1
71 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), gaussian_sigma)
72 | if heatmap.max() > 0:
73 | heatmap /= heatmap.max()
74 | if transpose:
75 | heatmap = heatmap.transpose()
76 | return heatmap
77 |
78 |
79 | def SIM(map1, map2, eps=1e-12):
80 | map1, map2 = map1 / (map1.sum() + eps), map2 / (map2.sum() + eps)
81 | intersection = np.minimum(map1, map2)
82 | return np.sum(intersection)
83 |
84 |
85 | def AUC_Judd(saliency_map, fixation_map, jitter=True):
86 | saliency_map = np.array(saliency_map, copy=False)
87 | fixation_map = np.array(fixation_map, copy=False) > 0.5
88 | if not np.any(fixation_map):
89 | return np.nan
90 | if saliency_map.shape != fixation_map.shape:
91 | saliency_map = cv2.resize(saliency_map, fixation_map.shape, interpolation=cv2.INTER_AREA)
92 | if jitter:
93 | saliency_map += np.random.rand(*saliency_map.shape) * 1e-7
94 | saliency_map = (saliency_map - np.min(saliency_map)) / (np.max(saliency_map) - np.min(saliency_map) + 1e-12)
95 |
96 | S = saliency_map.ravel()
97 | F = fixation_map.ravel()
98 | S_fix = S[F]
99 | n_fix = len(S_fix)
100 | n_pixels = len(S)
101 | thresholds = sorted(S_fix, reverse=True)
102 | tp = np.zeros(len(thresholds) + 2)
103 | fp = np.zeros(len(thresholds) + 2)
104 |     tp[0] = 0
105 |     tp[-1] = 1
106 |     fp[0] = 0
107 |     fp[-1] = 1
108 | for k, thresh in enumerate(thresholds):
109 | above_th = np.sum(S >= thresh)
110 | tp[k + 1] = (k + 1) / float(n_fix)
111 | fp[k + 1] = (above_th - (k + 1)) / float(n_pixels - n_fix)
112 | return np.trapz(tp, fp)
113 |
114 |
115 | def NSS(saliency_map, fixation_map):
116 | MAP = (saliency_map - saliency_map.mean()) / (saliency_map.std())
117 |     mask = fixation_map.astype(bool)
118 | score = MAP[mask].mean()
119 | return score
120 |
121 |
122 | def compute_score(pred, gt, valid_thresh=0.001):
123 | if torch.is_tensor(pred):
124 | pred = pred.numpy()
125 | if torch.is_tensor(gt):
126 | gt = gt.numpy()
127 |
128 | pred = pred / (pred.max() + 1e-12)
129 |
130 | all_thresh = np.linspace(0.001, 1.0, 41)
131 | tp = np.zeros((all_thresh.shape[0],))
132 | fp = np.zeros((all_thresh.shape[0],))
133 | fn = np.zeros((all_thresh.shape[0],))
134 | tn = np.zeros((all_thresh.shape[0],))
135 | valid_gt = gt > valid_thresh
136 | for idx, thresh in enumerate(all_thresh):
137 | mask = (pred >= thresh)
138 | tp[idx] += np.sum(np.logical_and(mask == 1, valid_gt == 1))
139 | tn[idx] += np.sum(np.logical_and(mask == 0, valid_gt == 0))
140 | fp[idx] += np.sum(np.logical_and(mask == 1, valid_gt == 0))
141 | fn[idx] += np.sum(np.logical_and(mask == 0, valid_gt == 1))
142 |
143 | scores = {}
144 | gt_real = np.array(gt)
145 | if gt_real.sum() == 0:
146 | gt_real = np.ones(gt_real.shape) / np.product(gt_real.shape)
147 |
148 | score = SIM(pred, gt_real)
149 | scores['SIM'] = score if not np.isnan(score) else None
150 |
151 | gt_binary = np.array(gt)
152 |     gt_binary = gt_binary / (gt_binary.max() + 1e-12) if gt_binary.max() > 0 else gt_binary
153 | gt_binary = np.where(gt_binary > 0.5, 1, 0)
154 | score = AUC_Judd(pred, gt_binary)
155 | scores['AUC-J'] = score if not np.isnan(score) else None
156 |
157 | score = NSS(pred, gt_binary)
158 | scores['NSS'] = score if not np.isnan(score) else None
159 |
160 | return dict(scores), tp, tn, fp, fn
161 |
162 |
163 | def evaluate_affordance(preds_dict, gts_dict, val_log=None,
164 | sz=32, fps=False, kmeans=False, n_pts=5,
165 | gaussian_sigma=3., gaussian_k_ratio=3.):
166 | scores = []
167 | all_thresh = np.linspace(0.001, 1.0, 41)
168 | tp = np.zeros((all_thresh.shape[0],))
169 | fp = np.zeros((all_thresh.shape[0],))
170 | fn = np.zeros((all_thresh.shape[0],))
171 | tn = np.zeros((all_thresh.shape[0],))
172 |
173 | pred_hmaps = Parallel(n_jobs=16, verbose=0)(delayed(compute_heatmap)(norm_contacts, (sz, sz),
174 | fps=fps, kmeans=kmeans, n_pts=n_pts,
175 | gaussian_sigma=gaussian_sigma,
176 | k_ratio=gaussian_k_ratio)
177 | for (uid, norm_contacts) in preds_dict.items())
178 | gt_hmaps = Parallel(n_jobs=16, verbose=0)(delayed(compute_heatmap)(norm_contacts, (sz, sz),
179 | fps=fps, n_pts=n_pts,
180 | gaussian_sigma=0,
181 | k_ratio=3.)
182 | for (uid, norm_contacts) in gts_dict.items())
183 |
184 | for (pred_hmap, gt_hmap) in zip(pred_hmaps, gt_hmaps):
185 | score, ctp, ctn, cfp, cfn = compute_score(pred_hmap, gt_hmap)
186 | scores.append(score)
187 | tp = tp + ctp
188 | tn = tn + ctn
189 | fp = fp + cfp
190 | fn = fn + cfn
191 |
192 | write_out = []
193 | metrics = {}
194 | for key in ['SIM', 'AUC-J', 'NSS']:
195 | key_score = [s[key] for s in scores if s[key] is not None]
196 | mean, stderr = np.mean(key_score), np.std(key_score) / (np.sqrt(len(key_score)))
197 | log_str = '%s: %.3f ± %.3f (%d/%d)' % (key, mean, stderr, len(key_score), len(gts_dict))
198 | write_out.append(log_str)
199 | metrics[key] = mean
200 | write_out = '\n'.join(write_out)
201 | print(write_out)
202 | if val_log is not None:
203 | with open(val_log, "a") as log_file:
204 | log_file.write(write_out + '\n')
205 |
206 | write_out = []
207 | prec = tp / (tp + fp + 1e-6)
208 | recall = tp / (tp + fn + 1e-6)
209 | f1 = 2 * prec * recall / (prec + recall + 1e-6)
210 | idx = np.argmax(f1)
211 | prec_score = prec[idx]
212 | f1_score = f1[idx]
213 | recall_score = recall[idx]
214 |
215 | log_str = 'Precision: {:.3f}'.format(prec_score)
216 | write_out.append(log_str)
217 | log_str = 'Recall: {:0.4f}'.format(recall_score)
218 | write_out.append(log_str)
219 | log_str = 'F1 Score: {:0.4f}'.format(f1_score)
220 | write_out.append(log_str)
221 | write_out = '\n'.join(write_out)
222 | print(write_out)
223 | if val_log is not None:
224 | with open(val_log, "a") as log_file:
225 | log_file.write(write_out + '\n')
226 |
227 | return metrics
--------------------------------------------------------------------------------
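
A minimal usage sketch (not part of the repository) tying the helpers above together; the toy contact points and the 32×32 heatmap size are assumptions made only to illustrate the expected inputs of `compute_heatmap` and `compute_score`.

```python
import numpy as np
from evaluation.affordance_eval import compute_heatmap, compute_score

pred_points = np.array([[0.40, 0.55], [0.42, 0.50]])  # predicted contacts, normalized (x, y)
gt_points = np.array([[0.45, 0.52]])                  # ground-truth contacts, normalized (x, y)

# predictions are blurred with a non-zero sigma, ground truth with sigma 0,
# mirroring the settings used in evaluate_affordance
pred_hmap = compute_heatmap(pred_points, (32, 32), gaussian_sigma=3., k_ratio=3.)
gt_hmap = compute_heatmap(gt_points, (32, 32), gaussian_sigma=0., k_ratio=3.)

scores, tp, tn, fp, fn = compute_score(pred_hmap, gt_hmap)
print(scores)  # {'SIM': ..., 'AUC-J': ..., 'NSS': ...}
```
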
/evaluation/traj_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def compute_ade(pred_traj, gt_traj, valid_traj=None, reduction=True):
6 | valid_loc = (gt_traj[:, :, :, 0] >= 0) & (gt_traj[:, :, :, 1] >= 0) \
7 | & (gt_traj[:, :, :, 0] < 1) & (gt_traj[:, :, :, 1] < 1)
8 |
9 | error = gt_traj - pred_traj
10 | error = error * valid_loc[:, :, :, None]
11 |
12 | if torch.is_tensor(error):
13 | if valid_traj is None:
14 | valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1])
15 | error = error ** 2
16 | ade = torch.sqrt(error.sum(dim=3)).mean(dim=2) * valid_traj
17 | if reduction:
18 | ade = ade.sum() / valid_traj.sum()
19 | valid_traj = valid_traj.sum()
20 | else:
21 | if valid_traj is None:
22 | valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int)
23 | error = np.linalg.norm(error, axis=3)
24 | ade = error.mean(axis=2) * valid_traj
25 | if reduction:
26 | ade = ade.sum() / valid_traj.sum()
27 | valid_traj = valid_traj.sum()
28 |
29 | return ade, valid_traj
30 |
31 |
32 | def compute_fde(pred_traj, gt_traj, valid_traj=None, reduction=True):
33 | pred_last = pred_traj[:, :, -1, :]
34 | gt_last = gt_traj[:, :, -1, :]
35 |
36 | valid_loc = (gt_last[:, :, 0] >= 0) & (gt_last[:, :, 1] >= 0) \
37 | & (gt_last[:, :, 0] < 1) & (gt_last[:, :, 1] < 1)
38 |
39 | error = gt_last - pred_last
40 | error = error * valid_loc[:, :, None]
41 |
42 | if torch.is_tensor(error):
43 | if valid_traj is None:
44 | valid_traj = torch.ones(pred_traj.shape[0], pred_traj.shape[1])
45 | error = error ** 2
46 | fde = torch.sqrt(error.sum(dim=2)) * valid_traj
47 | if reduction:
48 | fde = fde.sum() / valid_traj.sum()
49 | valid_traj = valid_traj.sum()
50 | else:
51 | if valid_traj is None:
52 | valid_traj = np.ones((pred_traj.shape[0], pred_traj.shape[1]), dtype=int)
53 | error = np.linalg.norm(error, axis=2)
54 | fde = error * valid_traj
55 | if reduction:
56 | fde = fde.sum() / valid_traj.sum()
57 | valid_traj = valid_traj.sum()
58 |
59 | return fde, valid_traj
60 |
61 |
62 | def evaluate_traj_stochastic(preds, gts, valids):
63 | len_dataset, num_samples, num_obj = preds.shape[0], preds.shape[1], preds.shape[2]
64 | ade_list, fde_list = [], []
65 | for idx in range(num_samples):
66 |         ade, _ = compute_ade(preds[:, idx, :, :, :], gts, valids, reduction=False)
67 |         ade_list.append(ade)
68 |         fde, _ = compute_fde(preds[:, idx, :, :, :], gts, valids, reduction=False)
69 |         fde_list.append(fde)
70 |
71 | if torch.is_tensor(preds):
72 | ade_list = torch.stack(ade_list, dim=0)
73 | fde_list = torch.stack(fde_list, dim=0)
74 |
75 | ade_err_min, _ = torch.min(ade_list, dim=0)
76 | ade_err_min = ade_err_min * valids
77 | fde_err_min, _ = torch.min(fde_list, dim=0)
78 | fde_err_min = fde_err_min * valids
79 |
80 | ade_err_mean = torch.mean(ade_list, dim=0)
81 | ade_err_mean = ade_err_mean * valids
82 | fde_err_mean = torch.mean(fde_list, dim=0)
83 | fde_err_mean = fde_err_mean * valids
84 |
85 | ade_err_std = torch.std(ade_list, dim=0) * np.sqrt((ade_list.shape[0] - 1.) / ade_list.shape[0])
86 | ade_err_std = ade_err_std * valids
87 | fde_err_std = torch.std(fde_list, dim=0) * np.sqrt((fde_list.shape[0] - 1.) / fde_list.shape[0])
88 | fde_err_std = fde_err_std * valids
89 |
90 | else:
91 | ade_list = np.array(ade_list, dtype=np.float32)
92 | fde_list = np.array(fde_list, dtype=np.float32)
93 |
94 | ade_err_min = ade_list.min(axis=0) * valids
95 | fde_err_min = fde_list.min(axis=0) * valids
96 |
97 | ade_err_mean = ade_list.mean(axis=0) * valids
98 | fde_err_mean = fde_list.mean(axis=0) * valids
99 |
100 | ade_err_std = ade_list.std(axis=0) * valids
101 | fde_err_std = fde_list.std(axis=0) * valids
102 |
103 | ade_mean = ade_err_mean.sum() / valids.sum()
104 | fde_mean = fde_err_mean.sum() / valids.sum()
105 |
106 | ade_std = ade_err_std.sum() / valids.sum()
107 | fde_std = fde_err_std.sum() / valids.sum()
108 | ade_mean_info = 'ADE: %.3f ± %.3f (%d/%d)' % (ade_mean, ade_std, valids.sum(), len_dataset * num_obj)
109 | fde_mean_info = "FDE: %.3f ± %.3f (%d/%d)" % (fde_mean, fde_std, valids.sum(), len_dataset * num_obj)
110 |
111 | ade_min = ade_err_min.sum() / valids.sum()
112 | fde_min = fde_err_min.sum() / valids.sum()
113 | ade_min_info = 'min ADE: %.3f (%d/%d)' % (ade_min, valids.sum(), len_dataset * num_obj)
114 | fde_min_info = "min FDE: %.3f (%d/%d)" % (fde_min, valids.sum(), len_dataset * num_obj)
115 |
116 | print(ade_min_info)
117 | print(fde_min_info)
118 | print(ade_mean_info)
119 | print(fde_mean_info)
120 |
121 | return ade_mean, fde_mean
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
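
A minimal shape sketch (not part of the repository) for the trajectory metrics above; the batch size, number of stochastic samples, and horizon length below are assumptions.

```python
import numpy as np
from evaluation.traj_eval import compute_ade, compute_fde, evaluate_traj_stochastic

B, S, H, T = 2, 5, 2, 4                                    # batch, samples, hands, future steps
preds = np.random.rand(B, S, H, T, 2).astype(np.float32)   # predicted trajectories, normalized (x, y)
gts = np.random.rand(B, H, T, 2).astype(np.float32)        # ground-truth trajectories
valids = np.ones((B, H), dtype=np.float32)                 # 1 where a hand trajectory is annotated

ade, _ = compute_ade(preds[:, 0], gts, valids)   # mean displacement over the whole horizon
fde, _ = compute_fde(preds[:, 0], gts, valids)   # displacement at the final step only
evaluate_traj_stochastic(preds, gts, valids)     # min/mean ADE and FDE over the S samples
```
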
/netscripts/epoch_feat.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 |
5 | from netscripts.epoch_utils import progress_bar as bar, AverageMeters
6 | from evaluation.traj_eval import evaluate_traj_stochastic
7 | from evaluation.affordance_eval import evaluate_affordance
8 |
9 |
10 | def epoch_pass(loader, model, epoch, phase, optimizer=None,
11 | train=True, use_cuda=False,
12 | num_samples=5, pred_len=4, num_points=5, gaussian_sigma=3., gaussian_k_ratio=3.,
13 | scheduler=None):
14 | time_meters = AverageMeters()
15 |
16 | if train:
17 | print(f"{phase} epoch: {epoch + 1}")
18 | loss_meters = AverageMeters()
19 | model.train()
20 | else:
21 | print(f"evaluate epoch {epoch}")
22 | preds_traj, gts_traj, valids_traj = [], [], []
23 | gts_affordance_dict, preds_affordance_dict = {}, {}
24 | model.eval()
25 |
26 | if use_cuda and torch.cuda.is_available():
27 | device = torch.device('cuda')
28 | else:
29 | device = torch.device("cpu")
30 | end = time.time()
31 | for batch_idx, sample in enumerate(loader):
32 | if train:
33 | input = sample['feat'].float().to(device)
34 | bbox_feat = sample['bbox_feat'].float().to(device)
35 | valid_mask = sample['valid_mask'].float().to(device)
36 | future_hands = sample['future_hands'].float().to(device)
37 | contact_point = sample['contact_point'].float().to(device)
38 | future_valid = sample['future_valid'].float().to(device)
39 | time_meters.add_loss_value("data_time", time.time() - end)
40 | model_loss, model_losses = model(input, bbox_feat=bbox_feat,
41 | valid_mask=valid_mask, future_hands=future_hands,
42 | contact_point=contact_point, future_valid=future_valid)
43 |
44 | optimizer.zero_grad()
45 | model_loss.backward()
46 | optimizer.step()
47 |
48 | for key, val in model_losses.items():
49 | if val is not None:
50 | loss_meters.add_loss_value(key, val)
51 |
52 | time_meters.add_loss_value("batch_time", time.time() - end)
53 |
54 | suffix = "({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s " \
55 | "| Hand Traj Loss: {traj_loss:.3f} " \
56 | "| Hand Traj KL Loss: {traj_kl_loss:.3f} " \
57 | "| Object Affordance Loss: {obj_loss:.3f} " \
58 | "| Object Affordance KL Loss: {obj_kl_loss:.3f} " \
59 | "| Total Loss: {total_loss:.3f} ".format(batch=batch_idx + 1, size=len(loader),
60 | data=time_meters.average_meters["data_time"].val,
61 | bt=time_meters.average_meters["batch_time"].avg,
62 | traj_loss=loss_meters.average_meters["traj_loss"].avg,
63 | traj_kl_loss=loss_meters.average_meters[
64 | "traj_kl_loss"].avg,
65 | obj_loss=loss_meters.average_meters["obj_loss"].avg,
66 | obj_kl_loss=loss_meters.average_meters["obj_kl_loss"].avg,
67 | total_loss=loss_meters.average_meters[
68 | "total_loss"].avg)
69 | bar(suffix)
70 | end = time.time()
71 | if scheduler is not None:
72 | scheduler.step()
73 | else:
74 | input = sample['feat'].float().to(device)
75 | bbox_feat = sample['bbox_feat'].float().to(device)
76 | valid_mask = sample['valid_mask'].float().to(device)
77 | future_valid = sample['future_valid'].float().to(device)
78 |
79 | time_meters.add_loss_value("data_time", time.time() - end)
80 |
81 | with torch.no_grad():
82 | pred_future_hands, contact_points = model(input, bbox_feat, valid_mask,
83 | num_samples=num_samples,
84 | future_valid=future_valid,
85 | pred_len=pred_len)
86 |
87 | uids = sample['uid'].numpy()
88 | future_hands = sample['future_hands'][:, :, 1:, :].float().numpy()
89 | future_valid = sample['future_valid'].float().numpy()
90 |
91 | gts_traj.append(future_hands)
92 | valids_traj.append(future_valid)
93 |
94 | pred_future_hands = pred_future_hands.cpu().numpy()
95 | preds_traj.append(pred_future_hands)
96 |
97 | if 'eval' in loader.dataset.partition:
98 | contact_points = contact_points.cpu().numpy()
99 | for idx, uid in enumerate(uids):
100 | gts_affordance_dict[uid] = loader.dataset.eval_labels[uid]['norm_contacts']
101 | preds_affordance_dict[uid] = contact_points[idx]
102 |
103 | time_meters.add_loss_value("batch_time", time.time() - end)
104 |
105 | suffix = "({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s" \
106 | .format(batch=batch_idx + 1, size=len(loader),
107 | data=time_meters.average_meters["data_time"].val,
108 | bt=time_meters.average_meters["batch_time"].avg)
109 |
110 | bar(suffix)
111 | end = time.time()
112 |
113 | if train:
114 | return loss_meters
115 | else:
116 | val_info = {}
117 | if phase == "traj":
118 | gts_traj = np.concatenate(gts_traj)
119 | preds_traj = np.concatenate(preds_traj)
120 | valids_traj = np.concatenate(valids_traj)
121 |
122 | ade, fde = evaluate_traj_stochastic(preds_traj, gts_traj, valids_traj)
123 | val_info.update({"traj_ade": ade, "traj_fde": fde})
124 |
125 | if 'eval' in loader.dataset.partition and phase == "affordance":
126 | affordance_metrics = evaluate_affordance(preds_affordance_dict,
127 | gts_affordance_dict,
128 | n_pts=num_points,
129 | gaussian_sigma=gaussian_sigma,
130 | gaussian_k_ratio=gaussian_k_ratio)
131 | val_info.update(affordance_metrics)
132 |
133 | return val_info
134 |
--------------------------------------------------------------------------------
/netscripts/epoch_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | def progress_bar(msg=None):
5 |
6 | L = []
7 | if msg:
8 | L.append(msg)
9 |
10 | msg = ''.join(L)
11 | sys.stdout.write(msg+'\n')
12 |
13 |
14 | class AverageMeter(object):
15 | """Computes and stores the average and current value"""
16 |
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=1):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / self.count
31 |
32 |
33 | class AverageMeters:
34 | def __init__(self):
35 | super().__init__()
36 | self.average_meters = {}
37 |
38 | def add_loss_value(self, loss_name, loss_val, n=1):
39 | if loss_name not in self.average_meters:
40 | self.average_meters[loss_name] = AverageMeter()
41 | self.average_meters[loss_name].update(loss_val, n=n)
--------------------------------------------------------------------------------
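
A minimal sketch (not part of the repository) of how these meters are used inside an epoch loop; the loss values are placeholders.

```python
from netscripts.epoch_utils import AverageMeters, progress_bar

meters = AverageMeters()
for step, loss in enumerate([0.9, 0.7, 0.5]):   # placeholder per-batch losses
    meters.add_loss_value("traj_loss", loss)
    progress_bar("(%d/3) traj_loss %.3f" % (step + 1, meters.average_meters["traj_loss"].avg))
```
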
/netscripts/get_datasets.py:
--------------------------------------------------------------------------------
1 | from datasets.datasetopts import DatasetArgs
2 | from datasets.holoaders import EpicHODataset as HODataset, FeaturesHOLoader, get_dataloaders
3 |
4 |
5 | def get_dataset(args, base_path="./"):
6 | if args.evaluate:
7 | mode = "validation"
8 | else:
9 | mode = 'train'
10 |
11 | datasetargs = DatasetArgs(ek_version=args.ek_version, mode=mode,
12 | use_label_only=True, base_path=base_path,
13 | batch_size=args.batch_size, num_workers=args.workers)
14 |
15 | dls = get_dataloaders(datasetargs, HODataset, featuresloader=FeaturesHOLoader)
16 | return mode, dls
17 |
--------------------------------------------------------------------------------
/netscripts/get_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from networks.traj_decoder import TrajCVAE
3 | from networks.affordance_decoder import AffordanceCVAE
4 | from networks.transformer import ObjectTransformer
5 | from networks.model import Model
6 |
7 |
8 | def get_network(args, num_frames_input=10, num_frames_output=4):
9 | hand_head = TrajCVAE(in_dim=2, hidden_dim=args.hidden_dim,
10 | latent_dim=args.latent_dim, condition_dim=args.embed_dim,
11 | coord_dim=args.coord_dim)
12 | obj_head = AffordanceCVAE(in_dim=2, hidden_dim=args.hidden_dim,
13 | latent_dim=args.latent_dim, condition_dim=args.embed_dim)
14 | net = ObjectTransformer(src_in_features=args.src_in_features,
15 | trg_in_features=args.trg_in_features,
16 | num_patches=args.num_patches,
17 | hand_head=hand_head, obj_head=obj_head,
18 | encoder_time_embed_type=args.encoder_time_embed_type,
19 | decoder_time_embed_type=args.decoder_time_embed_type,
20 | num_frames_input=num_frames_input,
21 | num_frames_output=num_frames_output,
22 | embed_dim=args.embed_dim, coord_dim=args.coord_dim,
23 | num_heads=args.num_heads, enc_depth=args.enc_depth, dec_depth=args.dec_depth)
24 | net = torch.nn.DataParallel(net)
25 |
26 | model = Model(net, lambda_obj=args.lambda_obj, lambda_traj=args.lambda_traj,
27 | lambda_obj_kl=args.lambda_obj_kl, lambda_traj_kl=args.lambda_traj_kl)
28 | return model
29 |
--------------------------------------------------------------------------------
/netscripts/get_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Warmup(torch.optim.lr_scheduler._LRScheduler):
5 | def __init__(
6 | self,
7 | optimizer: torch.optim.Optimizer,
8 | scheduler: torch.optim.lr_scheduler._LRScheduler,
9 | init_lr_ratio: float = 0.0,
10 | num_epochs: int = 5,
11 | last_epoch: int = -1,
12 | iters_per_epoch: int = None,
13 | ):
14 | self.base_scheduler = scheduler
15 | self.warmup_iters = max(num_epochs * iters_per_epoch, 1)
16 | if self.warmup_iters > 1:
17 | self.init_lr_ratio = init_lr_ratio
18 | else:
19 | self.init_lr_ratio = 1.0
20 | super().__init__(optimizer, last_epoch)
21 |
22 | def get_lr(self):
23 | assert self.last_epoch < self.warmup_iters
24 | return [
25 | el * (self.init_lr_ratio + (1 - self.init_lr_ratio) *
26 | (float(self.last_epoch) / self.warmup_iters))
27 | for el in self.base_lrs
28 | ]
29 |
30 | def step(self, *args, **kwargs):
31 | if self.last_epoch < (self.warmup_iters - 1):
32 | super().step(*args, **kwargs)
33 | else:
34 | self.base_scheduler.step(*args, **kwargs)
35 |
36 |
37 | def get_optimizer(args, model, train_loader):
38 | assert train_loader is not None, "train_loader is None, " \
39 |                                      "warmup and cosine learning rate schedules need the number of iterations in the train dataloader"
40 | iters_per_epoch = len(train_loader)
41 | vae_params = [p for p_name, p in model.named_parameters()
42 | if ('vae' in p_name or 'head' in p_name) and p.requires_grad]
43 | other_params = [p for p_name, p in model.named_parameters()
44 | if ('vae' not in p_name and 'head' not in p_name) and p.requires_grad]
45 |
46 | if args.optimizer == "adam":
47 | optimizer = torch.optim.Adam([{'params': vae_params, 'weight_decay': 0.0}, {'params': other_params}],
48 | lr=args.lr, weight_decay=args.weight_decay)
49 | elif args.optimizer == "rms":
50 | optimizer = torch.optim.RMSprop([{'params': vae_params, 'weight_decay': 0.0}, {'params': other_params}],
51 | lr=args.lr, weight_decay=args.weight_decay)
52 | elif args.optimizer == "sgd":
53 | optimizer = torch.optim.SGD([{'params': vae_params, 'weight_decay': 0.0}, {'params': other_params}],
54 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
55 | elif args.optimizer == 'adamw':
56 | optimizer = torch.optim.AdamW([{'params': vae_params, 'weight_decay': 0.0}, {'params': other_params}], lr=args.lr)
57 | else:
58 | raise ValueError("unsupported optimizer type")
59 |
60 | for group in optimizer.param_groups:
61 | group["lr"] = args.lr
62 | group["initial_lr"] = args.lr
63 |
64 | if args.scheduler == "step":
65 |         assert isinstance(args.lr_decay_step, int), "step learning rate scheduler needs an integer lr_decay_step"
66 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)
67 | elif args.scheduler == "multistep":
68 | if isinstance(args.lr_decay_step, list):
69 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_step, gamma=args.lr_decay_gamma)
70 | else:
71 |             scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epochs // 2], gamma=0.1)
72 | elif args.scheduler == "cosine":
73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs*iters_per_epoch,
74 | last_epoch=-1, eta_min=0)
75 | else:
76 | raise ValueError("Unrecognized learning rate scheduler {}".format(args.scheduler))
77 |
78 | main_scheduler = Warmup(optimizer, scheduler, init_lr_ratio=0., num_epochs=args.warmup_epochs,
79 | iters_per_epoch=iters_per_epoch)
80 | return optimizer, main_scheduler
81 |
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
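
A minimal sketch (not part of the repository) of wiring `get_optimizer` to a toy model; the hyper-parameter values in `args` are placeholders rather than the repository defaults, and the 100-iteration loader stands in for the real dataloader.

```python
import torch
from types import SimpleNamespace
from netscripts.get_optimizer import get_optimizer

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)
        self.hand_head = torch.nn.Linear(8, 2)   # "head"/"vae" parameters skip weight decay

args = SimpleNamespace(optimizer="adamw", lr=1e-4, weight_decay=1e-4, momentum=0.9,
                       scheduler="cosine", lr_decay_step=10, lr_decay_gamma=0.1,
                       epochs=30, warmup_epochs=5)
train_loader = range(100)                        # anything with len(); 100 iterations per epoch
optimizer, scheduler = get_optimizer(args, ToyNet(), train_loader)

for _ in train_loader:                           # one epoch of dummy updates
    optimizer.step()
    scheduler.step()                             # linear warmup first, then cosine decay
```
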
/netscripts/modelio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import traceback
4 | import warnings
5 | import torch
6 |
7 |
8 | def load_checkpoint(model, resume_path, strict=True, device=None):
9 | if os.path.isfile(resume_path):
10 | print("=> loading checkpoint '{}'".format(resume_path))
11 | if device is not None:
12 | checkpoint = torch.load(resume_path, map_location=device)
13 | else:
14 | checkpoint = torch.load(resume_path)
15 | if "module" in list(checkpoint["state_dict"].keys())[0]:
16 | state_dict = checkpoint["state_dict"]
17 | else:
18 | state_dict = {
19 | "module.{}".format(key): item
20 | for key, item in checkpoint["state_dict"].items()}
21 | print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint["epoch"]))
22 | missing_states = set(model.state_dict().keys()) - set(state_dict.keys())
23 | if len(missing_states) > 0:
24 | warnings.warn("Missing keys ! : {}".format(missing_states))
25 | model.load_state_dict(state_dict, strict=strict)
26 | else:
27 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path))
28 | return checkpoint["epoch"]
29 |
30 |
31 | def save_checkpoint(state, checkpoint="checkpoint", filename="checkpoint.pth.tar"):
32 | filepath = os.path.join(checkpoint, filename)
33 | torch.save(state, filepath)
34 |
--------------------------------------------------------------------------------
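
A minimal sketch (not part of the repository) of the checkpoint round trip; the toy `DataParallel` model and file name are placeholders, chosen only because `load_checkpoint` keys off the "module." prefix.

```python
import torch
from netscripts.modelio import save_checkpoint, load_checkpoint

model = torch.nn.DataParallel(torch.nn.Linear(8, 2))   # state_dict keys get the "module." prefix
save_checkpoint({"epoch": 1, "state_dict": model.state_dict()},
                checkpoint=".", filename="toy_checkpoint.pth.tar")
epoch = load_checkpoint(model, "./toy_checkpoint.pth.tar")   # returns the stored epoch (1)
```
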
/networks/affordance_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from networks.decoder_modules import VAE
4 |
5 |
6 | class AffordanceCVAE(nn.Module):
7 | def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None,
8 | pred_len=4, condition_traj=True, z_scale=2.0):
9 | super().__init__()
10 | self.latent_dim = latent_dim
11 | self.condition_traj = condition_traj
12 | self.z_scale = z_scale
13 | if self.condition_traj:
14 | if coord_dim is None:
15 | coord_dim = hidden_dim // 2
16 | self.coord_dim = coord_dim
17 | self.traj_to_feature = nn.Sequential(
18 | nn.Linear(2*(pred_len+1), coord_dim*(pred_len+1), bias=False),
19 | nn.ELU(inplace=True))
20 | self.traj_context_fusion = nn.Sequential(
21 | nn.Linear(condition_dim+coord_dim*(pred_len+1), condition_dim, bias=False),
22 | nn.ELU(inplace=True))
23 |
24 | self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,
25 | conditional=True, condition_dim=condition_dim)
26 |
27 | def forward(self, context, contact_point, hand_traj=None, return_pred=False):
28 | if self.condition_traj:
29 | assert hand_traj is not None
30 | batch_size = context.shape[0]
31 | hand_traj = hand_traj.reshape(batch_size, -1)
32 | traj_feat = self.traj_to_feature(hand_traj)
33 | fusion_feat = torch.cat([context, traj_feat], dim=1)
34 | condition_context = self.traj_context_fusion(fusion_feat)
35 | else:
36 | condition_context = context
37 | if not return_pred:
38 | recon_loss, KLD = self.cvae(contact_point, c=condition_context)
39 | return recon_loss, KLD
40 | else:
41 | pred_contact, recon_loss, KLD = self.cvae(contact_point, c=condition_context, return_pred=return_pred)
42 | return pred_contact, recon_loss, KLD
43 |
44 | def inference(self, context, hand_traj=None):
45 | if self.condition_traj:
46 | assert hand_traj is not None
47 | batch_size = context.shape[0]
48 | hand_traj = hand_traj.reshape(batch_size, -1)
49 | traj_feat = self.traj_to_feature(hand_traj)
50 | fusion_feat = torch.cat([context, traj_feat], dim=1)
51 | condition_context = self.traj_context_fusion(fusion_feat)
52 | else:
53 | condition_context = context
54 | z = self.z_scale * torch.randn([condition_context.shape[0], self.latent_dim], device=condition_context.device)
55 | recon_x = self.cvae.inference(z, c=condition_context)
56 | return recon_x
--------------------------------------------------------------------------------
/networks/decoder_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class VAE(nn.Module):
6 |
7 | def __init__(self, in_dim, hidden_dim, latent_dim, conditional=False, condition_dim=None):
8 |
9 | super().__init__()
10 |
11 | self.latent_dim = latent_dim
12 | self.conditional = conditional
13 |
14 | if self.conditional and condition_dim is not None:
15 | input_dim = in_dim + condition_dim
16 | dec_dim = latent_dim + condition_dim
17 | else:
18 | input_dim = in_dim
19 | dec_dim = latent_dim
20 | self.enc_MLP = nn.Sequential(
21 | nn.Linear(input_dim, hidden_dim),
22 | nn.ELU())
23 | self.linear_means = nn.Linear(hidden_dim, latent_dim)
24 | self.linear_log_var = nn.Linear(hidden_dim, latent_dim)
25 | self.dec_MLP = nn.Sequential(
26 | nn.Linear(dec_dim, hidden_dim),
27 | nn.ELU(),
28 | nn.Linear(hidden_dim, in_dim))
29 |
30 | def forward(self, x, c=None, return_pred=False):
31 | if self.conditional and c is not None:
32 | inp = torch.cat((x, c), dim=-1)
33 | else:
34 | inp = x
35 | h = self.enc_MLP(inp)
36 | mean = self.linear_means(h)
37 | log_var = self.linear_log_var(h)
38 | z = self.reparameterize(mean, log_var)
39 | if self.conditional and c is not None:
40 | z = torch.cat((z, c), dim=-1)
41 | recon_x = self.dec_MLP(z)
42 | recon_loss, KLD = self.loss_fn(recon_x, x, mean, log_var)
43 | if not return_pred:
44 | return recon_loss, KLD
45 | else:
46 | return recon_x, recon_loss, KLD
47 |
48 | def loss_fn(self, recon_x, x, mean, log_var):
49 | recon_loss = torch.sum((recon_x - x) ** 2, dim=1)
50 | KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1)
51 | return recon_loss, KLD
52 |
53 | def reparameterize(self, mu, log_var):
54 | std = torch.exp(0.5 * log_var)
55 | eps = torch.randn_like(std)
56 | return mu + eps * std
57 |
58 | def inference(self, z, c=None):
59 | if self.conditional and c is not None:
60 | z = torch.cat((z, c), dim=-1)
61 | recon_x = self.dec_MLP(z)
62 | return recon_x
--------------------------------------------------------------------------------
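
A minimal sketch (not part of the repository) of the conditional VAE interface; the 2-D targets, 512-D condition, and batch size are assumptions.

```python
import torch
from networks.decoder_modules import VAE

cvae = VAE(in_dim=2, hidden_dim=512, latent_dim=256, conditional=True, condition_dim=512)
x = torch.rand(8, 2)      # e.g. normalized 2-D targets (waypoints or contact points)
c = torch.rand(8, 512)    # conditioning context

recon_loss, kld = cvae(x, c=c)          # per-sample losses, each of shape (8,)
z = 2.0 * torch.randn(8, 256)           # scaled prior sample, mirroring z_scale in the CVAE heads
samples = cvae.inference(z, c=c)        # decoded samples, shape (8, 2)
```
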
/networks/embedding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class PositionalEncoding(nn.Module):
8 | def __init__(self, d_model, max_len=5000):
9 | super(PositionalEncoding, self).__init__()
10 | pe = torch.zeros(max_len, d_model)
11 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
12 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
13 | pe[:, 0::2] = torch.sin(position * div_term)
14 | pe[:, 1::2] = torch.cos(position * div_term)
15 | pe = pe.unsqueeze(0)
16 |
17 | self.register_buffer('pe', pe)
18 |
19 | def forward(self, x):
20 | return x + self.pe[:, :x.shape[1]]
21 |
22 |
23 | class Encoder_PositionalEmbedding(nn.Module):
24 | def __init__(self, d_model, seq_len):
25 | super(Encoder_PositionalEmbedding, self).__init__()
26 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model))
27 |
28 | def forward(self, x):
29 | B, T = x.shape[:2]
30 | if T != self.position_embedding.size(1):
31 | position_embedding = self.position_embedding.transpose(1, 2)
32 | new_position_embedding = F.interpolate(position_embedding, size=(T), mode='nearest')
33 | new_position_embedding = new_position_embedding.transpose(1, 2)
34 | x = x + new_position_embedding
35 | else:
36 | x = x + self.position_embedding
37 | return x
38 |
39 |
40 | class Decoder_PositionalEmbedding(nn.Module):
41 | def __init__(self, d_model, seq_len):
42 | super(Decoder_PositionalEmbedding, self).__init__()
43 | self.position_embedding = nn.Parameter(torch.zeros(1, seq_len, d_model))
44 |
45 | def forward(self, x):
46 | x = x + self.position_embedding[:, :x.shape[1], :]
47 | return x
48 |
49 |
--------------------------------------------------------------------------------
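
A minimal sketch (not part of the repository) of the fixed sinusoidal time embedding; the batch size, 10 observed steps, and 512-D width are assumptions.

```python
import torch
from networks.embedding import PositionalEncoding

time_embed = PositionalEncoding(d_model=512)
tokens = torch.rand(2, 10, 512)      # (batch, time steps, embed dim)
tokens = time_embed(tokens)          # adds sin/cos codes along the time axis, same shape
```
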
/networks/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.net_utils import DropPath, get_pad_mask
5 | from einops import rearrange
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 |
27 | class ScaledDotProductAttention(nn.Module):
28 | def __init__(self, temperature, attn_dropout=0.1):
29 | super().__init__()
30 | self.temperature = temperature
31 | self.dropout = nn.Dropout(attn_dropout)
32 |
33 | def forward(self, q, k, v, mask=None):
34 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
35 | if mask is not None:
36 | attn = attn.masked_fill(mask == 0, -1e9)
37 | attn = self.dropout(F.softmax(attn, dim=-1))
38 | output = torch.matmul(attn, v)
39 |
40 | return output, attn
41 |
42 |
43 | class MultiHeadAttention(nn.Module):
44 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True):
45 | super().__init__()
46 | self.num_heads = num_heads
47 | head_dim = dim // num_heads
48 | self.with_qkv = with_qkv
49 | if self.with_qkv:
50 | self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
51 | self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
52 | self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
53 |
54 | self.proj = nn.Linear(dim, dim)
55 | self.proj_drop = nn.Dropout(proj_drop)
56 | self.attention = ScaledDotProductAttention(temperature=qk_scale or head_dim ** 0.5)
57 | self.attn_drop = nn.Dropout(attn_drop)
58 |
59 | def forward(self, q, k, v, mask=None):
60 | B, Nq, Nk, Nv, C = q.shape[0], q.shape[1], k.shape[1], v.shape[1], q.shape[2]
61 | if self.with_qkv:
62 | q = self.proj_q(q).reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2)
63 | k = self.proj_k(k).reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2)
64 | v = self.proj_v(v).reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2)
65 | else:
66 | q = q.reshape(B, Nq, self.num_heads, C // self.num_heads).transpose(1, 2)
67 | k = k.reshape(B, Nk, self.num_heads, C // self.num_heads).transpose(1, 2)
68 | v = v.reshape(B, Nv, self.num_heads, C // self.num_heads).transpose(1, 2)
69 | if mask is not None:
70 | mask = mask.unsqueeze(1)
71 |
72 | x, attn = self.attention(q, k, v, mask=mask)
73 | x = x.transpose(1, 2).reshape(B, Nq, C)
74 | if self.with_qkv:
75 | x = self.proj(x)
76 | x = self.proj_drop(x)
77 | return x
78 |
79 |
80 | class EncoderBlock(nn.Module):
81 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
82 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
83 | super().__init__()
84 | self.norm1 = norm_layer(dim)
85 | self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
86 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
87 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
88 | self.norm2 = norm_layer(dim)
89 | mlp_hidden_dim = int(dim * mlp_ratio)
90 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
91 |
92 | def forward(self, x, B, T, N, mask=None):
93 |
94 | if mask is not None:
95 | src_mask = rearrange(mask, 'b n t -> b (n t)', b=B, n=N, t=T)
96 | src_mask = get_pad_mask(src_mask, 0)
97 | else:
98 | src_mask = None
99 | x2 = self.norm1(x)
100 | x = x + self.drop_path(self.attn(q=x2, k=x2, v=x2, mask=src_mask))
101 | x = x + self.drop_path(self.mlp(self.norm2(x)))
102 | return x
103 |
104 |
105 | class DecoderBlock(nn.Module):
106 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
107 | drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
108 | super().__init__()
109 | self.norm1 = norm_layer(dim)
110 | self.self_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
111 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
112 |
113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114 |
115 | self.norm2 = norm_layer(dim)
116 | self.enc_dec_attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
117 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
118 |
119 | self.norm3 = nn.LayerNorm(dim)
120 | mlp_hidden_dim = int(dim * mlp_ratio)
121 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
122 |
123 | def forward(self, tgt, memory, memory_mask=None, trg_mask=None):
124 | tgt_2 = self.norm1(tgt)
125 | tgt = tgt + self.drop_path(self.self_attn(q=tgt_2, k=tgt_2, v=tgt_2, mask=trg_mask))
126 | tgt = tgt + self.drop_path(self.enc_dec_attn(q=self.norm2(tgt), k=memory, v=memory, mask=memory_mask))
127 |         tgt = tgt + self.drop_path(self.mlp(self.norm3(tgt)))
128 | return tgt
129 |
--------------------------------------------------------------------------------
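
A minimal sketch (not part of the repository) of a single pass through the two blocks; the token layout (patches flattened over time) follows how `Encoder` calls `EncoderBlock`, while the batch size, frame count, and width below are assumptions.

```python
import torch
from networks.layer import EncoderBlock, DecoderBlock

B, T, N, dim = 2, 10, 5, 512                  # batch, observed frames, patches, embed dim
x = torch.rand(B, N * T, dim)                 # encoder tokens, flattened over (patch, time)
mask = torch.ones(B, N, T)                    # 1 = valid token, 0 = padded

enc = EncoderBlock(dim=dim, num_heads=8)
x = enc(x, B, T, N, mask=mask)                # (B, N * T, dim)

dec = DecoderBlock(dim=dim, num_heads=8)
queries = torch.rand(B, 4, dim)               # e.g. 4 future-step queries
out = dec(queries, x)                         # cross-attends to the encoder memory, (B, 4, dim)
```
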
/networks/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Model(nn.Module):
6 |
7 | def __init__(self, net, lambda_obj=None, lambda_traj=None, lambda_obj_kl=None, lambda_traj_kl=None):
8 | super(Model, self).__init__()
9 | self.net = net
10 | self.lambda_obj = lambda_obj
11 | self.lambda_obj_kl = lambda_obj_kl
12 | self.lambda_traj = lambda_traj
13 | self.lambda_traj_kl = lambda_traj_kl
14 |
15 | def forward(self, feat, bbox_feat, valid_mask, future_hands=None, contact_point=None, future_valid=None,
16 | num_samples=5, pred_len=4):
17 | if self.training:
18 | losses = {}
19 | total_loss = 0
20 | traj_loss, traj_kl_loss, obj_loss, obj_kl_loss = self.net(feat, bbox_feat, valid_mask, future_hands,
21 | contact_point, future_valid)
22 |
23 | if self.lambda_traj is not None and traj_loss is not None:
24 | traj_loss = self.lambda_traj * traj_loss.sum()
25 | total_loss += traj_loss
26 | losses['traj_loss'] = traj_loss.detach().cpu()
27 | else:
28 | losses['traj_loss'] = 0.
29 |
30 | if self.lambda_traj_kl is not None and traj_kl_loss is not None:
31 | traj_kl_loss = self.lambda_traj_kl * traj_kl_loss.sum()
32 | total_loss += traj_kl_loss
33 | losses['traj_kl_loss'] = traj_kl_loss.detach().cpu()
34 | else:
35 | losses['traj_kl_loss'] = 0.
36 |
37 | if self.lambda_obj is not None and obj_loss is not None:
38 | obj_loss = self.lambda_obj * obj_loss.sum()
39 | total_loss += obj_loss
40 | losses['obj_loss'] = obj_loss.detach().cpu()
41 | else:
42 | losses['obj_loss'] = 0.
43 |
44 | if self.lambda_obj_kl is not None and obj_kl_loss is not None:
45 | obj_kl_loss = self.lambda_obj_kl * obj_kl_loss.sum()
46 | total_loss += obj_kl_loss
47 | losses['obj_kl_loss'] = obj_kl_loss.detach().cpu()
48 | else:
49 | losses['obj_kl_loss'] = 0.
50 |
51 | if total_loss is not None:
52 | losses["total_loss"] = total_loss.detach().cpu()
53 | else:
54 | losses["total_loss"] = 0.
55 | return total_loss, losses
56 | else:
57 | future_hands_list = []
58 | contact_points_list = []
59 | for i in range(num_samples):
60 | future_hands, contact_point = self.net.module.inference(feat, bbox_feat, valid_mask,
61 | future_valid=future_valid,
62 | pred_len=pred_len)
63 | future_hands_list.append(future_hands)
64 | contact_points_list.append(contact_point)
65 |
66 | contact_points = torch.stack(contact_points_list, dim=0)
67 |
68 | assert len(contact_points.shape) == 3
69 | contact_points = contact_points.transpose(0, 1)
70 |
71 | future_hands_list = torch.stack(future_hands_list, dim=0)
72 | future_hands_list = future_hands_list.transpose(0, 1)
73 | return future_hands_list, contact_points
74 |
--------------------------------------------------------------------------------
/networks/net_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import warnings
5 |
6 |
7 | def get_pad_mask(seq, pad_idx=0):
8 | if seq.dim() != 2:
9 |         raise ValueError("seq has to be a 2-dimensional tensor!")
10 | if not isinstance(pad_idx, int):
11 |         raise TypeError("pad_idx has to be an int!")
12 |
13 | return (seq != pad_idx).unsqueeze(1) # equivalent (seq != pad_idx).unsqueeze(-2)
14 |
15 |
16 | def get_subsequent_mask(seq, diagonal=1):
17 | if seq.dim() < 2:
18 |         raise ValueError("seq has to be at least a 2-dimensional tensor!")
19 |
20 | seq_len = seq.size(1)
21 | mask = (1 - torch.triu(torch.ones((1, seq_len, seq_len), device=seq.device), diagonal=diagonal)).bool()
22 | return mask
23 |
24 |
25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
26 | def norm_cdf(x):
27 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
28 |
29 | if (mean < a - 2 * std) or (mean > b + 2 * std):
30 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
31 | "The distribution of values may be incorrect.",
32 | stacklevel=2)
33 |
34 | with torch.no_grad():
35 | l = norm_cdf((a - mean) / std)
36 | u = norm_cdf((b - mean) / std)
37 |
38 | tensor.uniform_(2 * l - 1, 2 * u - 1)
39 | tensor.erfinv_()
40 |
41 | tensor.mul_(std * math.sqrt(2.))
42 | tensor.add_(mean)
43 | tensor.clamp_(min=a, max=b)
44 | return tensor
45 |
46 |
47 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
48 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
49 |
50 |
51 | def drop_path(x, drop_prob: float = 0., training: bool = False):
52 | if drop_prob == 0. or not training:
53 | return x
54 | keep_prob = 1 - drop_prob
55 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
56 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57 | random_tensor.floor_()
58 | output = x.div(keep_prob) * random_tensor
59 | return output
60 |
61 |
62 | class DropPath(nn.Module):
63 | def __init__(self, drop_prob=None):
64 | super(DropPath, self).__init__()
65 | self.drop_prob = drop_prob
66 |
67 | def forward(self, x):
68 | return drop_path(x, self.drop_prob, self.training)
69 |
70 |
71 | def traj_affordance_dist(hand_traj, contact_point, future_valid=None, invalid_value=9):
72 | batch_size = contact_point.shape[0]
73 | expand_size = int(hand_traj.shape[0] / batch_size)
74 | contact_point = contact_point.unsqueeze(dim=1).expand(-1, expand_size, 2).reshape(-1, 2) # (B * 2 * Tf, 2)
75 | dist = torch.sum((hand_traj - contact_point) ** 2, dim=1).reshape(batch_size, -1) # (B, 2 * Tf)
76 | if future_valid is None:
77 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) # from small to high
78 | return sorted_dist[:, 0] # (B, )
79 | else:
80 | dist = dist.reshape(batch_size, 2, -1) # (B, 2, Tf)
81 | future_valid = future_valid > 0
82 | future_invalid = ~future_valid[:, :, None].expand(dist.shape)
83 | dist[future_invalid] = invalid_value # set invalid dist to be very large
84 | sorted_dist, sorted_idx = torch.sort(dist, dim=-1, descending=False) # from small to high
85 | selected_dist = sorted_dist[:, :, 0] # (B, 2)
86 | selected_dist, selected_idx = selected_dist.min(dim=1) # selected_dist, selected_idx (B, )
87 | valid = torch.gather(future_valid, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1)
88 | selected_dist = selected_dist * valid
89 |
90 | return selected_dist
--------------------------------------------------------------------------------
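
A minimal sketch (not part of the repository) of the two masking helpers used by the transformer; the 4-step toy sequence is an assumption.

```python
import torch
from networks.net_utils import get_pad_mask, get_subsequent_mask

seq = torch.tensor([[1., 1., 1., 0.],        # last position is padding
                    [1., 1., 1., 1.]])
pad_mask = get_pad_mask(seq, pad_idx=0)      # (2, 1, 4) bool, False at padded positions
causal = get_subsequent_mask(seq)            # (1, 4, 4) lower-triangular causal mask
print(pad_mask.shape, causal.shape)
```
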
/networks/traj_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import einops
4 | from networks.decoder_modules import VAE
5 |
6 |
7 | class TrajCVAE(nn.Module):
8 | def __init__(self, in_dim, hidden_dim, latent_dim, condition_dim, coord_dim=None,
9 | condition_contact=False, z_scale=2.0):
10 | super().__init__()
11 | self.latent_dim = latent_dim
12 | self.condition_contact = condition_contact
13 | self.z_scale = z_scale
14 | if self.condition_contact:
15 | if coord_dim is None:
16 | coord_dim = hidden_dim // 2
17 | self.coord_dim = coord_dim
18 | self.contact_to_feature = nn.Sequential(
19 | nn.Linear(2, coord_dim, bias=False),
20 | nn.ELU(inplace=True))
21 | self.contact_context_fusion = nn.Sequential(
22 | nn.Linear(condition_dim+coord_dim, condition_dim, bias=False),
23 | nn.ELU(inplace=True))
24 |
25 | self.cvae = VAE(in_dim=in_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,
26 | conditional=True, condition_dim=condition_dim)
27 |
28 | def forward(self, context, target_hand, future_valid, contact_point=None, return_pred=False):
29 | batch_size = future_valid.shape[0]
30 | if self.condition_contact:
31 | assert contact_point is not None
32 | time_steps = int(context.shape[0] / batch_size / 2)
33 | contact_feat = self.contact_to_feature(contact_point)
34 | contact_feat = einops.repeat(contact_feat, 'm n -> m p q n', p=2, q=time_steps)
35 | contact_feat = contact_feat.reshape(-1, self.coord_dim)
36 | fusion_feat = torch.cat([context, contact_feat], dim=1)
37 | condition_context = self.contact_context_fusion(fusion_feat)
38 | else:
39 | condition_context = context
40 | if not return_pred:
41 | recon_loss, KLD = self.cvae(target_hand, c=condition_context)
42 | else:
43 | pred_hand, recon_loss, KLD = self.cvae(target_hand, c=condition_context, return_pred=return_pred)
44 | KLD = KLD.reshape(batch_size, 2, -1).sum(-1)
45 | KLD = (KLD * future_valid).sum(1)
46 | recon_loss = recon_loss.reshape(batch_size, 2, -1).sum(-1)
47 | traj_loss = (recon_loss * future_valid).sum(1)
48 | if not return_pred:
49 | return traj_loss, KLD
50 | else:
51 | return pred_hand, traj_loss, KLD
52 |
53 | def inference(self, context, contact_point=None):
54 | if self.condition_contact:
55 | assert contact_point is not None
56 | batch_size = contact_point.shape[0]
57 | time_steps = int(context.shape[0] / batch_size)
58 | contact_feat = self.contact_to_feature(contact_point)
59 | contact_feat = einops.repeat(contact_feat, 'm n -> m p n', p=time_steps)
60 | contact_feat = contact_feat.reshape(-1, self.coord_dim)
61 | fusion_feat = torch.cat([context, contact_feat], dim=1)
62 | condition_context = self.contact_context_fusion(fusion_feat)
63 | else:
64 | condition_context = context
65 | z = self.z_scale * torch.randn([context.shape[0], self.latent_dim], device=context.device)
66 | recon_x = self.cvae.inference(z, c=condition_context)
67 | return recon_x
--------------------------------------------------------------------------------
/networks/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from networks.embedding import PositionalEncoding, Encoder_PositionalEmbedding, Decoder_PositionalEmbedding
5 | from networks.layer import EncoderBlock, DecoderBlock
6 | from networks.net_utils import trunc_normal_, get_pad_mask, get_subsequent_mask, traj_affordance_dist
7 |
8 |
9 | class Encoder(nn.Module):
10 | def __init__(self, num_patches=5, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.,
11 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
12 | drop_path_rate=0.1, norm_layer=nn.LayerNorm,
13 | dropout=0., time_embed_type=None, num_frames=None):
14 | super().__init__()
15 | if time_embed_type is None or num_frames is None:
16 | time_embed_type = 'sin'
17 | self.time_embed_type = time_embed_type
18 | self.num_patches = num_patches # (hand, object global feature patches, default: 5)
19 | self.depth = depth
20 | self.dropout = nn.Dropout(dropout)
21 | self.num_features = self.embed_dim = embed_dim
22 |
23 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
24 | self.pos_drop = nn.Dropout(p=drop_rate)
25 | if not self.time_embed_type == "sin" and num_frames is not None:
26 | self.time_embed = Encoder_PositionalEmbedding(embed_dim, seq_len=num_frames)
27 | else:
28 | self.time_embed = PositionalEncoding(embed_dim)
29 | self.time_drop = nn.Dropout(p=drop_rate)
30 |
31 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
32 | self.encoder_blocks = nn.ModuleList([EncoderBlock(
33 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
34 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
35 | for i in range(self.depth)])
36 | self.norm = norm_layer(embed_dim)
37 | trunc_normal_(self.pos_embed, std=.02)
38 |
39 | @torch.jit.ignore
40 | def no_weight_decay(self):
41 | return {'pos_embed', 'time_embed'}
42 |
43 | def forward(self, x, mask=None):
44 | B, T, N = x.shape[:3]
45 |
46 | x = rearrange(x, 'b t n m -> (b t) n m', b=B, t=T, n=N)
47 | x = x + self.pos_embed
48 | x = self.pos_drop(x)
49 |
50 | x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
51 | x = self.time_embed(x)
52 | x = self.time_drop(x)
53 | x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
54 |
55 | mask = mask.transpose(1, 2)
56 | for blk in self.encoder_blocks:
57 | x = blk(x, B, T, N, mask=mask)
58 |
59 | x = rearrange(x, 'b (n t) m -> b t n m', b=B, t=T, n=N)
60 | x = self.norm(x)
61 | return x
62 |
63 |
64 | class Decoder(nn.Module):
65 | def __init__(self, in_features, embed_dim=512, depth=6, num_heads=8, mlp_ratio=4.,
66 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
67 | drop_path_rate=0.1, norm_layer=nn.LayerNorm, dropout=0.,
68 | time_embed_type=None, num_frames=None):
69 | super().__init__()
70 | self.depth = depth
71 | self.dropout = nn.Dropout(dropout)
72 | self.num_features = self.embed_dim = embed_dim
73 |
74 | self.trg_embedding = nn.Linear(in_features, embed_dim)
75 |
76 | if time_embed_type is None or num_frames is None:
77 | time_embed_type = 'sin'
78 | self.time_embed_type = time_embed_type
79 | if not self.time_embed_type == "sin" and num_frames is not None:
80 | self.time_embed = Decoder_PositionalEmbedding(embed_dim, seq_len=num_frames)
81 | else:
82 | self.time_embed = PositionalEncoding(embed_dim)
83 | self.time_drop = nn.Dropout(p=drop_rate)
84 |
85 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
86 | self.decoder_blocks = nn.ModuleList([DecoderBlock(
87 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
88 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
89 | for i in range(self.depth)])
90 | self.norm = norm_layer(embed_dim)
91 |
92 | @torch.jit.ignore
93 | def no_weight_decay(self):
94 | return {'pos_embed', 'time_embed'}
95 |
96 | def forward(self, trg, memory, memory_mask=None, trg_mask=None):
97 | trg = self.trg_embedding(trg)
98 | trg = self.time_embed(trg)
99 | trg = self.time_drop(trg)
100 |
101 | for blk in self.decoder_blocks:
102 | trg = blk(trg, memory, memory_mask=memory_mask, trg_mask=trg_mask)
103 |
104 | trg = self.norm(trg)
105 | return trg
106 |
107 |
108 | class ObjectTransformer(nn.Module):
109 |
110 | def __init__(self, src_in_features, trg_in_features, num_patches,
111 | hand_head, obj_head,
112 | embed_dim=512, coord_dim=64, num_heads=8, enc_depth=6, dec_depth=4,
113 | mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
114 | drop_path_rate=0.1, norm_layer=nn.LayerNorm, dropout=0.,
115 | encoder_time_embed_type='sin', decoder_time_embed_type='sin',
116 | num_frames_input=None, num_frames_output=None):
117 | super().__init__()
118 | self.num_features = self.embed_dim = embed_dim
119 | self.coord_dim = coord_dim
120 | self.downproject = nn.Linear(src_in_features, embed_dim)
121 |
122 | self.bbox_to_feature = nn.Sequential(
123 | nn.Linear(4, self.coord_dim // 2),
124 | nn.ELU(inplace=True),
125 | nn.Linear(self.coord_dim // 2, self.coord_dim),
126 | nn.ELU()
127 | )
128 | self.feat_fusion = nn.Sequential(
129 | nn.Linear(self.embed_dim + self.coord_dim, self.embed_dim),
130 | nn.ELU(inplace=True))
131 |
132 | self.encoder = Encoder(num_patches=num_patches,
133 | embed_dim=embed_dim, depth=enc_depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
134 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
135 | drop_path_rate=drop_path_rate, norm_layer=norm_layer, dropout=dropout,
136 | time_embed_type=encoder_time_embed_type, num_frames=num_frames_input)
137 |
138 | self.decoder = Decoder(in_features=trg_in_features, embed_dim=embed_dim,
139 | depth=dec_depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
140 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
141 | attn_drop_rate=attn_drop_rate,
142 | drop_path_rate=drop_path_rate, norm_layer=norm_layer, dropout=dropout,
143 | time_embed_type=decoder_time_embed_type, num_frames=num_frames_output)
144 |
145 | self.hand_head = hand_head
146 | self.object_head = obj_head
147 | self.apply(self._init_weights)
148 |
149 | def _init_weights(self, m):
150 | if isinstance(m, nn.Linear):
151 | trunc_normal_(m.weight, std=.02)
152 | if isinstance(m, nn.Linear) and m.bias is not None:
153 | nn.init.constant_(m.bias, 0)
154 | elif isinstance(m, nn.LayerNorm):
155 | nn.init.constant_(m.bias, 0)
156 | nn.init.constant_(m.weight, 1.0)
157 |
158 | def encoder_input(self, feat, bbox_feat):
159 | B, T = feat.shape[0], feat.shape[2]
160 | feat = self.downproject(feat)
161 | bbox_feat = bbox_feat.view(-1, 4)
162 | bbox_feat = self.bbox_to_feature(bbox_feat)
163 | bbox_feat = bbox_feat.view(B, -1, T, self.coord_dim)
164 | ho_feat = feat[:, 1:, :, :]
165 | global_feat = feat[:, 0:1, :, :]
166 | feat = torch.cat((ho_feat, bbox_feat), dim=-1)
167 | feat = feat.view(-1, self.embed_dim + self.coord_dim)
168 | feat = self.feat_fusion(feat)
169 | feat = feat.view(B, -1, T, self.embed_dim)
170 | feat = torch.cat((global_feat, feat), dim=1)
171 | feat = feat.transpose(1, 2)
172 | return feat
173 |
174 | def forward(self, feat, bbox_feat, valid_mask, future_hands, contact_point, future_valid):
175 | # feat: (B, 5, T, src_in_features), global, hand & obj, T=10
176 | # bbox_feat: (B, 4, T, 4), hand & obj
177 | # valid_mask: (B, 4, T), hand & obj / (B, 5, T), hand & obj & global
178 | # future_hands: (B, 2, T, 2) right & left, T=5 (contains the last observation frame)
179 | # contact_point: (B, 2)
180 | # future_valid: (B, 2), right & left traj valid
181 | # return: traj_loss (B), traj_kl_loss (B), obj_loss (B), obj_kl_loss (B)
182 | if not valid_mask.shape[1] == feat.shape[1]:
183 | src_mask = torch.cat(
184 | (torch.ones_like(valid_mask[:, 0:1, :], dtype=valid_mask.dtype, device=valid_mask.device),
185 | valid_mask), dim=1).transpose(1, 2)
186 | else:
187 | src_mask = valid_mask.transpose(1, 2)
188 | feat = self.encoder_input(feat, bbox_feat)
189 | x = self.encoder(feat, mask=src_mask)
190 |
191 | memory = x[:, -1, :, :]
192 | memory_mask = get_pad_mask(src_mask[:, -1, :], pad_idx=0)
193 |
194 | future_rhand, future_lhand = future_hands[:, 0, :, :], future_hands[:, 1, :, :]
195 |
196 | rhand_input = future_rhand[:, :-1, :]
197 | lhand_input = future_lhand[:, :-1, :]
198 | trg_mask = torch.ones_like(rhand_input[:, :, 0])
199 | trg_mask = get_subsequent_mask(trg_mask)
200 |
201 | x_rhand = self.decoder(rhand_input, memory,
202 | memory_mask=memory_mask, trg_mask=trg_mask)
203 | x_lhand = self.decoder(lhand_input, memory,
204 | memory_mask=memory_mask, trg_mask=trg_mask)
205 |
206 | x_hand = torch.cat((x_rhand, x_lhand), dim=1)
207 | x_hand = x_hand.reshape(-1, self.embed_dim)
208 |
209 | target_hand = future_hands[:, :, 1:, :]
210 | target_hand = target_hand.reshape(-1, target_hand.shape[-1])
211 |
212 | pred_hand, traj_loss, traj_kl_loss = self.hand_head(x_hand, target_hand, future_valid, contact=None,
213 | return_pred=True)
214 |
215 | r_pred_contact, r_obj_loss, r_obj_kl_loss = self.object_head(memory[:, 0, :], contact_point, future_rhand,
216 | return_pred=True)
217 | l_pred_contact, l_obj_loss, l_obj_kl_loss = self.object_head(memory[:, 0, :], contact_point, future_lhand,
218 | return_pred=True)
219 |
220 | obj_loss = torch.stack([r_obj_loss, l_obj_loss], dim=1)
221 | obj_kl_loss = torch.stack([r_obj_kl_loss, l_obj_kl_loss], dim=1)
222 |
223 | obj_loss[~(future_valid > 0)] = 1e9
224 | selected_obj_loss, selected_idx = obj_loss.min(dim=1)
225 | selected_valid = torch.gather(future_valid, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1)
226 | selected_obj_kl_loss = torch.gather(obj_kl_loss, dim=1, index=selected_idx.unsqueeze(dim=1)).squeeze(dim=1)
227 | obj_loss = selected_obj_loss * selected_valid
228 | obj_kl_loss = selected_obj_kl_loss * selected_valid
229 |
230 | return traj_loss, traj_kl_loss, obj_loss, obj_kl_loss
231 |
232 | def inference(self, feat, bbox_feat, valid_mask, future_valid=None, pred_len=4):
233 | # feat: (B, 5, T, src_in_features), global, hand & obj, T=10
234 | # bbox_feat: (B, 4, T, 4), hand & obj
235 | # valid_mask: (B, 4, T), hand & obj
236 | # future_valid: (B, 2), right & left traj valid
237 | # return: future_hands (B, 2, T, 2), does not include the last observation frame
238 | # return: pred_contact (B, 2)
239 | B, T = feat.shape[0], feat.shape[2]
240 | if not valid_mask.shape[1] == feat.shape[1]:
241 | src_mask = torch.cat(
242 | (torch.ones_like(valid_mask[:, 0:1, :], dtype=valid_mask.dtype, device=valid_mask.device),
243 | valid_mask), dim=1).transpose(1, 2)
244 | else:
245 | src_mask = valid_mask.transpose(1, 2)
246 | feat = self.encoder_input(feat, bbox_feat)
247 | x = self.encoder(feat, mask=src_mask)
248 |
249 | memory = x[:, -1, :, :]
250 | memory_mask = get_pad_mask(src_mask[:, -1, :], pad_idx=0)
251 |
252 | observe_bbox = bbox_feat[:, :2, -1, :]
253 | observe_rhand, observe_lhand = observe_bbox[:, 0, :], observe_bbox[:, 1, :]
254 | future_rhand = (observe_rhand[:, :2] + observe_rhand[:, 2:]) / 2
255 | future_lhand = (observe_lhand[:, :2] + observe_lhand[:, 2:]) / 2
256 | future_rhand = future_rhand.unsqueeze(dim=1)
257 | future_lhand = future_lhand.unsqueeze(dim=1)
258 |
259 | pred_contact = None
260 | for i in range(pred_len):
261 | trg_mask = torch.ones_like(future_rhand[:, :, 0])
262 | trg_mask = get_subsequent_mask(trg_mask)
263 |
264 | x_rhand = self.decoder(future_rhand, memory,
265 | memory_mask=memory_mask, trg_mask=trg_mask)
266 | x_hand = x_rhand.reshape(-1, self.embed_dim)
267 | loc_rhand = self.hand_head.inference(x_hand, pred_contact)
268 | loc_rhand = loc_rhand.reshape(B, -1, 2)
269 | pred_rhand = loc_rhand[:, -1:, :]
270 | future_rhand = torch.cat((future_rhand, pred_rhand), dim=1)
271 |
272 | for i in range(pred_len):
273 | trg_mask = torch.ones_like(future_lhand[:, :, 0])
274 | trg_mask = get_subsequent_mask(trg_mask)
275 |
276 | x_lhand = self.decoder(future_lhand, memory,
277 | memory_mask=memory_mask, trg_mask=trg_mask)
278 | x_hand = x_lhand.reshape(-1, self.embed_dim)
279 | loc_lhand = self.hand_head.inference(x_hand, pred_contact)
280 | loc_lhand = loc_lhand.reshape(B, -1, 2)
281 | pred_lhand = loc_lhand[:, -1:, :]
282 | future_lhand = torch.cat((future_lhand, pred_lhand), dim=1)
283 |
284 | future_hands = torch.stack((future_rhand[:, 1:, :], future_lhand[:, 1:, :]), dim=1)
285 |
286 | r_pred_contact = self.object_head.inference(memory[:, 0, :], future_rhand)
287 | l_pred_contact = self.object_head.inference(memory[:, 0, :], future_lhand)
288 | pred_contact = torch.stack([r_pred_contact, l_pred_contact], dim=1)
289 |
290 | if future_valid is not None and torch.all(future_valid.sum(dim=1) >= 1):
291 | r_pred_contact_dist = traj_affordance_dist(future_hands.reshape(-1, 2), r_pred_contact,
292 | future_valid)
293 | l_pred_contact_dist = traj_affordance_dist(future_hands.reshape(-1, 2), l_pred_contact,
294 | future_valid)
295 | pred_contact_dist = torch.stack((r_pred_contact_dist, l_pred_contact_dist), dim=1)
296 | _, selected_idx = pred_contact_dist.min(dim=1)
297 | selected_idx = selected_idx.unsqueeze(dim=1).unsqueeze(dim=2).expand(pred_contact.shape[0], 1,
298 | pred_contact.shape[2])
299 | pred_contact = torch.gather(pred_contact, dim=1, index=selected_idx).squeeze(dim=1)
300 |
301 | return future_hands, pred_contact
302 |
--------------------------------------------------------------------------------
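The shape comments in `ObjectTransformer.forward` above summarize the expected inputs. The following minimal sketch (dummy tensors only, no model weights or attention blocks) illustrates how the source mask for the always-valid global token is assembled before encoding, mirroring the concatenation in `forward`/`inference`; the batch size and random values are assumptions for demonstration.

```python
import torch

# Illustrative shapes, matching the comments in ObjectTransformer.forward:
# feat is (B, 5, T, C) with one global token plus hand & object tokens,
# valid_mask is (B, 4, T) for the hand & object tokens only.
B, T, C = 2, 10, 1024
feat = torch.randn(B, 5, T, C)
bbox_feat = torch.rand(B, 4, T, 4)
valid_mask = torch.randint(0, 2, (B, 4, T)).float()

# As in forward(): when the mask lacks the global row, prepend an all-ones row
# for the global token, then move the time axis next to the batch axis.
if valid_mask.shape[1] != feat.shape[1]:
    src_mask = torch.cat(
        (torch.ones_like(valid_mask[:, 0:1, :]), valid_mask), dim=1).transpose(1, 2)
else:
    src_mask = valid_mask.transpose(1, 2)
print(src_mask.shape)  # torch.Size([2, 10, 5])
```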
/options/expopts.py:
--------------------------------------------------------------------------------
1 | def add_exp_opts(parser):
2 | parser.add_argument("--resume", type=str, nargs="+", metavar="PATH",
3 | help="path to latest checkpoint (default: none)")
4 | parser.add_argument("--evaluate", dest="evaluate", action="store_true",
5 | help="evaluate model on validation set")
6 | parser.add_argument("--test_freq", type=int, default=10,
7 | help="testing frequency on evaluation dataset (set specific in traineval.py)")
8 | parser.add_argument("--snapshot", default=10, type=int, metavar="N",
9 | help="How often to take a snapshot of the model (0 = never)")
10 | parser.add_argument("--use_cuda", default=1, type=int, help="use GPU (default: True)")
11 | parser.add_argument('--ek_version', default="ek55", choices=["ek55", "ek100"], help="epic dataset version")
12 | parser.add_argument("--traj_only", action="store_true", help="evaluate traj on validation dataset")
--------------------------------------------------------------------------------
/options/netsopts.py:
--------------------------------------------------------------------------------
1 | def add_nets_opts(parser):
2 | parser.add_argument('--src_in_features', type=int, default=1024, help='Network encoder input size')
3 | parser.add_argument('--trg_in_features', type=int, default=2, help='Network decoder input size')
4 | parser.add_argument('--num_patches', type=int, default=5, help='Number of input patches (global + hand & object features)')
5 | parser.add_argument('--num_classes', type=int, default=2513, help='Number of classes')
6 |
7 | parser.add_argument('--embed_dim', type=int, default=512, help='embedded dimension')
8 | parser.add_argument('--num_heads', type=int, default=8, help='num of heads in transformer')
9 | parser.add_argument('--enc_depth', type=int, default=6, help='transformer encoder depth')
10 | parser.add_argument('--dec_depth', type=int, default=4, help='transformer decoder depth')
11 |
12 | parser.add_argument('--coord_dim', type=int, default=64, help='coordinates feature dimension')
13 | parser.add_argument('--hidden_dim', type=int, default=512, help='stochastic modules hidden dimension')
14 | parser.add_argument('--latent_dim', type=int, default=256, help='stochastic modules latent dimension')
15 |
16 | parser.add_argument("--encoder_time_embed_type", default="sin",
17 | choices=["sin", "param"], help="transformer encoder time position embedding")
18 | parser.add_argument("--decoder_time_embed_type", default="sin",
19 | choices=["sin", "param"], help="transformer decoder time position embedding")
20 |
21 | parser.add_argument("--num_samples", default=20, type=int, help="get number of samples during inference, "
22 | "stochastic model multiple runs")
23 | parser.add_argument("--num_points", default=5, type=int,
24 | help="number of remaining contact points after farthest point "
25 | "sampling for evaluation affordance")
26 | parser.add_argument("--gaussian_sigma", default=3., type=float,
27 | help="predicted contact points gaussian kernel sigma")
28 | parser.add_argument("--gaussian_k_ratio", default=3., type=float,
29 | help="predicted contact points gaussian kernel size")
30 |
31 | # Options for loss supervision
32 | parser.add_argument("--lambda_obj", default=1e-1, type=float, help="Weight to supervise object affordance")
33 | parser.add_argument("--lambda_traj", default=1., type=float, help="Weight to supervise hand traj")
34 | parser.add_argument("--lambda_obj_kl", default=1e-3, type=float, help="Weight to supervise object affordance KLD")
35 | parser.add_argument("--lambda_traj_kl", default=1e-3, type=float, help="Weight to supervise hand traj KLD")
36 |
37 |
38 | def add_train_opts(parser):
39 | parser.add_argument("--manual_seed", default=0, type=int, help="manual seed")
40 | parser.add_argument("-j", "--workers", default=16, type=int, help="number of workers")
41 | parser.add_argument("--epochs", default=35, type=int, help="number epochs")
42 | parser.add_argument("--batch_size", default=128, type=int, help="batch size")
43 |
44 | parser.add_argument("--optimizer", default="adam", choices=["rms", "adam", "sgd", "adamw"])
45 | parser.add_argument("--lr", "--learning-rate", default=1e-4, type=float, metavar="LR", help="initial learning rate")
46 | parser.add_argument("--momentum", default=0.9, type=float)
47 |
48 | parser.add_argument("--scheduler", default="cosine", choices=['cosine', 'step', 'multistep'],
49 | help="learning rate scheduler")
50 | parser.add_argument("--warmup_epochs", default=5, type=int, help="number of warmup epochs to run")
51 | parser.add_argument("--lr_decay_step", nargs="+", default=10, type=int,
52 | help="Epochs after which to decay learning rate")
53 | parser.add_argument(
54 | "--lr_decay_gamma", default=0.5, type=float, help="Factor by which to decay the learning rate")
55 | parser.add_argument("--weight_decay", default=1e-4, type=float)
56 |
--------------------------------------------------------------------------------
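A minimal sketch of how these option helpers could be combined on a single `argparse` parser, assuming the repository root is on `PYTHONPATH`; the exact call site in `traineval.py` is not shown here, so treat this as illustrative rather than the project's canonical entry point.

```python
import argparse

from options.netsopts import add_nets_opts, add_train_opts
from options.expopts import add_exp_opts

parser = argparse.ArgumentParser(description="OCT training/evaluation options")
add_nets_opts(parser)   # network sizes and loss weights
add_train_opts(parser)  # optimizer, scheduler, batch size, epochs
add_exp_opts(parser)    # checkpointing, evaluation, dataset version

# Parse an example command line and inspect a few defaults.
args = parser.parse_args(["--ek_version", "ek55", "--batch_size", "64"])
print(args.embed_dim, args.lr, args.ek_version, args.batch_size)  # 512 0.0001 ek55 64
```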
/preprocess/affordance_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from preprocess.dataset_util import bbox_inter
4 |
5 |
6 | def skin_extract(image):
7 | def color_segmentation():
8 | lower_HSV_values = np.array([0, 40, 0], dtype="uint8")
9 | upper_HSV_values = np.array([25, 255, 255], dtype="uint8")
10 | lower_YCbCr_values = np.array((0, 138, 67), dtype="uint8")
11 | upper_YCbCr_values = np.array((255, 173, 133), dtype="uint8")
12 | mask_YCbCr = cv2.inRange(YCbCr_image, lower_YCbCr_values, upper_YCbCr_values)
13 | mask_HSV = cv2.inRange(HSV_image, lower_HSV_values, upper_HSV_values)
14 | binary_mask_image = cv2.add(mask_HSV, mask_YCbCr)
15 | return binary_mask_image
16 |
17 | HSV_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
18 | YCbCr_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCR_CB)
19 | binary_mask_image = color_segmentation()
20 | image_foreground = cv2.erode(binary_mask_image, None, iterations=3)
21 | dilated_binary_image = cv2.dilate(binary_mask_image, None, iterations=3)
22 | ret, image_background = cv2.threshold(dilated_binary_image, 1, 128, cv2.THRESH_BINARY)
23 |
24 | image_marker = cv2.add(image_foreground, image_background)
25 | image_marker32 = np.int32(image_marker)
26 | cv2.watershed(image, image_marker32)
27 | m = cv2.convertScaleAbs(image_marker32)
28 | ret, image_mask = cv2.threshold(m, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
29 | kernel = np.ones((20, 20), np.uint8)
30 | image_mask = cv2.morphologyEx(image_mask, cv2.MORPH_CLOSE, kernel)
31 | return image_mask
32 |
33 |
34 | def farthest_sampling(pcd, n_samples, init_pcd=None):
35 | def compute_distance(a, b):
36 | return np.linalg.norm(a - b, ord=2, axis=2)
37 |
38 | n_pts, dim = pcd.shape[0], pcd.shape[1]
39 | selected_pts_expanded = np.zeros(shape=(n_samples, 1, dim))
40 | remaining_pts = np.copy(pcd)
41 |
42 | if init_pcd is None:
43 | if n_pts > 1:
44 | start_idx = np.random.randint(low=0, high=n_pts - 1)
45 | else:
46 | start_idx = 0
47 | selected_pts_expanded[0] = remaining_pts[start_idx]
48 | n_selected_pts = 1
49 | else:
50 | num_points = min(init_pcd.shape[0], n_samples)
51 | selected_pts_expanded[:num_points] = init_pcd[:num_points, None, :]
52 | n_selected_pts = num_points
53 |
54 | for _ in range(1, n_samples):
55 | if n_selected_pts < n_samples:
56 | dist_pts_to_selected = compute_distance(remaining_pts, selected_pts_expanded[:n_selected_pts]).T
57 | dist_pts_to_selected_min = np.min(dist_pts_to_selected, axis=1, keepdims=True)
58 | res_selected_idx = np.argmax(dist_pts_to_selected_min)
59 | selected_pts_expanded[n_selected_pts] = remaining_pts[res_selected_idx]
60 | n_selected_pts += 1
61 |
62 | selected_pts = np.squeeze(selected_pts_expanded, axis=1)
63 | return selected_pts
64 |
65 |
66 | def compute_heatmap(points, image_size, k_ratio=3.0):
67 | points = np.asarray(points)
68 | heatmap = np.zeros((image_size[0], image_size[1]), dtype=np.float32)
69 | n_points = points.shape[0]
70 | for i in range(n_points):
71 | x = points[i, 0]
72 | y = points[i, 1]
73 | col = int(x)
74 | row = int(y)
75 | # clamp out-of-range points to the image border instead of relying on an
76 | # exception (negative indices would otherwise wrap around silently)
77 | col = min(max(col, 0), image_size[0] - 1)
78 | row = min(max(row, 0), image_size[1] - 1)
79 | heatmap[col, row] += 1.0
80 | 
81 | k_size = int(np.sqrt(image_size[0] * image_size[1]) / k_ratio)
82 | if k_size % 2 == 0:
83 | k_size += 1
84 | heatmap = cv2.GaussianBlur(heatmap, (k_size, k_size), 0)
85 | if heatmap.max() > 0:
86 | heatmap /= heatmap.max()
87 | heatmap = heatmap.transpose()
88 | return heatmap
89 |
90 |
91 | def select_points_bbox(bbox, points, tolerance=2):
92 | x1, y1, x2, y2 = bbox
93 | ind_x = np.logical_and(points[:, 0] > x1-tolerance, points[:, 0] < x2+tolerance)
94 | ind_y = np.logical_and(points[:, 1] > y1-tolerance, points[:, 1] < y2+tolerance)
95 | ind = np.logical_and(ind_x, ind_y)
96 | indices = np.where(ind)[0]
97 | return points[indices]
98 |
99 |
100 | def find_contour_points(mask):
101 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2:]  # compatible with OpenCV 3.x and 4.x return values
102 | if len(contours) != 0:
103 | c = max(contours, key=cv2.contourArea)
104 | c = c.squeeze(axis=1)
105 | return c
106 | else:
107 | return None
108 |
109 |
110 | def get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj):
111 | # active_obj_traj: active obj traj in last observation frame
112 | # obj_bboxs_traj: active obj bbox traj in last observation frame
113 | select_points_homo = np.concatenate((select_points, np.ones((select_points.shape[0], 1), dtype=np.float32)), axis=1)
114 | select_points_homo = np.dot(select_points_homo, homography.T)
115 | select_points_homo = select_points_homo[:, :2] / select_points_homo[:, None, 2]
116 |
117 | obj_point_last_observe = np.array(active_obj_traj[0])
118 | obj_point_future_start = np.array(active_obj_traj[-1])
119 |
120 | future2last_trans = obj_point_last_observe - obj_point_future_start
121 | select_points_homo = select_points_homo + future2last_trans
122 |
123 | fill_indices = [idx for idx, points in enumerate(obj_bboxs_traj) if points is not None]
124 | contour_last_observe = obj_bboxs_traj[fill_indices[0]]
125 | contour_future_homo = obj_bboxs_traj[fill_indices[-1]] + future2last_trans
126 | contour_last_observe = contour_last_observe[:, None, :].astype(np.int32)
127 | contour_future_homo = contour_future_homo[:, None, :].astype(np.int32)
128 | filtered_points = []
129 | for point in select_points_homo:
130 | if cv2.pointPolygonTest(contour_last_observe, (float(point[0]), float(point[1])), False) >= 0 \
131 | or cv2.pointPolygonTest(contour_future_homo, (float(point[0]), float(point[1])), False) >= 0:
132 | filtered_points.append(point)
133 | filtered_points = np.array(filtered_points)
134 | return filtered_points
135 |
136 |
137 | def compute_affordance(frame, active_hand, active_obj, num_points=5, num_sampling=20):
138 | skin_mask = skin_extract(frame)
139 | hand_bbox = np.array(active_hand.bbox.coords_int).reshape(-1)
140 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1)
141 | obj_center = active_obj.bbox.center
142 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox)
143 | if not iou > 0:
144 | return None
145 | x1, y1, x2, y2 = hand_bbox
146 | hand_mask = np.zeros_like(skin_mask, dtype=np.uint8)
147 | hand_mask[y1:y2, x1:x2] = 255
148 | hand_mask = cv2.bitwise_and(skin_mask, hand_mask)
149 | select_points, init_points = None, None
150 | contact_points = find_contour_points(hand_mask)
151 |
152 | if contact_points is not None and contact_points.shape[0] > 0:
153 | contact_points = select_points_bbox((xA, yA, xB, yB), contact_points)
154 | if contact_points.shape[0] >= num_points:
155 | if contact_points.shape[0] > num_sampling:
156 | contact_points = farthest_sampling(contact_points, n_samples=num_sampling)
157 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1)
158 | indices = np.argsort(distance)[:num_points]
159 | select_points = contact_points[indices]
160 | elif contact_points.shape[0] > 0:
161 | print("not enough boundary points detected, sampling points in the interaction region")
162 | init_points = contact_points
163 | else:
164 | print("no boundary points detected, using farthest point sampling")
165 | else:
166 | print("no boundary points detected, using farthest point sampling")
167 | if select_points is None:
168 | ho_mask = np.zeros_like(skin_mask, dtype=np.uint8)
169 | ho_mask[yA:yB, xA:xB] = 255
170 | ho_mask = cv2.bitwise_and(skin_mask, ho_mask)
171 | points = np.array(np.where(ho_mask[yA:yB, xA:xB] > 0)).T
172 | if points.shape[0] == 0:
173 | xx, yy = np.meshgrid(np.arange(xB - xA), np.arange(yB - yA))
174 | xx += xA
175 | yy += yA
176 | points = np.vstack([xx.reshape(-1), yy.reshape(-1)]).T
177 | else:
178 | points = points[:, [1, 0]]
179 | points[:, 0] += xA
180 | points[:, 1] += yA
181 | if not points.shape[0] > 0:
182 | return None
183 | contact_points = farthest_sampling(points, n_samples=min(num_sampling, points.shape[0]), init_pcd=init_points)
184 | distance = np.linalg.norm(contact_points - obj_center, ord=2, axis=1)
185 | indices = np.argsort(distance)[:num_points]
186 | select_points = contact_points[indices]
187 | return select_points
188 |
189 |
190 | def compute_obj_affordance(frame, annot, active_obj, active_obj_idx, homography,
191 | active_obj_traj, obj_bboxs_traj,
192 | num_points=5, num_sampling=20,
193 | hand_threshold=0.1, obj_threshold=0.1):
194 | affordance_info = {}
195 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold,
196 | hand_threshold=hand_threshold)
197 | select_points = None
198 | for hand_idx, object_idx in hand_object_idx_correspondences.items():
199 | if object_idx == active_obj_idx:
200 | active_hand = annot.hands[hand_idx]
201 | affordance_info[active_hand.side.name] = np.array(active_hand.bbox.coords_int).reshape(-1)
202 | cmap_points = compute_affordance(frame, active_hand, active_obj, num_points=num_points, num_sampling=num_sampling)
203 | if select_points is None and (cmap_points is not None and cmap_points.shape[0] > 0):
204 | select_points = cmap_points
205 | elif select_points is not None and (cmap_points is not None and cmap_points.shape[0] > 0):
206 | select_points = np.concatenate((select_points, cmap_points), axis=0)
207 | if select_points is None:
208 | print("affordance contact points filtered out")
209 | return None
210 | select_points_homo = get_points_homo(select_points, homography, active_obj_traj, obj_bboxs_traj)
211 | if len(select_points_homo) == 0:
212 | print("affordance contact points filtered out")
213 | return None
214 | else:
215 | affordance_info["select_points"] = select_points
216 | affordance_info["select_points_homo"] = select_points_homo
217 |
218 | obj_bbox = np.array(active_obj.bbox.coords_int).reshape(-1)
219 | affordance_info["obj_bbox"] = obj_bbox
220 | return affordance_info
--------------------------------------------------------------------------------
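A small usage sketch of `farthest_sampling` and `compute_heatmap` on synthetic points, assuming the repository root is on `PYTHONPATH` and the environment above (NumPy, OpenCV, protobuf) is installed; the cluster location and image size are arbitrary.

```python
import numpy as np
from preprocess.affordance_util import farthest_sampling, compute_heatmap

rng = np.random.default_rng(0)

# 200 synthetic contact candidates clustered around (200, 120) in a 456x256 frame.
points = rng.normal(loc=(200.0, 120.0), scale=10.0, size=(200, 2))

# Keep 20 well-spread points, as compute_affordance does before ranking
# candidates by distance to the object center.
sampled = farthest_sampling(points, n_samples=20)
print(sampled.shape)  # (20, 2)

# Rasterize the points into a normalized heatmap. Given the internal indexing,
# image_size is passed as (width, height) and the returned map is (height, width).
heatmap = compute_heatmap(sampled, image_size=(456, 256), k_ratio=3.0)
print(heatmap.shape, heatmap.max())  # (256, 456) 1.0
```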
/preprocess/dataset_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from preprocess.ho_types import FrameDetections, HandDetection, HandSide, HandState, ObjectDetection
5 |
6 |
7 | def sample_action_anticipation_frames(frame_start, t_buffer=1, fps=4.0, fps_init=60.0):
8 | time_start = (frame_start - 1) / fps_init
9 | num_frames = int(np.floor(t_buffer * fps))
10 | times = (np.arange(num_frames + 1) - num_frames) / fps + time_start
11 | times = np.clip(times, 0, np.inf)
12 | times = times.astype(np.float32)
13 | frames_idxs = np.floor(times * fps_init).astype(np.int32) + 1
14 | if frames_idxs.max() >= 1:
15 | frames_idxs[frames_idxs < 1] = frames_idxs[frames_idxs >= 1].min()
16 | return list(frames_idxs)
17 |
18 |
19 | def load_ho_annot(video_detections, frame_index, imgW=456, imgH=256):
20 | annot = video_detections[frame_index-1]  # frame_index starts from 1
21 | assert annot.frame_number == frame_index, "wrong frame index"
22 | annot.scale(width_factor=imgW, height_factor=imgH)
23 | return annot
24 |
25 |
26 | def load_img(frames_path, frame_index):
27 | frame = cv2.imread(os.path.join(frames_path, "frame_{:010d}.jpg".format(frame_index)))
28 | return frame
29 |
30 |
31 | def get_mask(frame, annot, hand_threshold=0.1, obj_threshold=0.1):
32 | msk_img = np.ones((frame.shape[:2]), dtype=frame.dtype)
33 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold]
34 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold]
35 | for hand in hands:
36 | (x1, y1), (x2, y2) = hand.bbox.coords_int
37 | msk_img[y1:y2, x1:x2] = 0
38 |
39 | if len(objs) > 0:
40 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold,
41 | hand_threshold=hand_threshold)
42 | for hand_idx, object_idx in hand_object_idx_correspondences.items():
43 | hand = annot.hands[hand_idx]
44 | obj = annot.objects[object_idx]
45 | if not hand.state.value == HandState.STATIONARY_OBJECT.value:
46 | (x1, y1), (x2, y2) = obj.bbox.coords_int
47 | msk_img[y1:y2, x1:x2] = 0
48 | return msk_img
49 |
50 |
51 | def bbox_inter(boxA, boxB):
52 | xA = max(boxA[0], boxB[0])
53 | yA = max(boxA[1], boxB[1])
54 | xB = min(boxA[2], boxB[2])
55 | yB = min(boxA[3], boxB[3])
56 |
57 | interArea = max(xB - xA, 0) * max(yB - yA, 0)
58 | if interArea == 0:
59 | return xA, yA, xB, yB, 0
60 |
61 | boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
62 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
63 | iou = interArea / float(boxAArea + boxBArea - interArea)
64 | return xA, yA, xB, yB, iou
65 |
66 |
67 | def compute_iou(boxA, boxB):
68 | boxA = np.array(boxA).reshape(-1)
69 | boxB = np.array(boxB).reshape(-1)
70 | xA = max(boxA[0], boxB[0])
71 | yA = max(boxA[1], boxB[1])
72 | xB = min(boxA[2], boxB[2])
73 | yB = min(boxA[3], boxB[3])
74 | interArea = max(xB - xA, 0) * max(yB - yA, 0)
75 | if interArea == 0:
76 | return 0
77 | boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
78 | boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
79 | iou = interArea / float(boxAArea + boxBArea - interArea)
80 | return iou
81 |
82 |
83 | def points_in_bbox(point, bbox):
84 | (x1, y1), (x2, y2) = bbox
85 | (x, y) = point
86 | return (x1 <= x <= x2) and (y1 <= y <= y2)
87 |
88 |
89 | def valid_point(point, imgW=456, imgH=256):
90 | if point is None:
91 | return False
92 | else:
93 | x, y = point
94 | return (0 <= x < imgW) and (0 <= y < imgH)
95 |
96 |
97 | def valid_traj(traj, imgW=456, imgH=256):
98 | if len(traj) > 0:
99 | num_outlier = np.sum([not valid_point(point, imgW=imgW, imgH=imgH)
100 | for point in traj if point is not None])
101 | valid_ratio = np.sum([valid_point(point, imgW=imgW, imgH=imgH) for point in traj[1:]]) / len(traj[1:])
102 | valid_last = valid_point(traj[-1], imgW=imgW, imgH=imgH)
103 | if num_outlier > 1 or valid_ratio < 0.5 or not valid_last:
104 | traj = []
105 | return traj
106 |
107 |
108 | def get_valid_traj(traj, imgW=456, imgH=256):
109 | try:
110 | traj[traj < 0] = traj[traj >= 0].min()
111 | except ValueError:
112 | # every coordinate is negative, so there is no valid minimum to copy
113 | traj[traj < 0] = 0
114 | 
115 | # clamp x coordinates that fall outside the image to the last valid column
116 | traj[:, 0][traj[:, 0] >= imgW] = imgW - 1
117 | 
118 | # clamp y coordinates that fall outside the image to the last valid row
119 | traj[:, 1][traj[:, 1] >= imgH] = imgH - 1
120 | 
121 | return traj
122 |
123 |
124 | def fetch_data(frames_path, video_detections, frames_idxs, hand_threshold=0.1, obj_threshold=0.1):
125 | tolerance = frames_idxs[1] - frames_idxs[0]  # extend the future action frame by tolerance to find a hand-object interaction
126 | frames = []
127 | annots = []
128 |
129 | miss_hand = 0
130 | for frame_idx in frames_idxs[:-1]:
131 | frame = load_img(frames_path, frame_idx)
132 | annot = load_ho_annot(video_detections, frame_idx)
133 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold]
134 | if len(hands) == 0:
135 | miss_hand += 1
136 | frames.append(frame)
137 | annots.append(annot)
138 | if miss_hand == len(frames_idxs[:-1]):
139 | return None
140 | frame_idx = frames_idxs[-1]
141 | frames_idxs = frames_idxs[:-1]
142 |
143 | hand_sides = []
144 | idx = 0
145 | flag = False
146 | while idx < tolerance:
147 | annot = load_ho_annot(video_detections, frame_idx)
148 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold]
149 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold]
150 | if len(hands) > 0 and len(objs) > 0:  # at least one hand is in contact with an object
151 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold,
152 | hand_threshold=hand_threshold)
153 | for hand_idx, object_idx in hand_object_idx_correspondences.items():
154 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords).reshape(-1)
155 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords).reshape(-1)
156 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox)
157 | contact_state = annot.hands[hand_idx].state.value
158 | if iou > 0 and (contact_state == HandState.STATIONARY_OBJECT.value or
159 | contact_state == HandState.PORTABLE_OBJECT.value):
160 | hand_side = annot.hands[hand_idx].side.name
161 | hand_sides.append(hand_side)
162 | flag = True
163 | if flag:
164 | break
165 | else:
166 | idx += 1
167 | frame_idx += 1
168 | else:
169 | idx += 1
170 | frame_idx += 1
171 | if flag:
172 | frames_idxs.append(frame_idx)
173 | frames.append(load_img(frames_path, frame_idx))
174 | annots.append(annot)
175 | return frames_idxs, frames, annots, list(set(hand_sides)) # remove redundant hand sides
176 | else:
177 | return None
178 |
179 |
180 | def save_video_info(save_path, video_index, frames_idxs, homography_stack, contacts,
181 | hand_trajs, obj_trajs, affordance_info):
182 | import pickle
183 | video_info = {"frame_indices": frames_idxs,
184 | "homography": homography_stack,
185 | "contact": contacts}
186 | video_info.update({"hand_trajs": hand_trajs})
187 | video_info.update({"obj_trajs": obj_trajs})
188 | video_info.update({"affordance": affordance_info})
189 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'wb') as f:
190 | pickle.dump(video_info, f)
191 |
192 |
193 | def load_video_info(save_path, video_index):
194 | import pickle
195 | with open(os.path.join(save_path, "label_{}.pkl".format(video_index)), 'rb') as f:
196 | video_info = pickle.load(f)
197 | return video_info
198 |
--------------------------------------------------------------------------------
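The following sketch exercises `sample_action_anticipation_frames` and the box-overlap helpers with concrete numbers; the values are illustrative, and the imports assume the repository root is on `PYTHONPATH` with its dependencies installed.

```python
import numpy as np
from preprocess.dataset_util import sample_action_anticipation_frames, bbox_inter, compute_iou

# Sample observation frames from the 1 second before an action starting at
# frame 1000 of a 60 fps video, at 4 fps (the defaults above).
frame_idxs = sample_action_anticipation_frames(frame_start=1000, t_buffer=1, fps=4.0, fps_init=60.0)
print(frame_idxs)  # 5 frame indices, spaced 15 source frames apart

# Intersection and IoU of two axis-aligned boxes in (x1, y1, x2, y2) form.
boxA = np.array([0, 0, 10, 10])
boxB = np.array([5, 5, 15, 15])
xA, yA, xB, yB, iou = bbox_inter(boxA, boxB)
print((xA, yA, xB, yB), round(iou, 3))  # (5, 5, 10, 10) 0.143 -> 25 / (100 + 100 - 25)
print(compute_iou(boxA, boxB))          # same value via the convenience wrapper
```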
/preprocess/ho_types.py:
--------------------------------------------------------------------------------
1 | """The core set of types that represent hand-object detections"""
2 |
3 | from enum import Enum, unique
4 | from itertools import chain
5 | from typing import Dict, Iterator, List, Tuple, cast
6 |
7 | import numpy as np
8 | from dataclasses import dataclass
9 | import preprocess.types_pb2 as pb
10 |
11 | __all__ = [
12 | "HandSide",
13 | "HandState",
14 | "FloatVector",
15 | "BBox",
16 | "HandDetection",
17 | "ObjectDetection",
18 | "FrameDetections",
19 | ]
20 |
21 |
22 | @unique
23 | class HandSide(Enum):
24 | LEFT = 0
25 | RIGHT = 1
26 |
27 |
28 | @unique
29 | class HandState(Enum):
30 | """An enum describing the different states a hand can be in:
31 | - No contact: The hand isn't touching anything
32 | - Self contact: The hand is touching itself
33 | - Another person: The hand is touching another person
34 | - Portable object: The hand is in contact with a portable object
35 | - Stationary object: The hand is in contact with an immovable/stationary object"""
36 |
37 | NO_CONTACT = 0
38 | SELF_CONTACT = 1
39 | ANOTHER_PERSON = 2
40 | PORTABLE_OBJECT = 3
41 | STATIONARY_OBJECT = 4
42 |
43 |
44 | @dataclass
45 | class FloatVector:
46 | """A floating-point 2D vector representation"""
47 | x: np.float32
48 | y: np.float32
49 |
50 | def to_protobuf(self) -> pb.FloatVector:
51 | vector = pb.FloatVector()
52 | vector.x = self.x
53 | vector.y = self.y
54 | assert vector.IsInitialized()
55 | return vector
56 |
57 | @staticmethod
58 | def from_protobuf(vector: pb.FloatVector) -> "FloatVector":
59 | return FloatVector(x=vector.x, y=vector.y)
60 |
61 | def __add__(self, other: "FloatVector") -> "FloatVector":
62 | return FloatVector(x=self.x + other.x, y=self.y + other.y)
63 |
64 | def __mul__(self, scaler: float) -> "FloatVector":
65 | return FloatVector(x=self.x * scaler, y=self.y * scaler)
66 |
67 | def __iter__(self) -> Iterator[float]:
68 | yield from (self.x, self.y)
69 |
70 | @property
71 | def coord(self) -> Tuple[float, float]:
72 | """Return coordinates as a tuple"""
73 | return (self.x, self.y)
74 |
75 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
76 | """Scale x component by ``width_factor`` and y component by ``height_factor``"""
77 | self.x *= width_factor
78 | self.y *= height_factor
79 |
80 |
81 | @dataclass
82 | class BBox:
83 | left: float
84 | top: float
85 | right: float
86 | bottom: float
87 |
88 | def to_protobuf(self) -> pb.BBox:
89 | bbox = pb.BBox()
90 | bbox.left = self.left
91 | bbox.top = self.top
92 | bbox.right = self.right
93 | bbox.bottom = self.bottom
94 | assert bbox.IsInitialized()
95 | return bbox
96 |
97 | @staticmethod
98 | def from_protobuf(bbox: pb.BBox) -> "BBox":
99 | return BBox(
100 | left=bbox.left,
101 | top=bbox.top,
102 | right=bbox.right,
103 | bottom=bbox.bottom,
104 | )
105 |
106 | @property
107 | def center(self) -> Tuple[float, float]:
108 | x = (self.left + self.right) / 2
109 | y = (self.top + self.bottom) / 2
110 | return x, y
111 |
112 | @property
113 | def center_int(self) -> Tuple[int, int]:
114 | """Get center position as a tuple of integers (rounded)"""
115 | x, y = self.center
116 | return (round(x), round(y))
117 |
118 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
119 | self.left *= width_factor
120 | self.right *= width_factor
121 | self.top *= height_factor
122 | self.bottom *= height_factor
123 |
124 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
125 | x, y = self.center
126 | new_width = self.width * width_factor
127 | new_height = self.height * height_factor
128 | self.left = x - new_width / 2
129 | self.right = x + new_width / 2
130 | self.top = y - new_height / 2
131 | self.bottom = y + new_height / 2
132 |
133 | @property
134 | def coords(self) -> Tuple[Tuple[float, float], Tuple[float, float]]:
135 | return (
136 | self.top_left,
137 | self.bottom_right,
138 | )
139 |
140 | @property
141 | def coords_int(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
142 | return (
143 | self.top_left_int,
144 | self.bottom_right_int,
145 | )
146 |
147 | @property
148 | def width(self) -> float:
149 | return self.right - self.left
150 |
151 | @property
152 | def height(self) -> float:
153 | return self.bottom - self.top
154 |
155 | @property
156 | def top_left(self) -> Tuple[float, float]:
157 | return (self.left, self.top)
158 |
159 | @property
160 | def bottom_right(self) -> Tuple[float, float]:
161 | return (self.right, self.bottom)
162 |
163 | @property
164 | def top_left_int(self) -> Tuple[int, int]:
165 | return (round(self.left), round(self.top))
166 |
167 | @property
168 | def bottom_right_int(self) -> Tuple[int, int]:
169 | return (round(self.right), round(self.bottom))
170 |
171 |
172 | @dataclass
173 | class HandDetection:
174 | """Dataclass representing a hand detection, consisting of a bounding box,
175 | a score (representing the model's confidence this is a hand), the predicted state
176 | of the hand, whether this is a left/right hand, and a predicted offset to the
177 | interacted object if the hand is interacting."""
178 |
179 | bbox: BBox
180 | score: np.float32
181 | state: HandState
182 | side: HandSide
183 | object_offset: FloatVector
184 |
185 | def to_protobuf(self) -> pb.HandDetection:
186 | detection = pb.HandDetection()
187 | detection.bbox.MergeFrom(self.bbox.to_protobuf())
188 | detection.score = self.score
189 | detection.state = self.state.value
190 | detection.object_offset.MergeFrom(self.object_offset.to_protobuf())
191 | detection.side = self.side.value
192 | assert detection.IsInitialized()
193 | return detection
194 |
195 | @staticmethod
196 | def from_protobuf(detection: pb.HandDetection) -> "HandDetection":
197 | return HandDetection(
198 | bbox=BBox.from_protobuf(detection.bbox),
199 | score=detection.score,
200 | state=HandState(detection.state),
201 | object_offset=FloatVector.from_protobuf(detection.object_offset),
202 | side=HandSide(detection.side),
203 | )
204 |
205 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
206 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor)
207 | self.object_offset.scale(width_factor=width_factor, height_factor=height_factor)
208 |
209 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
210 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor)
211 |
212 |
213 | @dataclass
214 | class ObjectDetection:
215 | """Dataclass representing an object detection, consisting of a bounding box and a
216 | score (the model's confidence this is an object)"""
217 |
218 | bbox: BBox
219 | score: np.float32
220 |
221 | def to_protobuf(self) -> pb.ObjectDetection:
222 | detection = pb.ObjectDetection()
223 | detection.bbox.MergeFrom(self.bbox.to_protobuf())
224 | detection.score = self.score
225 | assert detection.IsInitialized()
226 | return detection
227 |
228 | @staticmethod
229 | def from_protobuf(detection: pb.ObjectDetection) -> "ObjectDetection":
230 | return ObjectDetection(
231 | bbox=BBox.from_protobuf(detection.bbox), score=detection.score
232 | )
233 |
234 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
235 | self.bbox.scale(width_factor=width_factor, height_factor=height_factor)
236 |
237 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
238 | self.bbox.center_scale(width_factor=width_factor, height_factor=height_factor)
239 |
240 |
241 | @dataclass
242 | class FrameDetections:
243 | """Dataclass representing hand-object detections for a frame of a video"""
244 |
245 | video_id: str
246 | frame_number: int
247 | objects: List[ObjectDetection]
248 | hands: List[HandDetection]
249 |
250 | def to_protobuf(self) -> pb.Detections:
251 | detections = pb.Detections()
252 | detections.video_id = self.video_id
253 | detections.frame_number = self.frame_number
254 | detections.hands.extend([hand.to_protobuf() for hand in self.hands])
255 | detections.objects.extend([obj.to_protobuf() for obj in self.objects])
256 | assert detections.IsInitialized()
257 | return detections
258 |
259 | @staticmethod
260 | def from_protobuf(detections: pb.Detections) -> "FrameDetections":
261 | return FrameDetections(
262 | video_id=detections.video_id,
263 | frame_number=detections.frame_number,
264 | hands=[HandDetection.from_protobuf(hand_pb) for hand_pb in detections.hands],
265 | objects=[ObjectDetection.from_protobuf(obj_pb) for obj_pb in detections.objects],
266 | )
267 |
268 | @staticmethod
269 | def from_protobuf_str(pb_str: bytes) -> "FrameDetections":
270 | pb_detection = pb.Detections()
271 | pb_detection.MergeFromString(pb_str)
272 | return FrameDetections.from_protobuf(pb_detection)
273 |
274 | def get_hand_object_interactions(
275 | self, object_threshold: float = 0, hand_threshold: float = 0
276 | ) -> Dict[int, int]:
277 | """Match the hands to objects based on the hand offset vector that the model
278 | uses to predict the location of the interacted object.
279 |
280 | Args:
281 | object_threshold: Object score threshold above which to consider objects
282 | for matching
283 | hand_threshold: Hand score threshold above which to consider hands for
284 | matching.
285 |
286 | Returns:
287 | A dictionary mapping hand detections to objects by indices
288 | """
289 | interactions = dict()
290 | object_idxs = [
291 | i for i, obj in enumerate(self.objects) if obj.score >= object_threshold
292 | ]
293 | object_centers = np.array(
294 | [self.objects[object_id].bbox.center for object_id in object_idxs]
295 | )
296 | for hand_idx, hand_detection in enumerate(self.hands):
297 | if (
298 | hand_detection.state.value == HandState.NO_CONTACT.value
299 | or hand_detection.score <= hand_threshold
300 | ):
301 | continue
302 | estimated_object_position = (
303 | np.array(hand_detection.bbox.center) +
304 | np.array(hand_detection.object_offset.coord)
305 | )
306 | distances = ((object_centers - estimated_object_position) ** 2).sum(
307 | axis=-1)
308 | interactions[hand_idx] = object_idxs[cast(int, np.argmin(distances))]
309 | return interactions
310 |
311 | def scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
312 | """
313 | Scale the coordinates of all the hands/objects. x components are multiplied
314 | by the ``width_factor`` and y components by the ``height_factor``
315 | """
316 | for det in chain(self.hands, self.objects):
317 | det.scale(width_factor=width_factor, height_factor=height_factor)
318 |
319 | def center_scale(self, width_factor: float = 1, height_factor: float = 1) -> None:
320 | """
321 | Scale all the hands/objects about their center points.
322 | """
323 | for det in chain(self.hands, self.objects):
324 | det.center_scale(width_factor=width_factor, height_factor=height_factor)
--------------------------------------------------------------------------------
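A minimal sketch of constructing the detection dataclasses by hand and querying hand-object correspondences; the coordinates and scores are made up for illustration, and importing `preprocess.ho_types` assumes protobuf is installed as in the environment above.

```python
import numpy as np
from preprocess.ho_types import (BBox, FloatVector, HandDetection, HandSide,
                                 HandState, ObjectDetection, FrameDetections)

# One right hand holding a portable object; the offset points from the hand's
# center toward the estimated object location.
hand = HandDetection(bbox=BBox(left=0.50, top=0.50, right=0.70, bottom=0.80),
                     score=np.float32(0.9),
                     state=HandState.PORTABLE_OBJECT,
                     side=HandSide.RIGHT,
                     object_offset=FloatVector(x=np.float32(0.05), y=np.float32(-0.10)))
obj = ObjectDetection(bbox=BBox(left=0.55, top=0.35, right=0.75, bottom=0.60),
                      score=np.float32(0.8))
frame = FrameDetections(video_id="P01_01", frame_number=1, objects=[obj], hands=[hand])

# Map detected hands to the object they most likely interact with.
print(frame.get_hand_object_interactions(object_threshold=0.1, hand_threshold=0.1))  # {0: 0}

# Convert normalized coordinates to a 456x256 pixel frame in place.
frame.scale(width_factor=456, height_factor=256)
print(frame.hands[0].bbox.coords_int, frame.objects[0].bbox.center_int)
```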
/preprocess/obj_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from preprocess.dataset_util import bbox_inter, HandState, compute_iou, \
3 | valid_traj, get_valid_traj, points_in_bbox
4 | from preprocess.traj_util import get_homo_point, get_homo_bbox_point
5 |
6 |
7 | def find_active_side(annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1):
8 | if len(hand_sides) == 1:
9 | return hand_sides[0]
10 | else:
11 | hand_counter = {"LEFT": 0, "RIGHT": 0}
12 | for annot in annots:
13 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold]
14 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold]
15 | if len(hands) > 0 and len(objs) > 0:
16 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold,
17 | hand_threshold=hand_threshold)
18 | for hand_idx, object_idx in hand_object_idx_correspondences.items():
19 | hand_bbox = np.array(annot.hands[hand_idx].bbox.coords_int).reshape(-1)
20 | obj_bbox = np.array(annot.objects[object_idx].bbox.coords_int).reshape(-1)
21 | xA, yA, xB, yB, iou = bbox_inter(hand_bbox, obj_bbox)
22 | if iou > 0:
23 | hand_side = annot.hands[hand_idx].side.name
24 | if annot.hands[hand_idx].state.value == HandState.PORTABLE_OBJECT.value:
25 | hand_counter[hand_side] += 1
26 | elif annot.hands[hand_idx].state.value == HandState.STATIONARY_OBJECT.value:
27 | hand_counter[hand_side] += 0.5
28 | if hand_counter["LEFT"] == hand_counter["RIGHT"]:
29 | return "RIGHT"
30 | else:
31 | return max(hand_counter, key=hand_counter.get)
32 |
33 |
34 | def compute_contact(annots, hand_side, contact_state, hand_threshold=0.1):
35 | contacts = []
36 | for annot in annots:
37 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold
38 | and hand.side.name == hand_side and hand.state.value == contact_state]
39 | if len(hands) > 0:
40 | contacts.append(1)
41 | else:
42 | contacts.append(0)
43 | contacts = np.array(contacts)
44 | padding_contacts = np.pad(contacts, [1, 1], 'edge')
45 | contacts = np.convolve(padding_contacts, [1, 1, 1], 'same')
46 | contacts = contacts[1:-1] / 3
47 | contacts = contacts > 0.5
48 | indices = np.diff(contacts) != 0
49 | if indices.sum() == 0:
50 | return contacts
51 | else:
52 | split = np.where(indices)[0] + 1
53 | contacts_idx = split[-1]
54 | contacts[:contacts_idx] = False
55 | return contacts
56 |
57 |
58 | def find_active_obj_side(annot, hand_side, return_hand=False, return_idx=False, hand_threshold=0.1, obj_threshold=0.1):
59 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold]
60 | if len(objs) == 0:
61 | return None
62 | else:
63 | hand_object_idx_correspondences = annot.get_hand_object_interactions(object_threshold=obj_threshold,
64 | hand_threshold=hand_threshold)
65 | for hand_idx, object_idx in hand_object_idx_correspondences.items():
66 | if annot.hands[hand_idx].side.name == hand_side:
67 | if return_hand and return_idx:
68 | return annot.objects[object_idx], object_idx, annot.hands[hand_idx], hand_idx
69 | elif return_hand:
70 | return annot.objects[object_idx], annot.hands[hand_idx]
71 | elif return_idx:
72 | return annot.objects[object_idx], object_idx
73 | else:
74 | return annot.objects[object_idx]
75 | return None
76 |
77 |
78 | def find_active_obj_iou(objs, bbox):
79 | max_iou = 0
80 | active_obj = None
81 | for obj in objs:
82 | iou = compute_iou(obj.bbox.coords, bbox)
83 | if iou > max_iou:
84 | max_iou = iou
85 | active_obj = obj
86 | return active_obj, max_iou
87 |
88 |
89 | def traj_compute(annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1):
90 | annot = annots[-1]
91 | obj_traj = []
92 | obj_centers = []
93 | obj_bboxs = []
94 | obj_bboxs_traj = []
95 | active_hand_side = find_active_side(annots, hand_sides, hand_threshold=hand_threshold,
96 | obj_threshold=obj_threshold)
97 | active_obj, active_object_idx, active_hand, active_hand_idx = find_active_obj_side(annot,
98 | hand_side=active_hand_side,
99 | return_hand=True, return_idx=True,
100 | hand_threshold=hand_threshold,
101 | obj_threshold=obj_threshold)
102 | contact_state = active_hand.state.value
103 | contacts = compute_contact(annots, active_hand_side, contact_state,
104 | hand_threshold=hand_threshold)
105 | obj_center = active_obj.bbox.center
106 | obj_centers.append(obj_center)
107 | obj_point = get_homo_point(obj_center, homography_stack[-1])
108 | obj_bbox = active_obj.bbox.coords
109 | obj_traj.append(obj_point)
110 | obj_bboxs.append(obj_bbox)
111 |
112 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[-1])
113 | obj_bboxs_traj.append(obj_points2d)
114 |
115 | for idx in np.arange(len(annots)-2, -1, -1):
116 | annot = annots[idx]
117 | objs = [obj for obj in annot.objects if obj.score >= obj_threshold]
118 | contact = contacts[idx]
119 | if not contact:
120 | obj_centers.append(None)
121 | obj_traj.append(None)
122 | obj_bboxs_traj.append(None)
123 | else:
124 | if len(objs) >= 2:
125 | target_obj, max_iou = find_active_obj_iou(objs, obj_bboxs[-1])
126 | if target_obj is None:
127 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side,
128 | hand_threshold=hand_threshold,
129 | obj_threshold=obj_threshold)
130 | if target_obj is None:
131 | obj_centers.append(None)
132 | obj_traj.append(None)
133 | obj_bboxs_traj.append(None)
134 | else:
135 | obj_center = target_obj.bbox.center
136 | obj_centers.append(obj_center)
137 | obj_point = get_homo_point(obj_center, homography_stack[idx])
138 | obj_bbox = target_obj.bbox.coords
139 | obj_traj.append(obj_point)
140 | obj_bboxs.append(obj_bbox)
141 |
142 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx])
143 | obj_bboxs_traj.append(obj_points2d)
144 |
145 | elif len(objs) > 0:
146 | target_obj = find_active_obj_side(annot, hand_side=active_hand_side,
147 | hand_threshold=hand_threshold,
148 | obj_threshold=obj_threshold)
149 | if target_obj is None:
150 | obj_centers.append(None)
151 | obj_traj.append(None)
152 | obj_bboxs_traj.append(None)
153 | else:
154 | obj_center = target_obj.bbox.center
155 | obj_centers.append(obj_center)
156 | obj_point = get_homo_point(obj_center, homography_stack[idx])
157 | obj_bbox = target_obj.bbox.coords
158 | obj_traj.append(obj_point)
159 | obj_bboxs.append(obj_bbox)
160 |
161 | obj_points2d = get_homo_bbox_point(obj_bbox, homography_stack[idx])
162 | obj_bboxs_traj.append(obj_points2d)
163 | else:
164 | obj_centers.append(None)
165 | obj_traj.append(None)
166 | obj_bboxs_traj.append(None)
167 | obj_bboxs.reverse()
168 | obj_traj.reverse()
169 | obj_centers.reverse()
170 | obj_bboxs_traj.reverse()
171 | return obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj
172 |
173 |
174 | def traj_filter(obj_traj, obj_centers, obj_bbox, contacts, homography_stack, contact_ratio=0.4):
175 | assert len(obj_traj) == len(obj_centers), "traj length and center length not equal"
176 | assert len(obj_centers) == len(homography_stack), "center length and homography length not equal"
177 | homo_last2first = homography_stack[-1]
178 | homo_first2last = np.linalg.inv(homo_last2first)
179 | obj_points = []
180 | obj_inside, obj_detect = [], []
181 | for idx, obj_center in enumerate(obj_centers):
182 | if obj_center is not None:
183 | homo_current2first = homography_stack[idx]
184 | homo_current2last = homo_current2first.dot(homo_first2last)
185 | obj_point = get_homo_point(obj_center, homo_current2last)
186 | obj_points.append(obj_point)
187 | obj_inside.append(points_in_bbox(obj_point, obj_bbox))
188 | obj_detect.append(True)
189 | else:
190 | obj_detect.append(False)
191 | obj_inside = np.array(obj_inside)
192 | obj_detect = np.array(obj_detect)
193 | contacts = np.bitwise_and(obj_detect, contacts)
194 | if np.sum(obj_inside) == len(obj_inside) and np.sum(contacts) / len(contacts) < contact_ratio:
195 | obj_traj = np.tile(obj_traj[-1], (len(obj_traj), 1))
196 | return obj_traj, contacts
197 |
198 |
199 | def traj_completion(traj, imgW=456, imgH=256):
200 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None]
201 | full_traj = traj.copy()
202 | if len(fill_indices) == 1:
203 | point = traj[fill_indices[0]]
204 | full_traj = np.array([point] * len(traj), dtype=np.float32)
205 | else:
206 | contact_time = fill_indices[0]
207 | if contact_time > 0:
208 | full_traj[:contact_time] = [traj[contact_time]] * contact_time
209 | for previous_idx, current_idx in zip(fill_indices[:-1], fill_indices[1:]):
210 | start_point, end_point = traj[previous_idx], traj[current_idx]
211 | time_expand = current_idx - previous_idx
212 | for idx in range(previous_idx+1, current_idx):
213 | full_traj[idx] = (idx-previous_idx) / time_expand * end_point + (current_idx-idx) / time_expand * start_point
214 | full_traj = np.array(full_traj, dtype=np.float32)
215 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH)
216 | return full_traj, fill_indices
217 |
218 |
219 | def compute_obj_traj(frames, annots, hand_sides, homography_stack, hand_threshold=0.1, obj_threshold=0.1,
220 | contact_ratio=0.4):
221 | imgH, imgW = frames[0].shape[:2]
222 | obj_traj, obj_centers, obj_bboxs, contacts, active_obj, active_object_idx, obj_bboxs_traj = traj_compute(annots, hand_sides, homography_stack,
223 | hand_threshold=hand_threshold, obj_threshold=obj_threshold)
224 | obj_traj, contacts = traj_filter(obj_traj, obj_centers, obj_bboxs[-1], contacts, homography_stack,
225 | contact_ratio=contact_ratio)
226 | obj_traj = valid_traj(obj_traj, imgW=imgW, imgH=imgH)
227 | if len(obj_traj) == 0:
228 | print("object traj filtered out")
229 | return None
230 | else:
231 | complete_traj, fill_indices = traj_completion(obj_traj, imgW=imgW, imgH=imgH)
232 | obj_trajs = {"traj": complete_traj, "fill_indices": fill_indices, "centers": obj_centers}
233 | return contacts, obj_trajs, active_obj, active_object_idx, obj_bboxs_traj
--------------------------------------------------------------------------------
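A short sketch of `traj_completion` from this file on a toy object trajectory with missing detections; the point values are arbitrary, and the import assumes the repository root is on `PYTHONPATH`.

```python
import numpy as np
from preprocess.obj_util import traj_completion

# An object trajectory with missing detections (None) between two observations;
# gaps are filled by linear interpolation and the head is padded with the first
# observed point, as in compute_obj_traj.
traj = [None,
        np.array([100.0, 80.0], dtype=np.float32),
        None,
        None,
        np.array([160.0, 110.0], dtype=np.float32)]
full_traj, fill_indices = traj_completion(traj, imgW=456, imgH=256)
print(fill_indices)  # [1, 4]
print(full_traj)     # (5, 2) array; rows 2-3 interpolated, row 0 copies row 1
```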
/preprocess/traj_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from preprocess.dataset_util import get_mask, valid_traj
4 |
5 |
6 | def match_keypoints(kpsA, kpsB, featuresA, featuresB, ratio=0.7, reprojThresh=4.0):
7 | matcher = cv2.DescriptorMatcher_create("BruteForce")
8 | rawMatches = matcher.knnMatch(featuresA, featuresB, 2)
9 | matches = []
10 |
11 | for m in rawMatches:
12 | if len(m) == 2 and m[0].distance < m[1].distance * ratio:
13 | matches.append(m[0])
14 |
15 | if len(matches) > 4:
16 | ptsA = np.float32([kpsA[m.queryIdx].pt for m in matches])
17 | ptsB = np.float32([kpsB[m.trainIdx].pt for m in matches])
18 | (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)
19 | matchesMask = status.ravel().tolist()
20 | return matches, H, matchesMask
21 | return None
22 |
23 |
24 | def get_pair_homography(frame_1, frame_2, annot_1, annot_2, hand_threshold=0.1, obj_threshold=0.1):
25 | flag = True
26 | descriptor = cv2.xfeatures2d.SURF_create()
27 | msk_img_1 = get_mask(frame_1, annot_1, hand_threshold=hand_threshold, obj_threshold=obj_threshold)
28 | msk_img_2 = get_mask(frame_2, annot_2, hand_threshold=hand_threshold, obj_threshold=obj_threshold)
29 | (kpsA, featuresA) = descriptor.detectAndCompute(frame_1, mask=msk_img_1)
30 | (kpsB, featuresB) = descriptor.detectAndCompute(frame_2, mask=msk_img_2)
31 | matches, matchesMask = None, None
32 | try:
33 | (matches, H_BA, matchesMask) = match_keypoints(kpsB, kpsA, featuresB, featuresA)
34 | except Exception:
35 | print("compute homography failed!")
36 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3)
37 | flag = False
38 |
39 | # findHomography can also return None when estimation fails
40 | if H_BA is None:
41 | print("compute homography failed!")
42 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3)
43 | flag = False
44 | try:
45 | np.linalg.inv(H_BA)
46 | except Exception:
47 | print("compute homography failed!")
48 | H_BA = np.array([1.0, 0, 0, 0, 1.0, 0, 0, 0, 1.0]).reshape(3, 3)
49 | flag = False
50 | return matches, H_BA, matchesMask, flag
51 |
52 |
53 | def get_homo_point(point, homography):
54 | cx, cy = point
55 | center = np.array((cx, cy, 1.0), dtype=np.float32)
56 | x, y, z = np.dot(homography, center)
57 | x, y = x / z, y / z
58 | point = np.array((x, y), dtype=np.float32)
59 | return point
60 |
61 |
62 | def get_homo_bbox_point(bbox, homography):
63 | x1, y1, x2, y2 = np.array(bbox).reshape(-1)
64 | points = np.array([[x1, y1], [x2, y1], [x1, y2], [x2, y2]], dtype=np.float32)
65 | points_homo = np.concatenate((points, np.ones((4, 1), dtype=np.float32)), axis=1)
66 | points_coord = np.dot(points_homo, homography.T)
67 | points_coord2d = points_coord[:, :2] / points_coord[:, None, 2]
68 | return points_coord2d
69 |
70 |
71 | def get_hand_center(annot, hand_threshold=0.1):
72 | hands = [hand for hand in annot.hands if hand.score >= hand_threshold]
73 | hands_center = {}
74 | hands_score = {}
75 | for hand in hands:
76 | side = hand.side.name
77 | score = hand.score
78 | if side not in hands_center or score > hands_score[side]:
79 | hands_center[side] = hand.bbox.center
80 | hands_score[side] = score
81 | return hands_center
82 |
83 |
84 | def get_hand_point(hands_center, homography, side):
85 | point, homo_point = None, None
86 | if side in hands_center:
87 | point = hands_center[side]
88 | homo_point = get_homo_point(point, homography)
89 | return point, homo_point
90 |
91 |
92 | def traj_compute(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1):
93 | imgH, imgW = frames[0].shape[:2]
94 | left_traj, right_traj = [], []
95 | left_centers, right_centers = [], []
96 | homography_stack = [np.eye(3)]
97 | for idx in range(1, len(frames)):
98 | matches, H_BA, matchesMask, flag = get_pair_homography(frames[idx - 1], frames[idx],
99 | annots[idx - 1], annots[idx],
100 | hand_threshold=hand_threshold,
101 | obj_threshold=obj_threshold)
102 | if not flag:
103 | return None
104 | else:
105 | homography_stack.append(np.dot(homography_stack[-1], H_BA))
106 | for idx in range(len(frames)):
107 | hands_center = get_hand_center(annots[idx], hand_threshold=hand_threshold)
108 | if "LEFT" in hand_sides:
109 | left_center, left_point = get_hand_point(hands_center, homography_stack[idx], "LEFT")
110 | left_centers.append(left_center)
111 | left_traj.append(left_point)
112 | if "RIGHT" in hand_sides:
113 | right_center, right_point = get_hand_point(hands_center, homography_stack[idx], "RIGHT")
114 | right_centers.append(right_center)
115 | right_traj.append(right_point)
116 |
117 | left_traj = valid_traj(left_traj, imgW=imgW, imgH=imgH)
118 | right_traj = valid_traj(right_traj, imgW=imgW, imgH=imgH)
119 | return left_traj, left_centers, right_traj, right_centers, homography_stack
120 |
121 |
122 | def traj_completion(traj, side, imgW=456, imgH=256):
123 | from scipy.interpolate import CubicHermiteSpline
124 |
125 | def get_valid_traj(traj, imgW, imgH):
126 | traj[traj < 0] = traj[traj >= 0].min()
127 | traj[:, 0][traj[:, 0] > 1.5 * imgW] = 1.5 * imgW
128 | traj[:, 1][traj[:, 1] > 1.5 * imgH] = 1.5 * imgH
129 | return traj
130 |
131 | def spline_interpolation(axis):
132 | fill_times = np.array(fill_indices, dtype=np.float32)
133 | fill_traj = np.array([traj[idx][axis] for idx in fill_indices], dtype=np.float32)
134 | dt = fill_times[2:] - fill_times[:-2]
135 | dt = np.hstack([fill_times[1] - fill_times[0], dt, fill_times[-1] - fill_times[-2]])
136 | dx = fill_traj[2:] - fill_traj[:-2]
137 | dx = np.hstack([fill_traj[1] - fill_traj[0], dx, fill_traj[-1] - fill_traj[-2]])
138 | dxdt = dx / dt
139 | curve = CubicHermiteSpline(fill_times, fill_traj, dxdt)
140 | full_traj = curve(np.arange(len(traj), dtype=np.float32))
141 | return full_traj, curve
142 |
143 | fill_indices = [idx for idx, point in enumerate(traj) if point is not None]
144 | if 0 not in fill_indices:
145 | if side == "LEFT":
146 | traj[0] = np.array((0.25*imgW, 1.5*imgH), dtype=np.float32)
147 | else:
148 | traj[0] = np.array((0.75*imgW, 1.5*imgH), dtype=np.float32)
149 | fill_indices = np.insert(fill_indices, 0, 0).tolist()
150 | fill_indices.sort()
151 | full_traj_x, curve_x = spline_interpolation(axis=0)
152 | full_traj_y, curve_y = spline_interpolation(axis=1)
153 | full_traj = np.stack([full_traj_x, full_traj_y], axis=1)
154 | full_traj = get_valid_traj(full_traj, imgW=imgW, imgH=imgH)
155 | curve = [curve_x, curve_y]
156 | return full_traj, fill_indices, curve
157 |
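
The completion step above fills frames where a hand was not detected by fitting a cubic Hermite spline through the detected points, with knot derivatives estimated by finite differences (central differences inside, one-sided at the two ends); `traj_completion` runs this once per axis and stacks the results. A standalone sketch of the same scheme on a toy 1-D track (frame indices and values are made up):

```python
import numpy as np
from scipy.interpolate import CubicHermiteSpline

# x-coordinate observed only at frames 0, 2, 3 and 6
fill_times = np.array([0., 2., 3., 6.], dtype=np.float32)
fill_vals = np.array([10., 30., 45., 90.], dtype=np.float32)

# knot derivatives: central differences inside, one-sided at the two ends
dt = np.hstack([fill_times[1] - fill_times[0], fill_times[2:] - fill_times[:-2], fill_times[-1] - fill_times[-2]])
dx = np.hstack([fill_vals[1] - fill_vals[0], fill_vals[2:] - fill_vals[:-2], fill_vals[-1] - fill_vals[-2]])
curve = CubicHermiteSpline(fill_times, fill_vals, dx / dt)

# evaluate at every frame index to obtain a dense track, including the missing frames 1, 4, 5
full_track = curve(np.arange(7, dtype=np.float32))
print(np.round(full_track, 1))
```
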
158 |
159 | def compute_hand_traj(frames, annots, hand_sides, hand_threshold=0.1, obj_threshold=0.1):
160 | imgH, imgW = frames[0].shape[:2]
161 | results = traj_compute(frames, annots, hand_sides,
162 | hand_threshold=hand_threshold, obj_threshold=obj_threshold)
163 | if results is None:
164 | print("compute homography failed")
165 | return None
166 | else:
167 | left_traj, left_centers, right_traj, right_centers, homography_stack = results
168 | if len(left_traj) == 0 and len(right_traj) == 0:
169 | print("compute traj failed")
170 | return None
171 | hand_trajs = {}
172 | if len(left_traj) == 0:
173 | print("left traj filtered out")
174 | else:
175 | left_complete_traj, left_fill_indices, left_curve = traj_completion(left_traj, side="LEFT",
176 | imgW=imgW, imgH=imgH)
177 | hand_trajs["LEFT"] = {"traj": left_complete_traj, "fill_indices": left_fill_indices,
178 | "fit_curve": left_curve, "centers": left_centers}
179 | if len(right_traj) == 0:
180 | print("right traj filtered out")
181 | else:
182 | right_complete_traj, right_fill_indices, right_curve = traj_completion(right_traj, side="RIGHT",
183 | imgW=imgW, imgH=imgH)
184 | hand_trajs["RIGHT"] = {"traj": right_complete_traj, "fill_indices": right_fill_indices,
185 | "fit_curve": right_curve, "centers": right_centers}
186 | return homography_stack, hand_trajs
187 |
--------------------------------------------------------------------------------
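
`traj_compute` above expresses every detected hand center in the coordinate frame of the first observation: `homography_stack[i]` is the running product of the per-pair homographies and maps frame `i` back into frame 0, and `get_homo_point` divides the projected point by its homogeneous z. A minimal sketch of that composition with made-up, translation-only homographies:

```python
import numpy as np

def map_point(point, homography):
    # same normalization as get_homo_point: project, then divide by the homogeneous z
    x, y, z = homography @ np.array([point[0], point[1], 1.0], dtype=np.float32)
    return np.array([x / z, y / z], dtype=np.float32)

# hypothetical pairwise homographies: H_10 maps frame 1 -> frame 0, H_21 maps frame 2 -> frame 1
H_10 = np.array([[1.0, 0.0, 5.0],
                 [0.0, 1.0, -2.0],
                 [0.0, 0.0, 1.0]])
H_21 = np.array([[1.0, 0.0, 3.0],
                 [0.0, 1.0, 1.0],
                 [0.0, 0.0, 1.0]])

# chained as in traj_compute: homography_stack[i] maps frame i into frame 0
homography_stack = [np.eye(3)]
for H in (H_10, H_21):
    homography_stack.append(homography_stack[-1] @ H)

hand_center = (100.0, 80.0)  # hand center detected in frame 2 (made-up pixel coordinates)
print(map_point(hand_center, homography_stack[2]))  # [108.  79.] in frame-0 coordinates
```
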
/preprocess/types_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: types.proto
4 |
5 | from google.protobuf import descriptor as _descriptor
6 | from google.protobuf import message as _message
7 | from google.protobuf import reflection as _reflection
8 | from google.protobuf import symbol_database as _symbol_database
9 | # @@protoc_insertion_point(imports)
10 |
11 | _sym_db = _symbol_database.Default()
12 |
13 |
14 | DESCRIPTOR = _descriptor.FileDescriptor(
15 | name='types.proto',
16 | package='model.detections',
17 | syntax='proto3',
18 | serialized_options=None,
19 | create_key=_descriptor._internal_create_key,
20 | serialized_pb=b'\n\x0btypes.proto\x12\x10model.detections\"#\n\x0b\x46loatVector\x12\t\n\x01x\x18\x01 \x01(\x02\x12\t\n\x01y\x18\x02 \x01(\x02\"@\n\x04\x42\x42ox\x12\x0c\n\x04left\x18\x01 \x01(\x02\x12\x0b\n\x03top\x18\x02 \x01(\x02\x12\r\n\x05right\x18\x03 \x01(\x02\x12\x0e\n\x06\x62ottom\x18\x04 \x01(\x02\"\xfc\x02\n\rHandDetection\x12$\n\x04\x62\x62ox\x18\x01 \x01(\x0b\x32\x16.model.detections.BBox\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x38\n\x05state\x18\x03 \x01(\x0e\x32).model.detections.HandDetection.HandState\x12\x34\n\robject_offset\x18\x04 \x01(\x0b\x32\x1d.model.detections.FloatVector\x12\x36\n\x04side\x18\x05 \x01(\x0e\x32(.model.detections.HandDetection.HandSide\"m\n\tHandState\x12\x0e\n\nNO_CONTACT\x10\x00\x12\x10\n\x0cSELF_CONTACT\x10\x01\x12\x12\n\x0e\x41NOTHER_PERSON\x10\x02\x12\x13\n\x0fPORTABLE_OBJECT\x10\x03\x12\x15\n\x11STATIONARY_OBJECT\x10\x04\"\x1f\n\x08HandSide\x12\x08\n\x04LEFT\x10\x00\x12\t\n\x05RIGHT\x10\x01\"F\n\x0fObjectDetection\x12$\n\x04\x62\x62ox\x18\x01 \x01(\x0b\x32\x16.model.detections.BBox\x12\r\n\x05score\x18\x02 \x01(\x02\"\x98\x01\n\nDetections\x12\x10\n\x08video_id\x18\x01 \x01(\t\x12\x14\n\x0c\x66rame_number\x18\x02 \x01(\x05\x12.\n\x05hands\x18\x03 \x03(\x0b\x32\x1f.model.detections.HandDetection\x12\x32\n\x07objects\x18\x04 \x03(\x0b\x32!.model.detections.ObjectDetectionb\x06proto3'
21 | )
22 |
23 |
24 |
25 | _HANDDETECTION_HANDSTATE = _descriptor.EnumDescriptor(
26 | name='HandState',
27 | full_name='model.detections.HandDetection.HandState',
28 | filename=None,
29 | file=DESCRIPTOR,
30 | create_key=_descriptor._internal_create_key,
31 | values=[
32 | _descriptor.EnumValueDescriptor(
33 | name='NO_CONTACT', index=0, number=0,
34 | serialized_options=None,
35 | type=None,
36 | create_key=_descriptor._internal_create_key),
37 | _descriptor.EnumValueDescriptor(
38 | name='SELF_CONTACT', index=1, number=1,
39 | serialized_options=None,
40 | type=None,
41 | create_key=_descriptor._internal_create_key),
42 | _descriptor.EnumValueDescriptor(
43 | name='ANOTHER_PERSON', index=2, number=2,
44 | serialized_options=None,
45 | type=None,
46 | create_key=_descriptor._internal_create_key),
47 | _descriptor.EnumValueDescriptor(
48 | name='PORTABLE_OBJECT', index=3, number=3,
49 | serialized_options=None,
50 | type=None,
51 | create_key=_descriptor._internal_create_key),
52 | _descriptor.EnumValueDescriptor(
53 | name='STATIONARY_OBJECT', index=4, number=4,
54 | serialized_options=None,
55 | type=None,
56 | create_key=_descriptor._internal_create_key),
57 | ],
58 | containing_type=None,
59 | serialized_options=None,
60 | serialized_start=375,
61 | serialized_end=484,
62 | )
63 | _sym_db.RegisterEnumDescriptor(_HANDDETECTION_HANDSTATE)
64 |
65 | _HANDDETECTION_HANDSIDE = _descriptor.EnumDescriptor(
66 | name='HandSide',
67 | full_name='model.detections.HandDetection.HandSide',
68 | filename=None,
69 | file=DESCRIPTOR,
70 | create_key=_descriptor._internal_create_key,
71 | values=[
72 | _descriptor.EnumValueDescriptor(
73 | name='LEFT', index=0, number=0,
74 | serialized_options=None,
75 | type=None,
76 | create_key=_descriptor._internal_create_key),
77 | _descriptor.EnumValueDescriptor(
78 | name='RIGHT', index=1, number=1,
79 | serialized_options=None,
80 | type=None,
81 | create_key=_descriptor._internal_create_key),
82 | ],
83 | containing_type=None,
84 | serialized_options=None,
85 | serialized_start=486,
86 | serialized_end=517,
87 | )
88 | _sym_db.RegisterEnumDescriptor(_HANDDETECTION_HANDSIDE)
89 |
90 |
91 | _FLOATVECTOR = _descriptor.Descriptor(
92 | name='FloatVector',
93 | full_name='model.detections.FloatVector',
94 | filename=None,
95 | file=DESCRIPTOR,
96 | containing_type=None,
97 | create_key=_descriptor._internal_create_key,
98 | fields=[
99 | _descriptor.FieldDescriptor(
100 | name='x', full_name='model.detections.FloatVector.x', index=0,
101 | number=1, type=2, cpp_type=6, label=1,
102 | has_default_value=False, default_value=float(0),
103 | message_type=None, enum_type=None, containing_type=None,
104 | is_extension=False, extension_scope=None,
105 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
106 | _descriptor.FieldDescriptor(
107 | name='y', full_name='model.detections.FloatVector.y', index=1,
108 | number=2, type=2, cpp_type=6, label=1,
109 | has_default_value=False, default_value=float(0),
110 | message_type=None, enum_type=None, containing_type=None,
111 | is_extension=False, extension_scope=None,
112 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
113 | ],
114 | extensions=[
115 | ],
116 | nested_types=[],
117 | enum_types=[
118 | ],
119 | serialized_options=None,
120 | is_extendable=False,
121 | syntax='proto3',
122 | extension_ranges=[],
123 | oneofs=[
124 | ],
125 | serialized_start=33,
126 | serialized_end=68,
127 | )
128 |
129 |
130 | _BBOX = _descriptor.Descriptor(
131 | name='BBox',
132 | full_name='model.detections.BBox',
133 | filename=None,
134 | file=DESCRIPTOR,
135 | containing_type=None,
136 | create_key=_descriptor._internal_create_key,
137 | fields=[
138 | _descriptor.FieldDescriptor(
139 | name='left', full_name='model.detections.BBox.left', index=0,
140 | number=1, type=2, cpp_type=6, label=1,
141 | has_default_value=False, default_value=float(0),
142 | message_type=None, enum_type=None, containing_type=None,
143 | is_extension=False, extension_scope=None,
144 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
145 | _descriptor.FieldDescriptor(
146 | name='top', full_name='model.detections.BBox.top', index=1,
147 | number=2, type=2, cpp_type=6, label=1,
148 | has_default_value=False, default_value=float(0),
149 | message_type=None, enum_type=None, containing_type=None,
150 | is_extension=False, extension_scope=None,
151 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
152 | _descriptor.FieldDescriptor(
153 | name='right', full_name='model.detections.BBox.right', index=2,
154 | number=3, type=2, cpp_type=6, label=1,
155 | has_default_value=False, default_value=float(0),
156 | message_type=None, enum_type=None, containing_type=None,
157 | is_extension=False, extension_scope=None,
158 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
159 | _descriptor.FieldDescriptor(
160 | name='bottom', full_name='model.detections.BBox.bottom', index=3,
161 | number=4, type=2, cpp_type=6, label=1,
162 | has_default_value=False, default_value=float(0),
163 | message_type=None, enum_type=None, containing_type=None,
164 | is_extension=False, extension_scope=None,
165 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
166 | ],
167 | extensions=[
168 | ],
169 | nested_types=[],
170 | enum_types=[
171 | ],
172 | serialized_options=None,
173 | is_extendable=False,
174 | syntax='proto3',
175 | extension_ranges=[],
176 | oneofs=[
177 | ],
178 | serialized_start=70,
179 | serialized_end=134,
180 | )
181 |
182 |
183 | _HANDDETECTION = _descriptor.Descriptor(
184 | name='HandDetection',
185 | full_name='model.detections.HandDetection',
186 | filename=None,
187 | file=DESCRIPTOR,
188 | containing_type=None,
189 | create_key=_descriptor._internal_create_key,
190 | fields=[
191 | _descriptor.FieldDescriptor(
192 | name='bbox', full_name='model.detections.HandDetection.bbox', index=0,
193 | number=1, type=11, cpp_type=10, label=1,
194 | has_default_value=False, default_value=None,
195 | message_type=None, enum_type=None, containing_type=None,
196 | is_extension=False, extension_scope=None,
197 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
198 | _descriptor.FieldDescriptor(
199 | name='score', full_name='model.detections.HandDetection.score', index=1,
200 | number=2, type=2, cpp_type=6, label=1,
201 | has_default_value=False, default_value=float(0),
202 | message_type=None, enum_type=None, containing_type=None,
203 | is_extension=False, extension_scope=None,
204 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
205 | _descriptor.FieldDescriptor(
206 | name='state', full_name='model.detections.HandDetection.state', index=2,
207 | number=3, type=14, cpp_type=8, label=1,
208 | has_default_value=False, default_value=0,
209 | message_type=None, enum_type=None, containing_type=None,
210 | is_extension=False, extension_scope=None,
211 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
212 | _descriptor.FieldDescriptor(
213 | name='object_offset', full_name='model.detections.HandDetection.object_offset', index=3,
214 | number=4, type=11, cpp_type=10, label=1,
215 | has_default_value=False, default_value=None,
216 | message_type=None, enum_type=None, containing_type=None,
217 | is_extension=False, extension_scope=None,
218 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
219 | _descriptor.FieldDescriptor(
220 | name='side', full_name='model.detections.HandDetection.side', index=4,
221 | number=5, type=14, cpp_type=8, label=1,
222 | has_default_value=False, default_value=0,
223 | message_type=None, enum_type=None, containing_type=None,
224 | is_extension=False, extension_scope=None,
225 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
226 | ],
227 | extensions=[
228 | ],
229 | nested_types=[],
230 | enum_types=[
231 | _HANDDETECTION_HANDSTATE,
232 | _HANDDETECTION_HANDSIDE,
233 | ],
234 | serialized_options=None,
235 | is_extendable=False,
236 | syntax='proto3',
237 | extension_ranges=[],
238 | oneofs=[
239 | ],
240 | serialized_start=137,
241 | serialized_end=517,
242 | )
243 |
244 |
245 | _OBJECTDETECTION = _descriptor.Descriptor(
246 | name='ObjectDetection',
247 | full_name='model.detections.ObjectDetection',
248 | filename=None,
249 | file=DESCRIPTOR,
250 | containing_type=None,
251 | create_key=_descriptor._internal_create_key,
252 | fields=[
253 | _descriptor.FieldDescriptor(
254 | name='bbox', full_name='model.detections.ObjectDetection.bbox', index=0,
255 | number=1, type=11, cpp_type=10, label=1,
256 | has_default_value=False, default_value=None,
257 | message_type=None, enum_type=None, containing_type=None,
258 | is_extension=False, extension_scope=None,
259 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
260 | _descriptor.FieldDescriptor(
261 | name='score', full_name='model.detections.ObjectDetection.score', index=1,
262 | number=2, type=2, cpp_type=6, label=1,
263 | has_default_value=False, default_value=float(0),
264 | message_type=None, enum_type=None, containing_type=None,
265 | is_extension=False, extension_scope=None,
266 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
267 | ],
268 | extensions=[
269 | ],
270 | nested_types=[],
271 | enum_types=[
272 | ],
273 | serialized_options=None,
274 | is_extendable=False,
275 | syntax='proto3',
276 | extension_ranges=[],
277 | oneofs=[
278 | ],
279 | serialized_start=519,
280 | serialized_end=589,
281 | )
282 |
283 |
284 | _DETECTIONS = _descriptor.Descriptor(
285 | name='Detections',
286 | full_name='model.detections.Detections',
287 | filename=None,
288 | file=DESCRIPTOR,
289 | containing_type=None,
290 | create_key=_descriptor._internal_create_key,
291 | fields=[
292 | _descriptor.FieldDescriptor(
293 | name='video_id', full_name='model.detections.Detections.video_id', index=0,
294 | number=1, type=9, cpp_type=9, label=1,
295 | has_default_value=False, default_value=b"".decode('utf-8'),
296 | message_type=None, enum_type=None, containing_type=None,
297 | is_extension=False, extension_scope=None,
298 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
299 | _descriptor.FieldDescriptor(
300 | name='frame_number', full_name='model.detections.Detections.frame_number', index=1,
301 | number=2, type=5, cpp_type=1, label=1,
302 | has_default_value=False, default_value=0,
303 | message_type=None, enum_type=None, containing_type=None,
304 | is_extension=False, extension_scope=None,
305 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
306 | _descriptor.FieldDescriptor(
307 | name='hands', full_name='model.detections.Detections.hands', index=2,
308 | number=3, type=11, cpp_type=10, label=3,
309 | has_default_value=False, default_value=[],
310 | message_type=None, enum_type=None, containing_type=None,
311 | is_extension=False, extension_scope=None,
312 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
313 | _descriptor.FieldDescriptor(
314 | name='objects', full_name='model.detections.Detections.objects', index=3,
315 | number=4, type=11, cpp_type=10, label=3,
316 | has_default_value=False, default_value=[],
317 | message_type=None, enum_type=None, containing_type=None,
318 | is_extension=False, extension_scope=None,
319 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
320 | ],
321 | extensions=[
322 | ],
323 | nested_types=[],
324 | enum_types=[
325 | ],
326 | serialized_options=None,
327 | is_extendable=False,
328 | syntax='proto3',
329 | extension_ranges=[],
330 | oneofs=[
331 | ],
332 | serialized_start=592,
333 | serialized_end=744,
334 | )
335 |
336 | _HANDDETECTION.fields_by_name['bbox'].message_type = _BBOX
337 | _HANDDETECTION.fields_by_name['state'].enum_type = _HANDDETECTION_HANDSTATE
338 | _HANDDETECTION.fields_by_name['object_offset'].message_type = _FLOATVECTOR
339 | _HANDDETECTION.fields_by_name['side'].enum_type = _HANDDETECTION_HANDSIDE
340 | _HANDDETECTION_HANDSTATE.containing_type = _HANDDETECTION
341 | _HANDDETECTION_HANDSIDE.containing_type = _HANDDETECTION
342 | _OBJECTDETECTION.fields_by_name['bbox'].message_type = _BBOX
343 | _DETECTIONS.fields_by_name['hands'].message_type = _HANDDETECTION
344 | _DETECTIONS.fields_by_name['objects'].message_type = _OBJECTDETECTION
345 | DESCRIPTOR.message_types_by_name['FloatVector'] = _FLOATVECTOR
346 | DESCRIPTOR.message_types_by_name['BBox'] = _BBOX
347 | DESCRIPTOR.message_types_by_name['HandDetection'] = _HANDDETECTION
348 | DESCRIPTOR.message_types_by_name['ObjectDetection'] = _OBJECTDETECTION
349 | DESCRIPTOR.message_types_by_name['Detections'] = _DETECTIONS
350 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
351 |
352 | FloatVector = _reflection.GeneratedProtocolMessageType('FloatVector', (_message.Message,), {
353 | 'DESCRIPTOR' : _FLOATVECTOR,
354 | '__module__' : 'types_pb2'
355 | # @@protoc_insertion_point(class_scope:model.detections.FloatVector)
356 | })
357 | _sym_db.RegisterMessage(FloatVector)
358 |
359 | BBox = _reflection.GeneratedProtocolMessageType('BBox', (_message.Message,), {
360 | 'DESCRIPTOR' : _BBOX,
361 | '__module__' : 'types_pb2'
362 | # @@protoc_insertion_point(class_scope:model.detections.BBox)
363 | })
364 | _sym_db.RegisterMessage(BBox)
365 |
366 | HandDetection = _reflection.GeneratedProtocolMessageType('HandDetection', (_message.Message,), {
367 | 'DESCRIPTOR' : _HANDDETECTION,
368 | '__module__' : 'types_pb2'
369 | # @@protoc_insertion_point(class_scope:model.detections.HandDetection)
370 | })
371 | _sym_db.RegisterMessage(HandDetection)
372 |
373 | ObjectDetection = _reflection.GeneratedProtocolMessageType('ObjectDetection', (_message.Message,), {
374 | 'DESCRIPTOR' : _OBJECTDETECTION,
375 | '__module__' : 'types_pb2'
376 | # @@protoc_insertion_point(class_scope:model.detections.ObjectDetection)
377 | })
378 | _sym_db.RegisterMessage(ObjectDetection)
379 |
380 | Detections = _reflection.GeneratedProtocolMessageType('Detections', (_message.Message,), {
381 | 'DESCRIPTOR' : _DETECTIONS,
382 | '__module__' : 'types_pb2'
383 | # @@protoc_insertion_point(class_scope:model.detections.Detections)
384 | })
385 | _sym_db.RegisterMessage(Detections)
386 |
387 |
388 | # @@protoc_insertion_point(module_scope)
389 |
--------------------------------------------------------------------------------
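
`types_pb2.py` above is protoc output from `types.proto` and is not meant to be hand-edited; it defines the wire format of the per-frame hand and object detections consumed during preprocessing. A minimal round-trip sketch, assuming a protobuf 3.x runtime (this descriptor-based generated code predates the protobuf 4 code generator) and the repository root on `PYTHONPATH`; all field values are made up:

```python
from preprocess.types_pb2 import Detections, HandDetection

# build a Detections message with one right-hand detection
det = Detections(video_id="P01_01", frame_number=16827)
hand = det.hands.add()
hand.score = 0.93
hand.side = HandDetection.RIGHT
hand.state = HandDetection.PORTABLE_OBJECT
hand.bbox.left, hand.bbox.top, hand.bbox.right, hand.bbox.bottom = 120.0, 60.0, 200.0, 140.0

# serialize to bytes and parse back
payload = det.SerializeToString()
restored = Detections()
restored.ParseFromString(payload)
print(restored.hands[0].side == HandDetection.RIGHT)  # True
```
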
/preprocess/vis_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import matplotlib.pyplot as plt
4 | from preprocess.affordance_util import compute_heatmap
5 |
6 | hand_rgb = {"LEFT": (0, 90, 181), "RIGHT": (220, 50, 32)}
7 | object_rgb = (255, 194, 10)
8 |
9 |
10 | def vis_traj(frame_vis, traj, fill_indices=None, side=None, circle_radius=4, circle_thickness=3, line_thickness=2, style='line', gap=5):
11 | for idx in range(len(traj)):
12 | x, y = traj[idx]
13 | # thickness=-1 draws a filled circle; observed points (in fill_indices) and
14 | # interpolated points are currently styled the same way
15 | thickness = -1
16 |
17 | color = hand_rgb[side][::-1] if side is not None else (0, 255, 255)
18 | frame_vis = cv2.circle(frame_vis, (int(round(x)), int(round(y))), radius=circle_radius, color=color,
19 | thickness=thickness)
20 | if idx > 0:
21 | pt1 = (int(round(traj[idx-1][0])), int(round(traj[idx-1][1])))
22 | pt2 = (int(round(traj[idx][0])), int(round(traj[idx][1])))
23 | dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** .5
24 | pts = []
25 | for i in np.arange(0, dist, gap):
26 | r = i / dist
27 | x = int((pt1[0] * (1 - r) + pt2[0] * r) + .5)
28 | y = int((pt1[1] * (1 - r) + pt2[1] * r) + .5)
29 | p = (x, y)
30 | pts.append(p)
31 | if style == 'dotted':
32 | for p in pts:
33 | cv2.circle(frame_vis, p, circle_thickness, color, -1)
34 | else:
35 | if len(pts) > 0:
36 | s = pts[0]
37 | e = pts[0]
38 | i = 0
39 | for p in pts:
40 | s = e
41 | e = p
42 | if i % 2 == 1:
43 | cv2.line(frame_vis, s, e, color, line_thickness)
44 | i += 1
45 | return frame_vis
46 |
47 |
48 | def vis_hand_traj(frames, hand_trajs):
49 | frame_vis = frames[0].copy()
50 | for side in hand_trajs:
51 | meta = hand_trajs[side]
52 | traj, fill_indices = meta["traj"], meta["fill_indices"]
53 | frame_vis = vis_traj(frame_vis, traj, fill_indices, side)
54 | return frame_vis
55 |
56 |
57 | def vis_affordance(frame, affordance_info):
58 | select_points = affordance_info["select_points_homo"]
59 | hmap = compute_heatmap(select_points, (frame.shape[1], frame.shape[0]))
60 | hmap = (hmap * 255).astype(np.uint8)
61 | hmap = cv2.applyColorMap(hmap, colormap=cv2.COLORMAP_JET)
62 | for idx in range(len(select_points)):
63 | point = select_points[idx].astype(int)  # np.int is removed in recent NumPy versions
64 | frame = cv2.circle(frame, (point[0], point[1]), radius=2, color=(255, 0, 255),
65 | thickness=-1)
66 | overlay = (0.7 * frame + 0.3 * hmap).astype(np.uint8)
67 | return overlay
68 |
--------------------------------------------------------------------------------
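
A quick way to sanity-check the drawing helpers above is to run `vis_traj` on a blank canvas with a synthetic trajectory; `vis_hand_traj` and `vis_affordance` follow the same pattern but expect the dictionaries produced by `compute_hand_traj` and the affordance preprocessing. The canvas size, points and output filename below are arbitrary:

```python
import numpy as np
import cv2
from preprocess.vis_util import vis_traj

# blank 456x256 canvas and a made-up 10-point trajectory for the right hand
frame = np.full((256, 456, 3), 255, dtype=np.uint8)
traj = np.stack([np.linspace(60, 380, 10), np.linspace(200, 90, 10)], axis=1)

vis = vis_traj(frame.copy(), traj, fill_indices=list(range(10)), side="RIGHT")
cv2.imwrite("traj_vis.jpg", vis)
```
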
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision
2 | torch
3 | lmdbdict
4 | tqdm
5 | scikit-learn
6 | einops
--------------------------------------------------------------------------------
/traineval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.parallel
8 | import torch.optim
9 |
10 | from netscripts.get_datasets import get_dataset
11 | from netscripts.get_network import get_network
12 | from netscripts.get_optimizer import get_optimizer
13 | from netscripts import modelio
14 | from netscripts.epoch_feat import epoch_pass
15 | from options import netsopts, expopts
16 | from datasets.datasetopts import DatasetArgs
17 |
18 |
19 | def main(args):
20 |
21 | # Initialize random seeds
22 | torch.cuda.manual_seed_all(args.manual_seed)
23 | torch.manual_seed(args.manual_seed)
24 | np.random.seed(args.manual_seed)
25 | random.seed(args.manual_seed)
26 |
27 | datasetargs = DatasetArgs(ek_version=args.ek_version)
28 | num_frames_input = int(datasetargs.fps * datasetargs.t_buffer)
29 | num_frames_output = int(datasetargs.fps * datasetargs.t_ant)
30 | model = get_network(args, num_frames_input=num_frames_input,
31 | num_frames_output=num_frames_output)
32 |
33 | if args.use_cuda and torch.cuda.is_available():
34 | print("Using {} GPUs !".format(torch.cuda.device_count()))
35 | model.cuda()
36 |
37 | start_epoch = 0
38 | if args.resume is not None:
39 | device = torch.device('cuda') if torch.cuda.is_available() and args.use_cuda else torch.device('cpu')
40 | start_epoch = modelio.load_checkpoint(model, resume_path=args.resume[0], strict=False, device=device)
41 | print("Loaded checkpoint from epoch {}, starting from there".format(start_epoch))
42 |
43 | _, dls = get_dataset(args, base_path="./")
44 |
45 | if args.evaluate:
46 | args.epochs = start_epoch + 1
47 | traj_val_loader = None
48 | else:
49 | train_loader = dls['train']
50 | traj_val_loader = dls['validation']
51 | print("training dataset size: {}".format(len(train_loader.dataset)))
52 | optimizer, scheduler = get_optimizer(args, model=model, train_loader=train_loader)
53 |
54 | if not args.traj_only:
55 | val_loader = dls['eval']
56 | else:
57 | traj_val_loader = val_loader = dls['validation']
58 | print("evaluation dataset size: {}".format(len(val_loader.dataset)))
59 |
60 | for epoch in range(start_epoch, args.epochs):
61 | if not args.evaluate:
62 | print("Using lr {}".format(optimizer.param_groups[0]["lr"]))
63 | epoch_pass(
64 | loader=train_loader,
65 | model=model,
66 | phase='train',
67 | optimizer=optimizer,
68 | epoch=epoch,
69 | train=True,
70 | use_cuda=args.use_cuda,
71 | scheduler=scheduler)
72 |
73 | if args.evaluate or (epoch + 1) % args.test_freq == 0:
74 | with torch.no_grad():
75 | if not args.traj_only:
76 | epoch_pass(
77 | loader=val_loader,
78 | model=model,
79 | epoch=epoch,
80 | phase='affordance',
81 | optimizer=None,
82 | train=False,
83 | use_cuda=args.use_cuda,
84 | num_samples=args.num_samples,
85 | num_points=args.num_points)
86 | else:
87 | with torch.no_grad():
88 | epoch_pass(
89 | loader=traj_val_loader,
90 | model=model,
91 | epoch=epoch,
92 | phase='traj',
93 | optimizer=None,
94 | train=False,
95 | use_cuda=args.use_cuda,
96 | num_samples=args.num_samples)
97 |
98 | if not args.evaluate:
99 | if (epoch + 1 - args.warmup_epochs) % args.snapshot == 0:
100 | print(f"save epoch {epoch+1} checkpoint to {os.path.join(args.host_folder,args.exp_id)}")
101 | modelio.save_checkpoint(
102 | {
103 | "epoch": epoch + 1,
104 | "network": args.network,
105 | "state_dict": model.state_dict(),
106 | "optimizer": optimizer.state_dict(),
107 | },
108 | checkpoint=os.path.join(args.host_folder, args.exp_id),
109 | filename = f"checkpoint_{epoch+1}.pth.tar")
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = argparse.ArgumentParser(description="HOI Forecasting")
114 | netsopts.add_nets_opts(parser)
115 | netsopts.add_train_opts(parser)
116 | expopts.add_exp_opts(parser)
117 | args = parser.parse_args()
118 |
119 | if args.use_cuda and torch.cuda.is_available():
120 | num_gpus = torch.cuda.device_count()
121 | args.batch_size = args.batch_size * num_gpus
122 | args.lr = args.lr * num_gpus
123 |
124 | if args.traj_only: assert args.evaluate, "trajectory-only evaluation runs on the validation set; set --evaluate"
125 | main(args)
126 | print("All done !")
--------------------------------------------------------------------------------
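
The checkpoint written above bundles the epoch, network name and the model and optimizer state dicts under fixed keys. A minimal sketch of saving and restoring such a checkpoint with plain `torch.save` / `torch.load` (`netscripts/modelio.py` presumably wraps this pattern; the tiny `Linear` model and the paths are placeholders):

```python
import os
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the actual network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# save with the same keys used in traineval.py
checkpoint_dir = "checkpoints/demo"
os.makedirs(checkpoint_dir, exist_ok=True)
state = {"epoch": 1, "network": "demo", "state_dict": model.state_dict(),
         "optimizer": optimizer.state_dict()}
torch.save(state, os.path.join(checkpoint_dir, "checkpoint_1.pth.tar"))

# restore the weights (and optionally the optimizer) and resume from the stored epoch
ckpt = torch.load(os.path.join(checkpoint_dir, "checkpoint_1.pth.tar"), map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
print("resuming from epoch", start_epoch)
```
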